diff --git a/Dockerfile b/Dockerfile
index 9f484f033..58716a311 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -28,7 +28,9 @@ FROM vegasim_base AS vegasim_test
 
 COPY pytest.ini .
 COPY ./requirements-dev.txt .
+COPY ./requirements-learning.txt .
 RUN pip install -r requirements-dev.txt
+RUN pip install -r requirements-learning.txt
 
 COPY ./examples ./examples
 COPY ./pyproject.toml ./pyproject.toml
diff --git a/Jenkinsfile b/Jenkinsfile
index 24d58444a..58596b287 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -59,9 +59,9 @@ pipeline {
                 script {
                     vegaMarketSim ignoreFailure: false,
                         timeout: 90,
-                        vegaMarketSim: commitHash,
+                        vegaMarketSimBranch: commitHash,
                         vegaVersion: params.VEGA_VERSION,
-                        jenkinsSharedLib: params.JENKINS_SHARED_LIB_BRANCH
+                        jenkinsSharedLibBranch: params.JENKINS_SHARED_LIB_BRANCH
                 }
             }
         }
diff --git a/tests/integration/test_reinforcement.py b/tests/integration/test_reinforcement.py
new file mode 100644
index 000000000..aafa0c7d1
--- /dev/null
+++ b/tests/integration/test_reinforcement.py
@@ -0,0 +1,9 @@
+import pytest
+
+
+@pytest.mark.integration
+def test_rl_run():
+    # Simply testing that it doesn't error
+    import vega_sim.reinforcement.run_rl_agent as rl
+
+    rl._run(1)
diff --git a/vega_sim/reinforcement/agents/learning_agent.py b/vega_sim/reinforcement/agents/learning_agent.py
index 24da13cb6..24cd940cd 100644
--- a/vega_sim/reinforcement/agents/learning_agent.py
+++ b/vega_sim/reinforcement/agents/learning_agent.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 import numpy as np
 from collections import namedtuple, defaultdict
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 from functools import partial
 
 from tqdm import tqdm
@@ -44,10 +44,10 @@
 def state_fn(
     service: VegaServiceNull,
-    agents: List[Agent],
+    agents: Dict[str, Agent],
     state_values=None,
 ) -> Tuple[LAMarketState, AbstractAction]:
-    learner = [a for a in agents if isinstance(a, LearningAgent)][0]
+    learner = agents["learner"]
     return (learner.latest_state, learner.latest_action)
@@ -234,7 +234,8 @@ def finalise(self):
         if account.margin > 0:
             print(
-                "Market should be settled but there is still balance in margin account. What's up?"
+                "Market should be settled but there is still balance in margin account."
+                " What's up?"
             )
 
         self.latest_action = self.empty_action()
diff --git a/vega_sim/reinforcement/run_rl_agent.py b/vega_sim/reinforcement/run_rl_agent.py
index 39591db08..97b014c4c 100644
--- a/vega_sim/reinforcement/run_rl_agent.py
+++ b/vega_sim/reinforcement/run_rl_agent.py
@@ -1,35 +1,23 @@
 import argparse
 import logging
-from typing import List, Optional, Tuple
 import os
