Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: restrict input/search type for vector fields #2025

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def entity_to_array_arr(entity: List[Any], field_info: Any):
return convert_to_array_arr(entity.get("values", []), field_info)


def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
def pack_field_value_to_field_data(
field_value: Any, field_data: schema_types.FieldData, field_info: Any
):
field_type = field_data.type
if field_type == DataType.BOOL:
field_data.scalars.bool_data.data.append(field_value)
Expand All @@ -304,26 +306,51 @@ def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info
elif field_type == DataType.DOUBLE:
field_data.scalars.double_data.data.append(field_value)
elif field_type == DataType.FLOAT_VECTOR:
field_data.vectors.dim = len(field_value)
field_data.vectors.float_vector.data.extend(field_value)
f_value = field_value
if isinstance(field_value, np.ndarray):
if field_value.dtype not in ("float32", "float64"):
raise ParamError(
                    message="invalid input for float32 vector, expect np.ndarray with dtype=float32 or float64"
)
f_value = field_value.view(np.float32).tolist()

field_data.vectors.dim = len(f_value)
field_data.vectors.float_vector.data.extend(f_value)

elif field_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)

elif field_type == DataType.FLOAT16_VECTOR:
v_bytes = (
bytes(field_value)
if not isinstance(field_value, np.ndarray)
else field_value.view(np.uint8).tobytes()
)
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "float16":
raise ParamError(
message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
                message="invalid input type for float16 vector, expect bytes or np.ndarray with dtype=float16"
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.float16_vector += v_bytes

elif field_type == DataType.BFLOAT16_VECTOR:
v_bytes = (
bytes(field_value)
if not isinstance(field_value, np.ndarray)
else field_value.view(np.uint8).tobytes()
)
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "bfloat16":
raise ParamError(
message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
                message="invalid input type for bfloat16 vector, expect bytes or np.ndarray with dtype=bfloat16"
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.bfloat16_vector += v_bytes
Expand Down
55 changes: 29 additions & 26 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,46 +475,49 @@ def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwa

return fields_info, enable_dynamic

def _prepare_row_insert_request(
@retry_on_rpc_failure()
def insert_rows(
self,
collection_name: str,
entity_rows: List,
entities: Union[Dict, List[Dict]],
partition_name: Optional[str] = None,
schema: Optional[dict] = None,
timeout: Optional[float] = None,
**kwargs,
):
if not isinstance(entity_rows, list):
raise ParamError(message="None rows, please provide valid row data.")

fields_info, enable_dynamic = self._get_info(collection_name, timeout, **kwargs)
return Prepare.row_insert_param(
collection_name,
entity_rows,
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, timeout, **kwargs
)
resp = self._stub.Insert(request=request, timeout=timeout)
check_status(resp.status)
ts_utils.update_collection_ts(collection_name, resp.timestamp)
return MutationResult(resp)

@retry_on_rpc_failure()
def insert_rows(
def _prepare_row_insert_request(
self,
collection_name: str,
entities: List,
entity_rows: Union[List[Dict], Dict],
partition_name: Optional[str] = None,
schema: Optional[dict] = None,
timeout: Optional[float] = None,
**kwargs,
):
if isinstance(entities, dict):
entities = [entities]
request = self._prepare_row_insert_request(
collection_name, entities, partition_name, timeout, **kwargs
if isinstance(entity_rows, dict):
entity_rows = [entity_rows]

if not isinstance(schema, dict):
schema = self.describe_collection(collection_name, timeout=timeout)

fields_info = schema.get("fields")
enable_dynamic = schema.get("enable_dynamic_field", False)

return Prepare.row_insert_param(
collection_name,
entity_rows,
partition_name,
fields_info,
enable_dynamic=enable_dynamic,
)
rf = self._stub.Insert.future(request, timeout=timeout)
response = rf.result()
check_status(response.status)
m = MutationResult(response)
ts_utils.update_collection_ts(collection_name, m.timestamp)
return m

def _prepare_batch_insert_request(
self,
Expand Down Expand Up @@ -1376,7 +1379,7 @@ def _wait_for_flushed(
end = time.time()
if timeout is not None and end - start > timeout:
raise MilvusException(
message=f"wait for flush timeout, collection: {collection_name}"
                        message=f"wait for flush timeout, collection: {collection_name}, flush_ts: {flush_ts}"
)

if not flush_ret:
Expand Down
38 changes: 22 additions & 16 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,10 @@ def _parse_row_request(
field["name"]: field for field in fields_info if not field.get("auto_id", False)
}

meta_field = (
schema_types.FieldData(is_dynamic=True, type=DataType.JSON) if enable_dynamic else None
)
if meta_field is not None:
field_info_map[meta_field.field_name] = meta_field
fields_data[meta_field.field_name] = meta_field
if enable_dynamic:
d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
fields_data[d_field.field_name] = d_field
field_info_map[d_field.field_name] = d_field

try:
for entity in entities:
Expand All @@ -390,7 +388,7 @@ def _parse_row_request(

if enable_dynamic:
json_value = entity_helper.convert_to_json(json_dict)
meta_field.scalars.json_data.data.append(json_value)
d_field.scalars.json_data.data.append(json_value)

except (TypeError, ValueError) as e:
raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
Expand All @@ -400,7 +398,7 @@ def _parse_row_request(
)

if enable_dynamic:
request.fields_data.append(meta_field)
request.fields_data.append(d_field)

_, _, auto_id_loc = traverse_rows_info(fields_info, entities)
if auto_id_loc is not None:
Expand All @@ -418,16 +416,18 @@ def row_insert_param(
collection_name: str,
entities: List,
partition_name: str,
fields_info: Any,
fields_info: Dict,
enable_dynamic: bool = False,
):
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")

# insert_request.hash_keys won't be filled in client.
tag = partition_name if isinstance(partition_name, str) else ""
p_name = partition_name if isinstance(partition_name, str) else ""
request = milvus_types.InsertRequest(
collection_name=collection_name, partition_name=tag, num_rows=len(entities)
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
Expand All @@ -445,9 +445,11 @@ def row_upsert_param(
raise ParamError(message="Missing collection meta to validate entities")

# upsert_request.hash_keys won't be filled in client.
tag = partition_name if isinstance(partition_name, str) else ""
p_name = partition_name if isinstance(partition_name, str) else ""
request = milvus_types.UpsertRequest(
collection_name=collection_name, partition_name=tag, num_rows=len(entities)
collection_name=collection_name,
partition_name=p_name,
num_rows=len(entities),
)

return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
Expand All @@ -469,7 +471,7 @@ def _pre_batch_check(
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")

location, primary_key_loc, auto_id_loc = traverse_info(fields_info, entities)
location, primary_key_loc, auto_id_loc = traverse_info(fields_info)

# though impossible from sdk
if primary_key_loc is None:
Expand Down Expand Up @@ -583,16 +585,20 @@ def _prepare_placeholder_str(cls, data: Any):

elif isinstance(data[0], np.ndarray):
dtype = data[0].dtype
pl_values = (array.tobytes() for array in data)

if dtype == "bfloat16":
pl_type = PlaceholderType.BFLOAT16_VECTOR
pl_values = (array.tobytes() for array in data)
elif dtype == "float16":
pl_type = PlaceholderType.FLOAT16_VECTOR
elif dtype == "float32":
pl_values = (array.tobytes() for array in data)
elif dtype in ("float32", "float64"):
pl_type = PlaceholderType.FloatVector
pl_values = (blob.vector_float_to_bytes(entity) for entity in data)

elif dtype == "byte":
pl_type = PlaceholderType.BinaryVector
pl_values = data

else:
err_msg = f"unsupported data type: {dtype}"
Expand Down
55 changes: 2 additions & 53 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def traverse_rows_info(fields_info: Any, entities: List):
return location, primary_key_loc, auto_id_loc


def traverse_info(fields_info: Any, entities: List):
def traverse_info(fields_info: Any):
location, primary_key_loc, auto_id_loc = {}, None, None
for i, field in enumerate(fields_info):
if field.get("is_primary", False):
Expand All @@ -259,58 +259,7 @@ def traverse_info(fields_info: Any, entities: List):
if field.get("auto_id", False):
auto_id_loc = i
continue

match_flag = False
field_name = field["name"]
field_type = field["type"]

for entity in entities:
entity_name, entity_type = entity["name"], entity["type"]

if field_name == entity_name:
if field_type != entity_type:
raise ParamError(
message=f"Collection field type is {field_type}"
f", but entities field type is {entity_type}"
)

entity_dim, field_dim = 0, 0
if entity_type in [
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
]:
field_dim = field["params"]["dim"]
entity_dim = len(entity["values"][0])

if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim:
raise ParamError(
message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim}"
)

if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim:
raise ParamError(
message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim * 8}"
)

if (
entity_type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]
and int(entity_dim // 2) != field_dim
):
raise ParamError(
message=f"Collection field dim is {field_dim}"
f", but entities field dim is {int(entity_dim // 2)}"
)

location[field["name"]] = i
match_flag = True
break

if not match_flag:
raise ParamError(message=f"Field {field['name']} don't match in entities")
location[field["name"]] = i

return location, primary_key_loc, auto_id_loc

Expand Down
Loading