From e99b9af38e873f4c65d9a74e0e1c007e90a58cd7 Mon Sep 17 00:00:00 2001 From: Mega-JC <65417594+Mega-JC@users.noreply.github.com> Date: Fri, 26 Jul 2024 21:49:38 +0200 Subject: [PATCH] Generalize message validation system for showcase messages in 'showcase' extension --- pcbot/exts/showcase/__init__.py | 61 ++ pcbot/exts/showcase/cogs.py | 767 ++++++++++++++++++++++++ pcbot/exts/showcase/utils/__init__.py | 20 + pcbot/exts/showcase/utils/rules.py | 271 +++++++++ pcbot/exts/showcase/utils/utils.py | 76 +++ pcbot/exts/showcase/utils/validators.py | 234 ++++++++ 6 files changed, 1429 insertions(+) create mode 100644 pcbot/exts/showcase/__init__.py create mode 100644 pcbot/exts/showcase/cogs.py create mode 100644 pcbot/exts/showcase/utils/__init__.py create mode 100644 pcbot/exts/showcase/utils/rules.py create mode 100644 pcbot/exts/showcase/utils/utils.py create mode 100644 pcbot/exts/showcase/utils/validators.py diff --git a/pcbot/exts/showcase/__init__.py b/pcbot/exts/showcase/__init__.py new file mode 100644 index 0000000..12cf62b --- /dev/null +++ b/pcbot/exts/showcase/__init__.py @@ -0,0 +1,61 @@ +from typing import Collection +import discord +import snakecore + +from .utils import ShowcaseChannelConfig + +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot + + +@snakecore.commands.decorators.with_config_kwargs +async def setup( + bot: BotT, + showcase_channels_config: Collection[ShowcaseChannelConfig], + theme_color: int | discord.Color = 0, +): + # validate showcase channels config + for i, showcase_channel_config in enumerate(showcase_channels_config): + if "channel_id" not in showcase_channel_config: + raise ValueError("Showcase channel config must have a 'channel_id' key") + elif ( + "default_auto_archive_duration" in showcase_channel_config + and not isinstance( + showcase_channel_config["default_auto_archive_duration"], int + ) + ): + raise ValueError( + "Showcase channel config 'default_auto_archive_duration' must be an integer" + ) + elif ( + "default_thread_slowmode_delay" in showcase_channel_config + and not isinstance( + showcase_channel_config["default_thread_slowmode_delay"], int + ) + ): + raise ValueError( + "Showcase channel config 'default_thread_slowmode_delay' must be an integer" + ) + elif "showcase_message_rules" not in showcase_channel_config: + raise ValueError( + "Showcase channel config must have a 'showcase_message_rules' key" + ) + + from .utils import dispatch_rule_specifier_dict_validator, BadRuleSpecifier + + specifier_dict_validator = dispatch_rule_specifier_dict_validator( + showcase_channel_config["showcase_message_rules"] + ) + + # validate 'showcase_message_rules' value + try: + specifier_dict_validator( + showcase_channel_config["showcase_message_rules"] # type: ignore + ) + except BadRuleSpecifier as e: + raise ValueError( + f"Error while parsing config.{i}.showcase_message_rules field: {e}" + ) from e + + from .cogs import Showcasing + + await bot.add_cog(Showcasing(bot, showcase_channels_config, theme_color)) diff --git a/pcbot/exts/showcase/cogs.py b/pcbot/exts/showcase/cogs.py new file mode 100644 index 0000000..42fa93c --- /dev/null +++ b/pcbot/exts/showcase/cogs.py @@ -0,0 +1,767 @@ +"""This file is a part of the source code for PygameCommunityBot. +This project has been licensed under the MIT license. +Copyright (c) 2022-present pygame-community. +""" + +import abc +import asyncio +from collections.abc import Collection +import datetime +import enum +import itertools +import re +import time +from typing import Any, Callable, Literal, NotRequired, Protocol, TypedDict + +import discord +from discord.ext import commands +import snakecore +from snakecore.commands import flagconverter_kwargs +from snakecore.commands import UnicodeEmoji +from snakecore.commands.converters import DateTime + +from .utils import ShowcaseChannelConfig, validate_message + +from ...base import BaseExtensionCog + +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot + + +class Showcasing(BaseExtensionCog, name="showcasing"): + """A cog for managing showcase forum/threaded channels.""" + + def __init__( + self, + bot: BotT, + showcase_channels_config: Collection[ShowcaseChannelConfig], + theme_color: int | discord.Color = 0, + ) -> None: + super().__init__(bot, theme_color=theme_color) + self.showcase_channels_config: dict[int, ShowcaseChannelConfig] = { + showcase_channel_config["channel_id"]: showcase_channel_config + for showcase_channel_config in showcase_channels_config + } + self.entry_message_deletion_dict: dict[int, tuple[asyncio.Task[None], int]] = {} + + @commands.guild_only() + @commands.max_concurrency(1, per=commands.BucketType.guild, wait=True) + @commands.group( + invoke_without_command=True, + ) + async def showcase(self, ctx: commands.Context[BotT]): + pass + + @showcase.command( + name="rank", + usage=" [tags: Text...] [include_tags: Text...] " + "[exclude_tags: Text...] [before: Thread/DateTime] " + "[after: Thread/DateTime] [rank_emoji: Emoji]", + extras=dict(response_deletion_with_reaction=True), + ) + @flagconverter_kwargs() + async def showcase_rank( + self, + ctx: commands.Context[BotT], + forum: discord.ForumChannel, + *, + amount: commands.Range[int, 0], + include_tags: tuple[str, ...] = commands.flag(aliases=["tags"], default=()), + exclude_tags: tuple[str, ...] = (), + before: discord.Thread | DateTime | None = None, + after: discord.Thread | DateTime | None = None, + rank_emoji: UnicodeEmoji | discord.PartialEmoji | str | None = None, + ): + """Rank the specified showcase forum channel's posts by the number of reactions they have. + + __**Parameters:**__ + + **``** + > The forum channel to rank. + + **``** + > The amount of posts to rank. + + **`[include_tags: Text...]`** + **`[tags: Text...]`** + > A flag for specifying the inclusionary tags to filter posts by. + > Cannot be used with `exclude_tags`. + + **`[exclude_tags: Text...]`** + > A flag for specifying the exlcusary tags to filter posts by. + > Cannot be used with `include_tags`. + + **`[before: Thread/DateTime]`** + > A flag for specifying the thread to start the ranking from. + + **`[after: Thread/DateTime]`** + > A flag for specifying the thread to end the ranking at. + + **`[rank_emoji: Emoji]`** + > A flag for specifying the reaction emoji to use for ranking. In omitted, + > all used reaction emojis will be counted and summed up to calculate the rank. + """ + + assert ( + ctx.guild + and ctx.bot.user + and (bot_member := ctx.guild.get_member(ctx.bot.user.id)) + and isinstance( + ctx.channel, + (discord.TextChannel, discord.VoiceChannel, discord.Thread), + ) + and isinstance(ctx.author, discord.Member) + ) + + channel = forum + + if isinstance(rank_emoji, str): + rank_emoji = discord.PartialEmoji(name=rank_emoji) + + if include_tags and exclude_tags: + raise commands.CommandInvokeError( + commands.CommandError( + "You cannot specify both `include_tags:` and `exclude_tags:` at the same time." + ) + ) + + tags = [tag.name.lower() for tag in channel.available_tags] + + if include_tags: + include_tags = tuple(tag.lower() for tag in include_tags) + tags = [tag for tag in tags if tag in include_tags] + + if exclude_tags: + exclude_tags = tuple(tag.lower() for tag in exclude_tags) + tags = [tag for tag in tags if tag not in exclude_tags] + + if not snakecore.utils.have_permissions_in_channels( + ctx.author, + channel, + "view_channel", + ): + raise commands.CommandInvokeError( + commands.CommandError( + "You do not have enough permissions to run this command on the " + f"specified destination (<#{channel.id}>." + ) + ) + + if isinstance(before, discord.Thread) and before.parent_id != channel.id: + raise commands.CommandInvokeError( + commands.CommandError( + "`before` has to be an ID of a thread from the specified channel", + ) + ) + + if isinstance(after, discord.Thread) and after.parent_id != channel.id: + raise commands.CommandInvokeError( + commands.CommandError( + "`after` has to be an ID of a thread from the specified channel", + ) + ) + + before_ts = ( + before.replace(tzinfo=datetime.timezone.utc) + if isinstance(before, datetime.datetime) + else ( + discord.utils.snowflake_time(before.id) + if isinstance(before, discord.Thread) + else None + ) + ) + + after_ts = ( + after.replace(tzinfo=datetime.timezone.utc) + if isinstance(after, datetime.datetime) + else ( + discord.utils.snowflake_time(after.id) + if isinstance(after, discord.Thread) + else None + ) + ) + + async def count_unique_thread_reactions( + thread: discord.Thread, starter_message: discord.Message + ): + if rank_emoji: + return sum( + reaction.count + for reaction in starter_message.reactions + if snakecore.utils.is_emoji_equal(rank_emoji, reaction.emoji) + ) + + user_ids_by_reaction: dict[tuple[str, int], list[int]] = {} + + for i, reaction in enumerate(starter_message.reactions): + user_ids_by_reaction[str(reaction.emoji), i] = [ + user.id async for user in reaction.users() + ] + + return len( + set(itertools.chain.from_iterable(user_ids_by_reaction.values())) + ) + + async def thread_triple(thread: discord.Thread): + try: + starter_message = thread.starter_message or await thread.fetch_message( + thread.id + ) + except discord.NotFound: + return None + + return ( + thread, + starter_message, + await count_unique_thread_reactions(thread, starter_message), + ) + + max_archived_threads = max( + amount - len(channel.threads), 0 + ) # subtract active threads + + thread_triples = sorted( # sort triples by reaction count + ( + sorted_thread_triples := [ + # retrieve threads as + # (thread, message, reaction_count) tuples within time range + # in descending order + triple + for thread in itertools.chain( + sorted(channel.threads, key=lambda t: t.id, reverse=True), + ( + [ + thread + async for thread in channel.archived_threads( + limit=max_archived_threads, + ) + ] + if max_archived_threads + else (()) + ), + ) + if ( + before_ts is None + or discord.utils.snowflake_time(thread.id) < before_ts + ) + and ( + after_ts is None + or discord.utils.snowflake_time(thread.id) > after_ts + ) + and (triple := (await thread_triple(thread))) + and any(tag.name.lower() in tags for tag in triple[0].applied_tags) + ][:amount] + ), + key=lambda tup: tup[2], + reverse=True, + ) + + if not thread_triples: + raise commands.CommandInvokeError( + commands.CommandError("No threads found in the specified channel.") + ) + + embed_dict = { + "title": f"Showcase Rankings for {channel.mention} Posts by Emoji\n" + f"({len(thread_triples)} selected, from " + " " + f"to , based on unique reactions)", + "color": self.theme_color.value, + "fields": [], + } + + for i, triple in enumerate(thread_triples): + thread, starter_message, thread_reactions_count = triple + if thread_reactions_count: + embed_dict["fields"].append( + dict( + name=( + f"{i + 1}. " + + ( + f"{rank_emoji}: {thread_reactions_count}" + if rank_emoji + else f"{thread_reactions_count}: Unique | " + + ", ".join( + f"{reaction.emoji}: {reaction.count}" + for reaction in starter_message.reactions + ) + ) + ), + value=f"{thread.jump_url}", + inline=False, + ) + ) + + # divide embed dict into lists of multiple embed dicts if necessary + response_embed_dict_lists = [ + snakecore.utils.embeds.split_embed_dict(embed_dict) + ] + + # group those lists based on the total character count of the embeds + for i in range(len(response_embed_dict_lists)): + response_embed_dicts_list = response_embed_dict_lists[i] + total_char_count = 0 + for j in range(len(response_embed_dicts_list)): + response_embed_dict = response_embed_dicts_list[j] + if ( + total_char_count + + snakecore.utils.embeds.check_embed_dict_char_count( + response_embed_dict + ) + ) > snakecore.utils.embeds.EMBED_TOTAL_CHAR_LIMIT: + response_embed_dict_lists.insert( + # slice up the response embed dict list to fit the character + # limit per message + i + 1, + response_embed_dicts_list[j : j + 1], + ) + response_embed_dict_lists[i] = response_embed_dicts_list[:j] + else: + total_char_count += ( + snakecore.utils.embeds.check_embed_dict_char_count( + response_embed_dict + ) + ) + + for response_embed_dicts_list in response_embed_dict_lists: + await ctx.send( + embeds=[ + discord.Embed.from_dict(embed_dict) + for embed_dict in response_embed_dicts_list + ] + ) + + @staticmethod + async def delete_bad_message_with_thread( + message: discord.Message, delay: float = 0.0 + ): + """A function to pardon a bad message and its post/thread (if present) with a grace period. If this coroutine is not cancelled during the + grace period specified in `delay` in seconds, it will delete `thread`, if possible. + """ + try: + await asyncio.sleep(delay) # allow cancelling during delay + except asyncio.CancelledError: + return + + else: + try: + if isinstance(message.channel, discord.Thread): + await message.channel.delete() + + await message.delete() + except discord.NotFound: + # don't error here if thread and/or message were already deleted + pass + + def showcase_message_validity_check( + self, + message: discord.Message, + ) -> tuple[bool, str | None]: + """Checks if a showcase message has the right format. + + Returns + ------- + tuple[bool, str | None]: + A tuple containing a boolean indicating whether the message is valid or not, and a string describing the reason why it is invalid if it is not valid. + """ + return validate_message( + message, + self.showcase_channels_config[message.channel.id]["showcase_message_rules"], + ) + + @commands.Cog.listener() + async def on_thread_create(self, thread: discord.Thread): + if not ( + isinstance(thread.parent, discord.ForumChannel) + and thread.parent_id in self.showcase_channels_config + ): + return + + try: + message = thread.starter_message or await thread.fetch_message(thread.id) + except discord.NotFound: + return + + is_valid, reason = self.showcase_message_validity_check(message) + + if not is_valid: + deletion_datetime = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=5) + warn_msg = await message.reply( + "### Invalid showcase message\n\n" + f"{reason}\n\n" + " If no changes are made, your message (and its thread/post) will be " + f"deleted {snakecore.utils.create_markdown_timestamp(deletion_datetime, 'R')}." + ) + self.entry_message_deletion_dict[message.id] = ( + asyncio.create_task( + self.delete_bad_message_with_thread(message, delay=300) + ), + warn_msg.id, + ) + + async def prompt_author_for_feedback_thread(self, message: discord.Message): + assert ( + message.guild + and isinstance(message.channel, discord.TextChannel) + and self.bot.user + and (bot_member := message.guild.get_member(self.bot.user.id)) + ) + bot_perms = message.channel.permissions_for(bot_member) + + if not bot_perms.create_public_threads: + return + + deletion_datetime = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=1) + + alert_msg = await message.reply( + content=f"Need a feedback thread?\n\n-# This message will be deleted " + + snakecore.utils.create_markdown_timestamp(deletion_datetime, "R") + + ".", + ) + + await alert_msg.add_reaction("✅") + await alert_msg.add_reaction("❌") + + try: + event = await self.bot.wait_for( + "raw_reaction_add", + check=lambda event: event.message_id == alert_msg.id + and ( + event.user_id == message.author.id + or ( + event.member + and (not event.member.bot) + and ( + ( + perms := message.channel.permissions_for(event.member) + ).administrator + or perms.manage_messages + ) + ) + ) + and ( + snakecore.utils.is_emoji_equal(event.emoji, "✅") + or snakecore.utils.is_emoji_equal(event.emoji, "❌") + ), + timeout=60, + ) + except asyncio.TimeoutError: + try: + await alert_msg.delete() + except discord.NotFound: + pass + else: + if snakecore.utils.is_emoji_equal(event.emoji, "✅"): + try: + await message.create_thread( + name=( + f"Feedback for " + + f"@{message.author.name} | {str(message.author.id)[-6:]}" + )[:100], + auto_archive_duration=( + self.showcase_channels_config[message.channel.id].get( + "default_auto_archive_duration", 60 + ) + if bot_perms.manage_threads + else discord.utils.MISSING + ), # type: ignore + slowmode_delay=( + self.showcase_channels_config[message.channel.id].get( + "default_thread_slowmode_delay", + ) + if bot_perms.manage_threads + else None + ), # type: ignore + reason=f"A '#{message.channel.name}' message " + "author requested a feedback thread.", + ) + except discord.HTTPException: + pass + + try: + await alert_msg.delete() + except discord.NotFound: + pass + + @commands.Cog.listener() + async def on_message(self, message: discord.Message): + if not ( + (not message.author.bot) + and ( + isinstance(message.channel, discord.TextChannel) + and message.channel.id + in self.showcase_channels_config # is message in a showcase text channel + ) + ): + return + + is_valid, reason = self.showcase_message_validity_check(message) + + if is_valid: + await self.prompt_author_for_feedback_thread(message) + else: + deletion_datetime = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=5) + warn_msg = await message.reply( + "### Invalid showcase message\n\n" + f"{reason}\n\n" + " If no changes are made, your message (and its thread/post) will be " + f"deleted {snakecore.utils.create_markdown_timestamp(deletion_datetime, 'R')}." + ) + self.entry_message_deletion_dict[message.id] = ( + asyncio.create_task( + self.delete_bad_message_with_thread(message, delay=300) + ), + warn_msg.id, + ) + + @commands.Cog.listener() + async def on_message_edit(self, old: discord.Message, new: discord.Message): + if not ( + (not new.author.bot) + and ( + new.channel.id + in self.showcase_channels_config # is message in a showcase text channel + or ( + isinstance(new.channel, discord.Thread) + and new.channel.parent_id in self.showcase_channels_config + and new.id == new.channel.id + ) # is starter message of a post in a showcase forum + ) + and ( + new.content != old.content + or new.embeds != old.embeds + or new.attachments != old.attachments + ) + ): + return + + is_valid, reason = self.showcase_message_validity_check(new) + + if not is_valid: + if new.id in self.entry_message_deletion_dict: + deletion_data_tuple = self.entry_message_deletion_dict[new.id] + deletion_task = deletion_data_tuple[0] + if deletion_task.done(): + del self.entry_message_deletion_dict[new.id] + else: + try: + deletion_task.cancel() # try to cancel deletion after noticing edit by sender + + # fetch warning message from inside a post or refrencing the target message in a text showcase channel + warn_msg = await new.channel.fetch_message( + deletion_data_tuple[1] + ) + deletion_datetime = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=5) + await warn_msg.edit( + content=( + "### Invalid showcase message\n\n" + "Your edited showcase message is invalid.\n\n" + f"{reason}\n\n" + " If no changes are made, your post will be " + f"deleted " + + snakecore.utils.create_markdown_timestamp( + deletion_datetime, "R" + ) + + "." + ) + ) + self.entry_message_deletion_dict[new.id] = ( + asyncio.create_task( + self.delete_bad_message_with_thread(new, delay=300) + ), + warn_msg.id, + ) + except ( + discord.NotFound + ): # cancelling didn't work, warning and post were already deleted + if new.id in self.entry_message_deletion_dict: + del self.entry_message_deletion_dict[new.id] + + else: # an edit led to an invalid post from a valid one + deletion_datetime = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(minutes=5) + warn_msg = await new.reply( + "Your post must contain an attachment or text and safe links " + "to be valid.\n\n" + "- Attachment-only entries must be in reference to a previous " + "post of yours.\n" + "- Text-only posts must contain at least 32 characters " + "(including their title and including links, but not links " + "alone).\n\nIf no changes are made, your post will be" + f" deleted " + + snakecore.utils.create_markdown_timestamp(deletion_datetime, "R") + + "." + ) + + self.entry_message_deletion_dict[new.id] = ( + asyncio.create_task( + self.delete_bad_message_with_thread(new, delay=300) + ), + warn_msg.id, + ) + + elif ( + is_valid + ) and new.id in self.entry_message_deletion_dict: # an invalid entry was corrected + deletion_data_tuple = self.entry_message_deletion_dict[new.id] + deletion_task = deletion_data_tuple[0] + if not deletion_task.done(): # too late to do anything + try: + deletion_task.cancel() # try to cancel deletion after noticing valid edit by sender + await discord.PartialMessage( + channel=new.channel, id=deletion_data_tuple[1] + ).delete() + except ( + discord.NotFound + ): # cancelling didn't work, warning was already deleted + pass + + if new.id in self.entry_message_deletion_dict: + del self.entry_message_deletion_dict[new.id] + + if isinstance(new.channel, discord.TextChannel): + try: + # check if a feedback thread was previously created for this message + _ = new.channel.get_thread( + new.id + ) or await new.channel.guild.fetch_channel(new.id) + except discord.NotFound: + pass + else: + return + + await self.prompt_author_for_feedback_thread(new) + + @commands.Cog.listener() + async def on_message_delete(self, message: discord.Message): + if not ( + (not message.author.bot) + and ( + message.channel.id + in self.showcase_channels_config # is message in a showcase text channel + or ( + isinstance(message.channel, discord.Thread) + and message.channel.parent_id in self.showcase_channels_config + and message.id == message.channel.id + ) # is starter message of a post in a showcase forum + ) + ): + return + + if ( + message.id in self.entry_message_deletion_dict + ): # for case where user deletes their bad entry by themselves + deletion_data_tuple = self.entry_message_deletion_dict[message.id] + deletion_task = deletion_data_tuple[0] + if not deletion_task.done(): + deletion_task.cancel() + try: + await discord.PartialMessage( + channel=message.channel, id=deletion_data_tuple[1] + ).delete() + except discord.NotFound: + # warning message and post were already deleted + pass + + del self.entry_message_deletion_dict[message.id] + + alert_destination = message.channel + + if isinstance(message.channel, discord.TextChannel): + try: + alert_destination = message.channel.get_thread( + message.id + ) or await message.channel.guild.fetch_channel(message.id) + except discord.NotFound: + return + + if not isinstance(alert_destination, discord.Thread): + return + + alert_msg = await alert_destination.send( + embed=discord.Embed.from_dict( + dict( + title="Post/Thread scheduled for deletion", + description=( + "This post/thread is scheduled for deletion:\n\n" + "The OP has deleted their starter message." + + "\n\nIt will be deleted " + f"****." + ), + color=0x551111, + footer=dict(text="React with ❌ to cancel the deletion."), + ) + ) + ) + + await alert_msg.add_reaction("❌") + + try: + await self.bot.wait_for( + "raw_reaction_add", + check=lambda event: event.message_id == alert_msg.id + and ( + event.user_id == message.author.id + or ( + event.member + and (not event.member.bot) + and ( + ( + perms := message.channel.permissions_for(event.member) + ).administrator + or perms.manage_messages + ) + ) + ) + and snakecore.utils.is_emoji_equal(event.emoji, "❌"), + timeout=300, + ) + except asyncio.TimeoutError: + try: + await alert_destination.delete() + except discord.NotFound: + pass + else: + try: + await alert_msg.delete() + except discord.NotFound: + pass + + @commands.Cog.listener() + async def on_raw_thread_delete(self, payload: discord.RawThreadDeleteEvent): + if ( + payload.parent_id not in self.showcase_channels_config + or payload.thread_id not in self.entry_message_deletion_dict + ): + return + + deletion_data_tuple = self.entry_message_deletion_dict[payload.thread_id] + deletion_task = deletion_data_tuple[0] + if not deletion_task.done(): + deletion_task.cancel() + + del self.entry_message_deletion_dict[payload.thread_id] diff --git a/pcbot/exts/showcase/utils/__init__.py b/pcbot/exts/showcase/utils/__init__.py new file mode 100644 index 0000000..dcb8b11 --- /dev/null +++ b/pcbot/exts/showcase/utils/__init__.py @@ -0,0 +1,20 @@ +from abc import ABC +import re +from typing import Any, Callable, Collection, Literal, NotRequired, TypedDict +import discord +import snakecore + +from .utils import * +from .validators import * + +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot + + +class ShowcaseChannelConfig(TypedDict): + """A typed dict for specifying showcase channel configurations.""" + + channel_id: int + default_auto_archive_duration: NotRequired[int] + default_thread_slowmode_delay: NotRequired[int] + showcase_message_rules: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList + "A rule specifier dict for validating messages posted to the showcase channel" diff --git a/pcbot/exts/showcase/utils/rules.py b/pcbot/exts/showcase/utils/rules.py new file mode 100644 index 0000000..a91c8c6 --- /dev/null +++ b/pcbot/exts/showcase/utils/rules.py @@ -0,0 +1,271 @@ +# Base class for common message validation logic +import re +from typing import Literal +import discord +from .utils import MISSING, URL_PATTERN, DiscordMessageRule, EnforceType, is_vcs_url + + +class ContentRule(DiscordMessageRule, name="content"): + """A rule for validating if a Discord message contains only, any (the default), or no content.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"] = "any", + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of content according to the specified arguments.""" + + has_content = bool(message.content) + only_content = not has_content and not (message.attachments or message.embeds) + + if enforce_type == "always" and arg == "only" and not only_content: + return (False, "Message must always contain only text content") + + if enforce_type == "always" and arg == "any" and not has_content: + return (False, "Message must always contain text content") + + if enforce_type == "always" and arg == "none" and has_content: + return (False, "Message must always contain no text content") + + if enforce_type == "never" and arg == "only" and only_content: + return (False, "Message must never contain only text content") + + if enforce_type == "never" and arg == "any" and has_content: + return (False, "Message must never contain text content") + + if enforce_type == "never" and arg == "none" and not has_content: + return (False, "Message must never contain no text content") + + return (True, None) + + @staticmethod + def validate_arg(arg: Literal["any", "only", "none"]) -> str | None: + if arg not in (MISSING, "any", "only", "none"): + return "Argument must be one of 'any', 'only', or 'none'" + + +class ContentLengthRule(DiscordMessageRule, name="content-length"): + """A rule for validating if a Discord message contains text content within the specified length range.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: tuple[int, int], + ): + """Validate a message for the presence of text content within the specified length range.""" + + if not isinstance(arg, tuple) or len(arg) != 2: + raise ValueError("Argument must be a tuple of two integers") + + min_length, max_length = arg + + min_length = min_length or 0 + max_length = max_length or 4096 + + if min_length > max_length: + raise ValueError( + "Minimum length must be less than or equal to maximum length" + ) + + content_length = len(message.content) + + if enforce_type == "always" and not ( + min_length <= content_length <= max_length + ): + return ( + False, + f"Message must always contain text content within {min_length}-{max_length} characters", + ) + + if enforce_type == "never" and (min_length <= content_length <= max_length): + return ( + False, + f"Message must never contain text content within {min_length}-{max_length} characters", + ) + + return (True, None) + + @staticmethod + def validate_arg(arg: tuple[int | None, int | None]) -> str | None: + if (not isinstance(arg, (list, tuple))) or ( + isinstance(arg, (list, tuple)) and len(arg) != 2 + ): + return "Argument must be a list/tuple of two integers" + + if arg[0] is not None and arg[1] is not None: + if arg[0] > arg[1]: + return "Minimum length must be less than or equal to maximum length" + elif arg[0] is not None: + if arg[0] < 0: + return "Minimum length must be greater than or equal to 0" + elif arg[1] is not None: + if arg[1] < 0: + return "Maximum length must be greater than or equal to 0" + + +class URLsRule(DiscordMessageRule, name="urls"): + """A rule for validating if a Discord message contains only, at least one or no URLs.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of URLs according to the specified arguments.""" + + search_obj = tuple(re.finditer(URL_PATTERN, message.content)) + links = tuple(match.group() for match in search_obj if match) + any_urls = bool(links) + only_urls = any_urls and sum(len(link) for link in links) == len( + re.sub(r"\s", "", message.content) + ) + no_urls = not any_urls + + if enforce_type == "always" and arg == "only" and not only_urls: + return (False, "Message must always contain only URLs") + + if enforce_type == "always" and arg == "any" and not any_urls: + return (False, "Message must always contain at least one URL") + + if enforce_type == "always" and arg == "none" and not no_urls: + return (False, "Message must always contain no URLs") + + if enforce_type == "never" and arg == "only" and only_urls: + return (False, "Message must never contain only URLs") + + if enforce_type == "never" and arg == "any" and any_urls: + return (False, "Message must never contain at least one URL") + + if enforce_type == "never" and arg == "none" and no_urls: + return (False, "Message must never contain no URLs") + + return (True, None) + + +# Rule for validating VCS URLs +class VCSURLsRule(DiscordMessageRule, name="vcs-urls"): + """A rule for validating if a Discord message contains only, at least one (the default), or no valid VCS URLs.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "all", "none"] = "any", + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of VCS URLs according to the specified arguments.""" + + search_obj = tuple(re.finditer(URL_PATTERN, message.content or "")) + links = tuple(match.group() for match in search_obj if match) + any_vcs_urls = links and any(is_vcs_url(link) for link in links) + no_vcs_urls = not any_vcs_urls + all_vcs_urls = not any(not is_vcs_url(link) for link in links) + + if enforce_type == "always" and arg == "all" and not all_vcs_urls: + return (False, "Message must always contain only valid VCS URLs") + + if enforce_type == "always" and arg == "any" and not any_vcs_urls: + return (False, "Message must always contain at least one valid VCS URL") + + if enforce_type == "always" and arg == "none" and not no_vcs_urls: + return (False, "Message must always contain no valid VCS URLs") + + if enforce_type == "never" and arg == "all" and all_vcs_urls: + return (False, "Message must never contain only valid VCS URLs") + + if enforce_type == "never" and arg == "any" and any_vcs_urls: + return (False, "Message must never contain at least one valid VCS URL") + + if enforce_type == "never" and arg == "none" and no_vcs_urls: + return (False, "Message must never contain no valid VCS URLs") + + return (True, None) + + @staticmethod + def validate_arg(arg: Literal["any", "all", "none"]) -> str | None: + if arg not in (MISSING, "any", "all", "none"): + return "Argument must be one of 'any', 'all', or 'none'" + + +class AttachmentsRule(DiscordMessageRule, name="attachments"): + """A rule for validating if a Discord message contains only, at least one or no attachments.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ): + """Validate a message for the presence of attachments according to the specified arguments.""" + + any_attachments = bool(message.attachments) + only_attachments = any_attachments and not (message.content or message.embeds) + no_attachments = not any_attachments + + if enforce_type == "always" and arg == "only" and not only_attachments: + return (False, "Message must always contain only attachments") + + if enforce_type == "always" and arg == "any" and not any_attachments: + return (False, "Message must always contain at least one attachment") + + if enforce_type == "always" and arg == "none" and not no_attachments: + return (False, "Message must always contain no attachments") + + if enforce_type == "never" and arg == "only" and only_attachments: + return (False, "Message must never contain only attachments") + + if enforce_type == "never" and arg == "any" and any_attachments: + return (False, "Message must never contain at least one attachment") + + if enforce_type == "never" and arg == "none" and no_attachments: + return (False, "Message must never contain no attachments") + + return (True, None) + + +class EmbedsRule(DiscordMessageRule, name="embeds"): + """A rule for validating if a Discord message contains only, at least one or no embeds.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of embeds according to the specified arguments.""" + + any_embeds = bool(message.embeds) + only_embeds = any_embeds and not (message.content or message.attachments) + no_embeds = not any_embeds + + if enforce_type == "always" and arg == "only" and not only_embeds: + return (False, "Message must always contain only embeds") + + if enforce_type == "always" and arg == "any" and not any_embeds: + return (False, "Message must always contain at least one embed") + + if enforce_type == "always" and arg == "none" and not no_embeds: + return (False, "Message must always contain no embeds") + + if enforce_type == "never" and arg == "only" and only_embeds: + return (False, "Message must never contain only embeds") + + if enforce_type == "never" and arg == "any" and any_embeds: + return (False, "Message must never contain at least one embed") + + if enforce_type == "never" and arg == "none" and no_embeds: + return (False, "Message must never contain no embeds") + + return (True, None) + + +RULE_MAPPING: dict[str, type[DiscordMessageRule]] = { + "content": ContentRule, + "content-length": ContentLengthRule, + "urls": URLsRule, + "vcs-urls": VCSURLsRule, + "attachments": AttachmentsRule, + "embeds": EmbedsRule, +} diff --git a/pcbot/exts/showcase/utils/utils.py b/pcbot/exts/showcase/utils/utils.py new file mode 100644 index 0000000..5127bb3 --- /dev/null +++ b/pcbot/exts/showcase/utils/utils.py @@ -0,0 +1,76 @@ +# ABC for rules +from abc import ABC, abstractmethod +import re +from typing import Any, Literal, NotRequired, TypedDict +import discord + +EnforceType = Literal["always", "never"] + +MISSING: Any = object() + + +URL_PATTERN = re.compile( + r"(?P\w+):\/\/(?:(?P[\w_.-]+(?::[\w_.-]+)?)@)?(?P(?:(?P[\w_-]+(?:\.[\w_-]+)*)\.)?(?P(?P[\w_-]+)\.(?P\w+))|(?P[\w_-]+))(?:\:(?P\d+))?(?P\/[\w.,@?^=%&:\/~+-]*)?(?:\?(?P[\w.,@?^=%&:\/~+-]*))?(?:#(?P[\w@?^=%&\/~+#-]*))?" +) + + +def is_vcs_url(url: str) -> bool: + """Check if a URL points to a known VCS SaaS (e.g. GitHub, GitLab, Bitbucket).""" + return bool( + (match_ := (re.match(URL_PATTERN, url))) + and match_.group("scheme") in ("https", "http") + and match_.group("domain") in ("github.com", "gitlab.com", "bitbucket.org") + ) + + +class RuleSpecifier(TypedDict): + name: str + enforce_type: EnforceType + arg: NotRequired[Any] + description: NotRequired[str] + + +class RuleSpecifierPair(TypedDict): + mode: Literal["and", "or"] + clause1: "RuleSpecifier | RuleSpecifierPair | RuleSpecifierList" + clause2: "RuleSpecifier | RuleSpecifierPair | RuleSpecifierList" + description: NotRequired[str] + + +class RuleSpecifierList(TypedDict): + mode: Literal["any", "all"] + clauses: list["RuleSpecifier | RuleSpecifierPair | RuleSpecifierList"] + description: NotRequired[str] + + +class BadRuleSpecifier(Exception): + """Exception raised when a rule specifier is invalid.""" + + pass + + +class DiscordMessageRule(ABC): + name: str + + def __init_subclass__(cls, name: str) -> None: + cls.name = name + + @staticmethod + @abstractmethod + def validate( + enforce_type: EnforceType, message: discord.Message, arg: Any = None + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + ... + + @staticmethod + def validate_arg(arg: Any) -> str | None: + ... + + +class AsyncDiscordMessageRule(DiscordMessageRule, name="AsyncDiscordMessageRule"): + @staticmethod + @abstractmethod + async def validate( + enforce_type: EnforceType, message: discord.Message, arg: Any = None + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + ... diff --git a/pcbot/exts/showcase/utils/validators.py b/pcbot/exts/showcase/utils/validators.py new file mode 100644 index 0000000..330c614 --- /dev/null +++ b/pcbot/exts/showcase/utils/validators.py @@ -0,0 +1,234 @@ +from typing import Callable, overload + +import discord + +from .rules import RULE_MAPPING +from .utils import ( + MISSING, + BadRuleSpecifier, + RuleSpecifier, + RuleSpecifierList, + RuleSpecifierPair, +) + + +def dispatch_rule_specifier_dict_validator( + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +) -> ( + Callable[[RuleSpecifier], None] + | Callable[[RuleSpecifierPair], None] + | Callable[[RuleSpecifierList], None] + | None +): + """Dispatch the appropriate validator to use to validate the structure of a rule specifier.""" + + if "mode" in specifier: + if specifier["mode"] in ("and", "or"): + return validate_rule_specifier_dict_pair + elif specifier["mode"] in ("any", "all"): + return validate_rule_specifier_dict_list + else: + return validate_rule_specifier_dict_single + + return None + + +def validate_rule_specifier_dict_single( + specifier: RuleSpecifier, + depth_viz: str = "RuleSpecifier", +) -> None: + """Validate a single rule specifier's structure.""" + + if specifier["name"] not in RULE_MAPPING: + raise BadRuleSpecifier( + f"{depth_viz}.name: Unknown rule '{specifier['name']}'" + ) # type + elif "enforce_type" not in specifier or specifier["enforce_type"].lower() not in ( + "always", + "never", + ): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifier 'enforce_type' field must be set to 'always' or 'never'" + ) + + error_string = RULE_MAPPING[specifier["name"]].validate_arg( + specifier.get("arg", MISSING) + ) + + if error_string is not None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifier 'arg' field validation failed: {error_string}" + ) + + +def validate_rule_specifier_dict_pair( + specifier: RuleSpecifierPair, + depth_viz: str = "RuleSpecifierPair", +) -> None: + """Validate a rule specifier pair's structure.""" + + if "mode" not in specifier or specifier["mode"] not in ("and", "or"): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'mode' field must be 'and' or 'or'" + ) + + if "clause1" not in specifier or "clause2" not in specifier: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair must have 'clause1' " + "and 'clause2' fields pointing to RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dicts" + ) + + dict_validator1 = dispatch_rule_specifier_dict_validator(specifier["clause1"]) + dict_validator2 = dispatch_rule_specifier_dict_validator(specifier["clause2"]) + + if dict_validator1 is None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'clause1' field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + if dict_validator2 is None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'clause2' field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + dict_validator1(specifier["clause1"], depth_viz=f"{depth_viz}.clause1") # type: ignore + dict_validator2(specifier["clause2"], depth_viz=f"{depth_viz}.clause2") # type: ignore + + +def validate_rule_specifier_dict_list( + specifier: RuleSpecifierList, + depth_viz: str = "RuleSpecifierList", +) -> None: + """Validate a rule specifier list's structure.""" + + if "mode" not in specifier or specifier["mode"] not in ("any", "all"): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierList 'mode' field must be 'any' or 'all'" + ) + + if "clauses" not in specifier or not specifier["clauses"]: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierList must have 'clauses' " + "field pointing to a list of RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dicts" + ) + + for i, clause in enumerate(specifier["clauses"]): + dict_validator = dispatch_rule_specifier_dict_validator(clause) + if dict_validator is None: + raise BadRuleSpecifier( + f"{depth_viz}.clauses.{i} field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + dict_validator(clause, depth_viz=f"{depth_viz}.clauses.{i}") # type: ignore + + +def dispatch_rule_specifier_message_validator( + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +): + """Dispatch the appropriate validator to use to enforce a rule specifier on a Discord message.""" + + if "mode" in specifier: + if specifier["mode"] in ("and", "or"): + return rule_specifier_pair_validate_message + elif specifier["mode"] in ("any", "all"): + return rule_specifier_list_validate_message + return rule_specifier_single_validate_message + + +def rule_specifier_single_validate_message( + specifier: RuleSpecifier, + message: discord.Message, + depth_viz: str = "", +) -> tuple[bool, str | None]: + """Validate a message according to a single rule specifier.""" + + rule = RULE_MAPPING[specifier["name"]] + + if "arg" in specifier: + result = rule.validate(specifier["enforce_type"], message, specifier["arg"]) + else: + result = rule.validate(specifier["enforce_type"], message) + + if "description" in specifier: + # insert description of rule specifier if present + return (result[0], specifier["description"] if not result[0] else None) + + return result + + +def rule_specifier_pair_validate_message( + specifier: RuleSpecifierPair, + message: discord.Message, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier pair.""" + + success = True + failure_description = specifier.get("description") + + print(specifier) + + validator1 = dispatch_rule_specifier_message_validator(specifier["clause1"]) + validator2 = dispatch_rule_specifier_message_validator(specifier["clause2"]) + + result1 = validator1(specifier["clause1"], message) # type: ignore + result2 = None + + success = result1[0] + if (specifier["mode"] == "and" and success) or ( + specifier["mode"] == "or" and not success + ): + result2 = validator2(specifier["clause2"], message) # type: ignore + success = bool(result2[0]) + + if not result1[0] and failure_description is None: + failure_description = result1[1] + elif result2 and not result2[0] and failure_description is None: + failure_description = result2[1] + + return (success, failure_description if not success else None) + + +def rule_specifier_list_validate_message( + specifier: RuleSpecifierList, + message: discord.Message, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier list.""" + + success = True + failure_description = specifier.get("description") + + if specifier["mode"] == "all": + for i, clause in enumerate(specifier["clauses"]): + validator = dispatch_rule_specifier_message_validator(clause) + result = validator(clause, message) # type: ignore + if not result[0]: + success = False + if failure_description is None: + failure_description = result[1] + break + + elif specifier["mode"] == "any": + for i, clause in enumerate(specifier["clauses"]): + validator = dispatch_rule_specifier_message_validator(clause) + result = validator(clause, message) # type: ignore + success = success or result[0] + + if not success and failure_description is None: + failure_description = result[1] + + return (success, failure_description if not success else None) + + +def validate_message( + message: discord.Message, + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier.""" + + validator = dispatch_rule_specifier_message_validator(specifier) + result = validator(specifier, message) # type: ignore + + return result