Skip to content

Commit

Permalink
feat: sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
brunoalho99 committed Apr 7, 2024
1 parent b16a4a2 commit 930963b
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 29 deletions.
3 changes: 2 additions & 1 deletion llmstudio/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from dotenv import load_dotenv
import socket

from dotenv import load_dotenv

load_dotenv(os.path.join(os.getcwd(), ".env"))


Expand Down
7 changes: 3 additions & 4 deletions llmstudio/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uvicorn
from fastapi import FastAPI
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter

from llmstudio.config import TRACKING_HOST, TRACKING_PORT
from llmstudio.engine.providers import *
from llmstudio.tracking.logs.endpoints import LogsRoutes
Expand Down Expand Up @@ -29,7 +29,6 @@ def create_tracking_app() -> FastAPI:
allow_headers=["*"],
)


@app.get(TRACKING_HEALTH_ENDPOINT)
def health_check():
"""Health check endpoint to ensure the API is running."""
Expand All @@ -38,7 +37,7 @@ def health_check():
tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
LogsRoutes(tracking_router)
SessionsRoutes(tracking_router)

app.include_router(tracking_router)

@app.on_event("startup")
Expand Down
3 changes: 2 additions & 1 deletion llmstudio/tracking/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ def create_tracking_engine(uri: str):

Base = declarative_base()


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
db.close()
5 changes: 5 additions & 0 deletions llmstudio/tracking/logs/crud.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from sqlalchemy.orm import Session

from llmstudio.tracking.logs import models, schemas


def get_project_by_name(db: Session, name: str):
return db.query(models.Project).filter(models.Project.name == name).first()


def get_logs_by_session(db: Session, session_id: str, skip: int = 0, limit: int = 100):
return (
db.query(models.LogDefault)
Expand All @@ -13,11 +16,13 @@ def get_logs_by_session(db: Session, session_id: str, skip: int = 0, limit: int
.all()
)


def add_log(db: Session, log: schemas.LogDefaultCreate):
db_log = models.LogDefault(**log.dict())
db.add(db_log)
db.commit()
db.refresh(db_log)

return db_log


Expand Down
44 changes: 29 additions & 15 deletions llmstudio/tracking/logs/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from llmstudio.tracking.database import get_db, engine
from llmstudio.tracking.logs import crud, schemas,models

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.logs import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class LogsRoutes:
def __init__(self, router: APIRouter):
self.router = router
Expand All @@ -20,24 +24,34 @@ def define_routes(self):
)(self.add_log)

# Read logs
self.router.get(
"/logs",
response_model=List[schemas.LogDefault]
)(self.read_logs)
self.router.get("/logs", response_model=List[schemas.LogDefault])(
self.read_logs
)

# Read logs by session
self.router.get(
"/logs_by_session",
response_model=List[schemas.LogDefault]
)(self.read_logs_by_session)
self.router.get("/logs_by_session", response_model=List[schemas.LogDefault])(
self.read_logs_by_session
)

async def add_log(self, log: schemas.LogDefaultCreate, db: Session = Depends(get_db)):
async def add_log(
self, log: schemas.LogDefaultCreate, db: Session = Depends(get_db)
):
return crud.add_log(db=db, log=log)

async def read_logs(self, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)):
async def read_logs(
self, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
):
logs = crud.get_logs(db, skip=skip, limit=limit)
return logs

async def read_logs_by_session(self, session_id: str, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)):
logs = crud.get_logs_by_session(db, session_id=session_id, skip=skip, limit=limit)
async def read_logs_by_session(
self,
session_id: str,
skip: int = 0,
limit: int = 1000,
db: Session = Depends(get_db),
):
logs = crud.get_logs_by_session(
db, session_id=session_id, skip=skip, limit=limit
)
return logs
1 change: 1 addition & 0 deletions llmstudio/tracking/logs/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy import JSON, Column, DateTime, Integer, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base


Expand Down
1 change: 1 addition & 0 deletions llmstudio/tracking/logs/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


Expand Down
Empty file.
41 changes: 41 additions & 0 deletions llmstudio/tracking/session/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sqlalchemy.orm import Session

from llmstudio.tracking.session import models, schemas


def get_project_by_name(db: Session, name: str):
return db.query(models.Project).filter(models.Project.name == name).first()


def get_session_by_id(db: Session, session_id: str):
return (
db.query(models.SessionDefault)
.filter(models.SessionDefault.session_id == session_id)
.first()
)


def add_session(db: Session, session: schemas.SessionDefaultCreate):
db_session = models.SessionDefault(**session.dict())

db.add(db_session)
db.commit()
db.refresh(db_session)
return db_session


def update_session(db: Session, session: schemas.SessionDefaultCreate):
existing_session = get_session_by_id(db, session.session_id)
for key, value in session.dict().items():
setattr(existing_session, key, value)

db.commit()
db.refresh(existing_session)
return existing_session


def upsert_session(db: Session, session: schemas.SessionDefaultCreate):
try:
return update_session(db, session)
except:
return add_session(db, session)
34 changes: 34 additions & 0 deletions llmstudio/tracking/session/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.session import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class SessionsRoutes:
def __init__(self, router: APIRouter):
self.router = router
self.define_routes()

def define_routes(self):
# Add session
self.router.post(
"/session",
response_model=schemas.SessionDefault,
)(self.add_session)

# Read session
self.router.get("/session/{session_id}", response_model=schemas.SessionDefault)(
self.get_session
)

async def add_session(
self, session: schemas.SessionDefaultCreate, db: Session = Depends(get_db)
):
return crud.upsert_session(db=db, session=session)

async def get_session(self, session_id: str, db: Session = Depends(get_db)):
logs = crud.get_session_by_id(db, session_id=session_id)
return logs
14 changes: 14 additions & 0 deletions llmstudio/tracking/session/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sqlalchemy import JSON, Column, DateTime, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base


class SessionDefault(Base):
__tablename__ = "sessions"
session_id = Column(String, primary_key=True)
chat_history = Column(JSON)
updated_at = Column(
DateTime(timezone=True), onupdate=func.now(), server_default=func.now()
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
18 changes: 18 additions & 0 deletions llmstudio/tracking/session/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from datetime import datetime
from typing import Any, Dict, List

from pydantic import BaseModel


class SessionDefaultBase(BaseModel):
session_id: str
chat_history: List[Dict[str, Any]] = None


class SessionDefault(SessionDefaultBase):
created_at: datetime
updated_at: datetime


class SessionDefaultCreate(SessionDefaultBase):
pass
19 changes: 11 additions & 8 deletions llmstudio/tracking/tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json

import requests

from llmstudio.config import TRACKING_HOST, TRACKING_PORT


Expand All @@ -17,12 +19,13 @@ def log(self, data: dict):
return req

def session(self, data: dict):
req = self._session.post(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session",
headers={"accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
timeout=100,
)
return req

req = self._session.post(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session",
headers={"accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
timeout=100,
)
return req


tracker = Tracker()

0 comments on commit 930963b

Please sign in to comment.