Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support restful & go sdk for Int8Vector #39278

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions client/column/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil

case schemapb.DataType_SparseFloatVector:
sparseVectors := fd.GetVectors().GetSparseFloatVector()
if sparseVectors == nil {
Expand All @@ -303,6 +304,29 @@
vectors = append(vectors, vector)
}
return NewColumnSparseVectors(fd.GetFieldName(), vectors), nil

case schemapb.DataType_Int8Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Int8Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}

Check warning on line 313 in client/column/columns.go

View check run for this annotation

Codecov / codecov/patch

client/column/columns.go#L312-L313

Added lines #L312 - L313 were not covered by tests
data := x.Int8Vector
dim := int(vectors.GetDim())
if end < 0 {
end = len(data) / dim
}
vector := make([][]int8, 0, end-begin) // shall not have remnant
// TODO caiyd: is there a better way to convert []byte to []int8?
for i := begin; i < end; i++ {
v := make([]int8, dim)
for j := 0; j < dim; j++ {
v[j] = int8(data[i*dim+j])
}
vector = append(vector, v)
}
return NewColumnInt8Vector(fd.GetFieldName(), dim, vector), nil

default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
Expand Down
14 changes: 12 additions & 2 deletions client/column/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ func values2FieldData[T any](values []T, fieldType entity.FieldType, dim int) *s
entity.FieldTypeFloat16Vector,
entity.FieldTypeBFloat16Vector,
entity.FieldTypeBinaryVector,
entity.FieldTypeSparseVector:
entity.FieldTypeSparseVector,
entity.FieldTypeInt8Vector:
fd.Field = &schemapb.FieldData_Vectors{
Vectors: values2Vectors(values, fieldType, int64(dim)),
}
Expand Down Expand Up @@ -265,8 +266,17 @@ func values2Vectors[T any](values []T, fieldType entity.FieldType, dim int64) *s
Contents: data,
},
}
case entity.FieldTypeInt8Vector:
var vectors []entity.Int8Vector
vectors, ok = any(values).([]entity.Int8Vector)
data := make([]byte, 0, int64(len(vectors))*dim)
for _, vector := range vectors {
data = append(data, vector.Serialize()...)
}
vectorField.Data = &schemapb.VectorField_Int8Vector{
Int8Vector: data,
}
}

if !ok {
panic(fmt.Sprintf("unexpected values type(%T) of fieldType %v", values, fieldType))
}
Expand Down
33 changes: 33 additions & 0 deletions client/column/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,36 @@
vectorBase: c.vectorBase.slice(start, end),
}
}

/* int8 vector */

// ColumnInt8Vector is the column type for int8 vector fields.
// It embeds the generic vectorBase specialized for entity.Int8Vector values.
type ColumnInt8Vector struct {
*vectorBase[entity.Int8Vector]
}

// NewColumnInt8Vector creates an int8 vector column named fieldName with the
// given dim from raw [][]int8 rows.
// Each row is converted (not copied) to entity.Int8Vector, so the resulting
// vectors share their backing arrays with the rows in data.
func NewColumnInt8Vector(fieldName string, dim int, data [][]int8) *ColumnInt8Vector {
	toVector := func(row []int8, _ int) entity.Int8Vector {
		return entity.Int8Vector(row)
	}
	base := newVectorBase(fieldName, dim, lo.Map(data, toVector), entity.FieldTypeInt8Vector)
	return &ColumnInt8Vector{vectorBase: base}
}

