177 lines
5.8 KiB
Python
177 lines
5.8 KiB
Python
import contextlib
|
|
import pathlib
|
|
import tempfile
|
|
import typing
|
|
import unittest
|
|
import unittest.mock
|
|
|
|
# Raise unittest's private repr-truncation limit (unittest.util._MAX_LENGTH)
# to a huge value so assertion failure messages show full values rather than
# shortened reprs. The trailing directives suppress the type checker's
# unresolved-attribute error and ruff's private-member-access rule (SLF001).
__import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 # ty:ignore[unresolved-attribute] # noqa: SLF001
|
|
|
|
|
|
class BaseTestCase(unittest.TestCase):
|
|
@typing.override
|
|
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
|
|
self.mocks: list[unittest.mock.Mock] = []
|
|
self.tmp_dir: tempfile.TemporaryDirectory | None = None
|
|
self.tmp_path: pathlib.Path = pathlib.Path()
|
|
super().__init__(*args, **kwargs)
|
|
|
|
@typing.override
|
|
def tearDown(self) -> None:
|
|
if self.tmp_dir is not None:
|
|
self.tmp_dir.cleanup()
|
|
self.tmp_dir = None
|
|
super().tearDown()
|
|
|
|
def get_tmp_dir(self) -> str:
|
|
self.tmp_dir = tempfile.TemporaryDirectory(delete=False)
|
|
self.tmp_path = pathlib.Path(self.tmp_dir.name)
|
|
return self.tmp_dir.name
|
|
|
|
def new_mock(self, spec: type | None = None) -> unittest.mock.Mock:
|
|
mock = unittest.mock.Mock(spec)
|
|
self.mocks += [mock]
|
|
return mock
|
|
|
|
@contextlib.contextmanager
|
|
def patch(
|
|
self, target: str, return_value: typing.Any = None, count: int = 1
|
|
) -> typing.Iterator[unittest.mock.Mock]:
|
|
with unittest.mock.patch(
|
|
target, return_value=return_value, create=True
|
|
) as mock:
|
|
yield mock
|
|
self.assertEqual(mock.call_count, count, mock)
|
|
|
|
@contextlib.contextmanager
|
|
def patch_calls(
|
|
self,
|
|
target: str,
|
|
args: list[typing.Iterable[typing.Any]] | None = None,
|
|
return_values: list[typing.Any] | None = None,
|
|
kwargs: list[dict[str, typing.Any]] | None = None,
|
|
) -> typing.Iterator[unittest.mock.Mock]:
|
|
if args is None:
|
|
args = [[]]
|
|
if return_values is None:
|
|
return_values = [None] * len(args)
|
|
if kwargs is None:
|
|
kwargs = [{}] * len(args)
|
|
with unittest.mock.patch(
|
|
target, side_effect=return_values, create=True
|
|
) as mock:
|
|
yield mock
|
|
self.__check_calls(mock, args, kwargs)
|
|
|
|
@contextlib.contextmanager
|
|
def patch_call(
|
|
self,
|
|
target: str,
|
|
args: typing.Iterable[typing.Any] | None = None,
|
|
return_value: typing.Any = None,
|
|
kwargs: dict[str, typing.Any] | None = None,
|
|
) -> typing.Iterator[unittest.mock.Mock]:
|
|
if args is None:
|
|
args = []
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
with self.patch_calls(target, [args], [return_value], [kwargs]) as mock:
|
|
yield mock
|
|
|
|
@contextlib.contextmanager
|
|
def seal_mocks(self, *extra_mocks: unittest.mock.Mock) -> typing.Iterator[None]:
|
|
for mock in self.mocks:
|
|
unittest.mock.seal(mock)
|
|
for mock in extra_mocks:
|
|
unittest.mock.seal(mock)
|
|
yield
|
|
|
|
@contextlib.contextmanager
|
|
def mock_calls(
|
|
self,
|
|
mock: unittest.mock.Mock,
|
|
args: list[typing.Iterable[typing.Any]] | None = None,
|
|
return_values: list[typing.Any] | None = None,
|
|
kwargs: list[dict[str, typing.Any]] | None = None,
|
|
) -> typing.Iterator[None]:
|
|
if args is None:
|
|
args = [[]]
|
|
if return_values is None:
|
|
return_values = [None] * len(args)
|
|
if kwargs is None:
|
|
kwargs = [{}] * len(args)
|
|
mock.side_effect = return_values
|
|
mock.reset_mock()
|
|
yield
|
|
self.__check_calls(mock, args, kwargs)
|
|
|
|
@contextlib.contextmanager
|
|
def mock_call(
|
|
self,
|
|
mock: unittest.mock.Mock,
|
|
args: typing.Iterable[typing.Any] | None = None,
|
|
return_value: typing.Any = None,
|
|
kwargs: dict[str, typing.Any] | None = None,
|
|
) -> typing.Iterator[None]:
|
|
if args is None:
|
|
args = []
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
with self.mock_calls(mock, [args], [return_value], [kwargs]):
|
|
yield
|
|
|
|
@contextlib.contextmanager
|
|
def mock_calls_unchecked(
|
|
self,
|
|
mock: unittest.mock.Mock,
|
|
count: int = 1,
|
|
return_values: list[typing.Any] | None = None,
|
|
) -> typing.Iterator[None]:
|
|
if return_values is None:
|
|
return_values = [None] * count
|
|
mock.side_effect = return_values
|
|
mock.reset_mock()
|
|
yield
|
|
self.assertEqual(mock.call_count, count, mock)
|
|
|
|
@contextlib.contextmanager
|
|
def mock_call_unchecked(
|
|
self,
|
|
mock: unittest.mock.Mock,
|
|
return_value: typing.Any = None,
|
|
) -> typing.Iterator[None]:
|
|
with self.mock_calls_unchecked(mock, 1, [return_value]):
|
|
yield
|
|
|
|
def assert_file_content(self, file: pathlib.Path, *expected_content: str) -> None:
|
|
assert file.parent.is_dir(), file
|
|
assert file.exists(), file
|
|
assert file.is_file(), file
|
|
with file.open() as file_content:
|
|
self.assertEqual(file_content.read(), "\n".join(expected_content))
|
|
|
|
def __check_calls(
|
|
self,
|
|
mock: unittest.mock.Mock,
|
|
args: list[typing.Iterable[typing.Any]],
|
|
kwargs: list[dict[str, typing.Any]],
|
|
) -> None:
|
|
total_rows = max(len(args), len(mock.method_calls), len(kwargs))
|
|
missing_calls = max(0, total_rows - len(mock.mock_calls))
|
|
missing_args = max(0, total_rows - len(args))
|
|
missing_kwargs = max(0, total_rows - len(kwargs))
|
|
for i, values in enumerate(
|
|
zip(
|
|
mock.mock_calls + [None] * missing_calls,
|
|
args + [[]] * missing_args,
|
|
kwargs + [{}] * missing_kwargs,
|
|
strict=False,
|
|
)
|
|
):
|
|
real_call, expected_args, expected_kwargs = values
|
|
self.assertEqual(
|
|
real_call,
|
|
unittest.mock.call(*expected_args, **expected_kwargs),
|
|
f"{i + 1}: {mock}",
|
|
)
|