Checkpointer.save, can global_step become a kwarg? #818

Open
tiamilani opened this issue Jan 19, 2023 · 0 comments

@tiamilani (Contributor)

Hi, I'm not an expert in TF-Agents; I started learning this library quite recently, so I may be overlooking some implementation detail.

I noticed that the save() method of the utils.common.Checkpointer class requires global_step as a positional argument with no default value.
The global_step argument is then forwarded as checkpoint_number to the save() method of the _manager object (a tf.train.CheckpointManager).
The code I'm referring to is:

  def save(self, global_step: tf.Tensor,
           options: tf.train.CheckpointOptions = None):
    """Save state to checkpoint."""
    saved_checkpoint = self._manager.save(
        checkpoint_number=global_step, options=options)
    self._checkpoint_exists = True
    logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))
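For reference, the numbering rule that the tf.train.CheckpointManager.save documentation describes can be modeled in a few lines of plain Python. This is a sketch, not TensorFlow code: ManagerNumberingModel is a hypothetical name, plain ints stand in for tensors, and the "ckpt" prefix assumes the manager's default checkpoint_name. The point it illustrates is that the save counter is bumped on every save and is used as the checkpoint number only when checkpoint_number is None.

```python
from typing import Optional


class ManagerNumberingModel:
    """Plain-Python model (hypothetical, not TensorFlow) of the numbering rule
    documented for tf.train.CheckpointManager.save: the save counter is
    incremented on every save, and it is used as the checkpoint number only
    when checkpoint_number is None."""

    def __init__(self, directory: str, checkpoint_name: str = "ckpt") -> None:
        self._prefix = directory.rstrip("/") + "/" + checkpoint_name
        self.save_counter = 0  # plays the role of checkpoint.save_counter

    def save(self, checkpoint_number: Optional[int] = None) -> str:
        # The counter advances even when an explicit number is passed.
        self.save_counter += 1
        number = self.save_counter if checkpoint_number is None else int(checkpoint_number)
        return "%s-%d" % (self._prefix, number)


manager = ManagerNumberingModel("/tmp/train")
print(manager.save())                       # /tmp/train/ckpt-1
print(manager.save(checkpoint_number=500))  # /tmp/train/ckpt-500
print(manager.save())                       # /tmp/train/ckpt-3
```

The third call shows that an explicit number does not stop the counter from advancing, so the next default-numbered checkpoint is ckpt-3, not ckpt-2.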

Given that CheckpointManager.save() also accepts None for the checkpoint_number kwarg (its default), wouldn't it be more correct to implement the save method as follows?

  def save(self, global_step: Optional[tf.Tensor] = None,
           options: tf.train.CheckpointOptions = None):
    """Save state to checkpoint."""
    saved_checkpoint = self._manager.save(
        checkpoint_number=global_step, options=options)
    self._checkpoint_exists = True
    logging.info('%s', 'Saved checkpoint: {}'.format(saved_checkpoint))

With this change, users could also rely directly on the checkpoint's save_counter, which CheckpointManager uses to number checkpoints when checkpoint_number is None.
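A minimal sketch of the proposed behavior, runnable without TensorFlow. Checkpointer here is a trimmed illustration of the proposal (not the TF-Agents class), and RecordingManager is a hypothetical stand-in for tf.train.CheckpointManager that records each checkpoint_number it receives and, when that is None, numbers the checkpoint with its own counter; plain ints stand in for tf.Tensor steps.

```python
import logging
from typing import Any, List, Optional


class RecordingManager:
    """Hypothetical stand-in for tf.train.CheckpointManager (no TensorFlow).

    Records every requested checkpoint_number, bumps its counter on each save,
    and uses that counter as the number when checkpoint_number is None.
    """

    def __init__(self) -> None:
        self.save_counter = 0
        self.calls: List[Optional[int]] = []
        self.saved: List[str] = []

    def save(self, checkpoint_number: Optional[int] = None, options: Any = None) -> str:
        self.calls.append(checkpoint_number)
        self.save_counter += 1
        number = self.save_counter if checkpoint_number is None else int(checkpoint_number)
        name = "ckpt-%d" % number
        self.saved.append(name)
        return name


class Checkpointer:
    """Trimmed illustration of the proposed signature (not the TF-Agents class)."""

    def __init__(self, manager: Any) -> None:
        self._manager = manager
        self._checkpoint_exists = False

    def save(self, global_step: Optional[int] = None, options: Any = None) -> None:
        # global_step is now optional; None is forwarded unchanged, so the
        # manager falls back to its own save counter.
        saved_checkpoint = self._manager.save(checkpoint_number=global_step, options=options)
        self._checkpoint_exists = True
        logging.info("Saved checkpoint: %s", saved_checkpoint)


manager = RecordingManager()
checkpointer = Checkpointer(manager)
checkpointer.save()                 # numbered by the manager's counter
checkpointer.save()
checkpointer.save(global_step=100)  # an explicit step still works as before
print(manager.calls)  # [None, None, 100]
print(manager.saved)  # ['ckpt-1', 'ckpt-2', 'ckpt-100']
```

Because the default is None rather than a new sentinel, existing callers that pass global_step keep the same behavior, and only callers that omit it get the counter-based numbering.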

Am I missing a reason why global_step must be specified, instead of allowing the default CheckpointManager numbering (based on save_counter) to be used?

If there is positive feedback on this change, I can also submit a pull request :)
Thanks in advance for your help!

tiamilani changed the title from "Checkpointer.save can global_step become a kwargs?" to "Checkpointer.save, can global_step become a kwargs" on Jan 19, 2023
tiamilani changed the title from "Checkpointer.save, can global_step become a kwargs" to "Checkpointer.save, can global_step become a kwarg?" on Jan 19, 2023