diff --git a/stapler/cert_manager.py b/stapler/cert_manager.py index cd21372..c9adb72 100644 --- a/stapler/cert_manager.py +++ b/stapler/cert_manager.py @@ -122,7 +122,9 @@ class CertManager: ) self.logger.info("Created self-signed certificate for %s", host) except CertManagerError: - self.logger.exception("Could not create certbot certificate for %s\n%s") + self.logger.exception( + "Could not create self-signed certificate for %s", host + ) return False except subprocess.CalledProcessError as e: self.logger.exception( @@ -172,7 +174,7 @@ class CertManager: ) self.logger.info("Created certbot certificate for %s", host) except CertManagerError: - self.logger.exception("Could not create certbot certificate for %s\n%s") + self.logger.exception("Could not create certbot certificate for %s", host) return False except subprocess.CalledProcessError as e: self.logger.exception( diff --git a/stapler/handlers.py b/stapler/handlers.py index 7c78f47..1add5f1 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -1,4 +1,5 @@ import abc +import contextlib import http import http.cookiejar import http.server @@ -223,6 +224,14 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): def server_signature(self) -> str: return self.server_version + "\n\n" + STAPLER_ASCII + "\n" + @contextlib.contextmanager + def handle_errors(self) -> typing.Iterator[None]: + try: + yield + except Exception as e: + self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) + self.logger.exception("Internal Server Error") + class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): protocol_version = "HTTP/1.1" @@ -335,42 +344,45 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): @typing.override def do_HEAD(self) -> None: - self._pre_log_request() - if not self._proxy_or_redirect(): - super().do_HEAD() + with self.handle_errors(): + self._pre_log_request() + if not self._proxy_or_redirect(): + super().do_HEAD() @typing.override def do_GET(self) -> None: - self._pre_log_request() - if self._proxy_or_redirect(): - return None - if self.path == "/" and self.host == self.default_host: - return self.send_basic_body(self.server_signature()) - return super().do_GET() + with self.handle_errors(): + self._pre_log_request() + if self._proxy_or_redirect(): + return None + if self.path == "/" and self.host == self.default_host: + return self.send_basic_body(self.server_signature()) + return super().do_GET() def do_PUT(self) -> None: - self._pre_log_request() - if self._proxy_or_redirect(): - return - if (path := self.__check_put_request()) is None: - return - if self.has_target_redirect: - if not self._update_redirect(path): + with self.handle_errors(): + self._pre_log_request() + if self._proxy_or_redirect(): return - elif self.has_target_proxy: - if not self._update_proxy(path): + if (path := self.__check_put_request()) is None: return - elif not self._update_extract(path): - return - if self.has_request_host: - self.registry.set_host(path, self.target_host) - if self.has_request_host_only: - self.registry.set_host_only(path, self.target_host) - self.send_status( - http.HTTPStatus.CREATED, - "Resource updated", - str(self.registry.get_from_path(path)), - ) + if self.has_target_redirect: + if not self._update_redirect(path): + return + elif self.has_target_proxy: + if not self._update_proxy(path): + return + elif not self._update_extract(path): + return + if self.has_request_host: + self.registry.set_host(path, self.target_host) + if self.has_request_host_only: + self.registry.set_host_only(path, self.target_host) + self.send_status( + http.HTTPStatus.CREATED, + "Resource updated", + str(self.registry.get_from_path(path)), + ) def do_POST(self) -> None: self.do_PUT() # be gentle on them @@ -379,32 +391,36 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.do_PUT() # be gentle on them def do_DELETE(self) -> None: - self._pre_log_request() - if self._proxy_or_redirect(): + with self.handle_errors(): + self._pre_log_request() + if self._proxy_or_redirect(): + return + if (path := self.__check_update_request()) is None: + return + if self._update_remove(path): + self.send_status( + http.HTTPStatus.OK, + f"Resource /{path}/ removed", + ) return - if (path := self.__check_update_request()) is None: - return - if self._update_remove(path): - self.send_status( - http.HTTPStatus.OK, - f"Resource /{path}/ removed", - ) - return def do_CONNECT(self) -> None: - self._pre_log_request() - if not self._proxy_or_redirect(): - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + with self.handle_errors(): + self._pre_log_request() + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def do_OPTIONS(self) -> None: - self._pre_log_request() - if not self._proxy_or_redirect(): - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + with self.handle_errors(): + self._pre_log_request() + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def do_TRACE(self) -> None: - self._pre_log_request() - if not self._proxy_or_redirect(): - self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) + with self.handle_errors(): + self._pre_log_request() + if not self._proxy_or_redirect(): + self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) def _update_extract(self, path: str) -> bool: if self.in_size == 0: @@ -422,9 +438,6 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): except tarfile.TarError: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") return False - except Exception as e: - self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) - return False self.registry.add(path) self.token_manager.set_token(path, self.token) if self.has_target_spa: @@ -457,11 +470,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): if not self.data_dir.exists(path): self.send_error(http.HTTPStatus.NOT_FOUND, "Not found") return False - try: - self.data_dir.remove(path) - except Exception as e: - self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) - return False + self.data_dir.remove(path) self.registry.remove(path) return True @@ -578,11 +587,13 @@ class UpgradeHandler(RequestHandler): server_version = "StaplerUpgradeServer/" + PKG_VERSION def do_HEAD(self) -> None: - self._pre_log_request() - self.send_redirect(f"https://{self.host}{self.path}") + with self.handle_errors(): + self._pre_log_request() + self.send_redirect(f"https://{self.host}{self.path}") def do_GET(self) -> None: - if self.path.startswith(self.CERTBOT_CHALLENGE_PATH): - super().do_GET() - else: - self.do_HEAD() + with self.handle_errors(): + if self.path.startswith(self.CERTBOT_CHALLENGE_PATH): + super().do_GET() + else: + self.do_HEAD()