diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 7729d54..da57f06 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -66,7 +66,7 @@ docker-run docker run - [x] X-Redirect - [x] X-Proxy - [ ] detect root certificate change and update server -- [ ] detect tokens change and update token_manager +- [x] detect tokens change and update token_manager - [ ] allow args before/after command - [x] proper doc diff --git a/src/server.py b/src/server.py index e532f60..277cb1e 100644 --- a/src/server.py +++ b/src/server.py @@ -112,9 +112,17 @@ class StaplerServer: server.server_address[0], server.server_port, ) - threading.Thread(target=server.serve_forever).start() + threading.Thread(target=server.serve_forever, daemon=True).start() return server + def __token_manager_background(self) -> None: + with contextlib.suppress(KeyboardInterrupt): + while True: + self.token_manager.detect_file_change() + + def __start_background_tasks(self) -> None: + threading.Thread(target=self.__token_manager_background, daemon=True).start() + def run(self) -> int: self.logger.info("Version %s", project.get_version()) for line in STAPLER_ASCII.split("\n"): @@ -127,6 +135,7 @@ class StaplerServer: "https" if https else "http", self.params.host, ) + self.__start_background_tasks() with contextlib.suppress(KeyboardInterrupt): base_server.serve_forever() self.logger.info("Shutting down...") diff --git a/src/token_manager.py b/src/token_manager.py index eb40a5e..a4c8854 100644 --- a/src/token_manager.py +++ b/src/token_manager.py @@ -13,6 +13,7 @@ if typing.TYPE_CHECKING: class TokenManager: __slots__ = [ + "last_file_change", "logger", "registry", "token_hashes", @@ -28,6 +29,7 @@ class TokenManager: self.tokens_file: pathlib.Path = pathlib.Path(params.data_dir) / self.FILE self.registry: Registry = registry self.token_hashes: list[str] = [] + self.last_file_change: int | float = 0 def init(self) -> None: self.logger.debug("Initializing...") @@ -38,7 +40,7 @@ class TokenManager: self.token_hashes = self.__load_hashes() if not self.tokens_file.exists(): self.__save_hashes() - self.tokens_file.chmod(0o600) + self.last_file_change = self.tokens_file.stat().st_mtime def is_valid(self, token: str) -> bool: return self.__hash_token(token) in self.token_hashes @@ -58,6 +60,14 @@ class TokenManager: self.logger.warning("NEW TOKEN: %s", new_token) self.logger.warning("Please copy this secret value before it disappears") + def detect_file_change(self) -> None: + if ( + not self.tokens_file.exists() + or self.tokens_file.stat().st_mtime != self.last_file_change + ): + self.logger.debug("Detected change: %s", self.tokens_file) + self.init() + def __hash_token(self, token: str) -> str: return hashlib.sha512( (self.token_salt + token).encode(), usedforsecurity=True diff --git a/tests/test_server.py b/tests/test_server.py index 08dae3d..423bc64 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -53,6 +53,7 @@ class TestStaplerServer(BaseTestCase): def test_run_http(self) -> None: self.server.params = Parameters(https=False, with_certificates=False) + self.token_manager.detect_file_change.side_effect = KeyboardInterrupt with ( self.mock_call(self.registry.load_pages), self.mock_call(self.data_dir.init), @@ -62,8 +63,10 @@ class TestStaplerServer(BaseTestCase): self.seal_mocks(), ): self.assertEqual(self.server.run(), 0) + self.token_manager.detect_file_change.assert_called_once() def test_run_https_fail(self) -> None: + self.token_manager.detect_file_change.side_effect = KeyboardInterrupt with ( self.mock_call(self.registry.load_pages), self.mock_call(self.registry.get_hosts, [], []), @@ -76,8 +79,10 @@ class TestStaplerServer(BaseTestCase): self.seal_mocks(), ): self.assertEqual(self.server.run(), 0) + self.token_manager.detect_file_change.assert_called_once() def test_run_https(self) -> None: + self.token_manager.detect_file_change.side_effect = KeyboardInterrupt with ( self.mock_call(self.registry.load_pages), self.mock_call(self.registry.get_hosts, [], []), @@ -96,3 +101,4 @@ class TestStaplerServer(BaseTestCase): self.seal_mocks(self.context_mock), ): self.assertEqual(self.server.run(), 0) + self.token_manager.detect_file_change.assert_called_once() diff --git a/tests/test_token_manager.py b/tests/test_token_manager.py index 67c49e1..5e4831d 100644 --- a/tests/test_token_manager.py +++ b/tests/test_token_manager.py @@ -119,6 +119,21 @@ class TestTokenManager(BaseTestCase): ): self.token_manager.set_token("test_1", "secret") + def test_detect_file_change(self) -> None: + self.seal_mocks() + self.token_manager.detect_file_change() + self.assert_file_content(self.tmp_tokens_file, self.SALT_HASH) + self.assertEqual(self.tmp_tokens_file.stat().st_mode, 0o100600) + self.assertListEqual(self.token_manager.token_hashes, []) + + def test_detect_file_change_nothing(self) -> None: + with self.tmp_tokens_file.open(mode="w") as file: + file.write(self.SALT_HASH + "\n" + self.SECRET_HASH) + self.token_manager.last_file_change = self.tmp_tokens_file.stat().st_mtime + self.seal_mocks() + self.token_manager.detect_file_change() + self.assertListEqual(self.token_manager.token_hashes, []) + @unittest.mock.patch("secrets.token_hex") def test_new_token(self, mock_token_hex: unittest.mock.Mock) -> None: mock_token_hex.return_value = "secret"