diff --git a/bson/bson_binary_vector_spec_test.go b/bson/bson_binary_vector_spec_test.go index 02d41a5617..62b88229c7 100644 --- a/bson/bson_binary_vector_spec_test.go +++ b/bson/bson_binary_vector_spec_test.go @@ -70,23 +70,23 @@ func Test_BsonBinaryVector(t *testing.T) { val := Binary{Subtype: TypeBinaryVector} for _, tc := range [][]byte{ - {Float32Vector, 0, 42}, - {Float32Vector, 0, 42, 42}, - {Float32Vector, 0, 42, 42, 42}, + {byte(Float32Vector), 0, 42}, + {byte(Float32Vector), 0, 42, 42}, + {byte(Float32Vector), 0, 42, 42, 42}, - {Float32Vector, 0, 42, 42, 42, 42, 42}, - {Float32Vector, 0, 42, 42, 42, 42, 42, 42}, - {Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42}, + {byte(Float32Vector), 0, 42, 42, 42, 42, 42}, + {byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42}, + {byte(Float32Vector), 0, 42, 42, 42, 42, 42, 42, 42}, } { t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) { val.Data = tc b, err := Marshal(D{{"vector", val}}) require.NoError(t, err, "marshaling test BSON") var got struct { - Vector Vector[float32] + Vector Vector } err = Unmarshal(b, &got) - require.ErrorContains(t, err, ErrInsufficientVectorData.Error()) + require.ErrorContains(t, err, errInsufficientVectorData.Error()) }) } }) @@ -95,19 +95,18 @@ func Test_BsonBinaryVector(t *testing.T) { t.Parallel() t.Run("Marshaling", func(t *testing.T) { - val := BitVector{Padding: 1} - _, err := Marshal(val) - require.EqualError(t, err, ErrNonZeroVectorPadding.Error()) + _, err := NewPackedBitsVector(nil, 1) + require.EqualError(t, err, errNonZeroVectorPadding.Error()) }) t.Run("Unmarshaling", func(t *testing.T) { - val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}} + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 1}}}} b, err := Marshal(val) require.NoError(t, err, "marshaling test BSON") var got struct { - Vector Vector[float32] + Vector Vector } err = Unmarshal(b, &got) - require.ErrorContains(t, err, ErrNonZeroVectorPadding.Error()) + require.ErrorContains(t, err, errNonZeroVectorPadding.Error()) }) }) @@ -115,19 +114,18 @@ func Test_BsonBinaryVector(t *testing.T) { t.Parallel() t.Run("Marshaling", func(t *testing.T) { - val := BitVector{Padding: 8} - _, err := Marshal(val) - require.EqualError(t, err, ErrVectorPaddingTooLarge.Error()) + _, err := NewPackedBitsVector(nil, 8) + require.EqualError(t, err, errVectorPaddingTooLarge.Error()) }) t.Run("Unmarshaling", func(t *testing.T) { - val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}} + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{byte(PackedBitVector), 8}}}} b, err := Marshal(val) require.NoError(t, err, "marshaling test BSON") var got struct { - Vector Vector[float32] + Vector Vector } err = Unmarshal(b, &got) - require.ErrorContains(t, err, ErrVectorPaddingTooLarge.Error()) + require.ErrorContains(t, err, errVectorPaddingTooLarge.Error()) }) }) } @@ -156,22 +154,23 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector t.Skipf("skip invalid case %s", test.Description) } - var testVector interface{} + testVector := make(map[string]Vector) switch alias := test.DtypeHex; alias { case "0x03": - testVector = map[string]Vector[int8]{ - testKey: {convertSlice[int8](test.Vector)}, + testVector[testKey] = Vector{ + dType: Int8Vector, + int8Data: convertSlice[int8](test.Vector), } case "0x27": - testVector = map[string]Vector[float32]{ - testKey: {convertSlice[float32](test.Vector)}, + testVector[testKey] = Vector{ + dType: Float32Vector, + float32Data: convertSlice[float32](test.Vector), } case "0x10": - testVector = map[string]BitVector{ - testKey: { - Padding: uint8(test.Padding), - Data: convertSlice[byte](test.Vector), - }, + testVector[testKey] = Vector{ + dType: PackedBitVector, + bitData: convertSlice[byte](test.Vector), + bitPadding: uint8(test.Padding), } default: t.Fatalf("unsupported vector type: %s", alias) @@ -183,18 +182,8 @@ func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVector t.Run("Unmarshaling", func(t *testing.T) { t.Parallel() - var got interface{} - switch alias := test.DtypeHex; alias { - case "0x03": - got = make(map[string]Vector[int8]) - case "0x27": - got = make(map[string]Vector[float32]) - case "0x10": - got = make(map[string]BitVector) - default: - t.Fatalf("unsupported type: %s", alias) - } - err := Unmarshal(testBSON, got) + var got map[string]Vector + err := Unmarshal(testBSON, &got) require.NoError(t, err) require.Equal(t, testVector, got) }) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index fa47613dee..dfff145219 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -42,9 +42,7 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) - reg.RegisterTypeDecoder(tInt8Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) - reg.RegisterTypeDecoder(tFloat32Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) - reg.RegisterTypeDecoder(tBitVector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) + reg.RegisterTypeDecoder(tVector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) @@ -561,10 +559,10 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro } func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - if t != tInt8Vector && t != tFloat32Vector && t != tBitVector { + if t != tVector { return emptyValue, ValueDecoderError{ Name: "VectorDecodeValue", - Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Types: []reflect.Type{tVector}, Received: reflect.Zero(t), } } @@ -585,10 +583,10 @@ func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. // vectorDecodeValue is the ValueDecoderFunc for Vector. func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { t := val.Type() - if !val.CanSet() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) { + if !val.CanSet() || t != tVector { return ValueDecoderError{ Name: "VectorDecodeValue", - Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Types: []reflect.Type{tVector}, Received: val, } } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 2141c3c91d..bd5a20f2f9 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -70,9 +70,7 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tInt8Vector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tFloat32Vector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tBitVector, ValueEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) @@ -370,26 +368,14 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error // vectorEncodeValue is the ValueEncoderFunc for Vector. func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { t := val.Type() - if !val.IsValid() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) { + if !val.IsValid() || t != tVector { return ValueEncoderError{Name: "VectorEncodeValue", - Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Types: []reflect.Type{tVector}, Received: val, } } - var b Binary - var err error - switch v := val.Interface().(type) { - case Vector[int8]: - b, err = NewBinaryFromVector(v) - case Vector[float32]: - b, err = NewBinaryFromVector(v) - case BitVector: - b, err = NewBinaryFromVector(v) - } - if err != nil { - return err - } - + v := val.Interface().(Vector) + b := v.Binary() return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } diff --git a/bson/types.go b/bson/types.go index 2550098cca..c2883aa4ef 100644 --- a/bson/types.go +++ b/bson/types.go @@ -107,9 +107,7 @@ var tJavaScript = reflect.TypeOf(JavaScript("")) var tSymbol = reflect.TypeOf(Symbol("")) var tTimestamp = reflect.TypeOf(Timestamp{}) var tDecimal = reflect.TypeOf(Decimal128{}) -var tInt8Vector = reflect.TypeOf(Vector[int8]{}) -var tFloat32Vector = reflect.TypeOf(Vector[float32]{}) -var tBitVector = reflect.TypeOf(BitVector{}) +var tVector = reflect.TypeOf(Vector{}) var tMinKey = reflect.TypeOf(MinKey{}) var tMaxKey = reflect.TypeOf(MaxKey{}) var tD = reflect.TypeOf(D{}) diff --git a/bson/vector.go b/bson/vector.go index fa71716984..95ae5d0b03 100644 --- a/bson/vector.go +++ b/bson/vector.go @@ -13,96 +13,208 @@ import ( "math" ) +// VectorDType represents the Vector data type. +type VectorDType byte + // These constants are vector data types. const ( - Int8Vector = 0x03 - Float32Vector = 0x27 - PackedBitVector = 0x10 + Int8Vector VectorDType = 0x03 + Float32Vector VectorDType = 0x27 + PackedBitVector VectorDType = 0x10 ) +// Stringer of VectorDType +func (vt VectorDType) String() string { + switch vt { + case Int8Vector: + return "int8" + case Float32Vector: + return "float32" + case PackedBitVector: + return "packed bit" + default: + return "invalid" + } +} + // These are vector conversion errors. var ( - ErrNotVector = errors.New("not a vector") - ErrInsufficientVectorData = errors.New("insufficient data") - ErrNonZeroVectorPadding = errors.New("padding must be 0") - ErrVectorPaddingTooLarge = errors.New("padding larger than 7") + errInsufficientVectorData = errors.New("insufficient data") + errNonZeroVectorPadding = errors.New("padding must be 0") + errVectorPaddingTooLarge = errors.New("padding larger than 7") ) -// Vector represents a densely packed array of numbers. -type Vector[T int8 | float32] struct { - Data []T +type vectorTypeError struct { + Method string + Type VectorDType } -// BitVector represents a binary quantized (PACKED_BIT) vector of 0s and 1s -// (bits). The Padding prescribes the number of bits to ignore in the final byte -// of the Data. It should be 0 for an empty Data and always less than 8. -type BitVector struct { - Padding uint8 - Data []byte +// Error implements the error interface. +func (vte vectorTypeError) Error() string { + return "Call of " + vte.Method + " on " + vte.Type.String() + " vector" } -func newInt8Vector(b []byte) (Vector[int8], error) { - var v Vector[int8] - if len(b) == 0 { - return v, ErrInsufficientVectorData +// Vector represents a densely packed array of numbers / bits. +type Vector struct { + dType VectorDType + int8Data []int8 + float32Data []float32 + bitData []byte + bitPadding uint8 +} + +// Type returns the vector type. +func (v Vector) Type() VectorDType { + return v.dType +} + +// Int8 returns the int8 slice hold by the vector. +// It panics if v is not an int8 vector. +func (v Vector) Int8() []int8 { + d, ok := v.Int8OK() + if !ok { + panic(vectorTypeError{"bson.Vector.Int8", v.dType}) } - if padding := b[0]; padding > 0 { - return v, ErrNonZeroVectorPadding + return d +} + +// Int8OK is the same as Int8, but returns a boolean instead of panicking. +func (v Vector) Int8OK() ([]int8, bool) { + if v.dType != Int8Vector { + return nil, false } - s := make([]int8, 0, len(b)-1) - for i := 1; i < len(b); i++ { - s = append(s, int8(b[i])) + return v.int8Data, true +} + +// Float32 returns the float32 slice hold by the vector. +// It panics if v is not a float32 vector. +func (v Vector) Float32() []float32 { + d, ok := v.Float32OK() + if !ok { + panic(vectorTypeError{"bson.Vector.Float32", v.dType}) } - v.Data = s - return v, nil + return d } -func newFloat32Vector(b []byte) (Vector[float32], error) { - var v Vector[float32] - if len(b) == 0 { - return v, ErrInsufficientVectorData +// Float32OK is the same as Float32, but returns a boolean instead of panicking. +func (v Vector) Float32OK() ([]float32, bool) { + if v.dType != Float32Vector { + return nil, false } - if padding := b[0]; padding > 0 { - return v, ErrNonZeroVectorPadding + return v.float32Data, true +} + +// PackedBit returns the byte slice representing the binary quantized (packed bit) vector and the byte padding, which +// is the number of bits in the final byte that are to be ignored. +// It panics if v is not a packed bit vector. +func (v Vector) PackedBit() ([]byte, uint8) { + d, p, ok := v.PackedBitOK() + if !ok { + panic(vectorTypeError{"bson.Vector.PackedBit", v.dType}) } - l := (len(b) - 1) / 4 - if l*4 != len(b)-1 { - return v, ErrInsufficientVectorData + return d, p +} + +// PackedBitOK is the same as PackedBit, but returns a boolean instead of panicking. +func (v Vector) PackedBitOK() ([]byte, uint8, bool) { + if v.dType != PackedBitVector { + return nil, 0, false } - s := make([]float32, 0, l) - for i := 1; i < len(b); i += 4 { - s = append(s, math.Float32frombits(binary.LittleEndian.Uint32(b[i:i+4]))) + return v.bitData, v.bitPadding, true +} + +// Binary returns the BSON Binary of the Vector. +func (v Vector) Binary() Binary { + switch v.Type() { + case Int8Vector: + return binaryFromInt8Vector(v.Int8()) + case Float32Vector: + return binaryFromFloat32Vector(v.Float32()) + case PackedBitVector: + return binaryFromBitVector(v.PackedBit()) + default: + panic("invalid Vector type") } - v.Data = s - return v, nil } -func newBitVector(b []byte) (BitVector, error) { - var v BitVector - if len(b) == 0 { - return v, ErrInsufficientVectorData +func binaryFromInt8Vector(v []int8) Binary { + data := make([]byte, 2, len(v)+2) + copy(data, []byte{byte(Int8Vector), 0}) + for _, e := range v { + data = append(data, byte(e)) + } + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +func binaryFromFloat32Vector(v []float32) Binary { + data := make([]byte, 2, len(v)*4+2) + copy(data, []byte{byte(Float32Vector), 0}) + var a [4]byte + for _, e := range v { + binary.LittleEndian.PutUint32(a[:], math.Float32bits(e)) + data = append(data, a[:]...) } - padding := b[0] + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +func binaryFromBitVector(bits []byte, padding uint8) Binary { + data := []byte{byte(PackedBitVector), padding} + data = append(data, bits...) + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + } +} + +// NewVector constructs a Vector from a slice of int8 or float32. +func NewVector[T int8 | float32](data []T) Vector { + var v Vector + switch a := any(data).(type) { + case []int8: + v.dType = Int8Vector + v.int8Data = append(v.int8Data, a...) + case []float32: + v.dType = Float32Vector + v.float32Data = append(v.float32Data, a...) + default: + panic(fmt.Errorf("unsupported type %T", data)) + } + return v +} + +// NewPackedBitVector constructs a Vector from a byte slice and a value of byte padding. +func NewPackedBitVector(bits []byte, padding uint8) (Vector, error) { + var v Vector if padding > 7 { - return v, ErrVectorPaddingTooLarge + return v, errVectorPaddingTooLarge } - if padding > 0 && len(b) == 1 { - return v, ErrNonZeroVectorPadding + if padding > 0 && len(bits) == 0 { + return v, errNonZeroVectorPadding } - v.Padding = padding - v.Data = b[1:] + v.dType = PackedBitVector + v.bitData = append(v.bitData, bits...) + v.bitPadding = padding return v, nil } // NewVectorFromBinary unpacks a BSON Binary into a Vector. -func NewVectorFromBinary(b Binary) (interface{}, error) { +func NewVectorFromBinary(b Binary) (Vector, error) { + var v Vector if b.Subtype != TypeBinaryVector { - return nil, ErrNotVector + return v, errors.New("not a vector") } if len(b.Data) < 2 { - return nil, ErrInsufficientVectorData + return v, errInsufficientVectorData } - switch t := b.Data[0]; t { + switch t := b.Data[0]; VectorDType(t) { case Int8Vector: return newInt8Vector(b.Data[1:]) case Float32Vector: @@ -110,64 +222,47 @@ func NewVectorFromBinary(b Binary) (interface{}, error) { case PackedBitVector: return newBitVector(b.Data[1:]) default: - return nil, fmt.Errorf("invalid Vector data type: %d", t) + return v, fmt.Errorf("invalid Vector data type: %d", t) } } -func binaryFromFloat32Vector(v Vector[float32]) (Binary, error) { - data := make([]byte, 2, len(v.Data)*4+2) - copy(data, []byte{Float32Vector, 0}) - var a [4]byte - for _, e := range v.Data { - binary.LittleEndian.PutUint32(a[:], math.Float32bits(e)) - data = append(data, a[:]...) +func newInt8Vector(b []byte) (Vector, error) { + var v Vector + if len(b) == 0 { + return v, errInsufficientVectorData } - - return Binary{ - Subtype: TypeBinaryVector, - Data: data, - }, nil -} - -func binaryFromInt8Vector(v Vector[int8]) (Binary, error) { - data := make([]byte, 2, len(v.Data)+2) - copy(data, []byte{Int8Vector, 0}) - for _, e := range v.Data { - data = append(data, byte(e)) + if padding := b[0]; padding > 0 { + return v, errNonZeroVectorPadding } - - return Binary{ - Subtype: TypeBinaryVector, - Data: data, - }, nil + s := make([]int8, 0, len(b)-1) + for i := 1; i < len(b); i++ { + s = append(s, int8(b[i])) + } + return NewVector(s), nil } -func binaryFromBitVector(v BitVector) (Binary, error) { - var b Binary - if v.Padding > 7 { - return b, ErrVectorPaddingTooLarge +func newFloat32Vector(b []byte) (Vector, error) { + var v Vector + if len(b) == 0 { + return v, errInsufficientVectorData } - if v.Padding > 0 && len(v.Data) == 0 { - return b, ErrNonZeroVectorPadding + if padding := b[0]; padding > 0 { + return v, errNonZeroVectorPadding } - data := []byte{PackedBitVector, v.Padding} - data = append(data, v.Data...) - return Binary{ - Subtype: TypeBinaryVector, - Data: data, - }, nil -} - -// NewBinaryFromVector converts a Vector into a BSON Binary. -func NewBinaryFromVector[T BitVector | Vector[int8] | Vector[float32]](v T) (Binary, error) { - switch a := any(v).(type) { - case Vector[int8]: - return binaryFromInt8Vector(a) - case Vector[float32]: - return binaryFromFloat32Vector(a) - case BitVector: - return binaryFromBitVector(a) - default: - return Binary{}, fmt.Errorf("unsupported type %T", v) + l := (len(b) - 1) / 4 + if l*4 != len(b)-1 { + return v, errInsufficientVectorData + } + s := make([]float32, 0, l) + for i := 1; i < len(b); i += 4 { + s = append(s, math.Float32frombits(binary.LittleEndian.Uint32(b[i:i+4]))) + } + return NewVector(s), nil +} + +func newBitVector(b []byte) (Vector, error) { + if len(b) == 0 { + return Vector{}, errInsufficientVectorData } + return NewPackedBitVector(b[1:], b[0]) }