From 1ab91658c2a3b5e03d9fb363a244e903479a848c Mon Sep 17 00:00:00 2001 From: Zomatree Date: Mon, 8 Aug 2022 19:23:25 +0100 Subject: [PATCH] handle server join event --- revolt/member.py | 6 +++- revolt/message.py | 17 +++++------ revolt/state.py | 17 ++++++----- revolt/types/gateway.py | 8 +++++- revolt/user.py | 12 ++++++-- revolt/websocket.py | 64 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 103 insertions(+), 21 deletions(-) diff --git a/revolt/member.py b/revolt/member.py index fd0dcb6..8ebc734 100755 --- a/revolt/member.py +++ b/revolt/member.py @@ -35,10 +35,13 @@ class Member(User): def __init__(self, data: MemberPayload, server: Server, state: State): user = state.get_user(data["_id"]["user"]) + + # due to not having a user payload and only a user object we have to manually add all the attributes instead of calling User.__init__ + flattern_user(self, user) + user._members.append(self) self._state = state - self.nickname = data.get("nickname") if avatar := data.get("avatar"): self.guild_avatar = Asset(avatar, state) @@ -49,6 +52,7 @@ def __init__(self, data: MemberPayload, server: Server, state: State): self.roles = sorted(roles, key=lambda role: role.rank, reverse=True) self.server = server + self.nickname = data.get("nickname") @property def avatar(self) -> Optional[Asset]: diff --git a/revolt/message.py b/revolt/message.py index b14fe81..9c24f43 100755 --- a/revolt/message.py +++ b/revolt/message.py @@ -13,7 +13,7 @@ from .types import Masquerade as MasqueradePayload from .types import Message as MessagePayload from .types import MessageReplyPayload - + from .server import Server __all__ = ( "Message", @@ -36,8 +36,6 @@ class Message: The embeds of the message channel: :class:`Messageable` The channel the message was sent in - server: :class:`Server` - The server the message was sent in author: Union[:class:`Member`, :class:`User`] The author of the message, will be :class:`User` in DMs edited_at: Optional[:class:`datetime.datetime`] @@ -49,7 +47,7 @@ class Message: reply_ids: list[:class:`str`] The message's ids this message has replies to """ - __slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "server", "author", "edited_at", "mentions", "replies", "reply_ids") + __slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "author", "edited_at", "mentions", "replies", "reply_ids") def __init__(self, data: MessagePayload, state: State): self.state = state @@ -63,10 +61,8 @@ def __init__(self, data: MessagePayload, state: State): assert isinstance(channel, Messageable) self.channel = channel - self.server = self.channel and self.channel.server - - if self.server: - author = state.get_member(self.server.id, data["author"]) + if server_id := self.channel.server_id: + author = state.get_member(server_id, data["author"]) else: author = state.get_user(data["author"]) @@ -135,6 +131,11 @@ def reply(self, *args, mention: bool = False, **kwargs): """ return self.channel.send(*args, **kwargs, replies=[MessageReply(self, mention)]) + @property + def server(self) -> Server: + """:class:`Server` The server this voice channel belongs too""" + return self.channel.server + class MessageReply(NamedTuple): """A namedtuple which represents a reply to a message. diff --git a/revolt/state.py b/revolt/state.py index 187a055..5961624 100755 --- a/revolt/state.py +++ b/revolt/state.py @@ -93,12 +93,15 @@ def get_message(self, message_id: str) -> Message: raise LookupError - async def fetch_all_server_members(self): - for server_id in self.servers: - data = await self.http.fetch_members(server_id) + async def fetch_server_members(self, server_id: str): + data = await self.http.fetch_members(server_id) + + for user in data["users"]: + self.add_user(user) - for user in data["users"]: - self.add_user(user) + for member in data["members"]: + self.add_member(server_id, member) - for member in data["members"]: - self.add_member(server_id, member) + async def fetch_all_server_members(self): + for server_id in self.servers: + await self.fetch_server_members(server_id) diff --git a/revolt/types/gateway.py b/revolt/types/gateway.py index 5461d3e..7a2580d 100755 --- a/revolt/types/gateway.py +++ b/revolt/types/gateway.py @@ -38,7 +38,8 @@ "ServerRoleUpdateEventPayload", "ServerRoleDeleteEventPayload", "UserUpdateEventPayload", - "UserRelationshipEventPayload" + "UserRelationshipEventPayload", + "ServerCreateEventPayload" ) class BasePayload(TypedDict): @@ -130,6 +131,11 @@ class ServerUpdateEventPayload(BasePayload): class ServerDeleteEventPayload(BasePayload): id: str +class ServerCreateEventPayload(BasePayload): + id: str + server: Server + channels: list[Channel] + class ServerMemberUpdateEventPayloadData(TypedDict, total=False): nickname: str avatar: File diff --git a/revolt/user.py b/revolt/user.py index b74528c..c0ea7f2 100755 --- a/revolt/user.py +++ b/revolt/user.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Literal, NamedTuple, Optional, Union from .asset import Asset, PartialAsset from .channel import DMChannel @@ -13,7 +13,7 @@ from .types import File from .types import Status as StatusPayload from .types import User as UserPayload - + from .member import Member __all__ = ("User", "Status", "Relation", "UserProfile") @@ -59,10 +59,11 @@ class User(Messageable): The dm channel between the client and the user, this will only be set if the client has dm'ed the user or :meth:`User.open_dm` was run """ __flattern_attributes__ = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel") - __slots__ = (*__flattern_attributes__, "state") + __slots__ = (*__flattern_attributes__, "state", "_members") def __init__(self, data: UserPayload, state: State): self.state = state + self._members: list[Member] = [] # we store all member versions of this user to avoid having to check every guild when needing to update. self.id = data["_id"] self.original_name = data["username"] self.dm_channel = None @@ -153,6 +154,11 @@ def _update(self, *, status: Optional[StatusPayload] = None, profile_content: Op if online: self.online = online + # update user infomation for all members + + for member in self._members: + User._update(member, status=status, profile_content=profile_content, profile_background=profile_background, avatar=avatar, online=online) + async def default_avatar(self) -> bytes: """Returns the default avatar for this user diff --git a/revolt/websocket.py b/revolt/websocket.py index 7e889f8..330057e 100755 --- a/revolt/websocket.py +++ b/revolt/websocket.py @@ -15,6 +15,7 @@ from .types import (MessageDeleteEventPayload, MessageUpdateEventPayload, ServerDeleteEventPayload, ServerMemberJoinEventPayload, ServerMemberLeaveEventPayload, + ServerCreateEventPayload, ServerMemberUpdateEventPayload, ServerRoleDeleteEventPayload, ServerRoleUpdateEventPayload, ServerUpdateEventPayload, UserRelationshipEventPayload, @@ -46,7 +47,7 @@ logger = logging.getLogger("revolt") class WebsocketHandler: - __slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready") + __slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready", "server_events") def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, dispatch: Callable[..., None], state: State): self.session = session @@ -58,6 +59,11 @@ def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, disp self.loop = asyncio.get_running_loop() self.user = None self.ready = asyncio.Event() + self.server_events: dict[str, asyncio.Event] = {} + + async def _wait_for_server_ready(self, server_id: str): + if event := self.server_events.get(server_id): + await event.wait() async def send_payload(self, payload: BasePayload): if use_msgpack: @@ -120,6 +126,10 @@ async def handle_ready(self, payload: ReadyEventPayload): async def handle_message(self, payload: MessageEventPayload): message = self.state.add_message(cast(MessagePayload, payload)) + + if server := message.server: + await self._wait_for_server_ready(server.id) + self.dispatch("message", message) async def handle_messageupdate(self, payload: MessageUpdateEventPayload): @@ -140,6 +150,9 @@ async def handle_messageupdate(self, payload: MessageUpdateEventPayload): message._update(**kwargs) + if server := message.server: + await self._wait_for_server_ready(server.id) + self.dispatch("message_update", message) async def handle_messagedelete(self, payload: MessageDeleteEventPayload): @@ -151,11 +164,18 @@ async def handle_messagedelete(self, payload: MessageDeleteEventPayload): return self.state.messages.remove(message) + + if server := message.server: + await self._wait_for_server_ready(server.id) + self.dispatch("message_delete", message) async def handle_channelcreate(self, payload: ChannelCreateEventPayload): channel = self.state.add_channel(payload) + if server := channel.server: + await self._wait_for_server_ready(server.id) + self.dispatch("channel_create", channel) async def handle_channelupdate(self, payload: ChannelUpdateEventPayload): @@ -174,23 +194,35 @@ async def handle_channelupdate(self, payload: ChannelUpdateEventPayload): if isinstance(channel, (TextChannel, VoiceChannel, GroupDMChannel)): channel.description = None + if server := channel.server: + await self._wait_for_server_ready(server.id) + self.dispatch("channel_update", old_channel, channel) async def handle_channeldelete(self, payload: ChannelDeleteEventPayload): channel = self.state.channels.pop(payload["id"]) + if server := channel.server: + await self._wait_for_server_ready(server.id) + self.dispatch("channel_delete", channel) async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload): channel = self.state.get_channel(payload["id"]) user = self.state.get_user(payload["user"]) + if server := channel.server: + await self._wait_for_server_ready(server.id) + self.dispatch("typing_start", channel, user) async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload): channel = self.state.get_channel(payload["id"]) user = self.state.get_user(payload["user"]) + if server := channel.server: + await self._wait_for_server_ready(server.id) + self.dispatch("typing_stop", channel, user) async def handle_serverupdate(self, payload: ServerUpdateEventPayload): @@ -210,6 +242,8 @@ async def handle_serverupdate(self, payload: ServerUpdateEventPayload): elif clear == "Description": server.description = None + await self._wait_for_server_ready(server.id) + self.dispatch("server_update", old_server, server) async def handle_serverdelete(self, payload: ServerDeleteEventPayload): @@ -218,9 +252,26 @@ async def handle_serverdelete(self, payload: ServerDeleteEventPayload): for channel in server.channels: del self.state.channels[channel.id] + await self._wait_for_server_ready(server.id) + self.dispatch("server_delete", server) + async def handle_servercreate(self, payload: ServerCreateEventPayload): + for channel in payload["channels"]: + self.state.add_channel(channel) + + server = self.state.add_server(payload["server"]) + + # lock all server events until we fetch all the members, otherwise the cache will be incomplete + self.server_events[server.id] = asyncio.Event() + await self.state.fetch_server_members(server.id) + self.server_events.pop(server.id).set() + + self.dispatch("server_create", server) + async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload): + await self._wait_for_server_ready(payload["id"]["server"]) + member = self.state.get_member(payload["id"]["server"], payload["id"]["user"]) old_member = copy(member) @@ -239,9 +290,16 @@ async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload): self.dispatch("member_join", member) async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload): + await self._wait_for_server_ready(payload["id"]) + server = self.state.get_server(payload["id"]) member = server._members.pop(payload["user"]) + # remove the member from the user + + user = self.state.get_user(payload["user"]) + user._members.remove(member) + self.dispatch("member_leave", member) async def handle_serveroleupdate(self, payload: ServerRoleUpdateEventPayload): @@ -255,12 +313,16 @@ async def handle_serveroleupdate(self, payload: ServerRoleUpdateEventPayload): role._update(**payload["data"]) + await self._wait_for_server_ready(server.id) + self.dispatch("role_update", old_role, role) async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload): server = self.state.get_server(payload["id"]) role = server._roles.pop(payload["role_id"]) + await self._wait_for_server_ready(server.id) + self.dispatch("role_delete", role) async def handle_userupdate(self, payload: UserUpdateEventPayload):