From 41be16477da586953717fff9f28ed6196ad12510 Mon Sep 17 00:00:00 2001
From: Zomatree <me@zomatree.live>
Date: Thu, 15 Feb 2024 20:25:37 +0000
Subject: [PATCH] feat: Command cooldowns

---
 revolt/ext/commands/__init__.py |   1 +
 revolt/ext/commands/command.py  |  51 ++++++++---
 revolt/ext/commands/context.py  |   1 +
 revolt/ext/commands/cooldown.py | 144 ++++++++++++++++++++++++++++++++
 revolt/ext/commands/errors.py   |  15 ++++
 5 files changed, 202 insertions(+), 10 deletions(-)
 create mode 100644 revolt/ext/commands/cooldown.py

diff --git a/revolt/ext/commands/__init__.py b/revolt/ext/commands/__init__.py
index b6529fd..93ea2af 100755
--- a/revolt/ext/commands/__init__.py
+++ b/revolt/ext/commands/__init__.py
@@ -4,6 +4,7 @@
 from .command import *
 from .context import *
 from .converters import *
+from .cooldown import *
 from .errors import *
 from .group import *
 from .help import *
diff --git a/revolt/ext/commands/command.py b/revolt/ext/commands/command.py
index e9cfbbf..fd65508 100755
--- a/revolt/ext/commands/command.py
+++ b/revolt/ext/commands/command.py
@@ -9,8 +9,9 @@
 
 from revolt.utils import maybe_coroutine
 
-from .errors import InvalidLiteralArgument, UnionConverterError
+from .errors import CommandOnCooldown, InvalidLiteralArgument, UnionConverterError
 from .utils import ClientT_Co_D, evaluate_parameters, ClientT_Co
+from .cooldown import BucketType, CooldownMapping
 
 if TYPE_CHECKING:
     from .checks import Check
@@ -43,28 +44,44 @@ class Command(Generic[ClientT_Co_D]):
         The cog the command is apart of.
     usage: Optional[:class:`str`]
         The usage string for the command
-    checks: list[Callable]
+    checks: Optional[list[Callable]]
         The list of checks the command has
+    cooldown: Optional[:class:`Cooldown`]
+        The cooldown for the command to restrict how often the command can be used
     description: Optional[:class:`str`]
         The commands description if it has one
     hidden: :class:`bool`
         Whether or not the command should be hidden from the help command
     """
-    __slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden")
-
-    def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str], usage: Optional[str] = None):
+    __slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden", "cooldown", "cooldown_bucket")
+
+    def __init__(
+            self,
+            callback: Callable[..., Coroutine[Any, Any, Any]],
+            name: str,
+            *,
+            aliases: list[str] | None = None,
+            usage: Optional[str] = None,
+            checks: list[Check[ClientT_Co_D]] | None = None,
+            cooldown: Optional[CooldownMapping] | None = None,
+            bucket: Optional[BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]]] = None,
+            description: str | None = None,
+            hidden: bool = False,
+        ):
         self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback
         self.name: str = name
-        self.aliases: list[str] = aliases
+        self.aliases: list[str] = aliases or []
         self.usage: str | None = usage
         self.signature: inspect.Signature = inspect.signature(self.callback)
         self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
-        self.checks: list[Check[ClientT_Co_D]] = getattr(callback, "_checks", [])
+        self.checks: list[Check[ClientT_Co_D]] = checks or getattr(callback, "_checks", [])
+        self.cooldown = cooldown or getattr(callback, "_cooldown", None)
+        self.cooldown_bucket: BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]] = bucket or getattr(callback, "_bucket", BucketType.default)
         self.parent: Optional[Group[ClientT_Co_D]] = None
         self.cog: Optional[Cog[ClientT_Co_D]] = None
         self._error_handler: Callable[[Any, Context[ClientT_Co_D], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
-        self.description: str | None = callback.__doc__
-        self.hidden: bool = False
+        self.description: str | None = description or callback.__doc__
+        self.hidden: bool = hidden
 
     async def invoke(self, context: Context[ClientT_Co_D], *args: Any, **kwargs: Any) -> Any:
         """Runs the command and calls the error handler if the command errors.
@@ -181,6 +198,18 @@ async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None:
 
                 context.args.append(arg)
 
+    async def run_cooldown(self, context: Context[ClientT_Co_D]):
+        if mapping := self.cooldown:
+            if isinstance(self.cooldown_bucket, BucketType):
+                key = self.cooldown_bucket.resolve(context)
+            else:
+                key = await self.cooldown_bucket(context)
+
+            cooldown = mapping.get_bucket(key)
+
+            if retry_after := cooldown.update_cooldown():
+                raise CommandOnCooldown(retry_after)
+
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__} name=\"{self.name}\">"
 
