security: check host validity before anything
This commit is contained in:
+1
-1
@@ -16,4 +16,4 @@ dev = [
|
|||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["ALL"]
|
select = ["ALL"]
|
||||||
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603"]
|
ignore = ["D", "E501", "S104", "PLR2004", "ANN401", "BLE001", "COM812", "S603", "PLR0911"]
|
||||||
+16
-2
@@ -19,6 +19,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
server_version = "StaplerServer/" + project.get_version()
|
server_version = "StaplerServer/" + project.get_version()
|
||||||
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
|
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
|
||||||
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
|
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
|
||||||
|
HOST_PART_REGEX = re.compile(r"^[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9]$")
|
||||||
|
HOST_PORT_REGEX = re.compile(r"^\d{2,5}$")
|
||||||
|
|
||||||
@typing.override
|
@typing.override
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -55,6 +57,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
self.__pre_log_request()
|
self.__pre_log_request()
|
||||||
if (sub_path := self.__check_update_request()) is None:
|
if (sub_path := self.__check_update_request()) is None:
|
||||||
return None
|
return None
|
||||||
|
host: str | None = self.headers["X-Host"]
|
||||||
|
if host is not None and not self.__valid_host(host):
|
||||||
|
return self.send_error(
|
||||||
|
http.HTTPStatus.BAD_REQUEST, "Invalid requested host"
|
||||||
|
)
|
||||||
if (content_length := self.__get_length()) == 0:
|
if (content_length := self.__get_length()) == 0:
|
||||||
return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
|
return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
|
||||||
if content_length > self.max_size_bytes:
|
if content_length > self.max_size_bytes:
|
||||||
@@ -73,8 +80,8 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
http.HTTPStatus.CREATED,
|
http.HTTPStatus.CREATED,
|
||||||
f"Resource /{sub_path}/ updated",
|
f"Resource /{sub_path}/ updated",
|
||||||
)
|
)
|
||||||
if self.headers["X-Host"] is not None:
|
if host is not None:
|
||||||
self.registry.set_host(sub_path, self.headers["X-Host"])
|
self.registry.set_host(sub_path, host)
|
||||||
self.registry.add(sub_path)
|
self.registry.add(sub_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -204,6 +211,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
return 0
|
return 0
|
||||||
return int(self.headers["Content-Length"])
|
return int(self.headers["Content-Length"])
|
||||||
|
|
||||||
|
def __valid_host(self, host: str) -> bool:
|
||||||
|
hostname, port = host.split(":", maxsplit=2)
|
||||||
|
for part in hostname.split("."):
|
||||||
|
if not self.HOST_PART_REGEX.fullmatch(part):
|
||||||
|
return False
|
||||||
|
return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port))
|
||||||
|
|
||||||
def __server_index(self) -> None:
|
def __server_index(self) -> None:
|
||||||
self.__send_basic_body(self.server_version + "\n")
|
self.__send_basic_body(self.server_version + "\n")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user