Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Feat/session management #86

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/upload-pypi-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Upload Python package to PyPI

on:
push:
branches:
- feat/session_management
paths:
- "llmstudio/**"

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.x"

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -

- name: Configure Poetry
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }}

- name: Build and publish to PyPI
run: |
poetry version $(poetry version -s).dev1
poetry build
poetry publish
3 changes: 2 additions & 1 deletion llmstudio/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from dotenv import load_dotenv
import socket

from dotenv import load_dotenv

load_dotenv(os.path.join(os.getcwd(), ".env"))


Expand Down
44 changes: 7 additions & 37 deletions llmstudio/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import extract, func
from sqlalchemy.orm import Session

from llmstudio.config import TRACKING_HOST, TRACKING_PORT
from llmstudio.engine.providers import *
from llmstudio.tracking import crud, models, schemas
from llmstudio.tracking.database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)
from llmstudio.tracking.logs.endpoints import LogsRoutes
from llmstudio.tracking.session.endpoints import SessionsRoutes

TRACKING_HEALTH_ENDPOINT = "/health"
TRACKING_TITLE = "LLMstudio Tracking API"
TRACKING_DESCRIPTION = "The tracking API for LLM interactions"
TRACKING_VERSION = "0.0.1"
TRACKING_BASE_ENDPOINT = "/api/tracking"


## Tracking
def create_tracking_app() -> FastAPI:
app = FastAPI(
Expand All @@ -34,41 +29,16 @@ def create_tracking_app() -> FastAPI:
allow_headers=["*"],
)

def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

@app.get(TRACKING_HEALTH_ENDPOINT)
def health_check():
"""Health check endpoint to ensure the API is running."""
return {"status": "healthy", "message": "Tracking is up and running"}

@app.post(
f"{TRACKING_BASE_ENDPOINT}/logs",
response_model=schemas.LogDefault,
)
def add_log(log: schemas.LogDefaultCreate, db: Session = Depends(get_db)):
return crud.add_log(db=db, log=log)
tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
LogsRoutes(tracking_router)
SessionsRoutes(tracking_router)

@app.get(f"{TRACKING_BASE_ENDPOINT}/logs", response_model=list[schemas.LogDefault])
def read_logs(skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)):
logs = crud.get_logs(db, skip=skip, limit=limit)
return logs

@app.get(
f"{TRACKING_BASE_ENDPOINT}/logs_by_session",
response_model=list[schemas.LogDefault],
)
def read_logs_by_session(
session_id: str, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
):
logs = crud.get_logs_by_session(
db, session_id=session_id, skip=skip, limit=limit
)
return logs
app.include_router(tracking_router)

@app.on_event("startup")
async def startup_event():
Expand Down
8 changes: 8 additions & 0 deletions llmstudio/tracking/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def create_tracking_engine(uri: str):
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from llmstudio.tracking import models, schemas
from llmstudio.tracking.logs import models, schemas


def get_project_by_name(db: Session, name: str):
Expand All @@ -22,6 +22,7 @@ def add_log(db: Session, log: schemas.LogDefaultCreate):
db.add(db_log)
db.commit()
db.refresh(db_log)

return db_log


Expand Down
57 changes: 57 additions & 0 deletions llmstudio/tracking/logs/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.logs import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class LogsRoutes:
def __init__(self, router: APIRouter):
self.router = router

# Define routes
self.define_routes()

def define_routes(self):
# Add log
self.router.post(
"/logs",
response_model=schemas.LogDefault,
)(self.add_log)

# Read logs
self.router.get("/logs", response_model=List[schemas.LogDefault])(
self.read_logs
)

# Read logs by session
self.router.get("/logs_by_session", response_model=List[schemas.LogDefault])(
self.read_logs_by_session
)

async def add_log(
self, log: schemas.LogDefaultCreate, db: Session = Depends(get_db)
):
return crud.add_log(db=db, log=log)

async def read_logs(
self, skip: int = 0, limit: int = 1000, db: Session = Depends(get_db)
):
logs = crud.get_logs(db, skip=skip, limit=limit)
return logs

