refactor: use slots and strongly typed properties

This commit is contained in:
2026-04-20 10:48:58 +02:00
parent fc7d3cb0e8
commit 5fb10ffb9d
10 changed files with 161 additions and 91 deletions
+15 -7
View File
@@ -6,7 +6,7 @@ import subprocess
import typing import typing
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import params from .params import Parameters
class CertManagerError(Exception): class CertManagerError(Exception):
@@ -14,16 +14,24 @@ class CertManagerError(Exception):
class CertManager: class CertManager:
__slots__ = [
"certbot_conf",
"certbot_www",
"logger",
"self_signed_path",
"with_certbot",
]
SELF_SIGNED_DAYS = 30 SELF_SIGNED_DAYS = 30
CRT_FILE = "fullchain.pem" CRT_FILE = "fullchain.pem"
KEY_FILE = "privkey.pem" KEY_FILE = "privkey.pem"
def __init__(self, params: params.Parameters) -> None: def __init__(self, params: Parameters) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.certbot_conf = pathlib.Path(params.certbot_conf) self.certbot_conf: pathlib.Path = pathlib.Path(params.certbot_conf)
self.certbot_www = pathlib.Path(params.certbot_www) self.certbot_www: pathlib.Path = pathlib.Path(params.certbot_www)
self.self_signed_path = pathlib.Path(params.self_signed_path) self.self_signed_path: pathlib.Path = pathlib.Path(params.self_signed_path)
self.with_certbot = params.with_certbot self.with_certbot: bool = params.with_certbot
def init(self, hosts: list[str]) -> None: def init(self, hosts: list[str]) -> None:
self.logger.debug("Initializing...") self.logger.debug("Initializing...")
+7 -2
View File
@@ -10,12 +10,17 @@ if typing.TYPE_CHECKING:
class DataDir: class DataDir:
__slots__ = [
"logger",
"root_path",
]
PATH_REGEX = re.compile(r"^[\w-]+$") PATH_REGEX = re.compile(r"^[\w-]+$")
NEEDED_FILES: typing.ClassVar[list[str]] = ["favicon.ico"] NEEDED_FILES: typing.ClassVar[list[str]] = ["favicon.ico"]
def __init__(self, root_path: str) -> None: def __init__(self, root_path: str) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.root_path = pathlib.Path(root_path) self.root_path: pathlib.Path = pathlib.Path(root_path)
def init(self) -> None: def init(self) -> None:
self.logger.debug("Initializing...") self.logger.debug("Initializing...")
+75 -45
View File
@@ -9,10 +9,14 @@ import re
import tarfile import tarfile
import typing import typing
from . import STAPLER_ASCII, data_dir, logs, project from . import STAPLER_ASCII, logs, project
from .data_dir import DataDir
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import cert_manager, params, registry, token_manager from .cert_manager import CertManager
from .params import Parameters
from .registry import Registry
from .token_manager import TokenManager
class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
@@ -20,12 +24,14 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
def __init__( def __init__(
self, self,
*args: typing.Any, *args: typing.Any,
params: params.Parameters, params: Parameters,
**kwargs: dict[str, typing.Any], **kwargs: dict[str, typing.Any],
) -> None: ) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.default_host = params.host.split(":", maxsplit=2)[0] self.default_host: str = params.host.split(":", maxsplit=2)[0]
self.out_size = 0 self.out_size: int = 0
self._host: str | None = None
self._in_size: int | None = None
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@typing.override @typing.override
@@ -74,7 +80,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
code = color + str(code.value) + logs.TermColor.RESET code = color + str(code.value) + logs.TermColor.RESET
if size == "" and self.out_size > 0: if size == "" and self.out_size > 0:
size = str(self.out_size) size = str(self.out_size)
args = (code, self.address_string(), self._get_host(), self.requestline) args = (code, self.address_string(), self.host, self.requestline)
fmt = "%s - %s - %s - %s" fmt = "%s - %s - %s - %s"
if size != "": if size != "":
args = (*args, size) args = (*args, size)
@@ -117,10 +123,22 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.close_connection = True self.close_connection = True
@property
def host(self) -> str:
if self._host is None:
self._host = self._get_host()
return self._host
def _get_host(self) -> str: def _get_host(self) -> str:
host = self._get_header("Host", self.default_host) host = self._get_header("Host", self.default_host)
return host.split(":", maxsplit=2)[0] return host.split(":", maxsplit=2)[0]
@property
def in_size(self) -> int:
if self._in_size is None:
self._in_size = self._get_length()
return self._in_size
def _get_length(self) -> int: def _get_length(self) -> int:
return int(self._get_header("Content-Length", "0")) return int(self._get_header("Content-Length", "0"))
@@ -137,10 +155,10 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
) )
def _pre_log_request(self) -> None: # pragma: no cover def _pre_log_request(self) -> None: # pragma: no cover
args = ("...", self.address_string(), self._get_host(), self.requestline) args = ("...", self.address_string(), self.host, self.requestline)
fmt = "%s - %s - %s - %s" fmt = "%s - %s - %s - %s"
if (size := self._get_length()) > 0: if self.in_size > 0:
args = (*args, size) args = (*args, self.in_size)
fmt += " - %s" fmt += " - %s"
self.logger.debug(fmt, *args) self.logger.debug(fmt, *args)
@@ -163,21 +181,39 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
def __init__( def __init__(
self, self,
*args: typing.Any, *args: typing.Any,
params: params.Parameters, params: Parameters,
registry: registry.Registry, registry: Registry,
cert_manager: cert_manager.CertManager, cert_manager: CertManager,
token_manager: token_manager.TokenManager, token_manager: TokenManager,
**kwargs: dict[str, typing.Any], **kwargs: dict[str, typing.Any],
) -> None: ) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.token_manager = token_manager self.token_manager: TokenManager = token_manager
self.data_dir = data_dir.DataDir(params.data_dir) self.data_dir: DataDir = DataDir(params.data_dir)
self.max_size_bytes = params.max_size_bytes self.max_size_bytes: int = params.max_size_bytes
self.registry = registry self.registry: Registry = registry
self.cert_manager = cert_manager self.cert_manager: CertManager = cert_manager
self.certbot_www = os.path.realpath(params.certbot_www) self.certbot_www: str = os.path.realpath(params.certbot_www)
self._token: str | None = None
self._target_host: str | None = None
super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument]
@property
def token(self) -> str:
if self._token is None:
self._token = self._get_header(self.TOKEN_HEADER)
return self._token
@property
def target_host(self) -> str:
if self._target_host is None:
self._target_host = self._get_header(self.HOST_HEADER).lower()
return self._target_host
@property
def has_target_host(self) -> bool:
return len(self.target_host) > 0
@typing.override @typing.override
def do_HEAD(self) -> None: def do_HEAD(self) -> None:
self._pre_log_request() self._pre_log_request()
@@ -186,7 +222,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
@typing.override @typing.override
def do_GET(self) -> None: def do_GET(self) -> None:
self._pre_log_request() self._pre_log_request()
if self.path == "/" and self._get_host() == self.default_host: if self.path == "/" and self.host == self.default_host:
return self.send_basic_body(self.server_signature()) return self.send_basic_body(self.server_signature())
super().do_GET() super().do_GET()
return None return None
@@ -195,30 +231,25 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
self._pre_log_request() self._pre_log_request()
if (sub_path := self.__check_update_request()) is None: if (sub_path := self.__check_update_request()) is None:
return None return None
host: str | None = ( if self.has_target_host and not self.__valid_host(self.target_host):
self._get_header(self.HOST_HEADER).lower()
if self._has_header(self.HOST_HEADER)
else None
)
if host is not None and not self.__valid_host(host):
return self.send_error( return self.send_error(
http.HTTPStatus.BAD_REQUEST, "Invalid requested host" http.HTTPStatus.BAD_REQUEST, "Invalid requested host"
) )
if ( if (
host is not None self.has_target_host
and (page := self.registry.get_from_host(host)) is not None and (page := self.registry.get_from_host(self.target_host)) is not None
and page.path != sub_path and page.path != sub_path
): ):
return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken")
if (content_length := self._get_length()) == 0: if self.in_size == 0:
return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
if content_length > self.max_size_bytes: if self.in_size > self.max_size_bytes:
return self.send_error( return self.send_error(
http.HTTPStatus.CONTENT_TOO_LARGE, http.HTTPStatus.CONTENT_TOO_LARGE,
"Archive too large", "Archive too large",
) )
try: try:
file_bytes = io.BytesIO(self.rfile.read(content_length)) file_bytes = io.BytesIO(self.rfile.read(self.in_size))
self.data_dir.extract_tar_bytes(sub_path, file_bytes) self.data_dir.extract_tar_bytes(sub_path, file_bytes)
except tarfile.TarError: except tarfile.TarError:
return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive")
@@ -229,9 +260,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
f"Resource /{sub_path}/ updated", f"Resource /{sub_path}/ updated",
) )
self.registry.add(sub_path) self.registry.add(sub_path)
self.token_manager.set_token(self._get_header(self.TOKEN_HEADER), sub_path) self.token_manager.set_token(self.token, sub_path)
if host is not None and self.cert_manager.create_or_update(host): if self.has_target_host and self.cert_manager.create_or_update(
self.registry.set_host(sub_path, host) self.target_host
):
self.registry.set_host(sub_path, self.target_host)
return None return None
def do_DELETE(self) -> None: def do_DELETE(self) -> None:
@@ -261,14 +294,12 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
def translate_path(self, path: str) -> str: def translate_path(self, path: str) -> str:
if path.startswith(self.CERTBOT_CHALLENGE_PATH): if path.startswith(self.CERTBOT_CHALLENGE_PATH):
return self.certbot_www + path return self.certbot_www + path
host = self._get_host()
if ( if (
host != self.default_host self.host != self.default_host
and (page := self.registry.get_from_host(host := self._get_host())) and (page := self.registry.get_from_host(self.host)) is not None
is not None
): ):
path = f"/{page.path}" + path path = f"/{page.path}" + path
elif host != self.default_host: elif self.host != self.default_host:
return "" return ""
elif ( elif (
path not in self.AUTHORIZED_PATHS path not in self.AUTHORIZED_PATHS
@@ -283,14 +314,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
if not self._has_header(self.TOKEN_HEADER): if not self._has_header(self.TOKEN_HEADER):
self.send_error(http.HTTPStatus.BAD_REQUEST, "No X-Token header in request") self.send_error(http.HTTPStatus.BAD_REQUEST, "No X-Token header in request")
return None return None
token = self._get_header(self.TOKEN_HEADER) if not self.token_manager.is_valid(self.token):
if not self.token_manager.is_valid(token):
self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
return None return None
if (sub_path := self.__get_subpath(self.path, self.UPDATE_PATH_REGEX)) is None: if (sub_path := self.__get_subpath(self.path, self.UPDATE_PATH_REGEX)) is None:
self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
return None return None
if not self.token_manager.is_valid_for_path(token, sub_path): if not self.token_manager.is_valid_for_path(self.token, sub_path):
self.send_error(http.HTTPStatus.FORBIDDEN, "Path forbidden for this token") self.send_error(http.HTTPStatus.FORBIDDEN, "Path forbidden for this token")
return None return None
return sub_path return sub_path
@@ -314,7 +344,7 @@ class UpgradeHandler(BaseHandler):
self._pre_log_request() self._pre_log_request()
self.send_status_only( self.send_status_only(
http.HTTPStatus.MOVED_PERMANENTLY, http.HTTPStatus.MOVED_PERMANENTLY,
headers={"Location": f"https://{self._get_host()}{self.path.lower()}"}, headers={"Location": f"https://{self.host}{self.path.lower()}"},
) )
def do_GET(self) -> None: def do_GET(self) -> None:
+3 -3
View File
@@ -3,7 +3,7 @@ import logging
import typing import typing
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import params from .params import Parameters
class TermColor(enum.StrEnum): class TermColor(enum.StrEnum):
@@ -54,7 +54,7 @@ class ColoredLoggingFormatter(logging.Formatter):
@typing.override @typing.override
def __init__(self, trace: bool) -> None: def __init__(self, trace: bool) -> None:
self.trace = trace self.trace: bool = trace
super().__init__() super().__init__()
@typing.override @typing.override
@@ -75,7 +75,7 @@ class ColoredLoggingFormatter(logging.Formatter):
return formatter.format(record) return formatter.format(record)
def setup_logs(params: params.Parameters) -> None: def setup_logs(params: Parameters) -> None:
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setFormatter(ColoredLoggingFormatter(trace=params.debug)) stream_handler.setFormatter(ColoredLoggingFormatter(trace=params.debug))
log_level = logging.INFO log_level = logging.INFO
+1 -1
View File
@@ -1,7 +1,7 @@
import dataclasses import dataclasses
@dataclasses.dataclass @dataclasses.dataclass(slots=True)
class Page: class Page:
path: str path: str
with_index: bool = False with_index: bool = False
+1 -1
View File
@@ -8,7 +8,7 @@ from . import project
__EPILOG = "(Each option can be supplied with equivalent environment variable.)" __EPILOG = "(Each option can be supplied with equivalent environment variable.)"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True, slots=True)
class Parameters: class Parameters:
debug: bool = False debug: bool = False
data_dir: str = "./data" data_dir: str = "./data"
+16 -9
View File
@@ -1,20 +1,27 @@
import logging import logging
import typing import typing
from . import data_dir, page from .data_dir import DataDir
from .page import Page
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import params from .params import Parameters
class Registry: class Registry:
__slots__ = [
"data_dir",
"logger",
"pages",
]
HOST_FILE = ".host" HOST_FILE = ".host"
TOKEN_FILE = ".token" # noqa: S105 TOKEN_FILE = ".token" # noqa: S105
def __init__(self, params: params.Parameters) -> None: def __init__(self, params: Parameters) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.pages: dict[str, page.Page] = {} self.pages: dict[str, Page] = {}
self.data_dir = data_dir.DataDir(params.data_dir) self.data_dir = DataDir(params.data_dir)
def load_pages(self) -> None: def load_pages(self) -> None:
self.pages = {} self.pages = {}
@@ -25,7 +32,7 @@ class Registry:
return [p.host for p in self.pages.values() if p.host is not None] return [p.host for p in self.pages.values() if p.host is not None]
def add(self, path: str) -> None: def add(self, path: str) -> None:
self.pages[path] = page.Page( self.pages[path] = Page(
path, path,
self.data_dir.has_index(path), self.data_dir.has_index(path),
self.data_dir.get_file(path, self.HOST_FILE), self.data_dir.get_file(path, self.HOST_FILE),
@@ -51,12 +58,12 @@ class Registry:
del self.pages[path] del self.pages[path]
self.logger.info("Removed %s", page) self.logger.info("Removed %s", page)
def get_from_path(self, path: str) -> page.Page | None: def get_from_path(self, path: str) -> Page | None:
if path in self.pages: if path in self.pages:
return self.pages[path] return self.pages[path]
return None return None
def get_from_host(self, host: str) -> page.Page | None: def get_from_host(self, host: str) -> Page | None:
for p in self.pages.values(): for p in self.pages.values():
if p.host == host: if p.host == host:
return p return p
+27 -16
View File
@@ -6,27 +6,38 @@ import typing
from . import ( from . import (
STAPLER_ASCII, STAPLER_ASCII,
cert_manager,
data_dir,
handlers,
project, project,
registry,
token_manager,
) )
from .cert_manager import CertManager
from .data_dir import DataDir
from .handlers import RequestHandler, UpgradeHandler
from .params import Parameters
from .registry import Registry
from .token_manager import TokenManager
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import params from .params import Parameters
class StaplerServer: class StaplerServer:
def __init__(self, params: params.Parameters) -> None: __slots__ = [
self.logger = logging.getLogger(self.__class__.__name__) "cert_manager",
self.params = params "data_dir",
self.registry = registry.Registry(params) "default_host",
self.cert_manager = cert_manager.CertManager(params) "logger",
self.token_manager = token_manager.TokenManager(params, self.registry) "params",
self.data_dir = data_dir.DataDir(params.data_dir) "registry",
self.default_host = params.host.split(":", maxsplit=2)[0] "token_manager",
]
def __init__(self, params: Parameters) -> None:
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.params: Parameters = params
self.registry: Registry = Registry(params)
self.cert_manager: CertManager = CertManager(params)
self.token_manager: TokenManager = TokenManager(params, self.registry)
self.data_dir: DataDir = DataDir(params.data_dir)
self.default_host: str = params.host.split(":", maxsplit=2)[0]
def __get_all_hosts(self) -> list[str]: def __get_all_hosts(self) -> list[str]:
return [self.default_host, *self.registry.get_hosts()] return [self.default_host, *self.registry.get_hosts()]
@@ -42,7 +53,7 @@ class StaplerServer:
def __request_handler( # pragma: no cover def __request_handler( # pragma: no cover
self, *args: typing.Any self, *args: typing.Any
) -> http.server.BaseHTTPRequestHandler: ) -> http.server.BaseHTTPRequestHandler:
return handlers.RequestHandler( return RequestHandler(
*args, *args,
params=self.params, params=self.params,
registry=self.registry, registry=self.registry,
@@ -83,7 +94,7 @@ class StaplerServer:
def __upgrade_handler( # pragma: no cover def __upgrade_handler( # pragma: no cover
self, *args: typing.Any self, *args: typing.Any
) -> http.server.BaseHTTPRequestHandler: ) -> http.server.BaseHTTPRequestHandler:
return handlers.UpgradeHandler( return UpgradeHandler(
*args, *args,
params=self.params, params=self.params,
) )
+15 -6
View File
@@ -7,17 +7,26 @@ import typing
from . import project from . import project
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from . import params, registry from .params import Parameters
from .registry import Registry
class TokenManager: class TokenManager:
__slots__ = [
"logger",
"registry",
"token_hashes",
"token_salt",
"tokens_file",
]
FILE = ".tokens" FILE = ".tokens"
def __init__(self, params: params.Parameters, registry: registry.Registry) -> None: def __init__(self, params: Parameters, registry: Registry) -> None:
self.logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.token_salt = params.token_salt self.token_salt: str = params.token_salt
self.tokens_file = pathlib.Path(params.data_dir) / self.FILE self.tokens_file: pathlib.Path = pathlib.Path(params.data_dir) / self.FILE
self.registry = registry self.registry: Registry = registry
self.token_hashes: list[str] = [] self.token_hashes: list[str] = []
def init(self) -> None: def init(self) -> None:
+1 -1
View File
@@ -11,7 +11,7 @@ class BaseTestCase(unittest.TestCase):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.mocks: list[unittest.mock.Mock] = [] self.mocks: list[unittest.mock.Mock] = []
self.tmp_dir: tempfile.TemporaryDirectory | None = None self.tmp_dir: tempfile.TemporaryDirectory | None = None
self.tmp_path = pathlib.Path() self.tmp_path: pathlib.Path = pathlib.Path()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@typing.override @typing.override