Showing 2 changed files with 264 additions and 0 deletions.

First file (+184 lines): the emote.algorithms.amp module.

from typing import Callable

import torch

from torch import Tensor, nn

from emote import Callback
from emote.callbacks.logging import LoggingMixin
from emote.callbacks.loss import LossCallback


def gradient_loss_function(model_output: Tensor, model_input: Tensor) -> Tensor:
    """Given the inputs and outputs of an nn.Module, computes the mean (over
    the batch) of the sum of squared partial derivatives of the outputs with
    respect to the inputs.

    Arguments:
        model_output (Tensor): the output of the nn.Module
        model_input (Tensor): the input to the nn.Module
    Returns:
        loss (Tensor): the gradient penalty, a scalar tensor
    """
    # grad can be implicitly created only for scalar outputs, so split the
    # batched output into per-sample (single-element) predictions.
    predictions = torch.split(model_output, 1, dim=0)
    inputs_grad = torch.autograd.grad(
        predictions, model_input, create_graph=True, retain_graph=True
    )
    inputs_grad = torch.cat(inputs_grad, dim=1)
    inputs_grad_norm = torch.square(inputs_grad)
    inputs_grad_norm = torch.sum(inputs_grad_norm, dim=1)
    return torch.mean(inputs_grad_norm)
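

# Illustrative note (added, not part of the commit): for a row-wise map
# y_i = sum_j x_ij**2, the per-sample gradient is dy_i/dx_ij = 2 * x_ij, so
# gradient_loss_function(y, x) returns mean_i sum_j (2 * x_ij)**2.
# test_gradient_loss in the accompanying test file checks this identity for
# y = sum(4 * x * x + sin(x)).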


class DiscriminatorLoss(LossCallback):
    """Trains a discriminator to separate reference-animation transitions
    (positive samples) from policy transitions (negative samples) for
    adversarial training."""

    def __init__(
        self,
        discriminator: nn.Module,
        imitation_state_map_fn: Callable[[Tensor], Tensor],
        policy_state_map_fn: Callable[[Tensor], Tensor],
        grad_loss_weight: float,
        optimizer: torch.optim.Optimizer,
        lr_schedule: torch.optim.lr_scheduler._LRScheduler,
        max_grad_norm: float,
        input_key: str = "features",
        name: str = "Discriminator",
    ):
        super().__init__(
            lr_schedule=lr_schedule,
            name=name,
            network=discriminator,
            optimizer=optimizer,
            max_grad_norm=max_grad_norm,
            data_group=None,
        )
        self._discriminator = discriminator
        self._imitation_state_map_function = imitation_state_map_fn
        self._policy_state_map_function = policy_state_map_fn
        self._grad_loss_weight = grad_loss_weight
        self._obs_key = input_key

    def loss(self, imitation_batch: dict, policy_batch: dict) -> Tensor:
        """Computes the discriminator loss.

        Arguments:
            imitation_batch (dict): a batch of data from the reference
                animation; the discriminator is trained to classify data from
                this batch as positive samples
            policy_batch (dict): a batch of data from the RL buffer; the
                discriminator is trained to classify data from this batch as
                negative samples
        Returns:
            loss (Tensor): the loss tensor
        """
        imitation_batch_size = imitation_batch["batch_size"]
        policy_data_batch_size = policy_batch["batch_size"]

        pos_obs: Tensor = imitation_batch["observation"][self._obs_key]
        pos_next_obs: Tensor = imitation_batch["next_observation"][self._obs_key]
        neg_obs: Tensor = policy_batch["observation"][self._obs_key]
        neg_next_obs: Tensor = policy_batch["next_observation"][self._obs_key]

        # The discriminator scores state transitions, i.e. concatenated
        # (state, next_state) pairs.
        pos_obs = self._imitation_state_map_function(pos_obs)
        pos_next_obs = self._imitation_state_map_function(pos_next_obs)
        pos_input = torch.cat([pos_obs, pos_next_obs], dim=-1)
        # Gradients w.r.t. the positive inputs are needed for the penalty term.
        pos_input.requires_grad_(True)

        neg_obs = self._policy_state_map_function(neg_obs)
        neg_next_obs = self._policy_state_map_function(neg_next_obs)
        neg_input = torch.cat([neg_obs, neg_next_obs], dim=-1)

        pos_output = self._discriminator(pos_input)
        neg_output = self._discriminator(neg_input)
        assert pos_output.shape == (imitation_batch_size, 1)
        assert neg_output.shape == (policy_data_batch_size, 1)

        # Least-squares targets: positive samples should be labeled +1 ...
        pos_loss = torch.mean(torch.square(pos_output - 1.0))
        # ... and negative samples should be labeled -1.
        neg_loss = torch.mean(torch.square(neg_output + 1.0))

        grad_penalty_loss = self._grad_loss_weight * gradient_loss_function(
            pos_output, pos_input
        )
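
        # Combined objective (added comment, summarizing the three terms):
        #   L = E_pos[(D(s, s') - 1)^2] + E_neg[(D(s, s') + 1)^2]
        #       + w_grad * E_pos[ || dD/d(s, s') ||^2 ]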
        loss = pos_loss + neg_loss + grad_penalty_loss

        self.log_scalar("amp/loss/pos_discrimination_loss", pos_loss)
        self.log_scalar("amp/loss/neg_discrimination_loss", neg_loss)
        self.log_scalar("amp/loss/grad_loss", grad_penalty_loss)
        self.log_scalar("amp/loss/total", loss)
        self.log_scalar("amp/predict/positive_samples_mean", torch.mean(pos_output))
        self.log_scalar("amp/predict/positive_samples_std", torch.std(pos_output))
        self.log_scalar("amp/predict/negative_samples_mean", torch.mean(neg_output))
        self.log_scalar("amp/predict/negative_samples_std", torch.std(neg_output))

        return loss


class AMPReward(Callback, LoggingMixin):
    """Adversarial reward augmentation with AMP (Adversarial Motion Priors)."""

    def __init__(
        self,
        discriminator: nn.Module,
        state_map_fn: Callable[[Tensor], Tensor],
        style_reward_weight: float,
        rollout_length: int,
        observation_key: str,
        data_group: str,
    ):
        super().__init__()
        self._discriminator = discriminator
        self._order = 0
        self.data_group = data_group
        self._style_reward_weight = style_reward_weight
        self._state_map_function = state_map_fn
        self._rollout_length = rollout_length
        self._obs_key = observation_key

    def begin_batch(
        self, observation: dict[str, Tensor], next_observation: dict[str, Tensor], rewards: Tensor
    ):
        """Updates the reward by adding the weighted AMP style reward.

        Arguments:
            observation: current observation
            next_observation: next observation
            rewards: task reward
        Returns:
            dict: the batch data with the updated reward
        """
        obs = observation[self._obs_key]
        bsz = obs.shape[0]
        # The batch is laid out as contiguous rollouts; regroup it so each
        # observation can be paired with its successor within the rollout.
        rollouts = obs.reshape(bsz // self._rollout_length, self._rollout_length, -1)
        next_obs = next_observation[self._obs_key]

        # Append the post-rollout observation, then shift by one step to get
        # the next observation for every step of every rollout.
        next_obs = next_obs.unsqueeze(dim=1)
        combined_obs = torch.cat((rollouts, next_obs), dim=1)
        next_obs = combined_obs[:, 1:]

        next_obs = next_obs.reshape(bsz, -1)

        state = self._state_map_function(obs)
        next_state = self._state_map_function(next_obs)

        consecutive_states = torch.cat([state, next_state], dim=-1)

        predictions = self._discriminator(consecutive_states).detach()

        # AMP style reward: highest when the discriminator scores the
        # transition close to +1 (expert-like).
        style_reward = 1.0 - 0.25 * (predictions - 1.0) ** 2.0
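        # Worked example (added, illustrative): discriminator outputs of -1, 0,
        # and +1 map to style rewards of 0.0, 0.75, and 1.0 respectively, so
        # transitions judged expert-like earn the highest reward.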
        scaled_style_reward = self._style_reward_weight * style_reward
        assert scaled_style_reward.shape == rewards.shape

        total_reward = rewards + scaled_style_reward

        self.log_scalar("amp/unscaled_style_reward", torch.mean(style_reward))
        self.log_scalar("amp/task_reward", torch.mean(rewards))
        self.log_scalar("amp/scaled_style_reward", torch.mean(scaled_style_reward))
        self.log_scalar("amp/total_reward", torch.mean(total_reward))
        self.log_scalar("amp/predicts_mean", torch.mean(predictions))
        self.log_scalar("amp/predicts_std", torch.std(predictions))

        return {self.data_group: {"rewards": total_reward}}
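
Both callbacks are meant to share one discriminator instance: DiscriminatorLoss trains it, while AMPReward only reads detached predictions from it. A minimal wiring sketch follows (hypothetical, not part of the commit; the Discriminator MLP and the "features" key mirror the test file below, and the "default" data group name is an assumption):

discriminator = Discriminator(input_size=20, hidden_dims=[128, 128])
disc_loss = DiscriminatorLoss(
    discriminator=discriminator,
    imitation_state_map_fn=lambda s: s,  # identity; real setups extract AMP features here
    policy_state_map_fn=lambda s: s,
    grad_loss_weight=1.0,
    optimizer=torch.optim.Adam(discriminator.parameters(), lr=1e-3),
    lr_schedule=None,
    max_grad_norm=10.0,
)
amp_reward = AMPReward(
    discriminator=discriminator,  # same instance the loss trains
    state_map_fn=lambda s: s,
    style_reward_weight=1.0,
    rollout_length=1,
    observation_key="features",
    data_group="default",
)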

Second file (+80 lines): tests for the AMP module.

import torch

from torch import Tensor, nn

from emote.algorithms.amp import AMPReward, DiscriminatorLoss, gradient_loss_function


class Discriminator(nn.Module):
    def __init__(self, input_size: int, hidden_dims: list[int]):
        super().__init__()
        self.encoder = nn.Sequential(
            *[
                nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
                for n_in, n_out in zip([input_size] + hidden_dims, hidden_dims)
            ],
        )
        final_layers: list[nn.Module] = [nn.Linear(hidden_dims[-1], 1)]
        self.final_layer = nn.Sequential(*final_layers)

    def forward(self, x: Tensor):
        return self.final_layer(self.encoder(x))


def state_map_fn(obs: Tensor):
    # Identity mapping; stands in for a task-specific feature extractor.
    return obs


def test_gradient_loss():
    x = torch.ones(10, 3, requires_grad=True)
    x = x * torch.rand(10, 3)
    y = torch.sum(4 * x * x + torch.sin(x), dim=1)

    grad1 = gradient_loss_function(y, x)
    # Analytic check: d/dx (4x^2 + sin(x)) = 8x + cos(x).
    y_dot = 8 * x + torch.cos(x)
    grad2 = torch.mean(torch.sum(y_dot * y_dot, dim=1))

    assert abs(grad1.item() - grad2.item()) < 0.001


def test_discriminator_loss():
    discriminator = Discriminator(input_size=20, hidden_dims=[128, 128])
    discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=0.001)
    loss = DiscriminatorLoss(
        discriminator=discriminator,
        imitation_state_map_fn=state_map_fn,
        policy_state_map_fn=state_map_fn,
        grad_loss_weight=1,
        optimizer=discriminator_opt,
        lr_schedule=None,
        input_key="features",
        max_grad_norm=10.0,
    )
    batch1 = {
        "batch_size": 30,
        "observation": {"features": torch.rand(30, 10)},
        "next_observation": {"features": torch.rand(30, 10)},
    }
    batch2 = {
        "batch_size": 30,
        "observation": {"features": torch.rand(30, 10)},
        "next_observation": {"features": torch.rand(30, 10)},
    }
    assert loss.loss(batch1, batch2) >= 0


def test_amp_reward():
    discriminator = Discriminator(input_size=20, hidden_dims=[128, 128])
    amp_reward = AMPReward(
        discriminator=discriminator,
        state_map_fn=state_map_fn,
        style_reward_weight=1.0,
        rollout_length=1,
        observation_key="features",
        data_group=None,
    )
    observation = {"features": torch.rand(30, 10)}
    next_observation = {"features": torch.rand(30, 10)}
    reward = torch.rand(30, 1)
    amp_reward.begin_batch(observation, next_observation, reward)
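    # Hypothetical extension (added, not part of the commit): the style reward
    # is bounded above by 1.0, so with style_reward_weight=1.0 the augmented
    # reward can exceed the task reward by at most 1.0.
    result = amp_reward.begin_batch(observation, next_observation, reward)
    total_reward = result[None]["rewards"]  # keyed by data_group, which is None here
    assert total_reward.shape == reward.shape
    assert torch.all(total_reward <= reward + 1.0)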