feat: Adding RL test and fixing up scenario (#287)
* test: Add RL test and fix up agent
TomMcL authored Dec 21, 2022
1 parent 649944d commit 403bc6b
Showing 6 changed files with 140 additions and 106 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -28,7 +28,9 @@ FROM vegasim_base AS vegasim_test
COPY pytest.ini .

COPY ./requirements-dev.txt .
COPY ./requirements-learning.txt .
RUN pip install -r requirements-dev.txt
RUN pip install -r requirements-learning.txt

COPY ./examples ./examples
COPY ./pyproject.toml ./pyproject.toml
4 changes: 2 additions & 2 deletions Jenkinsfile
@@ -59,9 +59,9 @@ pipeline {
script {
vegaMarketSim ignoreFailure: false,
timeout: 90,
vegaMarketSim: commitHash,
vegaMarketSimBranch: commitHash,
vegaVersion: params.VEGA_VERSION,
jenkinsSharedLib: params.JENKINS_SHARED_LIB_BRANCH
jenkinsSharedLibBranch: params.JENKINS_SHARED_LIB_BRANCH
}
}
}
9 changes: 9 additions & 0 deletions tests/integration/test_reinforcement.py
@@ -0,0 +1,9 @@
import pytest


@pytest.mark.integration
def test_rl_run():
# Simply testing that it doesn't error
import vega_sim.reinforcement.run_rl_agent as rl

rl._run(1)
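
The new test simply imports the module and calls the new _run entry point once. For reference, a minimal sketch of invoking the same entry point directly, using only the keyword arguments visible in the _run signature added by this commit (the temporary-directory handling is illustrative, not part of the commit):

import tempfile

import vega_sim.reinforcement.run_rl_agent as rl

# Hedged sketch: one training iteration on CPU, writing logs and weights to a
# throwaway directory. Keyword names match the _run signature introduced in
# this commit; the tempfile usage is illustrative only.
with tempfile.TemporaryDirectory() as results_dir:
    rl._run(
        max_iterations=1,
        results_dir=results_dir,
        resume_training=False,
        evaluate_only=False,
        plot_every_step=False,
        device="cpu",
    )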
9 changes: 5 additions & 4 deletions vega_sim/reinforcement/agents/learning_agent.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
import numpy as np
from collections import namedtuple, defaultdict
from typing import List, Tuple
from typing import List, Tuple, Dict

from functools import partial
from tqdm import tqdm
@@ -44,10 +44,10 @@

def state_fn(
service: VegaServiceNull,
agents: List[Agent],
agents: Dict[str, Agent],
state_values=None,
) -> Tuple[LAMarketState, AbstractAction]:
learner = [a for a in agents if isinstance(a, LearningAgent)][0]
learner = agents["learner"]
return (learner.latest_state, learner.latest_action)


@@ -234,7 +234,8 @@ def finalise(self):

if account.margin > 0:
print(
"Market should be settled but there is still balance in margin account. What's up?"
"Market should be settled but there is still balance in margin account."
" What's up?"
)

self.latest_action = self.empty_action()
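
The state_fn change above reflects the wider refactor in this commit: agents are now passed as a dict keyed by name, so the learner is read from the "learner" key instead of being found with an isinstance scan over a list. A minimal sketch of the new calling convention (the imports and the "learner" key come from this diff; the wrapper function itself is illustrative):

from typing import Dict, Tuple

from vega_sim.null_service import VegaServiceNull
from vega_sim.reinforcement.agents.learning_agent import state_fn
from vega_sim.reinforcement.la_market_state import AbstractAction, LAMarketState


# Hedged sketch: agents must contain the entry registered by run_iteration,
# i.e. scenario.agents["learner"] = learning_agent.
def snapshot_learner(
    vega: VegaServiceNull, agents: Dict[str, object]
) -> Tuple[LAMarketState, AbstractAction]:
    return state_fn(service=vega, agents=agents)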
202 changes: 108 additions & 94 deletions vega_sim/reinforcement/run_rl_agent.py
@@ -1,35 +1,23 @@
import argparse
import logging
from typing import List, Optional, Tuple
import os
import torch
import time

from logging import getLogger

from vega_sim.reinforcement.la_market_state import LAMarketState, AbstractAction
from vega_sim.reinforcement.agents.learning_agent import (
LearningAgent,
WALLET as LEARNING_WALLET,
state_fn,
import torch
from vega_sim.null_service import VegaServiceNull
from vega_sim.reinforcement.agents.learning_agent import WALLET as LEARNING_WALLET
from vega_sim.reinforcement.agents.learning_agent import LearningAgent, state_fn
from vega_sim.reinforcement.agents.learning_agent_heuristic import (
LearningAgentHeuristic,
)
from vega_sim.reinforcement.agents.learning_agent_MO import LearningAgentFixedVol
from vega_sim.reinforcement.agents.learning_agent_MO_with_vol import (
LearningAgentWithVol,
)
from vega_sim.reinforcement.agents.learning_agent_MO import LearningAgentFixedVol
from vega_sim.reinforcement.agents.learning_agent_heuristic import (
LearningAgentHeuristic,
)

from vega_sim.scenario.registry import IdealMarketMakerV2
from vega_sim.scenario.registry import CurveMarketMaker

from vega_sim.reinforcement.helpers import set_seed
from vega_sim.null_service import VegaServiceNull


from vega_sim.reinforcement.plot import plot_learning, plot_pnl, plot_simulation

from logging import getLogger
from vega_sim.scenario.registry import CurveMarketMaker
from vega_sim.scenario.common.agents import Snitch

logger = getLogger(__name__)

@@ -63,90 +51,51 @@ def run_iteration(
scenario.agents = scenario.configure_agents(
vega=vega, tag=str(step_tag), random_state=None
)
# add the learning agent to the environment's list of agents
learning_agent.set_market_tag(str(step_tag))
learning_agent.price_process = scenario.price_process
scenario.agents["learner"] = learning_agent

scenario.agents["snitch"] = Snitch(
agents=scenario.agents, additional_state_fn=scenario.state_extraction_fn
)

scenario.env = scenario.configure_environment(
vega=vega,
tag=str(step_tag),
)

# add the learning agent to the environment's list of agents
scenario.env.agents = scenario.env.agents + [learning_agent]

learning_agent.set_market_tag(str(step_tag))
learning_agent.price_process = scenario.price_process

result = scenario.env.run(
scenario.env.run(
run_with_console=run_with_console,
pause_at_completion=pause_at_completion,
)

result = scenario.get_additional_run_data()

# Update the memory of the learning agent with the simulated data
learning_agent.update_memory(result)

return result


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--num-procs", default=6, type=int)
parser.add_argument(
"--rl-max-it",
default=1,
type=int,
help="Number of iterations of policy improvement + policy iterations",
)
parser.add_argument("--use_cuda", action="store_true", default=False)
parser.add_argument("--use_mps", action="store_true", default=False)
parser.add_argument("--device", default=0, type=int)
parser.add_argument("--results_dir", default="numerical_results", type=str)
parser.add_argument(
"--evaluate",
default=0,
type=int,
help="If true, do not train and directly run the chosen number of evaluations",
)
parser.add_argument("--resume_training", action="store_true")
parser.add_argument("--plot_every_step", action="store_true")
parser.add_argument("--plot_only", action="store_true")

args = parser.parse_args()

# set device
if torch.cuda.is_available() and args.use_cuda:
device = "cuda:{}".format(args.device)
elif torch.backends.mps.is_available() and args.use_mps:
device = torch.device("mps")
logger.warn(
"WARNING: as of today this will likely crash due to mps not implementing"
" all required functionality."
)
else:
device = "cpu"

# create results dir
if not os.path.exists(args.results_dir):
os.makedirs(args.results_dir)
logfile_pol_imp = os.path.join(args.results_dir, "learning_pol_imp.csv")
logfile_pol_eval = os.path.join(args.results_dir, "learning_pol_eval.csv")
logfile_pnl = os.path.join(args.results_dir, "learning_pnl.csv")

if args.plot_only:
plot_learning(
results_dir=args.results_dir,
logfile_pol_eval=logfile_pol_eval,
logfile_pol_imp=logfile_pol_imp,
)
plot_pnl(results_dir=args.results_dir, logfile_pnl=logfile_pnl)
exit(0)

def _run(
max_iterations: int,
results_dir: str = "numerical_results",
resume_training: bool = False,
evaluate_only: bool = False,
plot_every_step: bool = False,
device: str = "cpu",
):
# set seed for results replication
set_seed(1)

# set market name
market_name = "ETH:USD"
position_decimals = 2
initial_price = 1000

logfile_pol_imp = os.path.join(results_dir, "learning_pol_imp.csv")
logfile_pol_eval = os.path.join(results_dir, "learning_pol_eval.csv")
logfile_pnl = os.path.join(results_dir, "learning_pnl.csv")

# create the Learning Agent
learning_agent = LearningAgentFixedVol(
@@ -172,12 +121,12 @@
) as vega:
vega.wait_for_total_catchup()

if args.evaluate == 0:
logger.info(f"Running training for {args.rl_max_it} iterations")
if not evaluate_only:
logger.info(f"Running training for {max_iterations} iterations")
# TRAINING OF AGENT
if args.resume_training:
logger.info("Loading neural net weights from: " + args.results_dir)
learning_agent.load(args.results_dir)
if resume_training:
logger.info("Loading neural net weights from: " + results_dir)
learning_agent.load(results_dir)
else:
with open(logfile_pol_imp, "w") as f:
f.write("iteration,loss\n")
@@ -186,7 +135,7 @@
with open(logfile_pnl, "w") as f:
f.write("iteration,pnl\n")

for it in range(args.rl_max_it):
for it in range(max_iterations):
# simulation of market to get some data

learning_agent.move_to_cpu()
@@ -205,11 +154,11 @@
learning_agent.policy_improvement(batch_size=100_000, n_epochs=10)

# save in case environment chooses to crash
learning_agent.save(args.results_dir)
learning_agent.save(results_dir)

if args.plot_every_step:
if plot_every_step:
plot_learning(
results_dir=args.results_dir,
results_dir=results_dir,
logfile_pol_eval=logfile_pol_eval,
logfile_pol_imp=logfile_pol_imp,
)
@@ -240,3 +189,68 @@ def run_iteration(
)

learning_agent.lerningIteration += 1


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--num-procs", default=6, type=int)
parser.add_argument(
"--rl-max-it",
default=1,
type=int,
help="Number of iterations of policy improvement + policy iterations",
)
parser.add_argument("--use_cuda", action="store_true", default=False)
parser.add_argument("--use_mps", action="store_true", default=False)
parser.add_argument("--device", default=0, type=int)
parser.add_argument("--results_dir", default="numerical_results", type=str)
parser.add_argument(
"--evaluate",
default=0,
type=int,
help="If true, do not train and directly run the chosen number of evaluations",
)
parser.add_argument("--resume_training", action="store_true")
parser.add_argument("--plot_every_step", action="store_true")
parser.add_argument("--plot_only", action="store_true")

args = parser.parse_args()

# set device
if torch.cuda.is_available() and args.use_cuda:
device = "cuda:{}".format(args.device)
elif torch.backends.mps.is_available() and args.use_mps:
device = torch.device("mps")
logger.warn(
"WARNING: as of today this will likely crash due to mps not implementing"
" all required functionality."
)
else:
device = "cpu"

# create results dir
if not os.path.exists(args.results_dir):
os.makedirs(args.results_dir)
logfile_pol_imp = os.path.join(args.results_dir, "learning_pol_imp.csv")
logfile_pol_eval = os.path.join(args.results_dir, "learning_pol_eval.csv")
logfile_pnl = os.path.join(args.results_dir, "learning_pnl.csv")

if args.plot_only:
plot_learning(
results_dir=args.results_dir,
logfile_pol_eval=logfile_pol_eval,
logfile_pol_imp=logfile_pol_imp,
)
plot_pnl(results_dir=args.results_dir, logfile_pnl=logfile_pnl)
exit(0)

_run(
max_iterations=args.rl_max_it,
results_dir=args.results_dir,
resume_training=args.resume_training,
evaluate_only=args.evaluate,
plot_every_step=args.plot_every_step,
device=device,
)
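
The run_iteration rewrite above (mirrored in run_simple_agent.py below) settles on one pattern: register the learning agent and a Snitch in scenario.agents, run the configured environment, then pull the recorded states with get_additional_run_data. A condensed sketch of that flow, assuming a scenario, vega service and learning agent already constructed as in this module:

from vega_sim.scenario.common.agents import Snitch


# Hedged sketch of the data-collection pattern used by both run_*_agent
# modules after this commit; scenario, vega and learning_agent are assumed
# to be built exactly as in run_iteration above.
def collect_run_data(scenario, vega, learning_agent, tag: str = "0"):
    scenario.agents = scenario.configure_agents(
        vega=vega, tag=tag, random_state=None
    )
    learning_agent.set_market_tag(tag)
    learning_agent.price_process = scenario.price_process
    scenario.agents["learner"] = learning_agent
    scenario.agents["snitch"] = Snitch(
        agents=scenario.agents, additional_state_fn=scenario.state_extraction_fn
    )
    scenario.env = scenario.configure_environment(vega=vega, tag=tag)
    scenario.env.run(run_with_console=False, pause_at_completion=False)
    return scenario.get_additional_run_data()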
20 changes: 14 additions & 6 deletions vega_sim/reinforcement/run_simple_agent.py
@@ -6,6 +6,7 @@
from vega_sim.reinforcement.agents.simple_agent import SimpleAgent

from vega_sim.scenario.registry import CurveMarketMaker
from vega_sim.scenario.common.agents import Snitch

from vega_sim.reinforcement.helpers import set_seed
from vega_sim.null_service import VegaServiceNull
@@ -36,18 +37,25 @@ def run_iteration(
random_agent_ordering=False,
sigma=100,
)
env = scenario.set_up_background_market(
vega=vega,
)
# add the learning agent to the environment's list of agents
env.agents = env.agents + [learning_agent]

scenario.agents = scenario.configure_agents(vega=vega, random_state=None)
# add the learning agent to the environment's list of agents
learning_agent.price_process = scenario.price_process
scenario.agents["learner"] = learning_agent

scenario.agents["snitch"] = Snitch(
agents=scenario.agents, additional_state_fn=scenario.state_extraction_fn
)

result = env.run(
scenario.env = scenario.configure_environment(vega=vega)

scenario.env.run(
run_with_console=run_with_console,
pause_at_completion=pause_at_completion,
)

result = scenario.get_additional_run_data()

return result


