Skip to content

Commit

Permalink
fixes after Fuzzing test
Browse files Browse the repository at this point in the history
  • Loading branch information
Gleb Brozhe committed Apr 11, 2024
1 parent 63a94fd commit afdd3e8
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 336 deletions.
205 changes: 88 additions & 117 deletions internal/decimal/decimal.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package decimal

import (
"errors"
"math/big"
"math/bits"
"strings"
)

const wordSize = bits.UintSize / 8
const (
wordSize = bits.UintSize / 8
bufferSize = 40
negMask = 0x80
)

var (
ten = big.NewInt(10)
Expand Down Expand Up @@ -58,7 +60,7 @@ func FromBytes(bts []byte, precision, scale uint32) *big.Int {

v.SetBytes(bts)

neg := bts[0]&0x80 != 0
neg := bts[0]&negMask != 0
if neg {
// Given bytes contains negative value.
// Interpret is as two's complement.
Expand Down Expand Up @@ -95,149 +97,118 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) {
return v, nil
}

if SetSpecialValue(v, s) {
return v, nil
s, neg, specialValue := setSpecialValue(s, v)
if specialValue != nil {
return specialValue, nil
}

neg, s := parseSign(s)

integral := precision - scale
s, err := parseNumber(s, v, integral, scale)
var err error
v, err = parseNumber(s, v, precision, scale, neg)
if err != nil {
return nil, err
}

if len(s) > 0 {
if err := handleRemainingDigits(s, v, precision); err != nil {
return nil, err
}
}
v.Mul(v, pow(ten, scale))
if neg {
v.Neg(v)
}

return v, nil
}

func SetSpecialValue(v *big.Int, s string) bool {
neg, s := parseSign(s)

func setSpecialValue(s string, v *big.Int) (string, bool, *big.Int) {
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
}
if isInf(s) {
if neg {
v.Set(neginf)
} else {
v.Set(inf)
return s, neg, v.Set(neginf)
}

return true
return s, neg, v.Set(inf)
}
if isNaN(s) {
if neg {
v.Set(negnan)
} else {
v.Set(nan)
return s, neg, v.Set(negnan)
}

return true
}

return false
}

func handleRemainingDigits(s string, v *big.Int, precision uint32) error {
c := s[0]
if !isDigit(c) {
return syntaxError(s)
}

if c >= '5' {
if c > '5' || shouldRoundUp(v, s) {
v.Add(v, one)
if v.Cmp(pow(ten, precision)) >= 0 {
v.Set(inf)
}
}
}

return nil
}

func shouldRoundUp(v *big.Int, s string) bool {
var x big.Int
plus := x.And(v, one).Cmp(zero) != 0
for !plus && len(s) > 0 {
c := s[0]
s = s[1:]
if c < '0' || c > '9' {
break
}
plus = c != '0'
}

return plus
}

func parseSign(s string) (neg bool, remaining string) {
if s == "" {
return false, s
}

neg = s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
return s, neg, v.Set(nan)
}

return neg, s
return s, neg, nil
}

func parseNumber(s string, v *big.Int, integral, scale uint32) (remaining string, err error) {
func parseNumber(s string, v *big.Int, precision, scale uint32, neg bool) (*big.Int, error) {
var err error
integral := precision - scale
var dot bool
var processed bool
var remainingBuilder strings.Builder

for _, c := range s {
for ; len(s) > 0; s = s[1:] {
c := s[0]
if c == '.' {
if dot {
return "", errors.New("syntax error: unexpected '.'")
return nil, syntaxError(s)
}
dot = true

continue
}
if dot && scale > 0 {
scale--
} else if dot {
break
}

if !isDigit(byte(c)) {
return "", errors.New("syntax error: non-digit characters")
if !isDigit(c) {
return nil, syntaxError(s)
}

if dot && scale > 0 {
scale--
} else if !dot {
if integral == 0 {
remainingBuilder.WriteRune(c)
processed = true
v.Mul(v, ten)
v.Add(v, big.NewInt(int64(c-'0')))

continue
if !dot && v.Cmp(zero) > 0 && integral == 0 {
if neg {
return neginf, nil
}
integral--
}

if !processed {
digitVal := big.NewInt(int64(c - '0'))
v.Mul(v, big.NewInt(10))
v.Add(v, digitVal)
return inf, nil
}
integral--
}
if len(s) > 0 { // Characters remaining.
v, err = handleRemainingDigits(s, v, precision)
if err != nil {
return nil, err
}
}
v.Mul(v, pow(ten, scale))
if neg {
v.Neg(v)
}

if !dot && scale > 0 {
for scale > 0 {
v.Mul(v, big.NewInt(10))
scale--
return v, nil
}

func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, error) {
c := s[0]
if !isDigit(c) {
return nil, syntaxError(s)
}
plus := c > '5'
if !plus && c == '5' {
var x big.Int
plus = x.And(v, one).Cmp(zero) != 0 // Last digit is not a zero.
for !plus && len(s) > 1 {
s = s[1:]
c := s[0]
if !isDigit(c) {
return nil, syntaxError(s)
}
plus = c != '0'
}
}
if plus {
v.Add(v, one)
if v.Cmp(pow(ten, precision)) >= 0 {
v.Set(inf)
}
}

// Convert the strings.Builder content to a string
return remainingBuilder.String(), nil
return v, nil
}

// Format returns the string representation of x with the given precision and
Expand All @@ -262,8 +233,17 @@ func Format(x *big.Int, precision, scale uint32) string {
return "nan"
}

v, neg := abs(x)
bts, pos := newStringBuffer()
v := big.NewInt(0).Set(x)
neg := x.Sign() < 0
if neg {
// Convert negative to positive.
v.Neg(x)
}

// log_{10}(2^120) ~= 36.12, 37 decimal places
// plus dot, zero before dot, sign.
bts := make([]byte, bufferSize)
pos := len(bts)

var digit big.Int
for ; v.Cmp(zero) > 0; v.Div(v, ten) {
Expand Down Expand Up @@ -324,15 +304,6 @@ func abs(x *big.Int) (*big.Int, bool) {
return v, neg
}

func newStringBuffer() ([]byte, int) {
// log_{10}(2^120) ~= 36.12, 37 decimal places
// plus dot, zero before dot, sign.
bts := make([]byte, 40)
pos := len(bts)

return bts, pos
}

func setDigitAtPosition(bts []byte, pos, digit int) {
const numbers = "0123456789"
bts[pos] = numbers[digit]
Expand Down
Loading

0 comments on commit afdd3e8

Please sign in to comment.