Skip to content

Commit

Permalink
feat: Command cooldowns
Browse files Browse the repository at this point in the history
  • Loading branch information
Zomatree committed Feb 15, 2024
1 parent 5a22d60 commit 41be164
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 10 deletions.
1 change: 1 addition & 0 deletions revolt/ext/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .command import *
from .context import *
from .converters import *
from .cooldown import *
from .errors import *
from .group import *
from .help import *
51 changes: 41 additions & 10 deletions revolt/ext/commands/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from revolt.utils import maybe_coroutine

from .errors import InvalidLiteralArgument, UnionConverterError
from .errors import CommandOnCooldown, InvalidLiteralArgument, UnionConverterError
from .utils import ClientT_Co_D, evaluate_parameters, ClientT_Co
from .cooldown import BucketType, CooldownMapping

if TYPE_CHECKING:
from .checks import Check
Expand Down Expand Up @@ -43,28 +44,44 @@ class Command(Generic[ClientT_Co_D]):
The cog the command is apart of.
usage: Optional[:class:`str`]
The usage string for the command
checks: list[Callable]
checks: Optional[list[Callable]]
The list of checks the command has
cooldown: Optional[:class:`Cooldown`]
The cooldown for the command to restrict how often the command can be used
description: Optional[:class:`str`]
The commands description if it has one
hidden: :class:`bool`
Whether or not the command should be hidden from the help command
"""
__slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden")

def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str], usage: Optional[str] = None):
__slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden", "cooldown", "cooldown_bucket")

def __init__(
self,
callback: Callable[..., Coroutine[Any, Any, Any]],
name: str,
*,
aliases: list[str] | None = None,
usage: Optional[str] = None,
checks: list[Check[ClientT_Co_D]] | None = None,
cooldown: Optional[CooldownMapping] | None = None,
bucket: Optional[BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]]] = None,
description: str | None = None,
hidden: bool = False,
):
self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback
self.name: str = name
self.aliases: list[str] = aliases
self.aliases: list[str] = aliases or []
self.usage: str | None = usage
self.signature: inspect.Signature = inspect.signature(self.callback)
self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {}))
self.checks: list[Check[ClientT_Co_D]] = getattr(callback, "_checks", [])
self.checks: list[Check[ClientT_Co_D]] = checks or getattr(callback, "_checks", [])
self.cooldown = cooldown or getattr(callback, "_cooldown", None)
self.cooldown_bucket: BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]] = bucket or getattr(callback, "_bucket", BucketType.default)
self.parent: Optional[Group[ClientT_Co_D]] = None
self.cog: Optional[Cog[ClientT_Co_D]] = None
self._error_handler: Callable[[Any, Context[ClientT_Co_D], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler
self.description: str | None = callback.__doc__
self.hidden: bool = False
self.description: str | None = description or callback.__doc__
self.hidden: bool = hidden

async def invoke(self, context: Context[ClientT_Co_D], *args: Any, **kwargs: Any) -> Any:
"""Runs the command and calls the error handler if the command errors.
Expand Down Expand Up @@ -181,6 +198,18 @@ async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None:

context.args.append(arg)

async def run_cooldown(self, context: Context[ClientT_Co_D]):
if mapping := self.cooldown:
if isinstance(self.cooldown_bucket, BucketType):
key = self.cooldown_bucket.resolve(context)
else:
key = await self.cooldown_bucket(context)

cooldown = mapping.get_bucket(key)

if retry_after := cooldown.update_cooldown():
raise CommandOnCooldown(retry_after)

def __repr__(self) -> str:
return f"<{self.__class__.__name__} name=\"{self.name}\">"

