From 7f02abca1ac47e0afd3ffe7ccb9375028bdc2e87 Mon Sep 17 00:00:00 2001 From: klemek Date: Mon, 20 Apr 2026 19:55:35 +0200 Subject: [PATCH] feat: cert_manager detect file change --- DEVELOPMENT.md | 2 +- src/cert_manager.py | 10 ++++++++++ src/server.py | 37 +++++++++++++++++++++++++++++++------ src/token_manager.py | 10 ++++------ tests/test_cert_manager.py | 11 +++++++++++ tests/test_server.py | 4 ++++ tests/test_token_manager.py | 12 ++++-------- 7 files changed, 65 insertions(+), 21 deletions(-) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index da57f06..c8b031e 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -65,7 +65,7 @@ docker-run docker run - [x] github actions - [x] X-Redirect - [x] X-Proxy -- [ ] detect root certificate change and update server +- [x] detect root certificate change and update server - [x] detect tokens change and update token_manager - [ ] allow args before/after command - [x] proper doc diff --git a/src/cert_manager.py b/src/cert_manager.py index d3eb795..2835b02 100644 --- a/src/cert_manager.py +++ b/src/cert_manager.py @@ -17,6 +17,7 @@ class CertManager: __slots__ = [ "certbot_conf", "certbot_www", + "last_file_change", "logger", "self_signed_path", "with_certbot", @@ -32,6 +33,7 @@ class CertManager: self.certbot_www: pathlib.Path = pathlib.Path(params.certbot_www) self.self_signed_path: pathlib.Path = pathlib.Path(params.self_signed_path) self.with_certbot: bool = params.with_certbot + self.last_file_change: int | float = 0 def init(self, hosts: list[str]) -> None: self.logger.debug("Initializing...") @@ -187,6 +189,7 @@ class CertManager: return None cert_file = self.get_cert(default_host) key_file = self.get_key(default_host) + self.last_file_change = cert_file.stat().st_mtime context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain( cert_file, @@ -195,6 +198,13 @@ class CertManager: context.sni_callback = self.__sni_callback return context + def detect_default_cert_change(self, default_host: str) -> bool: + cert_file = self.get_cert(default_host) + if cert_file.exists() and cert_file.stat().st_mtime != self.last_file_change: + self.logger.debug("Detected change: %s", cert_file) + return True + return False + def __sni_callback( self, socket: ssl.SSLObject, host: str | None, _: ssl.SSLContext, / ) -> None | int: diff --git a/src/server.py b/src/server.py index 277cb1e..8bc1eef 100644 --- a/src/server.py +++ b/src/server.py @@ -2,6 +2,7 @@ import contextlib import http.server import logging import threading +import time import typing from . import ( @@ -24,9 +25,11 @@ class StaplerServer: "cert_manager", "data_dir", "default_host", + "https", "logger", "params", "registry", + "server", "token_manager", ] @@ -38,6 +41,8 @@ class StaplerServer: self.token_manager: TokenManager = TokenManager(params, self.registry) self.data_dir: DataDir = DataDir(params.data_dir) self.default_host: str = params.host.split(":", maxsplit=2)[0] + self.server: http.server.ThreadingHTTPServer | None = None + self.https = params.https def __get_all_hosts(self) -> list[str]: return [self.default_host, *self.registry.get_hosts()] @@ -115,29 +120,49 @@ class StaplerServer: threading.Thread(target=server.serve_forever, daemon=True).start() return server - def __token_manager_background(self) -> None: + def __token_manager_background(self) -> None: # pragma: no cover with contextlib.suppress(KeyboardInterrupt): while True: - self.token_manager.detect_file_change() + if self.token_manager.detect_file_change(): + self.token_manager.init() + time.sleep(1) + + def __cert_manager_background(self) -> None: # pragma: no cover + with contextlib.suppress(KeyboardInterrupt): + while True: + if ( + self.server is not None + and self.cert_manager.detect_default_cert_change(self.default_host) + and ( + context := self.cert_manager.get_https_context( + self.default_host + ) + ) + is not None + ): + self.server.socket = context.wrap_socket(self.server.socket) + time.sleep(1) def __start_background_tasks(self) -> None: threading.Thread(target=self.__token_manager_background, daemon=True).start() + if self.https: + threading.Thread(target=self.__cert_manager_background, daemon=True).start() def run(self) -> int: self.logger.info("Version %s", project.get_version()) for line in STAPLER_ASCII.split("\n"): self.logger.debug(line.ljust(36)) self.__startup() - base_server, https = self.__create_base_server() - upgrade_server = self.__start_upgrade_server() if https else None + self.server, self.https = self.__create_base_server() + upgrade_server = self.__start_upgrade_server() if self.https else None self.logger.info( "Server up and ready on %s://%s", - "https" if https else "http", + "https" if self.https else "http", self.params.host, ) self.__start_background_tasks() with contextlib.suppress(KeyboardInterrupt): - base_server.serve_forever() + self.server.serve_forever() self.logger.info("Shutting down...") if upgrade_server is not None: upgrade_server.shutdown() diff --git a/src/token_manager.py b/src/token_manager.py index a4c8854..a844032 100644 --- a/src/token_manager.py +++ b/src/token_manager.py @@ -60,13 +60,11 @@ class TokenManager: self.logger.warning("NEW TOKEN: %s", new_token) self.logger.warning("Please copy this secret value before it disappears") - def detect_file_change(self) -> None: - if ( - not self.tokens_file.exists() - or self.tokens_file.stat().st_mtime != self.last_file_change - ): + def detect_file_change(self) -> bool: + if self.tokens_file.stat().st_mtime != self.last_file_change: self.logger.debug("Detected change: %s", self.tokens_file) - self.init() + return True + return False def __hash_token(self, token: str) -> str: return hashlib.sha512( diff --git a/tests/test_cert_manager.py b/tests/test_cert_manager.py index ab0e7d7..fe1704d 100644 --- a/tests/test_cert_manager.py +++ b/tests/test_cert_manager.py @@ -249,6 +249,17 @@ class TestRegistry(BaseTestCase): self.socket_mock, "new_host", self.context_mock ) + def test_detect_default_cert_change(self) -> None: + self._make_self_signed("localhost") + assert self.cert_manager.detect_default_cert_change("localhost") + + def test_detect_default_cert_change_nothing(self) -> None: + self._make_self_signed("localhost") + self.cert_manager.last_file_change = ( + (self.self_signed_path / "localhost" / CertManager.CRT_FILE).stat().st_mtime + ) + assert not self.cert_manager.detect_default_cert_change("localhost") + def _make_self_signed(self, host: str) -> None: (self.self_signed_path / host).mkdir(parents=True, exist_ok=True) (self.self_signed_path / host / CertManager.CRT_FILE).touch() diff --git a/tests/test_server.py b/tests/test_server.py index 423bc64..a3137b4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -83,6 +83,7 @@ class TestStaplerServer(BaseTestCase): def test_run_https(self) -> None: self.token_manager.detect_file_change.side_effect = KeyboardInterrupt + self.cert_manager.detect_default_cert_change.side_effect = KeyboardInterrupt with ( self.mock_call(self.registry.load_pages), self.mock_call(self.registry.get_hosts, [], []), @@ -102,3 +103,6 @@ class TestStaplerServer(BaseTestCase): ): self.assertEqual(self.server.run(), 0) self.token_manager.detect_file_change.assert_called_once() + self.cert_manager.detect_default_cert_change.assert_called_once_with( + "localhost" + ) diff --git a/tests/test_token_manager.py b/tests/test_token_manager.py index 5e4831d..6f1de65 100644 --- a/tests/test_token_manager.py +++ b/tests/test_token_manager.py @@ -120,19 +120,15 @@ class TestTokenManager(BaseTestCase): self.token_manager.set_token("test_1", "secret") def test_detect_file_change(self) -> None: + self.tmp_tokens_file.touch() self.seal_mocks() - self.token_manager.detect_file_change() - self.assert_file_content(self.tmp_tokens_file, self.SALT_HASH) - self.assertEqual(self.tmp_tokens_file.stat().st_mode, 0o100600) - self.assertListEqual(self.token_manager.token_hashes, []) + assert self.token_manager.detect_file_change() def test_detect_file_change_nothing(self) -> None: - with self.tmp_tokens_file.open(mode="w") as file: - file.write(self.SALT_HASH + "\n" + self.SECRET_HASH) + self.tmp_tokens_file.touch() self.token_manager.last_file_change = self.tmp_tokens_file.stat().st_mtime self.seal_mocks() - self.token_manager.detect_file_change() - self.assertListEqual(self.token_manager.token_hashes, []) + assert not self.token_manager.detect_file_change() @unittest.mock.patch("secrets.token_hex") def test_new_token(self, mock_token_hex: unittest.mock.Mock) -> None: