From 33cfd350a5e5e283829cc2d9cd50c99c6fb1d6d8 Mon Sep 17 00:00:00 2001 From: Klemek Date: Mon, 20 Apr 2026 14:41:27 +0200 Subject: [PATCH] feat: X-Proxy --- README.md | 2 +- src/handlers.py | 93 +++++++-- src/page.py | 3 + src/registry.py | 18 +- tests/__init__.py | 40 +++- tests/test_handlers.py | 432 ++++++++++++++++++++++++++++++++++++----- tests/test_page.py | 6 + tests/test_registry.py | 100 ++++++++++ 8 files changed, 609 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index 7889247..f205e7d 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ curl -X DELETE \ - [x] unit tests - [x] github actions - [x] X-Redirect -- [ ] X-Proxy +- [x] X-Proxy - [ ] proper doc ### Makefile targets diff --git a/src/handlers.py b/src/handlers.py index 582b21c..2d082b3 100644 --- a/src/handlers.py +++ b/src/handlers.py @@ -47,7 +47,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): ) -> None: shortmsg, longmsg = self.responses[code] if message is None: - message = shortmsg # pragma: no cover + message = shortmsg if explain is None: explain = longmsg if "text/" in self._get_header("Accept"): @@ -104,7 +104,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): encoded: bytes = body.encode() self.out_size = len(encoded) self.send_response(code, message) - self.send_header("Content-type", f"{content_type}; charset=UTF-8") + self.send_header("Content-Type", f"{content_type}; charset=UTF-8") self.send_header("Content-Length", str(len(encoded))) for header, value in headers.items(): self.send_header(header, value) # pragma: no cover @@ -134,14 +134,14 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): ) def send_proxy(self, url: str) -> None: - body: bytes | None = None - if self.in_size > 0: - body = self.rfile.read(self.in_size) headers = dict(self.headers) headers["Host"] = urllib.parse.urlparse(url).netloc headers["X-Forwarded-For"] = self.client_address[0] headers["X-Real-IP"] = self.client_address[0] try: + body: bytes | None = None + if self.in_size > 0: + body = self.rfile.read(self.in_size) response: requests.Response = requests.request( self.command, url, data=body, headers=headers, timeout=240 ) @@ -220,6 +220,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): TOKEN_HEADER = "X-Token" # noqa: S105 HOST_HEADER = "X-Host" REDIRECT_HEADER = "X-Redirect" + PROXY_HEADER = "X-Proxy" @typing.override def __init__( @@ -241,6 +242,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.__token: str | None = None self.__target_host: str | None = None self.__target_redirect: str | None = None + self.__target_proxy: str | None = None super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] @property @@ -249,6 +251,10 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.__token = self._get_header(self.TOKEN_HEADER) return self.__token + @property + def has_token(self) -> bool: + return len(self.token) > 0 + @property def target_host(self) -> str: if self.__target_host is None: @@ -269,28 +275,35 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def has_target_redirect(self) -> bool: return len(self.target_redirect) > 0 + @property + def target_proxy(self) -> str: + if self.__target_proxy is None: + self.__target_proxy = self._get_header(self.PROXY_HEADER).lower() + return self.__target_proxy + + @property + def has_target_proxy(self) -> bool: + return len(self.target_proxy) > 0 + @typing.override def do_HEAD(self) -> None: self._pre_log_request() - if ( - page := self.__get_page(self.path) - ) is not None and page.redirect is not None: - return self.send_redirect(page.redirect) - return super().do_HEAD() + if not self._proxy_or_redirect(): + super().do_HEAD() @typing.override def do_GET(self) -> None: self._pre_log_request() + if self._proxy_or_redirect(): + return None if self.path == "/" and self.host == self.default_host: return self.send_basic_body(self.server_signature()) - if ( - page := self.__get_page(self.path) - ) is not None and page.redirect is not None: - return self.send_redirect(page.redirect) return super().do_GET() def do_PUT(self) -> None: self._pre_log_request() + if self._proxy_or_redirect(): + return None if (path := self.__check_update_request()) is None: return None if self.has_target_host and not self.__valid_host(self.target_host): @@ -303,8 +316,15 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): and page.path != path ): return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") + if self.has_target_proxy and self.has_target_redirect: + return self.send_error( + http.HTTPStatus.BAD_REQUEST, + f"Cannot use {self.PROXY_HEADER} with {self.REDIRECT_HEADER}", + ) if self.has_target_redirect: self._update_redirect(path) + elif self.has_target_proxy: + self._update_proxy(path) else: self._update_extract(path) if self.has_target_host and self.cert_manager.create_or_update( @@ -321,21 +341,26 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def do_DELETE(self) -> None: self._pre_log_request() + if self._proxy_or_redirect(): + return None if (path := self.__check_update_request()) is None: return None return self._update_remove(path) def do_CONNECT(self) -> None: self._pre_log_request() - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def do_OPTIONS(self) -> None: self._pre_log_request() - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def do_TRACE(self) -> None: self._pre_log_request() - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def _update_extract(self, path: str) -> None: if self.in_size == 0: @@ -366,10 +391,22 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): http.HTTPStatus.BAD_REQUEST, f"No content must be sent with {self.REDIRECT_HEADER}", ) - self.data_dir.empty(path) - self.registry.add(path) - self.token_manager.set_token(path, self.token) self.registry.set_redirect(path, self.target_redirect) + self.token_manager.set_token(path, self.token) + self.send_status_only( + http.HTTPStatus.CREATED, + f"Resource /{path}/ updated", + ) + return None + + def _update_proxy(self, path: str) -> None: + if self.in_size > 0: + return self.send_error( + http.HTTPStatus.BAD_REQUEST, + f"No content must be sent with {self.PROXY_HEADER}", + ) + self.registry.set_proxy(path, self.target_proxy) + self.token_manager.set_token(path, self.token) self.send_status_only( http.HTTPStatus.CREATED, f"Resource /{path}/ updated", @@ -391,6 +428,22 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.registry.remove(path) return None + def _proxy_or_redirect(self) -> bool: + if self.has_token: + return False + if (page := self.__get_page(self.path)) is None: + return False + if page.redirect is not None: + self.send_redirect(page.redirect) + return True + if page.proxy is not None: + if self.host == self.default_host: + self.send_proxy(page.proxy + self.path.removeprefix(f"/{page.path}")) + else: + self.send_proxy(page.proxy + self.path) + return True + return False + @typing.override def list_directory(self, *_: typing.Any, **__: typing.Any) -> None: """Disable default directory listing.""" diff --git a/src/page.py b/src/page.py index 4fa7d33..43a73ae 100644 --- a/src/page.py +++ b/src/page.py @@ -8,6 +8,7 @@ class Page: host: str | None = None token_hash: str | None = None redirect: str | None = None + proxy: str | None = None def __repr__(self) -> str: out = f"/{self.path}/" @@ -15,6 +16,8 @@ class Page: out += f" [{self.host}]" if self.redirect is not None: out += f" (redirect: {self.redirect})" + elif self.proxy is not None: + out += f" (proxy: {self.proxy})" elif not self.with_index: out += " (no index)" return out diff --git a/src/registry.py b/src/registry.py index 7a3b64e..855b621 100644 --- a/src/registry.py +++ b/src/registry.py @@ -18,6 +18,7 @@ class Registry: HOST_FILE = ".host" TOKEN_FILE = ".token" # noqa: S105 REDIRECT_FILE = ".redirect" + PROXY_FILE = ".proxy" def __init__(self, params: Parameters) -> None: self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) @@ -39,6 +40,7 @@ class Registry: self.data_dir.get_file(path, self.HOST_FILE), self.data_dir.get_file(path, self.TOKEN_FILE), self.data_dir.get_file(path, self.REDIRECT_FILE), + self.data_dir.get_file(path, self.PROXY_FILE), ) self.logger.info("Updated %s", self.pages[path]) @@ -52,14 +54,26 @@ class Registry: if path in self.pages and self.pages[path].token_hash != token_hash: self.data_dir.set_file(path, self.TOKEN_FILE, token_hash, 0o600) self.pages[path].token_hash = token_hash - self.logger.debug("Updated %s", self.pages[path]) + self.logger.debug("Updated %s (token)", self.pages[path]) def set_redirect(self, path: str, redirect: str) -> None: - if path in self.pages and self.pages[path].redirect != redirect: + if path not in self.pages or self.pages[path].redirect != redirect: + self.data_dir.empty(path) self.data_dir.set_file(path, self.REDIRECT_FILE, redirect) + if path not in self.pages: + self.pages[path] = Page(path) self.pages[path].redirect = redirect self.logger.debug("Updated %s", self.pages[path]) + def set_proxy(self, path: str, proxy: str) -> None: + if path not in self.pages or self.pages[path].proxy != proxy: + self.data_dir.empty(path) + self.data_dir.set_file(path, self.PROXY_FILE, proxy) + if path not in self.pages: + self.pages[path] = Page(path) + self.pages[path].proxy = proxy + self.logger.debug("Updated %s", self.pages[path]) + def remove(self, path: str) -> None: if path in self.pages: page = self.pages[path] diff --git a/tests/__init__.py b/tests/__init__.py index 6aaa51f..37789d2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,6 +5,8 @@ import typing import unittest import unittest.mock +__import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 # ty:ignore[unresolved-attribute] # noqa: SLF001 + class BaseTestCase(unittest.TestCase): @typing.override @@ -47,16 +49,19 @@ class BaseTestCase(unittest.TestCase): target: str, args: list[typing.Iterable[typing.Any]] | None = None, return_values: list[typing.Any] | None = None, + kwargs: list[dict[str, typing.Any]] | None = None, ) -> typing.Iterator[unittest.mock.Mock]: if args is None: args = [[]] if return_values is None: return_values = [None] * len(args) + if kwargs is None: + kwargs = [{}] * len(args) with unittest.mock.patch( target, side_effect=return_values, create=True ) as mock: yield mock - self.__check_calls(mock, args) + self.__check_calls(mock, args, kwargs) @contextlib.contextmanager def patch_call( @@ -64,10 +69,13 @@ class BaseTestCase(unittest.TestCase): target: str, args: typing.Iterable[typing.Any] | None = None, return_value: typing.Any = None, + kwargs: dict[str, typing.Any] | None = None, ) -> typing.Iterator[unittest.mock.Mock]: if args is None: args = [] - with self.patch_calls(target, [args], [return_value]) as mock: + if kwargs is None: + kwargs = {} + with self.patch_calls(target, [args], [return_value], [kwargs]) as mock: yield mock @contextlib.contextmanager @@ -84,15 +92,18 @@ class BaseTestCase(unittest.TestCase): mock: unittest.mock.Mock, args: list[typing.Iterable[typing.Any]] | None = None, return_values: list[typing.Any] | None = None, + kwargs: list[dict[str, typing.Any]] | None = None, ) -> typing.Iterator[None]: if args is None: args = [[]] if return_values is None: return_values = [None] * len(args) + if kwargs is None: + kwargs = [{}] * len(args) mock.side_effect = return_values mock.reset_mock() yield - self.__check_calls(mock, args) + self.__check_calls(mock, args, kwargs) @contextlib.contextmanager def mock_call( @@ -100,10 +111,13 @@ class BaseTestCase(unittest.TestCase): mock: unittest.mock.Mock, args: typing.Iterable[typing.Any] | None = None, return_value: typing.Any = None, + kwargs: dict[str, typing.Any] | None = None, ) -> typing.Iterator[None]: if args is None: args = [] - with self.mock_calls(mock, [args], [return_value]): + if kwargs is None: + kwargs = {} + with self.mock_calls(mock, [args], [return_value], [kwargs]): yield @contextlib.contextmanager @@ -140,17 +154,23 @@ class BaseTestCase(unittest.TestCase): self, mock: unittest.mock.Mock, args: list[typing.Iterable[typing.Any]], + kwargs: list[dict[str, typing.Any]], ) -> None: + total_rows = max(len(args), len(mock.method_calls), len(kwargs)) + missing_calls = max(0, total_rows - len(mock.mock_calls)) + missing_args = max(0, total_rows - len(args)) + missing_kwargs = max(0, total_rows - len(kwargs)) for i, values in enumerate( zip( - mock.mock_calls - + [None] - * (max(len(args), len(mock.method_calls)) - len(mock.mock_calls)), - args + [[]] * (max(len(args), len(mock.method_calls)) - len(args)), + mock.mock_calls + [None] * missing_calls, + args + [[]] * missing_args, + kwargs + [{}] * missing_kwargs, strict=False, ) ): - real_call, expected_args = values + real_call, expected_args, expected_kwargs = values self.assertEqual( - real_call, unittest.mock.call(*expected_args), f"{i + 1}: {mock}" + real_call, + unittest.mock.call(*expected_args, **expected_kwargs), + f"{i + 1}: {mock}", ) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 0373226..c91be42 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -9,6 +9,8 @@ import tarfile import typing import unittest.mock +import requests + from src.handlers import BaseHandler, RequestHandler, UpgradeHandler from src.page import Page from src.params import Parameters @@ -22,6 +24,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC): self, path: str = "/", headers: dict[str, str | None] | None = None, + method: str = "GET", rfile: io.BufferedIOBase | None = None, ) -> BaseHandler: pass @@ -70,7 +73,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC): send_header_mock.assert_has_calls( [ unittest.mock.call("Content-Length", str(len(body.encode()))), - unittest.mock.call("Content-type", f"{content_type}; charset=UTF-8"), + unittest.mock.call("Content-Type", f"{content_type}; charset=UTF-8"), ] + [unittest.mock.call(header, value) for header, value in headers.items()], any_order=True, @@ -129,6 +132,7 @@ class TestRequestHandler(BaseHandlerTestCase): self, path: str = "/", headers: dict[str, str | None] | None = None, + method: str = "GET", rfile: io.BufferedIOBase | None = None, ) -> RequestHandler: if headers is None: @@ -146,9 +150,11 @@ class TestRequestHandler(BaseHandlerTestCase): token_manager=self.token_manager, ) handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] - handler.requestline = "GET /" + handler.requestline = f"{method} {path}" handler.path = path + handler.command = method handler.request_version = "HTTP/0.9" + handler.client_address = ("127.0.0.1", 12345) handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] handler.rfile = rfile if rfile is not None else io.BytesIO() handler.wfile = io.BytesIO() @@ -156,25 +162,7 @@ class TestRequestHandler(BaseHandlerTestCase): handler.data_dir = self.data_dir return handler - def test_do_head_redirect(self) -> None: - handler = self._get_handler("/path") - with ( - self.mock_call( - self.registry.get_from_path, - ["path"], - Page("path", redirect="https://example.com"), - ), - self.expects_status_only( - handler, - http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": "https://example.com"}, - ), - self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD", count=0), - self.seal_mocks(), - ): - handler.do_HEAD() - - def test_do_head_proxy(self) -> None: + def test_do_head_forward(self) -> None: handler = self._get_handler() with ( self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD"), @@ -191,37 +179,16 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_GET() - def test_do_get_redirect(self) -> None: - handler = self._get_handler("/path") - with ( - self.mock_call( - self.registry.get_from_path, - ["path"], - Page("path", redirect="https://example.com"), - ), - self.expects_status_only( - handler, - http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": "https://example.com"}, - ), - self.patch("http.server.SimpleHTTPRequestHandler.do_GET", count=0), - self.seal_mocks(), - ): - handler.do_GET() - - def test_do_get_proxy_on_other_path(self) -> None: + def test_do_get_forward_on_other_path(self) -> None: handler = self._get_handler("/file") with ( - self.mock_call( - self.registry.get_from_path, - ["file"], - ), + self.mock_call(self.registry.get_from_path, ["file"], Page("file")), self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), self.seal_mocks(), ): handler.do_GET() - def test_do_get_proxy_on_other_host(self) -> None: + def test_do_get_forward_on_other_host(self) -> None: handler = self._get_handler("/", {"Host": "other_host"}) with ( self.mock_call( @@ -236,6 +203,7 @@ class TestRequestHandler(BaseHandlerTestCase): def test_do_put_no_token(self) -> None: handler = self._get_handler("/path") with ( + self.mock_call(self.registry.get_from_path, ["path"]), self.expects_error( handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" ), @@ -246,6 +214,7 @@ class TestRequestHandler(BaseHandlerTestCase): def test_do_post_is_do_put(self) -> None: handler = self._get_handler("/path") with ( + self.mock_call(self.registry.get_from_path, ["path"]), self.expects_error( handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" ), @@ -256,6 +225,7 @@ class TestRequestHandler(BaseHandlerTestCase): def test_do_patch_is_do_put(self) -> None: handler = self._get_handler("/path") with ( + self.mock_call(self.registry.get_from_path, ["path"]), self.expects_error( handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" ), @@ -525,10 +495,8 @@ class TestRequestHandler(BaseHandlerTestCase): ["secret", "path"], True, # noqa: FBT003 ), - self.mock_call(self.data_dir.empty, ["path"]), - self.mock_call(self.registry.add, ["path"]), - self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.expects_status_only( handler, http.HTTPStatus.CREATED, "Resource /path/ updated" ), @@ -553,10 +521,8 @@ class TestRequestHandler(BaseHandlerTestCase): True, # noqa: FBT003 ), self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), - self.mock_call(self.data_dir.empty, ["path"]), - self.mock_call(self.registry.add, ["path"]), - self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 self.mock_call(self.registry.set_host, ["path", "example.com"]), self.expects_status_only( @@ -566,9 +532,114 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() + def test_do_put_proxy_with_content(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Proxy": "https://example.com", + "Content-Length": "1", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, + http.HTTPStatus.BAD_REQUEST, + "No content must be sent with X-Proxy", + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_proxy_ok(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Proxy": "https://example.com", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.set_proxy, ["path", "https://example.com"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_proxy_with_host(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Proxy": "https://example.com", + "X-Host": "example.com", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.mock_call(self.registry.set_proxy, ["path", "https://example.com"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), + self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 + self.mock_call(self.registry.set_host, ["path", "example.com"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_proxy_and_redirect(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Proxy": "https://example.com", + "X-Redirect": "https://example.com", + "X-Host": "example.com", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.expects_status_only( + handler, + http.HTTPStatus.BAD_REQUEST, + "Cannot use X-Proxy with X-Redirect", + ), + self.seal_mocks(), + ): + handler.do_PUT() + def test_do_delete_no_token(self) -> None: handler = self._get_handler("/path") with ( + self.mock_call(self.registry.get_from_path, ["path"]), self.expects_error( handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" ), @@ -662,6 +733,261 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_DELETE() + def test_do_post_proxy_no_body(self) -> None: + handler = self._get_handler("/path", method="POST") + response = requests.Response() + response.status_code = 200 + response.reason = "OK" + response.raw = io.BytesIO() + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com", + ], + response, + { + "data": None, + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + }, + "timeout": 240, + }, + ), + self.expects_status_only(handler, 200, "OK"), + self.seal_mocks(), + ): + handler.do_POST() + + def test_do_post_proxy_with_request_body(self) -> None: + handler = self._get_handler( + "/path", + method="POST", + headers={"Content-Length": "5"}, + rfile=io.BytesIO(b"hello"), + ) + response = requests.Response() + response.status_code = 200 + response.reason = "OK" + response.raw = io.BytesIO() + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com", + ], + response, + { + "data": b"hello", + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + "Content-Length": "5", + }, + "timeout": 240, + }, + ), + self.expects_status_only(handler, 200, "OK"), + self.seal_mocks(), + ): + handler.do_POST() + + def test_do_post_proxy_with_response_body(self) -> None: + handler = self._get_handler( + "/path", + method="POST", + ) + response = requests.Response() + response.status_code = 200 + response.reason = "OK" + response.headers["Content-Type"] = "text/plain; charset=UTF-8" + response.raw = io.BytesIO(b"hello") + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com", + ], + response, + { + "data": None, + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + }, + "timeout": 240, + }, + ), + self.expects_basic_body(handler, "hello", message="OK"), + self.seal_mocks(), + ): + handler.do_POST() + + def test_do_post_proxy_fail(self) -> None: + handler = self._get_handler("/path", method="POST") + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com", + ], + None, + { + "data": None, + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + }, + "timeout": 240, + }, + ) as request_mock, + self.expects_status_only( + handler, + http.HTTPStatus.BAD_GATEWAY, + "Could not reach https://example.com", + ), + self.seal_mocks(), + ): + request_mock.side_effect = Exception + handler.do_POST() + + def test_do_post_proxy_sub_path(self) -> None: + handler = self._get_handler("/path/index.html", method="POST") + response = requests.Response() + response.status_code = 200 + response.reason = "OK" + response.raw = io.BytesIO() + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com/index.html", + ], + response, + { + "data": None, + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + }, + "timeout": 240, + }, + ), + self.expects_status_only(handler, 200, "OK"), + self.seal_mocks(), + ): + handler.do_POST() + + def test_do_post_proxy_sub_path_for_host(self) -> None: + handler = self._get_handler( + "/path/index.html", method="POST", headers={"Host": "host"} + ) + response = requests.Response() + response.status_code = 200 + response.reason = "OK" + response.raw = io.BytesIO() + with ( + self.mock_call( + self.registry.get_from_host, + ["host"], + Page("path", proxy="https://example.com"), + ), + self.patch_call( + "requests.request", + [ + "POST", + "https://example.com/path/index.html", + ], + response, + { + "data": None, + "headers": { + "Host": "example.com", + "X-Forwarded-For": "127.0.0.1", + "X-Real-IP": "127.0.0.1", + }, + "timeout": 240, + }, + ), + self.expects_status_only(handler, 200, "OK"), + self.seal_mocks(), + ): + handler.do_POST() + + def test_do_method_not_supported(self) -> None: + handler = self._get_handler("/path") + for method in ["CONNECT", "TRACE", "OPTIONS"]: + with ( + self.subTest(None, method=method), + self.mock_call( + self.registry.get_from_path, + ["path"], + ), + self.expects_status_only( + handler, http.HTTPStatus.METHOD_NOT_ALLOWED, "Method Not Allowed" + ), + self.seal_mocks(), + ): + getattr(handler, f"do_{method}")() + + def test_do_redirect(self) -> None: + handler = self._get_handler("/path") + for method in [method.value for method in http.HTTPMethod]: + with ( + self.subTest(None, method=method), + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", redirect="https://example.com"), + ), + self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://example.com"}, + ), + self.patch( + f"http.server.SimpleHTTPRequestHandler.do_{method}", count=0 + ), + self.seal_mocks(), + ): + getattr(handler, f"do_{method}")() + def test_list_directory(self) -> None: handler = self._get_handler("/path/", {"Accept": "text/html"}) with ( @@ -780,6 +1106,7 @@ class TestUpgradeHandler(BaseHandlerTestCase): self, path: str = "/", headers: dict[str, str | None] | None = None, + method: str = "GET", rfile: io.BufferedIOBase | None = None, ) -> UpgradeHandler: if headers is None: @@ -792,8 +1119,9 @@ class TestUpgradeHandler(BaseHandlerTestCase): params=Parameters(), ) handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] - handler.requestline = "GET /" + handler.requestline = f"{method} {path}" handler.path = path + handler.command = method handler.request_version = "HTTP/0.9" handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] handler.rfile = rfile if rfile is not None else io.BytesIO() diff --git a/tests/test_page.py b/tests/test_page.py index 14b381c..f1452d8 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -21,3 +21,9 @@ class TestPage(BaseTestCase): str(Page("test_1", redirect="https://example.com")), "/test_1/ (redirect: https://example.com)", ) + + def test_repr_with_proxy(self) -> None: + self.assertEqual( + str(Page("test_1", proxy="https://example.com")), + "/test_1/ (proxy: https://example.com)", + ) diff --git a/tests/test_registry.py b/tests/test_registry.py index 96d63de..26a440f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -30,17 +30,21 @@ class TestRegistry(BaseTestCase): ["test_1", Registry.HOST_FILE], ["test_1", Registry.TOKEN_FILE], ["test_1", Registry.REDIRECT_FILE], + ["test_1", Registry.PROXY_FILE], ["test_2", Registry.HOST_FILE], ["test_2", Registry.TOKEN_FILE], ["test_2", Registry.REDIRECT_FILE], + ["test_2", Registry.PROXY_FILE], ], [ "test_1_host", "test_1_token", None, None, + None, "test_2_token", "test_2_redirect", + None, ], ), self.seal_mocks(), @@ -114,12 +118,33 @@ class TestRegistry(BaseTestCase): self.registry.set_token_hash("test_1", "new_value") self.assertEqual(self.registry.pages["test_1"].token_hash, "new_value") + def test_set_token_hash_no_change(self) -> None: + self.registry.pages["test_1"] = Page( + "test_1", + token_hash="secret", # noqa: S106 + ) + with ( + self.seal_mocks(), + ): + self.registry.set_token_hash("test_1", "secret") + self.assertEqual(self.registry.pages["test_1"].token_hash, "secret") + + def test_set_token_hash_not_found(self) -> None: + with ( + self.seal_mocks(), + ): + self.registry.set_token_hash("test_1", "secret") + def test_set_redirect(self) -> None: self.registry.pages["test_1"] = Page( "test_1", redirect="https://example.com", ) with ( + self.mock_call( + self.data_dir.empty, + ["test_1"], + ), self.mock_call( self.data_dir.set_file, ["test_1", Registry.REDIRECT_FILE, "https://new-example.com"], @@ -131,6 +156,81 @@ class TestRegistry(BaseTestCase): self.registry.pages["test_1"].redirect, "https://new-example.com" ) + def test_set_redirect_no_change(self) -> None: + self.registry.pages["test_1"] = Page( + "test_1", + redirect="https://example.com", + ) + with ( + self.seal_mocks(), + ): + self.registry.set_redirect("test_1", "https://example.com") + self.assertEqual(self.registry.pages["test_1"].redirect, "https://example.com") + + def test_set_redirect_not_found(self) -> None: + with ( + self.mock_call( + self.data_dir.empty, + ["test_1"], + ), + self.mock_call( + self.data_dir.set_file, + ["test_1", Registry.REDIRECT_FILE, "https://new-example.com"], + ), + self.seal_mocks(), + ): + self.registry.set_redirect("test_1", "https://new-example.com") + self.assertIn("test_1", self.registry.pages) + self.assertEqual( + self.registry.pages["test_1"].redirect, "https://new-example.com" + ) + + def test_set_proxy(self) -> None: + self.registry.pages["test_1"] = Page( + "test_1", + proxy="https://example.com", + ) + with ( + self.mock_call( + self.data_dir.empty, + ["test_1"], + ), + self.mock_call( + self.data_dir.set_file, + ["test_1", Registry.PROXY_FILE, "https://new-example.com"], + ), + self.seal_mocks(), + ): + self.registry.set_proxy("test_1", "https://new-example.com") + self.assertEqual(self.registry.pages["test_1"].proxy, "https://new-example.com") + + def test_set_proxy_no_change(self) -> None: + self.registry.pages["test_1"] = Page( + "test_1", + proxy="https://example.com", + ) + with ( + self.seal_mocks(), + ): + self.registry.set_proxy("test_1", "https://example.com") + self.assertEqual(self.registry.pages["test_1"].proxy, "https://example.com") + + def test_set_proxy_not_found(self) -> None: + with ( + self.mock_call( + self.data_dir.empty, + ["test_1"], + ), + self.mock_call( + self.data_dir.set_file, + ["test_1", Registry.PROXY_FILE, "https://new-example.com"], + ), + self.seal_mocks(), + ): + self.registry.set_proxy("test_1", "https://new-example.com") + self.assertIn("test_1", self.registry.pages) + self.assertEqual(self.registry.pages["test_1"].proxy, "https://new-example.com") + def test_remove(self) -> None: self.registry.pages["test_1"] = Page( "test_1",