diff --git a/revolt/ext/commands/__init__.py b/revolt/ext/commands/__init__.py index b6529fd..93ea2af 100755 --- a/revolt/ext/commands/__init__.py +++ b/revolt/ext/commands/__init__.py @@ -4,6 +4,7 @@ from .command import * from .context import * from .converters import * +from .cooldown import * from .errors import * from .group import * from .help import * diff --git a/revolt/ext/commands/command.py b/revolt/ext/commands/command.py index e9cfbbf..fd65508 100755 --- a/revolt/ext/commands/command.py +++ b/revolt/ext/commands/command.py @@ -9,8 +9,9 @@ from revolt.utils import maybe_coroutine -from .errors import InvalidLiteralArgument, UnionConverterError +from .errors import CommandOnCooldown, InvalidLiteralArgument, UnionConverterError from .utils import ClientT_Co_D, evaluate_parameters, ClientT_Co +from .cooldown import BucketType, CooldownMapping if TYPE_CHECKING: from .checks import Check @@ -43,28 +44,44 @@ class Command(Generic[ClientT_Co_D]): The cog the command is apart of. usage: Optional[:class:`str`] The usage string for the command - checks: list[Callable] + checks: Optional[list[Callable]] The list of checks the command has + cooldown: Optional[:class:`Cooldown`] + The cooldown for the command to restrict how often the command can be used description: Optional[:class:`str`] The commands description if it has one hidden: :class:`bool` Whether or not the command should be hidden from the help command """ - __slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden") - - def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str], usage: Optional[str] = None): + __slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden", "cooldown", "cooldown_bucket") + + def __init__( + self, + callback: Callable[..., Coroutine[Any, Any, Any]], + name: str, + *, + aliases: list[str] | None = None, + usage: Optional[str] = None, + checks: list[Check[ClientT_Co_D]] | None = None, + cooldown: Optional[CooldownMapping] | None = None, + bucket: Optional[BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]]] = None, + description: str | None = None, + hidden: bool = False, + ): self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback self.name: str = name - self.aliases: list[str] = aliases + self.aliases: list[str] = aliases or [] self.usage: str | None = usage self.signature: inspect.Signature = inspect.signature(self.callback) self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {})) - self.checks: list[Check[ClientT_Co_D]] = getattr(callback, "_checks", []) + self.checks: list[Check[ClientT_Co_D]] = checks or getattr(callback, "_checks", []) + self.cooldown = cooldown or getattr(callback, "_cooldown", None) + self.cooldown_bucket: BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]] = bucket or getattr(callback, "_bucket", BucketType.default) self.parent: Optional[Group[ClientT_Co_D]] = None self.cog: Optional[Cog[ClientT_Co_D]] = None self._error_handler: Callable[[Any, Context[ClientT_Co_D], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler - self.description: str | None = callback.__doc__ - self.hidden: bool = False + self.description: str | None = description or callback.__doc__ + self.hidden: bool = hidden async def invoke(self, context: Context[ClientT_Co_D], *args: Any, **kwargs: Any) -> Any: """Runs the command and calls the error handler if the command errors. @@ -181,6 +198,18 @@ async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None: context.args.append(arg) + async def run_cooldown(self, context: Context[ClientT_Co_D]): + if mapping := self.cooldown: + if isinstance(self.cooldown_bucket, BucketType): + key = self.cooldown_bucket.resolve(context) + else: + key = await self.cooldown_bucket(context) + + cooldown = mapping.get_bucket(key) + + if retry_after := cooldown.update_cooldown(): + raise CommandOnCooldown(retry_after) + def __repr__(self) -> str: return f"<{self.__class__.__name__} name=\"{self.name}\">" @@ -239,6 +268,8 @@ def command( The aliases of the command, defaults to no aliases cls: type[:class:`Command`] The class used for creating the command, this defaults to :class:`Command` but can be used to use a custom command subclass + usage: Optional[:class:`str`] + The signature for how the command should be called Returns -------- @@ -246,6 +277,6 @@ def command( A function that takes the command callback and returns a :class:`Command` """ def inner(func: Callable[..., Coroutine[Any, Any, Any]]) -> Command[ClientT_Co]: - return cls(func, name or func.__name__, aliases or [], usage) + return cls(func, name or func.__name__, aliases=aliases or [], usage=usage) return inner diff --git a/revolt/ext/commands/context.py b/revolt/ext/commands/context.py index 5060084..6b7ad67 100755 --- a/revolt/ext/commands/context.py +++ b/revolt/ext/commands/context.py @@ -97,6 +97,7 @@ async def invoke(self) -> Any: self.view.undo() + await command.run_cooldown(self) await command.parse_arguments(self) return await command.invoke(self, *self.args, **self.kwargs) diff --git a/revolt/ext/commands/cooldown.py b/revolt/ext/commands/cooldown.py new file mode 100644 index 0000000..3ae6b64 --- /dev/null +++ b/revolt/ext/commands/cooldown.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, cast + +from .errors import ServerOnly + +if TYPE_CHECKING: + from enum import Enum + + from .context import Context + from .utils import ClientT_Co_D, ClientT_Co +else: + from aenum import Enum + +__all__ = ("Cooldown", "CooldownMapping", "BucketType", "cooldown") + +T = TypeVar("T") + +class Cooldown: + """Represent a single cooldown for a single key + + Parameters + ----------- + rate: :class:`int` + How many times it can be used + per: :class:`int` + How long the window is before the ratelimit resets + """ + + def __init__(self, rate: int, per: int): + self.rate: int = rate + self.per: int = per + self.window: float = 0.0 + self.tokens: int = rate + self.last: float = 0.0 + + def get_tokens(self, current: float | None) -> int: + current = current or time.time() + + if current > (self.window + self.per): + return self.rate + else: + return self.tokens + + def update_cooldown(self) -> float | None: + current = time.time() + + self.last = current + + self.tokens = self.get_tokens(current) + + if self.tokens == 0: + return self.per - (current - self.window) + + self.tokens -= 1 + + if self.tokens == 0: + self.window = current + + return None + +class CooldownMapping: + """Holds all cooldowns for every key""" + def __init__(self, rate: int, per: int): + self.rate = rate + self.per = per + self.cache: dict[str, Cooldown] = {} + + def verify_cache(self): + current = time.time() + self.cache = {k: v for k, v in self.cache.items() if current < (v.last + v.per)} + + def get_bucket(self, key: str) -> Cooldown: + self.verify_cache() + + if not (rl := self.cache.get(key)): + self.cache[key] = rl = Cooldown(self.rate, self.per) + + return rl + +class BucketType(Enum): + default = 0 + user = 1 + server = 2 + channel = 3 + member = 4 + + def resolve(self, context: Context[ClientT_Co_D]) -> str: + if self == BucketType.default: + return f"{context.author.id}{context.channel.id}" + + elif self == BucketType.user: + return context.author.id + + elif self == BucketType.server: + if id := context.server_id: + return id + + raise ServerOnly + + elif self == BucketType.channel: + return context.channel.id + + else: # BucketType.member + if server_id := context.server_id: + return f"{context.author.id}{server_id}" + + raise ServerOnly + +def cooldown(rate: int, per: int, *, bucket: BucketType | Callable[[Context[ClientT_Co]], Coroutine[Any, Any, str]] = BucketType.default) -> Callable[[T], T]: + """Adds a cooldown to a command + + Parameters + ----------- + rate: :class:`int` + How many times it can be used + per: :class:`int` + How long the window is before the ratelimit resets + bucket: Optional[Union[:class:`BucketType`, Callable[[Context], str]]] + Controls how the key is generated for the cooldowns + + Examples + -------- + .. code-block:: python + @commands.command() + @commands.cooldown(1, 5) + async def ping(self, ctx: Context): + await ctx.send("Pong") + """ + def inner(func: T) -> T: + from .command import Command + + if isinstance(func, Command): + command = cast(Command[ClientT_Co], func) # cant verify generic at runtime so must cast + command.cooldown = CooldownMapping(rate, per) + command.cooldown_bucket = bucket + else: + func._cooldown = CooldownMapping(rate, per) # type: ignore + func._bucket = bucket # type: ignore + + return func + + return inner \ No newline at end of file diff --git a/revolt/ext/commands/errors.py b/revolt/ext/commands/errors.py index bd07e8b..80b8e9d 100755 --- a/revolt/ext/commands/errors.py +++ b/revolt/ext/commands/errors.py @@ -16,6 +16,7 @@ "UserConverterError", "MemberConverterError", "MissingSetup", + "CommandOnCooldown" ) class CommandError(RevoltError): @@ -97,3 +98,17 @@ def __init__(self, argument: str): class MissingSetup(CommandError): """Raised when an extension is missing the `setup` function""" + +class CommandOnCooldown(CommandError): + """Raised when a command is on cooldown + + Attributes + ----------- + retry_after: :class:`float` + How long the user must wait until the cooldown resets + """ + + __slots__ = ("retry_after",) + + def __init__(self, retry_after: float): + self.retry_after = retry_after \ No newline at end of file