feat: multiple hosts for same http server
This commit is contained in:
@@ -62,7 +62,7 @@ docker-build: ## docker build
|
|||||||
|
|
||||||
.PHONY: docker-run
|
.PHONY: docker-run
|
||||||
docker-run: docker-build ## docker run
|
docker-run: docker-build ## docker run
|
||||||
@$(DOCKER) run -it -p $(PORT):8080 -v ./data:/data $(DOCKER_TAG) --token $(TOKEN) --host localhost:$(PORT) --debug
|
@$(DOCKER) run -it -p $(PORT):8080 -v ./data:/data $(DOCKER_TAG) --token $(TOKEN) --debug
|
||||||
|
|
||||||
# ACTIONS
|
# ACTIONS
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ curl -X DELETE \
|
|||||||
- [x] certbot/self-signed create/renew in specific dir
|
- [x] certbot/self-signed create/renew in specific dir
|
||||||
- [x] better logger
|
- [x] better logger
|
||||||
- [ ] renew command
|
- [ ] renew command
|
||||||
- [ ] https mode w/ multiple hosts
|
- [x] https mode w/ multiple hosts
|
||||||
- [ ] restart command (on new/deleted host)
|
- [ ] restart command (on new/deleted host)
|
||||||
- [ ] proper doc
|
- [ ] proper doc
|
||||||
- [ ] log visits (and store accross sessions)
|
- [ ] log visits (and store accross sessions)
|
||||||
|
|||||||
+41
-12
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
|
import ssl
|
||||||
import subprocess
|
import subprocess
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
@@ -49,19 +50,21 @@ class CertManager:
|
|||||||
return True
|
return True
|
||||||
return created or self.__create_self_signed(host)
|
return created or self.__create_self_signed(host)
|
||||||
|
|
||||||
def get_pem(self, host: str) -> pathlib.Path | None:
|
def get_cert(self, host: str) -> pathlib.Path:
|
||||||
if self.__exists_certbot(host):
|
if self.__exists_certbot(host):
|
||||||
return self.__certbot_file(host, self.CRT_FILE)
|
return self.__certbot_file(host, self.CRT_FILE)
|
||||||
if self.__exists_self_signed(host):
|
if self.__exists_self_signed(host):
|
||||||
return self.__self_signed_file(host, self.CRT_FILE)
|
return self.__self_signed_file(host, self.CRT_FILE)
|
||||||
return None
|
msg = "Cannot get cert file for %s"
|
||||||
|
raise CertManagerError(msg, host)
|
||||||
|
|
||||||
def get_key(self, host: str) -> pathlib.Path | None:
|
def get_key(self, host: str) -> pathlib.Path:
|
||||||
if self.__exists_certbot(host):
|
if self.__exists_certbot(host):
|
||||||
return self.__certbot_file(host, self.KEY_FILE)
|
return self.__certbot_file(host, self.KEY_FILE)
|
||||||
if self.__exists_self_signed(host):
|
if self.__exists_self_signed(host):
|
||||||
return self.__self_signed_file(host, self.KEY_FILE)
|
return self.__self_signed_file(host, self.KEY_FILE)
|
||||||
return None
|
msg = "Cannot get key file for %s"
|
||||||
|
raise CertManagerError(msg, host)
|
||||||
|
|
||||||
def __self_signed_file(self, host: str, file: str) -> pathlib.Path:
|
def __self_signed_file(self, host: str, file: str) -> pathlib.Path:
|
||||||
return self.self_signed_path / host / file
|
return self.self_signed_path / host / file
|
||||||
@@ -83,9 +86,6 @@ class CertManager:
|
|||||||
cert_path = self.self_signed_path / host
|
cert_path = self.self_signed_path / host
|
||||||
if not cert_path.exists():
|
if not cert_path.exists():
|
||||||
cert_path.mkdir(parents=True)
|
cert_path.mkdir(parents=True)
|
||||||
cert_host: str = host
|
|
||||||
if ":" in host:
|
|
||||||
cert_host = host.split(":", maxsplit=2)[0]
|
|
||||||
try:
|
try:
|
||||||
# openssl req -new -newkey rsa:2048 -days 30 -nodes -x509 -keyout server.key -out server.crt
|
# openssl req -new -newkey rsa:2048 -days 30 -nodes -x509 -keyout server.key -out server.crt
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
@@ -104,7 +104,7 @@ class CertManager:
|
|||||||
"-out",
|
"-out",
|
||||||
cert_path / "fullchain.pem",
|
cert_path / "fullchain.pem",
|
||||||
"-subj",
|
"-subj",
|
||||||
f"/C=/ST=/L=/O=/OU=/CN={cert_host}",
|
f"/C=/ST=/L=/O=/OU=/CN={host}",
|
||||||
],
|
],
|
||||||
check=True,
|
check=True,
|
||||||
)
|
)
|
||||||
@@ -134,9 +134,6 @@ class CertManager:
|
|||||||
return binary_path
|
return binary_path
|
||||||
|
|
||||||
def __create_certbot(self, host: str) -> bool:
|
def __create_certbot(self, host: str) -> bool:
|
||||||
cert_host: str = host
|
|
||||||
if ":" in host:
|
|
||||||
cert_host = host.split(":", maxsplit=2)[0]
|
|
||||||
try:
|
try:
|
||||||
# certonly -v --webroot --webroot-path=/var/www/certbot --agree-tos --no-eff-email -n --force-renewal --expand
|
# certonly -v --webroot --webroot-path=/var/www/certbot --agree-tos --no-eff-email -n --force-renewal --expand
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
@@ -151,7 +148,7 @@ class CertManager:
|
|||||||
"--cert-name",
|
"--cert-name",
|
||||||
host,
|
host,
|
||||||
"--domain",
|
"--domain",
|
||||||
cert_host,
|
host,
|
||||||
],
|
],
|
||||||
check=True,
|
check=True,
|
||||||
)
|
)
|
||||||
@@ -160,3 +157,35 @@ class CertManager:
|
|||||||
self.logger.exception("Could not create certbot certificate for %s", host)
|
self.logger.exception("Could not create certbot certificate for %s", host)
|
||||||
return False
|
return False
|
||||||
return self.__exists_certbot(host)
|
return self.__exists_certbot(host)
|
||||||
|
|
||||||
|
def get_https_context(self, default_host: str) -> ssl.SSLContext | None:
|
||||||
|
if not self.exists(default_host):
|
||||||
|
self.logger.warning("Cannot create HTTPS context for %s", default_host)
|
||||||
|
return None
|
||||||
|
cert_file = self.get_cert(default_host)
|
||||||
|
key_file = self.get_key(default_host)
|
||||||
|
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
|
context.load_cert_chain(
|
||||||
|
cert_file,
|
||||||
|
key_file,
|
||||||
|
)
|
||||||
|
context.sni_callback = self.__sni_callback
|
||||||
|
return context
|
||||||
|
|
||||||
|
def __sni_callback(
|
||||||
|
self, socket: ssl.SSLObject, host: str, _: ssl.SSLContext, /
|
||||||
|
) -> None | int:
|
||||||
|
if host is None:
|
||||||
|
return
|
||||||
|
if not self.exists(host) and not self.create_or_update(host):
|
||||||
|
msg = "Could not get certificate for %s"
|
||||||
|
raise CertManagerError(msg, host)
|
||||||
|
new_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
|
cert_file = self.get_cert(host)
|
||||||
|
key_file = self.get_key(host)
|
||||||
|
new_context.load_cert_chain(
|
||||||
|
cert_file,
|
||||||
|
key_file,
|
||||||
|
)
|
||||||
|
socket.context = new_context
|
||||||
|
return None
|
||||||
|
|||||||
+10
-16
@@ -20,7 +20,6 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
|
CERTBOT_CHALLENGE_PATH = "/.well-known/acme-challenge"
|
||||||
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
|
PATH_REGEX = re.compile(r"^\/([\w-]+)\/")
|
||||||
HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$")
|
HOST_PART_REGEX = re.compile(r"^([a-zA-Z0-9]|[a-zA-Z0-9]*[a-zA-Z0-9][a-zA-Z0-9])$")
|
||||||
HOST_PORT_REGEX = re.compile(r"^\d{2,5}$")
|
|
||||||
|
|
||||||
@typing.override
|
@typing.override
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -31,7 +30,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
**kwargs: dict[str, typing.Any],
|
**kwargs: dict[str, typing.Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.logger = logging.getLogger(self.__class__.__name__)
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
self.default_host = params.host
|
self.default_host = params.host.split(":", maxsplit=2)[0]
|
||||||
self.token = params.token
|
self.token = params.token
|
||||||
self.data_dir = data_dir.DataDir(params.data_dir)
|
self.data_dir = data_dir.DataDir(params.data_dir)
|
||||||
self.max_size_bytes = params.max_size_bytes
|
self.max_size_bytes = params.max_size_bytes
|
||||||
@@ -114,12 +113,11 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
return self.certbot_www + path.removeprefix(self.CERTBOT_CHALLENGE_PATH)
|
return self.certbot_www + path.removeprefix(self.CERTBOT_CHALLENGE_PATH)
|
||||||
if (page := self.registry.get_from_host(self.__get_host())) is not None:
|
if (page := self.registry.get_from_host(self.__get_host())) is not None:
|
||||||
path = f"/{page.path}" + path
|
path = f"/{page.path}" + path
|
||||||
path = super().translate_path(path)
|
if self.__get_subpath(path) is None: # not a valid path
|
||||||
if self.__get_subpath() is None: # not a valid path
|
|
||||||
return ""
|
return ""
|
||||||
if pathlib.Path(path).name.startswith("."): # hidden files
|
if pathlib.Path(path).name.startswith("."): # hidden files
|
||||||
return ""
|
return ""
|
||||||
return path
|
return super().translate_path(path)
|
||||||
|
|
||||||
@typing.override
|
@typing.override
|
||||||
def send_error(
|
def send_error(
|
||||||
@@ -186,25 +184,25 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
if self.headers["X-Token"] != self.token:
|
if self.headers["X-Token"] != self.token:
|
||||||
self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
|
self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
|
||||||
return None
|
return None
|
||||||
if (sub_path := self.__get_subpath_full()) is None:
|
if (sub_path := self.__get_subpath_full(self.path)) is None:
|
||||||
self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
|
self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
|
||||||
return None
|
return None
|
||||||
return sub_path
|
return sub_path
|
||||||
|
|
||||||
def __get_subpath(self) -> str | None:
|
def __get_subpath(self, path: str) -> str | None:
|
||||||
if (match := self.PATH_REGEX.match(self.path)) is not None:
|
if (match := self.PATH_REGEX.match(path)) is not None:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __get_subpath_full(self) -> str | None:
|
def __get_subpath_full(self, path: str) -> str | None:
|
||||||
if (match := self.PATH_REGEX.fullmatch(self.path)) is not None:
|
if (match := self.PATH_REGEX.fullmatch(path)) is not None:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __get_host(self) -> str:
|
def __get_host(self) -> str:
|
||||||
if self.headers["Host"] is None:
|
if self.headers["Host"] is None:
|
||||||
return self.default_host
|
return self.default_host
|
||||||
return self.headers["Host"]
|
return self.headers["Host"].split(":", maxsplit=2)[0]
|
||||||
|
|
||||||
def __get_length(self) -> int:
|
def __get_length(self) -> int:
|
||||||
if not self.headers["Content-Length"]:
|
if not self.headers["Content-Length"]:
|
||||||
@@ -212,11 +210,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
return int(self.headers["Content-Length"])
|
return int(self.headers["Content-Length"])
|
||||||
|
|
||||||
def __valid_host(self, host: str) -> bool:
|
def __valid_host(self, host: str) -> bool:
|
||||||
hostname, port = host.split(":", maxsplit=2)
|
return all(self.HOST_PART_REGEX.fullmatch(part) for part in host.split("."))
|
||||||
for part in hostname.split("."):
|
|
||||||
if not self.HOST_PART_REGEX.fullmatch(part):
|
|
||||||
return False
|
|
||||||
return not (port and len(port) and not self.HOST_PORT_REGEX.fullmatch(port))
|
|
||||||
|
|
||||||
def __server_index(self) -> None:
|
def __server_index(self) -> None:
|
||||||
self.__send_basic_body(self.server_version + "\n")
|
self.__send_basic_body(self.server_version + "\n")
|
||||||
|
|||||||
+1
-1
@@ -13,7 +13,7 @@ class Page:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
out = self.get_url_path()
|
out = self.get_url_path()
|
||||||
if self.host is not None:
|
if self.host is not None:
|
||||||
out += f" [http://{self.host}/]"
|
out += f" [{self.host}]"
|
||||||
if not self.with_index:
|
if not self.with_index:
|
||||||
out += " (no index)"
|
out += " (no index)"
|
||||||
return out
|
return out
|
||||||
|
|||||||
+12
-3
@@ -5,7 +5,7 @@ import os
|
|||||||
from . import project
|
from . import project
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Parameters:
|
class Parameters:
|
||||||
port: int
|
port: int
|
||||||
host: str
|
host: str
|
||||||
@@ -18,6 +18,7 @@ class Parameters:
|
|||||||
self_signed_path: str
|
self_signed_path: str
|
||||||
with_certbot: bool
|
with_certbot: bool
|
||||||
with_certificates: bool
|
with_certificates: bool
|
||||||
|
https: bool
|
||||||
debug: bool
|
debug: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -160,8 +161,8 @@ def parse_parameters() -> Parameters:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--certbot",
|
"--certbot",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
help="Use Certbot (default: true)",
|
help="Use Certbot (default: false)",
|
||||||
default=True,
|
default=False,
|
||||||
dest="with_certbot",
|
dest="with_certbot",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -171,5 +172,13 @@ def parse_parameters() -> Parameters:
|
|||||||
default=True,
|
default=True,
|
||||||
dest="with_certificates",
|
dest="with_certificates",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--https",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
help="Use https (implies --certificates) (default: true)",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if args.https:
|
||||||
|
args.with_certificates = True
|
||||||
return Parameters.from_namespace(args)
|
return Parameters.from_namespace(args)
|
||||||
|
|||||||
+2
-3
@@ -12,7 +12,6 @@ class Registry:
|
|||||||
self.logger = logging.getLogger(self.__class__.__name__)
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
self.pages: dict[str, page.Page] = {}
|
self.pages: dict[str, page.Page] = {}
|
||||||
self.data_dir = data_dir.DataDir(params.data_dir)
|
self.data_dir = data_dir.DataDir(params.data_dir)
|
||||||
self.prefix = f"http://{params.host}"
|
|
||||||
|
|
||||||
def load_pages(self) -> None:
|
def load_pages(self) -> None:
|
||||||
self.pages = {}
|
self.pages = {}
|
||||||
@@ -28,7 +27,7 @@ class Registry:
|
|||||||
self.data_dir.has_index(path),
|
self.data_dir.has_index(path),
|
||||||
self.data_dir.get_host(path),
|
self.data_dir.get_host(path),
|
||||||
)
|
)
|
||||||
self.logger.info("Updated %s%s", self.prefix, str(self.pages[path]))
|
self.logger.info("Updated %s", self.pages[path])
|
||||||
|
|
||||||
def set_host(self, path: str, host: str) -> None:
|
def set_host(self, path: str, host: str) -> None:
|
||||||
self.data_dir.set_host(path, host)
|
self.data_dir.set_host(path, host)
|
||||||
@@ -37,7 +36,7 @@ class Registry:
|
|||||||
def remove(self, path: str) -> None:
|
def remove(self, path: str) -> None:
|
||||||
page = self.pages[path]
|
page = self.pages[path]
|
||||||
del self.pages[path]
|
del self.pages[path]
|
||||||
self.logger.info("Removed %s%s", self.prefix, str(page))
|
self.logger.info("Removed %s", page)
|
||||||
|
|
||||||
def get_from_host(self, host: str) -> page.Page | None:
|
def get_from_host(self, host: str) -> page.Page | None:
|
||||||
for p in self.pages.values():
|
for p in self.pages.values():
|
||||||
|
|||||||
+21
-9
@@ -15,10 +15,7 @@ class StaplerServer:
|
|||||||
self.params = params
|
self.params = params
|
||||||
self.registry = registry.Registry(params)
|
self.registry = registry.Registry(params)
|
||||||
self.cert_manager = cert.CertManager(params)
|
self.cert_manager = cert.CertManager(params)
|
||||||
self.server = http.server.ThreadingHTTPServer(
|
self.default_host = params.host.split(":", maxsplit=2)[0]
|
||||||
(params.bind, params.port),
|
|
||||||
self.request_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
def request_handler(self, *args: typing.Any) -> http.server.BaseHTTPRequestHandler:
|
def request_handler(self, *args: typing.Any) -> http.server.BaseHTTPRequestHandler:
|
||||||
return handler.RequestHandler(*args, params=self.params, registry=self.registry)
|
return handler.RequestHandler(*args, params=self.params, registry=self.registry)
|
||||||
@@ -27,19 +24,34 @@ class StaplerServer:
|
|||||||
self.logger.info("Starting up...")
|
self.logger.info("Starting up...")
|
||||||
self.registry.load_pages()
|
self.registry.load_pages()
|
||||||
if self.params.with_certificates:
|
if self.params.with_certificates:
|
||||||
self.cert_manager.init([self.params.host, *self.registry.get_hosts()])
|
self.cert_manager.init([self.default_host, *self.registry.get_hosts()])
|
||||||
|
|
||||||
|
def __create_https_context(self, server: http.server.HTTPServer) -> bool:
|
||||||
|
https = False
|
||||||
|
if (
|
||||||
|
context := self.cert_manager.get_https_context(self.default_host)
|
||||||
|
) is not None:
|
||||||
|
https = True
|
||||||
|
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||||
|
return https
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
self.logger.info("Version %s", project.get_version())
|
self.logger.info("Version %s", project.get_version())
|
||||||
self.__startup()
|
self.__startup()
|
||||||
|
server = http.server.ThreadingHTTPServer(
|
||||||
|
(self.params.bind, self.params.port),
|
||||||
|
self.request_handler,
|
||||||
|
)
|
||||||
|
https = self.params.https and self.__create_https_context(server)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Listening on %s:%d...",
|
"Listening on %s:%d...",
|
||||||
self.server.server_address[0],
|
server.server_address[0],
|
||||||
self.server.server_port,
|
server.server_port,
|
||||||
)
|
)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Server up and ready on http://%s",
|
"Server up and ready on %s://%s",
|
||||||
|
"https" if https else "http",
|
||||||
self.params.host,
|
self.params.host,
|
||||||
)
|
)
|
||||||
with contextlib.suppress(KeyboardInterrupt):
|
with contextlib.suppress(KeyboardInterrupt):
|
||||||
self.server.serve_forever()
|
server.serve_forever()
|
||||||
|
|||||||
Reference in New Issue
Block a user