Skip to content

Commit

Permalink
Merge pull request #1503 from c9s/edwin/okx/push-balance-snapshot
Browse files Browse the repository at this point in the history
FEATURE: [okx] emit balance snapshot after authenticated
  • Loading branch information
bailantaotao authored Jan 17, 2024
2 parents 735123b + c5d2047 commit 7f90620
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pkg/exchange/okex/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro
}

func (e *Exchange) NewStream() types.Stream {
return NewStream(e.client)
return NewStream(e.client, e)
}

func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval types.Interval, options types.KLineQueryOptions) ([]types.KLine, error) {
Expand Down
53 changes: 39 additions & 14 deletions pkg/exchange/okex/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/c9s/bbgo/pkg/exchange/okex/okexapi"
"github.com/c9s/bbgo/pkg/exchange/retry"
"github.com/c9s/bbgo/pkg/types"
)

Expand All @@ -31,7 +32,8 @@ type WebsocketLogin struct {
type Stream struct {
types.StandardStream

client *okexapi.RestClient
client *okexapi.RestClient
balanceProvider types.ExchangeAccountService

// public callbacks
kLineEventCallbacks []func(candle KLineEvent)
Expand All @@ -41,10 +43,11 @@ type Stream struct {
marketTradeEventCallbacks []func(tradeDetail []MarketTradeEvent)
}

func NewStream(client *okexapi.RestClient) *Stream {
func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccountService) *Stream {
stream := &Stream{
client: client,
StandardStream: types.NewStandardStream(),
client: client,
balanceProvider: balanceProvider,
StandardStream: types.NewStandardStream(),
}

stream.SetParser(parseWebSocketEvent)
Expand All @@ -57,7 +60,7 @@ func NewStream(client *okexapi.RestClient) *Stream {
stream.OnMarketTradeEvent(stream.handleMarketTradeEvent)
stream.OnOrderDetailsEvent(stream.handleOrderDetailsEvent)
stream.OnConnect(stream.handleConnect)
stream.OnAuth(stream.handleAuth)
stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot))
return stream
}

Expand Down Expand Up @@ -151,20 +154,42 @@ func (s *Stream) handleConnect() {
}
}

func (s *Stream) handleAuth() {
var subs = []WebsocketSubscription{
{Channel: ChannelAccount},
{Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)},
func (s *Stream) subscribePrivateChannels(next func()) func() {
return func() {
var subs = []WebsocketSubscription{
{Channel: ChannelAccount},
{Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)},
}

log.Infof("subscribing private channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})
if err != nil {
log.WithError(err).Error("private channel subscribe error")
return
}
next()
}
}

log.Infof("subscribing private channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
func (s *Stream) emitBalanceSnapshot() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()

var balancesMap types.BalanceMap
var err error
err = retry.GeneralBackoff(ctx, func() error {
balancesMap, err = s.balanceProvider.QueryAccountBalances(ctx)
return err
})
if err != nil {
log.WithError(err).Error("private channel subscribe error")
log.WithError(err).Error("no more attempts to retrieve balances")
return
}

s.EmitBalanceSnapshot(balancesMap)
}

func (s *Stream) handleOrderDetailsEvent(orderDetails []okexapi.OrderDetails) {
Expand Down
5 changes: 4 additions & 1 deletion pkg/exchange/okex/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {
}

exchange := New(key, secret, passphrase)
return NewStream(exchange.client)
return NewStream(exchange.client, exchange)
}

func TestStream(t *testing.T) {
Expand All @@ -39,6 +39,9 @@ func TestStream(t *testing.T) {
s.OnBalanceUpdate(func(balances types.BalanceMap) {
t.Log("got snapshot", balances)
})
s.OnBalanceSnapshot(func(balances types.BalanceMap) {
t.Log("got snapshot", balances)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
Expand Down

0 comments on commit 7f90620

Please sign in to comment.