diff --git a/src/handlers.py b/src/handlers.py index 4158516..4fcb9e2 100644 --- a/src/handlers.py +++ b/src/handlers.py @@ -15,7 +15,7 @@ if typing.TYPE_CHECKING: from . import cert_manager, params, registry, token_manager -class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): +class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): @typing.override def __init__( self, @@ -37,14 +37,12 @@ class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): ) -> None: shortmsg, longmsg = self.responses[code] if message is None: - message = shortmsg + message = shortmsg # pragma: no cover if explain is None: explain = longmsg - if hasattr(self, "headers") and ( - "Accept" not in self.headers["Accept"] or "text/" in self.headers["Accept"] - ): + if "text/" in self._get_header("Accept"): self.send_basic_body( - f"{code} {message}\n{explain}\n\n{self._server_signature()}", + f"{code} {message}\n{explain}\n\n{self.server_signature()}", code=code, message=message, ) @@ -52,17 +50,17 @@ class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.send_status_only(code, message) @typing.override - def log_message(self, format: str, *args: typing.Any) -> None: + def log_message(self, format: str, *args: typing.Any) -> None: # pragma: no cover fmt = "%s - " + format self.logger.info(fmt, self.address_string(), *args) @typing.override - def log_error(self, format: str, *args: typing.Any) -> None: + def log_error(self, format: str, *args: typing.Any) -> None: # pragma: no cover fmt = "%s - " + format self.logger.error(fmt, self.address_string(), *args) @typing.override - def log_request(self, code: str = "?", size: str = "-") -> None: # ty:ignore[invalid-method-override] + def log_request(self, code: str = "?", size: str = "-") -> None: # ty:ignore[invalid-method-override] # pragma: no cover if isinstance(code, http.HTTPStatus): color = logs.TermColor.RED if 100 <= code < 200: @@ -91,14 +89,15 @@ class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): message: str | None = None, headers: dict[str, str] | None = None, ) -> None: + if headers is None: + headers = {} encoded: bytes = body.encode() self.out_size = len(encoded) self.send_response(code, message) self.send_header("Content-type", f"{content_type}; charset=UTF-8") self.send_header("Content-Length", str(len(encoded))) - if headers is not None: - for header, value in headers.items(): - self.send_header(header, value) + for header, value in headers.items(): + self.send_header(header, value) # pragma: no cover self.end_headers() self.wfile.write(encoded) self.close_connection = True @@ -109,25 +108,35 @@ class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): message: str | None = None, headers: dict[str, str] | None = None, ) -> None: + if headers is None: + headers = {} self.send_response(code, message) self.send_header("Content-Length", "0") - if headers is not None: - for header, value in headers.items(): - self.send_header(header, value) + for header, value in headers.items(): + self.send_header(header, value) self.end_headers() self.close_connection = True def _get_host(self) -> str: - if not hasattr(self, "headers") or self.headers["Host"] is None: - return self.default_host - return self.headers["Host"].split(":", maxsplit=2)[0] + host = self._get_header("Host", self.default_host) + return host.split(":", maxsplit=2)[0] def _get_length(self) -> int: - if not hasattr(self, "headers") or not self.headers["Content-Length"]: - return 0 - return int(self.headers["Content-Length"]) + return int(self._get_header("Content-Length", "0")) - def _pre_log_request(self) -> None: + def _get_header(self, key: str, default_value: str = "") -> str: + if self._has_header(key): + return self.headers[key] + return default_value + + def _has_header(self, key: str) -> bool: + return ( + hasattr(self, "headers") + and key in self.headers + and len(self.headers[key]) > 0 + ) + + def _pre_log_request(self) -> None: # pragma: no cover args = ("...", self.address_string(), self._get_host(), self.requestline) fmt = "← %s - %s - %s - %s" if (size := self._get_length()) > 0: @@ -135,17 +144,20 @@ class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): fmt += " - %s" self.logger.debug(fmt, *args) - def _server_signature(self) -> str: + def server_signature(self) -> str: return self.server_version + "\n\n" + STAPLER_ASCII + "\n" -class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): +class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): protocol_version = "HTTP/2.0" server_version = "StaplerServer/" + project.get_version() CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" - PATH_REGEX = re.compile(r"^\/([\w-]+)\/") - HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$") + UPDATE_PATH_REGEX = re.compile(r"^\/([\w-]+)\/?$") + GET_PATH_REGEX = re.compile(r"^\/([\w-]+)\/") + HOST_PART_REGEX = re.compile(r"^([a-z0-9]|[a-z0-9][a-z0-9-]{,61}[a-z0-9])$") AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"] + TOKEN_HEADER = "X-Token" # noqa: S105 + HOST_HEADER = "X-Host" @typing.override def __init__( @@ -175,7 +187,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): def do_GET(self) -> None: self._pre_log_request() if self.path == "/" and self._get_host() == self.default_host: - return self.send_basic_body(self._server_signature()) + return self.send_basic_body(self.server_signature()) super().do_GET() return None @@ -183,14 +195,20 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): self._pre_log_request() if (sub_path := self.__check_update_request()) is None: return None - host: str | None = self.headers["X-Host"] + host: str | None = ( + self._get_header(self.HOST_HEADER).lower() + if self._has_header(self.HOST_HEADER) + else None + ) if host is not None and not self.__valid_host(host): return self.send_error( http.HTTPStatus.BAD_REQUEST, "Invalid requested host" ) if ( - page := self.registry.get_from_host(host) - ) is not None and page.path != sub_path: + host is not None + and (page := self.registry.get_from_host(host)) is not None + and page.path != sub_path + ): return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") if (content_length := self._get_length()) == 0: return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") @@ -211,7 +229,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): f"Resource /{sub_path}/ updated", ) self.registry.add(sub_path) - self.token_manager.set_token(self.headers["X-Token"], sub_path) + self.token_manager.set_token(self._get_header(self.TOKEN_HEADER), sub_path) if host is not None and self.cert_manager.create_or_update(host): self.registry.set_host(sub_path, host) return None @@ -243,12 +261,18 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): def translate_path(self, path: str) -> str: if path.startswith(self.CERTBOT_CHALLENGE_PATH): return self.certbot_www + path - if (page := self.registry.get_from_host(host := self._get_host())) is not None: + host = self._get_host() + if ( + host != self.default_host + and (page := self.registry.get_from_host(host := self._get_host())) + is not None + ): path = f"/{page.path}" + path elif host != self.default_host: return "" elif ( - path not in self.AUTHORIZED_PATHS and self.__get_subpath(path) is None + path not in self.AUTHORIZED_PATHS + and self.__get_subpath(path, self.GET_PATH_REGEX) is None ): # not a valid path return "" if pathlib.Path(path).name.startswith("."): # hidden files @@ -256,13 +280,14 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): return super().translate_path(path) def __check_update_request(self) -> str | None: - if (token := self.headers["X-Token"]) is None: + if not self._has_header(self.TOKEN_HEADER): self.send_error(http.HTTPStatus.BAD_REQUEST, "No X-Token header in request") return None + token = self._get_header(self.TOKEN_HEADER) if not self.token_manager.is_valid(token): self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") return None - if (sub_path := self.__get_subpath_full(self.path)) is None: + if (sub_path := self.__get_subpath(self.path, self.UPDATE_PATH_REGEX)) is None: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") return None if not self.token_manager.is_valid_for_path(token, sub_path): @@ -270,28 +295,26 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): return None return sub_path - def __get_subpath(self, path: str) -> str | None: - if (match := self.PATH_REGEX.match(path)) is not None: - return match.group(1) - return None - - def __get_subpath_full(self, path: str) -> str | None: - if (match := self.PATH_REGEX.fullmatch(path)) is not None: + def __get_subpath(self, path: str, regex: re.Pattern) -> str | None: + if (match := regex.match(path.lower())) is not None: return match.group(1) return None def __valid_host(self, host: str) -> bool: - return all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")) + return ( + all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")) + and len(host) < 256 + ) -class UpgradeHandler(_BaseHandler): +class UpgradeHandler(BaseHandler): server_version = "StaplerUpgradeServer/" + project.get_version() def do_HEAD(self) -> None: self._pre_log_request() self.send_status_only( http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": f"https://{self._get_host()}{self.path}"}, + headers={"Location": f"https://{self._get_host()}{self.path.lower()}"}, ) def do_GET(self) -> None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py new file mode 100644 index 0000000..c017552 --- /dev/null +++ b/tests/test_handlers.py @@ -0,0 +1,664 @@ +import abc +import collections +import contextlib +import http +import http.server +import io +import logging +import tarfile +import typing +import unittest.mock + +from src.cert_manager import CertManager +from src.data_dir import DataDir +from src.handlers import BaseHandler, RequestHandler, UpgradeHandler +from src.page import Page +from src.params import Parameters +from src.registry import Registry +from src.token_manager import TokenManager + +from . import BaseTestCase + + +class BaseHandlerTestCase(BaseTestCase, abc.ABC): + @abc.abstractmethod + def _get_handler( + self, + path: str = "/", + headers: dict[str, str | None] | None = None, + rfile: io.BufferedIOBase | None = None, + ) -> BaseHandler: + pass + + @contextlib.contextmanager + def expects_status_only( + self, + handler: BaseHandler, + code: int, + message: str | None = None, + headers: dict[str, str] | None = None, + ) -> typing.Iterator[None]: + if headers is None: + headers = {} + send_response_mock = handler.send_response = unittest.mock.Mock() # ty:ignore[invalid-assignment] + send_header_mock = handler.send_header = unittest.mock.Mock() # ty:ignore[invalid-assignment] + end_headers_mock = handler.end_headers = unittest.mock.Mock() # ty:ignore[invalid-assignment] + yield + send_response_mock.assert_called_once_with(code, message) + send_header_mock.assert_has_calls( + [ + unittest.mock.call("Content-Length", "0"), + ] + + [unittest.mock.call(header, value) for header, value in headers.items()], + any_order=True, + ) + end_headers_mock.assert_called_once() + + @contextlib.contextmanager + def expects_basic_body( # noqa: PLR0913 + self, + handler: BaseHandler, + body: str, + content_type: str = "text/plain", + code: int = http.HTTPStatus.OK, + message: str | None = None, + headers: dict[str, str] | None = None, + ) -> typing.Iterator[None]: + if headers is None: + headers = {} + send_response_mock = handler.send_response = unittest.mock.Mock() # ty:ignore[invalid-assignment] + send_header_mock = handler.send_header = unittest.mock.Mock() # ty:ignore[invalid-assignment] + end_headers_mock = handler.end_headers = unittest.mock.Mock() # ty:ignore[invalid-assignment] + yield + send_response_mock.assert_called_once_with(code, message) + send_header_mock.assert_has_calls( + [ + unittest.mock.call("Content-Length", str(len(body.encode()))), + unittest.mock.call("Content-type", f"{content_type}; charset=UTF-8"), + ] + + [unittest.mock.call(header, value) for header, value in headers.items()], + any_order=True, + ) + end_headers_mock.assert_called_once() + handler.wfile.seek(0) + self.assertEqual(handler.wfile.read(), body.encode()) + + @contextlib.contextmanager + def expects_error( + self, + handler: BaseHandler, + code: int, + message: str | None = None, + ) -> typing.Iterator[None]: + shortmsg, _ = RequestHandler.responses[code] + if message is None: + message = shortmsg + with self.expects_status_only(handler, code, message): + yield + + @contextlib.contextmanager + def expects_error_full( + self, + handler: BaseHandler, + code: int, + message: str | None = None, + explain: str | None = None, + ) -> typing.Iterator[None]: + shortmsg, longmsg = http.server.BaseHTTPRequestHandler.responses[code] + if message is None: + message = shortmsg + if explain is None: + explain = longmsg + with self.expects_basic_body( + handler, + body=f"{code} {message}\n{explain}\n\n{handler.server_signature()}", + code=code, + message=message, + ): + yield + + +class TestRequestHandler(BaseHandlerTestCase): + @typing.override + def setUp(self) -> None: + self.get_tmp_dir() + self.registry = self.mock(Registry) + self.cert_manager = self.mock(CertManager) + self.token_manager = self.mock(TokenManager) + self.certbot_www = self.tmp_path / "certbot_www" + self.data_dir = self.mock(DataDir) + super().setUp() + + def _get_handler( + self, + path: str = "/", + headers: dict[str, str | None] | None = None, + rfile: io.BufferedIOBase | None = None, + ) -> RequestHandler: + if headers is None: + headers = {} + with self.patch("http.server.BaseHTTPRequestHandler.__init__"): + handler = RequestHandler( + unittest.mock.MagicMock(), + "127.0.0.1", + unittest.mock.MagicMock(), + params=Parameters( + data_dir=self.get_tmp_dir(), certbot_www=str(self.certbot_www) + ), + registry=self.registry, + cert_manager=self.cert_manager, + token_manager=self.token_manager, + ) + handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] + handler.requestline = "GET /" + handler.path = path + handler.request_version = "HTTP/0.9" + handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] + handler.rfile = rfile if rfile is not None else io.BytesIO() + handler.wfile = io.BytesIO() + handler.logger = unittest.mock.Mock(logging.Logger) + handler.data_dir = self.data_dir + return handler + + def test_do_head_proxy(self) -> None: + handler = self._get_handler() + with ( + self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD"), + self.seal_mocks(), + ): + handler.do_HEAD() + + def test_do_get_index(self) -> None: + handler = self._get_handler("/") + with ( + self.expects_basic_body(handler, handler.server_signature()), + self.seal_mocks(), + ): + handler.do_GET() + + def test_do_get_proxy_on_other_path(self) -> None: + handler = self._get_handler("/file") + with ( + self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), + self.seal_mocks(), + ): + handler.do_GET() + + def test_do_get_proxy_on_other_host(self) -> None: + handler = self._get_handler("/", {"Host": "other_host"}) + with ( + self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), + self.seal_mocks(), + ): + handler.do_GET() + + def test_do_put_no_token(self) -> None: + handler = self._get_handler("/path") + with ( + self.expects_error( + handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_invalid_token(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], False), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.UNAUTHORIZED, "Invalid token"), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_invalid_path(self) -> None: + handler = self._get_handler("/pa.th", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.BAD_REQUEST, "Invalid path"), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_invalid_token_for_path(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + False, # noqa: FBT003 + ), + self.expects_error( + handler, http.HTTPStatus.FORBIDDEN, "Path forbidden for this token" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_invalid_host(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "X-Host": "invalid_host"} + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, http.HTTPStatus.BAD_REQUEST, "Invalid requested host" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_invalid_host_for_path(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "X-Host": "example.com"} + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call( + self.registry.get_from_host, ["example.com"], Page("other_path") + ), + self.expects_error( + handler, http.HTTPStatus.FORBIDDEN, "Host already taken" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_no_content(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, http.HTTPStatus.LENGTH_REQUIRED, "No body found" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_content_too_large(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "Content-Length": "999999999"} + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, + http.HTTPStatus.CONTENT_TOO_LARGE, + "Archive too large", + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_tar_error(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "Content-Length": "1"} + ) + handler.rfile.write(b"\0") + self.data_dir.extract_tar_bytes.side_effect = tarfile.TarError + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, http.HTTPStatus.BAD_REQUEST, "Invalid tar archive" + ), + self.seal_mocks(), + ): + handler.do_PUT() + self.data_dir.extract_tar_bytes.assert_called_once() + + def test_do_put_extract_error(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "Content-Length": "1"} + ) + handler.rfile.write(b"\0") + self.data_dir.extract_tar_bytes.side_effect = Exception + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error(handler, http.HTTPStatus.INTERNAL_SERVER_ERROR, ""), + self.seal_mocks(), + ): + handler.do_PUT() + self.data_dir.extract_tar_bytes.assert_called_once() + + def test_do_put_ok(self) -> None: + handler = self._get_handler( + "/path", {"X-Token": "secret", "Content-Length": "1"} + ) + handler.rfile.write(b"\0") + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call_unchecked(self.data_dir.extract_tar_bytes), + self.mock_call(self.registry.add, ["path"]), + self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_ok_with_host_fail_init(self) -> None: + handler = self._get_handler( + "/path", + {"X-Token": "secret", "Content-Length": "1", "X-Host": "example.com"}, + ) + handler.rfile.write(b"\0") + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.mock_call_unchecked(self.data_dir.extract_tar_bytes), + self.mock_call(self.registry.add, ["path"]), + self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.mock_call(self.cert_manager.create_or_update, ["example.com"], False), # noqa: FBT003 + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_ok_with_host(self) -> None: + handler = self._get_handler( + "/path", + {"X-Token": "secret", "Content-Length": "1", "X-Host": "example.com"}, + ) + handler.rfile.write(b"\0") + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.mock_call_unchecked(self.data_dir.extract_tar_bytes), + self.mock_call(self.registry.add, ["path"]), + self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 + self.mock_call(self.registry.set_host, ["path", "example.com"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_delete_no_token(self) -> None: + handler = self._get_handler("/path") + with ( + self.expects_error( + handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" + ), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_do_delete_invalid_token(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], False), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.UNAUTHORIZED, "Invalid token"), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_do_delete_invalid_path(self) -> None: + handler = self._get_handler("/pa.th", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.BAD_REQUEST, "Invalid path"), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_do_delete_invalid_token_for_path(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + False, # noqa: FBT003 + ), + self.expects_error( + handler, http.HTTPStatus.FORBIDDEN, "Path forbidden for this token" + ), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_do_delete_not_found(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.data_dir.exists, ["path"], False), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.NOT_FOUND, "Not found"), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_do_delete_remove_error(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + self.data_dir.remove.side_effect = Exception + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.data_dir.exists, ["path"], True), # noqa: FBT003 + self.mock_call(self.data_dir.exists, ["path"], True), # noqa: FBT003 + self.expects_error(handler, http.HTTPStatus.INTERNAL_SERVER_ERROR, ""), + self.seal_mocks(), + ): + handler.do_DELETE() + self.data_dir.remove.assert_called_once_with("path") + + def test_do_delete_ok(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.data_dir.exists, ["path"], True), # noqa: FBT003 + self.mock_call(self.data_dir.remove, ["path"]), + self.mock_call(self.registry.remove, ["path"]), + self.expects_error( + handler, http.HTTPStatus.NO_CONTENT, "Resource /path/ removed" + ), + self.seal_mocks(), + ): + handler.do_DELETE() + + def test_list_directory(self) -> None: + handler = self._get_handler("/path/", {"Accept": "text/html"}) + with ( + self.expects_error_full( + handler, http.HTTPStatus.NOT_FOUND, "File not found" + ), + self.seal_mocks(), + ): + handler.list_directory() + + def test_translate_path_certbot(self) -> None: + handler = self._get_handler() + with ( + self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/.well-known/acme-challenge/abcde"), + str(self.certbot_www / ".well-known" / "acme-challenge" / "abcde"), + ) + + def test_translate_path_host_not_found(self) -> None: + handler = self._get_handler(headers={"Host": "example.com"}) + with ( + self.mock_call(self.registry.get_from_host, ["example.com"]), + self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/"), + "", + ) + + def test_translate_path_invalid(self) -> None: + handler = self._get_handler() + with ( + self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/invalid.path"), + "", + ) + + def test_translate_path_favicon(self) -> None: + handler = self._get_handler() + with ( + self.patch_call( + "http.server.SimpleHTTPRequestHandler.translate_path", + ["/favicon.ico"], + ), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/favicon.ico"), + None, + ) + + def test_translate_path_dotfile(self) -> None: + handler = self._get_handler() + with ( + self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/path/.token"), + "", + ) + + def test_translate_path_with_host(self) -> None: + handler = self._get_handler(headers={"Host": "example.com"}) + with ( + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.patch_call( + "http.server.SimpleHTTPRequestHandler.translate_path", + ["/path/index.html"], + ), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/index.html"), + None, + ) + + def test_translate_path_default_host(self) -> None: + handler = self._get_handler() + with ( + self.patch_call( + "http.server.SimpleHTTPRequestHandler.translate_path", + ["/path/index.html"], + ), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/path/index.html"), + None, + ) + + +class TestUpgradeHandler(BaseHandlerTestCase): + def _get_handler( + self, + path: str = "/", + headers: dict[str, str | None] | None = None, + rfile: io.BufferedIOBase | None = None, + ) -> UpgradeHandler: + if headers is None: + headers = {} + with self.patch("http.server.BaseHTTPRequestHandler.__init__"): + handler = UpgradeHandler( + unittest.mock.MagicMock(), + "127.0.0.1", + unittest.mock.MagicMock(), + params=Parameters(), + ) + handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] + handler.requestline = "GET /" + handler.path = path + handler.request_version = "HTTP/0.9" + handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] + handler.rfile = rfile if rfile is not None else io.BytesIO() + handler.wfile = io.BytesIO() + handler.logger = unittest.mock.Mock(logging.Logger) + return handler + + def test_do_get(self) -> None: + handler = self._get_handler("/file") + with self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://localhost/file"}, + ): + handler.do_GET() + + def test_do_head(self) -> None: + handler = self._get_handler("/file") + with self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://localhost/file"}, + ): + handler.do_HEAD()