diff --git a/Cargo.lock b/Cargo.lock index 6569f3008d6..763425d53c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1370,6 +1370,7 @@ dependencies = [ "tracing-bunyan-formatter", "tracing-opentelemetry", "tracing-subscriber", + "uuid", ] [[package]] @@ -6864,6 +6865,7 @@ dependencies = [ "chroma-sysdb", "chroma-system", "chroma-types", + "clap", "criterion", "fastrace", "fastrace-opentelemetry", diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index ab011ea685a..805ed7e54bc 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -22,6 +22,8 @@ DeleteSegmentRequest, GetCollectionsRequest, GetCollectionsResponse, + GetCollectionSizeRequest, + GetCollectionSizeResponse, GetCollectionWithSegmentsRequest, GetCollectionWithSegmentsResponse, GetDatabaseRequest, @@ -419,7 +421,21 @@ def get_collections( ) raise InternalError() - @trace_method("SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION) + @overrides + def get_collection_size(self, id: UUID) -> int: + try: + request = GetCollectionSizeRequest(id=id.hex) + response: GetCollectionSizeResponse = self._sys_db_stub.GetCollectionSize( + request + ) + return response.total_records_post_compaction + except grpc.RpcError as e: + logger.error(f"Failed to get collection {id} size due to error: {e}") + raise InternalError() + + @trace_method( + "SysDB.get_collection_with_segments", OpenTelemetryGranularity.OPERATION + ) @overrides def get_collection_with_segments( self, collection_id: UUID diff --git a/chromadb/db/impl/grpc/server.py b/chromadb/db/impl/grpc/server.py index b6948ee9615..e7759194974 100644 --- a/chromadb/db/impl/grpc/server.py +++ b/chromadb/db/impl/grpc/server.py @@ -28,6 +28,8 @@ DeleteSegmentResponse, GetCollectionsRequest, GetCollectionsResponse, + GetCollectionSizeRequest, + GetCollectionSizeResponse, GetCollectionWithSegmentsRequest, GetCollectionWithSegmentsResponse, GetDatabaseRequest, @@ -371,6 +373,12 @@ def GetCollections( to_proto_collection(collection) for collection in found_collections ] ) + + @overrides(check_signature=False) + def GetCollectionSize(self, request: GetCollectionSizeRequest, context: grpc.ServicerContext) -> GetCollectionSizeResponse: + return GetCollectionSizeResponse( + total_records_post_compaction = 0, + ) @overrides(check_signature=False) def GetCollectionWithSegments( diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index c31b7388016..dc3fcfdd8de 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -961,3 +961,7 @@ def _insert_config_from_legacy_params( with self.tx() as cur: cur.execute(sql, params) return configuration + + @override + def get_collection_size(self, id: UUID) -> int: + raise NotImplementedError \ No newline at end of file diff --git a/chromadb/db/system.py b/chromadb/db/system.py index 40bc4ab468e..8cc51b67727 100644 --- a/chromadb/db/system.py +++ b/chromadb/db/system.py @@ -165,3 +165,8 @@ def update_collection( keys with None values will be removed and keys not present in the UpdateMetadata dict will be left unchanged.""" pass + + @abstractmethod + def get_collection_size(self, id: UUID) -> int: + """Returns the number of records in a collection.""" + pass \ No newline at end of file diff --git a/chromadb/quota/__init__.py b/chromadb/quota/__init__.py index e61b945fc26..0314a7aaa5b 100644 --- a/chromadb/quota/__init__.py +++ b/chromadb/quota/__init__.py @@ -1,6 +1,7 @@ from abc import abstractmethod from enum import Enum from typing import Dict, Any, Optional +from uuid import UUID from chromadb.api.types import ( Embeddings, @@ -62,6 +63,7 @@ def enforce( where_document: Optional[WhereDocument] = None, n_results: Optional[int] = None, query_embeddings: Optional[Embeddings] = None, + collection_id: Optional[UUID] = None, ) -> None: """ Enforces a quota. diff --git a/chromadb/quota/simple_quota_enforcer/__init__.py b/chromadb/quota/simple_quota_enforcer/__init__.py index f3a9a4035bb..4ad166a6c0b 100644 --- a/chromadb/quota/simple_quota_enforcer/__init__.py +++ b/chromadb/quota/simple_quota_enforcer/__init__.py @@ -1,5 +1,6 @@ from overrides import override from typing import Any, Callable, TypeVar, Dict, Optional +from uuid import UUID from chromadb.api.types import ( Embeddings, @@ -48,5 +49,6 @@ def enforce( where_document: Optional[WhereDocument] = None, n_results: Optional[int] = None, query_embeddings: Optional[Embeddings] = None, + collection_id: Optional[UUID] = None, ) -> None: pass diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index b875bf5e5fe..4d0e05ca716 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -1,6 +1,5 @@ from typing import ( Any, - Awaitable, Callable, cast, Dict, @@ -26,6 +25,7 @@ from chromadb.api.configuration import CollectionConfigurationInternal from pydantic import BaseModel +from chromadb import __version__ as chromadb_version from chromadb.api.types import ( Embedding, GetResult, @@ -65,7 +65,6 @@ ) from starlette.datastructures import Headers import logging -import importlib.metadata from chromadb.telemetry.product.events import ServerStartEvent from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid @@ -90,6 +89,7 @@ def rate_limit(func): async def wrapper(*args: Any, **kwargs: Any) -> Any: self = args[0] return await self._async_rate_limit_enforcer.rate_limit(func)(*args, **kwargs) + return wrapper @@ -243,7 +243,7 @@ def generate_openapi(self) -> Dict[str, Any]: schema: Dict[str, Any] = get_openapi( title="Chroma", routes=self._app.routes, - version=importlib.metadata.version("chromadb"), + version=chromadb_version, ) for key, value in self._extra_openapi_schemas.items(): @@ -470,7 +470,9 @@ async def auth_request( database: Optional[str], collection: Optional[str], ) -> None: - return await to_thread.run_sync(self.sync_auth_request, *(headers, action, tenant, database, collection)) + return await to_thread.run_sync( + self.sync_auth_request, *(headers, action, tenant, database, collection) + ) @trace_method( "FastAPI.sync_auth_request", @@ -631,7 +633,6 @@ def process_create_tenant(request: Request, raw_body: bytes) -> None: None, ) - return self._api.create_tenant(tenant.name) await to_thread.run_sync( @@ -1186,7 +1187,9 @@ async def get_nearest_neighbors( collection_id: str, request: Request, ) -> QueryResult: - @trace_method("internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION) + @trace_method( + "internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION + ) def process_query(request: Request, raw_body: bytes) -> QueryResult: query = validate_model(QueryEmbedding, orjson.loads(raw_body)) @@ -1406,7 +1409,14 @@ async def auth_and_get_tenant_and_database_for_request( (can be overwritten separately) - The user has access to a single tenant and/or single database. """ - return await to_thread.run_sync(self.auth_and_get_tenant_and_database_for_request, headers, action, tenant, database, collection) + return await to_thread.run_sync( + self.auth_and_get_tenant_and_database_for_request, + headers, + action, + tenant, + database, + collection, + ) def sync_auth_and_get_tenant_and_database_for_request( self, @@ -1416,7 +1426,6 @@ def sync_auth_and_get_tenant_and_database_for_request( database: Optional[str], collection: Optional[str], ) -> Tuple[Optional[str], Optional[str]]: - if not self.authn_provider: add_attributes_to_current_span( { diff --git a/chromadb/test/db/test_system.py b/chromadb/test/db/test_system.py index 20384e739c1..b92e9b144c4 100644 --- a/chromadb/test/db/test_system.py +++ b/chromadb/test/db/test_system.py @@ -273,6 +273,16 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None: assert by_collection_result == [] +def test_get_collection_size(sysdb: SysDB) -> None: + if not isinstance(sysdb, GrpcSysDB): + pytest.skip("Skipping because this functionality is only supported by GrpcSysDB") + + sysdb.reset_state() + + for collection in sample_collections: + collection_size = sysdb.get_collection_size(collection.id) + assert collection_size == 0 + def test_update_collections(sysdb: SysDB) -> None: coll = Collection( name=sample_collections[0].name, diff --git a/chromadb/test/ef/test_ollama_ef.py b/chromadb/test/ef/test_ollama_ef.py index 413bf091bf7..f9c22ef9f38 100644 --- a/chromadb/test/ef/test_ollama_ef.py +++ b/chromadb/test/ef/test_ollama_ef.py @@ -1,33 +1,50 @@ -import os - import pytest -import httpx -from httpx import HTTPError, ConnectError - -from chromadb.utils.embedding_functions import OllamaEmbeddingFunction - - -def test_ollama() -> None: - """ - To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file - Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables. - """ - if ( - os.environ.get("OLLAMA_SERVER_URL") is None - or os.environ.get("OLLAMA_MODEL") is None - ): - pytest.skip( - "OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test." - ) - try: - response = httpx.get(os.environ.get("OLLAMA_SERVER_URL", "")) - # If the response was successful, no Exception will be raised - response.raise_for_status() - except (HTTPError, ConnectError): - pytest.skip("Ollama server not running. Skipping test.") - ef = OllamaEmbeddingFunction( - model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text", - url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings", - ) + +from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, +) + + +def test_ollama_default_model() -> None: + pytest.importorskip("ollama", reason="ollama not installed") + ef = OllamaEmbeddingFunction() embeddings = ef(["Here is an article about llamas...", "this is another article"]) + assert embeddings is not None assert len(embeddings) == 2 + assert all(len(e) == 384 for e in embeddings) + + +def test_ollama_unknown_model() -> None: + pytest.importorskip("ollama", reason="ollama not installed") + model_name = "unknown-model" + ef = OllamaEmbeddingFunction(model_name=model_name) + with pytest.raises(Exception) as e: + ef(["Here is an article about llamas...", "this is another article"]) + assert f'model "{model_name}" not found' in str(e.value) + + +def test_ollama_backward_compat() -> None: + pytest.importorskip("ollama", reason="ollama not installed") + ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings") + embeddings = ef(["Here is an article about llamas...", "this is another article"]) + assert embeddings is not None + + +def test_wrong_url() -> None: + pytest.importorskip("ollama", reason="ollama not installed") + ef = OllamaEmbeddingFunction(url="http://localhost:11434/this_is_wrong") + with pytest.raises(Exception) as e: + ef(["Here is an article about llamas...", "this is another article"]) + assert "404" in str(e.value) + + +def test_ollama_ask_user_to_install() -> None: + try: + from ollama import Client # noqa: F401 + except ImportError: + pass + else: + pytest.skip("ollama python package is installed") + with pytest.raises(ValueError) as e: + OllamaEmbeddingFunction() + assert "The ollama python package is not installed" in str(e.value) diff --git a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py index a6293e36075..def39f140a2 100644 --- a/chromadb/utils/embedding_functions/ollama_embedding_function.py +++ b/chromadb/utils/embedding_functions/ollama_embedding_function.py @@ -1,29 +1,47 @@ import logging -from typing import Union, cast - -import httpx +from typing import Union, cast, Optional +from urllib.parse import urlparse from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) +DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32" + class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). """ - def __init__(self, url: str, model_name: str) -> None: + def __init__( + self, + url: Optional[str] = "http://localhost:11434", + model_name: Optional[str] = DEFAULT_MODEL_NAME, + *, + timeout: Optional[int] = 60, + ) -> None: """ Initialize the Ollama Embedding Function. Args: - url (str): The URL of the Ollama Server. - model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). + url (str): The Base URL of the Ollama Server (default: "http://localhost:11434"). + model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see defaults to "chroma/all-minilm-l6-v2-f32", for available models see https://ollama.com/library). """ - self._api_url = f"{url}" - self._model_name = model_name - self._session = httpx.Client() + + try: + from ollama import Client + except ImportError: + raise ValueError( + "The ollama python package is not installed. Please install it with `pip install ollama`" + ) + # adding this for backwards compatibility with the old version of the EF + self._base_url = url + if self._base_url.endswith("/api/embeddings"): + parsed_url = urlparse(url) + self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + self._client = Client(host=self._base_url, timeout=timeout) + self._model_name = model_name or DEFAULT_MODEL_NAME def __call__(self, input: Union[Documents, str]) -> Embeddings: """ @@ -36,23 +54,10 @@ def __call__(self, input: Union[Documents, str]) -> Embeddings: Embeddings: The embeddings for the texts. Example: - >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text") + >>> ollama_ef = OllamaEmbeddingFunction() >>> texts = ["Hello, world!", "How are you?"] >>> embeddings = ollama_ef(texts) """ - # Call Ollama Server API for each document - texts = input if isinstance(input, list) else [input] - embeddings = [ - self._session.post( - self._api_url, json={"model": self._model_name, "prompt": text} - ).json() - for text in texts - ] - return cast( - Embeddings, - [ - embedding["embedding"] - for embedding in embeddings - if "embedding" in embedding - ], - ) + # Call Ollama client + response = self._client.embed(model=self._model_name, input=input) + return cast(Embeddings, response["embeddings"]) diff --git a/clients/js/src/AdminClient.ts b/clients/js/src/AdminClient.ts index d66e635d97f..7f42fd74f0e 100644 --- a/clients/js/src/AdminClient.ts +++ b/clients/js/src/AdminClient.ts @@ -16,8 +16,9 @@ interface Tenant { name: string; } -// interface for tenant interface Database { + id: string; + tenant: string; name: string; } @@ -203,7 +204,7 @@ export class AdminClient { }: { name: string; tenantName: string; - }): Promise { + }): Promise<{ name: string }> { await this.api.createDatabase(tenantName, { name }, this.api.options); return { name }; @@ -234,13 +235,13 @@ export class AdminClient { name: string; tenantName: string; }): Promise { - const getDatabase = (await this.api.getDatabase( + const result = (await this.api.getDatabase( name, tenantName, this.api.options, )) as Database; - return { name: getDatabase.name } as Database; + return result; } /** @@ -282,13 +283,11 @@ export class AdminClient { offset?: number; tenantName: string; }): Promise { - const listDatabases = (await this.api.listDatabases( + return (await this.api.listDatabases( tenantName, limit, offset, this.api.options, )) as Database[]; - - return listDatabases.map((db) => ({ name: db.name })); } } diff --git a/clients/js/src/ChromaClient.ts b/clients/js/src/ChromaClient.ts index ac6c9bc170d..d6037f8fc6a 100644 --- a/clients/js/src/ChromaClient.ts +++ b/clients/js/src/ChromaClient.ts @@ -6,6 +6,7 @@ import { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction" import { Configuration, ApiApi as DefaultApi } from "./generated"; import type { ChromaClientParams, + CollectionMetadata, CollectionParams, ConfigOptions, CreateCollectionParams, @@ -288,7 +289,7 @@ export class ChromaClient { } /** - * Lists all collections. + * Get all collection names. * * @returns {Promise} A promise that resolves to a list of collection names. * @param {PositiveInteger} [params.limit] - Optional limit on the number of items to get. @@ -314,7 +315,42 @@ export class ChromaClient { offset, this.api.options, )) as Collection[]; - return collections.map((collection: Collection) => collection.name); + return collections.map((collection) => collection.name); + } + + /** + * List collection names, IDs, and metadata. + * + * @param {PositiveInteger} [params.limit] - Optional limit on the number of items to get. + * @param {PositiveInteger} [params.offset] - Optional offset on the items to get. + * @throws {Error} If there is an issue listing the collections. + * @returns {Promise<{ name: string, id: string, metadata?: CollectionMetadata }[]>} A promise that resolves to a list of collection names, IDs, and metadata. + * + * @example + * ```typescript + * const collections = await client.listCollectionsAndMetadata({ + * limit: 10, + * offset: 0, + * }); + */ + async listCollectionsAndMetadata({ + limit, + offset, + }: ListCollectionsParams = {}): Promise< + { + name: string; + id: string; + metadata?: CollectionMetadata; + }[] + > { + await this.init(); + return (await this.api.listCollections( + this.tenant, + this.database, + limit, + offset, + this.api.options, + )) as CollectionParams[]; } /** diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index fe9e64e2f18..06535907e29 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -43,3 +43,5 @@ export type { DeleteParams, CollectionParams, } from "./types"; + +export * from "./Errors"; diff --git a/clients/js/test/collection.client.test.ts b/clients/js/test/collection.client.test.ts index 22cb63edcec..2e6350669ba 100644 --- a/clients/js/test/collection.client.test.ts +++ b/clients/js/test/collection.client.test.ts @@ -1,9 +1,4 @@ -import { - expect, - test, - beforeEach, - describe, -} from "@jest/globals"; +import { expect, test, beforeEach, describe } from "@jest/globals"; import { DefaultEmbeddingFunction } from "../src"; import { ChromaClient } from "../src"; @@ -27,6 +22,16 @@ describe("collection operations", () => { expect(collections).toHaveLength(1); }); + test("it should list collections with metadata", async () => { + await client.createCollection({ name: "test", metadata: { test: "test" } }); + const collections = await client.listCollectionsAndMetadata(); + expect(collections).toHaveLength(1); + const [collection] = collections; + expect(collection).toHaveProperty("metadata"); + expect(collection.metadata).toHaveProperty("test"); + expect(collection.metadata).toEqual({ test: "test" }); + }); + test("it should create a collection", async () => { const collection = await client.createCollection({ name: "test" }); expect(collection).toBeDefined(); @@ -38,7 +43,7 @@ describe("collection operations", () => { const [returnedCollection] = collections; - expect(returnedCollection).toEqual("test") + expect(returnedCollection).toEqual("test"); expect([{ name: "test2", metadata: null }]).not.toEqual( expect.arrayContaining(collections), diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 0a7d1f300c8..f1cad2556e0 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - 'numpy >= 1.22.5, < 2.0.0', + 'numpy >= 1.22.5', 'opentelemetry-api>=1.2.0', 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-sdk>=1.2.0', diff --git a/docs/docs.trychroma.com/markdoc/content/docs/overview/getting-started.md b/docs/docs.trychroma.com/markdoc/content/docs/overview/getting-started.md index 1885977df38..2b7d61f460c 100644 --- a/docs/docs.trychroma.com/markdoc/content/docs/overview/getting-started.md +++ b/docs/docs.trychroma.com/markdoc/content/docs/overview/getting-started.md @@ -9,7 +9,7 @@ # Getting Started -Chroma is an AI-native open-source vector database. It comes with everything you need to get started built in, and runs on your machine. A [hosted version](https://airtable.com/shrOAiDUtS2ILy5vZ) is coming soon! +Chroma is an AI-native open-source vector database. It comes with everything you need to get started built in, and runs on your machine. A [hosted version](https://trychroma.com/signup) is now available for early access! ### 1. Install @@ -29,7 +29,7 @@ pip install chromadb {% Tab label="yarn" %} ```terminal -yarn install chromadb chromadb-default-embed +yarn add chromadb chromadb-default-embed ``` {% /Tab %} @@ -41,7 +41,7 @@ npm install --save chromadb chromadb-default-embed {% Tab label="pnpm" %} ```terminal -pnpm install chromadb chromadb-default-embed +pnpm add chromadb chromadb-default-embed ``` {% /Tab %} diff --git a/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/ollama.md b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/ollama.md index 1a00397f28f..160a309911a 100644 --- a/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/ollama.md +++ b/docs/docs.trychroma.com/markdoc/content/integrations/embedding-models/ollama.md @@ -15,10 +15,12 @@ a [model](https://github.com/ollama/ollama?tab=readme-ov-file#model-library) of {% Tab label="python" %} ```python -import chromadb.utils.embedding_functions as embedding_functions +from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, +) -ollama_ef = embedding_functions.OllamaEmbeddingFunction( - url="http://localhost:11434/api/embeddings", +ollama_ef = OllamaEmbeddingFunction( + url="http://localhost:11434", model_name="llama2", ) diff --git a/go/cmd/coordinator/cmd.go b/go/cmd/coordinator/cmd.go index ce490f92a9a..964d8ba283d 100644 --- a/go/cmd/coordinator/cmd.go +++ b/go/cmd/coordinator/cmd.go @@ -37,6 +37,7 @@ func init() { Cmd.Flags().StringVar(&conf.DBConfig.Username, "username", "chroma", "MetaTable username") Cmd.Flags().StringVar(&conf.DBConfig.Password, "password", "chroma", "MetaTable password") Cmd.Flags().StringVar(&conf.DBConfig.Address, "db-address", "postgres", "MetaTable db address") + Cmd.Flags().StringVar(&conf.DBConfig.ReadAddress, "read-db-address", "postgres", "MetaTable db read only address") Cmd.Flags().IntVar(&conf.DBConfig.Port, "db-port", 5432, "MetaTable db port") Cmd.Flags().StringVar(&conf.DBConfig.DBName, "db-name", "sysdb", "MetaTable db name") Cmd.Flags().IntVar(&conf.DBConfig.MaxIdleConns, "max-idle-conns", 10, "MetaTable max idle connections") diff --git a/go/pkg/sysdb/coordinator/coordinator.go b/go/pkg/sysdb/coordinator/coordinator.go index df638da63ab..97099af56ad 100644 --- a/go/pkg/sysdb/coordinator/coordinator.go +++ b/go/pkg/sysdb/coordinator/coordinator.go @@ -11,7 +11,6 @@ import ( "github.com/chroma-core/chroma/go/pkg/types" "github.com/pingcap/log" "go.uber.org/zap" - "gorm.io/gorm" ) // DeleteMode represents whether to perform a soft or hard delete @@ -33,7 +32,7 @@ type Coordinator struct { deleteMode DeleteMode } -func NewCoordinator(ctx context.Context, db *gorm.DB, deleteMode DeleteMode) (*Coordinator, error) { +func NewCoordinator(ctx context.Context, deleteMode DeleteMode) (*Coordinator, error) { s := &Coordinator{ ctx: ctx, deleteMode: deleteMode, @@ -115,6 +114,10 @@ func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.Uni return s.catalog.GetCollections(ctx, collectionID, collectionName, tenantID, databaseName, limit, offset) } +func (s *Coordinator) GetCollectionSize(ctx context.Context, collectionID types.UniqueID) (uint64, error) { + return s.catalog.GetCollectionSize(ctx, collectionID) +} + func (s *Coordinator) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) { return s.catalog.GetCollectionWithSegments(ctx, collectionID) } @@ -223,3 +226,7 @@ func (s *Coordinator) GetTenantsLastCompactionTime(ctx context.Context, tenantID func (s *Coordinator) FlushCollectionCompaction(ctx context.Context, flushCollectionCompaction *model.FlushCollectionCompaction) (*model.FlushCollectionInfo, error) { return s.catalog.FlushCollectionCompaction(ctx, flushCollectionCompaction) } + +func (s *Coordinator) ListCollectionsToGc(ctx context.Context) ([]*model.CollectionToGc, error) { + return s.catalog.ListCollectionsToGc(ctx) +} diff --git a/go/pkg/sysdb/coordinator/coordinator_test.go b/go/pkg/sysdb/coordinator/coordinator_test.go index 16d4bdb42a7..1bb12d6c821 100644 --- a/go/pkg/sysdb/coordinator/coordinator_test.go +++ b/go/pkg/sysdb/coordinator/coordinator_test.go @@ -25,6 +25,7 @@ import ( type APIsTestSuite struct { suite.Suite db *gorm.DB + read_db *gorm.DB collectionId1 types.UniqueID collectionId2 types.UniqueID records [][]byte @@ -37,7 +38,7 @@ type APIsTestSuite struct { func (suite *APIsTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting() } func (suite *APIsTestSuite) SetupTest() { @@ -53,7 +54,7 @@ func (suite *APIsTestSuite) SetupTest() { collection.Name = "collection_" + suite.T().Name() + strconv.Itoa(index) } ctx := context.Background() - c, err := NewCoordinator(ctx, suite.db, SoftDelete) + c, err := NewCoordinator(ctx, SoftDelete) if err != nil { suite.T().Fatalf("error creating coordinator: %v", err) } @@ -82,9 +83,9 @@ func (suite *APIsTestSuite) TearDownTest() { // TODO: This is not complete yet. We need to add more tests for the other APIs. // We will deprecate the example based tests once we have enough tests here. func testCollection(t *rapid.T) { - db := dbcore.ConfigDatabaseForTesting() + dbcore.ConfigDatabaseForTesting() ctx := context.Background() - c, err := NewCoordinator(ctx, db, HardDelete) + c, err := NewCoordinator(ctx, HardDelete) if err != nil { t.Fatalf("error creating coordinator: %v", err) } @@ -135,9 +136,9 @@ func testCollection(t *rapid.T) { } func testSegment(t *rapid.T) { - db := dbcore.ConfigDatabaseForTesting() + dbcore.ConfigDatabaseForTesting() ctx := context.Background() - c, err := NewCoordinator(ctx, db, HardDelete) + c, err := NewCoordinator(ctx, HardDelete) if err != nil { t.Fatalf("error creating coordinator: %v", err) } @@ -493,6 +494,16 @@ func (suite *APIsTestSuite) TestCreateGetDeleteCollections() { suite.Empty(segments) } +func (suite *APIsTestSuite) TestCollectionSize() { + ctx := context.Background() + + for _, collection := range suite.sampleCollections { + result, err := suite.coordinator.GetCollectionSize(ctx, collection.ID) + suite.NoError(err) + suite.Equal(uint64(0), result) + } +} + func (suite *APIsTestSuite) TestUpdateCollections() { ctx := context.Background() coll := &model.Collection{ diff --git a/go/pkg/sysdb/coordinator/model/collection.go b/go/pkg/sysdb/coordinator/model/collection.go index 20d37b1d0b4..99feef52aa5 100644 --- a/go/pkg/sysdb/coordinator/model/collection.go +++ b/go/pkg/sysdb/coordinator/model/collection.go @@ -5,17 +5,24 @@ import ( ) type Collection struct { - ID types.UniqueID - Name string - ConfigurationJsonStr string - Dimension *int32 - Metadata *CollectionMetadata[CollectionMetadataValueType] - TenantID string - DatabaseName string - Ts types.Timestamp - LogPosition int64 - Version int32 - UpdatedAt types.Timestamp + ID types.UniqueID + Name string + ConfigurationJsonStr string + Dimension *int32 + Metadata *CollectionMetadata[CollectionMetadataValueType] + TenantID string + DatabaseName string + Ts types.Timestamp + LogPosition int64 + Version int32 + UpdatedAt types.Timestamp + TotalRecordsPostCompaction uint64 +} + +type CollectionToGc struct { + ID types.UniqueID + Name string + VersionFilePath string } type CreateCollection struct { @@ -49,11 +56,12 @@ type UpdateCollection struct { } type FlushCollectionCompaction struct { - ID types.UniqueID - TenantID string - LogPosition int64 - CurrentCollectionVersion int32 - FlushSegmentCompactions []*FlushSegmentCompaction + ID types.UniqueID + TenantID string + LogPosition int64 + CurrentCollectionVersion int32 + FlushSegmentCompactions []*FlushSegmentCompaction + TotalRecordsPostCompaction uint64 } type FlushCollectionInfo struct { diff --git a/go/pkg/sysdb/coordinator/model_db_convert.go b/go/pkg/sysdb/coordinator/model_db_convert.go index 778c8056245..a4a7728c0ee 100644 --- a/go/pkg/sysdb/coordinator/model_db_convert.go +++ b/go/pkg/sysdb/coordinator/model_db_convert.go @@ -15,15 +15,16 @@ func convertCollectionToModel(collectionAndMetadataList []*dbmodel.CollectionAnd collections := make([]*model.Collection, 0, len(collectionAndMetadataList)) for _, collectionAndMetadata := range collectionAndMetadataList { collection := &model.Collection{ - ID: types.MustParse(collectionAndMetadata.Collection.ID), - Name: *collectionAndMetadata.Collection.Name, - ConfigurationJsonStr: *collectionAndMetadata.Collection.ConfigurationJsonStr, - Dimension: collectionAndMetadata.Collection.Dimension, - TenantID: collectionAndMetadata.TenantID, - DatabaseName: collectionAndMetadata.DatabaseName, - Ts: collectionAndMetadata.Collection.Ts, - LogPosition: collectionAndMetadata.Collection.LogPosition, - Version: collectionAndMetadata.Collection.Version, + ID: types.MustParse(collectionAndMetadata.Collection.ID), + Name: *collectionAndMetadata.Collection.Name, + ConfigurationJsonStr: *collectionAndMetadata.Collection.ConfigurationJsonStr, + Dimension: collectionAndMetadata.Collection.Dimension, + TenantID: collectionAndMetadata.TenantID, + DatabaseName: collectionAndMetadata.DatabaseName, + Ts: collectionAndMetadata.Collection.Ts, + LogPosition: collectionAndMetadata.Collection.LogPosition, + Version: collectionAndMetadata.Collection.Version, + TotalRecordsPostCompaction: collectionAndMetadata.Collection.TotalRecordsPostCompaction, } collection.Metadata = convertCollectionMetadataToModel(collectionAndMetadata.CollectionMetadata) collections = append(collections, collection) @@ -32,6 +33,23 @@ func convertCollectionToModel(collectionAndMetadataList []*dbmodel.CollectionAnd return collections } +func convertCollectionToGcToModel(collectionToGc []*dbmodel.CollectionToGc) []*model.CollectionToGc { + if collectionToGc == nil { + return nil + } + collections := make([]*model.CollectionToGc, 0, len(collectionToGc)) + // TODO(Sanket): Set version file path. + for _, collectionInfo := range collectionToGc { + collection := model.CollectionToGc{ + ID: types.MustParse(collectionInfo.ID), + Name: collectionInfo.Name, + VersionFilePath: "", + } + collections = append(collections, &collection) + } + return collections +} + func convertCollectionMetadataToModel(collectionMetadataList []*dbmodel.CollectionMetadata) *model.CollectionMetadata[model.CollectionMetadataValueType] { metadata := model.NewCollectionMetadata[model.CollectionMetadataValueType]() if collectionMetadataList == nil { diff --git a/go/pkg/sysdb/coordinator/model_db_convert_test.go b/go/pkg/sysdb/coordinator/model_db_convert_test.go index 5f252ee776c..92cae9114a4 100644 --- a/go/pkg/sysdb/coordinator/model_db_convert_test.go +++ b/go/pkg/sysdb/coordinator/model_db_convert_test.go @@ -135,12 +135,14 @@ func TestConvertCollectionToModel(t *testing.T) { collectionName := "collection_name" colllectionConfigurationJsonStr := "{\"a\": \"param\", \"b\": \"param2\", \"3\": true}" collectionDimension := int32(3) + collectionTotalRecordsPostCompaction := uint64(100) collectionAndMetadata := &dbmodel.CollectionAndMetadata{ Collection: &dbmodel.Collection{ - ID: collectionID.String(), - Name: &collectionName, - ConfigurationJsonStr: &colllectionConfigurationJsonStr, - Dimension: &collectionDimension, + ID: collectionID.String(), + Name: &collectionName, + ConfigurationJsonStr: &colllectionConfigurationJsonStr, + Dimension: &collectionDimension, + TotalRecordsPostCompaction: collectionTotalRecordsPostCompaction, }, CollectionMetadata: []*dbmodel.CollectionMetadata{}, } @@ -151,5 +153,6 @@ func TestConvertCollectionToModel(t *testing.T) { assert.Equal(t, collectionName, modelCollections[0].Name) assert.Equal(t, colllectionConfigurationJsonStr, modelCollections[0].ConfigurationJsonStr) assert.Equal(t, collectionDimension, *modelCollections[0].Dimension) + assert.Equal(t, collectionTotalRecordsPostCompaction, modelCollections[0].TotalRecordsPostCompaction) assert.Nil(t, modelCollections[0].Metadata) } diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index afd8796f86c..05811ee1d38 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -367,6 +367,36 @@ func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.Unique return collections, nil } +func (tc *Catalog) GetCollectionSize(ctx context.Context, collectionID types.UniqueID) (uint64, error) { + tracer := otel.Tracer + if tracer != nil { + _, span := tracer.Start(ctx, "Catalog.GetCollectionSize") + defer span.End() + } + + total_records_post_compaction, err := tc.metaDomain.CollectionDb(ctx).GetCollectionSize(collectionID.String()) + if err != nil { + return 0, err + } + return total_records_post_compaction, nil +} + +func (tc *Catalog) ListCollectionsToGc(ctx context.Context) ([]*model.CollectionToGc, error) { + tracer := otel.Tracer + if tracer != nil { + _, span := tracer.Start(ctx, "Catalog.ListCollectionsToGc") + defer span.End() + } + + collectionsToGc, err := tc.metaDomain.CollectionDb(ctx).ListCollectionsToGc() + + if err != nil { + return nil, err + } + collections := convertCollectionToGcToModel(collectionsToGc) + return collections, nil +} + func (tc *Catalog) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) { tracer := otel.Tracer if tracer != nil { @@ -864,7 +894,7 @@ func (tc *Catalog) FlushCollectionCompaction(ctx context.Context, flushCollectio } // update collection log position and version - collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionAndVersion(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion) + collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionVersionAndTotalRecords(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion, flushCollectionCompaction.TotalRecordsPostCompaction) if err != nil { return err } diff --git a/go/pkg/sysdb/coordinator/table_catalog_test.go b/go/pkg/sysdb/coordinator/table_catalog_test.go index 5be35fa0847..4680ccca18f 100644 --- a/go/pkg/sysdb/coordinator/table_catalog_test.go +++ b/go/pkg/sysdb/coordinator/table_catalog_test.go @@ -137,3 +137,17 @@ func TestCatalog_GetCollections(t *testing.T) { // assert that the mock methods were called as expected mockMetaDomain.AssertExpectations(t) } + +func TestCatalog_GetCollectionSize(t *testing.T) { + mockMetaDomain := &mocks.IMetaDomain{} + catalog := NewTableCatalog(nil, mockMetaDomain) + collectionID := types.MustParse("00000000-0000-0000-0000-000000000001") + mockMetaDomain.On("CollectionDb", context.Background()).Return(&mocks.ICollectionDb{}) + var total_records_post_compaction uint64 = uint64(100) + mockMetaDomain.CollectionDb(context.Background()).(*mocks.ICollectionDb).On("GetCollectionSize", *types.FromUniqueID(collectionID)).Return(total_records_post_compaction, nil) + collection_size, err := catalog.GetCollectionSize(context.Background(), collectionID) + + assert.NoError(t, err) + assert.Equal(t, total_records_post_compaction, collection_size) + mockMetaDomain.AssertExpectations(t) +} diff --git a/go/pkg/sysdb/grpc/cleaup_test.go b/go/pkg/sysdb/grpc/cleaup_test.go index d356a042be6..314088d8f8a 100644 --- a/go/pkg/sysdb/grpc/cleaup_test.go +++ b/go/pkg/sysdb/grpc/cleaup_test.go @@ -21,6 +21,7 @@ import ( type CleanupTestSuite struct { suite.Suite db *gorm.DB + read_db *gorm.DB s *Server tenantName string databaseName string @@ -29,14 +30,14 @@ type CleanupTestSuite struct { func (suite *CleanupTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting() s, err := NewWithGrpcProvider(Config{ SystemCatalogProvider: "database", SoftDeleteEnabled: true, SoftDeleteCleanupInterval: 1 * time.Second, SoftDeleteMaxAge: 0, SoftDeleteCleanupBatchSize: 10, - Testing: true}, grpcutils.Default, suite.db) + Testing: true}, grpcutils.Default) if err != nil { suite.T().Fatalf("error creating server: %v", err) } diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index 7d7abd69647..f0f9297106e 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -148,6 +148,26 @@ func (s *Server) GetCollections(ctx context.Context, req *coordinatorpb.GetColle return res, nil } +func (s *Server) GetCollectionSize(ctx context.Context, req *coordinatorpb.GetCollectionSizeRequest) (*coordinatorpb.GetCollectionSizeResponse, error) { + collectionID := req.Id + + res := &coordinatorpb.GetCollectionSizeResponse{} + + parsedCollectionID, err := types.ToUniqueID(&collectionID) + if err != nil { + log.Error("GetCollectionSize failed. collection id format error", zap.Error(err), zap.Stringp("collection_id", &collectionID)) + return res, grpcutils.BuildInternalGrpcError(err.Error()) + } + + total_records_post_compaction, err := s.coordinator.GetCollectionSize(ctx, parsedCollectionID) + if err != nil { + log.Error("GetCollectionSize failed. ", zap.Error(err), zap.Stringp("collection_id", &collectionID)) + return res, grpcutils.BuildInternalGrpcError(err.Error()) + } + res.TotalRecordsPostCompaction = total_records_post_compaction + return res, nil +} + func (s *Server) CheckCollections(ctx context.Context, req *coordinatorpb.CheckCollectionsRequest) (*coordinatorpb.CheckCollectionsResponse, error) { res := &coordinatorpb.CheckCollectionsResponse{} res.Deleted = make([]bool, len(req.CollectionIds)) @@ -325,11 +345,12 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator }) } FlushCollectionCompaction := &model.FlushCollectionCompaction{ - ID: collectionID, - TenantID: req.TenantId, - LogPosition: req.LogPosition, - CurrentCollectionVersion: req.CollectionVersion, - FlushSegmentCompactions: segmentCompactionInfo, + ID: collectionID, + TenantID: req.TenantId, + LogPosition: req.LogPosition, + CurrentCollectionVersion: req.CollectionVersion, + FlushSegmentCompactions: segmentCompactionInfo, + TotalRecordsPostCompaction: req.TotalRecordsPostCompaction, } flushCollectionInfo, err := s.coordinator.FlushCollectionCompaction(ctx, FlushCollectionCompaction) if err != nil { @@ -343,3 +364,21 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator } return res, nil } + +func (s *Server) ListCollectionsToGc(ctx context.Context, req *coordinatorpb.ListCollectionsToGcRequest) (*coordinatorpb.ListCollectionsToGcResponse, error) { + // Dumb implementation that just returns ALL the collections for now. + collectionsToGc, err := s.coordinator.ListCollectionsToGc(ctx) + if err != nil { + log.Error("ListCollectionsToGc failed", zap.Error(err)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + res := &coordinatorpb.ListCollectionsToGcResponse{} + for _, collectionToGc := range collectionsToGc { + res.Collections = append(res.Collections, &coordinatorpb.CollectionToGcInfo{ + Id: collectionToGc.ID.String(), + Name: collectionToGc.Name, + VersionFilePath: collectionToGc.VersionFilePath, + }) + } + return res, nil +} diff --git a/go/pkg/sysdb/grpc/collection_service_test.go b/go/pkg/sysdb/grpc/collection_service_test.go index a003e59621b..23222c5a307 100644 --- a/go/pkg/sysdb/grpc/collection_service_test.go +++ b/go/pkg/sysdb/grpc/collection_service_test.go @@ -29,6 +29,7 @@ type CollectionServiceTestSuite struct { suite.Suite catalog *coordinator.Catalog db *gorm.DB + read_db *gorm.DB s *Server tenantName string databaseName string @@ -37,10 +38,10 @@ type CollectionServiceTestSuite struct { func (suite *CollectionServiceTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting() s, err := NewWithGrpcProvider(Config{ SystemCatalogProvider: "database", - Testing: true}, grpcutils.Default, suite.db) + Testing: true}, grpcutils.Default) if err != nil { suite.T().Fatalf("error creating server: %v", err) } @@ -70,10 +71,10 @@ func (suite *CollectionServiceTestSuite) TearDownSuite() { // Collection created should have the right ID // Collection created should have the right timestamp func testCollection(t *rapid.T) { - db := dbcore.ConfigDatabaseForTesting() + dbcore.ConfigDatabaseForTesting() s, err := NewWithGrpcProvider(Config{ SystemCatalogProvider: "memory", - Testing: true}, grpcutils.Default, db) + Testing: true}, grpcutils.Default) if err != nil { t.Fatalf("error creating server: %v", err) } @@ -476,6 +477,22 @@ func (suite *CollectionServiceTestSuite) TestServer_FlushCollectionCompaction() suite.NoError(err) } +func (suite *CollectionServiceTestSuite) TestGetCollectionSize() { + collectionName := "collection_service_test_get_collection_size" + collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId) + suite.NoError(err) + + req := coordinatorpb.GetCollectionSizeRequest{ + Id: collectionID, + } + res, err := suite.s.GetCollectionSize(context.Background(), &req) + suite.NoError(err) + suite.Equal(uint64(100), res.TotalRecordsPostCompaction) + + err = dao.CleanUpTestCollection(suite.db, collectionID) + suite.NoError(err) +} + func TestCollectionServiceTestSuite(t *testing.T) { testSuite := new(CollectionServiceTestSuite) suite.Run(t, testSuite) diff --git a/go/pkg/sysdb/grpc/proto_model_convert.go b/go/pkg/sysdb/grpc/proto_model_convert.go index fe1c4e6c6ef..f3785bb5539 100644 --- a/go/pkg/sysdb/grpc/proto_model_convert.go +++ b/go/pkg/sysdb/grpc/proto_model_convert.go @@ -40,14 +40,15 @@ func convertCollectionToProto(collection *model.Collection) *coordinatorpb.Colle } collectionpb := &coordinatorpb.Collection{ - Id: collection.ID.String(), - Name: collection.Name, - ConfigurationJsonStr: collection.ConfigurationJsonStr, - Dimension: collection.Dimension, - Tenant: collection.TenantID, - Database: collection.DatabaseName, - LogPosition: collection.LogPosition, - Version: collection.Version, + Id: collection.ID.String(), + Name: collection.Name, + ConfigurationJsonStr: collection.ConfigurationJsonStr, + Dimension: collection.Dimension, + Tenant: collection.TenantID, + Database: collection.DatabaseName, + LogPosition: collection.LogPosition, + Version: collection.Version, + TotalRecordsPostCompaction: collection.TotalRecordsPostCompaction, } if collection.Metadata == nil { return collectionpb diff --git a/go/pkg/sysdb/grpc/proto_model_convert_test.go b/go/pkg/sysdb/grpc/proto_model_convert_test.go index 32282332ae9..18e97e56ba2 100644 --- a/go/pkg/sysdb/grpc/proto_model_convert_test.go +++ b/go/pkg/sysdb/grpc/proto_model_convert_test.go @@ -51,6 +51,7 @@ func TestConvertCollectionToProto(t *testing.T) { // Test case 2: collection is not nil dimention := int32(10) + num_records := uint64(0) collection := &model.Collection{ ID: types.NewUniqueID(), Name: "test_collection", @@ -62,6 +63,7 @@ func TestConvertCollectionToProto(t *testing.T) { "key3": &model.CollectionMetadataValueFloat64Type{Value: 3.14}, }, }, + TotalRecordsPostCompaction: num_records, } collectionpb = convertCollectionToProto(collection) assert.NotNil(t, collectionpb) @@ -72,6 +74,7 @@ func TestConvertCollectionToProto(t *testing.T) { assert.Equal(t, "value1", collectionpb.Metadata.Metadata["key1"].GetStringValue()) assert.Equal(t, int64(123), collectionpb.Metadata.Metadata["key2"].GetIntValue()) assert.Equal(t, 3.14, collectionpb.Metadata.Metadata["key3"].GetFloatValue()) + assert.Equal(t, uint64(0), collectionpb.TotalRecordsPostCompaction) } func TestConvertCollectionMetadataToProto(t *testing.T) { diff --git a/go/pkg/sysdb/grpc/server.go b/go/pkg/sysdb/grpc/server.go index 4cf8346426b..1c2410c51b4 100644 --- a/go/pkg/sysdb/grpc/server.go +++ b/go/pkg/sysdb/grpc/server.go @@ -16,7 +16,6 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/health" - "gorm.io/gorm" ) type Config struct { @@ -71,20 +70,20 @@ type Server struct { func New(config Config) (*Server, error) { if config.SystemCatalogProvider == "memory" { - return NewWithGrpcProvider(config, grpcutils.Default, nil) + return NewWithGrpcProvider(config, grpcutils.Default) } else if config.SystemCatalogProvider == "database" { dBConfig := config.DBConfig - db, err := dbcore.ConnectPostgres(dBConfig) + err := dbcore.ConnectDB(dBConfig) if err != nil { return nil, err } - return NewWithGrpcProvider(config, grpcutils.Default, db) + return NewWithGrpcProvider(config, grpcutils.Default) } else { return nil, errors.New("invalid system catalog provider, only memory and database are supported") } } -func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gorm.DB) (*Server, error) { +func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider) (*Server, error) { ctx := context.Background() s := &Server{ healthServer: health.NewServer(), @@ -97,7 +96,7 @@ func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gor deleteMode = coordinator.HardDelete } - coordinator, err := coordinator.NewCoordinator(ctx, db, deleteMode) + coordinator, err := coordinator.NewCoordinator(ctx, deleteMode) if err != nil { return nil, err } diff --git a/go/pkg/sysdb/grpc/tenant_database_service_test.go b/go/pkg/sysdb/grpc/tenant_database_service_test.go index 4d5d8c549ce..d966eb0b95f 100644 --- a/go/pkg/sysdb/grpc/tenant_database_service_test.go +++ b/go/pkg/sysdb/grpc/tenant_database_service_test.go @@ -30,10 +30,10 @@ type TenantDatabaseServiceTestSuite struct { func (suite *TenantDatabaseServiceTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, _ = dbcore.ConfigDatabaseForTesting() s, err := NewWithGrpcProvider(Config{ SystemCatalogProvider: "database", - Testing: true}, grpcutils.Default, suite.db) + Testing: true}, grpcutils.Default) if err != nil { suite.T().Fatalf("error creating server: %v", err) } diff --git a/go/pkg/sysdb/metastore/db/dao/collection.go b/go/pkg/sysdb/metastore/db/dao/collection.go index 6030cc8c6e4..7c21a6c8ba0 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection.go +++ b/go/pkg/sysdb/metastore/db/dao/collection.go @@ -16,7 +16,8 @@ import ( ) type collectionDb struct { - db *gorm.DB + db *gorm.DB + read_db *gorm.DB } var _ dbmodel.ICollectionDb = &collectionDb{} @@ -50,10 +51,22 @@ func (s *collectionDb) GetCollections(id *string, name *string, tenantID string, return s.getCollections(id, name, tenantID, databaseName, limit, offset, false) } +func (s *collectionDb) ListCollectionsToGc() ([]*dbmodel.CollectionToGc, error) { + // TODO(Sanket): Read version file path. + var collections []*dbmodel.CollectionToGc + // Use the read replica for this so as to not overwhelm the writer. + // Skip collections that have not been compacted even once. + err := s.read_db.Table("collections").Select("id, name, version").Find(&collections).Where("version > 0").Error + if err != nil { + return nil, err + } + return collections, nil +} + func (s *collectionDb) getCollections(id *string, name *string, tenantID string, databaseName string, limit *int32, offset *int32, is_deleted bool) (collectionWithMetdata []*dbmodel.CollectionAndMetadata, err error) { var collections []*dbmodel.Collection query := s.db.Table("collections"). - Select("collections.id, collections.log_position, collections.version, collections.name, collections.configuration_json_str, collections.dimension, collections.database_id, collections.is_deleted, databases.name, databases.tenant_id"). + Select("collections.id, collections.log_position, collections.version, collections.name, collections.configuration_json_str, collections.dimension, collections.database_id, collections.is_deleted, collections.total_records_post_compaction, databases.name, databases.tenant_id"). Joins("INNER JOIN databases ON collections.database_id = databases.id"). Order("collections.created_at ASC") @@ -96,22 +109,24 @@ func (s *collectionDb) getCollections(id *string, name *string, tenantID string, collectionCreatedAt sql.NullTime databaseName string databaseTenantID string + totalRecordsPostCompaction uint64 ) - err := rows.Scan(&collectionID, &logPosition, &version, &collectionName, &collectionConfigurationJsonStr, &collectionDimension, &collectionDatabaseID, &collectionIsDeleted, &databaseName, &databaseTenantID) + err := rows.Scan(&collectionID, &logPosition, &version, &collectionName, &collectionConfigurationJsonStr, &collectionDimension, &collectionDatabaseID, &collectionIsDeleted, &totalRecordsPostCompaction, &databaseName, &databaseTenantID) if err != nil { log.Error("scan collection failed", zap.Error(err)) return nil, err } collection := &dbmodel.Collection{ - ID: collectionID, - Name: &collectionName, - ConfigurationJsonStr: &collectionConfigurationJsonStr, - DatabaseID: collectionDatabaseID, - LogPosition: logPosition, - Version: version, - IsDeleted: collectionIsDeleted, + ID: collectionID, + Name: &collectionName, + ConfigurationJsonStr: &collectionConfigurationJsonStr, + DatabaseID: collectionDatabaseID, + LogPosition: logPosition, + Version: version, + IsDeleted: collectionIsDeleted, + TotalRecordsPostCompaction: totalRecordsPostCompaction, } if collectionDimension.Valid { collection.Dimension = &collectionDimension.Int32 @@ -140,6 +155,29 @@ func (s *collectionDb) getCollections(id *string, name *string, tenantID string, return } +func (s *collectionDb) GetCollectionSize(id string) (uint64, error) { + query := s.read_db.Table("collections"). + Select("collections.total_records_post_compaction"). + Where("collections.id = ?", id) + + rows, err := query.Rows() + if err != nil { + return 0, err + } + + var totalRecordsPostCompaction uint64 + + for rows.Next() { + err := rows.Scan(&totalRecordsPostCompaction) + if err != nil { + log.Error("scan collection failed", zap.Error(err)) + return 0, err + } + } + rows.Close() + return totalRecordsPostCompaction, nil +} + func (s *collectionDb) GetSoftDeletedCollections(collectionID *string, tenantID string, databaseName string, limit int32) ([]*dbmodel.CollectionAndMetadata, error) { return s.getCollections(collectionID, nil, tenantID, databaseName, &limit, nil, true) } @@ -209,8 +247,8 @@ func (s *collectionDb) Update(in *dbmodel.Collection) error { return nil } -func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) { - log.Info("update log position and version", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion)) +func (s *collectionDb) UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) { + log.Info("update log position, version, and total records post compaction", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion), zap.Uint64("totalRecords", totalRecordsPostCompaction)) var collection dbmodel.Collection // We use select for update to ensure no lost update happens even for isolation level read committed or below // https://patrick.engineering/posts/postgres-internals/ @@ -230,7 +268,7 @@ func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosit } version := currentCollectionVersion + 1 - err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version}).Error + err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version, "total_records_post_compaction": totalRecordsPostCompaction}).Error if err != nil { return 0, err } diff --git a/go/pkg/sysdb/metastore/db/dao/collection_test.go b/go/pkg/sysdb/metastore/db/dao/collection_test.go index 8f8b37e8c9e..5823c4a38b7 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection_test.go +++ b/go/pkg/sysdb/metastore/db/dao/collection_test.go @@ -16,6 +16,7 @@ import ( type CollectionDbTestSuite struct { suite.Suite db *gorm.DB + read_db *gorm.DB collectionDb *collectionDb tenantName string databaseName string @@ -24,9 +25,10 @@ type CollectionDbTestSuite struct { func (suite *CollectionDbTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting() suite.collectionDb = &collectionDb{ - db: suite.db, + db: suite.db, + read_db: suite.read_db, } suite.tenantName = "test_collection_tenant" suite.databaseName = "test_collection_database" @@ -75,6 +77,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() { suite.Len(collections[0].CollectionMetadata, 1) suite.Equal(metadata.Key, collections[0].CollectionMetadata[0].Key) suite.Equal(metadata.StrValue, collections[0].CollectionMetadata[0].StrValue) + suite.Equal(uint64(100), collections[0].Collection.TotalRecordsPostCompaction) // Test when filtering by ID collections, err = suite.collectionDb.GetCollections(nil, nil, suite.tenantName, suite.databaseName, nil, nil) @@ -120,7 +123,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() { suite.NoError(err) } -func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionAndVersion() { +func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionVersionAndTotalRecords() { collectionName := "test_collection_get_collections" collectionID, _ := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId) // verify default values @@ -131,22 +134,23 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionAndVersion suite.Equal(int32(0), collections[0].Collection.Version) // update log position and version - version, err := suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(10), 0) + version, err := suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(10), 0, uint64(100)) suite.NoError(err) suite.Equal(int32(1), version) collections, _ = suite.collectionDb.GetCollections(&collectionID, nil, "", "", nil, nil) suite.Len(collections, 1) suite.Equal(int64(10), collections[0].Collection.LogPosition) suite.Equal(int32(1), collections[0].Collection.Version) + suite.Equal(uint64(100), collections[0].Collection.TotalRecordsPostCompaction) // invalid log position - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(5), 0) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(5), 0, uint64(100)) suite.Error(err, "collection log position Stale") // invalid version - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 0) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 0, uint64(100)) suite.Error(err, "collection version invalid") - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 3) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 3, uint64(100)) suite.Error(err, "collection version invalid") //clean up @@ -206,6 +210,19 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_SoftDelete() { suite.NoError(err) } +func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollectionSize() { + collectionName := "test_collection_get_collection_size" + collectionID, err := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId) + suite.NoError(err) + + total_records_post_compaction, err := suite.collectionDb.GetCollectionSize(collectionID) + suite.NoError(err) + suite.Equal(uint64(100), total_records_post_compaction) + + err = CleanUpTestCollection(suite.db, collectionID) + suite.NoError(err) +} + func TestCollectionDbTestSuiteSuite(t *testing.T) { testSuite := new(CollectionDbTestSuite) suite.Run(t, testSuite) diff --git a/go/pkg/sysdb/metastore/db/dao/common.go b/go/pkg/sysdb/metastore/db/dao/common.go index 2c5fadb77ab..2962f01209d 100644 --- a/go/pkg/sysdb/metastore/db/dao/common.go +++ b/go/pkg/sysdb/metastore/db/dao/common.go @@ -22,7 +22,7 @@ func (*MetaDomain) TenantDb(ctx context.Context) dbmodel.ITenantDb { } func (*MetaDomain) CollectionDb(ctx context.Context) dbmodel.ICollectionDb { - return &collectionDb{dbcore.GetDB(ctx)} + return &collectionDb{dbcore.GetDB(ctx), dbcore.GetReadDB(ctx)} } func (*MetaDomain) CollectionMetadataDb(ctx context.Context) dbmodel.ICollectionMetadataDb { diff --git a/go/pkg/sysdb/metastore/db/dao/segment_test.go b/go/pkg/sysdb/metastore/db/dao/segment_test.go index 3f85ee4d78a..e19d52e80fd 100644 --- a/go/pkg/sysdb/metastore/db/dao/segment_test.go +++ b/go/pkg/sysdb/metastore/db/dao/segment_test.go @@ -23,7 +23,7 @@ type SegmentDbTestSuite struct { func (suite *SegmentDbTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, _ = dbcore.ConfigDatabaseForTesting() suite.segmentDb = &segmentDb{ db: suite.db, } diff --git a/go/pkg/sysdb/metastore/db/dao/tenant_test.go b/go/pkg/sysdb/metastore/db/dao/tenant_test.go index ac97e4f685b..21e1b58b84b 100644 --- a/go/pkg/sysdb/metastore/db/dao/tenant_test.go +++ b/go/pkg/sysdb/metastore/db/dao/tenant_test.go @@ -21,7 +21,7 @@ type TenantDbTestSuite struct { func (suite *TenantDbTestSuite) SetupSuite() { log.Info("setup suite") - suite.db = dbcore.ConfigDatabaseForTesting() + suite.db, _ = dbcore.ConfigDatabaseForTesting() suite.Db = &tenantDb{ db: suite.db, } diff --git a/go/pkg/sysdb/metastore/db/dao/test_utils.go b/go/pkg/sysdb/metastore/db/dao/test_utils.go index 74db6cd8678..87c5500f13c 100644 --- a/go/pkg/sysdb/metastore/db/dao/test_utils.go +++ b/go/pkg/sysdb/metastore/db/dao/test_utils.go @@ -118,12 +118,13 @@ func CreateTestCollection(db *gorm.DB, collectionName string, dimension int32, d defaultConfigurationJsonStr := "{\"a\": \"param\", \"b\": \"param2\", \"3\": true}" err := collectionDb.Insert(&dbmodel.Collection{ - ID: collectionId, - Name: &collectionName, - ConfigurationJsonStr: &defaultConfigurationJsonStr, - Dimension: &dimension, - DatabaseID: databaseID, - CreatedAt: time.Now(), + ID: collectionId, + Name: &collectionName, + ConfigurationJsonStr: &defaultConfigurationJsonStr, + Dimension: &dimension, + DatabaseID: databaseID, + CreatedAt: time.Now(), + TotalRecordsPostCompaction: uint64(100), }) if err != nil { return "", err diff --git a/go/pkg/sysdb/metastore/db/dbcore/core.go b/go/pkg/sysdb/metastore/db/dbcore/core.go index 4ad893f873c..1d1c9b9e54b 100644 --- a/go/pkg/sysdb/metastore/db/dbcore/core.go +++ b/go/pkg/sysdb/metastore/db/dbcore/core.go @@ -24,13 +24,15 @@ import ( ) var ( - globalDB *gorm.DB + globalDB *gorm.DB + globalReadDB *gorm.DB ) type DBConfig struct { Username string Password string Address string + ReadAddress string Port int DBName string MaxIdleConns int @@ -38,10 +40,26 @@ type DBConfig struct { SslMode string } -func ConnectPostgres(cfg DBConfig) (*gorm.DB, error) { - log.Info("ConnectPostgres", zap.String("host", cfg.Address), zap.String("database", cfg.DBName), zap.Int("port", cfg.Port)) +func ConnectDB(cfg DBConfig) error { + db, err := ConnectPostgres(cfg.Address, cfg.Username, cfg.Password, cfg.Port, cfg.DBName, cfg.SslMode, cfg.MaxIdleConns, cfg.MaxOpenConns) + if err != nil { + return err + } + read_db, err := ConnectPostgres(cfg.ReadAddress, cfg.Username, cfg.Password, cfg.Port, cfg.DBName, cfg.SslMode, cfg.MaxIdleConns, cfg.MaxOpenConns) + if err != nil { + return err + } + + globalDB = db + globalReadDB = read_db + + return nil +} + +func ConnectPostgres(address string, username string, password string, port int, dbName string, sslMode string, maxIdleConns int, maxOpenConns int) (*gorm.DB, error) { + log.Info("ConnectPostgres", zap.String("host", address), zap.String("database", dbName), zap.Int("port", port)) dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s", - cfg.Address, cfg.Username, cfg.Password, cfg.DBName, cfg.Port, cfg.SslMode) + address, username, password, dbName, port, sslMode) ormLogger := logger.Default ormLogger.LogMode(logger.Info) @@ -51,8 +69,8 @@ func ConnectPostgres(cfg DBConfig) (*gorm.DB, error) { }) if err != nil { log.Error("fail to connect db", - zap.String("host", cfg.Address), - zap.String("database", cfg.DBName), + zap.String("host", address), + zap.String("database", dbName), zap.Error(err)) return nil, err } @@ -65,29 +83,22 @@ func ConnectPostgres(cfg DBConfig) (*gorm.DB, error) { idb, err := db.DB() if err != nil { log.Error("fail to create db instance", - zap.String("host", cfg.Address), - zap.String("database", cfg.DBName), + zap.String("host", address), + zap.String("database", dbName), zap.Error(err)) return nil, err } - idb.SetMaxIdleConns(cfg.MaxIdleConns) - idb.SetMaxOpenConns(cfg.MaxOpenConns) - - globalDB = db + idb.SetMaxIdleConns(maxIdleConns) + idb.SetMaxOpenConns(maxOpenConns) log.Info("Postgres connected success", - zap.String("host", cfg.Address), - zap.String("database", cfg.DBName), + zap.String("host", address), + zap.String("database", dbName), zap.Error(err)) return db, nil } -// SetGlobalDB Only for test -func SetGlobalDB(db *gorm.DB) { - globalDB = db -} - type ctxTransactionKey struct{} func CtxWithTransaction(ctx context.Context, tx *gorm.DB) context.Context { @@ -128,6 +139,22 @@ func GetDB(ctx context.Context) *gorm.DB { return globalDB.WithContext(ctx) } +func GetReadDB(ctx context.Context) *gorm.DB { + iface := ctx.Value(ctxTransactionKey{}) + + if iface != nil { + tx, ok := iface.(*gorm.DB) + if !ok { + log.Error("unexpected context value type", zap.Any("type", reflect.TypeOf(tx))) + return nil + } + + return tx + } + + return globalReadDB.WithContext(ctx) +} + func CreateDefaultTenantAndDatabase(db *gorm.DB) string { defaultTenant := &dbmodel.Tenant{ ID: common.DefaultTenant, @@ -225,15 +252,19 @@ func GetDBConfigForTesting() DBConfig { MaxIdleConns: 10, MaxOpenConns: 100, SslMode: "disable", + ReadAddress: "localhost", } } -func ConfigDatabaseForTesting() *gorm.DB { - db, err := ConnectPostgres(GetDBConfigForTesting()) +func ConfigDatabaseForTesting() (*gorm.DB, *gorm.DB) { + cfg := GetDBConfigForTesting() + db, err := ConnectPostgres(cfg.Address, cfg.Username, cfg.Password, cfg.Port, cfg.DBName, cfg.SslMode, cfg.MaxIdleConns, cfg.MaxOpenConns) if err != nil { panic("failed to connect database") } - SetGlobalDB(db) + globalDB = db + // For testing, we set the read_db to be the same as the db + globalReadDB = db CreateTestTables(db) - return db + return globalDB, globalReadDB } diff --git a/go/pkg/sysdb/metastore/db/dbmodel/collection.go b/go/pkg/sysdb/metastore/db/dbmodel/collection.go index 761ed66acbd..b680b19687d 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/collection.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/collection.go @@ -7,17 +7,24 @@ import ( ) type Collection struct { - ID string `gorm:"id;primaryKey"` - Name *string `gorm:"name;not null;index:idx_name,unique;"` - ConfigurationJsonStr *string `gorm:"configuration_json_str"` - Dimension *int32 `gorm:"dimension"` - DatabaseID string `gorm:"database_id;not null;index:idx_name,unique;"` - Ts types.Timestamp `gorm:"ts;type:bigint;default:0"` - IsDeleted bool `gorm:"is_deleted;type:bool;default:false"` - CreatedAt time.Time `gorm:"created_at;type:timestamp;not null;default:current_timestamp"` - UpdatedAt time.Time `gorm:"updated_at;type:timestamp;not null;default:current_timestamp"` - LogPosition int64 `gorm:"log_position;default:0"` - Version int32 `gorm:"version;default:0"` + ID string `gorm:"id;primaryKey"` + Name *string `gorm:"name;not null;index:idx_name,unique;"` + ConfigurationJsonStr *string `gorm:"configuration_json_str"` + Dimension *int32 `gorm:"dimension"` + DatabaseID string `gorm:"database_id;not null;index:idx_name,unique;"` + Ts types.Timestamp `gorm:"ts;type:bigint;default:0"` + IsDeleted bool `gorm:"is_deleted;type:bool;default:false"` + CreatedAt time.Time `gorm:"created_at;type:timestamp;not null;default:current_timestamp"` + UpdatedAt time.Time `gorm:"updated_at;type:timestamp;not null;default:current_timestamp"` + LogPosition int64 `gorm:"log_position;default:0"` + Version int32 `gorm:"version;default:0"` + TotalRecordsPostCompaction uint64 `gorm:"total_records_post_compaction;default:0"` +} + +type CollectionToGc struct { + ID string `gorm:"id;primaryKey"` + Name string `gorm:"name;not null;index:idx_name,unique;"` + Version int32 `gorm:"version;default:0"` } func (v Collection) TableName() string { @@ -39,6 +46,8 @@ type ICollectionDb interface { Insert(in *Collection) error Update(in *Collection) error DeleteAll() error - UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) + UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) GetCollectionEntry(collectionID *string, databaseName *string) (*Collection, error) + GetCollectionSize(collectionID string) (uint64, error) + ListCollectionsToGc() ([]*CollectionToGc, error) } diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go index 21a3596f8ba..bf9c3e76895 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go @@ -88,6 +88,34 @@ func (_m *ICollectionDb) GetCollectionEntry(collectionID *string, databaseName * return r0, r1 } +// GetCollectionSize provides a mock function with given fields: collectionID +func (_m *ICollectionDb) GetCollectionSize(collectionID string) (uint64, error) { + ret := _m.Called(collectionID) + + if len(ret) == 0 { + panic("no return value specified for GetCollectionSize") + } + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(string) (uint64, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(string) uint64); ok { + r0 = rf(collectionID) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetCollections provides a mock function with given fields: collectionID, collectionName, tenantID, databaseName, limit, offset func (_m *ICollectionDb) GetCollections(collectionID *string, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*dbmodel.CollectionAndMetadata, error) { ret := _m.Called(collectionID, collectionName, tenantID, databaseName, limit, offset) @@ -166,6 +194,36 @@ func (_m *ICollectionDb) Insert(in *dbmodel.Collection) error { return r0 } +// ListCollectionsToGc provides a mock function with given fields: +func (_m *ICollectionDb) ListCollectionsToGc() ([]*dbmodel.CollectionToGc, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ListCollectionsToGc") + } + + var r0 []*dbmodel.CollectionToGc + var r1 error + if rf, ok := ret.Get(0).(func() ([]*dbmodel.CollectionToGc, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []*dbmodel.CollectionToGc); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*dbmodel.CollectionToGc) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: in func (_m *ICollectionDb) Update(in *dbmodel.Collection) error { ret := _m.Called(in) @@ -184,27 +242,27 @@ func (_m *ICollectionDb) Update(in *dbmodel.Collection) error { return r0 } -// UpdateLogPositionAndVersion provides a mock function with given fields: collectionID, logPosition, currentCollectionVersion -func (_m *ICollectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) { - ret := _m.Called(collectionID, logPosition, currentCollectionVersion) +// UpdateLogPositionVersionAndTotalRecords provides a mock function with given fields: collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction +func (_m *ICollectionDb) UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) { + ret := _m.Called(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) if len(ret) == 0 { - panic("no return value specified for UpdateLogPositionAndVersion") + panic("no return value specified for UpdateLogPositionVersionAndTotalRecords") } var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(string, int64, int32) (int32, error)); ok { - return rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(0).(func(string, int64, int32, uint64) (int32, error)); ok { + return rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } - if rf, ok := ret.Get(0).(func(string, int64, int32) int32); ok { - r0 = rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(0).(func(string, int64, int32, uint64) int32); ok { + r0 = rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(string, int64, int32) error); ok { - r1 = rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(1).(func(string, int64, int32, uint64) error); ok { + r1 = rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } else { r1 = ret.Error(1) } diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go index acb6685bfee..90d4d8dc782 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go @@ -12,6 +12,24 @@ type IDatabaseDb struct { mock.Mock } +// Delete provides a mock function with given fields: databaseID +func (_m *IDatabaseDb) Delete(databaseID string) error { + ret := _m.Called(databaseID) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(databaseID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteAll provides a mock function with given fields: func (_m *IDatabaseDb) DeleteAll() error { ret := _m.Called() @@ -108,6 +126,36 @@ func (_m *IDatabaseDb) Insert(in *dbmodel.Database) error { return r0 } +// ListDatabases provides a mock function with given fields: limit, offset, tenantID +func (_m *IDatabaseDb) ListDatabases(limit *int32, offset *int32, tenantID string) ([]*dbmodel.Database, error) { + ret := _m.Called(limit, offset, tenantID) + + if len(ret) == 0 { + panic("no return value specified for ListDatabases") + } + + var r0 []*dbmodel.Database + var r1 error + if rf, ok := ret.Get(0).(func(*int32, *int32, string) ([]*dbmodel.Database, error)); ok { + return rf(limit, offset, tenantID) + } + if rf, ok := ret.Get(0).(func(*int32, *int32, string) []*dbmodel.Database); ok { + r0 = rf(limit, offset, tenantID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*dbmodel.Database) + } + } + + if rf, ok := ret.Get(1).(func(*int32, *int32, string) error); ok { + r1 = rf(limit, offset, tenantID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewIDatabaseDb creates a new instance of IDatabaseDb. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewIDatabaseDb(t interface { diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go index 08087c86c37..040bf1de60c 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go @@ -82,6 +82,36 @@ func (_m *ISegmentDb) GetSegments(id types.UniqueID, segmentType *string, scope return r0, r1 } +// GetSegmentsByCollectionID provides a mock function with given fields: collectionID +func (_m *ISegmentDb) GetSegmentsByCollectionID(collectionID string) ([]*dbmodel.Segment, error) { + ret := _m.Called(collectionID) + + if len(ret) == 0 { + panic("no return value specified for GetSegmentsByCollectionID") + } + + var r0 []*dbmodel.Segment + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]*dbmodel.Segment, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(string) []*dbmodel.Segment); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*dbmodel.Segment) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Insert provides a mock function with given fields: _a0 func (_m *ISegmentDb) Insert(_a0 *dbmodel.Segment) error { ret := _m.Called(_a0) diff --git a/go/pkg/sysdb/metastore/db/migrations/20250109224431.sql b/go/pkg/sysdb/metastore/db/migrations/20250109224431.sql new file mode 100644 index 00000000000..35cab301001 --- /dev/null +++ b/go/pkg/sysdb/metastore/db/migrations/20250109224431.sql @@ -0,0 +1 @@ +ALTER TABLE "public"."collections" ADD COLUMN "total_records_post_compaction" bigint NULL DEFAULT 0; diff --git a/go/pkg/sysdb/metastore/db/migrations/atlas.sum b/go/pkg/sysdb/metastore/db/migrations/atlas.sum index dbc1a7b701d..91cd0abf3c3 100644 --- a/go/pkg/sysdb/metastore/db/migrations/atlas.sum +++ b/go/pkg/sysdb/metastore/db/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:zBpjJtt9tYhfBrZbaMXWdPN4CUXU4EEIPOy/5gaO0Vg= +h1:tHeZoWb7PdWh9oOHH9g4j69+xbuHhzdNXwzWKQltxEk= 20240313233558.sql h1:Gv0TiSYsqGoOZ2T2IWvX4BOasauxool8PrBOIjmmIdg= 20240321194713.sql h1:kVkNpqSFhrXGVGFFvL7JdK3Bw31twFcEhI6A0oCFCkg= 20240327075032.sql h1:nlr2J74XRU8erzHnKJgMr/tKqJxw9+R6RiiEBuvuzgo= @@ -8,3 +8,4 @@ h1:zBpjJtt9tYhfBrZbaMXWdPN4CUXU4EEIPOy/5gaO0Vg= 20240621171854.sql h1:kc8ZFK8A4/GLHu0gQ04RdIt3O38wfcD6ouXX4nr7lTk= 20241003212820.sql h1:zHloxrMr7EMcqV008a3aqQdU5fHjWY3m66CIoThexbo= 20241016181945.sql h1:O8UmR8rvD1LyKIld5OO9c0j+xSXW51MHL//gYUTQ2jo= +20250109224431.sql h1:RjJ2Q3jAWj48T2vmEo7X9rI9cKFC6zIcBUTq4RaE14A= diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index 1d1f153f2ba..541b036e56d 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -53,6 +53,7 @@ message Collection { string database = 7; int64 log_position = 8; int32 version = 9; + uint64 total_records_post_compaction = 10; } message Database { diff --git a/idl/chromadb/proto/compactor.proto b/idl/chromadb/proto/compactor.proto new file mode 100644 index 00000000000..1541e4144f4 --- /dev/null +++ b/idl/chromadb/proto/compactor.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package chroma; + +message CollectionIds { + repeated string ids = 1; +} + +message CompactionRequest { + CollectionIds ids = 1; +} + +message CompactionResponse { + // Empty +} + +service Compactor { + rpc Compact(CompactionRequest) returns (CompactionResponse) {} +} diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 30c6613dff7..f3d54cfc6ba 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -225,6 +225,7 @@ message FlushCollectionCompactionRequest { int64 log_position = 3; int32 collection_version = 4; repeated FlushSegmentCompactionInfo segment_compaction_info = 5; + uint64 total_records_post_compaction = 6; } message FlushCollectionCompactionResponse { @@ -360,6 +361,26 @@ message RestoreCollectionResponse { int64 new_collection_version = 1; } +message GetCollectionSizeRequest { + string id = 1; +} + +message GetCollectionSizeResponse { + uint64 total_records_post_compaction = 1; +} + +message ListCollectionsToGcRequest {} + +message CollectionToGcInfo { + string id = 1; + string name = 2; + string version_file_path = 3; +} + +message ListCollectionsToGcResponse { + repeated CollectionToGcInfo collections = 1; +} + service SysDB { rpc CreateDatabase(CreateDatabaseRequest) returns (CreateDatabaseResponse) {} rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {} @@ -383,4 +404,6 @@ service SysDB { rpc FlushCollectionCompaction(FlushCollectionCompactionRequest) returns (FlushCollectionCompactionResponse) {} rpc RestoreCollection(RestoreCollectionRequest) returns (RestoreCollectionResponse) {} rpc ListCollectionVersions(ListCollectionVersionsRequest) returns (ListCollectionVersionsResponse) {} + rpc GetCollectionSize(GetCollectionSizeRequest) returns (GetCollectionSizeResponse) {} + rpc ListCollectionsToGc(ListCollectionsToGcRequest) returns (ListCollectionsToGcResponse) {} } diff --git a/k8s/test/otel-collector.yaml b/k8s/test/otel-collector.yaml index d70341eaea9..8c1ceb7df3c 100644 --- a/k8s/test/otel-collector.yaml +++ b/k8s/test/otel-collector.yaml @@ -15,6 +15,10 @@ data: processors: batch: + # When using the tracing crate in Rust, we sometimes set the otel.name attribute for dynamic spans. Jaeger does not automatically override the span name with this attribute, so we do it manually here. + span/override-name: + name: + from_attributes: [name] exporters: prometheus: @@ -29,6 +33,7 @@ data: pipelines: traces: receivers: [otlp] + processors: [batch, span/override-name] exporters: [otlp/jaeger] metrics: receivers: [otlp] diff --git a/main.py b/main.py deleted file mode 100644 index 723399cb78a..00000000000 --- a/main.py +++ /dev/null @@ -1,5 +0,0 @@ -import chromadb - -if __name__ == '__main__': - client = chromadb.Client() - col = client.create_collection(name="test", metadata={"hnsw:search_ef": 100, "hnsw:construction_ef": 1000}) diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 6be8b5a2fe3..3c831d9f92d 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -133,12 +133,23 @@ impl ArrowUnorderedBlockfileWriter { Box::new(ArrowBlockfileError::MigrationError(e)) as Box })?; + let count = self + .root + .sparse_index + .data + .lock() + .counts + .values() + .map(|&x| x as u64) + .sum::(); + let flusher = ArrowBlockfileFlusher::new( self.block_manager, self.root_manager, blocks, self.root, self.id, + count, ); Ok(flusher) @@ -705,7 +716,7 @@ mod tests { use uuid::Uuid; #[tokio::test] - async fn test_count() { + async fn test_reader_count() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -744,6 +755,79 @@ mod tests { } } + #[tokio::test] + async fn test_writer_count() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let blockfile_provider = ArrowBlockfileProvider::new( + storage, + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + + // Test no keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + flusher.flush::<&str, Vec>().await.unwrap(); + + // Test 2 keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_1 = "key"; + let key1 = "zzzz"; + let value1 = vec![1, 2, 3]; + writer.set(prefix_1, key1, value1.clone()).await.unwrap(); + + let prefix_2 = "key"; + let key2 = "aaaa"; + let value2 = vec![4, 5, 6]; + writer.set(prefix_2, key2, value2).await.unwrap(); + + let flusher1 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher1.count()); + + // Test add keys after commit, before flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_3 = "key"; + let key3 = "yyyy"; + let value3 = vec![7, 8, 9]; + writer.set(prefix_3, key3, value3.clone()).await.unwrap(); + + let prefix_4 = "key"; + let key4 = "bbbb"; + let value4 = vec![10, 11, 12]; + writer.set(prefix_4, key4, value4).await.unwrap(); + + let flusher2 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher2.count()); + + flusher1.flush::<&str, Vec>().await.unwrap(); + flusher2.flush::<&str, Vec>().await.unwrap(); + + // Test count after flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + } + fn test_prefix(num_keys: u32, prefix_for_query: u32) { Runtime::new().unwrap().block_on(async { let tmp_dir = tempfile::tempdir().unwrap(); diff --git a/rust/blockstore/src/arrow/flusher.rs b/rust/blockstore/src/arrow/flusher.rs index 401d3ae4f7d..02726eb9f5c 100644 --- a/rust/blockstore/src/arrow/flusher.rs +++ b/rust/blockstore/src/arrow/flusher.rs @@ -14,6 +14,7 @@ pub struct ArrowBlockfileFlusher { blocks: Vec, root: RootWriter, id: Uuid, + count: u64, } impl ArrowBlockfileFlusher { @@ -23,6 +24,7 @@ impl ArrowBlockfileFlusher { blocks: Vec, root: RootWriter, id: Uuid, + count: u64, ) -> Self { Self { block_manager, @@ -30,6 +32,7 @@ impl ArrowBlockfileFlusher { blocks, root, id, + count, } } @@ -68,4 +71,8 @@ impl ArrowBlockfileFlusher { pub(crate) fn id(&self) -> Uuid { self.id } + + pub(crate) fn count(&self) -> u64 { + self.count + } } diff --git a/rust/blockstore/src/arrow/ordered_blockfile_writer.rs b/rust/blockstore/src/arrow/ordered_blockfile_writer.rs index 5fa021eefc7..6806f379b99 100644 --- a/rust/blockstore/src/arrow/ordered_blockfile_writer.rs +++ b/rust/blockstore/src/arrow/ordered_blockfile_writer.rs @@ -194,12 +194,23 @@ impl ArrowOrderedBlockfileWriter { Box::new(ArrowBlockfileError::MigrationError(e)) as Box })?; + let count = self + .root + .sparse_index + .data + .lock() + .counts + .values() + .map(|&x| x as u64) + .sum::(); + let flusher = ArrowBlockfileFlusher::new( self.block_manager, self.root_manager, blocks, self.root, self.id, + count, ); Ok(flusher) @@ -366,7 +377,7 @@ mod tests { use uuid::Uuid; #[tokio::test] - async fn test_count() { + async fn test_reader_count() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -405,6 +416,79 @@ mod tests { } } + #[tokio::test] + async fn test_writer_count() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let blockfile_provider = ArrowBlockfileProvider::new( + storage, + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + + // Test no keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + flusher.flush::<&str, Vec>().await.unwrap(); + + // Test 2 keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_1 = "key"; + let key1 = "zzzz"; + let value1 = vec![1, 2, 3]; + writer.set(prefix_1, key1, value1.clone()).await.unwrap(); + + let prefix_2 = "key"; + let key2 = "aaaa"; + let value2 = vec![4, 5, 6]; + writer.set(prefix_2, key2, value2).await.unwrap(); + + let flusher1 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher1.count()); + + // Test add keys after commit, before flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_3 = "key"; + let key3 = "yyyy"; + let value3 = vec![7, 8, 9]; + writer.set(prefix_3, key3, value3.clone()).await.unwrap(); + + let prefix_4 = "key"; + let key4 = "bbbb"; + let value4 = vec![10, 11, 12]; + writer.set(prefix_4, key4, value4).await.unwrap(); + + let flusher2 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher2.count()); + + flusher1.flush::<&str, Vec>().await.unwrap(); + flusher2.flush::<&str, Vec>().await.unwrap(); + + // Test count after flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + } + #[tokio::test] async fn test_blockfile() { let tmp_dir = tempfile::tempdir().unwrap(); diff --git a/rust/blockstore/src/arrow/provider.rs b/rust/blockstore/src/arrow/provider.rs index b4ea130e341..0a27d0d0e72 100644 --- a/rust/blockstore/src/arrow/provider.rs +++ b/rust/blockstore/src/arrow/provider.rs @@ -18,11 +18,29 @@ use chroma_cache::{CacheError, PersistentCache}; use chroma_config::Configurable; use chroma_error::{ChromaError, ErrorCodes}; use chroma_storage::Storage; +use futures::{stream::FuturesUnordered, StreamExt}; use std::sync::Arc; use thiserror::Error; use tracing::{Instrument, Span}; use uuid::Uuid; +#[derive(Error, Debug)] +pub enum ArrowBlockfileProviderPrefetchError { + #[error("Error reading root for blockfile: {0}")] + RootManager(#[from] Box), + #[error("Error fetching block: {0}")] + BlockManager(#[from] GetError), +} + +impl ChromaError for ArrowBlockfileProviderPrefetchError { + fn code(&self) -> ErrorCodes { + match self { + ArrowBlockfileProviderPrefetchError::RootManager(e) => e.code(), + ArrowBlockfileProviderPrefetchError::BlockManager(e) => e.code(), + } + } +} + /// A BlockFileProvider that creates ArrowBlockfiles (Arrow-backed blockfiles used for production). /// For now, it keeps a simple local cache of blockfiles. #[derive(Clone)] @@ -62,6 +80,29 @@ impl ArrowBlockfileProvider { } } + pub async fn prefetch(&self, id: &Uuid) -> Result { + let block_ids = self + .root_manager + .get_all_block_ids(id) + .await + .map_err(|e| ArrowBlockfileProviderPrefetchError::RootManager(Box::new(e)))?; + + let mut futures = FuturesUnordered::new(); + for block_id in block_ids.iter() { + // Don't prefetch if already cached. + if !self.block_manager.cached(block_id).await { + futures.push(self.block_manager.get(block_id)); + } + } + let count = futures.len(); + + while let Some(result) = futures.next().await { + result?; + } + + Ok(count) + } + pub async fn write< 'new, K: Key + Into + ArrowWriteableKey + 'new, @@ -427,10 +468,7 @@ impl RootManager { Some(index) => Ok(Some(index)), None => { tracing::info!("Cache miss - fetching root from storage"); - // TODO(hammadb): For legacy and temporary development purposes, we are reading the file - // from a fixed location. The path is sparse_index/ for legacy reasons. - // This will be replaced with a full prefix-based storage shortly - let key = format!("sparse_index/{}", id); + let key = Self::get_storage_key(id); tracing::debug!("Reading root from storage with key: {}", key); match self.storage.get(&key).await { Ok(bytes) => match RootReader::from_bytes::(&bytes, *id) { @@ -452,6 +490,19 @@ impl RootManager { } } + pub async fn get_all_block_ids(&self, id: &Uuid) -> Result, RootManagerError> { + let key = Self::get_storage_key(id); + tracing::debug!("Reading root from storage with key: {}", key); + match self.storage.get(&key).await { + Ok(bytes) => RootReader::get_all_block_ids_from_bytes(&bytes, *id) + .map_err(RootManagerError::FromBytesError), + Err(e) => { + tracing::error!("Error reading root from storage: {}", e); + Err(RootManagerError::StorageGetError(e)) + } + } + } + pub async fn flush<'read, K: ArrowWriteableKey + 'read>( &self, root: &RootWriter, @@ -492,6 +543,13 @@ impl RootManager { None => Err(RootManagerError::NotFound), } } + + fn get_storage_key(id: &Uuid) -> String { + // TODO(hammadb): For legacy and temporary development purposes, we are reading the file + // from a fixed location. The path is sparse_index/ for legacy reasons. + // This will be replaced with a full prefix-based storage shortly + format!("sparse_index/{}", id) + } } #[cfg(test)] diff --git a/rust/blockstore/src/arrow/root.rs b/rust/blockstore/src/arrow/root.rs index 34f5caef368..567306f5f40 100644 --- a/rust/blockstore/src/arrow/root.rs +++ b/rust/blockstore/src/arrow/root.rs @@ -269,6 +269,32 @@ impl ChromaError for FromBytesError { } impl RootReader { + pub(super) fn get_all_block_ids_from_bytes( + bytes: &[u8], + id: Uuid, + ) -> Result, FromBytesError> { + let mut cursor = std::io::Cursor::new(bytes); + let arrow_reader = arrow::ipc::reader::FileReader::try_new(&mut cursor, None); + + let record_batch = match arrow_reader { + Ok(mut reader) => match reader.next() { + Some(Ok(batch)) => batch, + Some(Err(e)) => return Err(FromBytesError::ArrowError(e)), + None => { + return Err(FromBytesError::NoDataError); + } + }, + Err(e) => return Err(FromBytesError::ArrowError(e)), + }; + + let (version, read_id) = Self::version_and_id_from_record_batch(&record_batch, id)?; + if read_id != id { + return Err(FromBytesError::IdMismatch); + } + + Self::block_ids_from_record_batch(&record_batch, version) + } + pub(super) fn from_bytes<'data, K: ArrowReadableKey<'data>>( bytes: &[u8], id: Uuid, @@ -287,20 +313,7 @@ impl RootReader { Err(e) => return Err(FromBytesError::ArrowError(e)), }; - let metadata = &record_batch.schema_ref().metadata; - let (version, read_id) = match (metadata.get("version"), metadata.get("id")) { - (Some(version), Some(read_id)) => ( - Version::try_from(version.as_str())?, - Uuid::parse_str(read_id)?, - ), - (Some(_), None) => return Err(FromBytesError::MissingMetadata("id".to_string())), - (None, Some(_)) => { - return Err(FromBytesError::MissingMetadata("version".to_string())); - } - // We default to the current version in the absence of metadata for these fields for - // backwards compatibility - (None, None) => (Version::V1, id), - }; + let (version, read_id) = Self::version_and_id_from_record_batch(&record_batch, id)?; if read_id != id { return Err(FromBytesError::IdMismatch); @@ -316,31 +329,7 @@ impl RootReader { // The sparse index copies the data so it can live as long as it needs to independently let record_batch: &'data RecordBatch = unsafe { std::mem::transmute(&record_batch) }; let key_arr = record_batch.column(1); - let mut ids: Vec = Vec::new(); - // Versions after V1 store uuid as bytes - if version == Version::V1 { - let id_array = record_batch - .column(2) - .as_any() - .downcast_ref::() - .expect("ID array to be a StringArray"); - - for i in 0..id_array.len() { - let id = Uuid::parse_str(id_array.value(i)).expect("ID to be a valid UUID"); - ids.push(id); - } - } else { - let id_arr = record_batch - .column(2) - .as_any() - .downcast_ref::() - .expect("ID array to be a BinaryArray"); - for i in 0..id_arr.len() { - let id = Uuid::from_slice(id_arr.value(i)).expect("ID to be a valid UUID"); - ids.push(id); - } - } // Version 1.1 is the first version to have a count column let mut counts = None; if version >= Version::V1_1 { @@ -352,6 +341,8 @@ impl RootReader { counts = Some(count_arr); } + let ids = Self::block_ids_from_record_batch(record_batch, version)?; + let mut forward = BTreeMap::new(); for (i, block_id) in ids.iter().enumerate() { let prefix = prefix_arr.value(i); @@ -394,6 +385,57 @@ impl RootReader { id: new_id, } } + + fn version_and_id_from_record_batch( + record_batch: &RecordBatch, + default_id: Uuid, + ) -> Result<(Version, Uuid), FromBytesError> { + let metadata = &record_batch.schema_ref().metadata; + match (metadata.get("version"), metadata.get("id")) { + (Some(version), Some(read_id)) => Ok(( + Version::try_from(version.as_str())?, + Uuid::parse_str(read_id)?, + )), + (Some(_), None) => Err(FromBytesError::MissingMetadata("id".to_string())), + (None, Some(_)) => Err(FromBytesError::MissingMetadata("version".to_string())), + // We default to the current version in the absence of metadata for these fields for + // backwards compatibility + (None, None) => Ok((Version::V1, default_id)), + } + } + + fn block_ids_from_record_batch( + record_batch: &RecordBatch, + version: Version, + ) -> Result, FromBytesError> { + let mut ids: Vec = Vec::new(); + // Versions after V1 store uuid as bytes + if version == Version::V1 { + let id_array = record_batch + .column(2) + .as_any() + .downcast_ref::() + .expect("ID array to be a StringArray"); + + for i in 0..id_array.len() { + let id = Uuid::parse_str(id_array.value(i)).expect("ID to be a valid UUID"); + ids.push(id); + } + } else { + let id_arr = record_batch + .column(2) + .as_any() + .downcast_ref::() + .expect("ID array to be a BinaryArray"); + + for i in 0..id_arr.len() { + let id = Uuid::from_slice(id_arr.value(i)).expect("ID to be a valid UUID"); + ids.push(id); + } + } + + Ok(ids) + } } #[cfg(test)] diff --git a/rust/blockstore/src/lib.rs b/rust/blockstore/src/lib.rs index c8cd1bee2fe..4e5e56a113a 100644 --- a/rust/blockstore/src/lib.rs +++ b/rust/blockstore/src/lib.rs @@ -10,10 +10,10 @@ use chroma_storage::test_storage; use provider::BlockfileProvider; pub use types::*; -pub fn test_arrow_blockfile_provider(size: usize) -> BlockfileProvider { +pub fn test_arrow_blockfile_provider(max_block_size_bytes: usize) -> BlockfileProvider { BlockfileProvider::new_arrow( test_storage(), - size, + max_block_size_bytes, new_cache_for_test(), new_cache_for_test(), ) diff --git a/rust/blockstore/src/provider.rs b/rust/blockstore/src/provider.rs index bd41a285b41..e904c96a935 100644 --- a/rust/blockstore/src/provider.rs +++ b/rust/blockstore/src/provider.rs @@ -100,6 +100,15 @@ impl BlockfileProvider { }; Ok(()) } + + pub async fn prefetch(&self, id: &uuid::Uuid) -> Result> { + match self { + BlockfileProvider::HashMapBlockfileProvider(_) => unimplemented!(), + BlockfileProvider::ArrowBlockfileProvider(provider) => { + provider.prefetch(id).await.map_err(|e| Box::new(e) as _) + } + } + } } // =================== Configurable =================== diff --git a/rust/blockstore/src/types/flusher.rs b/rust/blockstore/src/types/flusher.rs index 12d9ee76788..8364c382d2d 100644 --- a/rust/blockstore/src/types/flusher.rs +++ b/rust/blockstore/src/types/flusher.rs @@ -31,4 +31,11 @@ impl BlockfileFlusher { BlockfileFlusher::ArrowBlockfileFlusher(flusher) => flusher.id(), } } + + pub fn count(&self) -> u64 { + match self { + BlockfileFlusher::MemoryBlockfileFlusher(_) => unimplemented!(), // no op + BlockfileFlusher::ArrowBlockfileFlusher(flusher) => flusher.count(), + } + } } diff --git a/rust/garbage_collector/src/garbage_collector_component.rs b/rust/garbage_collector/src/garbage_collector_component.rs index a525615162b..7d285289c3e 100644 --- a/rust/garbage_collector/src/garbage_collector_component.rs +++ b/rust/garbage_collector/src/garbage_collector_component.rs @@ -83,8 +83,13 @@ impl Handler for GarbageCollector { _message: GarbageCollectMessage, _ctx: &ComponentContext, ) -> Self::Result { - // TODO(Sanket): Implement the garbage collection logic. - todo!() + // Get all collections to gc and create gc orchestrator for each. + let _ = self + .sysdb_client + .get_collections_to_gc() + .await + .expect("Failed to get collections to gc"); + // TODO(Sanket): Implement the logic to create gc orchestrator for each collection. } } diff --git a/rust/load/src/data_sets.rs b/rust/load/src/data_sets.rs index 9fbf1364c95..6eee6089791 100644 --- a/rust/load/src/data_sets.rs +++ b/rust/load/src/data_sets.rs @@ -1,3 +1,4 @@ +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use chromadb::collection::{GetOptions, QueryOptions}; @@ -265,6 +266,71 @@ const TINY_STORIES_DATA_SETS: &[TinyStoriesDataSet] = &[ TinyStoriesDataSet::new("stories9", PARAPHRASE_ALBERT_SMALL_V2, 50_000), ]; +//////////////////////////////////////////// RoundRobin //////////////////////////////////////////// + +/// A data set that round-robins between other data sets. +#[derive(Debug)] +pub struct RoundRobinDataSet { + name: String, + description: String, + data_sets: Vec>, + index: AtomicUsize, +} + +#[async_trait::async_trait] +impl DataSet for RoundRobinDataSet { + fn name(&self) -> String { + format!("round-robin-{}", self.name) + } + + fn description(&self) -> String { + format!("round robin between other data sets; {}", self.description) + } + + fn json(&self) -> serde_json::Value { + serde_json::json!("round-robin") + } + + async fn get( + &self, + client: &ChromaClient, + gq: GetQuery, + guac: &mut Guacamole, + ) -> Result<(), Box> { + let index = self + .index + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % self.data_sets.len(); + self.data_sets[index].get(client, gq, guac).await + } + + async fn query( + &self, + client: &ChromaClient, + qq: QueryQuery, + guac: &mut Guacamole, + ) -> Result<(), Box> { + let index = self + .index + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % self.data_sets.len(); + self.data_sets[index].query(client, qq, guac).await + } + + async fn upsert( + &self, + client: &ChromaClient, + uq: UpsertQuery, + guac: &mut Guacamole, + ) -> Result<(), Box> { + let index = self + .index + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % self.data_sets.len(); + self.data_sets[index].upsert(client, uq, guac).await + } +} + /////////////////////////////////////////// All Data Sets ////////////////////////////////////////// /// Get all data sets. @@ -273,6 +339,15 @@ pub fn all_data_sets() -> Vec> { for data_set in TINY_STORIES_DATA_SETS { data_sets.push(Arc::new(data_set.clone()) as _); } + data_sets.push(Arc::new(RoundRobinDataSet { + name: "tiny-stories".to_string(), + description: "tiny stories data sets".to_string(), + data_sets: TINY_STORIES_DATA_SETS + .iter() + .map(|ds| Arc::new(ds.clone()) as _) + .collect(), + index: AtomicUsize::new(0), + }) as _); for num_clusters in [10_000, 100_000] { for (seed_idx, seed_clusters) in [ 0xab1cd5b6a5173d40usize, diff --git a/rust/sysdb/Cargo.toml b/rust/sysdb/Cargo.toml index 15ce7001091..522fe932b36 100644 --- a/rust/sysdb/Cargo.toml +++ b/rust/sysdb/Cargo.toml @@ -21,6 +21,7 @@ tracing = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } tonic = { workspace = true } +uuid = { workspace = true } parking_lot = { workspace = true } chroma-config = { workspace = true } diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index e72574cbd80..0b58bcecc2f 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -18,6 +18,7 @@ use tonic::service::interceptor; use tonic::transport::Endpoint; use tonic::Request; use tonic::Status; +use uuid::{Error, Uuid}; #[derive(Debug, Clone)] pub enum SysDb { @@ -46,6 +47,15 @@ impl SysDb { } } + pub async fn get_collections_to_gc( + &mut self, + ) -> Result, GetCollectionsToGcError> { + match self { + SysDb::Grpc(grpc) => grpc.get_collections_to_gc().await, + SysDb::Test(_) => todo!(), + } + } + pub async fn get_segments( &mut self, id: Option, @@ -76,6 +86,7 @@ impl SysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { match self { SysDb::Grpc(grpc) => { @@ -85,6 +96,7 @@ impl SysDb { log_position, collection_version, segment_flush_info, + total_records_post_compaction, ) .await } @@ -95,6 +107,7 @@ impl SysDb { log_position, collection_version, segment_flush_info, + total_records_post_compaction, ) .await } @@ -167,6 +180,47 @@ impl Configurable for GrpcSysDb { } } +#[allow(dead_code)] +pub struct CollectionToGcInfo { + id: CollectionUuid, + name: String, + version_file_path: String, +} + +#[derive(Debug, Error)] +pub enum GetCollectionsToGcError { + #[error("Failed to parse uuid")] + ParsingError(#[from] Error), + #[error("Grpc request failed")] + RequestFailed(#[from] tonic::Status), +} + +impl ChromaError for GetCollectionsToGcError { + fn code(&self) -> ErrorCodes { + match self { + GetCollectionsToGcError::ParsingError(_) => ErrorCodes::Internal, + GetCollectionsToGcError::RequestFailed(_) => ErrorCodes::Internal, + } + } +} + +impl TryFrom for CollectionToGcInfo { + type Error = GetCollectionsToGcError; + + fn try_from(value: chroma_proto::CollectionToGcInfo) -> Result { + let collection_uuid = match Uuid::try_parse(&value.id) { + Ok(uuid) => uuid, + Err(e) => return Err(GetCollectionsToGcError::ParsingError(e)), + }; + let collection_id = CollectionUuid(collection_uuid); + Ok(CollectionToGcInfo { + id: collection_id, + name: value.name, + version_file_path: value.version_file_path, + }) + } +} + impl GrpcSysDb { async fn get_collections( &mut self, @@ -207,6 +261,25 @@ impl GrpcSysDb { } } + pub async fn get_collections_to_gc( + &mut self, + ) -> Result, GetCollectionsToGcError> { + let res = self + .client + .list_collections_to_gc(chroma_proto::ListCollectionsToGcRequest {}) + .await; + + match res { + Ok(collections) => collections + .into_inner() + .collections + .into_iter() + .map(|collection| collection.try_into()) + .collect::, GetCollectionsToGcError>>(), + Err(e) => Err(GetCollectionsToGcError::RequestFailed(e)), + } + } + async fn get_segments( &mut self, id: Option, @@ -273,6 +346,7 @@ impl GrpcSysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { let segment_compaction_info = segment_flush_info @@ -296,6 +370,7 @@ impl GrpcSysDb { log_position, collection_version, segment_compaction_info, + total_records_post_compaction, }; let res = self.client.flush_collection_compaction(req).await; diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index 0af96b4a4ec..ec601148c91 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -171,6 +171,7 @@ impl TestSysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { let mut inner = self.inner.lock(); let collection = inner.collections.get(&collection_id); @@ -182,6 +183,7 @@ impl TestSysDb { collection.log_position = log_position; let new_collection_version = collection_version + 1; collection.version = new_collection_version; + collection.total_records_post_compaction = total_records_post_compaction; inner .collections .insert(collection.collection_id, collection); diff --git a/rust/types/build.rs b/rust/types/build.rs index 4be0a2f31d7..7439f24741d 100644 --- a/rust/types/build.rs +++ b/rust/types/build.rs @@ -2,6 +2,7 @@ fn main() -> Result<(), Box> { // Compile the protobuf files in the chromadb proto directory. let mut proto_paths = vec![ "../../idl/chromadb/proto/chroma.proto", + "../../idl/chromadb/proto/compactor.proto", "../../idl/chromadb/proto/coordinator.proto", "../../idl/chromadb/proto/logservice.proto", "../../idl/chromadb/proto/query_executor.proto", diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs index 9c7bd8914f2..db09a1641ba 100644 --- a/rust/types/src/collection.rs +++ b/rust/types/src/collection.rs @@ -41,6 +41,7 @@ pub struct Collection { pub database: String, pub log_position: i64, pub version: i32, + pub total_records_post_compaction: u64, } #[derive(Error, Debug)] @@ -85,6 +86,7 @@ impl TryFrom for Collection { database: proto_collection.database, log_position: proto_collection.log_position, version: proto_collection.version, + total_records_post_compaction: proto_collection.total_records_post_compaction, }) } } @@ -143,6 +145,7 @@ mod test { database: "qux".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }; let converted_collection: Collection = proto_collection.try_into().unwrap(); assert_eq!( @@ -154,5 +157,6 @@ mod test { assert_eq!(converted_collection.dimension, None); assert_eq!(converted_collection.tenant, "baz".to_string()); assert_eq!(converted_collection.database, "qux".to_string()); + assert_eq!(converted_collection.total_records_post_compaction, 0); } } diff --git a/rust/types/src/segment.rs b/rust/types/src/segment.rs index 8858365554a..7ecec3be879 100644 --- a/rust/types/src/segment.rs +++ b/rust/types/src/segment.rs @@ -36,7 +36,7 @@ impl std::fmt::Display for SegmentUuid { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum SegmentType { HnswDistributed, BlockfileMetadata, diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index 3afa0fa53ae..7cf00d56348 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -4,13 +4,17 @@ version = "0.1.0" edition = "2021" [[bin]] -name = "query_service" -path = "src/bin/query_service.rs" +name = "compaction_client" +path = "src/bin/compaction_client.rs" [[bin]] name = "compaction_service" path = "src/bin/compaction_service.rs" +[[bin]] +name = "query_service" +path = "src/bin/query_service.rs" + [dependencies] rand = "0.8.5" murmur3 = "0.5.2" @@ -47,6 +51,7 @@ prost-types = { workspace = true } num_cpus = { workspace = true } flatbuffers = { workspace = true } tantivy = { workspace = true } +clap = { workspace = true } chroma-blockstore = { workspace = true } chroma-error = { workspace = true } diff --git a/rust/worker/src/bin/compaction_client.rs b/rust/worker/src/bin/compaction_client.rs new file mode 100644 index 00000000000..48f99c329c5 --- /dev/null +++ b/rust/worker/src/bin/compaction_client.rs @@ -0,0 +1,4 @@ +#[tokio::main] +async fn main() { + worker::compaction_client_entrypoint().await; +} diff --git a/rust/worker/src/compactor/compaction_client.rs b/rust/worker/src/compactor/compaction_client.rs new file mode 100644 index 00000000000..2e4e3b52956 --- /dev/null +++ b/rust/worker/src/compactor/compaction_client.rs @@ -0,0 +1,63 @@ +use chroma_types::chroma_proto::{ + compactor_client::CompactorClient, CollectionIds, CompactionRequest, +}; +use clap::{Parser, Subcommand}; +use thiserror::Error; +use tonic::transport::Channel; +use uuid::Uuid; + +/// Error for compaction client +#[derive(Debug, Error)] +pub enum CompactionClientError { + #[error("Compactor failed: {0}")] + Compactor(String), + #[error("Unable to connect to compactor: {0}")] + Connection(#[from] tonic::transport::Error), +} + +/// Tool to control compaction service +#[derive(Debug, Parser)] +#[command(version, about, long_about = None)] +pub struct CompactionClient { + /// Url of the target compactor + #[arg(short, long)] + url: String, + /// Subcommand for compaction + #[command(subcommand)] + command: CompactionCommand, +} + +#[derive(Debug, Subcommand)] +pub enum CompactionCommand { + /// Trigger a one-off compaction + Compact { + /// Specify Uuids of the collections to compact. If unspecified, no compaction will occur unless --all flag is specified + #[arg(short, long)] + id: Vec, + }, +} + +impl CompactionClient { + async fn grpc_client(&self) -> Result, CompactionClientError> { + Ok(CompactorClient::connect(self.url.clone()).await?) + } + + pub async fn run(&self) -> Result<(), CompactionClientError> { + match &self.command { + CompactionCommand::Compact { id } => { + let mut client = self.grpc_client().await?; + let response = client + .compact(CompactionRequest { + ids: Some(CollectionIds { + ids: id.iter().map(ToString::to_string).collect(), + }), + }) + .await; + if let Err(status) = response { + return Err(CompactionClientError::Compactor(status.to_string())); + } + } + }; + Ok(()) + } +} diff --git a/rust/worker/src/compactor/compaction_manager.rs b/rust/worker/src/compactor/compaction_manager.rs index 840de1bbeb1..df4861844fd 100644 --- a/rust/worker/src/compactor/compaction_manager.rs +++ b/rust/worker/src/compactor/compaction_manager.rs @@ -1,7 +1,8 @@ use super::scheduler::Scheduler; use super::scheduler_policy::LasCompactionTimeSchedulerPolicy; +use super::OneOffCompactionMessage; use crate::compactor::types::CompactionJob; -use crate::compactor::types::ScheduleMessage; +use crate::compactor::types::ScheduledCompactionMessage; use crate::config::CompactionServiceConfig; use crate::execution::orchestration::CompactOrchestrator; use crate::execution::orchestration::CompactionResponse; @@ -107,7 +108,7 @@ impl CompactionManager { let dispatcher = match self.dispatcher { Some(ref dispatcher) => dispatcher.clone(), None => { - println!("No dispatcher found"); + tracing::error!("No dispatcher found"); return Err(Box::new(CompactionError::FailedToCompact)); } }; @@ -139,43 +140,42 @@ impl CompactionManager { } } None => { - println!("No system found"); + tracing::error!("No system found"); return Err(Box::new(CompactionError::FailedToCompact)); } }; } - // TODO: make the return type more informative #[instrument(name = "CompactionManager::compact_batch")] - pub(crate) async fn compact_batch( - &mut self, - compacted: &mut Vec, - ) -> (u32, u32) { + pub(crate) async fn compact_batch(&mut self) -> Vec { self.scheduler.schedule().await; - let mut jobs = FuturesUnordered::new(); - for job in self.scheduler.get_jobs() { - let instrumented_span = span!(parent: None, tracing::Level::INFO, "Compacting job", collection_id = ?job.collection_id); - instrumented_span.follows_from(Span::current()); - jobs.push(self.compact(job).instrument(instrumented_span)); - } - println!("Compacting {} jobs", jobs.len()); - tracing::info!("Compacting {} jobs", jobs.len()); - let mut num_completed_jobs = 0; - let mut num_failed_jobs = 0; - while let Some(job) = jobs.next().await { - match job { - Ok(result) => { - println!("Compaction completed: {:?}", result); - compacted.push(result.compaction_job.collection_id); - num_completed_jobs += 1; - } - Err(e) => { - println!("Compaction failed: {:?}", e); - num_failed_jobs += 1; + let job_futures = self + .scheduler + .get_jobs() + .map(|job| { + let instrumented_span = span!(parent: None, tracing::Level::INFO, "Compacting job", collection_id = ?job.collection_id); + instrumented_span.follows_from(Span::current()); + self.compact(job).instrument(instrumented_span) + }) + .collect::>(); + + tracing::info!("Running {} compaction jobs", job_futures.len()); + + job_futures + .filter_map(|result| async move { + match result { + Ok(response) => { + tracing::info!("Compaction completed: {response:?}"); + Some(response.compaction_job.collection_id) + } + Err(err) => { + tracing::error!("Compaction failed {err}"); + None + } } - } - } - (num_completed_jobs, num_failed_jobs) + }) + .collect() + .await } pub(crate) fn set_dispatcher(&mut self, dispatcher: ComponentHandle) { @@ -285,11 +285,13 @@ impl Component for CompactionManager { } async fn start(&mut self, ctx: &ComponentContext) -> () { - println!("Starting CompactionManager"); - ctx.scheduler - .schedule(ScheduleMessage {}, self.compaction_interval, ctx, || { - Some(span!(parent: None, tracing::Level::INFO, "Scheduled compaction")) - }); + tracing::info!("Starting CompactionManager"); + ctx.scheduler.schedule( + ScheduledCompactionMessage {}, + self.compaction_interval, + ctx, + || Some(span!(parent: None, tracing::Level::INFO, "Scheduled compaction")), + ); } } @@ -301,25 +303,42 @@ impl Debug for CompactionManager { // ============== Handlers ============== #[async_trait] -impl Handler for CompactionManager { +impl Handler for CompactionManager { type Result = (); async fn handle( &mut self, - _message: ScheduleMessage, + _message: ScheduledCompactionMessage, ctx: &ComponentContext, ) { - println!("CompactionManager: Performing compaction"); - let mut ids = Vec::new(); - self.compact_batch(&mut ids).await; - + tracing::info!("CompactionManager: Performing scheduled compaction"); + let ids = self.compact_batch().await; self.hnsw_index_provider.purge_by_id(&ids).await; // Compaction is done, schedule the next compaction - ctx.scheduler - .schedule(ScheduleMessage {}, self.compaction_interval, ctx, || { - Some(span!(parent: None, tracing::Level::INFO, "Scheduled compaction")) - }); + ctx.scheduler.schedule( + ScheduledCompactionMessage {}, + self.compaction_interval, + ctx, + || Some(span!(parent: None, tracing::Level::INFO, "Scheduled compaction")), + ); + } +} + +#[async_trait] +impl Handler for CompactionManager { + type Result = (); + async fn handle( + &mut self, + message: OneOffCompactionMessage, + _ctx: &ComponentContext, + ) { + self.scheduler + .add_oneoff_collections(message.collection_ids); + tracing::info!( + "One-off collections queued: {:?}", + self.scheduler.get_oneoff_collections() + ); } } @@ -416,6 +435,7 @@ mod tests { database: "database_1".to_string(), log_position: -1, version: 0, + total_records_post_compaction: 0, }; let tenant_2 = "tenant_2".to_string(); @@ -428,6 +448,7 @@ mod tests { database: "database_2".to_string(), log_position: -1, version: 0, + total_records_post_compaction: 0, }; match *sysdb { SysDb::Test(ref mut sysdb) => { @@ -566,10 +587,7 @@ mod tests { let dispatcher_handle = system.start_component(dispatcher); manager.set_dispatcher(dispatcher_handle); manager.set_system(system); - let mut compacted = vec![]; - let (num_completed, number_failed) = manager.compact_batch(&mut compacted).await; - assert_eq!(num_completed, 2); - assert_eq!(number_failed, 0); + let compacted = manager.compact_batch().await; assert!( (compacted == vec![collection_uuid_1, collection_uuid_2]) || (compacted == vec![collection_uuid_2, collection_uuid_1]) diff --git a/rust/worker/src/compactor/compaction_server.rs b/rust/worker/src/compactor/compaction_server.rs new file mode 100644 index 00000000000..08c0d94781d --- /dev/null +++ b/rust/worker/src/compactor/compaction_server.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use chroma_system::ComponentHandle; +use chroma_types::chroma_proto::{ + compactor_server::{Compactor, CompactorServer}, + CompactionRequest, CompactionResponse, +}; +use tokio::signal::unix::{signal, SignalKind}; +use tonic::{transport::Server, Request, Response, Status}; +use tracing::trace_span; + +use crate::compactor::OneOffCompactionMessage; + +use super::CompactionManager; + +pub struct CompactionServer { + pub manager: ComponentHandle, + pub port: u16, +} + +impl CompactionServer { + pub async fn run(self) -> Result<(), Box> { + let addr = format!("[::]:{}", self.port).parse().unwrap(); + tracing::info!("Compaction server listening at {addr}"); + let server = Server::builder().add_service(CompactorServer::new(self)); + server + .serve_with_shutdown(addr, async { + match signal(SignalKind::terminate()) { + Ok(mut sigterm) => { + sigterm.recv().await; + tracing::info!("Received SIGTERM, shutting down") + } + Err(err) => { + tracing::error!("Failed to create SIGTERM handler: {err}") + } + } + }) + .await?; + Ok(()) + } +} + +#[async_trait] +impl Compactor for CompactionServer { + async fn compact( + &self, + request: Request, + ) -> Result, Status> { + let compaction_span = trace_span!("CompactionRequest", request = ?request); + self.manager + .receiver() + .send( + OneOffCompactionMessage::try_from(request.into_inner()) + .map_err(|e| Status::invalid_argument(e.to_string()))?, + Some(compaction_span), + ) + .await + .map_err(|e| Status::internal(e.to_string()))?; + Ok(Response::new(CompactionResponse {})) + } +} diff --git a/rust/worker/src/compactor/mod.rs b/rust/worker/src/compactor/mod.rs index 08e07c63a14..1c07272a7bb 100644 --- a/rust/worker/src/compactor/mod.rs +++ b/rust/worker/src/compactor/mod.rs @@ -6,3 +6,6 @@ mod types; pub(crate) use compaction_manager::*; pub(crate) use types::*; + +pub mod compaction_client; +pub mod compaction_server; diff --git a/rust/worker/src/compactor/scheduler.rs b/rust/worker/src/compactor/scheduler.rs index db99f677ac1..149e71c6c08 100644 --- a/rust/worker/src/compactor/scheduler.rs +++ b/rust/worker/src/compactor/scheduler.rs @@ -26,6 +26,7 @@ pub(crate) struct Scheduler { min_compaction_size: usize, memberlist: Option, assignment_policy: Box, + oneoff_collections: HashSet, disabled_collections: HashSet, } @@ -56,10 +57,19 @@ impl Scheduler { max_concurrent_jobs, memberlist: None, assignment_policy, + oneoff_collections: HashSet::new(), disabled_collections, } } + pub(crate) fn add_oneoff_collections(&mut self, ids: Vec) { + self.oneoff_collections.extend(ids); + } + + pub(crate) fn get_oneoff_collections(&self) -> Vec { + self.oneoff_collections.iter().cloned().collect() + } + async fn get_collections_with_new_data(&mut self) -> Vec { let collections = self .log @@ -154,7 +164,7 @@ impl Scheduler { } } } - self.filter_collections(collection_records) + collection_records } fn filter_collections(&mut self, collections: Vec) -> Vec { @@ -182,11 +192,35 @@ impl Scheduler { } pub(crate) async fn schedule_internal(&mut self, collection_records: Vec) { - let jobs = self - .policy - .determine(collection_records, self.max_concurrent_jobs as i32); self.job_queue.clear(); - self.job_queue.extend(jobs); + let mut scheduled_collections = Vec::new(); + for record in collection_records { + if self.oneoff_collections.contains(&record.collection_id) { + tracing::info!( + "Creating one-off compaction job for collection: {}", + record.collection_version + ); + self.job_queue.push(CompactionJob { + collection_id: record.collection_id, + tenant_id: record.tenant_id, + offset: record.offset, + collection_version: record.collection_version, + }); + self.oneoff_collections.remove(&record.collection_id); + if self.job_queue.len() == self.max_concurrent_jobs { + return; + } + } else { + scheduled_collections.push(record); + } + } + + let filtered_collections = self.filter_collections(scheduled_collections); + self.job_queue.extend( + self.policy + .determine(filtered_collections, self.max_concurrent_jobs as i32), + ); + self.job_queue.truncate(self.max_concurrent_jobs); } pub(crate) fn recompute_disabled_collections(&mut self) { @@ -315,6 +349,7 @@ mod tests { database: "database_1".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }; let tenant_2 = "tenant_2".to_string(); @@ -327,6 +362,7 @@ mod tests { database: "database_2".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }; match *sysdb { SysDb::Test(ref mut sysdb) => { @@ -539,6 +575,7 @@ mod tests { database: "database_1".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }; match *sysdb { diff --git a/rust/worker/src/compactor/types.rs b/rust/worker/src/compactor/types.rs index bcdc0f63e7c..ccdec24e2a9 100644 --- a/rust/worker/src/compactor/types.rs +++ b/rust/worker/src/compactor/types.rs @@ -9,4 +9,9 @@ pub(crate) struct CompactionJob { } #[derive(Clone, Debug)] -pub(crate) struct ScheduleMessage {} +pub struct ScheduledCompactionMessage {} + +#[derive(Clone, Debug)] +pub struct OneOffCompactionMessage { + pub collection_ids: Vec, +} diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index a469a82464b..c5b3e8d61ad 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -4,6 +4,7 @@ pub(super) mod count_records; pub mod flush_segment_writer; pub mod materialize_logs; pub(super) mod partition; +pub mod prefetch_segment; pub(super) mod register; pub mod spann_bf_pl; pub(super) mod spann_centers_search; diff --git a/rust/worker/src/execution/operators/prefetch_segment.rs b/rust/worker/src/execution/operators/prefetch_segment.rs new file mode 100644 index 00000000000..2ea54f55fb0 --- /dev/null +++ b/rust/worker/src/execution/operators/prefetch_segment.rs @@ -0,0 +1,219 @@ +use chroma_blockstore::provider::BlockfileProvider; +use chroma_error::ChromaError; +use chroma_system::{Operator, OperatorType}; +use chroma_types::{Segment, SegmentType}; +use futures::{stream::FuturesUnordered, StreamExt}; +use thiserror::Error; +use tonic::async_trait; +use uuid::Uuid; + +#[derive(Debug, Default)] +pub struct PrefetchSegmentOperator {} + +impl PrefetchSegmentOperator { + pub fn new() -> Self { + Self::default() + } +} + +#[derive(Debug)] +pub struct PrefetchSegmentInput { + segment: Segment, + blockfile_provider: BlockfileProvider, +} + +impl PrefetchSegmentInput { + pub fn new(segment: Segment, blockfile_provider: BlockfileProvider) -> Self { + Self { + segment, + blockfile_provider, + } + } +} + +#[derive(Debug)] +pub struct PrefetchSegmentOutput { + #[allow(dead_code)] + num_blocks_fetched: usize, +} + +#[derive(Debug, Error)] +pub enum PrefetchSegmentError { + #[error("Could not parse blockfile ID string: {0}")] + ParseBlockfileId(#[from] uuid::Error), + #[error("Error prefetching: {0}")] + Prefetch(#[from] Box), + #[error("Unsupported segment type: {:?}", .0)] + UnsupportedSegmentType(SegmentType), +} + +impl ChromaError for PrefetchSegmentError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + PrefetchSegmentError::ParseBlockfileId(_) => chroma_error::ErrorCodes::InvalidArgument, + PrefetchSegmentError::Prefetch(err) => err.code(), + PrefetchSegmentError::UnsupportedSegmentType(_) => { + chroma_error::ErrorCodes::InvalidArgument + } + } + } +} + +#[async_trait] +impl Operator for PrefetchSegmentOperator { + type Error = PrefetchSegmentError; + + async fn run( + &self, + input: &PrefetchSegmentInput, + ) -> Result { + if input.segment.r#type != SegmentType::BlockfileMetadata + && input.segment.r#type != SegmentType::BlockfileRecord + { + return Err(PrefetchSegmentError::UnsupportedSegmentType( + input.segment.r#type, + )); + } + + tracing::info!( + "Prefetching segment: {:?} ({:?})", + input.segment.r#type, + input.segment.id, + ); + + let mut futures = input + .segment + .file_path + .values() + .flatten() + .map(|blockfile_id| async move { + let blockfile_id = Uuid::parse_str(blockfile_id)?; + let count = input.blockfile_provider.prefetch(&blockfile_id).await?; + Ok::<_, PrefetchSegmentError>(count) + }) + .collect::>(); + + let mut total_blocks_fetched = 0; + while let Some(result) = futures.next().await { + total_blocks_fetched += result?; + } + + Ok(PrefetchSegmentOutput { + num_blocks_fetched: total_blocks_fetched, + }) + } + + fn get_type(&self) -> OperatorType { + OperatorType::IO + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::segment::materialize_logs; + use crate::segment::record_segment::{RecordSegmentReader, RecordSegmentWriter}; + use chroma_cache::new_cache_for_test; + use chroma_storage::test_storage; + use chroma_types::{Chunk, CollectionUuid, LogRecord, Operation, OperationRecord, SegmentUuid}; + use std::collections::HashMap; + use std::str::FromStr; + + #[tokio::test] + async fn test_loads_blocks_into_cache() { + let cache = new_cache_for_test(); + let blockfile_provider = + BlockfileProvider::new_arrow(test_storage(), 1000, cache, new_cache_for_test()); + + let mut record_segment = chroma_types::Segment { + id: SegmentUuid::from_str("00000000-0000-0000-0000-000000000000").expect("parse error"), + r#type: chroma_types::SegmentType::BlockfileRecord, + scope: chroma_types::SegmentScope::RECORD, + collection: CollectionUuid::from_str("00000000-0000-0000-0000-000000000000") + .expect("parse error"), + metadata: None, + file_path: HashMap::new(), + }; + { + let segment_writer = + RecordSegmentWriter::from_segment(&record_segment, &blockfile_provider) + .await + .expect("Error creating segment writer"); + let data = vec![ + LogRecord { + log_offset: 1, + record: OperationRecord { + id: "embedding_id_1".to_string(), + embedding: Some(vec![1.0, 2.0, 3.0]), + encoding: None, + metadata: None, + document: None, + operation: Operation::Add, + }, + }, + LogRecord { + log_offset: 2, + record: OperationRecord { + id: "embedding_id_2".to_string(), + embedding: Some(vec![4.0, 5.0, 6.0]), + encoding: None, + metadata: None, + document: None, + operation: Operation::Add, + }, + }, + LogRecord { + log_offset: 3, + record: OperationRecord { + id: "embedding_id_1".to_string(), + embedding: None, + encoding: None, + metadata: None, + document: None, + operation: Operation::Delete, + }, + }, + ]; + let data: Chunk = Chunk::new(data.into()); + let record_segment_reader: Option = None; + + let mat_records = materialize_logs(&record_segment_reader, data, None) + .await + .expect("Log materialization failed"); + segment_writer + .apply_materialized_log_chunk(&record_segment_reader, &mat_records) + .await + .expect("Apply materialized log failed"); + let flusher = segment_writer + .commit() + .await + .expect("Commit for segment writer failed"); + record_segment.file_path = flusher.flush().await.expect("Flush segment writer failed"); + } + + // Since our cache is write-through, this should have no effect + let input = PrefetchSegmentInput::new(record_segment.clone(), blockfile_provider.clone()); + let operator = PrefetchSegmentOperator::new(); + + let result = operator + .run(&input) + .await + .expect("Prefetch operator run failed"); + + assert_eq!(result.num_blocks_fetched, 0); + + // Clear the cache... + blockfile_provider.clear().await.unwrap(); + + // ...and now blocks should be fetched + let input = PrefetchSegmentInput::new(record_segment, blockfile_provider); + let operator = PrefetchSegmentOperator::new(); + + let result = operator + .run(&input) + .await + .expect("Prefetch operator run failed"); + + assert!(result.num_blocks_fetched > 0); + } +} diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs index 330c2975b7b..60eb01e3e15 100644 --- a/rust/worker/src/execution/operators/register.rs +++ b/rust/worker/src/execution/operators/register.rs @@ -34,6 +34,7 @@ impl RegisterOperator { /// collection version in sysdb is not the same as the current collection version, the flush /// operation will fail. /// * `segment_flush_info` - The segment flush info. +/// * `total_records_post_compaction` - The total number of records in the collection post compaction. /// * `sysdb` - The sysdb client. /// * `log` - The log client. pub struct RegisterInput { @@ -42,11 +43,13 @@ pub struct RegisterInput { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, sysdb: Box, log: Box, } impl RegisterInput { + #[allow(clippy::too_many_arguments)] /// Create a new flush sysdb input. pub fn new( tenant: String, @@ -54,6 +57,7 @@ impl RegisterInput { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, sysdb: Box, log: Box, ) -> Self { @@ -63,6 +67,7 @@ impl RegisterInput { log_position, collection_version, segment_flush_info, + total_records_post_compaction, sysdb, log, } @@ -112,6 +117,7 @@ impl Operator for RegisterOperator { input.log_position, input.collection_version, input.segment_flush_info.clone(), + input.total_records_post_compaction, ) .await; @@ -153,6 +159,7 @@ mod tests { let collection_uuid_1 = CollectionUuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(); let tenant_1 = "tenant_1".to_string(); + let total_records_post_compaction: u64 = 5; let collection_1 = Collection { collection_id: collection_uuid_1, name: "collection_1".to_string(), @@ -162,6 +169,7 @@ mod tests { database: "database_1".to_string(), log_position: 0, version: collection_version, + total_records_post_compaction, }; let collection_uuid_2 = @@ -176,6 +184,7 @@ mod tests { database: "database_2".to_string(), log_position: 0, version: collection_version, + total_records_post_compaction, }; match *sysdb { @@ -242,6 +251,7 @@ mod tests { log_position, collection_version, segment_flush_info.into(), + total_records_post_compaction, sysdb.clone(), log.clone(), ); @@ -268,6 +278,10 @@ mod tests { assert_eq!(collection.len(), 1); let collection = collection[0].clone(); assert_eq!(collection.log_position, log_position); + assert_eq!( + collection.total_records_post_compaction, + total_records_post_compaction + ); let collection_1_segments = sysdb .get_segments(None, None, None, collection_uuid_1) diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index b2659463e6b..6c7985335f4 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -21,6 +21,10 @@ use crate::execution::operators::partition::PartitionError; use crate::execution::operators::partition::PartitionInput; use crate::execution::operators::partition::PartitionOperator; use crate::execution::operators::partition::PartitionOutput; +use crate::execution::operators::prefetch_segment::PrefetchSegmentError; +use crate::execution::operators::prefetch_segment::PrefetchSegmentInput; +use crate::execution::operators::prefetch_segment::PrefetchSegmentOperator; +use crate::execution::operators::prefetch_segment::PrefetchSegmentOutput; use crate::execution::operators::register::RegisterError; use crate::execution::operators::register::RegisterInput; use crate::execution::operators::register::RegisterOperator; @@ -131,6 +135,8 @@ pub struct CompactOrchestrator { flush_results: Vec, // We track a parent span for each segment type so we can group all the spans for a given segment type (makes the resulting trace much easier to read) segment_spans: HashMap, + // Total number of records in the collection after the compaction + total_records_last_compaction: u64, } #[derive(Error, Debug)] @@ -171,6 +177,8 @@ pub enum CompactionError { MaterializeLogs(#[from] MaterializeLogOperatorError), #[error("Apply logs to segment writer error: {0}")] ApplyLogToSegmentWriter(#[from] ApplyLogToSegmentWriterOperatorError), + #[error("Prefetch segment error: {0}")] + PrefetchSegment(#[from] PrefetchSegmentError), #[error("Commit segment writer error: {0}")] CommitSegmentWriter(#[from] CommitSegmentWriterOperatorError), #[error("Flush segment writer error: {0}")] @@ -253,6 +261,7 @@ impl CompactOrchestrator { writers: OnceCell::new(), flush_results: Vec::new(), segment_spans: HashMap::new(), + total_records_last_compaction: 0, } } @@ -268,6 +277,21 @@ impl CompactOrchestrator { let input = PartitionInput::new(records, self.max_partition_size); let task = wrap(operator, input, ctx.receiver()); self.send(task, ctx).await; + + let segments = self.get_all_segments().await.unwrap(); + for segment in segments { + if segment.r#type == SegmentType::BlockfileMetadata + || segment.r#type == SegmentType::BlockfileRecord + { + let prefetch_task = wrap( + Box::new(PrefetchSegmentOperator::new()), + PrefetchSegmentInput::new(segment, self.blockfile_provider.clone()), + ctx.receiver(), + ); + + self.send(prefetch_task, ctx).await; + } + } } async fn materialize_log( @@ -471,6 +495,7 @@ impl CompactOrchestrator { log_position, self.compaction_job.collection_version, self.flush_results.clone().into(), + self.total_records_last_compaction, self.sysdb.clone(), self.log.clone(), ); @@ -721,6 +746,22 @@ impl Handler> for CompactOrchestrator } } +#[async_trait] +impl Handler> for CompactOrchestrator { + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + match self.ok_or_terminate(message.into_inner(), ctx) { + Some(_) => (), + None => return, + } + } +} + #[async_trait] impl Handler> for CompactOrchestrator { type Result = (); @@ -840,7 +881,13 @@ impl Handler return, }; - self.dispatch_segment_flush(message.flusher, ctx.receiver(), ctx) + let flusher = message.flusher; + // If the flusher recieved is a record segment flusher, get the number of keys for the blockfile and set it on the orchestrator + if let ChromaSegmentFlusher::RecordSegment(ref record_segment_flusher) = flusher { + self.total_records_last_compaction = record_segment_flusher.count(); + } + + self.dispatch_segment_flush(flusher, ctx.receiver(), ctx) .await; } } diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs index effe97a812c..6fd187e3f96 100644 --- a/rust/worker/src/lib.rs +++ b/rust/worker/src/lib.rs @@ -6,6 +6,9 @@ mod tracing; mod utils; use chroma_config::Configurable; +use clap::Parser; +use compactor::compaction_client::CompactionClient; +use compactor::compaction_server::CompactionServer; use memberlist::MemberlistProvider; use tokio::select; @@ -131,6 +134,15 @@ pub async fn compaction_service_entrypoint() { let mut memberlist_handle = system.start_component(memberlist); + let compaction_server = CompactionServer { + manager: compaction_manager_handle.clone(), + port: config.my_port, + }; + + let server_join_handle = tokio::spawn(async move { + let _ = compaction_server.run().await; + }); + let mut sigterm = match signal(SignalKind::terminate()) { Ok(sigterm) => sigterm, Err(e) => { @@ -151,7 +163,15 @@ pub async fn compaction_service_entrypoint() { let _ = compaction_manager_handle.join().await; system.stop().await; system.join().await; + let _ = server_join_handle.await; }, }; println!("Server stopped"); } + +pub async fn compaction_client_entrypoint() { + let client = CompactionClient::parse(); + if let Err(e) = client.run().await { + eprintln!("{e}"); + } +} diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs index 72eaff4444c..c77a4152f13 100644 --- a/rust/worker/src/segment/metadata_segment.rs +++ b/rust/worker/src/segment/metadata_segment.rs @@ -1387,6 +1387,8 @@ mod test { .commit() .await .expect("Commit for segment writer failed"); + let count = record_flusher.count(); + assert_eq!(count, 2_u64); let metadata_flusher = metadata_writer .commit() .await diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index 4d28487e47c..83fe6a3fb15 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -632,6 +632,10 @@ impl RecordSegmentFlusher { Ok(flushed_files) } + + pub(crate) fn count(&self) -> u64 { + self.id_to_user_id_flusher.count() + } } #[derive(Clone)] diff --git a/rust/worker/src/segment/test.rs b/rust/worker/src/segment/test.rs index c104a9ddd32..ac608d6f046 100644 --- a/rust/worker/src/segment/test.rs +++ b/rust/worker/src/segment/test.rs @@ -36,6 +36,7 @@ impl TestSegment { database: String::new(), log_position: 0, version: 0, + total_records_post_compaction: 0, }; Self { blockfile_provider: test_arrow_blockfile_provider(2 << 22), diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 828b575cf73..187c331e078 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -451,6 +451,7 @@ mod tests { database: "test-database".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }), knn: Some(chroma_proto::Segment { id: Uuid::new_v4().to_string(), @@ -553,6 +554,7 @@ mod tests { database: "test-database".to_string(), log_position: 0, version: 0, + total_records_post_compaction: 0, }); let request = chroma_proto::GetPlan { scan: Some(scan_operator.clone()), diff --git a/rust/worker/src/utils/convert.rs b/rust/worker/src/utils/convert.rs index cb14475b806..2a2f111ebbf 100644 --- a/rust/worker/src/utils/convert.rs +++ b/rust/worker/src/utils/convert.rs @@ -1,14 +1,19 @@ +use std::str::FromStr; + use chroma_types::{ chroma_proto::{self, GetResult, KnnBatchResult, KnnResult}, - ConversionError, ScalarEncoding, Where, + CollectionUuid, ConversionError, ScalarEncoding, Where, }; -use crate::execution::operators::{ - filter::FilterOperator, - knn::KnnOperator, - knn_projection::{KnnProjectionOperator, KnnProjectionOutput, KnnProjectionRecord}, - limit::LimitOperator, - projection::{ProjectionOperator, ProjectionOutput, ProjectionRecord}, +use crate::{ + compactor::OneOffCompactionMessage, + execution::operators::{ + filter::FilterOperator, + knn::KnnOperator, + knn_projection::{KnnProjectionOperator, KnnProjectionOutput, KnnProjectionRecord}, + limit::LimitOperator, + projection::{ProjectionOperator, ProjectionOutput, ProjectionRecord}, + }, }; impl TryFrom for FilterOperator { @@ -110,7 +115,7 @@ impl TryFrom for chroma_proto::KnnProjectionRecord { type Error = ConversionError; fn try_from(value: KnnProjectionRecord) -> Result { - Ok(chroma_proto::KnnProjectionRecord { + Ok(Self { record: Some(value.record.try_into()?), distance: value.distance, }) @@ -121,7 +126,7 @@ impl TryFrom for KnnResult { type Error = ConversionError; fn try_from(value: KnnProjectionOutput) -> Result { - Ok(KnnResult { + Ok(Self { records: value .records .into_iter() @@ -154,3 +159,20 @@ pub fn to_proto_knn_batch_result( .collect::>()?, }) } + +impl TryFrom for OneOffCompactionMessage { + type Error = ConversionError; + + fn try_from(value: chroma_proto::CompactionRequest) -> Result { + Ok(Self { + collection_ids: value + .ids + .ok_or(ConversionError::DecodeError)? + .ids + .into_iter() + .map(|id| CollectionUuid::from_str(&id)) + .collect::>() + .map_err(|_| ConversionError::DecodeError)?, + }) + } +}