Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove reference cycle in VecAccessMixin #4033

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

angus-g
Copy link
Contributor

@angus-g angus-g commented Feb 12, 2025

With an associated PETSc Vec, VecAccessMixin deferred its version property to a lambda to avoid allocating the storage until necessary. Unfortunately, this lambda creates a reference cycle to self for all users of the VecAccessMixin. Given that counter accesses should be relatively infrequent, it seems fine to look up the counter type within the method itself.

Description

Related to #4014. To benchmark, I'm using the following script (very similar to the one in the linked issue, but uses 500 timesteps, a timestepper object, and removes explicit GC calls):

from firedrake import *
from firedrake.adjoint import *
# from memory_profiler import profile

def test():
    T_c, rf = rf_generator()
    rf.fwd_call = profile(rf.__call__)
    rf.derivative = profile(rf.derivative)

    for i in range(5):
        rf.fwd_call(T_c)
        rf.derivative()

@profile
def rf_generator(checkpoint_to_disk=True):
    tape = get_working_tape()
    tape.clear_tape()
    continue_annotation()

    mesh = RectangleMesh(100, 100, 1.0, 1.0)

    if checkpoint_to_disk:
        enable_disk_checkpointing()
        mesh = checkpointable_mesh(mesh)

    V = VectorFunctionSpace(mesh, "CG", 2)
    Q = FunctionSpace(mesh, "CG", 1)

    # Define the rotation vector field
    X = SpatialCoordinate(mesh)

    w = Function(V, name="rotation").interpolate(as_vector([-X[1] - 0.5, X[0] - 0.5]))
    T_c = Function(Q, name="control")
    T = Function(Q, name="Temperature")
    T_c.interpolate(0.1 * exp(-0.5 * ((X - as_vector((0.75, 0.5))) / Constant(0.1)) ** 2))
    control = Control(T_c)
    T.assign(T_c)

    # for i in ts:
    for i in tape.timestepper(iter(range(500))):
        T.interpolate(T + inner(grad(T), w) * Constant(0.0001))

    objective = assemble(T**2 * dx)

    pause_annotation()
    return T_c, ReducedFunctional(objective, control)


if __name__ == "__main__":
    test()

I'm also running this on the #4020 branch to automatically enable the SingleDiskStorageSchedule and handle the leak of function within CheckpointFunction. On the pyadjoint side, I am using dolfin-adjoint/pyadjoint#194.

Here's a pretty simple mprof comparison:
image
In black is the base, without this branch. In blue is the base, but with gc.collect() within Block.recompute (very eager, and expensive, also doesn't apply to the derivative). In red is the result with this branch, without any explicit gc. Individual plots follow, but the rescaling means you have to look a bit more closely.

Base plot

image

GC plot

image

This PR

image

I think there is a still a bit left out there in terms of making expensive allocations delete through refcounting, and perhaps there's a more elegant way of implementing the change proposed here.

Copy link

github-actions bot commented Feb 12, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake real8196 ran7484 passed712 skipped0 failed

Copy link

github-actions bot commented Feb 12, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake complex8138 ran6478 passed1660 skipped0 failed

Copy link
Contributor

@connorjward connorjward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an excellent spot!

I have absolutely no idea why this is failing tests though... AFAICT the changes you have made shouldn't impact the rest of the code.

pyop2/types/data_carrier.py Outdated Show resolved Hide resolved
With an associated PETSc Vec, VecAccessMixin deferred its version
property to a lambda to avoid allocating the storage until necessary.
Unfortunately, this lambda creates a reference cycle to self for all
users of the VecAccessMixin. Given that counter accesses should be
relatively infrequent, it seems fine to look up the counter type within
the method itself.
Doesn't make sense to cache a reference to self, just return self.
@angus-g angus-g force-pushed the angus-g/vecaccess-cycle branch from 97b5063 to 2af3ac7 Compare February 17, 2025 01:19
@angus-g
Copy link
Contributor Author

angus-g commented Feb 17, 2025

I had to do a few wider modifications around the inheritance of AbstractDat and VecAccessMixin. Hopefully it passes the tests now.

@angus-g angus-g force-pushed the angus-g/vecaccess-cycle branch from 9643d7d to a17b255 Compare February 17, 2025 03:10
The inheritance chain for Dat and Global puts VecAccessMixin
(rightly) behind DataCarrier. This means that by MRO, the
increment_dat_version method provided on DataCarrier will be used,
which is a null operation. I think this makes sense, given that not all
classes use this. However, because we're providing
increment_dat_version as an override through VecAccessMixin, we need to
explicitly refer to it in the inheriting classes.
@angus-g angus-g force-pushed the angus-g/vecaccess-cycle branch from a17b255 to afe09ea Compare February 17, 2025 04:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants