nsfw filters
This commit is contained in:
@@ -47,11 +47,17 @@ class ChannelLogs:
|
||||
def is_format(self):
|
||||
return self.format == FORMAT
|
||||
|
||||
def preload(self, channel: discord.TextChannel):
|
||||
self.name = channel.name
|
||||
self.channel = channel
|
||||
|
||||
@property
|
||||
def nsfw(self):
|
||||
self.channel.nsfw
|
||||
|
||||
async def load(
|
||||
self, channel: discord.TextChannel, start_date: datetime, stop_date: datetime
|
||||
) -> Tuple[int, int]:
|
||||
self.name = channel.name
|
||||
self.channel = channel
|
||||
is_empty = self.last_message_id is None
|
||||
try:
|
||||
if is_empty:
|
||||
@@ -125,6 +131,8 @@ class ChannelLogs:
|
||||
yield len(self.messages), True
|
||||
|
||||
def dict(self) -> dict:
|
||||
channel = serialize(self, not_serialized=["channel", "guild", "start_date"])
|
||||
channel = serialize(
|
||||
self, not_serialized=["channel", "guild", "start_date"]
|
||||
)
|
||||
channel["messages"] = [message.dict() for message in self.messages]
|
||||
return channel
|
||||
|
||||
@@ -231,6 +231,7 @@ class GuildLogs:
|
||||
if channel.id not in self.channels or fresh:
|
||||
loading_new += 1
|
||||
self.channels[channel.id] = ChannelLogs(channel, self)
|
||||
self.channels[channel.id].preload(channel)
|
||||
workers += [
|
||||
Worker(self.channels[channel.id], channel, start_date, stop_date)
|
||||
]
|
||||
|
||||
+22
-1
@@ -15,6 +15,7 @@ from utils import (
|
||||
RELATIVE_REGEX,
|
||||
parse_time,
|
||||
command_cache,
|
||||
FilterLevel,
|
||||
)
|
||||
from logs import (
|
||||
GuildLogs,
|
||||
@@ -27,7 +28,7 @@ from logs import (
|
||||
|
||||
|
||||
class Scanner(ABC):
|
||||
VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention"]
|
||||
VALID_ARGS = ["me", "here", "fast", "fresh", "mobile", "mention", "nsfw", "nsfw:allow", "nsfw:only"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -139,6 +140,26 @@ class Scanner(ABC):
|
||||
|
||||
self.mention_users = "mention" in args or "mobile" in args
|
||||
|
||||
# nsfw filter
|
||||
if "nsfw" in args or "nsfw:allow" in args:
|
||||
self.nsfw = FilterLevel.ALLOW
|
||||
elif "nsfw:only" in args:
|
||||
self.nsfw = FilterLevel.ONLY
|
||||
else:
|
||||
self.nsfw = FilterLevel.NONE
|
||||
|
||||
# fix nsfw filter if channel specified
|
||||
if not self.full and any(channel.nsfw for channel in self.channels):
|
||||
self.nsfw = FilterLevel.ALLOW
|
||||
elif all(channel.nsfw for channel in self.channels):
|
||||
self.nsfw = FilterLevel.ONLY
|
||||
|
||||
# filter nsfw channels
|
||||
if self.nsfw == FilterLevel.NONE:
|
||||
self.channels = list(filter(lambda channel:not channel.nsfw, self.channels))
|
||||
elif self.nsfw == FilterLevel.ONLY:
|
||||
self.channels = list(filter(lambda channel:channel.nsfw, self.channels))
|
||||
|
||||
if not await self.init(message, *args):
|
||||
return
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Dict, Union, Optional, Any
|
||||
import os
|
||||
import logging
|
||||
@@ -17,6 +18,7 @@ COMMON_HELP_ARGS = [
|
||||
"<date2> - filter before <date2>",
|
||||
"fast - only read cache",
|
||||
"fresh - does not read cache (long)",
|
||||
"nsfw:allow/only - allow messages from nsfw channels",
|
||||
"mobile/mention - mentions users (fix @invalid-user bug)",
|
||||
]
|
||||
|
||||
@@ -49,6 +51,12 @@ def deltas(t0: datetime):
|
||||
return (datetime.now() - t0).total_seconds()
|
||||
|
||||
|
||||
class FilterLevel(Enum):
|
||||
NONE = 0
|
||||
ALLOW = 1
|
||||
ONLY = 2
|
||||
|
||||
|
||||
# DISCORD API
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user