diff --git a/stapler/cert_manager.py b/stapler/cert_manager.py index 647bd31..2b8525b 100644 --- a/stapler/cert_manager.py +++ b/stapler/cert_manager.py @@ -200,6 +200,7 @@ class CertManager: cert_file, key_file, ) + new_context.set_alpn_protocols(["http/1.1"]) socket.context = new_context except Exception: self.logger.exception("Could not create HTTPS context for %s", host) diff --git a/stapler/handlers.py b/stapler/handlers.py index d7bb5cc..f802ab6 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -27,6 +27,7 @@ if typing.TYPE_CHECKING: class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): timeout = 10 + protocol_version = "HTTP/1.1" REQUEST_COUNT = 0 @typing.override @@ -111,6 +112,8 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): @typing.override def log_request(self, code: str = "?", size: str = "-") -> None: # ty:ignore[invalid-method-override] # pragma: no cover if isinstance(code, http.HTTPStatus): + code = code.value + if isinstance(code, int): color = logs.TermColor.RED if 100 <= code < 200: color = logs.TermColor.CYAN @@ -120,7 +123,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): color = logs.TermColor.BLUE elif 400 <= code < 500: color = logs.TermColor.YELLOW - code = color + str(code.value) + logs.TermColor.RESET + code = color + str(code) + logs.TermColor.RESET if size == "" and self.out_size > 0: size = str(self.out_size) args = ( @@ -193,6 +196,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): headers=headers, allow_redirects=False, timeout=480, + stream=False, ) except Exception as e: self.send_error( @@ -385,6 +389,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self._pre_log_request() if not self._proxy_or_redirect(): super().do_HEAD() + self.close_connection = True @typing.override def do_GET(self) -> None: @@ -394,7 +399,9 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): return None if self.path == "/" and self.host == self.default_host: return self.send_basic_body(self.server_signature()) - return super().do_GET() + super().do_GET() + self.close_connection = True + return None def do_PUT(self) -> None: with self.handle_errors(): @@ -620,16 +627,19 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): class UpgradeHandler(RequestHandler): + protocol_version = "HTTP/1.0" server_version = "StaplerUpgradeServer/" + PKG_VERSION def do_HEAD(self) -> None: with self.handle_errors(): self._pre_log_request() self.send_redirect(f"https://{self.host}{self.path}") + self.close_connection = True def do_GET(self) -> None: with self.handle_errors(): if self.path.startswith(self.CERTBOT_CHALLENGE_PATH): super().do_GET() + self.close_connection = True else: self.do_HEAD() diff --git a/stapler/server.py b/stapler/server.py index 7400913..da5baf2 100644 --- a/stapler/server.py +++ b/stapler/server.py @@ -29,7 +29,6 @@ class StaplerServer: "logger", "params", "registry", - "server", "token_manager", ] @@ -41,7 +40,6 @@ class StaplerServer: self.token_manager: TokenManager = TokenManager(params, self.registry) self.data_dir: DataDir = DataDir(params.data_dir) self.default_host: str = params.host.split(":", maxsplit=2)[0] - self.server: http.server.ThreadingHTTPServer | None = None def __get_all_hosts(self) -> list[str]: return [self.default_host, *self.registry.get_hosts()] @@ -131,7 +129,7 @@ class StaplerServer: for line in STAPLER_ASCII.split("\n"): self.logger.debug(line.ljust(36)) self.__startup() - self.server = self.__create_base_server() + base_server = self.__create_base_server() upgrade_server = self.__start_upgrade_server() if self.params.https else None self.logger.info( "Server up and ready on %s://%s", @@ -140,7 +138,7 @@ class StaplerServer: ) self.__start_background_tasks() with contextlib.suppress(KeyboardInterrupt): - self.server.serve_forever() + base_server.serve_forever() self.logger.info("Shutting down...") if upgrade_server is not None: upgrade_server.shutdown() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 038c0d8..7d7b960 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -859,6 +859,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ), self.expects_status_only(handler, 200, "OK"), @@ -903,6 +904,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ), self.expects_status_only(handler, 200, "OK"), @@ -945,6 +947,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ), self.expects_basic_body(handler, "hello", message="OK"), @@ -979,6 +982,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ) as request_mock, self.expects_status_only( @@ -1022,6 +1026,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ), self.expects_status_only(handler, 200, "OK"), @@ -1062,6 +1067,7 @@ class TestRequestHandler(BaseHandlerTestCase): }, "allow_redirects": False, "timeout": 480, + "stream": False, }, ), self.expects_status_only(handler, 200, "OK"),