Skip to content

Commit

Permalink
Update Vector interface
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Jan 28, 2025
1 parent 73018a5 commit 65ef156
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 176 deletions.
73 changes: 31 additions & 42 deletions bson/bson_binary_vector_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,23 @@ func Test_BsonBinaryVector(t *testing.T) {
val := Binary{Subtype: TypeBinaryVector}

for _, tc := range [][]byte{
{Float32Vector, 0, 42},
{Float32Vector, 0, 42, 42},
{Float32Vector, 0, 42, 42, 42},
{byte(Float32Vector), 0, 42},
{byte(Float32Vector), 0, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42},

{Float32Vector, 0, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42},
{byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42, 42},
} {
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
val.Data = tc
b, err := Marshal(D{{"vector", val}})
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrInsufficientVectorData.Error())
require.ErrorContains(t, err, errInsufficientVectorData.Error())
})
}
})
Expand All @@ -95,39 +95,37 @@ func Test_BsonBinaryVector(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
val := BitVector{Padding: 1}
_, err := Marshal(val)
require.EqualError(t, err, ErrNonZeroVectorPadding.Error())
_, err := NewPackedBitsVector(nil, 1)

Check failure on line 98 in bson/bson_binary_vector_spec_test.go

View workflow job for this annotation

GitHub Actions / pre_commit

undefined: NewPackedBitsVector

Check failure on line 98 in bson/bson_binary_vector_spec_test.go

View workflow job for this annotation

GitHub Actions / pre_commit

undefined: NewPackedBitsVector
require.EqualError(t, err, errNonZeroVectorPadding.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 1}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrNonZeroVectorPadding.Error())
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
})
})

t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
val := BitVector{Padding: 8}
_, err := Marshal(val)
require.EqualError(t, err, ErrVectorPaddingTooLarge.Error())
_, err := NewPackedBitsVector(nil, 8)

Check failure on line 117 in bson/bson_binary_vector_spec_test.go

View workflow job for this annotation

GitHub Actions / pre_commit

undefined: NewPackedBitsVector (typecheck)

Check failure on line 117 in bson/bson_binary_vector_spec_test.go

View workflow job for this annotation

GitHub Actions / pre_commit

undefined: NewPackedBitsVector) (typecheck)
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 8}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
Vector Vector
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrVectorPaddingTooLarge.Error())
require.ErrorContains(t, err, errVectorPaddingTooLarge.Error())
})
})
}
Expand Down Expand Up @@ -156,22 +154,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
t.Skipf("skip invalid case %s", test.Description)
}

var testVector interface{}
testVector := make(map[string]Vector)
switch alias := test.DtypeHex; alias {
case "0x03":
testVector = map[string]Vector[int8]{
testKey: {convertSlice[int8](test.Vector)},
testVector[testKey] = Vector{
dType: Int8Vector,
int8Data: convertSlice[int8](test.Vector),
}
case "0x27":
testVector = map[string]Vector[float32]{
testKey: {convertSlice[float32](test.Vector)},
testVector[testKey] = Vector{
dType: Float32Vector,
float32Data: convertSlice[float32](test.Vector),
}
case "0x10":
testVector = map[string]BitVector{
testKey: {
Padding: uint8(test.Padding),
Data: convertSlice[byte](test.Vector),
},
testVector[testKey] = Vector{
dType: PackedBitVector,
bitData: convertSlice[byte](test.Vector),
bitPadding: uint8(test.Padding),
}
default:
t.Fatalf("unsupported vector type: %s", alias)
Expand All @@ -183,18 +182,8 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector
t.Run("Unmarshaling", func(t *testing.T) {
t.Parallel()

var got interface{}
switch alias := test.DtypeHex; alias {
case "0x03":
got = make(map[string]Vector[int8])
case "0x27":
got = make(map[string]Vector[float32])
case "0x10":
got = make(map[string]BitVector)
default:
t.Fatalf("unsupported type: %s", alias)
}
err := Unmarshal(testBSON, got)
var got map[string]Vector
err := Unmarshal(testBSON, &got)
require.NoError(t, err)
require.Equal(t, testVector, got)
})
Expand Down
12 changes: 5 additions & 7 deletions bson/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ func registerDefaultDecoders(reg *Registry) {

reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue))
reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType})
reg.RegisterTypeDecoder(tInt8Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
reg.RegisterTypeDecoder(tFloat32Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
reg.RegisterTypeDecoder(tBitVector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
reg.RegisterTypeDecoder(tVector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType})
reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType})
reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType})
Expand Down Expand Up @@ -561,10 +559,10 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro
}

func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tInt8Vector && t != tFloat32Vector && t != tBitVector {
if t != tVector {
return emptyValue, ValueDecoderError{
Name: "VectorDecodeValue",
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
Types: []reflect.Type{tVector},
Received: reflect.Zero(t),
}
}
Expand All @@ -585,10 +583,10 @@ func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.
// vectorDecodeValue is the ValueDecoderFunc for Vector.
func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
t := val.Type()
if !val.CanSet() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) {
if !val.CanSet() || t != tVector {
return ValueDecoderError{
Name: "VectorDecodeValue",
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
Types: []reflect.Type{tVector},
Received: val,
}
}
Expand Down
24 changes: 5 additions & 19 deletions bson/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ func registerDefaultEncoders(reg *Registry) {
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
reg.RegisterTypeEncoder(tInt8Vector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tFloat32Vector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tBitVector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
Expand Down Expand Up @@ -370,26 +368,14 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error
// vectorEncodeValue is the ValueEncoderFunc for Vector.
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
t := val.Type()
if !val.IsValid() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) {
if !val.IsValid() || t != tVector {
return ValueEncoderError{Name: "VectorEncodeValue",
Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector},
Types: []reflect.Type{tVector},
Received: val,
}
}
var b Binary
var err error
switch v := val.Interface().(type) {
case Vector[int8]:
b, err = NewBinaryFromVector(v)
case Vector[float32]:
b, err = NewBinaryFromVector(v)
case BitVector:
b, err = NewBinaryFromVector(v)
}
if err != nil {
return err
}

v := val.Interface().(Vector)
b := v.Binary()
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}

Expand Down
4 changes: 1 addition & 3 deletions bson/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ var tJavaScript = reflect.TypeOf(JavaScript(""))
var tSymbol = reflect.TypeOf(Symbol(""))
var tTimestamp = reflect.TypeOf(Timestamp{})
var tDecimal = reflect.TypeOf(Decimal128{})
var tInt8Vector = reflect.TypeOf(Vector[int8]{})
var tFloat32Vector = reflect.TypeOf(Vector[float32]{})
var tBitVector = reflect.TypeOf(BitVector{})
var tVector = reflect.TypeOf(Vector{})
var tMinKey = reflect.TypeOf(MinKey{})
var tMaxKey = reflect.TypeOf(MaxKey{})
var tD = reflect.TypeOf(D{})
Expand Down
Loading

0 comments on commit 65ef156

Please sign in to comment.