Expand Down Expand Up @@ -239,13 +268,15 @@ def command(
The aliases of the command, defaults to no aliases
cls: type[:class:`Command`]
The class used for creating the command, this defaults to :class:`Command` but can be used to use a custom command subclass
usage: Optional[:class:`str`]
The signature for how the command should be called
Returns
--------
Callable[Callable[..., Coroutine], :class:`Command`]
A function that takes the command callback and returns a :class:`Command`
"""
def inner(func: Callable[..., Coroutine[Any, Any, Any]]) -> Command[ClientT_Co]:
return cls(func, name or func.__name__, aliases or [], usage)
return cls(func, name or func.__name__, aliases=aliases or [], usage=usage)

return inner
1 change: 1 addition & 0 deletions revolt/ext/commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ async def invoke(self) -> Any:

self.view.undo()

await command.run_cooldown(self)
await command.parse_arguments(self)
return await command.invoke(self, *self.args, **self.kwargs)

Expand Down
144 changes: 144 additions & 0 deletions revolt/ext/commands/cooldown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, cast

from .errors import ServerOnly

if TYPE_CHECKING:
from enum import Enum

from .context import Context
from .utils import ClientT_Co_D, ClientT_Co
else:
from aenum import Enum

__all__ = ("Cooldown", "CooldownMapping", "BucketType", "cooldown")

T = TypeVar("T")

class Cooldown:
"""Represent a single cooldown for a single key
Parameters
-----------
rate: :class:`int`
How many times it can be used
per: :class:`int`
How long the window is before the ratelimit resets
"""

def __init__(self, rate: int, per: int):
self.rate: int = rate
self.per: int = per
self.window: float = 0.0
self.tokens: int = rate
self.last: float = 0.0

def get_tokens(self, current: float | None) -> int:
current = current or time.time()

if current > (self.window + self.per):
return self.rate
else:
return self.tokens

def update_cooldown(self) -> float | None:
current = time.time()

self.last = current

self.tokens = self.get_tokens(current)

if self.tokens == 0:
return self.per - (current - self.window)

self.tokens -= 1

if self.tokens == 0:
self.window = current

return None

class CooldownMapping:
"""Holds all cooldowns for every key"""
def __init__(self, rate: int, per: int):
self.rate = rate
self.per = per
self.cache: dict[str, Cooldown] = {}

def verify_cache(self):
current = time.time()
self.cache = {k: v for k, v in self.cache.items() if current < (v.last + v.per)}

def get_bucket(self, key: str) -> Cooldown:
self.verify_cache()

if not (rl := self.cache.get(key)):
self.cache[key] = rl = Cooldown(self.rate, self.per)

return rl

class BucketType(Enum):
default = 0
user = 1
server = 2
channel = 3
member = 4

def resolve(self, context: Context[ClientT_Co_D]) -> str:
if self == BucketType.default:
return f"{context.author.id}{context.channel.id}"

elif self == BucketType.user:
return context.author.id

elif self == BucketType.server:
if id := context.server_id:
return id

raise ServerOnly

elif self == BucketType.channel:
return context.channel.id

else: # BucketType.member
if server_id := context.server_id:
return f"{context.author.id}{server_id}"

raise ServerOnly

def cooldown(rate: int, per: int, *, bucket: BucketType | Callable[[Context[ClientT_Co]], Coroutine[Any, Any, str]] = BucketType.default) -> Callable[[T], T]:
"""Adds a cooldown to a command
Parameters
-----------
rate: :class:`int`
How many times it can be used
per: :class:`int`
How long the window is before the ratelimit resets
bucket: Optional[Union[:class:`BucketType`, Callable[[Context], str]]]
Controls how the key is generated for the cooldowns
Examples
--------
.. code-block:: python
@commands.command()
@commands.cooldown(1, 5)
async def ping(self, ctx: Context):
await ctx.send("Pong")
"""
def inner(func: T) -> T:
from .command import Command

if isinstance(func, Command):
command = cast(Command[ClientT_Co], func) # cant verify generic at runtime so must cast
command.cooldown = CooldownMapping(rate, per)
command.cooldown_bucket = bucket
else:
func._cooldown = CooldownMapping(rate, per) # type: ignore
func._bucket = bucket # type: ignore

return func

return inner
15 changes: 15 additions & 0 deletions revolt/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"UserConverterError",
"MemberConverterError",
"MissingSetup",
"CommandOnCooldown"
)

class CommandError(RevoltError):
Expand Down Expand Up @@ -97,3 +98,17 @@ def __init__(self, argument: str):

class MissingSetup(CommandError):
"""Raised when an extension is missing the `setup` function"""

class CommandOnCooldown(CommandError):
"""Raised when a command is on cooldown
Attributes
-----------
retry_after: :class:`float`
How long the user must wait until the cooldown resets
"""

__slots__ = ("retry_after",)

def __init__(self, retry_after: float):
self.retry_after = retry_after

0 comments on commit 41be164

Please sign in to comment.