Skip to content

Commit

Permalink
Update buttons logic (add ButtonsInfo)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlberLC committed Jun 19, 2022
1 parent 94eb13e commit 6cd4dad
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 119 deletions.
43 changes: 24 additions & 19 deletions multibot/bots/discord_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
import io
import pathlib
import random
from typing import Iterable
from typing import Any, Iterable

import discord
import flanautils
from discord.ext.commands import Bot
from discord.ui import Button, View
from flanautils import Media, MediaType, NotFoundError, OrderedSet, return_if_first_empty

from multibot import constants
from multibot.bots.multi_bot import MultiBot, parse_arguments
from multibot.exceptions import LimitError, SendError, UserDisconnectedError
from multibot.models import Chat, Message, Mute, Platform, Role, User
from multibot.models import Button, ButtonsInfo, Chat, Message, Mute, Platform, Role, User


# --------------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -106,7 +105,7 @@ async def _get_button_pressed_text(self, event: constants.DISCORD_EVENT) -> str
return button.label

@return_if_first_empty(exclude_self_types='DiscordBot', globals_=globals())
async def _get_button_pressed_user(self, event: constants.DISCORD_EVENT) -> User | None:
async def _get_button_presser_user(self, event: constants.DISCORD_EVENT) -> User | None:
try:
return self._create_user_from_discord_user(event.user)
except AttributeError:
Expand Down Expand Up @@ -419,42 +418,44 @@ async def send(
self,
text='',
media: Media = None,
buttons: list[str | list[str]] | None = None,
buttons: list[str | tuple[str, bool] | Button | list[str | tuple[str, bool] | Button]] | None = None,
chat: int | str | User | Chat | Message | None = None,
message: Message = None,
*,
buttons_key: Any = None,
reply_to: int | str | Message = None,
silent: bool = False,
send_as_file: bool = None,
edit=False
) -> Message | None:
def create_view() -> View | None:
if not buttons:
return
text = self._parse_html_to_discord_markdown(text)
file = await self._prepare_media_to_send(media)
view = None

view_ = View(timeout=None)
if buttons:
view = discord.ui.View(timeout=None)
for i, row in enumerate(buttons):
for button_text in row:
discord_button = Button(label=button_text, row=i)
for button in row:
discord_button = discord.ui.Button(label=button.text, row=i)
discord_button.callback = self._on_button_press_raw
view_.add_item(discord_button)

return view_

text = self._parse_html_to_discord_markdown(text)
file = await self._prepare_media_to_send(media)
view.add_item(discord_button)

if edit:
kwargs = {}
if file:
kwargs['attachments'] = [file]
if buttons is not None:
kwargs['view'] = create_view()
kwargs['view'] = view
message.buttons_info.buttons = buttons
if buttons_key is not None:
message.buttons_info.key = buttons_key

message.original_object = await message.original_object.edit(content=text, **kwargs)
if content := getattr(media, 'content', None):
message.contents = [content]
message.update_last_edit()
message.save()

return message

match reply_to:
Expand All @@ -466,7 +467,7 @@ def create_view() -> View | None:
reply_to = message_to_reply.original_object

try:
bot_message = await self._get_message(await chat.original_object.send(text, file=file, view=create_view(), reference=reply_to))
bot_message = await self._get_message(await chat.original_object.send(text, file=file, view=view, reference=reply_to))
except discord.errors.HTTPException as e:
if 'too large' in str(e).lower():
if random.randint(0, 10):
Expand All @@ -476,9 +477,13 @@ def create_view() -> View | None:
await self._manage_exceptions(SendError(error_message), chat)
return
raise e

bot_message.buttons_info = ButtonsInfo(buttons=buttons, key=buttons_key)
if content := getattr(media, 'content', None):
bot_message.contents = [content]

bot_message.save()

return bot_message

def start(self):
Expand Down
126 changes: 75 additions & 51 deletions multibot/bots/multi_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
import telethon.events
import telethon.events.common
from flanautils import AmbiguityError, Media, NotFoundError, OrderedSet, RatioMatch, return_if_first_empty, shift_args_if_called

from multibot import constants
from multibot.exceptions import LimitError, SendError
from multibot.models import Ban, BotAction, Chat, Message, Mute, Platform, RegisteredCallback, Role, User
from multibot.models import Ban, BotAction, Button, ButtonsGroup, Chat, Message, Mute, Platform, RegisteredButtonCallback, RegisteredCallback, Role, User


