Skip to content

Commit

Permalink
Add upsert test cases
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Nov 7, 2023
1 parent 18efb15 commit 043c4d7
Show file tree
Hide file tree
Showing 6 changed files with 640 additions and 47 deletions.
9 changes: 9 additions & 0 deletions test/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ func GenLongString(n int) string {
return builder.String()
}

func ColumnIndexFunc(data []entity.Column, fieldName string) int {
for index, column := range data {
if column.Name() == fieldName {
return index
}
}
return -1
}

// --- common utils ---

// --- gen fields ---
Expand Down
3 changes: 1 addition & 2 deletions test/testcases/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ func TestInsertAutoIdPkData(t *testing.T) {
// insert
pkColumn, floatColumn, vecColumn := common.GenDefaultColumnData(0, common.DefaultNb, common.DefaultDim)
_, errInsert := mc.Insert(ctx, collName, "", pkColumn, floatColumn, vecColumn)
//TODO change to check error code
common.CheckErr(t, errInsert, false, "invalid parameter") //, "can not assign primary field data when auto id enabled")
common.CheckErr(t, errInsert, false, "the length of passed fields is equal to needed: expected=2, actual=3: invalid parameter")

// flush and check row count
errFlush := mc.Flush(ctx, collName, false)
Expand Down
114 changes: 74 additions & 40 deletions test/testcases/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,27 +242,14 @@ func createVarcharCollectionWithDataIndex(ctx context.Context, t *testing.T, mc
return collName, ids
}

type CollectionFieldsType string

const (
Int64FloatVec CollectionFieldsType = "PkInt64FloatVec" // int64 + float + floatVec
Int64BinaryVec CollectionFieldsType = "Int64BinaryVec" // int64 + float + binaryVec
VarcharBinaryVec CollectionFieldsType = "PkVarcharBinaryVec" // varchar + binaryVec
Int64FloatVecJSON CollectionFieldsType = "PkInt64FloatVecJson" // int64 + float + floatVec + json
AllFields CollectionFieldsType = "AllFields" // all scalar fields + floatVec
CustomerFields CollectionFieldsType = "CustomerFields" // customer fields
)

type CollectionParams struct {
CollectionFieldsType CollectionFieldsType // collection fields type
AutoID bool // autoId
EnableDynamicField bool // enable dynamic field
ShardsNum int32
Fields []*entity.Field
Dim int64
MaxLength int64
}

func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient, cp CollectionParams, opts ...client.CreateCollectionOption) string {
collName := common.GenRandomString(4)
var fields []*entity.Field
Expand All @@ -282,8 +269,6 @@ func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient,
fields = append(fields, jsonField)
case AllFields:
fields = common.GenAllFields()
case CustomerFields:
fields = cp.Fields
}

// schema
Expand All @@ -300,19 +285,7 @@ func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient,
return collName
}

type DataParams struct {
CollectionName string // insert data into which collection
PartitionName string
CollectionFieldsType CollectionFieldsType // collection fields type
start int // start
nb int // insert how many data
dim int64
EnableDynamicField bool // whether insert dynamic field data
WithRows bool
Data []entity.Column
Rows []interface{}
}

// insert nb data
func insertData(ctx context.Context, t *testing.T, mc *base.MilvusClient, dp DataParams) (entity.Column, error) {
// todo autoid
// prepare data
Expand Down Expand Up @@ -360,12 +333,6 @@ func insertData(ctx context.Context, t *testing.T, mc *base.MilvusClient, dp Dat
rows = common.GenAllFieldsRows(dp.start, dp.nb, dp.dim, dp.EnableDynamicField)
}
data = common.GenAllFieldsData(dp.start, dp.nb, dp.dim)
case CustomerFields:
if dp.WithRows {
rows = dp.Rows
} else {
data = dp.Data
}
}

