From 3c7502290dda271e3bdc9649e85a1cb2409385f9 Mon Sep 17 00:00:00 2001 From: Roman Golov Date: Fri, 17 May 2024 21:34:13 +0300 Subject: [PATCH] Revert decimal to latest working version. --- CHANGELOG.md | 2 + internal/decimal/decimal.go | 170 ++++----- internal/decimal/decimal_test.go | 334 +----------------- .../unexpected_decimal_parse_test.go | 83 +++++ 4 files changed, 160 insertions(+), 429 deletions(-) create mode 100644 tests/integration/unexpected_decimal_parse_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 92b8a63b2..a038fbff2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Fixed incorrect formatting of decimal. Implementation of decimal has been reverted to latest working version + ## v3.67.1 * Fixed race of stop internal processes on close topic writer * Fixed goroutines leak within topic reader on network problems diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index 5926562f5..4fdbd32e6 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -3,6 +3,8 @@ package decimal import ( "math/big" "math/bits" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) const ( @@ -97,56 +99,27 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { return v, nil } - s, neg, specialValue := setSpecialValue(s, v) - if specialValue != nil { - return specialValue, nil - } - var err error - v, err = parseNumber(s, v, precision, scale, neg) - if err != nil { - return nil, err - } - - return v, nil -} - -func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) { - s, neg := parseSign(s) - - return parseSpecialValue(s, neg, v) -} - -func parseSign(s string) (string, bool) { neg := s[0] == '-' if neg || s[0] == '+' { s = s[1:] } - - return s, neg -} - -func parseSpecialValue(s string, neg bool, v *big.Int) (string, bool, *big.Int) { if isInf(s) { if neg { - return s, neg, v.Set(neginf) + return v.Set(neginf), nil } - return s, neg, v.Set(inf) + return v.Set(inf), nil } if isNaN(s) { if neg { - return s, neg, v.Set(negnan) + return v.Set(negnan), nil } - return s, neg, v.Set(nan) + return v.Set(nan), nil } - return s, neg, nil -} - -func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big.Int, error) { - var err error integral := precision - scale + var dot bool for ; len(s) > 0; s = s[1:] { c := s[0] @@ -158,10 +131,12 @@ func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big. continue } - if dot && scale > 0 { - scale-- - } else if dot { - break + if dot { + if scale > 0 { + scale-- + } else { + break + } } if !isDigit(c) { @@ -180,10 +155,30 @@ func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big. } integral-- } + //nolint:nestif if len(s) > 0 { // Characters remaining. - v, err = handleRemainingDigits(s, v, precision) - if err != nil { - return nil, err + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus := c > '5' + if !plus && c == '5' { + var x big.Int + plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. + for !plus && len(s) > 1 { + s = s[1:] + c := s[0] + if !isDigit(c) { + return nil, syntaxError(s) + } + plus = c != '0' + } + } + if plus { + v.Add(v, one) + if v.Cmp(pow(ten, precision)) >= 0 { + v.Set(inf) + } } } v.Mul(v, pow(ten, scale)) @@ -194,56 +189,26 @@ func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big. return v, nil } -func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, error) { - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus := c > '5' - if !plus && c == '5' { - var x big.Int - plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. - for !plus && len(s) > 1 { - s = s[1:] - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus = c != '0' - } - } - if plus { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) - } - } - - return v, nil -} - // Format returns the string representation of x with the given precision and // scale. -// -//nolint:funlen func Format(x *big.Int, precision, scale uint32) string { - // Check for special values and nil pointer upfront. - if x == nil { - return "0" - } - if x.CmpAbs(inf) == 0 { + switch { + case x.CmpAbs(inf) == 0: if x.Sign() < 0 { return "-inf" } return "inf" - } - if x.CmpAbs(nan) == 0 { + + case x.CmpAbs(nan) == 0: if x.Sign() < 0 { return "-nan" } return "nan" + + case x == nil: + return "0" } v := big.NewInt(0).Set(x) @@ -267,59 +232,42 @@ func Format(x *big.Int, precision, scale uint32) string { digit.Mod(v, ten) d := int(digit.Int64()) - - pos-- - if d != 0 || scale == 0 || pos >= 0 { - setDigitAtPosition(bts, pos, d) + if d != 0 || scale == 0 || pos > 0 { + const numbers = "0123456789" + pos-- + bts[pos] = numbers[d] } - if scale > 0 { scale-- if scale == 0 && pos > 0 { - bts[pos-1] = '.' pos-- + bts[pos] = '.' } } } - - for ; scale > 0; scale-- { - if precision == 0 { - pos = 0 - - break + if scale > 0 { + for ; scale > 0; scale-- { + if precision == 0 { + return errorTag + } + precision-- + pos-- + bts[pos] = '0' } - precision-- + pos-- - bts[pos] = '0' + bts[pos] = '.' } - if bts[pos] == '.' { pos-- bts[pos] = '0' } - if neg { pos-- bts[pos] = '-' } - return string(bts[pos:]) -} - -func abs(x *big.Int) (*big.Int, bool) { - v := big.NewInt(0).Set(x) - neg := x.Sign() < 0 - if neg { - // Convert negative to positive. - v.Neg(x) - } - - return v, neg -} - -func setDigitAtPosition(bts []byte, pos, digit int) { - const numbers = "0123456789" - bts[pos] = numbers[digit] + return xstring.FromBytes(bts[pos:]) } // BigIntToByte returns the 16-byte array representation of x. diff --git a/internal/decimal/decimal_test.go b/internal/decimal/decimal_test.go index a4fe5fcaf..d6945135f 100644 --- a/internal/decimal/decimal_test.go +++ b/internal/decimal/decimal_test.go @@ -2,10 +2,7 @@ package decimal import ( "encoding/binary" - "math/big" "testing" - - "github.com/stretchr/testify/require" ) func TestFromBytes(t *testing.T) { @@ -14,31 +11,43 @@ func TestFromBytes(t *testing.T) { bts []byte precision uint32 scale uint32 + format string }{ { bts: uint128(0xffffffffffffffff, 0xffffffffffffffff), precision: 22, scale: 9, + format: "-0.000000001", }, { bts: uint128(0xffffffffffffffff, 0), precision: 22, scale: 9, + format: "-18446744073.709551616", }, { bts: uint128(0x4000000000000000, 0), precision: 22, scale: 9, + format: "inf", }, { bts: uint128(0x8000000000000000, 0), precision: 22, scale: 9, + format: "-inf", }, { bts: uint128s(1000000000), precision: 22, scale: 9, + format: "1.000000000", + }, + { + bts: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, + precision: 22, + scale: 9, + format: "0.050000000", }, } { t.Run(test.name, func(t *testing.T) { @@ -51,6 +60,10 @@ func TestFromBytes(t *testing.T) { x, y, ) } + formatted := Format(x, test.precision, test.scale) + if test.format != formatted { + t.Errorf("unexpected decimal format. Expected: %s, actual %s", test.format, formatted) + } t.Logf( "%s %s", Format(x, test.precision, test.scale), @@ -60,181 +73,6 @@ func TestFromBytes(t *testing.T) { } } -func TestSetSpecialValue(t *testing.T) { - tests := []struct { - name string - input string - expectedS string - expectedNeg bool - expectedV *big.Int - }{ - { - name: "Positive infinity", - input: "inf", - expectedS: "inf", - expectedNeg: false, - expectedV: inf, - }, - { - name: "Negative infinity", - input: "-inf", - expectedS: "inf", - expectedNeg: true, - expectedV: neginf, - }, - { - name: "Positive NaN", - input: "nan", - expectedS: "nan", - expectedNeg: false, - expectedV: nan, - }, - { - name: "Negative NaN", - input: "-nan", - expectedS: "nan", - expectedNeg: true, - expectedV: negnan, - }, - { - name: "Regular number", - input: "123", - expectedS: "123", - expectedNeg: false, - expectedV: nil, - }, - { - name: "Negative regular number", - input: "-123", - expectedS: "123", - expectedNeg: true, - expectedV: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := big.NewInt(0) - gotS, gotNeg, gotV := setSpecialValue(tt.input, v) - require.Equal(t, tt.expectedS, gotS) - require.Equal(t, tt.expectedNeg, gotNeg) - if tt.expectedV != nil { - require.Equal(t, 0, tt.expectedV.Cmp(gotV)) - } else { - require.Nil(t, gotV) - } - }) - } -} - -func TestPrepareValue(t *testing.T) { - tests := []struct { - name string - input *big.Int - expectedValue *big.Int - expectedNeg bool - }{ - { - name: "Positive value", - input: big.NewInt(123), - expectedValue: big.NewInt(123), - expectedNeg: false, - }, - { - name: "Negative value", - input: big.NewInt(-123), - expectedValue: big.NewInt(123), - expectedNeg: true, - }, - { - name: "Zero value", - input: big.NewInt(0), - expectedValue: big.NewInt(0), - expectedNeg: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - value, neg := abs(tt.input) - require.Equal(t, tt.expectedValue, value) - require.Equal(t, tt.expectedNeg, neg) - }) - } -} - -func TestParseNumber(t *testing.T) { - // Mock or define these as per your actual implementation. - tests := []struct { - name string - s string - wantValue *big.Int - precision uint32 - scale uint32 - neg bool - wantErr bool - }{ - { - name: "Valid number without decimal", - s: "123", - precision: 3, - scale: 0, - neg: false, - wantValue: big.NewInt(123), - wantErr: false, - }, - { - name: "Valid number with decimal", - s: "123.45", - precision: 5, - scale: 2, - neg: false, - wantValue: big.NewInt(12345), - wantErr: false, - }, - { - name: "Valid negative number", - s: "123", - precision: 3, - scale: 0, - neg: true, - wantValue: big.NewInt(-123), - wantErr: false, - }, - { - name: "Syntax error with non-digit", - s: "123a", - precision: 4, - scale: 0, - neg: false, - wantValue: nil, - wantErr: true, - }, - { - name: "Multiple decimal points", - s: "12.3.4", - precision: 5, - scale: 2, - neg: false, - wantValue: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := big.NewInt(0) - gotValue, gotErr := parseNumber(tt.s, v, tt.precision, tt.scale, tt.neg) - if tt.wantErr { - require.Error(t, gotErr) - } else { - require.NoError(t, gotErr) - require.Equal(t, 0, tt.wantValue.Cmp(gotValue)) - } - }) - } -} - func uint128(hi, lo uint64) []byte { p := make([]byte, 16) binary.BigEndian.PutUint64(p[:8], hi) @@ -246,143 +84,3 @@ func uint128(hi, lo uint64) []byte { func uint128s(lo uint64) []byte { return uint128(0, lo) } - -func TestParse(t *testing.T) { - tests := []struct { - name string - s string - precision uint32 - scale uint32 - }{ - { - name: "Specific Parse test", - s: "100", - precision: 0, - scale: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expectedRes, expectedErr := oldParse(tt.s, tt.precision, tt.scale) - res, err := Parse(tt.s, tt.precision, tt.scale) - if expectedErr == nil { - require.Equal(t, expectedRes, res) - } else { - require.Error(t, err) - } - }) - } -} - -func FuzzParse(f *testing.F) { - f.Fuzz(func(t *testing.T, s string, precision, scale uint32) { - expectedRes, expectedErr := oldParse(s, precision, scale) - res, err := Parse(s, precision, scale) - if expectedErr == nil { - require.Equal(t, expectedRes, res) - } else { - require.Error(t, err) - } - }) -} - -func oldParse(s string, precision, scale uint32) (*big.Int, error) { - if scale > precision { - return nil, precisionError(s, precision, scale) - } - - v := big.NewInt(0) - if s == "" { - return v, nil - } - - neg := s[0] == '-' - if neg || s[0] == '+' { - s = s[1:] - } - if isInf(s) { - if neg { - return v.Set(neginf), nil - } - - return v.Set(inf), nil - } - if isNaN(s) { - if neg { - return v.Set(negnan), nil - } - - return v.Set(nan), nil - } - - integral := precision - scale - - var dot bool - for ; len(s) > 0; s = s[1:] { - c := s[0] - if c == '.' { - if dot { - return nil, syntaxError(s) - } - dot = true - - continue - } - if dot { - if scale > 0 { - scale-- - } else { - break - } - } - - if !isDigit(c) { - return nil, syntaxError(s) - } - - v.Mul(v, ten) - v.Add(v, big.NewInt(int64(c-'0'))) - - if !dot && v.Cmp(zero) > 0 && integral == 0 { - if neg { - return neginf, nil - } - - return inf, nil - } - integral-- - } - //nolint:nestif - if len(s) > 0 { // Characters remaining. - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus := c > '5' - if !plus && c == '5' { - var x big.Int - plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero. - for !plus && len(s) > 1 { - s = s[1:] - c := s[0] - if !isDigit(c) { - return nil, syntaxError(s) - } - plus = c != '0' - } - } - if plus { - v.Add(v, one) - if v.Cmp(pow(ten, precision)) >= 0 { - v.Set(inf) - } - } - } - v.Mul(v, pow(ten, scale)) - if neg { - v.Neg(v) - } - - return v, nil -} diff --git a/tests/integration/unexpected_decimal_parse_test.go b/tests/integration/unexpected_decimal_parse_test.go new file mode 100644 index 000000000..22b1d63ed --- /dev/null +++ b/tests/integration/unexpected_decimal_parse_test.go @@ -0,0 +1,83 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/decimal" + "github.com/ydb-platform/ydb-go-sdk/v3/table" + "github.com/ydb-platform/ydb-go-sdk/v3/table/types" +) + +func TestIssue1234UnexpectedDecimalRepresentation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAnonymousCredentials(), + ) + require.NoError(t, err) + + defer func(db *ydb.Driver) { + // cleanup + _ = db.Close(ctx) + }(db) + + tests := []struct { + name string + bts [16]byte + precision uint32 + scale uint32 + expectedFormat string + }{ + { + bts: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 250, 240, 128}, + precision: 22, + scale: 9, + expectedFormat: "0.050000000", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected := decimal.Decimal{ + Bytes: tt.bts, + Precision: tt.precision, + Scale: tt.scale, + } + var actual decimal.Decimal + + err = db.Table().Do(ctx, func(ctx context.Context, s table.Session) error { + _, result, err := s.Execute(ctx, table.DefaultTxControl(), ` + DECLARE $value AS Decimal(22,9); + SELECT $value;`, + table.NewQueryParameters( + table.ValueParam("$value", types.DecimalValue(&expected)), + ), + ) + if err != nil { + return err + } + for result.NextResultSet(ctx) { + for result.NextRow() { + err = result.Scan(&actual) + if err != nil { + return err + } + } + } + return nil + }) + require.NoError(t, err) + require.Equal(t, expected, actual) + require.Equal(t, tt.expectedFormat, actual.String()) + }) + } +}