Skip to content

Commit

Permalink
Add an utility to get server type (milvus-io#1381)
Browse files Browse the repository at this point in the history
Signed-off-by: longjiquan <[email protected]>
  • Loading branch information
longjiquan authored Apr 21, 2023
1 parent 6b5cf24 commit 05f3384
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 4 deletions.
6 changes: 5 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

from .utils import (
check_invalid_binary_vector,
len_of
len_of,
get_server_type,
)

from ..settings import DefaultConfig as config
Expand Down Expand Up @@ -192,6 +193,9 @@ def server_address(self):
""" Server network address """
return self._address

def get_server_type(self):
return get_server_type(self.server_address)

def reset_password(self, user, old_password, new_password, timeout=None):
"""
reset password and then setup the grpc channel.
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def name(self):
def handler(self):
return self._handler

def get_server_type(self):
return self._handler.get_server_type()

def reset_password(self, user, old_password, new_password):
self._handler.reset_password(user, old_password, new_password)

Expand Down
26 changes: 23 additions & 3 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime

from urllib.parse import urlparse

from .types import DataType
from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK
from ..exceptions import ParamError, MilvusException
Expand Down Expand Up @@ -167,7 +169,7 @@ def traverse_info(fields_info, entities):
if field_name == entity_name:
if field_type != entity_type:
raise ParamError(message=f"Collection field type is {field_type}"
f", but entities field type is {entity_type}")
f", but entities field type is {entity_type}")

entity_dim, field_dim = 0, 0
if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
Expand All @@ -176,11 +178,11 @@ def traverse_info(fields_info, entities):

if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim}")
f", but entities field dim is {entity_dim}")

if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim * 8}")
f", but entities field dim is {entity_dim * 8}")

location[field["name"]] = i
match_flag = True
Expand All @@ -191,3 +193,21 @@ def traverse_info(fields_info, entities):
message=f"Field {field['name']} don't match in entities")

return location, primary_key_loc, auto_id_loc


def get_protocol_and_domain(host):
o = urlparse(host)
return o.scheme, o.hostname


def get_server_type(host):
protocol, hostname = get_protocol_and_domain(host)
if protocol != "https":
return "milvus"
splits = hostname.split('.')
len_of_splits = len(splits)
if len_of_splits >= 2 and \
splits[len_of_splits - 2].lower() == "zillizcloud" and \
splits[len_of_splits - 1].lower() == "com":
return "zilliz"
return "milvus"
14 changes: 14 additions & 0 deletions pymilvus/orm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ def transfer_replica(source_group, target_group, collection_name, num_replicas,
"""
return _get_connection(using).transfer_replica(source_group, target_group, collection_name, num_replicas, timeout)


def flush_all(using="default", timeout=None, **kwargs):
""" Flush all collections. All insertions, deletions, and upserts before `flush_all` will be synced.
Expand Down Expand Up @@ -1064,3 +1065,16 @@ def flush_all(using="default", timeout=None, **kwargs):
>>> future.done() # flush_all finished
"""
return _get_connection(using).flush_all(timeout=timeout, **kwargs)


def get_server_type(using="default"):
""" Get the server type. Now, it will return "zilliz" if the connection related to an instance on the zilliz cloud,
otherwise "milvus" will be returned.
:param using: Alias to the connection. Default connection is used if this is not specified.
:type: str
:return: The server type.
:rtype: str
"""
return _get_connection(using).get_server_type()
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pymilvus.client import utils


class TestUtils:
def test_get_server_type(self):
url1 = 'in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url1) == "milvus"

url2 = 'https://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url2) == "zilliz"

url3 = 'http://in01-0390f61a8675594.aws-us-west-2.vectordb.zillizcloud.com'
assert utils.get_server_type(url3) == "milvus"

url4 = 'https://something.notzillizcloud.com'
assert utils.get_server_type(url4) == "milvus"

url5 = 'https://something.zillizcloud.not.com'
assert utils.get_server_type(url5) == "milvus"


0 comments on commit 05f3384

Please sign in to comment.