security: check host validity before anything

This commit is contained in:
2026-04-12 18:17:12 +02:00
parent ef92d95115
commit 213a311fd9
2 changed files with 17 additions and 3 deletions
+1 -1
View File
@@ -16,4 +16,4 @@ dev = [
[tool.ruff.lint]
select = ["ALL"]
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603"]
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603", "PLR0911"]
+16 -2
View File
@@ -19,6 +19,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
server_version = "StaplerServer/" + project.get_version()
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
HOST_PART_REGEX = re.compile(r"^[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9]$")
HOST_PORT_REGEX = re.compile(r"^\d{2,5}$")
@typing.override
def __init__(
@@ -55,6 +57,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
self.__pre_log_request()
if (sub_path := self.__check_update_request()) is None:
return None
host: str | None = self.headers["X-Host"]
if host is not None and not self.__valid_host(host):
return self.send_error(
http.HTTPStatus.BAD_REQUEST, "Invalid requested host"
)
if (content_length := self.__get_length()) == 0:
return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
if content_length > self.max_size_bytes:
@@ -73,8 +80,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
http.HTTPStatus.CREATED,
f"Resource /{sub_path}/ updated",
)
if self.headers["X-Host"] is not None:
self.registry.set_host(sub_path, self.headers["X-Host"])
if host is not None:
self.registry.set_host(sub_path, host)
self.registry.add(sub_path)
return None
@@ -204,6 +211,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
return 0
return int(self.headers["Content-Length"])
def __valid_host(self, host: str) -> bool:
hostname, port = host.split(":", maxsplit=2)
for part in hostname.split("."):
if not self.HOST_PART_REGEX.fullmatch(part):
return False
return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port))
def __server_index(self) -> None:
self.__send_basic_body(self.server_version + "\n")