From acba5dd6aa8c5e71c378e28e1122654af1e49c20 Mon Sep 17 00:00:00 2001 From: Furior Date: Mon, 10 Feb 2025 22:59:34 +0700 Subject: [PATCH] overload the wl route --- app/routes/v1/whitelist.py | 211 ++++++++++--------------------------- app/schemas/whitelist.py | 17 ++- 2 files changed, 69 insertions(+), 159 deletions(-) diff --git a/app/routes/v1/whitelist.py b/app/routes/v1/whitelist.py index 9e56e79..d22f012 100644 --- a/app/routes/v1/whitelist.py +++ b/app/routes/v1/whitelist.py @@ -1,6 +1,5 @@ import datetime import logging -from typing import Callable from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy import func, update @@ -10,10 +9,7 @@ from app.database.models import Player, Whitelist, WhitelistBan from app.deps import BEARER_DEP_RESPONSES, SessionDep, verify_bearer from app.schemas.generic import PaginatedResponse, paginate_selection -from app.schemas.whitelist import (NewWhitelistBanBase, NewWhitelistBanCkey, - NewWhitelistBanDiscord, NewWhitelistBanInternal, NewWhitelistBase, - NewWhitelistCkey, NewWhitelistDiscord, - NewWhitelistInternal, WhitelistPatch) +from app.schemas.whitelist import (NEW_WHITELIST_BAN_TYPES, NEW_WHITELIST_TYPES, WhitelistPatch, resolve_whitelist_type) logger = logging.getLogger(__name__) @@ -29,48 +25,6 @@ def select_only_active_whitelists(selection: SelectOfScalar[Whitelist]): ) -WHITELIST_TYPES_T = NewWhitelistCkey | NewWhitelistDiscord | NewWhitelistInternal - - -async def create_whitelist_helper( - session: SessionDep, - new_wl: NewWhitelistBase, - player_resolver: Callable[[WHITELIST_TYPES_T], bool], - admin_resolver: Callable[[WHITELIST_TYPES_T], bool], - ignore_bans: bool = False -) -> Whitelist: - """Core logic for creating whitelist entries""" - player: Player = session.exec( - select(Player).where(player_resolver(new_wl))).first() - admin: Player = session.exec( - select(Player).where(admin_resolver(new_wl))).first() - - if not player or not admin: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail="Player or admin not found") - - if not ignore_bans and session.exec( - select_only_active_whitelist_bans( - select(WhitelistBan) - .where(WhitelistBan.player_id == player.id) - .where(WhitelistBan.wl_type == new_wl.wl_type) - ) - ).first(): - raise HTTPException(status_code=status.HTTP_409_CONFLICT, - detail="Player is banned from this type of whitelist.") - - wl = Whitelist( - **{**new_wl.model_dump(), "player_id": player.id, "admin_id": admin.id}, - expiration_time=new_wl.get_expiration_time(), - ) - - session.add(wl) - session.commit() - session.refresh(wl) - logger.info("Whitelist created: %s", wl.model_dump_json()) - return wl - - def filter_whitelists(selection: SelectOfScalar[Whitelist], ckey: str | None = None, discord_id: str | None = None, @@ -178,42 +132,33 @@ async def get_whitelisted_ckeys(session: SessionDep, status_code=status.HTTP_201_CREATED, responses=WHITELIST_POST_RESPONSES, dependencies=[Depends(verify_bearer)]) -async def create_whitelist(session: SessionDep, new_wl: NewWhitelistInternal, ignore_bans: bool = False) -> Whitelist: - return await create_whitelist_helper( - session, - new_wl, - lambda d: Player.id == d.player_id, - lambda d: Player.id == d.admin_id, - ignore_bans - ) +async def create_whitelist(session: SessionDep, new_wl: NEW_WHITELIST_TYPES, ignore_bans: bool = False) -> Whitelist: + player_resolver, admin_resolver = resolve_whitelist_type(new_wl) + player = session.exec(select(Player).where(player_resolver(new_wl))).first() + admin = session.exec(select(Player).where(admin_resolver(new_wl))).first() -@whitelist_router.post("/ckey", - status_code=status.HTTP_201_CREATED, - responses=WHITELIST_POST_RESPONSES, - dependencies=[Depends(verify_bearer)]) -async def create_whitelist_by_ckey(session: SessionDep, new_wl: NewWhitelistCkey, ignore_bans: bool = False) -> Whitelist: - return await create_whitelist_helper( - session, - new_wl, - lambda d: Player.ckey == d.player_ckey, - lambda d: Player.ckey == d.admin_ckey, - ignore_bans - ) + if player is None or admin is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, + detail="Player or admin not found") + if not ignore_bans: + selection = select(WhitelistBan).where( + WhitelistBan.player_id == player.id).where( + WhitelistBan.wl_type == new_wl.wl_type) + selection = select_only_active_whitelist_bans(selection) + + if session.exec(selection).first() is not None: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, + detail="Player is banned from this type of whitelist.") + + wl = Whitelist(**new_wl.model_dump(), player_id=player.id, admin_id=admin.id) + session.add(wl) + session.commit() + session.refresh(wl) + return wl + -@whitelist_router.post("/discord", - status_code=status.HTTP_201_CREATED, - responses=WHITELIST_POST_RESPONSES, - dependencies=[Depends(verify_bearer)]) -async def create_whitelist_by_discord(session: SessionDep, new_wl: NewWhitelistDiscord, ignore_bans: bool = False) -> Whitelist: - return await create_whitelist_helper( - session, - new_wl, - lambda d: Player.discord_id == d.player_discord_id, - lambda d: Player.discord_id == d.admin_discord_id, - ignore_bans - ) # endregion # region Patch @@ -254,48 +199,6 @@ def select_only_active_whitelist_bans(selection: SelectOfScalar[WhitelistBan]): ) -BAN_POST_RESPONSES = { - **BEARER_DEP_RESPONSES, - status.HTTP_201_CREATED: {"description": "Ban created"}, - status.HTTP_404_NOT_FOUND: {"description": "Player or admin not found"}, -} - - -def create_ban_helper(session: SessionDep, - new_ban: NewWhitelistBanBase, - player_resolver: Callable[[WHITELIST_TYPES_T], bool], - admin_resolver: Callable[[WHITELIST_TYPES_T], bool], - invalidate_wls: bool = True - ) -> WhitelistBan: - player: Player = session.exec( - select(Player).where(player_resolver(new_ban))).first() - admin: Player = session.exec( - select(Player).where(admin_resolver(new_ban))).first() - - if not player or not admin: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail="Player or admin not found") - - ban = WhitelistBan(**new_ban.model_dump(), - player_id=player.id, admin_id=admin.id) - session.add(ban) - - if invalidate_wls: - session.exec( - update(Whitelist) - .where(Whitelist.player_id == player.id) - .where(Whitelist.wl_type == new_ban.wl_type) - .where(Whitelist.valid) - .where(Whitelist.expiration_time > datetime.datetime.now()) - .values(valid=False) - ) - - session.commit() - session.refresh(ban) - logger.info("Whitelist ban created: %s", ban.model_dump_json()) - return ban - - # region Get @@ -345,53 +248,45 @@ async def get_whitelist_bans(session: SessionDep, # endregion # region Post +BAN_POST_RESPONSES = { + **BEARER_DEP_RESPONSES, + status.HTTP_201_CREATED: {"description": "Ban created"}, + status.HTTP_404_NOT_FOUND: {"description": "Player or admin not found"}, +} @whitelist_ban_router.post("", status_code=status.HTTP_201_CREATED, responses=BAN_POST_RESPONSES, dependencies=[Depends(verify_bearer)]) async def create_whitelist_ban(session: SessionDep, - new_ban: NewWhitelistBanInternal, + new_ban: NEW_WHITELIST_BAN_TYPES, invalidate_wls: bool = True) -> WhitelistBan: - return create_ban_helper( - session, - new_ban, - lambda d: Player.id == d.player_id, - lambda d: Player.id == d.admin_id, - invalidate_wls - ) - + player_resolver, admin_resolver = resolve_whitelist_type(new_ban) + player = session.exec(select(Player).where(player_resolver(new_ban))).first() + admin = session.exec(select(Player).where(admin_resolver(new_ban))).first() -@whitelist_ban_router.post("/ckey", - status_code=status.HTTP_201_CREATED, - responses=BAN_POST_RESPONSES, - dependencies=[Depends(verify_bearer)]) -async def create_whitelist_ban_by_ckey(session: SessionDep, - new_ban: NewWhitelistBanCkey, - invalidate_wls: bool = True) -> WhitelistBan: - return create_ban_helper( - session, - new_ban, - lambda d: Player.ckey == d.player_ckey, - lambda d: Player.ckey == d.admin_ckey, - invalidate_wls - ) + if player is None or admin is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, + detail="Player or admin not found") + if invalidate_wls: + session.exec( + update(Whitelist).where( + Whitelist.player_id == player.id + ).where( + Whitelist.wl_type == new_ban.wl_type + ).where( + Whitelist.expiration_time < datetime.datetime.now() + ).values( + valid=False + ) + ) -@whitelist_ban_router.post("/discord", - status_code=status.HTTP_201_CREATED, - responses=BAN_POST_RESPONSES, - dependencies=[Depends(verify_bearer)]) -async def create_whitelist_ban_by_discord(session: SessionDep, - new_ban: NewWhitelistBanDiscord, - invalidate_wls: bool = True) -> WhitelistBan: - return create_ban_helper( - session, - new_ban, - lambda d: Player.discord_id == d.player_discord_id, - lambda d: Player.discord_id == d.admin_discord_id, - invalidate_wls - ) + ban = WhitelistBan(**new_ban.model_dump(), player_id=player.id, admin_id=admin.id) + session.add(ban) + session.commit() + session.refresh(ban) + return ban # endregion # region Patch diff --git a/app/schemas/whitelist.py b/app/schemas/whitelist.py index 777135b..41e07fa 100644 --- a/app/schemas/whitelist.py +++ b/app/schemas/whitelist.py @@ -1,6 +1,8 @@ import datetime +from typing import TYPE_CHECKING, Callable from pydantic import BaseModel - +from sqlmodel.sql.expression import SelectOfScalar +from app.database.models import Player class NewWhitelistBase(BaseModel): wl_type: str @@ -41,6 +43,19 @@ class NewWhitelistDiscord(NewWhitelistBase): class NewWhitelistBanDiscord(NewWhitelistDiscord, NewWhitelistBanBase): pass +NEW_WHITELIST_TYPES = NewWhitelistInternal | NewWhitelistDiscord | NewWhitelistCkey +NEW_WHITELIST_BAN_TYPES = NewWhitelistBanInternal | NewWhitelistBanDiscord | NewWhitelistBanCkey + +def resolve_whitelist_type(new_wl: NEW_WHITELIST_TYPES) -> tuple[Callable[[NEW_WHITELIST_TYPES], SelectOfScalar], Callable[[NEW_WHITELIST_TYPES], SelectOfScalar]]: + match new_wl: + case NewWhitelistInternal(): + return (lambda new_wl: Player.id == new_wl.player_id, lambda new_wl: Player.id == new_wl.admin_id) + case NewWhitelistDiscord(): + return (lambda new_wl: Player.discord_id == new_wl.player_discord_id, lambda new_wl: Player.discord_id == new_wl.admin_discord_id) + case NewWhitelistCkey(): + return (lambda new_wl: Player.ckey == new_wl.player_ckey, lambda new_wl: Player.ckey == new_wl.admin_ckey) + case _: + raise TypeError("Someone added a new whitelist type without a case in resolve_whitelist_type") class WhitelistPatch(BaseModel): valid: bool | None = None \ No newline at end of file