More tests

This commit is contained in:
klemek
2020-09-06 12:33:49 +02:00
parent 8934acb004
commit c14993653e
2 changed files with 67 additions and 15 deletions
+52 -15
View File
@@ -1,6 +1,6 @@
from unittest import TestCase, skip from unittest import TestCase, skip
from unittest.mock import Mock, MagicMock, AsyncMock from unittest.mock import AsyncMock
from tests.utils import AsyncTestCase from tests.utils import AsyncTestCase, patch_discord
import discord import discord
from datetime import datetime from datetime import datetime
@@ -8,8 +8,8 @@ from miniscord._bot import Bot
class TestInit(TestCase): class TestInit(TestCase):
@patch_discord
def test_normal(self): def test_normal(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
self.assertEqual("app_name", bot.app_name) self.assertEqual("app_name", bot.app_name)
self.assertEqual("version", bot.version) self.assertEqual("version", bot.version)
@@ -18,8 +18,8 @@ class TestInit(TestCase):
self.assertEqual(2, len(bot._Bot__commands)) self.assertEqual(2, len(bot._Bot__commands))
self.assertEqual(2, len(bot.games)) self.assertEqual(2, len(bot.games))
@patch_discord
def test_alias(self): def test_alias(self):
discord.Client = Mock()
bot = Bot("app_name", "version", alias="alias") bot = Bot("app_name", "version", alias="alias")
self.assertEqual("app_name", bot.app_name) self.assertEqual("app_name", bot.app_name)
self.assertEqual("version", bot.version) self.assertEqual("version", bot.version)
@@ -30,8 +30,8 @@ class TestInit(TestCase):
class TestInfo(AsyncTestCase): class TestInfo(AsyncTestCase):
@patch_discord
def test(self): def test(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
message = AsyncMock() message = AsyncMock()
t0 = datetime.now() t0 = datetime.now()
@@ -48,8 +48,8 @@ class TestInfo(AsyncTestCase):
class TestHelp(AsyncTestCase): class TestHelp(AsyncTestCase):
@patch_discord
def test_list_minimal(self): def test_list_minimal(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
message = AsyncMock() message = AsyncMock()
self._await(bot.help(None, message, "help")) self._await(bot.help(None, message, "help"))
@@ -61,8 +61,8 @@ class TestHelp(AsyncTestCase):
f"```" f"```"
) )
@patch_discord
def test_list_alias(self): def test_list_alias(self):
discord.Client = Mock()
bot = Bot("app_name", "version", alias="¡") bot = Bot("app_name", "version", alias="¡")
message = AsyncMock() message = AsyncMock()
self._await(bot.help(None, message, "help")) self._await(bot.help(None, message, "help"))
@@ -74,8 +74,8 @@ class TestHelp(AsyncTestCase):
f"```" f"```"
) )
@patch_discord
def test_list_functions(self): def test_list_functions(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
bot.register_command("", None, "test1: desc1", None) bot.register_command("", None, "test1: desc1", None)
bot.register_command("", None, "test2: desc2", None) bot.register_command("", None, "test2: desc2", None)
@@ -91,16 +91,16 @@ class TestHelp(AsyncTestCase):
f"```" f"```"
) )
@patch_discord
def test_long(self): def test_long(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
bot.register_command("test1", None, None, "long desc") bot.register_command("test1", None, None, "long desc")
message = AsyncMock() message = AsyncMock()
self._await(bot.help(None, message, "help", "test1")) self._await(bot.help(None, message, "help", "test1"))
message.channel.send.assert_awaited_once_with("long desc") message.channel.send.assert_awaited_once_with("long desc")
@patch_discord
def test_long_regex(self): def test_long_regex(self):
discord.Client = Mock()
bot = Bot("app_name", "version") bot = Bot("app_name", "version")
bot.register_command("test", None, None, "desc1") bot.register_command("test", None, None, "desc1")
bot.register_command("t.*", None, None, "desc2") bot.register_command("t.*", None, None, "desc2")
@@ -110,9 +110,28 @@ class TestHelp(AsyncTestCase):
class TestRegisterCommand(TestCase): class TestRegisterCommand(TestCase):
@skip @patch_discord
def test_todo(self): def test_normal(self):
self.fail("not implemented") bot = Bot("app_name", "version")
self.assertEqual(2, len(bot._Bot__commands))
def fn():
pass
bot.register_command("^t[eo]a?st$", fn, "short", "long")
self.assertEqual(3, len(bot._Bot__commands))
cmd = bot._Bot__commands[0]
self.assertEqual("^t[eo]a?st$", cmd.regex)
self.assertEqual(fn, cmd.compute)
self.assertEqual("short", cmd.help_short)
self.assertEqual("long", cmd.help_long)
@patch_discord
def test_add_regex(self):
bot = Bot("app_name", "version")
bot.register_command("test", None, None, None)
cmd = bot._Bot__commands[0]
self.assertEqual("^test$", cmd.regex)
class TestOnMessage(AsyncTestCase): class TestOnMessage(AsyncTestCase):
@@ -128,12 +147,30 @@ class TestOnReady(AsyncTestCase):
class TestOnGuildJoin(AsyncTestCase): class TestOnGuildJoin(AsyncTestCase):
@patch_discord
def test_no_log(self):
bot = Bot("app_name", "version")
bot.guild_logs_file = None
guild = AsyncMock()
self._await(bot.on_guild_join(guild))
# nothing
# TODO test normal file path
@skip @skip
def test_todo(self): def test_log(self):
self.fail("not implemented") self.fail("not implemented")
class TestOnGuildRemove(AsyncTestCase): class TestOnGuildRemove(AsyncTestCase):
@patch_discord
def test_no_log(self):
bot = Bot("app_name", "version")
bot.guild_logs_file = None
guild = AsyncMock()
self._await(bot.on_guild_remove(guild))
# nothing
# TODO test normal file path
@skip @skip
def test_todo(self): def test_log(self):
self.fail("not implemented") self.fail("not implemented")
+15
View File
@@ -1,4 +1,5 @@
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch
import asyncio import asyncio
@@ -12,3 +13,17 @@ class AsyncTestCase(TestCase):
def _await(self, fn): def _await(self, fn):
return self.loop.run_until_complete(fn) return self.loop.run_until_complete(fn)
def pass_through(arg):
return arg
def patch_discord(test):
def wrapper(*args):
m = MagicMock()
m.event = pass_through
with patch("discord.Client", return_value=m):
test(*args)
return wrapper