diff --git a/Makefile b/Makefile index c563163..eb96694 100644 --- a/Makefile +++ b/Makefile @@ -62,7 +62,7 @@ docker-build: ## docker build .PHONY: docker-run docker-run: docker-build ## docker run - @$(DOCKER) run -it -p $(PORT):8080 -v ./data:/data $(DOCKER_TAG) --token $(TOKEN) --host localhost:$(PORT) --debug + @$(DOCKER) run -it -p $(PORT):8080 -v ./data:/data $(DOCKER_TAG) --token $(TOKEN) --debug # ACTIONS diff --git a/README.md b/README.md index 1d33700..b651aba 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ curl -X DELETE \ - [x] certbot/self-signed create/renew in specific dir - [x] better logger - [ ] renew command -- [ ] https mode w/ multiple hosts +- [x] https mode w/ multiple hosts - [ ] restart command (on new/deleted host) - [ ] proper doc - [ ] log visits (and store accross sessions) diff --git a/src/cert.py b/src/cert.py index 1a54f8f..5c0fe6b 100644 --- a/src/cert.py +++ b/src/cert.py @@ -1,6 +1,7 @@ import logging import pathlib import shutil +import ssl import subprocess import typing @@ -49,19 +50,21 @@ class CertManager: return True return created or self.__create_self_signed(host) - def get_pem(self, host: str) -> pathlib.Path | None: + def get_cert(self, host: str) -> pathlib.Path: if self.__exists_certbot(host): return self.__certbot_file(host, self.CRT_FILE) if self.__exists_self_signed(host): return self.__self_signed_file(host, self.CRT_FILE) - return None + msg = "Cannot get cert file for %s" + raise CertManagerError(msg, host) - def get_key(self, host: str) -> pathlib.Path | None: + def get_key(self, host: str) -> pathlib.Path: if self.__exists_certbot(host): return self.__certbot_file(host, self.KEY_FILE) if self.__exists_self_signed(host): return self.__self_signed_file(host, self.KEY_FILE) - return None + msg = "Cannot get key file for %s" + raise CertManagerError(msg, host) def __self_signed_file(self, host: str, file: str) -> pathlib.Path: return self.self_signed_path / host / file @@ -83,9 +86,6 @@ class CertManager: cert_path = self.self_signed_path / host if not cert_path.exists(): cert_path.mkdir(parents=True) - cert_host: str = host - if ":" in host: - cert_host = host.split(":", maxsplit=2)[0] try: # openssl req -new -newkey rsa:2048 -days 30 -nodes -x509 -keyout server.key -out server.crt subprocess.run( @@ -104,7 +104,7 @@ class CertManager: "-out", cert_path / "fullchain.pem", "-subj", - f"/C=/ST=/L=/O=/OU=/CN={cert_host}", + f"/C=/ST=/L=/O=/OU=/CN={host}", ], check=True, ) @@ -134,9 +134,6 @@ class CertManager: return binary_path def __create_certbot(self, host: str) -> bool: - cert_host: str = host - if ":" in host: - cert_host = host.split(":", maxsplit=2)[0] try: # certonly -v --webroot --webroot-path=/var/www/certbot --agree-tos --no-eff-email -n --force-renewal --expand subprocess.run( @@ -151,7 +148,7 @@ class CertManager: "--cert-name", host, "--domain", - cert_host, + host, ], check=True, ) @@ -160,3 +157,35 @@ class CertManager: self.logger.exception("Could not create certbot certificate for %s", host) return False return self.__exists_certbot(host) + + def get_https_context(self, default_host: str) -> ssl.SSLContext | None: + if not self.exists(default_host): + self.logger.warning("Cannot create HTTPS context for %s", default_host) + return None + cert_file = self.get_cert(default_host) + key_file = self.get_key(default_host) + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain( + cert_file, + key_file, + ) + context.sni_callback = self.__sni_callback + return context + + def __sni_callback( + self, socket: ssl.SSLObject, host: str, _: ssl.SSLContext, / + ) -> None | int: + if host is None: + return + if not self.exists(host) and not self.create_or_update(host): + msg = "Could not get certificate for %s" + raise CertManagerError(msg, host) + new_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + cert_file = self.get_cert(host) + key_file = self.get_key(host) + new_context.load_cert_chain( + cert_file, + key_file, + ) + socket.context = new_context + return None diff --git a/src/handler.py b/src/handler.py index 3278d79..87265d1 100644 --- a/src/handler.py +++ b/src/handler.py @@ -20,7 +20,6 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" PATH_REGEX = re.compile(r"^\/([\w-]+)\/") HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$") - HOST_PORT_REGEX = re.compile(r"^\d{2,5}$") @typing.override def __init__( @@ -31,7 +30,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): **kwargs: dict[str, typing.Any], ) -> None: self.logger = logging.getLogger(self.__class__.__name__) - self.default_host = params.host + self.default_host = params.host.split(":", maxsplit=2)[0] self.token = params.token self.data_dir = data_dir.DataDir(params.data_dir) self.max_size_bytes = params.max_size_bytes @@ -114,12 +113,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): return self.certbot_www + path.removeprefix(self.CERTBOT_CHALLENGE_PATH) if (page := self.registry.get_from_host(self.__get_host())) is not None: path = f"/{page.path}" + path - path = super().translate_path(path) - if self.__get_subpath() is None: # not a valid path + if self.__get_subpath(path) is None: # not a valid path return "" if pathlib.Path(path).name.startswith("."): # hidden files return "" - return path + return super().translate_path(path) @typing.override def send_error( @@ -186,25 +184,25 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): if self.headers["X-Token"] != self.token: self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") return None - if (sub_path := self.__get_subpath_full()) is None: + if (sub_path := self.__get_subpath_full(self.path)) is None: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") return None return sub_path - def __get_subpath(self) -> str | None: - if (match := self.PATH_REGEX.match(self.path)) is not None: + def __get_subpath(self, path: str) -> str | None: + if (match := self.PATH_REGEX.match(path)) is not None: return match.group(1) return None - def __get_subpath_full(self) -> str | None: - if (match := self.PATH_REGEX.fullmatch(self.path)) is not None: + def __get_subpath_full(self, path: str) -> str | None: + if (match := self.PATH_REGEX.fullmatch(path)) is not None: return match.group(1) return None def __get_host(self) -> str: if self.headers["Host"] is None: return self.default_host - return self.headers["Host"] + return self.headers["Host"].split(":", maxsplit=2)[0] def __get_length(self) -> int: if not self.headers["Content-Length"]: @@ -212,11 +210,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): return int(self.headers["Content-Length"]) def __valid_host(self, host: str) -> bool: - hostname, port = host.split(":", maxsplit=2) - for part in hostname.split("."): - if not self.HOST_PART_REGEX.fullmatch(part): - return False - return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port)) + return all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")) def __server_index(self) -> None: self.__send_basic_body(self.server_version + "\n") diff --git a/src/page.py b/src/page.py index 2c14364..386f79e 100644 --- a/src/page.py +++ b/src/page.py @@ -13,7 +13,7 @@ class Page: def __repr__(self) -> str: out = self.get_url_path() if self.host is not None: - out += f" [http://{self.host}/]" + out += f" [{self.host}]" if not self.with_index: out += " (no index)" return out diff --git a/src/params.py b/src/params.py index f7881ed..c179615 100644 --- a/src/params.py +++ b/src/params.py @@ -5,7 +5,7 @@ import os from . import project -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Parameters: port: int host: str @@ -18,6 +18,7 @@ class Parameters: self_signed_path: str with_certbot: bool with_certificates: bool + https: bool debug: bool @classmethod @@ -160,8 +161,8 @@ def parse_parameters() -> Parameters: parser.add_argument( "--certbot", action=argparse.BooleanOptionalAction, - help="Use Certbot (default: true)", - default=True, + help="Use Certbot (default: false)", + default=False, dest="with_certbot", ) parser.add_argument( @@ -171,5 +172,13 @@ def parse_parameters() -> Parameters: default=True, dest="with_certificates", ) + parser.add_argument( + "--https", + action=argparse.BooleanOptionalAction, + help="Use https (implies --certificates) (default: true)", + default=True, + ) args = parser.parse_args() + if args.https: + args.with_certificates = True return Parameters.from_namespace(args) diff --git a/src/registry.py b/src/registry.py index f3b4cec..47cd1c8 100644 --- a/src/registry.py +++ b/src/registry.py @@ -12,7 +12,6 @@ class Registry: self.logger = logging.getLogger(self.__class__.__name__) self.pages: dict[str, page.Page] = {} self.data_dir = data_dir.DataDir(params.data_dir) - self.prefix = f"http://{params.host}" def load_pages(self) -> None: self.pages = {} @@ -28,7 +27,7 @@ class Registry: self.data_dir.has_index(path), self.data_dir.get_host(path), ) - self.logger.info("Updated %s%s", self.prefix, str(self.pages[path])) + self.logger.info("Updated %s", self.pages[path]) def set_host(self, path: str, host: str) -> None: self.data_dir.set_host(path, host) @@ -37,7 +36,7 @@ class Registry: def remove(self, path: str) -> None: page = self.pages[path] del self.pages[path] - self.logger.info("Removed %s%s", self.prefix, str(page)) + self.logger.info("Removed %s", page) def get_from_host(self, host: str) -> page.Page | None: for p in self.pages.values(): diff --git a/src/server.py b/src/server.py index b11431f..c31e178 100644 --- a/src/server.py +++ b/src/server.py @@ -15,10 +15,7 @@ class StaplerServer: self.params = params self.registry = registry.Registry(params) self.cert_manager = cert.CertManager(params) - self.server = http.server.ThreadingHTTPServer( - (params.bind, params.port), - self.request_handler, - ) + self.default_host = params.host.split(":", maxsplit=2)[0] def request_handler(self, *args: typing.Any) -> http.server.BaseHTTPRequestHandler: return handler.RequestHandler(*args, params=self.params, registry=self.registry) @@ -27,19 +24,34 @@ class StaplerServer: self.logger.info("Starting up...") self.registry.load_pages() if self.params.with_certificates: - self.cert_manager.init([self.params.host, *self.registry.get_hosts()]) + self.cert_manager.init([self.default_host, *self.registry.get_hosts()]) + + def __create_https_context(self, server: http.server.HTTPServer) -> bool: + https = False + if ( + context := self.cert_manager.get_https_context(self.default_host) + ) is not None: + https = True + server.socket = context.wrap_socket(server.socket, server_side=True) + return https def start(self) -> None: self.logger.info("Version %s", project.get_version()) self.__startup() + server = http.server.ThreadingHTTPServer( + (self.params.bind, self.params.port), + self.request_handler, + ) + https = self.params.https and self.__create_https_context(server) self.logger.info( "Listening on %s:%d...", - self.server.server_address[0], - self.server.server_port, + server.server_address[0], + server.server_port, ) self.logger.info( - "Server up and ready on http://%s", + "Server up and ready on %s://%s", + "https" if https else "http", self.params.host, ) with contextlib.suppress(KeyboardInterrupt): - self.server.serve_forever() + server.serve_forever()