From e3aefb92d1578600de5a3fee1b9e72eff1faf788 Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 25 Dec 2024 16:24:16 +0400 Subject: [PATCH] Add JSON marshalling and unmarshalling (#71) Co-authored-by: krishna sindhur --- wirebson/array.go | 50 ++++- wirebson/binary.go | 6 +- wirebson/bson_test.go | 430 +++++++++++++++++++++++++++++------------- wirebson/document.go | 50 ++++- 4 files changed, 399 insertions(+), 137 deletions(-) diff --git a/wirebson/array.go b/wirebson/array.go index 583178f..78c999b 100644 --- a/wirebson/array.go +++ b/wirebson/array.go @@ -17,11 +17,14 @@ package wirebson import ( "bytes" "encoding/binary" + "encoding/json" "iter" "log/slog" "sort" "strconv" + "go.mongodb.org/mongo-driver/v2/bson" + "github.com/FerretDB/wire/internal/util/lazyerrors" "github.com/FerretDB/wire/internal/util/must" ) @@ -177,6 +180,23 @@ func (arr *Array) Encode() (RawArray, error) { return buf.Bytes(), nil } +// MarshalJSON implements [json.Marshaler]. +func (arr *Array) MarshalJSON() ([]byte, error) { + must.NotBeZero(arr) + + a, err := toDriver(arr) + if err != nil { + return nil, lazyerrors.Error(err) + } + + b, err := bson.MarshalExtJSON(a, true, false) + if err != nil { + return nil, lazyerrors.Error(err) + } + + return b, nil +} + // Decode returns itself to implement [AnyArray]. // // Receiver must not be nil. @@ -185,6 +205,30 @@ func (arr *Array) Decode() (*Array, error) { return arr, nil } +// UnmarshalJSON implements [json.Unmarshaler]. +func (arr *Array) UnmarshalJSON(b []byte) error { + must.NotBeZero(arr) + + var a bson.A + if err := bson.UnmarshalExtJSON(b, true, &a); err != nil { + return lazyerrors.Error(err) + } + + v, err := fromDriver(a) + if err != nil { + return lazyerrors.Error(err) + } + + switch v := v.(type) { + case *Array: + must.NotBeZero(v) + *arr = *v + return nil + default: + return lazyerrors.Errorf("expected *Array, got %T", v) + } +} + // LogValue implements [slog.LogValuer]. func (arr *Array) LogValue() slog.Value { return slogValue(arr, 1) @@ -202,6 +246,8 @@ func (arr *Array) LogMessageIndent() string { // check interfaces var ( - _ AnyArray = (*Array)(nil) - _ slog.LogValuer = (*Array)(nil) + _ AnyArray = (*Array)(nil) + _ slog.LogValuer = (*Array)(nil) + _ json.Marshaler = (*Array)(nil) + _ json.Unmarshaler = (*Array)(nil) ) diff --git a/wirebson/binary.go b/wirebson/binary.go index c962ef9..03d2bcf 100644 --- a/wirebson/binary.go +++ b/wirebson/binary.go @@ -89,10 +89,8 @@ func decodeBinary(b []byte) (Binary, error) { res.Subtype = BinarySubtype(b[4]) - if i > 0 { - res.B = make([]byte, i) - copy(res.B, b[5:5+i]) - } + res.B = make([]byte, i) + copy(res.B, b[5:5+i]) return res, nil } diff --git a/wirebson/bson_test.go b/wirebson/bson_test.go index e9874ce..078cb59 100644 --- a/wirebson/bson_test.go +++ b/wirebson/bson_test.go @@ -16,6 +16,7 @@ package wirebson import ( "encoding/hex" + "encoding/json" "math" "strings" "testing" @@ -35,6 +36,7 @@ type normalTestCase struct { raw RawDocument doc *Document mi string + j string } // decodeTestCase represents a single test case for unsuccessful decoding. @@ -279,7 +281,7 @@ var normalTestCases = []normalTestCase{ "string", MustArray("foo", ""), "binary", MustArray( Binary{Subtype: BinaryUser, B: []byte{0x42}}, - Binary{}, + Binary{Subtype: BinaryGeneric, B: []byte{}}, ), "objectID", MustArray(ObjectID{0x42}, ObjectID{}), "bool", MustArray(true, false), @@ -361,6 +363,137 @@ var normalTestCases = []normalTestCase{ Decimal128(0,0), ], }`, + j: ` + { + "document": [ + { + "": "foo", + "bar": "baz", + "": "qux" + }, + {} + ], + "array": [ + [ + "foo" + ], + [] + ], + "float64": [ + { + "$numberDouble": "42.13" + }, + { + "$numberDouble": "0.0" + }, + { + "$numberDouble": "-0.0" + }, + { + "$numberDouble": "Infinity" + }, + { + "$numberDouble": "-Infinity" + } + ], + "string": [ + "foo", + "" + ], + "binary": [ + { + "$binary": { + "base64": "Qg==", + "subType": "80" + } + }, + { + "$binary": { + "base64": "", + "subType": "00" + } + } + ], + "objectID": [ + { + "$oid": "420000000000000000000000" + }, + { + "$oid": "000000000000000000000000" + } + ], + "bool": [ + true, + false + ], + "datetime": [ + { + "$date": { + "$numberLong": "1627378542123" + } + }, + { + "$date": { + "$numberLong": "-62135596800000" + } + } + ], + "null": [ + null + ], + "regex": [ + { + "$regularExpression": { + "pattern": "p", + "options": "o" + } + }, + { + "$regularExpression": { + "pattern": "", + "options": "" + } + } + ], + "int32": [ + { + "$numberInt": "42" + }, + { + "$numberInt": "0" + } + ], + "timestamp": [ + { + "$timestamp": { + "t": 0, + "i": 42 + } + }, + { + "$timestamp": { + "t": 0, + "i": 0 + } + } + ], + "int64": [ + { + "$numberLong": "42" + }, + { + "$numberLong": "0" + } + ], + "decimal128": [ + { + "$numberDecimal": "2.39807672958224171050E-6156" + }, + { + "$numberDecimal": "0E-6176" + } + ] + }`, }, { name: "nested", @@ -636,6 +769,10 @@ var normalTestCases = []normalTestCase{ { "foo": [], }`, + j: ` + { + "foo": [] + }`, }, { name: "duplicateKeys", @@ -654,6 +791,11 @@ var normalTestCases = []normalTestCase{ "": false, "": true, }`, + j: ` + { + "": false, + "": true + }`, }, } @@ -802,11 +944,19 @@ func TestNormal(t *testing.T) { assert.Equal(t, tc.raw, raw, "actual:\n"+hex.Dump(raw)) }) - t.Run("ToDriverFromDrive", func(t *testing.T) { - d, err := toDriver(tc.doc) + t.Run("MarshalUnmarshal", func(t *testing.T) { + // We should set all tc.j and remove this Skip. + // TODO https://github.com/FerretDB/wire/issues/49 + if tc.j == "" { + t.Skip("https://github.com/FerretDB/wire/issues/49") + } + + b, err := json.MarshalIndent(tc.doc, "", " ") require.NoError(t, err) + assert.Equal(t, testutil.Unindent(tc.j), string(b)) - doc, err := fromDriver(d) + var doc *Document + err = json.Unmarshal([]byte(tc.j), &doc) require.NoError(t, err) assert.Equal(t, tc.doc, doc) }) @@ -865,203 +1015,223 @@ var drain any func BenchmarkDocumentDecode(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - b.ReportAllocs() + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() - var err error - for range b.N { - drain, err = tc.raw.Decode() - } + var err error + for range b.N { + drain, err = tc.raw.Decode() + } - b.StopTimer() + b.StopTimer() - require.NoError(b, err) - require.NotNil(b, drain) - }) + require.NoError(b, err) + require.NotNil(b, drain) + }) + } } } func BenchmarkDocumentDecodeDeep(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - b.ReportAllocs() + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() - var err error - for range b.N { - drain, err = tc.raw.DecodeDeep() - } + var err error + for range b.N { + drain, err = tc.raw.DecodeDeep() + } - b.StopTimer() + b.StopTimer() - require.NoError(b, err) - require.NotNil(b, drain) - }) + require.NoError(b, err) + require.NotNil(b, drain) + }) + } } } func BenchmarkDocumentEncode(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.Decode() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.Decode() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain, err = doc.Encode() - } + for range b.N { + drain, err = doc.Encode() + } - b.StopTimer() + b.StopTimer() - require.NoError(b, err) - assert.NotNil(b, drain) - }) + require.NoError(b, err) + assert.NotNil(b, drain) + }) + } } } func BenchmarkDocumentEncodeDeep(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.DecodeDeep() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.DecodeDeep() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain, err = doc.Encode() - } + for range b.N { + drain, err = doc.Encode() + } - b.StopTimer() + b.StopTimer() - require.NoError(b, err) - assert.NotNil(b, drain) - }) + require.NoError(b, err) + assert.NotNil(b, drain) + }) + } } } func BenchmarkDocumentLogValue(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.Decode() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.Decode() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogValue().Resolve().String() - } + for range b.N { + drain = doc.LogValue().Resolve().String() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - assert.NotContains(b, drain, "panicked") - assert.NotContains(b, drain, "called too many times") - }) + assert.NotEmpty(b, drain) + assert.NotContains(b, drain, "panicked") + assert.NotContains(b, drain, "called too many times") + }) + } } } func BenchmarkDocumentLogValueDeep(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.DecodeDeep() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.DecodeDeep() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogValue().Resolve().String() - } + for range b.N { + drain = doc.LogValue().Resolve().String() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - assert.NotContains(b, drain, "panicked") - assert.NotContains(b, drain, "called too many times") - }) + assert.NotEmpty(b, drain) + assert.NotContains(b, drain, "panicked") + assert.NotContains(b, drain, "called too many times") + }) + } } } func BenchmarkDocumentLogMessage(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.Decode() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.Decode() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogMessage() - } + for range b.N { + drain = doc.LogMessage() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - }) + assert.NotEmpty(b, drain) + }) + } } } func BenchmarkDocumentLogMessageDeep(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.DecodeDeep() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.DecodeDeep() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogMessage() - } + for range b.N { + drain = doc.LogMessage() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - }) + assert.NotEmpty(b, drain) + }) + } } } func BenchmarkDocumentLogMessageIndent(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.Decode() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.Decode() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogMessageIndent() - } + for range b.N { + drain = doc.LogMessageIndent() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - }) + assert.NotEmpty(b, drain) + }) + } } } func BenchmarkDocumentLogMessageIndentDeep(b *testing.B) { for _, tc := range normalTestCases { - b.Run(tc.name, func(b *testing.B) { - doc, err := tc.raw.DecodeDeep() - require.NoError(b, err) + if tc.raw != nil { + b.Run(tc.name, func(b *testing.B) { + doc, err := tc.raw.DecodeDeep() + require.NoError(b, err) - b.ReportAllocs() - b.ResetTimer() + b.ReportAllocs() + b.ResetTimer() - for range b.N { - drain = doc.LogMessageIndent() - } + for range b.N { + drain = doc.LogMessageIndent() + } - b.StopTimer() + b.StopTimer() - assert.NotEmpty(b, drain) - }) + assert.NotEmpty(b, drain) + }) + } } } @@ -1121,24 +1291,27 @@ func testRawDocument(t *testing.T, rawDoc RawDocument) { assert.Equal(t, rawDoc, raw) }) - t.Run("ToDriverFromDriver", func(t *testing.T) { + t.Run("MarshalUnmarshal", func(t *testing.T) { doc, err := rawDoc.DecodeDeep() if err != nil { return } - d, err := toDriver(doc) - require.NoError(t, err) + b, err := json.Marshal(doc) + d, _ := toDriver(doc) + require.NoError(t, err, "%s\n%#v", doc.LogMessage(), d) - doc2, err := fromDriver(d) + var doc2 *Document + err = json.Unmarshal(b, &doc2) require.NoError(t, err) - assert.Equal(t, doc, doc2) }) } func FuzzDocument(f *testing.F) { for _, tc := range normalTestCases { - f.Add([]byte(tc.raw)) + if tc.raw != nil { + f.Add([]byte(tc.raw)) + } } for _, tc := range decodeTestCases { @@ -1148,11 +1321,10 @@ func FuzzDocument(f *testing.F) { f.Fuzz(func(t *testing.T, b []byte) { t.Parallel() - testRawDocument(t, RawDocument(b)) + rawDoc := RawDocument(b) - l, err := FindRaw(b) - if err == nil { - testRawDocument(t, RawDocument(b[:l])) - } + t.Run("TestRawDocument", func(t *testing.T) { + testRawDocument(t, rawDoc) + }) }) } diff --git a/wirebson/document.go b/wirebson/document.go index 53b15a5..3180d43 100644 --- a/wirebson/document.go +++ b/wirebson/document.go @@ -17,10 +17,13 @@ package wirebson import ( "bytes" "encoding/binary" + "encoding/json" "iter" "log/slog" "slices" + "go.mongodb.org/mongo-driver/v2/bson" + "github.com/FerretDB/wire/internal/util/lazyerrors" "github.com/FerretDB/wire/internal/util/must" ) @@ -253,6 +256,23 @@ func (doc *Document) Encode() (RawDocument, error) { return buf.Bytes(), nil } +// MarshalJSON implements [json.Marshaler]. +func (doc *Document) MarshalJSON() ([]byte, error) { + must.NotBeZero(doc) + + d, err := toDriver(doc) + if err != nil { + return nil, lazyerrors.Error(err) + } + + b, err := bson.MarshalExtJSON(d, true, false) + if err != nil { + return nil, lazyerrors.Error(err) + } + + return b, nil +} + // Decode returns itself to implement [AnyDocument]. // // Receiver must not be nil. @@ -261,6 +281,30 @@ func (doc *Document) Decode() (*Document, error) { return doc, nil } +// UnmarshalJSON implements [json.Unmarshaler]. +func (doc *Document) UnmarshalJSON(b []byte) error { + must.NotBeZero(doc) + + var d bson.D + if err := bson.UnmarshalExtJSON(b, true, &d); err != nil { + return lazyerrors.Error(err) + } + + v, err := fromDriver(d) + if err != nil { + return lazyerrors.Error(err) + } + + switch v := v.(type) { + case *Document: + must.NotBeZero(v) + *doc = *v + return nil + default: + return lazyerrors.Errorf("expected *Document, got %T", v) + } +} + // LogValue implements [slog.LogValuer]. func (doc *Document) LogValue() slog.Value { return slogValue(doc, 1) @@ -278,6 +322,8 @@ func (doc *Document) LogMessageIndent() string { // check interfaces var ( - _ AnyDocument = (*Document)(nil) - _ slog.LogValuer = (*Document)(nil) + _ AnyDocument = (*Document)(nil) + _ slog.LogValuer = (*Document)(nil) + _ json.Marshaler = (*Document)(nil) + _ json.Unmarshaler = (*Document)(nil) )