From 1d2df5f0e3d02e7ee7967531220da57552384b3d Mon Sep 17 00:00:00 2001 From: Eran Duchan Date: Sun, 16 Feb 2020 17:50:04 +0200 Subject: [PATCH] Support v3io stream consumer groups (#48) --- go.mod | 4 +- go.sum | 10 +- pkg/common/backoff.go | 99 +++++ pkg/common/helper.go | 142 +++++++ pkg/dataplane/container.go | 6 + pkg/dataplane/http/container.go | 14 + pkg/dataplane/http/context.go | 72 +++- pkg/dataplane/http/headers.go | 23 +- pkg/dataplane/item.go | 31 +- pkg/dataplane/streamconsumergroup/claim.go | 204 +++++++++ pkg/dataplane/streamconsumergroup/config.go | 54 +++ .../sequencenumberhandler.go | 237 +++++++++++ pkg/dataplane/streamconsumergroup/session.go | 92 ++++ pkg/dataplane/streamconsumergroup/state.go | 34 ++ .../streamconsumergroup/statehandler.go | 398 ++++++++++++++++++ .../streamconsumergroup/statehandler_test.go | 82 ++++ .../streamconsumergroup.go | 142 +++++++ pkg/dataplane/streamconsumergroup/types.go | 61 +++ .../test/streamconsumergroup_test.go | 368 ++++++++++++++++ pkg/dataplane/test/sync_test.go | 51 +-- pkg/dataplane/test/test.go | 61 +++ pkg/dataplane/types.go | 48 ++- pkg/errors/errors.go | 3 +- 23 files changed, 2140 insertions(+), 96 deletions(-) mode change 100755 => 100644 go.mod mode change 100755 => 100644 go.sum create mode 100644 pkg/common/backoff.go create mode 100644 pkg/common/helper.go create mode 100644 pkg/dataplane/streamconsumergroup/claim.go create mode 100644 pkg/dataplane/streamconsumergroup/config.go create mode 100644 pkg/dataplane/streamconsumergroup/sequencenumberhandler.go create mode 100644 pkg/dataplane/streamconsumergroup/session.go create mode 100644 pkg/dataplane/streamconsumergroup/state.go create mode 100644 pkg/dataplane/streamconsumergroup/statehandler.go create mode 100644 pkg/dataplane/streamconsumergroup/statehandler_test.go create mode 100644 pkg/dataplane/streamconsumergroup/streamconsumergroup.go create mode 100644 pkg/dataplane/streamconsumergroup/types.go create mode 100644 
pkg/dataplane/test/streamconsumergroup_test.go diff --git a/go.mod b/go.mod old mode 100755 new mode 100644 index 386152d..8cb35ca --- a/go.mod +++ b/go.mod @@ -4,15 +4,17 @@ go 1.12 require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kylelemons/godebug v1.1.0 // indirect github.com/mattn/go-colorable v0.1.1 // indirect github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect github.com/nuclio/errors v0.0.1 github.com/nuclio/logger v0.0.0-20190303161055-fc1e4b16d127 github.com/nuclio/zap v0.0.2 github.com/pavius/zap v1.4.2-0.20180228181622-8d52692529b8 // indirect + github.com/philhofer/fwd v1.0.0 // indirect github.com/pkg/errors v0.8.1 // indirect + github.com/rs/xid v1.1.0 github.com/stretchr/testify v1.3.0 + github.com/tinylib/msgp v1.1.1 // indirect github.com/valyala/fasthttp v1.2.0 go.uber.org/atomic v1.3.2 // indirect go.uber.org/multierr v1.1.0 // indirect diff --git a/go.sum b/go.sum old mode 100755 new mode 100644 index c2515a8..740bbad --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/capnproto/go-capnproto2 v2.17.0+incompatible h1:vPbYlc2CBNdjzOMzHfwo7TbFNRBDaRKitlWiRs1riTw= -github.com/capnproto/go-capnproto2 v2.17.0+incompatible/go.mod h1:T3/pxeK0qevFRlAASYZe90Ozs+JmlQTNY+VLc6+lJHw= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -8,8 +6,6 @@ github.com/klauspost/compress v1.4.0 h1:8nsMz3tWa9SWWPL60G1V6CUsf4lLjWLTNEtibhe8 github.com/klauspost/compress v1.4.0/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/cpuid v0.0.0-20180405133222-e7e905edc00e h1:+lIPJOWl+jSiJOc70QXJ07+2eg2Jy2EC7Mi11BWujeM= github.com/klauspost/cpuid v0.0.0-20180405133222-e7e905edc00e/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= -github.com/kylelemons/godebug v1.1.0 
h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5 h1:tHXDdz1cpzGaovsTB+TVB8q90WEokoVmfMqoVcrLUgw= @@ -24,13 +20,19 @@ github.com/nuclio/zap v0.0.2 h1:rY5PkMOl8CTkqRqIPuxziBiKK6Mq/8oEurfgRnNtqf0= github.com/nuclio/zap v0.0.2/go.mod h1:SUxPsgePvlyjx6c5MtGdB50pf0IQThtlyLwISLboeuc= github.com/pavius/zap v1.4.2-0.20180228181622-8d52692529b8 h1:WqLgmr/wj9TO5Sc6oYPQRAJBxuHE0NTeuVeFnT+FZVo= github.com/pavius/zap v1.4.2-0.20180228181622-8d52692529b8/go.mod h1:6FWOCx06uh50GClv8S2cfk3asqTJs3qq3ZNRtLZE77I= +github.com/philhofer/fwd v1.0.0 h1:UbZqGr5Y38ApvM/V/jEljVxwocdweyH+vmYvRPBnbqQ= +github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.1.0 h1:9Z322kTPrDR5GpxTH+1yl7As6tEHIH9aGsRccl20ELk= +github.com/rs/xid v1.1.0/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/tinylib/msgp v1.1.1 h1:TnCZ3FIuKeaIy+F45+Cnp+caqdXGy4z74HvwXN+570Y= +github.com/tinylib/msgp v1.1.1/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.2.0 h1:dzZJf2IuMiclVjdw0kkT+f9u4YdrapbNyGAN47E/qnk= diff --git a/pkg/common/backoff.go b/pkg/common/backoff.go new file mode 100644 index 0000000..9bbc316 --- /dev/null +++ b/pkg/common/backoff.go @@ -0,0 +1,99 @@ +package common + +import ( + "math" + "math/rand" + "sync/atomic" + "time" +) + +// Backoff is a time.Duration counter, starting at Min. After every call to +// the Duration method the current timing is multiplied by Factor, but it +// never exceeds Max. +// +// Backoff is not generally concurrent-safe, but the ForAttempt method can +// be used concurrently. +type Backoff struct { + attempt uint64 + // Factor is the multiplying factor for each increment step + Factor float64 + // Jitter eases contention by randomizing backoff steps + Jitter bool + // Min and Max are the minimum and maximum values of the counter + Min, Max time.Duration +} + +// Duration returns the duration for the current attempt before incrementing +// the attempt counter. See ForAttempt. +func (b *Backoff) Duration() time.Duration { + d := b.ForAttempt(float64(atomic.AddUint64(&b.attempt, 1) - 1)) + return d +} + +const maxInt64 = float64(math.MaxInt64 - 512) + +// ForAttempt returns the duration for a specific attempt. This is useful if +// you have a large number of independent Backoffs, but don't want use +// unnecessary memory storing the Backoff parameters per Backoff. The first +// attempt should be 0. +// +// ForAttempt is concurrent-safe. 
+func (b *Backoff) ForAttempt(attempt float64) time.Duration { + // Zero-values are nonsensical, so we use + // them to apply defaults + min := b.Min + if min <= 0 { + min = 100 * time.Millisecond + } + max := b.Max + if max <= 0 { + max = 10 * time.Second + } + if min >= max { + // short-circuit + return max + } + factor := b.Factor + if factor <= 0 { + factor = 2 + } + //calculate this duration + minf := float64(min) + durf := minf * math.Pow(factor, attempt) + if b.Jitter { + durf = rand.Float64()*(durf-minf) + minf + } + //ensure float64 wont overflow int64 + if durf > maxInt64 { + return max + } + dur := time.Duration(durf) + //keep within bounds + if dur < min { + return min + } + if dur > max { + return max + } + return dur +} + +// Reset restarts the current attempt counter at zero. +func (b *Backoff) Reset() { + atomic.StoreUint64(&b.attempt, 0) +} + +// Attempt returns the current attempt counter value. +func (b *Backoff) Attempt() float64 { + return float64(atomic.LoadUint64(&b.attempt)) +} + +// Copy returns a backoff with equals constraints as the original +func (b *Backoff) Copy() *Backoff { + return &Backoff{ + Factor: b.Factor, + Jitter: b.Jitter, + Min: b.Min, + Max: b.Max, + } +} diff --git a/pkg/common/helper.go b/pkg/common/helper.go new file mode 100644 index 0000000..126f630 --- /dev/null +++ b/pkg/common/helper.go @@ -0,0 +1,142 @@ +package common + +import ( + "context" + "reflect" + "runtime" + "time" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" +) + +func getFunctionName(fn interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() +} + +// give either retryInterval or backoff +func RetryFunc(ctx context.Context, + loggerInstance logger.Logger, + attempts int, + retryInterval *time.Duration, + backoff *Backoff, + fn func(int) (bool, error)) error { + + var err error + var retry bool + + for attempt := 1; attempt <= attempts; attempt++ { + retry, err = fn(attempt) + + // if there's no need to retry 
- we're done + if !retry { + return err + } + + // are we out of time? + if ctx.Err() != nil { + + loggerInstance.WarnWithCtx(ctx, + "Context error detected during retries", + "ctxErr", ctx.Err(), + "previousErr", err, + "function", getFunctionName(fn), + "attempt", attempt) + + // return the error if one was provided + if err != nil { + return err + } + + return ctx.Err() + } + + if backoff != nil { + time.Sleep(backoff.Duration()) + } else { + if retryInterval == nil { + return errors.New("Either retry interval or backoff must be given") + } + time.Sleep(*retryInterval) + } + } + + // attempts exhausted and we're unsuccessful + // Return the original error for later checking + loggerInstance.WarnWithCtx(ctx, + "Failed final attempt to invoke function", + "function", getFunctionName(fn), + "err", err, + "attempts", attempts) + + // this shouldn't happen + if err == nil { + loggerInstance.ErrorWithCtx(ctx, + "Failed final attempt to invoke function, but error is nil. This shouldn't happen", + "function", getFunctionName(fn), + "err", err, + "attempts", attempts) + return errors.New("Failed final attempt to invoke function without proper error supplied") + } + return err +} + +func MakeRange(min, max int) []int { + a := make([]int, max-min+1) + for i := range a { + a[i] = min + i + } + return a +} + +func IntSliceContainsInt(slice []int, number int) bool { + for _, intInSlice := range slice { + if intInSlice == number { + return true + } + } + + return false +} + +func IntSlicesEqual(slice1 []int, slice2 []int) bool { + if len(slice1) != len(slice2) { + return false + } + + for intIndex := 0; intIndex < len(slice1); intIndex++ { + if slice1[intIndex] != slice2[intIndex] { + return false + } + } + + return true +} + +func Uint64SlicesEqual(slice1 []uint64, slice2 []uint64) bool { + if len(slice1) != len(slice2) { + return false + } + + for intIndex := 0; intIndex < len(slice1); intIndex++ { + if slice1[intIndex] != slice2[intIndex] { + return false + } + } + + return 
true +} + +func StringSlicesEqual(slice1 []string, slice2 []string) bool { + if len(slice1) != len(slice2) { + return false + } + + for stringIndex := 0; stringIndex < len(slice1); stringIndex++ { + if slice1[stringIndex] != slice2[stringIndex] { + return false + } + } + + return true +} diff --git a/pkg/dataplane/container.go b/pkg/dataplane/container.go index 05157ed..c2dbff3 100644 --- a/pkg/dataplane/container.go +++ b/pkg/dataplane/container.go @@ -101,6 +101,12 @@ type Container interface { // CreateStreamSync CreateStreamSync(*CreateStreamInput) error + // DescribeStream + DescribeStream(*DescribeStreamInput, interface{}, chan *Response) (*Request, error) + + // DescribeStreamSync + DescribeStreamSync(*DescribeStreamInput) (*Response, error) + // DeleteStream DeleteStream(*DeleteStreamInput, interface{}, chan *Response) (*Request, error) diff --git a/pkg/dataplane/http/container.go b/pkg/dataplane/http/container.go index 941029b..cd5d26d 100644 --- a/pkg/dataplane/http/container.go +++ b/pkg/dataplane/http/container.go @@ -178,6 +178,20 @@ func (c *container) CreateStreamSync(createStreamInput *v3io.CreateStreamInput) return c.session.context.CreateStreamSync(createStreamInput) } +// DescribeStream +func (c *container) DescribeStream(describeStreamInput *v3io.DescribeStreamInput, + context interface{}, + responseChan chan *v3io.Response) (*v3io.Request, error) { + c.populateInputFields(&describeStreamInput.DataPlaneInput) + return c.session.context.DescribeStream(describeStreamInput, context, responseChan) +} + +// DescribeStreamSync +func (c *container) DescribeStreamSync(describeStreamInput *v3io.DescribeStreamInput) (*v3io.Response, error) { + c.populateInputFields(&describeStreamInput.DataPlaneInput) + return c.session.context.DescribeStreamSync(describeStreamInput) +} + // DeleteStream func (c *container) DeleteStream(deleteStreamInput *v3io.DeleteStreamInput, context interface{}, responseChan chan *v3io.Response) (*v3io.Request, error) { 
c.populateInputFields(&deleteStreamInput.DataPlaneInput) diff --git a/pkg/dataplane/http/context.go b/pkg/dataplane/http/context.go index d465427..38ad1c4 100755 --- a/pkg/dataplane/http/context.go +++ b/pkg/dataplane/http/context.go @@ -204,8 +204,6 @@ func (c *context) GetItemSync(getItemInput *v3io.GetItemInput) (*v3io.Response, Item map[string]map[string]interface{} }{} - c.logger.DebugWithCtx(getItemInput.Ctx, "Body", "body", string(response.Body())) - // unmarshal the body err = json.Unmarshal(response.Body(), &item) if err != nil { @@ -280,7 +278,7 @@ func (c *context) GetItemsSync(getItemsInput *v3io.GetItemsInput) (*v3io.Respons } headers := getItemsHeadersCapnp - if getItemsInput.RequestJsonResponse { + if getItemsInput.RequestJSONResponse { headers = getItemsHeaders } @@ -518,6 +516,40 @@ func (c *context) CreateStreamSync(createStreamInput *v3io.CreateStreamInput) er return err } +// DescribeStream +func (c *context) DescribeStream(describeStreamInput *v3io.DescribeStreamInput, + context interface{}, + responseChan chan *v3io.Response) (*v3io.Request, error) { + return c.sendRequestToWorker(describeStreamInput, context, responseChan) +} + +// DescribeStreamSync +func (c *context) DescribeStreamSync(describeStreamInput *v3io.DescribeStreamInput) (*v3io.Response, error) { + response, err := c.sendRequest(&describeStreamInput.DataPlaneInput, + http.MethodPut, + describeStreamInput.Path, + "", + describeStreamHeaders, + nil, + false) + if err != nil { + return nil, err + } + + describeStreamOutput := v3io.DescribeStreamOutput{} + + // unmarshal the body into an ad hoc structure + err = json.Unmarshal(response.Body(), &describeStreamOutput) + if err != nil { + return nil, err + } + + // set the output in the response + response.Output = &describeStreamOutput + + return response, nil +} + // DeleteStream func (c *context) DeleteStream(deleteStreamInput *v3io.DeleteStreamInput, context interface{}, @@ -575,7 +607,7 @@ func (c *context) 
SeekShardSync(seekShardInput *v3io.SeekShardInput) (*v3io.Resp if seekShardInput.Type == v3io.SeekShardInputTypeSequence { buffer.WriteString(`, "StartingSequenceNumber": `) - buffer.WriteString(strconv.Itoa(seekShardInput.StartingSequenceNumber)) + buffer.WriteString(strconv.FormatUint(seekShardInput.StartingSequenceNumber, 10)) } else if seekShardInput.Type == v3io.SeekShardInputTypeTime { buffer.WriteString(`, "TimestampSec": `) buffer.WriteString(strconv.Itoa(seekShardInput.Timestamp)) @@ -866,11 +898,11 @@ func (c *context) sendRequest(dataPlaneInput *v3io.DataPlaneInput, request.Header.Add(headerName, headerValue) } - c.logger.DebugWithCtx(dataPlaneInput.Ctx, - "Tx", - "uri", uriStr, - "method", method, - "body-length", len(body)) + //c.logger.DebugWithCtx(dataPlaneInput.Ctx, + // "Tx", + // "uri", uriStr, + // "method", method, + // "body-length", len(body)) if dataPlaneInput.Timeout <= 0 { err = c.httpClient.Do(request, response.HTTPResponse) @@ -885,14 +917,14 @@ func (c *context) sendRequest(dataPlaneInput *v3io.DataPlaneInput, statusCode = response.HTTPResponse.StatusCode() { - contentLength := response.HTTPResponse.Header.ContentLength() - if contentLength < 0 { - contentLength = 0 - } - c.logger.DebugWithCtx(dataPlaneInput.Ctx, - "Rx", - "statusCode", statusCode, - "Content-Length", contentLength) + //contentLength := response.HTTPResponse.Header.ContentLength() + //if contentLength < 0 { + // contentLength = 0 + //} + //c.logger.DebugWithCtx(dataPlaneInput.Ctx, + // "Rx", + // "statusCode", statusCode, + // "Content-Length", contentLength) } // did we get a 2xx response? 
@@ -938,7 +970,7 @@ func (c *context) buildRequestURI(urlString string, containerName string, query if strings.HasSuffix(pathStr, "/") { uri.Path += "/" // retain trailing slash } - uri.RawQuery = strings.ReplaceAll(query, " ", "%20") + uri.RawQuery = strings.Replace(query, " ", "%20", -1) return uri, nil } @@ -959,6 +991,8 @@ func (c *context) encodeTypedAttributes(attributes map[string]interface{}) (map[ return nil, fmt.Errorf("unexpected attribute type for %s: %T", attributeName, reflect.TypeOf(attributeValue)) case int: typedAttributes[attributeName]["N"] = strconv.Itoa(value) + case uint64: + typedAttributes[attributeName]["N"] = strconv.FormatUint(value, 10) case int64: typedAttributes[attributeName]["N"] = strconv.FormatInt(value, 10) // this is a tmp bypass to the fact Go maps Json numbers to float64 @@ -1116,6 +1150,8 @@ func (c *context) workerEntry(workerIndex int) { err = c.UpdateItemSync(typedInput) case *v3io.CreateStreamInput: err = c.CreateStreamSync(typedInput) + case *v3io.DescribeStreamInput: + response, err = c.DescribeStreamSync(typedInput) case *v3io.DeleteStreamInput: err = c.DeleteStreamSync(typedInput) case *v3io.GetRecordsInput: diff --git a/pkg/dataplane/http/headers.go b/pkg/dataplane/http/headers.go index da64a1e..3f976da 100755 --- a/pkg/dataplane/http/headers.go +++ b/pkg/dataplane/http/headers.go @@ -2,14 +2,15 @@ package v3iohttp // function names const ( - putItemFunctionName = "PutItem" - updateItemFunctionName = "UpdateItem" - getItemFunctionName = "GetItem" - getItemsFunctionName = "GetItems" - createStreamFunctionName = "CreateStream" - putRecordsFunctionName = "PutRecords" - getRecordsFunctionName = "GetRecords" - seekShardsFunctionName = "SeekShard" + putItemFunctionName = "PutItem" + updateItemFunctionName = "UpdateItem" + getItemFunctionName = "GetItem" + getItemsFunctionName = "GetItems" + createStreamFunctionName = "CreateStream" + describeStreamFunctionName = "DescribeStream" + putRecordsFunctionName = "PutRecords" + 
getRecordsFunctionName = "GetRecords" + seekShardsFunctionName = "SeekShard" ) // headers for put item @@ -49,6 +50,12 @@ var createStreamHeaders = map[string]string{ "X-v3io-function": createStreamFunctionName, } +// headers for describe stream +var describeStreamHeaders = map[string]string{ + "Content-Type": "application/json", + "X-v3io-function": describeStreamFunctionName, +} + // headers for put records var putRecordsHeaders = map[string]string{ "Content-Type": "application/json", diff --git a/pkg/dataplane/item.go b/pkg/dataplane/item.go index 5fb6743..5bce099 100644 --- a/pkg/dataplane/item.go +++ b/pkg/dataplane/item.go @@ -29,7 +29,12 @@ func (i Item) GetField(name string) interface{} { } func (i Item) GetFieldInt(name string) (int, error) { - switch typedField := i[name].(type) { + fieldValue, fieldFound := i[name] + if !fieldFound { + return 0, v3ioerrors.ErrNotFound + } + + switch typedField := fieldValue.(type) { case int: return typedField, nil case float64: @@ -42,7 +47,12 @@ func (i Item) GetFieldInt(name string) (int, error) { } func (i Item) GetFieldString(name string) (string, error) { - switch typedField := i[name].(type) { + fieldValue, fieldFound := i[name] + if !fieldFound { + return "", v3ioerrors.ErrNotFound + } + + switch typedField := fieldValue.(type) { case int: return strconv.Itoa(typedField), nil case float64: @@ -53,3 +63,20 @@ func (i Item) GetFieldString(name string) (string, error) { return "", v3ioerrors.ErrInvalidTypeConversion } } + +func (i Item) GetFieldUint64(name string) (uint64, error) { + fieldValue, fieldFound := i[name] + if !fieldFound { + return 0, v3ioerrors.ErrNotFound + } + + switch typedField := fieldValue.(type) { + // TODO: properly handle uint64 + case int: + return uint64(typedField), nil + case uint64: + return typedField, nil + default: + return 0, v3ioerrors.ErrInvalidTypeConversion + } +} diff --git a/pkg/dataplane/streamconsumergroup/claim.go b/pkg/dataplane/streamconsumergroup/claim.go new file mode 100644 
index 0000000..1557440 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/claim.go @@ -0,0 +1,204 @@ +package streamconsumergroup + +import ( + "fmt" + "path" + "strconv" + "time" + + "github.com/v3io/v3io-go/pkg/dataplane" + v3ioerrors "github.com/v3io/v3io-go/pkg/errors" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" +) + +type claim struct { + logger logger.Logger + streamConsumerGroup *streamConsumerGroup + shardID int + recordBatchChan chan *RecordBatch + stopRecordBatchFetchChan chan struct{} + currentShardLocation string +} + +func newClaim(streamConsumerGroup *streamConsumerGroup, shardID int) (*claim, error) { + return &claim{ + logger: streamConsumerGroup.logger.GetChild(fmt.Sprintf("claim-%d", shardID)), + streamConsumerGroup: streamConsumerGroup, + shardID: shardID, + recordBatchChan: make(chan *RecordBatch, streamConsumerGroup.config.Claim.RecordBatchChanSize), + stopRecordBatchFetchChan: make(chan struct{}, 1), + }, nil +} + +func (c *claim) start() error { + c.logger.DebugWith("Starting claim") + + go func() { + err := c.fetchRecordBatches(c.stopRecordBatchFetchChan, + c.streamConsumerGroup.config.Claim.RecordBatchFetch.Interval) + + if err != nil { + c.logger.WarnWith("Failed to fetch record batches", "err", errors.GetErrorStackString(err, 10)) + } + }() + + go func() { + + // tell the consumer group handler to consume the claim + c.logger.DebugWith("Calling ConsumeClaim on handler") + if err := c.streamConsumerGroup.handler.ConsumeClaim(c.streamConsumerGroup.session, c); err != nil { + c.logger.WarnWith("ConsumeClaim returned with error", "err", errors.GetErrorStackString(err, 10)) + } + + if err := c.stop(); err != nil { + c.logger.WarnWith("Failed to stop claim after consumption", "err", errors.GetErrorStackString(err, 10)) + } + }() + + return nil +} + +func (c *claim) stop() error { + c.logger.DebugWith("Stopping claim") + + // don't block + select { + case c.stopRecordBatchFetchChan <- struct{}{}: + default: + } + + return nil 
+} + +func (c *claim) GetStreamPath() string { + return c.streamConsumerGroup.streamPath +} + +func (c *claim) GetShardID() int { + return c.shardID +} + +func (c *claim) GetCurrentLocation() string { + return c.currentShardLocation +} + +func (c *claim) GetRecordBatchChan() <-chan *RecordBatch { + return c.recordBatchChan +} + +func (c *claim) fetchRecordBatches(stopChannel chan struct{}, fetchInterval time.Duration) error { + var err error + + // read initial location. use config if error. might need to wait until shard actually exists + c.currentShardLocation, err = c.getCurrentShardLocation(c.shardID) + if err != nil { + if err == v3ioerrors.ErrStopped { + return nil + } + + return errors.Wrap(err, "Failed to get shard location") + } + + for { + select { + case <-time.After(fetchInterval): + c.currentShardLocation, err = c.fetchRecordBatch(c.currentShardLocation) + if err != nil { + c.logger.WarnWith("Failed fetching record batch", "err", errors.GetErrorStackString(err, 10)) + continue + } + + case <-stopChannel: + close(c.recordBatchChan) + c.logger.Debug("Stopping fetch") + return nil + } + } +} + +func (c *claim) fetchRecordBatch(location string) (string, error) { + getRecordsInput := v3io.GetRecordsInput{ + Path: path.Join(c.streamConsumerGroup.streamPath, strconv.Itoa(c.shardID)), + Location: location, + Limit: c.streamConsumerGroup.config.Claim.RecordBatchFetch.NumRecordsInBatch, + } + + response, err := c.streamConsumerGroup.container.GetRecordsSync(&getRecordsInput) + if err != nil { + return "", errors.Wrapf(err, "Failed fetching record batch: %s", location) + } + + defer response.Release() + + getRecordsOutput := response.Output.(*v3io.GetRecordsOutput) + + if len(getRecordsOutput.Records) == 0 { + return getRecordsOutput.NextLocation, nil + } + + records := make([]v3io.StreamRecord, len(getRecordsOutput.Records)) + + for receivedRecordIndex, receivedRecord := range getRecordsOutput.Records { + record := v3io.StreamRecord{ + ShardID: &c.shardID, + 
Data: receivedRecord.Data, + ClientInfo: receivedRecord.ClientInfo, + PartitionKey: receivedRecord.PartitionKey, + SequenceNumber: receivedRecord.SequenceNumber, + } + + records[receivedRecordIndex] = record + } + + recordBatch := RecordBatch{ + Location: location, + Records: records, + NextLocation: getRecordsOutput.NextLocation, + ShardID: c.shardID, + } + + // write into chunks channel, blocking if there's no space + c.recordBatchChan <- &recordBatch + + return getRecordsOutput.NextLocation, nil +} + +func (c *claim) getCurrentShardLocation(shardID int) (string, error) { + + // get the location from persistency + currentShardLocation, err := c.streamConsumerGroup.sequenceNumberHandler.getShardLocationFromPersistency(shardID) + if err != nil && errors.RootCause(err) != errShardNotFound { + return "", errors.Wrap(err, "Failed to get shard location") + } + + // if shard wasn't found, try again periodically + if errors.RootCause(err) == errShardNotFound { + for { + select { + + // TODO: from configuration + case <-time.After(c.streamConsumerGroup.config.SequenceNumber.ShardWaitInterval): + + // get the location from persistency + currentShardLocation, err = c.streamConsumerGroup.sequenceNumberHandler.getShardLocationFromPersistency(shardID) + if err != nil { + if errors.RootCause(err) == errShardNotFound { + + // shard doesn't exist yet, try again + continue + } + + return "", errors.Wrap(err, "Failed to get shard location") + } + + return currentShardLocation, nil + case <-c.stopRecordBatchFetchChan: + return "", v3ioerrors.ErrStopped + } + } + } + + return currentShardLocation, nil +} diff --git a/pkg/dataplane/streamconsumergroup/config.go b/pkg/dataplane/streamconsumergroup/config.go new file mode 100644 index 0000000..333e392 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/config.go @@ -0,0 +1,54 @@ +package streamconsumergroup + +import ( + "time" + + "github.com/v3io/v3io-go/pkg/common" + "github.com/v3io/v3io-go/pkg/dataplane" +) + +type Config struct 
{ + Session struct { + Timeout time.Duration `json:"timeout,omitempty"` + HeartbeatInterval time.Duration + } `json:"session,omitempty"` + State struct { + ModifyRetry struct { + Attempts int `json:"attempts,omitempty"` + Backoff common.Backoff `json:"backoff,omitempty"` + } `json:"modifyRetry,omitempty"` + } `json:"state,omitempty"` + SequenceNumber struct { + CommitInterval time.Duration `json:"commitInterval,omitempty"` + ShardWaitInterval time.Duration `json:"shardWaitInterval,omitempty"` + } + Claim struct { + RecordBatchChanSize int `json:"recordBatchChanSize,omitempty"` + RecordBatchFetch struct { + Interval time.Duration `json:"interval,omitempty"` + NumRecordsInBatch int `json:"numRecordsInBatch,omitempty"` + InitialLocation v3io.SeekShardInputType `json:"initialLocation,omitempty"` + } `json:"recordBatchFetch,omitempty"` + } `json:"claim,omitempty"` +} + +// NewConfig returns a new configuration instance with sane defaults. +func NewConfig() *Config { + c := &Config{} + c.Session.Timeout = 10 * time.Second + c.Session.HeartbeatInterval = 3 * time.Second + c.State.ModifyRetry.Attempts = 100 + c.State.ModifyRetry.Backoff = common.Backoff{ + Min: 50 * time.Millisecond, + Max: 1 * time.Second, + Factor: 4, + } + c.SequenceNumber.CommitInterval = 10 * time.Second + c.SequenceNumber.ShardWaitInterval = 1 * time.Second + c.Claim.RecordBatchChanSize = 100 + c.Claim.RecordBatchFetch.Interval = 250 * time.Millisecond + c.Claim.RecordBatchFetch.NumRecordsInBatch = 10 + c.Claim.RecordBatchFetch.InitialLocation = v3io.SeekShardInputTypeEarliest + + return c +} diff --git a/pkg/dataplane/streamconsumergroup/sequencenumberhandler.go b/pkg/dataplane/streamconsumergroup/sequencenumberhandler.go new file mode 100644 index 0000000..5728635 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/sequencenumberhandler.go @@ -0,0 +1,237 @@ +package streamconsumergroup + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/v3io/v3io-go/pkg/common" + 
"github.com/v3io/v3io-go/pkg/dataplane" + "github.com/v3io/v3io-go/pkg/errors" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" +) + +var errShardNotFound = errors.New("Shard not found") +var errShardSequenceNumberAttributeNotFound = errors.New("Shard sequenceNumber attribute") + +type sequenceNumberHandler struct { + logger logger.Logger + streamConsumerGroup *streamConsumerGroup + markedShardSequenceNumbers []uint64 + markedShardSequenceNumbersLock sync.RWMutex + stopMarkedShardSequenceNumberCommitterChan chan struct{} + lastCommittedShardSequenceNumbers []uint64 +} + +func newSequenceNumberHandler(streamConsumerGroup *streamConsumerGroup) (*sequenceNumberHandler, error) { + + return &sequenceNumberHandler{ + logger: streamConsumerGroup.logger.GetChild("sequenceNumberHandler"), + streamConsumerGroup: streamConsumerGroup, + markedShardSequenceNumbers: make([]uint64, streamConsumerGroup.totalNumShards), + stopMarkedShardSequenceNumberCommitterChan: make(chan struct{}, 1), + }, nil +} + +func (snh *sequenceNumberHandler) start() error { + snh.logger.DebugWith("Starting sequenceNumber handler") + + // stopped on stop() + go snh.markedShardSequenceNumbersCommitter(snh.streamConsumerGroup.config.SequenceNumber.CommitInterval, + snh.stopMarkedShardSequenceNumberCommitterChan) + + return nil +} + +func (snh *sequenceNumberHandler) stop() error { + snh.logger.DebugWith("Stopping sequenceNumber handler") + + select { + case snh.stopMarkedShardSequenceNumberCommitterChan <- struct{}{}: + default: + } + + return nil +} + +func (snh *sequenceNumberHandler) markShardSequenceNumber(shardID int, sequenceNumber uint64) error { + + // lock semantics are reverse - it's OK to write in parallel since each write goes + // to a different cell in the array, but once a read is happening we need to stop the world + snh.markedShardSequenceNumbersLock.RLock() + snh.markedShardSequenceNumbers[shardID] = sequenceNumber + snh.markedShardSequenceNumbersLock.RUnlock() + + return nil 
+} + +func (snh *sequenceNumberHandler) getShardLocationFromPersistency(shardID int) (string, error) { + snh.logger.DebugWith("Getting shard sequenceNumber from persistency", "shardID", shardID) + + shardPath, err := snh.streamConsumerGroup.getShardPath(shardID) + if err != nil { + return "", errors.Wrapf(err, "Failed getting shard path: %v", shardID) + } + + seekShardInput := v3io.SeekShardInput{ + Path: shardPath, + } + + // get the shard sequenceNumber from the item + shardSequenceNumber, err := snh.getShardSequenceNumberFromItemAttributes(shardPath) + if err != nil { + + // if the error is that the attribute wasn't found, but the shard was found - seek the shard + // according to the configuration + if err != errShardSequenceNumberAttributeNotFound { + return "", errors.Wrap(err, "Failed to get shard sequenceNumber from item attributes") + } + + seekShardInput.Type = snh.streamConsumerGroup.config.Claim.RecordBatchFetch.InitialLocation + } else { + + // use sequence number + seekShardInput.Type = v3io.SeekShardInputTypeSequence + seekShardInput.StartingSequenceNumber = shardSequenceNumber + 1 + } + + return snh.getShardLocationWithSeek(shardPath, &seekShardInput) +} + +// returns the shard's committed sequenceNumber and a single error; the error distinguishes a missing shard from a missing attribute +func (snh *sequenceNumberHandler) getShardSequenceNumberFromItemAttributes(shardPath string) (uint64, error) { + response, err := snh.streamConsumerGroup.container.GetItemSync(&v3io.GetItemInput{ + Path: shardPath, + AttributeNames: []string{snh.getShardCommittedSequenceNumberAttributeName()}, + }) + + if err != nil { + errWithStatusCode, errHasStatusCode := err.(v3ioerrors.ErrorWithStatusCode) + if !errHasStatusCode { + return 0, errors.Wrap(err, "Got error without status code") + } + + if errWithStatusCode.StatusCode() != http.StatusNotFound { + return 0, errors.Wrap(err, "Failed getting shard item") + } + + // TODO: remove after errors.Is support added + snh.logger.DebugWith("Could not 
find shard, probably doesn't exist yet", "path", shardPath) + + return 0, errShardNotFound + } + + defer response.Release() + + getItemOutput := response.Output.(*v3io.GetItemOutput) + + // return the attribute name + sequenceNumber, err := getItemOutput.Item.GetFieldUint64(snh.getShardCommittedSequenceNumberAttributeName()) + if err != nil && err == v3ioerrors.ErrNotFound { + return 0, errShardSequenceNumberAttributeNotFound + } + + // return the sequenceNumber we found + return sequenceNumber, nil +} + +func (snh *sequenceNumberHandler) getShardLocationWithSeek(shardPath string, seekShardInput *v3io.SeekShardInput) (string, error) { + + snh.logger.DebugWith("Seeking shard", "shardPath", shardPath, "seekShardInput", seekShardInput) + + response, err := snh.streamConsumerGroup.container.SeekShardSync(seekShardInput) + if err != nil { + return "", errors.Wrap(err, "Failed to seek shard") + } + defer response.Release() + + location := response.Output.(*v3io.SeekShardOutput).Location + + snh.logger.DebugWith("Seek shard succeeded", "shardPath", shardPath, "location", location) + + return location, nil +} + +func (snh *sequenceNumberHandler) getShardCommittedSequenceNumberAttributeName() string { + return fmt.Sprintf("__%s_committed_sequence_number", snh.streamConsumerGroup.name) +} + +func (snh *sequenceNumberHandler) setShardSequenceNumberInPersistency(shardID int, sequenceNumber uint64) error { + snh.logger.DebugWith("Setting shard sequenceNumber in persistency", "shardID", shardID, "sequenceNumber", sequenceNumber) + shardPath, err := snh.streamConsumerGroup.getShardPath(shardID) + if err != nil { + return errors.Wrapf(err, "Failed getting shard path: %v", shardID) + } + + return snh.streamConsumerGroup.container.UpdateItemSync(&v3io.UpdateItemInput{ + Path: shardPath, + Attributes: map[string]interface{}{ + snh.getShardCommittedSequenceNumberAttributeName(): sequenceNumber, + }, + }) +} + +func (snh *sequenceNumberHandler) 
markedShardSequenceNumbersCommitter(interval time.Duration, stopChan chan struct{}) { + for { + select { + case <-time.After(interval): + if err := snh.commitMarkedShardSequenceNumbers(); err != nil { + snh.logger.WarnWith("Failed committing marked shard sequenceNumbers", "err", errors.GetErrorStackString(err, 10)) + continue + } + case <-stopChan: + snh.logger.Debug("Stopped committing marked shard sequenceNumbers") + + // do the last commit + if err := snh.commitMarkedShardSequenceNumbers(); err != nil { + snh.logger.WarnWith("Failed committing marked shard sequenceNumbers on stop", "err", errors.GetErrorStackString(err, 10)) + } + return + } + } +} + +func (snh *sequenceNumberHandler) commitMarkedShardSequenceNumbers() error { + var markedShardSequenceNumbersCopy []uint64 + + // create a copy of the marked shard sequenceNumbers + snh.markedShardSequenceNumbersLock.Lock() + markedShardSequenceNumbersCopy = append(markedShardSequenceNumbersCopy, snh.markedShardSequenceNumbers...) + snh.markedShardSequenceNumbersLock.Unlock() + + // if there was no chance since last, do nothing + if common.Uint64SlicesEqual(snh.lastCommittedShardSequenceNumbers, markedShardSequenceNumbersCopy) { + return nil + } + + snh.logger.DebugWith("Committing marked shard sequenceNumbers", "markedShardSequenceNumbersCopy", markedShardSequenceNumbersCopy) + + var failedShardIDs []int + for shardID, sequenceNumber := range markedShardSequenceNumbersCopy { + + // the sequenceNumber array holds a sequenceNumber for all partitions, indexed by their id to allow for + // faster writes (using a rw lock) only the relevant shards ever get populated + if sequenceNumber == 0 { + continue + } + + if err := snh.setShardSequenceNumberInPersistency(shardID, sequenceNumber); err != nil { + snh.logger.WarnWith("Failed committing shard sequenceNumber", "shardID", shardID, + "sequenceNumber", sequenceNumber, + "err", errors.GetErrorStackString(err, 10)) + + failedShardIDs = append(failedShardIDs, shardID) + } + 
} + + if len(failedShardIDs) > 0 { + return errors.Errorf("Failed committing marked shard sequenceNumbers in shards: %v", failedShardIDs) + } + + snh.lastCommittedShardSequenceNumbers = markedShardSequenceNumbersCopy + + return nil +} diff --git a/pkg/dataplane/streamconsumergroup/session.go b/pkg/dataplane/streamconsumergroup/session.go new file mode 100644 index 0000000..25b0e07 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/session.go @@ -0,0 +1,92 @@ +package streamconsumergroup + +import ( + v3io "github.com/v3io/v3io-go/pkg/dataplane" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" +) + +type session struct { + logger logger.Logger + streamConsumerGroup *streamConsumerGroup + state *SessionState + claims []Claim +} + +func newSession(streamConsumerGroup *streamConsumerGroup, + sessionState *SessionState) (Session, error) { + + return &session{ + logger: streamConsumerGroup.logger.GetChild("session"), + streamConsumerGroup: streamConsumerGroup, + state: sessionState, + }, nil +} + +func (s *session) start() error { + s.logger.DebugWith("Starting session") + + // for each shard we need handle, create a StreamConsumerGroupClaim object and start it + for _, shardID := range s.state.Shards { + claim, err := newClaim(s.streamConsumerGroup, shardID) + if err != nil { + return errors.Wrapf(err, "Failed creating stream consumer group claim for shard: %d", shardID) + } + + // add to claims + s.claims = append(s.claims, claim) + } + + // tell the consumer group handler to set up + s.logger.DebugWith("Triggering given handler Setup") + if err := s.streamConsumerGroup.handler.Setup(s); err != nil { + return errors.Wrap(err, "Failed to set up session") + } + + s.logger.DebugWith("Starting claim consumption") + for _, claim := range s.claims { + if err := claim.start(); err != nil { + return errors.Wrap(err, "Failed starting stream consumer group claim") + } + } + + return nil +} + +func (s *session) stop() error { + s.logger.DebugWith("Stopping session, 
triggering given handler cleanup") + + // tell the consumer group handler to set up + if err := s.streamConsumerGroup.handler.Cleanup(s); err != nil { + return errors.Wrap(err, "Failed to cleanup") + } + + s.logger.DebugWith("Stopping claims") + + for _, claim := range s.claims { + err := claim.stop() + if err != nil { + return errors.Wrap(err, "Failed starting stream consumer group claim") + } + } + + return nil +} + +func (s *session) GetClaims() []Claim { + return s.claims +} + +func (s *session) GetMemberID() string { + return s.streamConsumerGroup.memberID +} + +func (s *session) MarkRecord(record *v3io.StreamRecord) error { + err := s.streamConsumerGroup.sequenceNumberHandler.markShardSequenceNumber(*record.ShardID, record.SequenceNumber) + if err != nil { + return errors.Wrap(err, "Failed marking record") + } + + return nil +} diff --git a/pkg/dataplane/streamconsumergroup/state.go b/pkg/dataplane/streamconsumergroup/state.go new file mode 100644 index 0000000..a765812 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/state.go @@ -0,0 +1,34 @@ +package streamconsumergroup + +type State struct { + SchemasVersion string `json:"schema_version"` + SessionStates []*SessionState `json:"session_states"` +} + +func newState() (*State, error) { + return &State{ + SchemasVersion: "0.0.1", + SessionStates: []*SessionState{}, + }, nil +} + +func (s *State) deepCopy() *State { + stateCopy := State{} + stateCopy.SchemasVersion = s.SchemasVersion + for _, stateSession := range s.SessionStates { + stateSessionCopy := stateSession + stateCopy.SessionStates = append(stateCopy.SessionStates, stateSessionCopy) + } + + return &stateCopy +} + +func (s *State) findSessionStateByMemberID(memberID string) *SessionState { + for _, sessionState := range s.SessionStates { + if sessionState.MemberID == memberID { + return sessionState + } + } + + return nil +} diff --git a/pkg/dataplane/streamconsumergroup/statehandler.go b/pkg/dataplane/streamconsumergroup/statehandler.go new file 
mode 100644 index 0000000..ca93fd7 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/statehandler.go @@ -0,0 +1,398 @@ +package streamconsumergroup + +import ( + "context" + "encoding/json" + "fmt" + "math" + "path" + "time" + + "github.com/v3io/v3io-go/pkg/common" + "github.com/v3io/v3io-go/pkg/dataplane" + "github.com/v3io/v3io-go/pkg/errors" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" +) + +const stateContentsAttributeKey string = "state" + +var errNoFreeShardGroups = errors.New("No free shard groups") + +type stateHandler struct { + logger logger.Logger + streamConsumerGroup *streamConsumerGroup + stopChan chan struct{} + getStateChan chan chan *State +} + +func newStateHandler(streamConsumerGroup *streamConsumerGroup) (*stateHandler, error) { + return &stateHandler{ + logger: streamConsumerGroup.logger.GetChild("stateHandler"), + streamConsumerGroup: streamConsumerGroup, + stopChan: make(chan struct{}, 1), + getStateChan: make(chan chan *State), + }, nil +} + +func (sh *stateHandler) start() error { + + // stops on stop() + go sh.refreshStatePeriodically() + + return nil +} + +func (sh *stateHandler) stop() error { + + select { + case sh.stopChan <- struct{}{}: + default: + } + + return nil +} + +func (sh *stateHandler) getOrCreateSessionState(memberID string) (*SessionState, error) { + + // create a channel on which we'll request the state + stateResponseChan := make(chan *State, 1) + + // send the channel to the refreshing goroutine. 
it'll post the state to this channel + sh.getStateChan <- stateResponseChan + + // wait on it + state := <-stateResponseChan + + // get the member's session state + return sh.getSessionState(state, memberID) +} + +func (sh *stateHandler) getSessionState(state *State, memberID string) (*SessionState, error) { + for _, sessionState := range state.SessionStates { + if sessionState.MemberID == memberID { + return sessionState, nil + } + } + + return nil, errors.Errorf("Member state not found: %s", memberID) +} + +func (sh *stateHandler) refreshStatePeriodically() { + var err error + + // guaranteed to only be REPLACED by a new instance - not edited. as such, once this is initialized + // it points to a read only state object + var lastState *State + + for { + select { + + // if we're asked to get state, get it + case stateResponseChan := <-sh.getStateChan: + if lastState != nil { + stateResponseChan <- lastState + } else { + lastState, err = sh.refreshState() + if err != nil { + sh.logger.WarnWith("Failed getting state", "err", errors.GetErrorStackString(err, 10)) + } + + // lastState may be nil + stateResponseChan <- lastState + } + + // periodically get the state + case <-time.After(sh.streamConsumerGroup.config.Session.HeartbeatInterval): + lastState, err = sh.refreshState() + if err != nil { + sh.logger.WarnWith("Failed refreshing state", "err", errors.GetErrorStackString(err, 10)) + continue + } + + // if we're told to stop, exit the loop + case <-sh.stopChan: + sh.logger.Debug("Stopping") + return + } + } +} + +func (sh *stateHandler) refreshState() (*State, error) { + return sh.modifyState(func(state *State) (*State, error) { + + // remove stale sessions from state + if err := sh.removeStaleSessionStates(state); err != nil { + return nil, errors.Wrap(err, "Failed to remove stale sessions") + } + + // find our session by member ID + sessionState := state.findSessionStateByMemberID(sh.streamConsumerGroup.memberID) + + // session already exists - just set the last 
heartbeat + if sessionState != nil { + sessionState.LastHeartbeat = time.Now() + + // we're done + return state, nil + } + + // session doesn't exist - create it + if err := sh.createSessionState(state); err != nil { + return nil, errors.Wrap(err, "Failed to create session state") + } + + return state, nil + }) +} + +func (sh *stateHandler) createSessionState(state *State) error { + if state.SessionStates == nil { + state.SessionStates = []*SessionState{} + } + + // assign shards + shards, err := sh.assignShards(sh.streamConsumerGroup.maxReplicas, sh.streamConsumerGroup.totalNumShards, state) + if err != nil { + return errors.Wrap(err, "Failed resolving shards for session") + } + + sh.logger.DebugWith("Assigned shards", + "shards", shards, + "state", state) + + state.SessionStates = append(state.SessionStates, &SessionState{ + MemberID: sh.streamConsumerGroup.memberID, + LastHeartbeat: time.Now(), + Shards: shards, + }) + + return nil +} + +func (sh *stateHandler) assignShards(maxReplicas int, numShards int, state *State) ([]int, error) { + + // per replica index, holds which shards it should handle + replicaShardGroups, err := sh.getReplicaShardGroups(maxReplicas, numShards) + if err != nil { + return nil, errors.Wrap(err, "Failed to get replica shard group") + } + + // empty shard groups are not unique - therefore simply check whether the number of + // empty shard groups allocated to sessions is equal to the number of empty shard groups + // required. 
if not, allocate an empty shard group + if sh.getAssignEmptyShardGroup(replicaShardGroups, state) { + return []int{}, nil + } + + // simply look for the first non-assigned replica shard group which isn't empty + for _, replicaShardGroup := range replicaShardGroups { + + // we already checked if we need to allocate an empty shard group + if len(replicaShardGroup) == 0 { + continue + } + + foundReplicaShardGroup := false + + for _, sessionState := range state.SessionStates { + if common.IntSlicesEqual(replicaShardGroup, sessionState.Shards) { + foundReplicaShardGroup = true + break + } + } + + if !foundReplicaShardGroup { + return replicaShardGroup, nil + } + } + + return nil, errNoFreeShardGroups +} + +func (sh *stateHandler) getReplicaShardGroups(maxReplicas int, numShards int) ([][]int, error) { + var replicaShardGroups [][]int + shards := common.MakeRange(0, numShards) + + step := float64(numShards) / float64(maxReplicas) + + for replicaIndex := 0; replicaIndex < maxReplicas; replicaIndex++ { + replicaIndexFloat := float64(replicaIndex) + startShard := int(math.Floor(replicaIndexFloat*step + 0.5)) + endShard := int(math.Floor((replicaIndexFloat+1)*step + 0.5)) + + replicaShardGroups = append(replicaShardGroups, shards[startShard:endShard]) + } + + return replicaShardGroups, nil +} + +func (sh *stateHandler) getAssignEmptyShardGroup(replicaShardGroups [][]int, state *State) bool { + numEmptyShardGroupRequired := 0 + for _, replicaShardGroup := range replicaShardGroups { + if len(replicaShardGroup) == 0 { + numEmptyShardGroupRequired++ + } + } + + numEmptyShardGroupAssigned := 0 + for _, sessionState := range state.SessionStates { + if len(sessionState.Shards) == 0 { + numEmptyShardGroupAssigned++ + } + } + + return numEmptyShardGroupRequired != numEmptyShardGroupAssigned + +} + +func (sh *stateHandler) modifyState(modifier stateModifier) (*State, error) { + var modifiedState *State + + backoff := sh.streamConsumerGroup.config.State.ModifyRetry.Backoff + attempts 
:= sh.streamConsumerGroup.config.State.ModifyRetry.Attempts + + err := common.RetryFunc(context.TODO(), sh.logger, attempts, nil, &backoff, func(int) (bool, error) { + state, mtime, err := sh.getStateFromPersistency() + if err != nil && err != v3ioerrors.ErrNotFound { + return true, errors.Wrap(err, "Failed getting current state from persistency") + } + + if state == nil { + state, err = newState() + if err != nil { + return true, errors.Wrap(err, "Failed to create state") + } + } + + // for logging + previousState := state.deepCopy() + + modifiedState, err = modifier(state) + if err != nil { + return true, errors.Wrap(err, "Failed modifying state") + } + + sh.logger.DebugWith("Modified state, saving", + "previousState", previousState, + "modifiedState", modifiedState) + + err = sh.setStateInPersistency(modifiedState, mtime) + if err != nil { + return true, errors.Wrap(err, "Failed setting state in persistency state") + } + + return false, nil + }) + + if err != nil { + return nil, errors.Wrap(err, "Failed modifying state, attempts exhausted") + } + + return modifiedState, nil +} + +func (sh *stateHandler) getStateFilePath() (string, error) { + return path.Join(sh.streamConsumerGroup.streamPath, fmt.Sprintf("%s-state.json", sh.streamConsumerGroup.name)), nil +} + +func (sh *stateHandler) setStateInPersistency(state *State, mtime *int) error { + stateFilePath, err := sh.getStateFilePath() + if err != nil { + return errors.Wrap(err, "Failed getting state file path") + } + + stateContents, err := json.Marshal(state) + if err != nil { + return errors.Wrap(err, "Failed marshaling state file contents") + } + + var condition string + if mtime != nil { + condition = fmt.Sprintf("__mtime_nsecs == %v", *mtime) + } + + err = sh.streamConsumerGroup.container.UpdateItemSync(&v3io.UpdateItemInput{ + Path: stateFilePath, + Condition: condition, + Attributes: map[string]interface{}{ + stateContentsAttributeKey: string(stateContents), + }, + }) + if err != nil { + return 
errors.Wrap(err, "Failed setting state in persistency") + } + + return nil +} + +func (sh *stateHandler) getStateFromPersistency() (*State, *int, error) { + stateFilePath, err := sh.getStateFilePath() + if err != nil { + return nil, nil, errors.Wrap(err, "Failed getting state file path") + } + + response, err := sh.streamConsumerGroup.container.GetItemSync(&v3io.GetItemInput{ + Path: stateFilePath, + AttributeNames: []string{"__mtime_nsecs", stateContentsAttributeKey}, + }) + + if err != nil { + errWithStatusCode, errHasStatusCode := err.(v3ioerrors.ErrorWithStatusCode) + if !errHasStatusCode { + return nil, nil, errors.Wrap(err, "Got error without status code") + } + + if errWithStatusCode.StatusCode() != 404 { + return nil, nil, errors.Wrap(err, "Failed getting state item") + } + + return nil, nil, v3ioerrors.ErrNotFound + } + + defer response.Release() + + getItemOutput := response.Output.(*v3io.GetItemOutput) + + stateContents, err := getItemOutput.Item.GetFieldString(stateContentsAttributeKey) + if err != nil { + return nil, nil, errors.Wrap(err, "Failed getting state attribute") + } + + var state State + + err = json.Unmarshal([]byte(stateContents), &state) + if err != nil { + return nil, nil, errors.Wrapf(err, "Failed unmarshalling state contents: %s", stateContents) + } + + stateMtime, err := getItemOutput.Item.GetFieldInt("__mtime_nsecs") + if err != nil { + return nil, nil, errors.New("Failed getting mtime attribute") + } + + return &state, &stateMtime, nil +} + +func (sh *stateHandler) removeStaleSessionStates(state *State) error { + + // clear out the sessions since we only want the valid sessions + var activeSessionStates []*SessionState + + for _, sessionState := range state.SessionStates { + + // check if the last heartbeat happened prior to the session timeout + if time.Since(sessionState.LastHeartbeat) < sh.streamConsumerGroup.config.Session.Timeout { + activeSessionStates = append(activeSessionStates, sessionState) + } else { + 
sh.logger.DebugWith("Removing stale member", + "memberID", sessionState.MemberID, + "lastHeartbeat", time.Since(sessionState.LastHeartbeat)) + } + } + + state.SessionStates = activeSessionStates + + return nil +} diff --git a/pkg/dataplane/streamconsumergroup/statehandler_test.go b/pkg/dataplane/streamconsumergroup/statehandler_test.go new file mode 100644 index 0000000..65538b9 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/statehandler_test.go @@ -0,0 +1,82 @@ +package streamconsumergroup + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type stateHandlerSuite struct { + suite.Suite + stateHandler *stateHandler +} + +func (suite *stateHandlerSuite) TestAssignShards() { + + for _, testCase := range []struct { + name string + maxReplicas int + numShards int + existingShardGroups [][]int + expectedShardGroup []int + }{ + { + name: "even, more shards than replicas", + maxReplicas: 4, + numShards: 8, + existingShardGroups: [][]int{{0, 1}, {4, 5}}, + expectedShardGroup: []int{2, 3}, + }, + { + name: "odd, more shards than replicas", + maxReplicas: 3, + numShards: 8, + existingShardGroups: [][]int{{0, 1, 2}}, + expectedShardGroup: []int{3, 4}, + }, + { + name: "equal number of shards and replicas", + maxReplicas: 4, + numShards: 4, + existingShardGroups: [][]int{{0}, {1}, {3}}, + expectedShardGroup: []int{2}, + }, + { + name: "more replicas than shards, no empty groups assigned", + maxReplicas: 4, + numShards: 2, + existingShardGroups: [][]int{{0}, {1}}, + expectedShardGroup: []int{}, + }, + { + name: "more replicas than shards, all empty groups assigned", + maxReplicas: 4, + numShards: 2, + existingShardGroups: [][]int{{}, {}}, + expectedShardGroup: []int{0}, + }, + { + name: "more replicas than shards, some empty groups assigned", + maxReplicas: 4, + numShards: 2, + existingShardGroups: [][]int{{}, {0}}, + expectedShardGroup: []int{}, + }, + } { + // make state from shard groups + state := State{} + for _, existingShardGroup := range 
testCase.existingShardGroups { + state.SessionStates = append(state.SessionStates, &SessionState{ + Shards: existingShardGroup, + }) + } + + assignedShardGroup, err := suite.stateHandler.assignShards(testCase.maxReplicas, testCase.numShards, &state) + suite.Require().NoError(err) + suite.Require().Equal(testCase.expectedShardGroup, assignedShardGroup, testCase.name) + } +} + +func TestBinaryTestSuite(t *testing.T) { + suite.Run(t, new(stateHandlerSuite)) +} diff --git a/pkg/dataplane/streamconsumergroup/streamconsumergroup.go b/pkg/dataplane/streamconsumergroup/streamconsumergroup.go new file mode 100644 index 0000000..1595cbe --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/streamconsumergroup.go @@ -0,0 +1,142 @@ +package streamconsumergroup + +import ( + "fmt" + "path" + "strconv" + + "github.com/v3io/v3io-go/pkg/dataplane" + + "github.com/nuclio/errors" + "github.com/nuclio/logger" + "github.com/rs/xid" +) + +type streamConsumerGroup struct { + name string + memberID string + logger logger.Logger + config *Config + streamPath string + maxReplicas int + stateHandler *stateHandler + sequenceNumberHandler *sequenceNumberHandler + container v3io.Container + handler Handler + session Session + totalNumShards int +} + +func NewStreamConsumerGroup(parentLogger logger.Logger, + name string, + memberName string, + config *Config, + streamPath string, + maxReplicas int, + container v3io.Container) (StreamConsumerGroup, error) { + var err error + + // add uniqueness + memberID := fmt.Sprintf("%s-%s", memberName, xid.New().String()) + + if config == nil { + config = NewConfig() + } + + newStreamConsumerGroup := streamConsumerGroup{ + name: name, + memberID: memberID, + logger: parentLogger.GetChild(fmt.Sprintf("%s-%s", name, memberID)), + config: config, + streamPath: streamPath, + maxReplicas: maxReplicas, + container: container, + } + + // get the total number of shards for this stream + newStreamConsumerGroup.totalNumShards, err = 
newStreamConsumerGroup.getTotalNumberOfShards() + if err != nil { + return nil, errors.Wrap(err, "Failed to get total number of shards") + } + + // create & start a state handler for the stream + newStreamConsumerGroup.stateHandler, err = newStateHandler(&newStreamConsumerGroup) + if err != nil { + return nil, errors.Wrap(err, "Failed creating stream consumer group state handler") + } + + err = newStreamConsumerGroup.stateHandler.start() + if err != nil { + return nil, errors.Wrap(err, "Failed starting stream consumer group state handler") + } + + // create & start an location handler for the stream + newStreamConsumerGroup.sequenceNumberHandler, err = newSequenceNumberHandler(&newStreamConsumerGroup) + if err != nil { + return nil, errors.Wrap(err, "Failed creating stream consumer group location handler") + } + + err = newStreamConsumerGroup.sequenceNumberHandler.start() + if err != nil { + return nil, errors.Wrap(err, "Failed starting stream consumer group state handler") + } + + return &newStreamConsumerGroup, nil +} + +func (scg *streamConsumerGroup) Consume(handler Handler) error { + scg.logger.DebugWith("Starting consumption of consumer group") + + scg.handler = handler + + // get the state (holding our shards) + sessionState, err := scg.stateHandler.getOrCreateSessionState(scg.memberID) + if err != nil { + return errors.Wrap(err, "Failed getting stream consumer group member state") + } + + // create a session object from our state + scg.session, err = newSession(scg, sessionState) + if err != nil { + return errors.Wrap(err, "Failed creating stream consumer group session") + } + + // start it + return scg.session.start() +} + +func (scg *streamConsumerGroup) Close() error { + scg.logger.DebugWith("Closing consumer group") + + if err := scg.stateHandler.stop(); err != nil { + return errors.Wrapf(err, "Failed stopping state handler") + } + if err := scg.sequenceNumberHandler.stop(); err != nil { + return errors.Wrapf(err, "Failed stopping location handler") + } 
+ + if scg.session != nil { + if err := scg.session.stop(); err != nil { + return errors.Wrap(err, "Failed stopping member session") + } + } + + return nil +} + +func (scg *streamConsumerGroup) getShardPath(shardID int) (string, error) { + return path.Join(scg.streamPath, strconv.Itoa(shardID)), nil +} + +func (scg *streamConsumerGroup) getTotalNumberOfShards() (int, error) { + response, err := scg.container.DescribeStreamSync(&v3io.DescribeStreamInput{ + Path: scg.streamPath, + }) + if err != nil { + return 0, errors.Wrapf(err, "Failed describing stream: %s", scg.streamPath) + } + + defer response.Release() + + return response.Output.(*v3io.DescribeStreamOutput).ShardCount, nil +} diff --git a/pkg/dataplane/streamconsumergroup/types.go b/pkg/dataplane/streamconsumergroup/types.go new file mode 100644 index 0000000..d522d59 --- /dev/null +++ b/pkg/dataplane/streamconsumergroup/types.go @@ -0,0 +1,61 @@ +package streamconsumergroup + +import ( + "time" + + v3io "github.com/v3io/v3io-go/pkg/dataplane" +) + +type stateModifier func(*State) (*State, error) + +type SessionState struct { + MemberID string `json:"member_id"` + LastHeartbeat time.Time `json:"last_heartbeat_time"` + Shards []int `json:"shards"` +} + +type Handler interface { + + // Setup is run at the beginning of a new session, before ConsumeClaim. + Setup(Session) error + + // Cleanup is run at the end of a session, once all ConsumeClaim goroutines have exited + // but before the locations are committed for the very last time. + Cleanup(Session) error + + // ConsumeClaim must start a consumer loop of ConsumerGroupClaim's Messages(). + // Once the Messages() channel is closed, the Handler must finish its processing + // loop and exit. 
+ ConsumeClaim(Session, Claim) error +} + +type RecordBatch struct { + Records []v3io.StreamRecord + Location string + NextLocation string + ShardID int +} + +type StreamConsumerGroup interface { + Consume(Handler) error + Close() error +} + +type Session interface { + GetClaims() []Claim + GetMemberID() string + MarkRecord(*v3io.StreamRecord) error + + start() error + stop() error +} + +type Claim interface { + GetStreamPath() string + GetShardID() int + GetCurrentLocation() string + GetRecordBatchChan() <-chan *RecordBatch + + start() error + stop() error +} diff --git a/pkg/dataplane/test/streamconsumergroup_test.go b/pkg/dataplane/test/streamconsumergroup_test.go new file mode 100644 index 0000000..3362cf7 --- /dev/null +++ b/pkg/dataplane/test/streamconsumergroup_test.go @@ -0,0 +1,368 @@ +package test + +import ( + "encoding/json" + "fmt" + "sync/atomic" + "testing" + "time" + + v3io "github.com/v3io/v3io-go/pkg/dataplane" + "github.com/v3io/v3io-go/pkg/dataplane/streamconsumergroup" + + "github.com/nuclio/logger" + "github.com/stretchr/testify/suite" +) + +type recordData struct { + ShardID int `json:"shard_id"` + Index int `json:"index"` +} + +type streamConsumerGroupTestSuite struct { + StreamTestSuite + streamPath string +} + +func (suite *streamConsumerGroupTestSuite) SetupSuite() { + suite.StreamTestSuite.SetupSuite() + suite.createContainer() + suite.streamPath = fmt.Sprintf("%s/test-stream-0/", suite.testPath) +} + +func (suite *streamConsumerGroupTestSuite) TestLocationHandling() { + consumerGroupName := "cg0" + numShards := 8 + + suite.createStream(suite.streamPath, numShards) + + memberGroup := newMemberGroup(suite, + consumerGroupName, + 2, + numShards, + 2, + []int{0, 0, 0, 0, 0, 0, 0, 0}, + []int{5, 10, 10, 10, 15, 10, 10, 20}) + + // wait a bit for things to happen - the members should all connect, get their partitions and start consuming + // but not actually consume anything + time.Sleep(3 * time.Second) + + // must have exactly 2 shards 
each, must all be consuming, must all have not processed any messages + memberGroup.verifyClaimShards(numShards, []int{4}) + memberGroup.verifyNumActiveClaimConsumptions(numShards) + memberGroup.verifyNumRecordsConsumed([]int{0, 0, 0, 0, 0, 0, 0, 0}) + + suite.writeRecords([]int{30, 30, 30, 30, 30, 30, 30, 30}) + + // wait a bit for things to happen - the members should read data from the shards up to the amount they were + // told to read, verifying that each message is in order and the expected + time.Sleep(15 * time.Second) + + memberGroup.verifyClaimShards(numShards, []int{4}) + memberGroup.verifyNumActiveClaimConsumptions(0) + memberGroup.verifyNumRecordsConsumed([]int{5, 10, 10, 10, 15, 10, 10, 20}) + + // stop the group + memberGroup.stop() + time.Sleep(3 * time.Second) + + memberGroup = newMemberGroup(suite, + consumerGroupName, + 4, + numShards, + 4, + []int{5, 10, 10, 10, 15, 10, 10, 20}, + []int{50, 50, 50, 50, 50, 50, 50, 50}) + + // wait a bit for things to happen + time.Sleep(30 * time.Second) + + memberGroup.verifyClaimShards(numShards, []int{2}) + memberGroup.verifyNumActiveClaimConsumptions(8) + memberGroup.verifyNumRecordsConsumed([]int{25, 20, 20, 20, 15, 20, 20, 10}) + + memberGroup.stop() + time.Sleep(3 * time.Second) +} + +func (suite *streamConsumerGroupTestSuite) createStream(streamPath string, numShards int) { + createStreamInput := v3io.CreateStreamInput{ + Path: streamPath, + ShardCount: numShards, + RetentionPeriodHours: 1, + } + + err := suite.container.CreateStreamSync(&createStreamInput) + suite.Require().NoError(err, "Failed to create stream") +} + +func (suite *streamConsumerGroupTestSuite) writeRecords(numRecordsPerShard []int) { + var records []*v3io.StreamRecord + + suite.logger.DebugWith("Writing records", "numRecordsPerShard", numRecordsPerShard) + + for shardID, numRecordsPerShard := range numRecordsPerShard { + + // we're taking address + shardIDCopy := shardID + + for recordIndex := 0; recordIndex < numRecordsPerShard; 
recordIndex++ { + recordDataInstance := recordData{ + ShardID: shardIDCopy, + Index: recordIndex, + } + + marshalledRecordDataInstance, err := json.Marshal(&recordDataInstance) + suite.Require().NoError(err) + + records = append(records, &v3io.StreamRecord{ + ShardID: &shardIDCopy, + Data: marshalledRecordDataInstance, + }) + } + } + + putRecordsInput := v3io.PutRecordsInput{ + Path: suite.streamPath, + Records: records, + } + + response, err := suite.container.PutRecordsSync(&putRecordsInput) + suite.Require().NoError(err, "Failed to put records") + + putRecordsResponse := response.Output.(*v3io.PutRecordsOutput) + suite.Require().Equal(0, putRecordsResponse.FailedRecordCount) + + suite.logger.DebugWith("Done writing records", "numRecordsPerShard", numRecordsPerShard) +} + +// +// Orchestrates a group of members +// + +type memberGroup struct { + suite *streamConsumerGroupTestSuite + members []*member + numberOfRecordsConsumed []int +} + +func newMemberGroup(suite *streamConsumerGroupTestSuite, + consumerGroupName string, + maxNumMembers int, + numShards int, + numMembers int, + expectedInitialRecordIndex []int, + numberOfRecordToConsume []int) *memberGroup { + newMemberGroup := memberGroup{ + suite: suite, + numberOfRecordsConsumed: make([]int, numShards), + } + + memberChan := make(chan *member, numMembers) + + for memberIdx := 0; memberIdx < numMembers; memberIdx++ { + go func() { + memberInstance := newMember(suite, + consumerGroupName, + maxNumMembers, + numShards, + memberIdx, + newMemberGroup.numberOfRecordsConsumed) + + // start + memberInstance.start(expectedInitialRecordIndex, numberOfRecordToConsume) + + // shove to member chan + memberChan <- memberInstance + }() + } + + for memberInstance := range memberChan { + newMemberGroup.members = append(newMemberGroup.members, memberInstance) + if len(newMemberGroup.members) >= numMembers { + break + } + } + + return &newMemberGroup +} + +func (mg *memberGroup) verifyClaimShards(expectedTotalNumShards int, 
expectedNumShardsPerMember []int) { + totalNumShards := 0 + + for _, member := range mg.members { + numMemberShards := len(member.claims) + + mg.suite.Require().Contains(expectedNumShardsPerMember, + numMemberShards, + "Member %s doesn't have the required amount of shards. Has %d, expected %v", + member.id, + numMemberShards, + expectedNumShardsPerMember) + + totalNumShards += numMemberShards + } + + mg.suite.Require().Equal(expectedTotalNumShards, totalNumShards) +} + +func (mg *memberGroup) verifyNumActiveClaimConsumptions(expectedNumActiveClaimConsumptions int) { + totalNumActiveClaimConsumptions := 0 + + for _, member := range mg.members { + totalNumActiveClaimConsumptions += int(atomic.LoadInt64(&member.numActiveClaimConsumptions)) // ConsumeClaim updates this counter atomically from other goroutines + } + + mg.suite.Require().Equal(expectedNumActiveClaimConsumptions, totalNumActiveClaimConsumptions) +} + +func (mg *memberGroup) verifyNumRecordsConsumed(expectedNumRecordsConsumed []int) { + mg.suite.Require().Equal(expectedNumRecordsConsumed, mg.numberOfRecordsConsumed) +} + +func (mg *memberGroup) stop() { + for _, member := range mg.members { + member.stop() + } + + mg.suite.logger.Info("Member group stopped") +} + +// +// Simulates a member +// + +type member struct { + suite *streamConsumerGroupTestSuite + logger logger.Logger + id string + expectedStartRecordIndex []int + numberOfRecordToConsume []int + numberOfRecordsConsumed []int + streamConsumerGroup streamconsumergroup.StreamConsumerGroup + claims []streamconsumergroup.Claim + numActiveClaimConsumptions int64 +} + +func newMember(suite *streamConsumerGroupTestSuite, + consumerGroupName string, + maxNumMembers int, + numShards int, + index int, + numberOfRecordsConsumed []int) *member { + id := fmt.Sprintf("m%d", index) + + streamConsumerGroupConfig := streamconsumergroup.NewConfig() + streamConsumerGroupConfig.Claim.RecordBatchFetch.NumRecordsInBatch = 10 + streamConsumerGroupConfig.Claim.RecordBatchFetch.Interval = 50 * time.Millisecond + + streamConsumerGroup, err :=
streamconsumergroup.NewStreamConsumerGroup( + suite.logger, + consumerGroupName, + id, + streamConsumerGroupConfig, + suite.streamPath, + maxNumMembers, + suite.container) + suite.Require().NoError(err, "Failed creating stream consumer group") + + return &member{ + suite: suite, + logger: suite.logger.GetChild(id), + id: id, + streamConsumerGroup: streamConsumerGroup, + expectedStartRecordIndex: make([]int, numShards), + numberOfRecordToConsume: make([]int, numShards), + numberOfRecordsConsumed: numberOfRecordsConsumed, + } +} + +func (m *member) Setup(session streamconsumergroup.Session) error { + m.claims = session.GetClaims() + + shardIDs := m.getShardIDs() + m.logger.DebugWith("Setup called", "shardIDs", shardIDs) + + return nil +} + +func (m *member) Cleanup(session streamconsumergroup.Session) error { + m.logger.DebugWith("Cleanup called") + return nil +} + +func (m *member) ConsumeClaim(session streamconsumergroup.Session, claim streamconsumergroup.Claim) error { + numActiveClaimConsumptions := atomic.AddInt64(&m.numActiveClaimConsumptions, 1) + m.logger.DebugWith("Consume Claims called", "numActiveClaimConsumptions", numActiveClaimConsumptions) + + expectedRecordIndex := m.expectedStartRecordIndex[claim.GetShardID()] + + // reduce at the end + defer func() { + numActiveClaimConsumptions := atomic.AddInt64(&m.numActiveClaimConsumptions, -1) + + m.logger.DebugWith("Consume Claims done", + "numRecordsConsumed", m.numberOfRecordsConsumed, + "numActiveClaimConsumptions", numActiveClaimConsumptions) + }() + + // start reading + for recordBatch := range claim.GetRecordBatchChan() { + + // iterate over records + for _, record := range recordBatch.Records { + recordDataInstance := recordData{} + + // read the data into message + err := json.Unmarshal(record.Data, &recordDataInstance) + m.suite.Require().NoError(err) + + // make sure we're reading the proper shard + m.suite.Require().Equal(recordDataInstance.ShardID, claim.GetShardID()) + + // check we got the 
expected message index + m.suite.Require().Equal(expectedRecordIndex, recordDataInstance.Index) + + expectedRecordIndex++ + m.numberOfRecordsConsumed[claim.GetShardID()]++ + + err = session.MarkRecord(&record) + m.suite.Require().NoError(err) + + if m.numberOfRecordsConsumed[claim.GetShardID()] >= m.numberOfRecordToConsume[claim.GetShardID()] { + return nil + } + } + } + + return nil +} + +func (m *member) start(expectedStartRecordIndex []int, numberOfRecordToConsume []int) { + m.expectedStartRecordIndex = expectedStartRecordIndex + m.numberOfRecordToConsume = numberOfRecordToConsume + + // start consuming + err := m.streamConsumerGroup.Consume(m) + m.suite.Require().NoError(err) +} + +func (m *member) stop() { + err := m.streamConsumerGroup.Close() + m.suite.Require().NoError(err) +} + +func (m *member) getShardIDs() []int { + var shardIDs []int + + for _, claim := range m.claims { + shardIDs = append(shardIDs, claim.GetShardID()) + } + + return shardIDs +} + +func TestStreamConsumerGroupTestSuite(t *testing.T) { + suite.Run(t, new(streamConsumerGroupTestSuite)) +} diff --git a/pkg/dataplane/test/sync_test.go b/pkg/dataplane/test/sync_test.go index dc55f6d..5489155 100644 --- a/pkg/dataplane/test/sync_test.go +++ b/pkg/dataplane/test/sync_test.go @@ -632,28 +632,22 @@ func (suite *syncContainerKVTestSuite) SetupSuite() { type syncStreamTestSuite struct { syncTestSuite - testPath string + StreamTestSuite StreamTestSuite } func (suite *syncStreamTestSuite) SetupTest() { - suite.testPath = "/stream-test" - err := suite.deleteAllStreamsInPath(suite.testPath) - // get the underlying root error - if err != nil { - errWithStatusCode, errHasStatusCode := err.(v3ioerrors.ErrorWithStatusCode) - suite.Require().True(errHasStatusCode) - // File not found is OK - suite.Require().Equal(404, errWithStatusCode.StatusCode(), "Failed to setup test suite") + suite.StreamTestSuite = StreamTestSuite{ + testSuite: suite.syncTestSuite.testSuite, } + suite.StreamTestSuite.SetupTest() } 
func (suite *syncStreamTestSuite) TearDownTest() { - err := suite.deleteAllStreamsInPath(suite.testPath) - suite.Require().NoError(err, "Failed to tear down test suite") + suite.StreamTestSuite.TearDownTest() } func (suite *syncStreamTestSuite) TestStream() { - streamPath := fmt.Sprintf("%s/mystream/", suite.testPath) + streamPath := fmt.Sprintf("%s/mystream/", suite.StreamTestSuite.testPath) // // Create the stream @@ -772,39 +766,6 @@ func (suite *syncStreamTestSuite) TestStream() { suite.Require().NoError(err, "Failed to delete stream") } -func (suite *syncStreamTestSuite) deleteAllStreamsInPath(path string) error { - - getContainerContentsInput := v3io.GetContainerContentsInput{ - Path: path, - } - - suite.populateDataPlaneInput(&getContainerContentsInput.DataPlaneInput) - - // get all streams in the test path - response, err := suite.container.GetContainerContentsSync(&getContainerContentsInput) - - if err != nil { - return err - } - response.Release() - - // iterate over streams (prefixes) and delete them - for _, commonPrefix := range response.Output.(*v3io.GetContainerContentsOutput).CommonPrefixes { - deleteStreamInput := v3io.DeleteStreamInput{ - Path: "/" + commonPrefix.Prefix, - } - - suite.populateDataPlaneInput(&deleteStreamInput.DataPlaneInput) - - err := suite.container.DeleteStreamSync(&deleteStreamInput) - if err != nil { - return err - } - } - - return nil -} - type syncContextStreamTestSuite struct { syncStreamTestSuite } diff --git a/pkg/dataplane/test/test.go b/pkg/dataplane/test/test.go index 8d77116..08a1d09 100644 --- a/pkg/dataplane/test/test.go +++ b/pkg/dataplane/test/test.go @@ -5,6 +5,7 @@ import ( "github.com/v3io/v3io-go/pkg/dataplane" "github.com/v3io/v3io-go/pkg/dataplane/http" + "github.com/v3io/v3io-go/pkg/errors" "github.com/nuclio/logger" "github.com/nuclio/zap" @@ -72,3 +73,63 @@ func (suite *testSuite) createContainer() { }) suite.Require().NoError(err) } + +type StreamTestSuite struct { // nolint: deadcode + testSuite + 
testPath string +} + +func (suite *StreamTestSuite) SetupSuite() { + suite.testSuite.SetupSuite() + suite.testPath = "/stream-test" +} + +func (suite *StreamTestSuite) SetupTest() { + err := suite.deleteAllStreamsInPath(suite.testPath) + + // get the underlying root error + if err != nil { + errWithStatusCode, errHasStatusCode := err.(v3ioerrors.ErrorWithStatusCode) + suite.Require().True(errHasStatusCode) + + // File not found is OK + suite.Require().Equal(404, errWithStatusCode.StatusCode(), "Failed to setup test suite") + } +} + +func (suite *StreamTestSuite) TearDownTest() { + err := suite.deleteAllStreamsInPath(suite.testPath) + suite.Require().NoError(err, "Failed to tear down test suite") +} + +func (suite *StreamTestSuite) deleteAllStreamsInPath(path string) error { + getContainerContentsInput := v3io.GetContainerContentsInput{ + Path: path, + } + + suite.populateDataPlaneInput(&getContainerContentsInput.DataPlaneInput) + + // get all streams in the test path + response, err := suite.container.GetContainerContentsSync(&getContainerContentsInput) + + if err != nil { + return err + } + response.Release() + + // iterate over streams (prefixes) and delete them + for _, commonPrefix := range response.Output.(*v3io.GetContainerContentsOutput).CommonPrefixes { + deleteStreamInput := v3io.DeleteStreamInput{ + Path: "/" + commonPrefix.Prefix, + } + + suite.populateDataPlaneInput(&deleteStreamInput.DataPlaneInput) + + err := suite.container.DeleteStreamSync(&deleteStreamInput) + if err != nil { + return err + } + } + + return nil +} diff --git a/pkg/dataplane/types.go b/pkg/dataplane/types.go index 37b33e0..053980d 100644 --- a/pkg/dataplane/types.go +++ b/pkg/dataplane/types.go @@ -125,20 +125,22 @@ func mode(v3ioFileMode FileMode) (os.FileMode, error) { // For example Scan API returns file mode as decimal number (base 10) while ListDir as Octal (base 8) var sFileMode = string(v3ioFileMode) if strings.HasPrefix(sFileMode, "0") { + // Convert octal representation of 
V3IO into decimal representation of Go - if mode, err := strconv.ParseUint(sFileMode, 8, 32); err != nil { - return os.FileMode(S_IFMT), err - } else { - golangFileMode := ((mode & S_IFMT) << 17) | (mode & IP_OFFMASK) - return os.FileMode(golangFileMode), nil - } - } else { - mode, err := strconv.ParseUint(sFileMode, 10, 32) + mode, err := strconv.ParseUint(sFileMode, 8, 32) if err != nil { return os.FileMode(S_IFMT), err } - return os.FileMode(mode), nil + + golangFileMode := ((mode & S_IFMT) << 17) | (mode & IP_OFFMASK) + return os.FileMode(golangFileMode), nil } + + mode, err := strconv.ParseUint(sFileMode, 10, 32) + if err != nil { + return os.FileMode(S_IFMT), err + } + return os.FileMode(mode), nil } type GetContainerContentsOutput struct { @@ -254,7 +256,7 @@ type GetItemsInput struct { TotalSegments int SortKeyRangeStart string SortKeyRangeEnd string - RequestJsonResponse bool + RequestJSONResponse bool `json:"RequestJsonResponse"` } type GetItemsOutput struct { @@ -269,10 +271,11 @@ type GetItemsOutput struct { // type StreamRecord struct { - ShardID *int - Data []byte - ClientInfo []byte - PartitionKey string + ShardID *int + Data []byte + ClientInfo []byte + PartitionKey string + SequenceNumber uint64 } type SeekShardInputType int @@ -291,6 +294,17 @@ type CreateStreamInput struct { RetentionPeriodHours int } +type DescribeStreamInput struct { + DataPlaneInput + Path string +} + +type DescribeStreamOutput struct { + DataPlaneOutput + ShardCount int + RetentionPeriodHours int +} + type DeleteStreamInput struct { DataPlaneInput Path string @@ -303,7 +317,7 @@ type PutRecordsInput struct { } type PutRecordResult struct { - SequenceNumber int + SequenceNumber uint64 ShardID int `json:"ShardId"` ErrorCode int ErrorMessage string @@ -319,7 +333,7 @@ type SeekShardInput struct { DataPlaneInput Path string Type SeekShardInputType - StartingSequenceNumber int + StartingSequenceNumber uint64 Timestamp int } @@ -338,7 +352,7 @@ type GetRecordsInput struct { type 
GetRecordsResult struct { ArrivalTimeSec int ArrivalTimeNSec int - SequenceNumber int + SequenceNumber uint64 ClientInfo []byte PartitionKey string Data []byte diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 0f12001..103f7f3 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -5,7 +5,8 @@ import ( ) var ErrInvalidTypeConversion = errors.New("Invalid type conversion") - +var ErrNotFound = errors.New("Not found") +var ErrStopped = errors.New("Stopped") var ErrTimeout = errors.New("Timed out") type ErrorWithStatusCode struct {