diff --git a/miniscord/_bot.py b/miniscord/_bot.py index 347887b..c6ed49c 100644 --- a/miniscord/_bot.py +++ b/miniscord/_bot.py @@ -66,18 +66,22 @@ class Bot(object): def __register_commands(self): # register default commands tmp_alias = '' if self.alias is None else self.alias - self.register_command("(help|h)", self.help, "help: show this help", - f"```\n" - f"* {tmp_alias}help\n" - f"\tShows the list of commands.\n" - f"* {tmp_alias}help [command]\n" - f"\tShows help about a specific command.\n" - f"```") - self.register_command("(info|about)", self.info, "info: show description", - f"```\n" - f"* {tmp_alias}info:\n" - f"\tShows this bot's status.\n" - f"```") + self.register_command( + "(help|h)", self.help, "help: show this help", + f"```\n" + f"* {tmp_alias}help\n" + f"\tShows the list of commands.\n" + f"* {tmp_alias}help [command]\n" + f"\tShows help about a specific command.\n" + f"```" + ) + self.register_command( + "(info|about)", self.info, "info: show description", + f"```\n" + f"* {tmp_alias}info:\n" + f"\tShows this bot's status.\n" + f"```" + ) def __generate_game(self) -> str: game = random.choice(self.games) @@ -86,12 +90,14 @@ class Bot(object): else: return game - async def info(self, _client: discord.client, message: discord.Message): - await message.channel.send(f"```\n" - f"{self.app_name} v{self.version}\n" - f"* Started at {self.__t0:%Y-%m-%d %H:%M}\n" - f"* Connected to {len(self.client.guilds)} guilds\n" - f"```") + async def info(self, _client: discord.client, message: discord.Message, *args: str): + await message.channel.send( + f"```\n" + f"{self.app_name} v{self.version}\n" + f"* Started at {self.__t0:%Y-%m-%d %H:%M}\n" + f"* Connected to {len(self.client.guilds)} guilds\n" + f"```" + ) async def help(self, _client: discord.client, message: discord.Message, *args: str): if len(args) <= 1: @@ -100,7 +106,8 @@ class Bot(object): f"```\n" f"List of available commands:\n" + "".join([f"* {tmp_alias}{command.help_short}\n" for command in self.__commands]) + - f"```") + f"```" + ) else: for command in self.__commands: if re.match(command.regex, args[1].lower() if self.lower_command_names else args[1]): @@ -159,7 +166,8 @@ class Bot(object): await message.author.create_dm() await message.author.dm_channel.send( f"Hi, this bot doesn\'t have the permission to send a message to" - f" #{message.channel} in server '{message.guild}'") + f" #{message.channel} in server '{message.guild}'" + ) return await command.compute(self.client, message, *command_args) break @@ -176,6 +184,10 @@ class Bot(object): pass def register_command(self, regex: str, compute: CommandFunction, help_short: str, help_long: str): + if not regex.startswith("^"): + regex = "^" + regex + if not regex.endswith("$"): + regex = regex + "$" self.__commands.insert(0, Command(regex, compute, help_short, help_long)) def start(self): @@ -200,8 +212,10 @@ class Bot(object): self.__last_error = repr(e) filename = f"error_{t:%Y-%m-%d_%H-%M-%S}.txt" with open(filename, 'w') as f: - f.write(f"{self.app_name} v{self.version} started at {self.__t0:%Y-%m-%d %H:%M}\r\n" - f"Exception raised at {t:%Y-%m-%d %H:%M}\r\n" - f"\r\n" - f"{traceback.format_exc()}") + f.write( + f"{self.app_name} v{self.version} started at {self.__t0:%Y-%m-%d %H:%M}\r\n" + f"Exception raised at {t:%Y-%m-%d %H:%M}\r\n" + f"\r\n" + f"{traceback.format_exc()}" + ) time.sleep(self.error_restart_delay) diff --git a/tests/unit/miniscord/test_bot.py b/tests/unit/miniscord/test_bot.py new file mode 100644 index 0000000..f2eb339 --- /dev/null +++ b/tests/unit/miniscord/test_bot.py @@ -0,0 +1,139 @@ +from unittest import TestCase, skip +from unittest.mock import Mock, MagicMock, AsyncMock +from tests.utils import AsyncTestCase + +import discord +from datetime import datetime +from miniscord._bot import Bot + + +class TestInit(TestCase): + def test_normal(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + self.assertEqual("app_name", bot.app_name) + self.assertEqual("version", bot.version) + self.assertIsNone(bot.alias) + discord.Client.assert_called_once() + self.assertEqual(2, len(bot._Bot__commands)) + self.assertEqual(2, len(bot.games)) + + def test_alias(self): + discord.Client = Mock() + bot = Bot("app_name", "version", alias="alias") + self.assertEqual("app_name", bot.app_name) + self.assertEqual("version", bot.version) + self.assertEqual("alias", bot.alias) + discord.Client.assert_called_once() + self.assertEqual(2, len(bot._Bot__commands)) + self.assertEqual(3, len(bot.games)) + + +class TestInfo(AsyncTestCase): + def test(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + message = AsyncMock() + t0 = datetime.now() + bot._Bot__t0 = t0 + bot.client.guilds = [None, None, None] + self._await(bot.info(None, message, "info")) + message.channel.send.assert_awaited_once_with( + f"```\n" + f"app_name vversion\n" + f"* Started at {t0:%Y-%m-%d %H:%M}\n" + f"* Connected to 3 guilds\n" + f"```" + ) + + +class TestHelp(AsyncTestCase): + def test_list_minimal(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + message = AsyncMock() + self._await(bot.help(None, message, "help")) + message.channel.send.assert_awaited_once_with( + f"```\n" + f"List of available commands:\n" + f"* info: show description\n" + f"* help: show this help\n" + f"```" + ) + + def test_list_alias(self): + discord.Client = Mock() + bot = Bot("app_name", "version", alias="¡") + message = AsyncMock() + self._await(bot.help(None, message, "help")) + message.channel.send.assert_awaited_once_with( + f"```\n" + f"List of available commands:\n" + f"* ¡info: show description\n" + f"* ¡help: show this help\n" + f"```" + ) + + def test_list_functions(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + bot.register_command("", None, "test1: desc1", None) + bot.register_command("", None, "test2: desc2", None) + message = AsyncMock() + self._await(bot.help(None, message, "help")) + message.channel.send.assert_awaited_once_with( + f"```\n" + f"List of available commands:\n" + f"* test2: desc2\n" + f"* test1: desc1\n" + f"* info: show description\n" + f"* help: show this help\n" + f"```" + ) + + def test_long(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + bot.register_command("test1", None, None, "long desc") + message = AsyncMock() + self._await(bot.help(None, message, "help", "test1")) + message.channel.send.assert_awaited_once_with("long desc") + + def test_long_regex(self): + discord.Client = Mock() + bot = Bot("app_name", "version") + bot.register_command("test", None, None, "desc1") + bot.register_command("t.*", None, None, "desc2") + message = AsyncMock() + self._await(bot.help(None, message, "help", "test")) + message.channel.send.assert_awaited_once_with("desc2") + + +class TestRegisterCommand(TestCase): + @skip + def test_todo(self): + self.fail("not implemented") + + +class TestOnMessage(AsyncTestCase): + @skip + def test_todo(self): + self.fail("not implemented") + + +class TestOnReady(AsyncTestCase): + @skip + def test_todo(self): + self.fail("not implemented") + + +class TestOnGuildJoin(AsyncTestCase): + @skip + def test_todo(self): + self.fail("not implemented") + + +class TestOnGuildRemove(AsyncTestCase): + @skip + def test_todo(self): + self.fail("not implemented") diff --git a/tests/unit/miniscord/test_discord_utils.py b/tests/unit/miniscord/test_discord_utils.py index 74223df..9a46e55 100644 --- a/tests/unit/miniscord/test_discord_utils.py +++ b/tests/unit/miniscord/test_discord_utils.py @@ -1,47 +1,41 @@ from unittest import TestCase -import asyncio from unittest.mock import Mock, AsyncMock +from tests.utils import AsyncTestCase + import discord from miniscord._discord_utils import delete_message, message_id -class TestDeleteMessage(TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - +class TestDeleteMessage(AsyncTestCase): def test_success(self): - mock = AsyncMock() - self.assertTrue(self.loop.run_until_complete(delete_message(mock))) - mock.delete.assert_awaited_once() + message = AsyncMock() + self.assertTrue(self._await(delete_message(message))) + message.delete.assert_awaited_once() def test_forbidden(self): - mock = AsyncMock() - mock.delete.side_effect = discord.Forbidden(Mock(), "") - self.assertFalse(self.loop.run_until_complete(delete_message(mock))) - mock.delete.assert_awaited_once() + message = AsyncMock() + message.delete.side_effect = discord.Forbidden(Mock(), "") + self.assertFalse(self._await(delete_message(message))) + message.delete.assert_awaited_once() def test_not_found(self): - mock = AsyncMock() - mock.delete.side_effect = discord.NotFound(Mock(), "") - self.assertFalse(self.loop.run_until_complete(delete_message(mock))) - mock.delete.assert_awaited_once() + message = AsyncMock() + message.delete.side_effect = discord.NotFound(Mock(), "") + self.assertFalse(self._await(delete_message(message))) + message.delete.assert_awaited_once() class TestMessageId(TestCase): def test_direct(self): - mock = Mock() - mock.channel.type = discord.ChannelType.private - mock.author.id = "TEST" - self.assertEqual("TEST", message_id(mock)) + message = Mock() + message.channel.type = discord.ChannelType.private + message.author.id = "TEST" + self.assertEqual("TEST", message_id(message)) def test_not_direct(self): - mock = Mock() - mock.channel.type = discord.ChannelType.text - mock.guild.id = "TEST1" - mock.channel.id = "TEST2" - mock.author.id = "TEST3" - self.assertEqual("TEST1/TEST2/TEST3", message_id(mock)) + message = Mock() + message.channel.type = discord.ChannelType.text + message.guild.id = "TEST1" + message.channel.id = "TEST2" + message.author.id = "TEST3" + self.assertEqual("TEST1/TEST2/TEST3", message_id(message)) diff --git a/tests/unit/miniscord/test_utils.py b/tests/unit/miniscord/test_utils.py index 75674c5..257a079 100644 --- a/tests/unit/miniscord/test_utils.py +++ b/tests/unit/miniscord/test_utils.py @@ -1,4 +1,5 @@ from unittest import TestCase + from miniscord._utils import sanitize_input, parse_arguments diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..299555d --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,14 @@ +from unittest import TestCase +import asyncio + + +class AsyncTestCase(TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def _await(self, fn): + return self.loop.run_until_complete(fn)