diff --git a/stapler/handlers.py b/stapler/handlers.py index 2590f87..0d958fe 100644 --- a/stapler/handlers.py +++ b/stapler/handlers.py @@ -1,5 +1,6 @@ import abc import contextlib +import datetime import http import http.cookiejar import http.server @@ -300,6 +301,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): REDIRECT_HEADER = "X-Redirect" PROXY_HEADER = "X-Proxy" SPA_HEADER = "X-SPA" + RATE_LIMIT = datetime.timedelta(seconds=1) @typing.override def __init__( @@ -323,6 +325,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): self.__target_redirect: str | None = None self.__target_proxy: str | None = None self.__target_spa: str | None = None + self.rate_limits: dict[str, datetime.datetime] = {} try: super().__init__(*args, directory=params.data_dir, **kwargs, params=params) # ty:ignore[unknown-argument] except (BrokenPipeError, ConnectionResetError) as e: @@ -608,7 +611,20 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): path = f"/{page.path}/{page.spa}" return super().translate_path(path) + def __check_rate_limit(self) -> bool: + now = datetime.datetime.now(tz=datetime.UTC) + address = self.address_string() + last_atempt = self.rate_limits.get(address, None) + self.rate_limits[address] = now + return last_atempt is None or now - last_atempt > self.RATE_LIMIT + + def __clear_rate_limit(self) -> None: + del self.rate_limits[self.address_string()] + def __check_update_request(self) -> str | None: + if not self.__check_rate_limit(): + self.send_error(http.HTTPStatus.TOO_MANY_REQUESTS, "Rate limit exceeded") + return None if not self._has_header(self.TOKEN_HEADER): self.send_error( http.HTTPStatus.BAD_REQUEST, f"No {self.TOKEN_HEADER} header in request" @@ -617,6 +633,7 @@ class RequestHandler(http.server.SimpleHTTPRequestHandler, BaseHandler): if not self.token_manager.is_valid(self.token): self.send_error(http.HTTPStatus.UNAUTHORIZED, "Invalid token") return None + self.__clear_rate_limit() if (sub_path := self.__get_path(self.path, self.UPDATE_PATH_REGEX)) is None: self.send_error(http.HTTPStatus.BAD_REQUEST, "Invalid path") return None diff --git a/tests/test_handlers.py b/tests/test_handlers.py index fab7233..6f8fdc5 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,6 +1,7 @@ import abc import collections import contextlib +import datetime import http import http.server import io @@ -275,6 +276,7 @@ class TestRequestHandler(BaseHandlerTestCase): self.seal_mocks(), ): handler.do_PUT() + assert "127.0.0.1" in handler.rate_limits def test_do_post_is_do_put(self) -> None: handler = self._get_handler("/path") @@ -306,6 +308,18 @@ class TestRequestHandler(BaseHandlerTestCase): self.seal_mocks(), ): handler.do_PUT() + assert "127.0.0.1" in handler.rate_limits + + def test_do_put_rate_limit(self) -> None: + handler = self._get_handler("/path", {"X-Token": "secret"}) + handler.rate_limits["127.0.0.1"] = datetime.datetime.now(tz=datetime.UTC) + with ( + self.expects_error( + handler, http.HTTPStatus.TOO_MANY_REQUESTS, "Rate limit exceeded" + ), + self.seal_mocks(), + ): + handler.do_PUT() def test_do_put_invalid_path(self) -> None: handler = self._get_handler("/pa.th", {"X-Token": "secret"}) @@ -792,6 +806,7 @@ class TestRequestHandler(BaseHandlerTestCase): self.seal_mocks(), ): handler.do_DELETE() + assert "127.0.0.1" in handler.rate_limits def test_do_delete_invalid_token(self) -> None: handler = self._get_handler("/path", {"X-Token": "secret"}) @@ -801,6 +816,7 @@ class TestRequestHandler(BaseHandlerTestCase): self.seal_mocks(), ): handler.do_DELETE() + assert "127.0.0.1" in handler.rate_limits def test_do_delete_invalid_path(self) -> None: handler = self._get_handler("/pa.th", {"X-Token": "secret"})