diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/.DS_Store differ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..71e3f4b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,137 @@ + + +# Changelog + +## [Unreleased] + +### Features + +* [#17803](https://github.com/cosmos/cosmos-sdk/pull/17803) Add mutative api for Int.BigInt() + +### Bug Fixes + +* [#18228](https://github.com/cosmos/cosmos-sdk/pull/18228) Fix panic when calling `BigInt()` on an uninitialized `Uint`. +* [#18214](https://github.com/cosmos/cosmos-sdk/pull/18214) Ensure that modifying the argument to `NewUIntFromBigInt` doesn't mutate the returned value. +* [#18211](https://github.com/cosmos/cosmos-sdk/pull/18211) RelativePow now returns 1 when 0^0, before it was returning the scale factor. +* [#17725](https://github.com/cosmos/cosmos-sdk/pull/17725) Fix state break in ApproxRoot. This has been present since math/v1.0.1. It changed the rounding behavior at precision end in an intermediary division from banker's to truncation. The truncation occurs from binary right shift in the case of square roots. The change is now reverted back to banker's rounding universally for any root. + +## [math/v1.1.2](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.1.2) - 2023-08-21 + +### Bug Fixes + +* [#17489](https://github.com/cosmos/cosmos-sdk/pull/17489) Revert [#16263](https://github.com/cosmos/cosmos-sdk/pull/16263). + +## [math/v1.1.1](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.1.1) - 2023-08-21 + +### Bug Fixes + +* [#17480](https://github.com/cosmos/cosmos-sdk/pull/17480) Fix panic when calling `.Size()` on a nil `math.Int` value. + +## [math/v1.1.0](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.1.0) - 2023-08-19 + +### Features + +* [#17427](https://github.com/cosmos/cosmos-sdk/pull/17427) Implement LegacyDec.MulRoundUp that rounds up at precision end. + +### Improvements + +* [#17109](https://github.com/cosmos/cosmos-sdk/pull/17109) Add `.ToLegacyDec()` method on `math.Int` type for converting to `math.LegacyDec`. +* [#16263](https://github.com/cosmos/cosmos-sdk/pull/16263) Improved `math/Int.Size` by computing the decimal digits count instead of firstly invoking .Marshal() then checking the length + +### Bug Fixes + +* [#17352](https://github.com/cosmos/cosmos-sdk/pull/17352) Ensure that modifying the argument to `NewIntFromBigInt` doesn't mutate the returned value. +* [#16266](https://github.com/cosmos/cosmos-sdk/pull/16266) Fix legacy dec power mut zero exponent precision. + +## [math/v1.0.1](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.0.1) - 2023-05-15 + +### Improvements + +* [#15768](https://github.com/cosmos/cosmos-sdk/pull/15768) Removed the second call to the `init` method for the global variable `grand`. +* [#16141](https://github.com/cosmos/cosmos-sdk/pull/16141) Speedup `LegacyDec.ApproxRoot` and `LegacyDec.ApproxSqrt`. + +### Bug Fixes + +* [#15714](https://github.com/cosmos/cosmos-sdk/pull/15714) `FormatInt` returns an error on empty string. + +## [math/v1.0.0](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.0.0) - 2023-03-23 + +### Bug Fixes + +* [#15506](https://github.com/cosmos/cosmos-sdk/issues/16605) Dec marshal shouldn't have side effects + +## [math/v1.0.0-rc.0](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.0.0-rc.0) - 2023-03-13 + +### Features + +* [#15043](https://github.com/cosmos/cosmos-sdk/issues/15043) add rand funcs to math + +### Bug Fixes + +* [#14922](https://github.com/cosmos/cosmos-sdk/issues/14922) check for negative precision + +### Testing + +* [#15215](https://github.com/cosmos/cosmos-sdk/issues/15215) fix `FormatDec` test + +## [math/v1.0.0-beta.6](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.0.0-beta.6) - 2023-02-06 + +### Features + +* [#14760](https://github.com/cosmos/cosmos-sdk/issues/14760) add collections key encoders and value encoders for common types. +* [#14166](https://github.com/cosmos/cosmos-sdk/issues/14166) math: add generics versions of Max, Min to cater to all numeric types +* [#13381](https://github.com/cosmos/cosmos-sdk/issues/13381) add uint `IsNil` method + +### Improvements + +* [#14010](https://github.com/cosmos/cosmos-sdk/issues/14010) math: optimize and test FormatInt + simplify LegacyNewDecFromStr +* [#12794](https://github.com/cosmos/cosmos-sdk/issues/12794) math: precompute & use square of precisionReuse instead of 2 repeated computations + +### Bug Fixes + +* [#14691](https://github.com/cosmos/cosmos-sdk/issues/14691) do not flatten events attributes by event types +* [#14252](https://github.com/cosmos/cosmos-sdk/issues/14252) math: add LegacyNewDecFromStr fuzzers + remove unnecessary error wrapping + +### Testing + +* [#14576](https://github.com/cosmos/cosmos-sdk/issues/14576) Added test cases for precisionMultiplier + +## [math/v1.0.0-beta.3](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.0.0-beta.3) - 2022-07-20 + +### Bug Fixes + +* [#11996](https://github.com/cosmos/cosmos-sdk/issues/11996) math: fix Uint.Unmarshal's lack of negative value checking + + diff --git a/dec.go b/dec.go new file mode 100644 index 0000000..794a93b --- /dev/null +++ b/dec.go @@ -0,0 +1,957 @@ +package math + +import ( + "encoding/json" + "errors" + "fmt" + "math/big" + "strconv" + "strings" + "testing" +) + +// NOTE: never use new(Dec) or else we will panic unmarshalling into the +// nil embedded big.Int +type LegacyDec struct { + i *big.Int +} + +const ( + // number of decimal places + LegacyPrecision = 18 + + // bits required to represent the above precision + // Ceiling[Log2[10^Precision - 1]] + LegacyDecimalPrecisionBits = 60 + + // decimalTruncateBits is the minimum number of bits removed + // by a truncate operation. It is equal to + // Floor[Log2[10^Precision - 1]]. + decimalTruncateBits = LegacyDecimalPrecisionBits - 1 + + maxDecBitLen = MaxBitLen + decimalTruncateBits + + // max number of iterations in ApproxRoot function + maxApproxRootIterations = 300 +) + +var ( + precisionReuse = new(big.Int).Exp(big.NewInt(10), big.NewInt(LegacyPrecision), nil) + fivePrecision = new(big.Int).Quo(precisionReuse, big.NewInt(2)) + precisionMultipliers []*big.Int + zeroInt = big.NewInt(0) + oneInt = big.NewInt(1) + tenInt = big.NewInt(10) + smallestDec = LegacySmallestDec() +) + +// Decimal errors +var ( + ErrLegacyEmptyDecimalStr = errors.New("decimal string cannot be empty") + ErrLegacyInvalidDecimalLength = errors.New("invalid decimal length") + ErrLegacyInvalidDecimalStr = errors.New("invalid decimal string") +) + +// Set precision multipliers +func init() { + precisionMultipliers = make([]*big.Int, LegacyPrecision+1) + for i := 0; i <= LegacyPrecision; i++ { + precisionMultipliers[i] = calcPrecisionMultiplier(int64(i)) + } +} + +func precisionInt() *big.Int { + return new(big.Int).Set(precisionReuse) +} + +func LegacyZeroDec() LegacyDec { return LegacyDec{new(big.Int).Set(zeroInt)} } +func LegacyOneDec() LegacyDec { return LegacyDec{precisionInt()} } +func LegacySmallestDec() LegacyDec { return LegacyDec{new(big.Int).Set(oneInt)} } + +// calculate the precision multiplier +func calcPrecisionMultiplier(prec int64) *big.Int { + if prec < 0 { + panic(fmt.Sprintf("negative precision %v", prec)) + } + + if prec > LegacyPrecision { + panic(fmt.Sprintf("too much precision, maximum %v, provided %v", LegacyPrecision, prec)) + } + zerosToAdd := LegacyPrecision - prec + multiplier := new(big.Int).Exp(tenInt, big.NewInt(zerosToAdd), nil) + return multiplier +} + +// get the precision multiplier, do not mutate result +func precisionMultiplier(prec int64) *big.Int { + if prec < 0 { + panic(fmt.Sprintf("negative precision %v", prec)) + } + + if prec > LegacyPrecision { + panic(fmt.Sprintf("too much precision, maximum %v, provided %v", LegacyPrecision, prec)) + } + return precisionMultipliers[prec] +} + +// create a new Dec from integer assuming whole number +func LegacyNewDec(i int64) LegacyDec { + return LegacyNewDecWithPrec(i, 0) +} + +// create a new Dec from integer with decimal place at prec +// CONTRACT: prec <= Precision +func LegacyNewDecWithPrec(i, prec int64) LegacyDec { + return LegacyDec{ + new(big.Int).Mul(big.NewInt(i), precisionMultiplier(prec)), + } +} + +// create a new Dec from big integer assuming whole numbers +// CONTRACT: prec <= Precision +func LegacyNewDecFromBigInt(i *big.Int) LegacyDec { + return LegacyNewDecFromBigIntWithPrec(i, 0) +} + +// create a new Dec from big integer assuming whole numbers +// CONTRACT: prec <= Precision +func LegacyNewDecFromBigIntWithPrec(i *big.Int, prec int64) LegacyDec { + return LegacyDec{ + new(big.Int).Mul(i, precisionMultiplier(prec)), + } +} + +// create a new Dec from big integer assuming whole numbers +// CONTRACT: prec <= Precision +func LegacyNewDecFromInt(i Int) LegacyDec { + return LegacyNewDecFromIntWithPrec(i, 0) +} + +// create a new Dec from big integer with decimal place at prec +// CONTRACT: prec <= Precision +func LegacyNewDecFromIntWithPrec(i Int, prec int64) LegacyDec { + return LegacyDec{ + new(big.Int).Mul(i.BigInt(), precisionMultiplier(prec)), + } +} + +// create a decimal from an input decimal string. +// valid must come in the form: +// +// (-) whole integers (.) decimal integers +// +// examples of acceptable input include: +// +// -123.456 +// 456.7890 +// 345 +// -456789 +// +// NOTE - An error will return if more decimal places +// are provided in the string than the constant Precision. +// +// CONTRACT - This function does not mutate the input str. +func LegacyNewDecFromStr(str string) (LegacyDec, error) { + // first extract any negative symbol + neg := false + if len(str) > 0 && str[0] == '-' { + neg = true + str = str[1:] + } + + if len(str) == 0 { + return LegacyDec{}, ErrLegacyEmptyDecimalStr + } + + strs := strings.Split(str, ".") + lenDecs := 0 + combinedStr := strs[0] + + if len(strs) == 2 { // has a decimal place + lenDecs = len(strs[1]) + if lenDecs == 0 || len(combinedStr) == 0 { + return LegacyDec{}, ErrLegacyInvalidDecimalLength + } + combinedStr += strs[1] + } else if len(strs) > 2 { + return LegacyDec{}, ErrLegacyInvalidDecimalStr + } + + if lenDecs > LegacyPrecision { + return LegacyDec{}, fmt.Errorf("value '%s' exceeds max precision by %d decimal places: max precision %d", str, LegacyPrecision-lenDecs, LegacyPrecision) + } + + // add some extra zero's to correct to the Precision factor + zerosToAdd := LegacyPrecision - lenDecs + zeros := strings.Repeat("0", zerosToAdd) + combinedStr += zeros + + combined, ok := new(big.Int).SetString(combinedStr, 10) // base 10 + if !ok { + return LegacyDec{}, fmt.Errorf("failed to set decimal string with base 10: %s", combinedStr) + } + if combined.BitLen() > maxDecBitLen { + return LegacyDec{}, fmt.Errorf("decimal '%s' out of range; bitLen: got %d, max %d", str, combined.BitLen(), maxDecBitLen) + } + if neg { + combined = new(big.Int).Neg(combined) + } + + return LegacyDec{combined}, nil +} + +// Decimal from string, panic on error +func LegacyMustNewDecFromStr(s string) LegacyDec { + dec, err := LegacyNewDecFromStr(s) + if err != nil { + panic(err) + } + return dec +} + +func (d LegacyDec) IsNil() bool { return d.i == nil } // is decimal nil +func (d LegacyDec) IsZero() bool { return (d.i).Sign() == 0 } // is equal to zero +func (d LegacyDec) IsNegative() bool { return (d.i).Sign() == -1 } // is negative +func (d LegacyDec) IsPositive() bool { return (d.i).Sign() == 1 } // is positive +func (d LegacyDec) Equal(d2 LegacyDec) bool { return (d.i).Cmp(d2.i) == 0 } // equal decimals +func (d LegacyDec) GT(d2 LegacyDec) bool { return (d.i).Cmp(d2.i) > 0 } // greater than +func (d LegacyDec) GTE(d2 LegacyDec) bool { return (d.i).Cmp(d2.i) >= 0 } // greater than or equal +func (d LegacyDec) LT(d2 LegacyDec) bool { return (d.i).Cmp(d2.i) < 0 } // less than +func (d LegacyDec) LTE(d2 LegacyDec) bool { return (d.i).Cmp(d2.i) <= 0 } // less than or equal +func (d LegacyDec) Neg() LegacyDec { return LegacyDec{new(big.Int).Neg(d.i)} } // reverse the decimal sign +func (d LegacyDec) NegMut() LegacyDec { d.i.Neg(d.i); return d } // reverse the decimal sign, mutable +func (d LegacyDec) Abs() LegacyDec { return LegacyDec{new(big.Int).Abs(d.i)} } // absolute value +func (d LegacyDec) AbsMut() LegacyDec { d.i.Abs(d.i); return d } // absolute value, mutable +func (d LegacyDec) Set(d2 LegacyDec) LegacyDec { d.i.Set(d2.i); return d } // set to existing dec value +func (d LegacyDec) Clone() LegacyDec { return LegacyDec{new(big.Int).Set(d.i)} } // clone new dec + +// BigInt returns a copy of the underlying big.Int. +func (d LegacyDec) BigInt() *big.Int { + if d.IsNil() { + return nil + } + + cp := new(big.Int) + return cp.Set(d.i) +} + +func (d LegacyDec) ImmutOp(op func(LegacyDec, LegacyDec) LegacyDec, d2 LegacyDec) LegacyDec { + return op(d.Clone(), d2) +} + +func (d LegacyDec) ImmutOpInt(op func(LegacyDec, Int) LegacyDec, d2 Int) LegacyDec { + return op(d.Clone(), d2) +} + +func (d LegacyDec) ImmutOpInt64(op func(LegacyDec, int64) LegacyDec, d2 int64) LegacyDec { + // TODO: use already allocated operand bigint to avoid + // newint each time, add mutex for race condition + // Issue: https://github.com/cosmos/cosmos-sdk/issues/11166 + return op(d.Clone(), d2) +} + +func (d LegacyDec) SetInt64(i int64) LegacyDec { + d.i.SetInt64(i) + d.i.Mul(d.i, precisionReuse) + return d +} + +// addition +func (d LegacyDec) Add(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.AddMut, d2) +} + +// mutable addition +func (d LegacyDec) AddMut(d2 LegacyDec) LegacyDec { + d.i.Add(d.i, d2.i) + + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// subtraction +func (d LegacyDec) Sub(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.SubMut, d2) +} + +// mutable subtraction +func (d LegacyDec) SubMut(d2 LegacyDec) LegacyDec { + d.i.Sub(d.i, d2.i) + + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// multiplication +func (d LegacyDec) Mul(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.MulMut, d2) +} + +// mutable multiplication +func (d LegacyDec) MulMut(d2 LegacyDec) LegacyDec { + d.i.Mul(d.i, d2.i) + chopped := chopPrecisionAndRound(d.i) + + if chopped.BitLen() > maxDecBitLen { + panic("Int overflow") + } + *d.i = *chopped + return d +} + +// multiplication truncate +func (d LegacyDec) MulTruncate(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.MulTruncateMut, d2) +} + +// mutable multiplication truncage +func (d LegacyDec) MulTruncateMut(d2 LegacyDec) LegacyDec { + d.i.Mul(d.i, d2.i) + chopPrecisionAndTruncate(d.i) + + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// multiplication round up at precision end. +func (d LegacyDec) MulRoundUp(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.MulRoundUpMut, d2) +} + +// mutable multiplication with round up at precision end. +func (d LegacyDec) MulRoundUpMut(d2 LegacyDec) LegacyDec { + d.i.Mul(d.i, d2.i) + chopPrecisionAndRoundUp(d.i) + + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// multiplication +func (d LegacyDec) MulInt(i Int) LegacyDec { + return d.ImmutOpInt(LegacyDec.MulIntMut, i) +} + +func (d LegacyDec) MulIntMut(i Int) LegacyDec { + d.i.Mul(d.i, i.BigInt()) + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// MulInt64 - multiplication with int64 +func (d LegacyDec) MulInt64(i int64) LegacyDec { + return d.ImmutOpInt64(LegacyDec.MulInt64Mut, i) +} + +func (d LegacyDec) MulInt64Mut(i int64) LegacyDec { + d.i.Mul(d.i, big.NewInt(i)) + + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// quotient +func (d LegacyDec) Quo(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.QuoMut, d2) +} + +var squaredPrecisionReuse = new(big.Int).Mul(precisionReuse, precisionReuse) + +// mutable quotient +func (d LegacyDec) QuoMut(d2 LegacyDec) LegacyDec { + // multiply by precision twice + d.i.Mul(d.i, squaredPrecisionReuse) + d.i.Quo(d.i, d2.i) + + chopPrecisionAndRound(d.i) + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// quotient truncate +func (d LegacyDec) QuoTruncate(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.QuoTruncateMut, d2) +} + +// mutable quotient truncate +func (d LegacyDec) QuoTruncateMut(d2 LegacyDec) LegacyDec { + // multiply precision twice + d.i.Mul(d.i, squaredPrecisionReuse) + d.i.Quo(d.i, d2.i) + + chopPrecisionAndTruncate(d.i) + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// quotient, round up +func (d LegacyDec) QuoRoundUp(d2 LegacyDec) LegacyDec { + return d.ImmutOp(LegacyDec.QuoRoundupMut, d2) +} + +// mutable quotient, round up +func (d LegacyDec) QuoRoundupMut(d2 LegacyDec) LegacyDec { + // multiply precision twice + d.i.Mul(d.i, squaredPrecisionReuse) + d.i.Quo(d.i, d2.i) + + chopPrecisionAndRoundUp(d.i) + if d.i.BitLen() > maxDecBitLen { + panic("Int overflow") + } + return d +} + +// quotient +func (d LegacyDec) QuoInt(i Int) LegacyDec { + return d.ImmutOpInt(LegacyDec.QuoIntMut, i) +} + +func (d LegacyDec) QuoIntMut(i Int) LegacyDec { + d.i.Quo(d.i, i.BigInt()) + return d +} + +// QuoInt64 - quotient with int64 +func (d LegacyDec) QuoInt64(i int64) LegacyDec { + return d.ImmutOpInt64(LegacyDec.QuoInt64Mut, i) +} + +func (d LegacyDec) QuoInt64Mut(i int64) LegacyDec { + d.i.Quo(d.i, big.NewInt(i)) + return d +} + +// ApproxRoot returns an approximate estimation of a Dec's positive real nth root +// using Newton's method (where n is positive). The algorithm starts with some guess and +// computes the sequence of improved guesses until an answer converges to an +// approximate answer. It returns `|d|.ApproxRoot() * -1` if input is negative. +// A maximum number of 100 iterations is used a backup boundary condition for +// cases where the answer never converges enough to satisfy the main condition. +func (d LegacyDec) ApproxRoot(root uint64) (guess LegacyDec, err error) { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + err = errors.New("out of bounds") + } + } + }() + + if d.IsNegative() { + absRoot, err := d.Neg().ApproxRoot(root) + return absRoot.NegMut(), err + } + + // One decimal, that we invalidate later. Helps us save a heap allocation. + scratchOneDec := LegacyOneDec() + if root == 1 || d.IsZero() || d.Equal(scratchOneDec) { + return d, nil + } + + if root == 0 { + return scratchOneDec, nil + } + + guess, delta := scratchOneDec, LegacyOneDec() + + for iter := 0; iter < maxApproxRootIterations && delta.Abs().GT(smallestDec); iter++ { + prev := guess.Power(root - 1) + if prev.IsZero() { + prev = smallestDec + } + delta.Set(d).QuoMut(prev) + delta.SubMut(guess) + delta.QuoInt64Mut(int64(root)) + + guess.AddMut(delta) + } + + return guess, nil +} + +// Power returns a the result of raising to a positive integer power +func (d LegacyDec) Power(power uint64) LegacyDec { + res := LegacyDec{new(big.Int).Set(d.i)} + return res.PowerMut(power) +} + +func (d LegacyDec) PowerMut(power uint64) LegacyDec { + if power == 0 { + // Set to 1 with the correct precision. + d.i.Set(precisionReuse) + return d + } + tmp := LegacyOneDec() + + for i := power; i > 1; { + if i%2 != 0 { + tmp.MulMut(d) + } + i /= 2 + d.MulMut(d) + } + + return d.MulMut(tmp) +} + +// ApproxSqrt is a wrapper around ApproxRoot for the common special case +// of finding the square root of a number. It returns -(sqrt(abs(d)) if input is negative. +func (d LegacyDec) ApproxSqrt() (LegacyDec, error) { + return d.ApproxRoot(2) +} + +// is integer, e.g. decimals are zero +func (d LegacyDec) IsInteger() bool { + return new(big.Int).Rem(d.i, precisionReuse).Sign() == 0 +} + +// format decimal state +func (d LegacyDec) Format(s fmt.State, verb rune) { + _, err := s.Write([]byte(d.String())) + if err != nil { + panic(err) + } +} + +func (d LegacyDec) String() string { + if d.i == nil { + return d.i.String() + } + + isNeg := d.IsNegative() + + if isNeg { + d = d.Neg() + } + + bzInt, err := d.i.MarshalText() + if err != nil { + return "" + } + inputSize := len(bzInt) + + var bzStr []byte + + // TODO: Remove trailing zeros + // case 1, purely decimal + if inputSize <= LegacyPrecision { + bzStr = make([]byte, LegacyPrecision+2) + + // 0. prefix + bzStr[0] = byte('0') + bzStr[1] = byte('.') + + // set relevant digits to 0 + for i := 0; i < LegacyPrecision-inputSize; i++ { + bzStr[i+2] = byte('0') + } + + // set final digits + copy(bzStr[2+(LegacyPrecision-inputSize):], bzInt) + } else { + // inputSize + 1 to account for the decimal point that is being added + bzStr = make([]byte, inputSize+1) + decPointPlace := inputSize - LegacyPrecision + + copy(bzStr, bzInt[:decPointPlace]) // pre-decimal digits + bzStr[decPointPlace] = byte('.') // decimal point + copy(bzStr[decPointPlace+1:], bzInt[decPointPlace:]) // post-decimal digits + } + + if isNeg { + return "-" + string(bzStr) + } + + return string(bzStr) +} + +// Float64 returns the float64 representation of a Dec. +// Will return the error if the conversion failed. +func (d LegacyDec) Float64() (float64, error) { + return strconv.ParseFloat(d.String(), 64) +} + +// MustFloat64 returns the float64 representation of a Dec. +// Would panic if the conversion failed. +func (d LegacyDec) MustFloat64() float64 { + if value, err := strconv.ParseFloat(d.String(), 64); err != nil { + panic(err) + } else { + return value + } +} + +// ____ +// __| |__ "chop 'em +// ` \ round!" +// ___|| ~ _ -bankers +// | | __ +// | | | __|__|__ +// |_____: / | $$$ | +// |________| + +// Remove a Precision amount of rightmost digits and perform bankers rounding +// on the remainder (gaussian rounding) on the digits which have been removed. +// +// Mutates the input. Use the non-mutative version if that is undesired +func chopPrecisionAndRound(d *big.Int) *big.Int { + // remove the negative and add it back when returning + if d.Sign() == -1 { + // make d positive, compute chopped value, and then un-mutate d + d = d.Neg(d) + d = chopPrecisionAndRound(d) + d = d.Neg(d) + return d + } + + // get the truncated quotient and remainder + quo, rem := d, big.NewInt(0) + quo, rem = quo.QuoRem(d, precisionReuse, rem) + + if rem.Sign() == 0 { // remainder is zero + return quo + } + + switch rem.Cmp(fivePrecision) { + case -1: + return quo + case 1: + return quo.Add(quo, oneInt) + default: // bankers rounding must take place + // always round to an even number + if quo.Bit(0) == 0 { + return quo + } + return quo.Add(quo, oneInt) + } +} + +func chopPrecisionAndRoundUp(d *big.Int) *big.Int { + // remove the negative and add it back when returning + if d.Sign() == -1 { + // make d positive, compute chopped value, and then un-mutate d + d = d.Neg(d) + // truncate since d is negative... + chopPrecisionAndTruncate(d) + d = d.Neg(d) + return d + } + + // get the truncated quotient and remainder + quo, rem := d, big.NewInt(0) + quo, rem = quo.QuoRem(d, precisionReuse, rem) + + if rem.Sign() == 0 { // remainder is zero + return quo + } + + return quo.Add(quo, oneInt) +} + +func chopPrecisionAndRoundNonMutative(d *big.Int) *big.Int { + tmp := new(big.Int).Set(d) + return chopPrecisionAndRound(tmp) +} + +// RoundInt64 rounds the decimal using bankers rounding +func (d LegacyDec) RoundInt64() int64 { + chopped := chopPrecisionAndRoundNonMutative(d.i) + if !chopped.IsInt64() { + panic("Int64() out of bound") + } + return chopped.Int64() +} + +// RoundInt round the decimal using bankers rounding +func (d LegacyDec) RoundInt() Int { + return NewIntFromBigInt(chopPrecisionAndRoundNonMutative(d.i)) +} + +// chopPrecisionAndTruncate is similar to chopPrecisionAndRound, +// but always rounds down. It does not mutate the input. +func chopPrecisionAndTruncate(d *big.Int) { + d.Quo(d, precisionReuse) +} + +func chopPrecisionAndTruncateNonMutative(d *big.Int) *big.Int { + tmp := new(big.Int).Set(d) + chopPrecisionAndTruncate(tmp) + return tmp +} + +// TruncateInt64 truncates the decimals from the number and returns an int64 +func (d LegacyDec) TruncateInt64() int64 { + chopped := chopPrecisionAndTruncateNonMutative(d.i) + if !chopped.IsInt64() { + panic("Int64() out of bound") + } + return chopped.Int64() +} + +// TruncateInt truncates the decimals from the number and returns an Int +func (d LegacyDec) TruncateInt() Int { + return NewIntFromBigInt(chopPrecisionAndTruncateNonMutative(d.i)) +} + +// TruncateDec truncates the decimals from the number and returns a Dec +func (d LegacyDec) TruncateDec() LegacyDec { + return LegacyNewDecFromBigInt(chopPrecisionAndTruncateNonMutative(d.i)) +} + +// Ceil returns the smallest interger value (as a decimal) that is greater than +// or equal to the given decimal. +func (d LegacyDec) Ceil() LegacyDec { + tmp := new(big.Int).Set(d.i) + + quo, rem := tmp, big.NewInt(0) + quo, rem = quo.QuoRem(tmp, precisionReuse, rem) + + // no need to round with a zero remainder regardless of sign + if rem.Cmp(zeroInt) == 0 { + return LegacyNewDecFromBigInt(quo) + } + + if rem.Sign() == -1 { + return LegacyNewDecFromBigInt(quo) + } + + return LegacyNewDecFromBigInt(quo.Add(quo, oneInt)) +} + +// LegacyMaxSortableDec is the largest Dec that can be passed into SortableDecBytes() +// Its negative form is the least Dec that can be passed in. +var LegacyMaxSortableDec LegacyDec + +func init() { + LegacyMaxSortableDec = LegacyOneDec().Quo(LegacySmallestDec()) +} + +// ValidSortableDec ensures that a Dec is within the sortable bounds, +// a Dec can't have a precision of less than 10^-18. +// Max sortable decimal was set to the reciprocal of SmallestDec. +func LegacyValidSortableDec(dec LegacyDec) bool { + return dec.Abs().LTE(LegacyMaxSortableDec) +} + +// SortableDecBytes returns a byte slice representation of a Dec that can be sorted. +// Left and right pads with 0s so there are 18 digits to left and right of the decimal point. +// For this reason, there is a maximum and minimum value for this, enforced by ValidSortableDec. +func LegacySortableDecBytes(dec LegacyDec) []byte { + if !LegacyValidSortableDec(dec) { + panic("dec must be within bounds") + } + // Instead of adding an extra byte to all sortable decs in order to handle max sortable, we just + // makes its bytes be "max" which comes after all numbers in ASCIIbetical order + if dec.Equal(LegacyMaxSortableDec) { + return []byte("max") + } + // For the same reason, we make the bytes of minimum sortable dec be --, which comes before all numbers. + if dec.Equal(LegacyMaxSortableDec.Neg()) { + return []byte("--") + } + // We move the negative sign to the front of all the left padded 0s, to make negative numbers come before positive numbers + if dec.IsNegative() { + return append([]byte("-"), []byte(fmt.Sprintf(fmt.Sprintf("%%0%ds", LegacyPrecision*2+1), dec.Abs().String()))...) + } + return []byte(fmt.Sprintf(fmt.Sprintf("%%0%ds", LegacyPrecision*2+1), dec.String())) +} + +// reuse nil values +var nilJSON []byte + +func init() { + empty := new(big.Int) + bz, _ := empty.MarshalText() + nilJSON, _ = json.Marshal(string(bz)) +} + +// MarshalJSON marshals the decimal +func (d LegacyDec) MarshalJSON() ([]byte, error) { + if d.i == nil { + return nilJSON, nil + } + return json.Marshal(d.String()) +} + +// UnmarshalJSON defines custom decoding scheme +func (d *LegacyDec) UnmarshalJSON(bz []byte) error { + if d.i == nil { + d.i = new(big.Int) + } + + var text string + err := json.Unmarshal(bz, &text) + if err != nil { + return err + } + + // TODO: Reuse dec allocation + newDec, err := LegacyNewDecFromStr(text) + if err != nil { + return err + } + + d.i = newDec.i + return nil +} + +// MarshalYAML returns the YAML representation. +func (d LegacyDec) MarshalYAML() (interface{}, error) { + return d.String(), nil +} + +// Marshal implements the gogo proto custom type interface. +func (d LegacyDec) Marshal() ([]byte, error) { + i := d.i + if i == nil { + i = new(big.Int) + } + return i.MarshalText() +} + +// MarshalTo implements the gogo proto custom type interface. +func (d *LegacyDec) MarshalTo(data []byte) (n int, err error) { + i := d.i + if i == nil { + i = new(big.Int) + } + + if i.Cmp(zeroInt) == 0 { + copy(data, []byte{0x30}) + return 1, nil + } + + bz, err := d.Marshal() + if err != nil { + return 0, err + } + + copy(data, bz) + return len(bz), nil +} + +// Unmarshal implements the gogo proto custom type interface. +func (d *LegacyDec) Unmarshal(data []byte) error { + if len(data) == 0 { + d = nil + return nil + } + + if d.i == nil { + d.i = new(big.Int) + } + + if err := d.i.UnmarshalText(data); err != nil { + return err + } + + if d.i.BitLen() > maxDecBitLen { + return fmt.Errorf("decimal out of range; got: %d, max: %d", d.i.BitLen(), maxDecBitLen) + } + + return nil +} + +// Size implements the gogo proto custom type interface. +func (d *LegacyDec) Size() int { + bz, _ := d.Marshal() + return len(bz) +} + +// Override Amino binary serialization by proxying to protobuf. +func (d LegacyDec) MarshalAmino() ([]byte, error) { return d.Marshal() } +func (d *LegacyDec) UnmarshalAmino(bz []byte) error { return d.Unmarshal(bz) } + +// helpers + +// test if two decimal arrays are equal +func LegacyDecsEqual(d1s, d2s []LegacyDec) bool { + if len(d1s) != len(d2s) { + return false + } + + for i, d1 := range d1s { + if !d1.Equal(d2s[i]) { + return false + } + } + return true +} + +// minimum decimal between two +func LegacyMinDec(d1, d2 LegacyDec) LegacyDec { + if d1.LT(d2) { + return d1 + } + return d2 +} + +// maximum decimal between two +func LegacyMaxDec(d1, d2 LegacyDec) LegacyDec { + if d1.LT(d2) { + return d2 + } + return d1 +} + +// intended to be used with require/assert: require.True(DecEq(...)) +func LegacyDecEq(t *testing.T, exp, got LegacyDec) (*testing.T, bool, string, string, string) { + t.Helper() + return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp.String(), got.String() +} + +func LegacyDecApproxEq(t *testing.T, d1, d2, tol LegacyDec) (*testing.T, bool, string, string, string) { + t.Helper() + diff := d1.Sub(d2).Abs() + return t, diff.LTE(tol), "expected |d1 - d2| <:\t%v\ngot |d1 - d2| = \t\t%v", tol.String(), diff.String() +} + +// FormatDec formats a decimal (as encoded in protobuf) into a value-rendered +// string following ADR-050. This function operates with string manipulation +// (instead of manipulating the sdk.Dec object). +func FormatDec(v string) (string, error) { + parts := strings.Split(v, ".") + if len(parts) > 2 { + return "", fmt.Errorf("invalid decimal: too many points in %s", v) + } + + intPart, err := FormatInt(parts[0]) + if err != nil { + return "", err + } + + if len(parts) == 1 { + return intPart, nil + } + + decPart := strings.TrimRight(parts[1], "0") + if len(decPart) == 0 { + return intPart, nil + } + + // Ensure that the decimal part has only digits. + // https://github.com/cosmos/cosmos-sdk/issues/12811 + if !hasOnlyDigits(decPart) { + return "", fmt.Errorf("non-digits detected after decimal point in: %q", decPart) + } + + return intPart + "." + decPart, nil +} diff --git a/dec_internal_test.go b/dec_internal_test.go new file mode 100644 index 0000000..8b89930 --- /dev/null +++ b/dec_internal_test.go @@ -0,0 +1,109 @@ +package math + +import ( + "encoding/json" + "math/big" + "testing" + + "github.com/stretchr/testify/suite" +) + +type decimalInternalTestSuite struct { + suite.Suite +} + +func TestDecimalInternalTestSuite(t *testing.T) { + suite.Run(t, new(decimalInternalTestSuite)) +} + +func (s *decimalInternalTestSuite) TestPrecisionMultiplier() { + tests := []struct { + prec int64 + exp *big.Int + }{ + { + 5, + big.NewInt(10000000000000), + }, + { + 8, + big.NewInt(10000000000), + }, + { + 11, + big.NewInt(10000000), + }, + { + 15, + big.NewInt(1000), + }, + { + 18, + big.NewInt(1), + }, + } + for _, tt := range tests { + res := precisionMultiplier(tt.prec) + s.Require().Equal(0, res.Cmp(tt.exp), "equality was incorrect, res %v, exp %v", res, tt.exp) + } +} + +func (s *decimalInternalTestSuite) TestZeroDeserializationJSON() { + d := LegacyDec{new(big.Int)} + err := json.Unmarshal([]byte(`"0"`), &d) + s.Require().Nil(err) + err = json.Unmarshal([]byte(`"{}"`), &d) + s.Require().NotNil(err) +} + +func (s *decimalInternalTestSuite) TestSerializationGocodecJSON() { + d := LegacyMustNewDecFromStr("0.333") + + bz, err := json.Marshal(d) + s.Require().NoError(err) + + d2 := LegacyDec{new(big.Int)} + err = json.Unmarshal(bz, &d2) + s.Require().NoError(err) + s.Require().True(d.Equal(d2), "original: %v, unmarshalled: %v", d, d2) +} + +func (s *decimalInternalTestSuite) TestDecMarshalJSON() { + decimal := func(i int64) LegacyDec { + d := LegacyNewDec(0) + d.i = new(big.Int).SetInt64(i) + return d + } + tests := []struct { + name string + d LegacyDec + want string + wantErr bool // if wantErr = false, will also attempt unmarshaling + }{ + {"zero", decimal(0), "\"0.000000000000000000\"", false}, + {"one", decimal(1), "\"0.000000000000000001\"", false}, + {"ten", decimal(10), "\"0.000000000000000010\"", false}, + {"12340", decimal(12340), "\"0.000000000000012340\"", false}, + {"zeroInt", LegacyNewDec(0), "\"0.000000000000000000\"", false}, + {"oneInt", LegacyNewDec(1), "\"1.000000000000000000\"", false}, + {"tenInt", LegacyNewDec(10), "\"10.000000000000000000\"", false}, + {"12340Int", LegacyNewDec(12340), "\"12340.000000000000000000\"", false}, + } + for _, tt := range tests { + tt := tt + s.T().Run(tt.name, func(t *testing.T) { + got, err := tt.d.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("Dec.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + s.Require().Equal(tt.want, string(got), "incorrect marshaled value") + unmarshalledDec := LegacyNewDec(0) + err := unmarshalledDec.UnmarshalJSON(got) + s.Require().NoError(err) + s.Require().Equal(tt.d, unmarshalledDec, "incorrect unmarshalled value") + } + }) + } +} diff --git a/dec_test.go b/dec_test.go new file mode 100644 index 0000000..85ff732 --- /dev/null +++ b/dec_test.go @@ -0,0 +1,756 @@ +package math_test + +import ( + "bytes" + "encoding/json" + "fmt" + "math/big" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "sigs.k8s.io/yaml" + + "cosmossdk.io/math" +) + +type decimalTestSuite struct { + suite.Suite +} + +func TestDecimalTestSuite(t *testing.T) { + suite.Run(t, new(decimalTestSuite)) +} + +func TestDecApproxEq(t *testing.T) { + // d1 = 0.55, d2 = 0.6, tol = 0.1 + d1 := math.LegacyNewDecWithPrec(55, 2) + d2 := math.LegacyNewDecWithPrec(6, 1) + tol := math.LegacyNewDecWithPrec(1, 1) + + require.True(math.LegacyDecApproxEq(t, d1, d2, tol)) + + // d1 = 0.55, d2 = 0.6, tol = 1E-5 + d1 = math.LegacyNewDecWithPrec(55, 2) + d2 = math.LegacyNewDecWithPrec(6, 1) + tol = math.LegacyNewDecWithPrec(1, 5) + + require.False(math.LegacyDecApproxEq(t, d1, d2, tol)) + + // d1 = 0.6, d2 = 0.61, tol = 0.01 + d1 = math.LegacyNewDecWithPrec(6, 1) + d2 = math.LegacyNewDecWithPrec(61, 2) + tol = math.LegacyNewDecWithPrec(1, 2) + + require.True(math.LegacyDecApproxEq(t, d1, d2, tol)) +} + +// create a decimal from a decimal string (ex. "1234.5678") +func (s *decimalTestSuite) mustNewDecFromStr(str string) (d math.LegacyDec) { + d, err := math.LegacyNewDecFromStr(str) + s.Require().NoError(err) + + return d +} + +func (s *decimalTestSuite) TestNewDecFromStr() { + largeBigInt, ok := new(big.Int).SetString("3144605511029693144278234343371835", 10) + s.Require().True(ok) + + largerBigInt, ok := new(big.Int).SetString("8888888888888888888888888888888888888888888888888888888888888888888844444440", 10) + s.Require().True(ok) + + largestBigInt, ok := new(big.Int).SetString("33499189745056880149688856635597007162669032647290798121690100488888732861290034376435130433535", 10) + s.Require().True(ok) + + tests := []struct { + decimalStr string + expErr bool + exp math.LegacyDec + }{ + {"", true, math.LegacyDec{}}, + {"0.-75", true, math.LegacyDec{}}, + {"0", false, math.LegacyNewDec(0)}, + {"1", false, math.LegacyNewDec(1)}, + {"1.1", false, math.LegacyNewDecWithPrec(11, 1)}, + {"0.75", false, math.LegacyNewDecWithPrec(75, 2)}, + {"0.8", false, math.LegacyNewDecWithPrec(8, 1)}, + {"0.11111", false, math.LegacyNewDecWithPrec(11111, 5)}, + {"314460551102969.3144278234343371835", true, math.LegacyNewDec(3141203149163817869)}, + { + "314460551102969314427823434337.1835718092488231350", + true, math.LegacyNewDecFromBigIntWithPrec(largeBigInt, 4), + }, + { + "314460551102969314427823434337.1835", + false, math.LegacyNewDecFromBigIntWithPrec(largeBigInt, 4), + }, + {".", true, math.LegacyDec{}}, + {".0", true, math.LegacyNewDec(0)}, + {"1.", true, math.LegacyNewDec(1)}, + {"foobar", true, math.LegacyDec{}}, + {"0.foobar", true, math.LegacyDec{}}, + {"0.foobar.", true, math.LegacyDec{}}, + {"8888888888888888888888888888888888888888888888888888888888888888888844444440", false, math.LegacyNewDecFromBigInt(largerBigInt)}, + {"33499189745056880149688856635597007162669032647290798121690100488888732861290.034376435130433535", false, math.LegacyNewDecFromBigIntWithPrec(largestBigInt, 18)}, + {"133499189745056880149688856635597007162669032647290798121690100488888732861291", true, math.LegacyDec{}}, + } + + for tcIndex, tc := range tests { + res, err := math.LegacyNewDecFromStr(tc.decimalStr) + if tc.expErr { + s.Require().NotNil(err, "error expected, decimalStr %v, tc %v", tc.decimalStr, tcIndex) + } else { + s.Require().Nil(err, "unexpected error, decimalStr %v, tc %v", tc.decimalStr, tcIndex) + s.Require().True(res.Equal(tc.exp), "equality was incorrect, res %v, exp %v, tc %v", res, tc.exp, tcIndex) + } + + // negative tc + res, err = math.LegacyNewDecFromStr("-" + tc.decimalStr) + if tc.expErr { + s.Require().NotNil(err, "error expected, decimalStr %v, tc %v", tc.decimalStr, tcIndex) + } else { + s.Require().Nil(err, "unexpected error, decimalStr %v, tc %v", tc.decimalStr, tcIndex) + exp := tc.exp.Mul(math.LegacyNewDec(-1)) + s.Require().True(res.Equal(exp), "equality was incorrect, res %v, exp %v, tc %v", res, exp, tcIndex) + } + } +} + +func (s *decimalTestSuite) TestDecString() { + tests := []struct { + d math.LegacyDec + want string + }{ + {math.LegacyNewDec(0), "0.000000000000000000"}, + {math.LegacyNewDec(1), "1.000000000000000000"}, + {math.LegacyNewDec(10), "10.000000000000000000"}, + {math.LegacyNewDec(12340), "12340.000000000000000000"}, + {math.LegacyNewDecWithPrec(12340, 4), "1.234000000000000000"}, + {math.LegacyNewDecWithPrec(12340, 5), "0.123400000000000000"}, + {math.LegacyNewDecWithPrec(12340, 8), "0.000123400000000000"}, + {math.LegacyNewDecWithPrec(1009009009009009009, 17), "10.090090090090090090"}, + } + for tcIndex, tc := range tests { + s.Require().Equal(tc.want, tc.d.String(), "bad String(), index: %v", tcIndex) + } +} + +func (s *decimalTestSuite) TestDecFloat64() { + tests := []struct { + d math.LegacyDec + want float64 + }{ + {math.LegacyNewDec(0), 0.000000000000000000}, + {math.LegacyNewDec(1), 1.000000000000000000}, + {math.LegacyNewDec(10), 10.000000000000000000}, + {math.LegacyNewDec(12340), 12340.000000000000000000}, + {math.LegacyNewDecWithPrec(12340, 4), 1.234000000000000000}, + {math.LegacyNewDecWithPrec(12340, 5), 0.123400000000000000}, + {math.LegacyNewDecWithPrec(12340, 8), 0.000123400000000000}, + {math.LegacyNewDecWithPrec(1009009009009009009, 17), 10.090090090090090090}, + } + for tcIndex, tc := range tests { + value, err := tc.d.Float64() + s.Require().Nil(err, "error getting Float64(), index: %v", tcIndex) + s.Require().Equal(tc.want, value, "bad Float64(), index: %v", tcIndex) + s.Require().Equal(tc.want, tc.d.MustFloat64(), "bad MustFloat64(), index: %v", tcIndex) + } +} + +func (s *decimalTestSuite) TestEqualities() { + tests := []struct { + d1, d2 math.LegacyDec + gt, lt, eq bool + }{ + {math.LegacyNewDec(0), math.LegacyNewDec(0), false, false, true}, + {math.LegacyNewDecWithPrec(0, 2), math.LegacyNewDecWithPrec(0, 4), false, false, true}, + {math.LegacyNewDecWithPrec(100, 0), math.LegacyNewDecWithPrec(100, 0), false, false, true}, + {math.LegacyNewDecWithPrec(-100, 0), math.LegacyNewDecWithPrec(-100, 0), false, false, true}, + {math.LegacyNewDecWithPrec(-1, 1), math.LegacyNewDecWithPrec(-1, 1), false, false, true}, + {math.LegacyNewDecWithPrec(3333, 3), math.LegacyNewDecWithPrec(3333, 3), false, false, true}, + + {math.LegacyNewDecWithPrec(0, 0), math.LegacyNewDecWithPrec(3333, 3), false, true, false}, + {math.LegacyNewDecWithPrec(0, 0), math.LegacyNewDecWithPrec(100, 0), false, true, false}, + {math.LegacyNewDecWithPrec(-1, 0), math.LegacyNewDecWithPrec(3333, 3), false, true, false}, + {math.LegacyNewDecWithPrec(-1, 0), math.LegacyNewDecWithPrec(100, 0), false, true, false}, + {math.LegacyNewDecWithPrec(1111, 3), math.LegacyNewDecWithPrec(100, 0), false, true, false}, + {math.LegacyNewDecWithPrec(1111, 3), math.LegacyNewDecWithPrec(3333, 3), false, true, false}, + {math.LegacyNewDecWithPrec(-3333, 3), math.LegacyNewDecWithPrec(-1111, 3), false, true, false}, + + {math.LegacyNewDecWithPrec(3333, 3), math.LegacyNewDecWithPrec(0, 0), true, false, false}, + {math.LegacyNewDecWithPrec(100, 0), math.LegacyNewDecWithPrec(0, 0), true, false, false}, + {math.LegacyNewDecWithPrec(3333, 3), math.LegacyNewDecWithPrec(-1, 0), true, false, false}, + {math.LegacyNewDecWithPrec(100, 0), math.LegacyNewDecWithPrec(-1, 0), true, false, false}, + {math.LegacyNewDecWithPrec(100, 0), math.LegacyNewDecWithPrec(1111, 3), true, false, false}, + {math.LegacyNewDecWithPrec(3333, 3), math.LegacyNewDecWithPrec(1111, 3), true, false, false}, + {math.LegacyNewDecWithPrec(-1111, 3), math.LegacyNewDecWithPrec(-3333, 3), true, false, false}, + } + + for tcIndex, tc := range tests { + s.Require().Equal(tc.gt, tc.d1.GT(tc.d2), "GT result is incorrect, tc %d", tcIndex) + s.Require().Equal(tc.lt, tc.d1.LT(tc.d2), "LT result is incorrect, tc %d", tcIndex) + s.Require().Equal(tc.eq, tc.d1.Equal(tc.d2), "equality result is incorrect, tc %d", tcIndex) + } +} + +func (s *decimalTestSuite) TestDecsEqual() { + tests := []struct { + d1s, d2s []math.LegacyDec + eq bool + }{ + {[]math.LegacyDec{math.LegacyNewDec(0)}, []math.LegacyDec{math.LegacyNewDec(0)}, true}, + {[]math.LegacyDec{math.LegacyNewDec(0)}, []math.LegacyDec{math.LegacyNewDec(1)}, false}, + {[]math.LegacyDec{math.LegacyNewDec(0)}, []math.LegacyDec{}, false}, + {[]math.LegacyDec{math.LegacyNewDec(0), math.LegacyNewDec(1)}, []math.LegacyDec{math.LegacyNewDec(0), math.LegacyNewDec(1)}, true}, + {[]math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(0)}, []math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(0)}, true}, + {[]math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(0)}, []math.LegacyDec{math.LegacyNewDec(0), math.LegacyNewDec(1)}, false}, + {[]math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(0)}, []math.LegacyDec{math.LegacyNewDec(1)}, false}, + {[]math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(2)}, []math.LegacyDec{math.LegacyNewDec(2), math.LegacyNewDec(4)}, false}, + {[]math.LegacyDec{math.LegacyNewDec(3), math.LegacyNewDec(18)}, []math.LegacyDec{math.LegacyNewDec(1), math.LegacyNewDec(6)}, false}, + } + + for tcIndex, tc := range tests { + s.Require().Equal(tc.eq, math.LegacyDecsEqual(tc.d1s, tc.d2s), "equality of decional arrays is incorrect, tc %d", tcIndex) + s.Require().Equal(tc.eq, math.LegacyDecsEqual(tc.d2s, tc.d1s), "equality of decional arrays is incorrect (converse), tc %d", tcIndex) + } +} + +func (s *decimalTestSuite) TestArithmetic() { + tests := []struct { + d1, d2 math.LegacyDec + expMul, expMulTruncate, expMulRoundUp math.LegacyDec + expQuo, expQuoRoundUp, expQuoTruncate math.LegacyDec + expAdd, expSub math.LegacyDec + }{ + // d1 d2 MUL MulTruncate MulRoundUp QUO QUORoundUp QUOTrunctate ADD SUB + {math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0)}, + {math.LegacyNewDec(1), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(1), math.LegacyNewDec(1)}, + {math.LegacyNewDec(0), math.LegacyNewDec(1), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(1), math.LegacyNewDec(-1)}, + {math.LegacyNewDec(0), math.LegacyNewDec(-1), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(-1), math.LegacyNewDec(1)}, + {math.LegacyNewDec(-1), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(0), math.LegacyNewDec(-1), math.LegacyNewDec(-1)}, + + {math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(2), math.LegacyNewDec(0)}, + {math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(-2), math.LegacyNewDec(0)}, + {math.LegacyNewDec(1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(0), math.LegacyNewDec(2)}, + {math.LegacyNewDec(-1), math.LegacyNewDec(1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(-1), math.LegacyNewDec(0), math.LegacyNewDec(-2)}, + + { + math.LegacyNewDec(3), math.LegacyNewDec(7), math.LegacyNewDec(21), math.LegacyNewDec(21), math.LegacyNewDec(21), + math.LegacyNewDecWithPrec(428571428571428571, 18), math.LegacyNewDecWithPrec(428571428571428572, 18), math.LegacyNewDecWithPrec(428571428571428571, 18), + math.LegacyNewDec(10), math.LegacyNewDec(-4), + }, + { + math.LegacyNewDec(2), math.LegacyNewDec(4), math.LegacyNewDec(8), math.LegacyNewDec(8), math.LegacyNewDec(8), math.LegacyNewDecWithPrec(5, 1), math.LegacyNewDecWithPrec(5, 1), math.LegacyNewDecWithPrec(5, 1), + math.LegacyNewDec(6), math.LegacyNewDec(-2), + }, + + {math.LegacyNewDec(100), math.LegacyNewDec(100), math.LegacyNewDec(10000), math.LegacyNewDec(10000), math.LegacyNewDec(10000), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(200), math.LegacyNewDec(0)}, + + { + math.LegacyNewDecWithPrec(15, 1), math.LegacyNewDecWithPrec(15, 1), math.LegacyNewDecWithPrec(225, 2), math.LegacyNewDecWithPrec(225, 2), math.LegacyNewDecWithPrec(225, 2), + math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(1), math.LegacyNewDec(3), math.LegacyNewDec(0), + }, + { + math.LegacyNewDecWithPrec(3333, 4), math.LegacyNewDecWithPrec(333, 4), math.LegacyNewDecWithPrec(1109889, 8), math.LegacyNewDecWithPrec(1109889, 8), math.LegacyNewDecWithPrec(1109889, 8), + math.LegacyMustNewDecFromStr("10.009009009009009009"), math.LegacyMustNewDecFromStr("10.009009009009009010"), math.LegacyMustNewDecFromStr("10.009009009009009009"), + math.LegacyNewDecWithPrec(3666, 4), math.LegacyNewDecWithPrec(3, 1), + }, + } + + for tcIndex, tc := range tests { + tc := tc + resAdd := tc.d1.Add(tc.d2) + resSub := tc.d1.Sub(tc.d2) + resMul := tc.d1.Mul(tc.d2) + resMulTruncate := tc.d1.MulTruncate(tc.d2) + resMulRoundUp := tc.d1.MulRoundUp(tc.d2) + s.Require().True(tc.expAdd.Equal(resAdd), "exp %v, res %v, tc %d", tc.expAdd, resAdd, tcIndex) + s.Require().True(tc.expSub.Equal(resSub), "exp %v, res %v, tc %d", tc.expSub, resSub, tcIndex) + s.Require().True(tc.expMul.Equal(resMul), "exp %v, res %v, tc %d", tc.expMul, resMul, tcIndex) + s.Require().True(tc.expMulTruncate.Equal(resMulTruncate), "exp %v, res %v, tc %d", tc.expMulTruncate, resMulTruncate, tcIndex) + s.Require().True(tc.expMulRoundUp.Equal(resMulRoundUp), "exp %v, res %v, tc %d", tc.expMulRoundUp, resMulRoundUp, tcIndex) + + if tc.d2.IsZero() { // panic for divide by zero + s.Require().Panics(func() { tc.d1.Quo(tc.d2) }) + } else { + resQuo := tc.d1.Quo(tc.d2) + s.Require().True(tc.expQuo.Equal(resQuo), "exp %v, res %v, tc %d", tc.expQuo.String(), resQuo.String(), tcIndex) + + resQuoRoundUp := tc.d1.QuoRoundUp(tc.d2) + s.Require().True(tc.expQuoRoundUp.Equal(resQuoRoundUp), "exp %v, res %v, tc %d", + tc.expQuoRoundUp.String(), resQuoRoundUp.String(), tcIndex) + + resQuoTruncate := tc.d1.QuoTruncate(tc.d2) + s.Require().True(tc.expQuoTruncate.Equal(resQuoTruncate), "exp %v, res %v, tc %d", + tc.expQuoTruncate.String(), resQuoTruncate.String(), tcIndex) + } + } +} + +func (s *decimalTestSuite) TestMulRoundUp_RoundingAtPrecisionEnd() { + var ( + a = math.LegacyMustNewDecFromStr("0.000000000000000009") + b = math.LegacyMustNewDecFromStr("0.000000000000000009") + expectedRoundUp = math.LegacyMustNewDecFromStr("0.000000000000000001") + expectedTruncate = math.LegacyMustNewDecFromStr("0.000000000000000000") + ) + + actualRoundUp := a.MulRoundUp(b) + s.Require().Equal(expectedRoundUp.String(), actualRoundUp.String(), "exp %v, res %v", expectedRoundUp, actualRoundUp) + + actualTruncate := a.MulTruncate(b) + s.Require().Equal(expectedTruncate.String(), actualTruncate.String(), "exp %v, res %v", expectedRoundUp, actualTruncate) +} + +func (s *decimalTestSuite) TestBankerRoundChop() { + tests := []struct { + d1 math.LegacyDec + exp int64 + }{ + {s.mustNewDecFromStr("0.25"), 0}, + {s.mustNewDecFromStr("0"), 0}, + {s.mustNewDecFromStr("1"), 1}, + {s.mustNewDecFromStr("0.75"), 1}, + {s.mustNewDecFromStr("0.5"), 0}, + {s.mustNewDecFromStr("7.5"), 8}, + {s.mustNewDecFromStr("1.5"), 2}, + {s.mustNewDecFromStr("2.5"), 2}, + {s.mustNewDecFromStr("0.545"), 1}, // 0.545-> 1 even though 5 is first decimal and 1 not even + {s.mustNewDecFromStr("1.545"), 2}, + } + + for tcIndex, tc := range tests { + resNeg := tc.d1.Neg().RoundInt64() + s.Require().Equal(-1*tc.exp, resNeg, "negative tc %d", tcIndex) + + resPos := tc.d1.RoundInt64() + s.Require().Equal(tc.exp, resPos, "positive tc %d", tcIndex) + } +} + +func (s *decimalTestSuite) TestTruncate() { + tests := []struct { + d1 math.LegacyDec + exp int64 + }{ + {s.mustNewDecFromStr("0"), 0}, + {s.mustNewDecFromStr("0.25"), 0}, + {s.mustNewDecFromStr("0.75"), 0}, + {s.mustNewDecFromStr("1"), 1}, + {s.mustNewDecFromStr("1.5"), 1}, + {s.mustNewDecFromStr("7.5"), 7}, + {s.mustNewDecFromStr("7.6"), 7}, + {s.mustNewDecFromStr("7.4"), 7}, + {s.mustNewDecFromStr("100.1"), 100}, + {s.mustNewDecFromStr("1000.1"), 1000}, + } + + for tcIndex, tc := range tests { + resNeg := tc.d1.Neg().TruncateInt64() + s.Require().Equal(-1*tc.exp, resNeg, "negative tc %d", tcIndex) + + resPos := tc.d1.TruncateInt64() + s.Require().Equal(tc.exp, resPos, "positive tc %d", tcIndex) + } +} + +func (s *decimalTestSuite) TestStringOverflow() { + // two random 64 bit primes + dec1, err := math.LegacyNewDecFromStr("51643150036226787134389711697696177267") + s.Require().NoError(err) + dec2, err := math.LegacyNewDecFromStr("-31798496660535729618459429845579852627") + s.Require().NoError(err) + dec3 := dec1.Add(dec2) + s.Require().Equal( + "19844653375691057515930281852116324640.000000000000000000", + dec3.String(), + ) +} + +func (s *decimalTestSuite) TestDecMulInt() { + tests := []struct { + sdkDec math.LegacyDec + sdkInt math.Int + want math.LegacyDec + }{ + {math.LegacyNewDec(10), math.NewInt(2), math.LegacyNewDec(20)}, + {math.LegacyNewDec(1000000), math.NewInt(100), math.LegacyNewDec(100000000)}, + {math.LegacyNewDecWithPrec(1, 1), math.NewInt(10), math.LegacyNewDec(1)}, + {math.LegacyNewDecWithPrec(1, 5), math.NewInt(20), math.LegacyNewDecWithPrec(2, 4)}, + } + for i, tc := range tests { + got := tc.sdkDec.MulInt(tc.sdkInt) + s.Require().Equal(tc.want, got, "Incorrect result on test case %d", i) + } +} + +func (s *decimalTestSuite) TestDecCeil() { + testCases := []struct { + input math.LegacyDec + expected math.LegacyDec + }{ + {math.LegacyNewDecWithPrec(1000000000000000, math.LegacyPrecision), math.LegacyNewDec(1)}, // 0.001 => 1.0 + {math.LegacyNewDecWithPrec(-1000000000000000, math.LegacyPrecision), math.LegacyZeroDec()}, // -0.001 => 0.0 + {math.LegacyZeroDec(), math.LegacyZeroDec()}, // 0.0 => 0.0 + {math.LegacyNewDecWithPrec(900000000000000000, math.LegacyPrecision), math.LegacyNewDec(1)}, // 0.9 => 1.0 + {math.LegacyNewDecWithPrec(4001000000000000000, math.LegacyPrecision), math.LegacyNewDec(5)}, // 4.001 => 5.0 + {math.LegacyNewDecWithPrec(-4001000000000000000, math.LegacyPrecision), math.LegacyNewDec(-4)}, // -4.001 => -4.0 + {math.LegacyNewDecWithPrec(4700000000000000000, math.LegacyPrecision), math.LegacyNewDec(5)}, // 4.7 => 5.0 + {math.LegacyNewDecWithPrec(-4700000000000000000, math.LegacyPrecision), math.LegacyNewDec(-4)}, // -4.7 => -4.0 + } + + for i, tc := range testCases { + res := tc.input.Ceil() + s.Require().Equal(tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) + } +} + +func (s *decimalTestSuite) TestPower() { + testCases := []struct { + input math.LegacyDec + power uint64 + expected math.LegacyDec + }{ + {math.LegacyNewDec(100), 0, math.LegacyOneDec()}, // 10 ^ (0) => 1.0 + {math.LegacyOneDec(), 10, math.LegacyOneDec()}, // 1.0 ^ (10) => 1.0 + {math.LegacyNewDecWithPrec(5, 1), 2, math.LegacyNewDecWithPrec(25, 2)}, // 0.5 ^ 2 => 0.25 + {math.LegacyNewDecWithPrec(2, 1), 2, math.LegacyNewDecWithPrec(4, 2)}, // 0.2 ^ 2 => 0.04 + {math.LegacyNewDecFromInt(math.NewInt(3)), 3, math.LegacyNewDecFromInt(math.NewInt(27))}, // 3 ^ 3 => 27 + {math.LegacyNewDecFromInt(math.NewInt(-3)), 4, math.LegacyNewDecFromInt(math.NewInt(81))}, // -3 ^ 4 = 81 + {math.LegacyNewDecWithPrec(1414213562373095049, 18), 2, math.LegacyNewDecFromInt(math.NewInt(2))}, // 1.414213562373095049 ^ 2 = 2 + } + + for i, tc := range testCases { + res := tc.input.Power(tc.power) + s.Require().True(tc.expected.Sub(res).Abs().LTE(math.LegacySmallestDec()), "unexpected result for test case %d, normal power, input: %v", i, tc.input) + + mutableInput := tc.input + mutableInput.PowerMut(tc.power) + s.Require().True(tc.expected.Sub(mutableInput).Abs().LTE(math.LegacySmallestDec()), + "unexpected result for test case %d, input %v", i, tc.input) + s.Require().True(res.Equal(tc.input), "unexpected result for test case %d, mutable power, input: %v", i, tc.input) + } +} + +func (s *decimalTestSuite) TestApproxRoot() { + testCases := []struct { + input math.LegacyDec + root uint64 + expected math.LegacyDec + }{ + {math.LegacyOneDec(), 10, math.LegacyOneDec()}, // 1.0 ^ (0.1) => 1.0 + {math.LegacyNewDecWithPrec(25, 2), 2, math.LegacyNewDecWithPrec(5, 1)}, // 0.25 ^ (0.5) => 0.5 + {math.LegacyNewDecWithPrec(4, 2), 2, math.LegacyNewDecWithPrec(2, 1)}, // 0.04 ^ (0.5) => 0.2 + {math.LegacyNewDecFromInt(math.NewInt(27)), 3, math.LegacyNewDecFromInt(math.NewInt(3))}, // 27 ^ (1/3) => 3 + {math.LegacyNewDecFromInt(math.NewInt(-81)), 4, math.LegacyNewDecFromInt(math.NewInt(-3))}, // -81 ^ (0.25) => -3 + {math.LegacyNewDecFromInt(math.NewInt(2)), 2, math.LegacyNewDecWithPrec(1414213562373095049, 18)}, // 2 ^ (0.5) => 1.414213562373095049 + {math.LegacyNewDecWithPrec(1005, 3), 31536000, math.LegacyMustNewDecFromStr("1.000000000158153904")}, // 1.005 ^ (1/31536000) ≈ 1.00000000016 + {math.LegacySmallestDec(), 2, math.LegacyNewDecWithPrec(1, 9)}, // 1e-18 ^ (0.5) => 1e-9 + {math.LegacySmallestDec(), 3, math.LegacyMustNewDecFromStr("0.000000999999999997")}, // 1e-18 ^ (1/3) => 1e-6 + {math.LegacyNewDecWithPrec(1, 8), 3, math.LegacyMustNewDecFromStr("0.002154434690031900")}, // 1e-8 ^ (1/3) ≈ 0.00215443469 + {math.LegacyMustNewDecFromStr("9000002314687921634000000000000000000021394871242000000000000000"), 2, math.LegacyMustNewDecFromStr("94868342004527103646332858502867.899477053226766107")}, + } + + // In the case of 1e-8 ^ (1/3), the result repeats every 5 iterations starting from iteration 24 + // (i.e. 24, 29, 34, ... give the same result) and never converges enough. The maximum number of + // iterations (300) causes the result at iteration 300 to be returned, regardless of convergence. + + for i, tc := range testCases { + res, err := tc.input.ApproxRoot(tc.root) + s.Require().NoError(err) + s.Require().True(tc.expected.Sub(res).Abs().LTE(math.LegacySmallestDec()), "unexpected result for test case %d, input: %v", i, tc.input) + } +} + +func (s *decimalTestSuite) TestApproxSqrt() { + testCases := []struct { + input math.LegacyDec + expected math.LegacyDec + }{ + {math.LegacyOneDec(), math.LegacyOneDec()}, // 1.0 => 1.0 + {math.LegacyNewDecWithPrec(25, 2), math.LegacyNewDecWithPrec(5, 1)}, // 0.25 => 0.5 + {math.LegacyNewDecWithPrec(4, 2), math.LegacyNewDecWithPrec(2, 1)}, // 0.09 => 0.3 + {math.LegacyNewDec(9), math.LegacyNewDecFromInt(math.NewInt(3))}, // 9 => 3 + {math.LegacyNewDec(-9), math.LegacyNewDecFromInt(math.NewInt(-3))}, // -9 => -3 + {math.LegacyNewDec(2), math.LegacyNewDecWithPrec(1414213562373095049, 18)}, // 2 => 1.414213562373095049 + { // 2^127 - 1 => 13043817825332782212.3495718062525083688 which rounds to 13043817825332782212.3495718062525083689 + math.LegacyNewDec(2).Power(127).Sub(math.LegacyOneDec()), + math.LegacyMustNewDecFromStr("13043817825332782212.349571806252508369"), + }, + {math.LegacyMustNewDecFromStr("1.000000011823380862"), math.LegacyMustNewDecFromStr("1.000000005911690414")}, + } + + for i, tc := range testCases { + res, err := tc.input.ApproxSqrt() + s.Require().NoError(err) + s.Require().Equal(tc.expected, res, "unexpected result for test case %d, input: %v", i, tc.input) + } +} + +func (s *decimalTestSuite) TestDecSortableBytes() { + tests := []struct { + d math.LegacyDec + want []byte + }{ + {math.LegacyNewDec(0), []byte("000000000000000000.000000000000000000")}, + {math.LegacyNewDec(1), []byte("000000000000000001.000000000000000000")}, + {math.LegacyNewDec(10), []byte("000000000000000010.000000000000000000")}, + {math.LegacyNewDec(12340), []byte("000000000000012340.000000000000000000")}, + {math.LegacyNewDecWithPrec(12340, 4), []byte("000000000000000001.234000000000000000")}, + {math.LegacyNewDecWithPrec(12340, 5), []byte("000000000000000000.123400000000000000")}, + {math.LegacyNewDecWithPrec(12340, 8), []byte("000000000000000000.000123400000000000")}, + {math.LegacyNewDecWithPrec(1009009009009009009, 17), []byte("000000000000000010.090090090090090090")}, + {math.LegacyNewDecWithPrec(-1009009009009009009, 17), []byte("-000000000000000010.090090090090090090")}, + {math.LegacyNewDec(1000000000000000000), []byte("max")}, + {math.LegacyNewDec(-1000000000000000000), []byte("--")}, + } + for tcIndex, tc := range tests { + s.Require().Equal(tc.want, math.LegacySortableDecBytes(tc.d), "bad String(), index: %v", tcIndex) + } + + s.Require().Panics(func() { math.LegacySortableDecBytes(math.LegacyNewDec(1000000000000000001)) }) + s.Require().Panics(func() { math.LegacySortableDecBytes(math.LegacyNewDec(-1000000000000000001)) }) +} + +func (s *decimalTestSuite) TestDecEncoding() { + largestBigInt, ok := new(big.Int).SetString("33499189745056880149688856635597007162669032647290798121690100488888732861290034376435130433535", 10) + s.Require().True(ok) + + smallestBigInt, ok := new(big.Int).SetString("-33499189745056880149688856635597007162669032647290798121690100488888732861290034376435130433535", 10) + s.Require().True(ok) + + const maxDecBitLen = 315 + maxInt, ok := new(big.Int).SetString(strings.Repeat("1", maxDecBitLen), 2) + s.Require().True(ok) + + testCases := []struct { + input math.LegacyDec + rawBz string + jsonStr string + yamlStr string + }{ + { + math.LegacyNewDec(0), "30", + "\"0.000000000000000000\"", + "\"0.000000000000000000\"\n", + }, + { + math.LegacyNewDecWithPrec(4, 2), + "3430303030303030303030303030303030", + "\"0.040000000000000000\"", + "\"0.040000000000000000\"\n", + }, + { + math.LegacyNewDecWithPrec(-4, 2), + "2D3430303030303030303030303030303030", + "\"-0.040000000000000000\"", + "\"-0.040000000000000000\"\n", + }, + { + math.LegacyNewDecWithPrec(1414213562373095049, 18), + "31343134323133353632333733303935303439", + "\"1.414213562373095049\"", + "\"1.414213562373095049\"\n", + }, + { + math.LegacyNewDecWithPrec(-1414213562373095049, 18), + "2D31343134323133353632333733303935303439", + "\"-1.414213562373095049\"", + "\"-1.414213562373095049\"\n", + }, + { + math.LegacyNewDecFromBigIntWithPrec(largestBigInt, 18), + "3333343939313839373435303536383830313439363838383536363335353937303037313632363639303332363437323930373938313231363930313030343838383838373332383631323930303334333736343335313330343333353335", + "\"33499189745056880149688856635597007162669032647290798121690100488888732861290.034376435130433535\"", + "\"33499189745056880149688856635597007162669032647290798121690100488888732861290.034376435130433535\"\n", + }, + { + math.LegacyNewDecFromBigIntWithPrec(smallestBigInt, 18), + "2D3333343939313839373435303536383830313439363838383536363335353937303037313632363639303332363437323930373938313231363930313030343838383838373332383631323930303334333736343335313330343333353335", + "\"-33499189745056880149688856635597007162669032647290798121690100488888732861290.034376435130433535\"", + "\"-33499189745056880149688856635597007162669032647290798121690100488888732861290.034376435130433535\"\n", + }, + { + math.LegacyNewDecFromBigIntWithPrec(maxInt, 18), + "3636373439353934383732353238343430303734383434343238333137373938353033353831333334353136333233363435333939303630383435303530323434343434333636343330363435303137313838323137353635323136373637", + "\"66749594872528440074844428317798503581334516323645399060845050244444366430645.017188217565216767\"", + "\"66749594872528440074844428317798503581334516323645399060845050244444366430645.017188217565216767\"\n", + }, + } + + for _, tc := range testCases { + bz, err := tc.input.Marshal() + s.Require().NoError(err) + s.Require().Equal(tc.rawBz, fmt.Sprintf("%X", bz)) + + var other math.LegacyDec + s.Require().NoError((&other).Unmarshal(bz)) + s.Require().True(tc.input.Equal(other)) + + bz, err = json.Marshal(tc.input) + s.Require().NoError(err) + s.Require().Equal(tc.jsonStr, string(bz)) + s.Require().NoError(json.Unmarshal(bz, &other)) + s.Require().True(tc.input.Equal(other)) + + bz, err = yaml.Marshal(tc.input) + s.Require().NoError(err) + s.Require().Equal(tc.yamlStr, string(bz)) + } +} + +// Showcase that different orders of operations causes different results. +func (s *decimalTestSuite) TestOperationOrders() { + n1 := math.LegacyNewDec(10) + n2 := math.LegacyNewDec(1000000010) + s.Require().Equal(n1.Mul(n2).Quo(n2), math.LegacyNewDec(10)) + s.Require().NotEqual(n1.Mul(n2).Quo(n2), n1.Quo(n2).Mul(n2)) +} + +func BenchmarkMarshalTo(b *testing.B) { + b.ReportAllocs() + bis := []struct { + in math.LegacyDec + want []byte + }{ + { + math.LegacyNewDec(1e8), []byte{ + 0x31, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + }, + }, + {math.LegacyNewDec(0), []byte{0x30}}, + } + data := make([]byte, 100) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + for _, bi := range bis { + if n, err := bi.in.MarshalTo(data); err != nil { + b.Fatal(err) + } else if !bytes.Equal(data[:n], bi.want) { + b.Fatalf("Mismatch\nGot: % x\nWant: % x\n", data[:n], bi.want) + } + } + } +} + +var sink interface{} + +func BenchmarkLegacyQuoMut(b *testing.B) { + b1 := math.LegacyNewDec(17e2 + 8371) + b2 := math.LegacyNewDec(4371) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink = b1.QuoMut(b2) + } + + if sink == nil { + b.Fatal("Benchmark did not run") + } + sink = (interface{})(nil) +} + +func BenchmarkLegacyQuoTruncateMut(b *testing.B) { + b1 := math.LegacyNewDec(17e2 + 8371) + b2 := math.LegacyNewDec(4371) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink = b1.QuoTruncateMut(b2) + } + + if sink == nil { + b.Fatal("Benchmark did not run") + } + sink = (interface{})(nil) +} + +func BenchmarkLegacySqrtOnMersennePrime(b *testing.B) { + b1 := math.LegacyNewDec(2).Power(127).Sub(math.LegacyOneDec()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink, _ = b1.ApproxSqrt() + } + + if sink == nil { + b.Fatal("Benchmark did not run") + } + sink = (interface{})(nil) +} + +func BenchmarkLegacyQuoRoundupMut(b *testing.B) { + b1 := math.LegacyNewDec(17e2 + 8371) + b2 := math.LegacyNewDec(4371) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink = b1.QuoRoundupMut(b2) + } + + if sink == nil { + b.Fatal("Benchmark did not run") + } + sink = (interface{})(nil) +} + +func TestFormatDec(t *testing.T) { + type decimalTest []string + var testcases []decimalTest + raw, err := os.ReadFile("./testdata/decimals.json") + require.NoError(t, err) + err = json.Unmarshal(raw, &testcases) + require.NoError(t, err) + + for _, tc := range testcases { + tc := tc + t.Run(tc[0], func(t *testing.T) { + out, err := math.FormatDec(tc[0]) + require.NoError(t, err) + require.Equal(t, tc[1], out) + }) + } +} + +func TestFormatDecNonDigits(t *testing.T) { + badCases := []string{ + "10.a", + "1a.10", + "p1a10.", + "0.10p", + "--10", + "12.😎😎", + "11111111111133333333333333333333333333333a", + "11111111111133333333333333333333333333333 192892", + } + + for _, value := range badCases { + value := value + t.Run(value, func(t *testing.T) { + s, err := math.FormatDec(value) + if err == nil { + t.Fatal("Expected an error") + } + if g, w := err.Error(), "non-digits"; !strings.Contains(g, w) { + t.Errorf("Error mismatch\nGot: %q\nWant substring: %q", g, w) + } + if s != "" { + t.Fatalf("Got a non-empty string: %q", s) + } + }) + } +} + +func TestNegativePrecisionPanic(t *testing.T) { + require.Panics(t, func() { + math.LegacyNewDecWithPrec(10, -1) + }) +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..17664a9 --- /dev/null +++ b/doc.go @@ -0,0 +1,6 @@ +/* +Package math implements custom Cosmos SDK math types used for arithmetic +operations. Signed and unsigned integer types utilize Golang's standard library +big integers types, having a maximum bit length of 256 bits. +*/ +package math diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..e50ae41 --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,24 @@ +package math + +import ( + "testing" +) + +func FuzzLegacyNewDecFromStr(f *testing.F) { + if testing.Short() { + f.Skip("running in -short mode") + } + + f.Add("-123.456") + f.Add("123.456789") + f.Add("123456789") + f.Add("0.12123456789") + f.Add("-12123456789") + + f.Fuzz(func(t *testing.T, input string) { + dec, err := LegacyNewDecFromStr(input) + if err != nil && !dec.IsNil() { + t.Fatalf("Inconsistency: dec.notNil=%v yet err=%v", dec, err) + } + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4e2b87c --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module cosmossdk.io/math + +go 1.20 + +require ( + github.com/stretchr/testify v1.8.4 + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db + sigs.k8s.io/yaml v1.3.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +// Issue with math.Int{}.Size() implementation. +retract [v1.1.0, v1.1.1] diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ea4dfe1 --- /dev/null +++ b/go.sum @@ -0,0 +1,28 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/int.go b/int.go new file mode 100644 index 0000000..8685cf5 --- /dev/null +++ b/int.go @@ -0,0 +1,524 @@ +package math + +import ( + "encoding" + "encoding/json" + "fmt" + "math/big" + "strings" + "sync" + "testing" +) + +// MaxBitLen defines the maximum bit length supported bit Int and Uint types. +const MaxBitLen = 256 + +func newIntegerFromString(s string) (*big.Int, bool) { + return new(big.Int).SetString(s, 0) +} + +func equal(i, i2 *big.Int) bool { return i.Cmp(i2) == 0 } + +func gt(i, i2 *big.Int) bool { return i.Cmp(i2) == 1 } + +func gte(i, i2 *big.Int) bool { return i.Cmp(i2) >= 0 } + +func lt(i, i2 *big.Int) bool { return i.Cmp(i2) == -1 } + +func lte(i, i2 *big.Int) bool { return i.Cmp(i2) <= 0 } + +func add(i, i2 *big.Int) *big.Int { return new(big.Int).Add(i, i2) } + +func sub(i, i2 *big.Int) *big.Int { return new(big.Int).Sub(i, i2) } + +func mul(i, i2 *big.Int) *big.Int { return new(big.Int).Mul(i, i2) } + +func div(i, i2 *big.Int) *big.Int { return new(big.Int).Quo(i, i2) } + +func mod(i, i2 *big.Int) *big.Int { return new(big.Int).Mod(i, i2) } + +func neg(i *big.Int) *big.Int { return new(big.Int).Neg(i) } + +func abs(i *big.Int) *big.Int { return new(big.Int).Abs(i) } + +func min(i, i2 *big.Int) *big.Int { + if i.Cmp(i2) == 1 { + return new(big.Int).Set(i2) + } + + return new(big.Int).Set(i) +} + +func max(i, i2 *big.Int) *big.Int { + if i.Cmp(i2) == -1 { + return new(big.Int).Set(i2) + } + + return new(big.Int).Set(i) +} + +func unmarshalText(i *big.Int, text string) error { + if err := i.UnmarshalText([]byte(text)); err != nil { + return err + } + + if i.BitLen() > MaxBitLen { + return fmt.Errorf("integer out of range: %s", text) + } + + return nil +} + +var _ customProtobufType = (*Int)(nil) + +// Int wraps big.Int with a 256 bit range bound +// Checks overflow, underflow and division by zero +// Exists in range from -(2^256 - 1) to 2^256 - 1 +type Int struct { + i *big.Int +} + +// BigInt converts Int to big.Int +func (i Int) BigInt() *big.Int { + if i.IsNil() { + return nil + } + return new(big.Int).Set(i.i) +} + +// BigInt converts Int to big.Int, mutative the input +func (i Int) BigIntMut() *big.Int { + if i.IsNil() { + return nil + } + return i.i +} + +// IsNil returns true if Int is uninitialized +func (i Int) IsNil() bool { + return i.i == nil +} + +// NewInt constructs Int from int64 +func NewInt(n int64) Int { + return Int{big.NewInt(n)} +} + +// NewIntFromUint64 constructs an Int from a uint64. +func NewIntFromUint64(n uint64) Int { + b := big.NewInt(0) + b.SetUint64(n) + return Int{b} +} + +// NewIntFromBigInt constructs Int from big.Int. If the provided big.Int is nil, +// it returns an empty instance. This function panics if the bit length is > 256. +// Note, the caller can safely mutate the argument after this function returns. +func NewIntFromBigInt(i *big.Int) Int { + if i == nil { + return Int{} + } + + if i.BitLen() > MaxBitLen { + panic("NewIntFromBigInt() out of bound") + } + + return Int{new(big.Int).Set(i)} +} + +// NewIntFromString constructs Int from string +func NewIntFromString(s string) (res Int, ok bool) { + i, ok := newIntegerFromString(s) + if !ok { + return + } + // Check overflow + if i.BitLen() > MaxBitLen { + ok = false + return + } + return Int{i}, true +} + +// NewIntWithDecimal constructs Int with decimal +// Result value is n*10^dec +func NewIntWithDecimal(n int64, dec int) Int { + if dec < 0 { + panic("NewIntWithDecimal() decimal is negative") + } + exp := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dec)), nil) + i := new(big.Int) + i.Mul(big.NewInt(n), exp) + + // Check overflow + if i.BitLen() > MaxBitLen { + panic("NewIntWithDecimal() out of bound") + } + return Int{i} +} + +// ZeroInt returns Int value with zero +func ZeroInt() Int { return Int{big.NewInt(0)} } + +// OneInt returns Int value with one +func OneInt() Int { return Int{big.NewInt(1)} } + +// ToLegacyDec converts Int to LegacyDec +func (i Int) ToLegacyDec() LegacyDec { + return LegacyNewDecFromInt(i) +} + +// Int64 converts Int to int64 +// Panics if the value is out of range +func (i Int) Int64() int64 { + if !i.i.IsInt64() { + panic("Int64() out of bound") + } + return i.i.Int64() +} + +// IsInt64 returns true if Int64() not panics +func (i Int) IsInt64() bool { + return i.i.IsInt64() +} + +// Uint64 converts Int to uint64 +// Panics if the value is out of range +func (i Int) Uint64() uint64 { + if !i.i.IsUint64() { + panic("Uint64() out of bounds") + } + return i.i.Uint64() +} + +// IsUint64 returns true if Uint64() not panics +func (i Int) IsUint64() bool { + return i.i.IsUint64() +} + +// IsZero returns true if Int is zero +func (i Int) IsZero() bool { + return i.i.Sign() == 0 +} + +// IsNegative returns true if Int is negative +func (i Int) IsNegative() bool { + return i.i.Sign() == -1 +} + +// IsPositive returns true if Int is positive +func (i Int) IsPositive() bool { + return i.i.Sign() == 1 +} + +// Sign returns sign of Int +func (i Int) Sign() int { + return i.i.Sign() +} + +// Equal compares two Ints +func (i Int) Equal(i2 Int) bool { + return equal(i.i, i2.i) +} + +// GT returns true if first Int is greater than second +func (i Int) GT(i2 Int) bool { + return gt(i.i, i2.i) +} + +// GTE returns true if receiver Int is greater than or equal to the parameter +// Int. +func (i Int) GTE(i2 Int) bool { + return gte(i.i, i2.i) +} + +// LT returns true if first Int is lesser than second +func (i Int) LT(i2 Int) bool { + return lt(i.i, i2.i) +} + +// LTE returns true if first Int is less than or equal to second +func (i Int) LTE(i2 Int) bool { + return lte(i.i, i2.i) +} + +// Add adds Int from another +func (i Int) Add(i2 Int) (res Int) { + res = Int{add(i.i, i2.i)} + // Check overflow + if res.i.BitLen() > MaxBitLen { + panic("Int overflow") + } + return +} + +// AddRaw adds int64 to Int +func (i Int) AddRaw(i2 int64) Int { + return i.Add(NewInt(i2)) +} + +// Sub subtracts Int from another +func (i Int) Sub(i2 Int) (res Int) { + res = Int{sub(i.i, i2.i)} + // Check overflow + if res.i.BitLen() > MaxBitLen { + panic("Int overflow") + } + return +} + +// SubRaw subtracts int64 from Int +func (i Int) SubRaw(i2 int64) Int { + return i.Sub(NewInt(i2)) +} + +// Mul multiples two Ints +func (i Int) Mul(i2 Int) (res Int) { + // Check overflow + if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen { + panic("Int overflow") + } + res = Int{mul(i.i, i2.i)} + // Check overflow if sign of both are same + if res.i.BitLen() > MaxBitLen { + panic("Int overflow") + } + return +} + +// MulRaw multipies Int and int64 +func (i Int) MulRaw(i2 int64) Int { + return i.Mul(NewInt(i2)) +} + +// Quo divides Int with Int +func (i Int) Quo(i2 Int) (res Int) { + // Check division-by-zero + if i2.i.Sign() == 0 { + panic("Division by zero") + } + return Int{div(i.i, i2.i)} +} + +// QuoRaw divides Int with int64 +func (i Int) QuoRaw(i2 int64) Int { + return i.Quo(NewInt(i2)) +} + +// Mod returns remainder after dividing with Int +func (i Int) Mod(i2 Int) Int { + if i2.Sign() == 0 { + panic("division-by-zero") + } + return Int{mod(i.i, i2.i)} +} + +// ModRaw returns remainder after dividing with int64 +func (i Int) ModRaw(i2 int64) Int { + return i.Mod(NewInt(i2)) +} + +// Neg negates Int +func (i Int) Neg() (res Int) { + return Int{neg(i.i)} +} + +// Abs returns the absolute value of Int. +func (i Int) Abs() Int { + return Int{abs(i.i)} +} + +// return the minimum of the ints +func MinInt(i1, i2 Int) Int { + return Int{min(i1.BigInt(), i2.BigInt())} +} + +// MaxInt returns the maximum between two integers. +func MaxInt(i, i2 Int) Int { + return Int{max(i.BigInt(), i2.BigInt())} +} + +// Human readable string +func (i Int) String() string { + return i.i.String() +} + +// MarshalJSON defines custom encoding scheme +func (i Int) MarshalJSON() ([]byte, error) { + if i.i == nil { // Necessary since default Uint initialization has i.i as nil + i.i = new(big.Int) + } + return marshalJSON(i.i) +} + +// UnmarshalJSON defines custom decoding scheme +func (i *Int) UnmarshalJSON(bz []byte) error { + if i.i == nil { // Necessary since default Int initialization has i.i as nil + i.i = new(big.Int) + } + return unmarshalJSON(i.i, bz) +} + +// MarshalJSON for custom encoding scheme +// Must be encoded as a string for JSON precision +func marshalJSON(i encoding.TextMarshaler) ([]byte, error) { + text, err := i.MarshalText() + if err != nil { + return nil, err + } + + return json.Marshal(string(text)) +} + +// UnmarshalJSON for custom decoding scheme +// Must be encoded as a string for JSON precision +func unmarshalJSON(i *big.Int, bz []byte) error { + var text string + if err := json.Unmarshal(bz, &text); err != nil { + return err + } + + return unmarshalText(i, text) +} + +// MarshalYAML returns the YAML representation. +func (i Int) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +// Marshal implements the gogo proto custom type interface. +func (i Int) Marshal() ([]byte, error) { + if i.i == nil { + i.i = new(big.Int) + } + return i.i.MarshalText() +} + +// MarshalTo implements the gogo proto custom type interface. +func (i *Int) MarshalTo(data []byte) (n int, err error) { + if i.i == nil { + i.i = new(big.Int) + } + if i.i.BitLen() == 0 { // The value 0 + n = copy(data, []byte{0x30}) + return n, nil + } + + bz, err := i.Marshal() + if err != nil { + return 0, err + } + + n = copy(data, bz) + return n, nil +} + +// Unmarshal implements the gogo proto custom type interface. +func (i *Int) Unmarshal(data []byte) error { + if len(data) == 0 { + i = nil + return nil + } + + if i.i == nil { + i.i = new(big.Int) + } + + if err := i.i.UnmarshalText(data); err != nil { + return err + } + + if i.i.BitLen() > MaxBitLen { + return fmt.Errorf("integer out of range; got: %d, max: %d", i.i.BitLen(), MaxBitLen) + } + + return nil +} + +// Size implements the gogo proto custom type interface. +func (i *Int) Size() int { + bz, _ := i.Marshal() + return len(bz) +} + +// Override Amino binary serialization by proxying to protobuf. +func (i Int) MarshalAmino() ([]byte, error) { return i.Marshal() } +func (i *Int) UnmarshalAmino(bz []byte) error { return i.Unmarshal(bz) } + +// intended to be used with require/assert: require.True(IntEq(...)) +func IntEq(t *testing.T, exp, got Int) (*testing.T, bool, string, string, string) { + t.Helper() + return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp.String(), got.String() +} + +func hasOnlyDigits(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} + +const thousandSeparator string = "'" + +var stringsBuilderPool = &sync.Pool{ + New: func() any { return new(strings.Builder) }, +} + +// FormatInt formats an integer (encoded as in protobuf) into a value-rendered +// string following ADR-050. This function operates with string manipulation +// (instead of manipulating the int or math.Int object). +func FormatInt(v string) (string, error) { + if len(v) == 0 { + return "", fmt.Errorf("cannot format empty string") + } + + sign := "" + if v[0] == '-' { + sign = "-" + v = v[1:] + } + if len(v) > 1 { + v = strings.TrimLeft(v, "0") + } + + // Ensure that the string contains only digits at this point. + if !hasOnlyDigits(v) { + return "", fmt.Errorf("expecting only digits 0-9, but got non-digits in %q", v) + } + + // 1. Less than 4 digits don't need any formatting. + if len(v) <= 3 { + return sign + v, nil + } + + sb := stringsBuilderPool.Get().(*strings.Builder) + defer stringsBuilderPool.Put(sb) + sb.Reset() + sb.Grow(len(v) + len(v)/3) // Exactly v + numberOfThousandSeparatorsIn(v) + + // 2. If the length of v is not a multiple of 3 e.g. 1234 or 12345, to achieve 1'234 or 12'345, + // we can simply slide to the first mod3 values of v that aren't the multiples of 3 then insert in + // the thousands separator so in this case: write(12'); then the remaining v will be entirely multiple + // of 3 hence v = 34* + if mod3 := len(v) % 3; mod3 != 0 { + sb.WriteString(v[:mod3]) + v = v[mod3:] + sb.WriteString(thousandSeparator) + } + + // 3. By this point v is entirely multiples of 3 hence we just insert the separator at every 3 digit. + for i := 0; i < len(v); i += 3 { + end := i + 3 + sb.WriteString(v[i:end]) + if end < len(v) { + sb.WriteString(thousandSeparator) + } + } + + return sign + sb.String(), nil +} diff --git a/int_internal_test.go b/int_internal_test.go new file mode 100644 index 0000000..d2cb902 --- /dev/null +++ b/int_internal_test.go @@ -0,0 +1,155 @@ +package math + +import ( + "math/big" + "math/rand" + "testing" + + "github.com/stretchr/testify/suite" +) + +type internalIntTestSuite struct { + suite.Suite +} + +func TestInternalIntTestSuite(t *testing.T) { + suite.Run(t, new(internalIntTestSuite)) +} + +func (s *internalIntTestSuite) TestEncodingRandom() { + for i := 0; i < 1000; i++ { + n := rand.Int63() + ni := NewInt(n) + var ri Int + + str, err := ni.Marshal() + s.Require().Nil(err) + err = (&ri).Unmarshal(str) + s.Require().Nil(err) + + s.Require().Equal(ni, ri, "binary mismatch; tc #%d, expected %s, actual %s", i, ni.String(), ri.String()) + s.Require().True(ni.i != ri.i, "pointer addresses are equal; tc #%d", i) + + bz, err := ni.MarshalJSON() + s.Require().Nil(err) + err = (&ri).UnmarshalJSON(bz) + s.Require().Nil(err) + + s.Require().Equal(ni, ri, "json mismatch; tc #%d, expected %s, actual %s", i, ni.String(), ri.String()) + s.Require().True(ni.i != ri.i, "pointer addresses are equal; tc #%d", i) + } + + for i := 0; i < 1000; i++ { + n := rand.Uint64() + ni := NewUint(n) + var ri Uint + + str, err := ni.Marshal() + s.Require().Nil(err) + err = (&ri).Unmarshal(str) + s.Require().Nil(err) + + s.Require().Equal(ni, ri, "binary mismatch; tc #%d, expected %s, actual %s", i, ni.String(), ri.String()) + s.Require().True(ni.i != ri.i, "pointer addresses are equal; tc #%d", i) + + bz, err := ni.MarshalJSON() + s.Require().Nil(err) + err = (&ri).UnmarshalJSON(bz) + s.Require().Nil(err) + + s.Require().Equal(ni, ri, "json mismatch; tc #%d, expected %s, actual %s", i, ni.String(), ri.String()) + s.Require().True(ni.i != ri.i, "pointer addresses are equal; tc #%d", i) + } +} + +func (s *internalIntTestSuite) TestSerializationOverflow() { + bx, _ := new(big.Int).SetString("115792089237316195423570985008687907853269984665640564039457584007913129639936", 10) + x := Int{bx} + y := new(Int) + + bz, err := x.Marshal() + s.Require().NoError(err) + + // require deserialization to fail due to overflow + s.Require().Error(y.Unmarshal(bz)) + + // require JSON deserialization to fail due to overflow + bz, err = x.MarshalJSON() + s.Require().NoError(err) + + s.Require().Error(y.UnmarshalJSON(bz)) +} + +func (s *internalIntTestSuite) TestDeserializeMaxERC20() { + bx, _ := new(big.Int).SetString("115792089237316195423570985008687907853269984665640564039457584007913129639935", 10) + x := Int{bx} + y := new(Int) + + bz, err := x.Marshal() + s.Require().NoError(err) + + // require deserialization to be successful + s.Require().NoError(y.Unmarshal(bz)) + + // require JSON deserialization to succeed + bz, err = x.MarshalJSON() + s.Require().NoError(err) + + s.Require().NoError(y.UnmarshalJSON(bz)) +} + +func (s *internalIntTestSuite) TestImmutabilityArithInt() { + size := 500 + + ops := []intOp{ + applyWithRand(Int.Add, (*big.Int).Add), + applyWithRand(Int.Sub, (*big.Int).Sub), + applyWithRand(Int.Mul, (*big.Int).Mul), + applyWithRand(Int.Quo, (*big.Int).Quo), + applyRawWithRand(Int.AddRaw, (*big.Int).Add), + applyRawWithRand(Int.SubRaw, (*big.Int).Sub), + applyRawWithRand(Int.MulRaw, (*big.Int).Mul), + applyRawWithRand(Int.QuoRaw, (*big.Int).Quo), + } + + for i := 0; i < 100; i++ { + uis := make([]Int, size) + bis := make([]*big.Int, size) + + n := rand.Int63() + ui := NewInt(n) + bi := new(big.Int).SetInt64(n) + + for j := 0; j < size; j++ { + op := ops[rand.Intn(len(ops))] + uis[j], bis[j] = op(ui, bi) + } + + for j := 0; j < size; j++ { + s.Require().Equal(0, bis[j].Cmp(uis[j].BigInt()), "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) + s.Require().Equal(NewIntFromBigInt(bis[j]), uis[j], "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) + s.Require().True(uis[j].i != bis[j], "Pointer addresses are equal. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String()) + } + } +} + +type ( + intOp func(Int, *big.Int) (Int, *big.Int) + bigIntFunc func(*big.Int, *big.Int, *big.Int) *big.Int +) + +func applyWithRand(intFn func(Int, Int) Int, bigIntFn bigIntFunc) intOp { + return func(integer Int, bigInteger *big.Int) (Int, *big.Int) { + r := rand.Int63() + br := new(big.Int).SetInt64(r) + return intFn(integer, NewInt(r)), bigIntFn(new(big.Int), bigInteger, br) + } +} + +func applyRawWithRand(intFn func(Int, int64) Int, bigIntFn bigIntFunc) intOp { + return func(integer Int, bigInteger *big.Int) (Int, *big.Int) { + r := rand.Int63() + br := new(big.Int).SetInt64(r) + return intFn(integer, r), bigIntFn(new(big.Int), bigInteger, br) + } +} diff --git a/int_test.go b/int_test.go new file mode 100644 index 0000000..410e395 --- /dev/null +++ b/int_test.go @@ -0,0 +1,631 @@ +package math_test + +import ( + "encoding/json" + "fmt" + "math/big" + "math/rand" + "os" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "cosmossdk.io/math" +) + +type intTestSuite struct { + suite.Suite +} + +func TestIntTestSuite(t *testing.T) { + suite.Run(t, new(intTestSuite)) +} + +func (s *intTestSuite) SetupSuite() { + s.T().Parallel() +} + +func (s *intTestSuite) TestFromInt64() { + for n := 0; n < 20; n++ { + r := rand.Int63() + s.Require().Equal(r, math.NewInt(r).Int64()) + } +} + +func (s *intTestSuite) TestFromUint64() { + for n := 0; n < 20; n++ { + r := rand.Uint64() + s.Require().True(math.NewIntFromUint64(r).IsUint64()) + s.Require().Equal(r, math.NewIntFromUint64(r).Uint64()) + } +} + +func (s *intTestSuite) TestNewIntFromBigInt() { + i := math.NewIntFromBigInt(nil) + s.Require().True(i.IsNil()) + + r := big.NewInt(42) + i = math.NewIntFromBigInt(r) + s.Require().Equal(r, i.BigInt()) + + // modify r and ensure i doesn't change + r = r.SetInt64(100) + s.Require().NotEqual(r, i.BigInt()) +} + +func (s *intTestSuite) TestConvertToBigIntMutative() { + r := big.NewInt(42) + i := math.NewIntFromBigInt(r) + + // Compare value of BigInt & BigIntMut + s.Require().Equal(i.BigInt(), i.BigIntMut()) + + // Modify BigIntMut() pointer and ensure i.BigIntMut() & i.BigInt() change + p := i.BigIntMut() + p.SetInt64(50) + s.Require().Equal(big.NewInt(50), i.BigIntMut()) + s.Require().Equal(big.NewInt(50), i.BigInt()) + + // Modify big.Int() pointer and ensure i.BigIntMut() & i.BigInt() don't change + p = i.BigInt() + p.SetInt64(60) + s.Require().NotEqual(big.NewInt(60), i.BigIntMut()) + s.Require().NotEqual(big.NewInt(60), i.BigInt()) +} + +func (s *intTestSuite) TestIntPanic() { + // Max Int = 2^256-1 = 1.1579209e+77 + // Min Int = -(2^256-1) = -1.1579209e+77 + s.Require().NotPanics(func() { math.NewIntWithDecimal(4, 76) }) + i1 := math.NewIntWithDecimal(4, 76) + s.Require().NotPanics(func() { math.NewIntWithDecimal(5, 76) }) + i2 := math.NewIntWithDecimal(5, 76) + s.Require().NotPanics(func() { math.NewIntWithDecimal(6, 76) }) + i3 := math.NewIntWithDecimal(6, 76) + + s.Require().Panics(func() { math.NewIntWithDecimal(2, 77) }) + s.Require().Panics(func() { math.NewIntWithDecimal(9, 80) }) + + // Overflow check + s.Require().NotPanics(func() { i1.Add(i1) }) + s.Require().NotPanics(func() { i2.Add(i2) }) + s.Require().Panics(func() { i3.Add(i3) }) + + s.Require().NotPanics(func() { i1.Sub(i1.Neg()) }) + s.Require().NotPanics(func() { i2.Sub(i2.Neg()) }) + s.Require().Panics(func() { i3.Sub(i3.Neg()) }) + + s.Require().Panics(func() { i1.Mul(i1) }) + s.Require().Panics(func() { i2.Mul(i2) }) + s.Require().Panics(func() { i3.Mul(i3) }) + + s.Require().Panics(func() { i1.Neg().Mul(i1.Neg()) }) + s.Require().Panics(func() { i2.Neg().Mul(i2.Neg()) }) + s.Require().Panics(func() { i3.Neg().Mul(i3.Neg()) }) + + // // Underflow check + i3n := i3.Neg() + s.Require().NotPanics(func() { i3n.Sub(i1) }) + s.Require().NotPanics(func() { i3n.Sub(i2) }) + s.Require().Panics(func() { i3n.Sub(i3) }) + + s.Require().NotPanics(func() { i3n.Add(i1.Neg()) }) + s.Require().NotPanics(func() { i3n.Add(i2.Neg()) }) + s.Require().Panics(func() { i3n.Add(i3.Neg()) }) + + s.Require().Panics(func() { i1.Mul(i1.Neg()) }) + s.Require().Panics(func() { i2.Mul(i2.Neg()) }) + s.Require().Panics(func() { i3.Mul(i3.Neg()) }) + + // Bound check + intmax := math.NewIntFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))) + intmin := intmax.Neg() + s.Require().NotPanics(func() { intmax.Add(math.ZeroInt()) }) + s.Require().NotPanics(func() { intmin.Sub(math.ZeroInt()) }) + s.Require().Panics(func() { intmax.Add(math.OneInt()) }) + s.Require().Panics(func() { intmin.Sub(math.OneInt()) }) + + s.Require().NotPanics(func() { math.NewIntFromBigInt(nil) }) + s.Require().True(math.NewIntFromBigInt(nil).IsNil()) + + // Division-by-zero check + s.Require().Panics(func() { i1.Quo(math.NewInt(0)) }) + + s.Require().NotPanics(func() { math.Int{}.BigInt() }) +} + +// Tests below uses randomness +// Since we are using *big.Int as underlying value +// and (U/)Int is immutable value(see TestImmutability(U/)Int) +// it is safe to use randomness in the tests +func (s *intTestSuite) TestIdentInt() { + for d := 0; d < 1000; d++ { + n := rand.Int63() + i := math.NewInt(n) + + ifromstr, ok := math.NewIntFromString(strconv.FormatInt(n, 10)) + s.Require().True(ok) + + cases := []int64{ + i.Int64(), + i.BigInt().Int64(), + ifromstr.Int64(), + math.NewIntFromBigInt(big.NewInt(n)).Int64(), + math.NewIntWithDecimal(n, 0).Int64(), + } + + for tcnum, tc := range cases { + s.Require().Equal(n, tc, "Int is modified during conversion. tc #%d", tcnum) + } + } +} + +func minint(i1, i2 int64) int64 { + if i1 < i2 { + return i1 + } + return i2 +} + +func maxint(i1, i2 int64) int64 { + if i1 > i2 { + return i1 + } + return i2 +} + +func (s *intTestSuite) TestArithInt() { + for d := 0; d < 1000; d++ { + n1 := int64(rand.Int31()) + i1 := math.NewInt(n1) + n2 := int64(rand.Int31()) + i2 := math.NewInt(n2) + + cases := []struct { + ires math.Int + nres int64 + }{ + {i1.Add(i2), n1 + n2}, + {i1.Sub(i2), n1 - n2}, + {i1.Mul(i2), n1 * n2}, + {i1.Quo(i2), n1 / n2}, + {i1.AddRaw(n2), n1 + n2}, + {i1.SubRaw(n2), n1 - n2}, + {i1.MulRaw(n2), n1 * n2}, + {i1.QuoRaw(n2), n1 / n2}, + {math.MinInt(i1, i2), minint(n1, n2)}, + {math.MaxInt(i1, i2), maxint(n1, n2)}, + {i1.Neg(), -n1}, + {i1.Abs(), n1}, + {i1.Neg().Abs(), n1}, + } + + for tcnum, tc := range cases { + s.Require().Equal(tc.nres, tc.ires.Int64(), "Int arithmetic operation does not match with int64 operation. tc #%d", tcnum) + } + } +} + +func (s *intTestSuite) TestCompInt() { + for d := 0; d < 1000; d++ { + n1 := int64(rand.Int31()) + i1 := math.NewInt(n1) + n2 := int64(rand.Int31()) + i2 := math.NewInt(n2) + + cases := []struct { + ires bool + nres bool + }{ + {i1.Equal(i2), n1 == n2}, + {i1.GT(i2), n1 > n2}, + {i1.LT(i2), n1 < n2}, + {i1.LTE(i2), n1 <= n2}, + } + + for tcnum, tc := range cases { + s.Require().Equal(tc.nres, tc.ires, "Int comparison operation does not match with int64 operation. tc #%d", tcnum) + } + } +} + +func randint() math.Int { + return math.NewInt(rand.Int63()) +} + +func (s *intTestSuite) TestImmutabilityAllInt() { + ops := []func(*math.Int){ + func(i *math.Int) { _ = i.Add(randint()) }, + func(i *math.Int) { _ = i.Sub(randint()) }, + func(i *math.Int) { _ = i.Mul(randint()) }, + func(i *math.Int) { _ = i.Quo(randint()) }, + func(i *math.Int) { _ = i.AddRaw(rand.Int63()) }, + func(i *math.Int) { _ = i.SubRaw(rand.Int63()) }, + func(i *math.Int) { _ = i.MulRaw(rand.Int63()) }, + func(i *math.Int) { _ = i.QuoRaw(rand.Int63()) }, + func(i *math.Int) { _ = i.Neg() }, + func(i *math.Int) { _ = i.Abs() }, + func(i *math.Int) { _ = i.IsZero() }, + func(i *math.Int) { _ = i.Sign() }, + func(i *math.Int) { _ = i.Equal(randint()) }, + func(i *math.Int) { _ = i.GT(randint()) }, + func(i *math.Int) { _ = i.LT(randint()) }, + func(i *math.Int) { _ = i.String() }, + } + + for i := 0; i < 1000; i++ { + n := rand.Int63() + ni := math.NewInt(n) + + for opnum, op := range ops { + op(&ni) + + s.Require().Equal(n, ni.Int64(), "Int is modified by operation. tc #%d", opnum) + s.Require().Equal(math.NewInt(n), ni, "Int is modified by operation. tc #%d", opnum) + } + } +} + +func (s *intTestSuite) TestEncodingTableInt() { + var i math.Int + + cases := []struct { + i math.Int + jsonBz []byte + rawBz []byte + }{ + { + math.NewInt(0), + []byte("\"0\""), + []byte{0x30}, + }, + { + math.NewInt(100), + []byte("\"100\""), + []byte{0x31, 0x30, 0x30}, + }, + { + math.NewInt(-100), + []byte("\"-100\""), + []byte{0x2d, 0x31, 0x30, 0x30}, + }, + { + math.NewInt(51842), + []byte("\"51842\""), + []byte{0x35, 0x31, 0x38, 0x34, 0x32}, + }, + { + math.NewInt(-51842), + []byte("\"-51842\""), + []byte{0x2d, 0x35, 0x31, 0x38, 0x34, 0x32}, + }, + { + math.NewInt(19513368), + []byte("\"19513368\""), + []byte{0x31, 0x39, 0x35, 0x31, 0x33, 0x33, 0x36, 0x38}, + }, + { + math.NewInt(-19513368), + []byte("\"-19513368\""), + []byte{0x2d, 0x31, 0x39, 0x35, 0x31, 0x33, 0x33, 0x36, 0x38}, + }, + { + math.NewInt(999999999999), + []byte("\"999999999999\""), + []byte{0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39}, + }, + { + math.NewInt(-999999999999), + []byte("\"-999999999999\""), + []byte{0x2d, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39}, + }, + } + + for tcnum, tc := range cases { + bz, err := tc.i.MarshalJSON() + s.Require().Nil(err, "Error marshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.jsonBz, bz, "Marshaled value is different from exported. tc #%d", tcnum) + + err = (&i).UnmarshalJSON(bz) + s.Require().Nil(err, "Error unmarshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.i, i, "Unmarshaled value is different from exported. tc #%d", tcnum) + + bz, err = tc.i.Marshal() + s.Require().Nil(err, "Error marshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.rawBz, bz, "Marshaled value is different from exported. tc #%d", tcnum) + + err = (&i).Unmarshal(bz) + s.Require().Nil(err, "Error unmarshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.i, i, "Unmarshaled value is different from exported. tc #%d", tcnum) + } +} + +func (s *intTestSuite) TestEncodingTableUint() { + var i math.Uint + + cases := []struct { + i math.Uint + jsonBz []byte + rawBz []byte + }{ + { + math.NewUint(0), + []byte("\"0\""), + []byte{0x30}, + }, + { + math.NewUint(100), + []byte("\"100\""), + []byte{0x31, 0x30, 0x30}, + }, + { + math.NewUint(51842), + []byte("\"51842\""), + []byte{0x35, 0x31, 0x38, 0x34, 0x32}, + }, + { + math.NewUint(19513368), + []byte("\"19513368\""), + []byte{0x31, 0x39, 0x35, 0x31, 0x33, 0x33, 0x36, 0x38}, + }, + { + math.NewUint(999999999999), + []byte("\"999999999999\""), + []byte{0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39}, + }, + } + + for tcnum, tc := range cases { + bz, err := tc.i.MarshalJSON() + s.Require().Nil(err, "Error marshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.jsonBz, bz, "Marshaled value is different from exported. tc #%d", tcnum) + + err = (&i).UnmarshalJSON(bz) + s.Require().Nil(err, "Error unmarshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.i, i, "Unmarshaled value is different from exported. tc #%d", tcnum) + + bz, err = tc.i.Marshal() + s.Require().Nil(err, "Error marshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.rawBz, bz, "Marshaled value is different from exported. tc #%d", tcnum) + + err = (&i).Unmarshal(bz) + s.Require().Nil(err, "Error unmarshaling Int. tc #%d, err %s", tcnum, err) + s.Require().Equal(tc.i, i, "Unmarshaled value is different from exported. tc #%d", tcnum) + } +} + +func (s *intTestSuite) TestIntMod() { + tests := []struct { + name string + x int64 + y int64 + ret int64 + wantPanic bool + }{ + {"3 % 10", 3, 10, 3, false}, + {"10 % 3", 10, 3, 1, false}, + {"4 % 2", 4, 2, 0, false}, + {"2 % 0", 2, 0, 0, true}, + } + + for _, tt := range tests { + if tt.wantPanic { + s.Require().Panics(func() { math.NewInt(tt.x).Mod(math.NewInt(tt.y)) }) + s.Require().Panics(func() { math.NewInt(tt.x).ModRaw(tt.y) }) + return + } + s.Require().True(math.NewInt(tt.x).Mod(math.NewInt(tt.y)).Equal(math.NewInt(tt.ret))) + s.Require().True(math.NewInt(tt.x).ModRaw(tt.y).Equal(math.NewInt(tt.ret))) + } +} + +func (s *intTestSuite) TestIntEq() { + _, resp, _, _, _ := math.IntEq(s.T(), math.ZeroInt(), math.ZeroInt()) + s.Require().True(resp) + _, resp, _, _, _ = math.IntEq(s.T(), math.OneInt(), math.ZeroInt()) + s.Require().False(resp) +} + +func TestRoundTripMarshalToInt(t *testing.T) { + values := []int64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + } + + for _, value := range values { + value := value + t.Run(fmt.Sprintf("%d", value), func(t *testing.T) { + t.Parallel() + + var scratch [20]byte + iv := math.NewInt(value) + n, err := iv.MarshalTo(scratch[:]) + if err != nil { + t.Fatal(err) + } + rt := new(math.Int) + if err := rt.Unmarshal(scratch[:n]); err != nil { + t.Fatal(err) + } + if !rt.Equal(iv) { + t.Fatalf("roundtrip=%q != original=%q", rt, iv) + } + }) + } +} + +func TestFormatInt(t *testing.T) { + type integerTest []string + var testcases []integerTest + raw, err := os.ReadFile("testdata/integers.json") + require.NoError(t, err) + err = json.Unmarshal(raw, &testcases) + require.NoError(t, err) + + for _, tc := range testcases { + out, err := math.FormatInt(tc[0]) + require.NoError(t, err) + require.Equal(t, tc[1], out) + } +} + +func TestFormatIntNonDigits(t *testing.T) { + badCases := []string{ + "a10", + "1a10", + "p1a10", + "10p", + "--10", + "😎😎", + "11111111111133333333333333333333333333333a", + "11111111111133333333333333333333333333333 192892", + } + + for _, value := range badCases { + value := value + t.Run(value, func(t *testing.T) { + s, err := math.FormatInt(value) + if err == nil { + t.Fatal("Expected an error") + } + if g, w := err.Error(), "but got non-digits in"; !strings.Contains(g, w) { + t.Errorf("Error mismatch\nGot: %q\nWant substring: %q", g, w) + } + if s != "" { + t.Fatalf("Got a non-empty string: %q", s) + } + }) + } +} + +func TestFormatIntEmptyString(t *testing.T) { + _, err := math.FormatInt("") + require.ErrorContains(t, err, "cannot format empty string") +} + +func TestFormatIntCorrectness(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"0", "0"}, + {"-2", "-2"}, + {"10", "10"}, + {"123", "123"}, + {"1234", "1'234"}, + {"12345", "12'345"}, + {"123456", "123'456"}, + {"-123456", "-123'456"}, + {"1234567", "1'234'567"}, + {"12345678", "12'345'678"}, + {"123456789", "123'456'789"}, + {"12345678910", "12'345'678'910"}, + {"9999999999999999", "9'999'999'999'999'999"}, + {"-9999999999999999", "-9'999'999'999'999'999"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.in, func(t *testing.T) { + got, err := math.FormatInt(tt.in) + if err != nil { + t.Fatal(err) + } + + if got != tt.want { + t.Fatalf("Mismatch:\n\tGot: %q\n\tWant: %q", got, tt.want) + } + }) + } +} + +var sizeTests = []struct { + s string + want int +}{ + {"", 1}, + {"0", 1}, + {"-0", 1}, + {"-10", 3}, + {"-10000", 6}, + {"10000", 5}, + {"100000", 6}, + {"99999", 5}, + {"9999999999", 10}, + {"10000000000", 11}, + {"99999999999", 11}, + {"999999999999", 12}, + {"9999999999999", 13}, + {"99999999999999", 14}, + {"999999999999999", 15}, + {"1000000000000000", 16}, + {"9999999999999999", 16}, + {"99999999999999999", 17}, + {"999999999999999999", 18}, + {"-999999999999999999", 19}, + {"9000000000000000000", 19}, + {"-9999999999999990000", 20}, + {"9999999999999990000", 19}, + {"9999999999999999000", 19}, + {"9999999999999999999", 19}, + {"-9999999999999999999", 20}, + {"18446744073709551616", 20}, + {"18446744073709551618", 20}, + {"184467440737095516181", 21}, + {"100000000000000000000000", 24}, + {"1000000000000000000000000000", 28}, + {"9000000000099999999999999999", 28}, + {"9999999999999999999999999999", 28}, + {"9903520314283042199192993792", 28}, + {"340282366920938463463374607431768211456", 39}, + {"3402823669209384634633746074317682114569999", 43}, + {"9999999999999999999999999999999999999999999", 43}, + {"99999999999999999999999999999999999999999999", 44}, + {"999999999999999999999999999999999999999999999", 45}, + {"90000000000999999999999999999000000000099999999999999999", 56}, + {"-90000000000999999999999999999000000000099999999999999999", 57}, + {"9000000000099999999999999999900000000009999999999999999990", 58}, + {"990000000009999999999999999990000000000999999999999999999999", 60}, + {"99000000000999999999999999999000000000099999999999999999999919", 62}, + {"90000000000999999990000000000000000000000000000000000000000000", 62}, + {"99999999999999999999999999990000000000000000000000000000000000", 62}, + {"11111111111111119999999999990000000000000000000000000000000000", 62}, + {"99000000000999999999999999999000000000099999999999999999999919", 62}, + {"10000000000000000000000000000000000000000000000000000000000000", 62}, + {"10000000000000000000000000000000000000000000000000000000000000000000000000000", 77}, + {"99999999999999999999999999999999999999999999999999999999999999999999999999999", 77}, + {"110000000000000000000000000000000000000000000000000000000000000000000000000009", 78}, +} + +func TestNewIntFromString(t *testing.T) { + for _, st := range sizeTests { + ii, _ := math.NewIntFromString(st.s) + require.Equal(t, st.want, ii.Size(), "size mismatch for %q", st.s) + } +} + +func BenchmarkIntSize(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for _, st := range sizeTests { + ii, _ := math.NewIntFromString(st.s) + got := ii.Size() + if got != st.want { + b.Errorf("%q:: got=%d, want=%d", st.s, got, st.want) + } + sink = got + } + } + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} diff --git a/max_min.go b/max_min.go new file mode 100644 index 0000000..407dd81 --- /dev/null +++ b/max_min.go @@ -0,0 +1,29 @@ +package math + +import "golang.org/x/exp/constraints" + +func Max[T constraints.Ordered](a, b T, rest ...T) T { + max := a + if b > a { + max = b + } + for _, val := range rest { + if val > max { + max = val + } + } + return max +} + +func Min[T constraints.Ordered](a, b T, rest ...T) T { + min := a + if b < a { + min = b + } + for _, val := range rest { + if val < min { + min = val + } + } + return min +} diff --git a/max_min_test.go b/max_min_test.go new file mode 100644 index 0000000..ae700f7 --- /dev/null +++ b/max_min_test.go @@ -0,0 +1,19 @@ +package math + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMax(t *testing.T) { + maxInt := Max(10, -10, 20, 1_000_000, 10, 8, -11_000_000, 20) + require.Equal(t, 1_000_000, maxInt, "invalid max for int") + minInt := Min(10, -10, 20, 1_000_000, 10, 8, -11_000_000, 20) + require.Equal(t, -11_000_000, minInt, "invalid min for int") + + maxf64 := Max(10.1, -10.1, 20.8, 1_000_000.9, 10.5, 8.4, -11_000_000.9, 20.7) + require.Equal(t, 1_000_000.9, maxf64, "invalid max for float64") + minf64 := Min(10.1, -10.1, 20.8, 1_000_000.9, 10.5, 8.4, -11_000_000.9, 20.7) + require.Equal(t, -11_000_000.9, minf64, "invalid min for float64") +} diff --git a/proto.go b/proto.go new file mode 100644 index 0000000..d1afa1a --- /dev/null +++ b/proto.go @@ -0,0 +1,15 @@ +package math + +// customProtobufType defines the interface custom gogo proto types must implement +// in order to be used as a "customtype" extension. +// +// ref: https://github.com/cosmos/gogoproto/blob/master/custom_types.md +type customProtobufType interface { + Marshal() ([]byte, error) + MarshalTo(data []byte) (n int, err error) + Unmarshal(data []byte) error + Size() int + + MarshalJSON() ([]byte, error) + UnmarshalJSON(data []byte) error +} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..d4a36b3 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,14 @@ +sonar.projectKey=cosmos-sdk-math +sonar.organization=cosmos + +sonar.projectName=Cosmos SDK - Math +sonar.project.monorepo.enabled=true + +sonar.sources=. +sonar.exclusions=**/*_test.go +sonar.tests=. +sonar.test.inclusions=**/*_test.go +sonar.go.coverage.reportPaths=coverage.out + +sonar.sourceEncoding=UTF-8 +sonar.scm.provider=git \ No newline at end of file diff --git a/testdata/decimals.json b/testdata/decimals.json new file mode 100644 index 0000000..3564b59 --- /dev/null +++ b/testdata/decimals.json @@ -0,0 +1,47 @@ +[ + ["0", "0"], + ["1", "1"], + ["12", "12"], + ["123", "123"], + ["1234", "1'234"], + ["0.1", "0.1"], + ["0.01", "0.01"], + ["0.001", "0.001"], + ["0.0001", "0.0001"], + ["0.00001", "0.00001"], + ["0.000001", "0.000001"], + ["0.0000001", "0.0000001"], + ["0.00000001", "0.00000001"], + ["0.000000001", "0.000000001"], + ["0.0000000001", "0.0000000001"], + ["0.00000000001", "0.00000000001"], + ["0.000000000001", "0.000000000001"], + ["0.0000000000001", "0.0000000000001"], + ["0.00000000000001", "0.00000000000001"], + ["0.000000000000001", "0.000000000000001"], + ["0.0000000000000001", "0.0000000000000001"], + ["0.00000000000000001", "0.00000000000000001"], + ["0.000000000000000001", "0.000000000000000001"], + ["0.100000000000000000", "0.1"], + ["0.010000000000000000", "0.01"], + ["0.001000000000000000", "0.001"], + ["0.000100000000000000", "0.0001"], + ["0.000010000000000000", "0.00001"], + ["0.000001000000000000", "0.000001"], + ["0.000000100000000000", "0.0000001"], + ["0.000000010000000000", "0.00000001"], + ["0.000000001000000000", "0.000000001"], + ["0.000000000100000000", "0.0000000001"], + ["0.000000000010000000", "0.00000000001"], + ["0.000000000001000000", "0.000000000001"], + ["0.000000000000100000", "0.0000000000001"], + ["0.000000000000010000", "0.00000000000001"], + ["0.000000000000001000", "0.000000000000001"], + ["0.000000000000000100", "0.0000000000000001"], + ["0.000000000000000010", "0.00000000000000001"], + ["0.000000000000000001", "0.000000000000000001"], + ["-10.0", "-10"], + ["-10000", "-10'000"], + ["-9999", "-9'999"], + ["-999999999999", "-999'999'999'999"] +] diff --git a/testdata/integers.json b/testdata/integers.json new file mode 100644 index 0000000..0faa4f5 --- /dev/null +++ b/testdata/integers.json @@ -0,0 +1,19 @@ +[ + ["0", "0"], + ["1", "1"], + ["12", "12"], + ["123", "123"], + ["1234", "1'234"], + ["12345", "12'345"], + ["123456", "123'456"], + ["1234567", "1'234'567"], + ["9007199254740991", "9'007'199'254'740'991"], + ["9007199254740992", "9'007'199'254'740'992"], + ["18446744073709551615", "18'446'744'073'709'551'615"], + ["18446744073709551616", "18'446'744'073'709'551'616"], + ["340282366920938463463374607431768211455", "340'282'366'920'938'463'463'374'607'431'768'211'455"], + ["01", "1"], + ["001", "1"], + ["0001", "1"], + ["00001", "1"] +] diff --git a/uint.go b/uint.go new file mode 100644 index 0000000..be588b0 --- /dev/null +++ b/uint.go @@ -0,0 +1,278 @@ +package math + +import ( + "errors" + "fmt" + "math/big" +) + +// Uint wraps integer with 256 bit range bound +// Checks overflow, underflow and division by zero +// Exists in range from 0 to 2^256-1 +type Uint struct { + i *big.Int +} + +// BigInt converts Uint to big.Int +func (u Uint) BigInt() *big.Int { + if u.IsNil() { + return nil + } + return new(big.Int).Set(u.i) +} + +// IsNil returns true if Uint is uninitialized +func (u Uint) IsNil() bool { + return u.i == nil +} + +// NewUintFromBigUint constructs Uint from big.Uint +func NewUintFromBigInt(i *big.Int) Uint { + u, err := checkNewUint(i) + if err != nil { + panic(fmt.Errorf("overflow: %s", err)) + } + return u +} + +// NewUint constructs Uint from int64 +func NewUint(n uint64) Uint { + i := new(big.Int) + i.SetUint64(n) + return NewUintFromBigInt(i) +} + +// NewUintFromString constructs Uint from string +func NewUintFromString(s string) Uint { + u, err := ParseUint(s) + if err != nil { + panic(err) + } + return u +} + +// ZeroUint returns unsigned zero. +func ZeroUint() Uint { return Uint{big.NewInt(0)} } + +// OneUint returns Uint value with one. +func OneUint() Uint { return Uint{big.NewInt(1)} } + +var _ customProtobufType = (*Uint)(nil) + +// Uint64 converts Uint to uint64 +// Panics if the value is out of range +func (u Uint) Uint64() uint64 { + if !u.i.IsUint64() { + panic("Uint64() out of bound") + } + return u.i.Uint64() +} + +// IsZero returns 1 if the uint equals to 0. +func (u Uint) IsZero() bool { return u.Equal(ZeroUint()) } + +// Equal compares two Uints +func (u Uint) Equal(u2 Uint) bool { return equal(u.i, u2.i) } + +// GT returns true if first Uint is greater than second +func (u Uint) GT(u2 Uint) bool { return gt(u.i, u2.i) } + +// GTE returns true if first Uint is greater than second +func (u Uint) GTE(u2 Uint) bool { return u.GT(u2) || u.Equal(u2) } + +// LT returns true if first Uint is lesser than second +func (u Uint) LT(u2 Uint) bool { return lt(u.i, u2.i) } + +// LTE returns true if first Uint is lesser than or equal to the second +func (u Uint) LTE(u2 Uint) bool { return !u.GT(u2) } + +// Add adds Uint from another +func (u Uint) Add(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Add(u.i, u2.i)) } + +// Add convert uint64 and add it to Uint +func (u Uint) AddUint64(u2 uint64) Uint { return u.Add(NewUint(u2)) } + +// Sub adds Uint from another +func (u Uint) Sub(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Sub(u.i, u2.i)) } + +// SubUint64 adds Uint from another +func (u Uint) SubUint64(u2 uint64) Uint { return u.Sub(NewUint(u2)) } + +// Mul multiplies two Uints +func (u Uint) Mul(u2 Uint) (res Uint) { + return NewUintFromBigInt(new(big.Int).Mul(u.i, u2.i)) +} + +// Mul multiplies two Uints +func (u Uint) MulUint64(u2 uint64) (res Uint) { return u.Mul(NewUint(u2)) } + +// Quo divides Uint with Uint +func (u Uint) Quo(u2 Uint) (res Uint) { return NewUintFromBigInt(div(u.i, u2.i)) } + +// Mod returns remainder after dividing with Uint +func (u Uint) Mod(u2 Uint) Uint { + if u2.IsZero() { + panic("division-by-zero") + } + return Uint{mod(u.i, u2.i)} +} + +// Incr increments the Uint by one. +func (u Uint) Incr() Uint { + return u.Add(OneUint()) +} + +// Decr decrements the Uint by one. +// Decr will panic if the Uint is zero. +func (u Uint) Decr() Uint { + return u.Sub(OneUint()) +} + +// Quo divides Uint with uint64 +func (u Uint) QuoUint64(u2 uint64) Uint { return u.Quo(NewUint(u2)) } + +// Return the minimum of the Uints +func MinUint(u1, u2 Uint) Uint { return NewUintFromBigInt(min(u1.i, u2.i)) } + +// Return the maximum of the Uints +func MaxUint(u1, u2 Uint) Uint { return NewUintFromBigInt(max(u1.i, u2.i)) } + +// Human readable string +func (u Uint) String() string { return u.i.String() } + +// MarshalJSON defines custom encoding scheme +func (u Uint) MarshalJSON() ([]byte, error) { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return marshalJSON(u.i) +} + +// UnmarshalJSON defines custom decoding scheme +func (u *Uint) UnmarshalJSON(bz []byte) error { + if u.i == nil { // Necessary since default Uint initialization has i.i as nil + u.i = new(big.Int) + } + return unmarshalJSON(u.i, bz) +} + +// Marshal implements the gogo proto custom type interface. +func (u Uint) Marshal() ([]byte, error) { + if u.i == nil { + u.i = new(big.Int) + } + return u.i.MarshalText() +} + +// MarshalTo implements the gogo proto custom type interface. +func (u *Uint) MarshalTo(data []byte) (n int, err error) { + if u.i == nil { + u.i = new(big.Int) + } + if u.i.BitLen() == 0 { // The value 0 + n = copy(data, []byte{0x30}) + return n, nil + } + + bz, err := u.Marshal() + if err != nil { + return 0, err + } + + n = copy(data, bz) + return n, nil +} + +// Unmarshal implements the gogo proto custom type interface. +func (u *Uint) Unmarshal(data []byte) error { + if len(data) == 0 { + u = nil + return nil + } + + if u.i == nil { + u.i = new(big.Int) + } + + if err := u.i.UnmarshalText(data); err != nil { + return err + } + + // Finally check for overflow. + return UintOverflow(u.i) +} + +// Size implements the gogo proto custom type interface. +func (u *Uint) Size() int { + bz, _ := u.Marshal() + return len(bz) +} + +// Override Amino binary serialization by proxying to protobuf. +func (u Uint) MarshalAmino() ([]byte, error) { return u.Marshal() } +func (u *Uint) UnmarshalAmino(bz []byte) error { return u.Unmarshal(bz) } + +// UintOverflow returns true if a given unsigned integer overflows and false +// otherwise. +func UintOverflow(i *big.Int) error { + if i.Sign() < 0 { + return errors.New("non-positive integer") + } + + if g, w := i.BitLen(), MaxBitLen; g > w { + return fmt.Errorf("integer out of range; got: %d, max: %d", g, w) + } + return nil +} + +// ParseUint reads a string-encoded Uint value and return a Uint. +func ParseUint(s string) (Uint, error) { + i, ok := new(big.Int).SetString(s, 0) + if !ok { + return Uint{}, fmt.Errorf("cannot convert %q to big.Int", s) + } + return checkNewUint(i) +} + +func checkNewUint(i *big.Int) (Uint, error) { + if err := UintOverflow(i); err != nil { + return Uint{}, err + } + return Uint{new(big.Int).Set(i)}, nil +} + +// RelativePow raises x to the power of n, where x (and the result, z) are scaled by factor b +// for example, RelativePow(210, 2, 100) = 441 (2.1^2 = 4.41) +func RelativePow(x, n, b Uint) (z Uint) { + if x.IsZero() { + if n.IsZero() { + z = OneUint() // 0^0 = 1 + return z + } + z = ZeroUint() // otherwise 0^a = 0 + return z + } + + z = x + if n.Mod(NewUint(2)).Equal(ZeroUint()) { + z = b + } + + halfOfB := b.Quo(NewUint(2)) + n = n.Quo(NewUint(2)) + + for n.GT(ZeroUint()) { + xSquared := x.Mul(x) + xSquaredRounded := xSquared.Add(halfOfB) + + x = xSquaredRounded.Quo(b) + + if n.Mod(NewUint(2)).Equal(OneUint()) { + zx := z.Mul(x) + zxRounded := zx.Add(halfOfB) + z = zxRounded.Quo(b) + } + n = n.Quo(NewUint(2)) + } + return z +} diff --git a/uint_internal_test.go b/uint_internal_test.go new file mode 100644 index 0000000..a3f1fb4 --- /dev/null +++ b/uint_internal_test.go @@ -0,0 +1,54 @@ +package math + +import ( + "math/big" + "math/rand" + "strconv" + "testing" + + "github.com/stretchr/testify/suite" +) + +type uintInternalTestSuite struct { + suite.Suite +} + +func TestUintInternalTestSuite(t *testing.T) { + suite.Run(t, new(uintInternalTestSuite)) +} + +func (s *uintInternalTestSuite) SetupSuite() { + s.T().Parallel() +} + +func (s *uintInternalTestSuite) TestIdentUint() { + for d := 0; d < 1000; d++ { + n := rand.Uint64() + i := NewUint(n) + + ifromstr := NewUintFromString(strconv.FormatUint(n, 10)) + + cases := []uint64{ + i.Uint64(), + i.BigInt().Uint64(), + i.i.Uint64(), + ifromstr.Uint64(), + NewUintFromBigInt(new(big.Int).SetUint64(n)).Uint64(), + } + + for tcnum, tc := range cases { + s.Require().Equal(n, tc, "Uint is modified during conversion. tc #%d", tcnum) + } + } +} + +func (s *uintInternalTestSuite) TestUintSize() { + x := Uint{i: nil} + s.Require().Equal(1, x.Size()) + x = NewUint(0) + s.Require().Equal(1, x.Size()) + x = NewUint(10) + s.Require().Equal(2, x.Size()) + x = NewUint(100) + s.Require().Equal(3, x.Size()) +} diff --git a/uint_test.go b/uint_test.go new file mode 100644 index 0000000..b75b751 --- /dev/null +++ b/uint_test.go @@ -0,0 +1,381 @@ +package math_test + +import ( + "fmt" + "math" + "math/big" + "math/rand" + "strings" + "testing" + + "github.com/stretchr/testify/suite" + + sdkmath "cosmossdk.io/math" +) + +type uintTestSuite struct { + suite.Suite +} + +func TestUnitTestSuite(t *testing.T) { + suite.Run(t, new(uintTestSuite)) +} + +func (s *uintTestSuite) SetupSuite() { + s.T().Parallel() +} + +func (s *uintTestSuite) TestUintPanics() { + // Max Uint = 1.15e+77 + // Min Uint = 0 + u1 := sdkmath.NewUint(0) + u2 := sdkmath.OneUint() + + s.Require().Equal(uint64(0), u1.Uint64()) + s.Require().Equal(uint64(1), u2.Uint64()) + + s.Require().Panics(func() { sdkmath.NewUintFromBigInt(big.NewInt(-5)) }) + s.Require().Panics(func() { sdkmath.NewUintFromString("-1") }) + s.Require().NotPanics(func() { + s.Require().True(sdkmath.NewUintFromString("0").Equal(sdkmath.ZeroUint())) + s.Require().True(sdkmath.NewUintFromString("5").Equal(sdkmath.NewUint(5))) + }) + + // Overflow check + s.Require().True(u1.Add(u1).Equal(sdkmath.ZeroUint())) + s.Require().True(u1.Add(sdkmath.OneUint()).Equal(sdkmath.OneUint())) + s.Require().Equal(uint64(0), u1.Uint64()) + s.Require().Equal(uint64(1), sdkmath.OneUint().Uint64()) + s.Require().Panics(func() { u1.SubUint64(2) }) + s.Require().True(u1.SubUint64(0).Equal(sdkmath.ZeroUint())) + s.Require().True(u2.Add(sdkmath.OneUint()).Sub(sdkmath.OneUint()).Equal(sdkmath.OneUint())) // i2 == 1 + s.Require().True(u2.Add(sdkmath.OneUint()).Mul(sdkmath.NewUint(5)).Equal(sdkmath.NewUint(10))) // i2 == 10 + s.Require().True(sdkmath.NewUint(7).Quo(sdkmath.NewUint(2)).Equal(sdkmath.NewUint(3))) + s.Require().True(sdkmath.NewUint(0).Quo(sdkmath.NewUint(2)).Equal(sdkmath.ZeroUint())) + s.Require().True(sdkmath.NewUint(5).MulUint64(4).Equal(sdkmath.NewUint(20))) + s.Require().True(sdkmath.NewUint(5).MulUint64(0).Equal(sdkmath.ZeroUint())) + + uintmax := sdkmath.NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))) + uintmin := sdkmath.ZeroUint() + + // divs by zero + s.Require().Panics(func() { sdkmath.OneUint().Mul(sdkmath.ZeroUint().SubUint64(uint64(1))) }) + s.Require().Panics(func() { sdkmath.OneUint().QuoUint64(0) }) + s.Require().Panics(func() { sdkmath.OneUint().Quo(sdkmath.ZeroUint()) }) + s.Require().Panics(func() { sdkmath.ZeroUint().QuoUint64(0) }) + s.Require().Panics(func() { sdkmath.OneUint().Quo(sdkmath.ZeroUint().Sub(sdkmath.OneUint())) }) + s.Require().Panics(func() { uintmax.Add(sdkmath.OneUint()) }) + s.Require().Panics(func() { uintmax.Incr() }) + s.Require().Panics(func() { uintmin.Sub(sdkmath.OneUint()) }) + s.Require().Panics(func() { uintmin.Decr() }) + + s.Require().NotPanics(func() { sdkmath.Uint{}.BigInt() }) + + s.Require().Equal(uint64(0), sdkmath.MinUint(sdkmath.ZeroUint(), sdkmath.OneUint()).Uint64()) + s.Require().Equal(uint64(1), sdkmath.MaxUint(sdkmath.ZeroUint(), sdkmath.OneUint()).Uint64()) + + // comparison ops + s.Require().True( + sdkmath.OneUint().GT(sdkmath.ZeroUint()), + ) + s.Require().False( + sdkmath.OneUint().LT(sdkmath.ZeroUint()), + ) + s.Require().True( + sdkmath.OneUint().GTE(sdkmath.ZeroUint()), + ) + s.Require().False( + sdkmath.OneUint().LTE(sdkmath.ZeroUint()), + ) + + s.Require().False(sdkmath.ZeroUint().GT(sdkmath.OneUint())) + s.Require().True(sdkmath.ZeroUint().LT(sdkmath.OneUint())) + s.Require().False(sdkmath.ZeroUint().GTE(sdkmath.OneUint())) + s.Require().True(sdkmath.ZeroUint().LTE(sdkmath.OneUint())) +} + +func (s *uintTestSuite) TestIsNil() { + s.Require().False(sdkmath.OneUint().IsNil()) + s.Require().True(sdkmath.Uint{}.IsNil()) +} + +func (s *uintTestSuite) TestArithUint() { + for d := 0; d < 1000; d++ { + n1 := uint64(rand.Uint32()) + u1 := sdkmath.NewUint(n1) + n2 := uint64(rand.Uint32()) + u2 := sdkmath.NewUint(n2) + + cases := []struct { + ures sdkmath.Uint + nres uint64 + }{ + {u1.Add(u2), n1 + n2}, + {u1.Mul(u2), n1 * n2}, + {u1.Quo(u2), n1 / n2}, + {u1.AddUint64(n2), n1 + n2}, + {u1.MulUint64(n2), n1 * n2}, + {u1.QuoUint64(n2), n1 / n2}, + {sdkmath.MinUint(u1, u2), minuint(n1, n2)}, + {sdkmath.MaxUint(u1, u2), maxuint(n1, n2)}, + {u1.Incr(), n1 + 1}, + } + + for tcnum, tc := range cases { + s.Require().Equal(tc.nres, tc.ures.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum) + } + + if n2 > n1 { + n1, n2 = n2, n1 + u1, u2 = sdkmath.NewUint(n1), sdkmath.NewUint(n2) + } + + subs := []struct { + ures sdkmath.Uint + nres uint64 + }{ + {u1.Sub(u2), n1 - n2}, + {u1.SubUint64(n2), n1 - n2}, + {u1.Decr(), n1 - 1}, + } + + for tcnum, tc := range subs { + s.Require().Equal(tc.nres, tc.ures.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum) + } + } +} + +func (s *uintTestSuite) TestCompUint() { + for d := 0; d < 10000; d++ { + n1 := rand.Uint64() + i1 := sdkmath.NewUint(n1) + n2 := rand.Uint64() + i2 := sdkmath.NewUint(n2) + + cases := []struct { + ires bool + nres bool + }{ + {i1.Equal(i2), n1 == n2}, + {i1.GT(i2), n1 > n2}, + {i1.LT(i2), n1 < n2}, + {i1.GTE(i2), !i1.LT(i2)}, + {!i1.GTE(i2), i1.LT(i2)}, + {i1.LTE(i2), n1 <= n2}, + {i2.LTE(i1), n2 <= n1}, + } + + for tcnum, tc := range cases { + s.Require().Equal(tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum) + } + } +} + +func (s *uintTestSuite) TestImmutabilityAllUint() { + ops := []func(*sdkmath.Uint){ + func(i *sdkmath.Uint) { _ = i.Add(sdkmath.NewUint(rand.Uint64())) }, + func(i *sdkmath.Uint) { _ = i.Sub(sdkmath.NewUint(rand.Uint64() % i.Uint64())) }, + func(i *sdkmath.Uint) { _ = i.Mul(randuint()) }, + func(i *sdkmath.Uint) { _ = i.Quo(randuint()) }, + func(i *sdkmath.Uint) { _ = i.AddUint64(rand.Uint64()) }, + func(i *sdkmath.Uint) { _ = i.SubUint64(rand.Uint64() % i.Uint64()) }, + func(i *sdkmath.Uint) { _ = i.MulUint64(rand.Uint64()) }, + func(i *sdkmath.Uint) { _ = i.QuoUint64(rand.Uint64()) }, + func(i *sdkmath.Uint) { _ = i.IsZero() }, + func(i *sdkmath.Uint) { _ = i.Equal(randuint()) }, + func(i *sdkmath.Uint) { _ = i.GT(randuint()) }, + func(i *sdkmath.Uint) { _ = i.GTE(randuint()) }, + func(i *sdkmath.Uint) { _ = i.LT(randuint()) }, + func(i *sdkmath.Uint) { _ = i.LTE(randuint()) }, + func(i *sdkmath.Uint) { _ = i.String() }, + func(i *sdkmath.Uint) { _ = i.Incr() }, + func(i *sdkmath.Uint) { + if i.IsZero() { + return + } + + _ = i.Decr() + }, + } + + for i := 0; i < 1000; i++ { + n := rand.Uint64() + ni := sdkmath.NewUint(n) + + for opnum, op := range ops { + op(&ni) + + s.Require().Equal(n, ni.Uint64(), "Uint is modified by operation. #%d", opnum) + s.Require().Equal(sdkmath.NewUint(n), ni, "Uint is modified by operation. #%d", opnum) + } + } +} + +func (s *uintTestSuite) TestSafeSub() { + testCases := []struct { + x, y sdkmath.Uint + expected uint64 + panic bool + }{ + {sdkmath.NewUint(0), sdkmath.NewUint(0), 0, false}, + {sdkmath.NewUint(10), sdkmath.NewUint(5), 5, false}, + {sdkmath.NewUint(5), sdkmath.NewUint(10), 5, true}, + {sdkmath.NewUint(math.MaxUint64), sdkmath.NewUint(0), math.MaxUint64, false}, + } + + for i, tc := range testCases { + tc := tc + if tc.panic { + s.Require().Panics(func() { tc.x.Sub(tc.y) }) + continue + } + s.Require().Equal( + tc.expected, tc.x.Sub(tc.y).Uint64(), + "invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i, + ) + } +} + +func (s *uintTestSuite) TestParseUint() { + type args struct { + s string + } + tests := []struct { + name string + args args + want sdkmath.Uint + wantErr bool + }{ + {"malformed", args{"malformed"}, sdkmath.Uint{}, true}, + {"empty", args{""}, sdkmath.Uint{}, true}, + {"positive", args{"50"}, sdkmath.NewUint(uint64(50)), false}, + {"negative", args{"-1"}, sdkmath.Uint{}, true}, + {"zero", args{"0"}, sdkmath.ZeroUint(), false}, + } + for _, tt := range tests { + got, err := sdkmath.ParseUint(tt.args.s) + if tt.wantErr { + s.Require().Error(err) + continue + } + s.Require().NoError(err) + s.Require().True(got.Equal(tt.want)) + } +} + +func (s *uintTestSuite) TestNewUintFromBigInt() { + r := big.NewInt(42) + i := sdkmath.NewUintFromBigInt(r) + s.Require().Equal(r, i.BigInt()) + + // modify r and ensure i doesn't change + r = r.SetInt64(100) + s.Require().NotEqual(r, i.BigInt()) +} + +func randuint() sdkmath.Uint { + return sdkmath.NewUint(rand.Uint64()) +} + +func (s *uintTestSuite) TestRelativePow() { + tests := []struct { + args []sdkmath.Uint + want sdkmath.Uint + }{ + {[]sdkmath.Uint{sdkmath.ZeroUint(), sdkmath.ZeroUint(), sdkmath.OneUint()}, sdkmath.OneUint()}, + {[]sdkmath.Uint{sdkmath.ZeroUint(), sdkmath.ZeroUint(), sdkmath.NewUint(10)}, sdkmath.NewUint(1)}, + {[]sdkmath.Uint{sdkmath.ZeroUint(), sdkmath.OneUint(), sdkmath.NewUint(10)}, sdkmath.ZeroUint()}, + {[]sdkmath.Uint{sdkmath.NewUint(10), sdkmath.NewUint(2), sdkmath.OneUint()}, sdkmath.NewUint(100)}, + {[]sdkmath.Uint{sdkmath.NewUint(210), sdkmath.NewUint(2), sdkmath.NewUint(100)}, sdkmath.NewUint(441)}, + {[]sdkmath.Uint{sdkmath.NewUint(2100), sdkmath.NewUint(2), sdkmath.NewUint(1000)}, sdkmath.NewUint(4410)}, + {[]sdkmath.Uint{sdkmath.NewUint(1000000001547125958), sdkmath.NewUint(600), sdkmath.NewUint(1000000000000000000)}, sdkmath.NewUint(1000000928276004850)}, + } + for i, tc := range tests { + res := sdkmath.RelativePow(tc.args[0], tc.args[1], tc.args[2]) + s.Require().Equal(tc.want, res, "unexpected result for test case %d, input: %v, got: %v", i, tc.args, res) + } +} + +func minuint(i1, i2 uint64) uint64 { + if i1 < i2 { + return i1 + } + return i2 +} + +func maxuint(i1, i2 uint64) uint64 { + if i1 > i2 { + return i1 + } + return i2 +} + +func TestRoundTripMarshalToUint(t *testing.T) { + values := []uint64{ + 0, + 1, + 1 << 10, + 1<<10 - 3, + 1<<63 - 1, + 1<<32 - 7, + 1<<22 - 8, + math.MaxUint64, + } + + for _, value := range values { + value := value + t.Run(fmt.Sprintf("%d", value), func(t *testing.T) { + t.Parallel() + + var scratch [20]byte + uv := sdkmath.NewUint(value) + n, err := uv.MarshalTo(scratch[:]) + if err != nil { + t.Fatal(err) + } + rt := new(sdkmath.Uint) + if err := rt.Unmarshal(scratch[:n]); err != nil { + t.Fatal(err) + } + if !rt.Equal(uv) { + t.Fatalf("roundtrip=%q != original=%q", rt, uv) + } + }) + } +} + +func TestWeakUnmarshalNegativeSign(t *testing.T) { + neg10, _ := new(big.Int).SetString("-10", 0) + blob, err := neg10.MarshalText() + if err != nil { + t.Fatal(err) + } + + ui := new(sdkmath.Uint) + err = ui.Unmarshal(blob) + if err == nil { + t.Fatal("Failed to catch the negative value") + } + if errStr := err.Error(); !strings.Contains(errStr, "non-positive") { + t.Fatalf("negative value not reported, got instead %q", errStr) + } +} + +func TestWeakUnmarshalOverflow(t *testing.T) { + exp := new(big.Int).SetUint64(256) + pos10, _ := new(big.Int).SetString("10", 0) + exp10Pow256 := new(big.Int).Exp(pos10, exp, nil) + blob, err := exp10Pow256.MarshalText() + if err != nil { + t.Fatal(err) + } + + ui := new(sdkmath.Uint) + err = ui.Unmarshal(blob) + if err == nil { + t.Fatal("Failed to catch the overflowed value") + } + if errStr := err.Error(); !strings.Contains(errStr, "out of range") { + t.Fatalf("out of range value not reported, got instead %q", errStr) + } +} diff --git a/unsafe/rand.go b/unsafe/rand.go new file mode 100644 index 0000000..66d48ae --- /dev/null +++ b/unsafe/rand.go @@ -0,0 +1,147 @@ +package unsafe + +import ( + crand "crypto/rand" + mrand "math/rand" + "sync" +) + +const ( + strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters +) + +// Rand is a prng, that is seeded with OS randomness. +// The OS randomness is obtained from crypto/rand, however none of the provided +// methods are suitable for cryptographic usage. +// They all utilize math/rand's prng internally. +// +// All of the methods here are suitable for concurrent use. +// This is achieved by using a mutex lock on all of the provided methods. +type Rand struct { + sync.Mutex + rand *mrand.Rand +} + +var grand *Rand + +func init() { + grand = NewRand() +} + +func NewRand() *Rand { + rand := &Rand{} + rand.init() + return rand +} + +func (r *Rand) init() { + bz := cRandBytes(8) + var seed uint64 + for i := 0; i < 8; i++ { + seed |= uint64(bz[i]) + seed <<= 8 + } + r.reset(int64(seed)) +} + +func (r *Rand) reset(seed int64) { + r.rand = mrand.New(mrand.NewSource(seed)) +} + +//---------------------------------------- +// Global functions + +func Seed(seed int64) { + grand.Seed(seed) +} + +func Str(length int) string { + return grand.Str(length) +} + +func Int63() int64 { + return grand.Int63() +} + +func Int() int { + return grand.Int() +} + +func Bytes(n int) []byte { + return grand.Bytes(n) +} + +//---------------------------------------- +// Rand methods + +func (r *Rand) Seed(seed int64) { + r.Lock() + r.reset(seed) + r.Unlock() +} + +// Str constructs a random alphanumeric string of given length. +func (r *Rand) Str(length int) string { + if length <= 0 { + return "" + } + + chars := []byte{} +MAIN_LOOP: + for { + val := r.Int63() + for i := 0; i < 10; i++ { + v := int(val & 0x3f) // rightmost 6 bits + if v >= 62 { // only 62 characters in strChars + val >>= 6 + continue + } else { + chars = append(chars, strChars[v]) + if len(chars) == length { + break MAIN_LOOP + } + val >>= 6 + } + } + } + + return string(chars) +} + +func (r *Rand) Int63() int64 { + r.Lock() + i63 := r.rand.Int63() + r.Unlock() + return i63 +} + +func (r *Rand) Int() int { + r.Lock() + i := r.rand.Int() + r.Unlock() + return i +} + +// Bytes returns n random bytes generated from the internal +// prng. +func (r *Rand) Bytes(n int) []byte { + // cRandBytes isn't guaranteed to be fast so instead + // use random bytes generated from the internal PRNG + bs := make([]byte, n) + for i := 0; i < len(bs); i++ { + bs[i] = byte(r.Int() & 0xFF) + } + return bs +} + +// NOTE: This relies on the os's random number generator. +// For real security, we should salt that with some seed. +// See github.com/cometbft/cometbft/crypto for a more secure reader. +func cRandBytes(numBytes int) []byte { + b := make([]byte, numBytes) + _, err := crand.Read(b) + if err != nil { + panic(err) + } + return b +}