Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CommandTree.get_hash #10082

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions discord/app_commands/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from __future__ import annotations
import logging
import inspect
import sys
from hashlib import blake2b

from typing import (
Any,
Expand Down Expand Up @@ -62,7 +64,7 @@
from .translator import Translator, locale_str
from ..errors import ClientException, HTTPException
from ..enums import AppCommandType, InteractionType
from ..utils import MISSING, _get_as_snowflake, _is_submodule, _shorten
from ..utils import MISSING, _get_as_snowflake, _is_submodule, _shorten, _to_json
from .._types import ClientT


Expand All @@ -82,6 +84,25 @@
_log = logging.getLogger(__name__)


if sys.version_info < (3, 9):
_blake_kwargs = {}
else:
# This prevents blake from raising an error
# with some configurations in python 3.9
_blake_kwargs = {'usedforsecurity': False}


def _hash_payload(payload: list[Dict[str, Any]]) -> bytes:
tree_hash = blake2b(digest_size=32, person=b"tree", last_node=True, **_blake_kwargs)
command_hashes = [
blake2b(_to_json(c).encode(), person=b"command", last_node=False, **_blake_kwargs).digest() for c in payload
]
for h in sorted(command_hashes):
tree_hash.update(h)

return b"v1:" + tree_hash.digest()


def _retrieve_guild_ids(
command: Any, guild: Optional[Snowflake] = MISSING, guilds: Sequence[Snowflake] = MISSING
) -> Optional[Set[int]]:
Expand Down Expand Up @@ -1076,6 +1097,17 @@ async def set_translator(self, translator: Optional[Translator]) -> None:
await translator.load()
self._state._translator = translator

async def _get_payload(self, *, guild: Optional[Snowflake] = None) -> List[Dict[str, Any]]:
commands = self._get_all_commands(guild=guild)
mikeshardmind marked this conversation as resolved.
Show resolved Hide resolved

translator = self.translator
if translator:
payload = [await command.get_translated_payload(self, translator) for command in commands]
else:
payload = [command.to_dict(self) for command in commands]

return payload

async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
"""|coro|

Expand Down Expand Up @@ -1116,13 +1148,7 @@ async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
if self.client.application_id is None:
raise MissingApplicationID

commands = self._get_all_commands(guild=guild)

translator = self.translator
if translator:
payload = [await command.get_translated_payload(self, translator) for command in commands]
else:
payload = [command.to_dict(self) for command in commands]
payload = await self._get_payload(guild=guild)

try:
if guild is None:
Expand All @@ -1131,7 +1157,7 @@ async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
data = await self._http.bulk_upsert_guild_commands(self.client.application_id, guild.id, payload=payload)
except HTTPException as e:
if e.status == 400 and e.code == 50035:
raise CommandSyncFailure(e, commands) from None
raise CommandSyncFailure(e, self._get_all_commands(guild=guild)) from None
raise

return [AppCommand(data=d, state=self._state) for d in data]
Expand Down Expand Up @@ -1315,3 +1341,30 @@ async def _call(self, interaction: Interaction[ClientT]) -> None:
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)

async def get_hash(self, *, guild: Optional[Snowflake] = None) -> bytes:
"""|coro|

Returns a hash for tree's state, either for the global tree
if no guild is provided, or the guild specific portions of
the tree if a guild is provided.

This can be used to avoid uneccessary syncing.

.. versionadded:: 2.5

Parameters
-----------
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the tree hash for. If ``None`` then then
the hash for all global commands instead.

This corresponds to the guild parameter used when syncing.

Returns
-------
bytes
A hash of the tree's state as it would be synced.
"""
payload = await self._get_payload(guild=guild)
return _hash_payload(payload)
Loading