From 112c2ff691425b1418b48950d8d70402588b2efd Mon Sep 17 00:00:00 2001
From: Furior <furiorg@gmail.com>
Date: Thu, 16 Jan 2025 21:48:37 +0700
Subject: [PATCH] wl improvements

---
 app/routes/whitelist.py           | 76 +++++++++++++++++++++++++++++--
 app/schemas/whitelist.py          | 26 ++++++++---
 app/tests/conftest.py             |  1 -
 app/tests/test_whitelist_route.py | 10 ++--
 requirements.txt                  |  6 +--
 5 files changed, 101 insertions(+), 18 deletions(-)

diff --git a/app/routes/whitelist.py b/app/routes/whitelist.py
index e2e7386..3b58d3a 100644
--- a/app/routes/whitelist.py
+++ b/app/routes/whitelist.py
@@ -1,12 +1,13 @@
 import datetime
 import logging
 
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, status
 from sqlmodel import select
 
-from app.database.models import Player, Whitelist, WhitelistBan
+from app.database.models import Player, Whitelist
 from app.deps import SessionDep, verify_bearer
-from app.schemas.whitelist import NewWhitelistBanCkey, NewWhitelistCkey
+from app.routes.player import get_player_by_ckey, get_player_by_discord
+from app.schemas.whitelist import NewWhitelistCkey, NewWhitelistDiscord
 
 logger = logging.getLogger("main-logger")
 
@@ -24,3 +25,72 @@ async def get_whitelists(session: SessionDep, active_only: bool = True) -> list[
             Whitelist.expiration_time > datetime.datetime.now()
         )
     return session.exec(selection).all()
+
+
+@router.post("/", dependencies=[Depends(verify_bearer)])
+async def create_whitelist(session: SessionDep, new_whitelist: Whitelist) -> Whitelist:
+    session.add(new_whitelist)
+    session.commit()
+    session.refresh(new_whitelist)
+    logger.info("Created whitelist: %s", new_whitelist)
+    return new_whitelist
+
+
+@router.get("/{wl_type}/ckey/{ckey}", status_code=status.HTTP_200_OK)
+async def get_whitelists_by_ckey(session: SessionDep, wl_type: str, ckey: str, active_only: bool = True) -> list[Whitelist]:
+    selection = select(Whitelist
+                       ).join(Player, Player.id == Whitelist.player_id
+                              ).where(Player.ckey == ckey
+                                      ).where(Whitelist.wl_type == wl_type)
+    if active_only:
+        selection = selection.where(
+            Whitelist.valid).where(
+            Whitelist.expiration_time > datetime.datetime.now()
+        )
+    return session.exec(selection).all()
+
+
+@router.post("/{wl_type}/ckey/{ckey}", dependencies=[Depends(verify_bearer)])
+async def create_whitelist_by_ckey(session: SessionDep, wl_type: str, ckey: str, new_whitelist: NewWhitelistCkey) -> Whitelist:
+    player = await get_player_by_ckey(session, ckey)
+    admin = await get_player_by_ckey(session, new_whitelist.admin_ckey)
+
+    wl = Whitelist(
+        player_id=player.id,
+        admin_id=admin.id,
+        wl_type=wl_type,
+        expiration_time=datetime.datetime.now(
+        ) + datetime.timedelta(days=new_whitelist.duration_days),
+        valid=new_whitelist.valid
+    )
+    return await create_whitelist(session, wl)
+
+
+@router.get("/{wl_type}/discord/{discord_id}", status_code=status.HTTP_200_OK)
+async def get_whitelists_by_discord(session: SessionDep, wl_type: str, discord_id: str, active_only: bool = True) -> list[Whitelist]:
+    selection = select(Whitelist).join(Player, Player.id == Whitelist.player_id).where(
+        Player.discord_id == discord_id).where(Whitelist.wl_type == wl_type)
+    if active_only:
+        selection = selection.where(
+            Whitelist.valid).where(
+            Whitelist.expiration_time > datetime.datetime.now()
+        )
+    return session.exec(selection).all()
+
+
+@router.post("/{wl_type}/discord/{discord_id}", dependencies=[Depends(verify_bearer)])
+async def create_whitelist_by_discord(session: SessionDep, new_whitelist: NewWhitelistDiscord) -> Whitelist:
+    player = await get_player_by_discord(session, new_whitelist.discord_id)
+    admin = await get_player_by_ckey(session, new_whitelist.admin_ckey)
+
+    wl = Whitelist(
+        player_id=player.id,
+        admin_id=admin.id,
+        wl_type=new_whitelist.wl_type,
+        expiration_time=datetime.datetime.now(
+        ) + datetime.timedelta(days=new_whitelist.duration_days),
+        valid=new_whitelist.valid
+    )
+    return await create_whitelist(session, wl)
+
+router.include_router(whitelist_ban_router)
diff --git a/app/schemas/whitelist.py b/app/schemas/whitelist.py
index d61f648..bbe147d 100644
--- a/app/schemas/whitelist.py
+++ b/app/schemas/whitelist.py
@@ -1,16 +1,28 @@
 from pydantic import BaseModel
 
 