async def read_logs_by_session(
self,
session_id: str,
skip: int = 0,
limit: int = 1000,
db: Session = Depends(get_db),
):
logs = crud.get_logs_by_session(
db, session_id=session_id, skip=skip, limit=limit
)
return logs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from sqlalchemy import JSON, Column, DateTime, Integer, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base
Expand Down
File renamed without changes.
Empty file.
48 changes: 48 additions & 0 deletions llmstudio/tracking/session/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from sqlalchemy.orm import Session

from llmstudio.tracking.session import models, schemas


def get_project_by_name(db: Session, name: str):
return db.query(models.Project).filter(models.Project.name == name).first()


def get_session_by_session_id(
db: Session, session_id: str, skip: int = 0, limit: int = 100
):
return (
db.query(models.SessionDefault)
.filter(models.SessionDefault.session_id == session_id)
.offset(skip)
.limit(limit)
.all()
)


def get_session_by_message_id(db: Session, message_id: int):
return (
db.query(models.SessionDefault)
.filter(models.SessionDefault.message_id == message_id)
.first()
)


def add_session(db: Session, session: schemas.SessionDefaultCreate):
db_session = models.SessionDefault(**session.dict())

db.add(db_session)
db.commit()
db.refresh(db_session)
return db_session


def update_session(db: Session, message_id: int, extras: dict):
existing_session = get_session_by_message_id(db, message_id)
existing_session.extras = extras
db.commit()
db.refresh(existing_session)
return existing_session


def upsert_session(db: Session, session: schemas.SessionDefaultCreate):
return add_session(db, session)
46 changes: 46 additions & 0 deletions llmstudio/tracking/session/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from llmstudio.tracking.database import engine, get_db
from llmstudio.tracking.session import crud, models, schemas

models.Base.metadata.create_all(bind=engine)


class SessionsRoutes:
def __init__(self, router: APIRouter):
self.router = router
self.define_routes()

def define_routes(self):
# Add session
self.router.post(
"/session",
response_model=schemas.SessionDefault,
)(self.add_session)

# Read session
self.router.get(
"/session/{session_id}", response_model=List[schemas.SessionDefault]
)(self.get_session)

self.router.patch(
"/session/{message_id}", response_model=schemas.SessionDefault
)(self.update_session)

async def add_session(
self, session: schemas.SessionDefaultCreate, db: Session = Depends(get_db)
):
return crud.upsert_session(db=db, session=session)

async def update_session(
self, message_id: int, extras: dict, db: Session = Depends(get_db)
):
sessions = crud.update_session(db, message_id=message_id, extras=extras)
return sessions

async def get_session(self, session_id: str, db: Session = Depends(get_db)):
sessions = crud.get_session_by_session_id(db, session_id=session_id)
return sessions
16 changes: 16 additions & 0 deletions llmstudio/tracking/session/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from sqlalchemy import JSON, Column, DateTime, Integer, String
from sqlalchemy.sql import func

from llmstudio.tracking.database import Base


class SessionDefault(Base):
__tablename__ = "sessions"
message_id = Column(Integer, primary_key=True)
session_id = Column(String, index=True)
chat_history = Column(JSON)
extras = Column(JSON)
updated_at = Column(
DateTime(timezone=True), onupdate=func.now(), server_default=func.now()
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
20 changes: 20 additions & 0 deletions llmstudio/tracking/session/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from datetime import datetime
from typing import Any, Dict, List

from pydantic import BaseModel


class SessionDefaultBase(BaseModel):
session_id: str
chat_history: List[Dict[str, Any]] = None
extras: Dict[str, Any] = None


class SessionDefault(SessionDefaultBase):
message_id: int
created_at: datetime
updated_at: datetime


class SessionDefaultCreate(SessionDefaultBase):
pass
25 changes: 25 additions & 0 deletions llmstudio/tracking/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,30 @@ def log(self, data: dict):
)
return req

def update_session(self, data: dict):
req = self._session.post(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session",
headers={"accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
timeout=100,
)
return req

def get_session(self, session_id: str):
req = self._session.post(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session/{session_id}",
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=100,
)
return req

def add_extras(self, message_id: int):
req = self._session.patch(
f"http://{TRACKING_HOST}:{TRACKING_PORT}/api/tracking/session/{message_id}",
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=100,
)
return req


tracker = Tracker()
Loading