From 851ffa91e37db2c1764e1696ac11a2cb3ca20167 Mon Sep 17 00:00:00 2001 From: Roger Peppe Date: Mon, 17 Feb 2020 15:34:09 +0000 Subject: [PATCH] avro: more efficient SingleEncoder This will make only a single call to the registry for a given type. Also add a `CheckMarshalType` method that can be used to check that a type is OK with a `SingleEncoder` and also populate the cache in advance. --- singleencoder.go | 34 +++++++++++++--- singleencoder_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 7 deletions(-) diff --git a/singleencoder.go b/singleencoder.go index cc0a538..ba916c7 100644 --- a/singleencoder.go +++ b/singleencoder.go @@ -3,6 +3,7 @@ package avro import ( "context" "reflect" + "sync" ) // EncodingRegistry is used by SingleEncoder to find @@ -21,6 +22,8 @@ type EncodingRegistry interface { type SingleEncoder struct { registry EncodingRegistry names *Names + // ids holds a map from Go type (reflect.Type) to schema ID (int64) + ids sync.Map } // NewSingleEncoder returns a SingleEncoder instance that encodes single @@ -39,16 +42,20 @@ func NewSingleEncoder(r EncodingRegistry, names *Names) *SingleEncoder { } } +// CheckMarshalType checks that the given type can be marshaled with the encoder. +// It also caches any type information obtained from the EncodingRegistry for the +// type, so future calls to Marshal with that type won't call the registry. +func (enc *SingleEncoder) CheckMarshalType(ctx context.Context, x interface{}) error { + _, err := enc.idForType(ctx, reflect.TypeOf(x)) + return err +} + // Marshal returns x marshaled as using the Avro binary encoding, // along with an identifier that records the type that it was encoded // with. 
func (enc *SingleEncoder) Marshal(ctx context.Context, x interface{}) ([]byte, error) { xv := reflect.ValueOf(x) - avroType, err := avroTypeOf(enc.names, xv.Type()) - if err != nil { - return nil, err - } - id, err := enc.registry.IDForSchema(ctx, avroType.String()) + id, err := enc.idForType(ctx, xv.Type()) if err != nil { return nil, err } @@ -57,3 +64,20 @@ func (enc *SingleEncoder) Marshal(ctx context.Context, x interface{}) ([]byte, e data, _, err := marshalAppend(enc.names, buf, xv) return data, err } + +func (enc *SingleEncoder) idForType(ctx context.Context, t reflect.Type) (int64, error) { + id, ok := enc.ids.Load(t) + if ok { + return id.(int64), nil + } + avroType, err := avroTypeOf(enc.names, t) + if err != nil { + return 0, err + } + id1, err := enc.registry.IDForSchema(ctx, avroType.String()) + if err != nil { + return 0, err + } + enc.ids.LoadOrStore(t, id1) + return id1, nil +} diff --git a/singleencoder_test.go b/singleencoder_test.go index 0477635..d018c73 100644 --- a/singleencoder_test.go +++ b/singleencoder_test.go @@ -2,6 +2,7 @@ package avro_test import ( "context" + "sync" "testing" qt "github.com/frankban/quicktest" @@ -11,8 +12,7 @@ import ( func TestSingleEncoder(t *testing.T) { c := qt.New(t) - avroType, err := avro.TypeOf(TestRecord{}) - c.Assert(err, qt.Equals, nil) + avroType := mustTypeOf(TestRecord{}) registry := memRegistry{ 1: avroType.String(), } @@ -28,3 +28,92 @@ func TestSingleEncoder(t *testing.T) { c.Assert(err, qt.Equals, nil) c.Assert(x, qt.DeepEquals, TestRecord{A: 20, B: 34}) } + +func TestSingleEncoderCheckMarshalTypeBadType(t *testing.T) { + c := qt.New(t) + enc := avro.NewSingleEncoder(memRegistry{}, nil) + err := enc.CheckMarshalType(context.Background(), struct{ C chan int }{}) + c.Assert(err, qt.ErrorMatches, `cannot use unnamed type struct .*`) +} + +func TestSingleEncoderCheckMarshalTypeNotFound(t *testing.T) { + c := qt.New(t) + enc := avro.NewSingleEncoder(memRegistry{}, nil) + err := 
enc.CheckMarshalType(context.Background(), TestRecord{}) + c.Assert(err, qt.ErrorMatches, `schema not found`) +} + +func TestSingleEncoderCachesTypes(t *testing.T) { + c := qt.New(t) + registry := &statsRegistry{ + memRegistry: memRegistry{ + 1: mustTypeOf(TestRecord{}).String(), + }, + } + enc := avro.NewSingleEncoder(registry, nil) + data, err := enc.Marshal(context.Background(), TestRecord{A: 20, B: 34}) + c.Assert(err, qt.Equals, nil) + c.Assert(data, qt.DeepEquals, []byte{1, 40, 68}) + + // Check that when we marshal it again we don't get another + // call to the registry. + data, err = enc.Marshal(context.Background(), TestRecord{A: 22, B: 35}) + c.Assert(err, qt.Equals, nil) + c.Assert(data, qt.DeepEquals, []byte{1, 44, 70}) + c.Assert(registry.idForSchemaCount, qt.Equals, 1) +} + +func TestSingleEncoderRace(t *testing.T) { + // Note: this test is designed to be run with the + // race detector enabled. + + c := qt.New(t) + + type T1 struct { + A int + } + type T2 struct { + B int + } + registry := memRegistry{ + 1: mustTypeOf(T1{}).String(), + 2: mustTypeOf(T2{}).String(), + } + enc := avro.NewSingleEncoder(registry, nil) + var wg sync.WaitGroup + marshal := func(x interface{}) { + defer wg.Done() + _, err := enc.Marshal(context.Background(), x) + c.Check(err, qt.Equals, nil) + } + wg.Add(3) + go marshal(T1{10}) + go marshal(T1{11}) + go marshal(T2{12}) + wg.Wait() +} + +// statsRegistry wraps a memRegistry instance and counts calls to some of its methods. 
+type statsRegistry struct { + idForSchemaCount int + schemaForIDCount int + memRegistry +} + +func (r *statsRegistry) IDForSchema(ctx context.Context, schema string) (int64, error) { + r.idForSchemaCount++ + return r.memRegistry.IDForSchema(ctx, schema) +} + +func (r *statsRegistry) SchemaForID(ctx context.Context, id int64) (string, error) { + r.schemaForIDCount++ + return r.memRegistry.SchemaForID(ctx, id) +} + +func mustTypeOf(x interface{}) *avro.Type { + t, err := avro.TypeOf(x) + if err != nil { + panic(err) + } + return t +}