-import torch
-import time
-
+from logging import getLogger
-from vega_sim.reinforcement.la_market_state import LAMarketState, AbstractAction
-from vega_sim.reinforcement.agents.learning_agent import (
-    LearningAgent,
-    WALLET as LEARNING_WALLET,
-    state_fn,
+import torch
+from vega_sim.null_service import VegaServiceNull
+from vega_sim.reinforcement.agents.learning_agent import WALLET as LEARNING_WALLET
+from vega_sim.reinforcement.agents.learning_agent import LearningAgent, state_fn
+from vega_sim.reinforcement.agents.learning_agent_heuristic import (
+    LearningAgentHeuristic,
 )
+from vega_sim.reinforcement.agents.learning_agent_MO import LearningAgentFixedVol
 from vega_sim.reinforcement.agents.learning_agent_MO_with_vol import (
     LearningAgentWithVol,
 )
-from vega_sim.reinforcement.agents.learning_agent_MO import LearningAgentFixedVol
-from vega_sim.reinforcement.agents.learning_agent_heuristic import (
-    LearningAgentHeuristic,
-)
-
-from vega_sim.scenario.registry import IdealMarketMakerV2
-from vega_sim.scenario.registry import CurveMarketMaker
-
 from vega_sim.reinforcement.helpers import set_seed
-from vega_sim.null_service import VegaServiceNull
-
-
 from vega_sim.reinforcement.plot import plot_learning, plot_pnl, plot_simulation
-
-from logging import getLogger
+from vega_sim.scenario.registry import CurveMarketMaker
+from vega_sim.scenario.common.agents import Snitch
 
 logger = getLogger(__name__)
@@ -63,90 +51,51 @@ def run_iteration(
     scenario.agents = scenario.configure_agents(
         vega=vega, tag=str(step_tag), random_state=None
     )
+    # add the learning agaent to the environment's list of agents
+    learning_agent.set_market_tag(str(step_tag))
+    learning_agent.price_process = scenario.price_process
+    scenario.agents["learner"] = learning_agent
+
+    scenario.agents["snitch"] = Snitch(
+        agents=scenario.agents, additional_state_fn=scenario.state_extraction_fn
+    )
 
     scenario.env = scenario.configure_environment(
         vega=vega,
         tag=str(step_tag),
     )
-    # add the learning agaent to the environement's list of agents
-    scenario.env.agents = scenario.env.agents + [learning_agent]
-
-    learning_agent.set_market_tag(str(step_tag))
-    learning_agent.price_process = scenario.price_process
-
-    result = scenario.env.run(
+    scenario.env.run(
         run_with_console=run_with_console,
         pause_at_completion=pause_at_completion,
     )
+
+    result = scenario.get_additional_run_data()
+
     # Update the memory of the learning agent with the simulated data
     learning_agent.update_memory(result)
 
     return result
 
 
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-n", "--num-procs", default=6, type=int)
-    parser.add_argument(
-        "--rl-max-it",
-        default=1,
-        type=int,
-        help="Number of iterations of policy improvement + policy iterations",
-    )
-    parser.add_argument("--use_cuda", action="store_true", default=False)
-    parser.add_argument("--use_mps", action="store_true", default=False)
-    parser.add_argument("--device", default=0, type=int)
-    parser.add_argument("--results_dir", default="numerical_results", type=str)
-    parser.add_argument(
-        "--evaluate",
-        default=0,
-        type=int,
-        help="If true, do not train and directly run the chosen number of evaluations",
-    )
-    parser.add_argument("--resume_training", action="store_true")
-    parser.add_argument("--plot_every_step", action="store_true")
-    parser.add_argument("--plot_only", action="store_true")
-
-    args = parser.parse_args()
-
-    # set device
-    if torch.cuda.is_available() and args.use_cuda:
-        device = "cuda:{}".format(args.device)
-    elif torch.backends.mps.is_available() and args.use_mps:
-        device = torch.device("mps")
-        logger.warn(
-            "WARNING: as of today this will likely crash due to mps not implementing"
-            " all required functionality."
-        )
-    else:
-        device = "cpu"
-
-    # create results dir
-    if not os.path.exists(args.results_dir):
-        os.makedirs(args.results_dir)
-    logfile_pol_imp = os.path.join(args.results_dir, "learning_pol_imp.csv")
-    logfile_pol_eval = os.path.join(args.results_dir, "learning_pol_eval.csv")
-    logfile_pnl = os.path.join(args.results_dir, "learning_pnl.csv")
-
-    if args.plot_only:
-        plot_learning(
-            results_dir=args.results_dir,
-            logfile_pol_eval=logfile_pol_eval,
-            logfile_pol_imp=logfile_pol_imp,
-        )
-        plot_pnl(results_dir=args.results_dir, logfile_pnl=logfile_pnl)
-        exit(0)
-
+def _run(
+    max_iterations: int,
+    results_dir: str = "numerical_results",
+    resume_training: bool = False,
+    evaluate_only: bool = False,
+    plot_every_step: bool = False,
+    device: str = "cpu",
+):
     # set seed for results replication
     set_seed(1)
 
     # set market name
     market_name = "ETH:USD"
     position_decimals = 2
-    initial_price = 1000
+
+    logfile_pol_imp = os.path.join(results_dir, "learning_pol_imp.csv")
+    logfile_pol_eval = os.path.join(results_dir, "learning_pol_eval.csv")
+    logfile_pnl = os.path.join(results_dir, "learning_pnl.csv")
 
     # create the Learning Agent
     learning_agent = LearningAgentFixedVol(
@@ -172,12 +121,12 @@ def run_iteration(
     ) as vega:
         vega.wait_for_total_catchup()
 
-        if args.evaluate == 0:
-            logger.info(f"Running training for {args.rl_max_it} iterations")
+        if not evaluate_only:
+            logger.info(f"Running training for {max_iterations} iterations")
             # TRAINING OF AGENT
-            if args.resume_training:
-                logger.info("Loading neural net weights from: " + args.results_dir)
-                learning_agent.load(args.results_dir)
+            if resume_training:
+                logger.info("Loading neural net weights from: " + results_dir)
+                learning_agent.load(results_dir)
             else:
                 with open(logfile_pol_imp, "w") as f:
                     f.write("iteration,loss\n")
@@ -186,7 +135,7 @@ def run_iteration(
                 with open(logfile_pnl, "w") as f:
                     f.write("iteration,pnl\n")
 
-            for it in range(args.rl_max_it):
+            for it in range(max_iterations):
                 # simulation of market to get some data
                 learning_agent.move_to_cpu()
 
@@ -205,11 +154,11 @@ def run_iteration(
                 learning_agent.policy_improvement(batch_size=100_000, n_epochs=10)
 
                 # save in case environment chooses to crash
-                learning_agent.save(args.results_dir)
+                learning_agent.save(results_dir)
 
-                if args.plot_every_step:
+                if plot_every_step:
                     plot_learning(
-                        results_dir=args.results_dir,
+                        results_dir=results_dir,
                         logfile_pol_eval=logfile_pol_eval,
                         logfile_pol_imp=logfile_pol_imp,
                     )
@@ -240,3 +189,68 @@ def run_iteration(
                 )
 
             learning_agent.lerningIteration += 1
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-n", "--num-procs", default=6, type=int)
+    parser.add_argument(
+        "--rl-max-it",
+        default=1,
+        type=int,
+        help="Number of iterations of policy improvement + policy iterations",
+    )
+    parser.add_argument("--use_cuda", action="store_true", default=False)
+    parser.add_argument("--use_mps", action="store_true", default=False)
+    parser.add_argument("--device", default=0, type=int)
+    parser.add_argument("--results_dir", default="numerical_results", type=str)
+    parser.add_argument(
+        "--evaluate",
+        default=0,
+        type=int,
+        help="If true, do not train and directly run the chosen number of evaluations",
+    )
+    parser.add_argument("--resume_training", action="store_true")
+    parser.add_argument("--plot_every_step", action="store_true")
+    parser.add_argument("--plot_only", action="store_true")
+
+    args = parser.parse_args()
+
+    # set device
+    if torch.cuda.is_available() and args.use_cuda:
+        device = "cuda:{}".format(args.device)
+    elif torch.backends.mps.is_available() and args.use_mps:
+        device = torch.device("mps")
+        logger.warn(
+            "WARNING: as of today this will likely crash due to mps not implementing"
+            " all required functionality."
+        )
+    else:
+        device = "cpu"
+
+    # create results dir
+    if not os.path.exists(args.results_dir):
+        os.makedirs(args.results_dir)
+    logfile_pol_imp = os.path.join(args.results_dir, "learning_pol_imp.csv")
+    logfile_pol_eval = os.path.join(args.results_dir, "learning_pol_eval.csv")
+    logfile_pnl = os.path.join(args.results_dir, "learning_pnl.csv")
+
+    if args.plot_only:
+        plot_learning(
+            results_dir=args.results_dir,
+            logfile_pol_eval=logfile_pol_eval,
+            logfile_pol_imp=logfile_pol_imp,
+        )
+        plot_pnl(results_dir=args.results_dir, logfile_pnl=logfile_pnl)
+        exit(0)
+
+    _run(
+        max_iterations=args.rl_max_it,
+        results_dir=args.results_dir,
+        resume_training=args.resume_training,
+        evaluate_only=args.evaluate,
+        plot_every_step=args.plot_every_step,
+        device=device,
+    )
diff --git a/vega_sim/reinforcement/run_simple_agent.py b/vega_sim/reinforcement/run_simple_agent.py
index d97b7b135..34ca35ae1 100644
--- a/vega_sim/reinforcement/run_simple_agent.py
+++ b/vega_sim/reinforcement/run_simple_agent.py
@@ -6,6 +6,7 @@
 from vega_sim.reinforcement.agents.simple_agent import SimpleAgent
 from vega_sim.scenario.registry import CurveMarketMaker
+from vega_sim.scenario.common.agents import Snitch
 
 from vega_sim.reinforcement.helpers import set_seed
 from vega_sim.null_service import VegaServiceNull
 
@@ -36,18 +37,25 @@ def run_iteration(
         random_agent_ordering=False,
         sigma=100,
     )
-    env = scenario.set_up_background_market(
-        vega=vega,
-    )
-    # add the learning agaent to the environement's list of agents
-    env.agents = env.agents + [learning_agent]
+    scenario.agents = scenario.configure_agents(vega=vega, random_state=None)
+    # add the learning agaent to the environment's list of agents
     learning_agent.price_process = scenario.price_process
+    scenario.agents["learner"] = learning_agent
+
+    scenario.agents["snitch"] = Snitch(
+        agents=scenario.agents, additional_state_fn=scenario.state_extraction_fn
+    )
 
-    result = env.run(
+    scenario.env = scenario.configure_environment(vega=vega)
+
+    scenario.env.run(
         run_with_console=run_with_console,
         pause_at_completion=pause_at_completion,
     )
+
+    result = scenario.get_additional_run_data()
+
     return result