Skip to content

Commit

Permalink
feat: Adding retry decorator (#360)
Browse files Browse the repository at this point in the history
* feat: Adding retry decoratoR
  • Loading branch information
TomMcL authored Mar 20, 2023
1 parent a8d6941 commit 16bd076
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 42 deletions.
4 changes: 3 additions & 1 deletion vega_sim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,9 @@ def _stream_handler(
market_id = getattr(event, "market_id", getattr(event, "market", None))
asset_decimals = asset_dp.get(
getattr(
event, "asset", mkt_to_asset[market_id] if market_id is not None else None
event,
"asset",
mkt_to_asset[market_id] if market_id is not None else None,
)
)

Expand Down
44 changes: 43 additions & 1 deletion vega_sim/api/data_raw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import datetime
import logging
from collections import namedtuple
from functools import wraps
from typing import Callable, Iterable, List, Optional, TypeVar, Union
import datetime

import vega_sim.grpc.client as vac
import vega_sim.proto.data_node.api.v2 as data_node_protos_v2
Expand All @@ -19,6 +20,27 @@
U = TypeVar("U")


def _retry(num_retry_attempts: int = 3):
"""Automatically retries a function a certain number
of times. Swallows any errors raised by earlier attempts
and will raise the last one if still failing.
"""

def retry_decorator(fn):
@wraps(fn)
def auto_retry_fn(*args, **kwargs):
for i in range(num_retry_attempts):
try:
return fn(*args, **kwargs)
except Exception as e:
if (i + 1) == num_retry_attempts:
raise e

return auto_retry_fn

return retry_decorator


def unroll_v2_pagination(
base_request: S,
request_func: Callable[[S], T],
Expand All @@ -38,6 +60,7 @@ def unroll_v2_pagination(
return full_list


@_retry(3)
def positions_by_market(
pub_key: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -58,6 +81,7 @@ def positions_by_market(
)


@_retry(3)
def all_markets(
data_client: vac.VegaTradingDataClientV2,
) -> List[vega_protos.markets.Market]:
Expand All @@ -71,6 +95,7 @@ def all_markets(
)


@_retry(3)
def market_info(
market_id: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -83,6 +108,7 @@ def market_info(
).market


@_retry(3)
def list_assets(data_client: vac.VegaTradingDataClientV2):
return unroll_v2_pagination(
base_request=data_node_protos_v2.trading_data.ListAssetsRequest(),
Expand All @@ -91,6 +117,7 @@ def list_assets(data_client: vac.VegaTradingDataClientV2):
)


@_retry(3)
def asset_info(
asset_id: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -111,6 +138,7 @@ def asset_info(
).asset


@_retry(3)
def list_accounts(
data_client: vac.VegaTradingDataClientV2,
asset_id: Optional[str] = None,
Expand All @@ -137,6 +165,7 @@ def list_accounts(
)


@_retry(3)
def market_accounts(
asset_id: str,
market_id: str,
Expand All @@ -155,6 +184,7 @@ def market_accounts(
)


@_retry(3)
def get_latest_market_data(
market_id: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -167,6 +197,7 @@ def get_latest_market_data(
).market_data


@_retry(3)
def market_data_history(
market_id: str,
start: datetime.datetime,
Expand All @@ -188,6 +219,7 @@ def market_data_history(
)


@_retry(3)
def infrastructure_fee_accounts(
asset_id: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -212,6 +244,7 @@ def infrastructure_fee_accounts(
)


@_retry(3)
def list_orders(
data_client: vac.VegaTradingDataClientV2,
market_id: str = None,
Expand Down Expand Up @@ -256,6 +289,7 @@ def list_orders(
)


@_retry(3)
def order_status(
order_id: str, data_client: vac.VegaTradingDataClientV2, version: int = 0
) -> Optional[vega_protos.vega.Order]:
Expand All @@ -282,6 +316,7 @@ def order_status(
).order


@_retry(3)
def market_depth(
market_id: str,
data_client: vac.VegaTradingDataClientV2,
Expand All @@ -294,6 +329,7 @@ def market_depth(
)


@_retry(3)
def liquidity_provisions(
data_client: vac.VegaTradingDataClientV2,
market_id: Optional[str] = None,
Expand Down Expand Up @@ -356,6 +392,7 @@ def observe_event_bus(
return data_client.ObserveEventBus(iter([request]))


@_retry(3)
def margin_levels(
data_client: vac.VegaTradingDataClientV2,
party_id: str,
Expand All @@ -370,6 +407,7 @@ def margin_levels(
)


@_retry(3)
def get_trades(
data_client: vac.VegaTradingDataClientV2,
market_id: str,
Expand All @@ -385,6 +423,7 @@ def get_trades(
)


@_retry(3)
def get_network_parameter(
data_client: vac.VegaTradingDataClientV2,
key: str,
Expand All @@ -406,6 +445,7 @@ def get_network_parameter(
).network_parameter


@_retry(3)
def list_transfers(
data_client: vac.VegaTradingDataClientV2,
party_id: Optional[str] = None,
Expand Down Expand Up @@ -448,6 +488,7 @@ def list_transfers(
)


@_retry(3)
def list_ledger_entries(
data_client: vac.VegaTradingDataClientV2,
close_on_account_filters: bool = False,
Expand Down Expand Up @@ -580,6 +621,7 @@ def _data_gen(
return _data_gen(data_stream=data_stream)


@_retry(3)
def get_risk_factors(
data_client: vac.VegaTradingDataClientV2,
market_id: str,
Expand Down
26 changes: 10 additions & 16 deletions vega_sim/local_data_cache.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
from __future__ import annotations

import copy
import grpc
import logging
import threading
import traceback
from collections import defaultdict
from queue import Queue, Empty
from itertools import product, chain
from typing import (
Any,
Dict,
List,
Optional,
Set,
Tuple,
Union,
Callable,
)
from itertools import chain, product
from queue import Empty, Queue
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import vega_sim.api.data as data
import vega_sim.api.data_raw as data_raw

import vega_sim.grpc.client as vac
import vega_sim.proto.vega as vega_protos
import vega_sim.proto.vega.events.v1.events_pb2 as events_protos


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -68,8 +59,11 @@ def _queue_forwarder(
sink.put(elem)
else:
sink.put(output)
except Exception:
logger.info("Data cache event bus closed")
except grpc._channel._MultiThreadedRendezvous as e:
if e.details() == "Socket closed":
logger.info("Data cache event bus closed")
else:
raise e


class DecimalsCache(defaultdict):
Expand Down
42 changes: 18 additions & 24 deletions vega_sim/scenario/common/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,24 @@ def step(self, vega_state: VegaState):
)

def place_order(self, vega_state: VegaState, volume: float, side: vega_protos.Side):
try:
if (
(
vega_state.market_state[self.market_id].trading_mode
== markets_protos.Market.TradingMode.TRADING_MODE_CONTINUOUS
)
and vega_state.market_state[self.market_id].state
== markets_protos.Market.State.STATE_ACTIVE
and volume != 0
):
self.vega.submit_market_order(
trading_key=self.key_name,
market_id=self.market_id,
side=side,
volume=volume,
wait=False,
fill_or_kill=False,
trading_wallet=self.wallet_name,
)
except:
import pdb

pdb.set_trace()
a = 4
if (
(
vega_state.market_state[self.market_id].trading_mode
== markets_protos.Market.TradingMode.TRADING_MODE_CONTINUOUS
)
and vega_state.market_state[self.market_id].state
== markets_protos.Market.State.STATE_ACTIVE
and volume != 0
):
self.vega.submit_market_order(
trading_key=self.key_name,
market_id=self.market_id,
side=side,
volume=volume,
wait=False,
fill_or_kill=False,
trading_wallet=self.wallet_name,
)


class PriceSensitiveLimitOrderTrader(StateAgentWithWallet):
Expand Down

0 comments on commit 16bd076

Please sign in to comment.