From b2858cca95a2a4d9322e0280601a8e75262ab37e Mon Sep 17 00:00:00 2001 From: Klemek Date: Tue, 18 May 2021 18:13:37 +0200 Subject: [PATCH] nsfw filters --- src/logs/channel_logs.py | 14 +++++++++++--- src/logs/guild_logs.py | 1 + src/scanners/scanner.py | 23 ++++++++++++++++++++++- src/utils/utils.py | 8 ++++++++ 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/logs/channel_logs.py b/src/logs/channel_logs.py index 2c4afe8..65d45c2 100644 --- a/src/logs/channel_logs.py +++ b/src/logs/channel_logs.py @@ -47,11 +47,17 @@ class ChannelLogs: def is_format(self): return self.format == FORMAT + def preload(self, channel: discord.TextChannel): + self.name = channel.name + self.channel = channel + + @property + def nsfw(self): + self.channel.nsfw + async def load( self, channel: discord.TextChannel, start_date: datetime, stop_date: datetime ) -> Tuple[int, int]: - self.name = channel.name - self.channel = channel is_empty = self.last_message_id is None try: if is_empty: @@ -125,6 +131,8 @@ class ChannelLogs: yield len(self.messages), True def dict(self) -> dict: - channel = serialize(self, not_serialized=["channel", "guild", "start_date"]) + channel = serialize( + self, not_serialized=["channel", "guild", "start_date"] + ) channel["messages"] = [message.dict() for message in self.messages] return channel diff --git a/src/logs/guild_logs.py b/src/logs/guild_logs.py index 7600077..8e84b23 100644 --- a/src/logs/guild_logs.py +++ b/src/logs/guild_logs.py @@ -231,6 +231,7 @@ class GuildLogs: if channel.id not in self.channels or fresh: loading_new += 1 self.channels[channel.id] = ChannelLogs(channel, self) + self.channels[channel.id].preload(channel) workers += [ Worker(self.channels[channel.id], channel, start_date, stop_date) ] diff --git a/src/scanners/scanner.py b/src/scanners/scanner.py index e5b41b4..4c35238 100644 --- a/src/scanners/scanner.py +++ b/src/scanners/scanner.py @@ -15,6 +15,7 @@ from utils import ( RELATIVE_REGEX, parse_time, command_cache, + FilterLevel, ) from logs import ( GuildLogs, @@ -27,7 +28,7 @@ from logs import ( class Scanner(ABC): - VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention"] + VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention", "nsfw", "nsfw:allow", "nsfw:only"] def __init__( self, @@ -139,6 +140,26 @@ class Scanner(ABC): self.mention_users = "mention" in args or "mobile" in args + # nsfw filter + if "nsfw" in args or "nsfw:allow" in args: + self.nsfw = FilterLevel.ALLOW + elif "nsfw:only" in args: + self.nsfw = FilterLevel.ONLY + else: + self.nsfw = FilterLevel.NONE + + # fix nsfw filter if channel specified + if not self.full and any(channel.nsfw for channel in self.channels): + self.nsfw = FilterLevel.ALLOW + elif all(channel.nsfw for channel in self.channels): + self.nsfw = FilterLevel.ONLY + + # filter nsfw channels + if self.nsfw == FilterLevel.NONE: + self.channels = list(filter(lambda channel:not channel.nsfw, self.channels)) + elif self.nsfw == FilterLevel.ONLY: + self.channels = list(filter(lambda channel:channel.nsfw, self.channels)) + if not await self.init(message, *args): return diff --git a/src/utils/utils.py b/src/utils/utils.py index 19e158a..bc79fe2 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Callable, List, Dict, Union, Optional, Any import os import logging @@ -17,6 +18,7 @@ COMMON_HELP_ARGS = [ " - filter before ", "fast - only read cache", "fresh - does not read cache (long)", + "nsfw:allow/only - allow messages from nsfw channels", "mobile/mention - mentions users (fix @invalid-user bug)", ] @@ -49,6 +51,12 @@ def deltas(t0: datetime): return (datetime.now() - t0).total_seconds() +class FilterLevel(Enum): + NONE = 0 + ALLOW = 1 + ONLY = 2 + + # DISCORD API