[ENH] Round robin grpc connections amongst N query nodes (#3454)
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - Retries requests in a round-robin fashion, preserving the existing tracing and configuration of the retry interceptor.
   - Retries requests across `query_replication_factor` nodes in a round-robin fashion (a conceptual sketch follows this list).
 - New functionality
   - None
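
For intuition, here is a minimal, self-contained sketch of the round-robin retry idea. The helper name `round_robin_retry` is hypothetical; the actual implementation in `DistributedExecutor._round_robin_retry` below additionally uses tenacity for exponential backoff with jitter, only retries specific gRPC status codes, and emits tracing spans.

```python
from typing import Callable, List, Optional, TypeVar

TIn = TypeVar("TIn")
TOut = TypeVar("TOut")


def round_robin_retry(
    funcs: List[Callable[[TIn], TOut]], args: TIn, max_attempts: int = 5
) -> TOut:
    """Call each candidate in turn, wrapping around, until one succeeds."""
    last_error: Optional[Exception] = None
    for attempt in range(max_attempts):
        try:
            # Attempt 0 hits funcs[0], attempt 1 hits funcs[1], and so on,
            # wrapping back to the first replica once the list is exhausted.
            return funcs[attempt % len(funcs)](args)
        except Exception as e:
            last_error = e
    assert last_error is not None
    raise last_error
```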

<img width="1728" alt="Screenshot 2025-01-14 at 3 18 01 PM"
src="https://github.com/user-attachments/assets/0cd177a0-4b7c-4bab-ba94-22b585001567"
/>



## Test plan
*How are these changes tested?*
Testing retries automatically would be nice, but it requires more investment than it's worth right now. I manually verified the behavior and cut a follow-up task; a hypothetical test sketch follows the checklist below.
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust
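
For reference, a hypothetical unit test along these lines could exercise the failover path with fake stubs. None of the names below (`FakeRpcError`, `test_round_robin_retry_fails_over`) exist in this PR, and constructing the executor via `object.__new__` is only a shortcut to reach `_round_robin_retry` without wiring up a full `System`.

```python
import grpc

from chromadb.execution.executor.distributed import DistributedExecutor


class FakeRpcError(grpc.RpcError):
    """Stand-in for a gRPC error carrying an UNAVAILABLE status code."""

    def code(self) -> grpc.StatusCode:
        return grpc.StatusCode.UNAVAILABLE


def test_round_robin_retry_fails_over() -> None:
    calls: list[str] = []

    def replica_0(_args: object) -> str:
        calls.append("replica-0")
        raise FakeRpcError()

    def replica_1(_args: object) -> str:
        calls.append("replica-1")
        return "ok"

    # Bypass __init__ since _round_robin_retry does not touch instance state.
    executor = object.__new__(DistributedExecutor)
    assert executor._round_robin_retry([replica_0, replica_1], None) == "ok"
    assert calls == ["replica-0", "replica-1"]
```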

## Documentation Changes
None
HammadB authored Jan 15, 2025
1 parent 6ae644d commit 4442978
Showing 5 changed files with 121 additions and 48 deletions.
1 change: 1 addition & 0 deletions chromadb/config.py
@@ -264,6 +264,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
"chromadb.segment.impl.manager.local.LocalSegmentManager"
)
chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor"
chroma_query_replication_factor: int = 2

chroma_logservice_host = "localhost"
chroma_logservice_port = 50052
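
Usage note (not part of the diff): the new setting can be overridden like any other chroma setting; a minimal sketch assuming the standard `Settings` entry point in `chromadb/config.py`:

```python
from chromadb.config import Settings

# Fan query requests out across 3 query-node replicas instead of the default 2.
settings = Settings(chroma_query_replication_factor=3)
```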
102 changes: 82 additions & 20 deletions chromadb/execution/executor/distributed.py
@@ -1,5 +1,6 @@
import threading
from typing import Dict, Optional
import random
from typing import Callable, Dict, List, Optional, TypeVar
import grpc
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
@@ -11,6 +12,14 @@
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from tenacity import (
RetryCallState,
Retrying,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception,
)
from opentelemetry.trace import Span


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
@@ -34,37 +43,86 @@ def _uri(metadata: Optional[Metadata]) -> Optional[str]:
return None


# Type variables for input and output types of the round-robin retry function
I = TypeVar("I") # noqa: E741
O = TypeVar("O") # noqa: E741


class DistributedExecutor(Executor):
_mtx: threading.Lock
_grpc_stub_pool: Dict[str, QueryExecutorStub]
_manager: DistributedSegmentManager
_request_timeout_seconds: int
_query_replication_factor: int

def __init__(self, system: System):
super().__init__(system)
self._mtx = threading.Lock()
self._grpc_stub_pool = dict()
self._grpc_stub_pool = {}
self._manager = self.require(DistributedSegmentManager)
self._request_timeout_seconds = system.settings.require(
"chroma_query_request_timeout_seconds"
)
self._query_replication_factor = system.settings.require(
"chroma_query_replication_factor"
)

def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
"""
Retry a list of functions in a round-robin fashion until one of them succeeds.
funcs: List of functions to retry
args: Arguments to pass to each function
"""
attempt_count = 0
sleep_span: Optional[Span] = None

def before_sleep(_: RetryCallState) -> None:
# HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
# This should really use our component system to get the tracer. Since our grpc utils use this pattern
# we are copying it here. This should be removed once we have a better way to get the tracer
from chromadb.telemetry.opentelemetry import tracer

nonlocal sleep_span
if tracer is not None:
sleep_span = tracer.start_span("Waiting to retry RPC")

for attempt in Retrying(
stop=stop_after_attempt(5),
wait=wait_exponential_jitter(0.1, jitter=0.1),
reraise=True,
retry=retry_if_exception(
lambda x: isinstance(x, grpc.RpcError)
and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
),
before_sleep=before_sleep,
):
if sleep_span is not None:
sleep_span.end()
sleep_span = None

with attempt:
return funcs[attempt_count % len(funcs)](args)
attempt_count += 1

# NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
raise Exception("Unreachable code error - should never reach here")

@overrides
def count(self, plan: CountPlan) -> int:
executor = self._grpc_executor_stub(plan.scan)
try:
count_result = executor.Count(convert.to_proto_count_plan(plan))
except grpc.RpcError as rpc_error:
raise rpc_error
endpoints = self._get_grpc_endpoints(plan.scan)
count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
count_result = self._round_robin_retry(
count_funcs, convert.to_proto_count_plan(plan)
)
return convert.from_proto_count_result(count_result)

@overrides
def get(self, plan: GetPlan) -> GetResult:
executor = self._grpc_executor_stub(plan.scan)
try:
get_result = executor.Get(convert.to_proto_get_plan(plan))
except grpc.RpcError as rpc_error:
raise rpc_error
endpoints = self._get_grpc_endpoints(plan.scan)
get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
records = convert.from_proto_get_result(get_result)

ids = [record["id"] for record in records]
@@ -102,11 +160,9 @@ def get(self, plan: GetPlan) -> GetResult:

@overrides
def knn(self, plan: KNNPlan) -> QueryResult:
executor = self._grpc_executor_stub(plan.scan)
try:
knn_result = executor.KNN(convert.to_proto_knn_plan(plan))
except grpc.RpcError as rpc_error:
raise rpc_error
endpoints = self._get_grpc_endpoints(plan.scan)
knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
results = convert.from_proto_knn_batch_result(knn_result)

ids = [[record["record"]["id"] for record in records] for records in results]
@@ -160,10 +216,17 @@ def knn(self, plan: KNNPlan) -> QueryResult:
included=plan.projection.included,
)

def _grpc_executor_stub(self, scan: Scan) -> QueryExecutorStub:
def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
# Since the grpc endpoint is determined by the collection uuid,
# the endpoint should be the same for all segments of the same collection
grpc_url = self._manager.get_endpoint(scan.record)
grpc_urls = self._manager.get_endpoints(
scan.record, self._query_replication_factor
)
# Shuffle the grpc urls to distribute the load evenly
random.shuffle(grpc_urls)
return grpc_urls

def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
with self._mtx:
if grpc_url not in self._grpc_stub_pool:
channel = grpc.insecure_channel(
@@ -172,5 +235,4 @@ def _grpc_executor_stub(self, scan: Scan) -> QueryExecutorStub:
interceptors = [OtelInterceptor()]
channel = grpc.intercept_channel(channel, *interceptors)
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)

return self._grpc_stub_pool[grpc_url]
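
To summarize the executor flow above: each request resolves up to `chroma_query_replication_factor` endpoints, shuffles them so first attempts spread across replicas, and fails over between them only on retriable gRPC errors. A sketch of that retry predicate, pulled out of the lambda above for readability (illustrative name, same logic):

```python
import grpc


def is_retriable(err: BaseException) -> bool:
    # Only connection-level failures are worth sending to another replica.
    return isinstance(err, grpc.RpcError) and err.code() in [
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.UNKNOWN,
    ]
```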
8 changes: 5 additions & 3 deletions chromadb/segment/distributed/__init__.py
@@ -9,11 +9,13 @@

class SegmentDirectory(Component):
"""A segment directory is a data interface that manages the location of segments. Concretely, this
means that for clustered chroma, it provides the grpc endpoint for a segment."""
means that for distributed chroma, it provides the grpc endpoint for a segment."""

@abstractmethod
def get_segment_endpoint(self, segment: Segment) -> str:
"""Return the segment residence for a given segment ID"""
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
"""Return the segment residences for a given segment ID. Will return at most n residences.
Should only return less than n residences if there are less than n residences available.
"""

@abstractmethod
def register_updated_segment_callback(
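
The interface change above turns a single-endpoint lookup into a top-`n` lookup. A hypothetical in-memory stand-in with the same shape (not part of this PR, useful mainly as a test double):

```python
from typing import List


class StaticSegmentDirectory:
    """Hypothetical stand-in mirroring the updated get_segment_endpoints contract."""

    def __init__(self, endpoints: List[str]) -> None:
        self._endpoints = endpoints

    def get_segment_endpoints(self, segment: object, n: int) -> List[str]:
        # `segment` stands in for chromadb.types.Segment.
        # Return at most n endpoints; fewer only if fewer are available.
        return self._endpoints[: min(n, len(self._endpoints))]
```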
52 changes: 30 additions & 22 deletions chromadb/segment/impl/distributed/segment_directory.py
@@ -1,7 +1,6 @@
import threading
import time
from typing import Any, Callable, Dict, Optional, cast

from typing import Any, Callable, Dict, List, Optional, cast
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
from overrides import EnforceOverrides, override
@@ -254,54 +253,63 @@ def stop(self) -> None:
return super().stop()

@override
def get_segment_endpoint(self, segment: Segment) -> str:
def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
if self._curr_memberlist is None or len(self._curr_memberlist) == 0:
raise ValueError("Memberlist is not initialized")

# assign() will throw an error if n is greater than the number of members
# clamp n to the number of members to align with the contract of this method
# which is to return at most n endpoints
n = min(n, len(self._curr_memberlist))

# Check if all members in the memberlist have a node set,
# if so, route using the node

# NOTE(@hammadb) 1/8/2024: This is to handle the migration between routing
# using the member id and routing using the node name
# We want to route using the node name over the member id
# because the node may have a disk cache that we want a
# stable identifier for over deploys.
can_use_node_routing = all(
[m.node != "" and len(m.node) != 0 for m in self._curr_memberlist]
can_use_node_routing = (
all([m.node != "" and len(m.node) != 0 for m in self._curr_memberlist])
and self._routing_mode == RoutingMode.NODE
)
if can_use_node_routing and self._routing_mode == RoutingMode.NODE:
if can_use_node_routing:
# If we are using node routing, assign using the node names
assignment = assign(
assignments = assign(
segment["collection"].hex,
[m.node for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]
n,
)
else:
# Query to the same collection should end up on the same endpoint
assignment = assign(
assignments = assign(
segment["collection"].hex,
[m.id for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]

service_name = self.extract_service_name(assignment)
# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
n,
)
assignments_set = set(assignments)
out_endpoints = []
for member in self._curr_memberlist:
is_chosen_with_node_routing = (
can_use_node_routing and member.node == assignment
can_use_node_routing and member.node in assignments_set
)
is_chosen_with_id_routing = (
not can_use_node_routing and member.id == assignment
not can_use_node_routing and member.id in assignments_set
)
if is_chosen_with_node_routing or is_chosen_with_id_routing:
# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
if member.ip is not None and member.ip != "":
endpoint = f"{member.ip}:50051"
return endpoint

endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return endpoint
out_endpoints.append(endpoint)
else:
service_name = self.extract_service_name(member.id)
endpoint = f"{member.id}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"
out_endpoints.append(endpoint)
return out_endpoints

@override
def register_updated_segment_callback(
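
For intuition about the routing above: `assign` with `n` and `murmur3hasher` is a rendezvous-style assignment, so a given collection consistently maps to the same `n` members as long as the memberlist is stable. A simplified, self-contained sketch of that idea (the hash and names are illustrative, not chroma's actual `assign`/`murmur3hasher`):

```python
import hashlib
from typing import List


def assign_top_n(key: str, members: List[str], n: int) -> List[str]:
    # Rendezvous (highest-random-weight) hashing: score every member against the
    # key and keep the n best. A fixed key maps to a stable member set, and
    # removing one member only reshuffles the keys that were assigned to it.
    def score(member: str) -> int:
        digest = hashlib.sha256(f"{key}|{member}".encode()).digest()
        return int.from_bytes(digest[:8], "big")

    return sorted(members, key=score, reverse=True)[:n]


# Example: a collection uuid hex consistently lands on the same 2 of 3 query nodes.
nodes = ["query-service-0", "query-service-1", "query-service-2"]
print(assign_top_n("8f14e45fcollectionhex", nodes, 2))
```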
6 changes: 3 additions & 3 deletions chromadb/segment/impl/manager/distributed.py
@@ -1,5 +1,5 @@
from threading import Lock
from typing import Dict, Sequence
from typing import Dict, List, Sequence
from uuid import UUID, uuid4

from overrides import override
@@ -87,8 +87,8 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
"DistributedSegmentManager.get_endpoint",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def get_endpoint(self, segment: Segment) -> str:
return self._segment_directory.get_segment_endpoint(segment)
def get_endpoints(self, segment: Segment, n: int) -> List[str]:
return self._segment_directory.get_segment_endpoints(segment, n)

@trace_method(
"DistributedSegmentManager.hint_use_collection",