# ---------------------------------------------------------- #
Expand Down Expand Up @@ -145,10 +144,36 @@ async def wrapper(self: MultiBot, message: Message, *_args, **_kwargs):
def parse_arguments(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any:
def parse_buttons(buttons_) -> list[list[Button]] | None:
match buttons_:
case [str(), *_] as buttons_:
buttons_ = [[Button(button_text, False) for button_text in buttons_]]
case [(str(), bool()), *_] as buttons_:
buttons_ = [[Button(button_text, is_checked) for button_text, is_checked in buttons_]]
case [Button(), *_] as buttons_:
buttons_ = [list(buttons_)]
case [[str(), *_], *_] as buttons_:
buttons_ = list(buttons_)
for i, buttons_row in enumerate(buttons_):
buttons_[i] = [Button(button_text, False) for button_text in buttons_row]
case [[(str(), bool()), *_], *_] as buttons_:
buttons_ = list(buttons_)
for i, buttons_row in enumerate(buttons_):
buttons_[i] = [Button(button_text, is_checked) for button_text, is_checked in buttons_row]
case [[Button(), *_], *_] as buttons_:
buttons_ = list(buttons_)
for i, buttons_row in enumerate(buttons_):
buttons_[i] = list(buttons_row)
case [] as buttons_:
pass
case _:
return
return buttons_

self: MultiBot | None = None
text = ''
media: Media | None = None
buttons: list[str | list[str]] | None = None
buttons: list[str | tuple[str, bool] | Button | list[str | tuple[str, bool] | Button]] | None = None
chat: int | str | User | Chat | Message | None = None
message: Message | None = None

Expand All @@ -162,20 +187,18 @@ async def wrapper(*args, **kwargs) -> Any:
text = str(number)
case Media() as media:
pass
case [str(), *_] as buttons:
buttons = [list(buttons)]
case [[str(), *_], *_] as buttons:
buttons = list(buttons)
for i, buttons_row in enumerate(buttons):
buttons[i] = list(buttons_row)
case User() as user:
chat = user
case Chat() as chat:
pass
case Message() as message:
pass
case _:
buttons = parse_buttons(arg)

chat = await self.get_chat(kwargs.get('chat', chat))
if 'buttons' in kwargs:
buttons = parse_buttons(kwargs['buttons'])
reply_to = kwargs.get('reply_to', None)
edit = kwargs.get('edit', None)

Expand All @@ -194,9 +217,10 @@ async def wrapper(*args, **kwargs) -> Any:
if not message and isinstance(reply_to, Message):
message = reply_to

for arg_name in ('self', 'text', 'media', 'buttons', 'message'):
for arg_name in ('self', 'text', 'media', 'message'):
if arg_name not in kwargs:
kwargs[arg_name] = locals()[arg_name]
kwargs['buttons'] = buttons
kwargs['chat'] = chat

return await func(**kwargs)
Expand Down Expand Up @@ -239,7 +263,7 @@ def __init__(self, bot_token: str, bot_client: T):
self.token: str = bot_token
self.client: T = bot_client
self._registered_callbacks: list[RegisteredCallback] = []
self._registered_button_callbacks: list[RegisteredCallback] = []
self._registered_button_callbacks: list[RegisteredButtonCallback] = []

self._add_handlers()

Expand Down Expand Up @@ -268,7 +292,7 @@ def _add_handlers(self):

self.register(self._on_users, constants.KEYWORDS['user'])

self.register_button(self._on_users_button_press, always=True)
self.register_button(self._on_users_button_press, ButtonsGroup.USERS)

async def _ban(self, user: int | str | User, group_: int | str | Chat | Message, message: Message = None):
pass
Expand Down Expand Up @@ -306,7 +330,7 @@ async def _get_button_pressed_text(self, event: constants.MESSAGE_EVENT) -> str
pass

@return_if_first_empty(exclude_self_types='MultiBot', globals_=globals())
async def _get_button_pressed_user(self, event: constants.MESSAGE_EVENT) -> User | None:
async def _get_button_presser_user(self, event: constants.MESSAGE_EVENT) -> User | None:
pass

@abstractmethod
Expand All @@ -328,15 +352,18 @@ async def _get_message(self, event: constants.MESSAGE_EVENT) -> Message:
id=await self._get_message_id(original_message),
author=await self._get_author(original_message),
text=await self._get_text(original_message),
button_pressed_text=await self._get_button_pressed_text(event),
button_pressed_user=await self._get_button_pressed_user(event),
mentions=await self._get_mentions(original_message),
chat=await self._get_chat(original_message),
replied_message=await self._get_replied_message(original_message),
is_inline=isinstance(event, telethon.events.InlineQuery.Event) if isinstance(event, constants.TELEGRAM_EVENT | constants.TELEGRAM_MESSAGE) else None,
original_object=original_message,
original_event=event
)
message.resolve()
message.pull_from_database()
if message.buttons_info:
message.buttons_info.pressed_text = await self._get_button_pressed_text(event)
message.buttons_info.presser_user = await self._get_button_presser_user(event)
message.save(pull_overwrite_fields=('_id', 'config'))

