From 213a311fd9d4760b5e0766e397816e1bd4663e6e Mon Sep 17 00:00:00 2001 From: klemek Date: Sun, 12 Apr 2026 18:17:12 +0200 Subject: [PATCH] security: check host validity before anything --- pyproject.toml | 2 +- src/handler.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ba2863..926b2fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,4 +16,4 @@ dev = [ [tool.ruff.lint] select = ["ALL"] -ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603"] \ No newline at end of file +ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603", "PLR0911"] \ No newline at end of file diff --git a/src/handler.py b/src/handler.py index 5ee9080..35e4730 100644 --- a/src/handler.py +++ b/src/handler.py @@ -19,6 +19,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): server_version = "StaplerServer/" + project.get_version() CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" PATH_REGEX = re.compile(r"^\/([\w-]+)\/") + HOST_PART_REGEX = re.compile(r"^[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9]$") + HOST_PORT_REGEX = re.compile(r"^\d{2,5}$") @typing.override def __init__( @@ -55,6 +57,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): self.__pre_log_request() if (sub_path := self.__check_update_request()) is None: return None + host: str | None = self.headers["X-Host"] + if host is not None and not self.__valid_host(host): + return self.send_error( + http.HTTPStatus.BAD_REQUEST, "Invalid requested host" + ) if (content_length := self.__get_length()) == 0: return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") if content_length > self.max_size_bytes: @@ -73,8 +80,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): http.HTTPStatus.CREATED, f"Resource /{sub_path}/ updated", ) - if self.headers["X-Host"] is not None: - self.registry.set_host(sub_path, self.headers["X-Host"]) + if host is not None: + self.registry.set_host(sub_path, host) self.registry.add(sub_path) return None @@ -204,6 +211,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): return 0 return int(self.headers["Content-Length"]) + def __valid_host(self, host: str) -> bool: + hostname, port = host.split(":", maxsplit=2) + for part in hostname.split("."): + if not self.HOST_PART_REGEX.fullmatch(part): + return False + return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port)) + def __server_index(self) -> None: self.__send_basic_body(self.server_version + "\n")