Skip to content

Commit

Permalink
fix: correct plot oom (#591)
Browse files Browse the repository at this point in the history
* fix: remove broken plotting tool

* fix: avoid empty legend warnings

* feat: add logging to plot calls

* fix: update error message for check tagged versions
  • Loading branch information
cdummett authored Jan 15, 2024
1 parent 78d7378 commit 2aed8c9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
4 changes: 3 additions & 1 deletion scripts/check_tagged_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def check_versions():
tagged_vega_versions.add(tag["name"][5:])

if versions_to_check.difference(tagged_vega_versions):
raise Exception(f"Missing tags for vega versions: {versions_to_check}")
raise Exception(
f"Missing tags for vega versions: {versions_to_check - tagged_vega_versions}"
)


if __name__ == "__main__":
Expand Down
40 changes: 38 additions & 2 deletions vega_sim/tools/scenario_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@

import numpy as np

import logging
from functools import wraps

logger = logging.getLogger(__name__)


def log_plotting_function(func):
@wraps(func)
def wrapped_fn(*args, **kwargs):
logger.info(f"Calling func {func.__name__}")
return func(*args, **kwargs)

return wrapped_fn


TRADING_MODE_COLORS = {
0: (200 / 255, 200 / 255, 200 / 255),
Expand Down Expand Up @@ -137,6 +151,7 @@ def _series_df_to_single_series(
return df_clone.merge(selector, on=["time", "market_id"]).set_index("time")


@log_plotting_function
def plot_trading_summary(
ax: Axes,
trade_df: pd.DataFrame,
Expand Down Expand Up @@ -180,6 +195,7 @@ def plot_trading_summary(
)


@log_plotting_function
def plot_total_traded_volume(ax: Axes, trades_df: pd.DataFrame) -> None:
if trades_df.empty:
return
Expand All @@ -191,13 +207,15 @@ def plot_total_traded_volume(ax: Axes, trades_df: pd.DataFrame) -> None:
ax.plot(traded)


@log_plotting_function
def plot_open_interest(ax: Axes, market_data_df: pd.DataFrame):
ax.set_title("Open Interest")

_set_xticks(ax, market_data_df.index)
ax.plot(market_data_df["open_interest"])


@log_plotting_function
def plot_open_notional(
ax: Axes,
market_data_df: pd.DataFrame,
Expand All @@ -211,6 +229,7 @@ def plot_open_notional(
ax.plot(market_data_df["open_interest"] * price_df["price"])


@log_plotting_function
def plot_spread(ax: Axes, order_book_df: pd.DataFrame) -> None:
if order_book_df.empty:
return
Expand All @@ -229,6 +248,7 @@ def plot_spread(ax: Axes, order_book_df: pd.DataFrame) -> None:
ax.plot(spread_df)


@log_plotting_function
def plot_margin_totals(ax: Axes, accounts_df: pd.DataFrame) -> None:
if accounts_df.empty:
return
Expand All @@ -243,6 +263,7 @@ def plot_margin_totals(ax: Axes, accounts_df: pd.DataFrame) -> None:
ax.plot(grouped_df)


@log_plotting_function
def plot_target_stake(ax: Axes, market_data_df: pd.DataFrame) -> None:
if market_data_df.empty:
return
Expand All @@ -252,6 +273,7 @@ def plot_target_stake(ax: Axes, market_data_df: pd.DataFrame) -> None:
ax.plot(market_data_df["target_stake"])


@log_plotting_function
def plot_run_outputs(run_name: Optional[str] = None) -> list[Figure]:
order_df = load_order_book_df(run_name=run_name)
trades_df = load_trades_df(run_name=run_name)
Expand Down Expand Up @@ -299,7 +321,7 @@ def plot_run_outputs(run_name: Optional[str] = None) -> list[Figure]:
ax6 = plt.subplot(427)
ax7 = plt.subplot(428)

plot_trading_summary(ax, market_trades_df, market_order_df, market_mid_df)
# plot_trading_summary(ax, market_trades_df, market_order_df, market_mid_df)
plot_total_traded_volume(ax2, market_trades_df)
plot_spread(ax3, order_book_df=market_order_df)
plot_open_interest(ax4, market_data_df)
Expand All @@ -312,6 +334,7 @@ def plot_run_outputs(run_name: Optional[str] = None) -> list[Figure]:
return figs


@log_plotting_function
def plot_trading_mode(
fig: Figure, data_df: pd.DataFrame, ss: Optional[SubplotSpec] = None
):
Expand Down Expand Up @@ -391,6 +414,7 @@ def plot_trading_mode(
ax1.legend(labels=names, loc="lower right")


@log_plotting_function
def plot_price_comparison(
fig: Figure,
data_df: pd.DataFrame,
Expand Down Expand Up @@ -450,6 +474,7 @@ def plot_price_comparison(
ax0.legend(labels=["external price", "mark price"])


@log_plotting_function
def plot_risky_close_outs(
fig: Figure,
accounts_df: pd.DataFrame,
Expand Down Expand Up @@ -547,6 +572,7 @@ def plot_risky_close_outs(
ax1r.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))


@log_plotting_function
def fuzz_plots(run_name: Optional[str] = None) -> Figure:
data_df = load_market_data_df(run_name=run_name)
accounts_df = load_accounts_df(run_name=run_name)
Expand Down Expand Up @@ -606,6 +632,7 @@ def fuzz_plots(run_name: Optional[str] = None) -> Figure:
return figs


@log_plotting_function
def account_plots(run_name: Optional[str] = None, agent_types: Optional[list] = None):
accounts_df = load_accounts_df(run_name=run_name)
agents_df = load_agents_df(run_name=run_name)
Expand Down Expand Up @@ -657,6 +684,7 @@ def account_plots(run_name: Optional[str] = None, agent_types: Optional[list] =
return fig


@log_plotting_function
def plot_price_monitoring(
run_name: Optional[str] = None,
fig: Optional[Figure] = None,
Expand Down Expand Up @@ -815,6 +843,7 @@ def plot_price_monitoring(
return figs


@log_plotting_function
def resource_monitoring_plot(run_name: Optional[str] = None):
resource_df = load_resource_df(
run_name=run_name,
Expand Down Expand Up @@ -878,6 +907,7 @@ def resource_monitoring_plot(run_name: Optional[str] = None):
return fig


@log_plotting_function
def reward_plots(run_name: Optional[str] = None):
accounts_df = load_accounts_df(run_name=run_name)
assets_df = load_assets_df(run_name=run_name)
Expand Down Expand Up @@ -966,6 +996,8 @@ def reward_plots(run_name: Optional[str] = None):
axs[-1].set_title(f"Asset: {plot[2]}")
accounts_for_asset = joined_df[joined_df.symbol == str(plot[2])]
grouped = accounts_for_asset.groupby(["agent_type", "time"])["balance"].sum()
if len(grouped.index.get_level_values(0).unique().values) == 1:
continue
for index in grouped.index.get_level_values(0).unique():
if index == "RewardFunder":
continue
Expand All @@ -978,6 +1010,7 @@ def reward_plots(run_name: Optional[str] = None):
return fig


@log_plotting_function
def plot_account_by_party(
run_name: Optional[str] = None,
fig: Optional[Figure] = None,
Expand Down Expand Up @@ -1059,6 +1092,7 @@ def plot_account_by_party(
axs[-1].legend()


@log_plotting_function
def sla_plot(run_name: Optional[str] = None):
accounts_df = load_accounts_df(run_name=run_name)
ledger_entries_df = load_ledger_entries_df(run_name=run_name)
Expand Down Expand Up @@ -1205,6 +1239,7 @@ def determine_poi(row):
return fig


@log_plotting_function
def price_series_plot(
run_name: Optional[str] = None,
fig: Optional[Figure] = None,
Expand Down Expand Up @@ -1296,6 +1331,7 @@ def price_series_plot(
return fig


@log_plotting_function
def liquidation_plot(
run_name: Optional[str] = None,
fig: Optional[Figure] = None,
Expand Down Expand Up @@ -1448,6 +1484,7 @@ def liquidation_plot(
return fig


@log_plotting_function
def plot_overlay_auctions(ax: Axes, chained_market_data: pd.DataFrame):
twinx = ax.twinx()
twinx.set_ylim(0, 1)
Expand Down Expand Up @@ -1548,5 +1585,4 @@ def plot_overlay_auctions(ax: Axes, chained_market_data: pd.DataFrame):
fig.savefig(f"{dir}/price.jpg")

if args.show:
plt.ion()
plt.show()

0 comments on commit 2aed8c9

Please sign in to comment.