Commit

support primary key types
Signed-off-by: shaoting-huang <[email protected]>
shaoting-huang committed Jun 27, 2024
1 parent 6bb784b commit 7d9d96f
Showing 3 changed files with 145 additions and 37 deletions.
23 changes: 23 additions & 0 deletions internal/storage/payload_writer.go
@@ -759,3 +759,26 @@ func milvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataType
        panic("unsupported data type")
    }
}
+
+func arrowTypeToMilvusDataType(arrowType arrow.DataType) schemapb.DataType {
+    switch arrowType.ID() {
+    case arrow.BOOL:
+        return schemapb.DataType_Bool
+    case arrow.INT8:
+        return schemapb.DataType_Int8
+    case arrow.INT16:
+        return schemapb.DataType_Int16
+    case arrow.INT32:
+        return schemapb.DataType_Int32
+    case arrow.INT64:
+        return schemapb.DataType_Int64
+    case arrow.FLOAT32:
+        return schemapb.DataType_Float
+    case arrow.FLOAT64:
+        return schemapb.DataType_Double
+    case arrow.STRING:
+        return schemapb.DataType_VarChar
+    default:
+        panic("unsupported arrow type")
+    }
+}
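
Aside (not part of the commit): the new helper is the inverse of milvusDataTypeToArrowType above. A minimal standalone sketch of the round-trip property the pair should satisfy, with toArrow/fromArrow as simplified stand-ins covering just the two pk types this change cares about:

// Standalone sketch: mapping a Milvus data type to Arrow and back should be
// the identity on the supported scalar types.
package main

import (
    "fmt"

    "github.com/apache/arrow/go/v12/arrow"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// toArrow is a simplified stand-in for milvusDataTypeToArrowType.
func toArrow(dt schemapb.DataType) arrow.DataType {
    switch dt {
    case schemapb.DataType_Int64:
        return arrow.PrimitiveTypes.Int64
    case schemapb.DataType_VarChar:
        return arrow.BinaryTypes.String
    default:
        panic("unsupported data type")
    }
}

// fromArrow is a simplified stand-in for arrowTypeToMilvusDataType.
func fromArrow(at arrow.DataType) schemapb.DataType {
    switch at.ID() {
    case arrow.INT64:
        return schemapb.DataType_Int64
    case arrow.STRING:
        return schemapb.DataType_VarChar
    default:
        panic("unsupported arrow type")
    }
}

func main() {
    for _, dt := range []schemapb.DataType{schemapb.DataType_Int64, schemapb.DataType_VarChar} {
        fmt.Println(dt, fromArrow(toArrow(dt)) == dt) // Int64 true, VarChar true
    }
}
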
88 changes: 53 additions & 35 deletions internal/storage/serde_events.go
@@ -28,7 +28,6 @@ import (
    "github.com/apache/arrow/go/v12/arrow"
    "github.com/apache/arrow/go/v12/arrow/array"
    "github.com/apache/arrow/go/v12/arrow/memory"
-
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/util/merr"
@@ -611,7 +610,7 @@ func (crr *simpleArrowRecordReader) Next() error {
        }
        record := crr.rr.Record()
        for i := range record.Schema().Fields() {
-            crr.r.schema[FieldID(i)] = schemapb.DataType_Int64
+            crr.r.schema[FieldID(i)] = arrowTypeToMilvusDataType(record.Column(i).DataType())
            crr.r.field2Col[FieldID(i)] = i
        }
        crr.r.r = record
@@ -645,23 +644,12 @@ func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error)
    }, nil
}

-func NewMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID) *MultiFieldDeltalogStreamWriter {
+func NewMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID, schema []*schemapb.FieldSchema) *MultiFieldDeltalogStreamWriter {
    return &MultiFieldDeltalogStreamWriter{
        collectionID: collectionID,
        partitionID:  partitionID,
        segmentID:    segmentID,
-        fieldSchemas: []*schemapb.FieldSchema{
-            {
-                FieldID:  common.RowIDField,
-                Name:     "pk",
-                DataType: schemapb.DataType_Int64,
-            },
-            {
-                FieldID:  common.TimeStampField,
-                Name:     "ts",
-                DataType: schemapb.DataType_Int64,
-            },
-        },
+        fieldSchemas: schema,
    }
}

@@ -759,23 +747,10 @@ func NewDeltalogMultiFieldWriter(partitionID, segmentID UniqueID, eventWriter *MultiFieldDeltalogStreamWriter
        return nil, err
    }
    return NewSerializeRecordWriter[*DeleteLog](rw, func(v []*DeleteLog) (Record, uint64, error) {
-        builders := [2]*array.Int64Builder{
-            array.NewInt64Builder(memory.DefaultAllocator),
-            array.NewInt64Builder(memory.DefaultAllocator),
-        }
-        var memorySize uint64
-        for _, vv := range v {
-            builders[0].Append(vv.Pk.GetValue().(int64))
-            memorySize += uint64(vv.Pk.GetValue().(int64))
-
-            builders[1].Append(int64(vv.Ts))
-            memorySize += uint64(vv.Ts)
-        }
-        arrs := []arrow.Array{builders[0].NewArray(), builders[1].NewArray()}
        fields := []arrow.Field{
            {
                Name:     "pk",
-                Type:     arrow.PrimitiveTypes.Int64,
+                Type:     serdeMap[schemapb.DataType(v[0].PkType)].arrowType(0),
                Nullable: false,
            },
            {
@@ -784,15 +759,47 @@ func NewDeltalogMultiFieldWriter(partitionID, segmentID UniqueID, eventWriter *MultiFieldDeltalogStreamWriter
                Nullable: false,
            },
        }
+        arrowSchema := arrow.NewSchema(fields, nil)
+        builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema)
+        defer builder.Release()
+
+        var memorySize uint64
+        pkType := schemapb.DataType(v[0].PkType)
+        switch pkType {
+        case schemapb.DataType_Int64:
+            pb := builder.Field(0).(*array.Int64Builder)
+            for _, vv := range v {
+                pk := vv.Pk.GetValue().(int64)
+                pb.Append(pk)
+                memorySize += uint64(pk)
+            }
+        case schemapb.DataType_VarChar:
+            pb := builder.Field(0).(*array.StringBuilder)
+            for _, vv := range v {
+                pk := vv.Pk.GetValue().(string)
+                pb.Append(pk)
+                memorySize += uint64(binary.Size(pk))
+            }
+        default:
+            return nil, 0, fmt.Errorf("unexpected pk type %v", v[0].PkType)
+        }
+
+        for _, vv := range v {
+            builder.Field(1).(*array.Int64Builder).Append(int64(vv.Ts))
+            memorySize += uint64(vv.Ts)
+        }
+
+        arr := []arrow.Array{builder.Field(0).NewArray(), builder.Field(1).NewArray()}
+
        field2Col := map[FieldID]int{
            common.RowIDField:     0,
            common.TimeStampField: 1,
        }
        schema := map[FieldID]schemapb.DataType{
-            common.RowIDField:     schemapb.DataType_Int64,
+            common.RowIDField:     pkType,
            common.TimeStampField: schemapb.DataType_Int64,
        }
-        return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrs, int64(len(v))), schema, field2Col), memorySize, nil
+        return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), schema, field2Col), memorySize, nil
    }, batchSize), nil
}

@@ -807,15 +814,26 @@ func NewDeltalogMultiFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog],
            v[i] = &DeleteLog{}
        }
    }
-    for fid := range r.Schema() {
+    for fid, dt := range r.Schema() {
        a := r.Column(fid)

        switch fid {
        case common.RowIDField:
-            int64Array := a.(*array.Int64)
-            for i := 0; i < int64Array.Len(); i++ {
-                v[i].Pk = NewInt64PrimaryKey(int64Array.Value(i))
+            switch dt {
+            case schemapb.DataType_Int64:
+                arr := a.(*array.Int64)
+                for i := 0; i < a.Len(); i++ {
+                    v[i].Pk = &Int64PrimaryKey{Value: arr.Value(i)}
+                    v[i].PkType = int64(dt)
+                }
+            case schemapb.DataType_VarChar:
+                arr := a.(*array.String)
+                for i := 0; i < a.Len(); i++ {
+                    v[i].Pk = &VarCharPrimaryKey{Value: arr.Value(i)}
+                    v[i].PkType = int64(dt)
+                }
            }
+
        case common.TimeStampField:
            int64Array := a.(*array.Int64)
            for i := 0; i < int64Array.Len(); i++ {
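
Aside: the writer change above replaces two fixed Int64 builders with an array.NewRecordBuilder driven by the pk schema. A minimal standalone sketch of that pattern (assuming only the arrow v12 API; no Milvus types appear here) building a two-column record with a string pk and an int64 ts:

package main

import (
    "fmt"

    "github.com/apache/arrow/go/v12/arrow"
    "github.com/apache/arrow/go/v12/arrow/array"
    "github.com/apache/arrow/go/v12/arrow/memory"
)

func main() {
    // Schema chosen at runtime, mirroring how the writer picks the pk column
    // type from the first DeleteLog's PkType.
    fields := []arrow.Field{
        {Name: "pk", Type: arrow.BinaryTypes.String, Nullable: false},
        {Name: "ts", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
    }
    schema := arrow.NewSchema(fields, nil)
    builder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
    defer builder.Release()

    // One builder per column; type assertions pick the concrete builder.
    for i, pk := range []string{"a", "b", "c"} {
        builder.Field(0).(*array.StringBuilder).Append(pk)
        builder.Field(1).(*array.Int64Builder).Append(int64(i + 1))
    }

    rec := builder.NewRecord()
    defer rec.Release()
    fmt.Println(rec.NumRows(), rec.NumCols()) // 3 2
}
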
71 changes: 69 additions & 2 deletions internal/storage/serde_events_test.go
@@ -20,6 +20,7 @@ import (
    "bytes"
    "context"
    "io"
+    "strconv"
    "testing"

    "github.com/apache/arrow/go/v12/arrow"
@@ -353,8 +354,19 @@ func TestDeltalogV2(t *testing.T) {
        assert.Equal(t, io.EOF, err)
    })

-    t.Run("test serialize deserialize", func(t *testing.T) {
-        eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0)
+    t.Run("test int64 pk", func(t *testing.T) {
+        eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0, []*schemapb.FieldSchema{
+            {
+                FieldID:  common.RowIDField,
+                Name:     "pk",
+                DataType: schemapb.DataType_Int64,
+            },
+            {
+                FieldID:  common.TimeStampField,
+                Name:     "ts",
+                DataType: schemapb.DataType_Int64,
+            },
+        })
        writer, err := NewDeltalogMultiFieldWriter(0, 0, eventWriter, 7)
        assert.NoError(t, err)

@@ -395,4 +407,59 @@
            assertTestDeltalogData(t, i, value)
        }
    })

t.Run("test varchar pk", func(t *testing.T) {
eventWriter := NewMultiFieldDeltalogStreamWriter(0, 0, 0, []*schemapb.FieldSchema{
{
FieldID: common.RowIDField,
Name: "pk",
DataType: schemapb.DataType_VarChar,
},
{
FieldID: common.TimeStampField,
Name: "ts",
DataType: schemapb.DataType_Int64,
},
})
writer, err := NewDeltalogMultiFieldWriter(0, 0, eventWriter, 7)
assert.NoError(t, err)

size := 10
pks := make([]string, size)
tss := make([]uint64, size)
for i := 0; i < size; i++ {
pks[i] = strconv.Itoa(i)
tss[i] = uint64(i + 1)
}
data := make([]*DeleteLog, size)
for i := range pks {
data[i] = NewDeleteLog(NewVarCharPrimaryKey(pks[i]), tss[i])
}

// Serialize the data
for i := 0; i < size; i++ {
err := writer.Write(data[i])
assert.NoError(t, err)
}
err = writer.Close()
assert.NoError(t, err)

blob, err := eventWriter.Finalize()
assert.NoError(t, err)
assert.NotNil(t, blob)
blobs := []*Blob{blob}

// Deserialize the data
reader, err := NewDeltalogDeserializeReader(blobs)
assert.NoError(t, err)
defer reader.Close()
for i := 0; i < size; i++ {
err = reader.Next()
assert.NoError(t, err)

value := reader.Value()
assert.Equal(t, &VarCharPrimaryKey{strconv.Itoa(i)}, value.Pk)
assert.Equal(t, uint64(i+1), value.Ts)
}
})
}
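
For completeness, the reader side in isolation: the new deserialization code dispatches on the stored pk data type and asserts the matching concrete Arrow array. A sketch of that dispatch (standalone, assuming only arrow v12; decodePks is an illustrative helper, not a function from this commit):

package main

import (
    "fmt"

    "github.com/apache/arrow/go/v12/arrow"
    "github.com/apache/arrow/go/v12/arrow/array"
    "github.com/apache/arrow/go/v12/arrow/memory"
)

// decodePks mirrors the reader's switch: the concrete Arrow array type
// (Int64 vs String) is chosen per column, then values are copied out.
func decodePks(a arrow.Array) []any {
    out := make([]any, 0, a.Len())
    switch arr := a.(type) {
    case *array.Int64:
        for i := 0; i < arr.Len(); i++ {
            out = append(out, arr.Value(i))
        }
    case *array.String:
        for i := 0; i < arr.Len(); i++ {
            out = append(out, arr.Value(i))
        }
    default:
        panic("unsupported pk column type")
    }
    return out
}

func main() {
    b := array.NewStringBuilder(memory.DefaultAllocator)
    defer b.Release()
    b.AppendValues([]string{"10", "11"}, nil)
    col := b.NewStringArray()
    defer col.Release()
    fmt.Println(decodePks(col)) // [10 11]
}
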
