diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 0000000..a50177e --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,52 @@ +name: Test + +on: ["push", "pull_request"] + +jobs: + syntax: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8, 3.9] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 black + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Code style with black + run: | + black --check + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8, 3.9] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r test-requirements.txt + pip install -r requirements.txt + - name: Test + run: pytest --cov=miniscord tests/ + - name: Upload coverage data to coveralls.io + run: coveralls + env: + COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index 4eeb270..dd5f426 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +![Python Version >= 3.7](https://img.shields.io/badge/python-%3E=3.7%20-blue) [![Scc Count Badge](https://sloc.xyz/github/klemek/miniscord?category=code)](https://github.com/boyter/scc/#badges-beta) [![Total alerts](https://img.shields.io/lgtm/alerts/g/Klemek/miniscord.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/Klemek/miniscord/alerts/) [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/Klemek/miniscord.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/Klemek/miniscord/context:python) @@ -182,6 +183,20 @@ async def message(client: discord.client, message: discord.Message): bot.register_watcher(message) # any message was sent (except this bot messages) ``` +### Registering events + +Register a [discord API event](https://discordpy.readthedocs.io/en/latest/api.html#discord-api-events) + +The function must be exactly named after the event + +```python +async def on_ready() -> bool: + print("on_ready") + return False # if True is returned, prevent miniscord handling of the event + +bot.register_event(on_ready) +``` + ### Game status On starting, the bot will cycle through 2-3 "game status" (under its name as "playing xxx") : @@ -329,6 +344,7 @@ bot.start() # this bot respond to "|help", "|info" and "|hello" ## Versions +* v0.0.3 : custom events handling * v0.0.2 : new answer capability * v0.0.1 : initial version diff --git a/miniscord/_bot.py b/miniscord/_bot.py index 70a4770..a867d26 100644 --- a/miniscord/_bot.py +++ b/miniscord/_bot.py @@ -58,6 +58,7 @@ class Bot(object): self.__watcher = None self.__commands = [] self.__fallback = None + self.__events = {} self.games = [f"v{version}", lambda: f"{len(self.client.guilds)} guilds"] if self.alias is not None: self.games += [f"{self.alias}help"] @@ -67,10 +68,10 @@ class Bot(object): self.__register_commands() def __register_events(self): - self.on_ready = self.client.event(self.on_ready) - self.on_message = self.client.event(self.on_message) - self.on_guild_join = self.client.event(self.on_guild_join) - self.on_guild_remove = self.client.event(self.on_guild_remove) + self.client.event(self.on_ready) + self.client.event(self.on_message) + self.client.event(self.on_guild_join) + self.client.event(self.on_guild_remove) def __register_commands(self): # register default commands @@ -145,7 +146,15 @@ class Bot(object): mention_author=self.answer_mention, ) - async def on_ready(self): + async def __handle_event(self, event_name: str, args: list) -> bool: + if event_name in self.__events: + return not await self.__events[event_name](*args) + return False + + async def on_ready(self, *args): + if await self.__handle_event("on_ready", args): + return + # Change status logging.info( f"{self.client.user} (v{self.version}) has connected to {len(self.client.guilds)} Discord guilds" @@ -162,7 +171,10 @@ class Bot(object): ) await asyncio.sleep(self.game_change_delay) - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message, *args): + if await self.__handle_event("on_message", [message, *args]): + return + if message.author == self.client.user: return # Ignore self messages @@ -181,7 +193,7 @@ class Bot(object): message.content = re.sub(r"<@!?[^>]+>", "", message.content) elif is_mention: message.content = re.sub( - f"^<@!?{self.client.user.id}>", "", message.content + f"<@!?{self.client.user.id}>", "", message.content, count=1 ) command_args = parse_arguments(message.content) @@ -200,12 +212,13 @@ class Bot(object): command_found = False + if self.lower_command_names: + command_args[0] = command_args[0].lower() + for command in self.__commands: if re.match( command.regex, - command_args[0].lower() - if self.lower_command_names - else command_args[0], + command_args[0], ): if self.log_calls: debug(message, str(command_args)) @@ -227,16 +240,29 @@ class Bot(object): if not command_found and self.__fallback is not None: await self.__fallback(self.client, message, *command_args) - async def on_guild_join(self, guild: discord.guild): + async def on_guild_join(self, guild: discord.guild, *args): + if await self.__handle_event("on_guild_join", [guild, *args]): + return + if self.guild_logs_file is not None: with open(self.guild_logs_file, encoding="utf-8", mode="a") as f: f.write(f"{datetime.now():%Y-%m-%d %H:%M} +{guild.id}: {guild.name}\n") - async def on_guild_remove(self, guild: discord.guild): + async def on_guild_remove(self, guild: discord.guild, *args): + if await self.__handle_event("on_guild_remove", [guild, *args]): + return + if self.guild_logs_file is not None: with open(self.guild_logs_file, encoding="utf-8", mode="a") as f: f.write(f"{datetime.now():%Y-%m-%d %H:%M} -{guild.id}: {guild.name}\n") + def register_event(self, event_callback: Callable): + event_name = event_callback.__name__ + if event_name in dir(self): + self.__events[event_name] = event_callback + else: + self.client.event(event_callback) + def register_command( self, regex: str, compute: CommandFunction, help_short: str, help_long: str ): diff --git a/setup.py b/setup.py index ef44774..891a0aa 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r") as fh: setuptools.setup( name="miniscord-Klemek", - version="0.0.2", + version="0.0.3", author="Klemek", description="A minimalist discord bot API", long_description=long_description, diff --git a/test-requirements.txt b/test-requirements.txt new file mode 100644 index 0000000..007731a --- /dev/null +++ b/test-requirements.txt @@ -0,0 +1,3 @@ +pytest~=6.2.3 +pytest-cov +coveralls \ No newline at end of file diff --git a/tests/unit/miniscord/test_bot.py b/tests/unit/miniscord/test_bot.py index 0bc8e9e..833a4fc 100644 --- a/tests/unit/miniscord/test_bot.py +++ b/tests/unit/miniscord/test_bot.py @@ -45,7 +45,9 @@ class TestInfo(AsyncTestCase): f"app_name vversion\n" f"* Started at {t0:%Y-%m-%d %H:%M}\n" f"* Connected to 3 guilds\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @@ -60,7 +62,9 @@ class TestHelp(AsyncTestCase): f"List of available commands:\n" f"* info: show description\n" f"* help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -73,7 +77,9 @@ class TestHelp(AsyncTestCase): f"List of available commands:\n" f"* ¡info: show description\n" f"* ¡help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -90,7 +96,9 @@ class TestHelp(AsyncTestCase): f"* test1: desc1\n" f"* info: show description\n" f"* help: show this help\n" - f"```" + f"```", + reference=message, + mention_author=False, ) @patch_discord @@ -99,7 +107,9 @@ class TestHelp(AsyncTestCase): bot.register_command("test1", None, None, "long desc") message = AsyncMock() self._await(bot.help(None, message, "help", "test1")) - message.channel.send.assert_awaited_once_with("long desc") + message.channel.send.assert_awaited_once_with( + "long desc", reference=message, mention_author=False + ) @patch_discord def test_long_regex(self): @@ -108,7 +118,42 @@ class TestHelp(AsyncTestCase): bot.register_command("t.*", None, None, "desc2") message = AsyncMock() self._await(bot.help(None, message, "help", "test")) - message.channel.send.assert_awaited_once_with("desc2") + message.channel.send.assert_awaited_once_with( + "desc2", reference=message, mention_author=False + ) + + @patch_discord + def test_not_found(self): + bot = Bot("app_name", "version") + message = AsyncMock() + self._await(bot.help(None, message, "help", "notfound")) + message.channel.send.assert_awaited_once_with( + f"Command `notfound` not found", + reference=message, + mention_author=False, + ) + + +class TestRegisterEvent(TestCase): + @patch_discord + def test_register_event_normal(self): + async def on_connect(): + pass + + bot = Bot("app_name", "version") + bot.client.event = MagicMock() + bot.register_event(on_connect) + bot.client.event.assert_called_once_with(on_connect) + + @patch_discord + def test_register_event_existing(self): + async def on_ready(): + pass + + bot = Bot("app_name", "version") + bot.client.event = MagicMock() + bot.register_event(on_ready) + bot.client.event.assert_not_called() class TestRegisterCommand(TestCase): @@ -162,7 +207,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_no_command(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -174,13 +219,15 @@ class TestOnMessage(AsyncTestCase): self._await(bot.on_message(message)) simple_callback.assert_not_awaited() watcher_callback.assert_awaited_once_with(bot.client, message) - fallback_callback.assert_awaited_once_with(bot.client, message, "testt", "arg0", "arg1") + fallback_callback.assert_awaited_once_with( + bot.client, message, "testt", "arg0", "arg1" + ) @patch_discord def test_mention_no_command_empty(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -198,7 +245,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_command_simple(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -208,7 +255,9 @@ class TestOnMessage(AsyncTestCase): message = AsyncMock() message.content = "<@12345> test arg0 arg1" self._await(bot.on_message(message)) - simple_callback.assert_awaited_once_with(bot.client, message, "test", "arg0", "arg1") + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "arg0", "arg1" + ) watcher_callback.assert_awaited_once_with(bot.client, message) fallback_callback.assert_not_awaited() @@ -216,7 +265,7 @@ class TestOnMessage(AsyncTestCase): def test_mention_command_regex(self): bot = Bot("app_name", "version") bot.enforce_write_permission = False - bot.client.user.id = '12345' + bot.client.user.id = "12345" regex_callback = AsyncMock() bot.register_command("^t[eo]a?st$", regex_callback, "short", "long") watcher_callback = AsyncMock() @@ -232,7 +281,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_alias_no_command(self): - bot = Bot("app_name", "version", alias='|') + bot = Bot("app_name", "version", alias="|") bot.enforce_write_permission = False simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") @@ -249,7 +298,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_alias_command_simple(self): - bot = Bot("app_name", "version", alias='|') + bot = Bot("app_name", "version", alias="|") bot.enforce_write_permission = False simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") @@ -267,7 +316,7 @@ class TestOnMessage(AsyncTestCase): @patch_discord def test_mention_no_permission(self): bot = Bot("app_name", "version") - bot.client.user.id = '12345' + bot.client.user.id = "12345" simple_callback = AsyncMock() bot.register_command("test", simple_callback, "short", "long") watcher_callback = AsyncMock() @@ -276,40 +325,217 @@ class TestOnMessage(AsyncTestCase): bot.register_fallback(fallback_callback) message = AsyncMock() message.content = "<@12345> test hey" - message.channel.__repr__ = lambda *a:'test_channel' - message.guild.__repr__ = lambda *a:'test_guild' + message.channel.__repr__ = lambda *a: "test_channel" + message.guild.__repr__ = lambda *a: "test_guild" permissions = AsyncMock() permissions.send_messages = False - message.channel.permissions_for = lambda u:permissions + message.channel.permissions_for = lambda u: permissions self._await(bot.on_message(message)) simple_callback.assert_not_awaited() watcher_callback.assert_awaited_once_with(bot.client, message) fallback_callback.assert_not_awaited() message.author.create_dm.assert_awaited_once() message.author.dm_channel.send.assert_awaited_once_with( - f"Hi, this bot doesn\'t have the permission to send a message to" - f" #test_channel in server 'test_guild'" + f"Hi, this bot doesn't have the permission to send a message to" + f" #test_channel in server 'test_guild'" ) - @skip + @patch_discord def test_mention_self(self): - self.fail("not implemented") + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> test arg0 arg1" + message.author = bot.client.user + self._await(bot.on_message(message)) + simple_callback.assert_not_awaited() + watcher_callback.assert_not_awaited() + fallback_callback.assert_not_awaited() - @skip + @patch_discord def test_mention_direct(self): - self.fail("not implemented") + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> test arg0 arg1" + message.channel.type == discord.ChannelType.private + self._await(bot.on_message(message)) + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "arg0", "arg1" + ) + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() - @skip + @patch_discord def test_any_mention(self): - self.fail("not implemented") + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.any_mention = True + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "test <@12345> arg0 arg1" + message.channel.type == discord.ChannelType.private + message.mentions = [bot.client.user] + self._await(bot.on_message(message)) + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "arg0", "arg1" + ) + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() - @skip + @patch_discord + def test_any_mention_off(self): + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.any_mention = False + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "test <@12345> arg0 arg1" + message.channel.type == discord.ChannelType.private + message.mentions = [bot.client.user] + self._await(bot.on_message(message)) + simple_callback.assert_not_awaited() + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() + + @patch_discord def test_remove_mentions(self): - self.fail("not implemented") + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.remove_mentions = True + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> test <@12345> arg1" + self._await(bot.on_message(message)) + simple_callback.assert_awaited_once_with(bot.client, message, "test", "arg1") + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() - @skip + @patch_discord + def test_remove_mentions_off(self): + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.remove_mentions = False + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> test <@12345> arg1" + self._await(bot.on_message(message)) + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "<@12345>", "arg1" + ) + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() + + @patch_discord def test_lower_command_names(self): - self.fail("not implemented") + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.lower_command_names = True + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> Test arg0 arg1" + self._await(bot.on_message(message)) + simple_callback.assert_awaited_once_with( + bot.client, message, "test", "arg0", "arg1" + ) + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_not_awaited() + + @patch_discord + def test_lower_command_names_off(self): + bot = Bot("app_name", "version") + bot.enforce_write_permission = False + bot.lower_command_names = False + bot.client.user.id = "12345" + simple_callback = AsyncMock() + bot.register_command("test", simple_callback, "short", "long") + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + fallback_callback = AsyncMock() + bot.register_fallback(fallback_callback) + message = AsyncMock() + message.content = "<@12345> Test arg0 arg1" + self._await(bot.on_message(message)) + simple_callback.assert_not_awaited() + watcher_callback.assert_awaited_once_with(bot.client, message) + fallback_callback.assert_awaited_once_with( + bot.client, message, "Test", "arg0", "arg1" + ) + + @patch_discord + def test_fire_registered_event(self): + bot = Bot("app_name", "version") + on_message = AsyncMock() + on_message.__name__ = "on_message" + on_message.return_value = True + bot.register_event(on_message) + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + message = AsyncMock() + message.content = "hello there" + self._await(bot.on_message(message)) + on_message.assert_awaited_once_with(message) + watcher_callback.assert_awaited_once_with(bot.client, message) + + @patch_discord + def test_fire_registered_event_cancel(self): + bot = Bot("app_name", "version") + on_message = AsyncMock() + on_message.__name__ = "on_message" + on_message.return_value = False + bot.register_event(on_message) + watcher_callback = AsyncMock() + bot.register_watcher(watcher_callback) + message = AsyncMock() + message.content = "hello there" + self._await(bot.on_message(message)) + on_message.assert_awaited_once_with(message) + watcher_callback.assert_not_awaited() + class TestOnReady(AsyncTestCase): LOG_PATH = "guilds.log" @@ -337,8 +563,7 @@ class TestOnReady(AsyncTestCase): except Exception as error: self.assertEqual(ex, error) client_mock.change_presence.assert_called_with( - activity="activity", - status=discord.Status.online + activity="activity", status=discord.Status.online ) @patch_discord_arg @@ -359,8 +584,10 @@ class TestOnReady(AsyncTestCase): except: pass with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} +id1: name1\n" - f"{d:%Y-%m-%d %H:%M} +id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} +id1: name1\n" f"{d:%Y-%m-%d %H:%M} +id2: name2\n", + f.read(), + ) @patch_discord_arg def test_log_exists(self, client_mock): @@ -381,6 +608,46 @@ class TestOnReady(AsyncTestCase): with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: self.assertEqual(f"test", f.read()) + @patch_discord_arg + def test_fire_registered_event(self, client_mock): + bot = Bot("app_name", "version") + bot.guild_logs_file = None + on_ready = AsyncMock() + on_ready.__name__ = "on_ready" + on_ready.return_value = True + bot.register_event(on_ready) + ex = Exception("test") + client_mock.change_presence.side_effect = ex + try: + with patch("discord.Game") as game_mock: + game_mock.return_value = "activity" + self._await(bot.on_ready()) + except Exception as error: + self.assertEqual(ex, error) + on_ready.assert_awaited_once() + client_mock.change_presence.assert_called_with( + activity="activity", status=discord.Status.online + ) + + @patch_discord_arg + def test_fire_registered_event_cancel(self, client_mock): + bot = Bot("app_name", "version") + bot.guild_logs_file = None + on_ready = AsyncMock() + on_ready.__name__ = "on_ready" + on_ready.return_value = False + bot.register_event(on_ready) + ex = Exception("test") + client_mock.change_presence.side_effect = ex + try: + with patch("discord.Game") as game_mock: + game_mock.return_value = "activity" + self._await(bot.on_ready()) + except Exception as error: + self.assertEqual(ex, error) + on_ready.assert_awaited_once() + client_mock.change_presence.assert_not_called() + class TestOnGuildJoin(AsyncTestCase): LOG_PATH = "guilds.log" @@ -418,8 +685,36 @@ class TestOnGuildJoin(AsyncTestCase): self._await(bot.on_guild_join(guild2)) self.assertTrue(path.exists(self.LOG_PATH)) with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} +id1: name1\n" - f"{d:%Y-%m-%d %H:%M} +id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} +id1: name1\n" f"{d:%Y-%m-%d %H:%M} +id2: name2\n", + f.read(), + ) + + @patch_discord + def test_fire_registered_event(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = self.LOG_PATH + on_guild_join = AsyncMock() + on_guild_join.__name__ = "on_guild_join" + on_guild_join.return_value = True + bot.register_event(on_guild_join) + guild = AsyncMock() + self._await(bot.on_guild_join(guild)) + on_guild_join.assert_awaited_once_with(guild) + self.assertTrue(path.exists(self.LOG_PATH)) + + @patch_discord + def test_fire_registered_event_cancel(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = self.LOG_PATH + on_guild_join = AsyncMock() + on_guild_join.__name__ = "on_guild_join" + on_guild_join.return_value = False + bot.register_event(on_guild_join) + guild = AsyncMock() + self._await(bot.on_guild_join(guild)) + on_guild_join.assert_awaited_once_with(guild) + self.assertFalse(path.exists(self.LOG_PATH)) class TestOnGuildRemove(AsyncTestCase): @@ -458,5 +753,33 @@ class TestOnGuildRemove(AsyncTestCase): self._await(bot.on_guild_remove(guild2)) self.assertTrue(path.exists(self.LOG_PATH)) with open(self.LOG_PATH, encoding="utf-8", mode="r") as f: - self.assertEqual(f"{d:%Y-%m-%d %H:%M} -id1: name1\n" - f"{d:%Y-%m-%d %H:%M} -id2: name2\n", f.read()) + self.assertEqual( + f"{d:%Y-%m-%d %H:%M} -id1: name1\n" f"{d:%Y-%m-%d %H:%M} -id2: name2\n", + f.read(), + ) + + @patch_discord + def test_fire_registered_event(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = self.LOG_PATH + on_guild_remove = AsyncMock() + on_guild_remove.__name__ = "on_guild_remove" + on_guild_remove.return_value = True + bot.register_event(on_guild_remove) + guild = AsyncMock() + self._await(bot.on_guild_remove(guild)) + on_guild_remove.assert_awaited_once_with(guild) + self.assertTrue(path.exists(self.LOG_PATH)) + + @patch_discord + def test_fire_registered_event_cancel(self): + bot = Bot("app_name", "version") + bot.guild_logs_file = self.LOG_PATH + on_guild_remove = AsyncMock() + on_guild_remove.__name__ = "on_guild_remove" + on_guild_remove.return_value = False + bot.register_event(on_guild_remove) + guild = AsyncMock() + self._await(bot.on_guild_remove(guild)) + on_guild_remove.assert_awaited_once_with(guild) + self.assertFalse(path.exists(self.LOG_PATH))