Skip to content

Commit

Permalink
Refactor 'anti-crosspost' cog in 'anti_crosspost' extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Mega-JC committed Jul 26, 2024
1 parent a8b7321 commit 3b800df
Showing 1 changed file with 143 additions and 104 deletions.
247 changes: 143 additions & 104 deletions pcbot/exts/anti_crosspost.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,50 @@
Copyright (c) 2022-present pygame-community.
"""

import asyncio
from collections import OrderedDict

import asyncio
from collections import OrderedDict
from collections.abc import Collection
from typing import TypedDict
import discord
from discord.ext import commands
import snakecore
from typing import TypedDict, Collection
from collections import OrderedDict

from ..base import BaseExtensionCog

# Define the type for the bot, supporting both Bot and AutoShardedBot from snakecore
BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot


def crosspost_cmp(message: discord.Message, other: discord.Message) -> bool:
"""
Compare two messages to determine if they are crossposts or duplicates.
Args:
message (discord.Message): The first message to compare.
other (discord.Message): The second message to compare.
Returns:
bool: True if the messages are similar enough to be considered
duplicates, otherwise False.
"""
if not message.content or not other.content:
return False

hamming_score = sum(x != y for x, y in zip(message.content, other.content)) / max(
len(message.content), len(other.content)
)

return hamming_score < 0.20 or any(
att1.url == att2.url or att1.size == att2.size # check for approximate matches
att1.filename == att2.filename and att1.size == att2.size
for att1, att2 in zip(message.attachments, other.attachments)
)


class UserCrosspostCache(TypedDict):
"""
A TypedDict for caching user messages and alert message IDs.
"""

message_groups: list[list[discord.Message]]
alert_message_ids: set[int]
message_to_alert: dict[int, int] # Mapping from message ID to alert message ID


class AntiCrosspostCog(BaseExtensionCog, name="anti-crosspost"):
Expand All @@ -42,9 +56,21 @@ def __init__(
channel_ids: Collection[int],
message_length_threshold: int,
max_tracked_users: int,
max_tracked_message_groups_per_user,
max_tracked_message_groups_per_user: int,
theme_color: int | discord.Color = 0,
) -> None:
"""
Initialize the AntiCrosspostCog.
Args:
bot (BotT): The bot instance.
channel_ids (Collection[int]): Collection of channel IDs to monitor.
message_length_threshold (int): Minimum length of a message to be considered.
max_tracked_users (int): Maximum number of users to track.
max_tracked_message_groups_per_user (int): Maximum number of message
groups to track per user.
theme_color (int | discord.Color): Theme color for the bot's responses.
"""
super().__init__(bot, theme_color)
self.channel_ids = set(channel_ids)
self.crossposting_cache: OrderedDict[int, UserCrosspostCache] = OrderedDict()
Expand All @@ -55,142 +81,144 @@ def __init__(

@commands.Cog.listener()
async def on_message(self, message: discord.Message):
if message.author.bot or not (
(
message.channel.id in self.channel_ids
or (
isinstance(message.channel, (discord.abc.GuildChannel))
and message.channel.category_id in self.channel_ids
)
or (
isinstance(message.channel, discord.Thread)
and (
message.channel.parent_id in self.channel_ids
or message.channel.parent
and message.channel.parent.category_id in self.channel_ids
)
)
"""
Event listener for new messages.
Args:
message (discord.Message): The message object.
"""
if (
message.author.bot
or not self._is_valid_channel(message.channel) # type: ignore
or message.type != discord.MessageType.default
or (
len(message.content) < self.message_length_threshold
and not message.attachments
)
and message.type == discord.MessageType.default
):
return

if len(message.content) < self.message_length_threshold:
return

# Initialize cache for new users
if message.author.id not in self.crossposting_cache:
self.crossposting_cache[message.author.id] = UserCrosspostCache(
message_groups=[[message]],
alert_message_ids=set(),
message_to_alert={},
)
return
else:
# Remove oldest user if the limit is exceeded
if len(self.crossposting_cache) > self.max_tracked_users:
self.crossposting_cache.popitem(last=False)

user_cache = self.crossposting_cache[message.author.id]

# Remove old message groups if limit is exceeded and the oldest group is too small
if (
len(self.crossposting_cache[message.author.id]["message_groups"])
> self.max_tracked_message_groups_per_user
and len(self.crossposting_cache[message.author.id]["message_groups"][0]) < 2
len(user_cache["message_groups"]) > self.max_tracked_message_groups_per_user
and len(user_cache["message_groups"][0]) < 2
):
self.crossposting_cache[message.author.id]["message_groups"].pop(0)
user_cache["message_groups"].pop(0)

for i, messages in enumerate(
self.crossposting_cache[message.author.id]["message_groups"]
):
break_outer = False
for j in range(len(messages)):
if message.channel.id != messages[j].channel.id and crosspost_cmp(
message, messages[j]
# Check for crossposts or duplicates in existing message groups
for messages in user_cache["message_groups"]:
for existing_message in messages:
if message.channel.id != existing_message.channel.id and crosspost_cmp(
message, existing_message
):
messages.append(message)

self.crossposting_cache[message.author.id]["alert_message_ids"].add(
(
await message.reply(
"This message is a recent crosspost/duplicate of the following messages: "
+ ", ".join([m.jump_url for m in messages])
+ ".\n\nPlease delete all duplicate messages."
)
).id
)
break_outer = True
# Send an alert message and add its ID to the alert list
try:
alert_message = await message.reply(
"This message is a recent crosspost/duplicate of the following messages: "
+ ", ".join([m.jump_url for m in messages])
+ ".\n\nPlease delete all duplicate messages."
)
user_cache["message_to_alert"][message.id] = alert_message.id
except discord.HTTPException:
# Silently handle errors
pass
break

if break_outer:
break
else:
continue
break
else:
self.crossposting_cache[message.author.id]["message_groups"].append(
[message]
)
# Add new message group if no crosspost is found
user_cache["message_groups"].append([message])
if (
len(self.crossposting_cache[message.author.id]["message_groups"])
len(user_cache["message_groups"])
> self.max_tracked_message_groups_per_user
):
self.crossposting_cache[message.author.id]["message_groups"].pop(0)
user_cache["message_groups"].pop(0)

@commands.Cog.listener()
async def on_message_delete(self, message: discord.Message):
if not (
message.guild
and (
message.channel.id in self.channel_ids
or isinstance(message.channel, discord.abc.GuildChannel)
and message.channel.category_id in self.channel_ids
or isinstance(message.channel, discord.Thread)
and message.channel.parent_id in self.channel_ids
)
):
"""
Event listener for deleted messages.
Args:
message (discord.Message): The message object.
"""
if not self._is_valid_channel(message.channel): # type: ignore
return

if message.author.id not in self.crossposting_cache:
return

user_cache = self.crossposting_cache[message.author.id]
stale_alert_message_ids: list[int] = []
for messages in self.crossposting_cache[message.author.id]["message_groups"]:

for messages in user_cache["message_groups"]:
for j in reversed(range(len(messages))):
if message.id == messages[j].id:
del messages[j] # remove the message from the crosspost group
for alert_message_id in tuple(
self.crossposting_cache[message.author.id]["alert_message_ids"]
):

try:
alert_message = discord.utils.find(
lambda m: m.id == alert_message_id,
self.bot.cached_messages,
)
if not alert_message:
alert_message = await message.channel.fetch_message(
alert_message_id
)
except discord.NotFound:
continue

if (
alert_message.reference
and alert_message.reference.message_id == message.id
):
self.crossposting_cache[message.author.id][
"alert_message_ids"
].remove(alert_message_id)
# mark the alert message as stale if it references the deleted message
stale_alert_message_ids.append(alert_message_id)
del messages[j]
if message.id in user_cache["message_to_alert"]:
stale_alert_message_ids.append(
user_cache["message_to_alert"].pop(message.id)
)
break

if len(messages) == 1:
# mark all alert messages for this crosspost group as stale
# as there is only one message left
stale_alert_message_ids.extend(
self.crossposting_cache[message.author.id]["alert_message_ids"]
)
self.crossposting_cache[message.author.id]["alert_message_ids"].clear()
# Mark all alert messages for this crosspost group as stale
stale_alert_message_ids.extend(user_cache["message_to_alert"].values())
user_cache["message_to_alert"].clear()

# Delete stale alert messages
for alert_message_id in stale_alert_message_ids:
try:
await discord.PartialMessage(
channel=message.channel, id=alert_message_id
).delete()
except (discord.NotFound, discord.Forbidden):
# Silently handle errors
pass

def _is_valid_channel(self, channel: discord.abc.GuildChannel) -> bool:
"""
Check if a channel is valid based on the configured channel IDs.
Args:
channel (discord.abc.GuildChannel): The channel to check.
Returns:
bool: True if the channel is valid, otherwise False.
"""
if isinstance(channel, discord.abc.GuildChannel):
# Check if the channel ID or category ID is in the monitored channel IDs
if (
channel.id in self.channel_ids
or channel.category_id in self.channel_ids
):
return True

# If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
if isinstance(channel, discord.Thread):
if channel.parent_id in self.channel_ids:
return True
if channel.parent and channel.parent.category_id in self.channel_ids:
return True

return False


@snakecore.commands.decorators.with_config_kwargs
async def setup(
Expand All @@ -201,13 +229,24 @@ async def setup(
message_length_threshold: int = 64,
theme_color: int | discord.Color = 0,
):
"""
Setup function to add the AntiCrosspostCog to the bot.
Args:
bot (BotT): The bot instance.
channel_ids (Collection[int]): Collection of channel IDs to monitor.
max_tracked_users (int): Maximum number of users to track.
max_tracked_message_groups_per_user (int): Maximum number of message groups to track per user.
message_length_threshold (int): Minimum length of a message to be considered.
theme_color (int | discord.Color): Theme color for the bot's responses.
"""
await bot.add_cog(
AntiCrosspostCog(
bot,
channel_ids,
message_length_threshold,
max_tracked_users,
max_tracked_message_groups_per_user,
message_length_threshold,
theme_color,
)
)

0 comments on commit 3b800df

Please sign in to comment.