Skip to content

Commit

Permalink
[ENH] Add num_records_last_compaction to sysdb (#3463)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
   - Adds `num_records_last_compaction` to distributed sysdb

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
drewkim authored and codetheweb committed Jan 23, 2025
1 parent d7b9a49 commit 3440b66
Show file tree
Hide file tree
Showing 85 changed files with 1,884 additions and 350 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 17 additions & 1 deletion chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
DeleteSegmentRequest,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
Expand Down Expand Up @@ -419,7 +421,21 @@ def get_collections(
)
raise InternalError()

@trace_method("SysDB.get_collection_size", OpenTelemetryGranularity.OPERATION)
@overrides
def get_collection_size(self, id: UUID) -> int:
    """Return the record count the sysdb service reports for a collection.

    Sends a ``GetCollectionSizeRequest`` carrying the collection id in hex
    form and returns the ``total_records_post_compaction`` field of the
    response.

    Args:
        id: UUID of the collection to look up.

    Returns:
        The ``total_records_post_compaction`` value from the response.

    Raises:
        InternalError: If the gRPC call fails with ``grpc.RpcError``.
    """
    try:
        request = GetCollectionSizeRequest(id=id.hex)
        response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize(
            request
        )
        return response.total_records_post_compaction
    except grpc.RpcError as e:
        logger.error(f"Failed to get collection {id} size due to error: {e}")
        # Chain the RPC failure so the root cause stays attached to the
        # InternalError seen by callers.
        raise InternalError() from e

@trace_method(
"SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION
)
@overrides
def get_collection_with_segments(
self, collection_id: UUID
Expand Down
8 changes: 8 additions & 0 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetCollectionSizeRequest,
GetCollectionSizeResponse,
GetCollectionWithSegmentsRequest,
GetCollectionWithSegmentsResponse,
GetDatabaseRequest,
Expand Down Expand Up @@ -371,6 +373,12 @@ def GetCollections(
to_proto_collection(collection) for collection in found_collections
]
)

@overrides(check_signature=False)
def GetCollectionSize(
    self, request: GetCollectionSizeRequest, context: grpc.ServicerContext
) -> GetCollectionSizeResponse:
    """Stub handler: ignores the request and always reports zero records."""
    size_response = GetCollectionSizeResponse(total_records_post_compaction=0)
    return size_response

@overrides(check_signature=False)
def GetCollectionWithSegments(
Expand Down
4 changes: 4 additions & 0 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,3 +961,7 @@ def _insert_config_from_legacy_params(
with self.tx() as cur:
cur.execute(sql, params)
return configuration

@override
def get_collection_size(self, id: UUID) -> int:
    """Collection size lookup is not implemented here; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError()
5 changes: 5 additions & 0 deletions chromadb/db/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,8 @@ def update_collection(
keys with None values will be removed and keys not present in the UpdateMetadata
dict will be left unchanged."""
pass

@abstractmethod
def get_collection_size(self, id: UUID) -> int:
    """Report how many records the collection identified by `id` holds.

    Abstract hook: the body is empty and concrete implementations are
    expected to supply the actual lookup.
    """
2 changes: 2 additions & 0 deletions chromadb/quota/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
from enum import Enum
from typing import Dict, Any, Optional
from uuid import UUID

from chromadb.api.types import (
Embeddings,
Expand Down Expand Up @@ -62,6 +63,7 @@ def enforce(
where_document: Optional[WhereDocument] = None,
n_results: Optional[int] = None,
query_embeddings: Optional[Embeddings] = None,
collection_id: Optional[UUID] = None,
) -> None:
"""
Enforces a quota.
Expand Down
2 changes: 2 additions & 0 deletions chromadb/quota/simple_quota_enforcer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from overrides import override
from typing import Any, Callable, TypeVar, Dict, Optional
from uuid import UUID

from chromadb.api.types import (
Embeddings,
Expand Down Expand Up @@ -48,5 +49,6 @@ def enforce(
where_document: Optional[WhereDocument] = None,
n_results: Optional[int] = None,
query_embeddings: Optional[Embeddings] = None,
collection_id: Optional[UUID] = None,
) -> None:
pass
25 changes: 17 additions & 8 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import (
Any,
Awaitable,
Callable,
cast,
Dict,
Expand All @@ -26,6 +25,7 @@

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
from chromadb import __version__ as chromadb_version
from chromadb.api.types import (
Embedding,
GetResult,
Expand Down Expand Up @@ -65,7 +65,6 @@
)
from starlette.datastructures import Headers
import logging
import importlib.metadata

from chromadb.telemetry.product.events import ServerStartEvent
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
Expand All @@ -90,6 +89,7 @@ def rate_limit(func):
async def wrapper(*args: Any, **kwargs: Any) -> Any:
self = args[0]
return await self._async_rate_limit_enforcer.rate_limit(func)(*args, **kwargs)

return wrapper


Expand Down Expand Up @@ -243,7 +243,7 @@ def generate_openapi(self) -> Dict[str, Any]:
schema: Dict[str, Any] = get_openapi(
title="Chroma",
routes=self._app.routes,
version=importlib.metadata.version("chromadb"),
version=chromadb_version,
)

for key, value in self._extra_openapi_schemas.items():
Expand Down Expand Up @@ -470,7 +470,9 @@ async def auth_request(
database: Optional[str],
collection: Optional[str],
) -> None:
return await to_thread.run_sync(self.sync_auth_request, *(headers, action, tenant, database, collection))
return await to_thread.run_sync(
self.sync_auth_request, *(headers, action, tenant, database, collection)
)

@trace_method(
"FastAPI.sync_auth_request",
Expand Down Expand Up @@ -631,7 +633,6 @@ def process_create_tenant(request: Request, raw_body: bytes) -> None:
None,
)


return self._api.create_tenant(tenant.name)

await to_thread.run_sync(
Expand Down Expand Up @@ -1186,7 +1187,9 @@ async def get_nearest_neighbors(
collection_id: str,
request: Request,
) -> QueryResult:
@trace_method("internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
@trace_method(
"internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION
)
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = validate_model(QueryEmbedding, orjson.loads(raw_body))

Expand Down Expand Up @@ -1406,7 +1409,14 @@ async def auth_and_get_tenant_and_database_for_request(
(can be overwritten separately)
- The user has access to a single tenant and/or single database.
"""
return await to_thread.run_sync(self.auth_and_get_tenant_and_database_for_request, headers, action, tenant, database, collection)
return await to_thread.run_sync(
self.auth_and_get_tenant_and_database_for_request,
headers,
action,
tenant,
database,
collection,
)

def sync_auth_and_get_tenant_and_database_for_request(
self,
Expand All @@ -1416,7 +1426,6 @@ def sync_auth_and_get_tenant_and_database_for_request(
database: Optional[str],
collection: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:

if not self.authn_provider:
add_attributes_to_current_span(
{
Expand Down
10 changes: 10 additions & 0 deletions chromadb/test/db/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
assert by_collection_result == []


def test_get_collection_size(sysdb: SysDB) -> None:
    """Each sample collection reports size 0 after a state reset (GrpcSysDB only)."""
    grpc_backed = isinstance(sysdb, GrpcSysDB)
    if not grpc_backed:
        pytest.skip("Skipping because this functionality is only supported by GrpcSysDB")

    sysdb.reset_state()

    for sample in sample_collections:
        assert sysdb.get_collection_size(sample.id) == 0

def test_update_collections(sysdb: SysDB) -> None:
coll = Collection(
name=sample_collections[0].name,
Expand Down
77 changes: 47 additions & 30 deletions chromadb/test/ef/test_ollama_ef.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,50 @@
import os

import pytest
import httpx
from httpx import HTTPError, ConnectError

from chromadb.utils.embedding_functions import OllamaEmbeddingFunction


def test_ollama() -> None:
"""
To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file
Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables.
"""
if (
os.environ.get("OLLAMA_SERVER_URL") is None
or os.environ.get("OLLAMA_MODEL") is None
):
pytest.skip(
"OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test."
)
try:
response = httpx.get(os.environ.get("OLLAMA_SERVER_URL", ""))
# If the response was successful, no Exception will be raised
response.raise_for_status()
except (HTTPError, ConnectError):
pytest.skip("Ollama server not running. Skipping test.")
ef = OllamaEmbeddingFunction(
model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text",
url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings",
)

from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)


def test_ollama_default_model() -> None:
    """With no arguments, two documents yield two 384-dimensional embeddings."""
    pytest.importorskip("ollama", reason="ollama not installed")
    documents = ["Here is an article about llamas...", "this is another article"]
    embed = OllamaEmbeddingFunction()
    vectors = embed(documents)
    assert vectors is not None
    assert len(vectors) == 2
    for vector in vectors:
        assert len(vector) == 384


def test_ollama_unknown_model() -> None:
    """Embedding with an unrecognised model name raises an error that names the model."""
    pytest.importorskip("ollama", reason="ollama not installed")
    missing_model = "unknown-model"
    embed = OllamaEmbeddingFunction(model_name=missing_model)
    with pytest.raises(Exception) as exc_info:
        embed(["Here is an article about llamas...", "this is another article"])
    expected_fragment = f'model "{missing_model}" not found'
    assert expected_fragment in str(exc_info.value)


def test_ollama_backward_compat() -> None:
    """A legacy URL ending in ``/api/embeddings`` still produces embeddings."""
    pytest.importorskip("ollama", reason="ollama not installed")
    legacy_url = "http://localhost:11434/api/embeddings"
    embed = OllamaEmbeddingFunction(url=legacy_url)
    vectors = embed(["Here is an article about llamas...", "this is another article"])
    assert vectors is not None


def test_wrong_url() -> None:
    """Embedding against a non-existent server path raises an error mentioning 404."""
    pytest.importorskip("ollama", reason="ollama not installed")
    bad_url = "http://localhost:11434/this_is_wrong"
    embed = OllamaEmbeddingFunction(url=bad_url)
    with pytest.raises(Exception) as exc_info:
        embed(["Here is an article about llamas...", "this is another article"])
    assert "404" in str(exc_info.value)


def test_ollama_ask_user_to_install() -> None:
    """Without the ollama package, constructing the EF raises an install hint."""
    ollama_available = True
    try:
        from ollama import Client  # noqa: F401
    except ImportError:
        ollama_available = False
    if ollama_available:
        # The error path can only be exercised when the package is absent.
        pytest.skip("ollama python package is installed")
    with pytest.raises(ValueError) as exc_info:
        OllamaEmbeddingFunction()
    assert "The ollama python package is not installed" in str(exc_info.value)
57 changes: 31 additions & 26 deletions chromadb/utils/embedding_functions/ollama_embedding_function.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,47 @@
import logging
from typing import Union, cast

import httpx
from typing import Union, cast, Optional
from urllib.parse import urlparse

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32"


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
"""

def __init__(
    self,
    url: Optional[str] = "http://localhost:11434",
    model_name: Optional[str] = DEFAULT_MODEL_NAME,
    *,
    timeout: Optional[int] = 60,
) -> None:
    """
    Initialize the Ollama Embedding Function.

    Args:
        url (str): The base URL of the Ollama server (default: "http://localhost:11434").
            A legacy URL ending in "/api/embeddings" is reduced to its scheme and host.
        model_name (str): The name of the model to use for text embeddings,
            e.g. "nomic-embed-text". Defaults to "chroma/all-minilm-l6-v2-f32";
            see https://ollama.com/library for available models.
        timeout (int): Timeout handed to the Ollama client (default: 60).
            NOTE(review): presumably seconds - confirm against the ollama client.

    Raises:
        ValueError: If the ollama python package is not installed.
    """
    try:
        from ollama import Client
    except ImportError as import_error:
        raise ValueError(
            "The ollama python package is not installed. Please install it with `pip install ollama`"
        ) from import_error

    # adding this for backwards compatibility with the old version of the EF
    self._base_url = url
    # `url` is Optional, so guard against None before inspecting its suffix;
    # a None value is forwarded to the client unchanged.
    if url is not None and url.endswith("/api/embeddings"):
        parsed_url = urlparse(url)
        self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
    self._client = Client(host=self._base_url, timeout=timeout)
    # An explicit None (or empty) model name falls back to the default model.
    self._model_name = model_name or DEFAULT_MODEL_NAME

def __call__(self, input: Union[Documents, str]) -> Embeddings:
"""
Expand All @@ -36,23 +54,10 @@ def __call__(self, input: Union[Documents, str]) -> Embeddings:
Embeddings: The embeddings for the texts.
Example:
>>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text")
>>> ollama_ef = OllamaEmbeddingFunction()
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = ollama_ef(texts)
"""
# Call Ollama Server API for each document
texts = input if isinstance(input, list) else [input]
embeddings = [
self._session.post(
self._api_url, json={"model": self._model_name, "prompt": text}
).json()
for text in texts
]
return cast(
Embeddings,
[
embedding["embedding"]
for embedding in embeddings
if "embedding" in embedding
],
)
# Call Ollama client
response = self._client.embed(model=self._model_name, input=input)
return cast(Embeddings, response["embeddings"])
Loading

0 comments on commit 3440b66

Please sign in to comment.