diff --git a/pyproject.toml b/pyproject.toml index e4e2470..476ee2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ module-name = "stapler" [dependency-groups] dev = [ "coverage>=7.13.5", + "parameterized>=0.9.0", "pytest>=9.0.3", "ruff>=0.15.10", "ty>=0.0.29", diff --git a/stapler/cert_manager.py b/stapler/cert_manager.py index c9adb72..158c1cd 100644 --- a/stapler/cert_manager.py +++ b/stapler/cert_manager.py @@ -5,6 +5,8 @@ import ssl import subprocess import typing +from stapler.strings import valid_host + if typing.TYPE_CHECKING: from .params import Parameters @@ -57,7 +59,7 @@ class CertManager: def create_or_update(self, host: str) -> bool: created = self.init_cert(host) - if self.with_certbot and self.__create_certbot(host): + if self.with_certbot and valid_host(host) and self.__create_certbot(host): return True return created or self.__create_self_signed(host) diff --git a/stapler/handlers.py b/stapler/handlers.py index bca736e..426dc4a 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -16,6 +16,7 @@ import requests from . import PKG_VERSION, STAPLER_ASCII, logs from .data_dir import DataDir +from .strings import sanitize_string, valid_host if typing.TYPE_CHECKING: from .page import Page @@ -25,7 +26,6 @@ if typing.TYPE_CHECKING: class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): - SANITIZE_REGEX = re.compile(r"[^\x20-\x7F]+") timeout = 10 REQUEST_COUNT = 0 @@ -82,7 +82,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): @typing.override def address_string(self) -> str: # pragma: no cover - return self.SANITIZE_REGEX.sub("?", super().address_string()) + return sanitize_string(super().address_string()) @typing.override def log_message(self, format: str, *args: typing.Any) -> None: # pragma: no cover @@ -100,7 +100,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.address_string(), self.host, format(self.__class__.REQUEST_COUNT, "07_d"), - self.SANITIZE_REGEX.sub("?", self.requestline), + sanitize_string(self.requestline), ) fmt = "← %s - %s - %s - %s - %s" if self.in_size > 0: @@ -128,7 +128,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.address_string(), self.host, format(self.__class__.REQUEST_COUNT, "07_d"), - self.SANITIZE_REGEX.sub("?", self.requestline), + sanitize_string(self.requestline), ) fmt = "→ %s - %s - %s - %s - %s" if size != "": @@ -579,7 +579,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): f"Cannot use {self.HOST_ONLY_HEADER} with {self.HOST_HEADER}", ) return None - if self.has_target_host and not self.__valid_host(self.target_host): + if self.has_target_host and not valid_host(self.target_host): self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid requested host") return None if self.has_target_proxy and self.has_target_redirect: @@ -602,12 +602,6 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): return match.group(1) return None - def __valid_host(self, host: str) -> bool: - return ( - all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")) - and len(host) < 256 - ) - def __get_page(self, src_path: str) -> Page | None: if self.host == self.default_host: if ( diff --git a/stapler/strings.py b/stapler/strings.py new file mode 100644 index 0000000..462db7c --- /dev/null +++ b/stapler/strings.py @@ -0,0 +1,19 @@ +import re + +__HOST_PART_REGEX = re.compile(r"^([a-z0-9]|[a-z0-9][a-z0-9-]{,61}[a-z0-9])$") +__SANITIZE_REGEX = re.compile(r"[^\x20-\x7F]") + + +def valid_host(host: str) -> bool: + parts = host.split(".") + return ( + len(parts) > 1 + and len(parts[-1]) > 1 + and all(__HOST_PART_REGEX.fullmatch(part) for part in parts) + and not all(part.isnumeric() for part in parts) + and len(host) < 256 + ) + + +def sanitize_string(raw: str) -> str: + return __SANITIZE_REGEX.sub("?", raw) diff --git a/tests/test_cert_manager.py b/tests/test_cert_manager.py index dc9f874..3d22490 100644 --- a/tests/test_cert_manager.py +++ b/tests/test_cert_manager.py @@ -44,32 +44,32 @@ class TestRegistry(BaseTestCase): self.patch("shutil.which", count=0), self.patch("subprocess.check_output", count=0), ): - self._make_self_signed("localhost") - self.cert_manager.init(["localhost"]) + self._make_self_signed("example.com") + self.cert_manager.init(["example.com"]) def test_exists_self_signed(self) -> None: - self._make_self_signed("localhost") - assert self.cert_manager.exists("localhost") + self._make_self_signed("example.com") + assert self.cert_manager.exists("example.com") def test_exists_certbot(self) -> None: - self._make_certbot("localhost") - assert self.cert_manager.exists("localhost") + self._make_certbot("example.com") + assert self.cert_manager.exists("example.com") def test_exists_fail(self) -> None: - assert not self.cert_manager.exists("localhost") + assert not self.cert_manager.exists("example.com") def test_exists_fail_without_certbot(self) -> None: self.cert_manager.with_certbot = False - self._make_certbot("localhost") - assert not self.cert_manager.exists("localhost") + self._make_certbot("example.com") + assert not self.cert_manager.exists("example.com") def test_init_cert_existing(self) -> None: with ( self.patch("shutil.which", count=0), self.patch("subprocess.check_output", count=0), ): - self._make_self_signed("localhost") - assert not self.cert_manager.init_cert("localhost") + self._make_self_signed("example.com") + assert not self.cert_manager.init_cert("example.com") def test_init_cert_fail(self) -> None: with ( @@ -77,7 +77,7 @@ class TestRegistry(BaseTestCase): self.patch("subprocess.check_output") as process_mock, ): process_mock.side_effect = subprocess.CalledProcessError(1, "", output=b"") - assert not self.cert_manager.init_cert("localhost") + assert not self.cert_manager.init_cert("example.com") def test_init_cert_new(self) -> None: with ( @@ -85,135 +85,137 @@ class TestRegistry(BaseTestCase): self.patch("subprocess.check_output") as process_mock, ): process_mock.side_effect = lambda *_, **__: self._make_self_signed( - "localhost" + "example.com" ) - assert self.cert_manager.init_cert("localhost") + assert self.cert_manager.init_cert("example.com") def test_create_or_update_existing_no_certbot(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") self.cert_manager.with_certbot = False with ( self.patch("shutil.which", return_value=""), self.patch("subprocess.check_output") as process_mock, ): process_mock.side_effect = lambda *_, **__: self._make_self_signed( - "localhost" + "example.com" ) - assert self.cert_manager.create_or_update("localhost") + assert self.cert_manager.create_or_update("example.com") def test_create_or_update_existing_certbot(self) -> None: - self._make_certbot("localhost") + self._make_certbot("example.com") with ( self.patch("shutil.which", return_value=""), self.patch("subprocess.check_output") as process_mock, ): - process_mock.side_effect = lambda *_, **__: self._make_certbot("localhost") - assert self.cert_manager.create_or_update("localhost") + process_mock.side_effect = lambda *_, **__: self._make_certbot( + "example.com" + ) + assert self.cert_manager.create_or_update("example.com") def test_create_or_update_existing_fail_both(self) -> None: - self._make_certbot("localhost") + self._make_certbot("example.com") with ( self.patch("shutil.which", return_value="", count=2), self.patch("subprocess.check_output", count=2) as process_mock, ): process_mock.side_effect = subprocess.CalledProcessError(1, "", output=b"") - assert not self.cert_manager.create_or_update("localhost") + assert not self.cert_manager.create_or_update("example.com") def test_create_or_update_existing_fail_both_binary(self) -> None: - self._make_certbot("localhost") + self._make_certbot("example.com") with ( self.patch("shutil.which", count=2), self.patch("subprocess.check_output", count=0), ): - assert not self.cert_manager.create_or_update("localhost") + assert not self.cert_manager.create_or_update("example.com") def test_get_cert_certbot(self) -> None: - self._make_certbot("localhost") + self._make_certbot("example.com") self.assertEqual( - self.cert_manager.get_cert("localhost"), - self.certbot_conf / "live" / "localhost" / CertManager.CRT_FILE, + self.cert_manager.get_cert("example.com"), + self.certbot_conf / "live" / "example.com" / CertManager.CRT_FILE, ) def test_get_cert_self_signed(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") self.assertEqual( - self.cert_manager.get_cert("localhost"), - self.self_signed_path / "localhost" / CertManager.CRT_FILE, + self.cert_manager.get_cert("example.com"), + self.self_signed_path / "example.com" / CertManager.CRT_FILE, ) def test_get_cert_fail(self) -> None: self.assertRaises( CertManagerError, - lambda: self.cert_manager.get_cert("localhost"), + lambda: self.cert_manager.get_cert("example.com"), ) def test_get_key_certbot(self) -> None: - self._make_certbot("localhost") + self._make_certbot("example.com") self.assertEqual( - self.cert_manager.get_key("localhost"), - self.certbot_conf / "live" / "localhost" / CertManager.KEY_FILE, + self.cert_manager.get_key("example.com"), + self.certbot_conf / "live" / "example.com" / CertManager.KEY_FILE, ) def test_get_key_self_signed(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") self.assertEqual( - self.cert_manager.get_key("localhost"), - self.self_signed_path / "localhost" / CertManager.KEY_FILE, + self.cert_manager.get_key("example.com"), + self.self_signed_path / "example.com" / CertManager.KEY_FILE, ) def test_get_key_fail(self) -> None: self.assertRaises( CertManagerError, - lambda: self.cert_manager.get_key("localhost"), + lambda: self.cert_manager.get_key("example.com"), ) def test_sni_callback_no_host(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") with ( self.patch("ssl.create_default_context", count=0), ): self.cert_manager.sni_callback(self.socket_mock, None, self.context_mock) def test_sni_callback_fail(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") with ( self.patch("shutil.which", count=3), self.patch("ssl.create_default_context", count=0), ): self.cert_manager.sni_callback( - self.socket_mock, "new_host", self.context_mock + self.socket_mock, "example.fr", self.context_mock ) def test_sni_callback_create_context(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") with ( self.patch("ssl.create_default_context", return_value=self.context_mock), self.mock_call( self.context_mock.load_cert_chain, [ - self.self_signed_path / "localhost" / CertManager.CRT_FILE, - self.self_signed_path / "localhost" / CertManager.KEY_FILE, + self.self_signed_path / "example.com" / CertManager.CRT_FILE, + self.self_signed_path / "example.com" / CertManager.KEY_FILE, ], ), self.patch("shutil.which", count=0), ): self.cert_manager.sni_callback( - self.socket_mock, "localhost", self.context_mock + self.socket_mock, "example.com", self.context_mock ) def test_sni_callback_create_context_fail(self) -> None: - self._make_self_signed("localhost") + self._make_self_signed("example.com") with ( self.patch("ssl.create_default_context", return_value=self.context_mock), self.patch("shutil.which", count=0), ): self.context_mock.load_cert_chain.side_effect = Exception self.cert_manager.sni_callback( - self.socket_mock, "localhost", self.context_mock + self.socket_mock, "example.com", self.context_mock ) self.context_mock.load_cert_chain.assert_called_once_with( - self.self_signed_path / "localhost" / CertManager.CRT_FILE, - self.self_signed_path / "localhost" / CertManager.KEY_FILE, + self.self_signed_path / "example.com" / CertManager.CRT_FILE, + self.self_signed_path / "example.com" / CertManager.KEY_FILE, ) def _make_self_signed(self, host: str) -> None: diff --git a/tests/test_strings.py b/tests/test_strings.py new file mode 100644 index 0000000..00cdd19 --- /dev/null +++ b/tests/test_strings.py @@ -0,0 +1,22 @@ +import parameterized + +from stapler.strings import sanitize_string, valid_host + +from . import BaseTestCase + + +class TestStrings(BaseTestCase): + def test_sanitize(self) -> None: + self.assertEqual("??A??", sanitize_string("\n\tA\x00\x99")) + + @parameterized.parameterized.expand( + [("example.com"), ("test-test.com"), ("subdomain.example.com")] + ) + def test_valid_host(self, host: str) -> None: + self.assertTrue(valid_host(host), host) + + @parameterized.parameterized.expand( + [("example.c"), ("localhost"), ("127.0.0.1"), ("test..com"), ("www-.test.com")] + ) + def test_invalid_host(self, host: str) -> None: + self.assertFalse(valid_host(host), host) diff --git a/uv.lock b/uv.lock index b76ac4f..833f36f 100644 --- a/uv.lock +++ b/uv.lock @@ -127,6 +127,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/b2/87e62e8c3e2f4b32e5fe99e0b86d576da1312593b39f47d8ceef365e95ed/packaging-26.2-py3-none-any.whl", hash = "sha256:5fc45236b9446107ff2415ce77c807cee2862cb6fac22b8a73826d0693b0980e", size = 100195, upload-time = "2026-04-24T20:15:22.081Z" }, ] +[[package]] +name = "parameterized" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/49/00c0c0cc24ff4266025a53e41336b79adaa5a4ebfad214f433d623f9865e/parameterized-0.9.0.tar.gz", hash = "sha256:7fc905272cefa4f364c1a3429cbbe9c0f98b793988efb5bf90aac80f08db09b1", size = 24351, upload-time = "2023-03-27T02:01:11.592Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/2f/804f58f0b856ab3bf21617cccf5b39206e6c4c94c2cd227bde125ea6105f/parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b", size = 20475, upload-time = "2023-03-27T02:01:09.31Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -212,6 +221,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "coverage" }, + { name = "parameterized" }, { name = "pytest" }, { name = "ruff" }, { name = "ty" }, @@ -223,6 +233,7 @@ requires-dist = [{ name = "requests", specifier = ">=2.33.1" }] [package.metadata.requires-dev] dev = [ { name = "coverage", specifier = ">=7.13.5" }, + { name = "parameterized", specifier = ">=0.9.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "ruff", specifier = ">=0.15.10" }, { name = "ty", specifier = ">=0.0.29" },