166 lines
5.3 KiB
Python
166 lines
5.3 KiB
Python
import contextlib
|
|
import http.server
|
|
import logging
|
|
import ssl
|
|
import threading
|
|
import time
|
|
import typing
|
|
|
|
from . import (
|
|
STAPLER_ASCII,
|
|
)
|
|
from .cert_manager import CertManager
|
|
from .data_dir import DataDir
|
|
from .handlers import RequestHandler, UpgradeHandler
|
|
from .params import Parameters
|
|
from .registry import Registry
|
|
from .token_manager import TokenManager
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from .params import Parameters
|
|
|
|
|
|
class StaplerServer:
    """Top-level server object tying together registry, certificates and tokens.

    The three public entry points (`run`, `renew` and `token`) each return an
    integer status (0 on success; `renew` returns 1 when certificates are
    disabled).
    """

    # NOTE(review): "https" is declared as a slot but never assigned anywhere
    # in this class; it is kept so the attribute layout stays unchanged.
    __slots__ = [
        "cert_manager",
        "data_dir",
        "default_host",
        "https",
        "logger",
        "params",
        "registry",
        "server",
        "token_manager",
    ]

    def __init__(self, params: "Parameters") -> None:
        """Build all collaborators from *params*; no sockets are opened yet.

        Args:
            params: Parsed runtime parameters shared with every collaborator.
        """
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self.params: Parameters = params
        self.registry: Registry = Registry(params)
        self.cert_manager: CertManager = CertManager(params)
        self.token_manager: TokenManager = TokenManager(params, self.registry)
        self.data_dir: DataDir = DataDir(params.data_dir)
        # Keep only the part of ``host`` before the first ":" (drops a port
        # suffix such as "example.org:8443").
        self.default_host: str = params.host.partition(":")[0]
        # Populated by run(); stays None until the server is started.
        self.server: http.server.ThreadingHTTPServer | None = None

    def __get_all_hosts(self) -> list[str]:
        """Return the default host followed by every host known to the registry."""
        return [self.default_host, *self.registry.get_hosts()]

    def __startup(self) -> None:
        """Load pages and initialise certificates (if enabled), data dir and tokens."""
        self.logger.info("Starting up...")
        self.registry.load_pages()
        if self.params.with_certificates:
            self.cert_manager.init(self.__get_all_hosts())
        self.data_dir.init()
        self.token_manager.init()

    def __request_handler(
        self, *args: typing.Any
    ) -> http.server.BaseHTTPRequestHandler:
        """Handler factory given to the main server in place of a handler class.

        The server's positional arguments are forwarded untouched; the shared
        params, registry and token manager are injected as keyword arguments.
        """
        return RequestHandler(
            *args,
            params=self.params,
            registry=self.registry,
            token_manager=self.token_manager,
        )

    def __create_base_server(self) -> http.server.ThreadingHTTPServer:
        """Create the main server, TLS-wrapped when ``params.https`` is set.

        Returns:
            A bound ``ThreadingHTTPServer`` that has not started serving yet.

        Raises:
            Exception: Any error raised while setting up TLS is re-raised after
                the already-bound listening socket has been closed.
        """
        use_tls = self.params.https
        port = self.params.https_port if use_tls else self.params.http_port
        server = http.server.ThreadingHTTPServer(
            (self.params.bind, port),
            self.__request_handler,
        )
        if use_tls:
            try:
                context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                # Configure the context completely before wrapping so the
                # listening socket never exists without the SNI hook.
                context.sni_callback = self.cert_manager.sni_callback
                server.socket = context.wrap_socket(server.socket, server_side=True)
            except Exception:
                # Do not leak the bound socket when TLS setup fails.
                server.server_close()
                raise
        self.logger.info(
            "Server listening on %s:%d...",
            server.server_address[0],
            server.server_port,
        )
        return server

    def __upgrade_handler(
        self, *args: typing.Any
    ) -> http.server.BaseHTTPRequestHandler:
        """Handler factory for the plain-HTTP upgrade server.

        Mirrors ``__request_handler`` but instantiates ``UpgradeHandler``.
        """
        return UpgradeHandler(
            *args,
            params=self.params,
            registry=self.registry,
            token_manager=self.token_manager,
        )

    def __start_upgrade_server(self) -> http.server.ThreadingHTTPServer:
        """Start a second server on ``params.http_port`` in a daemon thread.

        NOTE(review): presumably this redirects plain-HTTP clients to HTTPS;
        the actual behaviour lives in ``UpgradeHandler`` -- confirm there.

        Returns:
            The running server, so the caller can shut it down later.
        """
        server = http.server.ThreadingHTTPServer(
            (self.params.bind, self.params.http_port),
            self.__upgrade_handler,
        )
        self.logger.info(
            "Upgrade server listening on %s:%d...",
            server.server_address[0],
            server.server_port,
        )
        threading.Thread(target=server.serve_forever, daemon=True).start()
        return server

    def __token_manager_background(self) -> None:
        """Poll the token manager once per second and re-init it on file change."""
        with contextlib.suppress(KeyboardInterrupt):
            while True:
                if self.token_manager.detect_file_change():
                    self.token_manager.init()
                time.sleep(1)

    def __start_background_tasks(self) -> None:
        """Launch the token-watching loop in a daemon thread."""
        threading.Thread(target=self.__token_manager_background, daemon=True).start()

    def run(self) -> int:
        """Start the server(s) and block until interrupted.

        A ``KeyboardInterrupt`` while serving is treated as a normal shutdown
        request. Whatever way serving ends, the upgrade server (if any) is
        stopped and both listening sockets are closed.

        Returns:
            0 after a clean shutdown.
        """
        for line in STAPLER_ASCII.split("\n"):
            self.logger.debug(line.ljust(36))
        self.__startup()
        self.server = self.__create_base_server()
        upgrade_server: http.server.ThreadingHTTPServer | None = None
        try:
            if self.params.https:
                upgrade_server = self.__start_upgrade_server()
            self.logger.info(
                "Server up and ready on %s://%s",
                "https" if self.params.https else "http",
                self.params.host,
            )
            self.__start_background_tasks()
            with contextlib.suppress(KeyboardInterrupt):
                self.server.serve_forever()
            self.logger.info("Shutting down...")
        finally:
            # Release sockets even when startup or serving fails unexpectedly.
            if upgrade_server is not None:
                upgrade_server.shutdown()
                upgrade_server.server_close()
            self.server.server_close()
        return 0

    def renew(self) -> int:
        """Create or update the certificate of every known host.

        Returns:
            1 when ``params.with_certificates`` is false, otherwise 0.
        """
        self.logger.info("Starting up...")
        if not self.params.with_certificates:
            self.logger.warning("Cannot renew without certificates")
            return 1
        self.registry.load_pages()
        self.cert_manager.init(self.__get_all_hosts())
        for host in self.__get_all_hosts():
            self.cert_manager.create_or_update(host)
        return 0

    def token(self) -> int:
        """Load pages, initialise the token manager and request a new token.

        Returns:
            Always 0.
        """
        self.logger.info("Starting up...")
        self.registry.load_pages()
        self.token_manager.init()
        self.token_manager.new_token()
        return 0
|