diff --git a/genrl/agents/__init__.py b/genrl/agents/__init__.py
index 3257caff..f19f6432 100644
--- a/genrl/agents/__init__.py
+++ b/genrl/agents/__init__.py
@@ -41,5 +41,7 @@
 from genrl.agents.deep.sac.sac import SAC  # noqa
 from genrl.agents.deep.td3.td3 import TD3  # noqa
 from genrl.agents.deep.vpg.vpg import VPG  # noqa
+from genrl.agents.modelbased.base import ModelBasedAgent  # noqa
+from genrl.agents.modelbased.cem.cem import CEM  # noqa
 
 from genrl.agents.bandits.multiarmed.base import MABAgent  # noqa; noqa; noqa
diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py
index 1f94992e..9baa2bab 100644
--- a/genrl/agents/deep/a2c/a2c.py
+++ b/genrl/agents/deep/a2c/a2c.py
@@ -95,8 +95,8 @@ def _create_model(self) -> None:
         )
 
         actor_params, critic_params = self.ac.get_params()
-        self.optimizer_policy = opt.Adam(critic_params, lr=self.lr_policy)
-        self.optimizer_value = opt.Adam(actor_params, lr=self.lr_value)
+        self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
+        self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)
 
     def select_action(
         self, state: torch.Tensor, deterministic: bool = False
diff --git a/genrl/agents/modelbased/__init__.py b/genrl/agents/modelbased/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/genrl/agents/modelbased/base.py b/genrl/agents/modelbased/base.py
new file mode 100644
index 00000000..cb11e020
--- /dev/null
+++ b/genrl/agents/modelbased/base.py
@@ -0,0 +1,47 @@
+from abc import ABC
+
+import torch
+
+from genrl.agents import BaseAgent
+
+
+class Planner:
+    def __init__(self, initial_state, dynamics_model=None):
+        if dynamics_model is not None:
+            self.dynamics_model = dynamics_model
+        self.initial_state = initial_state
+
+    def _learn_dynamics_model(self, state):
+        raise NotImplementedError
+
+    def plan(self):
+        raise NotImplementedError
+
+    def execute_actions(self):
+        raise NotImplementedError
+
+
+class ModelBasedAgent(BaseAgent):
+    def __init__(self, *args, planner=None, **kwargs):
+        super(ModelBasedAgent, self).__init__(*args, **kwargs)
+        self.planner = planner
+
+    def plan(self):
+        """
+        To be used to plan out a sequence of actions
+        """
+        if self.planner is None:
+            raise ValueError("Provide a planner to plan for the environment")
+        self.planner.plan()
+
+    def generate_data(self):
+        """
+        To be used to generate synthetic data via a model (may be learnt or specified beforehand)
+        """
+        raise NotImplementedError
+
+    def value_equivalence(self, state_space):
+        """
+        To be used for approximate value estimation methods, e.g.
+        Value Iteration Networks
+        """
+        raise NotImplementedError
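Reviewer note: nothing in this PR actually wires a Planner into ModelBasedAgent (CEM below overrides plan() directly), so a minimal sketch of how the two pieces are meant to compose may help. RandomPlanner, the env argument, and the constructor call in the comments are illustrative assumptions, not part of the PR; only the Planner base class and the planner keyword come from base.py.

    from genrl.agents import ModelBasedAgent
    from genrl.agents.modelbased.base import Planner


    class RandomPlanner(Planner):
        """Toy planner that proposes a fixed number of random actions."""

        def __init__(self, initial_state, env, horizon=10):
            super().__init__(initial_state)
            self.env = env
            self.horizon = horizon

        def plan(self):
            # Propose `horizon` random actions from the env's action space
            return [self.env.action_space.sample() for _ in range(self.horizon)]


    # Hypothetical wiring (positional args assumed from how CEM is built in the test below):
    # agent = ModelBasedAgent("mlp", env, planner=RandomPlanner(env.reset(), env))
    # agent.plan()  # delegates to RandomPlanner.plan(); raises ValueError if no planner was given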
diff --git a/genrl/agents/modelbased/cem/__init__.py b/genrl/agents/modelbased/cem/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/genrl/agents/modelbased/cem/cem.py b/genrl/agents/modelbased/cem/cem.py
new file mode 100644
index 00000000..5451bb96
--- /dev/null
+++ b/genrl/agents/modelbased/cem/cem.py
@@ -0,0 +1,235 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from genrl.agents import ModelBasedAgent
+from genrl.core import RolloutBuffer
+from genrl.utils import get_env_properties, get_model, safe_mean
+
+
+class CEM(ModelBasedAgent):
+    """Cross Entropy Method (CEM) algorithm
+
+    Attributes:
+        network (str): The type of network to be used
+        env (Environment): The environment the agent is supposed to act on
+        create_model (bool): Whether the model of the algo should be created when initialised
+        policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network of the policy
+        lr_policy (float): Learning rate of the policy
+        percentile (float): Top percentile of rewards to consider as elite
+        simulations_per_epoch (int): Number of simulations to perform before taking a gradient step
+        rollout_size (int): Capacity of the rollout buffer
+        render (bool): Whether to render the environment or not
+        device (str): Hardware being used for training. Options:
+            ["cuda" -> GPU, "cpu" -> CPU]
+    """
+
+    def __init__(
+        self,
+        *args,
+        network: str = "mlp",
+        percentile: float = 70,
+        simulations_per_epoch: int = 1000,
+        rollout_size: int,
+        **kwargs
+    ):
+        super(CEM, self).__init__(*args, **kwargs)
+        self.network = network
+        self.rollout_size = rollout_size
+        self.rollout = RolloutBuffer(self.rollout_size, self.env)
+        self.percentile = percentile
+        self.simulations_per_epoch = simulations_per_epoch
+
+        self._create_model()
+        self.empty_logs()
+
+    def _create_model(self):
+        """Function to initialize the Policy
+
+        This will create the Policy net for the CEM agent
+        """
+        self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
+            self.env, self.network
+        )
+        self.agent = get_model("p", self.network)(
+            self.state_dim,
+            self.action_dim,
+            self.policy_layers,
+            "V",
+            discrete,
+            action_lim,
+        ).to(self.device)
+        self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)
+
+    def plan(self):
+        """Function to plan out one episode
+
+        Returns:
+            states (:obj:`list` of :obj:`torch.Tensor`): Batch of states the agent encountered in the episode
+            actions (:obj:`list` of :obj:`torch.Tensor`): Batch of actions the agent took in the episode
+            rewards (:obj:`torch.Tensor`): The episode reward obtained
+        """
+        state = self.env.reset()
+        self.rollout.reset()
+        states, actions = self.collect_rollouts(state)
+        return (states, actions, self.rewards[-1])
+    def select_elites(self, states_batch, actions_batch, rewards_batch):
+        """Function to select the elite states and elite actions based on the episode reward
+
+        Args:
+            states_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of states
+            actions_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of actions
+            rewards_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of rewards
+
+        Returns:
+            elite_states (:obj:`torch.Tensor`): Elite batch of states based on episode reward
+            elite_actions (:obj:`torch.Tensor`): Actions the agent took during the elite batch of states
+        """
+        reward_threshold = np.percentile(rewards_batch, self.percentile)
+        elite_states = torch.cat(
+            [
+                s.unsqueeze(0).clone()
+                for i in range(len(states_batch))
+                if rewards_batch[i] >= reward_threshold
+                for s in states_batch[i]
+            ],
+            dim=0,
+        )
+        elite_actions = torch.cat(
+            [
+                a.unsqueeze(0).clone()
+                for i in range(len(actions_batch))
+                if rewards_batch[i] >= reward_threshold
+                for a in actions_batch[i]
+            ],
+            dim=0,
+        )
+
+        return elite_states, elite_actions
+
+    def select_action(self, state):
+        """Select action given state
+
+        Action selection policy for the Cross Entropy agent
+
+        Args:
+            state (:obj:`torch.Tensor`): Current state of the agent
+
+        Returns:
+            action (:obj:`torch.Tensor`): Action taken by the agent
+        """
+        state = torch.as_tensor(state).float()
+        action, dist = self.agent.get_action(state)
+        return action
+
+    def update_params(self):
+        """Updates the Policy network of the CEM agent
+
+        Function to update the policy network
+        """
+        sess = [self.plan() for _ in range(self.simulations_per_epoch)]
+        batch_states, batch_actions, batch_rewards = zip(*sess)
+        elite_states, elite_actions = self.select_elites(
+            batch_states, batch_actions, batch_rewards
+        )
+        action_probs = self.agent.forward(elite_states.float().to(self.device))
+        loss = F.cross_entropy(
+            action_probs.view(-1, self.action_dim),
+            elite_actions.long().view(-1),
+        )
+        self.logs["crossentropy_loss"].append(loss.item())
+        self.optim.zero_grad()
+        loss.backward()
+        # torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
+        self.optim.step()
+
+    def get_traj_loss(self, values, dones):
+        # Not needed for CEM
+        pass
+
+    def collect_rollouts(self, state: torch.Tensor):
+        """Function to collect rollouts
+
+        Collects rollouts by playing the env like a human agent and inputs information into
+        the rollout buffer.
+
+        Args:
+            state (:obj:`torch.Tensor`): The starting state of the environment
+
+        Returns:
+            states (:obj:`list`): List of states the agent encountered during the episode
+            actions (:obj:`list`): List of actions the agent took in the corresponding states
+        """
+        states = []
+        actions = []
+        for i in range(self.rollout_size):
+            action = self.select_action(state)
+            states.append(state)
+            actions.append(action)
+
+            next_state, reward, dones, _ = self.env.step(action)
+
+            if self.render:
+                self.env.render()
+
+            state = next_state
+
+            self.collect_rewards(dones, i)
+
+            if torch.any(dones.byte()):
+                break
+
+        return states, actions
+
+    def collect_rewards(self, dones: torch.Tensor, timestep: int):
+        """Helper function to collect rewards
+
+        Runs through all the envs and collects rewards accumulated during rollouts
+
+        Args:
+            dones (:obj:`torch.Tensor`): Game over statuses of each environment
+            timestep (int): Timestep during rollout
+        """
+        for i, done in enumerate(dones):
+            if done or timestep == self.rollout_size - 1:
+                self.rewards.append(self.env.episode_reward[i].detach().clone())
+
+    def get_hyperparams(self) -> Dict[str, Any]:
+        """Get relevant hyperparameters to save
+
+        Returns:
+            hyperparams (:obj:`dict`): Hyperparameters to be saved
+            weights (:obj:`torch.Tensor`): Neural network weights
+        """
+        hyperparams = {
+            "network": self.network,
+            "lr_policy": self.lr_policy,
+            "rollout_size": self.rollout_size,
+        }
+        return hyperparams, self.agent.state_dict()
+
+    def _load_weights(self, weights) -> None:
+        self.agent.load_state_dict(weights)
+
+    def get_logging_params(self) -> Dict[str, Any]:
+        """Gets relevant parameters for logging
+
+        Returns:
+            logs (:obj:`dict`): Logging parameters for monitoring training
+        """
+        logs = {
+            "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
+            "mean_reward": safe_mean(self.rewards),
+        }
+
+        self.empty_logs()
+        return logs
+
+    def empty_logs(self):
+        """Empties logs"""
+        self.logs = {}
+        self.logs["crossentropy_loss"] = []
+        self.rewards = []
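Reviewer note: select_elites is the heart of CEM, so a small worked example for future readers. With percentile=70 and three simulated episodes scoring [1.0, 5.0, 10.0], np.percentile(rewards, 70) is about 7.0, so only the states and actions from the 10.0-reward episode are kept, and update_params then fits the policy to imitate those elite actions via cross-entropy. The same selection outside the agent (tensor shapes are illustrative, not from the PR):

    import numpy as np
    import torch

    rewards = [1.0, 5.0, 10.0]                       # one total reward per simulated episode
    episodes = [torch.randn(4, 3) for _ in rewards]  # per-episode stacks of states (illustrative)

    threshold = np.percentile(rewards, 70)           # ~7.0 for these rewards
    elite_states = torch.cat(
        [ep for ep, r in zip(episodes, rewards) if r >= threshold], dim=0
    )
    print(threshold, elite_states.shape)             # ~7.0, torch.Size([4, 3])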
diff --git a/genrl/agents/modelbased/mcts/__init__.py b/genrl/agents/modelbased/mcts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/genrl/agents/modelbased/mcts/base.py b/genrl/agents/modelbased/mcts/base.py
new file mode 100644
index 00000000..28bfda43
--- /dev/null
+++ b/genrl/agents/modelbased/mcts/base.py
@@ -0,0 +1,117 @@
+import numpy as np
+
+from genrl.agents.modelbased.base import ModelBasedAgent, Planner
+
+
+class TreePlanner(Planner):
+    def __init__(self, horizon, strategy="subtree"):
+        self.root = None
+        self.observations = []
+        self.horizon = horizon
+        self.strategy = strategy  # "reset" or "subtree"
+        self.reset()
+
+    def plan(self, state, obs):
+        raise NotImplementedError()
+
+    def get_plan(self):
+        actions = []
+        node = self.root
+        while node.children:
+            action = node.selection_rule()
+            actions.append(action)
+            node = node.children[action]
+        return actions
+
+    def step(self, state, action):
+        obs, reward, done, info = state.step(action)
+        self.observations.append(obs)
+        return obs, reward, done, info
+
+    def step_tree(self, actions):
+        if self.strategy == "reset":
+            self.reset()
+        elif self.strategy == "subtree":
+            if actions:
+                self._step_by_subtree(actions[0])
+            else:
+                self.reset()
+        else:
+            raise NotImplementedError
+
+    def _step_by_subtree(self, action):
+        if action in self.root.children:
+            self.root = self.root.children[action]
+            self.root.parent = None
+        else:
+            self.reset()
+
+    def get_visits(self):
+        visits = {}
+        for obs in self.observations:
+            if str(obs) not in visits.keys():
+                visits[str(obs)] = 0
+            visits[str(obs)] += 1
+        return visits
+
+    def reset(self):
+        raise NotImplementedError
+
+
+class Node:
+    def __init__(self, parent, planner):
+        self.parent = parent
+        self.planner = planner
+        self.children = {}
+
+        self.visits = 0
+
+    def get_value(self):
+        raise NotImplementedError
+
+    def expand(self, branch_factor):
+        for a in range(branch_factor):
+            self.children[a] = Node(self, self.planner)
+
+    def selection_rule(self):
+        raise NotImplementedError
+
+    def is_leaf(self):
+        return not self.children
+
+
+class TreeSearchAgent(ModelBasedAgent):
+    def __init__(self, *args, horizon, **kwargs):
+        super(TreeSearchAgent, self).__init__(*args, **kwargs)
+        self.planner = self._create_planner()
+        self.prev_actions = []
+        self.horizon = horizon
+        self.remaining_horizon = 0
+        self.steps = 0
+
+    def _create_planner(self):
+        pass
+
+    def plan(self, obs):
+        self.steps += 1
+        replan = self._replan(self.prev_actions)
+        if replan:
+            env = self.env
+            actions = self.planner.plan(state=env, obs=obs)
+        else:
+            actions = self.prev_actions[1:]
+
+        self.prev_actions = actions
+        return actions
+
+    def _replan(self, actions):
+        replan = self.remaining_horizon == 0 or len(actions) <= 1
+        if replan:
+            self.remaining_horizon = self.horizon
+        else:
+            self.remaining_horizon -= 1
+
+        self.planner.step_tree(actions)
+        return replan
+
+    def reset(self):
+        self.planner.reset()
+        self.remaining_horizon = 0
+        self.steps = 0
diff --git a/genrl/agents/modelbased/mcts/mcts.py b/genrl/agents/modelbased/mcts/mcts.py
new file mode 100644
index 00000000..a803694b
--- /dev/null
+++ b/genrl/agents/modelbased/mcts/mcts.py
@@ -0,0 +1,181 @@
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as opt
+
+from genrl.agents.modelbased.mcts.base import Node, TreePlanner, TreeSearchAgent
+from genrl.utils import get_env_properties, get_model
+
+
+class MCTSAgent(TreeSearchAgent):
+    def __init__(self, *args, **kwargs):
+        super(MCTSAgent, self).__init__(*args, **kwargs)
+        self.planner = self._create_planner()
+
+    def _create_planner(self):
+        prior_policy = None
+        rollout_policy = None
+        return MCTSPlanner(prior_policy, rollout_policy)
+
+    def _create_model(self):
+        if isinstance(self.network, str):
+            arch_type = self.network
+            if self.shared_layers is not None:
+                arch_type += "s"
+            state_dim, action_dim, discrete, action_lim = get_env_properties(
+                self.env, self.network
+            )
+            self.ac = get_model("ac", arch_type)(
+                state_dim,
+                action_dim,
+                shared_layers=self.shared_layers,
+                policy_layers=self.policy_layers,
+                value_layers=self.value_layers,
+                val_type="V",
+                discrete=discrete,
+                action_lim=action_lim,
+            ).to(self.device)
+
+        actor_params, critic_params = self.ac.get_params()
+        self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
+        self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)
+
+    def select_action(self, state) -> torch.Tensor:
+        # return action.detach(), value, dist.log_prob.cpu()
+        pass
+
+    def get_traj_loss(self) -> None:
+        pass
+
+    def evaluate_actions(self, states, actions):
+        # return values, dist.log_prob(actions).cpu(), dist.entropy.cpu()
+        pass
+
+    def update_params(self) -> None:
+        # TODO: policy_loss = ...
+        # TODO: value_loss = ...
+        pass
+
+
+class MCTSNode(Node):
+    def __init__(self, *args, prior, **kwargs):
+        super(MCTSNode, self).__init__(*args, **kwargs)
+        self.value = 0
+        self.count = 0
+        self.prior = prior
+
+    def selection_rule(self):
+        if not self.children:
+            return None
+        actions = list(self.children.keys())
+        counts = [self.children[a].count for a in actions]
+        max_count = max(counts)
+        # Most visited child, with ties broken by value
+        return max(
+            (a for a, c in zip(actions, counts) if c == max_count),
+            key=(lambda a: self.children[a].get_value()),
+        )
+
+    def sampling_rule(self, temp=0):
+        if self.children:
+            actions = list(self.children.keys())
+            scores = [self.children[a].selection_strategy(temp) for a in actions]
+            best = [i for i, s in enumerate(scores) if s == max(scores)]
+            return actions[np.random.choice(best)]
+        return None
+
+    def expand(self, actions_dist):
+        actions, probs = actions_dist
+        for i in range(len(actions)):
+            if actions[i] not in self.children:
+                self.children[actions[i]] = MCTSNode(
+                    parent=self, planner=self.planner, prior=probs[i]
+                )
+
+    def _update(self, total_rew):
+        self.count += 1
+        self.value += 1 / self.count * (total_rew - self.value)
+
+    def update_branch(self, total_rew):
+        self._update(total_rew)
+        if self.parent:
+            self.parent.update_branch(total_rew)
+
+    def get_child(self, action, obs=None):
+        child = self.children[action]
+        if obs is not None:
+            if str(obs) not in child.children:
+                child.children[str(obs)] = MCTSNode(
+                    parent=child, planner=self.planner, prior=0
+                )
+            child = child.children[str(obs)]
+        return child
+
+    def selection_strategy(self, temp=0):
+        if not self.parent:
+            return self.get_value()
+        return self.get_value() + temp * len(self.parent.children) * self.prior / (
+            self.count - 1
+        )
+
+    def get_value(self):
+        return self.value
+
+    def convert_visits_to_prior(self, reg=0.5):
+        self.count = 0
+        total_count = np.sum([(child.count + 1) for child in self.children.values()])
+        for child in self.children.values():
+            child.prior = reg * (child.count + 1) / total_count + reg / len(
+                self.children
+            )
+            child.convert_visits_to_prior()
+
+
+class MCTSPlanner(TreePlanner):
+    def __init__(self, *args, env, prior, rollout_policy, gamma, episodes, **kwargs):
+        super(MCTSPlanner, self).__init__(*args, **kwargs)
+        self.env = env
+        self.prior = prior
+        self.rollout_policy = rollout_policy
+        self.gamma = gamma
+        self.episodes = episodes
+
+    def reset(self):
+        self.root = MCTSNode(parent=None, planner=self, prior=None)
+
+    def _mc_search(self, state, obs):
+        # Runs one iteration of MCTS
+        node = self.root
+        total_rew = 0
+        depth = 0
+        terminal = False
+        while depth < self.horizon and node.children and not terminal:
+            action = node.sampling_rule()  # Not so sure about this
+            obs, reward, terminal, _ = self.step(state, action)
+            total_rew += self.gamma ** depth * reward
+            node_obs = obs
+            node = node.get_child(action, node_obs)
+            depth += 1
+
+        if not terminal:
+            total_rew = self.eval(state, obs, total_rew, depth=depth)
+        node.update_branch(total_rew)
+    def eval(self, state, obs, total_rew=0, depth=0):
+        # Run the rollout policy to yield a sample estimate of the value
+        for h in range(depth, self.horizon):
+            actions, probs = self.rollout_policy(state, obs)
+            action = np.random.choice(actions, p=probs)  # Sample an action from the rollout policy
+            obs, reward, terminal, _ = self.step(state, action)
+            total_rew += self.gamma ** h * reward
+            if np.all(terminal):
+                break
+        return total_rew
+
+    def plan(self, state, obs):
+        for i in range(self.episodes):
+            self._mc_search(copy.deepcopy(state), obs)
+        return self.get_plan()
+
+    def step_planner(self, action):
+        if self.strategy == "prior":
+            self._step_by_prior(action)
+        else:
+            super().step_tree([action])
+
+    def _step_by_prior(self, action):
+        self._step_by_subtree(action)
+        self.root.convert_visits_to_prior()
diff --git a/genrl/agents/modelbased/mcts/uct.py b/genrl/agents/modelbased/mcts/uct.py
new file mode 100644
index 00000000..b6d9128d
--- /dev/null
+++ b/genrl/agents/modelbased/mcts/uct.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+from genrl.agents.modelbased.mcts.mcts import MCTSNode
+
+
+class UCTNode(MCTSNode):
+    def __init__(self, *args, disc_factor, **kwargs):
+        super(UCTNode, self).__init__(*args, **kwargs)
+        self.disc_factor = disc_factor
+
+    def selection_strategy(self, temp=0):
+        if not self.parent:
+            return self.get_value()
+        return self.get_value() + temp * self.prior * np.sqrt(
+            np.log(self.parent.count) / self.count
+        )
diff --git a/tests/test_agents/test_modelbased/__init__.py b/tests/test_agents/test_modelbased/__init__.py
new file mode 100644
index 00000000..08c59ec8
--- /dev/null
+++ b/tests/test_agents/test_modelbased/__init__.py
@@ -0,0 +1 @@
+from tests.test_agents.test_modelbased.test_cem import TestCEM
diff --git a/tests/test_agents/test_modelbased/test_cem.py b/tests/test_agents/test_modelbased/test_cem.py
new file mode 100644
index 00000000..5cf76779
--- /dev/null
+++ b/tests/test_agents/test_modelbased/test_cem.py
@@ -0,0 +1,23 @@
+import shutil
+
+from genrl.agents import CEM
+from genrl.environments import VectorEnv
+from genrl.trainers import OnPolicyTrainer
+
+
+class TestCEM:
+    def test_CEM(self):
+        env = VectorEnv("CartPole-v0", 1)
+        algo = CEM(
+            "mlp",
+            env,
+            percentile=70,
+            policy_layers=[100],
+            rollout_size=100,
+            simulations_per_epoch=100,
+        )
+        trainer = OnPolicyTrainer(
+            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
+        )
+        trainer.train()
+        shutil.rmtree("./logs")
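Note: only CEM has a runnable entry point in this PR, exercised by the test above. For reference, the same end-to-end run outside pytest (hyperparameters are the test's smoke-test values, not tuned settings):

    from genrl.agents import CEM
    from genrl.environments import VectorEnv
    from genrl.trainers import OnPolicyTrainer

    env = VectorEnv("CartPole-v0", 1)
    agent = CEM(
        "mlp",
        env,
        percentile=70,              # keep roughly the top 30% of simulated episodes as elites
        policy_layers=[100],
        rollout_size=100,
        simulations_per_epoch=100,
    )
    trainer = OnPolicyTrainer(agent, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()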