From f2d1fc0bf8a987096253a48fd6cbc66f5124eb89 Mon Sep 17 00:00:00 2001
From: Mega-JC <65417594+Mega-JC@users.noreply.github.com>
Date: Sat, 27 Jul 2024 13:07:22 +0200
Subject: [PATCH] Refactor attachment checking in 'crosspost_cmp()' in
 'anti_crosspost' extension

---
 pcbot/exts/anti_crosspost.py | 75 +++++++++++++++++++++++++++++-------
 1 file changed, 61 insertions(+), 14 deletions(-)

diff --git a/pcbot/exts/anti_crosspost.py b/pcbot/exts/anti_crosspost.py
index 89444a6..2ab083b 100644
--- a/pcbot/exts/anti_crosspost.py
+++ b/pcbot/exts/anti_crosspost.py
@@ -15,7 +15,16 @@
 BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot
 
 
-def crosspost_cmp(message: discord.Message, other: discord.Message) -> bool:
+fetched_attachments: dict[int, bytes] = {}
+
+
+async def fetch_attachment(attachment: discord.Attachment, cache: bool = True) -> bytes:  # return the attachment's raw bytes, memoized by attachment id
+    if not cache or attachment.id not in fetched_attachments:  # cache miss, or caller passed cache=False to force a fresh download
+        fetched_attachments[attachment.id] = await attachment.read()  # store the bytes so repeated comparisons do not re-download
+    return fetched_attachments[attachment.id]  # NOTE(review): entries are never evicted here; prune fetched_attachments if memory matters
+
+
+async def crosspost_cmp(message: discord.Message, other: discord.Message) -> bool:
     """
     Compare two messages to determine if they are crossposts or duplicates.
 
@@ -28,24 +37,62 @@ def crosspost_cmp(message: discord.Message, other: discord.Message) -> bool:
         duplicates, otherwise False.
     """
 
-    similarity_score = 0.0
-    matching_attachments = False
-    if message.content and other.content:
+    similarity_score = None
+    matching_attachments = None
+
+    have_content = message.content and other.content
+    have_attachments = message.attachments and other.attachments
+
+    if have_content:
         hamming_score = sum(
             x != y for x, y in zip(message.content, other.content)
         ) / max(len(message.content), len(other.content))
         similarity_score = min(max(0, 1 - hamming_score), 1)
-
-    elif message.attachments and other.attachments:
-        matching_attachments = all(
-            att1.filename == att2.filename and att1.size == att2.size
-            for att1, att2 in zip(
-                sorted(message.attachments, key=lambda x: (x.filename, x.size)),
-                sorted(other.attachments, key=lambda x: (x.filename, x.size)),
+    else:
+        similarity_score = 0
+
+    if have_attachments:
+        # Check if the attachments are the same:
+        # - Sort both attachment lists by filename and size
+        # - Compare the sorted lists pairwise
+        # - A pair matches only if filename, size and content type are equal,
+        #   both files are under 8 MiB, and their downloaded bytes are identical
+        #   (attachments of 8 MiB or more are never treated as matching)
+
+        try:
+            matching_attachments = all(
+                [
+                    att1.filename == att2.filename
+                    and att1.size == att2.size
+                    and att1.size < 2**20 * 8
+                    and att2.size < 2**20 * 8
+                    and att1.content_type == att2.content_type
+                    and (await fetch_attachment(att1) == await fetch_attachment(att2))
+                    for att1, att2 in zip(
+                        sorted(message.attachments, key=lambda x: (x.filename, x.size)),
+                        sorted(other.attachments, key=lambda x: (x.filename, x.size)),
+                    )
+                ]
             )
-        )
+        except discord.HTTPException:
+            matching_attachments = False
+    else:
+        matching_attachments = False
 
-    return similarity_score > 0.80 or matching_attachments
+    if not have_content and (message.content or other.content):
+        return False
+    elif not have_attachments and (
+        message.attachments or other.attachments
+    ):
+        return False
+    elif have_content and have_attachments:
+        return similarity_score > 0.80 and matching_attachments
+    elif have_content:
+        return similarity_score > 0.80
+    elif have_attachments:
+        return matching_attachments
+    
+    return False
 
 
 class UserCrosspostCache(TypedDict):
@@ -132,7 +179,7 @@ async def on_message(self, message: discord.Message):
             for messages in user_cache["message_groups"]:
                 for existing_message in messages:
                     if (
-                        crosspost_cmp(message, existing_message)
+                        await crosspost_cmp(message, existing_message)
                         and message.created_at.timestamp()
                         - existing_message.created_at.timestamp()
                         <= self.crosspost_timedelta_threshold