import contextlib import pathlib import tempfile import typing import unittest import unittest.mock __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 # ty:ignore[unresolved-attribute] # noqa: SLF001 class BaseTestCase(unittest.TestCase): @typing.override def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: self.mocks: list[unittest.mock.Mock] = [] self.tmp_dir: tempfile.TemporaryDirectory | None = None self.tmp_path: pathlib.Path = pathlib.Path() super().__init__(*args, **kwargs) @typing.override def tearDown(self) -> None: if self.tmp_dir is not None: self.tmp_dir.cleanup() self.tmp_dir = None super().tearDown() def get_tmp_dir(self) -> str: self.tmp_dir = tempfile.TemporaryDirectory(delete=False) self.tmp_path = pathlib.Path(self.tmp_dir.name) return self.tmp_dir.name def new_mock(self, spec: type | None = None) -> unittest.mock.Mock: mock = unittest.mock.Mock(spec) self.mocks += [mock] return mock @contextlib.contextmanager def patch( self, target: str, return_value: typing.Any = None, count: int = 1 ) -> typing.Iterator[unittest.mock.Mock]: with unittest.mock.patch( target, return_value=return_value, create=True ) as mock: yield mock self.assertEqual(mock.call_count, count, mock) @contextlib.contextmanager def patch_calls( self, target: str, args: list[typing.Iterable[typing.Any]] | None = None, return_values: list[typing.Any] | None = None, kwargs: list[dict[str, typing.Any]] | None = None, ) -> typing.Iterator[unittest.mock.Mock]: if args is None: args = [[]] if return_values is None: return_values = [None] * len(args) if kwargs is None: kwargs = [{}] * len(args) with unittest.mock.patch( target, side_effect=return_values, create=True ) as mock: yield mock self.__check_calls(mock, args, kwargs) @contextlib.contextmanager def patch_call( self, target: str, args: typing.Iterable[typing.Any] | None = None, return_value: typing.Any = None, kwargs: dict[str, typing.Any] | None = None, ) -> typing.Iterator[unittest.mock.Mock]: if args is None: args = [] if kwargs is None: kwargs = {} with self.patch_calls(target, [args], [return_value], [kwargs]) as mock: yield mock @contextlib.contextmanager def seal_mocks(self, *extra_mocks: unittest.mock.Mock) -> typing.Iterator[None]: for mock in self.mocks: unittest.mock.seal(mock) for mock in extra_mocks: unittest.mock.seal(mock) yield @contextlib.contextmanager def mock_calls( self, mock: unittest.mock.Mock, args: list[typing.Iterable[typing.Any]] | None = None, return_values: list[typing.Any] | None = None, kwargs: list[dict[str, typing.Any]] | None = None, ) -> typing.Iterator[None]: if args is None: args = [[]] if return_values is None: return_values = [None] * len(args) if kwargs is None: kwargs = [{}] * len(args) mock.side_effect = return_values mock.reset_mock() yield self.__check_calls(mock, args, kwargs) @contextlib.contextmanager def mock_call( self, mock: unittest.mock.Mock, args: typing.Iterable[typing.Any] | None = None, return_value: typing.Any = None, kwargs: dict[str, typing.Any] | None = None, ) -> typing.Iterator[None]: if args is None: args = [] if kwargs is None: kwargs = {} with self.mock_calls(mock, [args], [return_value], [kwargs]): yield @contextlib.contextmanager def mock_calls_unchecked( self, mock: unittest.mock.Mock, count: int = 1, return_values: list[typing.Any] | None = None, ) -> typing.Iterator[None]: if return_values is None: return_values = [None] * count mock.side_effect = return_values mock.reset_mock() yield self.assertEqual(mock.call_count, count, mock) @contextlib.contextmanager def mock_call_unchecked( self, mock: unittest.mock.Mock, return_value: typing.Any = None, ) -> typing.Iterator[None]: with self.mock_calls_unchecked(mock, 1, [return_value]): yield def assert_file_content(self, file: pathlib.Path, *expected_content: str) -> None: assert file.parent.is_dir(), file assert file.exists(), file assert file.is_file(), file with file.open() as file_content: self.assertEqual(file_content.read(), "\n".join(expected_content)) def __check_calls( self, mock: unittest.mock.Mock, args: list[typing.Iterable[typing.Any]], kwargs: list[dict[str, typing.Any]], ) -> None: total_rows = max(len(args), len(mock.method_calls), len(kwargs)) missing_calls = max(0, total_rows - len(mock.mock_calls)) missing_args = max(0, total_rows - len(args)) missing_kwargs = max(0, total_rows - len(kwargs)) for i, values in enumerate( zip( mock.mock_calls + [None] * missing_calls, args + [[]] * missing_args, kwargs + [{}] * missing_kwargs, strict=False, ) ): real_call, expected_args, expected_kwargs = values self.assertEqual( real_call, unittest.mock.call(*expected_args, **expected_kwargs), f"{i + 1}: {mock}", )