Skip to content

Commit

Permalink
[BUG]: fix bad OpenAPI generation (#3445)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Jan 9, 2025
1 parent 2661b76 commit d50a942
Showing 1 changed file with 60 additions and 32 deletions.
92 changes: 60 additions & 32 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CapacityLimiter,
)
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute
Expand Down Expand Up @@ -61,6 +62,7 @@
)
from starlette.datastructures import Headers
import logging
import importlib.metadata

from chromadb.telemetry.product.events import ServerStartEvent
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
Expand Down Expand Up @@ -142,18 +144,6 @@ def validate_model(model: Type[D], data: Any) -> D: # type: ignore
return model.parse_obj(data) # pydantic 1.x


def get_openapi_extras_for_model(request_model: Type[D]) -> Dict[str, Any]:
openapi_extra = {
"requestBody": {
"content": {
"application/json": {"schema": request_model.model_json_schema()}
},
"required": True,
}
}
return openapi_extra


class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
# A simple subclass of fastapi's APIRouter which treats URLs with a
# trailing "/" the same as URLs without. Docs will only contain URLs
Expand Down Expand Up @@ -189,6 +179,10 @@ def __init__(self, settings: Settings):
self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
self._system = System(settings)
self._api: ServerAPI = self._system.instance(ServerAPI)

self._extra_openapi_schemas: Dict[str, Any] = {}
self._app.openapi = self.generate_openapi

self._opentelemetry_client = self._api.require(OpenTelemetryClient)
self._capacity_limiter = CapacityLimiter(
settings.chroma_server_thread_pool_size
Expand Down Expand Up @@ -232,6 +226,37 @@ def __init__(self, settings: Settings):
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ServerStartEvent())

def generate_openapi(self) -> Dict[str, Any]:
"""Used instead of the default openapi() generation handler to include manually-populated schemas."""
schema: Dict[str, Any] = get_openapi(
title="Chroma",
routes=self._app.routes,
version=importlib.metadata.version("chromadb"),
)

for key, value in self._extra_openapi_schemas.items():
schema["components"]["schemas"][key] = value

return schema

def get_openapi_extras_for_body_model(
self, request_model: Type[D]
) -> Dict[str, Any]:
schema = request_model.model_json_schema(
ref_template="#/components/schemas/{model}"
)
if "$defs" in schema:
for key, value in schema["$defs"].items():
self._extra_openapi_schemas[key] = value

openapi_extra = {
"requestBody": {
"content": {"application/json": {"schema": schema}},
"required": True,
}
}
return openapi_extra

def setup_v2_routes(self) -> None:
self.router.add_api_route("/api/v2", self.root, methods=["GET"])
self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"])
Expand All @@ -253,7 +278,7 @@ def setup_v2_routes(self) -> None:
self.create_database,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateDatabase),
openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
)

self.router.add_api_route(
Expand All @@ -268,7 +293,7 @@ def setup_v2_routes(self) -> None:
self.create_tenant,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateTenant),
openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
)

self.router.add_api_route(
Expand All @@ -295,7 +320,7 @@ def setup_v2_routes(self) -> None:
self.create_collection,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
)

self.router.add_api_route(
Expand All @@ -304,35 +329,35 @@ def setup_v2_routes(self) -> None:
methods=["POST"],
status_code=status.HTTP_201_CREATED,
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update",
self.update,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert",
self.upsert,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get",
self.get,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(GetEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete",
self.delete,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(DeleteEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count",
Expand All @@ -345,7 +370,9 @@ def setup_v2_routes(self) -> None:
self.get_nearest_neighbors,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(request_model=QueryEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(
request_model=QueryEmbedding
),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
Expand All @@ -358,7 +385,7 @@ def setup_v2_routes(self) -> None:
self.update_collection,
methods=["PUT"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
Expand Down Expand Up @@ -1138,7 +1165,7 @@ def setup_v1_routes(self) -> None:
self.create_database_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateDatabase),
openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
)

self.router.add_api_route(
Expand All @@ -1153,7 +1180,7 @@ def setup_v1_routes(self) -> None:
self.create_tenant_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateTenant),
openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
)

self.router.add_api_route(
Expand All @@ -1180,7 +1207,7 @@ def setup_v1_routes(self) -> None:
self.create_collection_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
)

self.router.add_api_route(
Expand All @@ -1189,35 +1216,35 @@ def setup_v1_routes(self) -> None:
methods=["POST"],
status_code=status.HTTP_201_CREATED,
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/update",
self.update_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/upsert",
self.upsert_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/get",
self.get_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(GetEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/delete",
self.delete_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(DeleteEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/count",
Expand All @@ -1230,7 +1257,7 @@ def setup_v1_routes(self) -> None:
self.get_nearest_neighbors_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(QueryEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(QueryEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}",
Expand All @@ -1243,7 +1270,7 @@ def setup_v1_routes(self) -> None:
self.update_collection_v1,
methods=["PUT"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}",
Expand Down Expand Up @@ -1598,6 +1625,7 @@ async def inner():
),
)
return api_collection_model

return await inner()

@trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)
Expand Down

0 comments on commit d50a942

Please sign in to comment.