diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py
index 1fa78c196..47a224670 100644
--- a/pymilvus/client/entity_helper.py
+++ b/pymilvus/client/entity_helper.py
@@ -291,7 +291,9 @@ def entity_to_array_arr(entity: List[Any], field_info: Any):
     return convert_to_array_arr(entity.get("values", []), field_info)
 
 
-def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
+def pack_field_value_to_field_data(
+    field_value: Any, field_data: schema_types.FieldData, field_info: Any
+):
     field_type = field_data.type
     if field_type == DataType.BOOL:
         field_data.scalars.bool_data.data.append(field_value)
@@ -304,26 +306,51 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
     elif field_type == DataType.DOUBLE:
         field_data.scalars.double_data.data.append(field_value)
     elif field_type == DataType.FLOAT_VECTOR:
-        field_data.vectors.dim = len(field_value)
-        field_data.vectors.float_vector.data.extend(field_value)
+        f_value = field_value
+        if isinstance(field_value, np.ndarray):
+            if field_value.dtype not in ("float32", "float64"):
+                raise ParamError(
+                    message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
+                )
+            f_value = field_value.astype(np.float32).tolist()
+
+        field_data.vectors.dim = len(f_value)
+        field_data.vectors.float_vector.data.extend(f_value)
+
     elif field_type == DataType.BINARY_VECTOR:
         field_data.vectors.dim = len(field_value) * 8
         field_data.vectors.binary_vector += bytes(field_value)
+
     elif field_type == DataType.FLOAT16_VECTOR:
-        v_bytes = (
-            bytes(field_value)
-            if not isinstance(field_value, np.ndarray)
-            else field_value.view(np.uint8).tobytes()
-        )
+        if isinstance(field_value, bytes):
+            v_bytes = field_value
+        elif isinstance(field_value, np.ndarray):
+            if field_value.dtype != "float16":
+                raise ParamError(
+                    message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
+                )
+            v_bytes = field_value.view(np.uint8).tobytes()
+        else:
+            raise ParamError(
+                message="invalid input type for float16 vector, expect np.ndarray with dtype=float16"
+            )
+
         field_data.vectors.dim = len(v_bytes) // 2
         field_data.vectors.float16_vector += v_bytes
+
     elif field_type == DataType.BFLOAT16_VECTOR:
-        v_bytes = (
-            bytes(field_value)
-            if not isinstance(field_value, np.ndarray)
-            else field_value.view(np.uint8).tobytes()
-        )
+        if isinstance(field_value, bytes):
+            v_bytes = field_value
+        elif isinstance(field_value, np.ndarray):
+            if field_value.dtype != "bfloat16":
+                raise ParamError(
+                    message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
+                )
+            v_bytes = field_value.view(np.uint8).tobytes()
+        else:
+            raise ParamError(
+                message="invalid input type for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
+            )
+
        field_data.vectors.dim = len(v_bytes) // 2
        field_data.vectors.bfloat16_vector += v_bytes
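Review note (not part of the patch): the rewritten vector branches replace the old blanket `bytes()`/`view()` conversion with explicit dtype validation. A minimal sketch of which NumPy inputs now pass, assuming plain NumPy (bfloat16 needs a third-party dtype such as the `ml_dtypes` package, so it is only noted in a comment):

```python
import numpy as np

# FLOAT_VECTOR: float32 and float64 ndarrays are accepted; anything else raises ParamError.
ok_f32 = np.random.rand(8).astype(np.float32)   # accepted as-is
ok_f64 = np.random.rand(8)                      # accepted, converted to float32
bad_ints = np.arange(8)                         # dtype=int64 -> ParamError

# FLOAT16_VECTOR: only raw bytes or an ndarray with dtype=float16 are accepted.
ok_f16 = np.random.rand(8).astype(np.float16)
raw = ok_f16.view(np.uint8).tobytes()           # what the client actually sends
assert raw == ok_f16.tobytes()                  # the uint8 view is a byte-level reinterpretation
assert len(raw) // 2 == 8                       # dim is recovered as len(bytes) // 2

# BFLOAT16_VECTOR follows the same pattern with dtype=bfloat16 (e.g. via ml_dtypes).
```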
diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index af36b1449..460529e1c 100644
--- a/pymilvus/client/grpc_handler.py
+++ b/pymilvus/client/grpc_handler.py
@@ -475,46 +475,49 @@ def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwa
         return fields_info, enable_dynamic
 
-    def _prepare_row_insert_request(
+    @retry_on_rpc_failure()
+    def insert_rows(
         self,
         collection_name: str,
-        entity_rows: List,
+        entities: Union[Dict, List[Dict]],
         partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
         timeout: Optional[float] = None,
         **kwargs,
     ):
-        if not isinstance(entity_rows, list):
-            raise ParamError(message="None rows, please provide valid row data.")
-
-        fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
-        return Prepare.row_insert_param(
-            collection_name,
-            entity_rows,
-            partition_name,
-            fields_info,
-            enable_dynamic=enable_dynamic,
+        request = self._prepare_row_insert_request(
+            collection_name, entities, partition_name, schema, timeout, **kwargs
         )
+        resp = self._stub.Insert(request=request, timeout=timeout)
+        check_status(resp.status)
+        ts_utils.update_collection_ts(collection_name, resp.timestamp)
+        return MutationResult(resp)
 
-    @retry_on_rpc_failure()
-    def insert_rows(
+    def _prepare_row_insert_request(
         self,
         collection_name: str,
-        entities: List,
+        entity_rows: Union[List[Dict], Dict],
         partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
         timeout: Optional[float] = None,
         **kwargs,
    ):
-        if isinstance(entities, dict):
-            entities = [entities]
-        request = self._prepare_row_insert_request(
-            collection_name, entities, partition_name, timeout, **kwargs
+        if isinstance(entity_rows, dict):
+            entity_rows = [entity_rows]
+
+        if not isinstance(schema, dict):
+            schema = self.describe_collection(collection_name, timeout=timeout)
+
+        fields_info = schema.get("fields")
+        enable_dynamic = schema.get("enable_dynamic_field", False)
+
+        return Prepare.row_insert_param(
+            collection_name,
+            entity_rows,
+            partition_name,
+            fields_info,
+            enable_dynamic=enable_dynamic,
         )
-        rf = self._stub.Insert.future(request, timeout=timeout)
-        response = rf.result()
-        check_status(response.status)
-        m = MutationResult(response)
-        ts_utils.update_collection_ts(collection_name, m.timestamp)
-        return m
 
     def _prepare_batch_insert_request(
         self,
@@ -1376,7 +1379,7 @@ def _wait_for_flushed(
             end = time.time()
             if timeout is not None and end - start > timeout:
                 raise MilvusException(
-                    message=f"wait for flush timeout, collection: {collection_name}"
+                    message=f"wait for flush timeout, collection: {collection_name}, flush_ts: {flush_ts}"
                 )
 
             if not flush_ret:
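Review note: `insert_rows` is now the public, retried entry point and accepts either a single row or a batch, and a caller-supplied schema dict skips the extra `describe_collection` round trip. A hedged usage sketch; the connection details, collection name, and field layout below are made up:

```python
from pymilvus import Collection, connections

connections.connect(host="localhost", port="19530")
coll = Collection("demo")  # hypothetical collection: id (INT64, primary), vec (FLOAT_VECTOR, dim=2)

# A single dict and a list of dicts both flow through insert_rows.
coll.insert({"id": 1, "vec": [0.1, 0.2]})
coll.insert([{"id": 2, "vec": [0.3, 0.4]}, {"id": 3, "vec": [0.5, 0.6]}])
```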
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 3f9ed2f9a..5ec58718a 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -364,12 +364,10 @@ def _parse_row_request(
             field["name"]: field for field in fields_info if not field.get("auto_id", False)
         }
 
-        meta_field = (
-            schema_types.FieldData(is_dynamic=True, type=DataType.JSON) if enable_dynamic else None
-        )
-        if meta_field is not None:
-            field_info_map[meta_field.field_name] = meta_field
-            fields_data[meta_field.field_name] = meta_field
+        if enable_dynamic:
+            d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
+            fields_data[d_field.field_name] = d_field
+            field_info_map[d_field.field_name] = d_field
 
         try:
             for entity in entities:
@@ -390,7 +388,7 @@ def _parse_row_request(
 
                 if enable_dynamic:
                     json_value = entity_helper.convert_to_json(json_dict)
-                    meta_field.scalars.json_data.data.append(json_value)
+                    d_field.scalars.json_data.data.append(json_value)
 
         except (TypeError, ValueError) as e:
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
@@ -400,7 +398,7 @@ def _parse_row_request(
         )
 
         if enable_dynamic:
-            request.fields_data.append(meta_field)
+            request.fields_data.append(d_field)
 
         _, _, auto_id_loc = traverse_rows_info(fields_info, entities)
         if auto_id_loc is not None:
@@ -418,16 +416,18 @@ def row_insert_param(
         collection_name: str,
         entities: List,
         partition_name: str,
-        fields_info: Any,
+        fields_info: Dict,
         enable_dynamic: bool = False,
     ):
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
         # insert_request.hash_keys won't be filled in client.
-        tag = partition_name if isinstance(partition_name, str) else ""
+        p_name = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.InsertRequest(
-            collection_name=collection_name, partition_name=tag, num_rows=len(entities)
+            collection_name=collection_name,
+            partition_name=p_name,
+            num_rows=len(entities),
         )
 
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -445,9 +445,11 @@ def row_upsert_param(
             raise ParamError(message="Missing collection meta to validate entities")
 
         # upsert_request.hash_keys won't be filled in client.
-        tag = partition_name if isinstance(partition_name, str) else ""
+        p_name = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.UpsertRequest(
-            collection_name=collection_name, partition_name=tag, num_rows=len(entities)
+            collection_name=collection_name,
+            partition_name=p_name,
+            num_rows=len(entities),
         )
 
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -469,7 +471,7 @@ def _pre_batch_check(
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
-        location, primary_key_loc, auto_id_loc = traverse_info(fields_info, entities)
+        location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
 
         # though impossible from sdk
         if primary_key_loc is None:
@@ -583,16 +585,20 @@ def _prepare_placeholder_str(cls, data: Any):
         elif isinstance(data[0], np.ndarray):
             dtype = data[0].dtype
-            pl_values = (array.tobytes() for array in data)
 
             if dtype == "bfloat16":
                 pl_type = PlaceholderType.BFLOAT16_VECTOR
+                pl_values = (array.tobytes() for array in data)
             elif dtype == "float16":
                 pl_type = PlaceholderType.FLOAT16_VECTOR
-            elif dtype == "float32":
+                pl_values = (array.tobytes() for array in data)
+            elif dtype in ("float32", "float64"):
                 pl_type = PlaceholderType.FloatVector
+                pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
+            elif dtype == "byte":
+                pl_type = PlaceholderType.BinaryVector
+                pl_values = data
             else:
                 err_msg = f"unsupported data type: {dtype}"
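Review note: `_prepare_placeholder_str` now chooses the byte encoding per dtype instead of up front, and float64 query vectors are accepted for FloatVector placeholders. An illustrative sketch of the dtype-to-placeholder mapping; the `vec` field name and the search call are hypothetical:

```python
import numpy as np

queries_f32 = [np.random.rand(128).astype(np.float32)]  # -> PlaceholderType.FloatVector
queries_f64 = [np.random.rand(128)]                     # float64 now also -> FloatVector
queries_f16 = [np.random.rand(128).astype(np.float16)]  # -> PlaceholderType.FLOAT16_VECTOR

# e.g. coll.search(data=queries_f16, anns_field="vec", param={}, limit=10)
```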
diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py
index 4d407e1ca..7ecc35670 100644
--- a/pymilvus/client/utils.py
+++ b/pymilvus/client/utils.py
@@ -250,7 +250,7 @@ def traverse_rows_info(fields_info: Any, entities: List):
     return location, primary_key_loc, auto_id_loc
 
 
-def traverse_info(fields_info: Any, entities: List):
+def traverse_info(fields_info: Any):
     location, primary_key_loc, auto_id_loc = {}, None, None
     for i, field in enumerate(fields_info):
         if field.get("is_primary", False):
@@ -259,58 +259,7 @@ def traverse_info(fields_info: Any):
         if field.get("auto_id", False):
             auto_id_loc = i
             continue
-
-        match_flag = False
-        field_name = field["name"]
-        field_type = field["type"]
-
-        for entity in entities:
-            entity_name, entity_type = entity["name"], entity["type"]
-
-            if field_name == entity_name:
-                if field_type != entity_type:
-                    raise ParamError(
-                        message=f"Collection field type is {field_type}"
-                        f", but entities field type is {entity_type}"
-                    )
-
-                entity_dim, field_dim = 0, 0
-                if entity_type in [
-                    DataType.FLOAT_VECTOR,
-                    DataType.BINARY_VECTOR,
-                    DataType.BFLOAT16_VECTOR,
-                    DataType.FLOAT16_VECTOR,
-                ]:
-                    field_dim = field["params"]["dim"]
-                    entity_dim = len(entity["values"][0])
-
-                if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim:
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {entity_dim}"
-                    )
-
-                if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim:
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {entity_dim * 8}"
-                    )
-
-                if (
-                    entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]
-                    and int(entity_dim // 2) != field_dim
-                ):
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {int(entity_dim // 2)}"
-                    )
-
-                location[field["name"]] = i
-                match_flag = True
-                break
-
-        if not match_flag:
-            raise ParamError(message=f"Field {field['name']} don't match in entities")
+        location[field["name"]] = i
 
     return location, primary_key_loc, auto_id_loc
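Review note: `traverse_info` no longer receives the entities and no longer cross-checks names, types, and dims; it only maps field names to positions (the dtype/shape validation moves to `pymilvus/orm/prepare.py`). A small sketch of the new behavior:

```python
from pymilvus.client.utils import traverse_info

fields_info = [
    {"name": "id", "is_primary": True, "auto_id": False},
    {"name": "age"},
    {"name": "vec"},
]
location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
assert location == {"id": 0, "age": 1, "vec": 2}
assert primary_key_loc == 0 and auto_id_loc is None
```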
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py
index 2e2459721..27fdfee71 100644
--- a/pymilvus/orm/collection.py
+++ b/pymilvus/orm/collection.py
@@ -34,6 +34,7 @@
     IndexNotExistException,
     PartitionAlreadyExistException,
     SchemaNotReadyException,
+    UpsertAutoIDTrueException,
 )
 from pymilvus.grpc_gen import schema_pb2
 from pymilvus.settings import Config
@@ -50,10 +51,11 @@
     CollectionSchema,
     FieldSchema,
     check_insert_schema,
-    check_is_row_based,
     check_schema,
     check_upsert_schema,
     construct_fields_from_dataframe,
+    is_row_based,
+    is_valid_insert_data,
 )
 from .types import DataType
 from .utility import _get_connection
@@ -241,7 +243,7 @@ def aliases(self, **kwargs) -> list:
     @property
     def description(self) -> str:
         """str: a text description of the collection."""
-        return self._schema.description
+        return self.schema.description
 
     @property
     def name(self) -> str:
@@ -289,7 +291,7 @@ def num_entities(self, **kwargs) -> int:
     @property
     def primary_field(self) -> FieldSchema:
         """FieldSchema: the primary field of the collection."""
-        return self._schema.primary_field
+        return self.schema.primary_field
 
     def flush(self, timeout: Optional[float] = None, **kwargs):
         """Seal all segments in the collection. Inserts after flushing will be written into
@@ -490,25 +492,28 @@ def insert(
             >>> res.insert_count
             10
         """
-        if data is None:
-            return MutationResult(data)
-        row_based = check_is_row_based(data)
+        if not is_valid_insert_data(data):
+            raise DataTypeNotSupportException(
+                message="The type of data should be List, pd.DataFrame or Dict"
+            )
+
         conn = self._get_connection()
-        if not row_based:
-            check_insert_schema(self._schema, data)
-            entities = Prepare.prepare_insert_data(data, self._schema)
-            return conn.batch_insert(
-                self._name,
-                entities,
-                partition_name,
+        if is_row_based(data):
+            return conn.insert_rows(
+                collection_name=self._name,
+                entities=data,
+                partition_name=partition_name,
                 timeout=timeout,
                 schema=self._schema_dict,
                 **kwargs,
             )
-        return conn.insert_rows(
-            collection_name=self._name,
-            entities=data,
-            partition_name=partition_name,
+
+        check_insert_schema(self.schema, data)
+        entities = Prepare.prepare_insert_data(data, self.schema)
+        return conn.batch_insert(
+            self._name,
+            entities,
+            partition_name,
             timeout=timeout,
             schema=self._schema_dict,
             **kwargs,
         )
@@ -614,27 +619,17 @@ def upsert(
             >>> res.upsert_count
             10
         """
-        if data is None:
-            return MutationResult(data)
-        row_based = check_is_row_based(data)
-        conn = self._get_connection()
-        if not row_based:
-            check_upsert_schema(self._schema, data)
-            entities = Prepare.prepare_upsert_data(data, self._schema)
+        if self.schema.auto_id:
+            raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)
 
-        res = conn.upsert(
-            self._name,
-            entities,
-            partition_name,
-            timeout=timeout,
-            schema=self._schema_dict,
-            **kwargs,
+        if not is_valid_insert_data(data):
+            raise DataTypeNotSupportException(
+                message="The type of data should be List, pd.DataFrame or Dict"
         )
-        if kwargs.get("_async", False):
-            return MutationFuture(res)
-        else:
+
+        conn = self._get_connection()
+        if is_row_based(data):
             res = conn.upsert_rows(
                 self._name,
                 data,
                 partition_name,
                 timeout=timeout,
                 schema=self._schema_dict,
                 **kwargs,
             )
+            return MutationResult(res)
 
-        return MutationResult(res)
+        check_upsert_schema(self.schema, data)
+        entities = Prepare.prepare_upsert_data(data, self.schema)
+        res = conn.upsert(
+            self._name,
+            entities,
+            partition_name,
+            timeout=timeout,
+            schema=self._schema_dict,
+            **kwargs,
+        )
+
+        return MutationFuture(res) if kwargs.get("_async", False) else MutationResult(res)
 
     def search(
         self,
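Review note: `insert`/`upsert` now validate the input type first and dispatch on shape: row-based data (a dict or a list of dicts) takes the `insert_rows`/`upsert_rows` path, column-based data (a list of columns or a DataFrame) keeps the batch path, and upsert fails fast on auto_id schemas. A hedged sketch reusing the hypothetical `coll` from the earlier note:

```python
import pandas as pd

coll.insert({"id": 1, "vec": [0.1, 0.2]})                    # row-based  -> insert_rows
coll.insert([[3, 4], [[0.5, 0.6], [0.7, 0.8]]])              # column-based -> batch_insert
coll.insert(pd.DataFrame({"id": [5], "vec": [[0.9, 1.0]]}))  # DataFrame -> batch_insert

# coll.insert("rows") now raises DataTypeNotSupportException, and
# coll.upsert(...) on a collection with auto_id=True raises UpsertAutoIDTrueException up front.
```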
diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py
index 4466fa78d..5e2902a20 100644
--- a/pymilvus/orm/prepare.py
+++ b/pymilvus/orm/prepare.py
@@ -22,6 +22,7 @@
     DataNotMatchException,
     DataTypeNotSupportException,
     ExceptionsMessage,
+    ParamError,
     UpsertAutoIDTrueException,
 )
 
@@ -65,6 +66,14 @@ def prepare_insert_data(
             if field.is_primary and field.auto_id:
                 tmp_fields.pop(i)
 
+        vec_dtype_checker = {
+            DataType.FLOAT_VECTOR: lambda ndarr: ndarr.dtype in ("float32", "float64"),
+            DataType.FLOAT16_VECTOR: lambda ndarr: ndarr.dtype in ("float16",),
+            DataType.BFLOAT16_VECTOR: lambda ndarr: ndarr.dtype in ("bfloat16",),
+        }
+
+        wrong_field_type = "Wrong type for vector field: {}, expect={}, got={}"
+        wrong_ndarr_type = "Wrong dtype of np.ndarray for vector field: {}, expect={}, got={}"
         for i, field in enumerate(tmp_fields):
             try:
                 f_data = data[i]
@@ -72,14 +81,70 @@ def prepare_insert_data(
             except IndexError:
                 entities.append({"name": field.name, "type": field.dtype, "values": []})
 
-            if isinstance(f_data, np.ndarray):
-                d = f_data.tolist()
-
-            elif isinstance(f_data[0], np.ndarray) and field.dtype in (
-                DataType.FLOAT16_VECTOR,
-                DataType.BFLOAT16_VECTOR,
-            ):
-                d = [bytes(ndarr.view(np.uint8).tolist()) for ndarr in f_data]
+            d = []
+            if field.dtype == DataType.FLOAT_VECTOR:
+                is_valid_ndarray = vec_dtype_checker[field.dtype]
+                if isinstance(f_data, np.ndarray):
+                    if not is_valid_ndarray(f_data):
+                        raise ParamError(
+                            message=wrong_ndarr_type.format(
+                                field.name, "np.float32/np.float64", f_data.dtype
+                            )
+                        )
+                    d = f_data.astype(np.float32).tolist()
+
+                elif isinstance(f_data[0], np.ndarray):
+                    for ndarr in f_data:
+                        if not is_valid_ndarray(ndarr):
+                            raise ParamError(
+                                message=wrong_ndarr_type.format(
+                                    field.name, "np.float32/np.float64", ndarr.dtype
+                                )
+                            )
+                        d.append(ndarr.tolist())
+
+                else:
+                    d = f_data if f_data is not None else []
+
+            elif field.dtype == DataType.FLOAT16_VECTOR:
+                is_valid_ndarray = vec_dtype_checker[field.dtype]
+                if isinstance(f_data[0], np.ndarray):
+                    for ndarr in f_data:
+                        if not is_valid_ndarray(ndarr):
+                            raise ParamError(
+                                message=wrong_ndarr_type.format(
+                                    field.name, "np.float16", ndarr.dtype
+                                )
+                            )
+                        d.append(ndarr.view(np.uint8).tobytes())
+                else:
+                    raise ParamError(
+                        message=wrong_field_type.format(
+                            field.name,
+                            "List",
+                            f"List[{type(f_data[0])}]",
+                        )
+                    )
+
+            elif field.dtype == DataType.BFLOAT16_VECTOR:
+                is_valid_ndarray = vec_dtype_checker[field.dtype]
+                if isinstance(f_data[0], np.ndarray):
+                    for ndarr in f_data:
+                        if not is_valid_ndarray(ndarr):
+                            raise ParamError(
+                                message=wrong_ndarr_type.format(
+                                    field.name, "np.bfloat16", ndarr.dtype
+                                )
+                            )
+                        d.append(ndarr.view(np.uint8).tobytes())
+                else:
+                    raise ParamError(
+                        message=wrong_field_type.format(
+                            field.name,
+                            "List",
+                            f"List[{type(f_data[0])}]",
+                        )
+                    )
 
             else:
                 d = f_data if f_data is not None else []
diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py
index 44eea8455..059a0f37a 100644
--- a/pymilvus/orm/schema.py
+++ b/pymilvus/orm/schema.py
@@ -441,6 +441,18 @@ def dtype(self) -> DataType:
         return self._dtype
 
 
+def is_valid_insert_data(data: Union[pd.DataFrame, list, dict]) -> bool:
+    """DataFrame, list, dict are valid insert data"""
+    return isinstance(data, (pd.DataFrame, list, dict))
+
+
+def is_row_based(data: Union[Dict, List[Dict]]) -> bool:
+    """Dict or List[Dict] are row-based"""
+    return isinstance(data, dict) or (
+        isinstance(data, list) and len(data) > 0 and isinstance(data[0], Dict)
+    )
+
+
 def check_is_row_based(data: Union[List[List], List[Dict], Dict, pd.DataFrame]) -> bool:
     if not isinstance(data, (pd.DataFrame, list, dict)):
         raise DataTypeNotSupportException(
diff --git a/pyproject.toml b/pyproject.toml
index 697ab28f7..c237b5dff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -110,6 +110,7 @@ lint.ignore = [
     "ARG005", # [ruff] ARG005 Unused lambda argument: `disable` [E]
     "TRY400",
     "PLR0912", # TODO
+    "PLR0915", # Too many statements TODO
     "C901", # TODO
     "PYI041", # TODO
 ]
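Review note: the new `is_valid_insert_data` / `is_row_based` helpers in `pymilvus/orm/schema.py` split apart `check_is_row_based`'s combined validate-and-classify role. A short sketch of their semantics:

```python
import pandas as pd
from pymilvus.orm.schema import is_row_based, is_valid_insert_data

assert is_row_based({"id": 1})                      # a single row
assert is_row_based([{"id": 1}, {"id": 2}])         # a list of rows
assert not is_row_based([[1, 2], [[0.1], [0.2]]])   # column-based list data
assert is_valid_insert_data(pd.DataFrame({"id": [1]}))
assert not is_valid_insert_data("rows")             # str is rejected by Collection.insert/upsert
```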