ralberd/path dependent adjoint #69

Merged · 22 commits · Jan 17, 2024

Commits:
ea74838
minor fix for Jax warning
ralberd Oct 27, 2023
2456a9f
unit tests for path dependent adjoints using j2
ralberd Oct 27, 2023
c2a0cfd
adding finite difference check test harness; adding FD checks for hyp…
ralberd Oct 31, 2023
9956926
pulling out gradient checks as separate file
ralberd Nov 1, 2023
112b10b
creating functions to avoid superfluous computation of 0 gradient val…
ralberd Nov 2, 2023
905d458
fixed bug in J2 plasticity adjoint
ralberd Nov 3, 2023
904a916
turning on actual total work objective now that sensitivity is fixed
ralberd Nov 3, 2023
e8e2941
Merge branch 'main' into ralberd/path_dependent_adjoint
ralberd Nov 3, 2023
e287a6b
WIP: moving functionality from scripts to MechanicsInverse; jitting f…
ralberd Nov 9, 2023
6cfd43d
adding gradient checks for HyperViscoelastiticy; needs to be updated
ralberd Nov 9, 2023
f618c1b
moving functionality into MechanicsInverse and improving jitting
ralberd Nov 10, 2023
9890dff
Merge branch 'main' into ralberd/path_dependent_adjoint
ralberd Nov 14, 2023
638dab4
changing use of dt to match Mechanics.py
ralberd Nov 14, 2023
1bb5080
Merge branch 'main' into ralberd/path_dependent_adjoint
ralberd Dec 5, 2023
056a392
adding gradient check for target curve difference L2 norm objective
ralberd Jan 3, 2024
670885f
Merge branch 'main' into ralberd/path_dependent_adjoint
ralberd Jan 3, 2024
6d6b713
updating to incorporate new isAxisymmetric variable in FunctionSpace
ralberd Jan 3, 2024
24b4f32
changing gradient checks to only check one step size (reduces run time)
ralberd Jan 4, 2024
2b15305
getting rid of repeated function definition
ralberd Jan 14, 2024
65732b3
fixing imports
ralberd Jan 14, 2024
5d3c31a
fixing file permissions
ralberd Jan 17, 2024
ebb5fb8
changing import of numpy to be onp
ralberd Jan 17, 2024
2 changes: 1 addition & 1 deletion optimism/Interpolants.py
@@ -69,7 +69,7 @@ def get_lobatto_nodes_1d(degree):
p = onp.polynomial.Legendre.basis(degree, domain=[0.0, 1.0])
dp = p.deriv()
xInterior = dp.roots()
-    xn = np.hstack(([0.0], xInterior, [1.0]))
+    xn = np.hstack((np.array([0.0]), xInterior, np.array([1.0])))
Collaborator:
Thanks, those warnings were annoying.

return xn


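The change above replaces Python lists with explicit arrays in the `np.hstack` call, which is what silences the JAX warning the commit message refers to. A minimal sketch of the pattern, using made-up interior node values (the real ones come from the Legendre-derivative roots computed just above):

```python
import jax.numpy as np  # optimism imports jax.numpy under the name np

# Illustrative stand-in for the interior Lobatto nodes; in Interpolants.py
# these come from the roots of the derivative of a Legendre basis polynomial.
xInterior = np.array([0.2764, 0.7236])

# Wrapping the endpoint lists in np.array avoids the implicit
# list-to-array conversion that newer JAX versions warn about.
xn = np.hstack((np.array([0.0]), xInterior, np.array([1.0])))
```
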
30 changes: 30 additions & 0 deletions optimism/inverse/AdjointFunctionSpace.py
@@ -0,0 +1,30 @@
from optimism import FunctionSpace
from optimism import Interpolants
from optimism import Mesh
from optimism.FunctionSpace import compute_element_volumes
from optimism.FunctionSpace import compute_element_volumes_axisymmetric
from optimism.FunctionSpace import map_element_shape_grads
from jax import vmap

def construct_function_space_for_adjoint(coords, mesh, quadratureRule, mode2D='cartesian'):

shapeOnRef = Interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss)

shapes = vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values)

shapeGrads = vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients)

if mode2D == 'cartesian':
el_vols = compute_element_volumes
isAxisymmetric = False
elif mode2D == 'axisymmetric':
el_vols = compute_element_volumes_axisymmetric
isAxisymmetric = True
vols = vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)

# unpack mesh and remake a mesh to make sure we get all the AD
mesh = Mesh.Mesh(coords=coords, conns=mesh.conns, simplexNodesOrdinals=mesh.simplexNodesOrdinals,
parentElement=mesh.parentElement, parentElement1d=mesh.parentElement1d, blocks=mesh.blocks,
nodeSets=mesh.nodeSets, sideSets=mesh.sideSets)

return FunctionSpace.FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)
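The construction above relies on element quantities being plain functions of the nodal coordinates, so JAX can differentiate through them; that is why the mesh is rebuilt from `coords` before assembling the function space. A standalone sketch of the same vmap-over-connectivity idea, using a hypothetical two-triangle mesh rather than optimism's actual Mesh/FunctionSpace API:

```python
import jax
import jax.numpy as jnp
from jax import vmap

# Hypothetical two-triangle mesh of the unit square (illustrative only):
coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
conns = jnp.array([[0, 1, 2], [0, 2, 3]])

def element_area(coords, conn):
    # signed area of a linear triangle from its nodal coordinates
    p0, p1, p2 = coords[conn[0]], coords[conn[1]], coords[conn[2]]
    v1, v2 = p1 - p0, p2 - p0
    return 0.5 * (v1[0] * v2[1] - v1[1] * v2[0])

# vmap over elements with coords held fixed, mirroring the (None, 0)
# in_axes pattern used for compute_element_volumes above
vols = vmap(element_area, (None, 0))(coords, conns)

# Because the element volumes are a pure function of the coordinates,
# JAX can produce sensitivities such as d(total volume)/d(coords):
dV_dx = jax.grad(lambda c: jnp.sum(vmap(element_area, (None, 0))(c, conns)))(coords)
```
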
99 changes: 99 additions & 0 deletions optimism/inverse/MechanicsInverse.py
@@ -0,0 +1,99 @@
from collections import namedtuple

from optimism.JaxConfig import *
btalamini (Collaborator), Jan 12, 2024:
You shouldn't need this import. It's a legacy thing I'm trying to remove. If it doesn't work without it, show me the error and I'll fix it.

Collaborator:
I happen to like import * from JaxConfig. For me, its intent was to avoid having to constantly duplicate lots of lines like: from jax import grad, import blah, ... which I found to be annoying in every file. It's an example where the DRY (don't repeat yourself) principle is in conflict with other principles. For me, I like DRY here, but I know there is a trade-off. I understand that it pulls in extra things, but that is essentially the point! Without it, the tops of files are more verbose and it takes longer to get to the actual work of the file. If JaxConfig pulls in things that may be duplicated or we don't want, we should delete them from JaxConfig.

tupek2 (Collaborator), Jan 12, 2024:
Plus, the way I think about it at least, I don't want to type jax.grad, because that is essentially hard-coding the implementation to jax (which we are hard-coded to anyway, but whatever). It's like in C++, where you can say using MatrixClass = Eigen::Matrix, but, potentially, if one is very careful, it is possible to switch out the matrix by just changing MatrixClass = Armadillo::Matrix. So I don't think of it as polluting the namespace; I think of it as concretizing the specific grad our library is using.

Collaborator:
I understand what you mean, but there are better ways of doing that. The wildcard import is discouraged pretty much universally because it makes it so hard to understand where things are coming from, and names can collide and unintentionally overwrite the one you want. It's essentially like a GOTO - you'd have to follow all the import * instances to see which grad you're ultimately getting. Even the Python project itself doesn't allow it in their style any more. If you want to use the concretization/substitution idea, there are ways to do it by importing things in the __init__.py file for the project, which is similar, but forces you to be explicit (like your MatrixClass = Armadillo::Matrix example).

Collaborator:
OK, can we basically put things like grad = jax.grad, etc. in init then? Basically things which are essentially always used and in the way we generally want to use them, so we can establish a standard way of doing it. That was the intent of the .Config file. In C++, the equivalent is basically #include "types.hpp", where types are put into the same namespace as the project on purpose. That is essentially the same thing I want to do in Python. The issue is that putting it in init puts it in EVERY file, whereas "from optimism.types import *" is more selective. So I don't really see what the issue with that is. We are not doing "from jax import *", we are including a types header.

Collaborator:
I think we should discuss this offline. Ryan, you can leave things as they are until we come to an agreement.

Collaborator:
Ok, go ahead and leave the wildcard import. We'll make a sweeping change in the future to get rid of them. Try to leave out the JaxConfig import in the future.
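The explicit-alias alternative discussed in this thread can be sketched as follows; the module name `_backend` is hypothetical and not part of optimism:

```python
# Hypothetical contents of an optimism/_backend.py "types header":
# one module pins the concrete implementations the project standardizes on,
# analogous to `using MatrixClass = Eigen::Matrix` in C++.
import jax

grad = jax.grad  # the project's concrete grad
jit = jax.jit
vmap = jax.vmap

# Client files would then do `from optimism._backend import grad, jit, vmap`,
# keeping imports explicit (no wildcard) while still centralizing the
# backend choice in one place.
f = lambda x: x**3
df = grad(f)
# df(2.0) == 12.0, since d/dx x^3 = 3x^2
```
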

from optimism import FunctionSpace
from optimism import Interpolants
from optimism.TensorMath import tensor_2D_to_3D

IvsUpdateInverseFunctions = namedtuple('IvsUpdateInverseFunctions',
['ivs_update_jac_ivs_prev',
'ivs_update_jac_disp_vjp',
'ivs_update_jac_coords_vjp'])

PathDependentResidualInverseFunctions = namedtuple('PathDependentResidualInverseFunctions',
['residual_jac_ivs_prev_vjp',
'residual_jac_coords_vjp'])

ResidualInverseFunctions = namedtuple('ResidualInverseFunctions',
['residual_jac_coords_vjp'])

def _compute_element_field_gradient(U, elemShapeGrads, elemConnectivity, modify_element_gradient):
elemNodalDisps = U[elemConnectivity]
elemGrads = vmap(FunctionSpace.compute_quadrature_point_field_gradient, (None, 0))(elemNodalDisps, elemShapeGrads)
elemGrads = modify_element_gradient(elemGrads)
return elemGrads

def _compute_field_gradient(shapeGrads, conns, nodalField, modify_element_gradient):
return vmap(_compute_element_field_gradient, (None,0,0,None))(nodalField, shapeGrads, conns, modify_element_gradient)

def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_state_new, output_shape):
dgQuadPointRavel = dispGrads.reshape(dispGrads.shape[0]*dispGrads.shape[1],*dispGrads.shape[2:])
stQuadPointRavel = states.reshape(states.shape[0]*states.shape[1],*states.shape[2:])
statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt)
return statesNew.reshape(output_shape)


def create_ivs_update_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
fs = functionSpace
shapeOnRef = Interpolants.compute_shapes(fs.mesh.parentElement, fs.quadratureRule.xigauss)

if mode2D == 'plane strain':
grad_2D_to_3D = vmap(tensor_2D_to_3D)
elif mode2D == 'axisymmetric':
raise NotImplementedError

modify_element_gradient = grad_2D_to_3D
if pressureProjectionDegree is not None:
raise NotImplementedError

def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=0.0):
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_gradient = jacfwd(materialModel.compute_state_new, argnums=1)
grad_shape = stateVariables.shape + (stateVariables.shape[2],)
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_gradient, grad_shape)

def compute_ivs_update_parameterized(U, stateVariables, coords, dt=0.0):
shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients)
dispGrads = _compute_field_gradient(shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_func = materialModel.compute_state_new
output_shape = stateVariables.shape
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_func, output_shape)

compute_partial_ivs_update_partial_coords = jit(lambda u, ivs, x, av, dt=0.0:
vjp(lambda z: compute_ivs_update_parameterized(u, ivs, z, dt), x)[1](av)[0])

def compute_ivs_update(U, stateVariables, dt=0.0):
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_func = materialModel.compute_state_new
output_shape = stateVariables.shape
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_func, output_shape)

compute_partial_ivs_update_partial_disp = jit(lambda x, ivs, av, dt=0.0:
vjp(lambda z: compute_ivs_update(z, ivs, dt), x)[1](av)[0])

return IvsUpdateInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev),
compute_partial_ivs_update_partial_disp,
compute_partial_ivs_update_partial_coords
)

def create_path_dependent_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_ivs_prev = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, z, x), iv)[1](vx)[0])

compute_partial_residual_partial_coords = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, iv, z), x)[1](vx)[0])

return PathDependentResidualInverseFunctions(compute_partial_residual_partial_ivs_prev,
compute_partial_residual_partial_coords
)

def create_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_coords = jit(lambda u, q, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, z), x)[1](vx)[0])

return ResidualInverseFunctions(compute_partial_residual_partial_coords)
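The vjp-of-grad pattern behind `create_residual_inverse_functions` above (the residual is the gradient of an energy, and its Jacobian with respect to coordinates is contracted against an adjoint vector via `jax.vjp`) can be illustrated on a toy problem; the quadratic `energy` below is made up for illustration, not optimism's:

```python
import jax
import jax.numpy as jnp

def energy(u, x):
    # Toy "energy" depending on a displacement u and coordinates x
    return 0.5 * jnp.sum((u - x)**2 * (1.0 + x**2))

# The residual is the gradient of the energy with respect to u
residual = jax.grad(energy, argnums=0)

def residual_jac_coords_vjp(u, x, v):
    # v^T * d(residual)/dx, without ever forming the Jacobian explicitly
    return jax.vjp(lambda z: residual(u, z), x)[1](v)[0]

u = jnp.array([1.0, 2.0])
x = jnp.array([0.5, -0.5])
v = jnp.array([1.0, 1.0])  # plays the role of the adjoint vector
out = residual_jac_coords_vjp(u, x, v)
```

Here residual_i = (u_i - x_i)(1 + x_i^2), so the Jacobian in x is diagonal and the vjp with v = (1, 1) recovers its diagonal entries.
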
67 changes: 67 additions & 0 deletions optimism/inverse/test/FiniteDifferenceFixture.py
@@ -0,0 +1,67 @@
from optimism.test.MeshFixture import MeshFixture
from collections import namedtuple
import numpy as onp

class FiniteDifferenceFixture(MeshFixture):
def assertFiniteDifferenceCheckHasVShape(self, errors, tolerance=1e-6):
minError = min(errors)
self.assertLess(minError, tolerance, "Smallest finite difference error not less than tolerance.")
self.assertLess(minError, errors[0], "Finite difference error does not decrease from initial step size.")
self.assertLess(minError, errors[-1], "Finite difference error does not increase after reaching minimum. Try more finite difference steps.")

def build_direction_vector(self, numDesignVars, seed=123):

onp.random.seed(seed)
directionVector = onp.random.uniform(-1.0, 1.0, numDesignVars)
normVector = directionVector / onp.linalg.norm(directionVector)

return onp.array(normVector)

def compute_finite_difference_error(self, stepSize, initialParameters):
storedState = self.forward_solve(initialParameters)
originalObjective = self.compute_objective_function(storedState, initialParameters)
gradient = self.compute_gradient(storedState, initialParameters)

directionVector = self.build_direction_vector(initialParameters.shape[0])
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

perturbedParameters = initialParameters + stepSize * directionVector
storedState = self.forward_solve(perturbedParameters)
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

fd_value = (perturbedObjective - originalObjective) / stepSize
error = abs(directionalDerivative - fd_value)

return error

def compute_finite_difference_errors(self, stepSize, steps, initialParameters, printOutput=True):
storedState = self.forward_solve(initialParameters)
originalObjective = self.compute_objective_function(storedState, initialParameters)
gradient = self.compute_gradient(storedState, initialParameters)

directionVector = self.build_direction_vector(initialParameters.shape[0])
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

fd_values = []
errors = []
for i in range(0, steps):
perturbedParameters = initialParameters + stepSize * directionVector
storedState = self.forward_solve(perturbedParameters)
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

fd_value = (perturbedObjective - originalObjective) / stepSize
fd_values.append(fd_value)

error = abs(directionalDerivative - fd_value)
errors.append(error)

stepSize *= 1e-1

if printOutput:
print("\n grad'*dir | FD approx | abs error")
print("--------------------------------------------------------------------------------")
for i in range(0, steps):
print(f" {directionalDerivative} | {fd_values[i]} | {errors[i]}")

return errors
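The V-shape that `assertFiniteDifferenceCheckHasVShape` looks for (truncation error shrinking with step size until floating-point roundoff takes over) can be reproduced on a simple analytic objective; the quadratic-plus-sine function below is illustrative, not from optimism:

```python
import numpy as onp

def objective(p):
    return onp.sum(p**2) + onp.sum(onp.sin(p))

def gradient(p):
    # exact analytic gradient of the objective
    return 2.0 * p + onp.cos(p)

onp.random.seed(123)
p0 = onp.random.uniform(-1.0, 1.0, 5)
d = onp.random.uniform(-1.0, 1.0, 5)
d /= onp.linalg.norm(d)  # unit direction, as in build_direction_vector

directionalDerivative = gradient(p0) @ d

errors = []
h = 1e-1
for _ in range(10):
    fd = (objective(p0 + h * d) - objective(p0)) / h
    errors.append(abs(directionalDerivative - fd))
    h *= 1e-1  # shrink the step by a decade each iteration

# V-shape: the error first decreases (truncation error ~ O(h)),
# then increases again once roundoff ~ eps/h dominates.
assert min(errors) < errors[0]
assert min(errors) < errors[-1]
```
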