From 7d9d96f85e56a85fb8fdfb95665befdf231400de Mon Sep 17 00:00:00 2001
From: shaoting-huang
Date: Thu, 27 Jun 2024 20:22:46 +0800
Subject: [PATCH] support int64 and varchar primary key types in deltalog

Signed-off-by: shaoting-huang
---
 internal/storage/payload_writer.go    | 23 +++++++
 internal/storage/serde_events.go      | 88 ++++++++++++++++-----------
 internal/storage/serde_events_test.go | 71 ++++++++++++++++++++-
 3 files changed, 145 insertions(+), 37 deletions(-)

diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go
index 8b8b00100564b..9926aff68e3c9 100644
--- a/internal/storage/payload_writer.go
+++ b/internal/storage/payload_writer.go
@@ -759,3 +759,26 @@ func milvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataTy
 		panic("unsupported data type")
 	}
 }
+
+func arrowTypeToMilvusDataType(arrowType arrow.DataType) schemapb.DataType {
+	switch arrowType.ID() {
+	case arrow.BOOL:
+		return schemapb.DataType_Bool
+	case arrow.INT8:
+		return schemapb.DataType_Int8
+	case arrow.INT16:
+		return schemapb.DataType_Int16
+	case arrow.INT32:
+		return schemapb.DataType_Int32
+	case arrow.INT64:
+		return schemapb.DataType_Int64
+	case arrow.FLOAT32:
+		return schemapb.DataType_Float
+	case arrow.FLOAT64:
+		return schemapb.DataType_Double
+	case arrow.STRING:
+		return schemapb.DataType_VarChar
+	default:
+		panic("unsupported arrow type")
+	}
+}
diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go
index 186cccfc2962d..e97f530c23875 100644
--- a/internal/storage/serde_events.go
+++ b/internal/storage/serde_events.go
@@ -28,7 +28,6 @@ import (
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/apache/arrow/go/v12/arrow/memory"
-
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/util/merr"
@@ -611,7 +610,7 @@ func (crr *simpleArrowRecordReader) Next() error {
 		}
 		record := crr.rr.Record()
 		for i := range record.Schema().Fields() {
-			crr.r.schema[FieldID(i)] = schemapb.DataType_Int64
+			crr.r.schema[FieldID(i)] = arrowTypeToMilvusDataType(record.Column(i).DataType())
 			crr.r.field2Col[FieldID(i)] = i
 		}
 		crr.r.r = record
@@ -645,23 +644,12 @@ func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error)
 	}, nil
 }
 
-func NewMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID) *MultiFieldDeltalogStreamWriter {
+func NewMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID, schema []*schemapb.FieldSchema) *MultiFieldDeltalogStreamWriter {
 	return &MultiFieldDeltalogStreamWriter{
 		collectionID: collectionID,
 		partitionID:  partitionID,
 		segmentID:    segmentID,
-		fieldSchemas: []*schemapb.FieldSchema{
-			{
-				FieldID:  common.RowIDField,
-				Name:     "pk",
-				DataType: schemapb.DataType_Int64,
-			},
-			{
-				FieldID:  common.TimeStampField,
-				Name:     "ts",
-				DataType: schemapb.DataType_Int64,
-			},
-		},
+		fieldSchemas: schema,
 	}
 }
 
@@ -759,23 +747,10 @@ func NewDeltalogMultiFieldWriter(partitionID, segmentID UniqueID, eventWriter *M
 		return nil, err
 	}
 	return NewSerializeRecordWriter[*DeleteLog](rw, func(v []*DeleteLog) (Record, uint64, error) {
-		builders := [2]*array.Int64Builder{
-			array.NewInt64Builder(memory.DefaultAllocator),
-			array.NewInt64Builder(memory.DefaultAllocator),
-		}
-		var memorySize uint64
-		for _, vv := range v {
-			builders[0].Append(vv.Pk.GetValue().(int64))
-			memorySize += uint64(vv.Pk.GetValue().(int64))
-
-			builders[1].Append(int64(vv.Ts))
-			memorySize += uint64(vv.Ts)
-		}
-		arrs := []arrow.Array{builders[0].NewArray(), builders[1].NewArray()}
 		fields := []arrow.Field{
 			{
 				Name:     "pk",
-				Type:     arrow.PrimitiveTypes.Int64,
+				Type:     serdeMap[schemapb.DataType(v[0].PkType)].arrowType(0),
 				Nullable: false,
 			},
 			{
@@ -784,15 +759,47 @@ func NewDeltalogMultiFieldWriter(partitionID, segmentID UniqueID, eventWriter *M
 				Nullable: false,
 			},
 		}
 
+		arrowSchema := arrow.NewSchema(fields, nil)
+		builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema)
+		defer builder.Release()
+
+		var memorySize uint64
+		pkType := schemapb.DataType(v[0].PkType)
+		switch pkType {
+		case schemapb.DataType_Int64:
+			pb := builder.Field(0).(*array.Int64Builder)
+			for _, vv := range v {
+				pk := vv.Pk.GetValue().(int64)
+				pb.Append(pk)
+				memorySize += uint64(pk)
+			}
+		case schemapb.DataType_VarChar:
+			pb := builder.Field(0).(*array.StringBuilder)
+			for _, vv := range v {
+				pk := vv.Pk.GetValue().(string)
+				pb.Append(pk)
+				memorySize += uint64(len(pk))
+			}
+		default:
+			return nil, 0, fmt.Errorf("unexpected pk type %v", v[0].PkType)
+		}
+
+		for _, vv := range v {
+			builder.Field(1).(*array.Int64Builder).Append(int64(vv.Ts))
+			memorySize += uint64(vv.Ts)
+		}
+
+		arr := []arrow.Array{builder.Field(0).NewArray(), builder.Field(1).NewArray()}
+
 		field2Col := map[FieldID]int{
 			common.RowIDField:     0,
 			common.TimeStampField: 1,
 		}
 		schema := map[FieldID]schemapb.DataType{
-			common.RowIDField:     schemapb.DataType_Int64,
+			common.RowIDField:     pkType,
 			common.TimeStampField: schemapb.DataType_Int64,
 		}
-		return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrs, int64(len(v))), schema, field2Col), memorySize, nil
+		return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), schema, field2Col), memorySize, nil
 	}, batchSize), nil
 }
 
@@ -807,15 +814,26 @@ func NewDeltalogMultiFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog],
 				v[i] = &DeleteLog{}
 			}
 		}
 
-		for fid := range r.Schema() {
+		for fid, dt := range r.Schema() {
 			a := r.Column(fid)
 			switch fid {
 			case common.RowIDField:
-				int64Array := a.(*array.Int64)
-				for i := 0; i < int64Array.Len(); i++ {
-					v[i].Pk = NewInt64PrimaryKey(int64Array.Value(i))
+				switch dt {
+				case schemapb.DataType_Int64:
+					arr := a.(*array.Int64)
+					for i := 0; i < a.Len(); i++ {
+						v[i].Pk = &Int64PrimaryKey{Value: arr.Value(i)}
+						v[i].PkType = int64(dt)
+					}
+				case schemapb.DataType_VarChar:
+					arr := a.(*array.String)
+					for i := 0; i < a.Len(); i++ {
+						v[i].Pk = &VarCharPrimaryKey{Value: arr.Value(i)}
+						v[i].PkType = int64(dt)
+					}
 				}
+
 			case common.TimeStampField:
 				int64Array := a.(*array.Int64)
 				for i := 0; i < int64Array.Len(); i++ {
diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go
index cc9713e383d9a..6131e3fc40be0 100644
--- a/internal/storage/serde_events_test.go
+++ b/internal/storage/serde_events_test.go
@@ -20,6 +20,7 @@ import (
 	"bytes"
 	"context"
 	"io"
+	"strconv"
 	"testing"
 
 	"github.com/apache/arrow/go/v12/arrow"
@@ -353,8 +354,19 @@ func TestDeltalogV2(t *testing.T) {
 		assert.Equal(t, io.EOF, err)
 	})
 
-	t.Run("test serialize deserialize", func(t *testing.T) {
-		eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0)
+	t.Run("test int64 pk", func(t *testing.T) {
+		eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0, []*schemapb.FieldSchema{
+			{
+				FieldID:  common.RowIDField,
+				Name:     "pk",
+				DataType: schemapb.DataType_Int64,
+			},
+			{
+				FieldID:  common.TimeStampField,
+				Name:     "ts",
+				DataType: schemapb.DataType_Int64,
+			},
+		})
 		writer, err := NewDeltalogMultiFieldWriter(0, 0, eventWriter, 7)
 		assert.NoError(t, err)
 
@@ -395,4 +407,59 @@ func TestDeltalogV2(t *testing.T) {
 			assertTestDeltalogData(t, i, value)
 		}
 	})
+
+	t.Run("test varchar pk", func(t *testing.T) {
+		eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0, []*schemapb.FieldSchema{
+			{
+				FieldID:  common.RowIDField,
+				Name:     "pk",
+				DataType: schemapb.DataType_VarChar,
+			},
+			{
+				FieldID:  common.TimeStampField,
+				Name:     "ts",
+				DataType: schemapb.DataType_Int64,
+			},
+		})
+		writer, err := NewDeltalogMultiFieldWriter(0, 0, eventWriter, 7)
+		assert.NoError(t, err)
+
+		size := 10
+		pks := make([]string, size)
+		tss := make([]uint64, size)
+		for i := 0; i < size; i++ {
+			pks[i] = strconv.Itoa(i)
+			tss[i] = uint64(i + 1)
+		}
+		data := make([]*DeleteLog, size)
+		for i := range pks {
+			data[i] = NewDeleteLog(NewVarCharPrimaryKey(pks[i]), tss[i])
+		}
+
+		// Serialize the data
+		for i := 0; i < size; i++ {
+			err := writer.Write(data[i])
+			assert.NoError(t, err)
+		}
+		err = writer.Close()
+		assert.NoError(t, err)
+
+		blob, err := eventWriter.Finalize()
+		assert.NoError(t, err)
+		assert.NotNil(t, blob)
+		blobs := []*Blob{blob}
+
+		// Deserialize the data
+		reader, err := NewDeltalogDeserializeReader(blobs)
+		assert.NoError(t, err)
+		defer reader.Close()
+		for i := 0; i < size; i++ {
+			err = reader.Next()
+			assert.NoError(t, err)
+
+			value := reader.Value()
+			assert.Equal(t, &VarCharPrimaryKey{strconv.Itoa(i)}, value.Pk)
+			assert.Equal(t, uint64(i+1), value.Ts)
+		}
+	})
 }
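
Reviewer note (not part of the patch): NewMultiFieldDeltalogStreamWriter no
longer hard-codes the int64 pk/ts field schemas, so every existing call site
now has to supply them explicitly. A minimal sketch of an adapted call site
in package storage, assuming the same "pk"/"ts" field names the writer used
to build internally; the zero IDs and batch size 7 are placeholders borrowed
from the tests:

// Illustrative only: writing varchar delete logs through the new API.
func exampleVarCharDeltalogWriter() error {
	eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0, []*schemapb.FieldSchema{
		{FieldID: common.RowIDField, Name: "pk", DataType: schemapb.DataType_VarChar},
		{FieldID: common.TimeStampField, Name: "ts", DataType: schemapb.DataType_Int64},
	})
	writer, err := NewDeltalogMultiFieldWriter(0, 0, eventWriter, 7)
	if err != nil {
		return err
	}
	// Each DeleteLog carries its PkType; the serializer switches on it to
	// pick the pk column builder (Int64Builder vs. StringBuilder).
	if err := writer.Write(NewDeleteLog(NewVarCharPrimaryKey("42"), 1)); err != nil {
		return err
	}
	return writer.Close()
}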
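A second note, on the new type mapping: for the scalar types a deltalog can
carry, arrowTypeToMilvusDataType is intended as the inverse of the existing
milvusDataTypeToArrowType, with one caveat: arrow.STRING always maps back to
DataType_VarChar, so any Milvus string type round-trips to VarChar. A sanity
sketch, under the assumptions that the forward mapping accepts dim=0 for
scalar types and that fmt is imported:

// Sanity sketch (not part of the patch): scalar Milvus types should survive
// the Milvus -> Arrow -> Milvus round trip unchanged. dim only matters for
// vector types, so 0 is passed; vector types are excluded because the
// reverse mapping handles scalar Arrow IDs only.
func checkScalarRoundTrip() {
	scalars := []schemapb.DataType{
		schemapb.DataType_Bool,
		schemapb.DataType_Int8, schemapb.DataType_Int16,
		schemapb.DataType_Int32, schemapb.DataType_Int64,
		schemapb.DataType_Float, schemapb.DataType_Double,
		schemapb.DataType_VarChar,
	}
	for _, dt := range scalars {
		if got := arrowTypeToMilvusDataType(milvusDataTypeToArrowType(dt, 0)); got != dt {
			panic(fmt.Sprintf("round-trip mismatch: %v -> %v", dt, got))
		}
	}
}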