Skip to content

Commit

Permalink
fix Unmarshal buffer corruption
Browse files Browse the repository at this point in the history
Also speed up SingleDecoder a lot (we'd forgotten to cache
the compile results) and unmarshaling in general (the debugf
calls were really slowing things down).

Here are the new benchmark results after these changes have been
made:

```
BenchmarkMarshal-4                  	10628875	       551 ns/op	     240 B/op	       4 allocs/op
BenchmarkSingleDecoderUnmarshal-4   	 3596359	      1619 ns/op	     272 B/op	      14 allocs/op
```
  • Loading branch information
rogpeppe committed Jan 27, 2020
1 parent a1153d7 commit 7b64504
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 20 deletions.
38 changes: 30 additions & 8 deletions analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ func compileDecoder(names *Names, t reflect.Type, writerType *Type) (*decodeProg
if err != nil {
return nil, fmt.Errorf("cannot determine schema for %s: %v", t, err)
}
if debugging {
debugf("compiling:\nwriter type: %s\nreader type: %s\n", writerType, readerType)
}
prog, err := compiler.Compile(writerType.avroType, readerType.avroType)
if err != nil {
return nil, fmt.Errorf("cannot create decoder: %v", err)
Expand All @@ -91,7 +94,9 @@ func analyzeProgramTypes(prog *vm.Program, t reflect.Type) (*decodeProgram, erro
enter: make([]func(reflect.Value) (reflect.Value, bool), len(prog.Instructions)),
makeDefault: make([]func() reflect.Value, len(prog.Instructions)),
}
debugf("analyze %d instructions\n%s {", len(prog.Instructions), prog)
if debugging {
debugf("analyze %d instructions; type %s\n%s {", len(prog.Instructions), t, prog)
}
defer debugf("}")
info, err := newAzTypeInfo(t)
if err != nil {
Expand Down Expand Up @@ -126,9 +131,13 @@ func analyzeProgramTypes(prog *vm.Program, t reflect.Type) (*decodeProgram, erro
}

func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
debugf("eval %v; path %s{", stack, pathStr(path))
if debugging {
debugf("analyzer.eval %v; path %s{", stack, pathStr(path))
}
defer func() {
debugf("} -> %v", retErr)
if debugging {
debugf("} -> %v", retErr)
}
}()
for {
pc := stack[len(stack)-1]
Expand All @@ -140,7 +149,9 @@ func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
// of the current path.
a.pcInfo[pc].path = append(a.pcInfo[pc].path, path...)
} else {
debugf("already evaluated instruction %d", pc)
if debugging {
debugf("already evaluated instruction %d", pc)
}
// We've already visited this instruction which
// means we can stop analysing here.
// Make sure that the path is consistent though,
Expand All @@ -150,12 +161,17 @@ func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
}
return nil
}
debugf("exec %d: %v", pc, a.prog.Instructions[pc])
if debugging {
debugf("exec %d: %v", pc, a.prog.Instructions[pc])
}

elem := path[len(path)-1]
switch inst := a.prog.Instructions[pc]; inst.Op {
case vm.Set:
if elem.info.isUnion {
if debugging {
debugf("patching Set to Nop")
}
// Set on a union type is just to set the type of the union,
// which is implicit with the next Enter, so we want to just
// ignore the instruction, so replace it with a jump to the next instruction,
Expand All @@ -179,7 +195,9 @@ func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
return fmt.Errorf("union index out of bounds; pc %d; type %s", pc, elem.ftype)
}
info := elem.info.entries[index]
debugf("enter %d -> %v, %d entries", index, info.ftype, len(info.entries))
if debugging {
debugf("enter %d -> %v, %d entries", index, info.ftype, len(info.entries))
}
if info.ftype == nil {
// Special case for the nil value. Return
// a zero value that will never be used.
Expand Down Expand Up @@ -270,7 +288,9 @@ func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
}
stack = stack[:len(stack)-1]
case vm.CondJump:
debugf("split {")
if debugging {
debugf("split {")
}
// Execute one path of the condition with a forked
// version of the state before carrying on with the
// current execution flow.
Expand All @@ -282,7 +302,9 @@ func (a *analyzer) eval(stack []int, path []pathElem) (retErr error) {
if err := a.eval(stack1, path1); err != nil {
return err
}
debugf("}")
if debugging {
debugf("}")
}
case vm.Jump:
stack[len(stack)-1] = inst.Operand - 1
case vm.EvalGreater,
Expand Down
74 changes: 74 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package avro_test

import (
"context"
"testing"

qt "github.com/frankban/quicktest"
"github.com/heetch/avro"
)

// BenchmarkMarshal times avro.Marshal on a small nested record holding
// two optional strings and a short slice of ints.
func BenchmarkMarshal(b *testing.B) {
	type R struct {
		A *string
		B *string
		C []int
	}
	type T struct {
		R R
	}

	// Build the value once, outside the loop, so that only the
	// Marshal call itself is repeated.
	var value T
	value.R.A = newString("hello")
	value.R.B = newString("goodbye")
	value.R.C = []int{1, 3, 1 << 20}

	for iter := 0; iter < b.N; iter++ {
		if _, _, err := avro.Marshal(value); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkSingleDecoderUnmarshal times SingleDecoder.Unmarshal on a
// message produced by a SingleEncoder that shares the same in-memory
// registry. Everything up to b.ResetTimer is setup and is not timed.
func BenchmarkSingleDecoderUnmarshal(b *testing.B) {
	c := qt.New(b)
	type R struct {
		A *string
		B *string
		C []int
	}
	type T struct {
		R R
	}

	// Register the schema for T under ID 1.
	avroType, err := avro.TypeOf(T{})
	c.Assert(err, qt.Equals, nil)
	registry := memRegistry{
		1: avroType.String(),
	}

	// Encode a single message to use as the benchmark input.
	ctx := context.Background()
	var input T
	input.R.A = newString("hello")
	input.R.B = newString("goodbye")
	input.R.C = []int{1, 3, 1 << 20}
	encoder := avro.NewSingleEncoder(registry, nil)
	data, err := encoder.Marshal(ctx, input)
	c.Assert(err, qt.Equals, nil)

	decoder := avro.NewSingleDecoder(registry, nil)
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		// Decode into a fresh value on every iteration.
		var out T
		if _, err := decoder.Unmarshal(ctx, data, &out); err != nil {
			b.Fatal(err)
		}
	}
}

// newString returns a pointer to a newly allocated string holding s.
// It is a convenience for filling in *string fields in test values.
func newString(s string) *string {
	p := new(string)
	*p = s
	return p
}
31 changes: 22 additions & 9 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ type decodeError struct {
// unmarshal unmarshals Avro binary data from r and writes it to target
// following the given program.
func unmarshal(r io.Reader, buf []byte, prog *decodeProgram, target reflect.Value) (_ *Type, err error) {
if debugging {
debugf("unmarshal %x into %s", buf, target.Type())
}
defer func() {
switch panicErr := recover().(type) {
case *decodeError:
Expand All @@ -92,15 +95,19 @@ func unmarshal(r io.Reader, buf []byte, prog *decodeProgram, target reflect.Valu
}

func (d *decoder) eval(target reflect.Value) {
if target.IsValid() {
debugf("eval %s", target.Type())
} else {
debugf("eval nil")
if debugging {
if target.IsValid() {
debugf("eval %s", target.Type())
} else {
debugf("eval nil")
}
defer debugf("}")
}
defer debugf("}")
var frame stackFrame
for ; d.pc < len(d.program.Instructions); d.pc++ {
debugf("x %d: %v", d.pc, d.program.Instructions[d.pc])
if debugging {
debugf("x %d: %v", d.pc, d.program.Instructions[d.pc])
}
switch inst := d.program.Instructions[d.pc]; inst.Op {
case vm.Read:
switch inst.Operand {
Expand All @@ -126,7 +133,9 @@ func (d *decoder) eval(target reflect.Value) {
frame.Bytes = d.readFixed(inst.Operand - 11)
}
case vm.Set:
debugf("%v on %s", inst, target.Type())
if debugging {
debugf("%v on %s", inst, target.Type())
}
switch inst.Operand {
case vm.Null:
case vm.Boolean:
Expand Down Expand Up @@ -170,14 +179,18 @@ func (d *decoder) eval(target reflect.Value) {
target.Field(inst.Operand).Set(v)
case vm.Enter:
val, isRef := d.program.enter[d.pc](target)
debugf("enter %d -> %#v (isRef %v) {", inst.Operand, val, isRef)
if debugging {
debugf("enter %d -> %#v (isRef %v) {", inst.Operand, val, isRef)
}
d.pc++
d.eval(val)
if !isRef {
target.Set(val)
}
case vm.Exit:
debugf("}")
if debugging {
debugf("}")
}
return
case vm.AppendArray:
target.Set(reflect.Append(target, reflect.Zero(target.Type().Elem())))
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.2.2 h1:xfmOhhoH5fGPgbEAlhLpJH9p0z/0Qizio9osmvn9IUY=
github.com/frankban/quicktest v1.2.2/go.mod h1:Qh/WofXFeiAFII1aEBu529AtJo6Zg2VHscnEsbBnJ20=
github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.1-0.20190312032427-6f77996f0c42 h1:q3pnF5JFBNRz8sRD+IRj7Y6DMyYGTNqnZ9axTbSfoNI=
github.com/google/go-cmp v0.2.1-0.20190312032427-6f77996f0c42/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
Expand Down
26 changes: 26 additions & 0 deletions gotype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,32 @@ func TestProtobufGeneratedType(t *testing.T) {
}`))
}

// TestUnmarshalDoesNotCorruptData checks that avro.Unmarshal does not
// modify the byte slice it is given: decoding the same data twice must
// succeed both times and the bytes must still match what Marshal
// produced.
func TestUnmarshalDoesNotCorruptData(t *testing.T) {
	c := qt.New(t)
	type R struct {
		A *string
		B *string
	}
	type T struct {
		R R
	}
	x := T{
		R: R{
			A: newString("hello"),
			B: newString("goodbye"),
		},
	}
	data, at, err := avro.Marshal(x)
	c.Assert(err, qt.Equals, nil)
	// Snapshot the encoded bytes with a real copy. A plain slice
	// assignment would share data's backing array, so any in-place
	// corruption would appear in both slices and the comparison at
	// the end could never fail.
	origData := append([]byte(nil), data...)
	var x1 T
	_, err = avro.Unmarshal(data, &x1, at)
	c.Assert(err, qt.Equals, nil)
	// Decode a second time: if the first call had scribbled on data,
	// this is where a decode error would be expected to surface.
	_, err = avro.Unmarshal(data, &x1, at)
	c.Assert(err, qt.Equals, nil)
	c.Assert(data, qt.DeepEquals, origData)
}

type OOBPanicEnum int

var enumValues = []string{"a", "b"}
Expand Down
7 changes: 7 additions & 0 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ func (d *decoder) fill(n int) int {
if len(d.buf)-d.scan >= n {
return n
}
if d.readErr != nil {
// If there's an error, there's no point in doing
// anything more. This is also crucial to avoid
// corrupting the buffer when it has been provided by a
// caller.
return len(d.buf) - d.scan
}
// Slide any remaining bytes to the
// start of the buffer.
total := copy(d.buf, d.buf[d.scan:])
Expand Down
4 changes: 4 additions & 0 deletions singledecoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ func (c *SingleDecoder) getProgram(ctx context.Context, vt reflect.Type, wID int
c.mu.RUnlock()
return prog, nil
}
if debugging {
debugf("no hit found for program %T schemaID %v", vt, wID)
}
wType := c.writerTypes[wID]
c.mu.RUnlock()

Expand Down Expand Up @@ -142,5 +145,6 @@ func (c *SingleDecoder) getProgram(ctx context.Context, vt reflect.Type, wID int
}
return nil, err
}
c.programs[decoderSchemaPair{vt, wID}] = prog
return prog, nil
}
12 changes: 9 additions & 3 deletions typeinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ type azTypeInfo struct {
}

func newAzTypeInfo(t reflect.Type) (azTypeInfo, error) {
debugf("azTypeInfo(%v)", t)
if debugging {
debugf("azTypeInfo(%v)", t)
}
switch t.Kind() {
case reflect.Struct:
info := azTypeInfo{
Expand Down Expand Up @@ -84,12 +86,16 @@ func newAzTypeInfo(t reflect.Type) (azTypeInfo, error) {
entry := newAzTypeInfoFromField(f, required, makeDefault, unionInfo)
info.entries = append(info.entries, entry)
}
debugf("-> record, %d entries", len(info.entries))
if debugging {
debugf("-> record, %d entries", len(info.entries))
}
return info, nil
default:
// TODO check for top-level union types too.
// See https://github.com/heetch/avro/issues/13
debugf("-> unknown")
if debugging {
debugf("-> unknown")
}
return azTypeInfo{
ftype: t,
}, nil
Expand Down

0 comments on commit 7b64504

Please sign in to comment.