Files
stapler/stapler/server.py
T

166 lines
5.3 KiB
Python

import contextlib
import http.server
import logging
import ssl
import threading
import time
import typing
from . import (
STAPLER_ASCII,
)
from .cert_manager import CertManager
from .data_dir import DataDir
from .handlers import RequestHandler, UpgradeHandler
from .params import Parameters
from .registry import Registry
from .token_manager import TokenManager
if typing.TYPE_CHECKING:
from .params import Parameters
class StaplerServer:
    """Top-level server object wiring together registry, certificates and tokens.

    The three public entry points (`run`, `renew`, `token`) each return a
    process exit code.
    """

    __slots__ = [
        "cert_manager",
        "data_dir",
        "default_host",
        "https",
        "logger",
        "params",
        "registry",
        "server",
        "token_manager",
    ]

    def __init__(self, params: Parameters) -> None:
        """Build the collaborating managers from *params*; no I/O is started here."""
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self.params: Parameters = params
        self.registry: Registry = Registry(params)
        self.cert_manager: CertManager = CertManager(params)
        self.token_manager: TokenManager = TokenManager(params, self.registry)
        self.data_dir: DataDir = DataDir(params.data_dir)
        # Host name only: everything before the first ":" of ``params.host``.
        self.default_host: str = params.host.partition(":")[0]
        # Set by ``run()`` once the main listening server exists.
        self.server: http.server.ThreadingHTTPServer | None = None

    def __get_all_hosts(self) -> list[str]:
        """Return the default host followed by every host known to the registry."""
        return [self.default_host, *self.registry.get_hosts()]

    def __startup(self) -> None:
        """Load pages and initialise the managers needed before serving."""
        self.logger.info("Starting up...")
        self.registry.load_pages()
        if self.params.with_certificates:
            self.cert_manager.init(self.__get_all_hosts())
        self.data_dir.init()
        self.token_manager.init()

    def __request_handler(
        self, *args: typing.Any
    ) -> http.server.BaseHTTPRequestHandler:
        """Handler factory passed to the main server; forwards *args* unchanged."""
        return RequestHandler(
            *args,
            params=self.params,
            registry=self.registry,
            token_manager=self.token_manager,
        )

    def __create_base_server(self) -> http.server.ThreadingHTTPServer:
        """Create (but do not start) the main server.

        Binds to ``params.https_port`` with a TLS-wrapped socket when
        ``params.https`` is set, otherwise to ``params.http_port``.
        """
        use_tls = self.params.https
        port = self.params.https_port if use_tls else self.params.http_port
        server = http.server.ThreadingHTTPServer(
            (self.params.bind, port),
            self.__request_handler,
        )
        if use_tls:
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            # Configure the context completely before wrapping the listening
            # socket, so the SNI callback is already registered on it.
            context.sni_callback = self.cert_manager.sni_callback
            server.socket = context.wrap_socket(server.socket, server_side=True)
        self.logger.info(
            "Server listening on %s:%d...",
            server.server_address[0],
            server.server_port,
        )
        return server

    def __upgrade_handler(
        self, *args: typing.Any
    ) -> http.server.BaseHTTPRequestHandler:
        """Handler factory for the plain-HTTP companion server.

        NOTE(review): presumably ``UpgradeHandler`` redirects clients to
        HTTPS - confirm in ``handlers``.
        """
        return UpgradeHandler(
            *args,
            params=self.params,
            registry=self.registry,
            token_manager=self.token_manager,
        )

    def __start_upgrade_server(self) -> http.server.ThreadingHTTPServer:
        """Bind the upgrade server on ``params.http_port`` and serve it in a daemon thread."""
        server = http.server.ThreadingHTTPServer(
            (self.params.bind, self.params.http_port),
            self.__upgrade_handler,
        )
        self.logger.info(
            "Upgrade server listening on %s:%d...",
            server.server_address[0],
            server.server_port,
        )
        threading.Thread(target=server.serve_forever, daemon=True).start()
        return server

    def __token_manager_background(self) -> None:
        """Poll once per second and re-initialise the token manager on file change."""
        with contextlib.suppress(KeyboardInterrupt):
            while True:
                if self.token_manager.detect_file_change():
                    self.token_manager.init()
                time.sleep(1)

    def __start_background_tasks(self) -> None:
        """Start the token-file watcher in a daemon thread."""
        threading.Thread(target=self.__token_manager_background, daemon=True).start()

    def run(self) -> int:
        """Start all servers and block until interrupted.

        Returns:
            ``0`` after a clean shutdown (including Ctrl-C).
        """
        for line in STAPLER_ASCII.split("\n"):
            self.logger.debug(line.ljust(36))
        self.__startup()

        server = self.__create_base_server()
        self.server = server
        upgrade_server: http.server.ThreadingHTTPServer | None = None
        try:
            if self.params.https:
                upgrade_server = self.__start_upgrade_server()
            self.logger.info(
                "Server up and ready on %s://%s",
                "https" if self.params.https else "http",
                self.params.host,
            )
            self.__start_background_tasks()
            with contextlib.suppress(KeyboardInterrupt):
                server.serve_forever()
            self.logger.info("Shutting down...")
        finally:
            # Release the listening sockets on every exit path, including a
            # failure to bind the upgrade server after the main one is bound.
            if upgrade_server is not None:
                upgrade_server.shutdown()
                upgrade_server.server_close()
            server.server_close()
        return 0

    def renew(self) -> int:
        """Create or update the certificate of every known host.

        Returns:
            ``0`` on success, ``1`` when certificates are disabled in params.
        """
        self.logger.info("Starting up...")
        if not self.params.with_certificates:
            self.logger.warning("Cannot renew without certificates")
            return 1
        self.registry.load_pages()
        self.cert_manager.init(self.__get_all_hosts())
        for host in self.__get_all_hosts():
            self.cert_manager.create_or_update(host)
        return 0

    def token(self) -> int:
        """Load pages, initialise the token manager and issue a new token.

        Returns:
            Always ``0``.
        """
        self.logger.info("Starting up...")
        self.registry.load_pages()
        self.token_manager.init()
        self.token_manager.new_token()
        return 0