diff --git a/stapler/handlers.py b/stapler/handlers.py index f14e6db..87a849c 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -538,7 +538,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): return self.registry.get_from_host(self.host) -class UpgradeHandler(BaseHandler): +class UpgradeHandler(RequestHandler): server_version = "StaplerUpgradeServer/" + PKG_VERSION def do_HEAD(self) -> None: @@ -546,4 +546,7 @@ class UpgradeHandler(BaseHandler): self.send_redirect(f"https://{self.host}{self.path}") def do_GET(self) -> None: - self.do_HEAD() + if self.path.startswith(self.CERTBOT_CHALLENGE_PATH): + super().do_GET() + else: + self.do_HEAD() diff --git a/stapler/server.py b/stapler/server.py index 08b8d47..4f27546 100644 --- a/stapler/server.py +++ b/stapler/server.py @@ -97,6 +97,8 @@ class StaplerServer: return UpgradeHandler( *args, params=self.params, + registry=self.registry, + token_manager=self.token_manager, ) def __start_upgrade_server(self) -> http.server.ThreadingHTTPServer: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e74d6db..db9daf8 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1168,6 +1168,11 @@ class TestRequestHandler(BaseHandlerTestCase): class TestUpgradeHandler(BaseHandlerTestCase): + @typing.override + def setUp(self) -> None: + self.data_dir = self.new_mock() + super().setUp() + def _get_handler( self, path: str = "/", @@ -1183,6 +1188,8 @@ class TestUpgradeHandler(BaseHandlerTestCase): "127.0.0.1", unittest.mock.MagicMock(), params=Parameters(), + registry=self.new_mock(), + token_manager=self.new_mock(), ) handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] handler.requestline = f"{method} {path}" @@ -1193,22 +1200,43 @@ class TestUpgradeHandler(BaseHandlerTestCase): handler.rfile = rfile if rfile is not None else io.BytesIO() handler.wfile = io.BytesIO() handler.logger = unittest.mock.Mock(logging.Logger) + handler.data_dir = self.data_dir return handler def test_do_get(self) -> None: handler = self._get_handler("/file") - with self.expects_status_only( - handler, - http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": "https://localhost/file"}, + with ( + self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://localhost/file"}, + ), + self.patch( + "http.server.SimpleHTTPRequestHandler.do_GET", + count=0, + ), + self.seal_mocks(), + ): + handler.do_GET() + + def test_do_get_certbot(self) -> None: + handler = self._get_handler("/.well-known/acme-challenge/abcde") + with ( + self.patch( + "http.server.SimpleHTTPRequestHandler.do_GET", + ), + self.seal_mocks(), ): handler.do_GET() def test_do_head(self) -> None: handler = self._get_handler("/file") - with self.expects_status_only( - handler, - http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": "https://localhost/file"}, + with ( + self.expects_status_only( + handler, + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": "https://localhost/file"}, + ), + self.seal_mocks(), ): handler.do_HEAD()