Files
stapler/stapler/handlers.py
T
2026-04-27 15:15:34 +02:00

589 lines
20 KiB
Python

import abc
import http
import http.cookiejar
import http.server
import io
import logging
import os
import pathlib
import re
import tarfile
import typing
import urllib.parse
import requests
from . import PKG_VERSION, STAPLER_ASCII, logs
from .data_dir import DataDir
if typing.TYPE_CHECKING:
from .page import Page
from .params import Parameters
from .registry import Registry
from .token_manager import TokenManager
class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
@typing.override
def __init__(
self,
*args: typing.Any,
params: Parameters,
**kwargs: dict[str, typing.Any],
) -> None:
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self.default_host: str = params.host.split(":", maxsplit=2)[0]
self.out_size: int = 0
self.__host: str | None = None
self.__in_size: int | None = None
self.https: bool = params.https
super().__init__(*args, **kwargs)
@typing.override
def send_error(
self,
code: int,
message: str | None = None,
explain: str | None = None,
) -> None:
self.send_status(code, message, explain)
def send_status(
self,
code: int,
message: str | None = None,
explain: str | None = None,
) -> None:
shortmsg, longmsg = self.responses[code]
if message is None:
message = shortmsg
if explain is None:
explain = longmsg
if (
not self._has_header("Accept")
or self._get_header("Accept").startswith("*/")
or self._get_header("Accept").startswith("text/")
):
self.send_basic_body(
f"{code} {message}\n{explain}\n\n{self.server_signature()}",
code=code,
message=message,
)
else:
self.send_status_only(code, message)
@typing.override
def log_message(self, format: str, *args: typing.Any) -> None: # pragma: no cover
fmt = "%s - " + format
self.logger.info(fmt, self.address_string(), *args)
@typing.override
def log_error(self, format: str, *args: typing.Any) -> None: # pragma: no cover
fmt = "%s - " + format
self.logger.error(fmt, self.address_string(), *args)
@typing.override
def log_request(self, code: str = "?", size: str = "-") -> None: # ty:ignore[invalid-method-override] # pragma: no cover
if isinstance(code, http.HTTPStatus):
color = logs.TermColor.RED
if 100 <= code < 200:
color = logs.TermColor.CYAN
if 200 <= code < 300:
color = logs.TermColor.GREEN
elif 300 <= code < 400:
color = logs.TermColor.BLUE
elif 400 <= code < 500:
color = logs.TermColor.YELLOW
code = color + str(code.value) + logs.TermColor.RESET
if size == "" and self.out_size > 0:
size = str(self.out_size)
args = (code, self.address_string(), self.host, self.requestline)
fmt = "%s - %s - %s - %s"
if size != "":
args = (*args, size)
fmt += " - %s"
self.logger.info(fmt, *args)
def send_basic_body(
self,
body: str,
content_type: str = "text/plain",
code: int = http.HTTPStatus.OK,
message: str | None = None,
) -> None:
encoded: bytes = body.encode()
self.out_size = len(encoded)
self.send_response(code, message)
self.send_header("Content-Type", f"{content_type}; charset=UTF-8")
self.send_header("Content-Length", str(len(encoded)))
self.end_headers()
self.wfile.write(encoded)
self.close_connection = True
def send_status_only(
self,
code: int,
message: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
if headers is None:
headers = {}
self.send_response(code, message)
self.send_header("Content-Length", "0")
for header, value in headers.items():
self.send_header(header, value)
self.end_headers()
self.close_connection = True
def send_redirect(self, location: str) -> None:
self.send_status_only(
http.HTTPStatus.MOVED_PERMANENTLY,
headers={"Location": location},
)
def send_proxy(self, url: str) -> None:
headers = dict(self.headers)
headers["Host"] = (target_host := urllib.parse.urlparse(url).netloc)
headers["X-Real-IP"] = self.client_address[0]
headers["X-Forwarded-Host"] = self.host
headers["X-Forwarded-For"] = self.client_address[0]
headers["X-Forwarded-Proto"] = "https" if self.https else "http"
try:
body: bytes | None = None
if self.in_size > 0:
body = self.rfile.read(self.in_size)
response: requests.Response = requests.request(
self.command,
url,
data=body,
headers=headers,
allow_redirects=False,
timeout=480,
)
except Exception as e:
self.send_error(
http.HTTPStatus.BAD_GATEWAY, f"Could not reach {url}", explain=str(e)
)
return
self.send_response(response.status_code, response.reason)
for header, value in response.headers.items():
if header.lower() not in [
"content-length",
"content-encoding",
"transfer-encoding",
"server",
"date",
]:
self.send_header(header, value.replace(target_host, self.host))
self.send_header("Content-Length", str(out_size := len(response.content)))
self.end_headers()
if out_size > 0:
self.wfile.write(response.content)
self.close_connection = True
@property
def host(self) -> str:
if self.__host is None:
self.__host = self._get_host()
return self.__host
def _get_host(self) -> str:
host = self._get_header("Host", self.default_host)
return host.split(":", maxsplit=2)[0]
@property
def in_size(self) -> int:
if self.__in_size is None:
self.__in_size = self._get_length()
return self.__in_size
def _get_length(self) -> int:
return int(self._get_header("Content-Length", "0"))
def _get_header(self, key: str, default_value: str = "") -> str:
if self._has_header(key):
return self.headers[key]
return default_value
def _has_header(self, key: str) -> bool:
return (
hasattr(self, "headers")
and key in self.headers
and len(self.headers[key]) > 0
)
def _pre_log_request(self) -> None: # pragma: no cover
args = ("...", self.address_string(), self.host, self.requestline)
fmt = "%s - %s - %s - %s"
if self.in_size > 0:
args = (*args, self.in_size)
fmt += " - %s"
self.logger.debug(fmt, *args)
def server_signature(self) -> str:
return self.server_version + "\n\n" + STAPLER_ASCII + "\n"
class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
    """Main stapler handler: serves pages and accepts authenticated updates.

    Reads (GET/HEAD) are served from the data directory, redirected or
    proxied depending on the page registered for the requested path or host.
    Writes (PUT/POST/PATCH/DELETE) require an ``X-Token`` header and update
    the registry; behaviour is selected by the ``X-*`` request headers
    declared below.
    """

    protocol_version = "HTTP/1.1"
    server_version = "StaplerServer/" + PKG_VERSION
    CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
    UPDATE_PATH_REGEX = re.compile(r"^\/([\w-]+)\/?$")
    GET_PATH_REGEX = re.compile(r"^\/([\w-]+)($|\/)")
    HOST_PART_REGEX = re.compile(r"^([a-z0-9]|[a-z0-9][a-z0-9-]{,61}[a-z0-9])$")
    AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"]
    TOKEN_HEADER = "X-Token"  # noqa: S105
    HOST_HEADER = "X-Host"
    HOST_ONLY_HEADER = "X-Host-Only"
    REDIRECT_HEADER = "X-Redirect"
    PROXY_HEADER = "X-Proxy"
    SPA_HEADER = "X-SPA"

    @typing.override
    def __init__(
        self,
        *args: typing.Any,
        # Quoted: these types are only imported under ``TYPE_CHECKING``.
        params: "Parameters",
        registry: "Registry",
        token_manager: "TokenManager",
        **kwargs: typing.Any,
    ) -> None:
        """Initialise handler state, then let the stdlib handle the request.

        Args:
            *args: Positional arguments forwarded to the stdlib handler.
            params: Server parameters (data dir, size limit, certbot dir...).
            registry: Page registry queried and updated by this handler.
            token_manager: Validates tokens and binds them to paths.
            **kwargs: Keyword arguments forwarded to the stdlib handler.
        """
        # State must exist before super().__init__(), which processes the
        # request from within the constructor.
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self.token_manager: TokenManager = token_manager
        self.data_dir: DataDir = DataDir(params.data_dir)
        self.root_path: pathlib.Path = pathlib.Path(params.data_dir)
        self.max_size_bytes: int = params.max_size_bytes
        self.registry: Registry = registry
        self.certbot_www: str = os.path.realpath(params.certbot_www)
        # Lazily populated caches for the request headers exposed below.
        self.__token: str | None = None
        self.__target_host: str | None = None
        self.__target_host_only: str | None = None
        self.__target_redirect: str | None = None
        self.__target_proxy: str | None = None
        self.__target_spa: str | None = None
        super().__init__(*args, directory=params.data_dir, **kwargs, params=params)  # ty:ignore[unknown-argument]

    @property
    def token(self) -> str:
        """Value of the ``X-Token`` header, or ``""`` when absent."""
        if self.__token is None:
            self.__token = self._get_header(self.TOKEN_HEADER)
        return self.__token

    @property
    def has_token(self) -> bool:
        """Whether a non-empty ``X-Token`` header was sent."""
        return len(self.token) > 0

    @property
    def request_host(self) -> str:
        """Lower-cased ``X-Host`` header, or ``""`` when absent."""
        if self.__target_host is None:
            self.__target_host = self._get_header(self.HOST_HEADER).lower()
        return self.__target_host

    @property
    def has_request_host(self) -> bool:
        """Whether a non-empty ``X-Host`` header was sent."""
        return len(self.request_host) > 0

    @property
    def request_host_only(self) -> str:
        """Lower-cased ``X-Host-Only`` header, or ``""`` when absent."""
        if self.__target_host_only is None:
            self.__target_host_only = self._get_header(self.HOST_ONLY_HEADER).lower()
        return self.__target_host_only

    @property
    def has_request_host_only(self) -> bool:
        """Whether a non-empty ``X-Host-Only`` header was sent."""
        return len(self.request_host_only) > 0

    @property
    def target_host(self) -> str:
        """Requested host: ``X-Host`` when given, otherwise ``X-Host-Only``."""
        if self.has_request_host:
            return self.request_host
        return self.request_host_only

    @property
    def has_target_host(self) -> bool:
        """Whether either host header was sent."""
        return self.has_request_host or self.has_request_host_only

    @property
    def target_redirect(self) -> str:
        """Lower-cased ``X-Redirect`` header, or ``""`` when absent.

        NOTE(review): lower-casing also affects the URL path, which may be
        case-sensitive on the target; confirm this is intended.
        """
        if self.__target_redirect is None:
            self.__target_redirect = self._get_header(self.REDIRECT_HEADER).lower()
        return self.__target_redirect

    @property
    def has_target_redirect(self) -> bool:
        """Whether a non-empty ``X-Redirect`` header was sent."""
        return len(self.target_redirect) > 0

    @property
    def target_proxy(self) -> str:
        """Lower-cased ``X-Proxy`` header, or ``""`` when absent."""
        if self.__target_proxy is None:
            self.__target_proxy = self._get_header(self.PROXY_HEADER).lower()
        return self.__target_proxy

    @property
    def has_target_proxy(self) -> bool:
        """Whether a non-empty ``X-Proxy`` header was sent."""
        return len(self.target_proxy) > 0

    @property
    def target_spa(self) -> str:
        """Lower-cased ``X-SPA`` header, or ``""`` when absent."""
        if self.__target_spa is None:
            self.__target_spa = self._get_header(self.SPA_HEADER).lower()
        return self.__target_spa

    @property
    def has_target_spa(self) -> bool:
        """Whether a non-empty ``X-SPA`` header was sent."""
        return len(self.target_spa) > 0

    @typing.override
    def do_HEAD(self) -> None:
        """Handle HEAD: redirect/proxy if configured, else serve statically."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            super().do_HEAD()

    @typing.override
    def do_GET(self) -> None:
        """Handle GET: redirect/proxy, server banner on ``/``, or static file."""
        self._pre_log_request()
        if self._proxy_or_redirect():
            return
        if self.path == "/" and self.host == self.default_host:
            self.send_basic_body(self.server_signature())
            return
        super().do_GET()

    def do_PUT(self) -> None:
        """Create or update a page.

        After validation, exactly one of redirect (``X-Redirect``), proxy
        (``X-Proxy``) or archive extraction (request body) is applied; the
        requested host binding is then stored and ``201`` is returned.
        """
        self._pre_log_request()
        if self._proxy_or_redirect():
            return
        if (path := self.__check_put_request()) is None:
            return

        if self.has_target_redirect:
            updated = self._update_redirect(path)
        elif self.has_target_proxy:
            updated = self._update_proxy(path)
        else:
            updated = self._update_extract(path)
        if not updated:
            # The failing helper has already sent the error response.
            return

        if self.has_request_host:
            self.registry.set_host(path, self.target_host)
        if self.has_request_host_only:
            self.registry.set_host_only(path, self.target_host)
        self.send_status(
            http.HTTPStatus.CREATED,
            "Resource updated",
            str(self.registry.get_from_path(path)),
        )

    def do_POST(self) -> None:
        """Handle POST exactly like PUT."""
        self.do_PUT()  # be gentle on them

    def do_PATCH(self) -> None:
        """Handle PATCH exactly like PUT."""
        self.do_PUT()  # be gentle on them

    def do_DELETE(self) -> None:
        """Remove a page after validating the token and the path."""
        self._pre_log_request()
        if self._proxy_or_redirect():
            return
        if (path := self.__check_update_request()) is None:
            return
        if self._update_remove(path):
            self.send_status(
                http.HTTPStatus.OK,
                f"Resource /{path}/ removed",
            )

    def do_CONNECT(self) -> None:
        """Proxy/redirect if configured, else ``405 Method Not Allowed``."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def do_OPTIONS(self) -> None:
        """Proxy/redirect if configured, else ``405 Method Not Allowed``."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def do_TRACE(self) -> None:
        """Proxy/redirect if configured, else ``405 Method Not Allowed``."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def _update_extract(self, path: str) -> bool:
        """Extract the uploaded tar archive into *path* and register the page.

        Returns:
            ``True`` on success; ``False`` after an error response was sent
            (``411`` no body, ``413`` too large, ``400`` bad archive, ``500``
            any other failure).
        """
        if self.in_size == 0:
            self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
            return False
        if self.in_size > self.max_size_bytes:
            self.send_error(
                http.HTTPStatus.CONTENT_TOO_LARGE,
                "Archive too large",
            )
            return False
        try:
            file_bytes = io.BytesIO(self.rfile.read(self.in_size))
            self.data_dir.extract_tar_bytes(path, file_bytes)
        except tarfile.TarError:
            self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive")
            return False
        except Exception as e:
            self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
            return False
        self.registry.add(path)
        self.token_manager.set_token(path, self.token)
        if self.has_target_spa:
            self.registry.set_spa(path, self.target_spa)
        return True

    def _update_redirect(self, path: str) -> bool:
        """Register *path* as a redirect; the request must have no body.

        Returns:
            ``True`` on success; ``False`` after ``400`` was sent.
        """
        if self.in_size > 0:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"No content must be sent with {self.REDIRECT_HEADER}",
            )
            return False
        self.registry.set_redirect(path, self.target_redirect)
        self.token_manager.set_token(path, self.token)
        return True

    def _update_proxy(self, path: str) -> bool:
        """Register *path* as a reverse proxy; the request must have no body.

        Returns:
            ``True`` on success; ``False`` after ``400`` was sent.
        """
        if self.in_size > 0:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"No content must be sent with {self.PROXY_HEADER}",
            )
            return False
        self.registry.set_proxy(path, self.target_proxy)
        self.token_manager.set_token(path, self.token)
        return True

    def _update_remove(self, path: str) -> bool:
        """Delete the data stored for *path* and unregister the page.

        NOTE(review): pages that were only ever a redirect or a proxy may
        have nothing in the data directory; confirm ``DataDir.exists``
        covers them, otherwise they cannot be deleted.

        Returns:
            ``True`` on success; ``False`` after ``404`` or ``500`` was sent.
        """
        if not self.data_dir.exists(path):
            self.send_error(http.HTTPStatus.NOT_FOUND, "Not found")
            return False
        try:
            self.data_dir.remove(path)
        except Exception as e:
            self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
            return False
        self.registry.remove(path)
        return True

    def _proxy_or_redirect(self) -> bool:
        """Answer the request with a redirect or proxy when one is configured.

        Requests carrying a token (management calls) and ACME challenge
        requests are never redirected or proxied.

        Returns:
            ``True`` when a response was sent, ``False`` when the caller
            still has to handle the request.
        """
        if self.has_token or self.path.startswith(self.CERTBOT_CHALLENGE_PATH):
            return False
        if (page := self.__get_page(self.path)) is None:
            return False
        if page.redirect is not None:
            self.send_redirect(page.redirect)
            return True
        if page.proxy is not None:
            if self.host == self.default_host:
                # On the default host the page is addressed as /<page>/...;
                # strip that prefix before forwarding upstream.
                self.send_proxy(page.proxy + self.path.removeprefix(f"/{page.path}"))
            else:
                self.send_proxy(page.proxy + self.path)
            return True
        return False

    @typing.override
    def list_directory(self, *_: typing.Any, **__: typing.Any) -> None:
        """Disable default directory listing."""
        self.send_error(http.HTTPStatus.NOT_FOUND, "File not found")

    @typing.override
    def translate_path(self, path: str) -> str:
        """Map a request path to a filesystem path.

        Returns ``""`` (which the stdlib turns into a 404) for anything that
        must not be served: unknown pages, hidden files and ACME requests
        that leave the challenge directory.

        Args:
            path: Raw request path, possibly with query string and
                percent-escapes.
        """
        if path.startswith(self.CERTBOT_CHALLENGE_PATH):
            return self.__translate_certbot_path(path)

        page = self.__get_page(path)
        if page is None:
            if path in self.AUTHORIZED_PATHS:
                return super().translate_path(path)
            return ""

        if self.host != self.default_host:
            # Custom hosts serve the page at "/": re-root under its directory.
            path = f"/{page.path}" + path

        # Checks below must see what will actually be opened, i.e. the path
        # without query string and with percent-escapes decoded; otherwise
        # "%2eenv" bypasses the hidden-file rule and "app.js?v=2" is not
        # recognised as an existing file.
        decoded = self.__decode_url_path(path)
        if pathlib.Path(decoded).name.startswith("."):  # hidden files
            return ""

        if page.spa is not None:
            candidate = self.root_path / decoded.lstrip("/")
            if not candidate.is_file() and not (candidate / "index.html").is_file():
                # Unknown route of a single-page app: serve its entry point.
                path = f"/{page.path}/{page.spa}"
        return super().translate_path(path)

    def __translate_certbot_path(self, path: str) -> str:
        """Resolve an ACME challenge path inside the certbot webroot.

        The path is normalised and must stay within
        ``<certbot_www>/.well-known/acme-challenge``; anything else (e.g.
        ``..`` traversal) yields ``""``.
        """
        challenge_root = os.path.normpath(
            self.certbot_www + self.CERTBOT_CHALLENGE_PATH
        )
        candidate = os.path.normpath(self.certbot_www + path)
        if candidate != challenge_root and not candidate.startswith(
            challenge_root + os.sep
        ):
            return ""
        return candidate

    @staticmethod
    def __decode_url_path(path: str) -> str:
        """Strip query string / fragment from *path* and decode ``%xx``."""
        path = path.split("?", 1)[0].split("#", 1)[0]
        return urllib.parse.unquote(path)

    def __check_update_request(self) -> str | None:
        """Validate token and path of a write request.

        Returns:
            The lower-cased page path, or ``None`` after an error response
            (``400``, ``401`` or ``403``) was sent.
        """
        if not self._has_header(self.TOKEN_HEADER):
            self.send_error(
                http.HTTPStatus.BAD_REQUEST, f"No {self.TOKEN_HEADER} header in request"
            )
            return None
        if not self.token_manager.is_valid(self.token):
            self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
            return None
        if (sub_path := self.__get_path(self.path, self.UPDATE_PATH_REGEX)) is None:
            self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
            return None
        if not self.token_manager.is_valid_for_path(self.token, sub_path):
            self.send_error(http.HTTPStatus.FORBIDDEN, "Path forbidden for this token")
            return None
        return sub_path

    def __check_put_request(self) -> str | None:
        """Validate a create/update request, including its ``X-*`` headers.

        Returns:
            The page path, or ``None`` after an error response was sent.
        """
        if (path := self.__check_update_request()) is None:
            return None
        if self.has_request_host and self.has_request_host_only:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"Cannot use {self.HOST_ONLY_HEADER} with {self.HOST_HEADER}",
            )
            return None
        if self.has_target_host and not self.__valid_host(self.target_host):
            self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid requested host")
            return None
        if self.has_target_proxy and self.has_target_redirect:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"Cannot use {self.PROXY_HEADER} with {self.REDIRECT_HEADER}",
            )
            return None
        if (
            self.has_target_host
            and (page := self.registry.get_from_host(self.target_host)) is not None
            and page.path != path
        ):
            # The host is already bound to a different page.
            self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken")
            return None
        return path

    def __get_path(self, path: str, regex: re.Pattern) -> str | None:
        """Return group 1 of *regex* matched on lower-cased *path*, or ``None``."""
        if (match := regex.match(path.lower())) is not None:
            return match.group(1)
        return None

    def __valid_host(self, host: str) -> bool:
        """Check *host* is under 256 chars and every label is well-formed."""
        return len(host) < 256 and all(
            self.HOST_PART_REGEX.fullmatch(part) for part in host.split(".")
        )

    def __get_page(self, src_path: str) -> "Page | None":
        """Find the page addressed by the current request.

        On the default host the page is selected by the first path segment
        and host-only pages are hidden; on any other host it is selected by
        the host name.
        """
        if self.host == self.default_host:
            if (
                (path := self.__get_path(src_path, self.GET_PATH_REGEX))
                and (page := self.registry.get_from_path(path)) is not None
                and not page.host_only
            ):
                return page
            return None
        return self.registry.get_from_host(self.host)
class UpgradeHandler(RequestHandler):
    """Handler that redirects HEAD/GET requests to their ``https://`` URL.

    GET requests for ACME challenges are the exception: they are served by
    the parent class.
    """

    server_version = "StaplerUpgradeServer/" + PKG_VERSION

    @typing.override
    def do_HEAD(self) -> None:
        """Redirect to the same host and path over HTTPS."""
        self._pre_log_request()
        self.send_redirect(f"https://{self.host}{self.path}")

    @typing.override
    def do_GET(self) -> None:
        """Serve ACME challenges; redirect every other GET to HTTPS."""
        if self.path.startswith(self.CERTBOT_CHALLENGE_PATH):
            super().do_GET()
            return
        # do_HEAD() sends a body-less 301, which is what GET needs too.
        self.do_HEAD()