diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 2c095b2b291..0f1c28081c1 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -4,7 +4,7 @@ import numpy as np from numpy.typing import NDArray -import chromadb.proto.chroma_pb2 as chroma_pb +import chromadb.proto.chroma_pb2 as query_pb import chromadb.proto.query_executor_pb2 as query_pb from chromadb.api.configuration import CollectionConfigurationInternal from chromadb.api.types import Embedding, Where, WhereDocument @@ -22,14 +22,12 @@ Metadata, Operation, OperationRecord, - RequestVersionContext, ScalarEncoding, Segment, SegmentScope, SeqId, UpdateMetadata, Vector, - VectorQueryResult, ) @@ -46,60 +44,60 @@ class KNNProjectionRecord(TypedDict): # TODO: Unit tests for this file, handling optional states etc -def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vector: +def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> query_pb.Vector: if encoding == ScalarEncoding.FLOAT32: as_bytes = np.array(vector, dtype=np.float32).tobytes() - proto_encoding = chroma_pb.ScalarEncoding.FLOAT32 + proto_encoding = query_pb.ScalarEncoding.FLOAT32 elif encoding == ScalarEncoding.INT32: as_bytes = np.array(vector, dtype=np.int32).tobytes() - proto_encoding = chroma_pb.ScalarEncoding.INT32 + proto_encoding = query_pb.ScalarEncoding.INT32 else: raise ValueError( f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \ or {ScalarEncoding.INT32}" ) - return chroma_pb.Vector(dimension=vector.size, vector=as_bytes, encoding=proto_encoding) + return query_pb.Vector(dimension=vector.size, vector=as_bytes, encoding=proto_encoding) -def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]: +def from_proto_vector(vector: query_pb.Vector) -> Tuple[Embedding, ScalarEncoding]: encoding = vector.encoding as_array: Union[NDArray[np.int32], NDArray[np.float32]] - if encoding == chroma_pb.ScalarEncoding.FLOAT32: + if encoding == query_pb.ScalarEncoding.FLOAT32: as_array = np.frombuffer(vector.vector, dtype=np.float32) out_encoding = ScalarEncoding.FLOAT32 - elif encoding == chroma_pb.ScalarEncoding.INT32: + elif encoding == query_pb.ScalarEncoding.INT32: as_array = np.frombuffer(vector.vector, dtype=np.int32) out_encoding = ScalarEncoding.INT32 else: raise ValueError( f"Unknown encoding {encoding}, expected one of \ - {chroma_pb.ScalarEncoding.FLOAT32} or {chroma_pb.ScalarEncoding.INT32}" + {query_pb.ScalarEncoding.FLOAT32} or {query_pb.ScalarEncoding.INT32}" ) return (as_array, out_encoding) -def from_proto_operation(operation: chroma_pb.Operation) -> Operation: - if operation == chroma_pb.Operation.ADD: +def from_proto_operation(operation: query_pb.Operation) -> Operation: + if operation == query_pb.Operation.ADD: return Operation.ADD - elif operation == chroma_pb.Operation.UPDATE: + elif operation == query_pb.Operation.UPDATE: return Operation.UPDATE - elif operation == chroma_pb.Operation.UPSERT: + elif operation == query_pb.Operation.UPSERT: return Operation.UPSERT - elif operation == chroma_pb.Operation.DELETE: + elif operation == query_pb.Operation.DELETE: return Operation.DELETE else: # TODO: full error raise RuntimeError(f"Unknown operation {operation}") -def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata]: +def from_proto_metadata(metadata: query_pb.UpdateMetadata) -> Optional[Metadata]: return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False)) def from_proto_update_metadata( - metadata: chroma_pb.UpdateMetadata, + metadata: query_pb.UpdateMetadata, ) -> Optional[UpdateMetadata]: return cast( Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True) @@ -107,7 +105,7 @@ def from_proto_update_metadata( def _from_proto_metadata_handle_none( - metadata: chroma_pb.UpdateMetadata, is_update: bool + metadata: query_pb.UpdateMetadata, is_update: bool ) -> Optional[Union[UpdateMetadata, Metadata]]: if not metadata.metadata: return None @@ -128,14 +126,14 @@ def _from_proto_metadata_handle_none( return out_metadata -def to_proto_update_metadata(metadata: UpdateMetadata) -> chroma_pb.UpdateMetadata: - return chroma_pb.UpdateMetadata( +def to_proto_update_metadata(metadata: UpdateMetadata) -> query_pb.UpdateMetadata: + return query_pb.UpdateMetadata( metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()} ) def from_proto_submit( - operation_record: chroma_pb.OperationRecord, seq_id: SeqId + operation_record: query_pb.OperationRecord, seq_id: SeqId ) -> LogRecord: embedding, encoding = from_proto_vector(operation_record.vector) record = LogRecord( @@ -151,7 +149,7 @@ def from_proto_submit( return record -def from_proto_segment(segment: chroma_pb.Segment) -> Segment: +def from_proto_segment(segment: query_pb.Segment) -> Segment: return Segment( id=UUID(hex=segment.id), type=segment.type, @@ -164,8 +162,8 @@ def from_proto_segment(segment: chroma_pb.Segment) -> Segment: ) -def to_proto_segment(segment: Segment) -> chroma_pb.Segment: - return chroma_pb.Segment( +def to_proto_segment(segment: Segment) -> query_pb.Segment: + return query_pb.Segment( id=segment["id"].hex, type=segment["type"], scope=to_proto_segment_scope(segment["scope"]), @@ -173,49 +171,49 @@ def to_proto_segment(segment: Segment) -> chroma_pb.Segment: metadata=None if segment["metadata"] is None else to_proto_update_metadata(segment["metadata"]), - file_paths={name: chroma_pb.FilePaths(paths=paths) for name, paths in segment["file_paths"].items()} + file_paths={name: query_pb.FilePaths(paths=paths) for name, paths in segment["file_paths"].items()} ) -def from_proto_segment_scope(segment_scope: chroma_pb.SegmentScope) -> SegmentScope: - if segment_scope == chroma_pb.SegmentScope.VECTOR: +def from_proto_segment_scope(segment_scope: query_pb.SegmentScope) -> SegmentScope: + if segment_scope == query_pb.SegmentScope.VECTOR: return SegmentScope.VECTOR - elif segment_scope == chroma_pb.SegmentScope.METADATA: + elif segment_scope == query_pb.SegmentScope.METADATA: return SegmentScope.METADATA - elif segment_scope == chroma_pb.SegmentScope.RECORD: + elif segment_scope == query_pb.SegmentScope.RECORD: return SegmentScope.RECORD else: raise RuntimeError(f"Unknown segment scope {segment_scope}") -def to_proto_segment_scope(segment_scope: SegmentScope) -> chroma_pb.SegmentScope: +def to_proto_segment_scope(segment_scope: SegmentScope) -> query_pb.SegmentScope: if segment_scope == SegmentScope.VECTOR: - return chroma_pb.SegmentScope.VECTOR + return query_pb.SegmentScope.VECTOR elif segment_scope == SegmentScope.METADATA: - return chroma_pb.SegmentScope.METADATA + return query_pb.SegmentScope.METADATA elif segment_scope == SegmentScope.RECORD: - return chroma_pb.SegmentScope.RECORD + return query_pb.SegmentScope.RECORD else: raise RuntimeError(f"Unknown segment scope {segment_scope}") def to_proto_metadata_update_value( value: Union[str, int, float, bool, None] -) -> chroma_pb.UpdateMetadataValue: +) -> query_pb.UpdateMetadataValue: # Be careful with the order here. Since bools are a subtype of int in python, # isinstance(value, bool) and isinstance(value, int) both return true # for a value of bool type. if isinstance(value, bool): - return chroma_pb.UpdateMetadataValue(bool_value=value) + return query_pb.UpdateMetadataValue(bool_value=value) elif isinstance(value, str): - return chroma_pb.UpdateMetadataValue(string_value=value) + return query_pb.UpdateMetadataValue(string_value=value) elif isinstance(value, int): - return chroma_pb.UpdateMetadataValue(int_value=value) + return query_pb.UpdateMetadataValue(int_value=value) elif isinstance(value, float): - return chroma_pb.UpdateMetadataValue(float_value=value) + return query_pb.UpdateMetadataValue(float_value=value) # None is used to delete the metadata key. elif value is None: - return chroma_pb.UpdateMetadataValue() + return query_pb.UpdateMetadataValue() else: raise ValueError( f"Unknown metadata value type {type(value)}, expected one of str, int, \ @@ -223,7 +221,7 @@ def to_proto_metadata_update_value( ) -def from_proto_collection(collection: chroma_pb.Collection) -> Collection: +def from_proto_collection(collection: query_pb.Collection) -> Collection: return Collection( id=UUID(hex=collection.id), name=collection.name, @@ -243,8 +241,8 @@ def from_proto_collection(collection: chroma_pb.Collection) -> Collection: ) -def to_proto_collection(collection: Collection) -> chroma_pb.Collection: - return chroma_pb.Collection( +def to_proto_collection(collection: Collection) -> query_pb.Collection: + return query_pb.Collection( id=collection["id"].hex, name=collection["name"], configuration_json_str=collection.get_configuration().to_json_str(), @@ -259,15 +257,15 @@ def to_proto_collection(collection: Collection) -> chroma_pb.Collection: ) -def to_proto_operation(operation: Operation) -> chroma_pb.Operation: +def to_proto_operation(operation: Operation) -> query_pb.Operation: if operation == Operation.ADD: - return chroma_pb.Operation.ADD + return query_pb.Operation.ADD elif operation == Operation.UPDATE: - return chroma_pb.Operation.UPDATE + return query_pb.Operation.UPDATE elif operation == Operation.UPSERT: - return chroma_pb.Operation.UPSERT + return query_pb.Operation.UPSERT elif operation == Operation.DELETE: - return chroma_pb.Operation.DELETE + return query_pb.Operation.DELETE else: raise ValueError( f"Unknown operation {operation}, expected one of {Operation.ADD}, \ @@ -277,7 +275,7 @@ def to_proto_operation(operation: Operation) -> chroma_pb.Operation: def to_proto_submit( submit_record: OperationRecord, -) -> chroma_pb.OperationRecord: +) -> query_pb.OperationRecord: vector = None if submit_record["embedding"] is not None and submit_record["encoding"] is not None: vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"]) @@ -286,7 +284,7 @@ def to_proto_submit( if submit_record["metadata"] is not None: metadata = to_proto_update_metadata(submit_record["metadata"]) - return chroma_pb.OperationRecord( + return query_pb.OperationRecord( id=submit_record["id"], vector=vector, metadata=metadata, @@ -295,7 +293,7 @@ def to_proto_submit( def to_proto_where(where: Where) -> query_pb.Where: - response = chroma_pb.Where() + response = query_pb.Where() if len(where) != 1: raise ValueError(f"Expected where to have exactly one operator, got {where}") @@ -308,13 +306,13 @@ def to_proto_where(where: Where) -> query_pb.Where: raise ValueError( f"Expected where value for $and or $or to be a list of where expressions, got {value}" ) - children: chroma_pb.WhereChildren = chroma_pb.WhereChildren( + children: query_pb.WhereChildren = query_pb.WhereChildren( children=[to_proto_where(w) for w in value] ) if key == "$and": - children.operator = chroma_pb.BooleanOperator.AND + children.operator = query_pb.BooleanOperator.AND else: - children.operator = chroma_pb.BooleanOperator.OR + children.operator = query_pb.BooleanOperator.OR response.children.CopyFrom(children) return response @@ -322,30 +320,30 @@ def to_proto_where(where: Where) -> query_pb.Where: # At this point we know we're at a direct comparison. It can either # be of the form {"key": "value"} or {"key": {"$operator": "value"}}. - dc = chroma_pb.DirectComparison() + dc = query_pb.DirectComparison() dc.key = key if not isinstance(value, dict): # {'key': 'value'} case if type(value) is str: - ssc = chroma_pb.SingleStringComparison() + ssc = query_pb.SingleStringComparison() ssc.value = value - ssc.comparator = chroma_pb.GenericComparator.EQ + ssc.comparator = query_pb.GenericComparator.EQ dc.single_string_operand.CopyFrom(ssc) elif type(value) is bool: - sbc = chroma_pb.SingleBoolComparison() + sbc = query_pb.SingleBoolComparison() sbc.value = value - sbc.comparator = chroma_pb.GenericComparator.EQ + sbc.comparator = query_pb.GenericComparator.EQ dc.single_bool_operand.CopyFrom(sbc) elif type(value) is int: - sic = chroma_pb.SingleIntComparison() + sic = query_pb.SingleIntComparison() sic.value = value - sic.generic_comparator = chroma_pb.GenericComparator.EQ + sic.generic_comparator = query_pb.GenericComparator.EQ dc.single_int_operand.CopyFrom(sic) elif type(value) is float: - sdc = chroma_pb.SingleDoubleComparison() + sdc = query_pb.SingleDoubleComparison() sdc.value = value - sdc.generic_comparator = chroma_pb.GenericComparator.EQ + sdc.generic_comparator = query_pb.GenericComparator.EQ dc.single_double_operand.CopyFrom(sdc) else: raise ValueError( @@ -367,29 +365,29 @@ def to_proto_where(where: Where) -> query_pb.Where: ) list_operator = None if operator == "$in": - list_operator = chroma_pb.ListOperator.IN + list_operator = query_pb.ListOperator.IN else: - list_operator = chroma_pb.ListOperator.NIN + list_operator = query_pb.ListOperator.NIN if type(operand[0]) is str: - slo = chroma_pb.StringListComparison() + slo = query_pb.StringListComparison() for x in operand: slo.values.extend([x]) # type: ignore slo.list_operator = list_operator dc.string_list_operand.CopyFrom(slo) elif type(operand[0]) is bool: - blo = chroma_pb.BoolListComparison() + blo = query_pb.BoolListComparison() for x in operand: blo.values.extend([x]) # type: ignore blo.list_operator = list_operator dc.bool_list_operand.CopyFrom(blo) elif type(operand[0]) is int: - ilo = chroma_pb.IntListComparison() + ilo = query_pb.IntListComparison() for x in operand: ilo.values.extend([x]) # type: ignore ilo.list_operator = list_operator dc.int_list_operand.CopyFrom(ilo) elif type(operand[0]) is float: - dlo = chroma_pb.DoubleListComparison() + dlo = query_pb.DoubleListComparison() for x in operand: dlo.values.extend([x]) # type: ignore dlo.list_operator = list_operator @@ -401,64 +399,64 @@ def to_proto_where(where: Where) -> query_pb.Where: elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]: # Direct comparison to a single value. if type(operand) is str: - ssc = chroma_pb.SingleStringComparison() + ssc = query_pb.SingleStringComparison() ssc.value = operand if operator == "$eq": - ssc.comparator = chroma_pb.GenericComparator.EQ + ssc.comparator = query_pb.GenericComparator.EQ elif operator == "$ne": - ssc.comparator = chroma_pb.GenericComparator.NE + ssc.comparator = query_pb.GenericComparator.NE else: raise ValueError( f"Expected where operator to be $eq or $ne, got {operator}" ) dc.single_string_operand.CopyFrom(ssc) elif type(operand) is bool: - sbc = chroma_pb.SingleBoolComparison() + sbc = query_pb.SingleBoolComparison() sbc.value = operand if operator == "$eq": - sbc.comparator = chroma_pb.GenericComparator.EQ + sbc.comparator = query_pb.GenericComparator.EQ elif operator == "$ne": - sbc.comparator = chroma_pb.GenericComparator.NE + sbc.comparator = query_pb.GenericComparator.NE else: raise ValueError( f"Expected where operator to be $eq or $ne, got {operator}" ) dc.single_bool_operand.CopyFrom(sbc) elif type(operand) is int: - sic = chroma_pb.SingleIntComparison() + sic = query_pb.SingleIntComparison() sic.value = operand if operator == "$eq": - sic.generic_comparator = chroma_pb.GenericComparator.EQ + sic.generic_comparator = query_pb.GenericComparator.EQ elif operator == "$ne": - sic.generic_comparator = chroma_pb.GenericComparator.NE + sic.generic_comparator = query_pb.GenericComparator.NE elif operator == "$gt": - sic.number_comparator = chroma_pb.NumberComparator.GT + sic.number_comparator = query_pb.NumberComparator.GT elif operator == "$lt": - sic.number_comparator = chroma_pb.NumberComparator.LT + sic.number_comparator = query_pb.NumberComparator.LT elif operator == "$gte": - sic.number_comparator = chroma_pb.NumberComparator.GTE + sic.number_comparator = query_pb.NumberComparator.GTE elif operator == "$lte": - sic.number_comparator = chroma_pb.NumberComparator.LTE + sic.number_comparator = query_pb.NumberComparator.LTE else: raise ValueError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" ) dc.single_int_operand.CopyFrom(sic) elif type(operand) is float: - sfc = chroma_pb.SingleDoubleComparison() + sfc = query_pb.SingleDoubleComparison() sfc.value = operand if operator == "$eq": - sfc.generic_comparator = chroma_pb.GenericComparator.EQ + sfc.generic_comparator = query_pb.GenericComparator.EQ elif operator == "$ne": - sfc.generic_comparator = chroma_pb.GenericComparator.NE + sfc.generic_comparator = query_pb.GenericComparator.NE elif operator == "$gt": - sfc.number_comparator = chroma_pb.NumberComparator.GT + sfc.number_comparator = query_pb.NumberComparator.GT elif operator == "$lt": - sfc.number_comparator = chroma_pb.NumberComparator.LT + sfc.number_comparator = query_pb.NumberComparator.LT elif operator == "$gte": - sfc.number_comparator = chroma_pb.NumberComparator.GTE + sfc.number_comparator = query_pb.NumberComparator.GTE elif operator == "$lte": - sfc.number_comparator = chroma_pb.NumberComparator.LTE + sfc.number_comparator = query_pb.NumberComparator.LTE else: raise ValueError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" @@ -478,7 +476,7 @@ def to_proto_where(where: Where) -> query_pb.Where: def to_proto_where_document(where_document: WhereDocument) -> query_pb.WhereDocument: - response = chroma_pb.WhereDocument() + response = query_pb.WhereDocument() if len(where_document) != 1: raise ValueError( f"Expected where_document to have exactly one operator, got {where_document}" @@ -491,13 +489,13 @@ def to_proto_where_document(where_document: WhereDocument) -> query_pb.WhereDocu raise ValueError( f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}" ) - children: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren( + children: query_pb.WhereDocumentChildren = query_pb.WhereDocumentChildren( children=[to_proto_where_document(w) for w in operand] ) if operator == "$and": - children.operator = chroma_pb.BooleanOperator.AND + children.operator = query_pb.BooleanOperator.AND else: - children.operator = chroma_pb.BooleanOperator.OR + children.operator = query_pb.BooleanOperator.OR response.children.CopyFrom(children) else: @@ -507,12 +505,12 @@ def to_proto_where_document(where_document: WhereDocument) -> query_pb.WhereDocu raise ValueError( f"Expected where_document operand to be a string, got {operand}" ) - dwd = chroma_pb.DirectWhereDocument() + dwd = query_pb.DirectWhereDocument() dwd.document = operand if operator == "$contains": - dwd.operator = chroma_pb.WhereDocumentOperator.CONTAINS + dwd.operator = query_pb.WhereDocumentOperator.CONTAINS elif operator == "$not_contains": - dwd.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS + dwd.operator = query_pb.WhereDocumentOperator.NOT_CONTAINS else: raise ValueError( f"Expected where_document operator to be one of $contains, $not_contains, got {operator}" @@ -533,7 +531,7 @@ def to_proto_scan(scan: Scan) -> query_pb.ScanOperator: def to_proto_filter(filter: Filter) -> query_pb.FilterOperator: return query_pb.FilterOperator( - ids=chroma_pb.UserIds(ids=filter.user_ids) if filter.user_ids is not None else None, + ids=query_pb.UserIds(ids=filter.user_ids) if filter.user_ids is not None else None, where=to_proto_where(filter.where) if filter.where else None, where_document=to_proto_where_document(filter.where_document) if filter.where_document diff --git a/chromadb/test/segment/distributed/test_protobuf_translation.py b/chromadb/test/segment/distributed/test_protobuf_translation.py index d29fcad0365..1992637a2ba 100644 --- a/chromadb/test/segment/distributed/test_protobuf_translation.py +++ b/chromadb/test/segment/distributed/test_protobuf_translation.py @@ -102,7 +102,7 @@ def test_where_document_to_proto_not_contains() -> None: proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" - assert proto.direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS + assert proto.direct.operator == query_pb.WhereDocumentOperator.NOT_CONTAINS def test_where_document_to_proto_contains_to_proto() -> None: @@ -110,7 +110,7 @@ def test_where_document_to_proto_contains_to_proto() -> None: proto = convert.to_proto_where_document(where_document) assert proto.HasField("direct") assert proto.direct.document == "test" - assert proto.direct.operator == pb.WhereDocumentOperator.CONTAINS + assert proto.direct.operator == query_pb.WhereDocumentOperator.CONTAINS def test_where_document_to_proto_and() -> None: @@ -123,7 +123,7 @@ def test_where_document_to_proto_and() -> None: proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.AND + assert children_pb.operator == query_pb.BooleanOperator.AND assert len(children_pb.children) == 2 children = children_pb.children @@ -131,8 +131,8 @@ def test_where_document_to_proto_and() -> None: assert child.HasField("direct") assert child.direct.document == "test" # Protobuf retains the order of repeated fields so this is safe. - assert children[0].direct.operator == pb.WhereDocumentOperator.CONTAINS - assert children[1].direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS + assert children[0].direct.operator == query_pb.WhereDocumentOperator.CONTAINS + assert children[1].direct.operator == query_pb.WhereDocumentOperator.NOT_CONTAINS def test_where_document_to_proto_or() -> None: @@ -145,7 +145,7 @@ def test_where_document_to_proto_or() -> None: proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.OR + assert children_pb.operator == query_pb.BooleanOperator.OR assert len(children_pb.children) == 2 children = children_pb.children @@ -153,8 +153,8 @@ def test_where_document_to_proto_or() -> None: assert child.HasField("direct") assert child.direct.document == "test" # Protobuf retains the order of repeated fields so this is safe. - assert children[0].direct.operator == pb.WhereDocumentOperator.CONTAINS - assert children[1].direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS + assert children[0].direct.operator == query_pb.WhereDocumentOperator.CONTAINS + assert children[1].direct.operator == query_pb.WhereDocumentOperator.NOT_CONTAINS def test_where_document_to_proto_nested_boolean_operators() -> None: @@ -177,7 +177,7 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: proto = convert.to_proto_where_document(where_document) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.AND + assert children_pb.operator == query_pb.BooleanOperator.AND assert len(children_pb.children) == 2 children = children_pb.children @@ -190,9 +190,9 @@ def test_where_document_to_proto_nested_boolean_operators() -> None: assert nested_child.HasField("direct") assert nested_child.direct.document == "test" # Protobuf retains the order of repeated fields so this is safe. - assert nested_children[0].direct.operator == pb.WhereDocumentOperator.CONTAINS + assert nested_children[0].direct.operator == query_pb.WhereDocumentOperator.CONTAINS assert ( - nested_children[1].direct.operator == pb.WhereDocumentOperator.NOT_CONTAINS + nested_children[1].direct.operator == query_pb.WhereDocumentOperator.NOT_CONTAINS ) @@ -242,7 +242,7 @@ def test_where_to_proto_and() -> None: proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.AND + assert children_pb.operator == query_pb.BooleanOperator.AND children = children_pb.children assert len(children) == 2 @@ -266,7 +266,7 @@ def test_where_to_proto_or() -> None: proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.OR + assert children_pb.operator == query_pb.BooleanOperator.OR children = children_pb.children assert len(children) == 2 @@ -300,7 +300,7 @@ def test_where_to_proto_nested_boolean_operators() -> None: proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.AND + assert children_pb.operator == query_pb.BooleanOperator.AND assert len(children_pb.children) == 2 children = children_pb.children @@ -331,7 +331,7 @@ def test_where_to_proto_float_operator() -> None: proto = convert.to_proto_where(where) assert proto.HasField("children") children_pb = proto.children - assert children_pb.operator == pb.BooleanOperator.AND + assert children_pb.operator == query_pb.BooleanOperator.AND assert len(children_pb.children) == 2 children = children_pb.children