From c186668208556a9c4283ecd36b3232f6e112437a Mon Sep 17 00:00:00 2001 From: klemek Date: Wed, 3 Jun 2026 19:51:08 +0200 Subject: [PATCH] fix: check loopback before cert creation --- stapler/cert_manager.py | 18 +++++++++++++++++- tests/test_cert_manager.py | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/stapler/cert_manager.py b/stapler/cert_manager.py index 42109d4..e188147 100644 --- a/stapler/cert_manager.py +++ b/stapler/cert_manager.py @@ -5,6 +5,8 @@ import ssl import subprocess import typing +import requests + from stapler.strings import valid_host if typing.TYPE_CHECKING: @@ -49,6 +51,18 @@ class CertManager: def exists(self, host: str) -> bool: return self.__exists_certbot(host) or self.__exists_self_signed(host) + def valid_host(self, host: str) -> bool: + try: + response = requests.head( + url=f"http://{host}/.well-known/stapler", + allow_redirects=True, + timeout=5, + stream=False, + ) + return type(response.status_code) is int and response.status_code < 400 + except Exception: + return False + def init_cert(self, host: str) -> bool: if not self.exists(host): return self.__create_self_signed(host) @@ -196,7 +210,9 @@ class CertManager: if host is None or not valid_host(host): return None self.logger.debug("servername callback: %s", host) - if not self.exists(host) and not self.create_or_update(host): + if not self.exists(host) and ( + not self.valid_host(host) or not self.create_or_update(host) + ): return None cert_file = self.get_cert(host) key_file = self.get_key(host) diff --git a/tests/test_cert_manager.py b/tests/test_cert_manager.py index 7debbfe..be381c4 100644 --- a/tests/test_cert_manager.py +++ b/tests/test_cert_manager.py @@ -4,6 +4,8 @@ import subprocess import typing import unittest.mock +import requests + from stapler.cert_manager import CertManager, CertManagerError from stapler.params import Parameters @@ -170,10 +172,24 @@ class TestRegistry(BaseTestCase): self.socket_mock, None, self.context_mock ) - def test_servername_callback_fail(self) -> None: + def test_servername_callback_fail_no_valid_host(self) -> None: self._make_self_signed("example.com") + with ( + self.patch("requests.head") as request_mock, + self.patch("ssl.create_default_context", count=0), + ): + request_mock.side_effect = Exception() + self.cert_manager.servername_callback( + self.socket_mock, "example.fr", self.context_mock + ) + + def test_servername_callback_fail_no_binaries(self) -> None: + self._make_self_signed("example.com") + response = requests.Response() + response.status_code = 200 with ( self.patch("shutil.which", count=3), + self.patch("requests.head", response), self.patch("ssl.create_default_context", count=0), ): self.cert_manager.servername_callback(