Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug example #688

Closed
wants to merge 4 commits into from
Closed

Bug example #688

wants to merge 4 commits into from

Conversation

reubenharry
Copy link
Contributor

@reubenharry reubenharry commented Jun 2, 2024

What appear to be a bug with jax (or at least a subtle case). A change of map between pmap and vmap in bug.py changes the printed result.

I also saw other very strange behaviors for this code that I have not yet reproduced in this example, where a print statement resulted in a similar change of value.

@reubenharry
Copy link
Contributor Author

The mistake can be traced to streaming_average_update, but I don't know exactly why. For example, if we replace streaming_average_update by:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
): return streaming_avg

The discrepancy disappears (at least in this case)

@junpenglao
Copy link
Member

junpenglao commented Jun 3, 2024

I suspect that the bug is due to the multi-chain and vmap/pmap. Bascially ravel_pytree is flattening across chain for some reason.

Nonetheless, could you try the below:

def streaming_average_update(
    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
):
    """Compute the streaming average of a function O(x) using a weight.
    Parameters:
    ----------
        current_value
            the current value of the function that we want to take average of
        previous_weight_and_average
            tuple of (previous_weight, previous_average) where previous_weight is the
            sum of weights and average is the current estimated average
        weight
            weight of the current state
        zero_prevention
            small value to prevent division by zero
    Returns:
    ----------
        new total weight and streaming average
    """
    previous_weight, previous_average = previous_weight_and_average
    current_weight = previous_weight + weight
    current_average = jax.tree.map(
        lambda x, avg: (previous_weight * avg + weight * x)
        / (current_weight + zero_prevention),
        current_value,
        previous_average,
    )
    return current_weight, current_average

@reubenharry
Copy link
Contributor Author

This doesn't affect the results. I don't think the pytree is the issue, since I also see the problem with:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):
    total, average = streaming_avg
    average = (total * average + weight * expectation) / (
        total + weight + zero_prevention
    )
    total += weight
    streaming_avg = (total, (average))
    return streaming_avg

@reubenharry
Copy link
Contributor Author

Actually even with:

def streaming_average_update(
    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
):

    return streaming_avg

I get a (small) discrepancy:

Result with <function pmap at 0x11f3b40e0> is [[ 1.2639244  -0.19290113]]
Result with <function vmap at 0x1224eff60> is [[ 1.2648091  -0.19346505]]

@reubenharry
Copy link
Contributor Author

For a batch of 2 instead of 1, I also see a discrepancy, although it is small. I wonder if there is a key issue involved:

Result with <function pmap at 0x127feff60> is [[ 0.13665608 -1.9322048 ]
 [ 1.276183   -0.18907894]]

Result with <function vmap at 0x12ff97f60> is [[ 0.13712421 -1.9322233 ]
 [ 1.2760496  -0.18907525]]

@reubenharry
Copy link
Contributor Author

The other odd thing worth mentioning is that the errors often get less bad with longer runs and higher dimensions, e.g.: 10D Gaussian with 10000 steps

Result with <function pmap at 0x165fc8180> is [[ 0.8321227  -2.5019531  -0.7922016  -1.3899261   1.2677195  -0.35655203
   0.48542497 -0.03980938  0.14613445 -1.3119547 ]
 [-0.03221251  0.5771989  -0.23249874 -0.62638044  0.9548867  -1.3366085
  -0.1022618   0.6617969   0.9869026  -1.4713044 ]]

Result with <function vmap at 0x11d3abe20> is [[ 0.8321227  -2.501953   -0.79220164 -1.389926    1.2677199  -0.35655132
   0.4854263  -0.03980982  0.14613375 -1.3119547 ]
 [-0.03221169  0.5772002  -0.2324989  -0.6263806   0.9548884  -1.3366088
  -0.1022612   0.6617972   0.9869019  -1.4713038 ]]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants