Add MultiCategoricalProjectionNetwork #705
Conversation
@sidney-tio Thank you for this PR. It could also solve #702 for me. Hope that this will be merged soon.
Hi team! Could I request a review for this, please?
    return distribs

  def _mode(self):
    return self._flatten_and_concat_event(
Why does `mode` need `flatten_and_concat_event` but `sample` doesn't?
`Blockwise` from tf-probability doesn't implement `mode`, but implements `flatten_and_concat_event` for `sample` and `mean`.
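A minimal sketch of the pattern being described, assuming tfp's `Blockwise` exposes `_flatten_and_concat_event` and the `distributions` property as in the version this PR targets (the subclass name here is hypothetical, not the PR's class):

```python
import tensorflow_probability as tfp

class ModeBlockwise(tfp.distributions.Blockwise):
  """Hypothetical Blockwise subclass that adds a mode."""

  def _mode(self, **kwargs):
    # Blockwise has no _mode of its own, so compute the mode of each
    # component distribution and pack the results into one flat vector with
    # the same helper Blockwise uses internally for sample() and mean().
    modes = [d.mode() for d in self.distributions]
    return self._flatten_and_concat_event(modes)
```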
  def __init__(self, logits, categories_shape):
    self.categories_shape = categories_shape
    distribs = self._create_distrib(logits)
Can you validate that the dimensions of logits align with categories_shape?
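Something along these lines could work; this is only a sketch under the assumption that `categories_shape` holds the number of categories per sub-action, so the last logits dimension should equal their sum (the helper name is hypothetical):

```python
def _validate_logits(logits, categories_shape):
  # Hypothetical check: the flat logits vector should have one entry per
  # category across all sub-actions.
  expected = sum(categories_shape)
  if logits.shape[-1] is not None and logits.shape[-1] != expected:
    raise ValueError(
        'Last dimension of logits (%s) does not match the total number of '
        'categories implied by categories_shape (%d).'
        % (logits.shape[-1], expected))
```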
  Args:
    sample_spec: A collection of `tensor_spec.BoundedTensorSpec` detailing
      the shape and dtypes of samples pulled from the output distribution.
    logits_init_output_factor: Output factor for initializing kernal
It's not clear what the output factor means here.
kernal -> kernel
Fixed the kernel spelling error. The description for `logits_init_output_factor` was inherited from `CategoricalProjectionNetwork`, but I do agree it is not exactly clear what output factor means here.
    self._projection_layer = tf.keras.layers.Dense(
        self.n_unique_categories,
        kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
Why use VarianceScaling as the initializer? Can you use just tf.keras.initializers?
Similar to the output factor, the initializer was carried over from `CategoricalProjectionNetwork`; I'm not sure whether there's any particular reason VarianceScaling is used there:

agents/tf_agents/networks/categorical_projection_network.py, lines 81 to 82 in 5360685:

    kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
        scale=logits_init_output_factor),
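For what it's worth, a sketch of the reviewer's suggestion (an assumption, not the merged code): the same layer written with the non-compat `tf.keras.initializers.VarianceScaling`, keeping the `scale=logits_init_output_factor` argument. Defaults may differ slightly between the v1 and v2 initializers, so this would need checking.

```python
import tensorflow as tf

# Placeholders for values the class already holds (hypothetical here).
n_unique_categories = 7
logits_init_output_factor = 0.1

# Hypothetical rewrite of the projection layer using the v2 initializer.
projection_layer = tf.keras.layers.Dense(
    n_unique_categories,
    kernel_initializer=tf.keras.initializers.VarianceScaling(
        scale=logits_init_output_factor),
    name='logits')
```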
  def _categories_shape(self, sample_spec):
    def _get_n_categories(array_spec):
      if not tensor_spec.is_bounded(array_spec):
Simplify to:

    if tensor_spec.is_bounded(array_spec) and tensor_spec.is_discrete(array_spec):
      n_categories = array_spec.maximum - array_spec.minimum + 1
      return n_categories
    else:
      raise ValueError(
          'sample_spec must be discrete and bounded. Got: %s.' % array_spec)
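As a quick usage check of the suggested logic (my own example, not from the PR): a discrete spec bounded in [0, 4] yields 4 - 0 + 1 = 5 categories.

```python
import tensorflow as tf
from tf_agents.specs import tensor_spec

spec = tensor_spec.BoundedTensorSpec([], tf.int32, minimum=0, maximum=4)
if tensor_spec.is_bounded(spec) and tensor_spec.is_discrete(spec):
  n_categories = spec.maximum - spec.minimum + 1
  print(int(n_categories))  # 5
```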
    logits = tf.reshape(logits, [-1] + [self.n_unique_categories])
    logits = batch_squash.unflatten(logits)
    if mask is not None:
      # assume mask is a flattened array for now
I suppose mask should have the same shape as actions.
I'm not sure about the most appropriate approach for this, so any advice here would be good. At this stage the outputs of the network exist as logits, i.e. the vector is flattened, so in this case `mask` needs to be flattened as well. Alternatively, if we want the `mask` to have the same shape as the actions, we could do the masking during init of `MultiCategoricalDistributionBlock`.
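For illustration only, a minimal sketch of the alternative mentioned above, assuming the caller provides one mask per sub-action in the same order as the logits blocks (the helper name is hypothetical):

```python
import tensorflow as tf

def flatten_mask(masks):
  # `masks` is a list of [batch, n_categories_i] tensors, one per sub-action.
  # Concatenating along the last axis lines them up with the flattened
  # logits vector that the network produces.
  return tf.concat(masks, axis=-1)
```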
      # Overwrite the logits for invalid actions to a very large negative
      # number. We do not use -inf because it produces NaNs in many tfp
      # functions.
      almost_neg_inf = tf.constant(logits.dtype.min, dtype = logits.dtype)
Make sure lines are <80 characters and parameters don't have extra spaces:

    almost_neg_inf = tf.constant(logits.dtype.min,
                                 dtype=logits.dtype)
      logits = tf.compat.v2.where(
          tf.cast(mask, tf.bool), logits, almost_neg_inf)

    return self.output_spec.build_distribution(logits= logits), ()
(logits= logits) -> (logits=logits)
    output_spec = [
        tensor_spec.BoundedTensorSpec([], tf.int32, 0, 1),
        tensor_spec.BoundedTensorSpec([], tf.int32, 0, 4)]
    network = multi_categorical_projection_network.MultiCategoricalProjectionNetwork(
The name is too long; can you make it shorter, e.g. multi_categorical_network.MultiCategoricalNetwork?
    self.assertEqual(tfp.distributions.Categorical, type(distribution.distributions[0]))
    self.assertEqual(2, len(distribution.distributions))
    self.assertEqual((3, 7), distribution._parameters['logits'].shape)
    self.assertEqual((3, 2), sample.shape)
Can you also test that the samples respect the bounds?
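A sketch of what that check could look like (an assumption about the test, reusing the `sample` tensor and the two bounded specs defined above):

```python
import numpy as np

sample_np = np.asarray(sample)  # `sample` comes from the test above
# First sub-action is bounded in [0, 1], the second in [0, 4].
assert np.all((sample_np[:, 0] >= 0) & (sample_np[:, 0] <= 1))
assert np.all((sample_np[:, 1] >= 0) & (sample_np[:, 1] <= 4))
```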
Thanks for the review! I've made the changes as advised. Some of the changes requested were for code inherited from CategoricalProjectionNetwork, so I'm not sure whether it warrants a separate PR to reconcile both?
@sguada Is there a reason this was never merged? I think this one would make such a difference.
Closing; after this much time I do not think anyone will get to it. I am merging others that I can review myself. This one is too big for me. Sorry this one got so far and didn't land. :-(
@sguada just for visibility.
Closes #694