Files
stapler/src/handlers.py
T
2026-04-20 23:34:39 +02:00

516 lines
18 KiB
Python

import abc
import http
import http.cookiejar
import http.server
import io
import logging
import os
import pathlib
import re
import tarfile
import typing
import urllib.parse

import requests

from . import STAPLER_ASCII, logs, project
from .data_dir import DataDir

if typing.TYPE_CHECKING:
    from .cert_manager import CertManager
    from .page import Page
    from .params import Parameters
    from .registry import Registry
    from .token_manager import TokenManager
class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
    """Shared plumbing for the Stapler HTTP handlers.

    Adds plain-text error pages, status-only and redirect responses, a small
    reverse proxy (``send_proxy``), colored access logs and lazily cached
    accessors for the request host and the declared request body size.
    """

    # Hop-by-hop headers (RFC 9110, section 7.6.1): they only describe a single
    # connection, so a proxy must not forward them in either direction.
    _HOP_BY_HOP_HEADERS: typing.ClassVar[frozenset[str]] = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-connection",
            "te",
            "transfer-encoding",
            "upgrade",
        }
    )
    # Upstream response headers send_proxy never copies: the hop-by-hop ones,
    # the body framing it recomputes itself (requests hands back an already
    # decoded body), and Server/Date, which send_response emits on its own.
    _PROXY_SKIPPED_RESPONSE_HEADERS: typing.ClassVar[frozenset[str]] = (
        _HOP_BY_HOP_HEADERS | {"content-length", "content-encoding", "server", "date"}
    )

    @typing.override
    def __init__(
        self,
        *args: typing.Any,
        params: Parameters,
        **kwargs: typing.Any,
    ) -> None:
        """Store per-request settings, then let the base class serve the request.

        :param params: server parameters; ``host`` and ``https`` are read here.
        """
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        # Normalized exactly like the per-request host so both compare directly.
        self.default_host: str = self._strip_port(params.host).lower()
        self.out_size: int = 0
        self.__host: str | None = None
        self.__in_size: int | None = None
        self.https: bool = params.https
        # Must come last: socketserver's handler __init__ serves the request.
        super().__init__(*args, **kwargs)

    @typing.override
    def send_error(
        self,
        code: int,
        message: str | None = None,
        explain: str | None = None,
    ) -> None:
        """Send an error response.

        A plain-text body (status, explanation, server signature) is only sent
        when the client accepts text and the response may carry a body;
        otherwise only the status line and headers are sent.

        :param code: HTTP status code.
        :param message: reason phrase; defaults to the standard one for ``code``.
        :param explain: longer explanation for the body; defaults to the standard one.
        """
        try:
            shortmsg, longmsg = self.responses[code]
        except KeyError:  # non-standard status code
            shortmsg, longmsg = "???", "???"
        if message is None:
            message = shortmsg
        if explain is None:
            explain = longmsg
        # The message becomes the HTTP reason phrase, which is written Latin-1
        # encoded on one line: drop line breaks (response splitting) and
        # unencodable characters (they would abort the response).
        message = (
            " ".join(message.splitlines())
            .encode("latin-1", errors="replace")
            .decode("latin-1")
        )
        # HEAD, 1xx, 204, 205 and 304 responses must never carry a body.
        bodiless = (
            getattr(self, "command", None) == "HEAD"
            or code < http.HTTPStatus.OK
            or code
            in (
                http.HTTPStatus.NO_CONTENT,
                http.HTTPStatus.RESET_CONTENT,
                http.HTTPStatus.NOT_MODIFIED,
            )
        )
        if not bodiless and "text/" in self._get_header("Accept"):
            self.send_basic_body(
                f"{code} {message}\n{explain}\n\n{self.server_signature()}",
                code=code,
                message=message,
            )
        else:
            self.send_status_only(code, message)

    @typing.override
    def log_message(self, format: str, *args: typing.Any) -> None:  # pragma: no cover
        """Route the base class's informational messages to ``self.logger``."""
        fmt = "%s - " + format
        self.logger.info(fmt, self.address_string(), *args)

    @typing.override
    def log_error(self, format: str, *args: typing.Any) -> None:  # pragma: no cover
        """Route the base class's error messages to ``self.logger``."""
        fmt = "%s - " + format
        self.logger.error(fmt, self.address_string(), *args)

    @typing.override
    def log_request(self, code: int | str = "?", size: str = "-") -> None:  # ty:ignore[invalid-method-override] # pragma: no cover
        """Log one access line: colored status, client, host, request line, size.

        :param code: status code; ints (including ``HTTPStatus``) are colored
            by class, anything else is logged as-is.
        :param size: response size; ``"-"`` (the base class default, meaning
            unknown) is replaced by ``self.out_size`` when that is known.
        """
        if isinstance(code, int):  # HTTPStatus, or a plain int from a proxied upstream
            status = int(code)
            if 100 <= status < 200:
                color = logs.TermColor.CYAN
            elif 200 <= status < 300:
                color = logs.TermColor.GREEN
            elif 300 <= status < 400:
                color = logs.TermColor.BLUE
            elif 400 <= status < 500:
                color = logs.TermColor.YELLOW
            else:
                color = logs.TermColor.RED
            code = color + str(status) + logs.TermColor.RESET
        if size == "-" and self.out_size > 0:
            size = str(self.out_size)
        args = (code, self.address_string(), self.host, self.requestline)
        fmt = "%s - %s - %s - %s"
        if size not in ("", "-"):
            args = (*args, size)
            fmt += " - %s"
        self.logger.info(fmt, *args)

    def send_basic_body(
        self,
        body: str,
        content_type: str = "text/plain",
        code: int = http.HTTPStatus.OK,
        message: str | None = None,
    ) -> None:
        """Send a complete response whose body is ``body`` encoded as UTF-8.

        The connection is closed after the response.
        """
        encoded: bytes = body.encode()
        # Set before send_response so log_request can report the size.
        self.out_size = len(encoded)
        self.send_response(code, message)
        self.send_header("Content-Type", f"{content_type}; charset=UTF-8")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)
        self.close_connection = True

    def send_status_only(
        self,
        code: int,
        message: str | None = None,
        headers: dict[str, str] | None = None,
    ) -> None:
        """Send a body-less response with optional extra ``headers``, then close."""
        if headers is None:
            headers = {}
        self.send_response(code, message)
        # RFC 9110 section 8.6: no Content-Length on 1xx/204, and on a 304 it
        # would have to equal the length of the 200 response, not 0.
        if code >= http.HTTPStatus.OK and code not in (
            http.HTTPStatus.NO_CONTENT,
            http.HTTPStatus.NOT_MODIFIED,
        ):
            self.send_header("Content-Length", "0")
        for header, value in headers.items():
            self.send_header(header, value)
        self.end_headers()
        self.close_connection = True

    def send_redirect(self, location: str) -> None:
        """Answer with a 301 Moved Permanently pointing to ``location``."""
        self.send_status_only(
            http.HTTPStatus.MOVED_PERMANENTLY,
            headers={"Location": location},
        )

    def send_proxy(self, url: str) -> None:
        """Forward the current request to ``url`` and relay the upstream response.

        The request is re-sent with the same method and body, its ``Host``
        set to the upstream authority, and the usual ``X-Real-IP`` /
        ``X-Forwarded-*`` headers. Hop-by-hop headers are dropped both ways.
        Occurrences of the upstream authority in response headers (e.g.
        ``Location``) are rewritten to the public host. Failures to reach
        the upstream are answered with 502 Bad Gateway.
        """
        target_host = urllib.parse.urlparse(url).netloc
        # Header names listed in "Connection" are hop-by-hop as well.
        dropped = self._HOP_BY_HOP_HEADERS | {
            token.strip().lower() for token in self._get_header("Connection").split(",")
        }
        # First value per header name, as dict(self.headers) gives; any
        # spelling of "Host" is removed so only the upstream one is sent.
        headers = {
            key: self.headers[key]
            for key in self.headers.keys()
            if key.lower() not in dropped and key.lower() != "host"
        }
        headers["Host"] = target_host
        headers["X-Real-IP"] = self.client_address[0]
        headers["X-Forwarded-Host"] = self.host
        headers["X-Forwarded-For"] = self.client_address[0]
        headers["X-Forwarded-Proto"] = "https" if self.https else "http"
        try:
            body: bytes | None = None
            if self.in_size > 0:
                body = self.rfile.read(self.in_size)
            response: requests.Response = requests.request(
                self.command,
                url,
                data=body,
                headers=headers,
                allow_redirects=False,
                timeout=480,
            )
        except Exception as e:
            self.send_error(
                http.HTTPStatus.BAD_GATEWAY, f"Could not reach {url}", explain=str(e)
            )
            return
        content = response.content
        # Set before send_response so log_request can report the size.
        self.out_size = len(content)
        self.send_response(response.status_code, response.reason)
        for header, value in response.headers.items():
            if header.lower() in self._PROXY_SKIPPED_RESPONSE_HEADERS:
                continue
            if target_host:  # replacing "" would splice the host between every character
                value = value.replace(target_host, self.host)
            self.send_header(header, value)
        self.send_header("Content-Length", str(len(content)))
        self.end_headers()
        if content:
            self.wfile.write(content)
        self.close_connection = True

    @property
    def host(self) -> str:
        """Requested host name: lower-cased, without port (cached per request)."""
        if self.__host is None:
            self.__host = self._get_host()
        return self.__host

    def _get_host(self) -> str:
        """Read the ``Host`` header, falling back to the default host."""
        host = self._get_header("Host", self.default_host)
        # Host names are case-insensitive; pages are registered lower-cased.
        return self._strip_port(host).lower()

    @staticmethod
    def _strip_port(host: str) -> str:
        """Return ``host`` without a trailing ``:port``.

        Bracketed IPv6 literals such as ``[::1]:8080`` keep their brackets.
        """
        host = host.strip()
        if host.startswith("["):
            end = host.find("]")
            return host if end == -1 else host[: end + 1]
        return host.split(":", maxsplit=1)[0]

    @property
    def in_size(self) -> int:
        """Declared request body size in bytes (cached per request)."""
        if self.__in_size is None:
            self.__in_size = self._get_length()
        return self.__in_size

    def _get_length(self) -> int:
        """Parse ``Content-Length``; a missing, malformed or negative value counts as 0."""
        try:
            length = int(self._get_header("Content-Length", "0"))
        except ValueError:
            return 0
        # A negative size would make rfile.read(n) read until EOF, bypassing
        # every size limit enforced by the callers.
        return max(length, 0)

    def _get_header(self, key: str, default_value: str = "") -> str:
        """Return the request header ``key``, or ``default_value`` if missing/empty."""
        if self._has_header(key):
            return self.headers[key]
        return default_value

    def _has_header(self, key: str) -> bool:
        """Tell whether the header ``key`` is present and non-empty.

        ``self.headers`` does not exist yet when an error occurs while the
        request line itself is being parsed.
        """
        return (
            hasattr(self, "headers")
            and key in self.headers
            and len(self.headers[key]) > 0
        )

    def _pre_log_request(self) -> None:  # pragma: no cover
        """Log a debug line when a request starts being handled."""
        args = ("...", self.address_string(), self.host, self.requestline)
        fmt = "%s - %s - %s - %s"
        if self.in_size > 0:
            args = (*args, self.in_size)
            fmt += " - %s"
        self.logger.debug(fmt, *args)

    def server_signature(self) -> str:
        """Return the server version followed by the Stapler ASCII banner."""
        return self.server_version + "\n\n" + STAPLER_ASCII + "\n"
