nsfw filters

This commit is contained in:
Klemek
2021-05-18 18:13:37 +02:00
parent a01414dce7
commit b2858cca95
4 changed files with 42 additions and 4 deletions
+11 -3
View File
@@ -47,11 +47,17 @@ class ChannelLogs:
def is_format(self): def is_format(self):
return self.format == FORMAT return self.format == FORMAT
def preload(self, channel: discord.TextChannel):
self.name = channel.name
self.channel = channel
@property
def nsfw(self):
self.channel.nsfw
async def load( async def load(
self, channel: discord.TextChannel, start_date: datetime, stop_date: datetime self, channel: discord.TextChannel, start_date: datetime, stop_date: datetime
) -> Tuple[int, int]: ) -> Tuple[int, int]:
self.name = channel.name
self.channel = channel
is_empty = self.last_message_id is None is_empty = self.last_message_id is None
try: try:
if is_empty: if is_empty:
@@ -125,6 +131,8 @@ class ChannelLogs:
yield len(self.messages), True yield len(self.messages), True
def dict(self) -> dict: def dict(self) -> dict:
channel = serialize(self, not_serialized=["channel", "guild", "start_date"]) channel = serialize(
self, not_serialized=["channel", "guild", "start_date"]
)
channel["messages"] = [message.dict() for message in self.messages] channel["messages"] = [message.dict() for message in self.messages]
return channel return channel
+1
View File
@@ -231,6 +231,7 @@ class GuildLogs:
if channel.id not in self.channels or fresh: if channel.id not in self.channels or fresh:
loading_new += 1 loading_new += 1
self.channels[channel.id] = ChannelLogs(channel, self) self.channels[channel.id] = ChannelLogs(channel, self)
self.channels[channel.id].preload(channel)
workers += [ workers += [
Worker(self.channels[channel.id], channel, start_date, stop_date) Worker(self.channels[channel.id], channel, start_date, stop_date)
] ]
+22 -1
View File
@@ -15,6 +15,7 @@ from utils import (
RELATIVE_REGEX, RELATIVE_REGEX,
parse_time, parse_time,
command_cache, command_cache,
FilterLevel,
) )
from logs import ( from logs import (
GuildLogs, GuildLogs,
@@ -27,7 +28,7 @@ from logs import (
class Scanner(ABC): class Scanner(ABC):
VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention"] VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention", "nsfw", "nsfw:allow", "nsfw:only"]
def __init__( def __init__(
self, self,
@@ -139,6 +140,26 @@ class Scanner(ABC):
self.mention_users = "mention" in args or "mobile" in args self.mention_users = "mention" in args or "mobile" in args
# nsfw filter
if "nsfw" in args or "nsfw:allow" in args:
self.nsfw = FilterLevel.ALLOW
elif "nsfw:only" in args:
self.nsfw = FilterLevel.ONLY
else:
self.nsfw = FilterLevel.NONE
# fix nsfw filter if channel specified
if not self.full and any(channel.nsfw for channel in self.channels):
self.nsfw = FilterLevel.ALLOW
elif all(channel.nsfw for channel in self.channels):
self.nsfw = FilterLevel.ONLY
# filter nsfw channels
if self.nsfw == FilterLevel.NONE:
self.channels = list(filter(lambda channel:not channel.nsfw, self.channels))
elif self.nsfw == FilterLevel.ONLY:
self.channels = list(filter(lambda channel:channel.nsfw, self.channels))
if not await self.init(message, *args): if not await self.init(message, *args):
return return
+8
View File
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Callable, List, Dict, Union, Optional, Any from typing import Callable, List, Dict, Union, Optional, Any
import os import os
import logging import logging
@@ -17,6 +18,7 @@ COMMON_HELP_ARGS = [
"<date2> - filter before <date2>", "<date2> - filter before <date2>",
"fast - only read cache", "fast - only read cache",
"fresh - does not read cache (long)", "fresh - does not read cache (long)",
"nsfw:allow/only - allow messages from nsfw channels",
"mobile/mention - mentions users (fix @invalid-user bug)", "mobile/mention - mentions users (fix @invalid-user bug)",
] ]
@@ -49,6 +51,12 @@ def deltas(t0: datetime):
return (datetime.now() - t0).total_seconds() return (datetime.now() - t0).total_seconds()
class FilterLevel(Enum):
NONE = 0
ALLOW = 1
ONLY = 2
# DISCORD API # DISCORD API