diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 8190cfc7f..ff6fae9df 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -586,6 +586,7 @@ def delete( partition_name, expression, consistency_level=kwargs.get("consistency_level", 0), + param_name=kwargs.get("param_name", None), ) future = self._stub.Delete.future(req, timeout=timeout) @@ -1002,7 +1003,7 @@ def describe_index( info_dict = {kv.key: kv.value for kv in response.index_descriptions[0].params} info_dict["field_name"] = response.index_descriptions[0].field_name info_dict["index_name"] = response.index_descriptions[0].index_name - if info_dict.get("params", None): + if info_dict.get("params"): info_dict["params"] = json.loads(info_dict["params"]) return info_dict diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 7ee4d1946..e413da1c5 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -370,6 +370,9 @@ def _parse_row_request( try: for entity in entities: + if not isinstance(entity, Dict): + msg = f"expected Dict, got '{type(entity).__name__}'" + raise TypeError(msg) for k, v in entity.items(): if k not in fields_data and not enable_dynamic: raise DataNotMatchException(message=ExceptionsMessage.InsertUnexpectedField) @@ -545,6 +548,7 @@ def delete_request( partition_name: str, expr: str, consistency_level: Optional[Union[int, str]], + **kwargs, ): def check_str(instr: str, prefix: str): if instr is None: @@ -557,7 +561,8 @@ def check_str(instr: str, prefix: str): check_str(collection_name, "collection_name") if partition_name is not None and partition_name != "": check_str(partition_name, "partition_name") - check_str(expr, "expr") + param_name = kwargs.get("param_name", "expr") + check_str(expr, param_name) return milvus_types.DeleteRequest( collection_name=collection_name, @@ -626,7 +631,7 @@ def search_requests_with_expr( if group_by_field is not None: search_params[GROUP_BY_FIELD] = 
group_by_field - if param.get("metric_type", None) is not None: + if param.get("metric_type") is not None: search_params["metric_type"] = param["metric_type"] if anns_field: diff --git a/pymilvus/client/ts_utils.py b/pymilvus/client/ts_utils.py index 46d72e365..c260aa4a8 100644 --- a/pymilvus/client/ts_utils.py +++ b/pymilvus/client/ts_utils.py @@ -75,7 +75,7 @@ def get_bounded_ts(): def construct_guarantee_ts(collection_name: str, kwargs: Dict): - consistency_level = kwargs.get("consistency_level", None) + consistency_level = kwargs.get("consistency_level") use_default = consistency_level is None if use_default: # in case of the default consistency is Customized or Session,
diff --git a/pymilvus/milvus_client/check.py b/pymilvus/milvus_client/check.py
new file mode 100644
index 000000000..e3ff3f0ef
--- /dev/null
+++ b/pymilvus/milvus_client/check.py
@@ -0,0 +1,29 @@
+from typing import Any
+
+
+def check_param_type(param_name: str, param: Any, expected_type: Any, ignore_none: bool = True):
+    """Raise TypeError when ``param`` is not an instance of ``expected_type``.
+
+    Args:
+        param_name: Argument name used in the error message.
+        param: Value to validate.
+        expected_type: A type, or a tuple of types, as accepted by ``isinstance``.
+        ignore_none: When True (the default), ``None`` is accepted without a check.
+
+    Raises:
+        TypeError: If ``param`` does not match ``expected_type``.
+    """
+    if ignore_none and param is None:
+        return
+    if isinstance(param, expected_type):
+        return
+    accepted = expected_type if isinstance(expected_type, tuple) else (expected_type,)
+    # A plain int is a valid value wherever a float is expected (e.g. timeout=10).
+    # bool is an int subclass; it is excluded so True/False is still rejected for float.
+    if float in accepted and isinstance(param, int) and not isinstance(param, bool):
+        return
+    expected = " or ".join(f"'{t.__name__}'" for t in accepted)
+    msg = f"wrong type of argument '{param_name}', "
+    msg += f"expected {expected}, "
+    msg += f"got '{type(param).__name__}'"
+    raise TypeError(msg)
diff --git a/pymilvus/milvus_client/index.py b/pymilvus/milvus_client/index.py index 1eb2c5a5e..8642937e7 100644 --- a/pymilvus/milvus_client/index.py +++ b/pymilvus/milvus_client/index.py @@ -19,6 +19,7 @@ def index_type(self): def __iter__(self): yield "field_name", self._field_name + yield "index_type", self._index_type yield "index_name", self._index_name yield from self._kwargs.items() diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index d0f316837..2bac8ed60 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -17,6 +17,7 @@ from pymilvus.orm.connections import connections from pymilvus.orm.types import DataType +from .check
import check_param_type from .index import IndexParams logger = logging.getLogger(__name__) @@ -51,14 +52,7 @@ def __init__( to None. """ # Optionial TQDM import - try: - import tqdm - - self.tqdm = tqdm.tqdm - except ImportError: - logger.debug("tqdm not found") - self.tqdm = lambda x, disable: x - + check_param_type("timeout", timeout, float) self._using = self._create_connection( uri, user, password, db_name, token, timeout=timeout, **kwargs ) @@ -78,6 +72,7 @@ def create_collection( index_params: Optional[IndexParams] = None, **kwargs, ): + check_param_type("timeout", timeout, float) if schema is None: return self._fast_create_collection( collection_name, @@ -107,6 +102,9 @@ def _fast_create_collection( timeout: Optional[float] = None, **kwargs, ): + if dimension is None: + msg = "missing requried argument: 'dimension'" + raise TypeError(msg) if "enable_dynamic_field" not in kwargs: kwargs["enable_dynamic_field"] = True @@ -141,9 +139,10 @@ def _fast_create_collection( index_params = self.prepare_index_params() index_type = "" index_name = "" - index_params.add_index( - vector_field_name, index_type, index_name, metric_type=metric_type, params={} - ) + params = { + "metric_type": metric_type, + } + index_params.add_index(vector_field_name, index_type, index_name, params=params) self.create_index(collection_name, index_params, timeout=timeout) self.load_collection(collection_name, timeout=timeout) @@ -154,6 +153,7 @@ def create_index( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) for index_param in index_params: self._create_index(collection_name, index_param, timeout=timeout, **kwargs) @@ -166,10 +166,18 @@ def _create_index( ): conn = self._get_connection() try: + params = index_param.get("params", {}) + _index_type = index_param.get("index_type") + if _index_type: + params["index_type"] = _index_type + _metric_type = index_param.get("metric_type") + if _metric_type: + params["metric_type"] = _metric_type + 
conn.create_index( collection_name, index_param["field_name"], - index_param.get("params", {}), + params, index_name=index_param.get("index_name", ""), timeout=timeout, **kwargs, @@ -205,10 +213,17 @@ def insert( Returns: Dict: Number of rows that were inserted. """ + check_param_type("timeout", timeout, float) # If no data provided, we cannot input anything if isinstance(data, Dict): data = [data] + msg = "wrong type of argument 'data'," + msg += f"expected 'Dict' or list of 'Dict', got '{type(data).__name__}'" + + if not isinstance(data, List): + raise TypeError(msg) + if len(data) == 0: return {"insert_count": 0} @@ -248,10 +263,17 @@ def upsert( Returns: List[Union[str, int]]: A list of primary keys that were inserted. """ + check_param_type("timeout", timeout, float) # If no data provided, we cannot input anything if isinstance(data, Dict): data = [data] + msg = "wrong type of argument 'data'," + msg += f"expected 'Dict' or list of 'Dict', got '{type(data).__name__}'" + + if not isinstance(data, List): + raise TypeError(msg) + if len(data) == 0: return {"upsert_count": 0} @@ -300,7 +322,7 @@ def search( List[List[dict]]: A nested list of dicts containing the result data. Embeddings are not included in the result data. """ - + check_param_type("timeout", timeout, float) conn = self._get_connection() try: res = conn.search( @@ -354,6 +376,7 @@ def query( Returns: List[dict]: A list of result dicts, vectors are not included. """ + check_param_type("timeout", timeout, float) if filter and not isinstance(filter, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) @@ -419,6 +442,7 @@ def get( Returns: List[dict]: A list of result dicts with keys {pk_field, vector_field} """ + check_param_type("timeout", timeout, float) if not isinstance(ids, list): ids = [ids] @@ -487,16 +511,16 @@ def delete( Returns: Dict: Number of rows that were deleted. 
""" + check_param_type("timeout", timeout, float) pks = kwargs.get("pks", []) if isinstance(pks, (int, str)): pks = [pks] - if ids: + if ids is not None: if isinstance(ids, (int, str)): pks.append(ids) elif isinstance(ids, list): pks.extend(ids) - expr = "" conn = self._get_connection() if pks: @@ -519,7 +543,14 @@ def delete( ret_pks = [] try: - res = conn.delete(collection_name, expr, partition_name, timeout=timeout, **kwargs) + res = conn.delete( + collection_name, + expr, + partition_name, + timeout=timeout, + param_name="filter or ids", + **kwargs, + ) if res.primary_keys: ret_pks.extend(res.primary_keys) except Exception as ex: @@ -532,6 +563,7 @@ def delete( return {"delete_count": res.delete_count} def get_collection_stats(self, collection_name: str, timeout: Optional[float] = None) -> Dict: + check_param_type("timeout", timeout, float) conn = self._get_connection() stats = conn.get_collection_stats(collection_name, timeout=timeout) result = {stat.key: stat.value for stat in stats} @@ -539,10 +571,12 @@ def get_collection_stats(self, collection_name: str, timeout: Optional[float] = return result def describe_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.describe_collection(collection_name, timeout=timeout, **kwargs) def has_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.has_collection(collection_name, timeout=timeout, **kwargs) @@ -551,6 +585,7 @@ def list_collections(self, **kwargs): return conn.list_collections(**kwargs) def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) """Delete the collection stored in this object""" conn = self._get_connection() conn.drop_collection(collection_name, timeout=timeout, **kwargs) @@ -563,6 
+598,7 @@ def rename_collection( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.rename_collections(old_name, new_name, target_db, timeout=timeout, **kwargs) @@ -668,6 +704,7 @@ def _pack_pks_expr(self, schema_dict: Dict, pks: List) -> str: def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): """Loads the collection.""" + check_param_type("timeout", timeout, float) conn = self._get_connection() try: conn.load_collection(collection_name, timeout=timeout, **kwargs) @@ -679,6 +716,7 @@ def load_collection(self, collection_name: str, timeout: Optional[float] = None, raise ex from ex def release_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() try: conn.release_collection(collection_name, timeout=timeout, **kwargs) @@ -696,12 +734,13 @@ def get_load_state( timeout: Optional[float] = None, **kwargs, ) -> Dict: + check_param_type("timeout", timeout, float) conn = self._get_connection() partition_names = None if partition_name: partition_names = [partition_name] try: - state = conn.get_load_state(collection_name, partition_names, tiemout=timeout, **kwargs) + state = conn.get_load_state(collection_name, partition_names, timeout=timeout, **kwargs) except Exception as ex: raise ex from ex @@ -713,6 +752,7 @@ def get_load_state( return ret def refresh_load(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) kwargs.pop("_refresh", None) conn = self._get_connection() conn.load_collection(collection_name, timeout=timeout, _refresh=True, **kwargs) @@ -748,6 +788,7 @@ def drop_index( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.drop_index(collection_name, "", index_name, timeout=timeout, **kwargs) @@ -758,6 
+799,7 @@ def describe_index( timeout: Optional[float] = None, **kwargs, ) -> Dict: + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.describe_index(collection_name, index_name, timeout=timeout, **kwargs) @@ -768,6 +810,7 @@ def create_partition( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.create_partition(collection_name, partition_name, timeout=timeout, **kwargs) @@ -778,6 +821,7 @@ def drop_partition( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.drop_partition(collection_name, partition_name, timeout=timeout, **kwargs) @@ -788,12 +832,14 @@ def has_partition( timeout: Optional[float] = None, **kwargs, ) -> bool: + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.has_partition(collection_name, partition_name, timeout=timeout, **kwargs) def list_partitions( self, collection_name: str, timeout: Optional[float] = None, **kwargs ) -> List[str]: + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.list_partitions(collection_name, timeout=timeout, **kwargs) @@ -804,6 +850,7 @@ def load_partitions( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) if isinstance(partition_names, str): partition_names = [partition_names] @@ -817,6 +864,7 @@ def release_partitions( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) if isinstance(partition_names, str): partition_names = [partition_names] conn = self._get_connection() @@ -825,14 +873,21 @@ def release_partitions( def get_partition_stats( self, collection_name: str, partition_name: str, timeout: Optional[float] = None, **kwargs ) -> Dict: + check_param_type("timeout", timeout, float) conn = self._get_connection() - return conn.get_partition_stats(collection_name, 
partition_name, timeout=timeout, **kwargs) + if not isinstance(partition_name, str): + msg = f"wrong type of argument 'partition_name', str expected, got '{type(partition_name).__name__}'" + raise TypeError(msg) + ret = conn.get_partition_stats(collection_name, partition_name, timeout=timeout, **kwargs) + return {stat.key: stat.value for stat in ret} def create_user(self, user_name: str, password: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.create_user(user_name, password, timeout=timeout, **kwargs) def drop_user(self, user_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.delete_user(user_name, timeout=timeout, **kwargs) @@ -845,6 +900,7 @@ def update_password( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.update_password(user_name, old_password, new_password, timeout=timeout, **kwargs) if reset_connection: @@ -852,10 +908,12 @@ def update_password( conn._setup_grpc_channel() def list_users(self, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() return conn.list_usernames(timeout=timeout, **kwargs) def describe_user(self, user_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() try: res = conn.select_one_user(user_name, True, timeout=timeout, **kwargs) @@ -870,26 +928,31 @@ def describe_user(self, user_name: str, timeout: Optional[float] = None, **kwarg return {} def grant_role(self, user_name: str, role_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.add_user_to_role(user_name, role_name, timeout=timeout, **kwargs) def revoke_role( self, user_name: str, role_name: str, 
timeout: Optional[float] = None, **kwargs ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.remove_user_from_role(user_name, role_name, timeout=timeout, **kwargs) def create_role(self, role_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.create_role(role_name, timeout=timeout, **kwargs) def drop_role(self, role_name: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.drop_role(role_name, timeout=timeout, **kwargs) def describe_role( self, role_name: str, timeout: Optional[float] = None, **kwargs ) -> List[Dict]: + check_param_type("timeout", timeout, float) conn = self._get_connection() db_name = kwargs.pop("db_name", "") try: @@ -899,6 +962,7 @@ def describe_role( return [dict(i) for i in res.groups] def list_roles(self, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() try: res = conn.select_all_role(False, timeout=timeout, **kwargs) @@ -918,6 +982,7 @@ def grant_privilege( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.grant_privilege( role_name, object_type, object_name, privilege, db_name, timeout=timeout, **kwargs @@ -933,6 +998,7 @@ def revoke_privilege( timeout: Optional[float] = None, **kwargs, ): + check_param_type("timeout", timeout, float) conn = self._get_connection() conn.revoke_privilege( role_name, object_type, object_name, privilege, db_name, timeout=timeout, **kwargs @@ -941,24 +1007,27 @@ def revoke_privilege( def create_alias( self, collection_name: str, alias: str, timeout: Optional[float] = None, **kwargs ): + check_param_type("timeout", timeout, float) conn = self._get_connection() - conn.create_alias(collection_name, alias, tiemout=timeout, **kwargs) + conn.create_alias(collection_name, alias, 
timeout=timeout, **kwargs) def drop_alias(self, alias: str, timeout: Optional[float] = None, **kwargs): + check_param_type("timeout", timeout, float) conn = self._get_connection() - conn.drop_alias(alias, tiemout=timeout, **kwargs) + conn.drop_alias(alias, timeout=timeout, **kwargs) def alter_alias( self, collection_name: str, alias: str, timeout: Optional[float] = None, **kwargs ): + check_param_type("timeout", timeout, float) conn = self._get_connection() - conn.alter_alias(collection_name, alias, tiemout=timeout, **kwargs) + conn.alter_alias(collection_name, alias, timeout=timeout, **kwargs) def describe_alias(self, alias: str, timeout: Optional[float] = None, **kwargs) -> Dict: - pass + check_param_type("timeout", timeout, float) def list_aliases(self, timeout: Optional[float] = None, **kwargs) -> List[str]: - pass + check_param_type("timeout", timeout, float) def using_database(self, db_name: str, **kwargs): conn = self._get_connection() diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index d152484b2..85d306b32 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -1260,7 +1260,7 @@ def indexes(self, **kwargs) -> List[Index]: for index in tmp_index: if index is not None: info_dict = {kv.key: kv.value for kv in index.params} - if info_dict.get("params", None): + if info_dict.get("params"): info_dict["params"] = json.loads(info_dict["params"]) index_info = Index( diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 50db692b4..5cad24dfe 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -176,7 +176,7 @@ def __check_reached_limit(self, ret: List): def __setup__pk_prop(self): fields = self._schema[FIELDS] for field in fields: - if IS_PRIMARY in field and field[IS_PRIMARY]: + if field.get(IS_PRIMARY): if field["type"] == DataType.VARCHAR: self._pk_str = True else: @@ -383,7 +383,7 @@ def __check_for_special_index_param(self): def __setup__pk_prop(self): fields = self._schema[FIELDS] 
for field in fields: - if IS_PRIMARY in field and field[IS_PRIMARY]: + if field.get(IS_PRIMARY): if field["type"] == DataType.VARCHAR: self._pk_str = True else: diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 142340e5b..c5674a31d 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -210,7 +210,9 @@ def auto_id(self): >>> schema.auto_id False """ - return self.primary_field.auto_id + if self.primary_field: + return self.primary_field.auto_id + return self._kwargs.get("auto_id", False) @auto_id.setter def auto_id(self, value: bool): @@ -309,11 +311,11 @@ def construct_from_dict(cls, raw: Dict): kwargs = {} kwargs.update(raw.get("params", {})) kwargs["is_primary"] = raw.get("is_primary", False) - if raw.get("auto_id", None) is not None: - kwargs["auto_id"] = raw.get("auto_id", None) + if raw.get("auto_id") is not None: + kwargs["auto_id"] = raw.get("auto_id") kwargs["is_partition_key"] = raw.get("is_partition_key", False) kwargs["is_dynamic"] = raw.get("is_dynamic", False) - kwargs["element_type"] = raw.get("element_type", None) + kwargs["element_type"] = raw.get("element_type") return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs) def to_dict(self):