diff --git a/bot/run.py b/bot/run.py index 03994e3..53ee099 100644 --- a/bot/run.py +++ b/bot/run.py @@ -4,11 +4,13 @@ from aiogram.types import Message from aiogram.utils.keyboard import InlineKeyboardBuilder from func.functions import * + # Other import asyncio import traceback import io import base64 + bot = Bot(token=token) dp = Dispatcher() builder = InlineKeyboardBuilder() @@ -33,15 +35,19 @@ CHAT_TYPE_GROUP = "group" CHAT_TYPE_SUPERGROUP = "supergroup" + def is_mentioned_in_group_or_supergroup(message): - return (message.chat.type in [CHAT_TYPE_GROUP, CHAT_TYPE_SUPERGROUP] - and message.text.startswith(mention)) + return message.chat.type in [CHAT_TYPE_GROUP, CHAT_TYPE_SUPERGROUP] and ( + (message.text is not None and message.text.startswith(mention)) + or (message.caption is not None and message.caption.startswith(mention)) + ) + async def get_bot_info(): global mention if mention is None: get = await bot.get_me() - mention = (f"@{get.username}") + mention = f"@{get.username}" return mention @@ -102,7 +108,9 @@ async def modelmanager_callback_handler(query: types.CallbackQuery): if model["details"]["families"]: modelicon = {"llama": "šŸ¦™", "clip": "šŸ“·"} try: - modelfamilies = "".join([modelicon[family] for family in model['details']['families']]) + modelfamilies = "".join( + [modelicon[family] for family in model["details"]["families"]] + ) except KeyError as e: # Use a default value when the key is not found modelfamilies = f"āœØ" @@ -113,7 +121,8 @@ async def modelmanager_callback_handler(query: types.CallbackQuery): ) ) await query.message.edit_text( - f"{len(models)} models available.\nšŸ¦™ = Regular\nšŸ¦™šŸ“· = Multimodal", reply_markup=modelmanager_builder.as_markup() + f"{len(models)} models available.\nšŸ¦™ = Regular\nšŸ¦™šŸ“· = Multimodal", + reply_markup=modelmanager_builder.as_markup(), ) @@ -147,30 +156,32 @@ async def handle_message(message: types.Message): await ollama_request(message) if is_mentioned_in_group_or_supergroup(message): # Remove the mention from the message - text_without_mention = message.text.replace(mention, "").strip() + if message.text is not None: + text_without_mention = message.text.replace(mention, "").strip() + prompt = text_without_mention + else: + text_without_mention = message.caption.replace(mention, "").strip() + prompt = text_without_mention + # Pass the modified text and bot instance to ollama_request - await ollama_request(types.Message( - message_id=message.message_id, - from_user=message.from_user, - date=message.date, - chat=message.chat, - text=text_without_mention - )) + await ollama_request(message, prompt) ... -async def ollama_request(message: types.Message): + + +async def ollama_request(message: types.Message, prompt: str = None): try: await bot.send_chat_action(message.chat.id, "typing") - prompt = message.text or message.caption - image_base64 = '' - if message.content_type == 'photo': + image_base64 = "" + if message.content_type == "photo": image_buffer = io.BytesIO() - await bot.download( - message.photo[-1], - destination=image_buffer - ) - image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8') + await bot.download(message.photo[-1], destination=image_buffer) + image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8") + + if prompt is None: + prompt = message.text or message.caption + full_response = "" sent_message = None last_sent_text = None @@ -180,12 +191,22 @@ async def ollama_request(message: types.Message): if ACTIVE_CHATS.get(message.from_user.id) is None: ACTIVE_CHATS[message.from_user.id] = { "model": modelname, - "messages": [{"role": "user", "content": prompt, "images": ([image_base64] if image_base64 else [])}], + "messages": [ + { + "role": "user", + "content": prompt, + "images": ([image_base64] if image_base64 else []), + } + ], "stream": True, } else: ACTIVE_CHATS[message.from_user.id]["messages"].append( - {"role": "user", "content": prompt, "images": ([image_base64] if image_base64 else [])} + { + "role": "user", + "content": prompt, + "images": ([image_base64] if image_base64 else []), + } ) logging.info( f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}" @@ -206,8 +227,11 @@ async def ollama_request(message: types.Message): if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk: if sent_message: if last_sent_text != full_response_stripped: - await bot.edit_message_text(chat_id=message.chat.id, message_id=sent_message.message_id, - text=full_response_stripped) + await bot.edit_message_text( + chat_id=message.chat.id, + message_id=sent_message.message_id, + text=full_response_stripped, + ) last_sent_text = full_response_stripped else: sent_message = await bot.send_message( @@ -218,16 +242,17 @@ async def ollama_request(message: types.Message): last_sent_text = full_response_stripped if response_data.get("done"): - if ( - full_response_stripped - and last_sent_text != full_response_stripped - ): + if full_response_stripped and last_sent_text != full_response_stripped: if sent_message: - await bot.edit_message_text(chat_id=message.chat.id, message_id=sent_message.message_id, - text=full_response_stripped) + await bot.edit_message_text( + chat_id=message.chat.id, + message_id=sent_message.message_id, + text=full_response_stripped, + ) else: - sent_message = await bot.send_message(chat_id=message.chat.id, - text=full_response_stripped) + sent_message = await bot.send_message( + chat_id=message.chat.id, text=full_response_stripped + ) await bot.edit_message_text( chat_id=message.chat.id, message_id=sent_message.message_id,