diff --git a/calculator.go b/calculator.go index 480862c..56b50e0 100644 --- a/calculator.go +++ b/calculator.go @@ -1,73 +1,52 @@ package money -import "math" +import ( + "github.com/shopspring/decimal" +) type calculator struct{} func (c *calculator) add(a, b Amount) Amount { - return a + b + return a.Add(b) } func (c *calculator) subtract(a, b Amount) Amount { - return a - b + return a.Sub(b) } func (c *calculator) multiply(a Amount, m int64) Amount { - return a * m + return a.Mul(decimal.NewFromInt(m)) } func (c *calculator) divide(a Amount, d int64) Amount { - return a / d + return a.Div(decimal.NewFromInt(d)) } func (c *calculator) modulus(a Amount, d int64) Amount { - return a % d + return a.Mod(decimal.NewFromInt(d)) } func (c *calculator) allocate(a Amount, r, s uint) Amount { - if a == 0 || s == 0 { - return 0 + if a.IsZero() || s == 0 { + return decimal.Zero } - return a * int64(r) / int64(s) + res := a.Mul(decimal.NewFromInt(int64(r))).Div(decimal.NewFromInt(int64(s))).IntPart() + return decimal.NewFromInt(res) } func (c *calculator) absolute(a Amount) Amount { - if a < 0 { - return -a - } - - return a + return a.Abs() } func (c *calculator) negative(a Amount) Amount { - if a > 0 { - return -a + if a.IsPositive() { + return a.Mul(decimal.NewFromInt(-1)) } return a } func (c *calculator) round(a Amount, e int) Amount { - if a == 0 { - return 0 - } - - absam := c.absolute(a) - exp := int64(math.Pow(10, float64(e))) - m := absam % exp - - if m > (exp / 2) { - absam += exp - } - - absam = (absam / exp) * exp - - if a < 0 { - a = -absam - } else { - a = absam - } - - return a + return a.Round(int32(e * -1)) } diff --git a/go.mod b/go.mod index 6d4223b..f9a9b34 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/Rhymond/go-money go 1.13 + +require github.com/shopspring/decimal v1.3.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3289fec --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= diff --git a/money.go b/money.go index 28fc431..e836b5d 100644 --- a/money.go +++ b/money.go @@ -5,13 +5,15 @@ import ( "encoding/json" "errors" "fmt" - "math" + + "github.com/shopspring/decimal" ) // Injection points for backward compatibility. // If you need to keep your JSON marshal/unmarshal way, overwrite them like below. -// money.UnmarshalJSON = func (m *Money, b []byte) error { ... } -// money.MarshalJSON = func (m Money) ([]byte, error) { ... } +// +// money.UnmarshalJSON = func (m *Money, b []byte) error { ... } +// money.MarshalJSON = func (m Money) ([]byte, error) { ... } var ( // UnmarshalJSON is injection point of json.Unmarshaller for money.Money UnmarshalJSON = defaultUnmarshalJSON @@ -69,7 +71,7 @@ func defaultMarshalJSON(m Money) ([]byte, error) { } // Amount is a data structure that stores the amount being used for calculations. -type Amount = int64 +type Amount = decimal.Decimal // Money represents monetary value information, stores // currency and amount value. @@ -81,7 +83,7 @@ type Money struct { // New creates and returns new instance of Money. func New(amount int64, code string) *Money { return &Money{ - amount: amount, + amount: decimal.NewFromInt(amount), currency: newCurrency(code).get(), } } @@ -89,8 +91,9 @@ func New(amount int64, code string) *Money { // NewFromFloat creates and returns new instance of Money from a float64. // Always rounding trailing decimals down. func NewFromFloat(amount float64, currency string) *Money { - currencyDecimals := math.Pow10(GetCurrency(currency).Fraction) - return New(int64(amount*currencyDecimals), currency) + c := GetCurrency(currency) + amt := decimal.NewFromFloat(amount).Mul(decimal.New(1, int32(c.Fraction))) + return New(amt.IntPart(), currency) } // Currency returns the currency used by Money. @@ -100,7 +103,7 @@ func (m *Money) Currency() *Currency { // Amount returns a copy of the internal monetary value as an int64. func (m *Money) Amount() int64 { - return m.amount + return m.amount.IntPart() } // SameCurrency check if given Money is equals by currency. @@ -117,14 +120,7 @@ func (m *Money) assertSameCurrency(om *Money) error { } func (m *Money) compare(om *Money) int { - switch { - case m.amount > om.amount: - return 1 - case m.amount < om.amount: - return -1 - } - - return 0 + return m.amount.Cmp(om.amount) } // Equals checks equality between two Money types. @@ -174,17 +170,17 @@ func (m *Money) LessThanOrEqual(om *Money) (bool, error) { // IsZero returns boolean of whether the value of Money is equals to zero. func (m *Money) IsZero() bool { - return m.amount == 0 + return m.amount.IsZero() } // IsPositive returns boolean of whether the value of Money is positive. func (m *Money) IsPositive() bool { - return m.amount > 0 + return m.amount.IsPositive() } // IsNegative returns boolean of whether the value of Money is negative. func (m *Money) IsNegative() bool { - return m.amount < 0 + return m.amount.IsNegative() } // Absolute returns new Money struct from given Money using absolute monetary value. @@ -245,12 +241,12 @@ func (m *Money) Split(n int) ([]*Money, error) { // Add leftovers to the first parties. v := int64(1) - if m.amount < 0 { + if m.amount.IsNegative() { v = -1 } - for p := 0; l != 0; p++ { - ms[p].amount = mutate.calc.add(ms[p].amount, v) - l-- + for p := 0; !l.IsZero(); p++ { + ms[p].amount = mutate.calc.add(ms[p].amount, decimal.NewFromInt(v)) + l = l.Sub(decimal.NewFromInt(1)) } return ms, nil @@ -273,7 +269,7 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { sum += uint(r) } - var total int64 + var total decimal.Decimal ms := make([]*Money, 0, len(rs)) for _, r := range rs { party := &Money{ @@ -282,7 +278,7 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { } ms = append(ms, party) - total += party.amount + total = total.Add(party.amount) } // if the sum of all ratios is zero, then we just returns zeros and don't do anything @@ -292,15 +288,15 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { } // Calculate leftover value and divide to first parties. - lo := m.amount - total - sub := int64(1) - if lo < 0 { - sub = -sub + lo := m.amount.Sub(total) + sub := decimal.NewFromInt(1) + if lo.IsNegative() { + sub = sub.Mul(decimal.NewFromInt(-1)) } - for p := 0; lo != 0; p++ { + for p := 0; !lo.IsZero(); p++ { ms[p].amount = mutate.calc.add(ms[p].amount, sub) - lo -= sub + lo = lo.Sub(sub) } return ms, nil @@ -309,13 +305,13 @@ func (m *Money) Allocate(rs ...int) ([]*Money, error) { // Display lets represent Money struct as string in given Currency value. func (m *Money) Display() string { c := m.currency.get() - return c.Formatter().Format(m.amount) + return c.Formatter().Format(m.amount.IntPart()) } // AsMajorUnits lets represent Money struct as subunits (float64) in given Currency value func (m *Money) AsMajorUnits() float64 { c := m.currency.get() - return c.Formatter().ToMajorUnits(m.amount) + return c.Formatter().ToMajorUnits(m.amount.IntPart()) } // UnmarshalJSON is implementation of json.Unmarshaller @@ -329,13 +325,15 @@ func (m Money) MarshalJSON() ([]byte, error) { } // Compare function compares two money of the same type -// if m.amount > om.amount returns (1, nil) -// if m.amount == om.amount returns (0, nil -// if m.amount < om.amount returns (-1, nil) +// +// if m.amount > om.amount returns (1, nil) +// if m.amount == om.amount returns (0, nil +// if m.amount < om.amount returns (-1, nil) +// // If compare moneys from distinct currency, return (m.amount, ErrCurrencyMismatch) func (m *Money) Compare(om *Money) (int, error) { if err := m.assertSameCurrency(om); err != nil { - return int(m.amount), err + return int(m.amount.IntPart()), err } return m.compare(om), nil diff --git a/money_example_test.go b/money_example_test.go index 8e948dd..5dda929 100644 --- a/money_example_test.go +++ b/money_example_test.go @@ -40,6 +40,14 @@ func ExampleNew() { // £1.00 } +func ExampleNewFromFloat() { + amount := 136.98 + fmt.Println(money.NewFromFloat(amount, "SGD").Display()) + + // Output: + // $136.98 +} + func ExampleMoney_comparisons() { pound := money.New(100, "GBP") twoPounds := money.New(200, "GBP") diff --git a/money_test.go b/money_test.go index e336998..c85f6ad 100644 --- a/money_test.go +++ b/money_test.go @@ -7,12 +7,14 @@ import ( "fmt" "reflect" "testing" + + "github.com/shopspring/decimal" ) func TestNew(t *testing.T) { m := New(1, EUR) - if m.amount != 1 { + if !m.amount.Equal(decimal.NewFromInt(1)) { t.Errorf("Expected %d got %d", 1, m.amount) } @@ -22,7 +24,7 @@ func TestNew(t *testing.T) { m = New(-100, EUR) - if m.amount != -100 { + if !m.amount.Equal(decimal.NewFromInt(-100)) { t.Errorf("Expected %d got %d", -100, m.amount) } } @@ -255,7 +257,7 @@ func TestMoney_Absolute(t *testing.T) { m := New(tc.amount, EUR) r := m.Absolute().amount - if r != tc.expected { + if !r.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected absolute %d to be %d got %d", m.amount, tc.expected, r) } @@ -276,7 +278,7 @@ func TestMoney_Negative(t *testing.T) { m := New(tc.amount, EUR) r := m.Negative().amount - if r != tc.expected { + if !r.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected absolute %d to be %d got %d", m.amount, tc.expected, r) } @@ -340,7 +342,7 @@ func TestMoney_Subtract(t *testing.T) { t.Error(err) } - if r.amount != tc.expected { + if !r.amount.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected %d - %d = %d got %d", tc.amount1, tc.amount2, tc.expected, r.amount) } @@ -373,7 +375,7 @@ func TestMoney_Multiply(t *testing.T) { m := New(tc.amount, EUR) r := m.Multiply(tc.multiplier).amount - if r != tc.expected { + if !r.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected %d * %d = %d got %d", tc.amount, tc.multiplier, tc.expected, r) } } @@ -397,7 +399,7 @@ func TestMoney_Round(t *testing.T) { m := New(tc.amount, EUR) r := m.Round().amount - if r != tc.expected { + if !r.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected rounded %d to be %d got %d", tc.amount, tc.expected, r) } } @@ -416,7 +418,7 @@ func TestMoney_RoundWithExponential(t *testing.T) { m := New(tc.amount, "CUR") r := m.Round().amount - if r != tc.expected { + if !r.Equal(decimal.NewFromInt(tc.expected)) { t.Errorf("Expected rounded %d to be %d got %d", tc.amount, tc.expected, r) } } @@ -442,7 +444,7 @@ func TestMoney_Split(t *testing.T) { split, _ := m.Split(tc.split) for _, party := range split { - rs = append(rs, party.amount) + rs = append(rs, party.amount.IntPart()) } if !reflect.DeepEqual(tc.expected, rs) { @@ -482,7 +484,7 @@ func TestMoney_Allocate(t *testing.T) { split, _ := m.Allocate(tc.ratios...) for _, party := range split { - rs = append(rs, party.amount) + rs = append(rs, party.amount.IntPart()) } if !reflect.DeepEqual(tc.expected, rs) { @@ -657,7 +659,7 @@ func TestMoney_Amount(t *testing.T) { func TestNewFromFloat(t *testing.T) { m := NewFromFloat(12.34, EUR) - if m.amount != 1234 { + if !m.amount.Equal(decimal.NewFromInt(1234)) { t.Errorf("Expected %d got %d", 1234, m.amount) } @@ -667,7 +669,7 @@ func TestNewFromFloat(t *testing.T) { m = NewFromFloat(-0.125, EUR) - if m.amount != -12 { + if !m.amount.Equal(decimal.NewFromInt(-12)) { t.Errorf("Expected %d got %d", -12, m.amount) } }