Skip to content

Commit

Permalink
Separate ws subscribe requests (#390)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric-Warehime authored May 6, 2024
1 parent a3dbc2f commit 29104c3
Show file tree
Hide file tree
Showing 16 changed files with 269 additions and 197 deletions.
19 changes: 4 additions & 15 deletions providers/websockets/bybit/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,16 @@ func NewSubscriptionRequestMessage(tickers []string) ([]handlers.WebsocketEncode
return nil, fmt.Errorf("tickers cannot be empty")
}

numMessages := (numTickers / MaxArgsPerRequest) + 1
messages := make([]handlers.WebsocketEncodedMessage, numMessages)

for i := range messages {
start := i * MaxArgsPerRequest
end := (i + 1) * MaxArgsPerRequest

var argTickers []string
if i == numMessages-1 {
// if we are on the last message, truncate
argTickers = tickers[start:]
} else {
argTickers = tickers[start:end]
}
messages := make([]handlers.WebsocketEncodedMessage, len(tickers))

for i, ticker := range tickers {

bz, err := json.Marshal(
SubscriptionRequest{
BaseRequest: BaseRequest{
Op: string(OperationSubscribe),
},
Args: argTickers,
Args: []string{ticker},
},
)
if err != nil {
Expand Down
38 changes: 22 additions & 16 deletions providers/websockets/bybit/ws_data_hander_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ func TestCreateMessage(t *testing.T) {
testCases := []struct {
name string
cps []types.ProviderTicker
expected func() []byte
expected func() [][]byte
expectedErr bool
}{
{
name: "no currency pairs",
cps: []types.ProviderTicker{},
expected: func() []byte {
expected: func() [][]byte {
return nil
},
expectedErr: true,
Expand All @@ -210,7 +210,7 @@ func TestCreateMessage(t *testing.T) {
cps: []types.ProviderTicker{
btcusdt,
},
expected: func() []byte {
expected: func() [][]byte {
msg := bybit.SubscriptionRequest{
BaseRequest: bybit.BaseRequest{
Op: string(bybit.OperationSubscribe),
Expand All @@ -221,7 +221,7 @@ func TestCreateMessage(t *testing.T) {
bz, err := json.Marshal(msg)
require.NoError(t, err)

return bz
return [][]byte{bz}
},
expectedErr: false,
},
Expand All @@ -231,18 +231,21 @@ func TestCreateMessage(t *testing.T) {
btcusdt,
ethusdt,
},
expected: func() []byte {
msg := bybit.SubscriptionRequest{
BaseRequest: bybit.BaseRequest{
Op: string(bybit.OperationSubscribe),
},
Args: []string{"tickers.BTCUSDT", "tickers.ETHUSDT"},
expected: func() [][]byte {
msgs := make([][]byte, 2)
for i, ticker := range []string{"tickers.BTCUSDT", "tickers.ETHUSDT"} {
msg := bybit.SubscriptionRequest{
BaseRequest: bybit.BaseRequest{
Op: string(bybit.OperationSubscribe),
},
Args: []string{ticker},
}
bz, err := json.Marshal(msg)
require.NoError(t, err)
msgs[i] = bz
}

bz, err := json.Marshal(msg)
require.NoError(t, err)

return bz
return msgs
},
expectedErr: false,
},
Expand All @@ -259,8 +262,11 @@ func TestCreateMessage(t *testing.T) {
return
}

require.Equal(t, 1, len(msgs))
require.EqualValues(t, tc.expected(), []byte(msgs[0]))
expected := tc.expected()
require.Equal(t, len(expected), len(msgs))
for i, msg := range msgs {
require.EqualValues(t, expected[i], msg)
}
})
}
}
20 changes: 12 additions & 8 deletions providers/websockets/coinbase/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,20 @@ func NewSubscribeRequestMessage(instruments []string) ([]handlers.WebsocketEncod
return nil, fmt.Errorf("no instruments provided")
}

bz, err := json.Marshal(SubscribeRequestMessage{
Type: string(SubscribeMessage),
ProductIDs: instruments,
Channels: []string{string(TickerChannel)},
})
if err != nil {
return nil, fmt.Errorf("failed to marshal subscribe request message %w", err)
msgs := make([]handlers.WebsocketEncodedMessage, len(instruments))
for i, instrument := range instruments {
bz, err := json.Marshal(SubscribeRequestMessage{
Type: string(SubscribeMessage),
ProductIDs: []string{instrument},
Channels: []string{string(TickerChannel)},
})
if err != nil {
return nil, fmt.Errorf("failed to marshal subscribe request message %w", err)
}
msgs[i] = bz
}

return []handlers.WebsocketEncodedMessage{bz}, nil
return msgs, nil
}

// SubscribeResponseMessage represents a subscribe response message.
Expand Down
29 changes: 16 additions & 13 deletions providers/websockets/coinbase/ws_data_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,21 +255,24 @@ func TestCreateMessages(t *testing.T) {
ethusd,
},
expected: func() []handlers.WebsocketEncodedMessage {
msg := coinbase.SubscribeRequestMessage{
Type: string(coinbase.SubscribeMessage),
ProductIDs: []string{
"BTC-USD",
"ETH-USD",
},
Channels: []string{
string(coinbase.TickerChannel),
},
}
msgs := make([]handlers.WebsocketEncodedMessage, 2)
for i, ticker := range []string{"BTC-USD", "ETH-USD"} {
msg := coinbase.SubscribeRequestMessage{
Type: string(coinbase.SubscribeMessage),
ProductIDs: []string{
ticker,
},
Channels: []string{
string(coinbase.TickerChannel),
},
}

bz, err := json.Marshal(msg)
require.NoError(t, err)
bz, err := json.Marshal(msg)
require.NoError(t, err)
msgs[i] = bz
}

return []handlers.WebsocketEncodedMessage{bz}
return msgs
},
expectedErr: false,
},
Expand Down
21 changes: 14 additions & 7 deletions providers/websockets/cryptodotcom/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,21 @@ func NewInstrumentMessage(instruments []string) ([]handlers.WebsocketEncodedMess
return nil, fmt.Errorf("no instruments specified")
}

bz, err := json.Marshal(InstrumentRequestMessage{
Method: string(InstrumentMethod),
Params: InstrumentParams{
Channels: instruments,
},
})
msgs := make([]handlers.WebsocketEncodedMessage, len(instruments))
for i, instrument := range instruments {
bz, err := json.Marshal(InstrumentRequestMessage{
Method: string(InstrumentMethod),
Params: InstrumentParams{
Channels: []string{instrument},
},
})
if err != nil {
return msgs, err
}
msgs[i] = bz
}

return []handlers.WebsocketEncodedMessage{bz}, err
return msgs, nil
}

// InstrumentResponseMessage is the response received from the Crypto.com websocket API when subscribing
Expand Down
54 changes: 37 additions & 17 deletions providers/websockets/cryptodotcom/ws_data_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,24 +292,26 @@ func TestCreateMessage(t *testing.T) {
testCases := []struct {
name string
cps []types.ProviderTicker
msg cryptodotcom.InstrumentRequestMessage
msgs []cryptodotcom.InstrumentRequestMessage
expectedErr bool
}{
{
name: "no currency pairs",
cps: []types.ProviderTicker{},
msg: cryptodotcom.InstrumentRequestMessage{},
msgs: []cryptodotcom.InstrumentRequestMessage{},
expectedErr: true,
},
{
name: "one currency pair",
cps: []types.ProviderTicker{
btcusd,
},
msg: cryptodotcom.InstrumentRequestMessage{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{"ticker.BTCUSD-PERP"},
msgs: []cryptodotcom.InstrumentRequestMessage{
{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{"ticker.BTCUSD-PERP"},
},
},
},
expectedErr: false,
Expand All @@ -321,13 +323,29 @@ func TestCreateMessage(t *testing.T) {
ethusd,
solusd,
},
msg: cryptodotcom.InstrumentRequestMessage{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{
"ticker.BTCUSD-PERP",
"ticker.ETHUSD-PERP",
"ticker.SOLUSD-PERP",
msgs: []cryptodotcom.InstrumentRequestMessage{
{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{
"ticker.BTCUSD-PERP",
},
},
},
{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{
"ticker.ETHUSD-PERP",
},
},
},
{
Method: "subscribe",
Params: cryptodotcom.InstrumentParams{
Channels: []string{
"ticker.SOLUSD-PERP",
},
},
},
},
Expand All @@ -347,10 +365,12 @@ func TestCreateMessage(t *testing.T) {
}
require.NoError(t, err)

expectedBz, err := json.Marshal(tc.msg)
require.NoError(t, err)
require.Equal(t, 1, len(msgs))
require.EqualValues(t, expectedBz, msgs[0])
require.Equal(t, len(tc.msgs), len(msgs))
for i := range tc.msgs {
expectedBz, err := json.Marshal(tc.msgs[i])
require.NoError(t, err)
require.EqualValues(t, expectedBz, msgs[i])
}
})
}
}
29 changes: 18 additions & 11 deletions providers/websockets/gate/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,24 @@ func NewSubscribeRequest(symbols []string) ([]handlers.WebsocketEncodedMessage,
return nil, fmt.Errorf("cannot attach payload of 0 length")
}

bz, err := json.Marshal(SubscribeRequest{
BaseMessage: BaseMessage{
Time: time.Now().UTC().Second(),
Channel: string(ChannelTickers),
Event: string(EventSubscribe),
},
ID: time.Now().UTC().Second(),
Payload: symbols,
})

return []handlers.WebsocketEncodedMessage{bz}, err
msgs := make([]handlers.WebsocketEncodedMessage, len(symbols))
for i, symbol := range symbols {
bz, err := json.Marshal(SubscribeRequest{
BaseMessage: BaseMessage{
Time: time.Now().UTC().Second(),
Channel: string(ChannelTickers),
Event: string(EventSubscribe),
},
ID: time.Now().UTC().Second(),
Payload: []string{symbol},
})
if err != nil {
return msgs, err
}
msgs[i] = bz
}

return msgs, nil
}

// SubscribeResponse is a subscription response sent from the Gate.io websocket API.
Expand Down
48 changes: 28 additions & 20 deletions providers/websockets/gate/ws_data_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,20 +274,24 @@ func TestCreateMessage(t *testing.T) {
ethusdt,
},
expected: func() []handlers.WebsocketEncodedMessage {
msg := gate.SubscribeRequest{
BaseMessage: gate.BaseMessage{
Time: time.Now().Second(),
Channel: string(gate.ChannelTickers),
Event: string(gate.EventSubscribe),
},
ID: time.Now().Second(),
Payload: []string{"BTC_USDT", "ETH_USDT"},
msgs := make([]handlers.WebsocketEncodedMessage, 2)
for i, ticker := range []string{"BTC_USDT", "ETH_USDT"} {
msg := gate.SubscribeRequest{
BaseMessage: gate.BaseMessage{
Time: time.Now().Second(),
Channel: string(gate.ChannelTickers),
Event: string(gate.EventSubscribe),
},
ID: time.Now().Second(),
Payload: []string{ticker},
}

bz, err := json.Marshal(msg)
require.NoError(t, err)
msgs[i] = bz
}

bz, err := json.Marshal(msg)
require.NoError(t, err)

return []handlers.WebsocketEncodedMessage{bz}
return msgs
},
expectedErr: false,
},
Expand All @@ -309,15 +313,19 @@ func TestCreateMessage(t *testing.T) {
expectedMsg gate.SubscribeRequest
)

// need to check the non-time based fields
err = json.Unmarshal(msgs[0], &gotMsg)
require.NoError(t, err)
err = json.Unmarshal(tc.expected()[0], &expectedMsg)
require.NoError(t, err)
expected := tc.expected()
require.Equal(t, len(expected), len(msgs))
for i := range expected {
// need to check the non-time based fields
err = json.Unmarshal(msgs[i], &gotMsg)
require.NoError(t, err)
err = json.Unmarshal(expected[i], &expectedMsg)
require.NoError(t, err)

require.Equal(t, expectedMsg.Event, gotMsg.Event)
require.Equal(t, expectedMsg.Channel, gotMsg.Channel)
require.Equal(t, expectedMsg.Payload, gotMsg.Payload)
require.Equal(t, expectedMsg.Event, gotMsg.Event)
require.Equal(t, expectedMsg.Channel, gotMsg.Channel)
require.Equal(t, expectedMsg.Payload, gotMsg.Payload)
}
})
}
}
Loading

0 comments on commit 29104c3

Please sign in to comment.