tests(params): add tests for params

This commit is contained in:
2026-04-17 23:51:52 +02:00
parent 66f7f605d1
commit 98a1a330c5
3 changed files with 81 additions and 56 deletions
+1 -1
View File
@@ -6,7 +6,7 @@ from src.server import StaplerServer
def main() -> None: def main() -> None:
params = parse_parameters() params = parse_parameters(sys.argv[1:])
setup_logs(params) setup_logs(params)
server = StaplerServer(params) server = StaplerServer(params)
method = getattr(server, params.command) method = getattr(server, params.command)
+43 -55
View File
@@ -10,21 +10,21 @@ __EPILOG = "(Each option can be supplied with equivalent environment variable.)"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Parameters: class Parameters:
http_port: int debug: bool = False
https_port: int data_dir: str = "./data"
host: str with_certificates: bool = True
data_dir: str self_signed_path: str = "./data/.certificates"
bind: str with_certbot: bool = True
token_salt: str certbot_conf: str = "/etc/letsencrypt"
max_size_bytes: int certbot_www: str = "./data/.certbot"
certbot_conf: str host: str = "localhost"
certbot_www: str http_port: int = 80
self_signed_path: str https_port: int = 443
with_certbot: bool https: bool = True
with_certificates: bool token_salt: str = ""
https: bool max_size_bytes: int = 2_000_000
command: typing.Literal["run", "renew"] bind: str = "0.0.0.0"
debug: bool command: typing.Literal["run", "renew", "token"] = "run"
@classmethod @classmethod
def from_namespace(cls, args: argparse.Namespace) -> Parameters: def from_namespace(cls, args: argparse.Namespace) -> Parameters:
@@ -75,112 +75,100 @@ def __add_arg_int(
) )
def __add_arg_str_required( def parse_parameters(args: typing.Sequence[str]) -> Parameters:
parser: argparse.ArgumentParser, default_values = Parameters()
*flags: str,
env_var: str,
help_txt: str,
) -> None:
parser.add_argument(
*flags,
metavar=env_var,
required=os.getenv(env_var) is None,
default=os.getenv(env_var),
help=help_txt,
)
def parse_parameters() -> Parameters:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
project.get_name(), project.get_name(),
description=project.get_description(), description=project.get_description(),
epilog=__EPILOG, epilog=__EPILOG,
suggest_on_error=True, suggest_on_error=True,
) )
parser.add_argument("--debug", action=argparse.BooleanOptionalAction) parser.add_argument(
"--debug", action=argparse.BooleanOptionalAction, default=default_values.debug
)
__add_arg_str( __add_arg_str(
parser, parser,
"-d", "-d",
"--data-dir", "--data-dir",
env_var="DATA_DIR", env_var="DATA_DIR",
default="./data", default=default_values.data_dir,
help_txt="directory where pages are/will be stored", help_txt="directory where pages are/will be stored",
) )
parser.add_argument( parser.add_argument(
"--certificates", "--certificates",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
help="Handle certificates (default: true)", help="Handle certificates (default: true)",
default=True, default=default_values.with_certificates,
dest="with_certificates", dest="with_certificates",
) )
parser.add_argument(
"--certbot",
action=argparse.BooleanOptionalAction,
help="Use Certbot (default: true)",
default=True,
dest="with_certbot",
)
__add_arg_str( __add_arg_str(
parser, parser,
"--self-signed-path", "--self-signed-path",
env_var="SELF_SIGNED_PATH", env_var="SELF_SIGNED_PATH",
default="./data/.certificates", default=default_values.self_signed_path,
help_txt="Self-signed certificates dir", help_txt="Self-signed certificates dir",
) )
parser.add_argument(
"--certbot",
action=argparse.BooleanOptionalAction,
help="Use Certbot (default: true)",
default=default_values.with_certbot,
dest="with_certbot",
)
__add_arg_str( __add_arg_str(
parser, parser,
"--certbot-conf", "--certbot-conf",
env_var="CERTBOT_CONF", env_var="CERTBOT_CONF",
default="/etc/letsencrypt", default=default_values.certbot_conf,
help_txt="Certbot config dir", help_txt="Certbot config dir",
) )
__add_arg_str( __add_arg_str(
parser, parser,
"--certbot-www", "--certbot-www",
env_var="CERTBOT_WWW", env_var="CERTBOT_WWW",
default="./data/.certbot", default=default_values.certbot_www,
help_txt="Certbot www dir", help_txt="Certbot www dir",
) )
__add_arg_str( __add_arg_str(
parser, parser,
"--host", "--host",
env_var="HOST", env_var="HOST",
default="localhost", default=default_values.host,
help_txt="server default host", help_txt="server default host",
) )
__add_arg_int( __add_arg_int(
parser, parser,
"--http-port", "--http-port",
env_var="HTTP_PORT", env_var="HTTP_PORT",
default=80, default=default_values.http_port,
help_txt="server http port", help_txt="server http port",
) )
__add_arg_int( __add_arg_int(
parser, parser,
"--https-port", "--https-port",
env_var="HTTPS_PORT", env_var="HTTPS_PORT",
default=443, default=default_values.https_port,
help_txt="server https port", help_txt="server https port",
) )
parser.add_argument( parser.add_argument(
"--https", "--https",
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
help="Use https (implies --certificates) (default: true)", help="Use https (implies --certificates) (default: true)",
default=True, default=default_values.https,
) )
__add_arg_str( __add_arg_str(
parser, parser,
"-t", "-t",
"--token-salt", "--token-salt",
env_var="TOKEN_SALT", env_var="TOKEN_SALT",
default="", default=default_values.token_salt,
help_txt="salt for tokens generation", help_txt="salt for tokens generation",
) )
__add_arg_int( __add_arg_int(
parser, parser,
"--max-size-bytes", "--max-size-bytes",
env_var="MAX_SIZE", env_var="MAX_SIZE",
default=2_000_000, default=default_values.max_size_bytes,
help_txt="max size of accepted archives (in bytes)", help_txt="max size of accepted archives (in bytes)",
) )
__add_arg_str( __add_arg_str(
@@ -188,14 +176,14 @@ def parse_parameters() -> Parameters:
"-b", "-b",
"--bind", "--bind",
env_var="BIND", env_var="BIND",
default="0.0.0.0", default=default_values.bind,
help_txt="server bind address", help_txt="server bind address",
) )
subparsers = parser.add_subparsers(dest="command", required=True, metavar="COMMAND") subparsers = parser.add_subparsers(dest="command", required=True, metavar="COMMAND")
subparsers.add_parser("run", help="Run Stapler server") subparsers.add_parser("run", help="Run Stapler server")
subparsers.add_parser("renew", help="Renew certificates") subparsers.add_parser("renew", help="Renew certificates")
subparsers.add_parser("token", help="Generate a new token") subparsers.add_parser("token", help="Generate a new token")
args = parser.parse_args() parsed_args = parser.parse_args(args)
if args.https: if parsed_args.https:
args.with_certificates = True parsed_args.with_certificates = True
return Parameters.from_namespace(args) return Parameters.from_namespace(parsed_args)
+37
View File
@@ -0,0 +1,37 @@
from src.params import Parameters, parse_parameters
from . import BaseTestCase
class TestParams(BaseTestCase):
ENV_COUNT = 10
def test_parse_parameters(self) -> None:
with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT):
params = parse_parameters(["run"])
self.assertEqual(params, Parameters())
def test_parse_parameters_with_implied_certificates(self) -> None:
with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT):
params = parse_parameters(["--no-certificates", "--https", "run"])
assert params.with_certificates
def test_parse_parameters_without_implied_certificates(self) -> None:
with self.patch("os.getenv", return_value=None, count=self.ENV_COUNT):
params = parse_parameters(["--no-certificates", "--no-https", "run"])
assert not params.with_certificates
def test_parse_parameters_with_env_var(self) -> None:
with self.patch("os.getenv", return_value="127.0.0.1", count=self.ENV_COUNT):
params = parse_parameters(["run"])
self.assertEqual(params.bind, "127.0.0.1")
def test_parse_parameters_with_env_var_int(self) -> None:
with self.patch("os.getenv", return_value="127", count=self.ENV_COUNT):
params = parse_parameters(["run"])
self.assertEqual(params.http_port, 127)
def test_parse_parameters_with_invalid_env_var_int(self) -> None:
with self.patch("os.getenv", return_value="aaa", count=self.ENV_COUNT):
params = parse_parameters(["run"])
self.assertEqual(params.http_port, 80)