diff --git a/bson/bson_binary_vector_test.go b/bson/bson_binary_vector_test.go new file mode 100644 index 0000000000..8303bc8db6 --- /dev/null +++ b/bson/bson_binary_vector_test.go @@ -0,0 +1,209 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "math" + "os" + "path" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/require" +) + +const bsonBinaryVectorDir = "../testdata/bson-binary-vector/" + +type bsonBinaryVectorTests struct { + Description string `json:"description"` + TestKey string `json:"test_key"` + Tests []bsonBinaryVectorTestCase `json:"tests"` +} + +type bsonBinaryVectorTestCase struct { + Description string `json:"description"` + Valid bool `json:"valid"` + Vector []interface{} `json:"vector"` + DtypeHex string `json:"dtype_hex"` + DtypeAlias string `json:"dtype_alias"` + Padding int `json:"padding"` + CanonicalBson string `json:"canonical_bson"` +} + +func Test_BsonBinaryVector(t *testing.T) { + t.Parallel() + + jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir) + require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err) + + for _, file := range jsonFiles { + filepath := path.Join(bsonBinaryVectorDir, file) + content, err := os.ReadFile(filepath) + require.NoErrorf(t, err, "reading test file %s", filepath) + + var tests bsonBinaryVectorTests + require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath) + + t.Run(tests.Description, func(t *testing.T) { + t.Parallel() + + for _, test := range tests.Tests { + test := test + t.Run(test.Description, func(t *testing.T) { + t.Parallel() + + runBsonBinaryVectorTest(t, tests.TestKey, test) + }) + } + }) + } + + t.Run("Insufficient vector data FLOAT32", func(t *testing.T) { + t.Parallel() + + val := Binary{Subtype: TypeBinaryVector} + + for _, tc := range [][]byte{ + {Float32Vector, 0, 42}, + {Float32Vector, 0, 42, 42}, + {Float32Vector, 0, 42, 42, 42}, + + {Float32Vector, 0, 42, 42, 42, 42, 42}, + {Float32Vector, 0, 42, 42, 42, 42, 42, 42}, + {Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42}, + } { + t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) { + val.Data = tc + b, err := Marshal(D{{"vector", val}}) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector[float32] + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, ErrInsufficientData.Error()) + }) + } + }) + + t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) { + t.Parallel() + + t.Run("Marshaling", func(t *testing.T) { + val := BitVector{Padding: 1} + _, err := Marshal(val) + require.EqualError(t, err, ErrNonZeroPadding.Error()) + }) + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector[float32] + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, ErrNonZeroPadding.Error()) + }) + }) + + t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) { + t.Parallel() + + t.Run("Marshaling", func(t *testing.T) { + val := BitVector{Padding: 8} + _, err := Marshal(val) + require.EqualError(t, err, ErrPaddingTooLarge.Error()) + }) + t.Run("Unmarshaling", func(t *testing.T) { + val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}} + b, err := Marshal(val) + require.NoError(t, err, "marshaling test BSON") + var got struct { + Vector Vector[float32] + } + err = Unmarshal(b, &got) + require.ErrorContains(t, err, ErrPaddingTooLarge.Error()) + }) + }) +} + +func convertSlice[T int8 | float32 | byte](s []interface{}) []T { + v := make([]T, len(s)) + for i, e := range s { + f := math.NaN() + switch v := e.(type) { + case float64: + f = v + case string: + if v == "inf" { + f = math.Inf(0) + } else if v == "-inf" { + f = math.Inf(-1) + } + } + v[i] = T(f) + } + return v +} + +func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) { + if !test.Valid { + t.Skipf("skip invalid case %s", test.Description) + } + + var testVector interface{} + switch alias := test.DtypeHex; alias { + case "0x03": + testVector = map[string]Vector[int8]{ + testKey: {convertSlice[int8](test.Vector)}, + } + case "0x27": + testVector = map[string]Vector[float32]{ + testKey: {convertSlice[float32](test.Vector)}, + } + case "0x10": + testVector = map[string]BitVector{ + testKey: { + Padding: uint8(test.Padding), + Data: convertSlice[byte](test.Vector), + }, + } + default: + t.Fatalf("unsupported vector type: %s", alias) + } + + testBSON, err := hex.DecodeString(test.CanonicalBson) + require.NoError(t, err, "decoding canonical BSON") + + t.Run("Unmarshaling", func(t *testing.T) { + t.Parallel() + + var got interface{} + switch alias := test.DtypeHex; alias { + case "0x03": + got = make(map[string]Vector[int8]) + case "0x27": + got = make(map[string]Vector[float32]) + case "0x10": + got = make(map[string]BitVector) + default: + t.Fatalf("unsupported type: %s", alias) + } + err := Unmarshal(testBSON, got) + require.NoError(t, err) + require.Equal(t, testVector, got) + }) + + t.Run("Marshaling", func(t *testing.T) { + t.Parallel() + + got, err := Marshal(testVector) + require.NoError(t, err) + require.Equal(t, testBSON, got) + }) +} diff --git a/bson/bson_corpus_spec_test.go b/bson/bson_corpus_spec_test.go index a0d5a5aa38..043aa2f019 100644 --- a/bson/bson_corpus_spec_test.go +++ b/bson/bson_corpus_spec_test.go @@ -217,7 +217,7 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string { func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D { var doc D err := Unmarshal(b, &doc) - expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType)) + require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType) return doc } @@ -225,7 +225,7 @@ func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D { // canonical BSON (cB) func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) { actual, err := Marshal(doc) - expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType)) + require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType) if diff := cmp.Diff(cB, actual); diff != "" { t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n", @@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) { // nativeToJSON encodes the native Document (doc) into an extended JSON string func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) { actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true) - expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType)) + require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType) if diff := cmp.Diff(ej, string(actualEJ)); diff != "" { t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n", @@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) { t.Run(v.Description, func(t *testing.T) { // get canonical BSON cB, err := hex.DecodeString(v.CanonicalBson) - expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description)) + require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description) // get canonical extended JSON var compactEJ bytes.Buffer @@ -341,7 +341,7 @@ func runTest(t *testing.T, file string) { /*** degenerate BSON round-trip tests (if exists) ***/ if v.DegenerateBSON != nil { dB, err := hex.DecodeString(*v.DegenerateBSON) - expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description)) + require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description) doc = bsonToNative(t, dB, "degenerate", v.Description) @@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) { for _, d := range test.DecodeErrors { t.Run(d.Description, func(t *testing.T) { b, err := hex.DecodeString(d.Bson) - expectNoError(t, err, d.Description) + require.NoError(t, err, d.Description) var doc D err = Unmarshal(b, &doc) @@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) { invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB) if invalidString || invalidDBPtr { - expectNoError(t, err, d.Description) + require.NoError(t, err, d.Description) return } } - expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description)) + require.Errorf(t, err, "%s: expected decode error", d.Description) }) } }) @@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) { if strings.Contains(p.Description, "Null") { _, err = Marshal(doc) } - expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description)) + require.Errorf(t, err, "%s: expected parse error", p.Description) default: t.Errorf("Update test to check for parse errors for type %s", test.BsonType) t.Fail() @@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) { func Test_BsonCorpus(t *testing.T) { jsonFiles, err := findJSONFilesInDir(dataDir) - if err != nil { - t.Fatalf("error finding JSON files in %s: %v", dataDir, err) - } + require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err) for _, file := range jsonFiles { runTest(t, file) } } -func expectNoError(t *testing.T, err error, desc string) { - if err != nil { - t.Helper() - t.Errorf("%s: Unepexted error: %v", desc, err) - t.FailNow() - } -} - -func expectError(t *testing.T, err error, desc string) { - if err == nil { - t.Helper() - t.Errorf("%s: Expected error", desc) - t.FailNow() - } -} - func TestRelaxedUUIDValidation(t *testing.T) { testCases := []struct { description string diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 2f195329ca..fa47613dee 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -42,6 +42,9 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) + reg.RegisterTypeDecoder(tInt8Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) + reg.RegisterTypeDecoder(tFloat32Vector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) + reg.RegisterTypeDecoder(tBitVector, decodeAdapter{vectorDecodeValue, vectorDecodeType}) reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) @@ -501,14 +504,8 @@ func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) er return nil } -func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - if t != tBinary { - return emptyValue, ValueDecoderError{ - Name: "BinaryDecodeValue", - Types: []reflect.Type{tBinary}, - Received: reflect.Zero(t), - } - } +func binaryDecode(vr ValueReader) (Binary, error) { + var b Binary var data []byte var subtype byte @@ -521,13 +518,31 @@ func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. case TypeUndefined: err = vr.ReadUndefined() default: - return emptyValue, fmt.Errorf("cannot decode %v into a Binary", vrType) + return b, fmt.Errorf("cannot decode %v into a Binary", vrType) } if err != nil { - return emptyValue, err + return b, err + } + b.Subtype = subtype + b.Data = data + + return b, nil +} + +func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tBinary { + return emptyValue, ValueDecoderError{ + Name: "BinaryDecodeValue", + Types: []reflect.Type{tBinary}, + Received: reflect.Zero(t), + } } - return reflect.ValueOf(Binary{Subtype: subtype, Data: data}), nil + b, err := binaryDecode(vr) + if err != nil { + return emptyValue, err + } + return reflect.ValueOf(b), nil } // binaryDecodeValue is the ValueDecoderFunc for Binary. @@ -545,6 +560,48 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro return nil } +func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tInt8Vector && t != tFloat32Vector && t != tBitVector { + return emptyValue, ValueDecoderError{ + Name: "VectorDecodeValue", + Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Received: reflect.Zero(t), + } + } + + b, err := binaryDecode(vr) + if err != nil { + return emptyValue, err + } + + v, err := NewVectorFromBinary(b) + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(v), nil +} + +// vectorDecodeValue is the ValueDecoderFunc for Vector. +func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { + t := val.Type() + if !val.CanSet() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) { + return ValueDecoderError{ + Name: "VectorDecodeValue", + Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Received: val, + } + } + + elem, err := vectorDecodeType(dctx, vr, t) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 9835738be3..2141c3c91d 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -70,6 +70,9 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tInt8Vector, ValueEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tFloat32Vector, ValueEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tBitVector, ValueEncoderFunc(vectorEncodeValue)) reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) @@ -364,6 +367,32 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } +// vectorEncodeValue is the ValueEncoderFunc for Vector. +func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + t := val.Type() + if !val.IsValid() || (t != tInt8Vector && t != tFloat32Vector && t != tBitVector) { + return ValueEncoderError{Name: "VectorEncodeValue", + Types: []reflect.Type{tInt8Vector, tFloat32Vector, tBitVector}, + Received: val, + } + } + var b Binary + var err error + switch v := val.Interface().(type) { + case Vector[int8]: + b, err = NewBinaryFromVector(v) + case Vector[float32]: + b, err = NewBinaryFromVector(v) + case BitVector: + b, err = NewBinaryFromVector(v) + } + if err != nil { + return err + } + + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + // undefinedEncodeValue is the ValueEncoderFunc for Undefined. func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { diff --git a/bson/extjson_parser_test.go b/bson/extjson_parser_test.go index 51a7de7aef..ae6bbf7bb3 100644 --- a/bson/extjson_parser_test.go +++ b/bson/extjson_parser_test.go @@ -45,6 +45,22 @@ type readKeyValueTestCase struct { valEFs []expectedErrorFunc } +func expectNoError(t *testing.T, err error, desc string) { + if err != nil { + t.Helper() + t.Errorf("%s: Unepexted error: %v", desc, err) + t.FailNow() + } +} + +func expectError(t *testing.T, err error, desc string) { + if err == nil { + t.Helper() + t.Errorf("%s: Expected error", desc) + t.FailNow() + } +} + func expectSpecificError(expected error) expectedErrorFunc { return func(t *testing.T, err error, desc string) { if !errors.Is(err, expected) { diff --git a/bson/json_scanner_test.go b/bson/json_scanner_test.go index 58f6e64594..b46d5aac9c 100644 --- a/bson/json_scanner_test.go +++ b/bson/json_scanner_test.go @@ -12,6 +12,7 @@ import ( "testing/iotest" "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/v2/internal/require" ) func jttDiff(t *testing.T, expected, actual jsonTokenType, desc string) { @@ -289,7 +290,7 @@ func TestJsonScannerValidInputs(t *testing.T) { for _, token := range tc.tokens { c, err := js.nextToken() - expectNoError(t, err, tc.desc) + require.NoError(t, err, tc.desc) jttDiff(t, token.t, c.t, tc.desc) jtvDiff(t, token.v, c.v, tc.desc) } @@ -303,7 +304,7 @@ func TestJsonScannerValidInputs(t *testing.T) { for _, token := range tc.tokens { c, err := js.nextToken() - expectNoError(t, err, tc.desc) + require.NoError(t, err, tc.desc) jttDiff(t, token.t, c.t, tc.desc) jtvDiff(t, token.v, c.v, tc.desc) } @@ -354,7 +355,7 @@ func TestJsonScannerInvalidInputs(t *testing.T) { c, err := js.nextToken() expectNilToken(t, c, tc.desc) - expectError(t, err, tc.desc) + require.Error(t, err, tc.desc) }) } } diff --git a/bson/types.go b/bson/types.go index dedc95a596..2550098cca 100644 --- a/bson/types.go +++ b/bson/types.go @@ -72,6 +72,7 @@ const ( TypeBinaryEncrypted byte = 0x06 TypeBinaryColumn byte = 0x07 TypeBinarySensitive byte = 0x08 + TypeBinaryVector byte = 0x09 TypeBinaryUserDefined byte = 0x80 ) @@ -106,6 +107,9 @@ var tJavaScript = reflect.TypeOf(JavaScript("")) var tSymbol = reflect.TypeOf(Symbol("")) var tTimestamp = reflect.TypeOf(Timestamp{}) var tDecimal = reflect.TypeOf(Decimal128{}) +var tInt8Vector = reflect.TypeOf(Vector[int8]{}) +var tFloat32Vector = reflect.TypeOf(Vector[float32]{}) +var tBitVector = reflect.TypeOf(BitVector{}) var tMinKey = reflect.TypeOf(MinKey{}) var tMaxKey = reflect.TypeOf(MaxKey{}) var tD = reflect.TypeOf(D{}) diff --git a/bson/vector.go b/bson/vector.go new file mode 100644 index 0000000000..d72c77793d --- /dev/null +++ b/bson/vector.go @@ -0,0 +1,166 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +// These constants indicate vector data types. +const ( + Int8Vector = 0x03 + Float32Vector = 0x27 + PackedBitVector = 0x10 +) + +var ( + ErrNotVector = errors.New("not a vector") + ErrInsufficientData = errors.New("insufficient data") + ErrNonZeroPadding = errors.New("padding must be 0") + ErrPaddingTooLarge = errors.New("padding larger than 7") +) + +type Vector[T int8 | float32] struct { + Data []T +} + +type BitVector struct { + Padding uint8 + Data []byte +} + +func newInt8Vector(b []byte) (Vector[int8], error) { + var v Vector[int8] + if len(b) == 0 { + return v, ErrInsufficientData + } + if padding := b[0]; padding > 0 { + return v, ErrNonZeroPadding + } + s := make([]int8, 0, len(b)-1) + for i := 1; i < len(b); i++ { + s = append(s, int8(b[i])) + } + v.Data = s + return v, nil +} + +func newFloat32Vector(b []byte) (Vector[float32], error) { + var v Vector[float32] + if len(b) == 0 { + return v, ErrInsufficientData + } + if padding := b[0]; padding > 0 { + return v, ErrNonZeroPadding + } + l := (len(b) - 1) / 4 + if l*4 != len(b)-1 { + return v, ErrInsufficientData + } + s := make([]float32, 0, l) + for i := 1; i < len(b); i += 4 { + s = append(s, math.Float32frombits(binary.LittleEndian.Uint32(b[i:i+4]))) + } + v.Data = s + return v, nil +} + +func newBitVector(b []byte) (BitVector, error) { + var v BitVector + if len(b) == 0 { + return v, ErrInsufficientData + } + padding := b[0] + if padding > 7 { + return v, ErrPaddingTooLarge + } + if padding > 0 && len(b) == 1 { + return v, ErrNonZeroPadding + } + v.Padding = padding + v.Data = b[1:] + return v, nil +} + +func NewVectorFromBinary(b Binary) (interface{}, error) { + if b.Subtype != TypeBinaryVector { + return nil, ErrNotVector + } + if len(b.Data) < 2 { + return nil, ErrInsufficientData + } + switch t := b.Data[0]; t { + case Int8Vector: + return newInt8Vector(b.Data[1:]) + case Float32Vector: + return newFloat32Vector(b.Data[1:]) + case PackedBitVector: + return newBitVector(b.Data[1:]) + default: + return nil, fmt.Errorf("invalid Vector data type: %d", t) + } +} + +func binaryFromFloat32Vector(v Vector[float32]) (Binary, error) { + data := make([]byte, 2, len(v.Data)*4+2) + copy(data, []byte{Float32Vector, 0}) + var a [4]byte + for _, e := range v.Data { + binary.LittleEndian.PutUint32(a[:], math.Float32bits(e)) + data = append(data, a[:]...) + } + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + }, nil +} + +func binaryFromInt8Vector(v Vector[int8]) (Binary, error) { + data := make([]byte, 2, len(v.Data)+2) + copy(data, []byte{Int8Vector, 0}) + for _, e := range v.Data { + data = append(data, byte(e)) + } + + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + }, nil +} + +func binaryFromBitVector(v BitVector) (Binary, error) { + var b Binary + if v.Padding > 7 { + return b, ErrPaddingTooLarge + } + if v.Padding > 0 && len(v.Data) == 0 { + return b, ErrNonZeroPadding + } + data := []byte{PackedBitVector, v.Padding} + data = append(data, v.Data...) + return Binary{ + Subtype: TypeBinaryVector, + Data: data, + }, nil +} + +func NewBinaryFromVector[T BitVector | Vector[int8] | Vector[float32]](v T) (Binary, error) { + switch a := any(v); a.(type) { + case Vector[int8]: + return binaryFromInt8Vector(a.(Vector[int8])) + case Vector[float32]: + return binaryFromFloat32Vector(a.(Vector[float32])) + case BitVector: + return binaryFromBitVector(a.(BitVector)) + default: + return Binary{}, fmt.Errorf("unsupported type %T", v) + } +} diff --git a/testdata/bson-binary-vector/float32.json b/testdata/bson-binary-vector/float32.json new file mode 100644 index 0000000000..d423f9e2bd --- /dev/null +++ b/testdata/bson-binary-vector/float32.json @@ -0,0 +1,50 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype FLOAT32", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector FLOAT32", + "valid": true, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" + }, + { + "description": "Vector with decimals and negative value FLOAT32", + "valid": true, + "vector": [127.7, -7.7], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927006666FF426666F6C000" + }, + { + "description": "Empty Vector FLOAT32", + "valid": true, + "vector": [], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009270000" + }, + { + "description": "Infinity Vector FLOAT32", + "valid": true, + "vector": ["-inf", 0.0, "inf"], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "2000000005766563746F72000E000000092700000080FF000000000000807F00" + }, + { + "description": "FLOAT32 with padding", + "valid": false, + "vector": [127.0, 7.0], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 3 + } + ] +} diff --git a/testdata/bson-binary-vector/int8.json b/testdata/bson-binary-vector/int8.json new file mode 100644 index 0000000000..d849819992 --- /dev/null +++ b/testdata/bson-binary-vector/int8.json @@ -0,0 +1,56 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype INT8", + "test_key": "vector", + "tests": [ + { + "description": "Simple Vector INT8", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000903007F0700" + }, + { + "description": "Empty Vector INT8", + "valid": true, + "vector": [], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009030000" + }, + { + "description": "Overflow Vector INT8", + "valid": false, + "vector": [128], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "Underflow Vector INT8", + "valid": false, + "vector": [-129], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + }, + { + "description": "INT8 with padding", + "valid": false, + "vector": [127, 7], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 3 + }, + { + "description": "INT8 with float inputs", + "valid": false, + "vector": [127.77, 7.77], + "dtype_hex": "0x03", + "dtype_alias": "INT8", + "padding": 0 + } + ] +} diff --git a/testdata/bson-binary-vector/packed_bit.json b/testdata/bson-binary-vector/packed_bit.json new file mode 100644 index 0000000000..0d5dae52b4 --- /dev/null +++ b/testdata/bson-binary-vector/packed_bit.json @@ -0,0 +1,97 @@ +{ + "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", + "test_key": "vector", + "tests": [ + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Simple Vector PACKED_BIT", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1600000005766563746F7200040000000910007F0700" + }, + { + "description": "Empty Vector PACKED_BIT", + "valid": true, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0, + "canonical_bson": "1400000005766563746F72000200000009100000" + }, + { + "description": "PACKED_BIT with padding", + "valid": true, + "vector": [127, 7], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000910037F0700" + }, + { + "description": "Overflow Vector PACKED_BIT", + "valid": false, + "vector": [256], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Underflow Vector PACKED_BIT", + "valid": false, + "vector": [-1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1 + }, + { + "description": "Exceeding maximum padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 8 + }, + { + "description": "Negative padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": -1 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + } + ] +} diff --git a/testdata/bson-corpus/binary.json b/testdata/bson-corpus/binary.json index 20aaef743b..0e0056f3a2 100644 --- a/testdata/bson-corpus/binary.json +++ b/testdata/bson-corpus/binary.json @@ -74,6 +74,36 @@ "description": "$type query operator (conflicts with legacy $binary form with $type field)", "canonical_bson": "180000000378001000000010247479706500020000000000", "canonical_extjson": "{\"x\" : { \"$type\" : {\"$numberInt\": \"2\"}}}" + }, + { + "description": "subtype 0x09 Vector FLOAT32", + "canonical_bson": "170000000578000A0000000927000000FE420000E04000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwAAAP5CAADgQA==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector INT8", + "canonical_bson": "11000000057800040000000903007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector PACKED_BIT", + "canonical_bson": "11000000057800040000000910007F0700", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAB/Bw==\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) FLOAT32", + "canonical_bson": "0F0000000578000200000009270000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"JwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) INT8", + "canonical_bson": "0F0000000578000200000009030000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"AwA=\", \"subType\": \"09\"}}}" + }, + { + "description": "subtype 0x09 Vector (Zero-length) PACKED_BIT", + "canonical_bson": "0F0000000578000200000009100000", + "canonical_extjson": "{\"x\": {\"$binary\": {\"base64\": \"EAA=\", \"subType\": \"09\"}}}" } ], "decodeErrors": [