diff --git a/pcbot/exts/anti_crosspost.py b/pcbot/exts/anti_crosspost.py index fafb16a..7161058 100644 --- a/pcbot/exts/anti_crosspost.py +++ b/pcbot/exts/anti_crosspost.py @@ -3,36 +3,50 @@ Copyright (c) 2022-present pygame-community. """ -import asyncio -from collections import OrderedDict - -import asyncio -from collections import OrderedDict -from collections.abc import Collection -from typing import TypedDict import discord from discord.ext import commands import snakecore +from typing import TypedDict, Collection +from collections import OrderedDict from ..base import BaseExtensionCog +# Define the type for the bot, supporting both Bot and AutoShardedBot from snakecore BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot def crosspost_cmp(message: discord.Message, other: discord.Message) -> bool: + """ + Compare two messages to determine if they are crossposts or duplicates. + + Args: + message (discord.Message): The first message to compare. + other (discord.Message): The second message to compare. + + Returns: + bool: True if the messages are similar enough to be considered + duplicates, otherwise False. + """ + if not message.content or not other.content: + return False + hamming_score = sum(x != y for x, y in zip(message.content, other.content)) / max( len(message.content), len(other.content) ) return hamming_score < 0.20 or any( - att1.url == att2.url or att1.size == att2.size # check for approximate matches + att1.filename == att2.filename and att1.size == att2.size for att1, att2 in zip(message.attachments, other.attachments) ) class UserCrosspostCache(TypedDict): + """ + A TypedDict for caching user messages and alert message IDs. + """ + message_groups: list[list[discord.Message]] - alert_message_ids: set[int] + message_to_alert: dict[int, int] # Mapping from message ID to alert message ID class AntiCrosspostCog(BaseExtensionCog, name="anti-crosspost"): @@ -42,9 +56,21 @@ def __init__( channel_ids: Collection[int], message_length_threshold: int, max_tracked_users: int, - max_tracked_message_groups_per_user, + max_tracked_message_groups_per_user: int, theme_color: int | discord.Color = 0, ) -> None: + """ + Initialize the AntiCrosspostCog. + + Args: + bot (BotT): The bot instance. + channel_ids (Collection[int]): Collection of channel IDs to monitor. + message_length_threshold (int): Minimum length of a message to be considered. + max_tracked_users (int): Maximum number of users to track. + max_tracked_message_groups_per_user (int): Maximum number of message + groups to track per user. + theme_color (int | discord.Color): Theme color for the bot's responses. + """ super().__init__(bot, theme_color) self.channel_ids = set(channel_ids) self.crossposting_cache: OrderedDict[int, UserCrosspostCache] = OrderedDict() @@ -55,142 +81,144 @@ def __init__( @commands.Cog.listener() async def on_message(self, message: discord.Message): - if message.author.bot or not ( - ( - message.channel.id in self.channel_ids - or ( - isinstance(message.channel, (discord.abc.GuildChannel)) - and message.channel.category_id in self.channel_ids - ) - or ( - isinstance(message.channel, discord.Thread) - and ( - message.channel.parent_id in self.channel_ids - or message.channel.parent - and message.channel.parent.category_id in self.channel_ids - ) - ) + """ + Event listener for new messages. + + Args: + message (discord.Message): The message object. + """ + if ( + message.author.bot + or not self._is_valid_channel(message.channel) # type: ignore + or message.type != discord.MessageType.default + or ( + len(message.content) < self.message_length_threshold + and not message.attachments ) - and message.type == discord.MessageType.default ): return - if len(message.content) < self.message_length_threshold: - return - + # Initialize cache for new users if message.author.id not in self.crossposting_cache: self.crossposting_cache[message.author.id] = UserCrosspostCache( message_groups=[[message]], - alert_message_ids=set(), + message_to_alert={}, ) - return + else: + # Remove oldest user if the limit is exceeded + if len(self.crossposting_cache) > self.max_tracked_users: + self.crossposting_cache.popitem(last=False) + user_cache = self.crossposting_cache[message.author.id] + + # Remove old message groups if limit is exceeded and the oldest group is too small if ( - len(self.crossposting_cache[message.author.id]["message_groups"]) - > self.max_tracked_message_groups_per_user - and len(self.crossposting_cache[message.author.id]["message_groups"][0]) < 2 + len(user_cache["message_groups"]) > self.max_tracked_message_groups_per_user + and len(user_cache["message_groups"][0]) < 2 ): - self.crossposting_cache[message.author.id]["message_groups"].pop(0) + user_cache["message_groups"].pop(0) - for i, messages in enumerate( - self.crossposting_cache[message.author.id]["message_groups"] - ): - break_outer = False - for j in range(len(messages)): - if message.channel.id != messages[j].channel.id and crosspost_cmp( - message, messages[j] + # Check for crossposts or duplicates in existing message groups + for messages in user_cache["message_groups"]: + for existing_message in messages: + if message.channel.id != existing_message.channel.id and crosspost_cmp( + message, existing_message ): messages.append(message) - self.crossposting_cache[message.author.id]["alert_message_ids"].add( - ( - await message.reply( - "This message is a recent crosspost/duplicate of the following messages: " - + ", ".join([m.jump_url for m in messages]) - + ".\n\nPlease delete all duplicate messages." - ) - ).id - ) - break_outer = True + # Send an alert message and add its ID to the alert list + try: + alert_message = await message.reply( + "This message is a recent crosspost/duplicate of the following messages: " + + ", ".join([m.jump_url for m in messages]) + + ".\n\nPlease delete all duplicate messages." + ) + user_cache["message_to_alert"][message.id] = alert_message.id + except discord.HTTPException: + # Silently handle errors + pass break - - if break_outer: - break + else: + continue + break else: - self.crossposting_cache[message.author.id]["message_groups"].append( - [message] - ) + # Add new message group if no crosspost is found + user_cache["message_groups"].append([message]) if ( - len(self.crossposting_cache[message.author.id]["message_groups"]) + len(user_cache["message_groups"]) > self.max_tracked_message_groups_per_user ): - self.crossposting_cache[message.author.id]["message_groups"].pop(0) + user_cache["message_groups"].pop(0) @commands.Cog.listener() async def on_message_delete(self, message: discord.Message): - if not ( - message.guild - and ( - message.channel.id in self.channel_ids - or isinstance(message.channel, discord.abc.GuildChannel) - and message.channel.category_id in self.channel_ids - or isinstance(message.channel, discord.Thread) - and message.channel.parent_id in self.channel_ids - ) - ): + """ + Event listener for deleted messages. + + Args: + message (discord.Message): The message object. + """ + if not self._is_valid_channel(message.channel): # type: ignore return if message.author.id not in self.crossposting_cache: return + user_cache = self.crossposting_cache[message.author.id] stale_alert_message_ids: list[int] = [] - for messages in self.crossposting_cache[message.author.id]["message_groups"]: + + for messages in user_cache["message_groups"]: for j in reversed(range(len(messages))): if message.id == messages[j].id: - del messages[j] # remove the message from the crosspost group - for alert_message_id in tuple( - self.crossposting_cache[message.author.id]["alert_message_ids"] - ): - - try: - alert_message = discord.utils.find( - lambda m: m.id == alert_message_id, - self.bot.cached_messages, - ) - if not alert_message: - alert_message = await message.channel.fetch_message( - alert_message_id - ) - except discord.NotFound: - continue - - if ( - alert_message.reference - and alert_message.reference.message_id == message.id - ): - self.crossposting_cache[message.author.id][ - "alert_message_ids" - ].remove(alert_message_id) - # mark the alert message as stale if it references the deleted message - stale_alert_message_ids.append(alert_message_id) + del messages[j] + if message.id in user_cache["message_to_alert"]: + stale_alert_message_ids.append( + user_cache["message_to_alert"].pop(message.id) + ) break if len(messages) == 1: - # mark all alert messages for this crosspost group as stale - # as there is only one message left - stale_alert_message_ids.extend( - self.crossposting_cache[message.author.id]["alert_message_ids"] - ) - self.crossposting_cache[message.author.id]["alert_message_ids"].clear() + # Mark all alert messages for this crosspost group as stale + stale_alert_message_ids.extend(user_cache["message_to_alert"].values()) + user_cache["message_to_alert"].clear() + # Delete stale alert messages for alert_message_id in stale_alert_message_ids: try: await discord.PartialMessage( channel=message.channel, id=alert_message_id ).delete() except (discord.NotFound, discord.Forbidden): + # Silently handle errors pass + def _is_valid_channel(self, channel: discord.abc.GuildChannel) -> bool: + """ + Check if a channel is valid based on the configured channel IDs. + + Args: + channel (discord.abc.GuildChannel): The channel to check. + + Returns: + bool: True if the channel is valid, otherwise False. + """ + if isinstance(channel, discord.abc.GuildChannel): + # Check if the channel ID or category ID is in the monitored channel IDs + if ( + channel.id in self.channel_ids + or channel.category_id in self.channel_ids + ): + return True + + # If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs + if isinstance(channel, discord.Thread): + if channel.parent_id in self.channel_ids: + return True + if channel.parent and channel.parent.category_id in self.channel_ids: + return True + + return False + @snakecore.commands.decorators.with_config_kwargs async def setup( @@ -201,13 +229,24 @@ async def setup( message_length_threshold: int = 64, theme_color: int | discord.Color = 0, ): + """ + Setup function to add the AntiCrosspostCog to the bot. + + Args: + bot (BotT): The bot instance. + channel_ids (Collection[int]): Collection of channel IDs to monitor. + max_tracked_users (int): Maximum number of users to track. + max_tracked_message_groups_per_user (int): Maximum number of message groups to track per user. + message_length_threshold (int): Minimum length of a message to be considered. + theme_color (int | discord.Color): Theme color for the bot's responses. + """ await bot.add_cog( AntiCrosspostCog( bot, channel_ids, + message_length_threshold, max_tracked_users, max_tracked_message_groups_per_user, - message_length_threshold, theme_color, ) )