diff --git a/app/handlers/base.py b/app/handlers/base.py index 1819879d..6d89699a 100644 --- a/app/handlers/base.py +++ b/app/handlers/base.py @@ -4,6 +4,7 @@ from aiogram.utils.markdown import hbold, hpre from app.infrastructure.database.models import Chat +from app.infrastructure.database.repo.chat import ChatRepo from app.utils.log import Logger logger = Logger(__name__) @@ -83,9 +84,9 @@ async def cancel_state(message: types.Message, state: FSMContext): @router.message(F.message.content_types == types.ContentType.MIGRATE_TO_CHAT_ID) -async def chat_migrate(message: types.Message, chat: Chat): +async def chat_migrate(message: types.Message, chat: Chat, chat_repo: ChatRepo): old_id = message.chat.id new_id = message.migrate_to_chat_id chat.chat_id = new_id - await chat.save() + await chat_repo.update(chat) logger.info(f"Migrate chat from {old_id} to {new_id}") diff --git a/app/handlers/karma.py b/app/handlers/karma.py index d01e7234..12e37f12 100644 --- a/app/handlers/karma.py +++ b/app/handlers/karma.py @@ -5,6 +5,7 @@ from aiogram.utils.text_decorations import html_decoration as hd from app.infrastructure.database.models import Chat, User +from app.infrastructure.database.repo.chat import ChatRepo from app.models.config import Config from app.services.karma import get_me_chat_info, get_me_info from app.services.karma import get_top as get_karma_top @@ -16,10 +17,10 @@ @router.message(Command("top", prefix="!"), F.chat.type == "private") -async def get_top_from_private(message: types.Message, user: User): +async def get_top_from_private(message: types.Message, user: User, chat_repo: ChatRepo): parts = message.text.split(maxsplit=1) if len(parts) > 1: - chat = await Chat.get(chat_id=int(parts[1])) + chat = await chat_repo.get_by_id(chat_id=int(parts[1])) else: return await message.reply( "Эту команду можно использовать только в группах " diff --git a/app/infrastructure/database/repo/__init__.py b/app/infrastructure/database/repo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/infrastructure/database/repo/chat.py b/app/infrastructure/database/repo/chat.py new file mode 100644 index 00000000..049db8a5 --- /dev/null +++ b/app/infrastructure/database/repo/chat.py @@ -0,0 +1,100 @@ +from collections import namedtuple + +from tortoise import BaseDBAsyncClient +from tortoise.exceptions import DoesNotExist + +from app.infrastructure.database.models import UserKarma +from app.infrastructure.database.models.chat import Chat +from app.infrastructure.database.models.user import User +from app.models.db.db import karma_filters +from app.utils.exceptions import NotHaveNeighbours + +TopResultEntry = namedtuple("TopResult", ("user", "karma")) +Neighbours = namedtuple("Neighbours", ("prev_id", "next_id")) + + +class ChatRepo: + def __init__(self, session: BaseDBAsyncClient | None = None): + self.session = session + + async def get_by_id(self, chat_id: int) -> Chat: + return await Chat.get(chat_id=chat_id, using_db=self.session) + + async def update(self, chat: Chat): + await chat.save(using_db=self.session) + + async def create_from_tg_chat(self, chat) -> Chat: + chat = await Chat.create( + chat_id=chat.id, + type_=chat.type, + title=chat.title, + username=chat.username, + using_db=self.session, + ) + return chat + + async def get_or_create_from_tg_chat(self, chat) -> Chat: + try: + chat = await Chat.get(chat_id=chat.id) + except DoesNotExist: + chat = await self.create_from_tg_chat(chat=chat) + return chat + + async def get_top_karma_list( + self, chat: Chat, limit: int = 15 + ) -> list[TopResultEntry]: + await chat.fetch_related("user_karma", using_db=self.session) + users_karmas = ( + await chat.user_karma.order_by(*karma_filters) + .limit(limit) + .prefetch_related("user") + .all() + ) + rez = [] + for user_karma in users_karmas: + user = user_karma.user + karma = user_karma.karma_round + rez.append(TopResultEntry(user, karma)) + + return rez + + async def get_neighbours( + self, user: User, chat: Chat + ) -> tuple[UserKarma, UserKarma, UserKarma]: + prev_id, next_id = await self.get_neighbours_id(chat.chat_id, user.id) + uk = ( + await chat.user_karma.filter(user_id__in=(prev_id, next_id)) + .prefetch_related("user") + .order_by(*karma_filters) + .all() + ) + + user_uk = ( + await chat.user_karma.filter(user=user).prefetch_related("user").first() + ) + return uk[0], user_uk, uk[1] + + async def get_neighbours_id(self, chat_id, user_id) -> Neighbours: + neighbours = await self.session.execute_query( + query=""" + SELECT prev_user_id, next_user_id + FROM ( + SELECT + user_id, + LAG(user_id) OVER(ORDER BY karma) prev_user_id, + LEAD(user_id) OVER(ORDER BY karma) next_user_id + FROM user_karma + WHERE chat_id = ? + ) + WHERE user_id = ?""", + values=[chat_id, user_id], + ) + try: + rez = dict(neighbours[1][0]) + except IndexError: + raise NotHaveNeighbours + try: + rez = int(rez["prev_user_id"]), int(rez["next_user_id"]) + except TypeError: + raise NotHaveNeighbours + return rez diff --git a/app/middlewares/db_middleware.py b/app/middlewares/db_middleware.py index 1c6fdadd..ba74e19e 100644 --- a/app/middlewares/db_middleware.py +++ b/app/middlewares/db_middleware.py @@ -5,8 +5,11 @@ from aiogram import BaseMiddleware, types from aiogram.dispatcher.event.bases import CancelHandler from aiogram.types import TelegramObject +from tortoise import BaseDBAsyncClient +from tortoise.transactions import in_transaction -from app.infrastructure.database.models import Chat, User +from app.infrastructure.database.models import User +from app.infrastructure.database.repo.chat import ChatRepo from app.services.settings import get_chat_settings from app.utils.lock_factory import LockFactory from app.utils.log import Logger @@ -29,18 +32,27 @@ async def __call__( user: types.User = data.get("event_from_user", None) if isinstance(event, types.Message) and event.sender_chat: raise CancelHandler - await self.setup_chat(data, user, chat) - return await handler(event, data) + + async with in_transaction() as session: + await self.setup_chat(session, data, user, chat) + return await handler(event, data) async def setup_chat( - self, data: dict, user: types.User, chat: Optional[types.Chat] = None + self, + session: BaseDBAsyncClient, + data: dict, + user: types.User, + chat: Optional[types.Chat] = None, ): try: + chat_repo = ChatRepo(session) + async with self.lock_factory.get_lock(user.id): user = await User.get_or_create_from_tg_user(user) + if chat and chat.type != "private": async with self.lock_factory.get_lock(chat.id): - chat = await Chat.get_or_create_from_tg_chat(chat) + chat = await chat_repo.get_or_create_from_tg_chat(chat) data["chat_settings"] = await get_chat_settings(chat=chat) except Exception as e: @@ -48,3 +60,4 @@ async def setup_chat( raise e data["user"] = user data["chat"] = chat + data["chat_repo"] = chat_repo