Skip to content

Commit

Permalink
Add 'exclude_alert_channel_ids' int list field support to disable sen…
Browse files Browse the repository at this point in the history
…ding crossposting alerts to those channels
  • Loading branch information
Mega-JC committed Aug 24, 2024
1 parent 945ba18 commit 6320a0f
Showing 1 changed file with 138 additions and 74 deletions.
212 changes: 138 additions & 74 deletions pcbot/exts/anti_crosspost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
import discord
from discord.ext import commands
import snakecore
from typing import TypedDict, Collection
from typing import TypedDict, Collection, cast
from collections import OrderedDict
import logging

from ..base import BaseExtensionCog

# Define the type for the bot, supporting both Bot and AutoShardedBot from snakecore
BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot
MessageableGuildChannel = (
discord.TextChannel | discord.VoiceChannel | discord.StageChannel | discord.Thread
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,13 +58,17 @@ async def crosspost_cmp(message: discord.Message, other: discord.Message) -> boo
"""
Compare two messages to determine if they are crossposts or duplicates.
Args:
message (discord.Message): The first message to compare.
other (discord.Message): The second message to compare.
Returns:
bool: True if the messages are similar enough to be considered
duplicates, otherwise False.
Parameters
----------
message : discord.Message
The first message to compare.
other : discord.Message
The second message to compare.
Returns
-------
bool
True if the messages are similar enough to be considered duplicates, otherwise False.
"""

similarity_score = None
Expand Down Expand Up @@ -123,14 +130,15 @@ class UserCrosspostCache(TypedDict):
"""

message_groups: list[list[discord.Message]]
message_to_alert: dict[int, int] # Mapping from message ID to alert message ID
message_to_alert_map: dict[int, int] # Mapping from message ID to alert message ID


class AntiCrosspostCog(BaseExtensionCog, name="anti-crosspost"):
def __init__(
self,
bot: BotT,
channel_ids: Collection[int],
exclude_alert_channel_ids: Collection[int] | None,
crosspost_timedelta_threshold: int,
same_channel_message_length_threshold: int,
cross_channel_message_length_threshold: int,
Expand All @@ -141,21 +149,30 @@ def __init__(
"""
Initialize the AntiCrosspostCog.
Args:
bot (BotT): The bot instance.
channel_ids (Collection[int]): Collection of channel IDs to monitor.
crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
if the messages are in the same channel.
cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
if the messages are in different channels.
max_tracked_users (int): Maximum number of users to track.
max_tracked_message_groups_per_user (int): Maximum number of message
groups to track per user.
theme_color (int | discord.Color): Theme color for the bot's responses.
Parameters
----------
bot : BotT
The bot instance.
channel_ids : Collection[int]
Collection of channel IDs to watch.
exclude_alert_channel_ids : Collection[int] or None
Collection of channel IDs to exclude from alerting.
crosspost_timedelta_threshold : int
Minimum time difference between messages to not be considered crossposts.
same_channel_message_length_threshold : int
Minimum length of a text-only message to be considered if the messages are in the same channel.
cross_channel_message_length_threshold : int
Minimum length of a text-only message to be considered if the messages are in different channels.
max_tracked_users : int
Maximum number of users to track.
max_tracked_message_groups_per_user : int
Maximum number of message groups to track per user.
theme_color : int or discord.Color, optional
Theme color for the bot's responses, by default 0.
"""
super().__init__(bot, theme_color)
self.channel_ids = set(channel_ids)
self.exclude_alert_channel_ids = set(exclude_alert_channel_ids or ())
self.crossposting_cache: OrderedDict[int, UserCrosspostCache] = OrderedDict()

self.crosspost_timedelta_threshold = crosspost_timedelta_threshold
Expand All @@ -170,15 +187,9 @@ def __init__(

@commands.Cog.listener()
async def on_message(self, message: discord.Message):
"""
Event listener for new messages.
Args:
message (discord.Message): The message object.
"""
if (
message.author.bot
or not await self._is_watched_channel(message.channel) # type: ignore
or not await self._check_channel(message.channel, self.channel_ids) # type: ignore
or message.type != discord.MessageType.default
or (
message.content
Expand All @@ -205,14 +216,15 @@ async def on_message(self, message: discord.Message):

user_cache = self.crossposting_cache[user_id]
if not any(len(group) > 1 for group in user_cache["message_groups"]):
# Remove user from cache if they dont have any crossposts
self.crossposting_cache.pop(user_id)
logger.debug(f"Removed user {user_id} from cache to enforce size limit")

# Initialize cache for new users
if message.author.id not in self.crossposting_cache:
self.crossposting_cache[message.author.id] = UserCrosspostCache(
message_groups=[[message]],
message_to_alert={},
message_to_alert_map={},
)
logger.debug(f"Initialized cache for new user {message.author.name}")
else:
Expand Down Expand Up @@ -248,14 +260,40 @@ async def on_message(self, message: discord.Message):
logger.debug(
f"Found crosspost for user {message.author.name}, message URL {message.jump_url}!!!!!!!!!!"
)
alert_channel = cast(MessageableGuildChannel, message.channel)
if self.exclude_alert_channel_ids and not await self._check_channel(
alert_channel, deny=self.exclude_alert_channel_ids
):
# Attempt to find the next best channel to alert in
print( [ msg.content for msg in messages[:-1] ])
for message in reversed(messages[:-1]):
alert_channel = cast(
MessageableGuildChannel, message.channel
)
if await self._check_channel(
alert_channel, deny=self.exclude_alert_channel_ids
):
break
else:
logger.debug(
f"No allowed alerting channel for user {message.author.name} found"
)
break # Don't issue an alert if not possible

if message.id in user_cache["message_to_alert_map"]:
logger.debug(
f"Message {message.id} is already being alerted for user {message.author.name}"
)
break # Don't issue an alert if already alerted

try:
alert_message = await message.reply(
alert_message = await alert_channel.send(
"This message is a recent crosspost/duplicate among the following messages: "
+ ", ".join([m.jump_url for m in messages])
+ ".\n\nPlease delete all duplicate messages."
+ ".\n\nPlease delete all duplicate messages.",
reference=message,
)
user_cache["message_to_alert"][
user_cache["message_to_alert_map"][
message.id
] = alert_message.id
logger.debug(
Expand Down Expand Up @@ -290,13 +328,7 @@ async def on_message(self, message: discord.Message):

@commands.Cog.listener()
async def on_message_delete(self, message: discord.Message):
"""
Event listener for deleted messages.
Args:
message (discord.Message): The message object.
"""
if not await self._is_watched_channel(message.channel): # type: ignore
if not await self._check_channel(message.channel, self.channel_ids): # type: ignore
return

if message.author.id not in self.crossposting_cache:
Expand All @@ -309,9 +341,9 @@ async def on_message_delete(self, message: discord.Message):
for j in range(len(messages) - 1, -1, -1):
if message.id == messages[j].id:
del messages[j]
if message.id in user_cache["message_to_alert"]:
if message.id in user_cache["message_to_alert_map"]:
stale_alert_message_ids.append(
user_cache["message_to_alert"].pop(message.id)
user_cache["message_to_alert_map"].pop(message.id)
)
logger.debug(
f"Removed message {message.jump_url} from user {message.author.name}'s cache due to deletion"
Expand All @@ -320,9 +352,12 @@ async def on_message_delete(self, message: discord.Message):

# Mark last alert message for this crosspost group as stale if the group
# has only one message
if len(messages) == 1 and messages[0].id in user_cache["message_to_alert"]:
if (
len(messages) == 1
and messages[0].id in user_cache["message_to_alert_map"]
):
stale_alert_message_ids.append(
user_cache["message_to_alert"].pop(messages[0].id)
user_cache["message_to_alert_map"].pop(messages[0].id)
)

# Delete stale alert messages
Expand All @@ -337,29 +372,48 @@ async def on_message_delete(self, message: discord.Message):
f"Failed to delete alert message ID {alert_message_id}: {e}"
)

async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
@staticmethod
async def _check_channel(
channel: discord.abc.GuildChannel | discord.Thread,
allow: Collection[int] = (),
deny: Collection[int] = (),
) -> bool:
"""
Check if a guild channel or thread is allowed or denied for something based on the provided allow and deny lists.
Parameters
----------
channel : discord.abc.GuildChannel | discord.Thread
The channel to check.
allow : Collection[int], optional
Collection of channel IDs to allow, by default ()
deny : Collection[int], optional
Collection of channel IDs to deny, by default ()
Returns
-------
bool: True if the channel is allowed, False if it is denied, and None if neither is allowed.
"""
Check if a channel is watched for crossposts based on the configured channel IDs.

Args:
channel (discord.abc.GuildChannel): The channel to check.
if not (allow or deny):
raise ValueError("Either 'allow' or 'deny' must be provided")

result = False

Returns:
bool: True if the channel is watched, otherwise False.
"""
if isinstance(channel, discord.abc.GuildChannel):
# Check if the channel ID or category ID is in the monitored channel IDs
if (
channel.id in self.channel_ids
or channel.category_id in self.channel_ids
):
return True

# If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
if isinstance(channel, discord.Thread):
if channel.parent_id in self.channel_ids:
return True
result = (
bool(allow) and (channel.id in allow or channel.category_id in allow)
) or not (
bool(deny) and (channel.id in deny or channel.category_id in deny)
)

# If the channel is a thread, check if the parent or the parent's category ID is in the monitored channel IDs
elif isinstance(channel, discord.Thread):
if not (
result := (bool(allow) and channel.parent_id in allow)
or not (bool(deny) and channel.parent_id in deny)
):
try:
parent = (
channel.parent
Expand All @@ -369,16 +423,18 @@ async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
except discord.NotFound:
pass
else:
if parent and parent.category_id in self.channel_ids:
return True
result = (bool(allow) and parent.category_id in allow) or not (
bool(deny) and parent.category_id in deny
)

return False
return result


@snakecore.commands.decorators.with_config_kwargs
async def setup(
bot: BotT,
channel_ids: Collection[int],
exclude_alert_channel_ids: Collection[int] | None = None,
max_tracked_users: int = 10,
max_tracked_message_groups_per_user: int = 10,
crosspost_timedelta_threshold: int = 86400,
Expand All @@ -389,22 +445,30 @@ async def setup(
"""
Setup function to add the AntiCrosspostCog to the bot.
Args:
bot (BotT): The bot instance.
channel_ids (Collection[int]): Collection of channel IDs to monitor.
max_tracked_users (int): Maximum number of users to track.
max_tracked_message_groups_per_user (int): Maximum number of message groups to track per user.
crosspost_timedelta_threshold (int): Minimum time difference between messages to not be considered crossposts.
same_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
if the messages are in the same channel.
cross_channel_message_length_threshold (int): Minimum length of a text-only message to be considered
if the messages are in different channels.
theme_color (int | discord.Color): Theme color for the bot's responses.
Parameters
----------
bot : BotT
The bot instance.
channel_ids : Collection[int]
Collection of channel IDs to watch.
exclude_alert_channel_ids : Collection[int] or None, optional
Collection of channel IDs to exclude from alerting, by default None
max_tracked_users : int, optional
Maximum number of users to track, by default 10
max_tracked_message_groups_per_user : int, optional
Maximum number of message groups to track per user, by default 10
crosspost_timedelta_threshold : int, optional
Minimum time difference between messages to not be considered crossposts, by default 86400
same_channel_message_length_threshold : int, optional
Minimum length of a text-only message to be considered if the messages are in the same channel, by default 64
cross_channel_message_length_threshold : int, optional
Minimum length of a text-only message to be considered if the messages are in different channels, by default 16
"""
await bot.add_cog(
AntiCrosspostCog(
bot,
channel_ids,
exclude_alert_channel_ids,
crosspost_timedelta_threshold,
same_channel_message_length_threshold,
cross_channel_message_length_threshold,
Expand Down

0 comments on commit 6320a0f

Please sign in to comment.