diff --git a/flybody/agents/utils_tf.py b/flybody/agents/utils_tf.py index 3736707..11066a3 100755 --- a/flybody/agents/utils_tf.py +++ b/flybody/agents/utils_tf.py @@ -1,5 +1,7 @@ """Utilities for tensorflow networks and nested data structures.""" +import numpy as np +from acme import types from acme.tf import utils as tf2_utils @@ -17,7 +19,9 @@ def __init__(self, policy, sample=False): self._policy = policy self._sample = sample - def __call__(self, observation): + def __call__(self, observation: types.NestedArray) -> np.ndarray: + # Add a dummy batch dimension and as a side effect convert numpy to TF, + # batched_observation: types.NestedTensor. batched_observation = tf2_utils.add_batch_dim(observation) distribution = self._policy(batched_observation) if self._sample: