diff --git a/src/logs/channel_logs.py b/src/logs/channel_logs.py index 39d854b..9e83b2b 100644 --- a/src/logs/channel_logs.py +++ b/src/logs/channel_logs.py @@ -1,6 +1,7 @@ from typing import Union, Tuple, Any import discord from discord import message +from datetime import datetime from . import MessageLog from utils import FakeMessage @@ -18,6 +19,7 @@ class ChannelLogs: self.id = channel.id self.name = channel.name self.last_message_id = None + self.first_message_id = None self.format = FORMAT self.messages = [] self.start_date = None @@ -32,6 +34,12 @@ class ChannelLogs: if channel["last_message_id"] is not None else None ) + self.first_message_id = ( + int(channel["first_message_id"]) + if "first_message_id" in channel + and channel["first_message_id"] is not None + else None + ) self.messages = [ MessageLog(message, self) for message in channel["messages"] ] @@ -42,48 +50,74 @@ class ChannelLogs: def is_format(self): return self.format == FORMAT - async def load(self, channel: discord.TextChannel) -> Tuple[int, int]: + async def load( + self, channel: discord.TextChannel, start_date: datetime, stop_date: datetime + ) -> Tuple[int, int]: self.name = channel.name self.channel = channel + is_empty = self.last_message_id is None try: - if self.last_message_id is not None: # append + if is_empty: + sanity_check = len(await channel.history(limit=1).flatten()) + if sanity_check != 1: + yield len(self.messages), True + return + # load backward + if is_empty or ( + start_date is not None + and self.start_date > start_date + and self.first_message_id is not None + ): + first_message_id = self.first_message_id + first_message_date = None + tmp_message_id = 0 + done = 0 + while ( + done >= CHUNK_SIZE + or first_message_id is None + or (first_message_date is None or first_message_date >= start_date) + and start_date is not None + ) and tmp_message_id != first_message_id: + tmp_message_id = first_message_id + done = 0 + async for message in channel.history( + limit=CHUNK_SIZE, + before=FakeMessage(first_message_id) + if first_message_id is not None + else None, + oldest_first=False, + ): + done += 1 + first_message_id = message.id + first_message_date = message.created_at + m = MessageLog(message, self) + await m.load(message) + self.messages += [m] + yield len(self.messages), False + if done >= CHUNK_SIZE and first_message_date < start_date: + # date was limiting here, store first message id + self.first_message_id = first_message_id + self.last_message_id = channel.last_message_id + # load forward + if not is_empty: tmp_message_id = None + last_message_date = self.messages[0].created_at while ( self.last_message_id != channel.last_message_id - and self.last_message_id != tmp_message_id - ): + or (stop_date is not None and last_message_date <= stop_date) + ) and self.last_message_id != tmp_message_id: tmp_message_id = self.last_message_id async for message in channel.history( limit=CHUNK_SIZE, after=FakeMessage(self.last_message_id), oldest_first=True, ): + last_message_date = message.created_at self.last_message_id = message.id m = MessageLog(message, self) await m.load(message) self.messages.insert(0, m) yield len(self.messages), False - else: # first load - last_message_id = None - done = 0 - sanity_check = len(await channel.history(limit=1).flatten()) - if sanity_check == 1: - while done >= CHUNK_SIZE or last_message_id is None: - done = 0 - async for message in channel.history( - limit=CHUNK_SIZE, - before=FakeMessage(last_message_id) - if last_message_id is not None - else None, - oldest_first=False, - ): - done += 1 - last_message_id = message.id - m = MessageLog(message, self) - await m.load(message) - self.messages += [m] - yield len(self.messages), False - self.last_message_id = channel.last_message_id except discord.errors.HTTPException: yield -1, True return # When an exception occurs (like Forbidden) diff --git a/src/logs/guild_logs.py b/src/logs/guild_logs.py index 9aa0d1a..6ba20cf 100644 --- a/src/logs/guild_logs.py +++ b/src/logs/guild_logs.py @@ -32,7 +32,13 @@ MAX_MODIFICATION_TIME = 365 * 24 * 60 * 60 class Worker: - def __init__(self, channel_log: ChannelLogs, channel: discord.TextChannel): + def __init__( + self, + channel_log: ChannelLogs, + channel: discord.TextChannel, + start_date: datetime, + stop_date: datetime, + ): self.channel_log = channel_log self.channel = channel self.start_msg = len(channel_log.messages) @@ -41,12 +47,16 @@ class Worker: self.done = False self.cancelled = False self.loop = asyncio.get_event_loop() + self.start_date = start_date + self.stop_date = stop_date def start(self): asyncio.run_coroutine_threadsafe(self.process(), self.loop) async def process(self): - async for count, done in self.channel_log.load(self.channel): + async for count, done in self.channel_log.load( + self.channel, self.start_date, self.stop_date + ): if count > 0: self.queried_msg = count - self.start_msg self.total_msg = count @@ -98,7 +108,9 @@ class GuildLogs: async def load( self, progress: discord.Message, - target_channels: List[discord.TextChannel] = [], + target_channels: List[discord.TextChannel], + start_date: datetime, + stop_date: datetime, *, fast: bool, fresh: bool, @@ -173,6 +185,8 @@ class GuildLogs: if ( not fast and not fresh + and start_date is None + and stop_date is None and last_time is not None and (time.time() - last_time) < MIN_MODIFICATION_TIME ): @@ -214,7 +228,9 @@ class GuildLogs: if channel.id not in self.channels or fresh: loading_new += 1 self.channels[channel.id] = ChannelLogs(channel, self) - workers += [Worker(self.channels[channel.id], channel)] + workers += [ + Worker(self.channels[channel.id], channel, start_date, stop_date) + ] warning_msg = "(this might take a while)" if len(target_channels) > 5 and loading_new > 5: warning_msg = "(most channels are new, this will take a long while)" @@ -255,7 +271,7 @@ class GuildLogs: f"Reading new history...\n{total_msg:,} messages in {total_chan:,}/{max_chan:,} channels ({round(queried_msg/deltas(t0)):,}m/s)\n{warning_msg}{remaining_msg}", ) logging.info( - f"log {self.guild.id} > queried in {delta(t0):,}ms -> {queried_msg / deltas(t0):,.3f} m/s" + f"log {self.guild.id} > queried {queried_msg} in {delta(t0):,}ms -> {queried_msg / deltas(t0):,.3f} m/s" ) # write logs real_total_msg = sum( diff --git a/src/scanners/scanner.py b/src/scanners/scanner.py index f30cb11..62ba168 100644 --- a/src/scanners/scanner.py +++ b/src/scanners/scanner.py @@ -90,10 +90,10 @@ class Scanner(ABC): ) return - self.start_datetime = None if len(dates) < 1 else min(dates) - self.stop_datetime = datetime.now() if len(dates) < 2 else max(dates) + self.start_date = None if len(dates) < 1 else min(dates) + self.stop_date = None if len(dates) < 2 else max(dates) - if self.start_datetime is not None and self.start_datetime > datetime.now(): + if self.start_date is not None and self.start_date > datetime.now(): await message.channel.send( f"Start date is after today", reference=message ) @@ -130,20 +130,13 @@ class Scanner(ABC): allowed_mentions=discord.AllowedMentions.none(), ) total_msg, total_chan = await logs.load( - progress, self.channels, fast="fast" in args, fresh="fresh" in args + progress, + self.channels, + self.start_date, + self.stop_date, + fast="fast" in args, + fresh="fresh" in args, ) - if self.start_datetime is not None: - self.start_datetime = max( - self.start_datetime, - min( - [ - logs.channels[channel.id].start_date - for channel in self.channels - if channel.id in logs.channels - and logs.channels[channel.id].start_date is not None - ] - ), - ) if total_msg == CANCELLED: await message.channel.send( "Operation cancelled by user", @@ -157,6 +150,21 @@ class Scanner(ABC): elif total_msg == NO_FILE: await message.channel.send(gdpr.TEXT) else: + if self.start_date is not None: + self.start_date = max( + self.start_date, + min( + [ + logs.channels[channel.id].start_date + for channel in self.channels + if channel.id in logs.channels + and logs.channels[channel.id].start_date is not None + ] + ), + ) + if self.stop_date is None: + self.stop_date = datetime.utcnow() + self.msg_count = 0 self.total_msg = 0 self.chan_count = 0 @@ -169,12 +177,12 @@ class Scanner(ABC): self.compute_message(channel_logs, message_log) for message_log in channel_logs.messages if ( - self.start_datetime is None - or message_log.created_at >= self.start_datetime + self.start_date is None + or message_log.created_at >= self.start_date ) and ( - self.stop_datetime is None - or message_log.created_at <= self.stop_datetime + self.stop_date is None + or message_log.created_at <= self.stop_date ) ] ) @@ -199,8 +207,8 @@ class Scanner(ABC): self.members, self.msg_count, self.chan_count, - self.start_datetime, - self.stop_datetime, + self.start_date, + self.stop_date, ) ) logging.info(f"scan {guild.id} > results in {delta(t0):,}ms") diff --git a/src/utils/utils.py b/src/utils/utils.py index 729d633..a439ffd 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -180,9 +180,7 @@ def parse_iso_datetime(str_date: str) -> datetime: return dateutil.parser.parse(str_date) -RELATIVE_REGEX = ( - r"(yesterday|today|\d*h(ours?)?|\d*d(ays?)?|\d*w(eeks?)?|\d*m(onths?)?|\d*y(ears?))" -) +RELATIVE_REGEX = r"(yesterday|today|\d*h(ours?)?|\d*d(ays?)?|\d*w(eeks?)?|\d*m(onths?)?|\d*y(ears?)?)" def parse_relative_time(src: str) -> datetime: