From cc76d75bf5a9b082ce5e871fb442c354b149852c Mon Sep 17 00:00:00 2001 From: klemek Date: Sat, 18 Apr 2026 17:36:20 +0200 Subject: [PATCH] tests(server): add server tests --- src/server.py | 13 +----- tests/test_server.py | 102 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 11 deletions(-) create mode 100644 tests/test_server.py diff --git a/src/server.py b/src/server.py index 41b34b0..eecb27e 100644 --- a/src/server.py +++ b/src/server.py @@ -39,16 +39,7 @@ class StaplerServer: self.data_dir.init() self.token_manager.init() - def __create_https_context(self, server: http.server.HTTPServer) -> bool: - https = False - if ( - context := self.cert_manager.get_https_context(self.default_host) - ) is not None: - https = True - server.socket = context.wrap_socket(server.socket, server_side=True) - return https - - def __request_handler( + def __request_handler( # pragma: no cover self, *args: typing.Any ) -> http.server.BaseHTTPRequestHandler: return handlers.RequestHandler( @@ -89,7 +80,7 @@ class StaplerServer: ) return server, context is not None - def __upgrade_handler( + def __upgrade_handler( # pragma: no cover self, *args: typing.Any ) -> http.server.BaseHTTPRequestHandler: return handlers.UpgradeHandler( diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..9455f5f --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,102 @@ +import logging +import ssl +import typing +import unittest +import unittest.mock + +from src.cert_manager import CertManager +from src.data_dir import DataDir +from src.params import Parameters +from src.registry import Registry +from src.server import StaplerServer +from src.token_manager import TokenManager + +from . import BaseTestCase + + +class TestStaplerServer(BaseTestCase): + @typing.override + def setUp(self) -> None: + self.server = StaplerServer(Parameters()) + self.server.logger = unittest.mock.Mock(logging.Logger) + self.registry = self.server.registry = self.mock(Registry) + self.cert_manager = self.server.cert_manager = self.mock(CertManager) + self.token_manager = self.server.token_manager = self.mock(TokenManager) + self.data_dir = self.server.data_dir = self.mock(DataDir) + self.server_mock = unittest.mock.MagicMock() + self.context_mock = unittest.mock.Mock(ssl.SSLContext) + super().setUp() + + def test_renew(self) -> None: + with ( + self.mock_call(self.registry.load_pages), + self.mock_calls( + self.registry.get_hosts, [[], []], [["host_1"], ["host_1"]] + ), + self.mock_call(self.cert_manager.init, [["localhost", "host_1"]]), + self.mock_calls( + self.cert_manager.create_or_update, [["localhost"], ["host_1"]] + ), + self.seal_mocks(), + ): + self.assertEqual(self.server.renew(), 0) + + def test_renew_without_certificates(self) -> None: + self.server.params = Parameters(with_certificates=False) + self.seal_mocks() + self.assertEqual(self.server.renew(), 1) + + def test_token(self) -> None: + with ( + self.mock_call(self.registry.load_pages), + self.mock_call(self.token_manager.init), + self.mock_call(self.token_manager.new_token), + self.seal_mocks(), + ): + self.assertEqual(self.server.token(), 0) + + def test_run_http(self) -> None: + self.server.params = Parameters(https=False, with_certificates=False) + with ( + self.mock_call(self.registry.load_pages), + self.mock_call(self.data_dir.init), + self.mock_call(self.token_manager.init), + self.patch("http.server.ThreadingHTTPServer", self.server_mock), + self.mock_call(self.server_mock.serve_forever), + self.seal_mocks(), + ): + self.assertEqual(self.server.run(), 0) + + def test_run_https_fail(self) -> None: + with ( + self.mock_call(self.registry.load_pages), + self.mock_call(self.registry.get_hosts, [], []), + self.mock_call(self.cert_manager.init, [["localhost"]]), + self.mock_call(self.data_dir.init), + self.mock_call(self.token_manager.init), + self.mock_call(self.cert_manager.get_https_context, ["localhost"]), + self.patch("http.server.ThreadingHTTPServer", self.server_mock), + self.mock_call(self.server_mock.serve_forever), + self.seal_mocks(), + ): + self.assertEqual(self.server.run(), 0) + + def test_run_https(self) -> None: + with ( + self.mock_call(self.registry.load_pages), + self.mock_call(self.registry.get_hosts, [], []), + self.mock_call(self.cert_manager.init, [["localhost"]]), + self.mock_call(self.data_dir.init), + self.mock_call(self.token_manager.init), + self.mock_call( + self.cert_manager.get_https_context, + ["localhost"], + self.context_mock, + ), + self.patch("http.server.ThreadingHTTPServer", self.server_mock, 2), + self.mock_call_unchecked(self.context_mock.wrap_socket), + self.mock_calls_unchecked(self.server_mock.serve_forever, 2), + self.mock_call(self.server_mock.shutdown), + self.seal_mocks(self.context_mock), + ): + self.assertEqual(self.server.run(), 0)