diff --git a/docs/cart_pole.gif b/docs/cart_pole.gif
new file mode 100644
index 00000000..96365f6a
Binary files /dev/null and b/docs/cart_pole.gif differ
diff --git a/docs/getting_started.md b/docs/getting_started.md
new file mode 100644
index 00000000..1f2ed253
--- /dev/null
+++ b/docs/getting_started.md
@@ -0,0 +1,31 @@
+# 🔥 Getting Started
+
+The `experiments` folder contains example runs for different Gymnasium environments.
+
+For example, you can run the cartpole example using DQN with the following command:
+
+```bash
+pdm run python experiments/gym/train_dqn_cartpole.py
+```
+
+![Cart pole agent](cart_pole.gif)
+
+The script comes with a number of predefined arguments, such as the learning rate, the hidden layer dimensions, and the batch size. You can find all the arguments in the `experiments/gym/train_dqn_cartpole.py` file.
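+
+Any of these defaults can be overridden on the command line, for example to train with more parallel environments and a lower learning rate:
+
+```bash
+pdm run python experiments/gym/train_dqn_cartpole.py --num-envs 8 --lr 5e-4
+```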
+
+## 📊 TensorBoard
+
+To visualize the training process, you can use TensorBoard. To do so, run the following command:
+
+```bash
+pdm run tensorboard --logdir ./mllogs
+```
+
+This will start a TensorBoard server on `localhost:6006`. Open that address in your browser to follow the training process, including the rewards and losses over time.
+
+![TensorBoard dashboard](tensorboard.png)
diff --git a/docs/index.rst b/docs/index.rst
index c46506c0..910a8aed 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -35,6 +35,7 @@ to build other algorithms.
    coding-standard
    📚 Editing documentation
    🌡 Metrics
+   🚀 Getting Started
 
 .. toctree::
    :maxdepth: 6
@@ -58,6 +59,7 @@ to build other algorithms.
 .. include:: adr/doc.md
 .. include:: documentation.md
 .. include:: metrics.md
+.. include:: getting_started.md
    :parser: myst_parser.sphinx_
diff --git a/docs/tensorboard.png b/docs/tensorboard.png
new file mode 100644
index 00000000..00cd8a83
Binary files /dev/null and b/docs/tensorboard.png differ
diff --git a/experiments/gym/train_dqn_cartpole.py b/experiments/gym/train_dqn_cartpole.py
index 558b8701..d92bb2a8 100644
--- a/experiments/gym/train_dqn_cartpole.py
+++ b/experiments/gym/train_dqn_cartpole.py
@@ -27,6 +27,15 @@
 
 
 def _make_env():
+    """Create the environment for the experiment.
+
+    The environment is created inside a thunk so that each worker of the vectorized
+    environment constructs its own instance rather than sharing one across processes.
+
+    Returns:
+        (Callable[[], gym.Env]): The thunk that creates the environment
+    """
+
     def _thunk():
         env = gym.make("CartPole-v1")
         env = gym.wrappers.FrameStack(env, 3)
@@ -37,6 +46,18 @@ def _thunk():
 
 
 class QNet(nn.Module):
+    """
+    Q-Network class for Q-Learning. It takes observations and returns Q-values for actions.
+
+    Attributes:
+        network (nn.Sequential): Neural network for computing Q-values.
+
+    Args:
+        num_obs (int): Dimensionality of observations.
+        num_actions (int): Number of possible actions.
+        hidden_dims (list of int): Dimensions of hidden layers.
+    """
+
     def __init__(self, num_obs, num_actions, hidden_dims):
         super(QNet, self).__init__()
 
@@ -53,10 +74,37 @@ def __init__(self, num_obs, num_actions, hidden_dims):
         self.network = nn.Sequential(*layers)
 
     def forward(self, obs):
+        """
+        Forward pass for the Q-Network.
+
+        Args:
+            obs (Tensor): Observations.
+
+        Returns:
+            Tensor: Q-values for each action.
+        """
         return self.network(obs)
 
 
 class DQNPolicy(nn.Module):
+    """
+    DQN Policy class that handles action selection with an epsilon-greedy strategy.
+
+    Attributes:
+        q_net (QNet): Q-Network to evaluate Q-values.
+        initial_epsilon (float): Initial value of epsilon in epsilon-greedy.
+        target_epsilon (float): Target value of epsilon.
+        step_count (int): Counter for steps taken.
+        epsilon_decay_duration (int): Steps over which epsilon is decayed.
+        log_epsilon (bool): Flag to log epsilon values.
+
+    Args:
+        q_net (QNet): Q-Network.
+        epsilon_range (list of float): Initial and target epsilon for epsilon-greedy.
+        epsilon_decay_duration (int): Number of steps over which epsilon will decay.
+        log_epsilon (bool): Whether to log epsilon values.
+    """
+
     def __init__(
         self, q_net, epsilon_range=[0.9, 0.05], epsilon_decay_duration=10_000, log_epsilon=True
     ):
@@ -71,6 +119,18 @@ def __init__(
 
     # Returns the index of the chosen action
     def forward(self, state):
+        """
+        Forward pass for action selection.
+
+        Args:
+            state (Tensor): The state observations.
+
+        Returns:
+            Tensor: Indices of chosen actions for each environment.
+        """
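+        # Exponential decay: epsilon starts at initial_epsilon (0.9 by default) and
+        # approaches target_epsilon (0.05 by default) as step_count grows; after
+        # epsilon_decay_duration steps the gap has shrunk by a factor of e.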
         with torch.no_grad():
             epsilon = self.target_epsilon + (self.initial_epsilon - self.target_epsilon) * math.exp(
                 -1.0 * self.step_count / self.epsilon_decay_duration
@@ -120,13 +180,18 @@
         (tuple[TableMemoryProxy, MemoryLoader]): A proxy for the memory and a dataloader
     """
 
+    # Create the memory
     table = DictObsNStepTable(
         spaces=space,
         use_terminal_column=False,
         maxlen=memory_size,
         device=device,
     )
+    # The memory proxy is used to upload the data to the memory
     memory_proxy = TableMemoryProxy(table=table, use_terminal=False)
+    # The data loader is used to sample the data from the memory
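+    # rollout_count is batch_size // len_rollout, e.g. 128 // 1 = 128 rollouts per
+    # sampled batch with the default arguments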
     data_loader = MemoryLoader(
         table=table,
         rollout_count=batch_size // len_rollout,
@@ -152,33 +217,20 @@
     Returns:
         (list[Callback]): the full list of callbacks for the training
     """
-    if args.use_wandb:
-        from emote.callbacks.wb_logger import WBLogger
-
-        config = {
-            "wandb_project": args.name,
-            "wandb_run": args.wandb_run,
-            "hidden_dims": args.hidden_layer_size,
-            "batch_size": args.batch_size,
-            "learning_rate": args.actor_lr,
-            "rollout_len": args.rollout_length,
-        }
-        logger = WBLogger(
-            callbacks=logged_cbs,
-            config=config,
-            log_interval=100,
-        )
-    else:
-        logger = TensorboardLogger(
-            logged_cbs,
-            SummaryWriter(log_dir=args.log_dir + "/" + args.name + "_{}".format(time.time())),
-            100,
-        )
+    # The logger callback is used for logging the training progress
+    logger = TensorboardLogger(
+        logged_cbs,
+        SummaryWriter(log_dir=args.log_dir + "/" + args.name + "_{}".format(time.time())),
+        100,
+    )
 
+    # Terminates the training after a certain number of backprop steps
     bp_step_terminator = BackPropStepsTerminator(bp_steps=args.bp_steps)
+    # Callbacks to be used during training
     callbacks = logged_cbs + [logger, bp_step_terminator]
 
     if cbs_name_to_checkpoint:
+        # The checkpointer exports the model weights to the checkpoint directory
        checkpointer = Checkpointer(
             callbacks=[
                 cb for cb in logged_cbs if hasattr(cb, "name") and cb.name in cbs_name_to_checkpoint
@@ -192,9 +244,11 @@
 
 
 def main(args):
+    # Create the environment
     env = DictGymWrapper(AsyncVectorEnv([_make_env() for _ in range(args.num_envs)]))
     device = torch.device(args.device)
 
+    # Define the space in order to create the memory
     input_shapes = {k: v.shape for k, v in env.dict_space.state.spaces.items()}
     output_shapes = {"actions": env.dict_space.actions.shape}
     action_shape = output_shapes["actions"]
@@ -231,14 +285,17 @@
 
     num_actions = env.action_space.nvec[0]
 
+    # Create our two networks and the policy
     online_q_net = QNet(num_obs, num_actions, args.hidden_dims)
     target_q_net = QNet(num_obs, num_actions, args.hidden_dims)
     policy = DQNPolicy(online_q_net)
 
+    # Move them to the device
     online_q_net = online_q_net.to(device)
     target_q_net = target_q_net.to(device)
     policy = policy.to(device)
 
+    # The agent proxy is responsible for inference
     agent_proxy = GenericAgentProxy(
         policy,
         device=device,
@@ -248,6 +305,7 @@
         spaces=spaces,
     )
 
+    # Create an optimizer for the online network
     optimizers = [
         QLoss(
             name="q1",
@@ -258,11 +316,13 @@
     ]
 
     train_callbacks = optimizers + [
+        # The QTarget callback is responsible for updating the target network
         QTarget(
             q_net=online_q_net,
             target_q_net=target_q_net,
             roll_length=args.rollout_length,
         ),
+        # The collector is responsible for the interaction with the environment
         ThreadedGymCollector(
             env,
             agent_proxy,
@@ -277,29 +337,51 @@
         train_callbacks,
     )
 
+    # The trainer acts as the main callback, responsible for calling all other callbacks
     trainer = Trainer(all_callbacks, dataloader)
     trainer.train()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--name", type=str, default="cartpole")
-    parser.add_argument("--log-dir", type=str, default="./mllogs/emote/cartpole")
-    parser.add_argument("--num-envs", type=int, default=4)
-    parser.add_argument("--rollout-length", type=int, default=1)
-    parser.add_argument("--batch-size", type=int, default=128)
-    parser.add_argument("--hidden-dims", type=list, default=[128, 128])
-    parser.add_argument("--lr", type=float, default=1e-3, help="The learning rate")
-    parser.add_argument("--device", type=str, default="cpu")
-    parser.add_argument("--bp-steps", type=int, default=50_000)
-    parser.add_argument("--memory-size", type=int, default=50_000)
-    parser.add_argument("--export-memory", action="store_true", default=False)
-    parser.add_argument("--use-wandb", action="store_true")
+    parser.add_argument("--name", type=str, default="cartpole", help="The name of the experiment")
     parser.add_argument(
-        "--wandb-run",
+        "--log-dir",
         type=str,
-        default=None,
-        help="Short display name of run for the W&B UI. Randomly generated by default.",
+        default="./mllogs/emote/cartpole",
+        help="Directory where logs will be stored.",
+    )
+    parser.add_argument(
+        "--num-envs", type=int, default=4, help="Number of environments to run in parallel"
+    )
+    parser.add_argument(
+        "--rollout-length",
+        type=int,
+        default=1,
+        help="The length of each rollout, i.e. the number of environment steps collected per sampled trajectory when estimating the expected return of the policy.",
+    )
+    parser.add_argument("--batch-size", type=int, default=128, help="Size of each training batch")
+    parser.add_argument(
+        "--hidden-dims", type=int, nargs="+", default=[128, 128], help="The hidden dimensions of the network"
+    )
+    parser.add_argument("--lr", type=float, default=1e-3, help="The learning rate")
+    parser.add_argument(
+        "--device", type=str, default="cpu", help="Device to run the model on, e.g. cpu or cuda:0"
+    )
+    parser.add_argument(
+        "--bp-steps",
+        type=int,
+        default=50_000,
+        help="Number of backpropagation steps until the training run is finished",
+    )
+    parser.add_argument(
+        "--memory-size",
+        type=int,
+        default=50_000,
+        help="The size of the replay buffer. More complex environments require larger buffers; for a simple environment like cartpole, 50_000 transitions is sufficient.",
+    )
+    parser.add_argument(
+        "--export-memory", action="store_true", default=False, help="Whether to export the memory"
     )
     args = parser.parse_args()
     main(args)