if dp.EnableDynamicField && !dp.WithRows {
Expand Down Expand Up @@ -446,12 +413,6 @@ func createCollectionAllFields(ctx context.Context, t *testing.T, mc *base.Milvu
return collName, ids
}

type HelpPartitionColumns struct {
PartitionName string
IdsColumn entity.Column
VectorColumn entity.Column
}

func createInsertTwoPartitions(ctx context.Context, t *testing.T, mc *base.MilvusClient, collName string, nb int) (partitionName string, defaultPartition HelpPartitionColumns, newPartition HelpPartitionColumns) {
// create new partition
partitionName = "new"
Expand Down Expand Up @@ -486,6 +447,79 @@ func createInsertTwoPartitions(ctx context.Context, t *testing.T, mc *base.Milvu
return partitionName, defaultPartition, newPartition
}

// prepare collection, maybe data index and load
func prepareCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient, collParam CollectionParams, opts ...PrepareCollectionOption) string {
// default insert nb entities with 0 start
defaultDp := DataParams{DoInsert: true, CollectionName: "", PartitionName: "", CollectionFieldsType: collParam.CollectionFieldsType,
start: 0, nb: common.DefaultNb, dim: collParam.Dim, EnableDynamicField: collParam.EnableDynamicField, WithRows: false}

// default do flush
defaultFp := FlushParams{DoFlush: true, PartitionNames: []string{}, async: false}

// default build index
idx, err := entity.NewIndexHNSW(entity.L2, 8, 96)
common.CheckErr(t, err, true)
defaultIndexParams := IndexParams{BuildIndex: true, Index: idx, FieldName: common.DefaultFloatVecFieldName, async: false}

// default load collection
defaultLp := LoadParams{DoLoad: true, async: false}
opt := &ClientParamsOption{
DataParams: defaultDp,
FlushParams: defaultFp,
IndexParams: defaultIndexParams,
LoadParams: defaultLp,
}
for _, o := range opts {
o(opt)
}
// create collection
collName := createCollection(ctx, t, mc, collParam, opt.CreateOpts)

// insert
if opt.DataParams.DoInsert {
if opt.DataParams.EnableDynamicField != collParam.EnableDynamicField {
t.Fatalf("The EnableDynamicField of CollectionParams and DataParams should be equal.")
}
opt.DataParams.CollectionName = collName
opt.DataParams.CollectionFieldsType = collParam.CollectionFieldsType
insertData(ctx, t, mc, opt.DataParams)
}

// flush
if opt.FlushParams.DoFlush {
err := mc.Flush(ctx, collName, opt.FlushParams.async)
common.CheckErr(t, err, true)
}

// index
if opt.IndexParams.BuildIndex {
var err error
if opt.IndexOpts == nil {
err = mc.CreateIndex(ctx, collName, opt.IndexParams.FieldName, opt.IndexParams.Index, opt.IndexParams.async)
} else {
err = mc.CreateIndex(ctx, collName, opt.IndexParams.FieldName, opt.IndexParams.Index, opt.IndexParams.async, opt.IndexOpts)
}
common.CheckErr(t, err, true)
}

// load
if opt.LoadParams.DoLoad {
var err error
if len(opt.LoadParams.PartitionNames) > 0 {
err = mc.LoadPartitions(ctx, collName, opt.LoadParams.PartitionNames, opt.LoadParams.async)
common.CheckErr(t, err, true)
} else {
if opt.LoadOpts != nil {
err = mc.LoadCollection(ctx, collName, opt.LoadParams.async, opt.LoadOpts)
} else {
err = mc.LoadCollection(ctx, collName, opt.LoadParams.async)
}
common.CheckErr(t, err, true)
}
}
return collName
}

func TestMain(m *testing.M) {
flag.Parse()
log.Printf("parse addr=%s", *addr)
Expand Down
116 changes: 116 additions & 0 deletions test/testcases/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package testcases

import (
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

type HelpPartitionColumns struct {
PartitionName string
IdsColumn entity.Column
VectorColumn entity.Column
}

type CollectionFieldsType string

type CollectionParams struct {
CollectionFieldsType CollectionFieldsType // collection fields type
AutoID bool // autoId
EnableDynamicField bool // enable dynamic field
ShardsNum int32
Dim int64
MaxLength int64
}

type DataParams struct {
CollectionName string // insert data into which collection
PartitionName string
CollectionFieldsType CollectionFieldsType // collection fields type
start int // start
nb int // insert how many data
dim int64
EnableDynamicField bool // whether insert dynamic field data
WithRows bool
DoInsert bool
}

func (d DataParams) IsEmpty() bool {
return d.CollectionName == "" || d.nb == 0
}

type FlushParams struct {
DoFlush bool
PartitionNames []string
async bool
}

type IndexParams struct {
BuildIndex bool
Index entity.Index
FieldName string
async bool
}

func (i IndexParams) IsEmpty() bool {
return i.Index == nil || i.FieldName == ""
}

type LoadParams struct {
DoLoad bool
PartitionNames []string
async bool
}

type ClientParamsOption struct {
DataParams DataParams
FlushParams FlushParams
IndexParams IndexParams
LoadParams LoadParams
CreateOpts client.CreateCollectionOption
IndexOpts client.IndexOption
LoadOpts client.LoadCollectionOption
}

type PrepareCollectionOption func(opt *ClientParamsOption)

func WithDataParams(dp DataParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.DataParams = dp
}
}

func WithFlushParams(fp FlushParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.FlushParams = fp
}
}

func WithIndexParams(ip IndexParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.IndexParams = ip
}
}

