diff --git a/protocol/streaming/full_node_streaming_manager.go b/protocol/streaming/full_node_streaming_manager.go index af0d5671cd..29717366fc 100644 --- a/protocol/streaming/full_node_streaming_manager.go +++ b/protocol/streaming/full_node_streaming_manager.go @@ -7,6 +7,7 @@ import ( "time" "github.com/dydxprotocol/v4-chain/protocol/lib" + pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types" satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" "cosmossdk.io/log" @@ -52,6 +53,8 @@ type FullNodeStreamingManagerImpl struct { clobPairIdToSubscriptionIdMapping map[uint32][]uint32 // map from subaccount id to subscription ids. subaccountIdToSubscriptionIdMapping map[satypes.SubaccountId][]uint32 + // map from market id to subscription ids. + marketIdToSubscriptionIdMapping map[uint32][]uint32 maxUpdatesInCache uint32 maxSubscriptionChannelSize uint32 @@ -79,6 +82,9 @@ type OrderbookSubscription struct { // Subaccount ids to subscribe to. subaccountIds []satypes.SubaccountId + // market ids to subscribe to. + marketIds []uint32 + // Stream messageSender types.OutgoingMessageSender @@ -114,6 +120,7 @@ func NewFullNodeStreamingManager( streamUpdateSubscriptionCache: make([][]uint32, 0), clobPairIdToSubscriptionIdMapping: make(map[uint32][]uint32), subaccountIdToSubscriptionIdMapping: make(map[satypes.SubaccountId][]uint32), + marketIdToSubscriptionIdMapping: make(map[uint32][]uint32), maxUpdatesInCache: maxUpdatesInCache, maxSubscriptionChannelSize: maxSubscriptionChannelSize, @@ -184,6 +191,7 @@ func (sm *FullNodeStreamingManagerImpl) getNextAvailableSubscriptionId() uint32 func (sm *FullNodeStreamingManagerImpl) Subscribe( clobPairIds []uint32, subaccountIds []*satypes.SubaccountId, + marketIds []uint32, messageSender types.OutgoingMessageSender, ) ( err error, @@ -206,6 +214,7 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe( initialized: &atomic.Bool{}, // False by default. clobPairIds: clobPairIds, subaccountIds: sIds, + marketIds: marketIds, messageSender: messageSender, updatesChannel: make(chan []clobtypes.StreamUpdate, sm.maxSubscriptionChannelSize), } @@ -231,6 +240,17 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe( subscription.subscriptionId, ) } + for _, marketId := range marketIds { + // if subaccountId exists in the map, append the subscription id to the slice + // otherwise, create a new slice with the subscription id + if _, ok := sm.marketIdToSubscriptionIdMapping[marketId]; !ok { + sm.marketIdToSubscriptionIdMapping[marketId] = []uint32{} + } + sm.marketIdToSubscriptionIdMapping[marketId] = append( + sm.marketIdToSubscriptionIdMapping[marketId], + subscription.subscriptionId, + ) + } sm.logger.Info( fmt.Sprintf( @@ -325,6 +345,21 @@ func (sm *FullNodeStreamingManagerImpl) removeSubscription( } } + // Iterate over the marketIdToSubscriptionIdMapping to remove the subscriptionIdToRemove + for marketId, subscriptionIds := range sm.marketIdToSubscriptionIdMapping { + for i, id := range subscriptionIds { + if id == subscriptionIdToRemove { + // Remove the subscription ID from the slice + sm.marketIdToSubscriptionIdMapping[marketId] = append(subscriptionIds[:i], subscriptionIds[i+1:]...) + break + } + } + // If the list is empty after removal, delete the key from the map + if len(sm.marketIdToSubscriptionIdMapping[marketId]) == 0 { + delete(sm.marketIdToSubscriptionIdMapping, marketId) + } + } + sm.logger.Info( fmt.Sprintf("Removed streaming subscription id %+v", subscriptionIdToRemove), ) @@ -372,6 +407,24 @@ func toSubaccountStreamUpdates( return streamUpdates } +func toPriceStreamUpdates( + priceUpdates []*pricestypes.StreamPriceUpdate, + blockHeight uint32, + execMode sdk.ExecMode, +) []clobtypes.StreamUpdate { + streamUpdates := make([]clobtypes.StreamUpdate, 0) + for _, update := range priceUpdates { + streamUpdates = append(streamUpdates, clobtypes.StreamUpdate{ + UpdateMessage: &clobtypes.StreamUpdate_PriceUpdate{ + PriceUpdate: update, + }, + BlockHeight: blockHeight, + ExecMode: uint32(execMode), + }) + } + return streamUpdates +} + func (sm *FullNodeStreamingManagerImpl) sendStreamUpdates( subscriptionId uint32, streamUpdates []clobtypes.StreamUpdate, @@ -466,6 +519,7 @@ func (sm *FullNodeStreamingManagerImpl) GetStagedFinalizeBlockEvents( func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot( offchainUpdates *clobtypes.OffchainUpdates, saUpdates []*satypes.StreamSubaccountUpdate, + priceUpdates []*pricestypes.StreamPriceUpdate, subscriptionId uint32, blockHeight uint32, execMode sdk.ExecMode, @@ -479,6 +533,7 @@ func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot( var streamUpdates []clobtypes.StreamUpdate streamUpdates = append(streamUpdates, toOrderbookStreamUpdate(offchainUpdates, blockHeight, execMode)...) streamUpdates = append(streamUpdates, toSubaccountStreamUpdates(saUpdates, blockHeight, execMode)...) + streamUpdates = append(streamUpdates, toPriceStreamUpdates(priceUpdates, blockHeight, execMode)...) sm.sendStreamUpdates(subscriptionId, streamUpdates) } @@ -863,6 +918,30 @@ func (sm *FullNodeStreamingManagerImpl) GetSubaccountSnapshotsForInitStreams( return ret } +func (sm *FullNodeStreamingManagerImpl) GetPriceSnapshotsForInitStreams( + getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate, +) map[uint32]*pricestypes.StreamPriceUpdate { + sm.Lock() + defer sm.Unlock() + + ret := make(map[uint32]*pricestypes.StreamPriceUpdate) + for _, subscription := range sm.orderbookSubscriptions { + // If the subscription has been initialized, no need to grab the price snapshot. + if alreadyInitialized := subscription.initialized.Load(); alreadyInitialized { + continue + } + + for _, marketId := range subscription.marketIds { + if _, exists := ret[marketId]; exists { + continue + } + + ret[marketId] = getPriceSnapshot(marketId) + } + } + return ret +} + // cacheStreamUpdatesByClobPairWithLock adds stream updates to cache, // and store corresponding clob pair Ids. // This method requires the lock and assumes that the lock has already been @@ -1003,6 +1082,7 @@ func (sm *FullNodeStreamingManagerImpl) getStagedEventsFromFinalizeBlock( func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams( getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates, subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate, + pricesSnapshots map[uint32]*pricestypes.StreamPriceUpdate, blockHeight uint32, execMode sdk.ExecMode, ) { @@ -1038,7 +1118,28 @@ func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams( } } - sm.SendCombinedSnapshot(allUpdates, saUpdates, subscriptionId, blockHeight, execMode) + priceUpdates := []*pricestypes.StreamPriceUpdate{} + for _, marketId := range subscription.marketIds { + if priceUpdate, ok := pricesSnapshots[marketId]; ok { + priceUpdates = append(priceUpdates, priceUpdate) + } else { + sm.logger.Error( + fmt.Sprintf( + "Price update not found for market id %v. This should not happen.", + marketId, + ), + ) + } + } + + sm.SendCombinedSnapshot( + allUpdates, + saUpdates, + priceUpdates, + subscriptionId, + blockHeight, + execMode, + ) if sm.snapshotBlockInterval != 0 { subscription.nextSnapshotBlock = blockHeight + sm.snapshotBlockInterval diff --git a/protocol/streaming/noop_streaming_manager.go b/protocol/streaming/noop_streaming_manager.go index 89250854c4..6d6dd41608 100644 --- a/protocol/streaming/noop_streaming_manager.go +++ b/protocol/streaming/noop_streaming_manager.go @@ -4,6 +4,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/dydxprotocol/v4-chain/protocol/streaming/types" clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types" + pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types" satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" ) @@ -22,6 +23,7 @@ func (sm *NoopGrpcStreamingManager) Enabled() bool { func (sm *NoopGrpcStreamingManager) Subscribe( _ []uint32, _ []*satypes.SubaccountId, + _ []uint32, _ types.OutgoingMessageSender, ) ( err error, @@ -58,9 +60,16 @@ func (sm *NoopGrpcStreamingManager) GetSubaccountSnapshotsForInitStreams( return nil } +func (sm *NoopGrpcStreamingManager) GetPriceSnapshotsForInitStreams( + _ func(_ uint32) *pricestypes.StreamPriceUpdate, +) map[uint32]*pricestypes.StreamPriceUpdate { + return nil +} + func (sm *NoopGrpcStreamingManager) InitializeNewStreams( getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates, subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate, + priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate, blockHeight uint32, execMode sdk.ExecMode, ) { diff --git a/protocol/streaming/types/interface.go b/protocol/streaming/types/interface.go index 33907fc1ec..028352a10b 100644 --- a/protocol/streaming/types/interface.go +++ b/protocol/streaming/types/interface.go @@ -3,6 +3,7 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types" + pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types" satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" ) @@ -14,6 +15,7 @@ type FullNodeStreamingManager interface { Subscribe( clobPairIds []uint32, subaccountIds []*satypes.SubaccountId, + marketIds []uint32, srv OutgoingMessageSender, ) ( err error, @@ -23,12 +25,16 @@ type FullNodeStreamingManager interface { InitializeNewStreams( getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates, subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate, + priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate, blockHeight uint32, execMode sdk.ExecMode, ) GetSubaccountSnapshotsForInitStreams( getSubaccountSnapshot func(subaccountId satypes.SubaccountId) *satypes.StreamSubaccountUpdate, ) map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate + GetPriceSnapshotsForInitStreams( + getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate, + ) map[uint32]*pricestypes.StreamPriceUpdate SendOrderbookUpdates( offchainUpdates *clobtypes.OffchainUpdates, ctx sdk.Context, diff --git a/protocol/streaming/ws/websocket_server.go b/protocol/streaming/ws/websocket_server.go index ba4477d03b..3cbb9219a4 100644 --- a/protocol/streaming/ws/websocket_server.go +++ b/protocol/streaming/ws/websocket_server.go @@ -16,6 +16,11 @@ import ( "github.com/gorilla/websocket" ) +const ( + CLOB_PAIR_IDS_QUERY_PARAM = "clobPairIds" + MARKET_IDS_QUERY_PARAM = "marketIds" +) + var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -61,7 +66,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) { conn.SetReadLimit(10 * 1024 * 1024) // Parse clobPairIds from query parameters - clobPairIds, err := parseClobPairIds(r) + clobPairIds, err := parseUint32(r, CLOB_PAIR_IDS_QUERY_PARAM) if err != nil { ws.logger.Error( "Error parsing clobPairIds", @@ -70,6 +75,18 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + // Parse marketIds from query parameters + marketIds, err := parseUint32(r, MARKET_IDS_QUERY_PARAM) + if err != nil { + ws.logger.Error( + "Error parsing marketIds", + "err", err, + ) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Parse subaccountIds from query parameters subaccountIds, err := parseSubaccountIds(r) if err != nil { @@ -93,6 +110,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) { err = ws.streamingManager.Subscribe( clobPairIds, subaccountIds, + marketIds, websocketMessageSender, ) if err != nil { @@ -136,26 +154,26 @@ func parseSubaccountIds(r *http.Request) ([]*satypes.SubaccountId, error) { return subaccountIds, nil } -// parseClobPairIds is a helper function to parse the clobPairIds from the query parameters. -func parseClobPairIds(r *http.Request) ([]uint32, error) { - clobPairIdsParam := r.URL.Query().Get("clobPairIds") - if clobPairIdsParam == "" { +// parseUint32 is a helper function to parse the uint32 from the query parameters. +func parseUint32(r *http.Request, queryParam string) ([]uint32, error) { + param := r.URL.Query().Get(queryParam) + if param == "" { return []uint32{}, nil } - idStrs := strings.Split(clobPairIdsParam, ",") - clobPairIds := make([]uint32, 0) + idStrs := strings.Split(param, ",") + ids := make([]uint32, 0) for _, idStr := range idStrs { id, err := strconv.Atoi(idStr) if err != nil { - return nil, fmt.Errorf("invalid clobPairId: %s", idStr) + return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr) } if id < 0 || id > math.MaxInt32 { - return nil, fmt.Errorf("invalid clob pair id: %s", idStr) + return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr) } - clobPairIds = append(clobPairIds, uint32(id)) + ids = append(ids, uint32(id)) } - return clobPairIds, nil + return ids, nil } // Start the websocket server in a separate goroutine. diff --git a/protocol/x/clob/keeper/grpc_stream_orderbook.go b/protocol/x/clob/keeper/grpc_stream_orderbook.go index caca5fbfbe..029266901a 100644 --- a/protocol/x/clob/keeper/grpc_stream_orderbook.go +++ b/protocol/x/clob/keeper/grpc_stream_orderbook.go @@ -11,6 +11,7 @@ func (k Keeper) StreamOrderbookUpdates( err := k.GetFullNodeStreamingManager().Subscribe( req.GetClobPairId(), req.GetSubaccountIds(), + req.GetMarketIds(), stream, ) if err != nil { diff --git a/protocol/x/clob/keeper/keeper.go b/protocol/x/clob/keeper/keeper.go index e371e91039..4d2303bb0d 100644 --- a/protocol/x/clob/keeper/keeper.go +++ b/protocol/x/clob/keeper/keeper.go @@ -23,6 +23,7 @@ import ( flags "github.com/dydxprotocol/v4-chain/protocol/x/clob/flags" "github.com/dydxprotocol/v4-chain/protocol/x/clob/rate_limit" "github.com/dydxprotocol/v4-chain/protocol/x/clob/types" + pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types" ) type ( @@ -331,6 +332,25 @@ func (k Keeper) GetSubaccountSnapshotsForInitStreams( ) } +func (k Keeper) GetPriceSnapshotsForInitStreams( + ctx sdk.Context, +) ( + priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate, +) { + lib.AssertCheckTxMode(ctx) + + return k.GetFullNodeStreamingManager().GetPriceSnapshotsForInitStreams( + func(marketId uint32) *pricestypes.StreamPriceUpdate { + update := k.pricesKeeper.GetStreamPriceUpdate( + ctx, + marketId, + true, + ) + return &update + }, + ) +} + // InitializeNewStreams initializes new streams for all uninitialized clob pairs // by sending the corresponding orderbook snapshots. func (k Keeper) InitializeNewStreams( @@ -339,6 +359,8 @@ func (k Keeper) InitializeNewStreams( ) { streamingManager := k.GetFullNodeStreamingManager() + priceSnapshots := k.GetPriceSnapshotsForInitStreams(ctx) + streamingManager.InitializeNewStreams( func(clobPairId types.ClobPairId) *types.OffchainUpdates { return k.MemClob.GetOffchainUpdatesForOrderbookSnapshot( @@ -347,6 +369,7 @@ func (k Keeper) InitializeNewStreams( ) }, subaccountSnapshots, + priceSnapshots, lib.MustConvertIntegerToUint32(ctx.BlockHeight()), ctx.ExecMode(), ) diff --git a/protocol/x/clob/types/expected_keepers.go b/protocol/x/clob/types/expected_keepers.go index a1e3bafd3a..968db325a8 100644 --- a/protocol/x/clob/types/expected_keepers.go +++ b/protocol/x/clob/types/expected_keepers.go @@ -152,6 +152,7 @@ type PerpetualsKeeper interface { type PricesKeeper interface { GetMarketParam(ctx sdk.Context, id uint32) (param pricestypes.MarketParam, exists bool) + GetStreamPriceUpdate(ctx sdk.Context, id uint32, snapshot bool) (val pricestypes.StreamPriceUpdate) } type StatsKeeper interface { diff --git a/protocol/x/prices/keeper/market_price.go b/protocol/x/prices/keeper/market_price.go index 3cb0fdd278..78490b0670 100644 --- a/protocol/x/prices/keeper/market_price.go +++ b/protocol/x/prices/keeper/market_price.go @@ -193,3 +193,26 @@ func (k Keeper) GetMarketIdToValidIndexPrice( } return ret } + +// GetStreamPriceUpdate returns a stream price update from its id. +func (k Keeper) GetStreamPriceUpdate( + ctx sdk.Context, + id uint32, + snapshot bool, +) ( + val types.StreamPriceUpdate, +) { + price, err := k.GetMarketPrice(ctx, id) + if err != nil { + k.Logger(ctx).Error( + "failed to get market price for streaming", + "market id", id, + "error", err, + ) + } + return types.StreamPriceUpdate{ + MarketId: id, + Price: price, + Snapshot: snapshot, + } +}