feat: add test cases for query iterator (#760)
Signed-off-by: ThreadDao <[email protected]>
ThreadDao authored Jun 11, 2024
1 parent 3a0f12b commit b091d81
Showing 10 changed files with 735 additions and 76 deletions.
2 changes: 2 additions & 0 deletions entity/schema.go
@@ -419,6 +419,8 @@ func (t FieldType) String() string {
return "[]byte"
case FieldTypeBFloat16Vector:
return "[]byte"
case FieldTypeSparseVector:
return "[]SparseEmbedding"
default:
return "undefined"
}
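For reference, a value of this field type is built from parallel position/value slices; a minimal sketch (assuming the SDK's entity.NewSliceSparseEmbedding constructor and a log import):

// Build a sparse embedding from parallel position/value slices.
positions := []uint32{1, 10, 100}
values := []float32{0.1, 0.2, 0.3}
sparseVec, err := entity.NewSliceSparseEmbedding(positions, values)
if err != nil {
log.Fatalf("failed to build sparse embedding: %v", err)
}
_ = sparseVec // FieldTypeSparseVector.String() reports this type as "[]SparseEmbedding"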
9 changes: 9 additions & 0 deletions test/base/milvus_client.go
@@ -487,6 +487,15 @@ func (mc *MilvusClient) Get(ctx context.Context, collName string, ids entity.Col
return queryResults, err
}

// QueryIterator returns a query iterator over the given collection
func (mc *MilvusClient) QueryIterator(ctx context.Context, opt *client.QueryIteratorOption) (*client.QueryIterator, error) {
funcName := "QueryIterator"
preRequest(funcName, ctx, opt)
itr, err := mc.mClient.QueryIterator(ctx, opt)
postResponse(funcName, err, itr)
return itr, err
}

// -- row based apis --

// CreateCollectionByRow Create Collection By Row
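For reference, the wrapper above is consumed like the underlying client API: batches are pulled with Next until io.EOF. A minimal sketch, assuming ctx, mc, and a loaded collection collName are in scope, plus this SDK's NewQueryIteratorOption/WithBatchSize option builders:

opt := client.NewQueryIteratorOption(collName).WithBatchSize(1000)
itr, err := mc.QueryIterator(ctx, opt)
if err != nil {
log.Fatalf("failed to create query iterator: %v", err)
}
for {
rs, err := itr.Next(ctx)
if err == io.EOF {
break // iterator drained
}
if err != nil {
log.Fatalf("iterator next failed: %v", err)
}
log.Printf("got a batch of %d rows", rs.Len())
}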
107 changes: 106 additions & 1 deletion test/common/response_check.go
@@ -1,7 +1,9 @@
package common

import (
"context"
"fmt"
"io"
"log"
"strings"
"testing"
@@ -160,14 +162,48 @@ func EqualColumn(t *testing.T, columnA entity.Column, columnB entity.Column) {
require.ElementsMatch(t, columnA.(*entity.ColumnFloatVector).Data(), columnB.(*entity.ColumnFloatVector).Data())
case entity.FieldTypeBinaryVector:
require.ElementsMatch(t, columnA.(*entity.ColumnBinaryVector).Data(), columnB.(*entity.ColumnBinaryVector).Data())
case entity.FieldTypeFloat16Vector:
require.ElementsMatch(t, columnA.(*entity.ColumnFloat16Vector).Data(), columnB.(*entity.ColumnFloat16Vector).Data())
case entity.FieldTypeBFloat16Vector:
require.ElementsMatch(t, columnA.(*entity.ColumnBFloat16Vector).Data(), columnB.(*entity.ColumnBFloat16Vector).Data())
case entity.FieldTypeSparseVector:
require.ElementsMatch(t, columnA.(*entity.ColumnSparseFloatVector).Data(), columnB.(*entity.ColumnSparseFloatVector).Data())
case entity.FieldTypeArray:
log.Println("TODO support column element type")
EqualArrayColumn(t, columnA, columnB)
default:
log.Printf("The column type not in: [%v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v]",
entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32,
entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeString,
entity.FieldTypeVarChar, entity.FieldTypeArray, entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector)
}
}

// EqualArrayColumn asserts that the name, type, and array data of two columns are equal
func EqualArrayColumn(t *testing.T, columnA entity.Column, columnB entity.Column) {
require.Equal(t, columnA.Name(), columnB.Name())
require.Equal(t, entity.FieldTypeArray, columnA.Type())
require.Equal(t, entity.FieldTypeArray, columnB.Type())
switch columnA.(type) {
case *entity.ColumnBoolArray:
require.ElementsMatch(t, columnA.(*entity.ColumnBoolArray).Data(), columnB.(*entity.ColumnBoolArray).Data())
case *entity.ColumnInt8Array:
require.ElementsMatch(t, columnA.(*entity.ColumnInt8Array).Data(), columnB.(*entity.ColumnInt8Array).Data())
case *entity.ColumnInt16Array:
require.ElementsMatch(t, columnA.(*entity.ColumnInt16Array).Data(), columnB.(*entity.ColumnInt16Array).Data())
case *entity.ColumnInt32Array:
require.ElementsMatch(t, columnA.(*entity.ColumnInt32Array).Data(), columnB.(*entity.ColumnInt32Array).Data())
case *entity.ColumnInt64Array:
require.ElementsMatch(t, columnA.(*entity.ColumnInt64Array).Data(), columnB.(*entity.ColumnInt64Array).Data())
case *entity.ColumnFloatArray:
require.ElementsMatch(t, columnA.(*entity.ColumnFloatArray).Data(), columnB.(*entity.ColumnFloatArray).Data())
case *entity.ColumnDoubleArray:
require.ElementsMatch(t, columnA.(*entity.ColumnDoubleArray).Data(), columnB.(*entity.ColumnDoubleArray).Data())
case *entity.ColumnVarCharArray:
require.ElementsMatch(t, columnA.(*entity.ColumnVarCharArray).Data(), columnB.(*entity.ColumnVarCharArray).Data())
default:
log.Printf("Now support array type: [%v, %v, %v, %v, %v, %v, %v, %v]",
entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32,
entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar)
}
}
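
A quick self-check of the helper above might look like this (an illustrative sketch, assuming the SDK's generated entity.NewColumnInt64Array constructor):

colA := entity.NewColumnInt64Array("int64Array", [][]int64{{1, 2}, {3, 4}})
colB := entity.NewColumnInt64Array("int64Array", [][]int64{{1, 2}, {3, 4}})
EqualArrayColumn(t, colA, colB) // passes: same name, element type, and data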

@@ -203,6 +239,75 @@ func CheckSearchResult(t *testing.T, actualSearchResults []client.SearchResult,

}

// EqualIntSlice reports whether a and b contain the same ints in the same order.
func EqualIntSlice(a []int, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
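
(On Go 1.21 and later, the standard library's slices.Equal performs the same element-wise comparison.)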

type CheckIteratorOption func(opt *checkIteratorOpt)

type checkIteratorOpt struct {
expBatchSize []int
expOutputFields []string
}

func WithExpBatchSize(expBatchSize []int) CheckIteratorOption {
return func(opt *checkIteratorOpt) {
opt.expBatchSize = expBatchSize
}
}

func WithExpOutputFields(expOutputFields []string) CheckIteratorOption {
return func(opt *checkIteratorOpt) {
opt.expOutputFields = expOutputFields
}
}

// CheckQueryIteratorResult checks a QueryIterator's total row count and, optionally, its per-batch sizes and output fields
func CheckQueryIteratorResult(ctx context.Context, t *testing.T, itr *client.QueryIterator, expLimit int, opts ...CheckIteratorOption) {
opt := &checkIteratorOpt{}
for _, o := range opts {
o(opt)
}
actualLimit := 0
var actualBatchSize []int
for {
rs, err := itr.Next(ctx)
if err != nil {
if err == io.EOF {
break
}
log.Fatalf("QueryIterator next gets error: %v", err)
}
//log.Printf("QueryIterator result len: %d", rs.Len())
//log.Printf("QueryIterator result data: %d", rs.GetColumn("int64"))

if opt.expBatchSize != nil {
actualBatchSize = append(actualBatchSize, rs.Len())
}
var actualOutputFields []string
if opt.expOutputFields != nil {
for _, column := range rs {
actualOutputFields = append(actualOutputFields, column.Name())
}
require.ElementsMatch(t, opt.expOutputFields, actualOutputFields)
}
actualLimit += rs.Len()
}
require.Equal(t, expLimit, actualLimit)
if opt.expBatchSize != nil {
log.Printf("QueryIterator result len: %v", actualBatchSize)
require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize))
}
}
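
Putting the helper and its options together, a test can assert both the total row count and the exact batch-size sequence; an illustrative sketch, assuming a loaded collection holding common.DefaultNb rows:

itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(collName).WithBatchSize(500))
common.CheckErr(t, err, true)
common.CheckQueryIteratorResult(ctx, t, itr, common.DefaultNb,
common.WithExpBatchSize(common.GenBatchSizes(common.DefaultNb, 500)))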

// CheckPersistentSegments check persistent segments
func CheckPersistentSegments(t *testing.T, actualSegments []*entity.Segment, expNb int64) {
actualNb := int64(0)
24 changes: 23 additions & 1 deletion test/common/utils.go
@@ -78,6 +78,7 @@ const (
DefaultPartitionNum = 16 // default num_partitions
MaxTopK = 16384
MaxVectorFieldNum = 4
DefaultBatchSize = 1000
)

var IndexStateValue = map[string]int32{
@@ -842,7 +843,7 @@ func GenDefaultJSONRows(start int, nb int, dim int64, enableDynamicField bool) [
}

for i := start; i < start+nb; i++ {
// jsonStruct row and dynamic row
//jsonStruct row and dynamic row
var jsonStruct JSONStruct
if i%2 == 0 {
jsonStruct = JSONStruct{
@@ -1386,6 +1387,8 @@ type InvalidExprStruct struct {
var InvalidExpressions = []InvalidExprStruct{
{Expr: "id in [0]", ErrNil: true, ErrMsg: "fieldName(id) not found"}, // not exist field but no error
{Expr: "int64 in not [0]", ErrNil: false, ErrMsg: "cannot parse expression"}, // wrong term expr keyword
{Expr: "int64 > 10 AND int64 < 100", ErrNil: false, ErrMsg: "cannot parse expression"}, // AND isn't supported
{Expr: "int64 < 10 OR int64 > 100", ErrNil: false, ErrMsg: "cannot parse expression"}, // OR isn't supported
{Expr: "int64 < floatVec", ErrNil: false, ErrMsg: "not supported"}, // unsupported compare field
{Expr: "floatVec in [0]", ErrNil: false, ErrMsg: "cannot be casted to FloatVector"}, // value and field type mismatch
{Expr: fmt.Sprintf("%s == 1", DefaultJSONFieldName), ErrNil: true, ErrMsg: ""}, // hist empty
@@ -1406,4 +1409,23 @@ var InvalidExpressions = []InvalidExprStruct{
{Expr: fmt.Sprintf("%s[-1] > 1", DefaultJSONFieldName), ErrNil: false, ErrMsg: "invalid expression"}, // json[-1] >
}

// GenBatchSizes returns the batch sizes expected when fetching limit rows with the given batch size.
func GenBatchSizes(limit int, batch int) []int {
if batch == 0 {
log.Fatal("batch size must be larger than 0")
}
if limit == 0 {
return []int{}
}
fullBatches := limit / batch
remainder := limit % batch
batchSizes := make([]int, 0, fullBatches+1)
for i := 0; i < fullBatches; i++ {
batchSizes = append(batchSizes, batch)
}
if remainder > 0 {
batchSizes = append(batchSizes, remainder)
}
return batchSizes
}
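
For example, GenBatchSizes(2500, 1000) yields [1000, 1000, 500] and GenBatchSizes(3000, 1000) yields [1000, 1000, 1000]; CheckQueryIteratorResult compares such a slice against the batch sizes the iterator actually returns.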

// --- search utils ---
2 changes: 1 addition & 1 deletion test/testcases/configure_test.go
@@ -74,7 +74,7 @@ func TestCompactAfterDelete(t *testing.T) {
common.CheckErr(t, err, true)

// delete half ids
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:common.DefaultNb/2])
deleteIds := ids.Slice(0, common.DefaultNb/2)
errDelete := mc.DeleteByPks(ctx, collName, "", deleteIds)
common.CheckErr(t, errDelete, true)

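The recurring change in this and the following test files replaces type-specific column reconstruction with the entity.Column interface's Slice method, which works uniformly across primary-key column types:

// before: re-wrap the first N raw values, tied to the concrete column type
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:common.DefaultNb/2])
// after: slice the column directly, no type assertion needed
deleteIds := ids.Slice(0, common.DefaultNb/2)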
12 changes: 6 additions & 6 deletions test/testcases/delete_test.go
@@ -28,7 +28,7 @@ func TestDelete(t *testing.T) {
common.CheckErr(t, errLoad, true)

// delete
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10])
deleteIds := ids.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, common.DefaultPartition, deleteIds)
common.CheckErr(t, errDelete, true)

@@ -48,7 +48,7 @@ func TestDeleteStringPks(t *testing.T) {
collName, ids := createVarcharCollectionWithDataIndex(ctx, t, mc, true, client.WithConsistencyLevel(entity.ClStrong))

// delete
deleteIds := entity.NewColumnVarChar(common.DefaultVarcharFieldName, ids.(*entity.ColumnVarChar).Data()[:10])
deleteIds := ids.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, common.DefaultPartition, deleteIds)
common.CheckErr(t, errDelete, true)

@@ -103,7 +103,7 @@ func TestDeleteNotExistPartition(t *testing.T) {
common.CheckErr(t, errLoad, true)

// delete
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10])
deleteIds := ids.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, "p1", deleteIds)
common.CheckErr(t, errDelete, false, fmt.Sprintf("partition p1 of collection %s does not exist", collName))
}
@@ -125,7 +125,7 @@ func TestDeleteEmptyPartitionNames(t *testing.T) {
mc.Flush(ctx, collName, false)

// delete
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, intColumn.(*entity.ColumnInt64).Data()[:10])
deleteIds := intColumn.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, emptyPartitionName, deleteIds)
common.CheckErr(t, errDelete, true)

@@ -160,7 +160,7 @@ func TestDeleteEmptyPartition(t *testing.T) {
common.CheckErr(t, errLoad, true)

// delete from empty partition p1
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, ids.(*entity.ColumnInt64).Data()[:10])
deleteIds := ids.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, "p1", deleteIds)
common.CheckErr(t, errDelete, true)

@@ -186,7 +186,7 @@ func TestDeletePartitionIdsNotMatch(t *testing.T) {
partitionName, vecColumnDefault, _ := createInsertTwoPartitions(ctx, t, mc, collName, common.DefaultNb)

// delete [0:10) from new partition -> delete nothing
deleteIds := entity.NewColumnInt64(common.DefaultIntFieldName, vecColumnDefault.IdsColumn.(*entity.ColumnInt64).Data()[:10])
deleteIds := vecColumnDefault.IdsColumn.Slice(0, 10)
errDelete := mc.DeleteByPks(ctx, collName, partitionName, deleteIds)
common.CheckErr(t, errDelete, true)

12 changes: 6 additions & 6 deletions test/testcases/highlevel_test.go
@@ -65,13 +65,13 @@ func TestNewCollection(t *testing.T) {
queryResult, err := mc.Get(
ctx,
collName,
entity.NewColumnInt64(DefaultPkFieldName, pkColumn.(*entity.ColumnInt64).Data()[:10]),
pkColumn.Slice(0, 10),
)
common.CheckErr(t, err, true)
common.CheckOutputFields(t, queryResult, []string{DefaultPkFieldName, DefaultVectorFieldName})
common.CheckQueryResult(t, queryResult, []entity.Column{
entity.NewColumnInt64(DefaultPkFieldName, pkColumn.(*entity.ColumnInt64).Data()[:10]),
entity.NewColumnFloatVector(DefaultVectorFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnFloatVector).Data()[:10]),
pkColumn.Slice(0, 10),
vecColumn.Slice(0, 10),
})

// search
@@ -142,13 +142,13 @@ func TestNewCollectionCustomize(t *testing.T) {
queryResult, err := mc.Get(
ctx,
collName,
entity.NewColumnVarChar(pkFieldName, pkColumn.(*entity.ColumnVarChar).Data()[:10]),
pkColumn.Slice(0, 10),
)
common.CheckErr(t, err, true)
common.CheckOutputFields(t, queryResult, []string{pkFieldName, vectorFieldName})
common.CheckQueryResult(t, queryResult, []entity.Column{
entity.NewColumnVarChar(pkFieldName, pkColumn.(*entity.ColumnVarChar).Data()[:10]),
entity.NewColumnFloatVector(vectorFieldName, int(common.DefaultDim), vecColumn.(*entity.ColumnFloatVector).Data()[:10]),
pkColumn.Slice(0, 10),
vecColumn.Slice(0, 10),
})

// search
2 changes: 1 addition & 1 deletion test/testcases/main_test.go
@@ -248,7 +248,7 @@ const (
Int64FloatVecJSON CollectionFieldsType = "PkInt64FloatVecJson" // int64 + float + floatVec + json
Int64FloatVecArray CollectionFieldsType = "Int64FloatVecArray" // int64 + float + floatVec + all array
Int64VarcharSparseVec CollectionFieldsType = "Int64VarcharSparseVec" // int64 + varchar + float32Vec + sparseVec
AllVectors CollectionFieldsType = "AllVectors" // int64 + fp32Vec + fp16Vec + binaryVec
AllVectors CollectionFieldsType = "AllVectors" // int64 + fp32Vec + fp16Vec + bf16Vec + binaryVec
AllFields CollectionFieldsType = "AllFields" // all scalar fields + floatVec
)