return message
Expand Down Expand Up @@ -463,12 +490,11 @@ async def _on_ban(self, message: Message):

@find_message
async def _on_button_press_raw(self, message: Message):
try:
registered_callbacks = self._parse_callbacks(message.button_pressed_text, self._registered_button_callbacks)
except AmbiguityError as e:
await self._manage_exceptions(e, message)
else:
for registered_callback in registered_callbacks:
if getattr(message.buttons_info, 'key', None) is None:
return

for registered_callback in self._registered_button_callbacks:
if registered_callback.key == message.buttons_info.key:
await registered_callback(message)

@inline(False)
Expand Down Expand Up @@ -539,45 +565,30 @@ async def _on_users(self, message: Message):
f"{joined_user_names}\n\n"
f"<b>Filtrar usuarios por roles:<b>",
flanautils.chunks([f'❌ {role_name}' for role_name in role_names], 5),
message
message,
buttons_key=ButtonsGroup.USERS
)

async def _on_users_button_press(self, message: Message):
await self._accept_button_event(message)

try:
button_role_name = message.button_pressed_text.split(maxsplit=1)[1]
button_role_name = message.buttons_info.pressed_text.split(maxsplit=1)[1]
except IndexError:
return
if not (role_names := [role.name for role in await self.get_roles(message.chat)]) or button_role_name not in role_names:
return

await self._accept_button_event(message)

new_buttons = []
selected_role_names = []
for row in message.original_object.components:
new_row = []
for button in row.children:
emoji, role_name = button.label.split(maxsplit=1)
if role_name == button_role_name:
if emoji == '✔':
new_emoji = '❌'
else:
new_emoji = '✔'
selected_role_names.append(role_name)
new_row.append(f'{new_emoji} {role_name}')
else:
if emoji == '✔':
selected_role_names.append(role_name)
new_row.append(button.label)
new_buttons.append(new_row)
pressed_button = message.buttons_info[message.buttons_info.pressed_text]
pressed_button.is_checked = not pressed_button.is_checked
pressed_button.text = f"{'✔' if pressed_button.is_checked else '❌'} {button_role_name}"

await self.edit(new_buttons, message)
selected_role_names = [checked_button.text.split(maxsplit=1)[1] for checked_button in message.buttons_info.checked_buttons()]
user_names = [f'<@{user.id}>' for user in await self.find_users_by_roles(selected_role_names, message)]
joined_user_names = ', '.join(user_names)
await self.edit(
f"<b>{len(user_names)} usuario{'' if len(user_names) == 1 else 's'}:<b>\n"
f"{joined_user_names}\n\n"
f"<b>Filtrar usuarios por roles:<b>",
message.buttons_info.buttons,
message
)

Expand Down Expand Up @@ -731,24 +742,37 @@ def decorator(func):
return decorator(func_) if func_ else decorator

@overload
def register_button(self, func_: Callable = None, keywords=(), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False):
def register_button(self, func_: Callable = None, key: Any = None):
pass

@overload
def register_button(self, keywords=(), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False):
def register_button(self, key: Any = None):
pass

@shift_args_if_called(exclude_self_types='MultiBot', globals_=globals())
def register_button(self, func_: Callable = None, keywords: str | Iterable[str | Iterable[str]] = (), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False):
def register_button(self, func_: Callable = None, key: Any = None):
def decorator(func):
self._registered_button_callbacks.append(RegisteredCallback(func, keywords, min_ratio, always, default))
self._registered_button_callbacks.append(RegisteredButtonCallback(func, key))
return func

return decorator(func_) if func_ else decorator

@abstractmethod
@parse_arguments
async def send(self, text='', media: Media = None, buttons: list[str | list[str]] | None = None, chat: int | str | User | Chat | Message | None = None, message: Message = None, *, reply_to: int | str | Message = None, silent: bool = False, send_as_file: bool = None, edit=False) -> Message | None:
async def send(
self,
text='',
media: Media = None,
buttons: list[str | tuple[str, bool] | list[str | tuple[str, bool]]] | None = None,
chat: int | str | User | Chat | Message | None = None,
message: Message = None,
*,
buttons_key: Any = None,
reply_to: int | str | Message = None,
silent: bool = False,
send_as_file: bool = None,
edit=False
) -> Message | None:
pass

@parse_arguments
Expand Down
Loading

0 comments on commit 6cd4dad

Please sign in to comment.