diff --git a/README.md b/README.md index e65b785..7889247 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ curl -X DELETE \ - [x] remove dot files after file extract - [x] unit tests - [x] github actions -- [ ] X-Redirect +- [x] X-Redirect - [ ] X-Proxy - [ ] proper doc diff --git a/src/data_dir.py b/src/data_dir.py index 37a28d0..b6720c6 100644 --- a/src/data_dir.py +++ b/src/data_dir.py @@ -90,6 +90,12 @@ class DataDir: shutil.rmtree(target_path) self.logger.debug("Deleted %s", target_path) + def empty(self, path: str) -> None: + self.remove(path) + target_path = self.root_path / path + target_path.mkdir() + self.logger.debug("Created empty %s", target_path) + def exists(self, path: str) -> bool: return ( self.PATH_REGEX.match(path) is not None and (self.root_path / path).is_dir() diff --git a/src/handlers.py b/src/handlers.py index f989942..75cce12 100644 --- a/src/handlers.py +++ b/src/handlers.py @@ -14,6 +14,7 @@ from .data_dir import DataDir if typing.TYPE_CHECKING: from .cert_manager import CertManager + from .page import Page from .params import Parameters from .registry import Registry from .token_manager import TokenManager @@ -30,8 +31,8 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) self.default_host: str = params.host.split(":", maxsplit=2)[0] self.out_size: int = 0 - self._host: str | None = None - self._in_size: int | None = None + self.__host: str | None = None + self.__in_size: int | None = None super().__init__(*args, **kwargs) @typing.override @@ -123,11 +124,17 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.end_headers() self.close_connection = True + def send_redirect(self, location: str) -> None: + self.send_status_only( + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": location}, + ) + @property def host(self) -> str: - if self._host is None: - self._host = self._get_host() - return self._host + if self.__host is None: + self.__host = self._get_host() + return self.__host def _get_host(self) -> str: host = self._get_header("Host", self.default_host) @@ -135,9 +142,9 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): @property def in_size(self) -> int: - if self._in_size is None: - self._in_size = self._get_length() - return self._in_size + if self.__in_size is None: + self.__in_size = self._get_length() + return self.__in_size def _get_length(self) -> int: return int(self._get_header("Content-Length", "0")) @@ -167,15 +174,16 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): - protocol_version = "HTTP/2.0" + protocol_version = "HTTP/1.1" server_version = "StaplerServer/" + project.get_version() CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" UPDATE_PATH_REGEX = re.compile(r"^\/([\w-]+)\/?$") - GET_PATH_REGEX = re.compile(r"^\/([\w-]+)\/") + GET_PATH_REGEX = re.compile(r"^\/([\w-]+)($|\/)") HOST_PART_REGEX = re.compile(r"^([a-z0-9]|[a-z0-9][a-z0-9-]{,61}[a-z0-9])$") AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"] TOKEN_HEADER = "X-Token" # noqa: S105 HOST_HEADER = "X-Host" + REDIRECT_HEADER = "X-Redirect" @typing.override def __init__( @@ -194,42 +202,60 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.registry: Registry = registry self.cert_manager: CertManager = cert_manager self.certbot_www: str = os.path.realpath(params.certbot_www) - self._token: str | None = None - self._target_host: str | None = None + self.__token: str | None = None + self.__target_host: str | None = None + self.__target_redirect: str | None = None super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] @property def token(self) -> str: - if self._token is None: - self._token = self._get_header(self.TOKEN_HEADER) - return self._token + if self.__token is None: + self.__token = self._get_header(self.TOKEN_HEADER) + return self.__token @property def target_host(self) -> str: - if self._target_host is None: - self._target_host = self._get_header(self.HOST_HEADER).lower() - return self._target_host + if self.__target_host is None: + self.__target_host = self._get_header(self.HOST_HEADER).lower() + return self.__target_host @property def has_target_host(self) -> bool: return len(self.target_host) > 0 + @property + def target_redirect(self) -> str: + if self.__target_redirect is None: + self.__target_redirect = self._get_header(self.REDIRECT_HEADER).lower() + return self.__target_redirect + + @property + def has_target_redirect(self) -> bool: + return len(self.target_redirect) > 0 + @typing.override def do_HEAD(self) -> None: self._pre_log_request() - super().do_HEAD() + if ( + page := self.__get_page(self.path) + ) is not None and page.redirect is not None: + return self.send_redirect(page.redirect) + return super().do_HEAD() @typing.override def do_GET(self) -> None: self._pre_log_request() if self.path == "/" and self.host == self.default_host: return self.send_basic_body(self.server_signature()) - super().do_GET() - return None + if ( + page := self.__get_page(self.path) + ) is not None and page.redirect is not None: + return self.send_redirect(page.redirect) + return super().do_GET() def do_PUT(self) -> None: self._pre_log_request() - if (sub_path := self.__check_update_request()) is None: + if (path := self.__check_update_request()) is None: return None if self.has_target_host and not self.__valid_host(self.target_host): return self.send_error( @@ -238,9 +264,26 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): if ( self.has_target_host and (page := self.registry.get_from_host(self.target_host)) is not None - and page.path != sub_path + and page.path != path ): return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") + if self.has_target_redirect: + self._update_redirect(path) + else: + self._update_extract(path) + if self.has_target_host and self.cert_manager.create_or_update( + self.target_host + ): + self.registry.set_host(path, self.target_host) + return None + + def do_DELETE(self) -> None: + self._pre_log_request() + if (path := self.__check_update_request()) is None: + return None + return self._update_remove(path) + + def _update_extract(self, path: str) -> None: if self.in_size == 0: return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") if self.in_size > self.max_size_bytes: @@ -250,39 +293,48 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): ) try: file_bytes = io.BytesIO(self.rfile.read(self.in_size)) - self.data_dir.extract_tar_bytes(sub_path, file_bytes) + self.data_dir.extract_tar_bytes(path, file_bytes) except tarfile.TarError: return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") except Exception as e: return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) + self.registry.add(path) + self.token_manager.set_token(path, self.token) self.send_status_only( http.HTTPStatus.CREATED, - f"Resource /{sub_path}/ updated", + f"Resource /{path}/ updated", ) - self.registry.add(sub_path) - self.token_manager.set_token(self.token, sub_path) - if self.has_target_host and self.cert_manager.create_or_update( - self.target_host - ): - self.registry.set_host(sub_path, self.target_host) return None - def do_DELETE(self) -> None: - self._pre_log_request() - if (sub_path := self.__check_update_request()) is None: - return None - if not self.data_dir.exists(sub_path): + def _update_redirect(self, path: str) -> None: + if self.in_size > 0: + return self.send_error( + http.HTTPStatus.BAD_REQUEST, + f"No content must be sent with {self.REDIRECT_HEADER}", + ) + self.data_dir.empty(path) + self.registry.add(path) + self.token_manager.set_token(path, self.token) + self.registry.set_redirect(path, self.target_redirect) + self.send_status_only( + http.HTTPStatus.CREATED, + f"Resource /{path}/ updated", + ) + return None + + def _update_remove(self, path: str) -> None: + if not self.data_dir.exists(path): self.send_error(http.HTTPStatus.NOT_FOUND, "Not found") return None try: - self.data_dir.remove(sub_path) + self.data_dir.remove(path) except Exception as e: return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) self.send_status_only( http.HTTPStatus.NO_CONTENT, - f"Resource /{sub_path}/ removed", + f"Resource /{path}/ removed", ) - self.registry.remove(sub_path) + self.registry.remove(path) return None @typing.override @@ -294,30 +346,27 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def translate_path(self, path: str) -> str: if path.startswith(self.CERTBOT_CHALLENGE_PATH): return self.certbot_www + path - if ( - self.host != self.default_host - and (page := self.registry.get_from_host(self.host)) is not None - ): + page = self.__get_page(path) + if page is None: + if path in self.AUTHORIZED_PATHS: + return super().translate_path(path) + return "" + if self.host != self.default_host: path = f"/{page.path}" + path - elif self.host != self.default_host: - return "" - elif ( - path not in self.AUTHORIZED_PATHS - and self.__get_subpath(path, self.GET_PATH_REGEX) is None - ): # not a valid path - return "" if pathlib.Path(path).name.startswith("."): # hidden files return "" return super().translate_path(path) def __check_update_request(self) -> str | None: if not self._has_header(self.TOKEN_HEADER): - self.send_error(http.HTTPStatus.BAD_REQUEST, "No X-Token header in request") + self.send_error( + http.HTTPStatus.BAD_REQUEST, f"No {self.TOKEN_HEADER} header in request" + ) return None if not self.token_manager.is_valid(self.token): self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") return None - if (sub_path := self.__get_subpath(self.path, self.UPDATE_PATH_REGEX)) is None: + if (sub_path := self.__get_path(self.path, self.UPDATE_PATH_REGEX)) is None: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") return None if not self.token_manager.is_valid_for_path(self.token, sub_path): @@ -325,7 +374,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): return None return sub_path - def __get_subpath(self, path: str, regex: re.Pattern) -> str | None: + def __get_path(self, path: str, regex: re.Pattern) -> str | None: if (match := regex.match(path.lower())) is not None: return match.group(1) return None @@ -336,16 +385,20 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): and len(host) < 256 ) + def __get_page(self, src_path: str) -> Page | None: + if self.host == self.default_host: + if path := self.__get_path(src_path, self.GET_PATH_REGEX): + return self.registry.get_from_path(path) + return None + return self.registry.get_from_host(self.host) + class UpgradeHandler(BaseHandler): server_version = "StaplerUpgradeServer/" + project.get_version() def do_HEAD(self) -> None: self._pre_log_request() - self.send_status_only( - http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": f"https://{self.host}{self.path.lower()}"}, - ) + self.send_redirect(f"https://{self.host}{self.path}") def do_GET(self) -> None: self.do_HEAD() diff --git a/src/page.py b/src/page.py index d648229..4fa7d33 100644 --- a/src/page.py +++ b/src/page.py @@ -7,11 +7,14 @@ class Page: with_index: bool = False host: str | None = None token_hash: str | None = None + redirect: str | None = None def __repr__(self) -> str: out = f"/{self.path}/" if self.host is not None: out += f" [{self.host}]" - if not self.with_index: + if self.redirect is not None: + out += f" (redirect: {self.redirect})" + elif not self.with_index: out += " (no index)" return out diff --git a/src/registry.py b/src/registry.py index 2ae04b9..7a3b64e 100644 --- a/src/registry.py +++ b/src/registry.py @@ -17,6 +17,7 @@ class Registry: HOST_FILE = ".host" TOKEN_FILE = ".token" # noqa: S105 + REDIRECT_FILE = ".redirect" def __init__(self, params: Parameters) -> None: self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) @@ -37,6 +38,7 @@ class Registry: self.data_dir.has_index(path), self.data_dir.get_file(path, self.HOST_FILE), self.data_dir.get_file(path, self.TOKEN_FILE), + self.data_dir.get_file(path, self.REDIRECT_FILE), ) self.logger.info("Updated %s", self.pages[path]) @@ -52,6 +54,12 @@ class Registry: self.pages[path].token_hash = token_hash self.logger.debug("Updated %s", self.pages[path]) + def set_redirect(self, path: str, redirect: str) -> None: + if path in self.pages and self.pages[path].redirect != redirect: + self.data_dir.set_file(path, self.REDIRECT_FILE, redirect) + self.pages[path].redirect = redirect + self.logger.debug("Updated %s", self.pages[path]) + def remove(self, path: str) -> None: if path in self.pages: page = self.pages[path] diff --git a/src/token_manager.py b/src/token_manager.py index 9c52671..eb40a5e 100644 --- a/src/token_manager.py +++ b/src/token_manager.py @@ -48,7 +48,7 @@ class TokenManager: page.token_hash is None or page.token_hash == self.__hash_token(token) ) - def set_token(self, token: str, path: str) -> None: + def set_token(self, path: str, token: str) -> None: self.registry.set_token_hash(path, self.__hash_token(token)) def new_token(self) -> None: diff --git a/tests/test_data_dir.py b/tests/test_data_dir.py index e572f13..ef79df7 100644 --- a/tests/test_data_dir.py +++ b/tests/test_data_dir.py @@ -109,7 +109,7 @@ class TestDataDir(BaseTestCase): self.assert_file_content(self.tmp_path / "test_1" / "value", "value") assert not (self.tmp_path / "test_1" / ".value").exists() assert not (self.tmp_path / "test_1" / ".git").exists() - assert (self.tmp_path / "test_1" / "dir").exists() + assert (self.tmp_path / "test_1" / "dir").is_dir() assert not (self.tmp_path / "test_1" / "dir" / ".invalid").exists() assert not (self.tmp_path / "test_1" / "dir" / ".test").exists() @@ -128,6 +128,28 @@ class TestDataDir(BaseTestCase): self.data_dir.extract_tar_bytes("~test", tar_bytes) assert not (self.tmp_path / "~test").exists() + def test_empty_create(self) -> None: + self.data_dir.empty("test_1") + assert (self.tmp_path / "test_1").is_dir() + self.assertListEqual(list((self.tmp_path / "test_1").iterdir()), []) + + def test_empty_existing(self) -> None: + self.__create_path("test_1", {".host": "value"}) + self.data_dir.empty("test_1") + assert (self.tmp_path / "test_1").is_dir() + self.assertListEqual(list((self.tmp_path / "test_1").iterdir()), []) + + def test_exists_invalid_dir(self) -> None: + self.__create_path(".certbot") + assert not self.data_dir.exists(".certbot") + + def test_exists_ok(self) -> None: + self.__create_path("test_1") + assert self.data_dir.exists("test_1") + + def test_exists_fail(self) -> None: + assert not self.data_dir.exists("test_1") + def __create_path(self, path: str, files: dict[str, str] | None = None) -> None: (self.tmp_path / path).mkdir() if files is not None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d5486bb..5a2e01d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -156,6 +156,24 @@ class TestRequestHandler(BaseHandlerTestCase): handler.data_dir = self.data_dir return handler + def test_do_head_redirect(self) -> None: + handler = self._get_handler("/path") + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", redirect="https://example.com"), + ), + self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://example.com"}, + ), + self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD", count=0), + self.seal_mocks(), + ): + handler.do_HEAD() + def test_do_head_proxy(self) -> None: handler = self._get_handler() with ( @@ -165,9 +183,28 @@ class TestRequestHandler(BaseHandlerTestCase): handler.do_HEAD() def test_do_get_index(self) -> None: - handler = self._get_handler("/") + handler = self._get_handler() with ( self.expects_basic_body(handler, handler.server_signature()), + self.patch("http.server.SimpleHTTPRequestHandler.do_GET", count=0), + self.seal_mocks(), + ): + handler.do_GET() + + def test_do_get_redirect(self) -> None: + handler = self._get_handler("/path") + with ( + self.mock_call( + self.registry.get_from_path, + ["path"], + Page("path", redirect="https://example.com"), + ), + self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://example.com"}, + ), + self.patch("http.server.SimpleHTTPRequestHandler.do_GET", count=0), self.seal_mocks(), ): handler.do_GET() @@ -175,6 +212,10 @@ class TestRequestHandler(BaseHandlerTestCase): def test_do_get_proxy_on_other_path(self) -> None: handler = self._get_handler("/file") with ( + self.mock_call( + self.registry.get_from_path, + ["file"], + ), self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), self.seal_mocks(), ): @@ -183,6 +224,10 @@ class TestRequestHandler(BaseHandlerTestCase): def test_do_get_proxy_on_other_host(self) -> None: handler = self._get_handler("/", {"Host": "other_host"}) with ( + self.mock_call( + self.registry.get_from_host, + ["other_host"], + ), self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), self.seal_mocks(), ): @@ -271,7 +316,7 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() - def test_do_put_no_content(self) -> None: + def test_do_put_extract_no_content(self) -> None: handler = self._get_handler("/path", {"X-Token": "secret"}) with ( self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 @@ -287,7 +332,7 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() - def test_do_put_content_too_large(self) -> None: + def test_do_put_extract_content_too_large(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "999999999"} ) @@ -307,7 +352,7 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() - def test_do_put_tar_error(self) -> None: + def test_do_put_extract_tar_error(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "1"} ) @@ -328,7 +373,7 @@ class TestRequestHandler(BaseHandlerTestCase): handler.do_PUT() self.data_dir.extract_tar_bytes.assert_called_once() - def test_do_put_extract_error(self) -> None: + def test_do_put_extract_other_error(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "1"} ) @@ -347,7 +392,7 @@ class TestRequestHandler(BaseHandlerTestCase): handler.do_PUT() self.data_dir.extract_tar_bytes.assert_called_once() - def test_do_put_ok(self) -> None: + def test_do_put_extract_ok(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "1"} ) @@ -361,7 +406,7 @@ class TestRequestHandler(BaseHandlerTestCase): ), self.mock_call_unchecked(self.data_dir.extract_tar_bytes), self.mock_call(self.registry.add, ["path"]), - self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.expects_status_only( handler, http.HTTPStatus.CREATED, "Resource /path/ updated" ), @@ -369,7 +414,7 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() - def test_do_put_ok_with_host_fail_init(self) -> None: + def test_do_put_extract_with_host_fail_init(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "1", "X-Host": "example.com"}, @@ -385,7 +430,7 @@ class TestRequestHandler(BaseHandlerTestCase): self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), self.mock_call_unchecked(self.data_dir.extract_tar_bytes), self.mock_call(self.registry.add, ["path"]), - self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), self.mock_call(self.cert_manager.create_or_update, ["example.com"], False), # noqa: FBT003 self.expects_status_only( handler, http.HTTPStatus.CREATED, "Resource /path/ updated" @@ -394,7 +439,7 @@ class TestRequestHandler(BaseHandlerTestCase): ): handler.do_PUT() - def test_do_put_ok_with_host(self) -> None: + def test_do_put_extract_with_host(self) -> None: handler = self._get_handler( "/path", {"X-Token": "secret", "Content-Length": "1", "X-Host": "example.com"}, @@ -410,7 +455,88 @@ class TestRequestHandler(BaseHandlerTestCase): self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), self.mock_call_unchecked(self.data_dir.extract_tar_bytes), self.mock_call(self.registry.add, ["path"]), - self.mock_call(self.token_manager.set_token, ["secret", "path"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), + self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 + self.mock_call(self.registry.set_host, ["path", "example.com"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_redirect_with_content(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Redirect": "https://example.com", + "Content-Length": "1", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.expects_error( + handler, + http.HTTPStatus.BAD_REQUEST, + "No content must be sent with X-Redirect", + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_redirect_ok(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Redirect": "https://example.com", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.data_dir.empty, ["path"]), + self.mock_call(self.registry.add, ["path"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), + self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), + self.expects_status_only( + handler, http.HTTPStatus.CREATED, "Resource /path/ updated" + ), + self.seal_mocks(), + ): + handler.do_PUT() + + def test_do_put_redirect_with_host(self) -> None: + handler = self._get_handler( + "/path", + { + "X-Token": "secret", + "X-Redirect": "https://example.com", + "X-Host": "example.com", + }, + ) + with ( + self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003 + self.mock_call( + self.token_manager.is_valid_for_path, + ["secret", "path"], + True, # noqa: FBT003 + ), + self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), + self.mock_call(self.data_dir.empty, ["path"]), + self.mock_call(self.registry.add, ["path"]), + self.mock_call(self.token_manager.set_token, ["path", "secret"]), + self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 self.mock_call(self.registry.set_host, ["path", "example.com"]), self.expects_status_only( @@ -577,6 +703,7 @@ class TestRequestHandler(BaseHandlerTestCase): def test_translate_path_dotfile(self) -> None: handler = self._get_handler() with ( + self.mock_call(self.registry.get_from_path, ["path"], Page("path")), self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), self.seal_mocks(), ): @@ -603,6 +730,7 @@ class TestRequestHandler(BaseHandlerTestCase): def test_translate_path_default_host(self) -> None: handler = self._get_handler() with ( + self.mock_call(self.registry.get_from_path, ["path"], Page("path")), self.patch_call( "http.server.SimpleHTTPRequestHandler.translate_path", ["/path/index.html"], @@ -614,6 +742,18 @@ class TestRequestHandler(BaseHandlerTestCase): None, ) + def test_translate_path_default_host_not_found(self) -> None: + handler = self._get_handler() + with ( + self.mock_call(self.registry.get_from_path, ["path"]), + self.patch("http.server.SimpleHTTPRequestHandler.translate_path", count=0), + self.seal_mocks(), + ): + self.assertEqual( + handler.translate_path("/path/index.html"), + "", + ) + class TestUpgradeHandler(BaseHandlerTestCase): def _get_handler( diff --git a/tests/test_page.py b/tests/test_page.py index 0668bfa..14b381c 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -15,3 +15,9 @@ class TestPage(BaseTestCase): str(Page("test_1", with_index=True, host="example.com")), "/test_1/ [example.com]", ) + + def test_repr_with_redirect(self) -> None: + self.assertEqual( + str(Page("test_1", redirect="https://example.com")), + "/test_1/ (redirect: https://example.com)", + ) diff --git a/tests/test_registry.py b/tests/test_registry.py index 9bb29e0..96d63de 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -29,14 +29,18 @@ class TestRegistry(BaseTestCase): [ ["test_1", Registry.HOST_FILE], ["test_1", Registry.TOKEN_FILE], + ["test_1", Registry.REDIRECT_FILE], ["test_2", Registry.HOST_FILE], ["test_2", Registry.TOKEN_FILE], + ["test_2", Registry.REDIRECT_FILE], ], [ "test_1_host", "test_1_token", None, + None, "test_2_token", + "test_2_redirect", ], ), self.seal_mocks(), @@ -51,6 +55,7 @@ class TestRegistry(BaseTestCase): True, # noqa: FBT003 "test_1_host", "test_1_token", + None, ), ) self.assertEqual( @@ -60,6 +65,7 @@ class TestRegistry(BaseTestCase): False, # noqa: FBT003 None, "test_2_token", + "test_2_redirect", ), ) @@ -108,6 +114,23 @@ class TestRegistry(BaseTestCase): self.registry.set_token_hash("test_1", "new_value") self.assertEqual(self.registry.pages["test_1"].token_hash, "new_value") + def test_set_redirect(self) -> None: + self.registry.pages["test_1"] = Page( + "test_1", + redirect="https://example.com", + ) + with ( + self.mock_call( + self.data_dir.set_file, + ["test_1", Registry.REDIRECT_FILE, "https://new-example.com"], + ), + self.seal_mocks(), + ): + self.registry.set_redirect("test_1", "https://new-example.com") + self.assertEqual( + self.registry.pages["test_1"].redirect, "https://new-example.com" + ) + def test_remove(self) -> None: self.registry.pages["test_1"] = Page( "test_1", diff --git a/tests/test_token_manager.py b/tests/test_token_manager.py index df96cbb..67c49e1 100644 --- a/tests/test_token_manager.py +++ b/tests/test_token_manager.py @@ -117,7 +117,7 @@ class TestTokenManager(BaseTestCase): ), self.seal_mocks(), ): - self.token_manager.set_token("secret", "test_1") + self.token_manager.set_token("test_1", "secret") @unittest.mock.patch("secrets.token_hex") def test_new_token(self, mock_token_hex: unittest.mock.Mock) -> None: