Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Feat/session management #86

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llmstudio/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from dotenv import load_dotenv
import socket

from dotenv import load_dotenv

load_dotenv(os.path.join(os.getcwd(), ".env"))


Expand Down
44 changes: 7 additions & 37 deletions llmstudio/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import extract, func
from sqlalchemy.orm import Session

from llmstudio.config import TRACKING_HOST, TRACKING_PORT
from llmstudio.engine.providers import *
from llmstudio.tracking import crud, models, schemas
from llmstudio.tracking.database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)
from llmstudio.tracking.logs.endpoints import LogsRoutes
from llmstudio.tracking.session.endpoints import SessionsRoutes

TRACKING_HEALTH_ENDPOINT = "/health"
TRACKING_TITLE = "LLMstudio Tracking API"
TRACKING_DESCRIPTION = "The tracking API for LLM interactions"
TRACKING_VERSION = "0.0.1"
TRACKING_BASE_ENDPOINT = "/api/tracking"


## Tracking
def create_tracking_app() -> FastAPI:
app = FastAPI(
Expand All @@ -34,41 +29,16 @@ def create_tracking_app() -> FastAPI:
allow_headers=["*"],
)

def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

@app.get(TRACKING_HEALTH_ENDPOINT)
def health_check():
"""Health check endpoint to ensure the API is running."""
return {"status": "healthy", "message": "Tracking is up and running"}

@app.post(
f"{TRACKING_BASE_ENDPOINT}/logs",
response_model=schemas.LogDefault,
)
def add_log(log: schemas.LogDefaultCreate, db: Session = Depends(get_db)):
return crud.add_log(db=db, log=log)
tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
LogsRoutes(tracking_router)
SessionsRoutes(tracking_router)

@app.get(f"{TRACKING_BASE_ENDPOINT}/logs", response_model=list[schemas.LogDefault])
def read_logs(skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)):
logs = crud.get_logs(db, skip=skip, limit=limit)
return logs

@app.get(
f"{TRACKING_BASE_ENDPOINT}/logs_by_session",
response_model=list[schemas.LogDefault],
)
def read_logs_by_session(
session_id: str, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
):
logs = crud.get_logs_by_session(
db, session_id=session_id, skip=skip, limit=limit
)
return logs
app.include_router(tracking_router)

@app.on_event("startup")
async def startup_event():
Expand Down
8 changes: 8 additions & 0 deletions llmstudio/tracking/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def create_tracking_engine(uri: str):
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from llmstudio.tracking import models, schemas
from llmstudio.tracking.logs import models, schemas


def get_project_by_name(db: Session, name: str):
Expand All @@ -22,6 +22,7 @@ def add_log(db: Session, log: schemas.LogDefaultCreate):
db.add(db_log)
db.commit()
db.refresh(db_log)

return db_log


Expand Down
57 changes: 57 additions & 0 deletions llmstudio/tracking/logs/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.logs import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class LogsRoutes:
def __init__(self, router: APIRouter):
self.router = router

# Define routes
self.define_routes()

def define_routes(self):
# Add log
self.router.post(
"/logs",
response_model=schemas.LogDefault,
)(self.add_log)

# Read logs
self.router.get("/logs", response_model=List[schemas.LogDefault])(
self.read_logs
)

# Read logs by session
self.router.get("/logs_by_session", response_model=List[schemas.LogDefault])(
self.read_logs_by_session
)

async def add_log(
self, log: schemas.LogDefaultCreate, db: Session = Depends(get_db)
):
return crud.add_log(db=db, log=log)

async def read_logs(
self, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
):
logs = crud.get_logs(db, skip=skip, limit=limit)
return logs

async def read_logs_by_session(
self,
session_id: str,
skip: int = 0,
limit: int = 1000,
db: Session = Depends(get_db),
):
logs = crud.get_logs_by_session(
db, session_id=session_id, skip=skip, limit=limit
)
return logs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from sqlalchemy import JSON, Column, DateTime, Integer, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base
Expand Down
File renamed without changes.
Empty file.
41 changes: 41 additions & 0 deletions llmstudio/tracking/session/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sqlalchemy.orm import Session

from llmstudio.tracking.session import models, schemas


def get_project_by_name(db: Session, name: str):
return db.query(models.Project).filter(models.Project.name == name).first()


def get_session_by_id(db: Session, session_id: str):
return (
db.query(models.SessionDefault)
.filter(models.SessionDefault.session_id == session_id)
.first()
)


def add_session(db: Session, session: schemas.SessionDefaultCreate):
db_session = models.SessionDefault(**session.dict())

db.add(db_session)
db.commit()
db.refresh(db_session)
return db_session


def update_session(db: Session, session: schemas.SessionDefaultCreate):
existing_session = get_session_by_id(db, session.session_id)
for key, value in session.dict().items():
setattr(existing_session, key, value)

db.commit()
db.refresh(existing_session)
return existing_session


def upsert_session(db: Session, session: schemas.SessionDefaultCreate):
try:
return update_session(db, session)
except:
return add_session(db, session)
34 changes: 34 additions & 0 deletions llmstudio/tracking/session/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.session import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class SessionsRoutes:
def __init__(self, router: APIRouter):
self.router = router
self.define_routes()

def define_routes(self):
# Add session
self.router.post(
"/session",
response_model=schemas.SessionDefault,
)(self.add_session)

# Read session
self.router.get("/session/{session_id}", response_model=schemas.SessionDefault)(
self.get_session
)

async def add_session(
self, session: schemas.SessionDefaultCreate, db: Session = Depends(get_db)
):
return crud.upsert_session(db=db, session=session)

async def get_session(self, session_id: str, db: Session = Depends(get_db)):
logs = crud.get_session_by_id(db, session_id=session_id)
return logs
14 changes: 14 additions & 0 deletions llmstudio/tracking/session/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from sqlalchemy import JSON, Column, DateTime, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base


class SessionDefault(Base):
__tablename__ = "sessions"
session_id = Column(String, primary_key=True)
chat_history = Column(JSON)
updated_at = Column(
DateTime(timezone=True), onupdate=func.now(), server_default=func.now()
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
18 changes: 18 additions & 0 deletions llmstudio/tracking/session/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from datetime import datetime
from typing import Any, Dict, List

from pydantic import BaseModel


class SessionDefaultBase(BaseModel):
session_id: str
chat_history: List[Dict[str, Any]] = None


class SessionDefault(SessionDefaultBase):
created_at: datetime
updated_at: datetime


class SessionDefaultCreate(SessionDefaultBase):
pass
9 changes: 9 additions & 0 deletions llmstudio/tracking/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,14 @@ def log(self, data: dict):
)
return req

def session(self, data: dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case, we might need more methods:

  1. to add something to the session
  2. to get the session_history

req = self._session.post(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session",
headers={"accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
timeout=100,
)
return req


tracker = Tracker()
Loading