Skip to content

Commit

Permalink
🔧 fix(coinbase): bug fix and add tests for API interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Feb 26, 2025
1 parent 5137f42 commit 910e1c9
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 41 deletions.
23 changes: 11 additions & 12 deletions pkg/exchange/coinbase/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func toGlobalOrder(cbOrder *api.Order) types.Order {
UUID: cbOrder.ID,
OrderID: FNV64a(cbOrder.ID),
OriginalStatus: string(cbOrder.Status),
CreationTime: types.Time(cbOrder.CreatedAt),
CreationTime: cbOrder.CreatedAt,
}
}

Expand Down Expand Up @@ -60,18 +60,17 @@ func FNV64a(text string) uint64 {
return hash.Sum64()
}

func toGlobalKline(symbol string, granity string, candle *api.Candle) types.KLine {
func toGlobalKline(symbol string, interval types.Interval, candle *api.Candle) types.KLine {
kline := types.KLine{
Exchange: types.ExchangeCoinBase,
Symbol: symbol,
StartTime: types.Time(candle.Time),
EndTime: types.Time(time.Time(candle.Time).Add(types.Interval(granity).Duration())),
Interval: types.Interval(granity),
Open: candle.Open,
Close: candle.Close,
High: candle.High,
Low: candle.Low,
Volume: candle.Volume,
Exchange: types.ExchangeCoinBase,
Symbol: symbol,
EndTime: types.Time(candle.Time),
Interval: interval,
Open: candle.Open,
Close: candle.Close,
High: candle.High,
Low: candle.Low,
Volume: candle.Volume,
}
return kline
}
Expand Down
72 changes: 43 additions & 29 deletions pkg/exchange/coinbase/exchage.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,20 @@ func (e *Exchange) SubmitOrder(ctx context.Context, order types.SubmitOrder) (cr
req.Price(order.Price)
}

// set time in force
switch order.TimeInForce {
case types.TimeInForceGTC:
req.TimeInForce("GTC")
case types.TimeInForceIOC:
req.TimeInForce("IOC")
case types.TimeInForceFOK:
req.TimeInForce("FOK")
case types.TimeInForceGTT:
req.TimeInForce("GTT")
default:
return nil, fmt.Errorf("unsupported time in force: %v", order.TimeInForce)
// set time in force, using default if not set
if len(order.TimeInForce) > 0 {
switch order.TimeInForce {
case types.TimeInForceGTC:
req.TimeInForce("GTC")
case types.TimeInForceIOC:
req.TimeInForce("IOC")
case types.TimeInForceFOK:
req.TimeInForce("FOK")
case types.TimeInForceGTT:
req.TimeInForce("GTT")
default:
return nil, fmt.Errorf("unsupported time in force: %v", order.TimeInForce)
}
}
// client order id
if len(order.ClientOrderID) > 0 {
Expand Down Expand Up @@ -286,7 +288,7 @@ func (e *Exchange) QueryTickers(ctx context.Context, symbol ...string) (map[stri
for _, s := range symbol {
ticker, err := e.QueryTicker(ctx, s)
if err != nil {
return nil, errors.Wrap(err, "failed to get tickers")
return nil, errors.Wrapf(err, "failed to get ticker for %v", s)
}
tickers[s] = *ticker
}
Expand All @@ -305,29 +307,38 @@ func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval type
log.Warnf("limit %d is greater than the maximum limit 300, set to 300", options.Limit)
options.Limit = DefaultKLineLimit
}
granity := interval.String()
granity := fmt.Sprintf("%d", interval.Seconds())
req := e.client.NewGetCandlesRequest().ProductID(toLocalSymbol(symbol)).Granularity(granity)
if options.StartTime != nil {
req.Start(*options.StartTime)
}
if options.EndTime != nil {
req.End(*options.EndTime)
}
rawCandles, err := req.Do(ctx)
res, err := req.Do(ctx)
if err != nil {
return nil, errors.Wrapf(err, "failed to get klines(%v): %v", interval, symbol)
}
if len(rawCandles) > options.Limit {
rawCandles = rawCandles[:options.Limit]
}
klines := make([]types.KLine, 0, len(rawCandles))
for _, rawCandle := range rawCandles {
candle, err := rawCandle.Candle()
candles := make([]api.Candle, 0, len(res))
for _, c := range res {
candle, err := c.Candle()
if err != nil {
log.Warnf("invalid raw candle detected, skipped: %v", rawCandle)
log.Warnf("invalid raw candle detected, skipping: %v", c)
continue
}
klines = append(klines, toGlobalKline(symbol, granity, candle))
candles = append(candles, *candle)
}
numCandles := len(candles)
klines := make([]types.KLine, 0, numCandles)
if numCandles > 0 {
for idx, candle := range candles {
kline := toGlobalKline(symbol, interval, &candle)
klines = append(klines, kline)
if idx > 0 {
klines[idx-1].StartTime = kline.EndTime
}
}
klines[numCandles-1].StartTime = types.Time(klines[numCandles-1].EndTime.Time().Add(-interval.Duration()))
}
return klines, nil
}
Expand All @@ -344,7 +355,7 @@ func (e *Exchange) QueryOrder(ctx context.Context, q types.OrderQuery) (*types.O
}

func (e *Exchange) QueryOrderTrades(ctx context.Context, q types.OrderQuery) ([]types.Trade, error) {
cbTrades, err := e.queryOrderTradesByPagination(ctx, q.OrderID)
cbTrades, err := e.queryOrderTradesByPagination(ctx, q)
if err != nil {
return nil, errors.Wrapf(err, "failed to get order trades: %v", q.OrderID)
}
Expand All @@ -355,15 +366,18 @@ func (e *Exchange) QueryOrderTrades(ctx context.Context, q types.OrderQuery) ([]
return trades, nil
}

func (e *Exchange) queryOrderTradesByPagination(ctx context.Context, orderID string) (api.TradeSnapshot, error) {
req := e.client.NewGetOrderTradesRequest().OrderID(orderID).Limit(PaginationLimit)
func (e *Exchange) queryOrderTradesByPagination(ctx context.Context, q types.OrderQuery) (api.TradeSnapshot, error) {
req := e.client.NewGetOrderTradesRequest().Limit(PaginationLimit)
if len(q.OrderID) > 0 {
req.OrderID(q.OrderID)
}
if len(q.Symbol) > 0 {
req.ProductID(toLocalSymbol(q.Symbol))
}
cbTrades, err := req.Do(ctx)
if err != nil {
return nil, err
}
if len(cbTrades) < PaginationLimit {
return cbTrades, nil
}

if len(cbTrades) < PaginationLimit {
return cbTrades, nil
Expand Down
161 changes: 161 additions & 0 deletions pkg/exchange/coinbase/exchange_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package coinbase

import (
"context"
"os"
"strconv"
"testing"

"github.com/stretchr/testify/assert"

"github.com/c9s/bbgo/pkg/fixedpoint"
"github.com/c9s/bbgo/pkg/testutil"
"github.com/c9s/bbgo/pkg/types"
)

func Test_new(t *testing.T) {
ex := getExchangeOrSkip(t)
assert.Equal(t, ex.Name(), types.ExchangeCoinBase)
t.Log("successfully created coinbase exchange client")
_ = ex.SupportedInterval()
_ = ex.PlatformFeeCurrency()
}

func Test_OrdersAPI(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()

// should fail on unsupported symbol
order, err := ex.SubmitOrder(
ctx,
types.SubmitOrder{
Market: types.Market{
Symbol: "NOTEXIST",
},
Side: types.SideTypeBuy,
Type: types.OrderTypeLimit,
Price: fixedpoint.MustNewFromString("0.001"),
Quantity: fixedpoint.MustNewFromString("0.001"),
})
assert.Error(t, err)
assert.Empty(t, order)
// should succeed
order, err = ex.SubmitOrder(
ctx,
types.SubmitOrder{
Market: types.Market{
Symbol: "ETHUSD",
},
Side: types.SideTypeBuy,
Type: types.OrderTypeLimit,
Price: fixedpoint.MustNewFromString("0.01"),
Quantity: fixedpoint.MustNewFromString("100"), // min funds is $1
})
assert.NoError(t, err)
assert.NotEmpty(t, order)

// test query open orders
order, err = ex.QueryOrder(ctx, types.OrderQuery{Symbol: "ETHUSD", OrderID: order.UUID, ClientOrderID: order.UUID})
assert.NoError(t, err)

orders, err := ex.QueryOpenOrders(ctx, "ETHUSD")
assert.NoError(t, err)
found := false
for _, o := range orders {
if o.UUID == order.UUID {
found = true
break
}
}
assert.True(t, found)

// test cancel order
err = ex.CancelOrders(ctx, types.Order{
Exchange: types.ExchangeCoinBase,
UUID: order.UUID,
})
assert.NoError(t, err)
}

func Test_QueryAccount(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
_, err := ex.QueryAccount(ctx)
assert.NoError(t, err)
}

func Test_QueryAccountBalances(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
_, err := ex.QueryAccountBalances(ctx)
assert.NoError(t, err)
}

func Test_QueryOpenOrders(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()

symbols := []string{"BTCUSD", "ETHUSD", "ETHBTC"}
for _, k := range symbols {
_, err := ex.QueryOpenOrders(ctx, k)
assert.NoError(t, err)
}
}

func Test_QueryMarkets(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
_, err := ex.QueryMarkets(ctx)
assert.NoError(t, err)
}

func Test_QueryTicker(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
ticker, err := ex.QueryTicker(ctx, "BTCUSD")
assert.NoError(t, err)
assert.NotNil(t, ticker)
}

func Test_QueryTickers(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
symbols := []string{"BTCUSD", "ETHUSD", "ETHBTC"}
tickers, err := ex.QueryTickers(ctx, symbols...)
assert.NoError(t, err)
assert.NotNil(t, tickers)
}

func Test_QueryKLines(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()
// should fail on unsupported interval
_, err := ex.QueryKLines(ctx, "BTCUSD", types.Interval12h, types.KLineQueryOptions{})
assert.Error(t, err)

klines, err := ex.QueryKLines(ctx, "BTCUSD", types.Interval1m, types.KLineQueryOptions{})
assert.NoError(t, err)
assert.NotNil(t, klines)
}

func Test_QueryOrderTrades(t *testing.T) {
ex := getExchangeOrSkip(t)
ctx := context.Background()

trades, err := ex.QueryOrderTrades(ctx, types.OrderQuery{Symbol: "ETHUSD"})
assert.NoError(t, err)
assert.NotNil(t, trades)
}

func getExchangeOrSkip(t *testing.T) *Exchange {
if b, _ := strconv.ParseBool(os.Getenv("CI")); b {
t.Skip("skip test for CI")
}
key, secret, passphrase, ok := testutil.IntegrationTestWithPassphraseConfigured(t, "COINBASE")
if !ok {
t.SkipNow()
return nil
}

return New(key, secret, passphrase, 0)
}

0 comments on commit 910e1c9

Please sign in to comment.