diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..8e286b7d1 --- /dev/null +++ b/.env.example @@ -0,0 +1,9 @@ +# Please copy this file and rename as .env, pymilvus will read .env file if provided + +MILVUS_URI= +# MILVUS_URI=https://username:password@in01-random123.xxx.com:19530 + +# Milvus connections configs +MILVUS_CONN_ALIAS=default +MILVUS_CONN_TIMEOUT=10 + diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index 72e94648a..545665e3b 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -31,7 +31,16 @@ ) from .client import __version__ -from .settings import DEBUG_LOG_LEVEL, INFO_LOG_LEVEL, WARN_LOG_LEVEL, ERROR_LOG_LEVEL +from .settings import ( + DEBUG_LOG_LEVEL, + INFO_LOG_LEVEL, + WARN_LOG_LEVEL, + ERROR_LOG_LEVEL, +) +# Compatiable +from .settings import Config as DefaultConfig + +from .client.constants import DEFAULT_RESOURCE_GROUP from .orm.collection import Collection from .orm.connections import connections, Connections @@ -58,7 +67,6 @@ ) from .orm import utility -from .orm.default_config import DefaultConfig, ENV_CONNECTION_CONF, DEFAULT_RESOURCE_GROUP from .orm.search import SearchResult, Hits, Hit from .orm.schema import FieldSchema, CollectionSchema diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 38fce7a5e..a6d383dac 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -1,7 +1,7 @@ import abc import numpy as np -from .configs import DefaultConfigs +from ..settings import Config from .types import DataType from .constants import DEFAULT_CONSISTENCY_LEVEL from ..grpc_gen import schema_pb2 @@ -88,7 +88,7 @@ def __pack(self, raw): else: self.params[type_param.key] = type_param.value # maybe we'd better not to check these fields in ORM. - if type_param.key in ["dim", DefaultConfigs.MaxVarCharLengthKey]: + if type_param.key in ["dim", Config.MaxVarCharLengthKey]: self.params[type_param.key] = int(type_param.value) index_dict = {} diff --git a/pymilvus/client/configs.py b/pymilvus/client/configs.py deleted file mode 100644 index 1c97fcfb3..000000000 --- a/pymilvus/client/configs.py +++ /dev/null @@ -1,7 +0,0 @@ -# TODO(dragondriver): add more default configs to here -class DefaultConfigs: - WaitTimeDurationWhenLoad = 0.5 # in seconds - MaxVarCharLengthKey = "max_length" - MaxVarCharLength = 65535 - EncodeProtocol = 'utf-8' - IndexName = "" diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index d9454ab89..a0b05a2e8 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -7,3 +7,4 @@ EVENTUALLY_TS = 1 BOUNDED_TS = 2 DEFAULT_CONSISTENCY_LEVEL = ConsistencyLevel.Bounded +DEFAULT_RESOURCE_GROUP = "__default_resource_group" diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index a36e1728f..ff0a63e70 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -1,7 +1,7 @@ from ..grpc_gen import schema_pb2 as schema_types from .types import DataType from ..exceptions import ParamError -from .configs import DefaultConfigs +from ..settings import Config def entity_type_to_dtype(entity_type): @@ -14,8 +14,8 @@ def entity_type_to_dtype(entity_type): def get_max_len_of_var_char(field_info) -> int: - k = DefaultConfigs.MaxVarCharLengthKey - v = DefaultConfigs.MaxVarCharLength + k = Config.MaxVarCharLengthKey + v = Config.MaxVarCharLength return field_info.get("params", {}).get(k, v) @@ -30,9 +30,9 @@ def check_str_arr(str_arr, max_len): def entity_to_str_arr(entity, field_info, check=True): arr = [] - if DefaultConfigs.EncodeProtocol.lower() != 'utf-8'.lower(): + if Config.EncodeProtocol.lower() != 'utf-8'.lower(): for s in entity.get("values"): - arr.append(s.encode(DefaultConfigs.EncodeProtocol)) + arr.append(s.encode(Config.EncodeProtocol)) else: arr = entity.get("values") max_len = int(get_max_len_of_var_char(field_info)) diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 5400c6530..e17520b8c 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -42,8 +42,7 @@ get_server_type, ) -from ..settings import DefaultConfig as config -from .configs import DefaultConfigs +from ..settings import Config from . import ts_utils from . import interceptor @@ -68,7 +67,7 @@ class GrpcHandler: - def __init__(self, uri=config.GRPC_URI, host="", port="", channel=None, **kwargs): + def __init__(self, uri=Config.GRPC_URI, host="", port="", channel=None, **kwargs): self._stub = None self._channel = channel @@ -560,7 +559,7 @@ def alter_alias(self, collection_name, alias, timeout=None, **kwargs): @retry_on_rpc_failure() def create_index(self, collection_name, field_name, params, timeout=None, **kwargs): # for historical reason, index_name contained in kwargs. - index_name = kwargs.pop("index_name", DefaultConfigs.IndexName) + index_name = kwargs.pop("index_name", Config.IndexName) copy_kwargs = copy.deepcopy(kwargs) collection_desc = self.describe_collection(collection_name, timeout=timeout, **copy_kwargs) @@ -732,7 +731,7 @@ def can_loop(t) -> bool: progress = self.get_loading_progress(collection_name, timeout=timeout) if progress >= 100: return - time.sleep(DefaultConfigs.WaitTimeDurationWhenLoad) + time.sleep(Config.WaitTimeDurationWhenLoad) raise MilvusException(message=f"wait for loading collection timeout, collection: {collection_name}") @retry_on_rpc_failure() @@ -788,7 +787,7 @@ def can_loop(t) -> bool: progress = self.get_loading_progress(collection_name, partition_names, timeout=timeout) if progress >= 100: return - time.sleep(DefaultConfigs.WaitTimeDurationWhenLoad) + time.sleep(Config.WaitTimeDurationWhenLoad) raise MilvusException(message=f"wait for loading partition timeout, collection: {collection_name}, partitions: {partition_names}") @retry_on_rpc_failure() diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py index 17022c59e..94cf3908b 100644 --- a/pymilvus/client/stub.py +++ b/pymilvus/client/stub.py @@ -3,7 +3,7 @@ from .grpc_handler import GrpcHandler from ..exceptions import MilvusException, ParamError from .types import CompactionState, CompactionPlans, Replica, BulkInsertState, ResourceGroupInfo -from ..settings import DefaultConfig as config +from ..settings import Config from ..decorators import deprecated from .check import is_legal_host, is_legal_port @@ -11,14 +11,14 @@ class Milvus: @deprecated - def __init__(self, host=None, port=config.GRPC_PORT, uri=config.GRPC_URI, channel=None, **kwargs): + def __init__(self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI, channel=None, **kwargs): self.address = self.__get_address(host, port, uri) self._handler = GrpcHandler(address=self.address, channel=channel, **kwargs) if kwargs.get("pre_ping", False) is True: self._handler._wait_for_channel_ready() - def __get_address(self, host=None, port=config.GRPC_PORT, uri=config.GRPC_URI): + def __get_address(self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI): if host is None and uri is None: raise ParamError(message='Host and uri cannot both be None') diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index ed159b6ab..12d688df7 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -40,10 +40,9 @@ ) from .future import SearchFuture, MutationFuture from .utility import _get_connection -from .default_config import DefaultConfig +from ..settings import Config from ..client.types import CompactionState, CompactionPlans, Replica, get_consistency_level, cmp_consistency_level from ..client.constants import DEFAULT_CONSISTENCY_LEVEL -from ..client.configs import DefaultConfigs @@ -167,7 +166,7 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): else: raise SchemaNotReadyException(message=ExceptionsMessage.AutoIDWithData) - using = kwargs.get("using", DefaultConfig.DEFAULT_USING) + using = kwargs.get("using", Config.MILVUS_CONN_ALIAS) conn = _get_connection(using) if conn.has_collection(name, **kwargs): resp = conn.describe_collection(name, **kwargs) @@ -186,7 +185,7 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): field.is_primary = True field.auto_id = False if field.dtype == DataType.VARCHAR: - field.params[DefaultConfigs.MaxVarCharLengthKey] = int(DefaultConfigs.MaxVarCharLength) + field.params[Config.MaxVarCharLengthKey] = int(Config.MaxVarCharLength) schema = CollectionSchema(fields=fields_schema) check_schema(schema) @@ -951,7 +950,7 @@ def index(self, **kwargs) -> Index: """ copy_kwargs = copy.deepcopy(kwargs) - index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) + index_name = copy_kwargs.pop("index_name", Config.IndexName) conn = self._get_connection() tmp_index = conn.describe_index(self._name, index_name, **copy_kwargs) if tmp_index is not None: @@ -1027,7 +1026,7 @@ def has_index(self, timeout=None, **kwargs) -> bool: """ conn = self._get_connection() copy_kwargs = copy.deepcopy(kwargs) - index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) + index_name = copy_kwargs.pop("index_name", Config.IndexName) if conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) is None: return False return True @@ -1062,7 +1061,7 @@ def drop_index(self, timeout=None, **kwargs): False """ copy_kwargs = copy.deepcopy(kwargs) - index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) + index_name = copy_kwargs.pop("index_name", Config.IndexName) conn = self._get_connection() tmp_index = conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) if tmp_index is not None: diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 2186fccc7..bccf0f9d3 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -10,9 +10,7 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -import os import copy -import re import threading from urllib import parse from typing import Tuple @@ -20,7 +18,7 @@ from ..client.check import is_legal_host, is_legal_port, is_legal_address from ..client.grpc_handler import GrpcHandler -from .default_config import DefaultConfig, ENV_CONNECTION_CONF +from ..settings import Config from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException @@ -61,49 +59,72 @@ class Connections(metaclass=SingleInstanceMetaClass): def __init__(self): """ Constructs a default milvus alias config - default config will be read from env: MILVUS_DEFAULT_CONNECTION, - or "localhost:19530" + default config will be read from env: MILVUS_URI and MILVUS_CONN_ALIAS + with default value: default="localhost:19530" - """ - self._alias = {} - self._connected_alias = {} - - self.add_connection(default=self._read_default_config_from_os_env()) + Read default connection config from environment variable: MILVUS_URI. + Format is: + [scheme://][@]host[:] - def _read_default_config_from_os_env(self): - """ Read default connection config from environment variable: MILVUS_DEFAULT_CONNECTION. - Format is: - [@]host[:] + scheme is one of: http, https, or - protocol is one of: http, https, tcp, or - Examples:: + Examples: localhost localhost:19530 test_user@localhost:19530 + http://test_userlocalhost:19530 + https://test_user:password@localhost:19530 + """ + self._alias = {} + self._connected_alias = {} + self._env_uri = None - # no need to adjust http://xxx, https://xxx, tcp://xxxx, - # because protocol is ignored - # @see __generate_address + if Config.MILVUS_URI != "": + address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI) + self._env_uri = (address, parsed_uri) - conf = os.getenv(ENV_CONNECTION_CONF, "").strip() - if not conf: - conf = DefaultConfig.DEFAULT_HOST + default_conn_config = { + "user": parsed_uri.username if parsed_uri.username is not None else "", + "address": address, + } + else: + default_conn_config = { + "user": "", + "address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}", + } - rex = re.compile(r"^(?:([^\s/\\:]+)@)?([^\s/\\:]+)(?::(\d{1,5}))?$") - matched = rex.search(conf) + self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) - if not matched: - raise ConnectionConfigException(message=ExceptionsMessage.EnvConfigErr % (ENV_CONNECTION_CONF, conf)) + def __verify_host_port(self, host, port): + if not is_legal_host(host): + raise ConnectionConfigException(message=ExceptionsMessage.HostType) + if not is_legal_port(port): + raise ConnectionConfigException(message=ExceptionsMessage.PortType) + if not 0 <= int(port) < 65535: + raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") + + + def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): + illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" + try: + parsed_uri = parse.urlparse(uri) + except (Exception) as e: + raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + + if len(parsed_uri.netloc) == 0: + raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None - user, host, port = matched.groups() - user = user or "" - port = port or DefaultConfig.DEFAULT_PORT + host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST + port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT + addr = f"{host}:{port}" - return { - "user": user, - "address": f"{host}:{port}" - } + self.__verify_host_port(host, port) + + if not is_legal_address(addr): + raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) + + return addr, parsed_uri def add_connection(self, **kwargs): """ Configures a milvus connection. @@ -136,7 +157,7 @@ def add_connection(self, **kwargs): ) """ for alias, config in kwargs.items(): - addr = self.__get_full_address( + addr, _ = self.__get_full_address( config.get("address", ""), config.get("uri", ""), config.get("host", ""), @@ -153,42 +174,21 @@ def add_connection(self, **kwargs): self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> str: + def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> (str, parse.ParseResult): if address != "": if not is_legal_address(address): raise ConnectionConfigException(message=f"Illegal address: {address}, should be in form 'localhost:19530'") - else: - address = self.__generate_address(uri, host, port) + return address, None - return address - - def __generate_address(self, uri: str, host: str, port: str) -> str: - illegal_uri_msg = "Illegal uri: [{}], should be in form 'http://example.com' or 'tcp://6.6.6.6:12345'" if uri != "": - try: - parsed_uri = parse.urlparse(uri) - except (Exception) as e: - raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None - - if len(parsed_uri.netloc) == 0: - raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) + address, parsed = self.__parse_address_from_uri(uri) + return address, parsed - addr = parsed_uri.netloc if ":" in parsed_uri.netloc else f"{parsed_uri.netloc}:{DefaultConfig.DEFAULT_PORT}" - if not is_legal_address(addr): - raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) - return addr + host = host if host != "" else Config.DEFAULT_HOST + port = port if port != "" else Config.DEFAULT_PORT + self.__verify_host_port(host, port) - host = host if host != "" else DefaultConfig.DEFAULT_HOST - port = port if port != "" else DefaultConfig.DEFAULT_PORT - - if not is_legal_host(host): - raise ConnectionConfigException(message=ExceptionsMessage.HostType) - if not is_legal_port(port): - raise ConnectionConfigException(message=ExceptionsMessage.PortType) - if not 0 <= int(port) < 65535: - raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") - - return f"{host}:{port}" + return f"{host}:{port}", None def disconnect(self, alias: str): """ Disconnects connection from the registry. @@ -214,7 +214,7 @@ def remove_connection(self, alias: str): self.disconnect(alias) self._alias.pop(alias, None) - def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwargs): + def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", **kwargs): """ Constructs a milvus connection and register it under given alias. @@ -234,14 +234,14 @@ def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwa * *port* (``str/int``) -- Optional. The port of Milvus instance. Default at 19530, PyMilvus will fill in the default port if only host is provided. + * *secure* (``bool``) -- + Optional. Default is false. If set to true, tls will be enabled. * *user* (``str``) -- Optional. Use which user to connect to Milvus instance. If user and password are provided, we will add related header in every RPC call. * *password* (``str``) -- Optional and required when user is provided. The password corresponding to the user. - * *secure* (``bool``) -- - Optional. Default is false. If set to true, tls will be enabled. * *client_key_path* (``str``) -- Optional. If use tls two-way authentication, need to write the client.key path. * *client_pem_path* (``str``) -- @@ -262,14 +262,11 @@ def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwa >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") """ - if not isinstance(alias, str): - raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - def connect_milvus(**kwargs): gh = GrpcHandler(**kwargs) t = kwargs.get("timeout") - timeout = t if isinstance(t, (int, float)) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT + timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT gh._wait_for_channel_ready(timeout=timeout) kwargs.pop('password') @@ -285,6 +282,9 @@ def with_config(config: Tuple) -> bool: return False + if not isinstance(alias, str): + raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) + config = ( kwargs.pop("address", ""), kwargs.pop("uri", ""), @@ -292,23 +292,50 @@ def with_config(config: Tuple) -> bool: kwargs.pop("port", "") ) + # 1st Priority: connection from params if with_config(config): - in_addr = self.__get_full_address(*config) + in_addr, parsed_uri = self.__get_full_address(*config) kwargs["address"] = in_addr if self.has_connection(alias): if self._alias[alias].get("address") != in_addr: raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) + # uri might take extra info + if parsed_uri is not None: + user = parsed_uri.username if parsed_uri.username is not None else user + password = parsed_uri.password if parsed_uri.password is not None else password + # Set secure=True if uri provided user and password + if len(user) > 0 and len(password) > 0: + kwargs["secure"] = True + connect_milvus(**kwargs, user=user, password=password) + return - else: - if alias not in self._alias: - raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + # 2nd Priority, connection configs from env + if self._env_uri is not None: + addr, parsed_uri = self._env_uri + kwargs["address"] = addr + + user = parsed_uri.username if parsed_uri.username is not None else "" + password = parsed_uri.password if parsed_uri.password is not None else "" + # Set secure=True if uri provided user and password + if len(user) > 0 and len(password) > 0: + kwargs["secure"] = True + + connect_milvus(**kwargs, user=user, password=password) + return + # 3rd Priority, connect to cached configs with provided user and password + if alias in self._alias: connect_alias = dict(self._alias[alias].items()) connect_alias["user"] = user connect_milvus(**connect_alias, password=password, **kwargs) + return + + # No params, env, and cached configs for the alias + raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + def list_connections(self) -> list: """ List names of all connections. @@ -369,7 +396,7 @@ def has_connection(self, alias: str) -> bool: raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) return alias in self._connected_alias - def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandler: + def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: """ Retrieves a GrpcHandler by alias. """ if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) diff --git a/pymilvus/orm/default_config.py b/pymilvus/orm/default_config.py deleted file mode 100644 index 3eaafb8fd..000000000 --- a/pymilvus/orm/default_config.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (C) 2019-2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. - - -class DefaultConfig: - DEFAULT_USING = "default" - DEFAULT_HOST = "localhost" - DEFAULT_PORT = "19530" - DEFAULT_CONNECT_TIMEOUT = 10 - - -ENV_CONNECTION_CONF = "MILVUS_DEFAULT_CONNECTION" - -DEFAULT_RESOURCE_GROUP = "__default_resource_group" diff --git a/pymilvus/orm/index.py b/pymilvus/orm/index.py index 2329242ca..2bbd39b62 100644 --- a/pymilvus/orm/index.py +++ b/pymilvus/orm/index.py @@ -13,7 +13,7 @@ import copy from ..exceptions import CollectionNotExistException, ExceptionsMessage -from ..client.configs import DefaultConfigs +from ..settings import Config class Index: @@ -64,7 +64,7 @@ def __init__(self, collection, field_name, index_params, **kwargs): self._collection = collection self._field_name = field_name self._index_params = index_params - index_name = kwargs.get("index_name", DefaultConfigs.IndexName) + index_name = kwargs.get("index_name", Config.IndexName) self._index_name = index_name self._kwargs = kwargs if self._kwargs.pop("construct_only", False): @@ -155,6 +155,6 @@ def drop(self, timeout=None, **kwargs): """ copy_kwargs = copy.deepcopy(kwargs) - index_name = copy_kwargs.pop("index_name", DefaultConfigs.IndexName) + index_name = copy_kwargs.pop("index_name", Config.IndexName) conn = self._get_connection() conn.drop_index(self._collection.name, self.field_name, index_name, timeout=timeout, **copy_kwargs) diff --git a/pymilvus/settings.py b/pymilvus/settings.py index b1a3b46cd..c08ad038b 100644 --- a/pymilvus/settings.py +++ b/pymilvus/settings.py @@ -1,16 +1,33 @@ import logging.config +import environs +env = environs.Env() +try: + env.read_env(".env") +except Exception: + pass -class DefaultConfig: +class Config: + # legacy env MILVUS_DEFAULT_CONNECTION, not recommended + LEGACY_URI = env.str("MILVUS_DEFAULT_CONNECTION", "") + MILVUS_URI = env.str("MILVUS_URI", LEGACY_URI) + + MILVUS_CONN_ALIAS = env.str("MILVUS_CONN_ALIAS", "default") + MILVUS_CONN_TIMEOUT = env.float("MILVUS_CONN_TIMEOUT", 10) + + # TODO tidy the following configs GRPC_PORT = "19530" GRPC_ADDRESS = "127.0.0.1:19530" GRPC_URI = f"tcp://{GRPC_ADDRESS}" - HTTP_PORT = "19121" - HTTP_ADDRESS = "127.0.0.1:19121" - HTTP_URI = f"http://{HTTP_ADDRESS}" + DEFAULT_HOST = "localhost" + DEFAULT_PORT = "19530" - CALC_DIST_METRIC = "L2" + WaitTimeDurationWhenLoad = 0.5 # in seconds + MaxVarCharLengthKey = "max_length" + MaxVarCharLength = 65535 + EncodeProtocol = 'utf-8' + IndexName = "" # logging diff --git a/requirements.txt b/requirements.txt index ac3f11773..dd32fd799 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ build==0.4.0 certifi==2022.12.7 chardet==4.0.0 +environs==9.5.0 grpcio==1.53.0 grpcio-testing==1.53.0 grpcio-tools==1.53.0 protobuf>=3.17.1 idna==2.10 -mmh3>=2.0 packaging==20.9 pep517==0.10.0 pyparsing==2.4.7 diff --git a/setup.py b/setup.py index a5a3ee9e8..e42e961a3 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ install_requires=[ "grpcio>=1.49.1,<=1.53.0", "protobuf>=3.20.0", - "mmh3>=2.0", + "environs<=9.5.0", "ujson>=2.0.0", "pandas>=1.2.4", ], diff --git a/tests/test_connections.py b/tests/test_connections.py index 800a6f53c..95531f861 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -6,7 +6,7 @@ from unittest import mock from pymilvus import connections -from pymilvus import DefaultConfig, MilvusException, ENV_CONNECTION_CONF +from pymilvus import DefaultConfig, MilvusException from pymilvus.exceptions import ErrorCode LOGGER = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class TestConnect: - connect with the providing configs if valid, and replace the old ones. Connect to a new alias will: - - connect with the provieding configs if valid, store the new alias with these configs + - connect with the providing configs if valid, store the new alias with these configs """ @pytest.fixture(scope="function", params=[ @@ -72,22 +72,38 @@ def test_connect_with_default_config(self): with mock.patch(f"{mock_prefix}.close", return_value=None): connections.disconnect(alias) - def test_connect_with_default_config_from_environment(self): - test_list = [ - ["", {"address": "localhost:19530", "user": ""}], - ["localhost", {"address": "localhost:19530", "user": ""}], - ["localhost:19530", {"address": "localhost:19530", "user": ""}], - ["abc@localhost", {"address": "localhost:19530", "user": "abc"}], - ["milvus_host", {"address": "milvus_host:19530", "user": ""}], - ["milvus_host:12012", {"address": "milvus_host:12012", "user": ""}], - ["abc@milvus_host:12012", {"address": "milvus_host:12012", "user": "abc"}], - ["abc@milvus_host", {"address": "milvus_host:19530", "user": "abc"}], - ] + @pytest.fixture(scope="function", params=[ + ("", {"address": "localhost:19530", "user": ""}), + ("localhost", {"address": "localhost:19530", "user": ""}), + ("localhost:19530", {"address": "localhost:19530", "user": ""}), + ("abc@localhost", {"address": "localhost:19530", "user": "abc"}), + ("milvus_host", {"address": "milvus_host:19530", "user": ""}), + ("milvus_host:12012", {"address": "milvus_host:12012", "user": ""}), + ("abc@milvus_host:12012", {"address": "milvus_host:12012", "user": "abc"}), + ("abc@milvus_host", {"address": "milvus_host:19530", "user": "abc"}), + ]) + def test_connect_with_default_config_from_environment(self, env_result): + os.environ[DefaultConfig.MILVUS_URI] = env_result[0] + assert env_result[1] == connections._read_default_config_from_os_env() + + with mock.patch(f"{mock_prefix}.__init__", return_value=None): + with mock.patch(f"{mock_prefix}._wait_for_channel_ready", return_value=None): + # use env + connections.connect() - for env_str, assert_config in test_list: - os.environ[ENV_CONNECTION_CONF] = env_str + assert env_result[1] == connections.get_connection_addr(DefaultConfig.MILVUS_CONN_ALIAS) - assert assert_config == connections._read_default_config_from_os_env() + with mock.patch(f"{mock_prefix}.__init__", return_value=None): + with mock.patch(f"{mock_prefix}._wait_for_channel_ready", return_value=None): + # use param + connections.connect(DefaultConfig.MILVUS_CONN_ALIAS, host="test_host", port="19999") + + curr_addr = connections.get_connection_addr(DefaultConfig.MILVUS_CONN_ALIAS) + assert env_result[1] != curr_addr + assert {"address":"test_host:19999", "user": ""} == curr_addr + + with mock.patch(f"{mock_prefix}.close", return_value=None): + connections.remove_connection(DefaultConfig.MILVUS_CONN_ALIAS) def test_connect_new_alias_with_configs(self): alias = "exist" @@ -296,14 +312,13 @@ def test_add_connection_uri(self, valid_uri): host, port = addr["address"].split(':') assert host in valid_uri['uri'] or host in DefaultConfig.DEFAULT_HOST assert port in valid_uri['uri'] or port in DefaultConfig.DEFAULT_PORT + print(addr) with mock.patch(f"{mock_prefix}.close", return_value=None): connections.remove_connection(alias) @pytest.mark.parametrize("invalid_uri", [ - {"uri": "http://:19530"}, - {"uri": "localhost:19530"}, - {"uri": ":80"}, + {"uri": "http://"}, {"uri": None}, {"uri": -1}, ]) diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 125b76d18..ef5729491 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -2,7 +2,7 @@ from pymilvus.client.prepare import Prepare from pymilvus import DataType, MilvusException, CollectionSchema, FieldSchema -from pymilvus.client.configs import DefaultConfigs +from pymilvus import DefaultConfig class TestPrepare: @@ -84,7 +84,7 @@ def test_param_error_get_schema(self, invalid_fields): {"name": "test_varchar", "type": DataType.VARCHAR, "is_primary": True, "params": {"dim": "invalid"}}, ]}, {"fields": [ - {"name": "test_floatvector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128, DefaultConfigs.MaxVarCharLengthKey: DefaultConfigs.MaxVarCharLength + 1}}, + {"name": "test_floatvector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128, DefaultConfig.MaxVarCharLengthKey: DefaultConfig.MaxVarCharLength + 1}}, ]} ]) def test_valid_type_params_get_collection_schema(self, valid_fields):