From c5d20476055f0624d002f46e98f26616fbf79152 Mon Sep 17 00:00:00 2001 From: Edwin Date: Wed, 17 Jan 2024 11:34:15 +0800 Subject: [PATCH] pkg/exchange: emit balance snapshot after authed --- pkg/exchange/okex/exchange.go | 2 +- pkg/exchange/okex/stream.go | 53 +++++++++++++++++++++++--------- pkg/exchange/okex/stream_test.go | 5 ++- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/pkg/exchange/okex/exchange.go b/pkg/exchange/okex/exchange.go index 389b56a502..dc20a2b717 100644 --- a/pkg/exchange/okex/exchange.go +++ b/pkg/exchange/okex/exchange.go @@ -375,7 +375,7 @@ func (e *Exchange) CancelOrders(ctx context.Context, orders ...types.Order) erro } func (e *Exchange) NewStream() types.Stream { - return NewStream(e.client) + return NewStream(e.client, e) } func (e *Exchange) QueryKLines(ctx context.Context, symbol string, interval types.Interval, options types.KLineQueryOptions) ([]types.KLine, error) { diff --git a/pkg/exchange/okex/stream.go b/pkg/exchange/okex/stream.go index 96b955c73d..1004af03e7 100644 --- a/pkg/exchange/okex/stream.go +++ b/pkg/exchange/okex/stream.go @@ -8,6 +8,7 @@ import ( "time" "github.com/c9s/bbgo/pkg/exchange/okex/okexapi" + "github.com/c9s/bbgo/pkg/exchange/retry" "github.com/c9s/bbgo/pkg/types" ) @@ -31,7 +32,8 @@ type WebsocketLogin struct { type Stream struct { types.StandardStream - client *okexapi.RestClient + client *okexapi.RestClient + balanceProvider types.ExchangeAccountService // public callbacks kLineEventCallbacks []func(candle KLineEvent) @@ -41,10 +43,11 @@ type Stream struct { marketTradeEventCallbacks []func(tradeDetail []MarketTradeEvent) } -func NewStream(client *okexapi.RestClient) *Stream { +func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccountService) *Stream { stream := &Stream{ - client: client, - StandardStream: types.NewStandardStream(), + client: client, + balanceProvider: balanceProvider, + StandardStream: types.NewStandardStream(), } stream.SetParser(parseWebSocketEvent) @@ -57,7 +60,7 @@ func NewStream(client *okexapi.RestClient) *Stream { stream.OnMarketTradeEvent(stream.handleMarketTradeEvent) stream.OnOrderDetailsEvent(stream.handleOrderDetailsEvent) stream.OnConnect(stream.handleConnect) - stream.OnAuth(stream.handleAuth) + stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot)) return stream } @@ -151,20 +154,42 @@ func (s *Stream) handleConnect() { } } -func (s *Stream) handleAuth() { - var subs = []WebsocketSubscription{ - {Channel: ChannelAccount}, - {Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)}, +func (s *Stream) subscribePrivateChannels(next func()) func() { + return func() { + var subs = []WebsocketSubscription{ + {Channel: ChannelAccount}, + {Channel: "orders", InstrumentType: string(okexapi.InstrumentTypeSpot)}, + } + + log.Infof("subscribing private channels: %+v", subs) + err := s.Conn.WriteJSON(WebsocketOp{ + Op: "subscribe", + Args: subs, + }) + if err != nil { + log.WithError(err).Error("private channel subscribe error") + return + } + next() } +} - log.Infof("subscribing private channels: %+v", subs) - err := s.Conn.WriteJSON(WebsocketOp{ - Op: "subscribe", - Args: subs, +func (s *Stream) emitBalanceSnapshot() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + var balancesMap types.BalanceMap + var err error + err = retry.GeneralBackoff(ctx, func() error { + balancesMap, err = s.balanceProvider.QueryAccountBalances(ctx) + return err }) if err != nil { - log.WithError(err).Error("private channel subscribe error") + log.WithError(err).Error("no more attempts to retrieve balances") + return } + + s.EmitBalanceSnapshot(balancesMap) } func (s *Stream) handleOrderDetailsEvent(orderDetails []okexapi.OrderDetails) { diff --git a/pkg/exchange/okex/stream_test.go b/pkg/exchange/okex/stream_test.go index 7f85973adb..e2e2609539 100644 --- a/pkg/exchange/okex/stream_test.go +++ b/pkg/exchange/okex/stream_test.go @@ -25,7 +25,7 @@ func getTestClientOrSkip(t *testing.T) *Stream { } exchange := New(key, secret, passphrase) - return NewStream(exchange.client) + return NewStream(exchange.client, exchange) } func TestStream(t *testing.T) { @@ -39,6 +39,9 @@ func TestStream(t *testing.T) { s.OnBalanceUpdate(func(balances types.BalanceMap) { t.Log("got snapshot", balances) }) + s.OnBalanceSnapshot(func(balances types.BalanceMap) { + t.Log("got snapshot", balances) + }) s.OnBookUpdate(func(book types.SliceOrderBook) { t.Log("got update", book) })