feat: X-Proxy

This commit is contained in:
2026-04-20 14:41:27 +02:00
committed by klemek
parent fb70638330
commit 33cfd350a5
8 changed files with 609 additions and 85 deletions
+1 -1
View File
@@ -116,7 +116,7 @@ curl -X DELETE \
- [x] unit tests - [x] unit tests
- [x] github actions - [x] github actions
- [x] X-Redirect - [x] X-Redirect
- [ ] X-Proxy - [x] X-Proxy
- [ ] proper doc - [ ] proper doc
### Makefile targets ### Makefile targets
+73 -20
View File
@@ -47,7 +47,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
) -> None: ) -> None:
shortmsg, longmsg = self.responses[code] shortmsg, longmsg = self.responses[code]
if message is None: if message is None:
message = shortmsg # pragma: no cover message = shortmsg
if explain is None: if explain is None:
explain = longmsg explain = longmsg
if "text/" in self._get_header("Accept"): if "text/" in self._get_header("Accept"):
@@ -104,7 +104,7 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
encoded: bytes = body.encode() encoded: bytes = body.encode()
self.out_size = len(encoded) self.out_size = len(encoded)
self.send_response(code, message) self.send_response(code, message)
self.send_header("Content-type", f"{content_type}; charset=UTF-8") self.send_header("Content-Type", f"{content_type}; charset=UTF-8")
self.send_header("Content-Length", str(len(encoded))) self.send_header("Content-Length", str(len(encoded)))
for header, value in headers.items(): for header, value in headers.items():
self.send_header(header, value) # pragma: no cover self.send_header(header, value) # pragma: no cover
@@ -134,14 +134,14 @@ class BaseHandler(abc.ABC, http.server.BaseHTTPRequestHandler):
) )
def send_proxy(self, url: str) -> None: def send_proxy(self, url: str) -> None:
body: bytes | None = None
if self.in_size > 0:
body = self.rfile.read(self.in_size)
headers = dict(self.headers) headers = dict(self.headers)
headers["Host"] = urllib.parse.urlparse(url).netloc headers["Host"] = urllib.parse.urlparse(url).netloc
headers["X-Forwarded-For"] = self.client_address[0] headers["X-Forwarded-For"] = self.client_address[0]
headers["X-Real-IP"] = self.client_address[0] headers["X-Real-IP"] = self.client_address[0]
try: try:
body: bytes | None = None
if self.in_size > 0:
body = self.rfile.read(self.in_size)
response: requests.Response = requests.request( response: requests.Response = requests.request(
self.command, url, data=body, headers=headers, timeout=240 self.command, url, data=body, headers=headers, timeout=240
) )
@@ -220,6 +220,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
TOKEN_HEADER = "X-Token" # noqa: S105 TOKEN_HEADER = "X-Token" # noqa: S105
HOST_HEADER = "X-Host" HOST_HEADER = "X-Host"
REDIRECT_HEADER = "X-Redirect" REDIRECT_HEADER = "X-Redirect"
PROXY_HEADER = "X-Proxy"
@typing.override @typing.override
def __init__( def __init__(
@@ -241,6 +242,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
self.__token: str | None = None self.__token: str | None = None
self.__target_host: str | None = None self.__target_host: str | None = None
self.__target_redirect: str | None = None self.__target_redirect: str | None = None
self.__target_proxy: str | None = None
super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument]
@property @property
@@ -249,6 +251,10 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
self.__token = self._get_header(self.TOKEN_HEADER) self.__token = self._get_header(self.TOKEN_HEADER)
return self.__token return self.__token
@property
def has_token(self) -> bool:
return len(self.token) > 0
@property @property
def target_host(self) -> str: def target_host(self) -> str:
if self.__target_host is None: if self.__target_host is None:
@@ -269,28 +275,35 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
def has_target_redirect(self) -> bool: def has_target_redirect(self) -> bool:
return len(self.target_redirect) > 0 return len(self.target_redirect) > 0
@property
def target_proxy(self) -> str:
if self.__target_proxy is None:
self.__target_proxy = self._get_header(self.PROXY_HEADER).lower()
return self.__target_proxy
@property
def has_target_proxy(self) -> bool:
return len(self.target_proxy) > 0
@typing.override @typing.override
def do_HEAD(self) -> None: def do_HEAD(self) -> None:
self._pre_log_request() self._pre_log_request()
if ( if not self._proxy_or_redirect():
page := self.__get_page(self.path) super().do_HEAD()
) is not None and page.redirect is not None:
return self.send_redirect(page.redirect)
return super().do_HEAD()
@typing.override @typing.override
def do_GET(self) -> None: def do_GET(self) -> None:
self._pre_log_request() self._pre_log_request()
if self._proxy_or_redirect():
return None
if self.path == "/" and self.host == self.default_host: if self.path == "/" and self.host == self.default_host:
return self.send_basic_body(self.server_signature()) return self.send_basic_body(self.server_signature())
if (
page := self.__get_page(self.path)
) is not None and page.redirect is not None:
return self.send_redirect(page.redirect)
return super().do_GET() return super().do_GET()
def do_PUT(self) -> None: def do_PUT(self) -> None:
self._pre_log_request() self._pre_log_request()
if self._proxy_or_redirect():
return None
if (path := self.__check_update_request()) is None: if (path := self.__check_update_request()) is None:
return None return None
if self.has_target_host and not self.__valid_host(self.target_host): if self.has_target_host and not self.__valid_host(self.target_host):
@@ -303,8 +316,15 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
and page.path != path and page.path != path
): ):
return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken") return self.send_error(http.HTTPStatus.FORBIDDEN, "Host already taken")
if self.has_target_proxy and self.has_target_redirect:
return self.send_error(
http.HTTPStatus.BAD_REQUEST,
f"Cannot use {self.PROXY_HEADER} with {self.REDIRECT_HEADER}",
)
if self.has_target_redirect: if self.has_target_redirect:
self._update_redirect(path) self._update_redirect(path)
elif self.has_target_proxy:
self._update_proxy(path)
else: else:
self._update_extract(path) self._update_extract(path)
if self.has_target_host and self.cert_manager.create_or_update( if self.has_target_host and self.cert_manager.create_or_update(
@@ -321,21 +341,26 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
def do_DELETE(self) -> None: def do_DELETE(self) -> None:
self._pre_log_request() self._pre_log_request()
if self._proxy_or_redirect():
return None
if (path := self.__check_update_request()) is None: if (path := self.__check_update_request()) is None:
return None return None
return self._update_remove(path) return self._update_remove(path)
def do_CONNECT(self) -> None: def do_CONNECT(self) -> None:
self._pre_log_request() self._pre_log_request()
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) if not self._proxy_or_redirect():
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)
def do_OPTIONS(self) -> None: def do_OPTIONS(self) -> None:
self._pre_log_request() self._pre_log_request()
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) if not self._proxy_or_redirect():
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)
def do_TRACE(self) -> None: def do_TRACE(self) -> None:
self._pre_log_request() self._pre_log_request()
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED) if not self._proxy_or_redirect():
self.send_error(http.HTTPStatus.METHOD_NOT_ALLOWED)
def _update_extract(self, path: str) -> None: def _update_extract(self, path: str) -> None:
if self.in_size == 0: if self.in_size == 0:
@@ -366,10 +391,22 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
http.HTTPStatus.BAD_REQUEST, http.HTTPStatus.BAD_REQUEST,
f"No content must be sent with {self.REDIRECT_HEADER}", f"No content must be sent with {self.REDIRECT_HEADER}",
) )
self.data_dir.empty(path)
self.registry.add(path)
self.token_manager.set_token(path, self.token)
self.registry.set_redirect(path, self.target_redirect) self.registry.set_redirect(path, self.target_redirect)
self.token_manager.set_token(path, self.token)
self.send_status_only(
http.HTTPStatus.CREATED,
f"Resource /{path}/ updated",
)
return None
def _update_proxy(self, path: str) -> None:
if self.in_size > 0:
return self.send_error(
http.HTTPStatus.BAD_REQUEST,
f"No content must be sent with {self.PROXY_HEADER}",
)
self.registry.set_proxy(path, self.target_proxy)
self.token_manager.set_token(path, self.token)
self.send_status_only( self.send_status_only(
http.HTTPStatus.CREATED, http.HTTPStatus.CREATED,
f"Resource /{path}/ updated", f"Resource /{path}/ updated",
@@ -391,6 +428,22 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
self.registry.remove(path) self.registry.remove(path)
return None return None
def _proxy_or_redirect(self) -> bool:
if self.has_token:
return False
if (page := self.__get_page(self.path)) is None:
return False
if page.redirect is not None:
self.send_redirect(page.redirect)
return True
if page.proxy is not None:
if self.host == self.default_host:
self.send_proxy(page.proxy + self.path.removeprefix(f"/{page.path}"))
else:
self.send_proxy(page.proxy + self.path)
return True
return False
@typing.override @typing.override
def list_directory(self, *_: typing.Any, **__: typing.Any) -> None: def list_directory(self, *_: typing.Any, **__: typing.Any) -> None:
"""Disable default directory listing.""" """Disable default directory listing."""
+3
View File
@@ -8,6 +8,7 @@ class Page:
host: str | None = None host: str | None = None
token_hash: str | None = None token_hash: str | None = None
redirect: str | None = None redirect: str | None = None
proxy: str | None = None
def __repr__(self) -> str: def __repr__(self) -> str:
out = f"/{self.path}/" out = f"/{self.path}/"
@@ -15,6 +16,8 @@ class Page:
out += f" [{self.host}]" out += f" [{self.host}]"
if self.redirect is not None: if self.redirect is not None:
out += f" (redirect: {self.redirect})" out += f" (redirect: {self.redirect})"
elif self.proxy is not None:
out += f" (proxy: {self.proxy})"
elif not self.with_index: elif not self.with_index:
out += " (no index)" out += " (no index)"
return out return out
+16 -2
View File
@@ -18,6 +18,7 @@ class Registry:
HOST_FILE = ".host" HOST_FILE = ".host"
TOKEN_FILE = ".token" # noqa: S105 TOKEN_FILE = ".token" # noqa: S105
REDIRECT_FILE = ".redirect" REDIRECT_FILE = ".redirect"
PROXY_FILE = ".proxy"
def __init__(self, params: Parameters) -> None: def __init__(self, params: Parameters) -> None:
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
@@ -39,6 +40,7 @@ class Registry:
self.data_dir.get_file(path, self.HOST_FILE), self.data_dir.get_file(path, self.HOST_FILE),
self.data_dir.get_file(path, self.TOKEN_FILE), self.data_dir.get_file(path, self.TOKEN_FILE),
self.data_dir.get_file(path, self.REDIRECT_FILE), self.data_dir.get_file(path, self.REDIRECT_FILE),
self.data_dir.get_file(path, self.PROXY_FILE),
) )
self.logger.info("Updated %s", self.pages[path]) self.logger.info("Updated %s", self.pages[path])
@@ -52,14 +54,26 @@ class Registry:
if path in self.pages and self.pages[path].token_hash != token_hash: if path in self.pages and self.pages[path].token_hash != token_hash:
self.data_dir.set_file(path, self.TOKEN_FILE, token_hash, 0o600) self.data_dir.set_file(path, self.TOKEN_FILE, token_hash, 0o600)
self.pages[path].token_hash = token_hash self.pages[path].token_hash = token_hash
self.logger.debug("Updated %s", self.pages[path]) self.logger.debug("Updated %s (token)", self.pages[path])
def set_redirect(self, path: str, redirect: str) -> None: def set_redirect(self, path: str, redirect: str) -> None:
if path in self.pages and self.pages[path].redirect != redirect: if path not in self.pages or self.pages[path].redirect != redirect:
self.data_dir.empty(path)
self.data_dir.set_file(path, self.REDIRECT_FILE, redirect) self.data_dir.set_file(path, self.REDIRECT_FILE, redirect)
if path not in self.pages:
self.pages[path] = Page(path)
self.pages[path].redirect = redirect self.pages[path].redirect = redirect
self.logger.debug("Updated %s", self.pages[path]) self.logger.debug("Updated %s", self.pages[path])
def set_proxy(self, path: str, proxy: str) -> None:
if path not in self.pages or self.pages[path].proxy != proxy:
self.data_dir.empty(path)
self.data_dir.set_file(path, self.PROXY_FILE, proxy)
if path not in self.pages:
self.pages[path] = Page(path)
self.pages[path].proxy = proxy
self.logger.debug("Updated %s", self.pages[path])
def remove(self, path: str) -> None: def remove(self, path: str) -> None:
if path in self.pages: if path in self.pages:
page = self.pages[path] page = self.pages[path]
+30 -10
View File
@@ -5,6 +5,8 @@ import typing
import unittest import unittest
import unittest.mock import unittest.mock
__import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 # ty:ignore[unresolved-attribute] # noqa: SLF001
class BaseTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
@typing.override @typing.override
@@ -47,16 +49,19 @@ class BaseTestCase(unittest.TestCase):
target: str, target: str,
args: list[typing.Iterable[typing.Any]] | None = None, args: list[typing.Iterable[typing.Any]] | None = None,
return_values: list[typing.Any] | None = None, return_values: list[typing.Any] | None = None,
kwargs: list[dict[str, typing.Any]] | None = None,
) -> typing.Iterator[unittest.mock.Mock]: ) -> typing.Iterator[unittest.mock.Mock]:
if args is None: if args is None:
args = [[]] args = [[]]
if return_values is None: if return_values is None:
return_values = [None] * len(args) return_values = [None] * len(args)
if kwargs is None:
kwargs = [{}] * len(args)
with unittest.mock.patch( with unittest.mock.patch(
target, side_effect=return_values, create=True target, side_effect=return_values, create=True
) as mock: ) as mock:
yield mock yield mock
self.__check_calls(mock, args) self.__check_calls(mock, args, kwargs)
@contextlib.contextmanager @contextlib.contextmanager
def patch_call( def patch_call(
@@ -64,10 +69,13 @@ class BaseTestCase(unittest.TestCase):
target: str, target: str,
args: typing.Iterable[typing.Any] | None = None, args: typing.Iterable[typing.Any] | None = None,
return_value: typing.Any = None, return_value: typing.Any = None,
kwargs: dict[str, typing.Any] | None = None,
) -> typing.Iterator[unittest.mock.Mock]: ) -> typing.Iterator[unittest.mock.Mock]:
if args is None: if args is None:
args = [] args = []
with self.patch_calls(target, [args], [return_value]) as mock: if kwargs is None:
kwargs = {}
with self.patch_calls(target, [args], [return_value], [kwargs]) as mock:
yield mock yield mock
@contextlib.contextmanager @contextlib.contextmanager
@@ -84,15 +92,18 @@ class BaseTestCase(unittest.TestCase):
mock: unittest.mock.Mock, mock: unittest.mock.Mock,
args: list[typing.Iterable[typing.Any]] | None = None, args: list[typing.Iterable[typing.Any]] | None = None,
return_values: list[typing.Any] | None = None, return_values: list[typing.Any] | None = None,
kwargs: list[dict[str, typing.Any]] | None = None,
) -> typing.Iterator[None]: ) -> typing.Iterator[None]:
if args is None: if args is None:
args = [[]] args = [[]]
if return_values is None: if return_values is None:
return_values = [None] * len(args) return_values = [None] * len(args)
if kwargs is None:
kwargs = [{}] * len(args)
mock.side_effect = return_values mock.side_effect = return_values
mock.reset_mock() mock.reset_mock()
yield yield
self.__check_calls(mock, args) self.__check_calls(mock, args, kwargs)
@contextlib.contextmanager @contextlib.contextmanager
def mock_call( def mock_call(
@@ -100,10 +111,13 @@ class BaseTestCase(unittest.TestCase):
mock: unittest.mock.Mock, mock: unittest.mock.Mock,
args: typing.Iterable[typing.Any] | None = None, args: typing.Iterable[typing.Any] | None = None,
return_value: typing.Any = None, return_value: typing.Any = None,
kwargs: dict[str, typing.Any] | None = None,
) -> typing.Iterator[None]: ) -> typing.Iterator[None]:
if args is None: if args is None:
args = [] args = []
with self.mock_calls(mock, [args], [return_value]): if kwargs is None:
kwargs = {}
with self.mock_calls(mock, [args], [return_value], [kwargs]):
yield yield
@contextlib.contextmanager @contextlib.contextmanager
@@ -140,17 +154,23 @@ class BaseTestCase(unittest.TestCase):
self, self,
mock: unittest.mock.Mock, mock: unittest.mock.Mock,
args: list[typing.Iterable[typing.Any]], args: list[typing.Iterable[typing.Any]],
kwargs: list[dict[str, typing.Any]],
) -> None: ) -> None:
total_rows = max(len(args), len(mock.method_calls), len(kwargs))
missing_calls = max(0, total_rows - len(mock.mock_calls))
missing_args = max(0, total_rows - len(args))
missing_kwargs = max(0, total_rows - len(kwargs))
for i, values in enumerate( for i, values in enumerate(
zip( zip(
mock.mock_calls mock.mock_calls + [None] * missing_calls,
+ [None] args + [[]] * missing_args,
* (max(len(args), len(mock.method_calls)) - len(mock.mock_calls)), kwargs + [{}] * missing_kwargs,
args + [[]] * (max(len(args), len(mock.method_calls)) - len(args)),
strict=False, strict=False,
) )
): ):
real_call, expected_args = values real_call, expected_args, expected_kwargs = values
self.assertEqual( self.assertEqual(
real_call, unittest.mock.call(*expected_args), f"{i + 1}: {mock}" real_call,
unittest.mock.call(*expected_args, **expected_kwargs),
f"{i + 1}: {mock}",
) )
+380 -52
View File
@@ -9,6 +9,8 @@ import tarfile
import typing import typing
import unittest.mock import unittest.mock
import requests
from src.handlers import BaseHandler, RequestHandler, UpgradeHandler from src.handlers import BaseHandler, RequestHandler, UpgradeHandler
from src.page import Page from src.page import Page
from src.params import Parameters from src.params import Parameters
@@ -22,6 +24,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC):
self, self,
path: str = "/", path: str = "/",
headers: dict[str, str | None] | None = None, headers: dict[str, str | None] | None = None,
method: str = "GET",
rfile: io.BufferedIOBase | None = None, rfile: io.BufferedIOBase | None = None,
) -> BaseHandler: ) -> BaseHandler:
pass pass
@@ -70,7 +73,7 @@ class BaseHandlerTestCase(BaseTestCase, abc.ABC):
send_header_mock.assert_has_calls( send_header_mock.assert_has_calls(
[ [
unittest.mock.call("Content-Length", str(len(body.encode()))), unittest.mock.call("Content-Length", str(len(body.encode()))),
unittest.mock.call("Content-type", f"{content_type}; charset=UTF-8"), unittest.mock.call("Content-Type", f"{content_type}; charset=UTF-8"),
] ]
+ [unittest.mock.call(header, value) for header, value in headers.items()], + [unittest.mock.call(header, value) for header, value in headers.items()],
any_order=True, any_order=True,
@@ -129,6 +132,7 @@ class TestRequestHandler(BaseHandlerTestCase):
self, self,
path: str = "/", path: str = "/",
headers: dict[str, str | None] | None = None, headers: dict[str, str | None] | None = None,
method: str = "GET",
rfile: io.BufferedIOBase | None = None, rfile: io.BufferedIOBase | None = None,
) -> RequestHandler: ) -> RequestHandler:
if headers is None: if headers is None:
@@ -146,9 +150,11 @@ class TestRequestHandler(BaseHandlerTestCase):
token_manager=self.token_manager, token_manager=self.token_manager,
) )
handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment]
handler.requestline = "GET /" handler.requestline = f"{method} {path}"
handler.path = path handler.path = path
handler.command = method
handler.request_version = "HTTP/0.9" handler.request_version = "HTTP/0.9"
handler.client_address = ("127.0.0.1", 12345)
handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment]
handler.rfile = rfile if rfile is not None else io.BytesIO() handler.rfile = rfile if rfile is not None else io.BytesIO()
handler.wfile = io.BytesIO() handler.wfile = io.BytesIO()
@@ -156,25 +162,7 @@ class TestRequestHandler(BaseHandlerTestCase):
handler.data_dir = self.data_dir handler.data_dir = self.data_dir
return handler return handler
def test_do_head_redirect(self) -> None: def test_do_head_forward(self) -> None:
handler = self._get_handler("/path")
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", redirect="https://example.com"),
),
self.expects_status_only(
handler,
http.HTTPStatus.MOVED_PERMANENTLY,
headers={"Location": "https://example.com"},
),
self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD", count=0),
self.seal_mocks(),
):
handler.do_HEAD()
def test_do_head_proxy(self) -> None:
handler = self._get_handler() handler = self._get_handler()
with ( with (
self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD"), self.patch("http.server.SimpleHTTPRequestHandler.do_HEAD"),
@@ -191,37 +179,16 @@ class TestRequestHandler(BaseHandlerTestCase):
): ):
handler.do_GET() handler.do_GET()
def test_do_get_redirect(self) -> None: def test_do_get_forward_on_other_path(self) -> None:
handler = self._get_handler("/path")
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", redirect="https://example.com"),
),
self.expects_status_only(
handler,
http.HTTPStatus.MOVED_PERMANENTLY,
headers={"Location": "https://example.com"},
),
self.patch("http.server.SimpleHTTPRequestHandler.do_GET", count=0),
self.seal_mocks(),
):
handler.do_GET()
def test_do_get_proxy_on_other_path(self) -> None:
handler = self._get_handler("/file") handler = self._get_handler("/file")
with ( with (
self.mock_call( self.mock_call(self.registry.get_from_path, ["file"], Page("file")),
self.registry.get_from_path,
["file"],
),
self.patch("http.server.SimpleHTTPRequestHandler.do_GET"), self.patch("http.server.SimpleHTTPRequestHandler.do_GET"),
self.seal_mocks(), self.seal_mocks(),
): ):
handler.do_GET() handler.do_GET()
def test_do_get_proxy_on_other_host(self) -> None: def test_do_get_forward_on_other_host(self) -> None:
handler = self._get_handler("/", {"Host": "other_host"}) handler = self._get_handler("/", {"Host": "other_host"})
with ( with (
self.mock_call( self.mock_call(
@@ -236,6 +203,7 @@ class TestRequestHandler(BaseHandlerTestCase):
def test_do_put_no_token(self) -> None: def test_do_put_no_token(self) -> None:
handler = self._get_handler("/path") handler = self._get_handler("/path")
with ( with (
self.mock_call(self.registry.get_from_path, ["path"]),
self.expects_error( self.expects_error(
handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request"
), ),
@@ -246,6 +214,7 @@ class TestRequestHandler(BaseHandlerTestCase):
def test_do_post_is_do_put(self) -> None: def test_do_post_is_do_put(self) -> None:
handler = self._get_handler("/path") handler = self._get_handler("/path")
with ( with (
self.mock_call(self.registry.get_from_path, ["path"]),
self.expects_error( self.expects_error(
handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request"
), ),
@@ -256,6 +225,7 @@ class TestRequestHandler(BaseHandlerTestCase):
def test_do_patch_is_do_put(self) -> None: def test_do_patch_is_do_put(self) -> None:
handler = self._get_handler("/path") handler = self._get_handler("/path")
with ( with (
self.mock_call(self.registry.get_from_path, ["path"]),
self.expects_error( self.expects_error(
handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request"
), ),
@@ -525,10 +495,8 @@ class TestRequestHandler(BaseHandlerTestCase):
["secret", "path"], ["secret", "path"],
True, # noqa: FBT003 True, # noqa: FBT003
), ),
self.mock_call(self.data_dir.empty, ["path"]),
self.mock_call(self.registry.add, ["path"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.expects_status_only( self.expects_status_only(
handler, http.HTTPStatus.CREATED, "Resource /path/ updated" handler, http.HTTPStatus.CREATED, "Resource /path/ updated"
), ),
@@ -553,10 +521,8 @@ class TestRequestHandler(BaseHandlerTestCase):
True, # noqa: FBT003 True, # noqa: FBT003
), ),
self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")), self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")),
self.mock_call(self.data_dir.empty, ["path"]),
self.mock_call(self.registry.add, ["path"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]), self.mock_call(self.registry.set_redirect, ["path", "https://example.com"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003 self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003
self.mock_call(self.registry.set_host, ["path", "example.com"]), self.mock_call(self.registry.set_host, ["path", "example.com"]),
self.expects_status_only( self.expects_status_only(
@@ -566,9 +532,114 @@ class TestRequestHandler(BaseHandlerTestCase):
): ):
handler.do_PUT() handler.do_PUT()
def test_do_put_proxy_with_content(self) -> None:
handler = self._get_handler(
"/path",
{
"X-Token": "secret",
"X-Proxy": "https://example.com",
"Content-Length": "1",
},
)
with (
self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003
self.mock_call(
self.token_manager.is_valid_for_path,
["secret", "path"],
True, # noqa: FBT003
),
self.expects_error(
handler,
http.HTTPStatus.BAD_REQUEST,
"No content must be sent with X-Proxy",
),
self.seal_mocks(),
):
handler.do_PUT()
def test_do_put_proxy_ok(self) -> None:
handler = self._get_handler(
"/path",
{
"X-Token": "secret",
"X-Proxy": "https://example.com",
},
)
with (
self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003
self.mock_call(
self.token_manager.is_valid_for_path,
["secret", "path"],
True, # noqa: FBT003
),
self.mock_call(self.registry.set_proxy, ["path", "https://example.com"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.expects_status_only(
handler, http.HTTPStatus.CREATED, "Resource /path/ updated"
),
self.seal_mocks(),
):
handler.do_PUT()
def test_do_put_proxy_with_host(self) -> None:
handler = self._get_handler(
"/path",
{
"X-Token": "secret",
"X-Proxy": "https://example.com",
"X-Host": "example.com",
},
)
with (
self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003
self.mock_call(
self.token_manager.is_valid_for_path,
["secret", "path"],
True, # noqa: FBT003
),
self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")),
self.mock_call(self.registry.set_proxy, ["path", "https://example.com"]),
self.mock_call(self.token_manager.set_token, ["path", "secret"]),
self.mock_call(self.cert_manager.create_or_update, ["example.com"], True), # noqa: FBT003
self.mock_call(self.registry.set_host, ["path", "example.com"]),
self.expects_status_only(
handler, http.HTTPStatus.CREATED, "Resource /path/ updated"
),
self.seal_mocks(),
):
handler.do_PUT()
def test_do_put_proxy_and_redirect(self) -> None:
handler = self._get_handler(
"/path",
{
"X-Token": "secret",
"X-Proxy": "https://example.com",
"X-Redirect": "https://example.com",
"X-Host": "example.com",
},
)
with (
self.mock_call(self.token_manager.is_valid, ["secret"], True), # noqa: FBT003
self.mock_call(
self.token_manager.is_valid_for_path,
["secret", "path"],
True, # noqa: FBT003
),
self.mock_call(self.registry.get_from_host, ["example.com"], Page("path")),
self.expects_status_only(
handler,
http.HTTPStatus.BAD_REQUEST,
"Cannot use X-Proxy with X-Redirect",
),
self.seal_mocks(),
):
handler.do_PUT()
def test_do_delete_no_token(self) -> None: def test_do_delete_no_token(self) -> None:
handler = self._get_handler("/path") handler = self._get_handler("/path")
with ( with (
self.mock_call(self.registry.get_from_path, ["path"]),
self.expects_error( self.expects_error(
handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request" handler, http.HTTPStatus.BAD_REQUEST, "No X-Token header in request"
), ),
@@ -662,6 +733,261 @@ class TestRequestHandler(BaseHandlerTestCase):
): ):
handler.do_DELETE() handler.do_DELETE()
def test_do_post_proxy_no_body(self) -> None:
handler = self._get_handler("/path", method="POST")
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO()
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com",
],
response,
{
"data": None,
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
},
"timeout": 240,
},
),
self.expects_status_only(handler, 200, "OK"),
self.seal_mocks(),
):
handler.do_POST()
def test_do_post_proxy_with_request_body(self) -> None:
handler = self._get_handler(
"/path",
method="POST",
headers={"Content-Length": "5"},
rfile=io.BytesIO(b"hello"),
)
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO()
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com",
],
response,
{
"data": b"hello",
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
"Content-Length": "5",
},
"timeout": 240,
},
),
self.expects_status_only(handler, 200, "OK"),
self.seal_mocks(),
):
handler.do_POST()
def test_do_post_proxy_with_response_body(self) -> None:
handler = self._get_handler(
"/path",
method="POST",
)
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.headers["Content-Type"] = "text/plain; charset=UTF-8"
response.raw = io.BytesIO(b"hello")
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com",
],
response,
{
"data": None,
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
},
"timeout": 240,
},
),
self.expects_basic_body(handler, "hello", message="OK"),
self.seal_mocks(),
):
handler.do_POST()
def test_do_post_proxy_fail(self) -> None:
handler = self._get_handler("/path", method="POST")
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com",
],
None,
{
"data": None,
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
},
"timeout": 240,
},
) as request_mock,
self.expects_status_only(
handler,
http.HTTPStatus.BAD_GATEWAY,
"Could not reach https://example.com",
),
self.seal_mocks(),
):
request_mock.side_effect = Exception
handler.do_POST()
def test_do_post_proxy_sub_path(self) -> None:
handler = self._get_handler("/path/index.html", method="POST")
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO()
with (
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com/index.html",
],
response,
{
"data": None,
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
},
"timeout": 240,
},
),
self.expects_status_only(handler, 200, "OK"),
self.seal_mocks(),
):
handler.do_POST()
def test_do_post_proxy_sub_path_for_host(self) -> None:
handler = self._get_handler(
"/path/index.html", method="POST", headers={"Host": "host"}
)
response = requests.Response()
response.status_code = 200
response.reason = "OK"
response.raw = io.BytesIO()
with (
self.mock_call(
self.registry.get_from_host,
["host"],
Page("path", proxy="https://example.com"),
),
self.patch_call(
"requests.request",
[
"POST",
"https://example.com/path/index.html",
],
response,
{
"data": None,
"headers": {
"Host": "example.com",
"X-Forwarded-For": "127.0.0.1",
"X-Real-IP": "127.0.0.1",
},
"timeout": 240,
},
),
self.expects_status_only(handler, 200, "OK"),
self.seal_mocks(),
):
handler.do_POST()
def test_do_method_not_supported(self) -> None:
handler = self._get_handler("/path")
for method in ["CONNECT", "TRACE", "OPTIONS"]:
with (
self.subTest(None, method=method),
self.mock_call(
self.registry.get_from_path,
["path"],
),
self.expects_status_only(
handler, http.HTTPStatus.METHOD_NOT_ALLOWED, "Method Not Allowed"
),
self.seal_mocks(),
):
getattr(handler, f"do_{method}")()
def test_do_redirect(self) -> None:
handler = self._get_handler("/path")
for method in [method.value for method in http.HTTPMethod]:
with (
self.subTest(None, method=method),
self.mock_call(
self.registry.get_from_path,
["path"],
Page("path", redirect="https://example.com"),
),
self.expects_status_only(
handler,
http.HTTPStatus.MOVED_PERMANENTLY,
headers={"Location": "https://example.com"},
),
self.patch(
f"http.server.SimpleHTTPRequestHandler.do_{method}", count=0
),
self.seal_mocks(),
):
getattr(handler, f"do_{method}")()
def test_list_directory(self) -> None: def test_list_directory(self) -> None:
handler = self._get_handler("/path/", {"Accept": "text/html"}) handler = self._get_handler("/path/", {"Accept": "text/html"})
with ( with (
@@ -780,6 +1106,7 @@ class TestUpgradeHandler(BaseHandlerTestCase):
self, self,
path: str = "/", path: str = "/",
headers: dict[str, str | None] | None = None, headers: dict[str, str | None] | None = None,
method: str = "GET",
rfile: io.BufferedIOBase | None = None, rfile: io.BufferedIOBase | None = None,
) -> UpgradeHandler: ) -> UpgradeHandler:
if headers is None: if headers is None:
@@ -792,8 +1119,9 @@ class TestUpgradeHandler(BaseHandlerTestCase):
params=Parameters(), params=Parameters(),
) )
handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment] handler.address_string = lambda: "127.0.0.1" # ty:ignore[invalid-assignment]
handler.requestline = "GET /" handler.requestline = f"{method} {path}"
handler.path = path handler.path = path
handler.command = method
handler.request_version = "HTTP/0.9" handler.request_version = "HTTP/0.9"
handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment] handler.headers = collections.defaultdict(lambda: None, headers) # ty:ignore[invalid-assignment]
handler.rfile = rfile if rfile is not None else io.BytesIO() handler.rfile = rfile if rfile is not None else io.BytesIO()
+6
View File
@@ -21,3 +21,9 @@ class TestPage(BaseTestCase):
str(Page("test_1", redirect="https://example.com")), str(Page("test_1", redirect="https://example.com")),
"/test_1/ (redirect: https://example.com)", "/test_1/ (redirect: https://example.com)",
) )
def test_repr_with_proxy(self) -> None:
self.assertEqual(
str(Page("test_1", proxy="https://example.com")),
"/test_1/ (proxy: https://example.com)",
)
+100
View File
@@ -30,17 +30,21 @@ class TestRegistry(BaseTestCase):
["test_1", Registry.HOST_FILE], ["test_1", Registry.HOST_FILE],
["test_1", Registry.TOKEN_FILE], ["test_1", Registry.TOKEN_FILE],
["test_1", Registry.REDIRECT_FILE], ["test_1", Registry.REDIRECT_FILE],
["test_1", Registry.PROXY_FILE],
["test_2", Registry.HOST_FILE], ["test_2", Registry.HOST_FILE],
["test_2", Registry.TOKEN_FILE], ["test_2", Registry.TOKEN_FILE],
["test_2", Registry.REDIRECT_FILE], ["test_2", Registry.REDIRECT_FILE],
["test_2", Registry.PROXY_FILE],
], ],
[ [
"test_1_host", "test_1_host",
"test_1_token", "test_1_token",
None, None,
None, None,
None,
"test_2_token", "test_2_token",
"test_2_redirect", "test_2_redirect",
None,
], ],
), ),
self.seal_mocks(), self.seal_mocks(),
@@ -114,12 +118,33 @@ class TestRegistry(BaseTestCase):
self.registry.set_token_hash("test_1", "new_value") self.registry.set_token_hash("test_1", "new_value")
self.assertEqual(self.registry.pages["test_1"].token_hash, "new_value") self.assertEqual(self.registry.pages["test_1"].token_hash, "new_value")
def test_set_token_hash_no_change(self) -> None:
self.registry.pages["test_1"] = Page(
"test_1",
token_hash="secret", # noqa: S106
)
with (
self.seal_mocks(),
):
self.registry.set_token_hash("test_1", "secret")
self.assertEqual(self.registry.pages["test_1"].token_hash, "secret")
def test_set_token_hash_not_found(self) -> None:
with (
self.seal_mocks(),
):
self.registry.set_token_hash("test_1", "secret")
def test_set_redirect(self) -> None: def test_set_redirect(self) -> None:
self.registry.pages["test_1"] = Page( self.registry.pages["test_1"] = Page(
"test_1", "test_1",
redirect="https://example.com", redirect="https://example.com",
) )
with ( with (
self.mock_call(
self.data_dir.empty,
["test_1"],
),
self.mock_call( self.mock_call(
self.data_dir.set_file, self.data_dir.set_file,
["test_1", Registry.REDIRECT_FILE, "https://new-example.com"], ["test_1", Registry.REDIRECT_FILE, "https://new-example.com"],
@@ -131,6 +156,81 @@ class TestRegistry(BaseTestCase):
self.registry.pages["test_1"].redirect, "https://new-example.com" self.registry.pages["test_1"].redirect, "https://new-example.com"
) )
def test_set_redirect_no_change(self) -> None:
self.registry.pages["test_1"] = Page(
"test_1",
redirect="https://example.com",
)
with (
self.seal_mocks(),
):
self.registry.set_redirect("test_1", "https://example.com")
self.assertEqual(self.registry.pages["test_1"].redirect, "https://example.com")
def test_set_redirect_not_found(self) -> None:
with (
self.mock_call(
self.data_dir.empty,
["test_1"],
),
self.mock_call(
self.data_dir.set_file,
["test_1", Registry.REDIRECT_FILE, "https://new-example.com"],
),
self.seal_mocks(),
):
self.registry.set_redirect("test_1", "https://new-example.com")
self.assertIn("test_1", self.registry.pages)
self.assertEqual(
self.registry.pages["test_1"].redirect, "https://new-example.com"
)
def test_set_proxy(self) -> None:
self.registry.pages["test_1"] = Page(
"test_1",
proxy="https://example.com",
)
with (
self.mock_call(
self.data_dir.empty,
["test_1"],
),
self.mock_call(
self.data_dir.set_file,
["test_1", Registry.PROXY_FILE, "https://new-example.com"],
),
self.seal_mocks(),
):
self.registry.set_proxy("test_1", "https://new-example.com")
self.assertEqual(self.registry.pages["test_1"].proxy, "https://new-example.com")
def test_set_proxy_no_change(self) -> None:
self.registry.pages["test_1"] = Page(
"test_1",
proxy="https://example.com",
)
with (
self.seal_mocks(),
):
self.registry.set_proxy("test_1", "https://example.com")
self.assertEqual(self.registry.pages["test_1"].proxy, "https://example.com")
def test_set_proxy_not_found(self) -> None:
with (
self.mock_call(
self.data_dir.empty,
["test_1"],
),
self.mock_call(
self.data_dir.set_file,
["test_1", Registry.PROXY_FILE, "https://new-example.com"],
),
self.seal_mocks(),
):
self.registry.set_proxy("test_1", "https://new-example.com")
self.assertIn("test_1", self.registry.pages)
self.assertEqual(self.registry.pages["test_1"].proxy, "https://new-example.com")
def test_remove(self) -> None: def test_remove(self) -> None:
self.registry.pages["test_1"] = Page( self.registry.pages["test_1"] = Page(
"test_1", "test_1",