diff --git a/src/garage/examples/torch/dqn_atari.py b/src/garage/examples/torch/dqn_atari.py index db4183256c..82c8441b2b 100755 --- a/src/garage/examples/torch/dqn_atari.py +++ b/src/garage/examples/torch/dqn_atari.py @@ -139,7 +139,7 @@ def dqn_atari(ctxt=None, env = Resize(env, 84, 84) env = ClipReward(env) env = StackFrames(env, 4, axis=0) - env = GymEnv(env, max_episode_length=max_episode_length) + env = GymEnv(env, max_episode_length=max_episode_length, is_image=True) set_seed(seed) trainer = Trainer(ctxt) @@ -152,13 +152,13 @@ def dqn_atari(ctxt=None, qf = DiscreteCNNQFunction( env_spec=env.spec, + image_format='NCHW', hidden_channels=hyperparams['hidden_channels'], kernel_sizes=hyperparams['kernel_sizes'], strides=hyperparams['strides'], hidden_w_init=( lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))), - hidden_sizes=hyperparams['hidden_sizes'], - is_image=True) + hidden_sizes=hyperparams['hidden_sizes']) policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) exploration_policy = EpsilonGreedyPolicy( diff --git a/src/garage/torch/__init__.py b/src/garage/torch/__init__.py index 9b47be1771..001de12142 100644 --- a/src/garage/torch/__init__.py +++ b/src/garage/torch/__init__.py @@ -1,19 +1,20 @@ """PyTorch-backed modules and algorithms.""" # yapf: disable -from garage.torch._functions import (compute_advantages, dict_np_to_torch, +from garage.torch._functions import (as_torch, as_torch_dict, + compute_advantages, expand_var, filter_valids, flatten_batch, flatten_to_single_vector, global_device, - NonLinearity, np_to_torch, pad_to_last, - prefer_gpu, product_of_gaussians, - set_gpu_mode, soft_update_model, - torch_to_np, TransposeImage, + NonLinearity, output_height_2d, + output_width_2d, pad_to_last, prefer_gpu, + product_of_gaussians, set_gpu_mode, + soft_update_model, torch_to_np, update_module_params) # yapf: enable __all__ = [ - 'compute_advantages', 'dict_np_to_torch', 'filter_valids', 'flatten_batch', - 'global_device', 'np_to_torch', 'pad_to_last', 'prefer_gpu', + 'compute_advantages', 'as_torch_dict', 'filter_valids', 'flatten_batch', + 'global_device', 'as_torch', 'pad_to_last', 'prefer_gpu', 'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np', 'update_module_params', 'NonLinearity', 'flatten_to_single_vector', - 'TransposeImage' + 'output_width_2d', 'output_height_2d', 'expand_var' ] diff --git a/src/garage/torch/_functions.py b/src/garage/torch/_functions.py index ece53b0e35..64c0f61d54 100644 --- a/src/garage/torch/_functions.py +++ b/src/garage/torch/_functions.py @@ -10,15 +10,13 @@ - Updating model parameters """ import copy -import dataclasses +import math +import warnings -import akro import torch from torch import nn import torch.nn.functional as F -from garage import EnvSpec, Wrapper - _USE_GPU = False _DEVICE = None _GPU_ID = 0 @@ -134,7 +132,7 @@ def filter_valids(tensor, valids): return [tensor[i][:valid] for i, valid in enumerate(valids)] -def np_to_torch(array): +def as_torch(array): """Numpy arrays to PyTorch tensors. Args: @@ -144,10 +142,10 @@ def np_to_torch(array): torch.Tensor: float tensor on the global device. """ - return torch.from_numpy(array).float().to(global_device()) + return torch.as_tensor(array).float().to(global_device()) -def dict_np_to_torch(array_dict): +def as_torch_dict(array_dict): """Convert a dict whose values are numpy arrays to PyTorch tensors. Modifies array_dict in place. 
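Reviewer note on the hunk above: the rename from `np_to_torch` to `as_torch` also swaps `torch.from_numpy` for `torch.as_tensor`, which accepts existing tensors as well as ndarrays. A minimal sketch of the difference (variable names are illustrative):

```python
import numpy as np
import torch

arr = np.ones(3)
t = torch.as_tensor(arr).float()   # works for ndarrays...
t2 = torch.as_tensor(t).float()    # ...and for tensors (no extra copy)
# torch.from_numpy(t)              # would raise TypeError: expected np.ndarray
```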
@@ -160,7 +158,7 @@ def dict_np_to_torch(array_dict): """ for key, value in array_dict.items(): - array_dict[key] = np_to_torch(value) + array_dict[key] = as_torch(value) return array_dict @@ -227,6 +225,7 @@ def update_module_params(module, new_params): # noqa: D202 generated by `torch.nn.Module.named_parameters()` """ + named_modules = dict(module.named_modules()) # pylint: disable=protected-access def update(m, name, param): @@ -234,8 +233,6 @@ def update(m, name, param): setattr(m, name, param) m._parameters[name] = param # noqa: E501 - named_modules = dict(module.named_modules()) - for name, new_param in new_params.items(): if '.' in name: module_name, param_name = tuple(name.rsplit('.', 1)) @@ -370,34 +367,116 @@ def __repr__(self): return repr(self.module) -class TransposeImage(Wrapper): - """Transpose observation space for image observation in PyTorch. +def _value_at_axis(value, axis): + """Get the value for a particular axis. + + Args: + value (tuple or list or int): Possible tuple of per-axis values. + axis (int): Axis to get value for. + + Returns: + int: the value at the available axis. - Reshape the input observation shape from (H, W, C) into (C, H, W) - in pytorch format. """ + if not isinstance(value, (list, tuple)): + return value + if len(value) == 1: + return value[0] + else: + return value[axis] - @property - def observation_space(self): - """akro.Space: The observation space specification.""" - obs_shape = self._env.observation_space.shape - return akro.Image((obs_shape[2], obs_shape[1], obs_shape[0])) - @property - def spec(self): - """EnvSpec: The environment specification.""" - return EnvSpec(self.observation_space, self._env.spec.action_space) +def output_height_2d(layer, height): + """Compute the output height of a torch.nn.Conv2d, assuming NCHW format. - def step(self, action): - """Step the wrapped env. + This requires knowing the input height. Because NCHW format makes this very + easy to mix up, this is a separate function from output_width_2d. - Args: - action (np.ndarray): An action provided by the agent. + It also works on torch.nn.MaxPool2d. - Returns: - EnvStep: The environment step resulting from the action. + This function implements the formula described in the torch.nn.Conv2d + documentation: + https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - """ - env_step = super().step(action) - obs = env_step.observation.transpose(2, 0, 1) - return dataclasses.replace(env_step, observation=obs) + Args: + layer (torch.nn.Conv2d): The layer to compute output size for. + height (int): The height of the input image. + + Returns: + int: The height of the output image. + + """ + assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d)) + padding = _value_at_axis(layer.padding, 0) + dilation = _value_at_axis(layer.dilation, 0) + kernel_size = _value_at_axis(layer.kernel_size, 0) + stride = _value_at_axis(layer.stride, 0) + return math.floor((height + 2 * padding - dilation * + (kernel_size - 1) - 1) / stride + 1) + + +def output_width_2d(layer, width): + """Compute the output width of a torch.nn.Conv2d, assuming NCHW format. + + This requires knowing the input width. Because NCHW format makes this very + easy to mix up, this is a separate function from output_height_2d. + + It also works on torch.nn.MaxPool2d. + + This function implements the formula described in the torch.nn.Conv2d + documentation: + https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + Args: + layer (torch.nn.Conv2d): The layer to compute output size for.
+ width (int): The width of the input image. + + Returns: + int: The width of the output image. + + """ + assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d)) + + padding = _value_at_axis(layer.padding, 1) + dilation = _value_at_axis(layer.dilation, 1) + kernel_size = _value_at_axis(layer.kernel_size, 1) + stride = _value_at_axis(layer.stride, 1) + return math.floor((width + 2 * padding - dilation * + (kernel_size - 1) - 1) / stride + 1) + + +def expand_var(name, item, n_expected, reference): + """Expand a variable to an expected length. + + This is used to handle arguments to primitives that can all be reasonably + set to the same value, or multiple different values. + + Args: + name (str): Name of variable being expanded. + item (any): Element being expanded. + n_expected (int): Number of elements expected. + reference (str): Source of n_expected. + + Returns: + list: List of references to item or item itself. + + Raises: + ValueError: If the variable is a sequence but length of the variable + is not 1 or n_expected. + + """ + if n_expected == 1: + warnings.warn( + f'Providing a {reference} of length 1 prevents {name} from ' + f'being expanded.') + if isinstance(item, (list, tuple)): + if len(item) == n_expected: + return item + elif len(item) == 1: + return list(item) * n_expected + else: + raise ValueError( + f'{name} is length {len(item)} but should be length ' + f'{n_expected} to match {reference}') + else: + return [item] * n_expected diff --git a/src/garage/torch/algos/bc.py b/src/garage/torch/algos/bc.py index 220444beeb..a29e58d162 100644 --- a/src/garage/torch/algos/bc.py +++ b/src/garage/torch/algos/bc.py @@ -11,7 +11,7 @@ from garage.np.algos.rl_algorithm import RLAlgorithm from garage.np.policies import Policy from garage.sampler import Sampler -from garage.torch import np_to_torch +from garage.torch import as_torch # yapf: enable @@ -130,8 +130,8 @@ def _train_once(self, trainer, epoch): minibatches = np.array_split(indices, self._minibatches_per_epoch) losses = [] for minibatch in minibatches: - observations = np_to_torch(batch.observations[minibatch]) - actions = np_to_torch(batch.actions[minibatch]) + observations = as_torch(batch.observations[minibatch]) + actions = as_torch(batch.actions[minibatch]) self._optimizer.zero_grad() loss = self._compute_loss(observations, actions) loss.backward() diff --git a/src/garage/torch/algos/ddpg.py b/src/garage/torch/algos/ddpg.py index a69b981423..ab9c5fa0d1 100644 --- a/src/garage/torch/algos/ddpg.py +++ b/src/garage/torch/algos/ddpg.py @@ -9,7 +9,7 @@ from garage import (_Default, log_performance, make_optimizer, obtain_evaluation_episodes) from garage.np.algos import RLAlgorithm -from garage.torch import dict_np_to_torch, torch_to_np +from garage.torch import as_torch_dict, torch_to_np # yapf: enable @@ -230,7 +230,7 @@ def optimize_policy(self, samples_data): qval: Q-value predicted by the Q-network. 
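As a quick sanity check on the size helpers introduced above (this snippet is not part of the patch; the layer parameters are arbitrary), `output_height_2d` and `output_width_2d` should agree with the shapes PyTorch actually produces:

```python
import torch
from garage.torch import output_height_2d, output_width_2d

conv = torch.nn.Conv2d(in_channels=3, out_channels=16,
                       kernel_size=5, stride=2, padding=1)
out = conv(torch.zeros(1, 3, 84, 84))
# floor((84 + 2*1 - 1*(5 - 1) - 1) / 2 + 1) = 41
assert output_height_2d(conv, 84) == out.shape[2] == 41
assert output_width_2d(conv, 84) == out.shape[3] == 41
```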
""" - transitions = dict_np_to_torch(samples_data) + transitions = as_torch_dict(samples_data) observations = transitions['observations'] rewards = transitions['rewards'].reshape(-1, 1) diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py index ad23afbf04..878995e14c 100644 --- a/src/garage/torch/algos/dqn.py +++ b/src/garage/torch/algos/dqn.py @@ -10,7 +10,7 @@ from garage import _Default, log_performance, make_optimizer from garage._functions import obtain_evaluation_episodes from garage.np.algos import RLAlgorithm -from garage.torch import global_device, np_to_torch +from garage.torch import as_torch, global_device class DQN(RLAlgorithm): @@ -243,12 +243,12 @@ def _optimize_qf(self, timesteps): qval: Q-value predicted by the Q-network. """ - observations = np_to_torch(timesteps.observations) - rewards = np_to_torch(timesteps.rewards).reshape(-1, 1) + observations = as_torch(timesteps.observations) + rewards = as_torch(timesteps.rewards).reshape(-1, 1) rewards *= self._reward_scale - actions = np_to_torch(timesteps.actions) - next_observations = np_to_torch(timesteps.next_observations) - terminals = np_to_torch(timesteps.terminals).reshape(-1, 1) + actions = as_torch(timesteps.actions) + next_observations = as_torch(timesteps.next_observations) + terminals = as_torch(timesteps.terminals).reshape(-1, 1) next_inputs = next_observations inputs = observations diff --git a/src/garage/torch/algos/sac.py b/src/garage/torch/algos/sac.py index 6ed1b63e0f..4d328dc912 100644 --- a/src/garage/torch/algos/sac.py +++ b/src/garage/torch/algos/sac.py @@ -10,7 +10,7 @@ from garage import log_performance, obtain_evaluation_episodes, StepType from garage.np.algos import RLAlgorithm -from garage.torch import dict_np_to_torch, global_device +from garage.torch import as_torch_dict, global_device # yapf: enable @@ -236,7 +236,7 @@ def train_once(self, itr=None, paths=None): if self.replay_buffer.n_transitions_stored >= self._min_buffer_size: samples = self.replay_buffer.sample_transitions( self._buffer_batch_size) - samples = dict_np_to_torch(samples) + samples = as_torch_dict(samples) policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples) self._update_targets() diff --git a/src/garage/torch/algos/td3.py b/src/garage/torch/algos/td3.py index 14351a0c8c..13f3037548 100644 --- a/src/garage/torch/algos/td3.py +++ b/src/garage/torch/algos/td3.py @@ -9,7 +9,7 @@ from garage import (_Default, log_performance, make_optimizer, obtain_evaluation_episodes) from garage.np.algos import RLAlgorithm -from garage.torch import (dict_np_to_torch, global_device, soft_update_model, +from garage.torch import (as_torch_dict, global_device, soft_update_model, torch_to_np) @@ -238,7 +238,7 @@ def _train_once(self, itr): # Sample from buffer samples = self._replay_buffer.sample_transitions( self._buffer_batch_size) - samples = dict_np_to_torch(samples) + samples = as_torch_dict(samples) # Optimize qf_loss, y, q, policy_loss = torch_to_np( diff --git a/src/garage/torch/modules/__init__.py b/src/garage/torch/modules/__init__.py index c02b7a0254..e725dfb4b8 100644 --- a/src/garage/torch/modules/__init__.py +++ b/src/garage/torch/modules/__init__.py @@ -1,7 +1,6 @@ """PyTorch Modules.""" # yapf: disable # isort:skip_file -from garage.torch.modules.categorical_cnn_module import CategoricalCNNModule from garage.torch.modules.cnn_module import CNNModule from garage.torch.modules.gaussian_mlp_module import ( GaussianMLPIndependentStdModule) # noqa: E501 @@ -12,15 +11,11 @@ from 
garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule # DiscreteCNNModule must go after MLPModule from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule -from garage.torch.modules.discrete_dueling_cnn_module import ( - DiscreteDuelingCNNModule) # yapf: enable __all__ = [ - 'CategoricalCNNModule', 'CNNModule', 'DiscreteCNNModule', - 'DiscreteDuelingCNNModule', 'MLPModule', 'MultiHeadedMLPModule', 'GaussianMLPModule', diff --git a/src/garage/torch/modules/categorical_cnn_module.py b/src/garage/torch/modules/categorical_cnn_module.py deleted file mode 100644 index aafe73410a..0000000000 --- a/src/garage/torch/modules/categorical_cnn_module.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Categorical CNN Module. - -A model represented by a categorical distribution -which is parameterized by a convolutional neural network (CNN) -followed a multilayer perceptron (MLP). -""" -import torch -from torch import nn -from torch.distributions import Categorical - -from garage.torch.modules.cnn_module import CNNModule - - -class CategoricalCNNModule(nn.Module): - """Categorical CNN Model. - - A model represented by a Categorical distribution - which is parameterized by a convolutional neural network (CNN) followed - by a fully-connected layer. - - Args: - input_var (torch.tensor): Input tensor of the model. - output_dim (int): Output dimension of the model. - kernel_sizes (tuple[int]): Dimension of the conv filters. - For example, (3, 5) means there are two convolutional layers. - The filter for first layer is of dimension (3 x 3) - and the second one is of dimension (5 x 5). - strides (tuple[int]): The stride of the sliding window. For example, - (1, 2) means there are two convolutional layers. The stride of the - filter for first layer is 1 and that of the second layer is 2. - hidden_channels (tuple[int]): Number of output channels for CNN. - For example, (3, 32) means there are two convolutional layers. - The filter for the first conv layer outputs 3 channels - and the second one outputs 32 channels. - hidden_sizes (list[int]): Output dimension of dense layer(s) for - the MLP for mean. For example, (32, 32) means the MLP consists - of two hidden layers, each with 32 hidden units. - hidden_nonlinearity (callable): Activation function for intermediate - dense layer(s). It should return a torch.Tensor. Set it to - None to maintain a linear activation. - hidden_w_init (callable): Initializer function for the weight - of intermediate dense layer(s). The function should return a - torch.Tensor. - hidden_b_init (callable): Initializer function for the bias - of intermediate dense layer(s). The function should return a - torch.Tensor. - paddings (tuple[int]): Zero-padding added to both sides of the input - padding_mode (str): The type of padding algorithm to use, - either 'SAME' or 'VALID'. - max_pool (bool): Bool for using max-pooling or not. - pool_shape (tuple[int]): Dimension of the pooling layer(s). For - example, (2, 2) means that all the pooling layers are of the same - shape (2, 2). - pool_stride (tuple[int]): The strides of the pooling layer(s). For - example, (2, 2) means that all the pooling layers have - strides (2, 2). - output_nonlinearity (callable): Activation function for output dense - layer. It should return a torch.Tensor. Set it to None to - maintain a linear activation. - output_w_init (callable): Initializer function for the weight - of output dense layer(s). The function should return a - torch.Tensor. 
- output_b_init (callable): Initializer function for the bias - of output dense layer(s). The function should return a - torch.Tensor. - layer_normalization (bool): Bool for using layer normalization or not. - is_image (bool): Whether observations are images or not. - """ - - def __init__(self, - input_var, - output_dim, - kernel_sizes, - hidden_channels, - strides=1, - hidden_sizes=(32, 32), - hidden_nonlinearity=torch.tanh, - hidden_w_init=nn.init.xavier_uniform_, - hidden_b_init=nn.init.zeros_, - paddings=0, - padding_mode='zeros', - max_pool=False, - pool_shape=None, - pool_stride=1, - output_nonlinearity=None, - output_w_init=nn.init.xavier_uniform_, - output_b_init=nn.init.zeros_, - layer_normalization=False, - is_image=True): - super().__init__() - self._input_var = input_var - self._action_dim = output_dim - self._kernel_sizes = kernel_sizes - self._strides = strides - self._hidden_sizes = hidden_sizes - self._hidden_conv_channels = hidden_channels - self._hidden_nonlinearity = hidden_nonlinearity - self._hidden_w_init = hidden_w_init - self._hidden_b_init = hidden_b_init - self._paddings = paddings - self._padding_mode = padding_mode - self._max_pool = max_pool - self._pool_shape = pool_shape - self._pool_stride = pool_stride - self._output_nonlinearity = output_nonlinearity - self._output_w_init = output_w_init - self._output_b_init = output_b_init - self._layer_normalization = layer_normalization - self._is_image = is_image - - self._cnn_module = CNNModule( - input_var=self._input_var, - kernel_sizes=self._kernel_sizes, - strides=self._strides, - hidden_channels=self._hidden_conv_channels, - hidden_nonlinearity=self._hidden_nonlinearity, - paddings=self._paddings, - padding_mode=self._padding_mode, - max_pool=self._max_pool, - pool_shape=self._pool_shape, - pool_stride=self._pool_stride, - is_image=self._is_image) - - def forward(self, *inputs): - """Forward method. - - Args: - *inputs: Input to the module. - - Returns: - torch.distributions.Categorical: Policy distribution. - - """ - assert len(inputs) == 1 - cnn_output = self._cnn_module(inputs[0]) - - # low-level pytorch fully-connection layer - w = torch.empty((cnn_output.shape[1], self._action_dim)) - w.requires_grad = True - b = torch.empty(self._action_dim) - b.require_grad = True - fc_w = self._hidden_w_init(w) - fc_b = self._hidden_b_init(b) - fc_output = cnn_output.mm(fc_w) + fc_b - - dist = Categorical(logits=fc_output) - return dist diff --git a/src/garage/torch/modules/cnn_module.py b/src/garage/torch/modules/cnn_module.py index 2316359c15..269765ea04 100644 --- a/src/garage/torch/modules/cnn_module.py +++ b/src/garage/torch/modules/cnn_module.py @@ -1,21 +1,36 @@ """CNN Module.""" -import copy +import warnings +import akro +import numpy as np import torch from torch import nn -from garage.torch import flatten_to_single_vector, NonLinearity +from garage import InOutSpec +from garage.torch import (expand_var, NonLinearity, output_height_2d, + output_width_2d) # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 # pylint: disable=abstract-method -# pylint: disable=unused-argument class CNNModule(nn.Module): """Convolutional neural network (CNN) model in pytorch. Args: - input_var (pytorch.tensor): Input tensor of the model. - Based on 'NCHW' data format: [batch_size, channel, height, width]. + spec (garage.InOutSpec): Specification of inputs and outputs. + The input should be in 'NCHW' format: [batch_size, channel, height, + width]. Will print a warning if the channel size is not 1 or 3. 
+ If output_space is specified, then a final linear layer will be + inserted to map to that dimensionality. + If output_space is None, it will be filled in with the computed + output space. + image_format (str): Either 'NCHW' or 'NHWC'. Should match the input + specification. Gym uses NHWC by default, but PyTorch uses NCHW by + default. + hidden_channels (tuple[int]): Number of output channels for CNN. + For example, (3, 32) means there are two convolutional layers. + The filter for the first conv layer outputs 3 channels + and the second one outputs 32 channels. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. The filter for first layer is of dimension (3 x 3) @@ -23,25 +38,18 @@ class CNNModule(nn.Module): strides (tuple[int]): The stride of the sliding window. For example, (1, 2) means there are two convolutional layers. The stride of the filter for first layer is 1 and that of the second layer is 2. - hidden_channels (tuple[int]): Number of output channels for CNN. - For example, (3, 32) means there are two convolutional layers. - The filter for the first conv layer outputs 3 channels - and the second one outputs 32 channels. + paddings (tuple[int]): Amount of zero-padding added to both sides of + the input of a conv layer. + padding_mode (str): The type of padding algorithm to use, i.e. + 'constant', 'reflect', 'replicate' or 'circular' and + by default is 'zeros'. hidden_nonlinearity (callable or torch.nn.Module): Activation function for intermediate dense layer(s). It should return a torch.Tensor. Set it to None to maintain a linear activation. - hidden_w_init (callable): Initializer function for the weight - of intermediate dense layer(s). The function should return a - torch.Tensor. hidden_b_init (callable): Initializer function for the bias of intermediate dense layer(s). The function should return a torch.Tensor. - paddings (tuple[int]): Amount of zero-padding added to both sides of - the input of a conv layer. - padding_mode (str): The type of padding algorithm to use, i.e. - 'constant', 'reflect', 'replicate' or 'circular' and - by default is 'zeros'. max_pool (bool): Bool for using max-pooling or not. pool_shape (tuple[int]): Dimension of the pooling layer(s). For example, (2, 2) means that all pooling layers are of the same @@ -50,174 +58,188 @@ class CNNModule(nn.Module): example, (2, 2) means that all the pooling layers have strides (2, 2). layer_normalization (bool): Bool for using layer normalization or not. - n_layers (int): number of convolutional layer. - is_image (bool): Whether observations are images or not. 
- """ - - def __init__(self, - input_var, - hidden_channels, - kernel_sizes, - strides, - hidden_nonlinearity=nn.ReLU, - hidden_w_init=nn.init.xavier_uniform_, - hidden_b_init=nn.init.zeros_, - paddings=0, - padding_mode='zeros', - max_pool=False, - pool_shape=None, - pool_stride=1, - layer_normalization=False, - n_layers=None, - is_image=True): - if len(strides) != len(hidden_channels): - raise ValueError('Strides and hidden_channels must have the same' - ' number of dimensions') - super().__init__() - self._hidden_channels = hidden_channels - self._kernel_sizes = kernel_sizes - self._strides = strides - self._hidden_nonlinearity = hidden_nonlinearity - self._hidden_w_init = hidden_w_init - self._hidden_b_init = hidden_b_init - self._paddings = paddings - self._padding_mode = padding_mode - self._max_pool = max_pool - self._pool_shape = pool_shape - self._pool_stride = pool_stride - self._layer_normalization = layer_normalization - self._is_image = is_image - - self._cnn_layers = nn.ModuleList() - self._in_channel = input_var.shape[1] # read in N, C, H, W - self._CNNCell() - - @classmethod - def _check_parameter_for_output_layer(cls, var, n_layers): - """Check input parameters for conv layer are valid. - - Args: - var (any): variable to be checked - n_layers (int): number of layers + hidden_w_init (callable): Initializer function for the weight + of intermediate dense layer(s). The function should return a + torch.Tensor. - Returns: - list: list of variables (length of n_layers) + Raises: + ValueError: If spec or other arguments are inconsistent. - Raises: - ValueError: if the variable is a list but length of the variable - is not equal to n_layers + """ - """ - if isinstance(var, (list, tuple)): - if len(var) == 1: - return list(var) * n_layers - if len(var) == n_layers: - return var - msg = ('{} should be either an integer or a collection of length ' - 'n_layers ({}), but got {} instead.') - raise ValueError(msg.format(str(var), n_layers, var)) - return [copy.deepcopy(var) for _ in range(n_layers)] + def __init__( + self, + spec, + image_format, + hidden_channels, + *, # Many things after this are ints or tuples of ints. + kernel_sizes, + strides, + paddings=0, + padding_mode='zeros', + hidden_nonlinearity=nn.ReLU, + hidden_w_init=nn.init.xavier_uniform_, + hidden_b_init=nn.init.zeros_, + max_pool=False, + pool_shape=None, + pool_stride=1, + layer_normalization=False): + super().__init__() + assert len(hidden_channels) > 0 + # PyTorch forces us to use NCHW internally. + in_channels, height, width = _check_spec(spec, image_format) + self._format = image_format + kernel_sizes = expand_var('kernel_sizes', kernel_sizes, + len(hidden_channels), 'hidden_channels') + strides = expand_var('strides', strides, len(hidden_channels), + 'hidden_channels') + paddings = expand_var('paddings', paddings, len(hidden_channels), + 'hidden_channels') + pool_shape = expand_var('pool_shape', pool_shape, len(hidden_channels), + 'hidden_channels') + pool_stride = expand_var('pool_stride', pool_stride, + len(hidden_channels), 'hidden_channels') + + self._cnn_layers = nn.Sequential() + + # In case there are no hidden channels, handle output case. 
+ out_channels = in_channels + for i, out_channels in enumerate(hidden_channels): + conv_layer = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_sizes[i], + stride=strides[i], + padding=paddings[i], + padding_mode=padding_mode) + height = output_height_2d(conv_layer, height) + width = output_width_2d(conv_layer, width) + hidden_w_init(conv_layer.weight) + hidden_b_init(conv_layer.bias) + self._cnn_layers.add_module(f'conv_{i}', conv_layer) + + if layer_normalization: + self._cnn_layers.add_module( + f'layer_norm_{i}', + nn.LayerNorm((out_channels, height, width))) + + if hidden_nonlinearity: + self._cnn_layers.add_module(f'non_linearity_{i}', + NonLinearity(hidden_nonlinearity)) + + if max_pool: + pool = nn.MaxPool2d(kernel_size=pool_shape[i], + stride=pool_stride[i]) + height = output_height_2d(pool, height) + width = output_width_2d(pool, width) + self._cnn_layers.add_module(f'max_pooling_{i}', pool) + + in_channels = out_channels + + output_dims = out_channels * height * width + + if spec.output_space is None: + final_spec = InOutSpec( + spec.input_space, + akro.Box(low=-np.inf, high=np.inf, shape=(output_dims, ))) + self._final_layer = None + else: + final_spec = spec + # Checked at start of __init__ + self._final_layer = nn.Linear(output_dims, + spec.output_space.shape[0]) + + self.spec = final_spec # pylint: disable=arguments-differ - def forward(self, input_val): + def forward(self, x): """Forward method. Args: - input_val (torch.Tensor): Input values with (N, C, H, W) - shape. + x (torch.Tensor): Input values. Should match image_format + specified at construction (either NCHW or NHWC). Returns: List[torch.Tensor]: Output values """ - if self._is_image: - input_val = torch.div(input_val, 255.0) - x = input_val + # Transform single values into batch, if necessary. + if len(x.shape) == 3: + x = x.unsqueeze(0) + # This should be the single place in torch that image normalization + # happens + if isinstance(self.spec.input_space, akro.Image): + x = torch.div(x, 255.0) + assert len(x.shape) == 4 + if self._format == 'NHWC': + # Convert to internal NCHW format + x = x.permute((0, 3, 1, 2)) for layer in self._cnn_layers: x = layer(x) - x = flatten_to_single_vector(x) + if self._format == 'NHWC': + # Convert back to NHWC (just in case) + x = x.permute((0, 2, 3, 1)) + # Remove non-batch dimensions + x = x.reshape(x.shape[0], -1) + # Apply final linearity, if it was requested.
+ if self._final_layer is not None: + x = self._final_layer(x) return x - def _CNNCell(self): - """Helper function for initializing convolutional layer(s).""" - prev_channel = self._in_channel - for index, (channel, kernel_size, stride) in enumerate( - zip(self._hidden_channels, self._kernel_sizes, self._strides)): - hidden_layers = nn.Sequential() - - if isinstance(self._paddings, (list, tuple)): - padding = self._paddings[index] - elif isinstance(self._paddings, int): - padding = self._paddings - - # conv 2D layer - conv_layer = _conv(in_channels=prev_channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=padding) - self._hidden_w_init(conv_layer.weight) - self._hidden_b_init(conv_layer.bias) - hidden_layers.add_module('conv_{}'.format(index), conv_layer) - - # layer normalization - if self._layer_normalization: - hidden_layers.add_module('layer_normalization', - nn.LayerNorm(channel)) - - # non-linear function - if self._hidden_nonlinearity: - hidden_layers.add_module( - 'non_linearity', NonLinearity(self._hidden_nonlinearity)) - - # max-pooling - if self._max_pool: - hidden_layers.add_module( - 'max_pooling', - nn.MaxPool2d(kernel_size=self._pool_shape, - stride=self._pool_stride)) - - self._cnn_layers.append(hidden_layers) - prev_channel = channel - - -def _conv(in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - padding_mode='zeros', - dilation=1, - bias=True): - """Helper function for performing convolution. + +def _check_spec(spec, image_format): + """Check that an InOutSpec is suitable for a CNNModule. Args: - in_channels (int): - Number of channels in the input image - out_channels (int): - Number of channels produced by the convolution - kernel_size (int or tuple): - Size of the convolving kernel - stride (int or tuple): Stride of the convolution. - Default: 1 - padding (int or tuple): Zero-padding added to both sides - of the input. Default: 0 - padding_mode (string): 'zeros', 'reflect', 'replicate' - or 'circular'. Default: 'zeros' - dilation (int or tuple): Spacing between kernel elements. - Default: 1 - bias (bool): If True, adds a learnable bias to the output. - Default: True + spec (garage.InOutSpec): Specification of inputs and outputs. The + input should be in 'NCHW' format: [batch_size, channel, height, + width]. Will print a warning if the channel size is not 1 or 3. + If output_space is specified, then a final linear layer will be + inserted to map to that dimensionality. If output_space is None, + it will be filled in with the computed output space. + image_format (str): Either 'NCHW' or 'NHWC'. Should match the input + specification. Gym uses NHWC by default, but PyTorch uses NCHW by + default. Returns: - torch.Tensor: The output of the 2D convolution. + tuple[int, int, int]: The input channels, height, and width. + + Raises: + ValueError: If spec isn't suitable for a CNNModule. 
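To make the spec-driven construction concrete, here is a minimal sketch of building the new `CNNModule` with `output_space=None`, so the module computes and publishes its own flattened output space (the input shape and layer sizes here are assumptions, not values from the patch):

```python
import akro
import numpy as np
from garage import InOutSpec
from garage.torch.modules import CNNModule

# Single-channel 84x84 input in NCHW format; no final linear layer requested.
spec = InOutSpec(akro.Box(low=0, high=255, shape=(1, 84, 84), dtype=np.uint8),
                 None)
cnn = CNNModule(spec=spec,
                image_format='NCHW',
                hidden_channels=(16, 32),
                kernel_sizes=(8, 4),
                strides=(4, 2))
# 84 -> floor((84 - 8)/4 + 1) = 20 -> floor((20 - 4)/2 + 1) = 9
print(cnn.spec.output_space.shape)  # (32 * 9 * 9,) == (2592,)
```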
""" - return nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - padding_mode=padding_mode, - dilation=dilation, - bias=bias) + # pylint: disable=no-else-raise + input_space = spec.input_space + output_space = spec.output_space + # Don't use isinstance, since akro.Space is guaranteed to inherit from + # gym.Space + if getattr(input_space, 'shape', None) is None: + raise ValueError( + f'input_space to CNNModule is {input_space}, but should be an ' + f'akro.Box or akro.Image') + elif len(input_space.shape) != 3: + raise ValueError( + f'Input to CNNModule is {input_space}, but should have three ' + f'dimensions.') + if (output_space is not None and not (hasattr(output_space, 'shape') + and len(output_space.shape) == 1)): + raise ValueError( + f'output_space to CNNModule is {output_space}, but should be ' + f'an akro.Box with a single dimension or None') + if image_format == 'NCHW': + in_channels = spec.input_space.shape[0] + height = spec.input_space.shape[1] + width = spec.input_space.shape[2] + elif image_format == 'NHWC': + height = spec.input_space.shape[0] + width = spec.input_space.shape[1] + in_channels = spec.input_space.shape[2] + else: + raise ValueError( + f'image_format has value {image_format!r}, but must be either ' + f"'NCHW' or 'NHWC'") + if in_channels not in (1, 3): + warnings.warn( + f'CNNModule input has {in_channels} channels, but ' + f'1 or 3 channels are typical. Consider changing the CNN ' + f'image_format.') + return in_channels, height, width diff --git a/src/garage/torch/modules/discrete_cnn_module.py b/src/garage/torch/modules/discrete_cnn_module.py index c26db257ae..39431f2719 100644 --- a/src/garage/torch/modules/discrete_cnn_module.py +++ b/src/garage/torch/modules/discrete_cnn_module.py @@ -1,7 +1,7 @@ -"""Discrete CNN Module.""" -import torch +"""Discrete CNN Q Function.""" from torch import nn +from garage import InOutSpec from garage.torch.modules import CNNModule, MLPModule @@ -14,9 +14,13 @@ class DiscreteCNNModule(nn.Module): of discrete outputs. Args: - input_shape (tuple[int]): Shape of the input. Based on 'NCHW' data - format: [batch_size, channel, height, width]. - output_dim (int): Output dimension of the fully-connected layer. + spec (garage.InOutSpec): Specification of inputs and outputs. + The input should be in 'NCHW' format: [batch_size, channel, height, + width]. Will print a warning if the channel size is not 1 or 3. + The output space will be flattened. + image_format (str): Either 'NCHW' or 'NHWC'. Should match the input + specification. Gym uses NHWC by default, but PyTorch uses NCHW by + default. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. The filter for first layer is of dimension (3 x 3) @@ -63,12 +67,12 @@ class DiscreteCNNModule(nn.Module): of output dense layer(s). The function should return a torch.Tensor. layer_normalization (bool): Bool for using layer normalization or not. - is_image (bool): If true, the inputs are normalized by dividing by 255. 
""" def __init__(self, - input_shape, - output_dim, + spec, + image_format, + *, kernel_sizes, hidden_channels, strides, @@ -85,13 +89,13 @@ def __init__(self, output_nonlinearity=None, output_w_init=nn.init.xavier_uniform_, output_b_init=nn.init.zeros_, - layer_normalization=False, - is_image=True): + layer_normalization=False): super().__init__() - input_var = torch.zeros(input_shape) - cnn_module = CNNModule(input_var=input_var, + cnn_spec = InOutSpec(input_space=spec.input_space, output_space=None) + cnn_module = CNNModule(spec=cnn_spec, + image_format=image_format, kernel_sizes=kernel_sizes, strides=strides, hidden_w_init=hidden_w_init, @@ -103,13 +107,10 @@ def __init__(self, max_pool=max_pool, layer_normalization=layer_normalization, pool_shape=pool_shape, - pool_stride=pool_stride, - is_image=is_image) - - with torch.no_grad(): - cnn_out = cnn_module(input_var) - flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1] + pool_stride=pool_stride) + flat_dim = cnn_module.spec.output_space.flat_dim + output_dim = spec.output_space.flat_dim mlp_module = MLPModule(flat_dim, output_dim, hidden_sizes, diff --git a/src/garage/torch/modules/discrete_dueling_cnn_module.py b/src/garage/torch/modules/discrete_dueling_cnn_module.py deleted file mode 100644 index 243c1b5579..0000000000 --- a/src/garage/torch/modules/discrete_dueling_cnn_module.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Discrete Dueling CNN Module.""" -import torch -from torch import nn - -from garage.torch.modules import CNNModule, MLPModule - - -# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 -# pylint: disable=abstract-method -class DiscreteDuelingCNNModule(nn.Module): - """Discrete Dueling CNN Module. - - A CNN followed by one or more fully connected layers with a set number - of discrete outputs for each of the advantage and value parts of - the dueling network. - - Args: - input_shape (tuple[int]): Shape of the input. Based on 'NCHW' data - format: [batch_size, channel, height, width]. - output_dim (int): Output dimension of the fully-connected layer. - kernel_sizes (tuple[int]): Dimension of the conv filters. - For example, (3, 5) means there are two convolutional layers. - The filter for first layer is of dimension (3 x 3) - and the second one is of dimension (5 x 5). - strides (tuple[int]): The stride of the sliding window. For example, - (1, 2) means there are two convolutional layers. The stride of the - filter for first layer is 1 and that of the second layer is 2. - hidden_channels (tuple[int]): Number of output channels for CNN. - For example, (3, 32) means there are two convolutional layers. - The filter for the first conv layer outputs 3 channels - and the second one outputs 32 channels. - hidden_sizes (list[int]): Output dimension of dense layer(s) for - the MLP for mean. For example, (32, 32) means the MLP consists - of two hidden layers, each with 32 hidden units. - mlp_hidden_nonlinearity (callable): Activation function for - intermediate dense layer(s) in the MLP. It should return - a torch.Tensor. Set it to None to maintain a linear activation. - cnn_hidden_nonlinearity (callable): Activation function for - intermediate CNN layer(s). It should return a torch.Tensor. - Set it to None to maintain a linear activation. - hidden_w_init (callable): Initializer function for the weight - of intermediate dense layer(s). The function should return a - torch.Tensor. - hidden_b_init (callable): Initializer function for the bias - of intermediate dense layer(s). 
The function should return a - torch.Tensor. - paddings (tuple[int]): Zero-padding added to both sides of the input - padding_mode (str): The type of padding algorithm to use, - either 'SAME' or 'VALID'. - max_pool (bool): Bool for using max-pooling or not. - pool_shape (tuple[int]): Dimension of the pooling layer(s). For - example, (2, 2) means that all the pooling layers are of the same - shape (2, 2). - pool_stride (tuple[int]): The strides of the pooling layer(s). For - example, (2, 2) means that all the pooling layers have - strides (2, 2). - output_nonlinearity (callable): Activation function for output dense - layer. It should return a torch.Tensor. Set it to None to - maintain a linear activation. - output_w_init (callable): Initializer function for the weight - of output dense layer(s). The function should return a - torch.Tensor. - output_b_init (callable): Initializer function for the bias - of output dense layer(s). The function should return a - torch.Tensor. - layer_normalization (bool): Bool for using layer normalization or not. - is_image (bool): If true, the inputs are normalized by dividing by 255. - """ - - def __init__(self, - input_shape, - output_dim, - kernel_sizes, - hidden_channels, - strides, - hidden_sizes=(32, 32), - cnn_hidden_nonlinearity=nn.ReLU, - mlp_hidden_nonlinearity=nn.ReLU, - hidden_w_init=nn.init.xavier_uniform_, - hidden_b_init=nn.init.zeros_, - paddings=0, - padding_mode='zeros', - max_pool=False, - pool_shape=None, - pool_stride=1, - output_nonlinearity=None, - output_w_init=nn.init.xavier_uniform_, - output_b_init=nn.init.zeros_, - layer_normalization=False, - is_image=True): - - super().__init__() - - input_var = torch.zeros(input_shape) - cnn_module = CNNModule(input_var=input_var, - kernel_sizes=kernel_sizes, - strides=strides, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - hidden_channels=hidden_channels, - hidden_nonlinearity=cnn_hidden_nonlinearity, - paddings=paddings, - padding_mode=padding_mode, - max_pool=max_pool, - layer_normalization=layer_normalization, - pool_shape=pool_shape, - pool_stride=pool_stride, - is_image=is_image) - - with torch.no_grad(): - cnn_out = cnn_module(input_var) - flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1] - - self._val = MLPModule(flat_dim, - 1, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) - self._act = MLPModule(flat_dim, - output_dim, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) - if mlp_hidden_nonlinearity is None: - self._module = nn.Sequential(cnn_module, nn.Flatten()) - else: - self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(), - nn.Flatten()) - - # pylint: disable=arguments-differ - def forward(self, inputs): - """Forward method. - - Args: - inputs (torch.Tensor): Inputs to the model of shape - (input_shape*). - - Returns: - torch.Tensor: Output tensor of shape :math:`(N, output_dim)`. 
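The advantage-centering in the removed `forward` below reappears in the rewritten `DiscreteDuelingCNNQFunction` later in this diff; it is the standard dueling aggregation of Wang et al. (2016):

```latex
Q(s, a) = V(s) + \Big( A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') \Big)
```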
- - """ - out = self._module(inputs) - val = self._val(out) - act = self._act(out) - act = act - act.mean(1).unsqueeze(1) - return val + act diff --git a/src/garage/torch/policies/categorical_cnn_policy.py b/src/garage/torch/policies/categorical_cnn_policy.py index dc7c02c8de..05ef812069 100644 --- a/src/garage/torch/policies/categorical_cnn_policy.py +++ b/src/garage/torch/policies/categorical_cnn_policy.py @@ -3,7 +3,8 @@ import torch from torch import nn -from garage.torch.modules import CategoricalCNNModule +from garage import InOutSpec +from garage.torch.modules import CNNModule, MultiHeadedMLPModule from garage.torch.policies.stochastic_policy import StochasticPolicy @@ -16,7 +17,9 @@ class CategoricalCNNPolicy(StochasticPolicy): It only works with akro.Discrete action space. Args: - env (garage.envs): Environment. + env_spec (garage.EnvSpec): Environment specification. + image_format (str): Either 'NCHW' or 'NHWC'. Should match env_spec. Gym + uses NHWC by default, but PyTorch uses NCHW by default. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. The filter for first layer is of dimension (3 x 3) @@ -49,9 +52,6 @@ class CategoricalCNNPolicy(StochasticPolicy): pool_stride (tuple[int]): The strides of the pooling layer(s). For example, (2, 2) means that all the pooling layers have strides (2, 2). - output_nonlinearity (callable): Activation function for output dense - layer. It should return a torch.Tensor. Set it to None to - maintain a linear activation. output_w_init (callable): Initializer function for the weight of output dense layer(s). The function should return a torch.Tensor. @@ -64,8 +64,10 @@ class CategoricalCNNPolicy(StochasticPolicy): """ def __init__(self, - env, + env_spec, + image_format, kernel_sizes, + *, hidden_channels, strides=1, hidden_sizes=(32, 32), @@ -77,74 +79,62 @@ def __init__(self, max_pool=False, pool_shape=None, pool_stride=1, - output_nonlinearity=None, output_w_init=nn.init.xavier_uniform_, output_b_init=nn.init.zeros_, layer_normalization=False, name='CategoricalCNNPolicy'): - if not isinstance(env.spec.action_space, akro.Discrete): + if not isinstance(env_spec.action_space, akro.Discrete): raise ValueError('CategoricalMLPPolicy only works ' 'with akro.Discrete action space.') - if isinstance(env.spec.observation_space, akro.Dict): + if isinstance(env_spec.observation_space, akro.Dict): raise ValueError('CNN policies do not support ' 'with akro.Dict observation spaces.') - super().__init__(env.spec, name) - self._env = env - self._obs_dim = self._env.spec.observation_space.shape - self._action_dim = self._env.spec.action_space.flat_dim - self._kernel_sizes = kernel_sizes - self._strides = strides - self._hidden_nonlinearity = hidden_nonlinearity - self._hidden_conv_channels = hidden_channels - self._hidden_w_init = hidden_w_init - self._hidden_b_init = hidden_b_init - self._hidden_sizes = hidden_sizes - self._paddings = paddings - self._padding_mode = padding_mode - self._max_pool = max_pool - self._pool_shape = pool_shape - self._pool_stride = pool_stride - self._output_nonlinearity = output_nonlinearity - self._output_w_init = output_w_init - self._output_b_init = output_b_init - self._layer_normalization = layer_normalization - self._is_image = isinstance(self._env.spec.observation_space, - akro.Image) + super().__init__(env_spec, name) + + self._cnn_module = CNNModule(InOutSpec( + self._env_spec.observation_space, None), + image_format=image_format, + 
kernel_sizes=kernel_sizes, + strides=strides, + hidden_channels=hidden_channels, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + hidden_nonlinearity=hidden_nonlinearity, + paddings=paddings, + padding_mode=padding_mode, + max_pool=max_pool, + pool_shape=pool_shape, + pool_stride=pool_stride, + layer_normalization=layer_normalization) + self._mlp_module = MultiHeadedMLPModule( + n_heads=1, + input_dim=self._cnn_module.spec.output_space.flat_dim, + output_dims=[self._env_spec.action_space.flat_dim], + hidden_sizes=hidden_sizes, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + hidden_nonlinearity=hidden_nonlinearity, + output_w_inits=output_w_init, + output_b_inits=output_b_init) def forward(self, observations): """Compute the action distributions from the observations. Args: - observations (torch.Tensor): Batch of observations on default - torch device. + observations (torch.Tensor): Observations to act on. Returns: torch.distributions.Distribution: Batch distribution of actions. dict[str, torch.Tensor]: Additional agent_info, as torch Tensors. Do not need to be detached, and can be on any device. """ - module = CategoricalCNNModule( - input_var=observations, - output_dim=self._action_dim, - kernel_sizes=self._kernel_sizes, - strides=self._strides, - hidden_channels=self._hidden_conv_channels, - hidden_sizes=self._hidden_sizes, - hidden_nonlinearity=self._hidden_nonlinearity, - hidden_w_init=self._hidden_w_init, - hidden_b_init=self._hidden_b_init, - paddings=self._paddings, - padding_mode=self._padding_mode, - max_pool=self._max_pool, - pool_shape=self._pool_shape, - pool_stride=self._pool_stride, - output_nonlinearity=self._output_nonlinearity, - output_w_init=self._output_w_init, - output_b_init=self._output_b_init, - layer_normalization=self._layer_normalization, - is_image=self._is_image) - - dist = module(observations) + # We're given flattened observations. + observations = observations.reshape( + -1, *self._env_spec.observation_space.shape) + cnn_output = self._cnn_module(observations) + mlp_output = self._mlp_module(cnn_output)[0] + logits = torch.softmax(mlp_output, axis=1) + dist = torch.distributions.Categorical(logits=logits) return dist, {} diff --git a/src/garage/torch/policies/discrete_cnn_policy.py b/src/garage/torch/policies/discrete_cnn_policy.py index bcae1e2d52..50b49757b8 100644 --- a/src/garage/torch/policies/discrete_cnn_policy.py +++ b/src/garage/torch/policies/discrete_cnn_policy.py @@ -1,8 +1,8 @@ """Discrete CNN Policy.""" -import akro import torch from torch import nn +from garage import InOutSpec from garage.torch.modules import DiscreteCNNModule from garage.torch.policies.stochastic_policy import StochasticPolicy @@ -15,6 +15,8 @@ class DiscreteCNNPolicy(StochasticPolicy): Args: env_spec (EnvSpec): Environment specification. + image_format (str): Either 'NCHW' or 'NHWC'. Should match env_spec. Gym + uses NHWC by default, but PyTorch uses NCHW by default. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. 
The filter for first layer is of dimension (3 x 3) @@ -66,6 +68,7 @@ class DiscreteCNNPolicy(StochasticPolicy): def __init__(self, env_spec, + image_format, kernel_sizes, hidden_channels, strides, @@ -86,19 +89,28 @@ def __init__(self, name='DiscreteCNNPolicy'): super().__init__(env_spec, name) - self._env_spec = env_spec - self._input_shape = env_spec.observation_space.shape - self._output_dim = env_spec.action_space.flat_dim - self._is_image = isinstance(self._env_spec.observation_space, - akro.Image) self._cnn_module = DiscreteCNNModule( - self._input_shape, self._output_dim, kernel_sizes, hidden_channels, - strides, hidden_sizes, cnn_hidden_nonlinearity, - mlp_hidden_nonlinearity, hidden_w_init, hidden_b_init, paddings, - padding_mode, max_pool, pool_shape, pool_stride, - output_nonlinearity, output_w_init, output_b_init, - layer_normalization, self._is_image) + spec=InOutSpec(input_space=env_spec.observation_space, + output_space=env_spec.action_space), + image_format=image_format, + kernel_sizes=kernel_sizes, + hidden_channels=hidden_channels, + strides=strides, + hidden_sizes=hidden_sizes, + cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, + mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + paddings=paddings, + padding_mode=padding_mode, + max_pool=max_pool, + pool_shape=pool_shape, + pool_stride=pool_stride, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) def forward(self, observations): """Compute the action distributions from the observations. @@ -114,12 +126,9 @@ def forward(self, observations): dict[str, torch.Tensor]: Additional agent_info, as torch Tensors. Do not need to be detached, and can be on any device. """ - observations = self._env_spec.observation_space.unflatten_n( - observations) - if isinstance(self._env_spec.observation_space, akro.Image): - observations = torch.div(observations, 255.0) - - observations = torch.Tensor(observations[0]) + # We're given flattened observations. + observations = observations.reshape( + -1, *self._env_spec.observation_space.shape) output = self._cnn_module(observations) logits = torch.softmax(output, axis=1) dist = torch.distributions.Bernoulli(logits=logits) diff --git a/src/garage/torch/policies/discrete_qf_argmax_policy.py b/src/garage/torch/policies/discrete_qf_argmax_policy.py index 9cf8c2e625..24d7d21949 100644 --- a/src/garage/torch/policies/discrete_qf_argmax_policy.py +++ b/src/garage/torch/policies/discrete_qf_argmax_policy.py @@ -5,7 +5,7 @@ import numpy as np import torch -from garage.torch import np_to_torch +from garage.torch import as_torch from garage.torch.policies.policy import Policy @@ -51,8 +51,8 @@ def get_action(self, observation): torch.Tensor: Predicted action with shape :math:`(A, )`. dict: Empty since this policy does not produce a distribution. """ - act, dist = self.get_actions(np.expand_dims(observation, axis=0)) - return act[0], dist + act, info = self.get_actions(np.expand_dims(observation, axis=0)) + return act[0], info def get_actions(self, observations): """Get actions given observations. @@ -66,4 +66,4 @@ def get_actions(self, observations): dict: Empty since this policy does not produce a distribution. 
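The rewritten policies and Q-functions above share one convention: callers now pass flattened observations, and each module restores the image shape itself via `reshape(-1, *observation_space.shape)`. A small sketch of that reshape (the shapes are assumptions):

```python
import torch

obs_shape = (4, 84, 84)               # e.g. stacked Atari frames, NCHW
flat = torch.zeros(32, 4 * 84 * 84)   # flattened batch from a replay buffer
images = flat.reshape(-1, *obs_shape)
assert images.shape == (32, 4, 84, 84)
```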
""" with torch.no_grad(): - return self(np_to_torch(observations)).cpu().numpy(), dict() + return self(as_torch(observations)).cpu().numpy(), dict() diff --git a/src/garage/torch/policies/stochastic_policy.py b/src/garage/torch/policies/stochastic_policy.py index 2755f259db..5185719db1 100644 --- a/src/garage/torch/policies/stochastic_policy.py +++ b/src/garage/torch/policies/stochastic_policy.py @@ -1,11 +1,10 @@ """Base Stochastic Policy.""" import abc -import akro import numpy as np import torch -from garage.torch import global_device +from garage.torch import as_torch from garage.torch.policies.policy import Policy @@ -39,8 +38,7 @@ def get_action(self, observation): observation = torch.flatten(observation) with torch.no_grad(): if not isinstance(observation, torch.Tensor): - observation = torch.as_tensor(observation).float().to( - global_device()) + observation = as_torch(observation) observation = observation.unsqueeze(0) action, agent_infos = self.get_actions(observation) return action[0], {k: v[0] for k, v in agent_infos.items()} @@ -81,18 +79,9 @@ def get_actions(self, observations): elif isinstance(observations[0], torch.Tensor) and len(observations[0].shape) > 1: observations = torch.flatten(observations, start_dim=1) - - if isinstance(self._env_spec.observation_space, akro.Image) and \ - len(observations.shape) < \ - len(self._env_spec.observation_space.shape): - observations = self._env_spec.observation_space.unflatten_n( - observations) with torch.no_grad(): if not isinstance(observations, torch.Tensor): - observations = torch.as_tensor(observations).float().to( - global_device()) - if isinstance(self._env_spec.observation_space, akro.Image): - observations /= 255.0 # scale image + observations = as_torch(observations) dist, info = self.forward(observations) return dist.sample().cpu().numpy(), { k: v.detach().cpu().numpy() diff --git a/src/garage/torch/q_functions/discrete_cnn_q_function.py b/src/garage/torch/q_functions/discrete_cnn_q_function.py index 4550ef52da..f7c662cf5f 100644 --- a/src/garage/torch/q_functions/discrete_cnn_q_function.py +++ b/src/garage/torch/q_functions/discrete_cnn_q_function.py @@ -2,12 +2,13 @@ import torch from torch import nn +from garage import InOutSpec from garage.torch.modules import DiscreteCNNModule # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 # pylint: disable=abstract-method -class DiscreteCNNQFunction(DiscreteCNNModule): +class DiscreteCNNQFunction(nn.Module): """Discrete CNN Q Function. A Q network that estimates Q values of all possible discrete actions. @@ -16,6 +17,9 @@ class DiscreteCNNQFunction(DiscreteCNNModule): Args: env_spec (EnvSpec): Environment specification. + image_format (str): Either 'NCHW' or 'NHWC'. Should match the input + specification. Gym uses NHWC by default, but PyTorch uses NCHW by + default. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. The filter for first layer is of dimension (3 x 3) @@ -62,11 +66,12 @@ class DiscreteCNNQFunction(DiscreteCNNModule): of output dense layer(s). The function should return a torch.Tensor. layer_normalization (bool): Bool for using layer normalization or not. - is_image (bool): If true, the inputs are normalized by dividing by 255. 
""" def __init__(self, env_spec, + image_format, + *, kernel_sizes, hidden_channels, strides, @@ -83,32 +88,32 @@ def __init__(self, output_nonlinearity=None, output_w_init=nn.init.xavier_uniform_, output_b_init=nn.init.zeros_, - layer_normalization=False, - is_image=True): + layer_normalization=False): + super().__init__() self._env_spec = env_spec - input_shape = (1, ) + env_spec.observation_space.shape - output_dim = env_spec.action_space.flat_dim - super().__init__(input_shape=input_shape, - output_dim=output_dim, - kernel_sizes=kernel_sizes, - strides=strides, - hidden_sizes=hidden_sizes, - hidden_channels=hidden_channels, - cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, - mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - paddings=paddings, - padding_mode=padding_mode, - max_pool=max_pool, - pool_shape=pool_shape, - pool_stride=pool_stride, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization, - is_image=is_image) + + self._cnn_module = DiscreteCNNModule( + spec=InOutSpec(input_space=env_spec.observation_space, + output_space=env_spec.action_space), + image_format=image_format, + kernel_sizes=kernel_sizes, + hidden_channels=hidden_channels, + strides=strides, + hidden_sizes=hidden_sizes, + cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, + mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + paddings=paddings, + padding_mode=padding_mode, + max_pool=max_pool, + pool_shape=pool_shape, + pool_stride=pool_stride, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) # pylint: disable=arguments-differ def forward(self, observations): @@ -120,10 +125,7 @@ def forward(self, observations): Returns: torch.Tensor: Output value """ - if observations.shape != self._env_spec.observation_space.shape: - # avoid using observation_space.unflatten_n - # to support tensors on GPUs - obs_shape = ((len(observations), ) + - self._env_spec.observation_space.shape) - observations = observations.reshape(obs_shape) - return super().forward(observations) + # We're given flattened observations. + observations = observations.reshape( + -1, *self._env_spec.observation_space.shape) + return self._cnn_module(observations) diff --git a/src/garage/torch/q_functions/discrete_dueling_cnn_q_function.py b/src/garage/torch/q_functions/discrete_dueling_cnn_q_function.py index 4c86147a2b..63a89b4e15 100644 --- a/src/garage/torch/q_functions/discrete_dueling_cnn_q_function.py +++ b/src/garage/torch/q_functions/discrete_dueling_cnn_q_function.py @@ -2,12 +2,13 @@ import torch from torch import nn -from garage.torch.modules import DiscreteDuelingCNNModule +from garage import InOutSpec +from garage.torch.modules import CNNModule, MLPModule # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 # pylint: disable=abstract-method -class DiscreteDuelingCNNQFunction(DiscreteDuelingCNNModule): +class DiscreteDuelingCNNQFunction(nn.Module): """Discrete Dueling CNN Q Function. A dueling Q network that estimates Q values of all possible discrete @@ -17,6 +18,9 @@ class DiscreteDuelingCNNQFunction(DiscreteDuelingCNNModule): Args: env_spec (EnvSpec): Environment specification. + image_format (str): Either 'NCHW' or 'NHWC'. Should match the input + specification. 
Gym uses NHWC by default, but PyTorch uses NCHW by + default. kernel_sizes (tuple[int]): Dimension of the conv filters. For example, (3, 5) means there are two convolutional layers. The filter for first layer is of dimension (3 x 3) @@ -63,11 +67,12 @@ class DiscreteDuelingCNNQFunction(DiscreteDuelingCNNModule): of output dense layer(s). The function should return a torch.Tensor. layer_normalization (bool): Bool for using layer normalization or not. - is_image (bool): If true, the inputs are normalized by dividing by 255. """ def __init__(self, env_spec, + image_format, + *, kernel_sizes, hidden_channels, strides, @@ -84,32 +89,56 @@ def __init__(self, output_nonlinearity=None, output_w_init=nn.init.xavier_uniform_, output_b_init=nn.init.zeros_, - layer_normalization=False, - is_image=True): + layer_normalization=False): + super().__init__() self._env_spec = env_spec - input_shape = (1, ) + env_spec.observation_space.shape - output_dim = env_spec.action_space.flat_dim - super().__init__(input_shape=input_shape, - output_dim=output_dim, - kernel_sizes=kernel_sizes, - strides=strides, - hidden_sizes=hidden_sizes, - hidden_channels=hidden_channels, - cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, - mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - paddings=paddings, - padding_mode=padding_mode, - max_pool=max_pool, - pool_shape=pool_shape, - pool_stride=pool_stride, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization, - is_image=is_image) + cnn_spec = InOutSpec(input_space=env_spec.observation_space, + output_space=None) + + cnn_module = CNNModule(spec=cnn_spec, + image_format=image_format, + kernel_sizes=kernel_sizes, + strides=strides, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + hidden_channels=hidden_channels, + hidden_nonlinearity=cnn_hidden_nonlinearity, + paddings=paddings, + padding_mode=padding_mode, + max_pool=max_pool, + layer_normalization=layer_normalization, + pool_shape=pool_shape, + pool_stride=pool_stride) + + # CNNModule computes output dimensionality + flat_dim = cnn_module.spec.output_space.flat_dim + + self._val = MLPModule(flat_dim, + 1, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + self._act = MLPModule(flat_dim, + env_spec.action_space.flat_dim, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + if mlp_hidden_nonlinearity is None: + self._module = nn.Sequential(cnn_module, nn.Flatten()) + else: + self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(), + nn.Flatten()) # pylint: disable=arguments-differ def forward(self, observations): @@ -121,10 +150,11 @@ def forward(self, observations): Returns: torch.Tensor: Output value """ - if observations.shape != self._env_spec.observation_space.shape: - # avoid using observation_space.unflatten_n - # to support tensors on GPUs - obs_shape = ((len(observations), ) + - self._env_spec.observation_space.shape) - observations = observations.reshape(obs_shape) - return super().forward(observations) + # 
We're given flattened observations. + observations = observations.reshape( + -1, *self._env_spec.observation_space.shape) + out = self._module(observations) + val = self._val(out) + act = self._act(out) + act = act - act.mean(1).unsqueeze(1) + return val + act diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py index 127192eacc..30aa970494 100644 --- a/tests/garage/torch/algos/test_dqn.py +++ b/tests/garage/torch/algos/test_dqn.py @@ -13,7 +13,7 @@ from garage.np.exploration_policies import EpsilonGreedyPolicy from garage.replay_buffer import PathBuffer from garage.sampler import FragmentWorker, LocalSampler -from garage.torch import np_to_torch +from garage.torch import as_torch from garage.torch.algos import DQN from garage.torch.policies import DiscreteQFArgmaxPolicy from garage.torch.q_functions import DiscreteMLPQFunction @@ -94,11 +94,11 @@ def test_dqn_loss(setup): timesteps = buff.sample_timesteps(algo._buffer_batch_size) timesteps_copy = copy.deepcopy(timesteps) - observations = np_to_torch(timesteps.observations) - rewards = np_to_torch(timesteps.rewards).reshape(-1, 1) - actions = np_to_torch(timesteps.actions) - next_observations = np_to_torch(timesteps.next_observations) - terminals = np_to_torch(timesteps.terminals).reshape(-1, 1) + observations = as_torch(timesteps.observations) + rewards = as_torch(timesteps.rewards).reshape(-1, 1) + actions = as_torch(timesteps.actions) + next_observations = as_torch(timesteps.next_observations) + terminals = as_torch(timesteps.terminals).reshape(-1, 1) next_inputs = next_observations inputs = observations @@ -138,11 +138,11 @@ def test_double_dqn_loss(setup): timesteps = buff.sample_timesteps(algo._buffer_batch_size) timesteps_copy = copy.deepcopy(timesteps) - observations = np_to_torch(timesteps.observations) - rewards = np_to_torch(timesteps.rewards).reshape(-1, 1) - actions = np_to_torch(timesteps.actions) - next_observations = np_to_torch(timesteps.next_observations) - terminals = np_to_torch(timesteps.terminals).reshape(-1, 1) + observations = as_torch(timesteps.observations) + rewards = as_torch(timesteps.rewards).reshape(-1, 1) + actions = as_torch(timesteps.actions) + next_observations = as_torch(timesteps.next_observations) + terminals = as_torch(timesteps.terminals).reshape(-1, 1) next_inputs = next_observations inputs = observations diff --git a/tests/garage/torch/modules/test_categorical_cnn_module.py b/tests/garage/torch/modules/test_categorical_cnn_module.py deleted file mode 100644 index 8f63203efd..0000000000 --- a/tests/garage/torch/modules/test_categorical_cnn_module.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Test CategoricalCNNModule.""" -import pickle - -import numpy as np -import pytest -import torch -from torch.distributions import Categorical -import torch.nn as nn - -from garage.torch.modules.categorical_cnn_module import CategoricalCNNModule - - -class TestCategoricalCNNModule: - """Test CategoricalCNNModule.""" - - def setup_method(self): - self.batch_size = 64 - self.input_width = 32 - self.input_height = 32 - self.in_channel = 3 - self.dtype = torch.float32 - self.input = torch.zeros( - (self.batch_size, self.in_channel, self.input_height, - self.input_width), - dtype=self.dtype) # minibatch size 64, image size [3, 32, 32] - - def test_dist(self): - model = CategoricalCNNModule( - input_var=self.input, - output_dim=1, - kernel_sizes=((3), ), - hidden_channels=((5), ), - strides=(1, ), - ) - dist = model(self.input) - assert isinstance(dist, Categorical) - - 
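The dueling head added in discrete_dueling_cnn_q_function.py above recombines a scalar value stream with a mean-centered advantage stream, which keeps the value/advantage decomposition identifiable. A tiny standalone sketch of that arithmetic, with made-up numbers:

```python
import torch

val = torch.tensor([[2.0], [1.0]])            # V(s): shape (batch, 1)
adv = torch.tensor([[1.0, 3.0], [0.0, 2.0]])  # A(s, a): shape (batch, n_actions)

# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), matching forward() above.
q = val + (adv - adv.mean(1).unsqueeze(1))
print(q)  # tensor([[1., 3.], [0., 2.]]); each row of q averages to V(s)
```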
@pytest.mark.parametrize( - 'output_dim, hidden_channels, kernel_sizes, strides, hidden_sizes', [ - (1, (1, ), (1, ), (1, ), (1, )), - (1, (3, ), (3, ), (2, ), (2, )), - (1, (3, ), (3, ), (2, ), (3, )), - (2, (3, 3), (3, 3), (2, 2), (1, 1)), - (3, (3, 3), (3, 3), (2, 2), (2, 2)), - ]) - def test_is_pickleable(self, output_dim, hidden_channels, kernel_sizes, - strides, hidden_sizes): - model = CategoricalCNNModule(input_var=self.input, - output_dim=output_dim, - kernel_sizes=kernel_sizes, - hidden_channels=hidden_channels, - strides=strides, - hidden_sizes=hidden_sizes, - hidden_nonlinearity=None, - hidden_w_init=nn.init.xavier_uniform_, - output_w_init=nn.init.zeros_) - dist1 = model(self.input) - - h = pickle.dumps(model) - model_pickled = pickle.loads(h) - dist2 = model_pickled(self.input) - - assert np.array_equal(dist1.probs.shape, dist2.probs.shape) - assert np.array_equal(torch.all(torch.eq(dist1.probs, dist2.probs)), - True) diff --git a/tests/garage/torch/modules/test_cnn_module.py b/tests/garage/torch/modules/test_cnn_module.py index 4463c144ff..3f42ae6de2 100644 --- a/tests/garage/torch/modules/test_cnn_module.py +++ b/tests/garage/torch/modules/test_cnn_module.py @@ -1,12 +1,15 @@ """Test CNNModule.""" import pickle +import akro import numpy as np import pytest import torch import torch.nn as nn +from garage import InOutSpec from garage.torch.modules import CNNModule +from garage.torch.modules.cnn_module import _check_spec class TestCNNModule: @@ -18,6 +21,11 @@ def setup_method(self): self.input_height = 32 self.in_channel = 3 self.dtype = torch.float32 + self.input_spec = InOutSpec( + akro.Box( + shape=[self.in_channel, self.input_height, self.input_width], + high=np.inf, + low=-np.inf), None) self.input = torch.zeros( (self.batch_size, self.in_channel, self.input_height, self.input_width), @@ -45,7 +53,8 @@ def test_output_values(self, kernel_sizes, hidden_channels, strides, """ module_with_nonlinear_function_and_module = CNNModule( - input_var=self.input, + self.input_spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, @@ -55,7 +64,8 @@ hidden_w_init=nn.init.xavier_uniform_) module_with_nonlinear_module_instance_and_function = CNNModule( - input_var=self.input, + self.input_spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, @@ -99,7 +109,8 @@ def test_output_values_with_unequal_stride_with_padding( paddings (tuple[int]): value of zero-padding. """ - model = CNNModule(input_var=self.input, + model = CNNModule(self.input_spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, @@ -136,7 +147,8 @@ def test_is_pickleable(self, hidden_channels, kernel_sizes, strides): strides (tuple[int]): strides. 
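These fixtures replace the old `input_var`-based construction with an `InOutSpec` plus an explicit `image_format`. A minimal sketch of the new pattern with illustrative shapes; passing `output_space=None` leaves the output space for `CNNModule` to derive from the conv stack, which is how the dueling q-function above obtains `cnn_module.spec.output_space.flat_dim`:

```python
import akro
import numpy as np
import torch

from garage import InOutSpec
from garage.torch.modules import CNNModule

# NCHW input spec with properly signed bounds; only the shape is
# consumed when building the conv stack.
spec = InOutSpec(akro.Box(shape=(3, 32, 32), low=-np.inf, high=np.inf),
                 None)
module = CNNModule(spec,
                   image_format='NCHW',
                   hidden_channels=(32, ),
                   kernel_sizes=(3, ),
                   strides=(1, ))
out = module(torch.zeros((8, 3, 32, 32)))  # a batch of NCHW images
```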
""" - model = CNNModule(input_var=self.input, + model = CNNModule(self.input_spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides) @@ -158,13 +170,14 @@ def test_is_pickleable(self, hidden_channels, kernel_sizes, strides): ((3, 3), (32, 64), (1, 1), 2, 2)]) def test_output_with_max_pooling(self, kernel_sizes, hidden_channels, strides, pool_shape, pool_stride): - model = CNNModule(input_var=self.input, + model = CNNModule(self.input_spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, max_pool=True, - pool_shape=(pool_shape, pool_shape), - pool_stride=(pool_stride, pool_stride)) + pool_shape=[(pool_shape, pool_shape)], + pool_stride=[(pool_stride, pool_stride)]) x = model(self.input) fc_w = torch.zeros((x.shape[1], 10)) fc_b = torch.zeros(10) @@ -182,8 +195,61 @@ def test_no_head_invalid_settings(self, hidden_nonlinear): """ expected_msg = 'Non linear function .* is not supported' with pytest.raises(ValueError, match=expected_msg): - CNNModule(input_var=self.input, + CNNModule(self.input_spec, + image_format='NCHW', hidden_channels=(32, ), kernel_sizes=(3, ), strides=(1, ), hidden_nonlinearity=hidden_nonlinear) + + +@pytest.mark.parametrize('kernel_sizes, hidden_channels, ' + 'strides, pool_shape, pool_stride', + [((1, ), (32, ), (1, ), 1, 1), + ((3, ), (32, ), (1, ), 1, 1), + ((3, ), (32, ), (2, ), 2, 2), + ((1, 1), (32, 64), (1, 1), 1, 1), + ((3, 2), (32, 64), (1, 1), 1, 1), + ((3, 2), (32, 64), (2, 1), 1, 1), + ((3, 3), (32, 64), (1, 1), 1, 1), + ((3, 3), (32, 64), (1, 1), 2, 2)]) +def test_set_output_size(kernel_sizes, hidden_channels, strides, pool_shape, + pool_stride): + spec = InOutSpec(akro.Box(shape=[3, 19, 15], high=np.inf, low=-np.inf), + akro.Box(shape=[200], high=np.inf, low=-np.inf)) + model = CNNModule(spec, + image_format='NCHW', + hidden_channels=hidden_channels, + kernel_sizes=kernel_sizes, + strides=strides, + pool_shape=[(pool_shape, pool_shape)], + pool_stride=[(pool_stride, pool_stride)], + layer_normalization=True) + images = torch.ones(10, 3, 19, 15) + x = model(images) + assert x.shape == (10, 200) + + +def test_check_spec(): + with pytest.raises(ValueError, match='should be an akro.Box'): + # Input space is not Box or Image + _check_spec(InOutSpec(akro.Dict(), None), 'NCHW') + with pytest.raises(ValueError, match='should have three dimensions'): + # Too many input dimensions + _check_spec( + InOutSpec(akro.Box(shape=[1, 1, 1, 1], low=-np.inf, high=np.inf), + None), 'NCHW') + with pytest.raises(ValueError, match='akro.Box with a single dimension'): + # Output is not one-dimensional + _check_spec( + InOutSpec(akro.Box(shape=[1, 1, 1], low=-np.inf, high=np.inf), + akro.Box( + shape=[1, 1], + low=-np.inf, + high=np.inf, + )), 'NCHW') + with pytest.warns(UserWarning): + # 4 color channels should warn + _check_spec( + InOutSpec(akro.Box(shape=[4, 1, 1], low=-np.inf, high=np.inf), + None), 'NCHW') diff --git a/tests/garage/torch/modules/test_discrete_cnn_module.py b/tests/garage/torch/modules/test_discrete_cnn_module.py index 82353edaf6..c7205352aa 100644 --- a/tests/garage/torch/modules/test_discrete_cnn_module.py +++ b/tests/garage/torch/modules/test_discrete_cnn_module.py @@ -1,11 +1,13 @@ """Test DiscreteCNNModule.""" import pickle +import akro import numpy as np import pytest import torch import torch.nn as nn +from garage import InOutSpec from garage.torch.modules import CNNModule, DiscreteCNNModule, MLPModule @@ -22,15 +24,16 @@ def 
test_output_values(output_dim, kernel_sizes, hidden_channels, strides, paddings): - batch_size = 64 input_width = 32 input_height = 32 in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) + input_shape = (in_channel, input_height, input_width) + spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf), + akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf)) obs = torch.rand(input_shape) - module = DiscreteCNNModule(input_shape=input_shape, - output_dim=output_dim, + module = DiscreteCNNModule(spec=spec, + image_format='NCHW', hidden_channels=hidden_channels, hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, @@ -38,17 +41,17 @@ def test_output_values(output_dim, kernel_sizes, hidden_channels, strides, paddings=paddings, padding_mode='zeros', hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) - cnn = CNNModule(input_var=obs, + cnn = CNNModule(spec=InOutSpec( + akro.Box(shape=input_shape, low=-np.inf, high=np.inf), None), + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, paddings=paddings, padding_mode='zeros', - hidden_w_init=nn.init.ones_, - is_image=False) + hidden_w_init=nn.init.ones_) flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1] mlp = MLPModule( @@ -69,14 +72,15 @@ def test_output_values(output_dim, kernel_sizes, hidden_channels, strides, [(1, (32, ), (1, ), (1, ))]) def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes, strides): - batch_size = 64 input_width = 32 input_height = 32 in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) + input_shape = (in_channel, input_height, input_width) + spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf), + akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf)) - module = DiscreteCNNModule(input_shape=input_shape, - output_dim=output_dim, + module = DiscreteCNNModule(spec=spec, + image_format='NCHW', hidden_channels=hidden_channels, hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, @@ -84,8 +88,7 @@ def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes, mlp_hidden_nonlinearity=None, cnn_hidden_nonlinearity=None, hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) assert len(module._module) == 3 @@ -96,15 +99,16 @@ def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes, (3, (32, 64), (1, 1), (1, 1)), (4, (32, 64), (3, 3), (1, 1))]) def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides): - batch_size = 64 input_width = 32 input_height = 32 in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) + input_shape = (in_channel, input_height, input_width) input_a = torch.ones(input_shape) + spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf), + akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf)) - model = DiscreteCNNModule(input_shape=input_shape, - output_dim=output_dim, + model = DiscreteCNNModule(spec=spec, + image_format='NCHW', hidden_channels=hidden_channels, kernel_sizes=kernel_sizes, mlp_hidden_nonlinearity=nn.ReLU, diff --git a/tests/garage/torch/modules/test_discrete_dueling_cnn_module.py b/tests/garage/torch/modules/test_discrete_dueling_cnn_module.py deleted file mode 100644 index ba6f92c318..0000000000 --- a/tests/garage/torch/modules/test_discrete_dueling_cnn_module.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Test 
DiscreteDuelingCNNModule.""" -import pickle - -import numpy as np -import pytest -import torch -import torch.nn as nn - -from garage.torch.modules import CNNModule, DiscreteDuelingCNNModule, MLPModule - - -@pytest.mark.parametrize( - 'output_dim, kernel_sizes, hidden_channels, strides, paddings', [ - (1, (1, ), (32, ), (1, ), (0, )), - (2, (3, ), (32, ), (1, ), (0, )), - (5, (3, ), (32, ), (2, ), (0, )), - (5, (5, ), (12, ), (1, ), (2, )), - (5, (1, 1), (32, 64), (1, 1), (0, 0)), - (10, (3, 3), (32, 64), (1, 1), (0, 0)), - (10, (3, 3), (32, 64), (2, 2), (0, 0)), - ]) -def test_dueling_output_values(output_dim, kernel_sizes, hidden_channels, - strides, paddings): - - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - obs = torch.rand(input_shape) - - module = DiscreteDuelingCNNModule(input_shape=input_shape, - output_dim=output_dim, - hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - padding_mode='zeros', - hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) - - cnn = CNNModule(input_var=obs, - hidden_channels=hidden_channels, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - padding_mode='zeros', - hidden_w_init=nn.init.ones_, - is_image=False) - flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1] - - mlp_adv = MLPModule( - flat_dim, - output_dim, - hidden_channels, - hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - ) - - mlp_val = MLPModule( - flat_dim, - 1, - hidden_channels, - hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - ) - - cnn_out = cnn(obs) - val = mlp_val(torch.flatten(cnn_out, start_dim=1)) - adv = mlp_adv(torch.flatten(cnn_out, start_dim=1)) - output = val + (adv - adv.mean(1).unsqueeze(1)) - - assert torch.all(torch.eq(output.detach(), module(obs).detach())) - - -@pytest.mark.parametrize('output_dim, hidden_channels, kernel_sizes, strides', - [(1, (32, ), (1, ), (1, ))]) -def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes, - strides): - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - - module = DiscreteDuelingCNNModule(input_shape=input_shape, - output_dim=output_dim, - hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, - kernel_sizes=kernel_sizes, - strides=strides, - mlp_hidden_nonlinearity=None, - cnn_hidden_nonlinearity=None, - hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) - - assert len(module._module) == 2 - - -@pytest.mark.parametrize('output_dim, hidden_channels, kernel_sizes, strides', - [(1, (32, ), (1, ), (1, )), (5, (32, ), (3, ), (1, )), - (2, (32, ), (3, ), (1, )), - (3, (32, 64), (1, 1), (1, 1)), - (4, (32, 64), (3, 3), (1, 1))]) -def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides): - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - input_a = torch.ones(input_shape) - - model = DiscreteDuelingCNNModule(input_shape=input_shape, - output_dim=output_dim, - hidden_channels=hidden_channels, - kernel_sizes=kernel_sizes, - mlp_hidden_nonlinearity=nn.ReLU, - cnn_hidden_nonlinearity=nn.ReLU, - strides=strides) - output1 = model(input_a) - - h = pickle.dumps(model) - model_pickled = pickle.loads(h) - output2 = model_pickled(input_a) - - 
assert np.array_equal(torch.all(torch.eq(output1, output2)), True) diff --git a/tests/garage/torch/policies/test_categorical_cnn_policy.py b/tests/garage/torch/policies/test_categorical_cnn_policy.py index 1c1560ae72..e7733edf30 100644 --- a/tests/garage/torch/policies/test_categorical_cnn_policy.py +++ b/tests/garage/torch/policies/test_categorical_cnn_policy.py @@ -4,7 +4,6 @@ import torch from garage.envs import GymEnv -from garage.torch import TransposeImage from garage.torch.policies import CategoricalCNNPolicy from tests.fixtures.envs.dummy import DummyDictEnv, DummyDiscretePixelEnv @@ -12,20 +11,6 @@ class TestCategoricalCNNPolicy: - def _initialize_obs_env(self, env): - """Initialize observation env depends on observation space type. - - If observation space (i.e. akro.Image, gym.spaces.Box) is an image, - wrap the input of shape (W, H, 3) for PyTorch (N, 3, W, H). - - Return: - Transformed environment (garage.envs). - """ - obs_shape = env.observation_space.shape - if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: - env = TransposeImage(env) - return env - @pytest.mark.parametrize( 'hidden_channels, kernel_sizes, strides, hidden_sizes', [ ((3, ), (3, ), (1, ), (4, )), @@ -36,8 +21,8 @@ def test_get_action(self, hidden_channels, kernel_sizes, strides, hidden_sizes): """Test get_action function.""" env = GymEnv(DummyDiscretePixelEnv(), is_image=True) - env = self._initialize_obs_env(env) - policy = CategoricalCNNPolicy(env=env, + policy = CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=kernel_sizes, hidden_channels=hidden_channels, strides=strides, @@ -57,8 +42,8 @@ def test_get_action_img_obs(self, hidden_channels, kernel_sizes, strides, hidden_sizes): """Test get_action function with akro.Image observation space.""" env = GymEnv(DummyDiscretePixelEnv(), is_image=True) - env = self._initialize_obs_env(env) - policy = CategoricalCNNPolicy(env=env, + policy = CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=kernel_sizes, hidden_channels=hidden_channels, strides=strides, @@ -79,8 +64,8 @@ def test_get_actions(self, hidden_channels, kernel_sizes, strides, hidden_sizes): """Test get_actions function with akro.Image observation space.""" env = GymEnv(DummyDiscretePixelEnv(), is_image=True) - env = self._initialize_obs_env(env) - policy = CategoricalCNNPolicy(env=env, + policy = CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=kernel_sizes, hidden_channels=hidden_channels, strides=strides, @@ -106,8 +91,8 @@ def test_is_pickleable(self, hidden_channels, kernel_sizes, strides, hidden_sizes): """Test if policy is pickable.""" env = GymEnv(DummyDiscretePixelEnv(), is_image=True) - env = self._initialize_obs_env(env) - policy = CategoricalCNNPolicy(env=env, + policy = CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=kernel_sizes, hidden_channels=hidden_channels, strides=strides, @@ -131,7 +116,8 @@ def test_does_not_support_dict_obs_space(self): with pytest.raises(ValueError, match=('CNN policies do not support ' 'with akro.Dict observation spaces.')): - CategoricalCNNPolicy(env=env, + CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=(3, ), hidden_channels=(3, )) @@ -139,7 +125,8 @@ def test_invalid_action_spaces(self): """Test that policy raises error if passed a box obs space.""" env = GymEnv(DummyDictEnv(act_space_type='box')) with pytest.raises(ValueError): - CategoricalCNNPolicy(env=env, + CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', 
kernel_sizes=(3, ), hidden_channels=(3, )) @@ -155,9 +142,9 @@ def test_obs_unflattened(self, hidden_channels, kernel_sizes, strides, then it is unflattened. """ env = GymEnv(DummyDiscretePixelEnv(), is_image=True) - env = self._initialize_obs_env(env) env.reset() - policy = CategoricalCNNPolicy(env=env, + policy = CategoricalCNNPolicy(env_spec=env.spec, + image_format='NHWC', kernel_sizes=kernel_sizes, hidden_channels=hidden_channels, strides=strides, diff --git a/tests/garage/torch/policies/test_discrete_cnn_policy.py b/tests/garage/torch/policies/test_discrete_cnn_policy.py index 882971b469..7420dbb16b 100644 --- a/tests/garage/torch/policies/test_discrete_cnn_policy.py +++ b/tests/garage/torch/policies/test_discrete_cnn_policy.py @@ -4,53 +4,30 @@ import torch.nn as nn from garage.envs import GymEnv -from garage.torch import TransposeImage from garage.torch.policies import DiscreteCNNPolicy -from tests.fixtures.envs.dummy import DummyDiscreteEnv +from tests.fixtures.envs.dummy import DummyDiscretePixelEnv -class TestCategoricalCNNPolicy: - - def _initialize_obs_env(self, env): - """Initialize observation env depends on observation space type. - - If observation space (i.e. akro.Image, gym.spaces.Box) is an image, - wrap the input of shape (W, H, 3) for PyTorch (N, 3, W, H). - - Return: - Transformed environment (garage.envs). - """ - obs_shape = env.observation_space.shape - if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: - env = TransposeImage(env) - return env +class TestDiscreteCNNPolicy: @pytest.mark.parametrize( - 'action_dim, kernel_sizes, hidden_channels, strides, paddings', [ - (3, (1, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (2, ), (0, )), - (3, (5, ), (12, ), (1, ), (2, )), - (3, (1, 1), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (2, 2), (0, 0)), + 'kernel_sizes, hidden_channels, strides, paddings', [ + ((1, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (2, ), (0, )), + ((5, ), (12, ), (1, ), (2, )), + ((1, 1), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (2, 2), (0, 0)), ]) - def test_get_action(self, action_dim, kernel_sizes, hidden_channels, - strides, paddings): + def test_get_action(self, kernel_sizes, hidden_channels, strides, + paddings): """Test get_action function.""" - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - env = GymEnv( - DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim)) - - env = self._initialize_obs_env(env) + env = GymEnv(DummyDiscretePixelEnv()) policy = DiscreteCNNPolicy(env_spec=env.spec, + image_format='NHWC', hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, paddings=paddings, @@ -62,33 +39,24 @@ def test_get_action(self, action_dim, kernel_sizes, hidden_channels, action, _ = policy.get_action(obs.flatten()) assert env.action_space.contains(int(action[0])) - assert env.action_space.n == action_dim @pytest.mark.parametrize( - 'action_dim, kernel_sizes, hidden_channels, strides, paddings', [ - (3, (1, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (2, ), (0, )), - (3, (5, ), (12, ), (1, ), (2, )), - (3, (1, 1), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (2, 2), (0, 0)), + 'kernel_sizes, hidden_channels, strides, paddings', [ + ((1, ), (32, ), 
(1, ), (0, )), + ((3, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (2, ), (0, )), + ((5, ), (12, ), (1, ), (2, )), + ((1, 1), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (2, 2), (0, 0)), ]) - def test_get_actions(self, action_dim, kernel_sizes, hidden_channels, - strides, paddings): + def test_get_actions(self, kernel_sizes, hidden_channels, strides, + paddings): """Test get_actions function.""" - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - env = GymEnv( - DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim)) - - env = self._initialize_obs_env(env) + env = GymEnv(DummyDiscretePixelEnv()) policy = DiscreteCNNPolicy(env_spec=env.spec, + image_format='NHWC', hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, paddings=paddings, @@ -102,33 +70,24 @@ def test_get_actions(self, action_dim, kernel_sizes, hidden_channels, actions, _ = policy.get_actions([obs, obs, obs]) for action in actions: assert env.action_space.contains(int(action[0])) - assert env.action_space.n == action_dim @pytest.mark.parametrize( - 'action_dim, kernel_sizes, hidden_channels, strides, paddings', [ - (3, (1, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (2, ), (0, )), - (3, (5, ), (12, ), (1, ), (2, )), - (3, (1, 1), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (2, 2), (0, 0)), + 'kernel_sizes, hidden_channels, strides, paddings', [ + ((1, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (2, ), (0, )), + ((5, ), (12, ), (1, ), (2, )), + ((1, 1), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (2, 2), (0, 0)), ]) - def test_is_pickleable(self, action_dim, kernel_sizes, hidden_channels, - strides, paddings): + def test_is_pickleable(self, kernel_sizes, hidden_channels, strides, + paddings): """Test if policy is pickable.""" - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - env = GymEnv( - DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim)) - - env = self._initialize_obs_env(env) + env = GymEnv(DummyDiscretePixelEnv()) policy = DiscreteCNNPolicy(env_spec=env.spec, + image_format='NHWC', hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, paddings=paddings, @@ -149,33 +108,25 @@ def test_is_pickleable(self, action_dim, kernel_sizes, hidden_channels, assert output_action_1.shape == output_action_2.shape @pytest.mark.parametrize( - 'action_dim, kernel_sizes, hidden_channels, strides, paddings', [ - (3, (1, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (1, ), (0, )), - (3, (3, ), (32, ), (2, ), (0, )), - (3, (5, ), (12, ), (1, ), (2, )), - (3, (1, 1), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (1, 1), (0, 0)), - (3, (3, 3), (32, 64), (2, 2), (0, 0)), + 'kernel_sizes, hidden_channels, strides, paddings', [ + ((1, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (1, ), (0, )), + ((3, ), (32, ), (2, ), (0, )), + ((5, ), (12, ), (1, ), (2, )), + ((1, 1), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (1, 1), (0, 0)), + ((3, 3), (32, 64), (2, 2), (0, 0)), ]) - def test_obs_unflattened(self, action_dim, kernel_sizes, hidden_channels, - strides, paddings): + def test_obs_unflattened(self, kernel_sizes, hidden_channels, 
strides, + paddings): """Test if a flattened image obs is passed to get_action then it is unflattened. """ - batch_size = 64 - input_width = 32 - input_height = 32 - in_channel = 3 - input_shape = (batch_size, in_channel, input_height, input_width) - env = GymEnv( - DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim)) - env = self._initialize_obs_env(env) - + env = GymEnv(DummyDiscretePixelEnv()) env.reset() policy = DiscreteCNNPolicy(env_spec=env.spec, + image_format='NHWC', hidden_channels=hidden_channels, - hidden_sizes=hidden_channels, kernel_sizes=kernel_sizes, strides=strides, paddings=paddings, @@ -185,4 +136,4 @@ def test_obs_unflattened(self, action_dim, kernel_sizes, hidden_channels, obs = env.observation_space.sample() action, _ = policy.get_action(env.observation_space.flatten(obs)) - env.step(action) + env.step(action[0]) diff --git a/tests/garage/torch/q_functions/test_discrete_cnn_q_function.py b/tests/garage/torch/q_functions/test_discrete_cnn_q_function.py index 667aae734d..29d9584ac8 100644 --- a/tests/garage/torch/q_functions/test_discrete_cnn_q_function.py +++ b/tests/garage/torch/q_functions/test_discrete_cnn_q_function.py @@ -25,6 +25,7 @@ def test_forward(batch_size, hidden_channels, kernel_sizes, strides): obs = torch.zeros((batch_size, ) + obs_dim, dtype=torch.float32) qf = DiscreteCNNQFunction(env_spec=env_spec, + image_format='NCHW', kernel_sizes=kernel_sizes, strides=strides, mlp_hidden_nonlinearity=None, @@ -32,8 +33,7 @@ def test_forward(batch_size, hidden_channels, kernel_sizes, strides): hidden_channels=hidden_channels, hidden_sizes=hidden_channels, hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) output = qf(obs) expected_output = torch.zeros(output.shape) @@ -54,6 +54,7 @@ def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides): obs = torch.ones((batch_size, ) + obs_dim, dtype=torch.float32) qf = DiscreteCNNQFunction(env_spec=env_spec, + image_format='NCHW', kernel_sizes=kernel_sizes, strides=strides, mlp_hidden_nonlinearity=None, @@ -61,8 +62,7 @@ def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides): hidden_channels=hidden_channels, hidden_sizes=hidden_channels, hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) output1 = qf(obs) p = pickle.dumps(qf) diff --git a/tests/garage/torch/q_functions/test_discrete_dueling_cnn_q_function.py b/tests/garage/torch/q_functions/test_discrete_dueling_cnn_q_function.py index f3bc334374..6e7131690a 100644 --- a/tests/garage/torch/q_functions/test_discrete_dueling_cnn_q_function.py +++ b/tests/garage/torch/q_functions/test_discrete_dueling_cnn_q_function.py @@ -25,6 +25,7 @@ def test_forward(batch_size, hidden_channels, kernel_sizes, strides): obs = torch.zeros((batch_size, ) + obs_dim, dtype=torch.float32) qf = DiscreteDuelingCNNQFunction(env_spec=env_spec, + image_format='NCHW', kernel_sizes=kernel_sizes, strides=strides, mlp_hidden_nonlinearity=None, @@ -32,8 +33,7 @@ def test_forward(batch_size, hidden_channels, kernel_sizes, strides): hidden_channels=hidden_channels, hidden_sizes=hidden_channels, hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) output = qf(obs) expected_output = torch.zeros(output.shape) @@ -54,6 +54,7 @@ def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides): obs = torch.ones((batch_size, ) + obs_dim, dtype=torch.float32) qf = 
DiscreteDuelingCNNQFunction(env_spec=env_spec, + image_format='NCHW', kernel_sizes=kernel_sizes, strides=strides, mlp_hidden_nonlinearity=None, @@ -61,8 +62,7 @@ def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides): hidden_channels=hidden_channels, hidden_sizes=hidden_channels, hidden_w_init=nn.init.ones_, - output_w_init=nn.init.ones_, - is_image=False) + output_w_init=nn.init.ones_) output1 = qf(obs) p = pickle.dumps(qf) diff --git a/tests/garage/torch/test_functions.py b/tests/garage/torch/test_functions.py index 8e12f31fad..f41dadcdda 100644 --- a/tests/garage/torch/test_functions.py +++ b/tests/garage/torch/test_functions.py @@ -5,19 +5,12 @@ import torch import torch.nn.functional as F -from garage.torch import (compute_advantages, - dict_np_to_torch, - flatten_to_single_vector, - global_device, - pad_to_last, - product_of_gaussians, - set_gpu_mode, - torch_to_np, - TransposeImage) +from garage.torch import (as_torch_dict, compute_advantages, + flatten_to_single_vector, global_device, pad_to_last, + product_of_gaussians, set_gpu_mode, torch_to_np) import garage.torch._functions as tu from tests.fixtures import TfGraphTestCase -from tests.fixtures.envs.dummy import DummyDiscretePixelEnv # yapf: enable @@ -59,10 +52,10 @@ def test_torch_to_np(): assert isinstance(np_out_2, np.ndarray) -def test_dict_np_to_torch(): +def test_as_torch_dict(): """Test if dict whose values are tensors can be converted to np arrays.""" dic = {'a': np.zeros(1), 'b': np.ones(1)} - dict_np_to_torch(dic) + as_torch_dict(dic) for tensor in dic.values(): assert isinstance(tensor, torch.Tensor) @@ -87,16 +80,6 @@ def test_flatten_to_single_vector(): assert expected.shape == flatten_tensor.shape -def test_transpose_image(): - """Test TransposeImage.""" - original_env = DummyDiscretePixelEnv() - obs_shape = original_env.observation_space.shape - if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: - transposed_env = TransposeImage(original_env) - assert (original_env.observation_space.shape[2] == - transposed_env.observation_space.shape[0]) - - class TestTorchAlgoUtils(TfGraphTestCase): """Test class for torch algo utility functions.""" # yapf: disable @@ -194,3 +177,20 @@ def test_out_of_index_error(self, nums): """Test pad_to_last raises IndexError.""" with pytest.raises(IndexError): pad_to_last(nums, total_length=10, axis=len(nums.shape)) + + +def test_expand_var(): + with pytest.raises(ValueError, match='test_var is length 2'): + tu.expand_var('test_var', (1, 2), 3, 'reference_var') + + +def test_value_at_axis(): + assert tu._value_at_axis('test_value', 0) == 'test_value' + assert tu._value_at_axis('test_value', 1) == 'test_value' + assert tu._value_at_axis(['a', 'b', 'c'], 0) == 'a' + assert tu._value_at_axis(['a', 'b', 'c'], 1) == 'b' + assert tu._value_at_axis(['a', 'b', 'c'], 2) == 'c' + assert tu._value_at_axis(('a', 'b', 'c'), 0) == 'a' + assert tu._value_at_axis(('a', 'b', 'c'), 1) == 'b' + assert tu._value_at_axis(('a', 'b', 'c'), 2) == 'c' + assert tu._value_at_axis(['test_value'], 3) == 'test_value'
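To close, a short sketch of the renamed conversion helpers exercised above (`np_to_torch` and `dict_np_to_torch` became `as_torch` and `as_torch_dict`); the array contents are illustrative:

```python
import numpy as np
import torch

from garage.torch import as_torch, as_torch_dict

# as_torch returns a float tensor on the global device; the reshape
# mirrors the pattern used in test_dqn_loss above.
rewards = as_torch(np.ones(4)).reshape(-1, 1)
assert rewards.dtype == torch.float32

# as_torch_dict converts the dict's values in place and returns the dict.
batch = as_torch_dict({'observations': np.zeros((4, 3)),
                       'actions': np.ones(4)})
assert all(isinstance(v, torch.Tensor) for v in batch.values())
```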