MaskedCategorical warns about inheriting _parameter_properties from its parent #658

Open
emma-resor opened this issue Sep 20, 2021 · 5 comments

Comments

@emma-resor

I am trying to use an action mask with a DQNAgent and, for the most part, have succeeded. But I keep getting a warning that I can't get rid of. The setup below is a convoluted way to trigger it, but it is how I hit it.

Versions:

tensorflow 2.6.0
tensorflow-datasets 4.2.0
tensorflow-estimator 2.6.0
tensorflow-metadata 0.21.0
tensorflow-probability 0.14.0

Original Example (for a more contrived example, see below)

My setup:

QRnnNetwork, DQNAgent, custom environment

Observation spec defined as follows:

self._observation_spec = {
    'observations': array_spec.BoundedArraySpec(
        name='observations',
        shape=(len(day_df.columns),),
        dtype=np.float32,
        minimum=0,
        maximum=1),
    'valid_actions': array_spec.BoundedArraySpec(
        name='valid_actions',
        shape=(2,),
        dtype=np.int32,
        minimum=0,
        maximum=1),
}

I create the network using the network-relevant portion of the observation:

q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec()['observations'], # input tensor spec
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)

The DQNAgent receives a splitter that separates the network input from the action mask:

def observation_and_action_splitter_fn(observation):
    return observation['observations'], observation['valid_actions']

tf_agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    epsilon_greedy=epsilon_greedy,
    observation_and_action_constraint_splitter=observation_and_action_splitter_fn,
    n_step_update=n_step_update,
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    reward_scale_factor=reward_scale_factor,
    gradient_clipping=gradient_clipping,
    debug_summaries=debug_summaries,
    summarize_grads_and_vars=summarize_grads_and_vars,
    train_step_counter=global_step)
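With this splitter, the Q-network only ever sees `observation['observations']`; as we understand it, the mask is applied afterwards, to the network's output. The numpy sketch below illustrates that masking step for greedy action selection. It assumes invalid actions are handled by overwriting their Q-values with a large negative number before the argmax; the constant -1e10 and the helper `masked_greedy_action` are illustrative assumptions, not tf_agents code.

```python
import numpy as np


def observation_and_action_splitter_fn(observation):
    # Same splitter as above: (network input, action mask).
    return observation['observations'], observation['valid_actions']


def masked_greedy_action(q_values, valid_actions, neg_inf=-1e10):
    """Hypothetical helper: argmax of Q over valid actions only.

    q_values: float array [batch, num_actions].
    valid_actions: 0/1 array of the same shape; nonzero means allowed.
    """
    mask = np.asarray(valid_actions).astype(bool)
    # Invalid actions get a huge negative value so argmax never picks them.
    masked_q = np.where(mask, q_values, neg_inf)
    return np.argmax(masked_q, axis=-1)


obs = {
    'observations': np.zeros((2, 3), dtype=np.float32),
    'valid_actions': np.array([[1, 0], [0, 1]], dtype=np.int32),
}
net_input, valid = observation_and_action_splitter_fn(obs)
q_values = np.array([[0.1, 0.9], [0.8, 0.2]], dtype=np.float32)  # pretend network output
# Without the mask the argmax would be [1, 0]; with it, only valid actions win.
print(masked_greedy_action(q_values, valid))  # [0 1]
```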

My Problem:

The warning I see (with great frequency) is the following:

WARNING:root:
Distribution subclass MaskedCategorical inherits `_parameter_properties from its parent (Categorical)
while also redefining `__init__`. The inherited annotations cover the following
parameters: dict_keys(['logits', 'probs']). It is likely that these do not match the subclass parameters.
This may lead to errors when computing batch shapes, slicing into batch
dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor
(e.g., when it is passed or returned from a `tf.function`), and possibly other
cases. The recommended pattern for distribution subclasses is to define a new
`_parameter_properties` method with the subclass parameters, and to store the
corresponding parameter values as `self._parameters` in `__init__`, after
calling the superclass constructor:


class MySubclass(tfd.SomeDistribution):

  def __init__(self, param_a, param_b):
    parameters = dict(locals())
    # ... do subclass initialization ...
    super(MySubclass, self).__init__(**base_class_params)
    # Ensure that the subclass (not base class) parameters are stored.
    self._parameters = parameters

  def _parameter_properties(self, dtype, num_classes=None):
    return dict(
      # Annotations may optionally specify properties, such as `event_ndims`,
      # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see
      # the `ParameterProperties` documentation for details.
      param_a=tfp.util.ParameterProperties(),
      param_b=tfp.util.ParameterProperties())

I traced this warning to https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/distributions/distribution.py

I believe the warning points to a deficiency in how a subclass (here, MaskedCategorical) inherits from TFP's Distribution class. In my case, it appears whenever MaskedCategorical is used, which happens because I pass an action constraint splitter fn to the agent.
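To make the condition in the warning concrete, here is a plain-Python sketch. `ToyDistribution`, `BadMasked`, `GoodMasked`, and `inherits_stale_parameter_properties` are all hypothetical names; the check is a simplified reading of the warning text, not TFP's actual implementation. `BadMasked` mirrors the reported situation (new `__init__`, inherited annotations), while `GoodMasked` follows the pattern the warning recommends.

```python
class ToyDistribution:
    """Hypothetical stand-in for a distribution base class (not TFP code)."""

    def __init__(self, logits=None, probs=None):
        # The base class records only the parameters it knows about.
        self._parameters = dict(logits=logits, probs=probs)

    @classmethod
    def _parameter_properties(cls, dtype=None, num_classes=None):
        # Annotations cover the base-class parameters only.
        return dict(logits=None, probs=None)


def inherits_stale_parameter_properties(cls) -> bool:
    """Simplified version of the condition the warning describes: the class
    defines its own __init__ but not its own _parameter_properties."""
    return "__init__" in vars(cls) and "_parameter_properties" not in vars(cls)


class BadMasked(ToyDistribution):
    """New __init__, inherited annotations: the case that triggers the warning."""

    def __init__(self, logits, mask):
        self._mask = mask
        super().__init__(logits=logits)  # _parameters ends up as logits/probs


class GoodMasked(ToyDistribution):
    """Follows the pattern recommended in the warning text."""

    def __init__(self, logits, mask):
        # The warning's template uses dict(locals()); an explicit dict keeps
        # `self` out of the stored parameters.
        parameters = dict(logits=logits, mask=mask)
        self._mask = mask
        super().__init__(logits=logits)
        # Store the subclass parameters *after* the base constructor runs.
        self._parameters = parameters

    @classmethod
    def _parameter_properties(cls, dtype=None, num_classes=None):
        return dict(logits=None, mask=None)
```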

Contrived Example

Here is a more contrived example where I hit this same warning:

import numpy as np
import tensorflow as tf
from tf_agents.distributions import masked

mask = np.array([1, 1])
zero_logits = tf.zeros(mask.shape, dtype=tf.float32)
masked_categorical = masked.MaskedCategorical(zero_logits, mask)
masked_categorical.sample()
Running this prints the same WARNING:root: message quoted above, twice in a row.
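For intuition about what the contrived example samples from, here is a numpy sketch of the class probabilities a masked categorical produces. It assumes masking works by replacing the logits of masked-out classes with a large negative constant (-1e10 is chosen for illustration) and then applying a softmax; `masked_categorical_probs` is a hypothetical helper, not the tf_agents implementation.

```python
import numpy as np


def masked_categorical_probs(logits, mask, neg_inf=-1e10):
    """Class probabilities after masking (numpy sketch, not tf_agents code).

    Entries where `mask` is zero get logit `neg_inf`, so after the softmax
    their probability underflows to zero. If every entry is masked, all
    logits are equal and the result is uniform.
    """
    logits = np.asarray(logits, dtype=np.float64)
    keep = np.asarray(mask).astype(bool)
    masked = np.where(keep, logits, neg_inf)
    # Numerically stable softmax along the class axis.
    shifted = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


# Zero logits with mask [1, 1], as in the example: uniform over both classes.
print(masked_categorical_probs(np.zeros(2), np.array([1, 1])).tolist())  # [0.5, 0.5]
# Masking out class 1 puts all the probability on class 0.
print(masked_categorical_probs(np.zeros(2), np.array([1, 0])).tolist())  # [1.0, 0.0]
```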

Thanks

emma-resor changed the title from "_DistributionMeta warns about inheriting _parameter_properties from its parent" to "MaskedCategorical warns about inheriting _parameter_properties from its parent" on Sep 20, 2021.
@emma-resor
Author

I know MaskedCategorical needs to define a _parameter_properties method to get rid of the warning, and I think this should do it:

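from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import prefer_static as ps

# Inside class MaskedCategorical(tfp.distributions.Categorical):
@classmethod
def _parameter_properties(cls, dtype, num_classes=None):
  # pylint: disable=g-long-lambda
  return dict(
      logits=parameter_properties.ParameterProperties(
          event_ndims=1,
          shape_fn=lambda sample_shape: ps.concat(
              [sample_shape, [num_classes]], axis=0)))
  # pylint: enable=g-long-lambda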

I arrived at this by mimicking the Categorical class's _parameter_properties method and removing probs (because MaskedCategorical requires probs to be None). I'd appreciate advice on whether this is correct.
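In the proposed method, `event_ndims=1` marks the last axis of `logits` (the class axis) as the event dimension, and `shape_fn` builds a `logits` shape by appending `num_classes` to the shape passed in as `sample_shape`. A numpy analogue of that shape arithmetic, with `np.concatenate` standing in for `ps.concat`; `logits_shape_fn` is a hypothetical name and this is a sketch for intuition only, not TFP code:

```python
import numpy as np


def logits_shape_fn(sample_shape, num_classes):
    """numpy analogue of the proposed shape_fn (illustration only).

    Returns the same leading dimensions as `sample_shape` plus a trailing
    axis of size `num_classes`.
    """
    leading = np.asarray(sample_shape, dtype=np.int64)
    trailing = np.asarray([num_classes], dtype=np.int64)
    return np.concatenate([leading, trailing], axis=0)


print(logits_shape_fn([3, 4], 5).tolist())  # [3, 4, 5]
print(logits_shape_fn([], 2).tolist())      # [2]
```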

@o0RUI0o

o0RUI0o commented Sep 22, 2021

I have the same problem. How can I turn off this warning? I have tried os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2", but it didn't work.

@tsimons89

I am getting the same warning. I just installed tf_agents today. It would be nice to get an answer to this.

@wangdong-ivymobile

I got the same warning.

@wangdong-ivymobile

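I found that tensorflow-probability logs this warning through Python's root logger (hence the WARNING:root: prefix). TF_CPP_MIN_LOG_LEVEL only controls TensorFlow's C++ logging, which is likely why setting it has no effect here.
Calling logging.disable(logging.WARNING) does suppress the warning, although it also silences every other WARNING-level (and lower) message in the process.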
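A narrower alternative, if the goal is to hide only this one message, is a `logging.Filter` on the root logger that drops records containing the warning's text and lets everything else through. This is a sketch: the class name and the matched substring are our own choices, and it assumes the message keeps the wording quoted in this issue and is logged directly on the root logger (as the WARNING:root: prefix suggests; logger-level filters do not see records propagated from child loggers).

```python
import logging


class DropParameterPropertiesWarning(logging.Filter):
    """Hypothetical filter: drops the TFP subclass warning, keeps the rest."""

    # Substring taken from the warning text quoted in this issue.
    MARKER = "inherits `_parameter_properties"

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False tells logging to discard the record.
        return self.MARKER not in record.getMessage()


# Attach to the root logger, which is where the warning is emitted.
logging.getLogger().addFilter(DropParameterPropertiesWarning())
```

Unlike `logging.disable(logging.WARNING)`, this leaves all other warnings visible.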

4 participants