Installed using virtualenv? pip? conda?: conda/pip
When I try to call get_initial_state on a loaded SavedModelPyTFEagerPolicy that was saved with PolicySaver and a batch_size argument that is not None, I encounter the error below.
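For reference, the loading and the failing call look roughly like this (a minimal sketch; policy_dir, eval_env, and the constructor arguments are assumptions from my notebook, not the exact code):

import tensorflow as tf
from tf_agents.policies import py_tf_eager_policy

# Load the policy exported by PolicySaver (policy_dir is a placeholder path;
# eval_env is the evaluation environment from the tutorial setup).
py_eval_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_env.time_step_spec(), eval_env.action_spec())

# Fails with the TypeError in the traceback below when the policy
# was saved with batch_size=1.
py_eval_policy.get_initial_state(batch_size=1)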
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-44-cbee9415af2c> in <module>()
----> 1 py_eval_policy.get_initial_state(batch_size=1)
7 frames
/usr/local/lib/python3.7/dist-packages/tf_agents/policies/py_policy.py in get_initial_state(self, batch_size)
137 An initial policy state.
138 """
--> 139 return self._get_initial_state(batch_size)
140
141 def action(self,
/usr/local/lib/python3.7/dist-packages/tf_agents/policies/py_tf_eager_policy.py in _get_initial_state(self, batch_size)
90 if batch_size is None:
91 batch_size = 0
---> 92 return self._policy.get_initial_state(batch_size=batch_size)
93
94 def _action(self, time_step, policy_state, seed: Optional[types.Seed] = None):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
883
884 with OptionalXlaContext(self._jit_compile):
--> 885 result = self._call(*args, **kwds)
886
887 new_tracing_count = self.experimental_get_tracing_count()
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
931 # This is the first call of __call__, so we have to initialize.
932 initializers = []
--> 933 self._initialize(args, kwds, add_initializers_to=initializers)
934 finally:
935 # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
758 self._concrete_stateful_fn = (
759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 760 *args, **kwds))
761
762 def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3064 args, kwargs = None, None
3065 with self._lock:
-> 3066 graph_function, _ = self._maybe_define_function(args, kwargs)
3067 return graph_function
3068
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3414 if self.input_signature is None or args is not None or kwargs is not None:
3415 args, kwargs, flat_args, filtered_flat_args = \
-> 3416 self._function_spec.canonicalize_function_inputs(*args, **kwargs)
3417 else:
3418 flat_args, filtered_flat_args = [None], []
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in canonicalize_function_inputs(self, *args, **kwargs)
2761 if len(missing_args) == 1:
2762 raise TypeError("{} missing 1 required argument: {}".format(
-> 2763 self.signature_summary(), missing_args[0]))
2764 else:
2765 raise TypeError("{} missing required arguments: {}".format(
TypeError: f(self) missing 1 required argument: self
I have reproduced the issue on Google Colab by modifying the official REINFORCE agent tutorial notebook to use an ActorDistributionRnnNetwork instead of an ActorDistributionNetwork and saving the model with policy_saver.PolicySaver(tf_agent.policy, batch_size=1, train_step=tf_agent.train_step_counter). You can find it here. The same error is produced when trying to call get_initial_state on the pure TF policy loaded by tf.saved_model.load(policy_dir), so the issue seems to be strictly related to the saving of the model, not the loading.
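The saving and loading steps in the notebook are essentially the following (a sketch; policy_dir is a placeholder path and tf_agent comes from the tutorial setup):

import tensorflow as tf
from tf_agents.policies import policy_saver

# Save the RNN policy with an explicit batch size, as in the modified tutorial.
saver = policy_saver.PolicySaver(
    tf_agent.policy, batch_size=1, train_step=tf_agent.train_step_counter)
saver.save(policy_dir)

# Loading the raw SavedModel and calling get_initial_state hits the same TypeError.
loaded_policy = tf.saved_model.load(policy_dir)
loaded_policy.get_initial_state(batch_size=1)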
Setting batch_size=None during saving seems to be a temporary workaround in the CartPole case, since it is then possible to call get_initial_state on the loaded policy with an arbitrary batch_size. However, I have a custom environment which is batched, so the model I have built always expects batched observations; setting batch_size=None is therefore not an option for me, since it causes an error when saving the model. I suspect the error has something to do with lines 255-256 of policy_saver.py, since that is the block that gets executed when batch_size is not None. The loaded model is otherwise working fine, e.g. calling action works if I manually pass the initial state using tf.zeros with the appropriate shape.
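The workaround for action looks roughly like this (a sketch; the state nesting and the (1, 40) shapes are assumptions matching a single LSTM layer of size 40 with batch size 1, and time_step comes from my environment):

import tensorflow as tf

# Build the recurrent policy state by hand instead of calling get_initial_state.
# The nesting and shapes are assumptions; they must match the policy's
# policy_state_spec for the ActorDistributionRnnNetwork being used.
manual_state = [tf.zeros((1, 40), dtype=tf.float32),
                tf.zeros((1, 40), dtype=tf.float32)]

# Calling action on the loaded policy works with the hand-built state.
action_step = loaded_policy.action(time_step, manual_state)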