@@ -239,6 +268,8 @@ def command(
         The aliases of the command, defaults to no aliases
     cls: type[:class:`Command`]
         The class used for creating the command, this defaults to :class:`Command` but can be used to use a custom command subclass
+    usage: Optional[:class:`str`]
+        The signature for how the command should be called
 
     Returns
     --------
@@ -246,6 +277,6 @@ def command(
         A function that takes the command callback and returns a :class:`Command`
     """
     def inner(func: Callable[..., Coroutine[Any, Any, Any]]) -> Command[ClientT_Co]:
-        return cls(func, name or func.__name__, aliases or [], usage)
+        return cls(func, name or func.__name__, aliases=aliases or [], usage=usage)
 
     return inner
diff --git a/revolt/ext/commands/context.py b/revolt/ext/commands/context.py
index 5060084..6b7ad67 100755
--- a/revolt/ext/commands/context.py
+++ b/revolt/ext/commands/context.py
@@ -97,6 +97,7 @@ async def invoke(self) -> Any:
 
                     self.view.undo()
 
+            await command.run_cooldown(self)
             await command.parse_arguments(self)
             return await command.invoke(self, *self.args, **self.kwargs)
 
diff --git a/revolt/ext/commands/cooldown.py b/revolt/ext/commands/cooldown.py
new file mode 100644
index 0000000..3ae6b64
--- /dev/null
+++ b/revolt/ext/commands/cooldown.py
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, cast
+
+from .errors import ServerOnly
+
+if TYPE_CHECKING:
+    from enum import Enum
+
+    from .context import Context
+    from .utils import ClientT_Co_D, ClientT_Co
+else:
+    from aenum import Enum
+
+__all__ = ("Cooldown", "CooldownMapping", "BucketType", "cooldown")
+
+T = TypeVar("T")
+
+class Cooldown:
+    """Represent a single cooldown for a single key
+
+    Parameters
+    -----------
+    rate: :class:`int`
+        How many times it can be used
+    per: :class:`int`
+        How long the window is before the ratelimit resets
+    """
+
+    def __init__(self, rate: int, per: int):
+        self.rate: int = rate
+        self.per: int = per
+        self.window: float = 0.0
+        self.tokens: int = rate
+        self.last: float = 0.0
+
+    def get_tokens(self, current: float | None) -> int:
+        current = current or time.time()
+
+        if current > (self.window + self.per):
+            return self.rate
+        else:
+            return self.tokens
+
+    def update_cooldown(self) -> float | None:
+        current = time.time()
+
+        self.last = current
+
+        self.tokens = self.get_tokens(current)
+
+        if self.tokens == 0:
+            return self.per - (current - self.window)
+
+        self.tokens -= 1
+
+        if self.tokens == 0:
+            self.window = current
+
+        return None
+
+class CooldownMapping:
+    """Holds all cooldowns for every key"""
+    def __init__(self, rate: int, per: int):
+        self.rate = rate
+        self.per = per
+        self.cache: dict[str, Cooldown] = {}
+
+    def verify_cache(self):
+        current = time.time()
+        self.cache = {k: v for k, v in self.cache.items() if current < (v.last + v.per)}
+
+    def get_bucket(self, key: str) -> Cooldown:
+        self.verify_cache()
+
+        if not (rl := self.cache.get(key)):
+            self.cache[key] = rl = Cooldown(self.rate, self.per)
+
+        return rl
+
+class BucketType(Enum):
+    default = 0
+    user = 1
+    server = 2
+    channel = 3
+    member = 4
+
+    def resolve(self, context: Context[ClientT_Co_D]) -> str:
+        if self == BucketType.default:
+            return f"{context.author.id}{context.channel.id}"
+
+        elif self == BucketType.user:
+            return context.author.id
+
+        elif self == BucketType.server:
+            if id := context.server_id:
+                return id
+
+            raise ServerOnly
+
+        elif self == BucketType.channel:
+            return context.channel.id
+
+        else:  # BucketType.member
+            if server_id := context.server_id:
+                return f"{context.author.id}{server_id}"
+
+            raise ServerOnly
+
+def cooldown(rate: int, per: int, *, bucket: BucketType | Callable[[Context[ClientT_Co]], Coroutine[Any, Any, str]] = BucketType.default) -> Callable[[T], T]:
+    """Adds a cooldown to a command
+
+    Parameters
+    -----------
+    rate: :class:`int`
+        How many times it can be used
+    per: :class:`int`
+        How long the window is before the ratelimit resets
+    bucket: Optional[Union[:class:`BucketType`, Callable[[Context], str]]]
+        Controls how the key is generated for the cooldowns
+
+    Examples
+    --------
+    .. code-block:: python
+        @commands.command()
+        @commands.cooldown(1, 5)
+        async def ping(self, ctx: Context):
+            await ctx.send("Pong")
+    """
+    def inner(func: T) -> T:
+        from .command import Command
+
+        if isinstance(func, Command):
+            command = cast(Command[ClientT_Co], func)  # cant verify generic at runtime so must cast
+            command.cooldown = CooldownMapping(rate, per)
+            command.cooldown_bucket = bucket
+        else:
+            func._cooldown = CooldownMapping(rate, per)  # type: ignore
+            func._bucket = bucket  # type: ignore
+
+        return func
+
+    return inner
\ No newline at end of file
diff --git a/revolt/ext/commands/errors.py b/revolt/ext/commands/errors.py
index bd07e8b..80b8e9d 100755
--- a/revolt/ext/commands/errors.py
+++ b/revolt/ext/commands/errors.py
@@ -16,6 +16,7 @@
     "UserConverterError",
     "MemberConverterError",
     "MissingSetup",
+    "CommandOnCooldown"
 )
 
 class CommandError(RevoltError):
@@ -97,3 +98,17 @@ def __init__(self, argument: str):
 
 class MissingSetup(CommandError):
     """Raised when an extension is missing the `setup` function"""
+
+class CommandOnCooldown(CommandError):
+    """Raised when a command is on cooldown
+
+    Attributes
+    -----------
+    retry_after: :class:`float`
+        How long the user must wait until the cooldown resets
+    """
+
+    __slots__ = ("retry_after",)
+
+    def __init__(self, retry_after: float):
+        self.retry_after = retry_after
\ No newline at end of file