// AppendValue appends vector value into values.
// It overrides the default type constraints by additionally accepting a raw `[]int8`.
// Any other value type is rejected with an error and nothing is appended.
// The length of the appended vector is not validated in this method.
func (c *ColumnInt8Vector) AppendValue(i interface{}) error {
switch vector := i.(type) {
case entity.Int8Vector:
c.values = append(c.values, vector)

Check warning on line 235 in client/column/vector.go

View check run for this annotation

Codecov / codecov/patch

client/column/vector.go#L234-L235

Added lines #L234 - L235 were not covered by tests
case []int8:
// raw int8 slice is accepted as a convenience for callers
c.values = append(c.values, vector)
default:
return errors.Newf("unexpected append value type %T, field type %v", vector, c.fieldType)

Check warning on line 239 in client/column/vector.go

View check run for this annotation

Codecov / codecov/patch

client/column/vector.go#L238-L239

Added lines #L238 - L239 were not covered by tests
}
return nil
}

// Slice returns a new ColumnInt8Vector wrapping the result of slicing the
// underlying vectorBase with the given start and end.
func (c *ColumnInt8Vector) Slice(start, end int) Column {
	sliced := c.vectorBase.slice(start, end)
	return &ColumnInt8Vector{vectorBase: sliced}
}
55 changes: 55 additions & 0 deletions client/column/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type VectorSuite struct {
Expand Down Expand Up @@ -187,6 +188,38 @@ func (s *VectorSuite) TestBasic() {
}
}
})

s.Run("int8_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 3
dim := rand.Intn(10) + 2
data := make([][]int8, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim, func(i int) int8 {
return int8(rand.Intn(256) - 128)
})
data = append(data, row)
}
column := NewColumnInt8Vector(name, dim, data)
s.Equal(entity.FieldTypeInt8Vector, column.Type())
s.Equal(name, column.Name())
s.Equal(lo.Map(data, func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), column.Data())
s.Equal(dim, column.Dim())

fd := column.FieldData()
s.Equal(name, fd.GetFieldName())
s.Equal(typeutil.Int8ArrayToBytes(lo.Flatten(data)), fd.GetVectors().GetInt8Vector())

result, err := FieldDataColumn(fd, 0, -1)
s.NoError(err)
parsed, ok := result.(*ColumnInt8Vector)
if s.True(ok) {
s.Equal(entity.FieldTypeInt8Vector, parsed.Type())
s.Equal(name, parsed.Name())
s.Equal(lo.Map(data, func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), parsed.Data())
s.Equal(dim, parsed.Dim())
}
})
}

func (s *VectorSuite) TestSlice() {
Expand Down Expand Up @@ -277,6 +310,28 @@ func (s *VectorSuite) TestSlice() {
s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BFloat16Vector { return entity.BFloat16Vector(row) }), slicedColumn.Data())
}
})

s.Run("int8_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 100
dim := rand.Intn(10) + 2
data := make([][]int8, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim, func(i int) int8 {
return int8(rand.Intn(256) - 128)
})
data = append(data, row)
}
column := NewColumnInt8Vector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnInt8Vector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), slicedColumn.Data())
}
})
}

func TestVectors(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions client/entity/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
return "Float16Vector"
case FieldTypeBFloat16Vector:
return "BFloat16Vector"
case FieldTypeInt8Vector:
return "Int8Vector"

Check warning on line 66 in client/entity/field.go

View check run for this annotation

Codecov / codecov/patch

client/entity/field.go#L65-L66

Added lines #L65 - L66 were not covered by tests
default:
return "undefined"
}
Expand Down Expand Up @@ -100,6 +102,8 @@
return "[]byte"
case FieldTypeBFloat16Vector:
return "[]byte"
case FieldTypeInt8Vector:
return "[]int8"

Check warning on line 106 in client/entity/field.go

View check run for this annotation

Codecov / codecov/patch

client/entity/field.go#L105-L106

Added lines #L105 - L106 were not covered by tests
default:
return "undefined"
}
Expand Down
22 changes: 20 additions & 2 deletions client/entity/vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
return typeutil.Float32ArrayToBFloat16Bytes(fv)
}

// FloatVector float32 vector wrapper.
// Float16Vector float16 vector wrapper.
type Float16Vector []byte

