diff --git a/multibot/bots/discord_bot.py b/multibot/bots/discord_bot.py index 4378efc..62edb60 100644 --- a/multibot/bots/discord_bot.py +++ b/multibot/bots/discord_bot.py @@ -6,18 +6,17 @@ import io import pathlib import random -from typing import Iterable +from typing import Any, Iterable import discord import flanautils from discord.ext.commands import Bot -from discord.ui import Button, View from flanautils import Media, MediaType, NotFoundError, OrderedSet, return_if_first_empty from multibot import constants from multibot.bots.multi_bot import MultiBot, parse_arguments from multibot.exceptions import LimitError, SendError, UserDisconnectedError -from multibot.models import Chat, Message, Mute, Platform, Role, User +from multibot.models import Button, ButtonsInfo, Chat, Message, Mute, Platform, Role, User # --------------------------------------------------------------------------------------------------- # @@ -106,7 +105,7 @@ async def _get_button_pressed_text(self, event: constants.DISCORD_EVENT) -> str return button.label @return_if_first_empty(exclude_self_types='DiscordBot', globals_=globals()) - async def _get_button_pressed_user(self, event: constants.DISCORD_EVENT) -> User | None: + async def _get_button_presser_user(self, event: constants.DISCORD_EVENT) -> User | None: try: return self._create_user_from_discord_user(event.user) except AttributeError: @@ -419,42 +418,44 @@ async def send( self, text='', media: Media = None, - buttons: list[str | list[str]] | None = None, + buttons: list[str | tuple[str, bool] | Button | list[str | tuple[str, bool] | Button]] | None = None, chat: int | str | User | Chat | Message | None = None, message: Message = None, *, + buttons_key: Any = None, reply_to: int | str | Message = None, silent: bool = False, send_as_file: bool = None, edit=False ) -> Message | None: - def create_view() -> View | None: - if not buttons: - return + text = self._parse_html_to_discord_markdown(text) + file = await self._prepare_media_to_send(media) + view = None - view_ = View(timeout=None) + if buttons: + view = discord.ui.View(timeout=None) for i, row in enumerate(buttons): - for button_text in row: - discord_button = Button(label=button_text, row=i) + for button in row: + discord_button = discord.ui.Button(label=button.text, row=i) discord_button.callback = self._on_button_press_raw - view_.add_item(discord_button) - - return view_ - - text = self._parse_html_to_discord_markdown(text) - file = await self._prepare_media_to_send(media) + view.add_item(discord_button) if edit: kwargs = {} if file: kwargs['attachments'] = [file] if buttons is not None: - kwargs['view'] = create_view() + kwargs['view'] = view + message.buttons_info.buttons = buttons + if buttons_key is not None: + message.buttons_info.key = buttons_key + message.original_object = await message.original_object.edit(content=text, **kwargs) if content := getattr(media, 'content', None): message.contents = [content] message.update_last_edit() message.save() + return message match reply_to: @@ -466,7 +467,7 @@ def create_view() -> View | None: reply_to = message_to_reply.original_object try: - bot_message = await self._get_message(await chat.original_object.send(text, file=file, view=create_view(), reference=reply_to)) + bot_message = await self._get_message(await chat.original_object.send(text, file=file, view=view, reference=reply_to)) except discord.errors.HTTPException as e: if 'too large' in str(e).lower(): if random.randint(0, 10): @@ -476,9 +477,13 @@ def create_view() -> View | None: await self._manage_exceptions(SendError(error_message), chat) return raise e + + bot_message.buttons_info = ButtonsInfo(buttons=buttons, key=buttons_key) if content := getattr(media, 'content', None): bot_message.contents = [content] + bot_message.save() + return bot_message def start(self): diff --git a/multibot/bots/multi_bot.py b/multibot/bots/multi_bot.py index 2e34095..1d53a38 100644 --- a/multibot/bots/multi_bot.py +++ b/multibot/bots/multi_bot.py @@ -25,10 +25,9 @@ import telethon.events import telethon.events.common from flanautils import AmbiguityError, Media, NotFoundError, OrderedSet, RatioMatch, return_if_first_empty, shift_args_if_called - from multibot import constants from multibot.exceptions import LimitError, SendError -from multibot.models import Ban, BotAction, Chat, Message, Mute, Platform, RegisteredCallback, Role, User +from multibot.models import Ban, BotAction, Button, ButtonsGroup, Chat, Message, Mute, Platform, RegisteredButtonCallback, RegisteredCallback, Role, User # ---------------------------------------------------------- # @@ -145,10 +144,36 @@ async def wrapper(self: MultiBot, message: Message, *_args, **_kwargs): def parse_arguments(func: Callable) -> Callable: @functools.wraps(func) async def wrapper(*args, **kwargs) -> Any: + def parse_buttons(buttons_) -> list[list[Button]] | None: + match buttons_: + case [str(), *_] as buttons_: + buttons_ = [[Button(button_text, False) for button_text in buttons_]] + case [(str(), bool()), *_] as buttons_: + buttons_ = [[Button(button_text, is_checked) for button_text, is_checked in buttons_]] + case [Button(), *_] as buttons_: + buttons_ = [list(buttons_)] + case [[str(), *_], *_] as buttons_: + buttons_ = list(buttons_) + for i, buttons_row in enumerate(buttons_): + buttons_[i] = [Button(button_text, False) for button_text in buttons_row] + case [[(str(), bool()), *_], *_] as buttons_: + buttons_ = list(buttons_) + for i, buttons_row in enumerate(buttons_): + buttons_[i] = [Button(button_text, is_checked) for button_text, is_checked in buttons_row] + case [[Button(), *_], *_] as buttons_: + buttons_ = list(buttons_) + for i, buttons_row in enumerate(buttons_): + buttons_[i] = list(buttons_row) + case [] as buttons_: + pass + case _: + return + return buttons_ + self: MultiBot | None = None text = '' media: Media | None = None - buttons: list[str | list[str]] | None = None + buttons: list[str | tuple[str, bool] | Button | list[str | tuple[str, bool] | Button]] | None = None chat: int | str | User | Chat | Message | None = None message: Message | None = None @@ -162,20 +187,18 @@ async def wrapper(*args, **kwargs) -> Any: text = str(number) case Media() as media: pass - case [str(), *_] as buttons: - buttons = [list(buttons)] - case [[str(), *_], *_] as buttons: - buttons = list(buttons) - for i, buttons_row in enumerate(buttons): - buttons[i] = list(buttons_row) case User() as user: chat = user case Chat() as chat: pass case Message() as message: pass + case _: + buttons = parse_buttons(arg) chat = await self.get_chat(kwargs.get('chat', chat)) + if 'buttons' in kwargs: + buttons = parse_buttons(kwargs['buttons']) reply_to = kwargs.get('reply_to', None) edit = kwargs.get('edit', None) @@ -194,9 +217,10 @@ async def wrapper(*args, **kwargs) -> Any: if not message and isinstance(reply_to, Message): message = reply_to - for arg_name in ('self', 'text', 'media', 'buttons', 'message'): + for arg_name in ('self', 'text', 'media', 'message'): if arg_name not in kwargs: kwargs[arg_name] = locals()[arg_name] + kwargs['buttons'] = buttons kwargs['chat'] = chat return await func(**kwargs) @@ -239,7 +263,7 @@ def __init__(self, bot_token: str, bot_client: T): self.token: str = bot_token self.client: T = bot_client self._registered_callbacks: list[RegisteredCallback] = [] - self._registered_button_callbacks: list[RegisteredCallback] = [] + self._registered_button_callbacks: list[RegisteredButtonCallback] = [] self._add_handlers() @@ -268,7 +292,7 @@ def _add_handlers(self): self.register(self._on_users, constants.KEYWORDS['user']) - self.register_button(self._on_users_button_press, always=True) + self.register_button(self._on_users_button_press, ButtonsGroup.USERS) async def _ban(self, user: int | str | User, group_: int | str | Chat | Message, message: Message = None): pass @@ -306,7 +330,7 @@ async def _get_button_pressed_text(self, event: constants.MESSAGE_EVENT) -> str pass @return_if_first_empty(exclude_self_types='MultiBot', globals_=globals()) - async def _get_button_pressed_user(self, event: constants.MESSAGE_EVENT) -> User | None: + async def _get_button_presser_user(self, event: constants.MESSAGE_EVENT) -> User | None: pass @abstractmethod @@ -328,8 +352,6 @@ async def _get_message(self, event: constants.MESSAGE_EVENT) -> Message: id=await self._get_message_id(original_message), author=await self._get_author(original_message), text=await self._get_text(original_message), - button_pressed_text=await self._get_button_pressed_text(event), - button_pressed_user=await self._get_button_pressed_user(event), mentions=await self._get_mentions(original_message), chat=await self._get_chat(original_message), replied_message=await self._get_replied_message(original_message), @@ -337,6 +359,11 @@ async def _get_message(self, event: constants.MESSAGE_EVENT) -> Message: original_object=original_message, original_event=event ) + message.resolve() + message.pull_from_database() + if message.buttons_info: + message.buttons_info.pressed_text = await self._get_button_pressed_text(event) + message.buttons_info.presser_user = await self._get_button_presser_user(event) message.save(pull_overwrite_fields=('_id', 'config')) return message @@ -463,12 +490,11 @@ async def _on_ban(self, message: Message): @find_message async def _on_button_press_raw(self, message: Message): - try: - registered_callbacks = self._parse_callbacks(message.button_pressed_text, self._registered_button_callbacks) - except AmbiguityError as e: - await self._manage_exceptions(e, message) - else: - for registered_callback in registered_callbacks: + if getattr(message.buttons_info, 'key', None) is None: + return + + for registered_callback in self._registered_button_callbacks: + if registered_callback.key == message.buttons_info.key: await registered_callback(message) @inline(False) @@ -539,45 +565,30 @@ async def _on_users(self, message: Message): f"{joined_user_names}\n\n" f"Filtrar usuarios por roles:", flanautils.chunks([f'❌ {role_name}' for role_name in role_names], 5), - message + message, + buttons_key=ButtonsGroup.USERS ) async def _on_users_button_press(self, message: Message): + await self._accept_button_event(message) + try: - button_role_name = message.button_pressed_text.split(maxsplit=1)[1] + button_role_name = message.buttons_info.pressed_text.split(maxsplit=1)[1] except IndexError: return - if not (role_names := [role.name for role in await self.get_roles(message.chat)]) or button_role_name not in role_names: - return - await self._accept_button_event(message) - - new_buttons = [] - selected_role_names = [] - for row in message.original_object.components: - new_row = [] - for button in row.children: - emoji, role_name = button.label.split(maxsplit=1) - if role_name == button_role_name: - if emoji == '✔': - new_emoji = '❌' - else: - new_emoji = '✔' - selected_role_names.append(role_name) - new_row.append(f'{new_emoji} {role_name}') - else: - if emoji == '✔': - selected_role_names.append(role_name) - new_row.append(button.label) - new_buttons.append(new_row) + pressed_button = message.buttons_info[message.buttons_info.pressed_text] + pressed_button.is_checked = not pressed_button.is_checked + pressed_button.text = f"{'✔' if pressed_button.is_checked else '❌'} {button_role_name}" - await self.edit(new_buttons, message) + selected_role_names = [checked_button.text.split(maxsplit=1)[1] for checked_button in message.buttons_info.checked_buttons()] user_names = [f'<@{user.id}>' for user in await self.find_users_by_roles(selected_role_names, message)] joined_user_names = ', '.join(user_names) await self.edit( f"{len(user_names)} usuario{'' if len(user_names) == 1 else 's'}:\n" f"{joined_user_names}\n\n" f"Filtrar usuarios por roles:", + message.buttons_info.buttons, message ) @@ -731,24 +742,37 @@ def decorator(func): return decorator(func_) if func_ else decorator @overload - def register_button(self, func_: Callable = None, keywords=(), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False): + def register_button(self, func_: Callable = None, key: Any = None): pass @overload - def register_button(self, keywords=(), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False): + def register_button(self, key: Any = None): pass @shift_args_if_called(exclude_self_types='MultiBot', globals_=globals()) - def register_button(self, func_: Callable = None, keywords: str | Iterable[str | Iterable[str]] = (), min_ratio=constants.PARSE_BUTTON_CALLBACKS_MIN_RATIO_DEFAULT, always=False, default=False): + def register_button(self, func_: Callable = None, key: Any = None): def decorator(func): - self._registered_button_callbacks.append(RegisteredCallback(func, keywords, min_ratio, always, default)) + self._registered_button_callbacks.append(RegisteredButtonCallback(func, key)) return func return decorator(func_) if func_ else decorator @abstractmethod @parse_arguments - async def send(self, text='', media: Media = None, buttons: list[str | list[str]] | None = None, chat: int | str | User | Chat | Message | None = None, message: Message = None, *, reply_to: int | str | Message = None, silent: bool = False, send_as_file: bool = None, edit=False) -> Message | None: + async def send( + self, + text='', + media: Media = None, + buttons: list[str | tuple[str, bool] | list[str | tuple[str, bool]]] | None = None, + chat: int | str | User | Chat | Message | None = None, + message: Message = None, + *, + buttons_key: Any = None, + reply_to: int | str | Message = None, + silent: bool = False, + send_as_file: bool = None, + edit=False + ) -> Message | None: pass @parse_arguments diff --git a/multibot/bots/telegram_bot.py b/multibot/bots/telegram_bot.py index 97b7918..13f0354 100644 --- a/multibot/bots/telegram_bot.py +++ b/multibot/bots/telegram_bot.py @@ -20,7 +20,7 @@ from multibot import constants from multibot.bots.multi_bot import MultiBot, find_message, inline, parse_arguments from multibot.exceptions import LimitError -from multibot.models import Chat, Message, Platform, User +from multibot.models import Button, ButtonsInfo, Chat, Message, Platform, User # ---------------------------------------------------------- # @@ -83,12 +83,14 @@ def _add_handlers(self): self.client.add_event_handler(self._on_new_message_raw, telethon.events.NewMessage) @return_if_first_empty(exclude_self_types='TelegramBot', globals_=globals()) - async def _create_bot_message_from_telegram_bot_message(self, original_message: constants.TELEGRAM_MESSAGE, chat: Chat, contents: Any = None) -> Message | None: + async def _create_bot_message_from_telegram_bot_message(self, original_message: constants.TELEGRAM_MESSAGE, chat: Chat, buttons: list[list[Button]], buttons_key: Any = None, contents: Any = None) -> Message | None: original_message._sender = await self.client.get_entity(self.id) original_message._chat = chat.original_object bot_message = await self._get_message(original_message) + bot_message.buttons_info = ButtonsInfo(buttons=buttons, key=buttons_key) bot_message.contents = contents or [] bot_message.save() + return bot_message @return_if_first_empty(exclude_self_types='TelegramBot', globals_=globals()) @@ -139,7 +141,7 @@ async def _get_button_pressed_text(self, event: constants.TELEGRAM_EVENT) -> str pass @return_if_first_empty(exclude_self_types='TelegramBot', globals_=globals()) - async def _get_button_pressed_user(self, event: constants.TELEGRAM_EVENT) -> User | None: + async def _get_button_presser_user(self, event: constants.TELEGRAM_EVENT) -> User | None: return await self._create_user_from_telegram_user(event.sender) @return_if_first_empty(exclude_self_types='TelegramBot', globals_=globals()) @@ -341,30 +343,27 @@ async def send( self, text='', media: Media = None, - buttons: list[str | list[str]] | None = None, + buttons: list[str | tuple[str, bool] | Button | list[str | tuple[str, bool] | Button]] | None = None, chat: int | str | User | Chat | Message | None = None, message: Message = None, *, + buttons_key: Any = None, reply_to: int | str | Message = None, silent: bool = False, send_as_file: bool = None, edit=False, ) -> Message | None: - def create_buttons() -> list[list[str]] | None: - if not buttons: - return + file = await self._prepare_media_to_send(media) + telegram_buttons = None + if buttons: telegram_buttons = [] for row in buttons: telegram_buttons_row = [] - for button_text in row: - telegram_buttons_row.append(telethon.Button.inline(button_text)) + for button in row: + telegram_buttons_row.append(telethon.Button.inline(button.text)) telegram_buttons.append(telegram_buttons_row) - return telegram_buttons - - file = await self._prepare_media_to_send(media) - kwargs = { 'file': file, 'parse_mode': 'html' @@ -384,38 +383,44 @@ def create_buttons() -> list[list[str]] | None: message.contents.append(message.original_event.builder.document(file, title=media.type_.name.title(), type=media.type_.name.lower())) elif edit: if buttons is not None: - kwargs['buttons'] = create_buttons() + kwargs['buttons'] = telegram_buttons + message.buttons_info.buttons = buttons + if buttons_key is not None: + message.buttons_info.key = buttons_key + try: - edited_message = await message.original_object.edit(text, **kwargs) + message.original_object = await message.original_object.edit(text, **kwargs) except ( telethon.errors.rpcerrorlist.PeerIdInvalidError, telethon.errors.rpcerrorlist.UserIsBlockedError, telethon.errors.rpcerrorlist.MessageNotModifiedError ): - pass + return else: - message.original_object = edited_message - if content := getattr(media, 'content', None): - message.contents = [content] - message.update_last_edit() - message.save() - return message - - match reply_to: - case str(): - reply_to = int(reply_to) - case Message() as message_to_reply: - reply_to = message_to_reply.original_object + if content := getattr(media, 'content', None): + message.contents = [content] + message.update_last_edit() + message.save() + + return message + + match reply_to: + case str(): + reply_to = int(reply_to) + case Message() as message_to_reply: + reply_to = message_to_reply.original_object try: - original_message = await self.client.send_message(chat.original_object, text, buttons=create_buttons(), reply_to=reply_to, silent=silent, **kwargs) + original_message = await self.client.send_message(chat.original_object, text, buttons=telegram_buttons, reply_to=reply_to, silent=silent, **kwargs) except (telethon.errors.rpcerrorlist.PeerIdInvalidError, telethon.errors.rpcerrorlist.UserIsBlockedError): return + if content := getattr(media, 'content', None): contents = [content] else: contents = [] - return await self._create_bot_message_from_telegram_bot_message(original_message, chat, contents=contents) + + return await self._create_bot_message_from_telegram_bot_message(original_message, chat, buttons, buttons_key, contents) @inline async def send_inline_results(self, message: Message): diff --git a/multibot/bots/twitch_bot.py b/multibot/bots/twitch_bot.py index acdd718..9b40930 100644 --- a/multibot/bots/twitch_bot.py +++ b/multibot/bots/twitch_bot.py @@ -6,7 +6,7 @@ import datetime import re from collections import defaultdict -from typing import Iterable, Iterator +from typing import Any, Iterable, Iterator import flanautils import pymongo @@ -243,6 +243,7 @@ async def send( chat: int | str | User | Chat | Message | None = None, message: Message = None, *, + buttons_key: Any = None, reply_to: str | Message = None, silent: bool = False, send_as_file: bool = None, diff --git a/multibot/models/__init__.py b/multibot/models/__init__.py index fe667e2..82692bb 100644 --- a/multibot/models/__init__.py +++ b/multibot/models/__init__.py @@ -1,4 +1,5 @@ from multibot.models.bot_action import * +from multibot.models.buttons import * from multibot.models.chat import * from multibot.models.database import * from multibot.models.enums import * diff --git a/multibot/models/buttons.py b/multibot/models/buttons.py new file mode 100644 index 0000000..7f624ea --- /dev/null +++ b/multibot/models/buttons.py @@ -0,0 +1,43 @@ +__all__ = ['Button', 'ButtonsInfo'] + +from dataclasses import dataclass, field +from typing import Any + +from flanautils import FlanaBase + +from multibot.models.user import User + + +@dataclass(eq=False) +class Button(FlanaBase): + text: str = None + is_checked: bool = False + + def _dict_repr(self) -> Any: + return bytes(self) + + +@dataclass(eq=False) +class ButtonsInfo(FlanaBase): + pressed_text: str = None + presser_user: User = None + buttons: list[list[Button]] = field(default_factory=lambda: [[]]) + key: Any = None + + def __getitem__(self, item) -> Button: + if not isinstance(item, str): + raise TypeError('index has to be a string') + + for row in self.buttons: + for button in row: + if button.text == item: + return button + + def _dict_repr(self) -> Any: + return bytes(self) + + def checked_buttons(self) -> list[Button]: + return [button for row in self.buttons for button in row if button.is_checked] + + def find_button(self, text: str) -> Button: + return self[text] diff --git a/multibot/models/enums.py b/multibot/models/enums.py index 12351f2..c3459fe 100644 --- a/multibot/models/enums.py +++ b/multibot/models/enums.py @@ -1,4 +1,4 @@ -__all__ = ['Action', 'Platform'] +__all__ = ['Action', 'ButtonsGroup', 'Platform'] from enum import auto @@ -10,6 +10,13 @@ class Action(FlanaEnum): MESSAGE_DELETED = auto() +class ButtonsGroup(FlanaEnum): + CONFIG = auto() + POLL = auto() + USERS = auto() + WEATHER = auto() + + class Platform(FlanaEnum): DISCORD = auto() TELEGRAM = auto() diff --git a/multibot/models/message.py b/multibot/models/message.py index 11f8515..1ab6053 100644 --- a/multibot/models/message.py +++ b/multibot/models/message.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from multibot import constants +from multibot.models.buttons import ButtonsInfo from multibot.models.chat import Chat from multibot.models.database import db from multibot.models.enums import Platform @@ -23,9 +24,8 @@ class Message(EventComponent): id: int | str = None author: User = None text: str = None - button_pressed_text: str = None - button_pressed_user: User = None mentions: list[User] = field(default_factory=list) + buttons_info: ButtonsInfo = None chat: Chat = None replied_message: Message = None date: datetime.datetime = field(default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)) diff --git a/multibot/models/registered_callback.py b/multibot/models/registered_callback.py index 895baec..efe3c2d 100644 --- a/multibot/models/registered_callback.py +++ b/multibot/models/registered_callback.py @@ -1,4 +1,4 @@ -__all__ = ['RegisteredCallback'] +__all__ = ['RegisteredCallbackBase', 'RegisteredCallback', 'RegisteredButtonCallback'] from dataclasses import dataclass from typing import Callable, Iterable @@ -10,8 +10,24 @@ @dataclass -class RegisteredCallback(FlanaBase): +class RegisteredCallbackBase(FlanaBase): callback: Callable + + def __call__(self, *args, **kwargs): + return self.callback(*args, **kwargs) + + def __eq__(self, other): + if isinstance(other, RegisteredCallback): + return self.callback == other.callback + else: + return self.callback == other + + def __hash__(self): + return hash(self.callback) + + +@dataclass(eq=False) +class RegisteredCallback(RegisteredCallbackBase): keywords: str | Iterable[str | Iterable[str]] min_ratio: float always: bool @@ -39,17 +55,7 @@ def __init__( self.always = always self.default = default - def __post_init__(self): - self.keywords = tuple(self.keywords) - - def __call__(self, *args, **kwargs): - return self.callback(*args, **kwargs) - - def __eq__(self, other): - if isinstance(other, RegisteredCallback): - return self.callback == other.callback - else: - return self.callback == other - def __hash__(self): - return hash(self.callback) +@dataclass(eq=False) +class RegisteredButtonCallback(RegisteredCallbackBase): + key: any diff --git a/multibot/models/user.py b/multibot/models/user.py index ab62d08..46c8a0e 100644 --- a/multibot/models/user.py +++ b/multibot/models/user.py @@ -19,3 +19,6 @@ class User(EventComponent): is_admin: bool = None is_bot: bool = None original_object: constants.ORIGINAL_USER = None + + def __getstate__(self): + return self._mongo_repr()