Hi, I am using `CategoricalDqnAgent` with multiple GPUs. This is my code block:
```python
with strategy.scope():
    collect_op_episode.run()
    experience, buffer_info = tf_agents.utils.eager_utils.get_next(iterator)
    train_loss = agent.train(experience)
```
And I got the following error:
```
RuntimeError                              Traceback (most recent call last)
File <timed exec>:30

File ~/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/tf_agent.py:330, in TFAgent.train(self, experience, weights, **kwargs)
    325   raise RuntimeError(
    326       "Cannot find _train_fn. Did %s.__init__ call super?"
    327       % type(self).__name__)
    329 if self._enable_functions:
--> 330   loss_info = self._train_fn(
    331       experience=experience, weights=weights, **kwargs)
    332 else:
    333   loss_info = self._train(experience=experience, weights=weights, **kwargs)

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:188, in function_in_tf1.<locals>.maybe_wrap.<locals>.with_check_resource_vars(*fn_args, **fn_kwargs)
    184 check_tf1_allowed()
    185 if has_eager_been_enabled():
    186   # We're either in eager mode or in tf.function mode (no in-between); so
    187   # autodep-like behavior is already expected of fn.
--> 188   return fn(*fn_args, **fn_kwargs)
    189 if not resource_variables_enabled():
    190   raise RuntimeError(MISSING_RESOURCE_VARIABLES_ERROR)

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/dqn/dqn_agent.py:393, in DqnAgent._train(self, experience, weights)
    391 def _train(self, experience, weights):
    392   with tf.GradientTape() as tape:
--> 393     loss_info = self._loss(
    394         experience,
    395         td_errors_loss_fn=self._td_errors_loss_fn,
    396         gamma=self._gamma,
    397         reward_scale_factor=self._reward_scale_factor,
    398         weights=weights,
    399         training=True)
    400   tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
    401   variables_to_train = self._q_network.trainable_weights

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/categorical_dqn/categorical_dqn_agent.py:436, in CategoricalDqnAgent._loss(self, experience, td_errors_loss_fn, gamma, reward_scale_factor, weights, training)
    431 else:
    432   critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
    433       labels=target_distribution,
    434       logits=chosen_action_logits)
--> 436 agg_loss = common.aggregate_losses(
    437     per_example_loss=critic_loss,
    438     regularization_loss=self._q_network.losses)
    439 total_loss = agg_loss.total_loss
    441 dict_losses = {'critic_loss': agg_loss.weighted,
    442                'reg_loss': agg_loss.regularization,
    443                'total_loss': total_loss}

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:1376, in aggregate_losses(per_example_loss, sample_weight, global_batch_size, regularization_loss)
   1372   per_example_loss = tf.reduce_mean(per_example_loss, range(1, loss_rank))
   1374 global_batch_size = global_batch_size and tf.cast(global_batch_size,
   1375                                                   per_example_loss.dtype)
-> 1376 weighted_loss = tf.nn.compute_average_loss(
   1377     per_example_loss,
   1378     global_batch_size=global_batch_size)
   1379 total_loss = weighted_loss
   1380 # Add scaled regularization losses.

RuntimeError: You are calling `compute_average_loss` in cross replica context, while it was expected to be called in replica context.
```
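The last frame of the traceback shows where the error comes from: `CategoricalDqnAgent._loss` calls `common.aggregate_losses` without a `global_batch_size`, so (assuming its default is `None`) `tf.nn.compute_average_loss` has to infer the global batch size from the number of replicas, and it refuses to do that in cross-replica context. `with strategy.scope():` on its own does not enter replica context; `strategy.run(...)` does. The minimal sketch below reproduces both cases with plain TensorFlow 2.x, assuming `tf.distribute.MirroredStrategy` (which falls back to a single CPU replica when no GPU is visible). The helper names are illustrative, not part of TF-Agents.

```python
import tensorflow as tf

# Falls back to a single CPU replica when no GPU is visible.
strategy = tf.distribute.MirroredStrategy()
per_example_loss = tf.constant([1.0, 2.0, 3.0])


def call_under_scope_only():
    # strategy.scope() controls variable placement, but code here still runs
    # in cross-replica context, which is what agent.train hit above.
    with strategy.scope():
        return tf.nn.compute_average_loss(per_example_loss)


def replica_fn(loss):
    # Inside strategy.run the code runs in replica context, so no global
    # batch size has to be passed explicitly.
    return tf.nn.compute_average_loss(loss)


@tf.function
def average_in_replica_context(loss):
    per_replica = strategy.run(replica_fn, args=(loss,))
    # Each replica receives the full tensor here and divides its sum by
    # (per-replica size * num_replicas), so a SUM over replicas is the mean.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)


try:
    call_under_scope_only()
    cross_replica_error = None
except RuntimeError as err:
    cross_replica_error = str(err)

print("scope only   ->", cross_replica_error)
print("strategy.run ->", float(average_in_replica_context(per_example_loss)))
```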
It seems that the loss computation in `CategoricalDqnAgent` doesn't support multi-GPU training with a distribution strategy?
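A commonly used pattern for custom training loops with `tf.distribute` (an assumption here, not something confirmed in this thread for `CategoricalDqnAgent`) is to keep `strategy.scope()` for creating the model and its variables, but to invoke the training step through `strategy.run` and feed it from `strategy.experimental_distribute_dataset`. The sketch below shows that pattern on a hypothetical toy linear model: `toy_train` stands in for `agent.train(experience)`, and `w`, `GLOBAL_BATCH_SIZE` and `LEARNING_RATE` are illustrative names, not TF-Agents API.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8   # hypothetical; assumed to divide evenly across replicas
LEARNING_RATE = 0.1     # hypothetical

with strategy.scope():
    # Hypothetical stand-in for the agent's trainable weights. Gradients of
    # compute_average_loss are already divided by the global batch size, so
    # SUM-aggregating the per-replica updates gives one global-batch step.
    w = tf.Variable(tf.zeros([4]), aggregation=tf.VariableAggregation.SUM)


def toy_train(features, targets):
    """Plays the role of agent.train(experience); runs in replica context."""
    with tf.GradientTape() as tape:
        preds = tf.linalg.matvec(features, w)
        per_example_loss = tf.square(preds - targets)
        # The call that failed in the traceback; allowed in replica context.
        loss = tf.nn.compute_average_loss(per_example_loss)
    grad = tape.gradient(loss, w)
    w.assign_sub(LEARNING_RATE * grad)
    return loss


@tf.function
def distributed_train_step(features, targets):
    # Analogue of strategy.run(agent.train, args=(experience,)).
    per_replica_loss = strategy.run(toy_train, args=(features, targets))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)


# Toy, noise-free linear data; each global batch is split across replicas.
features = tf.random.stateless_normal((64, 4), seed=(1, 2))
targets = tf.linalg.matvec(features, tf.constant([1.0, -2.0, 0.5, 3.0]))
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

losses = []
for _ in range(20):
    for x, y in dist_dataset:
        losses.append(float(distributed_train_step(x, y)))

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

Applied to the code block at the top, this would roughly mean calling `strategy.run(agent.train, args=(experience,))` inside a `tf.function` instead of `agent.train(experience)` directly, with `experience` coming from a distributed dataset. Whether the agent's optimizer and target-network updates then behave correctly under `MirroredStrategy` would still need to be verified.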
Best Regards, Jack Lu