-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bug example #688
Bug example #688
Conversation
The mistake can be traced to def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
): return streaming_avg The discrepancy disappears (at least in this case) |
I suspect that the bug is due to the multi-chain and vmap/pmap. Bascially ravel_pytree is flattening across chain for some reason. Nonetheless, could you try the below: def streaming_average_update(
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new total weight and streaming average
"""
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
lambda x, avg: (previous_weight * avg + weight * x)
/ (current_weight + zero_prevention),
current_value,
previous_average,
)
return current_weight, current_average |
This doesn't affect the results. I don't think the pytree is the issue, since I also see the problem with: def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
total, average = streaming_avg
average = (total * average + weight * expectation) / (
total + weight + zero_prevention
)
total += weight
streaming_avg = (total, (average))
return streaming_avg |
Actually even with: def streaming_average_update(
expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
return streaming_avg I get a (small) discrepancy:
|
For a batch of 2 instead of 1, I also see a discrepancy, although it is small. I wonder if there is a key issue involved:
|
The other odd thing worth mentioning is that the errors often get less bad with longer runs and higher dimensions, e.g.: 10D Gaussian with 10000 steps
|
What appear to be a bug with jax (or at least a subtle case). A change of
map
betweenpmap
andvmap
inbug.py
changes the printed result.I also saw other very strange behaviors for this code that I have not yet reproduced in this example, where a print statement resulted in a similar change of value.