diff --git a/.env.example b/.env.example index 49fb701..0c56d7d 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,2 @@ HOST=example.com -TOKEN=secret -PORT=443 \ No newline at end of file +TOKEN=secret \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 7acc5ff..86045e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,8 +4,9 @@ WORKDIR /app VOLUME [ "/data", "/etc/letsencrypt" ] -ENV PORT=8080 -ENV HOST=localhost:8080 +ENV HTTP_PORT=80 +ENV HTTPS_PORT=443 +ENV HOST=localhost ENV DATA_DIR=/data ENV MAX_SIZE=2000000 ENV BIND=0.0.0.0 diff --git a/Makefile b/Makefile index 54d2204..2f7082e 100644 --- a/Makefile +++ b/Makefile @@ -72,7 +72,7 @@ docker-build: ## docker build .PHONY: docker-run docker-run: docker-build ## docker run - @$(DOCKER) run -it -p $(PORT):8080 -v ./data:/data $(DOCKER_TAG) --debug --no-certbot --token $(TOKEN) run + @$(DOCKER) run -it -p $(PORT):80 -v ./data:/data $(DOCKER_TAG) --debug --no-certbot --no-https --token $(TOKEN) --host localhost:$(PORT) run # ACTIONS diff --git a/README.md b/README.md index 9b0c762..64df6e7 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ```txt usage: stapler [-h] [--debug | --no-debug] [-d DATA_DIR] [--certificates | --no-certificates] [--certbot | --no-certbot] [--self-signed-path SELF_SIGNED_PATH] [--certbot-conf CERTBOT_CONF] - [--certbot-www CERTBOT_WWW] [--host HOST] [-p PORT] [--https | --no-https] [-t TOKEN] [--max-size-bytes MAX_SIZE] [-b BIND] + [--certbot-www CERTBOT_WWW] [--host HOST] [--http-port HTTP_PORT] [--https-port HTTPS_PORT] [--https | --no-https] [-t TOKEN] [--max-size-bytes MAX_SIZE] [-b BIND] COMMAND ... Static pages as simple as a gzip file @@ -29,15 +29,18 @@ options: Certbot config dir (default: /etc/letsencrypt) --certbot-www CERTBOT_WWW Certbot www dir (default: ./data/.certbot) - --host HOST server default host (default: localhost:8080) - -p, --port PORT server port (default: 8080) + --host HOST server default host (default: localhost) + --http-port HTTP_PORT + server http port (default: 80) + --https-port HTTPS_PORT + server https port (default: 443) --https, --no-https Use https (implies --certificates) (default: true) -t, --token TOKEN secret token for update requests (default: ) --max-size-bytes MAX_SIZE max size of accepted archives (in bytes) (default: 2000000) -b, --bind BIND server bind address (default: 0.0.0.0) -(Each option can be supplied with equivalent environment variable.) +(Each option can be supplied with equivalent environment variable.) ``` ## Endpoints @@ -100,7 +103,7 @@ curl -X DELETE \ - [x] better error page - [x] add favicon.ico + special path - [x] [http.server security](https://docs.python.org/3/library/http.server.html#http-server-security) -- [ ] launch separate upgrade 80->443 server when https +- [x] launch separate upgrade 80->443 server when https - [ ] token management with "generate" command and bind path to specific token - [x] docker compose example + .env - [ ] proper doc diff --git a/docker-compose.example.yml b/docker-compose.example.yml index 089e5b4..0167f15 100644 --- a/docker-compose.example.yml +++ b/docker-compose.example.yml @@ -5,7 +5,8 @@ services: build: . restart: unless-stopped ports: - - "${PORT}:8080" + - "80:80" + - "443:443" volumes: - "./data:/data" - "./letsencrypt:/etc/letsencrypt" diff --git a/src/handler.py b/src/handlers.py similarity index 78% rename from src/handler.py rename to src/handlers.py index 28ec81e..8dc76b6 100644 --- a/src/handler.py +++ b/src/handlers.py @@ -1,3 +1,4 @@ +import abc import http import http.server import io @@ -14,116 +15,21 @@ if typing.TYPE_CHECKING: from . import params, registry -class RequestHandler(http.server.SimpleHTTPRequestHandler): +class _BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler): protocol_version = "HTTP/2.0" server_version = "StaplerServer/" + project.get_version() - CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" - PATH_REGEX = re.compile(r"^\/([\w-]+)\/") - HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$") - AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"] @typing.override def __init__( self, *args: typing.Any, params: params.Parameters, - registry: registry.Registry, - cert_manager: cert.CertManager, **kwargs: dict[str, typing.Any], ) -> None: self.logger = logging.getLogger(self.__class__.__name__) self.default_host = params.host.split(":", maxsplit=2)[0] - self.token = params.token - self.data_dir = data_dir.DataDir(params.data_dir) - self.max_size_bytes = params.max_size_bytes - self.registry = registry - self.cert_manager = cert_manager - self.certbot_www = os.path.realpath(params.certbot_www) self.out_size = 0 - super().__init__(*args, directory=params.data_dir, **kwargs) - - @typing.override - def do_HEAD(self) -> None: - self.__pre_log_request() - super().do_HEAD() - - @typing.override - def do_GET(self) -> None: - self.__pre_log_request() - if self.path == "/" and self.__get_host() == self.default_host: - return self.__server_index() - super().do_GET() - return None - - def do_PUT(self) -> None: - self.__pre_log_request() - if (sub_path := self.__check_update_request()) is None: - return None - host: str | None = self.headers["X-Host"] - if host is not None and not self.__valid_host(host): - return self.send_error( - http.HTTPStatus.BAD_REQUEST, "Invalid requested host" - ) - if (content_length := self.__get_length()) == 0: - return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") - if content_length > self.max_size_bytes: - return self.send_error( - http.HTTPStatus.CONTENT_TOO_LARGE, - "Archive too large", - ) - try: - file_bytes = io.BytesIO(self.rfile.read(content_length)) - self.data_dir.extract_tar_bytes(sub_path, file_bytes) - except tarfile.TarError: - return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") - except Exception as e: - return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) - self.__send_status_only( - http.HTTPStatus.CREATED, - f"Resource /{sub_path}/ updated", - ) - self.registry.add(sub_path) - if host is not None and self.cert_manager.create_or_update(host): - self.registry.set_host(sub_path, host) - self.registry.add(sub_path) - return None - - def do_DELETE(self) -> None: - self.__pre_log_request() - if (sub_path := self.__check_update_request()) is None: - return None - if not self.data_dir.exists(sub_path): - self.send_error(http.HTTPStatus.NOT_FOUND, "Not found") - return None - try: - self.data_dir.remove(sub_path) - except Exception as e: - return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) - self.__send_status_only( - http.HTTPStatus.NO_CONTENT, - f"Resource /{sub_path}/ removed", - ) - self.registry.remove(sub_path) - return None - - @typing.override - def list_directory(self, *_: typing.Any, **__: typing.Any) -> None: - """Disable default directory listing.""" - self.send_error(http.HTTPStatus.NOT_FOUND, "File not found") - - @typing.override - def translate_path(self, path: str) -> str: - if path.startswith(self.CERTBOT_CHALLENGE_PATH): - return self.certbot_www + path - if (page := self.registry.get_from_host(self.__get_host())) is not None: - path = f"/{page.path}" + path - elif ( - path not in self.AUTHORIZED_PATHS and self.__get_subpath(path) is None - ): # not a valid path - return "" - if pathlib.Path(path).name.startswith("."): # hidden files - return "" - return super().translate_path(path) + super().__init__(*args, **kwargs) @typing.override def send_error( @@ -138,13 +44,13 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): if explain is None: explain = longmsg if "Accept" not in self.headers["Accept"] or "text/" in self.headers["Accept"]: - self.__send_basic_body( + self.send_basic_body( f"{code} {message}\n{explain}\n{self.server_version}\n", code=code, message=message, ) else: - self.__send_status_only(code, message) + self.send_status_only(code, message) @typing.override def log_message(self, format: str, *args: typing.Any) -> None: @@ -171,21 +77,173 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): code = color + str(code.value) + logs.TermColor.RESET if size == "" and self.out_size > 0: size = str(self.out_size) - args = (code, self.address_string(), self.__get_host(), self.requestline) + args = (code, self.address_string(), self._get_host(), self.requestline) fmt = "→ %s - %s - %s - %s" if size != "": args = (*args, size) fmt += " - %s" self.logger.info(fmt, *args) - def __pre_log_request(self) -> None: - args = ("...", self.address_string(), self.__get_host(), self.requestline) + def send_basic_body( + self, + body: str, + content_type: str = "text/plain", + code: int = http.HTTPStatus.OK, + message: str | None = None, + headers: dict[str, str] | None = None, + ) -> None: + encoded: bytes = body.encode() + self.out_size = len(encoded) + self.send_response(code, message) + self.send_header("Content-type", f"{content_type}; charset=UTF-8") + self.send_header("Content-Length", str(len(encoded))) + if headers is not None: + for header, value in headers.items(): + self.send_header(header, value) + self.end_headers() + self.wfile.write(encoded) + + def send_status_only( + self, + code: int, + message: str | None = None, + headers: dict[str, str] | None = None, + ) -> None: + self.send_response(code, message) + self.send_header("Content-Length", "0") + if headers is not None: + for header, value in headers.items(): + self.send_header(header, value) + self.end_headers() + + def _get_host(self) -> str: + if self.headers["Host"] is None: + return self.default_host + return self.headers["Host"].split(":", maxsplit=2)[0] + + def _get_length(self) -> int: + if not self.headers["Content-Length"]: + return 0 + return int(self.headers["Content-Length"]) + + def _pre_log_request(self) -> None: + args = ("...", self.address_string(), self._get_host(), self.requestline) fmt = "← %s - %s - %s - %s" - if (size := self.__get_length()) > 0: + if (size := self._get_length()) > 0: args = (*args, size) fmt += " - %s" self.logger.debug(fmt, *args) + +class RequestHandler(http.server.SimpleHTTPRequestHandler, _BaseHandler): + protocol_version = "HTTP/2.0" + server_version = "StaplerServer/" + project.get_version() + CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge" + PATH_REGEX = re.compile(r"^\/([\w-]+)\/") + HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$") + AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"] + + @typing.override + def __init__( + self, + *args: typing.Any, + params: params.Parameters, + registry: registry.Registry, + cert_manager: cert.CertManager, + **kwargs: dict[str, typing.Any], + ) -> None: + self.logger = logging.getLogger(self.__class__.__name__) + self.token = params.token + self.data_dir = data_dir.DataDir(params.data_dir) + self.max_size_bytes = params.max_size_bytes + self.registry = registry + self.cert_manager = cert_manager + self.certbot_www = os.path.realpath(params.certbot_www) + super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] + + @typing.override + def do_HEAD(self) -> None: + self._pre_log_request() + super().do_HEAD() + + @typing.override + def do_GET(self) -> None: + self._pre_log_request() + if self.path == "/" and self._get_host() == self.default_host: + return self.__server_index() + super().do_GET() + return None + + def do_PUT(self) -> None: + self._pre_log_request() + if (sub_path := self.__check_update_request()) is None: + return None + host: str | None = self.headers["X-Host"] + if host is not None and not self.__valid_host(host): + return self.send_error( + http.HTTPStatus.BAD_REQUEST, "Invalid requested host" + ) + if (content_length := self._get_length()) == 0: + return self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found") + if content_length > self.max_size_bytes: + return self.send_error( + http.HTTPStatus.CONTENT_TOO_LARGE, + "Archive too large", + ) + try: + file_bytes = io.BytesIO(self.rfile.read(content_length)) + self.data_dir.extract_tar_bytes(sub_path, file_bytes) + except tarfile.TarError: + return self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive") + except Exception as e: + return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) + self.send_status_only( + http.HTTPStatus.CREATED, + f"Resource /{sub_path}/ updated", + ) + self.registry.add(sub_path) + if host is not None and self.cert_manager.create_or_update(host): + self.registry.set_host(sub_path, host) + self.registry.add(sub_path) + return None + + def do_DELETE(self) -> None: + self._pre_log_request() + if (sub_path := self.__check_update_request()) is None: + return None + if not self.data_dir.exists(sub_path): + self.send_error(http.HTTPStatus.NOT_FOUND, "Not found") + return None + try: + self.data_dir.remove(sub_path) + except Exception as e: + return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) + self.send_status_only( + http.HTTPStatus.NO_CONTENT, + f"Resource /{sub_path}/ removed", + ) + self.registry.remove(sub_path) + return None + + @typing.override + def list_directory(self, *_: typing.Any, **__: typing.Any) -> None: + """Disable default directory listing.""" + self.send_error(http.HTTPStatus.NOT_FOUND, "File not found") + + @typing.override + def translate_path(self, path: str) -> str: + if path.startswith(self.CERTBOT_CHALLENGE_PATH): + return self.certbot_www + path + if (page := self.registry.get_from_host(self._get_host())) is not None: + path = f"/{page.path}" + path + elif ( + path not in self.AUTHORIZED_PATHS and self.__get_subpath(path) is None + ): # not a valid path + return "" + if pathlib.Path(path).name.startswith("."): # hidden files + return "" + return super().translate_path(path) + def __check_update_request(self) -> str | None: if len(self.token) and self.headers["X-Token"] != self.token: self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") @@ -205,38 +263,20 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler): return match.group(1) return None - def __get_host(self) -> str: - if self.headers["Host"] is None: - return self.default_host - return self.headers["Host"].split(":", maxsplit=2)[0] - - def __get_length(self) -> int: - if not self.headers["Content-Length"]: - return 0 - return int(self.headers["Content-Length"]) - def __valid_host(self, host: str) -> bool: return all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")) def __server_index(self) -> None: - self.__send_basic_body(self.server_version + "\n") + self.send_basic_body(self.server_version + "\n") - def __send_basic_body( - self, - body: str, - content_type: str = "text/plain", - code: int = http.HTTPStatus.OK, - message: str | None = None, - ) -> None: - encoded: bytes = body.encode() - self.out_size = len(encoded) - self.send_response(code, message) - self.send_header("Content-type", f"{content_type}; charset=UTF-8") - self.send_header("Content-Length", str(len(encoded))) - self.end_headers() - self.wfile.write(encoded) - def __send_status_only(self, code: int, message: str | None = None) -> None: - self.send_response(code, message) - self.send_header("Content-Length", "0") - self.end_headers() +class UpgradeHandler(_BaseHandler): + def do_HEAD(self) -> None: + self._pre_log_request() + self.send_status_only( + http.HTTPStatus.MOVED_PERMANENTLY, + headers={"Location": f"https://{self._get_host()}{self.path}"}, + ) + + def do_GET(self) -> None: + self.do_HEAD() diff --git a/src/params.py b/src/params.py index 496b000..2d85247 100644 --- a/src/params.py +++ b/src/params.py @@ -10,7 +10,8 @@ __EPILOG = "(Each option can be supplied with equivalent environment variable.)" @dataclasses.dataclass(frozen=True) class Parameters: - port: int + http_port: int + https_port: int host: str data_dir: str bind: str @@ -144,16 +145,22 @@ def parse_parameters() -> Parameters: parser, "--host", env_var="HOST", - default="localhost:8080", + default="localhost", help_txt="server default host", ) __add_arg_int( parser, - "-p", - "--port", - env_var="PORT", - default=8080, - help_txt="server port", + "--http-port", + env_var="HTTP_PORT", + default=80, + help_txt="server http port", + ) + __add_arg_int( + parser, + "--https-port", + env_var="HTTPS_PORT", + default=443, + help_txt="server https port", ) parser.add_argument( "--https", diff --git a/src/server.py b/src/server.py index 99de631..66334b7 100644 --- a/src/server.py +++ b/src/server.py @@ -1,9 +1,10 @@ import contextlib import http.server import logging +import threading import typing -from . import cert, data_dir, handler, project, registry +from . import cert, data_dir, handlers, project, registry if typing.TYPE_CHECKING: from . import params @@ -18,14 +19,6 @@ class StaplerServer: self.data_dir = data_dir.DataDir(params.data_dir) self.default_host = params.host.split(":", maxsplit=2)[0] - def request_handler(self, *args: typing.Any) -> http.server.BaseHTTPRequestHandler: - return handler.RequestHandler( - *args, - params=self.params, - registry=self.registry, - cert_manager=self.cert_manager, - ) - def __get_all_hosts(self) -> list[str]: return [self.default_host, *self.registry.get_hosts()] @@ -47,26 +40,85 @@ class StaplerServer: server.socket = context.wrap_socket(server.socket, server_side=True) return https - def run(self) -> int: - self.logger.info("Version %s", project.get_version()) - self.__startup() - server = http.server.ThreadingHTTPServer( - (self.params.bind, self.params.port), - self.request_handler, + def __request_handler( + self, *args: typing.Any + ) -> http.server.BaseHTTPRequestHandler: + return handlers.RequestHandler( + *args, + params=self.params, + registry=self.registry, + cert_manager=self.cert_manager, ) - https = self.params.https and self.__create_https_context(server) + + def __create_base_server(self) -> tuple[http.server.ThreadingHTTPServer, bool]: + context = ( + self.cert_manager.get_https_context(self.default_host) + if self.params.https + else None + ) + if context is not None: + server = http.server.ThreadingHTTPServer( + ( + self.params.bind, + self.params.https_port, + ), + self.__request_handler, + ) + server.socket = context.wrap_socket(server.socket, server_side=True) + else: + server = http.server.ThreadingHTTPServer( + ( + self.params.bind, + self.params.http_port, + ), + self.__request_handler, + ) self.logger.info( - "Listening on %s:%d...", + "Server listening on %s:%d...", server.server_address[0], server.server_port, ) + return server, context is not None + + def __upgrade_handler( + self, *args: typing.Any + ) -> http.server.BaseHTTPRequestHandler: + return handlers.UpgradeHandler( + *args, + params=self.params, + ) + + def __start_upgrade_server(self) -> http.server.ThreadingHTTPServer: + server = http.server.ThreadingHTTPServer( + ( + self.params.bind, + self.params.http_port, + ), + self.__upgrade_handler, + ) + self.logger.info( + "Upgrade server listening on %s:%d...", + server.server_address[0], + server.server_port, + ) + threading.Thread(target=server.serve_forever).start() + return server + + def run(self) -> int: + self.logger.info("Version %s", project.get_version()) + self.__startup() + base_server, https = self.__create_base_server() + upgrade_server = self.__start_upgrade_server() if https else None self.logger.info( "Server up and ready on %s://%s", "https" if https else "http", self.params.host, ) with contextlib.suppress(KeyboardInterrupt): - server.serve_forever() + base_server.serve_forever() + self.logger.info("Shutting down...") + if upgrade_server is not None: + upgrade_server.shutdown() return 0 def renew(self) -> int: