CategoricalDqnAgent Distributional Training Loss Function error #816

jacklu333333 opened this issue Jan 7, 2023 · 0 comments

Hi,
I am using CategoricalDqnAgent with multiple GPUs. This is my code block:

with strategy.scope():

    collect_op_episode.run()

    experience, buffer_info = tf_agents.utils.eager_utils.get_next(iterator)
    train_loss = agent.train(experience)

I got the following error:

RuntimeError                              Traceback (most recent call last)
File <timed exec>:30

File ~/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/tf_agent.py:330, in TFAgent.train(self, experience, weights, **kwargs)
    325   raise RuntimeError(
    326       "Cannot find _train_fn.  Did %s.__init__ call super?"
    327       % type(self).__name__)
    329 if self._enable_functions:
--> 330   loss_info = self._train_fn(
    331       experience=experience, weights=weights, **kwargs)
    332 else:
    333   loss_info = self._train(experience=experience, weights=weights, **kwargs)

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:188, in function_in_tf1.<locals>.maybe_wrap.<locals>.with_check_resource_vars(*fn_args, **fn_kwargs)
    184 check_tf1_allowed()
    185 if has_eager_been_enabled():
    186   # We're either in eager mode or in tf.function mode (no in-between); so
    187   # autodep-like behavior is already expected of fn.
--> 188   return fn(*fn_args, **fn_kwargs)
    189 if not resource_variables_enabled():
    190   raise RuntimeError(MISSING_RESOURCE_VARIABLES_ERROR)

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/dqn/dqn_agent.py:393, in DqnAgent._train(self, experience, weights)
    391 def _train(self, experience, weights):
    392   with tf.GradientTape() as tape:
--> 393     loss_info = self._loss(
    394         experience,
    395         td_errors_loss_fn=self._td_errors_loss_fn,
    396         gamma=self._gamma,
    397         reward_scale_factor=self._reward_scale_factor,
    398         weights=weights,
    399         training=True)
    400   tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
    401   variables_to_train = self._q_network.trainable_weights

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/categorical_dqn/categorical_dqn_agent.py:436, in CategoricalDqnAgent._loss(self, experience, td_errors_loss_fn, gamma, reward_scale_factor, weights, training)
    431 else:
    432   critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
    433       labels=target_distribution,
    434       logits=chosen_action_logits)
--> 436 agg_loss = common.aggregate_losses(
    437     per_example_loss=critic_loss,
    438     regularization_loss=self._q_network.losses)
    439 total_loss = agg_loss.total_loss
    441 dict_losses = {'critic_loss': agg_loss.weighted,
    442                'reg_loss': agg_loss.regularization,
    443                'total_loss': total_loss}

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:1376, in aggregate_losses(per_example_loss, sample_weight, global_batch_size, regularization_loss)
   1372     per_example_loss = tf.reduce_mean(per_example_loss, range(1, loss_rank))
   1374   global_batch_size = global_batch_size and tf.cast(global_batch_size,
   1375                                                     per_example_loss.dtype)
-> 1376   weighted_loss = tf.nn.compute_average_loss(
   1377       per_example_loss,
   1378       global_batch_size=global_batch_size)
   1379   total_loss = weighted_loss
   1380 # Add scaled regularization losses.

RuntimeError: You are calling `compute_average_loss` in cross replica context, while it was expected to be called in replica context.

It seems that the loss computed by CategoricalDqnAgent does not support training under a distribution strategy with multiple GPUs. Is that the case?

Best Regards,
Jack Lu
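
Note on the error message: `tf.nn.compute_average_loss` raises this RuntimeError when it runs under a strategy but outside a per-replica function and no `global_batch_size` is passed, which matches the last frame of the traceback. The sketch below reproduces that with a toy loss tensor and then makes the same call inside `strategy.run`, where it succeeds. It is an illustration only, not a confirmed fix for this issue: `replica_step` and `per_example_loss` are made-up names, and whether wrapping the agent call the same way (for example `strategy.run(agent.train, args=(experience,))`) resolves the problem here is an assumption.

```python
import tensorflow as tf

# MirroredStrategy uses all visible GPUs; with none it runs one CPU replica.
strategy = tf.distribute.MirroredStrategy()

# Toy per-example loss (hypothetical data, stands in for the critic loss).
per_example_loss = tf.constant([1.0, 2.0, 3.0, 4.0])


def replica_step(loss):
    # Runs in replica context, so the global batch size can be inferred
    # as per-replica batch size times the number of replicas.
    return tf.nn.compute_average_loss(loss)


# Case 1: inside strategy.scope() but not inside strategy.run().
# This is cross-replica context, so the call is rejected.
raised = False
with strategy.scope():
    try:
        tf.nn.compute_average_loss(per_example_loss)
    except RuntimeError as err:
        raised = True
        print("cross-replica call failed:", err)

# Case 2: the same computation dispatched to each replica.
per_replica = strategy.run(replica_step, args=(per_example_loss,))

# Summing the per-replica results gives the loss averaged over the
# global batch.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print("aggregated loss:", float(total))
```

In this sketch every replica receives the same four-element tensor, so the summed result equals the plain mean of the toy losses regardless of how many devices are visible.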
