security: check host validity before anything
This commit is contained in:
+1
-1
@@ -16,4 +16,4 @@ dev = [
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603"]
|
||||
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603", "PLR0911"]
|
||||
+16
-2
@@ -19,6 +19,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
server_version = "StaplerServer/" + project.get_version()
|
||||
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
|
||||
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
|
||||
HOST_PART_REGEX = re.compile(r"^[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9]$")
|
||||
HOST_PORT_REGEX = re.compile(r"^\d{2,5}$")
|
||||
|
||||
@typing.override
|
||||
def __init__(
|
||||
@@ -55,6 +57,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
self.__pre_log_request()
|
||||
if (sub_path := self.__check_update_request()) is None:
|
||||
return None
|
||||
host: str | None = self.headers["X-Host"]
|
||||
if host is not None and not self.__valid_host(host):
|
||||
return self.send_error(
|
||||
http.HTTPStatus.BAD_REQUEST, "Invalid requested host"
|
||||
)
|
||||
if (content_length := self.__get_length()) == 0:
|
||||
return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
|
||||
if content_length > self.max_size_bytes:
|
||||
@@ -73,8 +80,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
http.HTTPStatus.CREATED,
|
||||
f"Resource /{sub_path}/ updated",
|
||||
)
|
||||
if self.headers["X-Host"] is not None:
|
||||
self.registry.set_host(sub_path, self.headers["X-Host"])
|
||||
if host is not None:
|
||||
self.registry.set_host(sub_path, host)
|
||||
self.registry.add(sub_path)
|
||||
return None
|
||||
|
||||
@@ -204,6 +211,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
return 0
|
||||
return int(self.headers["Content-Length"])
|
||||
|
||||
def __valid_host(self, host: str) -> bool:
|
||||
hostname, port = host.split(":", maxsplit=2)
|
||||
for part in hostname.split("."):
|
||||
if not self.HOST_PART_REGEX.fullmatch(part):
|
||||
return False
|
||||
return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port))
|
||||
|
||||
def __server_index(self) -> None:
|
||||
self.__send_basic_body(self.server_version + "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user