Add per example reward loss in LossInfo extra
PiperOrigin-RevId: 705472790
Change-Id: I6ef2f76efd06edae950a1754b5ad0fe2773b455a
TF-Agents Team authored and copybara-github committed Dec 12, 2024
1 parent ee4d6fe commit 528cef7
Showing 2 changed files with 33 additions and 36 deletions.
36 changes: 19 additions & 17 deletions tf_agents/bandits/agents/greedy_multi_objective_neural_agent.py
@@ -314,14 +314,14 @@ def _single_objective_loss(
sample_weights = (
sample_weights * 0.5 * tf.exp(-action_predicted_log_variance)
)
loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
loss = 0.5 * action_predicted_log_variance
# loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
# Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
# Bayesian Deep Learning for Computer Vision?." Advances in Neural
# Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
else:
predicted_values, _ = objective_network(observations, training=training)
loss = tf.constant(0.0)
loss = tf.zeros_like(single_objective_values)

action_predicted_values = common.index_with_actions(
predicted_values, tf.cast(actions, dtype=tf.int32)
@@ -338,18 +338,13 @@ def _single_objective_loss(
objective_idx
] * tf.reduce_mean(smoothness_batched * sample_weights)

# Reduction is done outside of the loss function because non-scalar
# weights with unknown shapes may trigger shape validation that fails
# XLA compilation.
loss += tf.reduce_mean(
tf.multiply(
self._error_loss_fns[objective_idx](
single_objective_values,
action_predicted_values,
reduction=tf.compat.v1.losses.Reduction.NONE,
),
sample_weights,
)
)
loss += tf.multiply(
self._error_loss_fns[objective_idx](
single_objective_values,
action_predicted_values,
reduction=tf.compat.v1.losses.Reduction.NONE,
),
sample_weights,
)

return loss
@@ -401,10 +396,10 @@ def _loss(
)
)

objective_losses = []
per_example_objective_losses = []
for idx in range(self._num_objectives):
single_objective_values = objective_values[:, idx]
objective_losses.append(
per_example_objective_losses.append(
self._single_objective_loss(
idx,
observations,
@@ -414,10 +409,17 @@ def _loss(
training,
)
)
per_example_loss = tf.reduce_sum(
tf.stack(per_example_objective_losses), axis=0
)
objective_losses = [
tf.reduce_mean(per_example_objective_losses[idx])
for idx in range(self._num_objectives)
]

self.compute_summaries(objective_losses)
total_loss = tf.reduce_sum(objective_losses)
return tf_agent.LossInfo(total_loss, extra=())
return tf_agent.LossInfo(total_loss, extra=per_example_loss)

def compute_summaries(self, losses: Sequence[tf.Tensor]):
if self._num_objectives != len(losses):
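For context (not part of the commit itself): each _single_objective_loss now returns a per-example loss of shape [batch_size] instead of a scalar, and _loss stacks those tensors, sums them over objectives for LossInfo.extra, and keeps reporting the mean-based scalar total_loss. A minimal sketch of the shape arithmetic, assuming a batch of 3 examples, 2 objectives, and made-up loss values:

import tensorflow as tf

# Hypothetical per-example losses for 2 objectives over a batch of 3 examples,
# mirroring what _single_objective_loss now returns (shape [batch_size]).
per_example_objective_losses = [
    tf.constant([0.2, 0.4, 0.6]),  # objective 0
    tf.constant([1.0, 2.0, 3.0]),  # objective 1
]

# Summed over objectives; this is what ends up in LossInfo.extra.
per_example_loss = tf.reduce_sum(
    tf.stack(per_example_objective_losses), axis=0
)
# per_example_loss == [1.2, 2.4, 3.6]

# Scalar per-objective means still feed the summaries and the total loss.
objective_losses = [tf.reduce_mean(l) for l in per_example_objective_losses]
total_loss = tf.reduce_sum(objective_losses)
# total_loss == 0.4 + 2.0 == 2.4

Because the mean is linear, taking tf.reduce_mean of each per-example term after the fact yields the same scalar loss as the previous reduce_mean-inside-the-loss version.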
33 changes: 14 additions & 19 deletions tf_agents/bandits/agents/greedy_reward_prediction_agent.py
@@ -287,7 +287,7 @@ def _train(self, experience, weights):

return loss_info

def reward_loss(
def per_example_reward_loss(
self,
observations: types.NestedTensor,
actions: types.Tensor,
@@ -325,7 +325,7 @@ def reward_loss(
sample_weights * 0.5 * tf.exp(-action_predicted_log_variance)
)

loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
loss = 0.5 * action_predicted_log_variance
# loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
# Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
# Bayesian Deep Learning for Computer Vision?." Advances in Neural
@@ -334,7 +334,7 @@ def reward_loss(
predicted_values, _ = self._reward_network(
observations, training=training
)
loss = tf.constant(0.0)
loss = tf.zeros_like(rewards)

action_predicted_values = common.index_with_actions(
predicted_values, tf.cast(actions, dtype=tf.int32)
@@ -348,24 +348,18 @@ def reward_loss(
self._laplacian_matrix, predicted_values, transpose_b=True
),
)
loss += self._laplacian_smoothing_weight * tf.reduce_mean(
loss += self._laplacian_smoothing_weight * (
tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights
)

# Reduction is done outside of the loss function because non-scalar
# weights with unknown shapes may trigger shape validation that fails
# XLA compilation.
loss += tf.reduce_mean(
tf.multiply(
self._error_loss_fn(
rewards,
action_predicted_values,
reduction=tf.compat.v1.losses.Reduction.NONE,
),
sample_weights,
)
)
loss += tf.multiply(
self._error_loss_fn(
rewards,
action_predicted_values,
reduction=tf.compat.v1.losses.Reduction.NONE,
),
sample_weights,
)

return loss

def _loss(
@@ -404,9 +398,10 @@ def _loss(
rewards_tensor = rewards[bandit_spec_utils.REWARD_SPEC_KEY]
else:
rewards_tensor = rewards
reward_loss = self.reward_loss(
per_example_reward_loss = self.per_example_reward_loss(
observations, actions, rewards_tensor, weights, training
)
reward_loss = tf.reduce_mean(per_example_reward_loss)

constraint_loss = tf.constant(0.0)
for i, c in enumerate(self._constraints, 0):
@@ -424,7 +419,7 @@ def _loss(
total_loss = reward_loss
if self._constraints:
total_loss += constraint_loss
return tf_agent.LossInfo(total_loss, extra=())
return tf_agent.LossInfo(total_loss, extra=per_example_reward_loss)

def compute_summaries(
self, loss: types.Tensor, constraint_loss: Optional[types.Tensor] = None
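As a usage sketch (an assumption, not shown in this commit), a caller could now read the per-example reward loss from the extra field of the LossInfo returned by the agent's train call. The names agent and experience below stand in for an existing GreedyRewardPredictionAgent and a matching batch of trajectories; the scalar loss itself is unchanged.

import tensorflow as tf

# `agent` and `experience` are assumed to already exist; this only shows
# where the new per-example loss surfaces.
loss_info = agent.train(experience)
total_loss = loss_info.loss         # scalar training loss, as before
per_example_loss = loss_info.extra  # per-example reward loss, shape [batch_size]

# Example use: indices of the hardest examples in the batch.
hardest = tf.argsort(per_example_loss, direction='DESCENDING')[:10]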
