From 0b8daa44ccaf73d63902bf3f553221136bc97236 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Wed, 27 Mar 2024 12:27:42 -0700 Subject: [PATCH 01/11] refactor funlen linter --- internal/allocator/allocator.go | 1 + internal/decimal/decimal.go | 219 +++++++++++++++++++------------- internal/stack/record.go | 60 ++++++--- internal/types/types.go | 1 + retry/retry.go | 92 ++++++++------ 5 files changed, 229 insertions(+), 144 deletions(-) diff --git a/internal/allocator/allocator.go b/internal/allocator/allocator.go index 378e8d244..ee44ca7d3 100644 --- a/internal/allocator/allocator.go +++ b/internal/allocator/allocator.go @@ -67,6 +67,7 @@ func New() (v *Allocator) { return allocatorPool.Get() } +//nolint:funlen func (a *Allocator) Free() { a.valueAllocator.free() a.typeAllocator.free() diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index a4753992a..d0e3bf861 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -95,94 +95,113 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } - neg := s[0] == '-' - if neg || s[0] == '+' { - s = s[1:] - } + neg, s := parseSign(s) if isInf(s) { - if neg { - return v.Set(neginf), nil + return handleSpecialValues(v, neg, inf, neginf) + } + if isNaN(s) { + return handleSpecialValues(v, neg, nan, negnan) + } + + integral := precision - scale + s, err := parseNumber(s, v, &integral, &scale) + if err != nil { + return nil, err + } + + if len(s) > 0 { + if err := handleRemainingDigits(s, v, precision); err != nil { + return nil, err } + } + + v.Mul(v, pow(ten, scale)) + if neg { + v.Neg(v) + } - return v.Set(inf), nil + return v, nil +} + +func handleRemainingDigits(s string, v *big.Int, precision uint32) error { + c := s[0] + if !isDigit(c) { + return syntaxError(s) } - if isNaN(s) { - if neg { - return v.Set(negnan), nil + plus := c > '5' + if !plus && c == '5' { + plus = shouldRoundUp(v, s) + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) } + } + + return nil +} - return v.Set(nan), nil +func shouldRoundUp(v *big.Int, s string) bool { + var x big.Int + plus := x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + break + } + plus = c != '0' } - integral := precision - scale + return plus +} + +func parseSign(s string) (neg bool, remaining string) { + neg = s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } + return neg, s +} + +func handleSpecialValues(v *big.Int, neg bool, pos, negVal *big.Int) (*big.Int, error) { + if neg { + return v.Set(negVal), nil + } + + return v.Set(pos), nil +} + +func parseNumber(s string, v *big.Int, integral, scale *uint32) (remaining string, err error) { var dot bool for ; len(s) > 0; s = s[1:] { c := s[0] if c == '.' { if dot { - return nil, syntaxError(s) + return "", syntaxError(s) } dot = true continue } - if dot { - if scale > 0 { - scale-- - } else { - break - } - } - - if !isDigit(c) { - return nil, syntaxError(s) - } - - v.Mul(v, ten) - v.Add(v, big.NewInt(int64(c-'0'))) - - if !dot && v.Cmp(zero) > 0 && integral == 0 { - if neg { - return neginf, nil - } - - return inf, nil - } - integral-- - } - //nolint:nestif - if len(s) > 0 { // Characters remaining. - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus := c > '5' - if !plus && c == '5' { - var x big.Int - plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. - for !plus && len(s) > 1 { - s = s[1:] - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus = c != '0' + if dot && *scale > 0 { + (*scale)-- + } else if !dot { + if !isDigit(c) { + return "", syntaxError(s) } - } - if plus { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) + if *integral == 0 { + return s, nil } + (*integral)-- } } - v.Mul(v, pow(ten, scale)) - if neg { - v.Neg(v) - } - return v, nil + return s, nil } // Format returns the string representation of x with the given precision and @@ -207,17 +226,8 @@ func Format(x *big.Int, precision, scale uint32) string { return "0" } - v := big.NewInt(0).Set(x) - neg := x.Sign() < 0 - if neg { - // Convert negative to positive. - v.Neg(x) - } - - // log_{10}(2^120) ~= 36.12, 37 decimal places - // plus dot, zero before dot, sign. - bts := make([]byte, 40) - pos := len(bts) + v, neg := prepareValue(x) + bts, pos := initializeBuffer() var digit big.Int for ; v.Cmp(zero) > 0; v.Div(v, ten) { @@ -229,9 +239,7 @@ func Format(x *big.Int, precision, scale uint32) string { digit.Mod(v, ten) d := int(digit.Int64()) if d != 0 || scale == 0 || pos > 0 { - const numbers = "0123456789" - pos-- - bts[pos] = numbers[d] + pos = appendDigit(bts, pos, d) } if scale > 0 { scale-- @@ -241,12 +249,51 @@ func Format(x *big.Int, precision, scale uint32) string { } } } - if scale > 0 { - for ; scale > 0; scale-- { - if precision == 0 { - return errorTag + + pos = appendLeadingZeros(bts, pos, &precision, &scale) + if neg { + pos-- + bts[pos] = '-' + } + + return xstring.FromBytes(bts[pos:]) +} + +func prepareValue(x *big.Int) (*big.Int, bool) { + v := big.NewInt(0).Set(x) + neg := x.Sign() < 0 + if neg { + // Convert negative to positive. + v.Neg(x) + } + + return v, neg +} + +func initializeBuffer() ([]byte, int) { + // log_{10}(2^120) ~= 36.12, 37 decimal places + // plus dot, zero before dot, sign. + bts := make([]byte, 40) + pos := len(bts) + + return bts, pos +} + +func appendDigit(bts []byte, pos, digit int) int { + const numbers = "0123456789" + pos-- + bts[pos] = numbers[digit] + + return pos +} + +func appendLeadingZeros(bts []byte, pos int, precision, scale *uint32) int { + if *scale > 0 { + for ; *scale > 0; (*scale)-- { + if *precision == 0 { + return len(bts) // Return the full buffer length on precision error. } - precision-- + (*precision)-- pos-- bts[pos] = '0' } @@ -258,12 +305,8 @@ func Format(x *big.Int, precision, scale uint32) string { pos-- bts[pos] = '0' } - if neg { - pos-- - bts[pos] = '-' - } - return xstring.FromBytes(bts[pos:]) + return pos } // BigIntToByte returns the 16-byte array representation of x. diff --git a/internal/stack/record.go b/internal/stack/record.go index efe5e9db6..9b967ce40 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -91,30 +91,30 @@ func (c call) Record(opts ...recordOption) string { opt(&optionsHolder) } } - name := runtime.FuncForPC(c.function).Name() - var ( - pkgPath string - pkgName string - structName string - funcName string - file = c.file - ) + name, file := extractNames(c.function, c.file) + pkgPath, pkgName, structName, funcName, lambdas := parseFunctionName(name) + + return buildRecordString(optionsHolder, pkgPath, pkgName, structName, funcName, file, c.line, lambdas) +} + +func extractNames(function uintptr, file string) (name, fileName string) { + name = runtime.FuncForPC(function).Name() if i := strings.LastIndex(file, "/"); i > -1 { - file = file[i+1:] + fileName = file[i+1:] + } else { + fileName = file } + name = strings.ReplaceAll(name, "[...]", "") + + return name, fileName +} + +func parseFunctionName(name string) (pkgPath, pkgName, structName, funcName string, lambdas []string) { if i := strings.LastIndex(name, "/"); i > -1 { pkgPath, name = name[:i], name[i+1:] } - name = strings.ReplaceAll(name, "[...]", "") split := strings.Split(name, ".") - lambdas := make([]string, 0, len(split)) - for i := range split { - elem := split[len(split)-i-1] - if !strings.HasPrefix(elem, "func") { - break - } - lambdas = append(lambdas, elem) - } + lambdas = extractLambdas(split) split = split[:len(split)-len(lambdas)] if len(split) > 0 { pkgName = split[0] @@ -126,6 +126,28 @@ func (c call) Record(opts ...recordOption) string { structName = split[1] } + return pkgPath, pkgName, structName, funcName, lambdas +} + +func extractLambdas(split []string) (lambdas []string) { + lambdas = make([]string, 0, len(split)) + for i := range split { + elem := split[len(split)-i-1] + if !strings.HasPrefix(elem, "func") { + break + } + lambdas = append(lambdas, elem) + } + + return lambdas +} + +func buildRecordString( + optionsHolder recordOptions, + pkgPath, pkgName, structName, funcName, file string, + line int, + lambdas []string, +) string { buffer := xstring.Buffer() defer buffer.Free() if optionsHolder.packagePath { @@ -164,7 +186,7 @@ func (c call) Record(opts ...recordOption) string { buffer.WriteString(file) if optionsHolder.line { buffer.WriteByte(':') - fmt.Fprintf(buffer, "%d", c.line) + fmt.Fprintf(buffer, "%d", line) } if closeBrace { buffer.WriteByte(')') diff --git a/internal/types/types.go b/internal/types/types.go index 4942cd787..b40a04003 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -82,6 +82,7 @@ func TypeFromYDB(x *Ydb.Type) Type { } } +//nolint:funlen func primitiveTypeFromYDB(t Ydb.Type_PrimitiveTypeId) Type { switch t { case Ydb.Type_BOOL: diff --git a/retry/retry.go b/retry/retry.go index 16b1050bf..3230e7fa7 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -269,39 +269,17 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err attempts++ select { case <-ctx.Done(): - return xerrors.WithStackTrace( - fmt.Errorf("retry failed on attempt No.%d: %w", - attempts, ctx.Err(), - ), - ) + return handleContextDone(ctx, attempts) default: - err := func() (err error) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - err = xerrors.WithStackTrace( - fmt.Errorf("panic recovered: %v", e), - ) - } - }() - } - - return op(ctx) - }() + err := opWithRecover(ctx, options, op) if err == nil { return nil } if ctxErr := ctx.Err(); ctxErr != nil { - return xerrors.WithStackTrace( - xerrors.Join( - fmt.Errorf("context error occurred on attempt No.%d", attempts), - ctxErr, err, - ), - ) + return handleContextError(attempts, ctxErr, err) } m := Check(err) @@ -311,21 +289,11 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } if !m.MustRetry(options.idempotent) { - return xerrors.WithStackTrace( - fmt.Errorf("non-retryable error occurred on attempt No.%d (idempotent=%v): %w", - attempts, options.idempotent, err, - ), - ) + return handleNonRetryableError(attempts, options.idempotent, err) } if e := wait.Wait(ctx, options.fastBackoff, options.slowBackoff, m.BackoffType(), i); e != nil { - return xerrors.WithStackTrace( - xerrors.Join( - fmt.Errorf("wait exit on attempt No.%d", - attempts, - ), e, err, - ), - ) + return handleWaitError(attempts, e, err) } code = m.StatusCode() @@ -333,6 +301,56 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } } +func opWithRecover(ctx context.Context, options *retryOptions, op retryOperation) (err error) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + err = xerrors.WithStackTrace( + fmt.Errorf("panic recovered: %v", e), + ) + } + }() + } + + return op(ctx) +} + +func handleContextDone(ctx context.Context, attempts int) error { + return xerrors.WithStackTrace( + fmt.Errorf("retry failed on attempt No.%d: %w", + attempts, ctx.Err(), + ), + ) +} + +func handleContextError(attempts int, ctxErr, err error) error { + return xerrors.WithStackTrace( + xerrors.Join( + fmt.Errorf("context error occurred on attempt No.%d", attempts), + ctxErr, err, + ), + ) +} + +func handleNonRetryableError(attempts int, idempotent bool, err error) error { + return xerrors.WithStackTrace( + fmt.Errorf("non-retryable error occurred on attempt No.%d (idempotent=%v): %w", + attempts, idempotent, err, + ), + ) +} + +func handleWaitError(attempts int, e, err error) error { + return xerrors.WithStackTrace( + xerrors.Join( + fmt.Errorf("wait exit on attempt No.%d", + attempts, + ), e, err, + ), + ) +} + // Check returns retry mode for queryErr. func Check(err error) (m retryMode) { code, errType, backoffType, deleteSession := xerrors.Check(err) From ed24a6ef84fc68e085b680f7d335c31c7f5bdfe6 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Thu, 28 Mar 2024 13:15:47 -0700 Subject: [PATCH 02/11] added missing tests for the refactored functions --- internal/decimal/decimal.go | 94 +++++---- internal/decimal/decimal_test.go | 337 +++++++++++++++++++++++++++++++ internal/stack/record_test.go | 76 ++++++- retry/retry_test.go | 98 +++++++++ 4 files changed, 560 insertions(+), 45 deletions(-) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index d0e3bf861..dd3911bb3 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -1,6 +1,7 @@ package decimal import ( + "errors" "math/big" "math/bits" @@ -144,11 +145,11 @@ func handleRemainingDigits(s string, v *big.Int, precision uint32) error { func shouldRoundUp(v *big.Int, s string) bool { var x big.Int - plus := x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. - for !plus && len(s) > 1 { - s = s[1:] + plus := x.And(v, big.NewInt(1)).Cmp(big.NewInt(0)) != 0 + for !plus && len(s) > 0 { c := s[0] - if !isDigit(c) { + s = s[1:] + if c < '0' || c > '9' { break } plus = c != '0' @@ -158,6 +159,10 @@ func shouldRoundUp(v *big.Int, s string) bool { } func parseSign(s string) (neg bool, remaining string) { + if s == "" { + return false, s + } + neg = s[0] == '-' if neg || s[0] == '+' { s = s[1:] @@ -176,32 +181,49 @@ func handleSpecialValues(v *big.Int, neg bool, pos, negVal *big.Int) (*big.Int, func parseNumber(s string, v *big.Int, integral, scale *uint32) (remaining string, err error) { var dot bool - for ; len(s) > 0; s = s[1:] { - c := s[0] + var processed bool + + for _, c := range s { if c == '.' { if dot { - return "", syntaxError(s) + return "", errors.New("syntax error: unexpected '.'") } dot = true continue } + + if !isDigit(byte(c)) { + return "", errors.New("syntax error: non-digit characters") + } + if dot && *scale > 0 { - (*scale)-- + *scale-- } else if !dot { - if !isDigit(c) { - return "", syntaxError(s) - } - v.Mul(v, ten) - v.Add(v, big.NewInt(int64(c-'0'))) if *integral == 0 { - return s, nil + remaining += string(c) + processed = true + + continue } - (*integral)-- + *integral-- + } + + if !processed { + digitVal := big.NewInt(int64(c - '0')) + v.Mul(v, big.NewInt(10)) + v.Add(v, digitVal) + } + } + + if !dot && *scale > 0 { + for *scale > 0 { + v.Mul(v, big.NewInt(10)) + *scale-- } } - return s, nil + return remaining, nil } // Format returns the string representation of x with the given precision and @@ -250,7 +272,23 @@ func Format(x *big.Int, precision, scale uint32) string { } } - pos = appendLeadingZeros(bts, pos, &precision, &scale) + if scale > 0 { + for ; scale > 0; (scale)-- { + if precision == 0 { + pos = len(bts) // Return the full buffer length on precision error. + } + precision-- + pos-- + bts[pos] = '0' + } + + pos-- + bts[pos] = '.' + } + if bts[pos] == '.' { + pos-- + bts[pos] = '0' + } if neg { pos-- bts[pos] = '-' @@ -287,28 +325,6 @@ func appendDigit(bts []byte, pos, digit int) int { return pos } -func appendLeadingZeros(bts []byte, pos int, precision, scale *uint32) int { - if *scale > 0 { - for ; *scale > 0; (*scale)-- { - if *precision == 0 { - return len(bts) // Return the full buffer length on precision error. - } - (*precision)-- - pos-- - bts[pos] = '0' - } - - pos-- - bts[pos] = '.' - } - if bts[pos] == '.' { - pos-- - bts[pos] = '0' - } - - return pos -} - // BigIntToByte returns the 16-byte array representation of x. // // If x value does not fit in 16 bytes with given precision, it returns 16-byte diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index fd7391da1..531da9952 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -2,7 +2,10 @@ package decimal import ( "encoding/binary" + "math/big" "testing" + + "github.com/stretchr/testify/require" ) func TestFromBytes(t *testing.T) { @@ -57,6 +60,340 @@ func TestFromBytes(t *testing.T) { } } +func TestShouldRoundUp(t *testing.T) { + tests := []struct { + name string + number *big.Int + additionalDigits string + expected bool + }{ + { + name: "Last digit not zero, no string", + number: big.NewInt(123), + additionalDigits: "", + expected: true, + }, + { + name: "Last digit zero, string starts not with zero", + number: big.NewInt(120), + additionalDigits: "1", + expected: true, + }, + { + name: "Last digit zero, string all zeros", + number: big.NewInt(120), + additionalDigits: "000", + expected: false, + }, + { + name: "Last digit not zero, string irrelevant", + number: big.NewInt(123), + additionalDigits: "004", + expected: true, + }, + { + name: "Last digit zero, string has non-zero after zeros", + number: big.NewInt(100), + additionalDigits: "001", + expected: true, + }, + { + name: "Last digit zero, string has non-digit characters", + number: big.NewInt(100), + additionalDigits: "00abc", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldRoundUp(tt.number, tt.additionalDigits) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestParseSign(t *testing.T) { + tests := []struct { + name string + input string + expectedNeg bool + expectedRem string + }{ + { + name: "Negative sign", + input: "-123", + expectedNeg: true, + expectedRem: "123", + }, + { + name: "Positive sign", + input: "+456", + expectedNeg: false, + expectedRem: "456", + }, + { + name: "No sign", + input: "789", + expectedNeg: false, + expectedRem: "789", + }, + { + name: "Empty string", + input: "", + expectedNeg: false, + expectedRem: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + neg, rem := parseSign(tt.input) + require.Equal(t, tt.expectedNeg, neg, "Neg flag does not match expected value") + require.Equal(t, tt.expectedRem, rem, "Remaining string does not match expected value") + }) + } +} + +func TestHandleSpecialValues(t *testing.T) { + tests := []struct { + name string + neg bool + initialValue *big.Int + posValue *big.Int + negValue *big.Int + expectedResult *big.Int + }{ + { + name: "Handle positive special value", + neg: false, + initialValue: big.NewInt(0), + posValue: big.NewInt(123), + negValue: big.NewInt(-123), + expectedResult: big.NewInt(123), + }, + { + name: "Handle negative special value", + neg: true, + initialValue: big.NewInt(0), + posValue: big.NewInt(123), + negValue: big.NewInt(-123), + expectedResult: big.NewInt(-123), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := handleSpecialValues(tt.initialValue, tt.neg, tt.posValue, tt.negValue) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expectedResult, result, "The result should match the expected value") + }) + } +} + +func TestParseNumber(t *testing.T) { + tests := []struct { + name string + s string + initialValue *big.Int + initialIntegral uint32 + initialScale uint32 + expectedValue *big.Int + expectedRemain string + expectError bool + }{ + { + name: "Parse integer", + s: "123", + initialValue: big.NewInt(0), + initialIntegral: 3, + initialScale: 0, + expectedValue: big.NewInt(123), + expectedRemain: "", + expectError: false, + }, + { + name: "Parse floating point", + s: "123.45", + initialValue: big.NewInt(0), + initialIntegral: 3, + initialScale: 2, + expectedValue: big.NewInt(12345), + expectedRemain: "", + expectError: false, + }, + { + name: "Non-digit character", + s: "123x45", + initialValue: big.NewInt(0), + initialIntegral: 3, + initialScale: 2, + expectedValue: nil, + expectedRemain: "", + expectError: true, + }, + { + name: "Multiple dots", + s: "123.45.67", + initialValue: big.NewInt(0), + initialIntegral: 3, + initialScale: 2, + expectedValue: nil, + expectedRemain: "", + expectError: true, + }, + { + name: "Early termination by integral", + s: "12345", + initialValue: big.NewInt(0), + initialIntegral: 3, + initialScale: 0, + expectedValue: big.NewInt(123), + expectedRemain: "45", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + integral := tt.initialIntegral + scale := tt.initialScale + remain, err := parseNumber(tt.s, tt.initialValue, &integral, &scale) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedValue, tt.initialValue) + require.Equal(t, tt.expectedRemain, remain) + } + }) + } +} + +func TestHandleRemainingDigits(t *testing.T) { + tests := []struct { + value *big.Int + expected *big.Int + name string + inputString string + precision uint32 + expectErr bool + }{ + { + name: "No rounding needed", + inputString: "4", + value: big.NewInt(1), + precision: 3, + expected: big.NewInt(1), + expectErr: false, + }, + { + name: "Rounding up needed", + inputString: "6", + value: big.NewInt(1), + precision: 3, + expected: big.NewInt(2), + expectErr: false, + }, + { + name: "Exactly halfway - assume round up for test", + inputString: "50", + value: big.NewInt(1), + precision: 3, + expected: big.NewInt(2), + expectErr: false, + }, + { + name: "Invalid character", + inputString: "a", + value: big.NewInt(1), + precision: 3, + expected: nil, + expectErr: true, + }, + { + name: "Exceeds precision limit - set to inf", + inputString: "9", // Triggers rounding that should exceed precision + value: pow(ten, 2), // Set v close to precision limit + precision: 2, // Precision limit + expected: inf, // Expected to be set to inf + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := handleRemainingDigits(tt.inputString, tt.value, tt.precision) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + if tt.expected.Cmp(inf) == 0 { + require.Equal(t, 0, tt.expected.Cmp(tt.value), "v should be set to inf") + } else { + require.Equal(t, tt.expected.String(), tt.value.String(), "Expected and actual values should match") + } + } + }) + } +} + +func TestPrepareValue(t *testing.T) { + tests := []struct { + name string + input *big.Int + expectedValue *big.Int + expectedNeg bool + }{ + { + name: "Positive value", + input: big.NewInt(123), + expectedValue: big.NewInt(123), + expectedNeg: false, + }, + { + name: "Negative value", + input: big.NewInt(-123), + expectedValue: big.NewInt(123), + expectedNeg: true, + }, + { + name: "Zero value", + input: big.NewInt(0), + expectedValue: big.NewInt(0), + expectedNeg: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, neg := prepareValue(tt.input) + require.Equal(t, tt.expectedValue, value) + require.Equal(t, tt.expectedNeg, neg) + }) + } +} + +func TestInitializeBuffer(t *testing.T) { + bts, pos := initializeBuffer() + require.Len(t, bts, 40) + require.Equal(t, 40, pos) +} + +func TestAppendDigit(t *testing.T) { + bts, _ := initializeBuffer() + pos := len(bts) + pos = appendDigit(bts, pos, 5) + + expectedByte := byte('5') // '5' as a byte, equivalent to 0x35 in hexadecimal + actualByte := bts[pos] // This should now correctly point to the appended '5' + + require.Equal(t, expectedByte, actualByte, "The appended byte should match '5'") +} + func uint128(hi, lo uint64) []byte { p := make([]byte, 16) binary.BigEndian.PutUint64(p[:8], hi) diff --git a/internal/stack/record_test.go b/internal/stack/record_test.go index d0cdec245..cd2f2d4f1 100644 --- a/internal/stack/record_test.go +++ b/internal/stack/record_test.go @@ -1,6 +1,9 @@ package stack import ( + "reflect" + "runtime" + "strings" "testing" "github.com/stretchr/testify/require" @@ -30,13 +33,13 @@ func TestRecord(t *testing.T) { }{ { act: Record(0), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:32)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:35)", }, { act: func() string { return Record(1) }(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:38)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:41)", }, { act: func() string { @@ -44,7 +47,7 @@ func TestRecord(t *testing.T) { return Record(2) }() }(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:46)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestRecord(record_test.go:49)", }, { act: testStruct{depth: 0, opts: []recordOption{ @@ -164,7 +167,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}.TestFunc(), - exp: "record_test.go:16", + exp: "record_test.go:19", }, { act: testStruct{depth: 0, opts: []recordOption{ @@ -236,7 +239,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}.TestFunc(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.testStruct.TestFunc.func1(record_test.go:16)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.testStruct.TestFunc.func1(record_test.go:19)", }, { act: (&testStruct{depth: 0, opts: []recordOption{ @@ -248,7 +251,7 @@ func TestRecord(t *testing.T) { // FileName(false), // Line(false), }}).TestPointerFunc(), - exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.(*testStruct).TestPointerFunc.func1(record_test.go:22)", + exp: "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.(*testStruct).TestPointerFunc.func1(record_test.go:25)", }, } { t.Run("", func(t *testing.T) { @@ -257,6 +260,67 @@ func TestRecord(t *testing.T) { } } +func TestExtractNames(t *testing.T) { + testFunc := func() {} + funcPtr := reflect.ValueOf(testFunc).Pointer() + + funcNameExpected := runtime.FuncForPC(funcPtr).Name() + + _, file, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller should return true indicating success") + + fileParts := strings.Split(file, "/") + fileNameExpected := fileParts[len(fileParts)-1] + + name, fileName := extractNames(funcPtr, file) + + require.Equal(t, funcNameExpected, name, "Function name should match expected value") + require.Equal(t, fileNameExpected, fileName, "File name should match expected value") +} + +func TestParseFunctionName(t *testing.T) { + name := "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestParseFunctionName.func1" + pkgPath, pkgName, structName, funcName, lambdas := parseFunctionName(name) + + require.Equal(t, "github.com/ydb-platform/ydb-go-sdk/v3/internal", pkgPath) + require.Equal(t, "stack", pkgName) + require.Empty(t, structName, "Struct name should be empty for standalone functions") + require.Equal(t, "TestParseFunctionName", funcName) + require.Contains(t, lambdas, "func1", "Lambdas should include 'func1'") +} + +func TestExtractLambdas(t *testing.T) { + split := []string{"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack", "TestExtractLambdas", "func1", "func2"} + lambdas := extractLambdas(split) + + require.Len(t, lambdas, 2, "There should be two lambda functions extracted") + require.Contains(t, lambdas, "func1") + require.Contains(t, lambdas, "func2") +} + +func TestBuildRecordString(t *testing.T) { + optionsHolder := recordOptions{ + packagePath: true, + packageName: false, + structName: true, + functionName: true, + fileName: true, + line: true, + lambdas: true, + } + pkgPath := "github.com/ydb-platform/ydb-go-sdk/v3/internal" + pkgName := "" + structName := "testStruct" + funcName := "TestFunc" + file := "record_test.go" + line := 319 + lambdas := []string{"func1"} + + result := buildRecordString(optionsHolder, pkgPath, pkgName, structName, funcName, file, line, lambdas) + expected := "github.com/ydb-platform/ydb-go-sdk/v3/internal.testStruct.TestFunc.func1(record_test.go:319)" + require.Equal(t, expected, result) +} + func BenchmarkCall(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { diff --git a/retry/retry_test.go b/retry/retry_test.go index 770f2ea65..87f5e8d0c 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -2,6 +2,7 @@ package retry import ( "context" + "errors" "fmt" "testing" "time" @@ -187,3 +188,100 @@ func TestRetryTransportCancelled(t *testing.T) { }) } } + +type MockPanicCallback struct { + called bool + received interface{} +} + +func (m *MockPanicCallback) Call(e interface{}) { + m.called = true + m.received = e +} + +func TestOpWithRecover_NoPanic(t *testing.T) { + ctx := context.Background() + options := &retryOptions{ + panicCallback: nil, + } + op := func(ctx context.Context) error { + return nil + } + + err := opWithRecover(ctx, options, op) + + require.NoError(t, err) +} + +func TestOpWithRecover_WithPanic(t *testing.T) { + ctx := context.Background() + mockCallback := new(MockPanicCallback) + options := &retryOptions{ + panicCallback: mockCallback.Call, + } + op := func(ctx context.Context) error { + panic("test panic") + } + + err := opWithRecover(ctx, options, op) + + require.Error(t, err) + require.Contains(t, err.Error(), "panic recovered: test panic") + require.True(t, mockCallback.called) + require.Equal(t, "test panic", mockCallback.received) +} + +func TestHandleContextDone(t *testing.T) { + attempts := 5 + ctx, cancel := context.WithCancel(context.Background()) + cancel() // immediately cancel to simulate a done context + + err := handleContextDone(ctx, attempts) + require.Error(t, err) + + expectedMsg := fmt.Sprintf("retry failed on attempt No.%d: %s", attempts, context.Canceled.Error()) + require.Contains(t, err.Error(), expectedMsg) +} + +func TestHandleContextError(t *testing.T) { + attempts := 3 + ctxErr := context.DeadlineExceeded + opErr := errors.New("operation failed") + + err := handleContextError(attempts, ctxErr, opErr) + require.Error(t, err) + + expectedMsg := fmt.Sprintf("context error occurred on attempt No.%d", attempts) + require.Contains(t, err.Error(), expectedMsg) + require.Contains(t, err.Error(), ctxErr.Error()) + require.Contains(t, err.Error(), opErr.Error()) +} + +func TestHandleNonRetryableError(t *testing.T) { + attempts := 2 + idempotent := false + opErr := errors.New("non-retryable error") + + err := handleNonRetryableError(attempts, idempotent, opErr) + require.Error(t, err) + + expectedMsg := fmt.Sprintf( + "non-retryable error occurred on attempt No.%d (idempotent=%v): %s", + attempts, idempotent, opErr.Error(), + ) + require.Contains(t, err.Error(), expectedMsg) +} + +func TestHandleWaitError(t *testing.T) { + attempts := 4 + waitErr := errors.New("wait error") + opErr := errors.New("operation during wait error") + + err := handleWaitError(attempts, waitErr, opErr) + require.Error(t, err) + + expectedMsg := fmt.Sprintf("wait exit on attempt No.%d", attempts) + require.Contains(t, err.Error(), expectedMsg) + require.Contains(t, err.Error(), waitErr.Error()) + require.Contains(t, err.Error(), opErr.Error()) +} From bacf99b14f70132eea7aa08343c2a2f240e5b4a7 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Fri, 29 Mar 2024 10:20:26 -0700 Subject: [PATCH 03/11] refactor after PR review --- internal/decimal/decimal.go | 147 +++++++++++++++++-------------- internal/decimal/decimal_test.go | 54 +----------- internal/stack/record.go | 51 ++++++----- internal/stack/record_test.go | 29 +++--- retry/retry.go | 58 ++++-------- retry/retry_test.go | 56 ------------ 6 files changed, 149 insertions(+), 246 deletions(-) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index dd3911bb3..dfe64f18d 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -4,8 +4,7 @@ import ( "errors" "math/big" "math/bits" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" + "strings" ) const wordSize = bits.UintSize / 8 @@ -96,16 +95,14 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } - neg, s := parseSign(s) - if isInf(s) { - return handleSpecialValues(v, neg, inf, neginf) - } - if isNaN(s) { - return handleSpecialValues(v, neg, nan, negnan) + if SetSpecialValue(v, s) { + return v, nil } + neg, s := parseSign(s) + integral := precision - scale - s, err := parseNumber(s, v, &integral, &scale) + s, err := parseNumber(s, v, integral, scale) if err != nil { return nil, err } @@ -115,7 +112,6 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return nil, err } } - v.Mul(v, pow(ten, scale)) if neg { v.Neg(v) @@ -124,19 +120,45 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } +func SetSpecialValue(v *big.Int, s string) bool { + neg := s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } + if isInf(s) { + if neg { + v.Set(neginf) + } else { + v.Set(inf) + } + + return true + } + if isNaN(s) { + if neg { + v.Set(negnan) + } else { + v.Set(nan) + } + + return true + } + + return false +} + func handleRemainingDigits(s string, v *big.Int, precision uint32) error { c := s[0] if !isDigit(c) { return syntaxError(s) } - plus := c > '5' - if !plus && c == '5' { - plus = shouldRoundUp(v, s) - } - if plus { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) + + if c >= '5' { + if c > '5' || shouldRoundUp(v, s[1:]) { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) + } } } @@ -145,7 +167,7 @@ func handleRemainingDigits(s string, v *big.Int, precision uint32) error { func shouldRoundUp(v *big.Int, s string) bool { var x big.Int - plus := x.And(v, big.NewInt(1)).Cmp(big.NewInt(0)) != 0 + plus := x.And(v, one).Cmp(zero) != 0 for !plus && len(s) > 0 { c := s[0] s = s[1:] @@ -171,17 +193,10 @@ func parseSign(s string) (neg bool, remaining string) { return neg, s } -func handleSpecialValues(v *big.Int, neg bool, pos, negVal *big.Int) (*big.Int, error) { - if neg { - return v.Set(negVal), nil - } - - return v.Set(pos), nil -} - -func parseNumber(s string, v *big.Int, integral, scale *uint32) (remaining string, err error) { +func parseNumber(s string, v *big.Int, integral, scale uint32) (remaining string, err error) { var dot bool var processed bool + var remainingBuilder strings.Builder for _, c := range s { if c == '.' { @@ -197,16 +212,16 @@ func parseNumber(s string, v *big.Int, integral, scale *uint32) (remaining strin return "", errors.New("syntax error: non-digit characters") } - if dot && *scale > 0 { - *scale-- + if dot && scale > 0 { + scale-- } else if !dot { - if *integral == 0 { - remaining += string(c) + if integral == 0 { + remainingBuilder.WriteRune(c) processed = true continue } - *integral-- + integral-- } if !processed { @@ -216,40 +231,41 @@ func parseNumber(s string, v *big.Int, integral, scale *uint32) (remaining strin } } - if !dot && *scale > 0 { - for *scale > 0 { + if !dot && scale > 0 { + for scale > 0 { v.Mul(v, big.NewInt(10)) - *scale-- + scale-- } } - return remaining, nil + // Convert the strings.Builder content to a string + return remainingBuilder.String(), nil } // Format returns the string representation of x with the given precision and // scale. func Format(x *big.Int, precision, scale uint32) string { - switch { - case x.CmpAbs(inf) == 0: + // Check for special values and nil pointer upfront. + if x == nil { + return "0" + } + if x.CmpAbs(inf) == 0 { if x.Sign() < 0 { return "-inf" } return "inf" - - case x.CmpAbs(nan) == 0: + } + if x.CmpAbs(nan) == 0 { if x.Sign() < 0 { return "-nan" } return "nan" - - case x == nil: - return "0" } - v, neg := prepareValue(x) - bts, pos := initializeBuffer() + v, neg := abs(x) + bts, pos := newStringBuffer() var digit big.Int for ; v.Cmp(zero) > 0; v.Div(v, ten) { @@ -260,44 +276,46 @@ func Format(x *big.Int, precision, scale uint32) string { digit.Mod(v, ten) d := int(digit.Int64()) - if d != 0 || scale == 0 || pos > 0 { - pos = appendDigit(bts, pos, d) + + pos-- + if d != 0 || scale == 0 || pos >= 0 { + setDigitAtPosition(bts, pos, d) } + if scale > 0 { scale-- if scale == 0 && pos > 0 { + bts[pos-1] = '.' pos-- - bts[pos] = '.' } } } - if scale > 0 { - for ; scale > 0; (scale)-- { - if precision == 0 { - pos = len(bts) // Return the full buffer length on precision error. - } - precision-- - pos-- - bts[pos] = '0' - } + for ; scale > 0; scale-- { + if precision == 0 { + pos = 0 + break + } + precision-- pos-- - bts[pos] = '.' + bts[pos] = '0' } + if bts[pos] == '.' { pos-- bts[pos] = '0' } + if neg { pos-- bts[pos] = '-' } - return xstring.FromBytes(bts[pos:]) + return string(bts[pos:]) } -func prepareValue(x *big.Int) (*big.Int, bool) { +func abs(x *big.Int) (*big.Int, bool) { v := big.NewInt(0).Set(x) neg := x.Sign() < 0 if neg { @@ -308,7 +326,7 @@ func prepareValue(x *big.Int) (*big.Int, bool) { return v, neg } -func initializeBuffer() ([]byte, int) { +func newStringBuffer() ([]byte, int) { // log_{10}(2^120) ~= 36.12, 37 decimal places // plus dot, zero before dot, sign. bts := make([]byte, 40) @@ -317,12 +335,9 @@ func initializeBuffer() ([]byte, int) { return bts, pos } -func appendDigit(bts []byte, pos, digit int) int { +func setDigitAtPosition(bts []byte, pos, digit int) { const numbers = "0123456789" - pos-- bts[pos] = numbers[digit] - - return pos } // BigIntToByte returns the 16-byte array representation of x. diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index 531da9952..11ac54b23 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -155,43 +155,6 @@ func TestParseSign(t *testing.T) { } } -func TestHandleSpecialValues(t *testing.T) { - tests := []struct { - name string - neg bool - initialValue *big.Int - posValue *big.Int - negValue *big.Int - expectedResult *big.Int - }{ - { - name: "Handle positive special value", - neg: false, - initialValue: big.NewInt(0), - posValue: big.NewInt(123), - negValue: big.NewInt(-123), - expectedResult: big.NewInt(123), - }, - { - name: "Handle negative special value", - neg: true, - initialValue: big.NewInt(0), - posValue: big.NewInt(123), - negValue: big.NewInt(-123), - expectedResult: big.NewInt(-123), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := handleSpecialValues(tt.initialValue, tt.neg, tt.posValue, tt.negValue) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, tt.expectedResult, result, "The result should match the expected value") - }) - } -} - func TestParseNumber(t *testing.T) { tests := []struct { name string @@ -259,7 +222,7 @@ func TestParseNumber(t *testing.T) { t.Run(tt.name, func(t *testing.T) { integral := tt.initialIntegral scale := tt.initialScale - remain, err := parseNumber(tt.s, tt.initialValue, &integral, &scale) + remain, err := parseNumber(tt.s, tt.initialValue, integral, scale) if tt.expectError { require.Error(t, err) @@ -370,7 +333,7 @@ func TestPrepareValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - value, neg := prepareValue(tt.input) + value, neg := abs(tt.input) require.Equal(t, tt.expectedValue, value) require.Equal(t, tt.expectedNeg, neg) }) @@ -378,22 +341,11 @@ func TestPrepareValue(t *testing.T) { } func TestInitializeBuffer(t *testing.T) { - bts, pos := initializeBuffer() + bts, pos := newStringBuffer() require.Len(t, bts, 40) require.Equal(t, 40, pos) } -func TestAppendDigit(t *testing.T) { - bts, _ := initializeBuffer() - pos := len(bts) - pos = appendDigit(bts, pos, 5) - - expectedByte := byte('5') // '5' as a byte, equivalent to 0x35 in hexadecimal - actualByte := bts[pos] // This should now correctly point to the appended '5' - - require.Equal(t, expectedByte, actualByte, "The appended byte should match '5'") -} - func uint128(hi, lo uint64) []byte { p := make([]byte, 16) binary.BigEndian.PutUint64(p[:8], hi) diff --git a/internal/stack/record.go b/internal/stack/record.go index 9b967ce40..444c76644 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -18,6 +18,14 @@ type recordOptions struct { lambdas bool } +type functionDetails struct { + pkgPath string + pkgName string + structName string + funcName string + lambdas []string +} + type recordOption func(opts *recordOptions) func PackageName(b bool) recordOption { @@ -91,13 +99,13 @@ func (c call) Record(opts ...recordOption) string { opt(&optionsHolder) } } - name, file := extractNames(c.function, c.file) - pkgPath, pkgName, structName, funcName, lambdas := parseFunctionName(name) + name, file := extractName(c.function, c.file) + fnDetails := parseFunctionName(name) - return buildRecordString(optionsHolder, pkgPath, pkgName, structName, funcName, file, c.line, lambdas) + return buildRecordString(optionsHolder, &fnDetails, file, c.line) } -func extractNames(function uintptr, file string) (name, fileName string) { +func extractName(function uintptr, file string) (name, fileName string) { name = runtime.FuncForPC(function).Name() if i := strings.LastIndex(file, "/"); i > -1 { fileName = file[i+1:] @@ -109,24 +117,25 @@ func extractNames(function uintptr, file string) (name, fileName string) { return name, fileName } -func parseFunctionName(name string) (pkgPath, pkgName, structName, funcName string, lambdas []string) { +func parseFunctionName(name string) functionDetails { + var details functionDetails if i := strings.LastIndex(name, "/"); i > -1 { - pkgPath, name = name[:i], name[i+1:] + details.pkgPath, name = name[:i], name[i+1:] } split := strings.Split(name, ".") - lambdas = extractLambdas(split) - split = split[:len(split)-len(lambdas)] + details.lambdas = extractLambdas(split) + split = split[:len(split)-len(details.lambdas)] if len(split) > 0 { - pkgName = split[0] + details.pkgName = split[0] } if len(split) > 1 { - funcName = split[len(split)-1] + details.funcName = split[len(split)-1] } if len(split) > 2 { - structName = split[1] + details.structName = split[1] } - return pkgPath, pkgName, structName, funcName, lambdas + return details } func extractLambdas(split []string) (lambdas []string) { @@ -144,36 +153,36 @@ func extractLambdas(split []string) (lambdas []string) { func buildRecordString( optionsHolder recordOptions, - pkgPath, pkgName, structName, funcName, file string, + fnDetails *functionDetails, + file string, line int, - lambdas []string, ) string { buffer := xstring.Buffer() defer buffer.Free() if optionsHolder.packagePath { - buffer.WriteString(pkgPath) + buffer.WriteString(fnDetails.pkgPath) } if optionsHolder.packageName { if buffer.Len() > 0 { buffer.WriteByte('/') } - buffer.WriteString(pkgName) + buffer.WriteString(fnDetails.pkgName) } - if optionsHolder.structName && len(structName) > 0 { + if optionsHolder.structName && len(fnDetails.structName) > 0 { if buffer.Len() > 0 { buffer.WriteByte('.') } - buffer.WriteString(structName) + buffer.WriteString(fnDetails.structName) } if optionsHolder.functionName { if buffer.Len() > 0 { buffer.WriteByte('.') } - buffer.WriteString(funcName) + buffer.WriteString(fnDetails.funcName) if optionsHolder.lambdas { - for i := range lambdas { + for i := range fnDetails.lambdas { buffer.WriteByte('.') - buffer.WriteString(lambdas[len(lambdas)-i-1]) + buffer.WriteString(fnDetails.lambdas[len(fnDetails.lambdas)-i-1]) } } } diff --git a/internal/stack/record_test.go b/internal/stack/record_test.go index cd2f2d4f1..a9fe86064 100644 --- a/internal/stack/record_test.go +++ b/internal/stack/record_test.go @@ -272,7 +272,7 @@ func TestExtractNames(t *testing.T) { fileParts := strings.Split(file, "/") fileNameExpected := fileParts[len(fileParts)-1] - name, fileName := extractNames(funcPtr, file) + name, fileName := extractName(funcPtr, file) require.Equal(t, funcNameExpected, name, "Function name should match expected value") require.Equal(t, fileNameExpected, fileName, "File name should match expected value") @@ -280,13 +280,13 @@ func TestExtractNames(t *testing.T) { func TestParseFunctionName(t *testing.T) { name := "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack.TestParseFunctionName.func1" - pkgPath, pkgName, structName, funcName, lambdas := parseFunctionName(name) + fnDetails := parseFunctionName(name) - require.Equal(t, "github.com/ydb-platform/ydb-go-sdk/v3/internal", pkgPath) - require.Equal(t, "stack", pkgName) - require.Empty(t, structName, "Struct name should be empty for standalone functions") - require.Equal(t, "TestParseFunctionName", funcName) - require.Contains(t, lambdas, "func1", "Lambdas should include 'func1'") + require.Equal(t, "github.com/ydb-platform/ydb-go-sdk/v3/internal", fnDetails.pkgPath) + require.Equal(t, "stack", fnDetails.pkgName) + require.Empty(t, fnDetails.structName, "Struct name should be empty for standalone functions") + require.Equal(t, "TestParseFunctionName", fnDetails.funcName) + require.Contains(t, fnDetails.lambdas, "func1", "Lambdas should include 'func1'") } func TestExtractLambdas(t *testing.T) { @@ -308,15 +308,18 @@ func TestBuildRecordString(t *testing.T) { line: true, lambdas: true, } - pkgPath := "github.com/ydb-platform/ydb-go-sdk/v3/internal" - pkgName := "" - structName := "testStruct" - funcName := "TestFunc" + fnDetails := functionDetails{ + pkgPath: "github.com/ydb-platform/ydb-go-sdk/v3/internal", + pkgName: "", + structName: "testStruct", + funcName: "TestFunc", + + lambdas: []string{"func1"}, + } file := "record_test.go" line := 319 - lambdas := []string{"func1"} - result := buildRecordString(optionsHolder, pkgPath, pkgName, structName, funcName, file, line, lambdas) + result := buildRecordString(optionsHolder, &fnDetails, file, line) expected := "github.com/ydb-platform/ydb-go-sdk/v3/internal.testStruct.TestFunc.func1(record_test.go:319)" require.Equal(t, expected, result) } diff --git a/retry/retry.go b/retry/retry.go index 3230e7fa7..1b847f3dc 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -230,6 +230,8 @@ func WithPanicCallback(panicCallback func(e interface{})) panicCallbackOption { // Warning: if deadline without deadline or cancellation func Retry will be worked infinite // // If you need to retry your op func on some logic errors - you must return RetryableError() from retryOperation +// +//nolint:funlen func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr error) { options := &retryOptions{ call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/retry.Retry"), @@ -269,7 +271,9 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err attempts++ select { case <-ctx.Done(): - return handleContextDone(ctx, attempts) + return xerrors.WithStackTrace( + fmt.Errorf("retry failed on attempt No.%d: %w", attempts, ctx.Err()), + ) default: err := opWithRecover(ctx, options, op) @@ -279,7 +283,12 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } if ctxErr := ctx.Err(); ctxErr != nil { - return handleContextError(attempts, ctxErr, err) + return xerrors.WithStackTrace( + xerrors.Join( + fmt.Errorf("context error occurred on attempt No.%d", attempts), + ctxErr, err, + ), + ) } m := Check(err) @@ -289,11 +298,17 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } if !m.MustRetry(options.idempotent) { - return handleNonRetryableError(attempts, options.idempotent, err) + return xerrors.WithStackTrace( + fmt.Errorf("non-retryable error occurred on attempt No.%d (idempotent=%v): %w", + attempts, options.idempotent, err), + ) } if e := wait.Wait(ctx, options.fastBackoff, options.slowBackoff, m.BackoffType(), i); e != nil { - return handleWaitError(attempts, e, err) + return xerrors.WithStackTrace( + xerrors.Join( + fmt.Errorf("wait exit on attempt No.%d", attempts), e, err), + ) } code = m.StatusCode() @@ -316,41 +331,6 @@ func opWithRecover(ctx context.Context, options *retryOptions, op retryOperation return op(ctx) } -func handleContextDone(ctx context.Context, attempts int) error { - return xerrors.WithStackTrace( - fmt.Errorf("retry failed on attempt No.%d: %w", - attempts, ctx.Err(), - ), - ) -} - -func handleContextError(attempts int, ctxErr, err error) error { - return xerrors.WithStackTrace( - xerrors.Join( - fmt.Errorf("context error occurred on attempt No.%d", attempts), - ctxErr, err, - ), - ) -} - -func handleNonRetryableError(attempts int, idempotent bool, err error) error { - return xerrors.WithStackTrace( - fmt.Errorf("non-retryable error occurred on attempt No.%d (idempotent=%v): %w", - attempts, idempotent, err, - ), - ) -} - -func handleWaitError(attempts int, e, err error) error { - return xerrors.WithStackTrace( - xerrors.Join( - fmt.Errorf("wait exit on attempt No.%d", - attempts, - ), e, err, - ), - ) -} - // Check returns retry mode for queryErr. func Check(err error) (m retryMode) { code, errType, backoffType, deleteSession := xerrors.Check(err) diff --git a/retry/retry_test.go b/retry/retry_test.go index 87f5e8d0c..75d9d7aad 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -2,7 +2,6 @@ package retry import ( "context" - "errors" "fmt" "testing" "time" @@ -230,58 +229,3 @@ func TestOpWithRecover_WithPanic(t *testing.T) { require.True(t, mockCallback.called) require.Equal(t, "test panic", mockCallback.received) } - -func TestHandleContextDone(t *testing.T) { - attempts := 5 - ctx, cancel := context.WithCancel(context.Background()) - cancel() // immediately cancel to simulate a done context - - err := handleContextDone(ctx, attempts) - require.Error(t, err) - - expectedMsg := fmt.Sprintf("retry failed on attempt No.%d: %s", attempts, context.Canceled.Error()) - require.Contains(t, err.Error(), expectedMsg) -} - -func TestHandleContextError(t *testing.T) { - attempts := 3 - ctxErr := context.DeadlineExceeded - opErr := errors.New("operation failed") - - err := handleContextError(attempts, ctxErr, opErr) - require.Error(t, err) - - expectedMsg := fmt.Sprintf("context error occurred on attempt No.%d", attempts) - require.Contains(t, err.Error(), expectedMsg) - require.Contains(t, err.Error(), ctxErr.Error()) - require.Contains(t, err.Error(), opErr.Error()) -} - -func TestHandleNonRetryableError(t *testing.T) { - attempts := 2 - idempotent := false - opErr := errors.New("non-retryable error") - - err := handleNonRetryableError(attempts, idempotent, opErr) - require.Error(t, err) - - expectedMsg := fmt.Sprintf( - "non-retryable error occurred on attempt No.%d (idempotent=%v): %s", - attempts, idempotent, opErr.Error(), - ) - require.Contains(t, err.Error(), expectedMsg) -} - -func TestHandleWaitError(t *testing.T) { - attempts := 4 - waitErr := errors.New("wait error") - opErr := errors.New("operation during wait error") - - err := handleWaitError(attempts, waitErr, opErr) - require.Error(t, err) - - expectedMsg := fmt.Sprintf("wait exit on attempt No.%d", attempts) - require.Contains(t, err.Error(), expectedMsg) - require.Contains(t, err.Error(), waitErr.Error()) - require.Contains(t, err.Error(), opErr.Error()) -} From 63a94fd115ff607830fa5e716ea3432e64b62984 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Wed, 10 Apr 2024 12:40:49 -0700 Subject: [PATCH 04/11] refactor after review --- internal/decimal/decimal.go | 8 +-- internal/decimal/decimal_test.go | 112 +++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 5 deletions(-) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index dfe64f18d..5da249462 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -121,10 +121,8 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { } func SetSpecialValue(v *big.Int, s string) bool { - neg := s[0] == '-' - if neg || s[0] == '+' { - s = s[1:] - } + neg, s := parseSign(s) + if isInf(s) { if neg { v.Set(neginf) @@ -154,7 +152,7 @@ func handleRemainingDigits(s string, v *big.Int, precision uint32) error { } if c >= '5' { - if c > '5' || shouldRoundUp(v, s[1:]) { + if c > '5' || shouldRoundUp(v, s) { v.Add(v, one) if v.Cmp(pow(ten, precision)) >= 0 { v.Set(inf) diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index 11ac54b23..0ff50b6ed 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -357,3 +357,115 @@ func uint128(hi, lo uint64) []byte { func uint128s(lo uint64) []byte { return uint128(0, lo) } + +func FuzzParse(f *testing.F) { + f.Fuzz(func(t *testing.T, s string, precision, scale uint32) { + expectedRes, expectedErr := oldParse(s, precision, scale) + res, err := Parse(s, precision, scale) + if expectedErr == nil { + require.Equal(t, expectedRes, res) + } else { + require.Error(t, err) + } + }) +} + +func oldParse(s string, precision, scale uint32) (*big.Int, error) { + if scale > precision { + return nil, precisionError(s, precision, scale) + } + + v := big.NewInt(0) + if s == "" { + return v, nil + } + + neg := s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } + if isInf(s) { + if neg { + return v.Set(neginf), nil + } + + return v.Set(inf), nil + } + if isNaN(s) { + if neg { + return v.Set(negnan), nil + } + + return v.Set(nan), nil + } + + integral := precision - scale + + var dot bool + for ; len(s) > 0; s = s[1:] { + c := s[0] + if c == '.' { + if dot { + return nil, syntaxError(s) + } + dot = true + + continue + } + if dot { + if scale > 0 { + scale-- + } else { + break + } + } + + if !isDigit(c) { + return nil, syntaxError(s) + } + + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) + + if !dot && v.Cmp(zero) > 0 && integral == 0 { + if neg { + return neginf, nil + } + + return inf, nil + } + integral-- + } + //nolint:nestif + if len(s) > 0 { // Characters remaining. + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) + } + } + } + v.Mul(v, pow(ten, scale)) + if neg { + v.Neg(v) + } + + return v, nil +} From afdd3e860ede58ef7b2c344f0af431126be8c83c Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Thu, 11 Apr 2024 11:33:28 -0700 Subject: [PATCH 05/11] fixes after Fuzzing test --- internal/decimal/decimal.go | 205 +++++++++---------- internal/decimal/decimal_test.go | 327 ++++++++++--------------------- 2 files changed, 196 insertions(+), 336 deletions(-) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index 5da249462..cf1098f86 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -1,13 +1,15 @@ package decimal import ( - "errors" "math/big" "math/bits" - "strings" ) -const wordSize = bits.UintSize / 8 +const ( + wordSize = bits.UintSize / 8 + bufferSize = 40 + negMask = 0x80 +) var ( ten = big.NewInt(10) @@ -58,7 +60,7 @@ func FromBytes(bts []byte, precision, scale uint32) *big.Int { v.SetBytes(bts) - neg := bts[0]&0x80 != 0 + neg := bts[0]&negMask != 0 if neg { // Given bytes contains negative value. // Interpret is as two's complement. @@ -95,149 +97,118 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } - if SetSpecialValue(v, s) { - return v, nil + s, neg, specialValue := setSpecialValue(s, v) + if specialValue != nil { + return specialValue, nil } - - neg, s := parseSign(s) - - integral := precision - scale - s, err := parseNumber(s, v, integral, scale) + var err error + v, err = parseNumber(s, v, precision, scale, neg) if err != nil { return nil, err } - if len(s) > 0 { - if err := handleRemainingDigits(s, v, precision); err != nil { - return nil, err - } - } - v.Mul(v, pow(ten, scale)) - if neg { - v.Neg(v) - } - return v, nil } -func SetSpecialValue(v *big.Int, s string) bool { - neg, s := parseSign(s) - +func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) { + neg := s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } if isInf(s) { if neg { - v.Set(neginf) - } else { - v.Set(inf) + return s, neg, v.Set(neginf) } - return true + return s, neg, v.Set(inf) } if isNaN(s) { if neg { - v.Set(negnan) - } else { - v.Set(nan) + return s, neg, v.Set(negnan) } - return true - } - - return false -} - -func handleRemainingDigits(s string, v *big.Int, precision uint32) error { - c := s[0] - if !isDigit(c) { - return syntaxError(s) - } - - if c >= '5' { - if c > '5' || shouldRoundUp(v, s) { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) - } - } - } - - return nil -} - -func shouldRoundUp(v *big.Int, s string) bool { - var x big.Int - plus := x.And(v, one).Cmp(zero) != 0 - for !plus && len(s) > 0 { - c := s[0] - s = s[1:] - if c < '0' || c > '9' { - break - } - plus = c != '0' - } - - return plus -} - -func parseSign(s string) (neg bool, remaining string) { - if s == "" { - return false, s - } - - neg = s[0] == '-' - if neg || s[0] == '+' { - s = s[1:] + return s, neg, v.Set(nan) } - return neg, s + return s, neg, nil } -func parseNumber(s string, v *big.Int, integral, scale uint32) (remaining string, err error) { +func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big.Int, error) { + var err error + integral := precision - scale var dot bool - var processed bool - var remainingBuilder strings.Builder - - for _, c := range s { + for ; len(s) > 0; s = s[1:] { + c := s[0] if c == '.' { if dot { - return "", errors.New("syntax error: unexpected '.'") + return nil, syntaxError(s) } dot = true continue } + if dot && scale > 0 { + scale-- + } else if dot { + break + } - if !isDigit(byte(c)) { - return "", errors.New("syntax error: non-digit characters") + if !isDigit(c) { + return nil, syntaxError(s) } - if dot && scale > 0 { - scale-- - } else if !dot { - if integral == 0 { - remainingBuilder.WriteRune(c) - processed = true + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) - continue + if !dot && v.Cmp(zero) > 0 && integral == 0 { + if neg { + return neginf, nil } - integral-- - } - if !processed { - digitVal := big.NewInt(int64(c - '0')) - v.Mul(v, big.NewInt(10)) - v.Add(v, digitVal) + return inf, nil } + integral-- + } + if len(s) > 0 { // Characters remaining. + v, err = handleRemainingDigits(s, v, precision) + if err != nil { + return nil, err + } + } + v.Mul(v, pow(ten, scale)) + if neg { + v.Neg(v) } - if !dot && scale > 0 { - for scale > 0 { - v.Mul(v, big.NewInt(10)) - scale-- + return v, nil +} + +func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, error) { + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) } } - // Convert the strings.Builder content to a string - return remainingBuilder.String(), nil + return v, nil } // Format returns the string representation of x with the given precision and @@ -262,8 +233,17 @@ func Format(x *big.Int, precision, scale uint32) string { return "nan" } - v, neg := abs(x) - bts, pos := newStringBuffer() + v := big.NewInt(0).Set(x) + neg := x.Sign() < 0 + if neg { + // Convert negative to positive. + v.Neg(x) + } + + // log_{10}(2^120) ~= 36.12, 37 decimal places + // plus dot, zero before dot, sign. + bts := make([]byte, bufferSize) + pos := len(bts) var digit big.Int for ; v.Cmp(zero) > 0; v.Div(v, ten) { @@ -324,15 +304,6 @@ func abs(x *big.Int) (*big.Int, bool) { return v, neg } -func newStringBuffer() ([]byte, int) { - // log_{10}(2^120) ~= 36.12, 37 decimal places - // plus dot, zero before dot, sign. - bts := make([]byte, 40) - pos := len(bts) - - return bts, pos -} - func setDigitAtPosition(bts []byte, pos, digit int) { const numbers = "0123456789" bts[pos] = numbers[digit] diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index 0ff50b6ed..9443b40ad 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -60,245 +60,68 @@ func TestFromBytes(t *testing.T) { } } -func TestShouldRoundUp(t *testing.T) { - tests := []struct { - name string - number *big.Int - additionalDigits string - expected bool - }{ - { - name: "Last digit not zero, no string", - number: big.NewInt(123), - additionalDigits: "", - expected: true, - }, - { - name: "Last digit zero, string starts not with zero", - number: big.NewInt(120), - additionalDigits: "1", - expected: true, - }, - { - name: "Last digit zero, string all zeros", - number: big.NewInt(120), - additionalDigits: "000", - expected: false, - }, - { - name: "Last digit not zero, string irrelevant", - number: big.NewInt(123), - additionalDigits: "004", - expected: true, - }, - { - name: "Last digit zero, string has non-zero after zeros", - number: big.NewInt(100), - additionalDigits: "001", - expected: true, - }, - { - name: "Last digit zero, string has non-digit characters", - number: big.NewInt(100), - additionalDigits: "00abc", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldRoundUp(tt.number, tt.additionalDigits) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestParseSign(t *testing.T) { +func TestSetSpecialValue(t *testing.T) { tests := []struct { name string input string + expectedS string expectedNeg bool - expectedRem string + expectedV *big.Int }{ { - name: "Negative sign", - input: "-123", - expectedNeg: true, - expectedRem: "123", - }, - { - name: "Positive sign", - input: "+456", + name: "Positive infinity", + input: "inf", + expectedS: "inf", expectedNeg: false, - expectedRem: "456", + expectedV: inf, }, { - name: "No sign", - input: "789", - expectedNeg: false, - expectedRem: "789", + name: "Negative infinity", + input: "-inf", + expectedS: "inf", + expectedNeg: true, + expectedV: neginf, }, { - name: "Empty string", - input: "", + name: "Positive NaN", + input: "nan", + expectedS: "nan", expectedNeg: false, - expectedRem: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - neg, rem := parseSign(tt.input) - require.Equal(t, tt.expectedNeg, neg, "Neg flag does not match expected value") - require.Equal(t, tt.expectedRem, rem, "Remaining string does not match expected value") - }) - } -} - -func TestParseNumber(t *testing.T) { - tests := []struct { - name string - s string - initialValue *big.Int - initialIntegral uint32 - initialScale uint32 - expectedValue *big.Int - expectedRemain string - expectError bool - }{ - { - name: "Parse integer", - s: "123", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 0, - expectedValue: big.NewInt(123), - expectedRemain: "", - expectError: false, - }, - { - name: "Parse floating point", - s: "123.45", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: big.NewInt(12345), - expectedRemain: "", - expectError: false, + expectedV: nan, }, { - name: "Non-digit character", - s: "123x45", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: nil, - expectedRemain: "", - expectError: true, - }, - { - name: "Multiple dots", - s: "123.45.67", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: nil, - expectedRemain: "", - expectError: true, - }, - { - name: "Early termination by integral", - s: "12345", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 0, - expectedValue: big.NewInt(123), - expectedRemain: "45", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - integral := tt.initialIntegral - scale := tt.initialScale - remain, err := parseNumber(tt.s, tt.initialValue, integral, scale) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectedValue, tt.initialValue) - require.Equal(t, tt.expectedRemain, remain) - } - }) - } -} - -func TestHandleRemainingDigits(t *testing.T) { - tests := []struct { - value *big.Int - expected *big.Int - name string - inputString string - precision uint32 - expectErr bool - }{ - { - name: "No rounding needed", - inputString: "4", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(1), - expectErr: false, - }, - { - name: "Rounding up needed", - inputString: "6", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(2), - expectErr: false, - }, - { - name: "Exactly halfway - assume round up for test", - inputString: "50", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(2), - expectErr: false, + name: "Negative NaN", + input: "-nan", + expectedS: "nan", + expectedNeg: true, + expectedV: negnan, }, { - name: "Invalid character", - inputString: "a", - value: big.NewInt(1), - precision: 3, - expected: nil, - expectErr: true, + name: "Regular number", + input: "123", + expectedS: "123", + expectedNeg: false, + expectedV: nil, }, { - name: "Exceeds precision limit - set to inf", - inputString: "9", // Triggers rounding that should exceed precision - value: pow(ten, 2), // Set v close to precision limit - precision: 2, // Precision limit - expected: inf, // Expected to be set to inf - expectErr: false, + name: "Negative regular number", + input: "-123", + expectedS: "123", + expectedNeg: true, + expectedV: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := handleRemainingDigits(tt.inputString, tt.value, tt.precision) - - if tt.expectErr { - require.Error(t, err) + v := big.NewInt(0) + gotS, gotNeg, gotV := setSpecialValue(tt.input, v) + require.Equal(t, tt.expectedS, gotS) + require.Equal(t, tt.expectedNeg, gotNeg) + if tt.expectedV != nil { + require.Equal(t, 0, tt.expectedV.Cmp(gotV)) } else { - require.NoError(t, err) - if tt.expected.Cmp(inf) == 0 { - require.Equal(t, 0, tt.expected.Cmp(tt.value), "v should be set to inf") - } else { - require.Equal(t, tt.expected.String(), tt.value.String(), "Expected and actual values should match") - } + require.Nil(t, gotV) } }) } @@ -340,10 +163,76 @@ func TestPrepareValue(t *testing.T) { } } -func TestInitializeBuffer(t *testing.T) { - bts, pos := newStringBuffer() - require.Len(t, bts, 40) - require.Equal(t, 40, pos) +func TestParseNumber(t *testing.T) { + // Mock or define these as per your actual implementation. + tests := []struct { + name string + s string + wantValue *big.Int + precision uint32 + scale uint32 + neg bool + wantErr bool + }{ + { + name: "Valid number without decimal", + s: "123", + precision: 3, + scale: 0, + neg: false, + wantValue: big.NewInt(123), + wantErr: false, + }, + { + name: "Valid number with decimal", + s: "123.45", + precision: 5, + scale: 2, + neg: false, + wantValue: big.NewInt(12345), + wantErr: false, + }, + { + name: "Valid negative number", + s: "123", + precision: 3, + scale: 0, + neg: true, + wantValue: big.NewInt(-123), + wantErr: false, + }, + { + name: "Syntax error with non-digit", + s: "123a", + precision: 4, + scale: 0, + neg: false, + wantValue: nil, + wantErr: true, + }, + { + name: "Multiple decimal points", + s: "12.3.4", + precision: 5, + scale: 2, + neg: false, + wantValue: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := big.NewInt(0) + gotValue, gotErr := parseNumber(tt.s, v, tt.precision, tt.scale, tt.neg) + if tt.wantErr { + require.Error(t, gotErr) + } else { + require.NoError(t, gotErr) + require.Equal(t, 0, tt.wantValue.Cmp(gotValue)) + } + }) + } } func uint128(hi, lo uint64) []byte { From 922b301377670b6272dbcfb4b90cc8215c5ab1c3 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Thu, 11 Apr 2024 11:47:33 -0700 Subject: [PATCH 06/11] added test for the specific scenario from the PR --- internal/decimal/decimal_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index 9443b40ad..e30a338e7 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -246,6 +246,35 @@ func uint128(hi, lo uint64) []byte { func uint128s(lo uint64) []byte { return uint128(0, lo) } +func TestParse(t *testing.T) { + + tests := []struct { + name string + s string + precision uint32 + scale uint32 + }{ + { + name: "Specific Parse test", + s: "100", + precision: 0, + scale: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expectedRes, expectedErr := oldParse(tt.s, tt.precision, tt.scale) + res, err := Parse(tt.s, tt.precision, tt.scale) + if expectedErr == nil { + require.Equal(t, expectedRes, res) + } else { + require.Error(t, err) + } + }) + } + +} func FuzzParse(f *testing.F) { f.Fuzz(func(t *testing.T, s string, precision, scale uint32) { From 07d827e1cfade6d3215b059081706f9631ffe829 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Thu, 11 Apr 2024 12:37:53 -0700 Subject: [PATCH 07/11] an attempt to resolve merge conflict --- internal/decimal/decimal_test.go | 3 +-- internal/stack/record.go | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index e30a338e7..a4fe5fcaf 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -246,8 +246,8 @@ func uint128(hi, lo uint64) []byte { func uint128s(lo uint64) []byte { return uint128(0, lo) } -func TestParse(t *testing.T) { +func TestParse(t *testing.T) { tests := []struct { name string s string @@ -273,7 +273,6 @@ func TestParse(t *testing.T) { } }) } - } func FuzzParse(f *testing.F) { diff --git a/internal/stack/record.go b/internal/stack/record.go index 444c76644..48f984127 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -99,6 +99,7 @@ func (c call) Record(opts ...recordOption) string { opt(&optionsHolder) } } + name, file := extractName(c.function, c.file) fnDetails := parseFunctionName(name) From 8a1e868550196e48705f090540fa6a08fed4ac05 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Tue, 16 Apr 2024 10:23:05 -0700 Subject: [PATCH 08/11] refactor after PR review --- internal/decimal/decimal.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index cf1098f86..4b7f4330d 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -111,10 +111,21 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { } func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) { + s, neg := parseSign(s) + + return parseSpecialValue(s, neg, v) +} + +func parseSign(s string) (string, bool) { neg := s[0] == '-' if neg || s[0] == '+' { s = s[1:] } + + return s, neg +} + +func parseSpecialValue(s string, neg bool, v *big.Int) (string, bool, *big.Int) { if isInf(s) { if neg { return s, neg, v.Set(neginf) From cc8c811034c8507865bcff4d652f818e45bfd958 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Wed, 17 Apr 2024 08:25:24 -0700 Subject: [PATCH 09/11] attempt to fix merge conflict --- internal/stack/record.go | 22 ++++++++-------------- internal/stack/record_test.go | 9 --------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/internal/stack/record.go b/internal/stack/record.go index 48f984127..e51fe6847 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -124,7 +124,14 @@ func parseFunctionName(name string) functionDetails { details.pkgPath, name = name[:i], name[i+1:] } split := strings.Split(name, ".") - details.lambdas = extractLambdas(split) + details.lambdas = make([]string, 0, len(split)) + for i := range split { + elem := split[len(split)-i-1] + if !strings.HasPrefix(elem, "func") { + break + } + details.lambdas = append(details.lambdas, elem) + } split = split[:len(split)-len(details.lambdas)] if len(split) > 0 { details.pkgName = split[0] @@ -139,19 +146,6 @@ func parseFunctionName(name string) functionDetails { return details } -func extractLambdas(split []string) (lambdas []string) { - lambdas = make([]string, 0, len(split)) - for i := range split { - elem := split[len(split)-i-1] - if !strings.HasPrefix(elem, "func") { - break - } - lambdas = append(lambdas, elem) - } - - return lambdas -} - func buildRecordString( optionsHolder recordOptions, fnDetails *functionDetails, diff --git a/internal/stack/record_test.go b/internal/stack/record_test.go index a9fe86064..9f10be303 100644 --- a/internal/stack/record_test.go +++ b/internal/stack/record_test.go @@ -289,15 +289,6 @@ func TestParseFunctionName(t *testing.T) { require.Contains(t, fnDetails.lambdas, "func1", "Lambdas should include 'func1'") } -func TestExtractLambdas(t *testing.T) { - split := []string{"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack", "TestExtractLambdas", "func1", "func2"} - lambdas := extractLambdas(split) - - require.Len(t, lambdas, 2, "There should be two lambda functions extracted") - require.Contains(t, lambdas, "func1") - require.Contains(t, lambdas, "func2") -} - func TestBuildRecordString(t *testing.T) { optionsHolder := recordOptions{ packagePath: true, From 757f14f6064d01b7020644827a0ef69abac4f763 Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Wed, 17 Apr 2024 08:27:51 -0700 Subject: [PATCH 10/11] attempt to resolve merge conflict --- internal/stack/record.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/stack/record.go b/internal/stack/record.go index e51fe6847..b8b19aa29 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -139,7 +139,7 @@ func parseFunctionName(name string) functionDetails { if len(split) > 1 { details.funcName = split[len(split)-1] } - if len(split) > 2 { + if len(split) > 2 { //nolint:gomnd details.structName = split[1] } From 013a60c92b063be33ef5727a2cb351f829bdbebb Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 23 Apr 2024 19:13:28 +0300 Subject: [PATCH 11/11] fix coverage --- internal/stack/record.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/internal/stack/record.go b/internal/stack/record.go index b8b19aa29..17c156736 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -2,6 +2,7 @@ package stack import ( "fmt" + "path" "runtime" "strings" @@ -108,11 +109,7 @@ func (c call) Record(opts ...recordOption) string { func extractName(function uintptr, file string) (name, fileName string) { name = runtime.FuncForPC(function).Name() - if i := strings.LastIndex(file, "/"); i > -1 { - fileName = file[i+1:] - } else { - fileName = file - } + _, fileName = path.Split(file) name = strings.ReplaceAll(name, "[...]", "") return name, fileName