# Patch summary and review notes. This preamble sits before the first
# "diff --git" header and is not part of any hunk.
#
# NOTE(review): the hunks below appear with their line breaks and indentation
# collapsed; they are kept verbatim here and will presumably need their
# original newlines restored before the patch can be applied.
#
# src/data_dir.py
#   - Adds `import re` and the class constant DataDir.PATH_REGEX.
#   - Adds DataDir.__valid_path(path, exists=True): true when (exists is False
#     or root_path/path is a directory) and PATH_REGEX matches path.
#   - list_paths keeps only entries accepted by __valid_path.
#   - set_host, has_index, get_host and remove run their bodies only when
#     __valid_path(path) holds; extract_tar_bytes does the same with
#     exists=False; exists() now returns __valid_path(path).
#
# src/handler.py
#   - Adds StaplerRequestHandler.PATH_REGEX (same pattern as the old inline
#     regex in get_subpath, minus the trailing "$").
#   - get_subpath gains match_full=True: fullmatch when true, prefix match
#     when false.
#   - translate_path returns "" when get_subpath(match_full=False) is None.
#
# Review notes:
#   - has_index has no fallback return, so a rejected path yields None rather
#     than False.
#   - set_host, extract_tar_bytes and remove do nothing for a rejected path
#     and give the caller no signal.
#   - DataDir.PATH_REGEX is used with match(); "$" also matches just before a
#     trailing newline, so "name\n" is accepted (reachable with exists=False).
#     fullmatch() would close this. "\w" is Unicode-aware unless re.ASCII is
#     set.
#   - extract_tar_bytes validates the target directory name only; member names
#     inside the archive go to extractall unchecked. If tar_bytes can be
#     untrusted, consider an extraction filter -- TODO confirm.
#   - translate_path validates via get_subpath, which reads self.path, not the
#     local `path` that was just prefixed with the host-mapped page path.
#     Presumably host-mapped requests without a leading "/name/" segment are
#     now rejected -- verify against the registry/host routing.
#
diff --git a/src/data_dir.py b/src/data_dir.py index 36ebeb4..c5b76e7 100644 --- a/src/data_dir.py +++ b/src/data_dir.py @@ -2,10 +2,12 @@ import os import io import tarfile import shutil +import re class DataDir: HOST_FILE = ".host" + PATH_REGEX = re.compile(r"^[\w-]+$") def __init__(self, root_path: str): self.root_path = root_path @@ -13,40 +15,49 @@ class DataDir: def list_paths(self) -> list[str]: paths: list[str] = [] for path in os.listdir(self.root_path): - if os.path.isdir(os.path.join(self.root_path, path)): + if self.__valid_path(path): paths += [path] return paths + def __valid_path(self, path: str, exists: bool = True) -> bool: + return ( + not exists or os.path.isdir(os.path.join(self.root_path, path)) + ) and self.PATH_REGEX.match(path) is not None + def set_host(self, path: str, host: str): - path_host = os.path.join(self.root_path, path, self.HOST_FILE) - with open(path_host, mode="w") as host_file: - host_file.write(host) + if self.__valid_path(path): + path_host = os.path.join(self.root_path, path, self.HOST_FILE) + with open(path_host, mode="w") as host_file: + host_file.write(host) def has_index(self, path: str): - path_index = os.path.join(self.root_path, path, "index.html") - return os.path.exists(path_index) and os.path.isfile(path_index) + if self.__valid_path(path): + path_index = os.path.join(self.root_path, path, "index.html") + return os.path.exists(path_index) and os.path.isfile(path_index) def get_host(self, path: str): - path_host = os.path.join(self.root_path, path, self.HOST_FILE) - if os.path.exists(path_host) and os.path.isfile(path_host): - try: - with open(path_host) as host_file: - return host_file.read().split("\n")[0].strip() - except Exception: - pass - return None + if self.__valid_path(path): + path_host = os.path.join(self.root_path, path, self.HOST_FILE) + if os.path.exists(path_host) and os.path.isfile(path_host): + try: + with open(path_host) as host_file: + return host_file.read().split("\n")[0].strip() + except 
Exception: + pass + return None def extract_tar_bytes(self, path: str, tar_bytes: io.BytesIO): - target_path = os.path.join(self.root_path, path) - with tarfile.open(fileobj=tar_bytes) as tar_file: - if os.path.exists(target_path): - shutil.rmtree(target_path) - tar_file.extractall(target_path) + if self.__valid_path(path, exists=False): + target_path = os.path.join(self.root_path, path) + with tarfile.open(fileobj=tar_bytes) as tar_file: + if os.path.exists(target_path): + shutil.rmtree(target_path) + tar_file.extractall(target_path) def remove(self, path: str): - target_path = os.path.join(self.root_path, path) - shutil.rmtree(target_path) + if self.__valid_path(path): + target_path = os.path.join(self.root_path, path) + shutil.rmtree(target_path) def exists(self, path: str): - target_path = os.path.join(self.root_path, path) - return os.path.exists(target_path) and os.path.isdir(target_path) + return self.__valid_path(path) diff --git a/src/handler.py b/src/handler.py index 18e3000..cc5ae61 100644 --- a/src/handler.py +++ b/src/handler.py @@ -12,6 +12,7 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler): protocol_version = "HTTP/2.0" server_version = "StaplerServer/" + project.get_version() CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" + PATH_REGEX = re.compile(r"^\/([\w-]+)\/") def __init__( self, *args, params: params.Parameters, registry: registry.Registry, **kwargs @@ -34,6 +35,8 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler): if (page := self.registry.get_from_host(self.get_host())) is not None: path = f"/{page.path}" + path path = super().translate_path(path) + if self.get_subpath(match_full=False) is None: # not a valid path + return "" if os.path.basename(path).startswith("."): # hidden files return "" return path @@ -83,8 +86,12 @@ class StaplerRequestHandler(http.server.SimpleHTTPRequestHandler): ) self.registry.remove(sub_path) - def get_subpath(self) -> str | None: - if (match := 
re.match(r"^\/([\w-]+)\/$", self.path)) is not None: + def get_subpath(self, match_full: bool = True) -> str | None: + if match_full: + match = self.PATH_REGEX.fullmatch(self.path) + else: + match = self.PATH_REGEX.match(self.path) + if match is not None: return match.group(1) return None