From 98a1a330c5efdb60e8ab86bf3960a8b1bcf95ac2 Mon Sep 17 00:00:00 2001 From: klemek Date: Fri, 17 Apr 2026 23:51:52 +0200 Subject: [PATCH] tests(params): add tests for params --- main.py | 2 +- src/params.py | 98 +++++++++++++++++++------------------------- tests/test_params.py | 37 +++++++++++++++++ 3 files changed, 81 insertions(+), 56 deletions(-) create mode 100644 tests/test_params.py diff --git a/main.py b/main.py index 88ac0d4..c2cb8cc 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ from src.server import StaplerServer def main() -> None: - params = parse_parameters() + params = parse_parameters(sys.argv[1:]) setup_logs(params) server = StaplerServer(params) method = getattr(server, params.command) diff --git a/src/params.py b/src/params.py index 8c06d91..a9c6688 100644 --- a/src/params.py +++ b/src/params.py @@ -10,21 +10,21 @@ __EPILOG = "(Each option can be supplied with equivalent environment variable.)" @dataclasses.dataclass(frozen=True) class Parameters: - http_port: int - https_port: int - host: str - data_dir: str - bind: str - token_salt: str - max_size_bytes: int - certbot_conf: str - certbot_www: str - self_signed_path: str - with_certbot: bool - with_certificates: bool - https: bool - command: typing.Literal["run", "renew"] - debug: bool + debug: bool = False + data_dir: str = "./data" + with_certificates: bool = True + self_signed_path: str = "./data/.certificates" + with_certbot: bool = True + certbot_conf: str = "/etc/letsencrypt" + certbot_www: str = "./data/.certbot" + host: str = "localhost" + http_port: int = 80 + https_port: int = 443 + https: bool = True + token_salt: str = "" + max_size_bytes: int = 2_000_000 + bind: str = "0.0.0.0" + command: typing.Literal["run", "renew", "token"] = "run" @classmethod def from_namespace(cls, args: argparse.Namespace) -> Parameters: @@ -75,112 +75,100 @@ def __add_arg_int( ) -def __add_arg_str_required( - parser: argparse.ArgumentParser, - *flags: str, - env_var: str, - help_txt: str, -) -> None: - parser.add_argument( - *flags, - metavar=env_var, - required=os.getenv(env_var) is None, - default=os.getenv(env_var), - help=help_txt, - ) - - -def parse_parameters() -> Parameters: +def parse_parameters(args: typing.Sequence[str]) -> Parameters: + default_values = Parameters() parser = argparse.ArgumentParser( project.get_name(), description=project.get_description(), epilog=__EPILOG, suggest_on_error=True, ) - parser.add_argument("--debug", action=argparse.BooleanOptionalAction) + parser.add_argument( + "--debug", action=argparse.BooleanOptionalAction, default=default_values.debug + ) __add_arg_str( parser, "-d", "--data-dir", env_var="DATA_DIR", - default="./data", + default=default_values.data_dir, help_txt="directory where pages are/will be stored", ) parser.add_argument( "--certificates", action=argparse.BooleanOptionalAction, help="Handle certificates (default: true)", - default=True, + default=default_values.with_certificates, dest="with_certificates", ) - parser.add_argument( - "--certbot", - action=argparse.BooleanOptionalAction, - help="Use Certbot (default: true)", - default=True, - dest="with_certbot", - ) __add_arg_str( parser, "--self-signed-path", env_var="SELF_SIGNED_PATH", - default="./data/.certificates", + default=default_values.self_signed_path, help_txt="Self-signed certificates dir", ) + parser.add_argument( + "--certbot", + action=argparse.BooleanOptionalAction, + help="Use Certbot (default: true)", + default=default_values.with_certbot, + dest="with_certbot", + ) __add_arg_str( parser, "--certbot-conf", env_var="CERTBOT_CONF", - default="/etc/letsencrypt", + default=default_values.certbot_conf, help_txt="Certbot config dir", ) __add_arg_str( parser, "--certbot-www", env_var="CERTBOT_WWW", - default="./data/.certbot", + default=default_values.certbot_www, help_txt="Certbot www dir", ) __add_arg_str( parser, "--host", env_var="HOST", - default="localhost", + default=default_values.host, help_txt="server default host", ) __add_arg_int( parser, "--http-port", env_var="HTTP_PORT", - default=80, + default=default_values.http_port, help_txt="server http port", ) __add_arg_int( parser, "--https-port", env_var="HTTPS_PORT", - default=443, + default=default_values.https_port, help_txt="server https port", ) parser.add_argument( "--https", action=argparse.BooleanOptionalAction, help="Use https (implies --certificates) (default: true)", - default=True, + default=default_values.https, ) __add_arg_str( parser, "-t", "--token-salt", env_var="TOKEN_SALT", - default="", + default=default_values.token_salt, help_txt="salt for tokens generation", ) __add_arg_int( parser, "--max-size-bytes", env_var="MAX_SIZE", - default=2_000_000, + default=default_values.max_size_bytes, help_txt="max size of accepted archives (in bytes)", ) __add_arg_str( @@ -188,14 +176,14 @@ def parse_parameters() -> Parameters: "-b", "--bind", env_var="BIND", - default="0.0.0.0", + default=default_values.bind, help_txt="server bind address", ) subparsers = parser.add_subparsers(dest="command", required=True, metavar="COMMAND") subparsers.add_parser("run", help="Run Stapler server") subparsers.add_parser("renew", help="Renew certificates") subparsers.add_parser("token", help="Generate a new token") - args = parser.parse_args() - if args.https: - args.with_certificates = True - return Parameters.from_namespace(args) + parsed_args = parser.parse_args(args) + if parsed_args.https: + parsed_args.with_certificates = True + return Parameters.from_namespace(parsed_args) diff --git a/tests/test_params.py b/tests/test_params.py new file mode 100644 index 0000000..553bb48 --- /dev/null +++ b/tests/test_params.py @@ -0,0 +1,37 @@ +from src.params import Parameters, parse_parameters + +from . import BaseTestCase + + +class TestParams(BaseTestCase): + ENV_COUNT = 10 + + def test_parse_parameters(self) -> None: + with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT): + params = parse_parameters(["run"]) + self.assertEqual(params, Parameters()) + + def test_parse_parameters_with_implied_certificates(self) -> None: + with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT): + params = parse_parameters(["--no-certificates", "--https", "run"]) + assert params.with_certificates + + def test_parse_parameters_without_implied_certificates(self) -> None: + with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT): + params = parse_parameters(["--no-certificates", "--no-https", "run"]) + assert not params.with_certificates + + def test_parse_parameters_with_env_var(self) -> None: + with self.patch("os.getenv", return_value="127.0.0.1", count=self.ENV_COUNT): + params = parse_parameters(["run"]) + self.assertEqual(params.bind, "127.0.0.1") + + def test_parse_parameters_with_env_var_int(self) -> None: + with self.patch("os.getenv", return_value="127", count=self.ENV_COUNT): + params = parse_parameters(["run"]) + self.assertEqual(params.http_port, 127) + + def test_parse_parameters_with_invalid_env_var_int(self) -> None: + with self.patch("os.getenv", return_value="aaa", count=self.ENV_COUNT): + params = parse_parameters(["run"]) + self.assertEqual(params.http_port, 80)