diff --git a/README.md b/README.md index 4eeb270..c38bd7b 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,20 @@ async def message(client: discord.client, message: discord.Message): bot.register_watcher(message) # any message was sent (except this bot messages) ``` +### Registering events + +Register a [discord API event](https://discordpy.readthedocs.io/en/latest/api.html#discord-api-events) + +The function must be exactly named after the event + +```python +async def on_ready() -> bool: + print("on_ready") + return False # if True is returned, prevent miniscord handling of the event + +bot.register_event(on_ready) +``` + ### Game status On starting, the bot will cycle through 2-3 "game status" (under its name as "playing xxx") : @@ -329,6 +343,7 @@ bot.start() # this bot respond to "|help", "|info" and "|hello" ## Versions +* v0.0.3 : custom events handling * v0.0.2 : new answer capability * v0.0.1 : initial version diff --git a/miniscord/_bot.py b/miniscord/_bot.py index 70a4770..abcb50e 100644 --- a/miniscord/_bot.py +++ b/miniscord/_bot.py @@ -58,6 +58,7 @@ class Bot(object): self.__watcher = None self.__commands = [] self.__fallback = None + self.__events = {} self.games = [f"v{version}", lambda: f"{len(self.client.guilds)} guilds"] if self.alias is not None: self.games += [f"{self.alias}help"] @@ -67,10 +68,10 @@ class Bot(object): self.__register_commands() def __register_events(self): - self.on_ready = self.client.event(self.on_ready) - self.on_message = self.client.event(self.on_message) - self.on_guild_join = self.client.event(self.on_guild_join) - self.on_guild_remove = self.client.event(self.on_guild_remove) + self.client.event(self.on_ready) + self.client.event(self.on_message) + self.client.event(self.on_guild_join) + self.client.event(self.on_guild_remove) def __register_commands(self): # register default commands @@ -145,7 +146,15 @@ class Bot(object): mention_author=self.answer_mention, ) - async def on_ready(self): + async def __handle_event(self, event_name: str, args: list) -> bool: + if event_name in self.__events: + return not await self.__events[event_name](*args) + return False + + async def on_ready(self, *args): + if await self.__handle_event("on_ready", args): + return + # Change status logging.info( f"{self.client.user} (v{self.version}) has connected to {len(self.client.guilds)} Discord guilds" @@ -162,7 +171,10 @@ class Bot(object): ) await asyncio.sleep(self.game_change_delay) - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message, *args): + if await self.__handle_event("on_message", [message, *args]): + return + if message.author == self.client.user: return # Ignore self messages @@ -227,16 +239,29 @@ class Bot(object): if not command_found and self.__fallback is not None: await self.__fallback(self.client, message, *command_args) - async def on_guild_join(self, guild: discord.guild): + async def on_guild_join(self, guild: discord.guild, *args): + if await self.__handle_event("on_guild_join", [guild, *args]): + return + if self.guild_logs_file is not None: with open(self.guild_logs_file, encoding="utf-8", mode="a") as f: f.write(f"{datetime.now():%Y-%m-%d %H:%M} +{guild.id}: {guild.name}\n") - async def on_guild_remove(self, guild: discord.guild): + async def on_guild_remove(self, guild: discord.guild, *args): + if await self.__handle_event("on_guild_remove", [guild, *args]): + return + if self.guild_logs_file is not None: with open(self.guild_logs_file, encoding="utf-8", mode="a") as f: f.write(f"{datetime.now():%Y-%m-%d %H:%M} -{guild.id}: {guild.name}\n") + def register_event(self, event_callback: Callable): + event_name = event_callback.__name__ + if event_name in dir(self): + self.__events[event_name] = event_callback + else: + self.client.event(event_callback) + def register_command( self, regex: str, compute: CommandFunction, help_short: str, help_long: str ): diff --git a/setup.py b/setup.py index ef44774..891a0aa 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r") as fh: setuptools.setup( name="miniscord-Klemek", - version="0.0.2", + version="0.0.3", author="Klemek", description="A minimalist discord bot API", long_description=long_description, diff --git a/tests/unit/miniscord/test_bot.py b/tests/unit/miniscord/test_bot.py index 0bc8e9e..3a4478f 100644 --- a/tests/unit/miniscord/test_bot.py +++ b/tests/unit/miniscord/test_bot.py @@ -45,7 +45,9 @@ class TestInfo(AsyncTestCase): f"app_name vversion\n" f"* Started at {t0:%Y-%m-%d %H:%M}\n" f"* Connected to 3 guilds\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @@ -60,7 +62,9 @@ class TestHelp(AsyncTestCase): f"List of available commands:\n" f"* info: show description\n" f"* help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -73,7 +77,9 @@ class TestHelp(AsyncTestCase): f"List of available commands:\n" f"* ¡info: show description\n" f"* ¡help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -90,7 +96,9 @@ class TestHelp(AsyncTestCase): f"* test1: desc1\n" f"* info: show description\n" f"* help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -99,7 +107,9 @@ class TestHelp(AsyncTestCase): bot.register_command("test1", None, None, "long desc") message = AsyncMock() self._await(bot.help(None, message, "help", "test1")) - message.channel.send.assert_awaited_once_with("long desc") + message.channel.send.assert_awaited_once_with( + "long desc", reference=message, mention_author=False + ) @patch_discord def test_long_regex(self): @@ -108,7 +118,15 @@ class TestHelp(AsyncTestCase): bot.register_command("t.*", None, None, "desc2") message = AsyncMock() self._await(bot.help(None, message, "help", "test")) - message.channel.send.assert_awaited_once_with("desc2") + message.channel.send.assert_awaited_once_with( + "desc2", reference=message, mention_author=False + ) + + +class TestRegisterEvent(TestCase): + @skip + def test_todo(self): + self.fail("not implemented") class TestRegisterCommand(TestCase): @@ -162,7 +180,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_no_command(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -174,13 +192,15 @@ class TestOnMessage(AsyncTestCase): self._await(bot.on_message(message)) simple_callback.assert_not_awaited() watcher_callback.assert_awaited_once_with(bot.client, message) - fallback_callback.assert_awaited_once_with(bot.client, message, "testt", "arg0", "arg1") + fallback_callback.assert_awaited_once_with( + bot.client, message, "testt", "arg0", "arg1" + ) @patch_discord def test_mention_no_command_empty(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -198,7 +218,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_command_simple(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -208,7 +228,9 @@ class TestOnMessage(AsyncTestCase): message = AsyncMock() message.content = "<@12345> test arg0 arg1" self._await(bot.on_message(message)) - simple_callback.assert_awaited_once_with(bot.client, message, "test", "arg0", "arg1") + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "arg0", "arg1" + ) watcher_callback.assert_awaited_once_with(bot.client, message) fallback_callback.assert_not_awaited() @@ -216,7 +238,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_command_regex(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" regex_callback = AsyncMock() bot.register_command("^t[eo]a?st$", regex_callback, "short", "long") watcher_callback = AsyncMock() @@ -232,7 +254,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_alias_no_command(self): - bot = Bot("app_name", "version", alias='|') + bot = Bot("app_name", "version", alias="|") bot.enforce_write_permission = False simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") @@ -249,7 +271,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_alias_command_simple(self): - bot = Bot("app_name", "version", alias='|') + bot = Bot("app_name", "version", alias="|") bot.enforce_write_permission = False simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") @@ -267,7 +289,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_no_permission(self): bot = Bot("app_name", "version") - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -276,19 +298,19 @@ class TestOnMessage(AsyncTestCase): bot.register_fallback(fallback_callback) message = AsyncMock() message.content = "<@12345> test hey" - message.channel.__repr__ = lambda *a:'test_channel' - message.guild.__repr__ = lambda *a:'test_guild' + message.channel.__repr__ = lambda *a: "test_channel" + message.guild.__repr__ = lambda *a: "test_guild" permissions = AsyncMock() permissions.send_messages = False - message.channel.permissions_for = lambda u:permissions + message.channel.permissions_for = lambda u: permissions self._await(bot.on_message(message)) simple_callback.assert_not_awaited() watcher_callback.assert_awaited_once_with(bot.client, message) fallback_callback.assert_not_awaited() message.author.create_dm.assert_awaited_once() message.author.dm_channel.send.assert_awaited_once_with( - f"Hi, this bot doesn\'t have the permission to send a message to" - f" #test_channel in server 'test_guild'" + f"Hi, this bot doesn't have the permission to send a message to" + f" #test_channel in server 'test_guild'" ) @skip @@ -311,6 +333,15 @@ class TestOnMessage(AsyncTestCase): def test_lower_command_names(self): self.fail("not implemented") + @skip + def test_fire_registered_event(self): + self.fail("not implemented") + + @skip + def test_fire_registered_event_cancel(self): + self.fail("not implemented") + + class TestOnReady(AsyncTestCase): LOG_PATH = "guilds.log" @@ -337,8 +368,7 @@ class TestOnReady(AsyncTestCase): except Exception as error: self.assertEqual(ex, error) client_mock.change_presence.assert_called_with( - activity="activity", - status=discord.Status.online + activity="activity", status=discord.Status.online ) @patch_discord_arg @@ -359,8 +389,10 @@ class TestOnReady(AsyncTestCase): except: pass with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} +id1: name1\n" - f"{d:%Y-%m-%d %H:%M} +id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} +id1: name1\n" f"{d:%Y-%m-%d %H:%M} +id2: name2\n", + f.read(), + ) @patch_discord_arg def test_log_exists(self, client_mock): @@ -381,6 +413,14 @@ class TestOnReady(AsyncTestCase): with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: self.assertEqual(f"test", f.read()) + @skip + def test_fire_registered_event(self): + self.fail("not implemented") + + @skip + def test_fire_registered_event_cancel(self): + self.fail("not implemented") + class TestOnGuildJoin(AsyncTestCase): LOG_PATH = "guilds.log" @@ -418,8 +458,18 @@ class TestOnGuildJoin(AsyncTestCase): self._await(bot.on_guild_join(guild2)) self.assertTrue(path.exists(self.LOG_PATH)) with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} +id1: name1\n" - f"{d:%Y-%m-%d %H:%M} +id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} +id1: name1\n" f"{d:%Y-%m-%d %H:%M} +id2: name2\n", + f.read(), + ) + + @skip + def test_fire_registered_event(self): + self.fail("not implemented") + + @skip + def test_fire_registered_event_cancel(self): + self.fail("not implemented") class TestOnGuildRemove(AsyncTestCase): @@ -458,5 +508,15 @@ class TestOnGuildRemove(AsyncTestCase): self._await(bot.on_guild_remove(guild2)) self.assertTrue(path.exists(self.LOG_PATH)) with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} -id1: name1\n" - f"{d:%Y-%m-%d %H:%M} -id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} -id1: name1\n" f"{d:%Y-%m-%d %H:%M} -id2: name2\n", + f.read(), + ) + + @skip + def test_fire_registered_event(self): + self.fail("not implemented") + + @skip + def test_fire_registered_event_cancel(self): + self.fail("not implemented")