Skip to content

Commit

Permalink
handle server join event
Browse files Browse the repository at this point in the history
  • Loading branch information
Zomatree committed Aug 8, 2022
1 parent 64ccd58 commit 1ab9165
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 21 deletions.
6 changes: 5 additions & 1 deletion revolt/member.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ class Member(User):

def __init__(self, data: MemberPayload, server: Server, state: State):
user = state.get_user(data["_id"]["user"])

# due to not having a user payload and only a user object we have to manually add all the attributes instead of calling User.__init__

flattern_user(self, user)
user._members.append(self)

self._state = state
self.nickname = data.get("nickname")

if avatar := data.get("avatar"):
self.guild_avatar = Asset(avatar, state)
Expand All @@ -49,6 +52,7 @@ def __init__(self, data: MemberPayload, server: Server, state: State):
self.roles = sorted(roles, key=lambda role: role.rank, reverse=True)

self.server = server
self.nickname = data.get("nickname")

@property
def avatar(self) -> Optional[Asset]:
Expand Down
17 changes: 9 additions & 8 deletions revolt/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .types import Masquerade as MasqueradePayload
from .types import Message as MessagePayload
from .types import MessageReplyPayload

from .server import Server

__all__ = (
"Message",
Expand All @@ -36,8 +36,6 @@ class Message:
The embeds of the message
channel: :class:`Messageable`
The channel the message was sent in
server: :class:`Server`
The server the message was sent in
author: Union[:class:`Member`, :class:`User`]
The author of the message, will be :class:`User` in DMs
edited_at: Optional[:class:`datetime.datetime`]
Expand All @@ -49,7 +47,7 @@ class Message:
reply_ids: list[:class:`str`]
The message's ids this message has replies to
"""
__slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "server", "author", "edited_at", "mentions", "replies", "reply_ids")
__slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "author", "edited_at", "mentions", "replies", "reply_ids")

def __init__(self, data: MessagePayload, state: State):
self.state = state
Expand All @@ -63,10 +61,8 @@ def __init__(self, data: MessagePayload, state: State):
assert isinstance(channel, Messageable)
self.channel = channel

self.server = self.channel and self.channel.server

if self.server:
author = state.get_member(self.server.id, data["author"])
if server_id := self.channel.server_id:
author = state.get_member(server_id, data["author"])
else:
author = state.get_user(data["author"])

Expand Down Expand Up @@ -135,6 +131,11 @@ def reply(self, *args, mention: bool = False, **kwargs):
"""
return self.channel.send(*args, **kwargs, replies=[MessageReply(self, mention)])

@property
def server(self) -> Server:
""":class:`Server` The server this voice channel belongs too"""
return self.channel.server

class MessageReply(NamedTuple):
"""A namedtuple which represents a reply to a message.
Expand Down
17 changes: 10 additions & 7 deletions revolt/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,15 @@ def get_message(self, message_id: str) -> Message:

raise LookupError

async def fetch_all_server_members(self):
for server_id in self.servers:
data = await self.http.fetch_members(server_id)
async def fetch_server_members(self, server_id: str):
data = await self.http.fetch_members(server_id)

for user in data["users"]:
self.add_user(user)

for user in data["users"]:
self.add_user(user)
for member in data["members"]:
self.add_member(server_id, member)

for member in data["members"]:
self.add_member(server_id, member)
async def fetch_all_server_members(self):
for server_id in self.servers:
await self.fetch_server_members(server_id)
8 changes: 7 additions & 1 deletion revolt/types/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"ServerRoleUpdateEventPayload",
"ServerRoleDeleteEventPayload",
"UserUpdateEventPayload",
"UserRelationshipEventPayload"
"UserRelationshipEventPayload",
"ServerCreateEventPayload"
)

class BasePayload(TypedDict):
Expand Down Expand Up @@ -130,6 +131,11 @@ class ServerUpdateEventPayload(BasePayload):
class ServerDeleteEventPayload(BasePayload):
id: str

class ServerCreateEventPayload(BasePayload):
id: str
server: Server
channels: list[Channel]

class ServerMemberUpdateEventPayloadData(TypedDict, total=False):
nickname: str
avatar: File
Expand Down
12 changes: 9 additions & 3 deletions revolt/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Literal, NamedTuple, Optional, Union

from .asset import Asset, PartialAsset
from .channel import DMChannel
Expand All @@ -13,7 +13,7 @@
from .types import File
from .types import Status as StatusPayload
from .types import User as UserPayload

from .member import Member

__all__ = ("User", "Status", "Relation", "UserProfile")

Expand Down Expand Up @@ -59,10 +59,11 @@ class User(Messageable):
The dm channel between the client and the user, this will only be set if the client has dm'ed the user or :meth:`User.open_dm` was run
"""
__flattern_attributes__ = ("id", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel")
__slots__ = (*__flattern_attributes__, "state")
__slots__ = (*__flattern_attributes__, "state", "_members")

def __init__(self, data: UserPayload, state: State):
self.state = state
self._members: list[Member] = [] # we store all member versions of this user to avoid having to check every guild when needing to update.
self.id = data["_id"]
self.original_name = data["username"]
self.dm_channel = None
Expand Down Expand Up @@ -153,6 +154,11 @@ def _update(self, *, status: Optional[StatusPayload] = None, profile_content: Op
if online:
self.online = online

# update user infomation for all members

for member in self._members:
User._update(member, status=status, profile_content=profile_content, profile_background=profile_background, avatar=avatar, online=online)

async def default_avatar(self) -> bytes:
"""Returns the default avatar for this user
Expand Down
64 changes: 63 additions & 1 deletion revolt/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .types import (MessageDeleteEventPayload, MessageUpdateEventPayload,
ServerDeleteEventPayload, ServerMemberJoinEventPayload,
ServerMemberLeaveEventPayload,
ServerCreateEventPayload,
ServerMemberUpdateEventPayload,
ServerRoleDeleteEventPayload, ServerRoleUpdateEventPayload,
ServerUpdateEventPayload, UserRelationshipEventPayload,
Expand Down Expand Up @@ -46,7 +47,7 @@
logger = logging.getLogger("revolt")

class WebsocketHandler:
__slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready")
__slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready", "server_events")

def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, dispatch: Callable[..., None], state: State):
self.session = session
Expand All @@ -58,6 +59,11 @@ def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, disp
self.loop = asyncio.get_running_loop()
self.user = None
self.ready = asyncio.Event()
self.server_events: dict[str, asyncio.Event] = {}

async def _wait_for_server_ready(self, server_id: str):
if event := self.server_events.get(server_id):
await event.wait()

async def send_payload(self, payload: BasePayload):
if use_msgpack:
Expand Down Expand Up @@ -120,6 +126,10 @@ async def handle_ready(self, payload: ReadyEventPayload):

async def handle_message(self, payload: MessageEventPayload):
message = self.state.add_message(cast(MessagePayload, payload))

if server := message.server:
await self._wait_for_server_ready(server.id)

self.dispatch("message", message)

async def handle_messageupdate(self, payload: MessageUpdateEventPayload):
Expand All @@ -140,6 +150,9 @@ async def handle_messageupdate(self, payload: MessageUpdateEventPayload):

message._update(**kwargs)

if server := message.server:
await self._wait_for_server_ready(server.id)

self.dispatch("message_update", message)

async def handle_messagedelete(self, payload: MessageDeleteEventPayload):
Expand All @@ -151,11 +164,18 @@ async def handle_messagedelete(self, payload: MessageDeleteEventPayload):
return

self.state.messages.remove(message)

if server := message.server:
await self._wait_for_server_ready(server.id)

self.dispatch("message_delete", message)

async def handle_channelcreate(self, payload: ChannelCreateEventPayload):
channel = self.state.add_channel(payload)

if server := channel.server:
await self._wait_for_server_ready(server.id)

self.dispatch("channel_create", channel)

async def handle_channelupdate(self, payload: ChannelUpdateEventPayload):
Expand All @@ -174,23 +194,35 @@ async def handle_channelupdate(self, payload: ChannelUpdateEventPayload):
if isinstance(channel, (TextChannel, VoiceChannel, GroupDMChannel)):
channel.description = None

if server := channel.server:
await self._wait_for_server_ready(server.id)

self.dispatch("channel_update", old_channel, channel)

async def handle_channeldelete(self, payload: ChannelDeleteEventPayload):
channel = self.state.channels.pop(payload["id"])

if server := channel.server:
await self._wait_for_server_ready(server.id)

self.dispatch("channel_delete", channel)

async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload):
channel = self.state.get_channel(payload["id"])
user = self.state.get_user(payload["user"])

if server := channel.server:
await self._wait_for_server_ready(server.id)

self.dispatch("typing_start", channel, user)

async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload):
channel = self.state.get_channel(payload["id"])
user = self.state.get_user(payload["user"])

if server := channel.server:
await self._wait_for_server_ready(server.id)

self.dispatch("typing_stop", channel, user)

async def handle_serverupdate(self, payload: ServerUpdateEventPayload):
Expand All @@ -210,6 +242,8 @@ async def handle_serverupdate(self, payload: ServerUpdateEventPayload):
elif clear == "Description":
server.description = None

await self._wait_for_server_ready(server.id)

self.dispatch("server_update", old_server, server)

async def handle_serverdelete(self, payload: ServerDeleteEventPayload):
Expand All @@ -218,9 +252,26 @@ async def handle_serverdelete(self, payload: ServerDeleteEventPayload):
for channel in server.channels:
del self.state.channels[channel.id]

await self._wait_for_server_ready(server.id)

self.dispatch("server_delete", server)

async def handle_servercreate(self, payload: ServerCreateEventPayload):
for channel in payload["channels"]:
self.state.add_channel(channel)

server = self.state.add_server(payload["server"])

# lock all server events until we fetch all the members, otherwise the cache will be incomplete
self.server_events[server.id] = asyncio.Event()
await self.state.fetch_server_members(server.id)
self.server_events.pop(server.id).set()

self.dispatch("server_create", server)

async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload):
await self._wait_for_server_ready(payload["id"]["server"])

member = self.state.get_member(payload["id"]["server"], payload["id"]["user"])
old_member = copy(member)

Expand All @@ -239,9 +290,16 @@ async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload):
self.dispatch("member_join", member)

async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload):
await self._wait_for_server_ready(payload["id"])

server = self.state.get_server(payload["id"])
member = server._members.pop(payload["user"])

# remove the member from the user

user = self.state.get_user(payload["user"])
user._members.remove(member)

self.dispatch("member_leave", member)

async def handle_serveroleupdate(self, payload: ServerRoleUpdateEventPayload):
Expand All @@ -255,12 +313,16 @@ async def handle_serveroleupdate(self, payload: ServerRoleUpdateEventPayload):

role._update(**payload["data"])

await self._wait_for_server_ready(server.id)

self.dispatch("role_update", old_role, role)

async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload):
server = self.state.get_server(payload["id"])
role = server._roles.pop(payload["role_id"])

await self._wait_for_server_ready(server.id)

self.dispatch("role_delete", role)

async def handle_userupdate(self, payload: UserUpdateEventPayload):
Expand Down

0 comments on commit 1ab9165

Please sign in to comment.