fix: better path detection

This commit is contained in:
2026-04-12 13:16:55 +02:00
parent 379defc3c0
commit f3b0b40251
2 changed files with 43 additions and 25 deletions
+34 -23
View File
@@ -2,10 +2,12 @@ import os
import io import io
import tarfile import tarfile
import shutil import shutil
import re
class DataDir: class DataDir:
HOST_FILE = ".host" HOST_FILE = ".host"
PATH_REGEX = re.compile(r"^[\w-]+$")
def __init__(self, root_path: str): def __init__(self, root_path: str):
self.root_path = root_path self.root_path = root_path
@@ -13,40 +15,49 @@ class DataDir:
def list_paths(self) -> list[str]: def list_paths(self) -> list[str]:
paths: list[str] = [] paths: list[str] = []
for path in os.listdir(self.root_path): for path in os.listdir(self.root_path):
if os.path.isdir(os.path.join(self.root_path, path)): if self.__valid_path(path):
paths += [path] paths += [path]
return paths return paths
def __valid_path(self, path: str, exists: bool = True) -> bool:
return (
not exists or os.path.isdir(os.path.join(self.root_path, path))
) and self.PATH_REGEX.match(path) is not None
def set_host(self, path: str, host: str): def set_host(self, path: str, host: str):
path_host = os.path.join(self.root_path, path, self.HOST_FILE) if self.__valid_path(path):
with open(path_host, mode="w") as host_file: path_host = os.path.join(self.root_path, path, self.HOST_FILE)
host_file.write(host) with open(path_host, mode="w") as host_file:
host_file.write(host)
def has_index(self, path: str): def has_index(self, path: str):
path_index = os.path.join(self.root_path, path, "index.html") if self.__valid_path(path):
return os.path.exists(path_index) and os.path.isfile(path_index) path_index = os.path.join(self.root_path, path, "index.html")
return os.path.exists(path_index) and os.path.isfile(path_index)
def get_host(self, path: str): def get_host(self, path: str):
path_host = os.path.join(self.root_path, path, self.HOST_FILE) if self.__valid_path(path):
if os.path.exists(path_host) and os.path.isfile(path_host): path_host = os.path.join(self.root_path, path, self.HOST_FILE)
try: if os.path.exists(path_host) and os.path.isfile(path_host):
with open(path_host) as host_file: try:
return host_file.read().split("\n")[0].strip() with open(path_host) as host_file:
except Exception: return host_file.read().split("\n")[0].strip()
pass except Exception:
return None pass
return None
def extract_tar_bytes(self, path: str, tar_bytes: io.BytesIO): def extract_tar_bytes(self, path: str, tar_bytes: io.BytesIO):
target_path = os.path.join(self.root_path, path) if self.__valid_path(path, exists=False):
with tarfile.open(fileobj=tar_bytes) as tar_file: target_path = os.path.join(self.root_path, path)
if os.path.exists(target_path): with tarfile.open(fileobj=tar_bytes) as tar_file:
shutil.rmtree(target_path) if os.path.exists(target_path):
tar_file.extractall(target_path) shutil.rmtree(target_path)
tar_file.extractall(target_path)
def remove(self, path: str): def remove(self, path: str):
target_path = os.path.join(self.root_path, path) if self.__valid_path(path):
shutil.rmtree(target_path) target_path = os.path.join(self.root_path, path)
shutil.rmtree(target_path)
def exists(self, path: str): def exists(self, path: str):
target_path = os.path.join(self.root_path, path) return self.__valid_path(path)
return os.path.exists(target_path) and os.path.isdir(target_path)
+9 -2
View File
@@ -12,6 +12,7 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler):
protocol_version = "HTTP/2.0" protocol_version = "HTTP/2.0"
server_version = "StaplerServer/" + project.get_version() server_version = "StaplerServer/" + project.get_version()
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
def __init__( def __init__(
self, *args, params: params.Parameters, registry: registry.Registry, **kwargs self, *args, params: params.Parameters, registry: registry.Registry, **kwargs
@@ -34,6 +35,8 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler):
if (page := self.registry.get_from_host(self.get_host())) is not None: if (page := self.registry.get_from_host(self.get_host())) is not None:
path = f"/{page.path}" + path path = f"/{page.path}" + path
path = super().translate_path(path) path = super().translate_path(path)
if self.get_subpath(match_full=False) is None: # not a valid path
return ""
if os.path.basename(path).startswith("."): # hidden files if os.path.basename(path).startswith("."): # hidden files
return "" return ""
return path return path
@@ -83,8 +86,12 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler):
) )
self.registry.remove(sub_path) self.registry.remove(sub_path)
def get_subpath(self) -> str | None: def get_subpath(self, match_full: bool = True) -> str | None:
if (match := re.match(r"^\/([\w-]+)\/$", self.path)) is not None: if match_full:
match = self.PATH_REGEX.fullmatch(self.path)
else:
match = self.PATH_REGEX.match(self.path)
if match is not None:
return match.group(1) return match.group(1)
return None return None