From da227ff9a17038fd8b054af42271a17fb4dd9d47 Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:43:06 +0800 Subject: [PATCH] feat: Support create collection with functions (#35973) relate: https://github.com/milvus-io/milvus/issues/35853 Support create collection with functions. Prepare for support bm25 function. --------- Signed-off-by: aoiasd --- go.mod | 2 +- go.sum | 4 +- internal/metastore/kv/rootcoord/kv_catalog.go | 50 +++++- .../metastore/kv/rootcoord/kv_catalog_test.go | 65 +++++++- .../kv/rootcoord/rootcoord_constant.go | 1 + internal/metastore/model/collection.go | 2 + internal/metastore/model/collection_test.go | 18 +- internal/metastore/model/field.go | 121 +++++++------- internal/metastore/model/function.go | 120 +++++++++++++ internal/metastore/model/function_test.go | 81 +++++++++ internal/proxy/meta_cache.go | 3 + internal/proxy/meta_cache_test.go | 71 +++++--- internal/proxy/task.go | 38 +++-- internal/proxy/task_index.go | 71 +++++++- internal/proxy/task_test.go | 36 +++- internal/proxy/util.go | 157 +++++++++++++++++- internal/rootcoord/broker.go | 1 + internal/rootcoord/create_collection_task.go | 38 ++++- internal/rootcoord/field_id.go | 3 + internal/rootcoord/root_coord.go | 1 + pkg/common/common.go | 2 + pkg/go.mod | 4 +- pkg/go.sum | 4 +- pkg/util/merr/utils.go | 1 - pkg/util/typeutil/schema.go | 11 ++ 25 files changed, 765 insertions(+), 140 deletions(-) create mode 100644 internal/metastore/model/function.go create mode 100644 internal/metastore/model/function_test.go diff --git a/go.mod b/go.mod index 1ac7ef8d6208c..f5554fbd84d8a 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd github.com/minio/minio-go/v7 v7.0.61 github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 github.com/prometheus/client_golang v1.14.0 diff --git a/go.sum b/go.sum index e5c034510d10c..1d8628c50315e 100644 --- a/go.sum +++ b/go.sum @@ -602,8 +602,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb h1:S3QIkNv9N1Vd1UKtdaQ4yVDPFAwFiPSAjN07axzbR70= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd h1:x0b0+foTe23sKcVFseR1DE8+BB08EH6ViiRHaz8PEik= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index 8e8e762b1783d..7af3c30485081 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -60,6 +60,14 @@ func BuildFieldKey(collectionID typeutil.UniqueID, fieldID int64) string { return fmt.Sprintf("%s/%d", BuildFieldPrefix(collectionID), fieldID) } +func BuildFunctionPrefix(collectionID typeutil.UniqueID) string { + return fmt.Sprintf("%s/%d", FunctionMetaPrefix, collectionID) +} + +func BuildFunctionKey(collectionID typeutil.UniqueID, functionID int64) string { + return fmt.Sprintf("%s/%d", BuildFunctionPrefix(collectionID), functionID) +} + func BuildAliasKey210(alias string) string { return fmt.Sprintf("%s/%s", CollectionAliasMetaPrefix210, alias) } @@ -166,7 +174,7 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, kvs := map[string]string{} - // save partition info to newly path. + // save partition info to new path. for _, partition := range coll.Partitions { k := BuildPartitionKey(coll.CollectionID, partition.PartitionID) partitionInfo := model.MarshalPartitionModel(partition) @@ -178,8 +186,7 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, } // no default aliases will be created. - - // save fields info to newly path. + // save fields info to new path. for _, field := range coll.Fields { k := BuildFieldKey(coll.CollectionID, field.FieldID) fieldInfo := model.MarshalFieldModel(field) @@ -190,6 +197,17 @@ func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, kvs[k] = string(v) } + // save functions info to new path. + for _, function := range coll.Functions { + k := BuildFunctionKey(coll.CollectionID, function.ID) + functionInfo := model.MarshalFunctionModel(function) + v, err := proto.Marshal(functionInfo) + if err != nil { + return err + } + kvs[k] = string(v) + } + // Though batchSave is not atomic enough, we can promise the atomicity outside. // Recovering from failure, if we found collection is creating, we should remove all these related meta. // since SnapshotKV may save both snapshot key and the original key if the original key is newest @@ -358,6 +376,24 @@ func (kc *Catalog) listFieldsAfter210(ctx context.Context, collectionID typeutil return fields, nil } +func (kc *Catalog) listFunctions(collectionID typeutil.UniqueID, ts typeutil.Timestamp) ([]*model.Function, error) { + prefix := BuildFunctionPrefix(collectionID) + _, values, err := kc.Snapshot.LoadWithPrefix(prefix, ts) + if err != nil { + return nil, err + } + functions := make([]*model.Function, 0, len(values)) + for _, v := range values { + functionSchema := &schemapb.FunctionSchema{} + err := proto.Unmarshal([]byte(v), functionSchema) + if err != nil { + return nil, err + } + functions = append(functions, model.UnmarshalFunctionModel(functionSchema)) + } + return functions, nil +} + func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *pb.CollectionInfo, ts typeutil.Timestamp, ) (*model.Collection, error) { @@ -379,6 +415,11 @@ func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *p } collection.Fields = fields + functions, err := kc.listFunctions(collection.CollectionID, ts) + if err != nil { + return nil, err + } + collection.Functions = functions return collection, nil } @@ -441,6 +482,9 @@ func (kc *Catalog) DropCollection(ctx context.Context, collectionInfo *model.Col for _, field := range collectionInfo.Fields { delMetakeysSnap = append(delMetakeysSnap, BuildFieldKey(collectionInfo.CollectionID, field.FieldID)) } + for _, function := range collectionInfo.Functions { + delMetakeysSnap = append(delMetakeysSnap, BuildFunctionKey(collectionInfo.CollectionID, function.ID)) + } // delMetakeysSnap = append(delMetakeysSnap, buildPartitionPrefix(collectionInfo.CollectionID)) // delMetakeysSnap = append(delMetakeysSnap, buildFieldPrefix(collectionInfo.CollectionID)) diff --git a/internal/metastore/kv/rootcoord/kv_catalog_test.go b/internal/metastore/kv/rootcoord/kv_catalog_test.go index b58010516aa14..44ee262970588 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog_test.go +++ b/internal/metastore/kv/rootcoord/kv_catalog_test.go @@ -207,8 +207,17 @@ func TestCatalog_ListCollections(t *testing.T) { return strings.HasPrefix(prefix, FieldMetaPrefix) }), ts). Return([]string{"key"}, []string{string(fm)}, nil) - kc := Catalog{Snapshot: kv} + functionMeta := &schemapb.FunctionSchema{} + fcm, err := proto.Marshal(functionMeta) + assert.NoError(t, err) + kv.On("LoadWithPrefix", mock.MatchedBy( + func(prefix string) bool { + return strings.HasPrefix(prefix, FunctionMetaPrefix) + }), ts). + Return([]string{"key"}, []string{string(fcm)}, nil) + + kc := Catalog{Snapshot: kv} ret, err := kc.ListCollections(ctx, testDb, ts) assert.NoError(t, err) assert.NotNil(t, ret) @@ -248,6 +257,16 @@ func TestCatalog_ListCollections(t *testing.T) { return strings.HasPrefix(prefix, FieldMetaPrefix) }), ts). Return([]string{"key"}, []string{string(fm)}, nil) + + functionMeta := &schemapb.FunctionSchema{} + fcm, err := proto.Marshal(functionMeta) + assert.NoError(t, err) + kv.On("LoadWithPrefix", mock.MatchedBy( + func(prefix string) bool { + return strings.HasPrefix(prefix, FunctionMetaPrefix) + }), ts). + Return([]string{"key"}, []string{string(fcm)}, nil) + kv.On("MultiSaveAndRemove", mock.Anything, mock.Anything, ts).Return(nil) kc := Catalog{Snapshot: kv} @@ -1215,6 +1234,22 @@ func TestCatalog_CreateCollection(t *testing.T) { err := kc.CreateCollection(ctx, coll, 100) assert.NoError(t, err) }) + + t.Run("create collection with function", func(t *testing.T) { + mockSnapshot := newMockSnapshot(t, withMockSave(nil), withMockMultiSave(nil)) + kc := &Catalog{Snapshot: mockSnapshot} + ctx := context.Background() + coll := &model.Collection{ + Partitions: []*model.Partition{ + {PartitionName: "test"}, + }, + Fields: []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}}, + Functions: []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}}, + State: pb.CollectionState_CollectionCreating, + } + err := kc.CreateCollection(ctx, coll, 100) + assert.NoError(t, err) + }) } func TestCatalog_DropCollection(t *testing.T) { @@ -1281,6 +1316,22 @@ func TestCatalog_DropCollection(t *testing.T) { err := kc.DropCollection(ctx, coll, 100) assert.NoError(t, err) }) + + t.Run("drop collection with function", func(t *testing.T) { + mockSnapshot := newMockSnapshot(t, withMockMultiSaveAndRemove(nil)) + kc := &Catalog{Snapshot: mockSnapshot} + ctx := context.Background() + coll := &model.Collection{ + Partitions: []*model.Partition{ + {PartitionName: "test"}, + }, + Fields: []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}}, + Functions: []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}}, + State: pb.CollectionState_CollectionDropping, + } + err := kc.DropCollection(ctx, coll, 100) + assert.NoError(t, err) + }) } func getUserInfoMetaString(username string) string { @@ -2779,3 +2830,15 @@ func TestCatalog_AlterDatabase(t *testing.T) { err = c.AlterDatabase(ctx, newDB, typeutil.ZeroTimestamp) assert.ErrorIs(t, err, mockErr) } + +func TestCatalog_listFunctionError(t *testing.T) { + mockSnapshot := newMockSnapshot(t) + kc := &Catalog{Snapshot: mockSnapshot} + mockSnapshot.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("mock error")) + _, err := kc.listFunctions(1, 1) + assert.Error(t, err) + + mockSnapshot.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return([]string{"test-key"}, []string{"invalid bytes"}, nil) + _, err = kc.listFunctions(1, 1) + assert.Error(t, err) +} diff --git a/internal/metastore/kv/rootcoord/rootcoord_constant.go b/internal/metastore/kv/rootcoord/rootcoord_constant.go index b250216048a37..c48e6f619529d 100644 --- a/internal/metastore/kv/rootcoord/rootcoord_constant.go +++ b/internal/metastore/kv/rootcoord/rootcoord_constant.go @@ -20,6 +20,7 @@ const ( PartitionMetaPrefix = ComponentPrefix + "/partitions" AliasMetaPrefix = ComponentPrefix + "/aliases" FieldMetaPrefix = ComponentPrefix + "/fields" + FunctionMetaPrefix = ComponentPrefix + "/functions" // CollectionAliasMetaPrefix210 prefix for collection alias meta CollectionAliasMetaPrefix210 = ComponentPrefix + "/collection-alias" diff --git a/internal/metastore/model/collection.go b/internal/metastore/model/collection.go index 66acf68cf248c..13cbac1d3d686 100644 --- a/internal/metastore/model/collection.go +++ b/internal/metastore/model/collection.go @@ -18,6 +18,7 @@ type Collection struct { Description string AutoID bool Fields []*Field + Functions []*Function VirtualChannelNames []string PhysicalChannelNames []string ShardsNum int32 @@ -54,6 +55,7 @@ func (c *Collection) Clone() *Collection { Properties: common.CloneKeyValuePairs(c.Properties), State: c.State, EnableDynamicField: c.EnableDynamicField, + Functions: CloneFunctions(c.Functions), } } diff --git a/internal/metastore/model/collection_test.go b/internal/metastore/model/collection_test.go index 7ddde61e9f495..0dfdc59c42be5 100644 --- a/internal/metastore/model/collection_test.go +++ b/internal/metastore/model/collection_test.go @@ -12,14 +12,16 @@ import ( ) var ( - colID int64 = 1 - colName = "c" - fieldID int64 = 101 - fieldName = "field110" - partID int64 = 20 - partName = "testPart" - tenantID = "tenant-1" - typeParams = []*commonpb.KeyValuePair{ + colID int64 = 1 + colName = "c" + fieldID int64 = 101 + fieldName = "field110" + partID int64 = 20 + partName = "testPart" + tenantID = "tenant-1" + functionID int64 = 1 + functionName = "test-bm25" + typeParams = []*commonpb.KeyValuePair{ { Key: "field110-k1", Value: "field110-v1", diff --git a/internal/metastore/model/field.go b/internal/metastore/model/field.go index 4693d2ba39d6b..49766675c7064 100644 --- a/internal/metastore/model/field.go +++ b/internal/metastore/model/field.go @@ -7,21 +7,22 @@ import ( ) type Field struct { - FieldID int64 - Name string - IsPrimaryKey bool - Description string - DataType schemapb.DataType - TypeParams []*commonpb.KeyValuePair - IndexParams []*commonpb.KeyValuePair - AutoID bool - State schemapb.FieldState - IsDynamic bool - IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition - IsClusteringKey bool - DefaultValue *schemapb.ValueField - ElementType schemapb.DataType - Nullable bool + FieldID int64 + Name string + IsPrimaryKey bool + Description string + DataType schemapb.DataType + TypeParams []*commonpb.KeyValuePair + IndexParams []*commonpb.KeyValuePair + AutoID bool + State schemapb.FieldState + IsDynamic bool + IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition + IsClusteringKey bool + IsFunctionOutput bool + DefaultValue *schemapb.ValueField + ElementType schemapb.DataType + Nullable bool } func (f *Field) Available() bool { @@ -30,21 +31,22 @@ func (f *Field) Available() bool { func (f *Field) Clone() *Field { return &Field{ - FieldID: f.FieldID, - Name: f.Name, - IsPrimaryKey: f.IsPrimaryKey, - Description: f.Description, - DataType: f.DataType, - TypeParams: common.CloneKeyValuePairs(f.TypeParams), - IndexParams: common.CloneKeyValuePairs(f.IndexParams), - AutoID: f.AutoID, - State: f.State, - IsDynamic: f.IsDynamic, - IsPartitionKey: f.IsPartitionKey, - IsClusteringKey: f.IsClusteringKey, - DefaultValue: f.DefaultValue, - ElementType: f.ElementType, - Nullable: f.Nullable, + FieldID: f.FieldID, + Name: f.Name, + IsPrimaryKey: f.IsPrimaryKey, + Description: f.Description, + DataType: f.DataType, + TypeParams: common.CloneKeyValuePairs(f.TypeParams), + IndexParams: common.CloneKeyValuePairs(f.IndexParams), + AutoID: f.AutoID, + State: f.State, + IsDynamic: f.IsDynamic, + IsPartitionKey: f.IsPartitionKey, + IsClusteringKey: f.IsClusteringKey, + IsFunctionOutput: f.IsFunctionOutput, + DefaultValue: f.DefaultValue, + ElementType: f.ElementType, + Nullable: f.Nullable, } } @@ -75,6 +77,7 @@ func (f *Field) Equal(other Field) bool { f.IsClusteringKey == other.IsClusteringKey && f.DefaultValue == other.DefaultValue && f.ElementType == other.ElementType && + f.IsFunctionOutput == other.IsFunctionOutput && f.Nullable == other.Nullable } @@ -97,20 +100,21 @@ func MarshalFieldModel(field *Field) *schemapb.FieldSchema { } return &schemapb.FieldSchema{ - FieldID: field.FieldID, - Name: field.Name, - IsPrimaryKey: field.IsPrimaryKey, - Description: field.Description, - DataType: field.DataType, - TypeParams: field.TypeParams, - IndexParams: field.IndexParams, - AutoID: field.AutoID, - IsDynamic: field.IsDynamic, - IsPartitionKey: field.IsPartitionKey, - IsClusteringKey: field.IsClusteringKey, - DefaultValue: field.DefaultValue, - ElementType: field.ElementType, - Nullable: field.Nullable, + FieldID: field.FieldID, + Name: field.Name, + IsPrimaryKey: field.IsPrimaryKey, + Description: field.Description, + DataType: field.DataType, + TypeParams: field.TypeParams, + IndexParams: field.IndexParams, + AutoID: field.AutoID, + IsDynamic: field.IsDynamic, + IsPartitionKey: field.IsPartitionKey, + IsClusteringKey: field.IsClusteringKey, + IsFunctionOutput: field.IsFunctionOutput, + DefaultValue: field.DefaultValue, + ElementType: field.ElementType, + Nullable: field.Nullable, } } @@ -132,20 +136,21 @@ func UnmarshalFieldModel(fieldSchema *schemapb.FieldSchema) *Field { } return &Field{ - FieldID: fieldSchema.FieldID, - Name: fieldSchema.Name, - IsPrimaryKey: fieldSchema.IsPrimaryKey, - Description: fieldSchema.Description, - DataType: fieldSchema.DataType, - TypeParams: fieldSchema.TypeParams, - IndexParams: fieldSchema.IndexParams, - AutoID: fieldSchema.AutoID, - IsDynamic: fieldSchema.IsDynamic, - IsPartitionKey: fieldSchema.IsPartitionKey, - IsClusteringKey: fieldSchema.IsClusteringKey, - DefaultValue: fieldSchema.DefaultValue, - ElementType: fieldSchema.ElementType, - Nullable: fieldSchema.Nullable, + FieldID: fieldSchema.FieldID, + Name: fieldSchema.Name, + IsPrimaryKey: fieldSchema.IsPrimaryKey, + Description: fieldSchema.Description, + DataType: fieldSchema.DataType, + TypeParams: fieldSchema.TypeParams, + IndexParams: fieldSchema.IndexParams, + AutoID: fieldSchema.AutoID, + IsDynamic: fieldSchema.IsDynamic, + IsPartitionKey: fieldSchema.IsPartitionKey, + IsClusteringKey: fieldSchema.IsClusteringKey, + IsFunctionOutput: fieldSchema.IsFunctionOutput, + DefaultValue: fieldSchema.DefaultValue, + ElementType: fieldSchema.ElementType, + Nullable: fieldSchema.Nullable, } } diff --git a/internal/metastore/model/function.go b/internal/metastore/model/function.go new file mode 100644 index 0000000000000..a817dee4ee10a --- /dev/null +++ b/internal/metastore/model/function.go @@ -0,0 +1,120 @@ +package model + +import ( + "slices" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type Function struct { + Name string + ID int64 + Description string + + Type schemapb.FunctionType + + InputFieldIDs []int64 + InputFieldNames []string + + OutputFieldIDs []int64 + OutputFieldNames []string + + Params []*commonpb.KeyValuePair +} + +func (f *Function) Clone() *Function { + return &Function{ + Name: f.Name, + ID: f.ID, + Description: f.Description, + Type: f.Type, + + InputFieldIDs: f.InputFieldIDs, + InputFieldNames: f.InputFieldNames, + + OutputFieldIDs: f.OutputFieldIDs, + OutputFieldNames: f.OutputFieldNames, + Params: f.Params, + } +} + +func (f *Function) Equal(other Function) bool { + return f.Name == other.Name && + f.Type == other.Type && + f.Description == other.Description && + slices.Equal(f.InputFieldNames, other.InputFieldNames) && + slices.Equal(f.InputFieldIDs, other.InputFieldIDs) && + slices.Equal(f.OutputFieldNames, other.OutputFieldNames) && + slices.Equal(f.OutputFieldIDs, other.OutputFieldIDs) && + slices.Equal(f.Params, other.Params) +} + +func CloneFunctions(functions []*Function) []*Function { + clone := make([]*Function, len(functions)) + for i, function := range functions { + clone[i] = function.Clone() + } + return functions +} + +func MarshalFunctionModel(function *Function) *schemapb.FunctionSchema { + if function == nil { + return nil + } + + return &schemapb.FunctionSchema{ + Name: function.Name, + Id: function.ID, + Description: function.Description, + Type: function.Type, + InputFieldIds: function.InputFieldIDs, + InputFieldNames: function.InputFieldNames, + OutputFieldIds: function.OutputFieldIDs, + OutputFieldNames: function.OutputFieldNames, + Params: function.Params, + } +} + +func UnmarshalFunctionModel(schema *schemapb.FunctionSchema) *Function { + if schema == nil { + return nil + } + return &Function{ + Name: schema.GetName(), + ID: schema.GetId(), + Description: schema.GetDescription(), + Type: schema.GetType(), + + InputFieldIDs: schema.GetInputFieldIds(), + InputFieldNames: schema.GetInputFieldNames(), + + OutputFieldIDs: schema.GetOutputFieldIds(), + OutputFieldNames: schema.GetOutputFieldNames(), + Params: schema.GetParams(), + } +} + +func MarshalFunctionModels(functions []*Function) []*schemapb.FunctionSchema { + if functions == nil { + return nil + } + + functionSchemas := make([]*schemapb.FunctionSchema, len(functions)) + for idx, function := range functions { + functionSchemas[idx] = MarshalFunctionModel(function) + } + return functionSchemas +} + +func UnmarshalFunctionModels(functions []*schemapb.FunctionSchema) []*Function { + if functions == nil { + return nil + } + + functionSchemas := make([]*Function, len(functions)) + for idx, function := range functions { + functionSchemas[idx] = UnmarshalFunctionModel(function) + } + return functionSchemas +} diff --git a/internal/metastore/model/function_test.go b/internal/metastore/model/function_test.go new file mode 100644 index 0000000000000..86dedbdcce7f7 --- /dev/null +++ b/internal/metastore/model/function_test.go @@ -0,0 +1,81 @@ +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +var ( + functionSchemaPb = &schemapb.FunctionSchema{ + Id: functionID, + Name: functionName, + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, + OutputFieldNames: []string{"sparse"}, + } + + functionModel = &Function{ + ID: functionID, + Name: functionName, + Type: schemapb.FunctionType_BM25, + InputFieldIDs: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIDs: []int64{103}, + OutputFieldNames: []string{"sparse"}, + } +) + +func TestMarshalFunctionModel(t *testing.T) { + ret := MarshalFunctionModel(functionModel) + assert.Equal(t, functionSchemaPb, ret) + assert.Nil(t, MarshalFunctionModel(nil)) +} + +func TestMarshalFunctionModels(t *testing.T) { + ret := MarshalFunctionModels([]*Function{functionModel}) + assert.Equal(t, []*schemapb.FunctionSchema{functionSchemaPb}, ret) + assert.Nil(t, MarshalFunctionModels(nil)) +} + +func TestUnmarshalFunctionModel(t *testing.T) { + ret := UnmarshalFunctionModel(functionSchemaPb) + assert.Equal(t, functionModel, ret) + assert.Nil(t, UnmarshalFunctionModel(nil)) +} + +func TestUnmarshalFunctionModels(t *testing.T) { + ret := UnmarshalFunctionModels([]*schemapb.FunctionSchema{functionSchemaPb}) + assert.Equal(t, []*Function{functionModel}, ret) + assert.Nil(t, UnmarshalFunctionModels(nil)) +} + +func TestFunctionEqual(t *testing.T) { + EqualFunction := Function{ + ID: functionID, + Name: functionName, + Type: schemapb.FunctionType_BM25, + InputFieldIDs: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIDs: []int64{103}, + OutputFieldNames: []string{"sparse"}, + } + + NoEqualFunction := Function{ + ID: functionID, + Name: functionName, + Type: schemapb.FunctionType_BM25, + InputFieldIDs: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIDs: []int64{102}, + OutputFieldNames: []string{"sparse"}, + } + + assert.True(t, functionModel.Equal(EqualFunction)) + assert.True(t, functionModel.Equal(*functionModel.Clone())) + assert.False(t, functionModel.Equal(NoEqualFunction)) +} diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 160ee053b89df..fc83d5b180b56 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -719,6 +719,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection Description: coll.Schema.Description, AutoID: coll.Schema.AutoID, Fields: make([]*schemapb.FieldSchema, 0), + Functions: make([]*schemapb.FunctionSchema, 0), EnableDynamicField: coll.Schema.EnableDynamicField, }, CollectionID: coll.CollectionID, @@ -735,6 +736,8 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection resp.Schema.Fields = append(resp.Schema.Fields, field) } } + + resp.Schema.Functions = append(resp.Schema.Functions, coll.Schema.Functions...) return resp, nil } diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index a1611aacc12e2..3f97a1dca6bbc 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -213,9 +213,10 @@ func TestMetaCache_GetCollection(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) id, err = globalMetaCache.GetCollectionID(ctx, dbName, "collection2") assert.Equal(t, rootCoord.GetAccessCount(), 2) @@ -225,9 +226,10 @@ func TestMetaCache_GetCollection(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection2", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection2", }) // test to get from cache, this should trigger root request @@ -239,9 +241,10 @@ func TestMetaCache_GetCollection(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) } @@ -298,9 +301,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) collection, err = globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1) assert.Equal(t, rootCoord.GetAccessCount(), 1) @@ -310,9 +314,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection2", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection2", }) // test to get from cache, this should trigger root request @@ -324,9 +329,10 @@ func TestMetaCache_GetCollectionName(t *testing.T) { assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) } @@ -349,18 +355,20 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) rootCoord.Error = true // should be cached with no error assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) } @@ -422,9 +430,10 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.NoError(t, err) assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{ - AutoID: true, - Fields: []*schemapb.FieldSchema{}, - Name: "collection1", + AutoID: true, + Fields: []*schemapb.FieldSchema{}, + Functions: []*schemapb.FunctionSchema{}, + Name: "collection1", }) time.Sleep(10 * time.Millisecond) } @@ -1071,6 +1080,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, dynamicField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: nil, skipDynamicField: false, @@ -1091,6 +1101,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { dynamicField, clusteringKeyField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: nil, skipDynamicField: false, @@ -1111,6 +1122,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { dynamicField, clusteringKeyField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"pk", "part_key", "vector", "clustering_key"}, skipDynamicField: false, @@ -1130,6 +1142,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, dynamicField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"pk", "part_key", "vector"}, skipDynamicField: true, @@ -1149,6 +1162,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, dynamicField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"part_key", "vector"}, skipDynamicField: true, @@ -1167,6 +1181,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, dynamicField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"pk", "vector"}, skipDynamicField: true, @@ -1185,6 +1200,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, dynamicField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"pk", "part_key"}, skipDynamicField: true, @@ -1203,6 +1219,7 @@ func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { vectorField, clusteringKeyField, }, + Functions: []*schemapb.FunctionSchema{}, }, loadFields: []string{"pk", "part_key", "vector"}, expectErr: true, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index e91181c328e75..ff5b86bcdbbcb 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -301,6 +301,10 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error { } t.schema.AutoID = false + if err := validateFunction(t.schema); err != nil { + return err + } + if t.ShardsNum > Params.ProxyCfg.MaxShardNum.GetAsInt32() { return fmt.Errorf("maximum shards's number should be limited to %d", Params.ProxyCfg.MaxShardNum.GetAsInt()) } @@ -632,6 +636,7 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error { Description: "", AutoID: false, Fields: make([]*schemapb.FieldSchema, 0), + Functions: make([]*schemapb.FunctionSchema, 0), }, CollectionID: 0, VirtualChannelNames: nil, @@ -681,23 +686,28 @@ func (t *describeCollectionTask) Execute(ctx context.Context) error { } if field.FieldID >= common.StartOfUserFieldID { t.result.Schema.Fields = append(t.result.Schema.Fields, &schemapb.FieldSchema{ - FieldID: field.FieldID, - Name: field.Name, - IsPrimaryKey: field.IsPrimaryKey, - AutoID: field.AutoID, - Description: field.Description, - DataType: field.DataType, - TypeParams: field.TypeParams, - IndexParams: field.IndexParams, - IsDynamic: field.IsDynamic, - IsPartitionKey: field.IsPartitionKey, - IsClusteringKey: field.IsClusteringKey, - DefaultValue: field.DefaultValue, - ElementType: field.ElementType, - Nullable: field.Nullable, + FieldID: field.FieldID, + Name: field.Name, + IsPrimaryKey: field.IsPrimaryKey, + AutoID: field.AutoID, + Description: field.Description, + DataType: field.DataType, + TypeParams: field.TypeParams, + IndexParams: field.IndexParams, + IsDynamic: field.IsDynamic, + IsPartitionKey: field.IsPartitionKey, + IsClusteringKey: field.IsClusteringKey, + DefaultValue: field.DefaultValue, + ElementType: field.ElementType, + Nullable: field.Nullable, + IsFunctionOutput: field.IsFunctionOutput, }) } } + + for _, function := range result.Schema.Functions { + t.result.Schema.Functions = append(t.result.Schema.Functions, proto.Clone(function).(*schemapb.FunctionSchema)) + } return nil } diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 6e1281c250261..a9cee3a377470 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -71,6 +71,7 @@ type createIndexTask struct { newExtraParams []*commonpb.KeyValuePair collectionID UniqueID + functionSchema *schemapb.FunctionSchema fieldSchema *schemapb.FieldSchema userAutoIndexMetricTypeSpecified bool } @@ -129,6 +130,48 @@ func wrapUserIndexParams(metricType string) []*commonpb.KeyValuePair { } } +func (cit *createIndexTask) parseFunctionParamsToIndex(indexParamsMap map[string]string) error { + if !cit.fieldSchema.GetIsFunctionOutput() { + return nil + } + + switch cit.functionSchema.GetType() { + case schemapb.FunctionType_BM25: + for _, kv := range cit.functionSchema.GetParams() { + switch kv.GetKey() { + case "bm25_k1": + if _, ok := indexParamsMap["bm25_k1"]; !ok { + indexParamsMap["bm25_k1"] = kv.GetValue() + } + case "bm25_b": + if _, ok := indexParamsMap["bm25_b"]; !ok { + indexParamsMap["bm25_b"] = kv.GetValue() + } + case "bm25_avgdl": + if _, ok := indexParamsMap["bm25_avgdl"]; !ok { + indexParamsMap["bm25_avgdl"] = kv.GetValue() + } + } + } + // set default avgdl + if _, ok := indexParamsMap["bm25_k1"]; !ok { + indexParamsMap["bm25_k1"] = "1.2" + } + + if _, ok := indexParamsMap["bm25_b"]; !ok { + indexParamsMap["bm25_b"] = "0.75" + } + + if _, ok := indexParamsMap["bm25_avgdl"]; !ok { + indexParamsMap["bm25_avgdl"] = "100" + } + default: + return fmt.Errorf("parse unknown type function params to index") + } + + return nil +} + func (cit *createIndexTask) parseIndexParams() error { cit.newExtraParams = cit.req.GetExtraParams() @@ -149,6 +192,11 @@ func (cit *createIndexTask) parseIndexParams() error { } } + // fill index param for bm25 function + if err := cit.parseFunctionParamsToIndex(indexParamsMap); err != nil { + return err + } + if err := ValidateAutoIndexMmapConfig(isVecIndex, indexParamsMap); err != nil { return err } @@ -353,18 +401,29 @@ func (cit *createIndexTask) parseIndexParams() error { return nil } -func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) { +func (cit *createIndexTask) getIndexedFieldAndFunction(ctx context.Context) error { schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.req.GetDbName(), cit.req.GetCollectionName()) if err != nil { log.Error("failed to get collection schema", zap.Error(err)) - return nil, fmt.Errorf("failed to get collection schema: %s", err) + return fmt.Errorf("failed to get collection schema: %s", err) } + field, err := schema.schemaHelper.GetFieldFromName(cit.req.GetFieldName()) if err != nil { log.Error("create index on non-exist field", zap.Error(err)) - return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.req.GetFieldName()) + return fmt.Errorf("cannot create index on non-exist field: %s", cit.req.GetFieldName()) } - return field, nil + + if field.IsFunctionOutput { + function, err := schema.schemaHelper.GetFunctionByOutputField(field) + if err != nil { + log.Error("create index failed, cannot find function of function output field", zap.Error(err)) + return fmt.Errorf("create index failed, cannot find function of function output field: %s", cit.req.GetFieldName()) + } + cit.functionSchema = function + } + cit.fieldSchema = field + return nil } func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error { @@ -452,11 +511,11 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { return err } - field, err := cit.getIndexedField(ctx) + err = cit.getIndexedFieldAndFunction(ctx) if err != nil { return err } - cit.fieldSchema = field + // check index param, not accurate, only some static rules err = cit.parseIndexParams() if err != nil { diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 6240feae49eb1..cb1de2b0351c3 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -2170,7 +2170,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { }) } -func Test_createIndexTask_getIndexedField(t *testing.T) { +func Test_createIndexTask_getIndexedFieldAndFunction(t *testing.T) { collectionName := "test" fieldName := "test" @@ -2224,9 +2224,9 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { }), nil) globalMetaCache = cache - field, err := cit.getIndexedField(context.Background()) + err := cit.getIndexedFieldAndFunction(context.Background()) assert.NoError(t, err) - assert.Equal(t, fieldName, field.GetName()) + assert.Equal(t, fieldName, cit.fieldSchema.GetName()) }) t.Run("schema not found", func(t *testing.T) { @@ -2237,7 +2237,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { mock.AnythingOfType("string"), ).Return(nil, errors.New("mock")) globalMetaCache = cache - _, err := cit.getIndexedField(context.Background()) + err := cit.getIndexedFieldAndFunction(context.Background()) assert.Error(t, err) }) @@ -2256,7 +2256,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { }, }), nil) globalMetaCache = cache - _, err := cit.getIndexedField(context.Background()) + err := cit.getIndexedFieldAndFunction(context.Background()) assert.Error(t, err) }) } @@ -3128,6 +3128,10 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { }, }, } + sparseVecField := &schemapb.FieldSchema{ + Name: "sparse", + DataType: schemapb.DataType_SparseFloatVector, + } partitionKeyField := &schemapb.FieldSchema{ Name: "partition_key", DataType: schemapb.DataType_Int64, @@ -3236,6 +3240,28 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { task.Schema = marshaledSchema err = task.PreExecute(ctx) assert.NoError(t, err) + + // test schema with function + // invalid function + schema.Functions = []*schemapb.FunctionSchema{ + {Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"invalid name"}}, + } + marshaledSchema, err = proto.Marshal(schema) + assert.NoError(t, err) + task.Schema = marshaledSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + // normal case + schema.Fields = append(schema.Fields, sparseVecField) + schema.Functions = []*schemapb.FunctionSchema{ + {Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{varCharField.Name}, OutputFieldNames: []string{sparseVecField.Name}}, + } + marshaledSchema, err = proto.Marshal(schema) + assert.NoError(t, err) + task.Schema = marshaledSchema + err = task.PreExecute(ctx) + assert.NoError(t, err) }) t.Run("Execute", func(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index fe90239b25eac..56be896626fc5 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -25,6 +25,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc/metadata" @@ -609,6 +610,143 @@ func validateSchema(coll *schemapb.CollectionSchema) error { return nil } +func validateFunction(coll *schemapb.CollectionSchema) error { + nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) { + return field.GetName(), field + }) + usedOutputField := typeutil.NewSet[string]() + usedFunctionName := typeutil.NewSet[string]() + // validate function + for _, function := range coll.GetFunctions() { + if usedFunctionName.Contain(function.GetName()) { + return fmt.Errorf("duplicate function name %s", function.GetName()) + } + + usedFunctionName.Insert(function.GetName()) + inputFields := []*schemapb.FieldSchema{} + for _, name := range function.GetInputFieldNames() { + inputField, ok := nameMap[name] + if !ok { + return fmt.Errorf("function input field not found %s", function.InputFieldNames) + } + inputFields = append(inputFields, inputField) + } + + err := checkFunctionInputField(function, inputFields) + if err != nil { + return err + } + + outputFields := make([]*schemapb.FieldSchema, len(function.GetOutputFieldNames())) + for i, name := range function.GetOutputFieldNames() { + outputField, ok := nameMap[name] + if !ok { + return fmt.Errorf("function output field not found %s", function.InputFieldNames) + } + outputField.IsFunctionOutput = true + outputFields[i] = outputField + if usedOutputField.Contain(name) { + return fmt.Errorf("duplicate function output %s", name) + } + usedOutputField.Insert(name) + } + + if err := checkFunctionOutputField(function, outputFields); err != nil { + return err + } + + if err := checkFunctionParams(function); err != nil { + return err + } + } + return nil +} + +func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error { + switch function.GetType() { + case schemapb.FunctionType_BM25: + if len(fields) != 1 { + return fmt.Errorf("bm25 only need 1 output field, but now %d", len(fields)) + } + + if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) { + return fmt.Errorf("bm25 only need sparse embedding output field, but now %s", fields[0].DataType.String()) + } + + if fields[0].GetIsPrimaryKey() { + return fmt.Errorf("bm25 output field can't be primary key") + } + + if fields[0].GetIsPartitionKey() || fields[0].GetIsClusteringKey() { + return fmt.Errorf("bm25 output field can't be partition key or cluster key field") + } + default: + return fmt.Errorf("check output field for unknown function type") + } + return nil +} + +func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error { + switch function.GetType() { + case schemapb.FunctionType_BM25: + if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar { + return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d field with type %s", + len(fields), fields[0].DataType.String()) + } + + default: + return fmt.Errorf("check input field with unknown function type") + } + return nil +} + +func checkFunctionParams(function *schemapb.FunctionSchema) error { + switch function.GetType() { + case schemapb.FunctionType_BM25: + for _, kv := range function.GetParams() { + switch kv.GetKey() { + case "bm25_k1": + k1, err := strconv.ParseFloat(kv.GetValue(), 64) + if err != nil { + return fmt.Errorf("failed to parse bm25_k1 value, %w", err) + } + + if k1 < 0 || k1 > 3 { + return fmt.Errorf("bm25_k1 must in [0,3] but now %f", k1) + } + + case "bm25_b": + b, err := strconv.ParseFloat(kv.GetValue(), 64) + if err != nil { + return fmt.Errorf("failed to parse bm25_b value, %w", err) + } + + if b < 0 || b > 1 { + return fmt.Errorf("bm25_b must in [0,1] but now %f", b) + } + + case "bm25_avgdl": + avgdl, err := strconv.ParseFloat(kv.GetValue(), 64) + if err != nil { + return fmt.Errorf("failed to parse bm25_avgdl value, %w", err) + } + + if avgdl <= 0 { + return fmt.Errorf("bm25_avgdl must large than zero but now %f", avgdl) + } + + case "analyzer_params": + // TODO ADD tokenizer check + default: + return fmt.Errorf("invalid function params, key: %s, value:%s", kv.GetKey(), kv.GetValue()) + } + } + default: + return fmt.Errorf("check function params with unknown function type") + } + return nil +} + // validateMultipleVectorFields check if schema has multiple vector fields. func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error { vecExist := false @@ -754,13 +892,19 @@ func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData { // fillFieldIDBySchema set fieldID to fieldData according FieldSchemas func fillFieldIDBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error { - if len(columns) != len(schema.GetFields()) { - return fmt.Errorf("len(columns) mismatch the len(fields), len(columns): %d, len(fields): %d", - len(columns), len(schema.GetFields())) - } fieldName2Schema := make(map[string]*schemapb.FieldSchema) + + expectColumnNum := 0 for _, field := range schema.GetFields() { fieldName2Schema[field.Name] = field + if !field.GetIsFunctionOutput() { + expectColumnNum++ + } + } + + if len(columns) != expectColumnNum { + return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d", + expectColumnNum, len(columns)) } for _, fieldData := range columns { @@ -1211,15 +1355,16 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey { return merr.WrapErrParameterInvalidMsg("primary key can't be with default value") } - if fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { // when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false autoGenFieldNum++ } if _, ok := dataNameSet[fieldSchema.GetName()]; !ok { - if fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { // autoGenField continue } + if fieldSchema.GetDefaultValue() == nil && !fieldSchema.GetNullable() { log.Warn("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName())) return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName()) diff --git a/internal/rootcoord/broker.go b/internal/rootcoord/broker.go index edd3bc0525faf..c4811b3550a9a 100644 --- a/internal/rootcoord/broker.go +++ b/internal/rootcoord/broker.go @@ -248,6 +248,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv Description: colMeta.Description, AutoID: colMeta.AutoID, Fields: model.MarshalFieldModels(colMeta.Fields), + Functions: model.MarshalFunctionModels(colMeta.Functions), }, PartitionIDs: partitionIDs, StartPositions: colMeta.StartPositions, diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index e625ba0315d1c..e61ab868a1863 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -259,10 +259,34 @@ func (t *createCollectionTask) validateSchema(schema *schemapb.CollectionSchema) return validateFieldDataType(schema) } -func (t *createCollectionTask) assignFieldID(schema *schemapb.CollectionSchema) { - for idx := range schema.GetFields() { - schema.Fields[idx].FieldID = int64(idx + StartOfUserFieldID) +func (t *createCollectionTask) assignFieldAndFunctionID(schema *schemapb.CollectionSchema) error { + name2id := map[string]int64{} + for idx, field := range schema.GetFields() { + field.FieldID = int64(idx + StartOfUserFieldID) + name2id[field.GetName()] = field.GetFieldID() + } + + for fidx, function := range schema.GetFunctions() { + function.InputFieldIds = make([]int64, len(function.InputFieldNames)) + function.Id = int64(fidx) + StartOfUserFunctionID + for idx, name := range function.InputFieldNames { + fieldId, ok := name2id[name] + if !ok { + return fmt.Errorf("input field %s of function %s not found", name, function.GetName()) + } + function.InputFieldIds[idx] = fieldId + } + + function.OutputFieldIds = make([]int64, len(function.OutputFieldNames)) + for idx, name := range function.OutputFieldNames { + fieldId, ok := name2id[name] + if !ok { + return fmt.Errorf("output field %s of function %s not found", name, function.GetName()) + } + function.OutputFieldIds[idx] = fieldId + } } + return nil } func (t *createCollectionTask) appendDynamicField(schema *schemapb.CollectionSchema) { @@ -303,7 +327,11 @@ func (t *createCollectionTask) prepareSchema() error { return err } t.appendDynamicField(&schema) - t.assignFieldID(&schema) + + if err := t.assignFieldAndFunctionID(&schema); err != nil { + return err + } + t.appendSysFields(&schema) t.schema = &schema return nil @@ -540,6 +568,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { Description: t.schema.Description, AutoID: t.schema.AutoID, Fields: model.UnmarshalFieldModels(t.schema.Fields), + Functions: model.UnmarshalFunctionModels(t.schema.Functions), VirtualChannelNames: vchanNames, PhysicalChannelNames: chanNames, ShardsNum: t.Req.ShardsNum, @@ -609,6 +638,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { Description: collInfo.Description, AutoID: collInfo.AutoID, Fields: model.MarshalFieldModels(collInfo.Fields), + Functions: model.MarshalFunctionModels(collInfo.Functions), }, }, }, &nullStep{}) diff --git a/internal/rootcoord/field_id.go b/internal/rootcoord/field_id.go index 99ae1b2160d5b..af17552970ec8 100644 --- a/internal/rootcoord/field_id.go +++ b/internal/rootcoord/field_id.go @@ -29,6 +29,9 @@ const ( // StartOfUserFieldID id of user defined field begin from here StartOfUserFieldID = common.StartOfUserFieldID + // StartOfUserFunctionID id of user defined function begin from here + StartOfUserFunctionID = common.StartOfUserFunctionID + // RowIDField id of row ID field RowIDField = common.RowIDField diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 7484b195a81a3..9ea9b830e283c 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -1124,6 +1124,7 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string, dbName str Description: collInfo.Description, AutoID: collInfo.AutoID, Fields: model.MarshalFieldModels(collInfo.Fields), + Functions: model.MarshalFunctionModels(collInfo.Functions), EnableDynamicField: collInfo.EnableDynamicField, } resp.CollectionID = collInfo.CollectionID diff --git a/pkg/common/common.go b/pkg/common/common.go index 7cd02083285e7..94f361da4a316 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -42,6 +42,8 @@ const ( // StartOfUserFieldID represents the starting ID of the user-defined field StartOfUserFieldID = 100 + // StartOfUserFunctionID represents the starting ID of the user-defined function + StartOfUserFunctionID = 100 // RowIDField is the ID of the RowID field reserved by the system RowIDField = 0 diff --git a/pkg/go.mod b/pkg/go.mod index fba4d00bf1e1f..1f805db350bde 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -11,10 +11,11 @@ require ( github.com/confluentinc/confluent-kafka-go v1.9.1 github.com/containerd/cgroups/v3 v3.0.3 github.com/expr-lang/expr v1.15.7 + github.com/golang/protobuf v1.5.4 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/json-iterator/go v1.1.12 github.com/klauspost/compress v1.17.7 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd github.com/nats-io/nats-server/v2 v2.10.12 github.com/nats-io/nats.go v1.34.1 github.com/panjf2000/ants/v2 v2.7.2 @@ -93,7 +94,6 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/pkg/go.sum b/pkg/go.sum index b81ae8d933c9f..c21c42eb4f067 100644 --- a/pkg/go.sum +++ b/pkg/go.sum @@ -494,8 +494,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb h1:S3QIkNv9N1Vd1UKtdaQ4yVDPFAwFiPSAjN07axzbR70= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240822040249-4bbc8f623cbb/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd h1:x0b0+foTe23sKcVFseR1DE8+BB08EH6ViiRHaz8PEik= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240909041258-8f8ca67816cd/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index e940651dd5d8f..2ad30564e570b 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -323,7 +323,6 @@ func WrapErrAsInputErrorWhen(err error, targets ...milvusError) error { if target.errCode == merr.errCode { log.Info("mark error as input error", zap.Error(err)) WithErrorType(InputError)(&merr) - log.Info("test--", zap.String("type", merr.errType.String())) return merr } } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 9f040c463fa3f..89d6690595e91 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -429,6 +429,17 @@ func (helper *SchemaHelper) GetVectorDimFromID(fieldID int64) (int, error) { return 0, fmt.Errorf("fieldID(%d) not has dim", fieldID) } +func (helper *SchemaHelper) GetFunctionByOutputField(field *schemapb.FieldSchema) (*schemapb.FunctionSchema, error) { + for _, function := range helper.schema.GetFunctions() { + for _, id := range function.GetOutputFieldIds() { + if field.GetFieldID() == id { + return function, nil + } + } + } + return nil, fmt.Errorf("function not exist") +} + func IsBinaryVectorType(dataType schemapb.DataType) bool { return dataType == schemapb.DataType_BinaryVector }