diff --git a/tests/unit/miniscord/test_bot.py b/tests/unit/miniscord/test_bot.py index f2eb339..399f0cb 100644 --- a/tests/unit/miniscord/test_bot.py +++ b/tests/unit/miniscord/test_bot.py @@ -1,6 +1,6 @@ from unittest import TestCase, skip -from unittest.mock import Mock, MagicMock, AsyncMock -from tests.utils import AsyncTestCase +from unittest.mock import AsyncMock +from tests.utils import AsyncTestCase, patch_discord import discord from datetime import datetime @@ -8,8 +8,8 @@ from miniscord._bot import Bot class TestInit(TestCase): + @patch_discord def test_normal(self): - discord.Client = Mock() bot = Bot("app_name", "version") self.assertEqual("app_name", bot.app_name) self.assertEqual("version", bot.version) @@ -18,8 +18,8 @@ class TestInit(TestCase): self.assertEqual(2, len(bot._Bot__commands)) self.assertEqual(2, len(bot.games)) + @patch_discord def test_alias(self): - discord.Client = Mock() bot = Bot("app_name", "version", alias="alias") self.assertEqual("app_name", bot.app_name) self.assertEqual("version", bot.version) @@ -30,8 +30,8 @@ class TestInit(TestCase): class TestInfo(AsyncTestCase): + @patch_discord def test(self): - discord.Client = Mock() bot = Bot("app_name", "version") message = AsyncMock() t0 = datetime.now() @@ -48,8 +48,8 @@ class TestInfo(AsyncTestCase): class TestHelp(AsyncTestCase): + @patch_discord def test_list_minimal(self): - discord.Client = Mock() bot = Bot("app_name", "version") message = AsyncMock() self._await(bot.help(None, message, "help")) @@ -61,8 +61,8 @@ class TestHelp(AsyncTestCase): f"```" ) + @patch_discord def test_list_alias(self): - discord.Client = Mock() bot = Bot("app_name", "version", alias="ยก") message = AsyncMock() self._await(bot.help(None, message, "help")) @@ -74,8 +74,8 @@ class TestHelp(AsyncTestCase): f"```" ) + @patch_discord def test_list_functions(self): - discord.Client = Mock() bot = Bot("app_name", "version") bot.register_command("", None, "test1: desc1", None) bot.register_command("", None, "test2: desc2", None) @@ -91,16 +91,16 @@ class TestHelp(AsyncTestCase): f"```" ) + @patch_discord def test_long(self): - discord.Client = Mock() bot = Bot("app_name", "version") bot.register_command("test1", None, None, "long desc") message = AsyncMock() self._await(bot.help(None, message, "help", "test1")) message.channel.send.assert_awaited_once_with("long desc") + @patch_discord def test_long_regex(self): - discord.Client = Mock() bot = Bot("app_name", "version") bot.register_command("test", None, None, "desc1") bot.register_command("t.*", None, None, "desc2") @@ -110,9 +110,28 @@ class TestHelp(AsyncTestCase): class TestRegisterCommand(TestCase): - @skip - def test_todo(self): - self.fail("not implemented") + @patch_discord + def test_normal(self): + bot = Bot("app_name", "version") + self.assertEqual(2, len(bot._Bot__commands)) + + def fn(): + pass + + bot.register_command("^t[eo]a?st$", fn, "short", "long") + self.assertEqual(3, len(bot._Bot__commands)) + cmd = bot._Bot__commands[0] + self.assertEqual("^t[eo]a?st$", cmd.regex) + self.assertEqual(fn, cmd.compute) + self.assertEqual("short", cmd.help_short) + self.assertEqual("long", cmd.help_long) + + @patch_discord + def test_add_regex(self): + bot = Bot("app_name", "version") + bot.register_command("test", None, None, None) + cmd = bot._Bot__commands[0] + self.assertEqual("^test$", cmd.regex) class TestOnMessage(AsyncTestCase): @@ -128,12 +147,30 @@ class TestOnReady(AsyncTestCase): class TestOnGuildJoin(AsyncTestCase): + @patch_discord + def test_no_log(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = None + guild = AsyncMock() + self._await(bot.on_guild_join(guild)) + # nothing + # TODO test normal file path + @skip - def test_todo(self): + def test_log(self): self.fail("not implemented") class TestOnGuildRemove(AsyncTestCase): + @patch_discord + def test_no_log(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = None + guild = AsyncMock() + self._await(bot.on_guild_remove(guild)) + # nothing + # TODO test normal file path + @skip - def test_todo(self): + def test_log(self): self.fail("not implemented") diff --git a/tests/utils.py b/tests/utils.py index 299555d..99f9e57 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ from unittest import TestCase +from unittest.mock import MagicMock, patch import asyncio @@ -12,3 +13,17 @@ class AsyncTestCase(TestCase): def _await(self, fn): return self.loop.run_until_complete(fn) + + +def pass_through(arg): + return arg + + +def patch_discord(test): + def wrapper(*args): + m = MagicMock() + m.event = pass_through + with patch("discord.Client", return_value=m): + test(*args) + + return wrapper