diff --git a/src/data_types/history.py b/src/data_types/history.py index c2a7229..5354bde 100644 --- a/src/data_types/history.py +++ b/src/data_types/history.py @@ -7,6 +7,7 @@ from utils import mention, from_now, str_datetime, message_link, SPLIT_TOKEN MAX_RANDOM_TRIES = 10 + class History: def __init__(self): self.messages = [] @@ -40,7 +41,7 @@ class History: message = random.choice(self.messages) real_message = await message.fetch() tries += 1 - + if real_message is None: return ["There was no messages matching your filters"] image = "" diff --git a/src/logs/channel_logs.py b/src/logs/channel_logs.py index 65d45c2..07bb14e 100644 --- a/src/logs/channel_logs.py +++ b/src/logs/channel_logs.py @@ -131,8 +131,6 @@ class ChannelLogs: yield len(self.messages), True def dict(self) -> dict: - channel = serialize( - self, not_serialized=["channel", "guild", "start_date"] - ) + channel = serialize(self, not_serialized=["channel", "guild", "start_date"]) channel["messages"] = [message.dict() for message in self.messages] return channel diff --git a/src/logs/message_log.py b/src/logs/message_log.py index 7ef51e6..bd69b4b 100644 --- a/src/logs/message_log.py +++ b/src/logs/message_log.py @@ -76,14 +76,13 @@ class MessageLog: self.reactions[str(reaction.emoji)] = [] async for user in reaction.users(): self.reactions[str(reaction.emoji)] += [user.id] - + async def fetch(self) -> Optional[discord.Message]: try: return await self.channel.channel.fetch_message(self.id) except (discord.NotFound, discord.Forbidden, discord.HTTPException): return None - def dict(self) -> dict: return serialize( self, not_serialized=["channel"], dates=["created_at", "edited_at"] diff --git a/src/scanners/find_scanner.py b/src/scanners/find_scanner.py index 8abf0e0..54ece5c 100644 --- a/src/scanners/find_scanner.py +++ b/src/scanners/find_scanner.py @@ -48,7 +48,10 @@ class FindScanner(Scanner): reference=message, ) return False - self.queries = [(query, query.strip("`") if re.match(r"^`.*`$", query) else None) for query in self.other_args] + self.queries = [ + (query, query.strip("`") if re.match(r"^`.*`$", query) else None) + for query in self.other_args + ] return True def compute_message(self, channel: ChannelLogs, message: MessageLog): @@ -125,5 +128,7 @@ class FindScanner(Scanner): if count > 0: matches[message.author].update_use(count, message.created_at) else: - matches[query[0]].update_use(count, message.created_at, message.author) + matches[query[0]].update_use( + count, message.created_at, message.author + ) return impacted diff --git a/src/scanners/history_scanner.py b/src/scanners/history_scanner.py index 6493137..52f2507 100644 --- a/src/scanners/history_scanner.py +++ b/src/scanners/history_scanner.py @@ -25,7 +25,13 @@ class HistoryScanner(Scanner, ABC): self.all_messages = "all" in args or "everyone" in args self.images_only = "image" in args if not self.images_only: - self.queries = [(query.lower(), query.strip("`") if re.match(r"^`.*`$", query) else None) for query in self.other_args] + self.queries = [ + ( + query.lower(), + query.strip("`") if re.match(r"^`.*`$", query) else None, + ) + for query in self.other_args + ] else: self.queries = [] return True diff --git a/src/scanners/scanner.py b/src/scanners/scanner.py index 5adc7b8..dbfb831 100644 --- a/src/scanners/scanner.py +++ b/src/scanners/scanner.py @@ -30,7 +30,17 @@ from logs import ( class Scanner(ABC): - VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention", "nsfw", "nsfw:allow", "nsfw:only"] + VALID_ARGS = [ + "me", + "here", + "fast", + "fresh", + "mobile", + "mention", + "nsfw", + "nsfw:allow", + "nsfw:only", + ] def __init__( self, @@ -149,18 +159,22 @@ class Scanner(ABC): self.nsfw = FilterLevel.ONLY else: self.nsfw = FilterLevel.NONE - + # fix nsfw filter if channel specified if not self.full and any(channel.nsfw for channel in self.channels): self.nsfw = FilterLevel.ALLOW elif all(channel.nsfw for channel in self.channels): self.nsfw = FilterLevel.ONLY - + # filter nsfw channels if self.nsfw == FilterLevel.NONE: - self.channels = list(filter(lambda channel:not channel.nsfw, self.channels)) + self.channels = list( + filter(lambda channel: not channel.nsfw, self.channels) + ) elif self.nsfw == FilterLevel.ONLY: - self.channels = list(filter(lambda channel:channel.nsfw, self.channels)) + self.channels = list( + filter(lambda channel: channel.nsfw, self.channels) + ) if not await self.init(message, *args): return @@ -244,15 +258,15 @@ class Scanner(ABC): # Display results t0 = datetime.now() intro = get_intro( - self.intro_context, - self.full, - self.channels, - self.members, - self.msg_count, - self.chan_count, - self.start_date, - self.stop_date, - ) + self.intro_context, + self.full, + self.channels, + self.members, + self.msg_count, + self.chan_count, + self.start_date, + self.stop_date, + ) if inspect.iscoroutinefunction(self.get_results): results = await self.get_results(intro) else: diff --git a/src/utils/utils.py b/src/utils/utils.py index 74c0f0c..2fd7413 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -99,6 +99,7 @@ class FakeMessage: def __init__(self, id: int): self.id = id + # FILE