Error when calling get_initial_state on loaded policies saved with PolicySaver when batch_size != None #660

Open
hornstenfilip opened this issue Sep 23, 2021 · 0 comments

hornstenfilip commented Sep 23, 2021

  • OS Platform and Distribution: Google Colab
  • TensorFlow installed from: Binary
  • TensorFlow version: 2.6.0
  • TF Agents version: 0.9.0
  • Python version: Python 3.7.12
  • Installed using virtualenv? pip? conda?: conda/pip

When I call get_initial_state on a loaded SavedModelPyTFEagerPolicy that was saved with PolicySaver and a non-None batch_size argument, e.g.

policy_saver.PolicySaver(tf_agent.policy, batch_size=1, train_step=tf_agent.train_step_counter)

I encounter the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-44-cbee9415af2c> in <module>()
----> 1 py_eval_policy.get_initial_state(batch_size=1)

7 frames
/usr/local/lib/python3.7/dist-packages/tf_agents/policies/py_policy.py in get_initial_state(self, batch_size)
    137       An initial policy state.
    138     """
--> 139     return self._get_initial_state(batch_size)
    140 
    141   def action(self,

/usr/local/lib/python3.7/dist-packages/tf_agents/policies/py_tf_eager_policy.py in _get_initial_state(self, batch_size)
     90     if batch_size is None:
     91       batch_size = 0
---> 92     return self._policy.get_initial_state(batch_size=batch_size)
     93 
     94   def _action(self, time_step, policy_state, seed: Optional[types.Seed] = None):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3414     if self.input_signature is None or args is not None or kwargs is not None:
   3415       args, kwargs, flat_args, filtered_flat_args = \
-> 3416           self._function_spec.canonicalize_function_inputs(*args, **kwargs)
   3417     else:
   3418       flat_args, filtered_flat_args = [None], []

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in canonicalize_function_inputs(self, *args, **kwargs)
   2761         if len(missing_args) == 1:
   2762           raise TypeError("{} missing 1 required argument: {}".format(
-> 2763               self.signature_summary(), missing_args[0]))
   2764         else:
   2765           raise TypeError("{} missing required arguments: {}".format(

TypeError: f(self) missing 1 required argument: self

I have reproduced the issue on Google Colab by modifying the official REINFORCE agent tutorial notebook to use an ActorDistributionRnnNetwork instead of an ActorDistributionNetwork and saving the model with policy_saver.PolicySaver(tf_agent.policy, batch_size=1, train_step=tf_agent.train_step_counter). You can find it here. The same error is produced when calling get_initial_state on the pure TF policy loaded by tf.saved_model.load(policy_dir), so the issue seems to be strictly related to saving the model, not loading it.
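
For concreteness, here is a minimal sketch of the save/load/repro flow, assuming tf_agent is the trained REINFORCE agent from the tutorial and policy_dir is a writable export directory (both names are placeholders):

import tensorflow as tf
from tf_agents.policies import policy_saver, py_tf_eager_policy

# Save the RNN policy with an explicit batch size.
saver = policy_saver.PolicySaver(
    tf_agent.policy, batch_size=1, train_step=tf_agent.train_step_counter)
saver.save(policy_dir)

# Loading succeeds, but get_initial_state raises the TypeError above.
py_eval_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, load_specs_from_pbtxt=True)
py_eval_policy.get_initial_state(batch_size=1)

# The raw SavedModel fails the same way.
tf_policy = tf.saved_model.load(policy_dir)
tf_policy.get_initial_state(batch_size=1)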

Setting batch_size=None when saving is a temporary workaround in the CartPole case, since get_initial_state can then be called on the loaded policy with an arbitrary batch_size. However, my custom environment is batched, so the model I have built always expects batched observations; setting batch_size=None is not an option for me, as it causes an error when saving the model. I suspect the error has something to do with lines 255-256 of policy_saver.py

get_initial_state_fn = functools.partial(
    policy.get_initial_state, batch_size=batch_size)

since this is the branch that is executed when batch_size is not None. The loaded model otherwise works fine; for example, calling action succeeds if I manually pass an initial state built with tf.zeros of the appropriate shape, as sketched below.
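A sketch of that manual workaround, assuming the default single LSTM cell of the tutorial network (lstm_size and eval_env are placeholders that must match the actual network and environment):

import tensorflow as tf

# Build the zero initial RNN state by hand instead of calling
# get_initial_state. The nest structure must match the policy's
# policy_state_spec; a single LSTM cell carries a (hidden, cell) pair.
batch_size = 1
lstm_size = 40  # placeholder: use the size the network was built with
manual_state = (tf.zeros([batch_size, lstm_size], tf.float32),
                tf.zeros([batch_size, lstm_size], tf.float32))

time_step = eval_env.reset()  # eval_env is a placeholder batched environment
action_step = py_eval_policy.action(time_step, manual_state)  # works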
