From d3d98bd9b26c8407816571bf837be1688deeed71 Mon Sep 17 00:00:00 2001 From: klemek Date: Mon, 11 May 2026 17:26:21 +0200 Subject: [PATCH] feat: handle HEAD requests --- stapler/handlers.py | 13 +++++++++---- tests/test_handlers.py | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/stapler/handlers.py b/stapler/handlers.py index f802ab6..870015c 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -153,7 +153,8 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.send_header("Content-Length", str(len(encoded))) self.send_header("Connection", "close") self.end_headers() - self.wfile.write(encoded) + if self.command != http.HTTPMethod.HEAD: + self.wfile.write(encoded) self.close_connection = True def send_status_only( @@ -222,7 +223,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.send_header("Content-Length", str(out_size := len(response.content))) self.send_header("Connection", "close") self.end_headers() - if out_size > 0: + if out_size > 0 and self.command != http.HTTPMethod.HEAD: self.wfile.write(response.content) self.close_connection = True @@ -387,9 +388,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def do_HEAD(self) -> None: with self.handle_errors(): self._pre_log_request() - if not self._proxy_or_redirect(): - super().do_HEAD() + if self._proxy_or_redirect(): + return None + if self.path == "/" and self.host == self.default_host: + return self.send_basic_body(self.server_signature()) + super().do_HEAD() self.close_connection = True + return None @typing.override def do_GET(self) -> None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 7d7b960..d25dcd2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -36,6 +36,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC): code: int, message: str | None = None, headers: dict[str, str] | None = None, + content_length: int = 0, ) -> typing.Iterator[None]: if headers is None: headers = {} @@ -46,7 +47,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC): send_response_mock.assert_called_once_with(code, message) send_header_mock.assert_has_calls( [ - unittest.mock.call("Content-Length", "0"), + unittest.mock.call("Content-Length", str(content_length)), ] + [unittest.mock.call(header, value) for header, value in headers.items()], any_order=True, @@ -192,9 +193,21 @@ class TestRequestHandler(BaseHandlerTestCase): token_manager=self.token_manager, ) - def test_do_head_forward(self) -> None: + def test_do_head_index(self) -> None: handler = self._get_handler() with ( + self.expects_status_only( + handler, 200, content_length=len(handler.server_signature()) + ), + self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD", count=0), + self.seal_mocks(), + ): + handler.do_HEAD() + + def test_do_head_forward(self) -> None: + handler = self._get_handler("/file") + with ( + self.mock_call(self.registry.get_from_path, ["file"], Page("file")), self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD"), self.seal_mocks(), ):