-class NewWhitelistCkey(BaseModel):
-    player_ckey: str
-    admin_ckey: str
+class NewWhitelistBase(BaseModel):
     wl_type: str
     duration_days: int
 
 
-class NewWhitelistBanCkey(BaseModel):
+class NewWhitelistBanBase(NewWhitelistBase):
+    reason: str | None = None
+
+
+class NewWhitelistCkey(NewWhitelistBase):
     player_ckey: str
     admin_ckey: str
-    wl_type: str
-    duration_days: int
-    reason: str
+
+
+class NewWhitelistBanCkey(NewWhitelistCkey, NewWhitelistBanBase):
+    pass
+
+
+class NewWhitelistDiscord(NewWhitelistBase):
+    player_discord_id: str
+    admin_discord_id: str
+
+
+class NewWhitelistBanDiscord(NewWhitelistCkey, NewWhitelistBanBase):
+    pass
diff --git a/app/tests/conftest.py b/app/tests/conftest.py
index 485fb16..0bd4c5b 100644
--- a/app/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -25,7 +25,6 @@ def client(app: FastAPI):
 
 @pytest.fixture(scope="function")
 def db_session():
-    
 
     # Create an in-memory SQLite database engine
     # sqlite_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
diff --git a/app/tests/test_whitelist_route.py b/app/tests/test_whitelist_route.py
index 5e190c5..0f7da9f 100644
--- a/app/tests/test_whitelist_route.py
+++ b/app/tests/test_whitelist_route.py
@@ -1,7 +1,6 @@
 import datetime
-from sqlmodel import select
 
-from app.database.models import Player, Whitelist, WhitelistBan
+from app.database.models import Whitelist
 
 
 def test_get_whitelists_general_empty(client):
@@ -18,11 +17,14 @@ def test_get_whitelists_all(client, whitelist_factory):
 
     assert wls == [Whitelist.model_validate(wl) for wl in response.json()]
 
+
 def test_get_whitelists_active(client, whitelist_factory):
     wls = [whitelist_factory() for _ in range(5)]
-    active_wls = [wl for wl in wls if wl.expiration_time > datetime.datetime.now() and wl.valid]
+    active_wls = [wl for wl in wls if wl.expiration_time >
+                  datetime.datetime.now() and wl.valid]
     response = client.get("/whitelist?active_only=true")
     assert response.status_code == 200
     assert len(response.json()) == len(active_wls)
 
-    assert active_wls == [Whitelist.model_validate(wl) for wl in response.json()]
\ No newline at end of file
+    assert active_wls == [Whitelist.model_validate(
+        wl) for wl in response.json()]
diff --git a/requirements.txt b/requirements.txt
index 2cc4f01..8f5e2b9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-fastapi ~ = 0.112
+fastapi ~= 0.112
 fastapi[standard]
-pydantic ~ = 2.8.2
-sqlalchemy ~ = 2.0.33
+pydantic ~= 2.8.2
+sqlalchemy ~= 2.0.33
 sqlmodel
 psycopg2
 aiohttp