Skip to content

Commit

Permalink
Add missing endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 20, 2023
1 parent 9391065 commit 815b312
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 16 deletions.
48 changes: 38 additions & 10 deletions backend/app/api/assistants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, List, Optional
from uuid import uuid4

from fastapi import APIRouter, Cookie, Path, Query
from fastapi import APIRouter, HTTPException, Path, Query
from pydantic import BaseModel, Field

import app.storage as storage
Expand All @@ -14,6 +15,17 @@
]


class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""

name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")


AssistantID = Annotated[str, Path(description="The ID of the assistant.")]


@router.get("/")
def list_assistants(opengpts_user_id: OpengptsUserId) -> List[AssistantWithoutUserId]:
"""List all assistants for the current user."""
Expand All @@ -32,20 +44,36 @@ def list_public_assistants(
)


class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""

name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")
@router.get("/{aid}")
def get_asistant(
opengpts_user_id: OpengptsUserId,
aid: AssistantID,
) -> Assistant:
"""Get an assistant by ID."""
assistant = storage.get_assistant(opengpts_user_id, aid)
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return assistant


AssistantID = Annotated[str, Path(description="The ID of the assistant.")]
@router.post("")
def create_assistant(
opengpts_user_id: OpengptsUserId,
payload: AssistantPayload,
) -> Assistant:
"""Create an assistant."""
return storage.put_assistant(
opengpts_user_id,
str(uuid4()),
name=payload.name,
config=payload.config,
public=payload.public,
)


@router.put("/{aid}")
def put_assistant(
opengpts_user_id: Annotated[str, Cookie()],
def upsert_assistant(
opengpts_user_id: OpengptsUserId,
aid: AssistantID,
payload: AssistantPayload,
) -> Assistant:
Expand Down
37 changes: 31 additions & 6 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, List
from uuid import uuid4

from fastapi import APIRouter, Path
from fastapi import APIRouter, HTTPException, Path
from pydantic import BaseModel, Field

import app.storage as storage
Expand All @@ -20,24 +21,48 @@ class ThreadPutRequest(BaseModel):


@router.get("/")
def list_threads_endpoint(
opengpts_user_id: OpengptsUserId
) -> List[ThreadWithoutUserId]:
def list_threads(opengpts_user_id: OpengptsUserId) -> List[ThreadWithoutUserId]:
"""List all threads for the current user."""
return storage.list_threads(opengpts_user_id)


@router.get("/{tid}/messages")
def get_thread_messages_endpoint(
def get_thread_messages(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
):
"""Get all messages for a thread."""
return storage.get_thread_messages(opengpts_user_id, tid)


@router.get("/{tid}")
def get_thread(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
) -> Thread:
"""Get a thread by ID."""
thread = storage.get_thread(opengpts_user_id, tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return thread


@router.post("")
def create_thread(
opengpts_user_id: OpengptsUserId,
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Create a thread."""
return storage.put_thread(
opengpts_user_id,
str(uuid4()),
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
)


@router.put("/{tid}")
def put_thread_endpoint(
def upsert_thread(
opengpts_user_id: OpengptsUserId,
tid: ThreadID,
thread_put_request: ThreadPutRequest,
Expand Down
7 changes: 7 additions & 0 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def list_threads(user_id: str) -> List[ThreadWithoutUserId]:
return [load(thread_hash_keys, values) for values in threads]


def get_thread(user_id: str, thread_id: str) -> Thread | None:
"""Get a thread by ID."""
client = _get_redis_client()
values = client.hmget(thread_key(user_id, thread_id), *thread_hash_keys)
return load(thread_hash_keys, values) if any(values) else None


def get_thread_messages(user_id: str, thread_id: str):
"""Get all messages for a thread."""
client = RedisCheckpoint()
Expand Down

0 comments on commit 815b312

Please sign in to comment.