Skip to content

Commit

Permalink
Rework garage.torch.CNNModule (#2189)
Browse files Browse the repository at this point in the history
It now takes an env spec and automatically computes the output size.
It also handles NHWC and NCHW instead of requiring an environment
wrapper.

This change also fixes several issues with the existing usage of CNNs
in pytorch.
  • Loading branch information
krzentner authored Mar 23, 2021
1 parent f896dca commit f056fb8
Show file tree
Hide file tree
Showing 29 changed files with 726 additions and 1,097 deletions.
6 changes: 3 additions & 3 deletions src/garage/examples/torch/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def dqn_atari(ctxt=None,
env = Resize(env, 84, 84)
env = ClipReward(env)
env = StackFrames(env, 4, axis=0)
env = GymEnv(env, max_episode_length=max_episode_length)
env = GymEnv(env, max_episode_length=max_episode_length, is_image=True)
set_seed(seed)
trainer = Trainer(ctxt)

Expand All @@ -152,13 +152,13 @@ def dqn_atari(ctxt=None,

qf = DiscreteCNNQFunction(
env_spec=env.spec,
image_format='NCHW',
hidden_channels=hyperparams['hidden_channels'],
kernel_sizes=hyperparams['kernel_sizes'],
strides=hyperparams['strides'],
hidden_w_init=(
lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
hidden_sizes=hyperparams['hidden_sizes'],
is_image=True)
hidden_sizes=hyperparams['hidden_sizes'])

policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
exploration_policy = EpsilonGreedyPolicy(
Expand Down
17 changes: 9 additions & 8 deletions src/garage/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
"""PyTorch-backed modules and algorithms."""
# yapf: disable
from garage.torch._functions import (compute_advantages, dict_np_to_torch,
from garage.torch._functions import (as_torch, as_torch_dict,
compute_advantages, expand_var,
filter_valids, flatten_batch,
flatten_to_single_vector, global_device,
NonLinearity, np_to_torch, pad_to_last,
prefer_gpu, product_of_gaussians,
set_gpu_mode, soft_update_model,
torch_to_np, TransposeImage,
NonLinearity, output_height_2d,
output_width_2d, pad_to_last, prefer_gpu,
product_of_gaussians, set_gpu_mode,
soft_update_model, torch_to_np,
update_module_params)

# yapf: enable
__all__ = [
'compute_advantages', 'dict_np_to_torch', 'filter_valids', 'flatten_batch',
'global_device', 'np_to_torch', 'pad_to_last', 'prefer_gpu',
'compute_advantages', 'as_torch_dict', 'filter_valids', 'flatten_batch',
'global_device', 'as_torch', 'pad_to_last', 'prefer_gpu',
'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np',
'update_module_params', 'NonLinearity', 'flatten_to_single_vector',
'TransposeImage'
'output_width_2d', 'output_height_2d', 'expand_var'
]
145 changes: 112 additions & 33 deletions src/garage/torch/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
- Updating model parameters
"""
import copy
import dataclasses
import math
import warnings

import akro
import torch
from torch import nn
import torch.nn.functional as F

from garage import EnvSpec, Wrapper

_USE_GPU = False
_DEVICE = None
_GPU_ID = 0
Expand Down Expand Up @@ -134,7 +132,7 @@ def filter_valids(tensor, valids):
return [tensor[i][:valid] for i, valid in enumerate(valids)]


def np_to_torch(array):
def as_torch(array):
"""Numpy arrays to PyTorch tensors.
Args:
Expand All @@ -144,10 +142,10 @@ def np_to_torch(array):
torch.Tensor: float tensor on the global device.
"""
return torch.from_numpy(array).float().to(global_device())
return torch.as_tensor(array).float().to(global_device())


def dict_np_to_torch(array_dict):
def as_torch_dict(array_dict):
"""Convert a dict whose values are numpy arrays to PyTorch tensors.
Modifies array_dict in place.
Expand All @@ -160,7 +158,7 @@ def dict_np_to_torch(array_dict):
"""
for key, value in array_dict.items():
array_dict[key] = np_to_torch(value)
array_dict[key] = as_torch(value)
return array_dict


Expand Down Expand Up @@ -227,15 +225,14 @@ def update_module_params(module, new_params): # noqa: D202
generated by `torch.nn.Module.named_parameters()`
"""
named_modules = dict(module.named_modules())

# pylint: disable=protected-access
def update(m, name, param):
del m._parameters[name] # noqa: E501
setattr(m, name, param)
m._parameters[name] = param # noqa: E501

named_modules = dict(module.named_modules())

for name, new_param in new_params.items():
if '.' in name:
module_name, param_name = tuple(name.rsplit('.', 1))
Expand Down Expand Up @@ -370,34 +367,116 @@ def __repr__(self):
return repr(self.module)


class TransposeImage(Wrapper):
"""Transpose observation space for image observation in PyTorch.
def _value_at_axis(value, axis):
"""Get the value for a particular axis.
Args:
value (tuple or list or int): Possible tuple of per-axis values.
axis (int): Axis to get value for.
Returns:
int: the value at the available axis.
Reshape the input observation shape from (H, W, C) into (C, H, W)
in pytorch format.
"""
if not isinstance(value, (list, tuple)):
return value
if len(value) == 1:
return value[0]
else:
return value[axis]

@property
def observation_space(self):
"""akro.Space: The observation space specification."""
obs_shape = self._env.observation_space.shape
return akro.Image((obs_shape[2], obs_shape[1], obs_shape[0]))

@property
def spec(self):
"""EnvSpec: The environment specification."""
return EnvSpec(self.observation_space, self._env.spec.action_space)
def output_height_2d(layer, height):
"""Compute the output height of a torch.nn.Conv2d, assuming NCHW format.
def step(self, action):
"""Step the wrapped env.
This requires knowing the input height. Because NCHW format makes this very
easy to mix up, this is a seperate function from conv2d_output_height.
Args:
action (np.ndarray): An action provided by the agent.
It also works on torch.nn.MaxPool2d.
Returns:
EnvStep: The environment step resulting from the action.
This function implements the formula described in the torch.nn.Conv2d
documentation:
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
"""
env_step = super().step(action)
obs = env_step.observation.transpose(2, 0, 1)
return dataclasses.replace(env_step, observation=obs)
Args:
layer (torch.nn.Conv2d): The layer to compute output size for.
height (int): The height of the input image.
Returns:
int: The height of the output image.
"""
assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d))
padding = _value_at_axis(layer.padding, 0)
dilation = _value_at_axis(layer.dilation, 0)
kernel_size = _value_at_axis(layer.kernel_size, 0)
stride = _value_at_axis(layer.stride, 0)
return math.floor((height + 2 * padding - dilation *
(kernel_size - 1) - 1) / stride + 1)


def output_width_2d(layer, width):
"""Compute the output width of a torch.nn.Conv2d, assuming NCHW format.
This requires knowing the input width. Because NCHW format makes this very
easy to mix up, this is a seperate function from conv2d_output_height.
It also works on torch.nn.MaxPool2d.
This function implements the formula described in the torch.nn.Conv2d
documentation:
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
Args:
layer (torch.nn.Conv2d): The layer to compute output size for.
width (int): The width of the input image.
Returns:
int: The width of the output image.
"""
assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d))

padding = _value_at_axis(layer.padding, 1)
dilation = _value_at_axis(layer.dilation, 1)
kernel_size = _value_at_axis(layer.kernel_size, 1)
stride = _value_at_axis(layer.stride, 1)
return math.floor((width + 2 * padding - dilation *
(kernel_size - 1) - 1) / stride + 1)


def expand_var(name, item, n_expected, reference):
"""Expand a variable to an expected length.
This is used to handle arguments to primitives that can all be reasonably
set to the same value, or multiple different values.
Args:
name (str): Name of variable being expanded.
item (any): Element being expanded.
n_expected (int): Number of elements expected.
reference (str): Source of n_expected.
Returns:
list: List of references to item or item itself.
Raises:
ValueError: If the variable is a sequence but length of the variable
is not 1 or n_expected.
"""
if n_expected == 1:
warnings.warn(
f'Providing a {reference} of length 1 prevents {name} from '
f'being expanded.')
if isinstance(item, (list, tuple)):
if len(item) == n_expected:
return item
elif len(item) == 1:
return list(item) * n_expected
else:
raise ValueError(
f'{name} is length {len(item)} but should be length '
f'{n_expected} to match {reference}')
else:
return [item] * n_expected
6 changes: 3 additions & 3 deletions src/garage/torch/algos/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from garage.np.algos.rl_algorithm import RLAlgorithm
from garage.np.policies import Policy
from garage.sampler import Sampler
from garage.torch import np_to_torch
from garage.torch import as_torch

# yapf: enable

Expand Down Expand Up @@ -130,8 +130,8 @@ def _train_once(self, trainer, epoch):
minibatches = np.array_split(indices, self._minibatches_per_epoch)
losses = []
for minibatch in minibatches:
observations = np_to_torch(batch.observations[minibatch])
actions = np_to_torch(batch.actions[minibatch])
observations = as_torch(batch.observations[minibatch])
actions = as_torch(batch.actions[minibatch])
self._optimizer.zero_grad()
loss = self._compute_loss(observations, actions)
loss.backward()
Expand Down
4 changes: 2 additions & 2 deletions src/garage/torch/algos/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from garage import (_Default, log_performance, make_optimizer,
obtain_evaluation_episodes)
from garage.np.algos import RLAlgorithm
from garage.torch import dict_np_to_torch, torch_to_np
from garage.torch import as_torch_dict, torch_to_np

# yapf: enable

Expand Down Expand Up @@ -230,7 +230,7 @@ def optimize_policy(self, samples_data):
qval: Q-value predicted by the Q-network.
"""
transitions = dict_np_to_torch(samples_data)
transitions = as_torch_dict(samples_data)

observations = transitions['observations']
rewards = transitions['rewards'].reshape(-1, 1)
Expand Down
12 changes: 6 additions & 6 deletions src/garage/torch/algos/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from garage import _Default, log_performance, make_optimizer
from garage._functions import obtain_evaluation_episodes
from garage.np.algos import RLAlgorithm
from garage.torch import global_device, np_to_torch
from garage.torch import as_torch, global_device


class DQN(RLAlgorithm):
Expand Down Expand Up @@ -243,12 +243,12 @@ def _optimize_qf(self, timesteps):
qval: Q-value predicted by the Q-network.
"""
observations = np_to_torch(timesteps.observations)
rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
observations = as_torch(timesteps.observations)
rewards = as_torch(timesteps.rewards).reshape(-1, 1)
rewards *= self._reward_scale
actions = np_to_torch(timesteps.actions)
next_observations = np_to_torch(timesteps.next_observations)
terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)
actions = as_torch(timesteps.actions)
next_observations = as_torch(timesteps.next_observations)
terminals = as_torch(timesteps.terminals).reshape(-1, 1)

next_inputs = next_observations
inputs = observations
Expand Down
4 changes: 2 additions & 2 deletions src/garage/torch/algos/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from garage import log_performance, obtain_evaluation_episodes, StepType
from garage.np.algos import RLAlgorithm
from garage.torch import dict_np_to_torch, global_device
from garage.torch import as_torch_dict, global_device

# yapf: enable

Expand Down Expand Up @@ -236,7 +236,7 @@ def train_once(self, itr=None, paths=None):
if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
samples = self.replay_buffer.sample_transitions(
self._buffer_batch_size)
samples = dict_np_to_torch(samples)
samples = as_torch_dict(samples)
policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
self._update_targets()

Expand Down
4 changes: 2 additions & 2 deletions src/garage/torch/algos/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from garage import (_Default, log_performance, make_optimizer,
obtain_evaluation_episodes)
from garage.np.algos import RLAlgorithm
from garage.torch import (dict_np_to_torch, global_device, soft_update_model,
from garage.torch import (as_torch_dict, global_device, soft_update_model,
torch_to_np)


Expand Down Expand Up @@ -238,7 +238,7 @@ def _train_once(self, itr):
# Sample from buffer
samples = self._replay_buffer.sample_transitions(
self._buffer_batch_size)
samples = dict_np_to_torch(samples)
samples = as_torch_dict(samples)

# Optimize
qf_loss, y, q, policy_loss = torch_to_np(
Expand Down
5 changes: 0 additions & 5 deletions src/garage/torch/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""PyTorch Modules."""
# yapf: disable
# isort:skip_file
from garage.torch.modules.categorical_cnn_module import CategoricalCNNModule
from garage.torch.modules.cnn_module import CNNModule
from garage.torch.modules.gaussian_mlp_module import (
GaussianMLPIndependentStdModule) # noqa: E501
Expand All @@ -12,15 +11,11 @@
from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
# DiscreteCNNModule must go after MLPModule
from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule
from garage.torch.modules.discrete_dueling_cnn_module import (
DiscreteDuelingCNNModule)
# yapf: enable

__all__ = [
'CategoricalCNNModule',
'CNNModule',
'DiscreteCNNModule',
'DiscreteDuelingCNNModule',
'MLPModule',
'MultiHeadedMLPModule',
'GaussianMLPModule',
Expand Down
Loading

0 comments on commit f056fb8

Please sign in to comment.