diff --git a/src/cert_manager.py b/src/cert_manager.py index 7128c77..d3eb795 100644 --- a/src/cert_manager.py +++ b/src/cert_manager.py @@ -6,7 +6,7 @@ import subprocess import typing if typing.TYPE_CHECKING: - from . import params + from .params import Parameters class CertManagerError(Exception): @@ -14,16 +14,24 @@ class CertManagerError(Exception): class CertManager: + __slots__ = [ + "certbot_conf", + "certbot_www", + "logger", + "self_signed_path", + "with_certbot", + ] + SELF_SIGNED_DAYS = 30 CRT_FILE = "fullchain.pem" KEY_FILE = "privkey.pem" - def __init__(self, params: params.Parameters) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.certbot_conf = pathlib.Path(params.certbot_conf) - self.certbot_www = pathlib.Path(params.certbot_www) - self.self_signed_path = pathlib.Path(params.self_signed_path) - self.with_certbot = params.with_certbot + def __init__(self, params: Parameters) -> None: + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.certbot_conf: pathlib.Path = pathlib.Path(params.certbot_conf) + self.certbot_www: pathlib.Path = pathlib.Path(params.certbot_www) + self.self_signed_path: pathlib.Path = pathlib.Path(params.self_signed_path) + self.with_certbot: bool = params.with_certbot def init(self, hosts: list[str]) -> None: self.logger.debug("Initializing...") diff --git a/src/data_dir.py b/src/data_dir.py index f29253b..37a28d0 100644 --- a/src/data_dir.py +++ b/src/data_dir.py @@ -10,12 +10,17 @@ if typing.TYPE_CHECKING: class DataDir: + __slots__ = [ + "logger", + "root_path", + ] + PATH_REGEX = re.compile(r"^[\w-]+$") NEEDED_FILES: typing.ClassVar[list[str]] = ["favicon.ico"] def __init__(self, root_path: str) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.root_path = pathlib.Path(root_path) + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.root_path: pathlib.Path = pathlib.Path(root_path) def init(self) -> None: self.logger.debug("Initializing...") diff --git a/src/handlers.py b/src/handlers.py index 4fcb9e2..f989942 100644 --- a/src/handlers.py +++ b/src/handlers.py @@ -9,10 +9,14 @@ import re import tarfile import typing -from . import STAPLER_ASCII, data_dir, logs, project +from . import STAPLER_ASCII, logs, project +from .data_dir import DataDir if typing.TYPE_CHECKING: - from . import cert_manager, params, registry, token_manager + from .cert_manager import CertManager + from .params import Parameters + from .registry import Registry + from .token_manager import TokenManager class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): @@ -20,12 +24,14 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): def __init__( self, *args: typing.Any, - params: params.Parameters, + params: Parameters, **kwargs: dict[str, typing.Any], ) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.default_host = params.host.split(":", maxsplit=2)[0] - self.out_size = 0 + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.default_host: str = params.host.split(":", maxsplit=2)[0] + self.out_size: int = 0 + self._host: str | None = None + self._in_size: int | None = None super().__init__(*args, **kwargs) @typing.override @@ -74,7 +80,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): code = color + str(code.value) + logs.TermColor.RESET if size == "" and self.out_size > 0: size = str(self.out_size) - args = (code, self.address_string(), self._get_host(), self.requestline) + args = (code, self.address_string(), self.host, self.requestline) fmt = "→ %s - %s - %s - %s" if size != "": args = (*args, size) @@ -117,10 +123,22 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): self.end_headers() self.close_connection = True + @property + def host(self) -> str: + if self._host is None: + self._host = self._get_host() + return self._host + def _get_host(self) -> str: host = self._get_header("Host", self.default_host) return host.split(":", maxsplit=2)[0] + @property + def in_size(self) -> int: + if self._in_size is None: + self._in_size = self._get_length() + return self._in_size + def _get_length(self) -> int: return int(self._get_header("Content-Length", "0")) @@ -137,10 +155,10 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): ) def _pre_log_request(self) -> None: # pragma: no cover - args = ("...", self.address_string(), self._get_host(), self.requestline) + args = ("...", self.address_string(), self.host, self.requestline) fmt = "← %s - %s - %s - %s" - if (size := self._get_length()) > 0: - args = (*args, size) + if self.in_size > 0: + args = (*args, self.in_size) fmt += " - %s" self.logger.debug(fmt, *args) @@ -163,21 +181,39 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def __init__( self, *args: typing.Any, - params: params.Parameters, - registry: registry.Registry, - cert_manager: cert_manager.CertManager, - token_manager: token_manager.TokenManager, + params: Parameters, + registry: Registry, + cert_manager: CertManager, + token_manager: TokenManager, **kwargs: dict[str, typing.Any], ) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.token_manager = token_manager - self.data_dir = data_dir.DataDir(params.data_dir) - self.max_size_bytes = params.max_size_bytes - self.registry = registry - self.cert_manager = cert_manager - self.certbot_www = os.path.realpath(params.certbot_www) + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.token_manager: TokenManager = token_manager + self.data_dir: DataDir = DataDir(params.data_dir) + self.max_size_bytes: int = params.max_size_bytes + self.registry: Registry = registry + self.cert_manager: CertManager = cert_manager + self.certbot_www: str = os.path.realpath(params.certbot_www) + self._token: str | None = None + self._target_host: str | None = None super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] + @property + def token(self) -> str: + if self._token is None: + self._token = self._get_header(self.TOKEN_HEADER) + return self._token + + @property + def target_host(self) -> str: + if self._target_host is None: + self._target_host = self._get_header(self.HOST_HEADER).lower() + return self._target_host + + @property + def has_target_host(self) -> bool: + return len(self.target_host) > 0 + @typing.override def do_HEAD(self) -> None: self._pre_log_request() @@ -186,7 +222,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): @typing.override def do_GET(self) -> None: self._pre_log_request() - if self.path == "/" and self._get_host() == self.default_host: + if self.path == "/" and self.host == self.default_host: return self.send_basic_body(self.server_signature()) super().do_GET() return None @@ -195,30 +231,25 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self._pre_log_request() if (sub_path := self.__check_update_request()) is None: return None - host: str | None = ( - self._get_header(self.HOST_HEADER).lower() - if self._has_header(self.HOST_HEADER) - else None - ) - if host is not None and not self.__valid_host(host): + if self.has_target_host and not self.__valid_host(self.target_host): return self.send_error( http.HTTPStatus.BAD_REQUEST, "Invalid requested host" ) if ( - host is not None - and (page := self.registry.get_from_host(host)) is not None + self.has_target_host + and (page := self.registry.get_from_host(self.target_host)) is not None and page.path != sub_path ): return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") - if (content_length := self._get_length()) == 0: + if self.in_size == 0: return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") - if content_length > self.max_size_bytes: + if self.in_size > self.max_size_bytes: return self.send_error( http.HTTPStatus.CONTENT_TOO_LARGE, "Archive too large", ) try: - file_bytes = io.BytesIO(self.rfile.read(content_length)) + file_bytes = io.BytesIO(self.rfile.read(self.in_size)) self.data_dir.extract_tar_bytes(sub_path, file_bytes) except tarfile.TarError: return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") @@ -229,9 +260,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): f"Resource /{sub_path}/ updated", ) self.registry.add(sub_path) - self.token_manager.set_token(self._get_header(self.TOKEN_HEADER), sub_path) - if host is not None and self.cert_manager.create_or_update(host): - self.registry.set_host(sub_path, host) + self.token_manager.set_token(self.token, sub_path) + if self.has_target_host and self.cert_manager.create_or_update( + self.target_host + ): + self.registry.set_host(sub_path, self.target_host) return None def do_DELETE(self) -> None: @@ -261,14 +294,12 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): def translate_path(self, path: str) -> str: if path.startswith(self.CERTBOT_CHALLENGE_PATH): return self.certbot_www + path - host = self._get_host() if ( - host != self.default_host - and (page := self.registry.get_from_host(host := self._get_host())) - is not None + self.host != self.default_host + and (page := self.registry.get_from_host(self.host)) is not None ): path = f"/{page.path}" + path - elif host != self.default_host: + elif self.host != self.default_host: return "" elif ( path not in self.AUTHORIZED_PATHS @@ -283,14 +314,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): if not self._has_header(self.TOKEN_HEADER): self.send_error(http.HTTPStatus.BAD_REQUEST, "No X-Token header in request") return None - token = self._get_header(self.TOKEN_HEADER) - if not self.token_manager.is_valid(token): + if not self.token_manager.is_valid(self.token): self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") return None if (sub_path := self.__get_subpath(self.path, self.UPDATE_PATH_REGEX)) is None: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") return None - if not self.token_manager.is_valid_for_path(token, sub_path): + if not self.token_manager.is_valid_for_path(self.token, sub_path): self.send_error(http.HTTPStatus.FORBIDDEN, "Path forbidden for this token") return None return sub_path @@ -314,7 +344,7 @@ class UpgradeHandler(BaseHandler): self._pre_log_request() self.send_status_only( http.HTTPStatus.MOVED_PERMANENTLY, - headers={"Location": f"https://{self._get_host()}{self.path.lower()}"}, + headers={"Location": f"https://{self.host}{self.path.lower()}"}, ) def do_GET(self) -> None: diff --git a/src/logs.py b/src/logs.py index 32a54c5..2ce841c 100644 --- a/src/logs.py +++ b/src/logs.py @@ -3,7 +3,7 @@ import logging import typing if typing.TYPE_CHECKING: - from . import params + from .params import Parameters class TermColor(enum.StrEnum): @@ -54,7 +54,7 @@ class ColoredLoggingFormatter(logging.Formatter): @typing.override def __init__(self, trace: bool) -> None: - self.trace = trace + self.trace: bool = trace super().__init__() @typing.override @@ -75,7 +75,7 @@ class ColoredLoggingFormatter(logging.Formatter): return formatter.format(record) -def setup_logs(params: params.Parameters) -> None: +def setup_logs(params: Parameters) -> None: stream_handler = logging.StreamHandler() stream_handler.setFormatter(ColoredLoggingFormatter(trace=params.debug)) log_level = logging.INFO diff --git a/src/page.py b/src/page.py index 7bc397e..d648229 100644 --- a/src/page.py +++ b/src/page.py @@ -1,7 +1,7 @@ import dataclasses -@dataclasses.dataclass +@dataclasses.dataclass(slots=True) class Page: path: str with_index: bool = False diff --git a/src/params.py b/src/params.py index a9c6688..531a50d 100644 --- a/src/params.py +++ b/src/params.py @@ -8,7 +8,7 @@ from . import project __EPILOG = "(Each option can be supplied with equivalent environment variable.)" -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class Parameters: debug: bool = False data_dir: str = "./data" diff --git a/src/registry.py b/src/registry.py index b49a8c0..2ae04b9 100644 --- a/src/registry.py +++ b/src/registry.py @@ -1,20 +1,27 @@ import logging import typing -from . import data_dir, page +from .data_dir import DataDir +from .page import Page if typing.TYPE_CHECKING: - from . import params + from .params import Parameters class Registry: + __slots__ = [ + "data_dir", + "logger", + "pages", + ] + HOST_FILE = ".host" TOKEN_FILE = ".token" # noqa: S105 - def __init__(self, params: params.Parameters) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.pages: dict[str, page.Page] = {} - self.data_dir = data_dir.DataDir(params.data_dir) + def __init__(self, params: Parameters) -> None: + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.pages: dict[str, Page] = {} + self.data_dir = DataDir(params.data_dir) def load_pages(self) -> None: self.pages = {} @@ -25,7 +32,7 @@ class Registry: return [p.host for p in self.pages.values() if p.host is not None] def add(self, path: str) -> None: - self.pages[path] = page.Page( + self.pages[path] = Page( path, self.data_dir.has_index(path), self.data_dir.get_file(path, self.HOST_FILE), @@ -51,12 +58,12 @@ class Registry: del self.pages[path] self.logger.info("Removed %s", page) - def get_from_path(self, path: str) -> page.Page | None: + def get_from_path(self, path: str) -> Page | None: if path in self.pages: return self.pages[path] return None - def get_from_host(self, host: str) -> page.Page | None: + def get_from_host(self, host: str) -> Page | None: for p in self.pages.values(): if p.host == host: return p diff --git a/src/server.py b/src/server.py index eecb27e..e532f60 100644 --- a/src/server.py +++ b/src/server.py @@ -6,27 +6,38 @@ import typing from . import ( STAPLER_ASCII, - cert_manager, - data_dir, - handlers, project, - registry, - token_manager, ) +from .cert_manager import CertManager +from .data_dir import DataDir +from .handlers import RequestHandler, UpgradeHandler +from .params import Parameters +from .registry import Registry +from .token_manager import TokenManager if typing.TYPE_CHECKING: - from . import params + from .params import Parameters class StaplerServer: - def __init__(self, params: params.Parameters) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.params = params - self.registry = registry.Registry(params) - self.cert_manager = cert_manager.CertManager(params) - self.token_manager = token_manager.TokenManager(params, self.registry) - self.data_dir = data_dir.DataDir(params.data_dir) - self.default_host = params.host.split(":", maxsplit=2)[0] + __slots__ = [ + "cert_manager", + "data_dir", + "default_host", + "logger", + "params", + "registry", + "token_manager", + ] + + def __init__(self, params: Parameters) -> None: + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.params: Parameters = params + self.registry: Registry = Registry(params) + self.cert_manager: CertManager = CertManager(params) + self.token_manager: TokenManager = TokenManager(params, self.registry) + self.data_dir: DataDir = DataDir(params.data_dir) + self.default_host: str = params.host.split(":", maxsplit=2)[0] def __get_all_hosts(self) -> list[str]: return [self.default_host, *self.registry.get_hosts()] @@ -42,7 +53,7 @@ class StaplerServer: def __request_handler( # pragma: no cover self, *args: typing.Any ) -> http.server.BaseHTTPRequestHandler: - return handlers.RequestHandler( + return RequestHandler( *args, params=self.params, registry=self.registry, @@ -83,7 +94,7 @@ class StaplerServer: def __upgrade_handler( # pragma: no cover self, *args: typing.Any ) -> http.server.BaseHTTPRequestHandler: - return handlers.UpgradeHandler( + return UpgradeHandler( *args, params=self.params, ) diff --git a/src/token_manager.py b/src/token_manager.py index e388b5c..9c52671 100644 --- a/src/token_manager.py +++ b/src/token_manager.py @@ -7,17 +7,26 @@ import typing from . import project if typing.TYPE_CHECKING: - from . import params, registry + from .params import Parameters + from .registry import Registry class TokenManager: + __slots__ = [ + "logger", + "registry", + "token_hashes", + "token_salt", + "tokens_file", + ] + FILE = ".tokens" - def __init__(self, params: params.Parameters, registry: registry.Registry) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self.token_salt = params.token_salt - self.tokens_file = pathlib.Path(params.data_dir) / self.FILE - self.registry = registry + def __init__(self, params: Parameters, registry: Registry) -> None: + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.token_salt: str = params.token_salt + self.tokens_file: pathlib.Path = pathlib.Path(params.data_dir) / self.FILE + self.registry: Registry = registry self.token_hashes: list[str] = [] def init(self) -> None: diff --git a/tests/__init__.py b/tests/__init__.py index ed71692..6aaa51f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,7 +11,7 @@ class BaseTestCase(unittest.TestCase): def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: self.mocks: list[unittest.mock.Mock] = [] self.tmp_dir: tempfile.TemporaryDirectory | None = None - self.tmp_path = pathlib.Path() + self.tmp_path: pathlib.Path = pathlib.Path() super().__init__(*args, **kwargs) @typing.override