diff --git a/tf_agents/bandits/networks/global_and_arm_feature_network.py b/tf_agents/bandits/networks/global_and_arm_feature_network.py index 907b82343..ecfaac9e8 100644 --- a/tf_agents/bandits/networks/global_and_arm_feature_network.py +++ b/tf_agents/bandits/networks/global_and_arm_feature_network.py @@ -67,6 +67,9 @@ def create_feed_forward_common_tower_network( activation_fn: Callable[ [types.Tensor], types.Tensor ] = tf.keras.activations.relu, + final_activation_fn_for_q_network: ( + Callable[[types.Tensor], types.Tensor] | None + ) = None, name: Optional[str] = None, ) -> types.Network: """Creates a common tower network with feedforward towers. @@ -92,6 +95,9 @@ def create_feed_forward_common_tower_network( arm_preprocessing_combiner: Preprocessing combiner for the arm features. activation_fn: A keras activation, specifying the activation function used in all layers. Defaults to relu. + final_activation_fn_for_q_network: Only used when `output_dim=1`, a Callable + specifying the activation function for the final layer of the common + tower. Defaults to None, which means no activation function. name: The network name to use. Shows up in Tensorboard losses. Returns: @@ -145,6 +151,7 @@ def create_feed_forward_common_tower_network( ), fc_layer_params=common_layers, activation_fn=activation_fn, + q_layer_activation_fn=final_activation_fn_for_q_network, ) else: common_network = encoding_network.EncodingNetwork( diff --git a/tf_agents/bandits/networks/global_and_arm_feature_network_test.py b/tf_agents/bandits/networks/global_and_arm_feature_network_test.py index c5b87d9f1..284240459 100644 --- a/tf_agents/bandits/networks/global_and_arm_feature_network_test.py +++ b/tf_agents/bandits/networks/global_and_arm_feature_network_test.py @@ -69,6 +69,29 @@ def testCreateFeedForwardCommonTowerNetwork( output = self.evaluate(output) self.assertAllEqual(output.shape, (batch_size, num_actions)) + @parameters + def testCreateFeedForwardCommonTowerNetworkWithFinalActivation( + self, batch_size, feature_dim, num_actions + ): + obs_spec = bandit_spec_utils.create_per_arm_observation_spec( + 7, feature_dim, num_actions + ) + net = gafn.create_feed_forward_common_tower_network( + obs_spec, + global_layers=(4, 3, 2), + arm_layers=(6, 5, 4), + common_layers=(7, 6, 5), + final_activation_fn_for_q_network=tf.exp, + ) + input_nest = tensor_spec.sample_spec_nest( + obs_spec, outer_dims=(batch_size,) + ) + output, _ = net(input_nest) + self.evaluate(tf.compat.v1.global_variables_initializer()) + output = self.evaluate(output) + self.assertAllEqual(output.shape, (batch_size, num_actions)) + self.assertAllGreaterEqual(output, 0.0) + @parameters def testCreateFeedForwardCommonTowerNetworkWithEmptyLayers( self, batch_size, feature_dim, num_actions