Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Refactor pymilvus #799

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added pymilvus/v2/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions pymilvus/v2/grpc_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .milvus_server import IServer, GrpcServer


class GRPCHandler:
def __init__(self, server):
if not isinstance(server, IServer):
raise TypeError("Except an IServer")
self._server = server

def create_collection(self, collection_name, fields, shards_num=2):
return self._server.create_collection(collection_name, fields, shards_num)


server_instance = GrpcServer()
GRPCHandler(server_instance)
162 changes: 162 additions & 0 deletions pymilvus/v2/milvus_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from abc import ABCMeta, abstractmethod

import grpc
from grpc._cython import cygrpc

from ..grpc_gen import common_pb2 as common_types
from ..grpc_gen import milvus_pb2 as milvus_types
from ..grpc_gen import milvus_pb2_grpc
from ..grpc_gen import schema_pb2 as schema_types


class IServer(metaclass=ABCMeta):
"""
The interface of milvus server.

Methods
-------
create_collection(collection_name, fields, shards_num) -> common_types.Status
Create a collection in Milvus

drop_collection(collection_name) -> common_types.Status
Drop a collection in Milvus

has_collection(collection_name) -> milvus_types.BoolResponse
Check if a collection exists in Milvus

describe_collection(collection_name) -> milvus_types.DescribeCollectionResponse
Get the schema of a collection in Milvus

list_collections() -> milvus_types.ShowCollectionsResponse
List all collections in Milvus

create_partition(collection_name, partition_name) -> common_types.Status
Create a partition in specified collection of Milvus

drop_partition(collection_name, partition_name) -> common_types.Status
Drop a partition in specified collection of Milvus

has_partition(collection_name, partition_name) -> milvus_types.BoolResponse
Check if a partition exists in specified collection of Milvus

list_partitions(collection_name) -> milvus_types.ShowPartitionsResponse
List all partitions in specified collection of Milvus
"""

def __init__(self):
pass

@abstractmethod
def create_collection(self, collection_name, fields, shards_num) -> common_types.Status:
pass

@abstractmethod
def drop_collection(self, collection_name) -> common_types.Status:
pass

@abstractmethod
def has_collection(self, collection_name) -> milvus_types.BoolResponse:
pass

@abstractmethod
def describe_collection(self, collection_name) -> milvus_types.DescribeCollectionResponse:
pass

@abstractmethod
def list_collections(self) -> milvus_types.ShowCollectionsResponse:
pass

@abstractmethod
def create_partition(self, collection_name, partition_name) -> common_types.Status:
pass

@abstractmethod
def drop_partition(self, collection_name, partition_name) -> common_types.Status:
pass

@abstractmethod
def has_partition(self, collection_name, partition_name) -> milvus_types.BoolResponse:
pass

@abstractmethod
def list_partitions(self, collection_name) -> milvus_types.ShowPartitionsResponse:
pass


class GrpcServer(IServer):
"""
Methods in this class cannot be covered by unit tests(unit tests should not depends on the milvus server), so that
keep them as simple as possible.
"""

def __init__(self, host="localhost", port="19530"):
super().__init__()
self._channel = grpc.insecure_channel(
f"{host}:{port}",
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
(cygrpc.ChannelArgKey.max_receive_message_length, -1),
('grpc.enable_retries', 1),
('grpc.keepalive_time_ms', 55000)]
)
self._stub = milvus_pb2_grpc.MilvusServiceStub(self._channel)

def create_collection(self, collection_name, fields, shards_num) -> common_types.Status:
assert isinstance(fields, dict)
assert "fields" in fields
assert sum(1 for field in fields["fields"] if "is_primary" in field) == 1
assert sum(1 for field in fields["fields"] if "auto_id" in field) <= 1

schema = schema_types.CollectionSchema(name=collection_name)
for field in fields["fields"]:
field_schema = schema_types.FieldSchema()
assert "name" in field
field_schema.name = field["name"]
assert "type" in field
field_schema.data_type = field["type"]

field_schema.is_primary_key = field.get("is_primary", False)
field_schema.autoID = field.get('auto_id', False)

if "params" in field:
assert isinstance(field["params"], dict)
assert "dim" in field["params"]
kv_pair = common_types.KeyValuePair(key="dim", value=str(int(field["params"]["dim"])))
field_schema.type_params.append(kv_pair)

schema.fields.append(field_schema)