func WithLoadParams(lp LoadParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.LoadParams = lp
}
}

func WithCreateOption(createOpts client.CreateCollectionOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.CreateOpts = createOpts
}
}

func WithIndexOption(indexOpts client.IndexOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.IndexOpts = indexOpts
}
}

func WithLoadOption(loadOpts client.LoadCollectionOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.LoadOpts = loadOpts
}
}
10 changes: 5 additions & 5 deletions test/testcases/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ func TestQueryOutputInvalidOutputFieldCount(t *testing.T) {
}

// test query count* after insert -> delete -> upsert -> compact
func TestQueryCountAfterDml(t *testing.T) {
func TestQueryCountAfterDml(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
// connect
mc := createMilvusClient(ctx, t)
Expand Down Expand Up @@ -872,7 +872,7 @@ func TestQueryCountAfterDml(t *testing.T) {
start: common.DefaultNb, nb: insertNb, dim: common.DefaultDim, EnableDynamicField: true}
insertData(ctx, t, mc, dpInsert)
countAfterInsert, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + insertNb), countAfterInsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+insertNb), countAfterInsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// delete 1000 entities -> count*
mc.Delete(ctx, collName, common.DefaultPartition, fmt.Sprintf("%s < 1000 ", common.DefaultIntFieldName))
Expand All @@ -885,20 +885,20 @@ func TestQueryCountAfterDml(t *testing.T) {
jsonColumn := common.GenDefaultJSONData(common.DefaultJSONFieldName, 0, upsertNb)
mc.Upsert(ctx, collName, "", intColumn, floatColumn, vecColumn, jsonColumn)
countAfterUpsert, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterUpsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterUpsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// upsert existed 100 entities -> count*
intColumn, floatColumn, vecColumn = common.GenDefaultColumnData(common.DefaultNb, upsertNb, common.DefaultDim)
jsonColumn = common.GenDefaultJSONData(common.DefaultJSONFieldName, common.DefaultNb, upsertNb)
mc.Upsert(ctx, collName, "", intColumn, floatColumn, vecColumn, jsonColumn)
countAfterUpsert2, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterUpsert2.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterUpsert2.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// compact -> count(*)
_, err := mc.Compact(ctx, collName, time.Second*60)
common.CheckErr(t, err, true)
countAfterCompact, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterCompact.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterCompact.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
}

// TODO offset and limit
Expand Down
Loading

0 comments on commit 043c4d7

Please sign in to comment.