From e83b1c1fec40537d282b407f9023a1bd5dc38408 Mon Sep 17 00:00:00 2001
From: "zhenshan.cao"
Date: Fri, 12 Aug 2022 17:28:45 +0800
Subject: [PATCH] Remove redundant checks and add MetaCache (#301)

Signed-off-by: zhenshan.cao
Signed-off-by: zhenshan.cao
---
 client/client.go                   |   7 +-
 client/client_grpc_collection.go   |  16 +-
 client/client_grpc_data.go         | 249 +++++++++--------------------
 client/client_grpc_data_test.go    |  58 +------
 client/client_grpc_options.go      |  11 +-
 client/client_grpc_options_test.go |  24 +--
 client/client_grpc_row.go          |   2 +-
 client/meta_cache.go               |  79 +++++++++
 client/meta_cache_test.go          |  55 +++++++
 client/timestamp_map.go            |  43 -----
 client/timestamp_map_test.go       |  32 ----
 examples/query/query.go            |   2 +-
 12 files changed, 252 insertions(+), 326 deletions(-)
 create mode 100644 client/meta_cache.go
 create mode 100644 client/meta_cache_test.go
 delete mode 100644 client/timestamp_map.go
 delete mode 100644 client/timestamp_map_test.go

diff --git a/client/client.go b/client/client.go
index 13d56152..7b697100 100644
--- a/client/client.go
+++ b/client/client.go
@@ -17,16 +17,11 @@ import (
 	"time"
 
 	"github.com/milvus-io/milvus-sdk-go/v2/entity"
+	"google.golang.org/grpc"
 )
 
 // Client is the interface used to communicate with Milvus
-// The common usage is like follow
-// c, err := client.NewGrpcClient(context.Background, "address-to-milvus") // or other creation func maybe added later
-// if err != nil {
-//    //handle err
-// }
-// // start doing things with client instance, note that there is no need to call Connect since NewXXXClient will do that for you
 type Client interface {
 	// Close close the remaining connection resources
 	Close() error
diff --git a/client/client_grpc_collection.go b/client/client_grpc_collection.go
index edd4e8ef..f9805489 100644
--- a/client/client_grpc_collection.go
+++ b/client/client_grpc_collection.go
@@ -19,9 +19,10 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/milvus-io/milvus-sdk-go/v2/entity"
+	"google.golang.org/grpc"
+
 	common "github.com/milvus-io/milvus-sdk-go/v2/internal/proto/common"
 	server "github.com/milvus-io/milvus-sdk-go/v2/internal/proto/server"
-	"google.golang.org/grpc"
 )
 
 // grpcClient, uses default grpc service definition to connect with Milvus2.0
@@ -231,6 +232,13 @@ func (c *grpcClient) DescribeCollection(ctx context.Context, collName string) (*
 		ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
 	}
 	collection.Name = collection.Schema.CollectionName
+	colInfo := collInfo{
+		ID:               collection.ID,
+		Name:             collection.Name,
+		Schema:           collection.Schema,
+		ConsistencyLevel: collection.ConsistencyLevel,
+	}
+	MetaCache.setCollectionInfo(resp.CollectionName, &colInfo)
 	return collection, nil
 }
 
@@ -250,7 +258,11 @@ func (c *grpcClient) DropCollection(ctx context.Context, collName string) error
 	if err != nil {
 		return err
 	}
-	return handleRespStatus(resp)
+	err = handleRespStatus(resp)
+	if err == nil {
+		MetaCache.setCollectionInfo(collName, nil)
+	}
+	return err
 }
 
 // HasCollection check whether collection name exists
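A minimal usage sketch (not part of the patch): after the two hunks above, DescribeCollection is the write path into the package-level MetaCache and DropCollection is its invalidation path. The server address and the collection name "films" below are placeholders.

// Sketch only: the cache lifecycle introduced by this change.
package main

import (
	"context"
	"log"

	"github.com/milvus-io/milvus-sdk-go/v2/client"
)

func main() {
	ctx := context.Background()
	c, err := client.NewGrpcClient(ctx, "localhost:19530") // address is an assumption
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	// DescribeCollection now also stores ID/Name/Schema/ConsistencyLevel into
	// MetaCache as a side effect, so later Search/Query calls can resolve their
	// options without an extra DescribeCollection round trip.
	if _, err := c.DescribeCollection(ctx, "films"); err != nil {
		log.Fatal(err)
	}

	// DropCollection evicts the cached entry once the server confirms the drop.
	if err := c.DropCollection(ctx, "films"); err != nil {
		log.Fatal(err)
	}
}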
diff --git a/client/client_grpc_data.go b/client/client_grpc_data.go
index f44cb701..a78b3078 100644
--- a/client/client_grpc_data.go
+++ b/client/client_grpc_data.go
@@ -17,10 +17,8 @@ import (
 	"errors"
 	"fmt"
 	"log"
-	"math"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/golang/protobuf/proto"
@@ -117,7 +115,7 @@ func (c *grpcClient) Insert(ctx context.Context, collName string, partitionName
 	if err := handleRespStatus(resp.GetStatus()); err != nil {
 		return nil, err
 	}
-	tsm.set(coll.ID, resp.Timestamp)
+	MetaCache.setSessionTs(collName, resp.Timestamp)
 	// 3. parse id column
 	return entity.IDColumns(resp.GetIDs(), 0, -1)
 }
@@ -222,7 +220,7 @@ func (c *grpcClient) DeleteByPks(ctx context.Context, collName string, partition
 	if err != nil {
 		return err
 	}
-	tsm.set(coll.ID, resp.Timestamp)
+	MetaCache.setSessionTs(collName, resp.Timestamp)
 	return nil
 }
 
@@ -232,124 +230,54 @@ func (c *grpcClient) Search(ctx context.Context, collName string, partitions []s
 	if c.service == nil {
 		return []SearchResult{}, ErrClientNotReady
 	}
-	// 1. check all input params
-	if err := c.checkCollectionExists(ctx, collName); err != nil {
-		return nil, err
+	_, ok := MetaCache.getCollectionInfo(collName)
+	if !ok {
+		c.DescribeCollection(ctx, collName)
 	}
-	for _, partition := range partitions {
-		err := c.checkPartitionExists(ctx, collName, partition)
-		if err != nil {
-			return nil, err
-		}
-	}
-	// TODO maybe add expr analysis?
-	coll, err := c.DescribeCollection(ctx, collName)
+	option, err := MakeSearchQueryOption(collName, opts...)
 	if err != nil {
 		return nil, err
 	}
-	if coll.Schema.CollectionName == "" {
-		coll.Schema.CollectionName = collName
-	}
-	mNameField := make(map[string]*entity.Field)
-	for _, field := range coll.Schema.Fields {
-		mNameField[field.Name] = field
-	}
-	for _, outField := range outputFields {
-		_, has := mNameField[outField]
-		if !has {
-			return nil, fmt.Errorf("field %s does not exist in collection %s", outField, collName)
-		}
-	}
-	vfDef, has := mNameField[vectorField]
-	if !has {
-		return nil, fmt.Errorf("vector field %s does not exist in collection %s", vectorField, collName)
-	}
-	dimStr := vfDef.TypeParams[entity.TypeParamDim]
-	for _, vector := range vectors {
-		if vector.FieldType() != vfDef.DataType {
-			return nil, fmt.Errorf("vector %s shall be type of %s, but input found type:%s", vectorField,
-				vfDef.DataType.String(), vector.FieldType().String())
-		}
-
-		if fmt.Sprintf("%d", vector.Dim()) != dimStr {
-			return nil, fmt.Errorf("vector %s has dim of %s while found search vector with dim %d", vectorField,
-				dimStr, vector.Dim())
-		}
-	}
-
-	switch vfDef.DataType {
-	case entity.FieldTypeFloatVector:
-		if metricType != entity.IP && metricType != entity.L2 {
-			return nil, fmt.Errorf("Float vector does not support metric type %s", metricType)
-		}
-	case entity.FieldTypeBinaryVector:
-		if metricType == entity.IP || metricType == entity.L2 {
-			return nil, fmt.Errorf("Binary vector does not support metric type %s", metricType)
-		}
+	// 2. Request milvus service
+	req, err := prepareSearchRequest(collName, partitions, expr, outputFields, vectors, vectorField, metricType, topK, sp, option)
+	if err != nil {
+		return nil, err
 	}
-	option, err := MakeSearchQueryOption(coll, opts...)
+	sr := make([]SearchResult, 0, len(vectors))
+	resp, err := c.service.Search(ctx, req)
 	if err != nil {
 		return nil, err
 	}
-
-	// 2. Request milvus service
-	reqs := splitSearchRequest(coll.Schema, partitions, expr, outputFields, vectors, vfDef.DataType, vectorField, metricType, topK, sp, option)
-	if len(reqs) == 0 {
-		return nil, errors.New("empty request generated")
+	if err := handleRespStatus(resp.GetStatus()); err != nil {
+		return nil, err
 	}
-	wg := &sync.WaitGroup{}
-	wg.Add(len(reqs))
-	var batchErr error
-	sr := make([]SearchResult, 0, len(vectors))
-	mut := sync.Mutex{}
-	for _, req := range reqs {
-		go func(req *server.SearchRequest) {
-			defer wg.Done()
-			resp, err := c.service.Search(ctx, req)
+	// 3. parse result into result
+	results := resp.GetResults()
+	offset := 0
+	fieldDataList := results.GetFieldsData()
+	for i := 0; i < int(results.GetNumQueries()); i++ {
+		rc := int(results.GetTopks()[i]) // result entry count for current query
+		entry := SearchResult{
+			ResultCount: rc,
+			Scores:      results.GetScores()[offset : offset+rc],
+		}
+		entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc)
+		if entry.Err != nil {
+			offset += rc
+			continue
+		}
+		entry.Fields = make([]entity.Column, 0, len(fieldDataList))
+		for _, fieldData := range fieldDataList {
+			column, err := entity.FieldDataColumn(fieldData, offset, offset+rc)
 			if err != nil {
-				batchErr = err
-				return
-			}
-			if err := handleRespStatus(resp.GetStatus()); err != nil {
-				batchErr = err
-				return
-			}
-			// 3. parse result into result
-			results := resp.GetResults()
-			offset := 0
-			fieldDataList := results.GetFieldsData()
-			for i := 0; i < int(results.GetNumQueries()); i++ {
-
-				rc := int(results.GetTopks()[i]) // result entry count for current query
-				entry := SearchResult{
-					ResultCount: rc,
-					Scores:      results.GetScores()[offset : offset+rc],
-				}
-				entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc)
-				if entry.Err != nil {
-					offset += rc
-					continue
-				}
-				entry.Fields = make([]entity.Column, 0, len(fieldDataList))
-				for _, fieldData := range fieldDataList {
-					column, err := entity.FieldDataColumn(fieldData, offset, offset+rc)
-					if err != nil {
-						entry.Err = err
-						continue
-					}
-					entry.Fields = append(entry.Fields, column)
-				}
-				mut.Lock()
-				sr = append(sr, entry)
-				mut.Unlock()
-				offset += rc
+				entry.Err = err
+				continue
 			}
-		}(req)
-	}
-	wg.Wait()
-	if batchErr != nil {
-		return []SearchResult{}, batchErr
+			entry.Fields = append(entry.Fields, column)
+		}
+		sr = append(sr, entry)
+		offset += rc
 	}
 	return sr, nil
 }
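A self-contained sketch (not part of the patch): the hunk above replaces the per-batch goroutines with a single Search RPC, so all queries' hits arrive in one flat buffer that is sliced by the per-query Topks counts. The field names below are simplified stand-ins for the server.SearchResultData accessors.

// Sketch of the offset-based slicing used in the new result parsing.
package main

import "fmt"

func main() {
	numQueries := 2
	topks := []int{3, 2}                      // result count per query, like results.GetTopks()
	scores := []float32{.9, .8, .7, .99, .95} // flat buffer, like results.GetScores()

	offset := 0
	for i := 0; i < numQueries; i++ {
		rc := topks[i]
		// each query's hits occupy scores[offset : offset+rc]
		fmt.Printf("query %d -> scores %v\n", i, scores[offset:offset+rc])
		offset += rc
	}
}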
@@ -359,23 +287,6 @@ func (c *grpcClient) QueryByPks(ctx context.Context, collectionName string, part
 	if c.service == nil {
 		return nil, ErrClientNotReady
 	}
-
-	// check collection exists and get collection schema
-	// check collection name
-	if err := c.checkCollectionExists(ctx, collectionName); err != nil {
-		return nil, err
-	}
-	coll, err := c.DescribeCollection(ctx, collectionName)
-	if err != nil {
-		return nil, err
-	}
-	// check partition exists
-	for _, partitionName := range partitionNames {
-		err := c.checkPartitionExists(ctx, collectionName, partitionName)
-		if err != nil {
-			return nil, err
-		}
-	}
 	// check primary keys
 	if ids.Len() == 0 {
 		return nil, errors.New("ids len must not be zero")
@@ -384,12 +295,6 @@ func (c *grpcClient) QueryByPks(ctx context.Context, collectionName string, part
 		return nil, errors.New("only int64 and varchar column can be primary key for now")
 	}
 
-	pkf := getPKField(coll.Schema)
-	// pkf shall not be nil since is returned from milvus
-	if pkf.Name != ids.Name() {
-		return nil, errors.New("only query by primary key is supported now")
-	}
-
 	var expr string
 	switch ids.Type() {
 	case entity.FieldTypeInt64:
@@ -402,7 +307,12 @@ func (c *grpcClient) QueryByPks(ctx context.Context, collectionName string, part
 		expr = fmt.Sprintf("%s in %s", ids.Name(), strings.Join(strings.Fields(fmt.Sprint(data)), ","))
 	}
 
-	option, err := MakeSearchQueryOption(coll, opts...)
+	_, ok := MetaCache.getCollectionInfo(collectionName)
+	if !ok {
+		c.DescribeCollection(ctx, collectionName)
+	}
+
+	option, err := MakeSearchQueryOption(collectionName, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -456,11 +366,14 @@ func getPKField(schema *entity.Schema) *entity.Field {
 	return nil
 }
 
-func splitSearchRequest(sch *entity.Schema, partitions []string,
-	expr string, outputFields []string, vectors []entity.Vector, fieldType entity.FieldType, vectorField string,
-	metricType entity.MetricType, topK int, sp entity.SearchParam, opt *SearchQueryOption) []*server.SearchRequest {
+func prepareSearchRequest(collName string, partitions []string,
+	expr string, outputFields []string, vectors []entity.Vector, vectorField string,
+	metricType entity.MetricType, topK int, sp entity.SearchParam, opt *SearchQueryOption) (*server.SearchRequest, error) {
 	params := sp.Params()
-	bs, _ := json.Marshal(params)
+	bs, err := json.Marshal(params)
+	if err != nil {
+		return nil, err
+	}
 	searchParams := entity.MapKvPairs(map[string]string{
 		"anns_field":     vectorField,
 		"topk":           fmt.Sprintf("%d", topK),
@@ -468,32 +381,19 @@ func splitSearchRequest(sch *entity.Schema, partitions []string,
 		"metric_type":    string(metricType),
 		"round_decimal":  "-1",
 	})
-
-	ers := estRowSize(sch, outputFields)
-	maxBatch := int(math.Ceil(float64(5*1024*1024) / float64(ers*int64(topK))))
-	var result []*server.SearchRequest
-	for i := 0; i*maxBatch < len(vectors); i++ {
-		start := i * maxBatch
-		end := (i + 1) * maxBatch
-		if end > len(vectors) {
-			end = len(vectors)
-		}
-		batchVectors := vectors[start:end]
-		req := &server.SearchRequest{
-			DbName:             "",
-			CollectionName:     sch.CollectionName,
-			PartitionNames:     partitions,
-			Dsl:                expr,
-			PlaceholderGroup:   vector2PlaceholderGroupBytes(batchVectors, fieldType),
-			DslType:            common.DslType_BoolExprV1,
-			OutputFields:       outputFields,
-			SearchParams:       searchParams,
-			GuaranteeTimestamp: opt.GuaranteeTimestamp,
-			TravelTimestamp:    opt.TravelTimestamp,
-		}
-		result = append(result, req)
+	req := &server.SearchRequest{
+		DbName:             "",
+		CollectionName:     collName,
+		PartitionNames:     partitions,
+		Dsl:                expr,
+		PlaceholderGroup:   vector2PlaceholderGroupBytes(vectors),
+		DslType:            common.DslType_BoolExprV1,
+		OutputFields:       outputFields,
+		SearchParams:       searchParams,
+		GuaranteeTimestamp: opt.GuaranteeTimestamp,
+		TravelTimestamp:    opt.TravelTimestamp,
 	}
-	return result
+	return req, nil
 }
 
 // GetPersistentSegmentInfo get persistent segment info
@@ -690,10 +590,10 @@ func columnToVectorsArray(collName string, partitions []string, column entity.Co
 	return result
 }
 
-func vector2PlaceholderGroupBytes(vectors []entity.Vector, fieldType entity.FieldType) []byte {
+func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
 	phg := &common.PlaceholderGroup{
 		Placeholders: []*common.PlaceholderValue{
-			vector2Placeholder(vectors, fieldType),
+			vector2Placeholder(vectors),
 		},
 	}
 
@@ -701,19 +601,22 @@ func vector2PlaceholderGroupBytes(vectors []entity.Vector, fieldType entity.Fiel
 	bs, _ := proto.Marshal(phg)
 	return bs
 }
 
-func vector2Placeholder(vectors []entity.Vector, fieldType entity.FieldType) *common.PlaceholderValue {
+func vector2Placeholder(vectors []entity.Vector) *common.PlaceholderValue {
 	var placeHolderType common.PlaceholderType
-	switch fieldType {
-	case entity.FieldTypeFloatVector:
-		placeHolderType = common.PlaceholderType_FloatVector
-	case entity.FieldTypeBinaryVector:
-		placeHolderType = common.PlaceholderType_BinaryVector
-	}
 	ph := &common.PlaceholderValue{
 		Tag:    "$0",
-		Type:   placeHolderType,
 		Values: make([][]byte, 0, len(vectors)),
 	}
+	if len(vectors) == 0 {
+		return ph
+	}
+	switch vectors[0].(type) {
+	case entity.FloatVector:
+		placeHolderType = common.PlaceholderType_FloatVector
+	case entity.BinaryVector:
+		placeHolderType = common.PlaceholderType_BinaryVector
+	}
+	ph.Type = placeHolderType
 	for _, vector := range vectors {
 		ph.Values = append(ph.Values, vector.Serialize())
 	}
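A small sketch (not part of the patch): vector2Placeholder now infers the placeholder type from the first vector's concrete Go type instead of the removed fieldType parameter. The helper below mirrors that type switch using the SDK's entity vector types; it assumes the milvus-sdk-go module is available, and placeholderKind is an illustrative name, not SDK API.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// placeholderKind reports which placeholder type the first vector would select.
func placeholderKind(vectors []entity.Vector) string {
	if len(vectors) == 0 {
		return "unknown" // mirrors the early return in vector2Placeholder
	}
	switch vectors[0].(type) {
	case entity.FloatVector:
		return "FloatVector"
	case entity.BinaryVector:
		return "BinaryVector"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(placeholderKind([]entity.Vector{entity.FloatVector{0.1, 0.2}})) // FloatVector
	fmt.Println(placeholderKind([]entity.Vector{entity.BinaryVector{0x01}}))    // BinaryVector
}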
diff --git a/client/client_grpc_data_test.go b/client/client_grpc_data_test.go
index 7b8df165..5d67943f 100644
--- a/client/client_grpc_data_test.go
+++ b/client/client_grpc_data_test.go
@@ -267,56 +267,8 @@ func TestGrpcSearch(t *testing.T) {
 		sp, err := entity.NewIndexFlatSearchParam(10)
 		assert.Nil(t, err)
 
-		// collection name
-		mock.delInjection(mHasCollection)
-		r, err := c.Search(ctx, testCollectionName, []string{}, "", []string{}, []entity.Vector{}, "vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
-		// partition
-		mock.setInjection(mHasCollection, hasCollectionDefault)
-		r, err = c.Search(ctx, testCollectionName, []string{"_non_exist"}, "", []string{}, []entity.Vector{}, "vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
-		// output field
-		mock.setInjection(mDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"extra"}, []entity.Vector{}, "vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
-		// vector field
-		mock.setInjection(mDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{}, "no_vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
-		// vector dim
-		badVectors := generateFloatVector(1, testVectorDim*2)
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{entity.FloatVector(badVectors[0])}, "vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
-		// wrong vector type
-		binaryVector := generateBinaryVector(1, testVectorDim)
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{entity.BinaryVector(binaryVector[0])}, "vector",
-			entity.L2, 5, sp)
-		assert.Nil(t, r)
-		assert.Error(t, err)
-
-		// metric type
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector",
-			entity.HAMMING, 5, sp)
-		assert.Nil(t, r)
-		assert.NotNil(t, err)
-
 		// specify guarantee timestamp in strong consistency level
-		r, err = c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector",
+		r, err := c.Search(ctx, testCollectionName, []string{}, "", []string{"int64"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector",
 			entity.HAMMING, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong), WithGuaranteeTimestamp(1))
 		assert.Nil(t, r)
 		assert.NotNil(t, err)
@@ -565,10 +517,6 @@ func TestGrpcQueryByPks(t *testing.T) {
 		// string pk field
 		_, err = c.QueryByPks(ctx, testCollectionName, []string{}, entity.NewColumnString("pk", []string{"1"}), []string{})
 		assert.Error(t, err)
-
-		// pk name not match
-		_, err = c.QueryByPks(ctx, testCollectionName, []string{}, entity.NewColumnInt64("non_pk", []int64{1}), []string{})
-		assert.Error(t, err)
 	})
 
 	t.Run("Query service error", func(t *testing.T) {
@@ -1073,7 +1021,7 @@ func TestVector2PlaceHolder(t *testing.T) {
 			vectors = append(vectors, entity.FloatVector(row))
 		}
 
-		phv := vector2Placeholder(vectors, entity.FieldTypeFloatVector)
+		phv := vector2Placeholder(vectors)
 		assert.Equal(t, "$0", phv.Tag)
 		assert.Equal(t, common.PlaceholderType_FloatVector, phv.Type)
 		require.Equal(t, len(vectors), len(phv.Values))
@@ -1089,7 +1037,7 @@ func TestVector2PlaceHolder(t *testing.T) {
 			vectors = append(vectors, entity.BinaryVector(row))
 		}
 
-		phv := vector2Placeholder(vectors, entity.FieldTypeBinaryVector)
+		phv := vector2Placeholder(vectors)
 		assert.Equal(t, "$0", phv.Tag)
 		assert.Equal(t, common.PlaceholderType_BinaryVector, phv.Type)
 		require.Equal(t, len(vectors), len(phv.Values))
diff --git a/client/client_grpc_options.go b/client/client_grpc_options.go
index 6c86812a..9813cdf8 100644
--- a/client/client_grpc_options.go
+++ b/client/client_grpc_options.go
@@ -69,14 +69,17 @@ func WithTravelTimestamp(tt uint64) SearchQueryOptionFunc {
 	}
 }
 
-func MakeSearchQueryOption(c *entity.Collection, opts ...SearchQueryOptionFunc) (*SearchQueryOption, error) {
+func MakeSearchQueryOption(collName string, opts ...SearchQueryOptionFunc) (*SearchQueryOption, error) {
 	opt := &SearchQueryOption{
-		ConsistencyLevel: c.ConsistencyLevel, // default
+		ConsistencyLevel: entity.ClBounded, // default
+	}
+	info, ok := MetaCache.getCollectionInfo(collName)
+	if ok {
+		opt.ConsistencyLevel = info.ConsistencyLevel
 	}
 	for _, o := range opts {
 		o(opt)
 	}
-
 	// sanity-check
 	if opt.ConsistencyLevel != entity.ClCustomized && opt.GuaranteeTimestamp != 0 {
 		return nil, errors.New("user can only specify guarantee timestamp under customized consistency level")
@@ -86,7 +89,7 @@ func MakeSearchQueryOption(c *entity.Collection, opts ...SearchQueryOptionFunc)
 	case entity.ClStrong:
 		opt.GuaranteeTimestamp = StrongTimestamp
 	case entity.ClSession:
-		ts, ok := tsm.get(c.ID)
+		ts, ok := MetaCache.getSessionTs(collName)
 		if !ok {
 			ts = EventuallyTimestamp
 		}
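A standalone sketch (not part of the patch): MakeSearchQueryOption now defaults to bounded consistency, upgrades to the cached collection's level when MetaCache has an entry, and then maps the level to a guarantee timestamp. Only the three constants below mirror client/meta_cache.go; resolveGuaranteeTs and the string-typed levels are illustrative stand-ins, not SDK API.

package main

import "fmt"

const (
	strongTs     uint64 = 0 // StrongTimestamp
	eventuallyTs uint64 = 1 // EventuallyTimestamp
	boundedTs    uint64 = 2 // BoundedTimestamp
)

// resolveGuaranteeTs picks a guarantee timestamp the way MakeSearchQueryOption does.
func resolveGuaranteeTs(level string, sessionTs uint64, hasSessionTs bool, userTs uint64) uint64 {
	switch level {
	case "Strong":
		return strongTs
	case "Session":
		if hasSessionTs {
			return sessionTs // last write timestamp recorded by Insert/DeleteByPks
		}
		return eventuallyTs
	case "Bounded":
		return boundedTs
	case "Eventually":
		return eventuallyTs
	default: // "Customized": caller-provided timestamp
		return userTs
	}
}

func main() {
	fmt.Println(resolveGuaranteeTs("Session", 0, false, 0)) // 1: no write seen yet, falls back to eventually
	fmt.Println(resolveGuaranteeTs("Session", 99, true, 0)) // 99: reuse the recorded write timestamp
}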
diff --git a/client/client_grpc_options_test.go b/client/client_grpc_options_test.go
index 3be080e0..aeb4b32c 100644
--- a/client/client_grpc_options_test.go
+++ b/client/client_grpc_options_test.go
@@ -46,12 +46,18 @@ func TestLoadCollectionWithReplicaNumber(t *testing.T) {
 
 func TestMakeSearchQueryOption(t *testing.T) {
 	c := &entity.Collection{
-		ID:               999,
+		Name:             "999",
 		ConsistencyLevel: entity.ClStrong,
 	}
 
+	cInfo := collInfo{
+		Name:             c.Name,
+		ConsistencyLevel: c.ConsistencyLevel,
+	}
+	MetaCache.setCollectionInfo(c.Name, &cInfo)
+
 	t.Run("strong consistency", func(t *testing.T) {
-		opt, err := MakeSearchQueryOption(c)
+		opt, err := MakeSearchQueryOption(c.Name)
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected := &SearchQueryOption{
@@ -62,7 +68,7 @@ func TestMakeSearchQueryOption(t *testing.T) {
 	})
 
 	t.Run("session consistency", func(t *testing.T) {
-		opt, err := MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClSession))
+		opt, err := MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClSession))
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected := &SearchQueryOption{
@@ -71,8 +77,8 @@ func TestMakeSearchQueryOption(t *testing.T) {
 		}
 		assert.Equal(t, expected, opt)
 
-		tsm.set(c.ID, 99)
-		opt, err = MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClSession))
+		MetaCache.setSessionTs(c.Name, 99)
+		opt, err = MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClSession))
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected = &SearchQueryOption{
@@ -83,7 +89,7 @@ func TestMakeSearchQueryOption(t *testing.T) {
 	})
 
 	t.Run("bounded consistency", func(t *testing.T) {
-		opt, err := MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClBounded))
+		opt, err := MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClBounded))
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected := &SearchQueryOption{
@@ -94,7 +100,7 @@ func TestMakeSearchQueryOption(t *testing.T) {
 	})
 
 	t.Run("eventually consistency", func(t *testing.T) {
-		opt, err := MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClEventually))
+		opt, err := MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClEventually))
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected := &SearchQueryOption{
@@ -105,7 +111,7 @@ func TestMakeSearchQueryOption(t *testing.T) {
 	})
 
 	t.Run("customized consistency", func(t *testing.T) {
-		opt, err := MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(100))
+		opt, err := MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(100))
 		assert.Nil(t, err)
 		assert.NotNil(t, opt)
 		expected := &SearchQueryOption{
@@ -116,7 +122,7 @@ func TestMakeSearchQueryOption(t *testing.T) {
 	})
 
 	t.Run("guarantee timestamp sanity check", func(t *testing.T) {
-		_, err := MakeSearchQueryOption(c, WithSearchQueryConsistencyLevel(entity.ClStrong), WithGuaranteeTimestamp(100))
+		_, err := MakeSearchQueryOption(c.Name, WithSearchQueryConsistencyLevel(entity.ClStrong), WithGuaranteeTimestamp(100))
 		assert.Error(t, err)
 	})
 }
diff --git a/client/client_grpc_row.go b/client/client_grpc_row.go
index a3f82594..33bc29c5 100644
--- a/client/client_grpc_row.go
+++ b/client/client_grpc_row.go
@@ -99,7 +99,7 @@ func (c *grpcClient) InsertByRows(ctx context.Context, collName string, partitio
 	if err := handleRespStatus(resp.GetStatus()); err != nil {
 		return nil, err
 	}
-	tsm.set(coll.ID, resp.Timestamp)
+	MetaCache.setSessionTs(collName, resp.Timestamp)
 	// 3. parse id column
 	return entity.IDColumns(resp.GetIDs(), 0, -1)
 }
diff --git a/client/meta_cache.go b/client/meta_cache.go
new file mode 100644
index 00000000..530cba29
--- /dev/null
+++ b/client/meta_cache.go
@@ -0,0 +1,79 @@
+package client
+
+import (
+	"sync"
+
+	"github.com/milvus-io/milvus-sdk-go/v2/entity"
+)
+
+// Magical timestamps for communicating with server
+const (
+	StrongTimestamp     uint64 = 0
+	EventuallyTimestamp uint64 = 1
+	BoundedTimestamp    uint64 = 2
+)
+
+type collInfo struct {
+	ID               int64          // collection id
+	Name             string         // collection name
+	Schema           *entity.Schema // collection schema, with fields schema and primary key definition
+	ConsistencyLevel entity.ConsistencyLevel
+}
+
+var MetaCache = metaCache{
+	sessionTsMap: make(map[string]uint64),
+	collInfoMap:  make(map[string]collInfo),
+}
+
+// timestampMap collects the last-write-timestamp of every collection, which is required by session consistency level.
+type metaCache struct {
+	sessionMu    sync.RWMutex
+	colMu        sync.RWMutex
+	sessionTsMap map[string]uint64 // collectionName -> last-write-timestamp
+	collInfoMap  map[string]collInfo
+}
+
+func (m *metaCache) getSessionTs(cName string) (uint64, bool) {
+	m.sessionMu.RLock()
+	defer m.sessionMu.RUnlock()
+	ts, ok := m.sessionTsMap[cName]
+	return ts, ok
+}
+
+func (m *metaCache) setSessionTs(cName string, ts uint64) {
+	m.sessionMu.Lock()
+	defer m.sessionMu.Unlock()
+	m.sessionTsMap[cName] = max(m.sessionTsMap[cName], ts) // increase monotonically
+}
+
+func (m *metaCache) setCollectionInfo(cName string, c *collInfo) {
+	m.colMu.Lock()
+	defer m.colMu.Unlock()
+	if c == nil {
+		delete(m.collInfoMap, cName)
+	} else {
+		m.collInfoMap[cName] = *c
+	}
+}
+
+func (m *metaCache) getCollectionInfo(cName string) (*collInfo, bool) {
+	m.colMu.RLock()
+	defer m.colMu.RUnlock()
+	col, ok := m.collInfoMap[cName]
+	if !ok {
+		return nil, false
+	}
+	return &collInfo{
+		ID:               col.ID,
+		Name:             col.Name,
+		Schema:           col.Schema,
+		ConsistencyLevel: col.ConsistencyLevel,
+	}, true
+}
+
+func max(x, y uint64) uint64 {
+	if x > y {
+		return x
+	}
+	return y
+}
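A sketch of what the new cache guarantees (not part of the patch), written as package-internal code since the methods are unexported and user code only reaches them indirectly through Insert/Delete/Describe/Drop and the search options. The collection name "films" is a placeholder.

package client

import "fmt"

func exampleMetaCacheUsage() {
	MetaCache.setSessionTs("films", 10)
	MetaCache.setSessionTs("films", 7) // ignored: session timestamps only move forward
	ts, ok := MetaCache.getSessionTs("films")
	fmt.Println(ts, ok) // 10 true

	MetaCache.setCollectionInfo("films", &collInfo{Name: "films"})
	info, _ := MetaCache.getCollectionInfo("films")
	info.Name = "changed" // mutating the returned copy...
	cached, _ := MetaCache.getCollectionInfo("films")
	fmt.Println(cached.Name) // ...does not touch the cache: prints "films"

	MetaCache.setCollectionInfo("films", nil) // nil evicts the entry, as DropCollection does
}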
diff --git a/client/meta_cache_test.go b/client/meta_cache_test.go
new file mode 100644
index 00000000..54d114bd
--- /dev/null
+++ b/client/meta_cache_test.go
@@ -0,0 +1,55 @@
+package client
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMetaCache(t *testing.T) {
+	meta := &metaCache{
+		sessionTsMap: make(map[string]uint64),
+		collInfoMap:  make(map[string]collInfo),
+	}
+
+	t.Run("session-ts-get", func(t *testing.T) {
+		ts, ok := meta.getSessionTs("")
+		assert.False(t, ok)
+		assert.Equal(t, uint64(0), ts)
+	})
+
+	t.Run("session-ts-set-then-get", func(t *testing.T) {
+		meta.setSessionTs("0", 1)
+		ts, ok := meta.getSessionTs("0")
+		assert.True(t, ok)
+		assert.Equal(t, uint64(1), ts)
+	})
+
+	t.Run("session-ts-monotonic-set", func(t *testing.T) {
+		meta.setSessionTs("0", 2)
+		meta.setSessionTs("0", 1)
+		ts, ok := meta.getSessionTs("0")
+		assert.True(t, ok)
+		assert.Equal(t, uint64(2), ts)
+	})
+
+	t.Run("info-get", func(t *testing.T) {
+		info, ok := meta.getCollectionInfo("")
+		assert.False(t, ok)
+		assert.Nil(t, info)
+	})
+
+	t.Run("info-set-get", func(t *testing.T) {
+		info1 := &collInfo{
+			Name: "aaa",
+		}
+		meta.setCollectionInfo(info1.Name, info1)
+		info2, ok := meta.getCollectionInfo(info1.Name)
+		assert.Equal(t, info1, info2)
+		assert.True(t, ok)
+		meta.setCollectionInfo(info1.Name, nil)
+		info2, ok = meta.getCollectionInfo(info1.Name)
+		assert.Nil(t, info2)
+		assert.False(t, ok)
+	})
+}
diff --git a/client/timestamp_map.go b/client/timestamp_map.go
deleted file mode 100644
index a26b11da..00000000
--- a/client/timestamp_map.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package client
-
-import (
-	"sync"
-)
-
-// Magical timestamps for communicating with server
-const (
-	StrongTimestamp     uint64 = 0
-	EventuallyTimestamp uint64 = 1
-	BoundedTimestamp    uint64 = 2
-)
-
-// global timestampMap
-var tsm = timestampMap{
-	m: make(map[int64]uint64),
-}
-
-// timestampMap collects the last-write-timestamp of every collection, which is required by session consistency level.
-type timestampMap struct {
-	mu sync.RWMutex
-	m  map[int64]uint64 // collectionID -> last-write-timestamp
-}
-
-func (tsm *timestampMap) get(cid int64) (uint64, bool) {
-	tsm.mu.RLock()
-	defer tsm.mu.RUnlock()
-	ts, ok := tsm.m[cid]
-	return ts, ok
-}
-
-func (tsm *timestampMap) set(cid int64, ts uint64) {
-	tsm.mu.Lock()
-	defer tsm.mu.Unlock()
-	tsm.m[cid] = max(tsm.m[cid], ts) // increase monotonically
-}
-
-func max(x, y uint64) uint64 {
-	if x > y {
-		return x
-	}
-	return y
-}
diff --git a/client/timestamp_map_test.go b/client/timestamp_map_test.go
deleted file mode 100644
index a30891ef..00000000
--- a/client/timestamp_map_test.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package client
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestTimestampMap(t *testing.T) {
-	tsm := &timestampMap{m: map[int64]uint64{}}
-
-	t.Run("get", func(t *testing.T) {
-		ts, ok := tsm.get(0)
-		assert.False(t, ok)
-		assert.Equal(t, uint64(0), ts)
-	})
-
-	t.Run("set-then-get", func(t *testing.T) {
-		tsm.set(0, 1)
-		ts, ok := tsm.get(0)
-		assert.True(t, ok)
-		assert.Equal(t, uint64(1), ts)
-	})
-
-	t.Run("monotonic-set", func(t *testing.T) {
-		tsm.set(0, 2)
-		tsm.set(0, 1)
-		ts, ok := tsm.get(0)
-		assert.True(t, ok)
-		assert.Equal(t, uint64(2), ts)
-	})
-}
diff --git a/examples/query/query.go b/examples/query/query.go
index d5f5c348..564e6d48 100644
--- a/examples/query/query.go
+++ b/examples/query/query.go
@@ -144,7 +144,7 @@ func main() {
 	begin = time.Now()
 	sp, _ = entity.NewIndexFlatSearchParam(10)
 	sRet, err = c.Search(ctx, collectionName, nil, "", []string{randomCol}, vec2search,
-		embeddingCol, entity.L2, topK, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong)
+		embeddingCol, entity.L2, topK, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong))
 	end = time.Now()
 	if err != nil {
 		log.Fatalf("failed to search collection, err: %v", err)