feat: rate limit

This commit is contained in:
2026-06-03 19:32:18 +02:00
parent ada69f773e
commit 132cbe9b20
2 changed files with 33 additions and 0 deletions
+17
View File
@@ -1,5 +1,6 @@
import abc
import contextlib
import datetime
import http
import http.cookiejar
import http.server
@@ -300,6 +301,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
REDIRECT_HEADER = "X-Redirect"
PROXY_HEADER = "X-Proxy"
SPA_HEADER = "X-SPA"
RATE_LIMIT = datetime.timedelta(seconds=1)
@typing.override
def __init__(
@@ -323,6 +325,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
self.__target_redirect: str | None = None
self.__target_proxy: str | None = None
self.__target_spa: str | None = None
self.rate_limits: dict[str, datetime.datetime] = {}
try:
super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument]
except (BrokenPipeError, ConnectionResetError) as e:
@@ -608,7 +611,20 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
path = f"/{page.path}/{page.spa}"
return super().translate_path(path)
def __check_rate_limit(self) -> bool:
now = datetime.datetime.now(tz=datetime.UTC)
address = self.address_string()
last_atempt = self.rate_limits.get(address, None)
self.rate_limits[address] = now
return last_atempt is None or now - last_atempt > self.RATE_LIMIT
def __clear_rate_limit(self) -> None:
del self.rate_limits[self.address_string()]
def __check_update_request(self) -> str | None:
if not self.__check_rate_limit():
self.send_error(http.HTTPStatus.TOO_MANY_REQUESTS, "Rate limit exceeded")
return None
if not self._has_header(self.TOKEN_HEADER):
self.send_error(
http.HTTPStatus.BAD_REQUEST, f"No {self.TOKEN_HEADER} header in request"
@@ -617,6 +633,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler):
if not self.token_manager.is_valid(self.token):
self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token")
return None
self.__clear_rate_limit()
if (sub_path := self.__get_path(self.path, self.UPDATE_PATH_REGEX)) is None:
self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path")
return None
+16
View File
@@ -1,6 +1,7 @@
import abc
import collections
import contextlib
import datetime
import http
import http.server
import io
@@ -275,6 +276,7 @@ class TestRequestHandler(BaseHandlerTestCase):
self.seal_mocks(),
):
handler.do_PUT()
assert "127.0.0.1" in handler.rate_limits
def test_do_post_is_do_put(self) -> None:
handler = self._get_handler("/path")
@@ -306,6 +308,18 @@ class TestRequestHandler(BaseHandlerTestCase):
self.seal_mocks(),
):
handler.do_PUT()
assert "127.0.0.1" in handler.rate_limits
def test_do_put_rate_limit(self) -> None:
handler = self._get_handler("/path", {"X-Token": "secret"})
handler.rate_limits["127.0.0.1"] = datetime.datetime.now(tz=datetime.UTC)
with (
self.expects_error(
handler, http.HTTPStatus.TOO_MANY_REQUESTS, "Rate limit exceeded"
),
self.seal_mocks(),
):
handler.do_PUT()
def test_do_put_invalid_path(self) -> None:
handler = self._get_handler("/pa.th", {"X-Token": "secret"})
@@ -792,6 +806,7 @@ class TestRequestHandler(BaseHandlerTestCase):
self.seal_mocks(),
):
handler.do_DELETE()
assert "127.0.0.1" in handler.rate_limits
def test_do_delete_invalid_token(self) -> None:
handler = self._get_handler("/path", {"X-Token": "secret"})
@@ -801,6 +816,7 @@ class TestRequestHandler(BaseHandlerTestCase):
self.seal_mocks(),
):
handler.do_DELETE()
assert "127.0.0.1" in handler.rate_limits
def test_do_delete_invalid_path(self) -> None:
handler = self._get_handler("/pa.th", {"X-Token": "secret"})