// Dim returns vector dimension.
Expand All @@ -77,7 +77,7 @@
return typeutil.Float16BytesToFloat32Vector(fv)
}

// FloatVector float32 vector wrapper.
// BFloat16Vector bfloat16 vector wrapper.
type BFloat16Vector []byte

// Dim returns vector dimension.
Expand Down Expand Up @@ -131,3 +131,21 @@
func (t Text) Serialize() []byte {
return []byte(t)
}

// Int8Vector is a vector wrapper whose underlying type is []int8;
// each element is one signed 8-bit component of the vector.
type Int8Vector []int8

// Dim returns the vector dimension, which is the number of int8 elements.
func (iv Int8Vector) Dim() int {
	dim := len(iv)
	return dim
}

// Serialize returns the vector as bytes, delegating the conversion to
// typeutil.Int8ArrayToBytes.
func (iv Int8Vector) Serialize() []byte {
	return typeutil.Int8ArrayToBytes([]int8(iv))
}

// FieldType returns the corresponding field type, FieldTypeInt8Vector.
func (iv Int8Vector) FieldType() FieldType {
return FieldTypeInt8Vector

Check warning on line 150 in client/entity/vectors.go

View check run for this annotation

Codecov / codecov/patch

client/entity/vectors.go#L149-L150

Added lines #L149 - L150 were not covered by tests
}
11 changes: 11 additions & 0 deletions client/entity/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,15 @@ func TestVectors(t *testing.T) {
assert.Equal(t, dim*8, bv.Dim())
assert.ElementsMatch(t, raw, bv.Serialize())
})

t.Run("test int8 vector", func(t *testing.T) {
raw := make([]int8, dim)
for i := 0; i < dim; i++ {
raw[i] = int8(rand.Intn(256) - 128)
}

iv := Int8Vector(raw)
assert.Equal(t, dim, iv.Dim())
assert.Equal(t, dim, len(iv.Serialize()))
})
}
6 changes: 3 additions & 3 deletions client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f
github.com/milvus-io/milvus/pkg v0.0.2-0.20250115044500-f5234c3c11a3
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
go.uber.org/atomic v1.10.0
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
)
Expand Down Expand Up @@ -99,6 +98,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
Expand Down
8 changes: 4 additions & 4 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,10 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f h1:So6RKU5wqP/8EaKogicJP8gZ2SrzzS/JprusBaE3RKc=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20250115044500-f5234c3c11a3 h1:WF9BkNk1XjLtwMbaB/cniRBMMNLnqG6e+AUbK8DciHQ=
github.com/milvus-io/milvus/pkg v0.0.2-0.20250115044500-f5234c3c11a3/go.mod h1:nxnHkDFB3jh27nTQJBaC4azAQO8chT03DkmoiZ5086s=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
Expand Down
10 changes: 9 additions & 1 deletion client/milvusclient/write_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.Type(), field.DataType)
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeFloat16Vector || field.DataType == entity.FieldTypeBFloat16Vector {
field.DataType == entity.FieldTypeFloat16Vector || field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeInt8Vector {
dim := 0
switch column := col.(type) {
case *column.ColumnFloatVector:
Expand All @@ -108,6 +109,8 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
dim = column.Dim()
case *column.ColumnBFloat16Vector:
dim = column.Dim()
case *column.ColumnInt8Vector:
dim = column.Dim()
}
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
Expand Down Expand Up @@ -234,6 +237,11 @@ func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int
return opt.WithColumns(column)
}

// WithInt8VectorColumn builds an int8 vector column from colName, dim and data
// and adds it to the option through WithColumns, returning that call's result.
func (opt *columnBasedDataOption) WithInt8VectorColumn(colName string, dim int, data [][]int8) *columnBasedDataOption {
	return opt.WithColumns(column.NewColumnInt8Vector(colName, dim, data))
}

func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBasedDataOption {
opt.partitionName = partitionName
return opt
Expand Down
Loading
Loading