request = milvus_types.CreateCollectionRequest(collection_name=collection_name,
schema=bytes(schema.SerializeToString()), shards_num=shards_num)
return self._stub.CreateCollection(request)

def drop_collection(self, collection_name) -> common_types.Status:
request = milvus_types.DropCollectionRequest(collection_name=collection_name)
return self._stub.DropCollection(request)

def has_collection(self, collection_name) -> milvus_types.BoolResponse:
request = milvus_types.HasCollectionRequest(collection_name=collection_name)
return self._stub.HasCollection(request)

def describe_collection(self, collection_name) -> milvus_types.DescribeCollectionResponse:
request = milvus_types.DescribeCollectionRequest(collection_name=collection_name)
return self._stub.DescribeCollection(request)

def list_collections(self) -> milvus_types.ShowCollectionsResponse:
request = milvus_types.ShowCollectionsRequest()
return self._stub.ShowCollections(request)

def create_partition(self, collection_name, partition_name) -> common_types.Status:
request = milvus_types.CreatePartitionRequest(collection_name=collection_name, partition_name=partition_name)
return self._stub.CreatePartition(request)

def drop_partition(self, collection_name, partition_name) -> common_types.Status:
request = milvus_types.DropPartitionRequest(collection_name=collection_name, partition_name=partition_name)
return self._stub.DropPartition(request)

def has_partition(self, collection_name, partition_name) -> milvus_types.BoolResponse:
request = milvus_types.HasPartitionRequest(collection_name=collection_name, partition_name=partition_name)
return self._stub.HasPartition(request)

def list_partitions(self, collection_name) -> milvus_types.ShowPartitionsResponse:
request = milvus_types.ShowPartitionsRequest(collection_name=collection_name)
return self._stub.ShowPartitions(request)
206 changes: 206 additions & 0 deletions pymilvus/v2/test_milvus_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import os
import random

import pytest

from .milvus_server import GrpcServer
from .types import DataType


@pytest.fixture
def server_instance():
# just for debug
host = os.getenv("host")
host = host if host else "127.0.0.1"
port = os.getenv("port")
port = port if port else "19530"
return GrpcServer(host=host, port=port)


@pytest.fixture
def collection_name():
# just for develop
return f"collection_{random.randint(100000000, 999999999)}"


@pytest.fixture
def partition_name():
# just for develop
return f"partition_{random.randint(100000000, 999999999)}"


class TestCreateCollection:
def test_create_collection(self, server_instance, collection_name):
response = server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
assert response.error_code == 0


class TestDropCollection:
def test_drop_collection(self, server_instance, collection_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
response = server_instance.drop_collection(collection_name)
assert response.error_code == 0


class TestHasCollection:
def test_has_collection(self, server_instance, collection_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
response = server_instance.has_collection(collection_name)
assert response.status.error_code == 0
assert response.value is True


class TestDescribeCollection:
def test_describe_collection(self, server_instance, collection_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
response = server_instance.describe_collection(collection_name)
assert response.status.error_code == 0


class TestListCollections:
def test_list_collections(self, server_instance, collection_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
response = server_instance.list_collections()
assert response.status.error_code == 0
assert collection_name in list(response.collection_names)


class TestCreatePartition:
def test_create_partition(self, server_instance, collection_name, partition_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
response = server_instance.create_partition(collection_name, partition_name)
assert response.error_code == 0


class TestDropPartition:
def test_drop_partition(self, server_instance, collection_name, partition_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
server_instance.create_partition(collection_name, partition_name)
response = server_instance.drop_partition(collection_name, partition_name)
assert response.error_code == 0


class TestHasPartition:
def test_has_partition(self, server_instance, collection_name, partition_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
server_instance.create_partition(collection_name, partition_name)
response = server_instance.has_partition(collection_name, partition_name)
assert response.status.error_code == 0
assert response.value is True


class TestListPartitions:
def test_list_partitions(self, server_instance, collection_name, partition_name):
server_instance.create_collection(collection_name, {"fields": [
{
"name": "my_id",
"type": DataType.INT64,
"auto_id": True,
"is_primary": True,
},
{
"name": "my_vector",
"type": DataType.FLOAT_VECTOR,
"params": {"dim": 64},
}
], "description": "this is a description"}, 2)
server_instance.create_partition(collection_name, partition_name)
response = server_instance.list_partitions(collection_name)
assert response.status.error_code == 0
assert partition_name in list(response.partition_names)
Loading