
Commit

fix: restrict input/search type for vector fields (#2025)
See also: #2018, #2004, #2016

---------

Signed-off-by: yangxuan <[email protected]>
XuanYang-cn authored Apr 10, 2024
1 parent 2112ee2 commit ea2c9c6
Showing 8 changed files with 221 additions and 151 deletions.
53 changes: 40 additions & 13 deletions pymilvus/client/entity_helper.py
@@ -291,7 +291,9 @@ def entity_to_array_arr(entity: List[Any], field_info: Any):
     return convert_to_array_arr(entity.get("values", []), field_info)
 
 
-def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
+def pack_field_value_to_field_data(
+    field_value: Any, field_data: schema_types.FieldData, field_info: Any
+):
     field_type = field_data.type
     if field_type == DataType.BOOL:
         field_data.scalars.bool_data.data.append(field_value)
@@ -304,26 +306,51 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
     elif field_type == DataType.DOUBLE:
         field_data.scalars.double_data.data.append(field_value)
     elif field_type == DataType.FLOAT_VECTOR:
-        field_data.vectors.dim = len(field_value)
-        field_data.vectors.float_vector.data.extend(field_value)
+        f_value = field_value
+        if isinstance(field_value, np.ndarray):
+            if field_value.dtype not in ("float32", "float64"):
+                raise ParamError(
+                    message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
+                )
+            f_value = field_value.view(np.float32).tolist()
+
+        field_data.vectors.dim = len(f_value)
+        field_data.vectors.float_vector.data.extend(f_value)
+
     elif field_type == DataType.BINARY_VECTOR:
         field_data.vectors.dim = len(field_value) * 8
         field_data.vectors.binary_vector += bytes(field_value)
+
     elif field_type == DataType.FLOAT16_VECTOR:
-        v_bytes = (
-            bytes(field_value)
-            if not isinstance(field_value, np.ndarray)
-            else field_value.view(np.uint8).tobytes()
-        )
+        if isinstance(field_value, bytes):
+            v_bytes = field_value
+        elif isinstance(field_value, np.ndarray):
+            if field_value.dtype != "float16":
+                raise ParamError(
+                    message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
+                )
+            v_bytes = field_value.view(np.uint8).tobytes()
+        else:
+            raise ParamError(
+                message="invalid input type for float16 vector, expect np.ndarray with dtype=float16"
+            )
+
         field_data.vectors.dim = len(v_bytes) // 2
         field_data.vectors.float16_vector += v_bytes
+
     elif field_type == DataType.BFLOAT16_VECTOR:
-        v_bytes = (
-            bytes(field_value)
-            if not isinstance(field_value, np.ndarray)
-            else field_value.view(np.uint8).tobytes()
-        )
+        if isinstance(field_value, bytes):
+            v_bytes = field_value
+        elif isinstance(field_value, np.ndarray):
+            if field_value.dtype != "bfloat16":
+                raise ParamError(
+                    message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
+                )
+            v_bytes = field_value.view(np.uint8).tobytes()
+        else:
+            raise ParamError(
+                message="invalid input type for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
+            )
+
         field_data.vectors.dim = len(v_bytes) // 2
         field_data.vectors.bfloat16_vector += v_bytes
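Taken together, the entity_helper.py changes mean inserts now validate vector inputs by type: FLOAT_VECTOR fields accept lists or float32/float64 ndarrays, while FLOAT16_VECTOR and BFLOAT16_VECTOR fields accept only raw bytes or an ndarray of the matching dtype. A minimal sketch of the resulting behavior at the client level — the local server URI, the collection "demo", and the field names are assumptions for illustration, not part of the commit:

# Sketch only: assumes a running local Milvus and a collection "demo" with
# a hypothetical 8-dim FLOAT16_VECTOR field named "fp16_vec".
import numpy as np
from pymilvus import MilvusClient
from pymilvus.exceptions import ParamError

client = MilvusClient(uri="http://localhost:19530")

# Accepted after this commit: an ndarray whose dtype matches the field type.
client.insert("demo", [{"id": 1, "fp16_vec": np.random.rand(8).astype(np.float16)}])

# Rejected after this commit: a plain Python list for a float16 field.
try:
    client.insert("demo", [{"id": 2, "fp16_vec": [0.1] * 8}])
except ParamError as e:
    print(e)  # invalid input type for float16 vector, expect np.ndarray with dtype=float16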
55 changes: 29 additions & 26 deletions pymilvus/client/grpc_handler.py
@@ -475,46 +475,49 @@ def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwa
 
         return fields_info, enable_dynamic
 
-    def _prepare_row_insert_request(
+    @retry_on_rpc_failure()
+    def insert_rows(
         self,
         collection_name: str,
-        entity_rows: List,
+        entities: Union[Dict, List[Dict]],
         partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
         timeout: Optional[float] = None,
         **kwargs,
     ):
-        if not isinstance(entity_rows, list):
-            raise ParamError(message="None rows, please provide valid row data.")
-
-        fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
-        return Prepare.row_insert_param(
-            collection_name,
-            entity_rows,
-            partition_name,
-            fields_info,
-            enable_dynamic=enable_dynamic,
+        request = self._prepare_row_insert_request(
+            collection_name, entities, partition_name, timeout, **kwargs
         )
+        resp = self._stub.Insert(request=request, timeout=timeout)
+        check_status(resp.status)
+        ts_utils.update_collection_ts(collection_name, resp.timestamp)
+        return MutationResult(resp)
 
-    @retry_on_rpc_failure()
-    def insert_rows(
+    def _prepare_row_insert_request(
         self,
         collection_name: str,
-        entities: List,
+        entity_rows: Union[List[Dict], Dict],
         partition_name: Optional[str] = None,
+        schema: Optional[dict] = None,
         timeout: Optional[float] = None,
         **kwargs,
    ):
-        if isinstance(entities, dict):
-            entities = [entities]
-        request = self._prepare_row_insert_request(
-            collection_name, entities, partition_name, timeout, **kwargs
+        if isinstance(entity_rows, dict):
+            entity_rows = [entity_rows]
+
+        if not isinstance(schema, dict):
+            schema = self.describe_collection(collection_name, timeout=timeout)
+
+        fields_info = schema.get("fields")
+        enable_dynamic = schema.get("enable_dynamic_field", False)
+
+        return Prepare.row_insert_param(
+            collection_name,
+            entity_rows,
+            partition_name,
+            fields_info,
+            enable_dynamic=enable_dynamic,
         )
-        rf = self._stub.Insert.future(request, timeout=timeout)
-        response = rf.result()
-        check_status(response.status)
-        m = MutationResult(response)
-        ts_utils.update_collection_ts(collection_name, m.timestamp)
-        return m
 
     def _prepare_batch_insert_request(
         self,
@@ -1376,7 +1379,7 @@ def _wait_for_flushed(
             end = time.time()
             if timeout is not None and end - start > timeout:
                 raise MilvusException(
-                    message=f"wait for flush timeout, collection: {collection_name}"
+                    message=f"wait for flush timeout, collection: {collection_name}, flusht_ts: {flush_ts}"
                 )
 
             if not flush_ret:
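The grpc_handler.py change swaps the roles of the two methods: insert_rows becomes the retry-decorated entry point and now issues a blocking self._stub.Insert(...) call, while _prepare_row_insert_request normalizes a bare dict into a one-row list and falls back to describe_collection only when no cached schema is supplied. A simplified stand-in for that normalization — not library code; the canned schema is hypothetical:

# Simplified stand-in for the normalization now done in _prepare_row_insert_request.
from typing import Any, Dict, List, Optional, Tuple, Union

def normalize_rows(
    entity_rows: Union[List[Dict], Dict],
    schema: Optional[dict],
    fetch_schema,
) -> Tuple[List[Dict], Any, bool]:
    if isinstance(entity_rows, dict):  # a single row may be passed as a bare dict
        entity_rows = [entity_rows]
    if not isinstance(schema, dict):  # only hit the server when no schema was given
        schema = fetch_schema()
    return entity_rows, schema.get("fields"), schema.get("enable_dynamic_field", False)

rows, fields_info, dynamic = normalize_rows(
    {"id": 1, "vec": [0.1, 0.2]},
    schema=None,
    fetch_schema=lambda: {"fields": [{"name": "id"}, {"name": "vec"}]},
)
print(rows, dynamic)  # [{'id': 1, 'vec': [0.1, 0.2]}] False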
38 changes: 22 additions & 16 deletions pymilvus/client/prepare.py
@@ -364,12 +364,10 @@ def _parse_row_request(
             field["name"]: field for field in fields_info if not field.get("auto_id", False)
         }
 
-        meta_field = (
-            schema_types.FieldData(is_dynamic=True, type=DataType.JSON) if enable_dynamic else None
-        )
-        if meta_field is not None:
-            field_info_map[meta_field.field_name] = meta_field
-            fields_data[meta_field.field_name] = meta_field
+        if enable_dynamic:
+            d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
+            fields_data[d_field.field_name] = d_field
+            field_info_map[d_field.field_name] = d_field
 
         try:
             for entity in entities:
@@ -390,7 +388,7 @@ def _parse_row_request(
 
                 if enable_dynamic:
                     json_value = entity_helper.convert_to_json(json_dict)
-                    meta_field.scalars.json_data.data.append(json_value)
+                    d_field.scalars.json_data.data.append(json_value)
 
         except (TypeError, ValueError) as e:
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
@@ -400,7 +398,7 @@ def _parse_row_request(
             )
 
         if enable_dynamic:
-            request.fields_data.append(meta_field)
+            request.fields_data.append(d_field)
 
         _, _, auto_id_loc = traverse_rows_info(fields_info, entities)
         if auto_id_loc is not None:
@@ -418,16 +416,18 @@ def row_insert_param(
         collection_name: str,
         entities: List,
         partition_name: str,
-        fields_info: Any,
+        fields_info: Dict,
         enable_dynamic: bool = False,
     ):
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
         # insert_request.hash_keys won't be filled in client.
-        tag = partition_name if isinstance(partition_name, str) else ""
+        p_name = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.InsertRequest(
-            collection_name=collection_name, partition_name=tag, num_rows=len(entities)
+            collection_name=collection_name,
+            partition_name=p_name,
+            num_rows=len(entities),
         )
 
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -445,9 +445,11 @@ def row_upsert_param(
             raise ParamError(message="Missing collection meta to validate entities")
 
         # upsert_request.hash_keys won't be filled in client.
-        tag = partition_name if isinstance(partition_name, str) else ""
+        p_name = partition_name if isinstance(partition_name, str) else ""
         request = milvus_types.UpsertRequest(
-            collection_name=collection_name, partition_name=tag, num_rows=len(entities)
+            collection_name=collection_name,
+            partition_name=p_name,
+            num_rows=len(entities),
         )
 
         return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
@@ -469,7 +471,7 @@ def _pre_batch_check(
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
-        location, primary_key_loc, auto_id_loc = traverse_info(fields_info, entities)
+        location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
 
         # though impossible from sdk
         if primary_key_loc is None:
@@ -583,16 +585,20 @@ def _prepare_placeholder_str(cls, data: Any):
 
         elif isinstance(data[0], np.ndarray):
             dtype = data[0].dtype
+            pl_values = (array.tobytes() for array in data)
 
             if dtype == "bfloat16":
                 pl_type = PlaceholderType.BFLOAT16_VECTOR
-                pl_values = (array.tobytes() for array in data)
             elif dtype == "float16":
                 pl_type = PlaceholderType.FLOAT16_VECTOR
-                pl_values = (array.tobytes() for array in data)
-            elif dtype == "float32":
+            elif dtype in ("float32", "float64"):
                 pl_type = PlaceholderType.FloatVector
                 pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
-
-            elif dtype == "byte":
-                pl_type = PlaceholderType.BinaryVector
-                pl_values = data
 
             else:
                 err_msg = f"unsupported data type: {dtype}"
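On the search side, prepare.py now derives the placeholder type strictly from the query ndarray's dtype: bfloat16 and float16 map to their own placeholder types, float32 and float64 both map to FloatVector, and the old "byte" fallback to BinaryVector is gone, so any other dtype raises. A simplified mirror of that dispatch — not library code, just the same decision table:

# Simplified mirror of the stricter dtype dispatch in Prepare._prepare_placeholder_str.
import numpy as np

def placeholder_type_for(arr: np.ndarray) -> str:
    name = arr.dtype.name  # "bfloat16" only appears with an extension dtype such as ml_dtypes
    if name == "bfloat16":
        return "BFLOAT16_VECTOR"
    if name == "float16":
        return "FLOAT16_VECTOR"
    if name in ("float32", "float64"):
        return "FLOAT_VECTOR"
    raise ValueError(f"unsupported data type: {name}")

print(placeholder_type_for(np.zeros(4, dtype=np.float64)))  # FLOAT_VECTOR (float64 now accepted)
print(placeholder_type_for(np.zeros(4, dtype=np.float16)))  # FLOAT16_VECTOR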
55 changes: 2 additions & 53 deletions pymilvus/client/utils.py
@@ -250,7 +250,7 @@ def traverse_rows_info(fields_info: Any, entities: List):
     return location, primary_key_loc, auto_id_loc
 
 
-def traverse_info(fields_info: Any, entities: List):
+def traverse_info(fields_info: Any):
     location, primary_key_loc, auto_id_loc = {}, None, None
     for i, field in enumerate(fields_info):
         if field.get("is_primary", False):
@@ -259,58 +259,7 @@ def traverse_info(fields_info: Any, entities: List):
         if field.get("auto_id", False):
             auto_id_loc = i
             continue
-
-        match_flag = False
-        field_name = field["name"]
-        field_type = field["type"]
-
-        for entity in entities:
-            entity_name, entity_type = entity["name"], entity["type"]
-
-            if field_name == entity_name:
-                if field_type != entity_type:
-                    raise ParamError(
-                        message=f"Collection field type is {field_type}"
-                        f", but entities field type is {entity_type}"
-                    )
-
-                entity_dim, field_dim = 0, 0
-                if entity_type in [
-                    DataType.FLOAT_VECTOR,
-                    DataType.BINARY_VECTOR,
-                    DataType.BFLOAT16_VECTOR,
-                    DataType.FLOAT16_VECTOR,
-                ]:
-                    field_dim = field["params"]["dim"]
-                    entity_dim = len(entity["values"][0])
-
-                if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim:
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {entity_dim}"
-                    )
-
-                if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim:
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {entity_dim * 8}"
-                    )
-
-                if (
-                    entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]
-                    and int(entity_dim // 2) != field_dim
-                ):
-                    raise ParamError(
-                        message=f"Collection field dim is {field_dim}"
-                        f", but entities field dim is {int(entity_dim // 2)}"
-                    )
-
-                location[field["name"]] = i
-                match_flag = True
-                break
-
-        if not match_flag:
-            raise ParamError(message=f"Field {field['name']} don't match in entities")
+        location[field["name"]] = i
 
     return location, primary_key_loc, auto_id_loc

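With the per-entity type and dimension checks deleted, traverse_info reduces to recording where each non-auto-id field sits; the dtype restrictions added in entity_helper.py now do the validating when values are packed. A short sketch of what the slimmed-down function computes, using hypothetical field dicts shaped like describe_collection() output:

# Sketch of the remaining traverse_info logic (field names are hypothetical).
fields_info = [
    {"name": "id", "is_primary": True, "auto_id": True},
    {"name": "vec"},
    {"name": "tag"},
]

location, primary_key_loc, auto_id_loc = {}, None, None
for i, field in enumerate(fields_info):
    if field.get("is_primary", False):
        primary_key_loc = i
    if field.get("auto_id", False):
        auto_id_loc = i
        continue  # auto-id fields are not supplied by the caller
    location[field["name"]] = i

print(location, primary_key_loc, auto_id_loc)  # {'vec': 1, 'tag': 2} 0 0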

0 comments on commit ea2c9c6