class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
    """Main Stapler handler: serves hosted pages and manages them via a token API.

    GET/HEAD serve files from the data directory. On the default host a page
    is addressed as ``/<page>/...``; on a page's own host the page is the
    root. Pages configured as a redirect or proxy are answered as such.
    PUT (and POST/PATCH as aliases) with a valid ``X-Token`` uploads a tar
    archive, or configures a redirect (``X-Redirect``) or proxy (``X-Proxy``),
    for ``/<page>/``. It can optionally bind the page to the host named in
    ``X-Host``. DELETE removes a page.
    """

    protocol_version = "HTTP/1.1"
    server_version = "StaplerServer/" + project.get_version()
    CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
    # "/<page>" or "/<page>/": a single segment of word characters or hyphens.
    UPDATE_PATH_REGEX = re.compile(r"^\/([\w-]+)\/?$")
    # First path segment ("/<page>" or "/<page>/...").
    GET_PATH_REGEX = re.compile(r"^\/([\w-]+)($|\/)")
    # One DNS label: 1-63 lowercase alphanumerics/hyphens, no edge hyphen.
    HOST_PART_REGEX = re.compile(r"^([a-z0-9]|[a-z0-9][a-z0-9-]{,61}[a-z0-9])$")
    AUTHORIZED_PATHS: typing.ClassVar[list[str]] = ["/favicon.ico"]
    TOKEN_HEADER = "X-Token"  # noqa: S105
    HOST_HEADER = "X-Host"
    REDIRECT_HEADER = "X-Redirect"
    PROXY_HEADER = "X-Proxy"

    @typing.override
    def __init__(
        self,
        *args: typing.Any,
        params: Parameters,
        registry: Registry,
        cert_manager: CertManager,
        token_manager: TokenManager,
        **kwargs: typing.Any,
    ) -> None:
        """Store collaborators, then let the base classes serve the request."""
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self.token_manager: TokenManager = token_manager
        self.data_dir: DataDir = DataDir(params.data_dir)
        self.max_size_bytes: int = params.max_size_bytes
        self.registry: Registry = registry
        self.cert_manager: CertManager = cert_manager
        # Resolved once so translate_path can confine challenge files to it.
        self.certbot_www: str = os.path.realpath(params.certbot_www)
        # Per-request caches for the header-backed properties below.
        self.__token: str | None = None
        self.__target_host: str | None = None
        self.__target_redirect: str | None = None
        self.__target_proxy: str | None = None
        # Must come last: the base initialiser serves the request immediately.
        super().__init__(*args, directory=params.data_dir, **kwargs, params=params)  # ty:ignore[unknown-argument]

    @property
    def token(self) -> str:
        """Value of the ``X-Token`` header ("" when absent)."""
        if self.__token is None:
            self.__token = self._get_header(self.TOKEN_HEADER)
        return self.__token

    @property
    def has_token(self) -> bool:
        """Tell whether the request carries a non-empty token."""
        return len(self.token) > 0

    @property
    def target_host(self) -> str:
        """Lower-cased ``X-Host`` header: the host to bind a page to ("" if absent)."""
        if self.__target_host is None:
            self.__target_host = self._get_header(self.HOST_HEADER).lower()
        return self.__target_host

    @property
    def has_target_host(self) -> bool:
        """Tell whether the request asks to bind a host."""
        return len(self.target_host) > 0

    @property
    def target_redirect(self) -> str:
        """``X-Redirect`` URL with scheme/authority lower-cased ("" if absent)."""
        if self.__target_redirect is None:
            self.__target_redirect = self.__normalize_url(
                self._get_header(self.REDIRECT_HEADER)
            )
        return self.__target_redirect

    @property
    def has_target_redirect(self) -> bool:
        """Tell whether the request configures a redirect."""
        return len(self.target_redirect) > 0

    @property
    def target_proxy(self) -> str:
        """``X-Proxy`` URL with scheme/authority lower-cased ("" if absent)."""
        if self.__target_proxy is None:
            self.__target_proxy = self.__normalize_url(
                self._get_header(self.PROXY_HEADER)
            )
        return self.__target_proxy

    @property
    def has_target_proxy(self) -> bool:
        """Tell whether the request configures a proxy."""
        return len(self.target_proxy) > 0

    @typing.override
    def do_HEAD(self) -> None:
        """Redirect/proxy configured pages, otherwise serve static headers."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            super().do_HEAD()

    @typing.override
    def do_GET(self) -> None:
        """Redirect/proxy configured pages, show the signature on the bare root, else serve files."""
        self._pre_log_request()
        if self._proxy_or_redirect():
            return None
        if self.path == "/" and self.host == self.default_host:
            return self.send_basic_body(self.server_signature())
        return super().do_GET()

    def do_PUT(self) -> None:
        """Create or update the page addressed by the request path.

        Exactly one of: tar archive body, ``X-Redirect`` or ``X-Proxy``.
        When ``X-Host`` is given it is validated, must not belong to another
        page, and is bound to the page only after the update succeeded and
        the certificate manager accepted it.
        """
        self._pre_log_request()
        if self._proxy_or_redirect():
            return None
        if (path := self.__check_update_request()) is None:
            return None
        if self.has_target_host and not self.__valid_host(self.target_host):
            return self.send_error(
                http.HTTPStatus.BAD_REQUEST, "Invalid requested host"
            )
        if (
            self.has_target_host
            and (page := self.registry.get_from_host(self.target_host)) is not None
            and page.path != path
        ):
            return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken")
        if self.has_target_proxy and self.has_target_redirect:
            return self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"Cannot use {self.PROXY_HEADER} with {self.REDIRECT_HEADER}",
            )
        if self.has_target_redirect:
            updated = self._update_redirect(path)
        elif self.has_target_proxy:
            updated = self._update_proxy(path)
        else:
            updated = self._update_extract(path)
        # A failed update already got its error response: do not request a
        # certificate for it nor bind the host to a page that was not stored.
        if (
            updated
            and self.has_target_host
            and self.cert_manager.create_or_update(self.target_host)
        ):
            self.registry.set_host(path, self.target_host)
        return None

    def do_POST(self) -> None:
        """Alias of PUT."""
        self.do_PUT()  # be gentle on them

    def do_PATCH(self) -> None:
        """Alias of PUT."""
        self.do_PUT()  # be gentle on them

    def do_DELETE(self) -> None:
        """Remove the page addressed by the request path (token required)."""
        self._pre_log_request()
        if self._proxy_or_redirect():
            return None
        if (path := self.__check_update_request()) is None:
            return None
        return self._update_remove(path)

    def do_CONNECT(self) -> None:
        """Only meaningful for proxied pages; 405 otherwise."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def do_OPTIONS(self) -> None:
        """Only meaningful for proxied pages; 405 otherwise."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def do_TRACE(self) -> None:
        """Only meaningful for proxied pages; 405 otherwise."""
        self._pre_log_request()
        if not self._proxy_or_redirect():
            self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)

    def _update_extract(self, path: str) -> bool:
        """Extract the tar archive in the request body as page ``path``.

        :return: True on success; False once an error response has been sent.
        """
        if self.in_size == 0:
            self.send_error(http.HTTPStatus.LENGTH_REQUIRED, "No body found")
            return False
        if self.in_size > self.max_size_bytes:
            self.send_error(
                http.HTTPStatus.CONTENT_TOO_LARGE,
                "Archive too large",
            )
            return False
        try:
            file_bytes = io.BytesIO(self.rfile.read(self.in_size))
            self.data_dir.extract_tar_bytes(path, file_bytes)
        except tarfile.TarError:
            self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid tar archive")
            return False
        except Exception as e:
            self.logger.exception("Could not extract archive for /%s/", path)
            self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
            return False
        self.registry.add(path)
        self.token_manager.set_token(path, self.token)
        self.send_status_only(
            http.HTTPStatus.CREATED,
            f"Resource /{path}/ updated",
        )
        return True

    def _update_redirect(self, path: str) -> bool:
        """Make page ``path`` redirect to ``X-Redirect`` (no body allowed).

        :return: True on success; False once an error response has been sent.
        """
        if self.in_size > 0:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"No content must be sent with {self.REDIRECT_HEADER}",
            )
            return False
        self.registry.set_redirect(path, self.target_redirect)
        self.token_manager.set_token(path, self.token)
        self.send_status_only(
            http.HTTPStatus.CREATED,
            f"Resource /{path}/ updated",
        )
        return True

    def _update_proxy(self, path: str) -> bool:
        """Make page ``path`` proxy to ``X-Proxy`` (no body allowed).

        :return: True on success; False once an error response has been sent.
        """
        if self.in_size > 0:
            self.send_error(
                http.HTTPStatus.BAD_REQUEST,
                f"No content must be sent with {self.PROXY_HEADER}",
            )
            return False
        self.registry.set_proxy(path, self.target_proxy)
        self.token_manager.set_token(path, self.token)
        self.send_status_only(
            http.HTTPStatus.CREATED,
            f"Resource /{path}/ updated",
        )
        return True

    def _update_remove(self, path: str) -> None:
        """Delete page ``path``: its files (if any) and its registry entry."""
        has_files = self.data_dir.exists(path)
        # _update_redirect/_update_proxy store pages through the registry only
        # and never touch data_dir, so such pages may have no data folder.
        if not has_files and self.registry.get_from_path(path) is None:
            self.send_error(http.HTTPStatus.NOT_FOUND, "Not found")
            return None
        if has_files:
            try:
                self.data_dir.remove(path)
            except Exception as e:
                self.logger.exception("Could not remove /%s/", path)
                return self.send_error(http.HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
        self.send_status_only(
            http.HTTPStatus.NO_CONTENT,
            f"Resource /{path}/ removed",
        )
        self.registry.remove(path)
        return None

    def _proxy_or_redirect(self) -> bool:
        """Answer the request as a redirect or proxy when its page is configured so.

        Token-bearing (management) requests and ACME challenges are never
        redirected or proxied.

        :return: True when a response was sent, False when the caller must
            handle the request itself.
        """
        if self.has_token or self.path.startswith(self.CERTBOT_CHALLENGE_PATH):
            return False
        if (page := self.__get_page(self.path)) is None:
            return False
        if page.redirect is not None:
            self.send_redirect(page.redirect)
            return True
        if page.proxy is not None:
            if self.host == self.default_host:
                # The page was looked up on the lower-cased path, so its
                # "/<page>" prefix is stripped case-insensitively too.
                prefix = f"/{page.path}"
                if self.path[: len(prefix)].lower() == prefix.lower():
                    self.send_proxy(page.proxy + self.path[len(prefix) :])
                else:
                    self.send_proxy(page.proxy + self.path)
            else:
                self.send_proxy(page.proxy + self.path)
            return True
        return False

    @typing.override
    def list_directory(self, *_: typing.Any, **__: typing.Any) -> None:
        """Disable default directory listing."""
        self.send_error(http.HTTPStatus.NOT_FOUND, "File not found")

    @typing.override
    def translate_path(self, path: str) -> str:
        """Map a request path to a filesystem path, or "" to refuse it (404).

        ACME challenge paths resolve inside ``certbot_www`` only. Page paths
        resolve inside the data directory, except any path with a dot-file or
        dot-directory segment (checked after percent-decoding).
        """
        if path.startswith(self.CERTBOT_CHALLENGE_PATH):
            return self.__certbot_file(path)
        page = self.__get_page(path)
        if page is None:
            if path in self.AUTHORIZED_PATHS:
                return super().translate_path(path)
            return ""
        if self.host != self.default_host:
            path = f"/{page.path}" + path
        if self.__is_hidden(path):
            return ""
        return super().translate_path(path)

    @staticmethod
    def __url_path(path: str) -> str:
        """Return the percent-decoded path of a request target, without query/fragment."""
        return urllib.parse.unquote(path.split("?", 1)[0].split("#", 1)[0])

    def __is_hidden(self, path: str) -> bool:
        """Tell whether any decoded segment of ``path`` is a dot-file or dot-directory.

        "." and ".." are not hidden names; the base translate_path drops them.
        """
        return any(
            segment.startswith(".") and segment not in (".", "..")
            for segment in self.__url_path(path).split("/")
        )

    def __certbot_file(self, path: str) -> str:
        """Resolve an ACME challenge path inside ``certbot_www``, or return "".

        The decoded path is resolved (".." and symlinks included) and refused
        if the result escapes the certbot webroot.
        """
        try:
            candidate = os.path.realpath(self.certbot_www + self.__url_path(path))
            inside = os.path.commonpath([self.certbot_www, candidate]) == self.certbot_www
        except ValueError:  # embedded NUL byte, or paths on different drives
            return ""
        return candidate if inside else ""

    @staticmethod
    def __normalize_url(url: str) -> str:
        """Lower-case the scheme and authority of ``url``, keeping the rest as sent.

        URL paths, queries and fragments are case-sensitive, so lower-casing
        the whole value could point a redirect/proxy elsewhere. Values that
        ``urlsplit`` rejects fall back to being fully lower-cased.
        """
        try:
            parts = urllib.parse.urlsplit(url)
        except ValueError:
            return url.lower()
        return parts._replace(
            scheme=parts.scheme.lower(), netloc=parts.netloc.lower()
        ).geturl()

    def __check_update_request(self) -> str | None:
        """Validate token and path of a management request.

        :return: the lower-cased page name, or None once an error response
            (400 missing token / bad path, 401 invalid token, 403 path not
            allowed for the token) has been sent.
        """
        if not self._has_header(self.TOKEN_HEADER):
            self.send_error(
                http.HTTPStatus.BAD_REQUEST, f"No {self.TOKEN_HEADER} header in request"
            )
            return None
        if not self.token_manager.is_valid(self.token):
            self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
            return None
        if (sub_path := self.__get_path(self.path, self.UPDATE_PATH_REGEX)) is None:
            self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
            return None
        if not self.token_manager.is_valid_for_path(self.token, sub_path):
            self.send_error(http.HTTPStatus.FORBIDDEN, "Path forbidden for this token")
            return None
        return sub_path

    def __get_path(self, path: str, regex: re.Pattern) -> str | None:
        """Return the first group of ``regex`` matched on the lower-cased ``path``."""
        if (match := regex.match(path.lower())) is not None:
            return match.group(1)
        return None

    def __valid_host(self, host: str) -> bool:
        """Tell whether ``host`` is made of valid DNS labels and under 256 chars."""
        return (
            all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split("."))
            and len(host) < 256
        )

    def __get_page(self, src_path: str) -> Page | None:
        """Find the page for a request: by first path segment on the default host, else by host."""
        if self.host == self.default_host:
            if path := self.__get_path(src_path, self.GET_PATH_REGEX):
                return self.registry.get_from_path(path)
            return None
        return self.registry.get_from_host(self.host)
class UpgradeHandler(BaseHandler):
    """Plain-HTTP handler that sends GET/HEAD requests to their HTTPS equivalent."""

    server_version = "StaplerUpgradeServer/" + project.get_version()

    def do_HEAD(self) -> None:
        """Reply with a permanent redirect to the same host and path over HTTPS."""
        self._pre_log_request()
        secure_url = "https://" + self.host + self.path
        self.send_redirect(secure_url)

    def do_GET(self) -> None:
        """Handle GET exactly like HEAD: the answer is a body-less redirect."""
        self.do_HEAD()