Skip to content

Commit

Permalink
feat: Adding random state fixing to parameter experiments to ensure c…
Browse files Browse the repository at this point in the history
…omparability (#107)

* feat: Adding random state fixing to parameter experiments to ensure comparability
  • Loading branch information
TomMcL authored Aug 10, 2022
1 parent e3e0895 commit ebd9916
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 55 deletions.
176 changes: 153 additions & 23 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ argon2-cffi-bindings==21.2.0; python_version >= "3.7"
argon2-cffi==21.3.0; python_version >= "3.7"
asttokens==2.0.5; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.8"
atomicwrites==1.4.1; sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.7.0" and python_full_version < "4.0.0"
attrs==21.4.0; python_full_version >= "3.7.0" and python_version >= "3.7" and python_full_version < "4.0.0"
attrs==22.1.0; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
backcall==0.2.0; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.8"
beautifulsoup4==4.11.1; python_full_version >= "3.6.0" and python_version >= "3.7"
black==22.6.0; python_full_version >= "3.6.2"
Expand All @@ -22,8 +22,8 @@ executing==0.9.1; python_full_version >= "3.7.0" and python_full_version < "4.0.
fastjsonschema==2.16.1; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
flake8==4.0.1; python_version >= "3.6"
fonttools==4.34.4; python_version >= "3.7"
grpcio-tools==1.48.0rc1; python_version >= "3.6"
grpcio==1.48.0rc1; python_version >= "3.6"
grpcio-tools==1.48.0; python_version >= "3.6"
grpcio==1.48.0; python_version >= "3.6"
idna==3.3; python_version >= "3.7" and python_version < "4"
inflection==0.5.1; python_version >= "3.5"
iniconfig==1.1.1; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
Expand All @@ -33,7 +33,7 @@ ipython==8.4.0; python_full_version >= "3.7.0" and python_full_version < "4.0.0"
ipywidgets==7.7.1
jedi==0.18.1; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.8"
jinja2==3.1.2; python_version >= "3.7"
jsonschema==4.7.2; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
jsonschema==4.9.0; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
jupyter-client==7.3.4; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
jupyter-core==4.11.1; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
jupyterlab-pygments==0.2.2; python_version >= "3.7"
Expand Down Expand Up @@ -85,7 +85,7 @@ pywinpty==2.0.6; os_name == "nt" and python_version >= "3.7"
pyzmq==23.2.0; python_full_version >= "3.7.0" and python_full_version < "4.0.0" and python_version >= "3.7"
requests-mock==1.9.3
requests==2.28.1; python_version >= "3.7" and python_version < "4"
scipy==1.8.1; python_version >= "3.8" and python_version < "3.11"
scipy==1.9.0; python_version >= "3.8" and python_version < "3.12"
send2trash==1.8.0; python_version >= "3.7"
setuptools-scm==7.0.5; python_version >= "3.7"
six==1.16.0; python_full_version >= "3.7.0" and python_version >= "3.8" and python_full_version < "4.0.0" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.7")
Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ argon2-cffi-bindings==21.2.0; python_version >= "3.7"
argon2-cffi==21.3.0; python_version >= "3.7"
asttokens==2.0.5; python_version >= "3.8"
atomicwrites==1.4.1; python_version >= "3.7" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.4.0"
attrs==21.4.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7"
attrs==22.1.0; python_version >= "3.7"
backcall==0.2.0; python_version >= "3.8"
beautifulsoup4==4.11.1; python_full_version >= "3.6.0" and python_version >= "3.7"
bleach==5.0.1; python_version >= "3.7"
Expand All @@ -19,8 +19,8 @@ entrypoints==0.4; python_version >= "3.7"
executing==0.9.1; python_version >= "3.8"
fastjsonschema==2.16.1; python_version >= "3.7"
fonttools==4.34.4; python_version >= "3.7"
grpcio-tools==1.48.0rc1; python_version >= "3.6"
grpcio==1.48.0rc1; python_version >= "3.6"
grpcio-tools==1.48.0; python_version >= "3.6"
grpcio==1.48.0; python_version >= "3.6"
idna==3.3; python_version >= "3.7" and python_version < "4"
inflection==0.5.1; python_version >= "3.5"
iniconfig==1.1.1; python_version >= "3.7"
Expand All @@ -30,7 +30,7 @@ ipython==8.4.0; python_version >= "3.8"
ipywidgets==7.7.1
jedi==0.18.1; python_version >= "3.8"
jinja2==3.1.2; python_version >= "3.7"
jsonschema==4.7.2; python_version >= "3.7"
jsonschema==4.9.0; python_version >= "3.7"
jupyter-client==7.3.4; python_full_version >= "3.7.0" and python_version >= "3.7"
jupyter-core==4.11.1; python_version >= "3.7"
jupyterlab-pygments==0.2.2; python_version >= "3.7"
Expand Down Expand Up @@ -73,7 +73,7 @@ pywin32==304; sys_platform == "win32" and platform_python_implementation != "PyP
pywinpty==2.0.6; os_name == "nt" and python_version >= "3.7"
pyzmq==23.2.0; python_version >= "3.7"
requests==2.28.1; python_version >= "3.7" and python_version < "4"
scipy==1.8.1; python_version >= "3.8" and python_version < "3.11"
scipy==1.9.0; python_version >= "3.8" and python_version < "3.12"
send2trash==1.8.0; python_version >= "3.7"
setuptools-scm==7.0.5; python_version >= "3.7"
six==1.16.0; python_version >= "3.8" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.8"
Expand Down
12 changes: 8 additions & 4 deletions vega_sim/parameter_test/parameter/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple
import pathlib
import numpy as np

from vega_sim.scenario.scenario import Scenario
from vega_sim.null_service import VegaServiceNull
Expand Down Expand Up @@ -52,6 +53,7 @@ def _run_parameter_iteration(
parameter_to_vary: str,
value: str,
additional_parameters_to_set: Optional[Dict[str, str]] = None,
random_state: Optional[np.random.RandomState] = None,
) -> Any:
with VegaServiceNull(
warn_on_raw_data_access=False, retain_log_files=True, run_with_console=False
Expand All @@ -73,25 +75,27 @@ def _run_parameter_iteration(
PARAMETER_AMEND_WALLET[0], parameter=parameter_to_vary, new_value=value
)

res = scenario.run_iteration(vega=vega)
# import pdb
res = scenario.run_iteration(vega=vega, random_state=random_state)

# pdb.set_trace()
return res


def run_single_parameter_experiment(
experiment: SingleParameterExperiment,
) -> Dict[str, List[Any]]:
results = {}
random_seeds = [
np.random.RandomState(i) for i in range(experiment.runs_per_scenario)
]
for value in experiment.values:
results[value] = []
for _ in range(experiment.runs_per_scenario):
for state in random_seeds:
results[value].append(
_run_parameter_iteration(
scenario=experiment.scenario,
parameter_to_vary=experiment.parameter_to_vary,
value=value,
random_state=state,
)
)
return results
Expand Down
1 change: 0 additions & 1 deletion vega_sim/scenario/ideal_market_maker/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def num_LimitOrderHit(self, bid_depth, ask_depth, num_buyMO, num_sellMO):
return num_BidLimitOrderHit, num_AskLimitOrderHit

def OptimalStrategy(self, current_position):

if current_position >= self.q_upper:
current_bid_depth = self.optimal_bid[self.current_step, 0]
current_ask_depth = 1 / 10**self.mdp
Expand Down
14 changes: 10 additions & 4 deletions vega_sim/scenario/ideal_market_maker/scenario.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import numpy as np
from typing import Any, Callable, List, Optional
from vega_sim.environment.agent import Agent

Expand Down Expand Up @@ -72,13 +73,15 @@ def set_up_background_market(
self,
vega: VegaServiceNull,
tag: str = "",
random_state: Optional[np.random.RandomState] = None,
) -> MarketEnvironment:
_, price_process = RW_model(
T=self.num_steps * self.dt,
dt=self.dt,
mdp=self.market_decimal,
sigma=self.sigma,
Midprice=self.initial_price,
random_state=random_state,
)

market_maker = OptimalMarketMaker(
Expand Down Expand Up @@ -151,10 +154,14 @@ def set_up_background_market(
)
return env

def run_iteration(self, vega: VegaServiceNull, pause_at_completion: bool = False):
def run_iteration(
self,
vega: VegaServiceNull,
pause_at_completion: bool = False,
random_state: Optional[np.random.RandomState] = None,
):
env = self.set_up_background_market(
vega=vega,
tag=str(0),
vega=vega, tag=str(0), random_state=random_state
)
result = env.run(
pause_at_completion=pause_at_completion,
Expand All @@ -163,7 +170,6 @@ def run_iteration(self, vega: VegaServiceNull, pause_at_completion: bool = False


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
Expand Down
19 changes: 13 additions & 6 deletions vega_sim/scenario/ideal_market_maker_v2/scenario.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import logging
import numpy as np
from typing import Any, Callable, List, Optional
from vega_sim.environment.agent import Agent

Expand Down Expand Up @@ -88,29 +89,34 @@ def __init__(
self.backgroundmarket_number_levels_per_side = (
backgroundmarket_number_levels_per_side
)
self.market_name = f"ETH:USD" if market_name is None else market_name
self.asset_name = f"tDAI" if asset_name is None else asset_name
self.market_name = "ETH:USD" if market_name is None else market_name
self.asset_name = "tDAI" if asset_name is None else asset_name

def _generate_price_process(self):
def _generate_price_process(
self,
random_state: Optional[np.random.RandomState] = None,
):
_, price_process = RW_model(
T=self.num_steps * self.dt,
dt=self.dt,
mdp=self.market_decimal,
sigma=self.sigma,
Midprice=self.initial_price,
random_state=random_state,
)
return price_process

def set_up_background_market(
self,
vega: VegaServiceNull,
tag: str = "",
random_state: Optional[np.random.RandomState] = None,
) -> MarketEnvironmentWithState:
# Set up market name and settlement asset
market_name = self.market_name + f"_{tag}"
asset_name = self.asset_name + f"_{tag}"

price_process = self._generate_price_process()
price_process = self._generate_price_process(random_state=random_state)

market_maker = OptimalMarketMaker(
wallet_name=MM_WALLET.name,
Expand Down Expand Up @@ -145,6 +151,7 @@ def set_up_background_market(
tag=str(tag),
buy_intensity=self.buy_intensity,
sell_intensity=self.sell_intensity,
random_state=random_state,
)

background_market = BackgroundMarket(
Expand Down Expand Up @@ -207,10 +214,10 @@ def run_iteration(
vega: VegaServiceNull,
pause_at_completion: bool = False,
run_with_console: bool = False,
random_state: Optional[np.random.RandomState] = None,
):
env = self.set_up_background_market(
vega=vega,
tag=str(0),
vega=vega, tag=str(0), random_state=random_state
)
result = env.run(
pause_at_completion=pause_at_completion,
Expand Down
16 changes: 10 additions & 6 deletions vega_sim/scenario/market_crash/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def __init__(
Callable[[VegaServiceNull, List[Agent]], Any]
] = None,
pause_every_n_steps: Optional[int] = None,
random_state: Optional[np.random.RandomState] = None,
trim_to_min: Optional[float] = None,
):
self.num_steps = num_steps
Expand Down Expand Up @@ -94,18 +93,19 @@ def __init__(
self.noise_buy_intensity = noise_buy_intensity
self.noise_sell_intensity = noise_sell_intensity
self.pause_every_n_steps = pause_every_n_steps
self.random_state = random_state
self.trim_to_min = trim_to_min
self.position_taker_mint = position_taker_mint
self.num_position_traders = num_position_traders
self.num_noise_traders = num_noise_traders
self.settle_at_end = settle_at_end
self.initial_asset_mint = initial_asset_mint

def _generate_price_process(self):
def _generate_price_process(
self,
random_state: Optional[np.random.RandomState] = None,
):
return regime_change_random_walk(
num_steps=self.num_steps + 1, # Number of steps plus 'initial' state
random_state=self.random_state,
sigma_pre=self.sigma_pre,
sigma_post=self.sigma_post,
drift_pre=self.drift_pre,
Expand All @@ -114,12 +114,14 @@ def _generate_price_process(self):
break_point=self.break_point,
decimal_precision=self.market_decimal,
trim_to_min=self.trim_to_min,
random_state=random_state,
)

def set_up_background_market(
self,
vega: VegaServiceNull,
tag: str = "",
random_state: Optional[np.random.RandomState] = None,
) -> MarketEnvironmentWithState:
self.market_name = f"BTC:DAI_{tag}"
self.asset_name = f"tDAI{tag}"
Expand Down Expand Up @@ -154,6 +156,7 @@ def set_up_background_market(
initial_asset_mint=self.position_taker_mint,
buy_intensity=self.noise_buy_intensity,
sell_intensity=self.noise_sell_intensity,
random_state=random_state,
)
)
for i in range(self.num_position_traders):
Expand All @@ -167,6 +170,7 @@ def set_up_background_market(
tag=f"{tag}_pos_{i}",
buy_intensity=self.position_taker_buy_intensity,
sell_intensity=self.position_taker_sell_intensity,
random_state=random_state,
)
)

Expand Down Expand Up @@ -238,10 +242,10 @@ def run_iteration(
vega: VegaServiceNull,
pause_at_completion: bool = False,
tag: Optional[str] = None,
random_state: Optional[np.random.RandomState] = None,
):
env = self.set_up_background_market(
vega=vega,
tag=tag if tag is not None else str(0),
vega=vega, tag=tag if tag is not None else str(0), random_state=random_state
)
result = env.run(
pause_at_completion=pause_at_completion,
Expand Down
9 changes: 8 additions & 1 deletion vega_sim/scenario/scenario.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import abc
import numpy as np
from typing import Optional

from vega_sim.null_service import VegaServiceNull


class Scenario(abc.ABC):
def run_iteration(self, vega: VegaServiceNull, pause_at_completion: bool = False):
def run_iteration(
self,
vega: VegaServiceNull,
pause_at_completion: bool = False,
random_state: Optional[np.random.RandomState] = None,
):
pass

0 comments on commit ebd9916

Please sign in to comment.