From afdd3e860ede58ef7b2c344f0af431126be8c83c Mon Sep 17 00:00:00 2001 From: Gleb Brozhe Date: Thu, 11 Apr 2024 11:33:28 -0700 Subject: [PATCH] fixes after Fuzzing test --- internal/decimal/decimal.go | 205 +++++++++---------- internal/decimal/decimal_test.go | 327 ++++++++++--------------------- 2 files changed, 196 insertions(+), 336 deletions(-) diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index 5da249462..cf1098f86 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -1,13 +1,15 @@ package decimal import ( - "errors" "math/big" "math/bits" - "strings" ) -const wordSize = bits.UintSize / 8 +const ( + wordSize = bits.UintSize / 8 + bufferSize = 40 + negMask = 0x80 +) var ( ten = big.NewInt(10) @@ -58,7 +60,7 @@ func FromBytes(bts []byte, precision, scale uint32) *big.Int { v.SetBytes(bts) - neg := bts[0]&0x80 != 0 + neg := bts[0]&negMask != 0 if neg { // Given bytes contains negative value. // Interpret is as two's complement. @@ -95,149 +97,118 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } - if SetSpecialValue(v, s) { - return v, nil + s, neg, specialValue := setSpecialValue(s, v) + if specialValue != nil { + return specialValue, nil } - - neg, s := parseSign(s) - - integral := precision - scale - s, err := parseNumber(s, v, integral, scale) + var err error + v, err = parseNumber(s, v, precision, scale, neg) if err != nil { return nil, err } - if len(s) > 0 { - if err := handleRemainingDigits(s, v, precision); err != nil { - return nil, err - } - } - v.Mul(v, pow(ten, scale)) - if neg { - v.Neg(v) - } - return v, nil } -func SetSpecialValue(v *big.Int, s string) bool { - neg, s := parseSign(s) - +func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) { + neg := s[0] == '-' + if neg || s[0] == '+' { + s = s[1:] + } if isInf(s) { if neg { - v.Set(neginf) - } else { - v.Set(inf) + return s, neg, v.Set(neginf) } - return true + return s, neg, v.Set(inf) } if isNaN(s) { if neg { - v.Set(negnan) - } else { - v.Set(nan) + return s, neg, v.Set(negnan) } - return true - } - - return false -} - -func handleRemainingDigits(s string, v *big.Int, precision uint32) error { - c := s[0] - if !isDigit(c) { - return syntaxError(s) - } - - if c >= '5' { - if c > '5' || shouldRoundUp(v, s) { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) - } - } - } - - return nil -} - -func shouldRoundUp(v *big.Int, s string) bool { - var x big.Int - plus := x.And(v, one).Cmp(zero) != 0 - for !plus && len(s) > 0 { - c := s[0] - s = s[1:] - if c < '0' || c > '9' { - break - } - plus = c != '0' - } - - return plus -} - -func parseSign(s string) (neg bool, remaining string) { - if s == "" { - return false, s - } - - neg = s[0] == '-' - if neg || s[0] == '+' { - s = s[1:] + return s, neg, v.Set(nan) } - return neg, s + return s, neg, nil } -func parseNumber(s string, v *big.Int, integral, scale uint32) (remaining string, err error) { +func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big.Int, error) { + var err error + integral := precision - scale var dot bool - var processed bool - var remainingBuilder strings.Builder - - for _, c := range s { + for ; len(s) > 0; s = s[1:] { + c := s[0] if c == '.' { if dot { - return "", errors.New("syntax error: unexpected '.'") + return nil, syntaxError(s) } dot = true continue } + if dot && scale > 0 { + scale-- + } else if dot { + break + } - if !isDigit(byte(c)) { - return "", errors.New("syntax error: non-digit characters") + if !isDigit(c) { + return nil, syntaxError(s) } - if dot && scale > 0 { - scale-- - } else if !dot { - if integral == 0 { - remainingBuilder.WriteRune(c) - processed = true + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) - continue + if !dot && v.Cmp(zero) > 0 && integral == 0 { + if neg { + return neginf, nil } - integral-- - } - if !processed { - digitVal := big.NewInt(int64(c - '0')) - v.Mul(v, big.NewInt(10)) - v.Add(v, digitVal) + return inf, nil } + integral-- + } + if len(s) > 0 { // Characters remaining. + v, err = handleRemainingDigits(s, v, precision) + if err != nil { + return nil, err + } + } + v.Mul(v, pow(ten, scale)) + if neg { + v.Neg(v) } - if !dot && scale > 0 { - for scale > 0 { - v.Mul(v, big.NewInt(10)) - scale-- + return v, nil +} + +func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, error) { + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) } } - // Convert the strings.Builder content to a string - return remainingBuilder.String(), nil + return v, nil } // Format returns the string representation of x with the given precision and @@ -262,8 +233,17 @@ func Format(x *big.Int, precision, scale uint32) string { return "nan" } - v, neg := abs(x) - bts, pos := newStringBuffer() + v := big.NewInt(0).Set(x) + neg := x.Sign() < 0 + if neg { + // Convert negative to positive. + v.Neg(x) + } + + // log_{10}(2^120) ~= 36.12, 37 decimal places + // plus dot, zero before dot, sign. + bts := make([]byte, bufferSize) + pos := len(bts) var digit big.Int for ; v.Cmp(zero) > 0; v.Div(v, ten) { @@ -324,15 +304,6 @@ func abs(x *big.Int) (*big.Int, bool) { return v, neg } -func newStringBuffer() ([]byte, int) { - // log_{10}(2^120) ~= 36.12, 37 decimal places - // plus dot, zero before dot, sign. - bts := make([]byte, 40) - pos := len(bts) - - return bts, pos -} - func setDigitAtPosition(bts []byte, pos, digit int) { const numbers = "0123456789" bts[pos] = numbers[digit] diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index 0ff50b6ed..9443b40ad 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -60,245 +60,68 @@ func TestFromBytes(t *testing.T) { } } -func TestShouldRoundUp(t *testing.T) { - tests := []struct { - name string - number *big.Int - additionalDigits string - expected bool - }{ - { - name: "Last digit not zero, no string", - number: big.NewInt(123), - additionalDigits: "", - expected: true, - }, - { - name: "Last digit zero, string starts not with zero", - number: big.NewInt(120), - additionalDigits: "1", - expected: true, - }, - { - name: "Last digit zero, string all zeros", - number: big.NewInt(120), - additionalDigits: "000", - expected: false, - }, - { - name: "Last digit not zero, string irrelevant", - number: big.NewInt(123), - additionalDigits: "004", - expected: true, - }, - { - name: "Last digit zero, string has non-zero after zeros", - number: big.NewInt(100), - additionalDigits: "001", - expected: true, - }, - { - name: "Last digit zero, string has non-digit characters", - number: big.NewInt(100), - additionalDigits: "00abc", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldRoundUp(tt.number, tt.additionalDigits) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestParseSign(t *testing.T) { +func TestSetSpecialValue(t *testing.T) { tests := []struct { name string input string + expectedS string expectedNeg bool - expectedRem string + expectedV *big.Int }{ { - name: "Negative sign", - input: "-123", - expectedNeg: true, - expectedRem: "123", - }, - { - name: "Positive sign", - input: "+456", + name: "Positive infinity", + input: "inf", + expectedS: "inf", expectedNeg: false, - expectedRem: "456", + expectedV: inf, }, { - name: "No sign", - input: "789", - expectedNeg: false, - expectedRem: "789", + name: "Negative infinity", + input: "-inf", + expectedS: "inf", + expectedNeg: true, + expectedV: neginf, }, { - name: "Empty string", - input: "", + name: "Positive NaN", + input: "nan", + expectedS: "nan", expectedNeg: false, - expectedRem: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - neg, rem := parseSign(tt.input) - require.Equal(t, tt.expectedNeg, neg, "Neg flag does not match expected value") - require.Equal(t, tt.expectedRem, rem, "Remaining string does not match expected value") - }) - } -} - -func TestParseNumber(t *testing.T) { - tests := []struct { - name string - s string - initialValue *big.Int - initialIntegral uint32 - initialScale uint32 - expectedValue *big.Int - expectedRemain string - expectError bool - }{ - { - name: "Parse integer", - s: "123", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 0, - expectedValue: big.NewInt(123), - expectedRemain: "", - expectError: false, - }, - { - name: "Parse floating point", - s: "123.45", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: big.NewInt(12345), - expectedRemain: "", - expectError: false, + expectedV: nan, }, { - name: "Non-digit character", - s: "123x45", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: nil, - expectedRemain: "", - expectError: true, - }, - { - name: "Multiple dots", - s: "123.45.67", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 2, - expectedValue: nil, - expectedRemain: "", - expectError: true, - }, - { - name: "Early termination by integral", - s: "12345", - initialValue: big.NewInt(0), - initialIntegral: 3, - initialScale: 0, - expectedValue: big.NewInt(123), - expectedRemain: "45", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - integral := tt.initialIntegral - scale := tt.initialScale - remain, err := parseNumber(tt.s, tt.initialValue, integral, scale) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectedValue, tt.initialValue) - require.Equal(t, tt.expectedRemain, remain) - } - }) - } -} - -func TestHandleRemainingDigits(t *testing.T) { - tests := []struct { - value *big.Int - expected *big.Int - name string - inputString string - precision uint32 - expectErr bool - }{ - { - name: "No rounding needed", - inputString: "4", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(1), - expectErr: false, - }, - { - name: "Rounding up needed", - inputString: "6", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(2), - expectErr: false, - }, - { - name: "Exactly halfway - assume round up for test", - inputString: "50", - value: big.NewInt(1), - precision: 3, - expected: big.NewInt(2), - expectErr: false, + name: "Negative NaN", + input: "-nan", + expectedS: "nan", + expectedNeg: true, + expectedV: negnan, }, { - name: "Invalid character", - inputString: "a", - value: big.NewInt(1), - precision: 3, - expected: nil, - expectErr: true, + name: "Regular number", + input: "123", + expectedS: "123", + expectedNeg: false, + expectedV: nil, }, { - name: "Exceeds precision limit - set to inf", - inputString: "9", // Triggers rounding that should exceed precision - value: pow(ten, 2), // Set v close to precision limit - precision: 2, // Precision limit - expected: inf, // Expected to be set to inf - expectErr: false, + name: "Negative regular number", + input: "-123", + expectedS: "123", + expectedNeg: true, + expectedV: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := handleRemainingDigits(tt.inputString, tt.value, tt.precision) - - if tt.expectErr { - require.Error(t, err) + v := big.NewInt(0) + gotS, gotNeg, gotV := setSpecialValue(tt.input, v) + require.Equal(t, tt.expectedS, gotS) + require.Equal(t, tt.expectedNeg, gotNeg) + if tt.expectedV != nil { + require.Equal(t, 0, tt.expectedV.Cmp(gotV)) } else { - require.NoError(t, err) - if tt.expected.Cmp(inf) == 0 { - require.Equal(t, 0, tt.expected.Cmp(tt.value), "v should be set to inf") - } else { - require.Equal(t, tt.expected.String(), tt.value.String(), "Expected and actual values should match") - } + require.Nil(t, gotV) } }) } @@ -340,10 +163,76 @@ func TestPrepareValue(t *testing.T) { } } -func TestInitializeBuffer(t *testing.T) { - bts, pos := newStringBuffer() - require.Len(t, bts, 40) - require.Equal(t, 40, pos) +func TestParseNumber(t *testing.T) { + // Mock or define these as per your actual implementation. + tests := []struct { + name string + s string + wantValue *big.Int + precision uint32 + scale uint32 + neg bool + wantErr bool + }{ + { + name: "Valid number without decimal", + s: "123", + precision: 3, + scale: 0, + neg: false, + wantValue: big.NewInt(123), + wantErr: false, + }, + { + name: "Valid number with decimal", + s: "123.45", + precision: 5, + scale: 2, + neg: false, + wantValue: big.NewInt(12345), + wantErr: false, + }, + { + name: "Valid negative number", + s: "123", + precision: 3, + scale: 0, + neg: true, + wantValue: big.NewInt(-123), + wantErr: false, + }, + { + name: "Syntax error with non-digit", + s: "123a", + precision: 4, + scale: 0, + neg: false, + wantValue: nil, + wantErr: true, + }, + { + name: "Multiple decimal points", + s: "12.3.4", + precision: 5, + scale: 2, + neg: false, + wantValue: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := big.NewInt(0) + gotValue, gotErr := parseNumber(tt.s, v, tt.precision, tt.scale, tt.neg) + if tt.wantErr { + require.Error(t, gotErr) + } else { + require.NoError(t, gotErr) + require.Equal(t, 0, tt.wantValue.Cmp(gotValue)) + } + }) + } } func uint128(hi, lo uint64) []byte {