From a88bd83ebd6d7b0b5c7575984e0892932ff67e2e Mon Sep 17 00:00:00 2001 From: Roman Vaxenburg Date: Wed, 15 May 2024 19:30:12 -0400 Subject: [PATCH] Fix types and add callback. --- flybody/agents/actors.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/flybody/agents/actors.py b/flybody/agents/actors.py index 522a375..f329e33 100755 --- a/flybody/agents/actors.py +++ b/flybody/agents/actors.py @@ -1,6 +1,6 @@ """Acme agent implementations.""" -from typing import Optional +from typing import Callable from acme import adders from acme import core @@ -27,9 +27,10 @@ class DelayedFeedForwardActor(core.Actor): def __init__( self, policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - action_delay: Optional[int] = None, + adder: adders.Adder | None = None, + variable_client: tf2_variable_utils.VariableClient | None = None, + action_delay: int | None = None, + observation_callback: Callable | None = None, ): """Initializes the actor. @@ -40,6 +41,8 @@ def __init__( variable_client: object which allows to copy weights from the learner copy of the policy to the actor copy (in case they are separate). action_delay: number of timesteps to delay the action for. + observation_callback: Optional callable to process observations before + passing them to policy. """ # Store these for later use. @@ -49,9 +52,10 @@ def __init__( self._action_delay = action_delay if action_delay is not None: self._action_queue = [] + self._observation_callback = observation_callback @tf.function - def _policy(self, observation: types.NestedTensor) -> types.NestedTensor: + def _policy(self, observation: types.NestedArray) -> types.NestedTensor: # Add a dummy batch dimension and as a side effect convert numpy to TF. batched_observation = tf2_utils.add_batch_dim(observation) @@ -66,6 +70,9 @@ def _policy(self, observation: types.NestedTensor) -> types.NestedTensor: def select_action(self, observation: types.NestedArray) -> types.NestedArray: + """Samples from the policy and returns an action.""" + if self._observation_callback is not None: + observation = self._observation_callback(observation) # Pass the observation through the policy network. action = self._policy(observation)