ralberd/path dependent adjoint #69
Changes from all commits
ea74838
2456a9f
c2a0cfd
9956926
112b10b
905d458
904a916
e8e2941
e287a6b
6cfd43d
f618c1b
9890dff
638dab4
1bb5080
056a392
670885f
6d6b713
24b4f32
2b15305
65732b3
5d3c31a
ebb5fb8
@@ -0,0 +1,30 @@
from optimism import FunctionSpace
from optimism import Interpolants
from optimism import Mesh
from optimism.FunctionSpace import compute_element_volumes
from optimism.FunctionSpace import compute_element_volumes_axisymmetric
from optimism.FunctionSpace import map_element_shape_grads
from jax import vmap


def construct_function_space_for_adjoint(coords, mesh, quadratureRule, mode2D='cartesian'):

    shapeOnRef = Interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss)

    # broadcast the single reference-element shape-function array to every element
    shapes = vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values)

    shapeGrads = vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients)

    if mode2D == 'cartesian':
        el_vols = compute_element_volumes
        isAxisymmetric = False
    elif mode2D == 'axisymmetric':
        el_vols = compute_element_volumes_axisymmetric
        isAxisymmetric = True
    else:
        raise ValueError(f"Unknown mode2D: {mode2D}")
    vols = vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)

    # unpack the mesh and rebuild it so the (possibly differentiated) coordinates
    # participate fully in automatic differentiation
    mesh = Mesh.Mesh(coords=coords, conns=mesh.conns, simplexNodesOrdinals=mesh.simplexNodesOrdinals,
                     parentElement=mesh.parentElement, parentElement1d=mesh.parentElement1d, blocks=mesh.blocks,
                     nodeSets=mesh.nodeSets, sideSets=mesh.sideSets)

    return FunctionSpace.FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)
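The `vmap(..., (0, None))` trick above replicates the reference-element shape-function array across all elements without a Python loop. A minimal standalone sketch of that broadcasting pattern, using toy connectivity and shape arrays rather than the optimism data structures:

```python
import jax.numpy as jnp
from jax import vmap

# Toy stand-ins for mesh.conns and shapeOnRef.values (hypothetical sizes):
conns = jnp.array([[0, 1, 2], [1, 3, 2]])  # 2 elements, 3 nodes each
shapeOnRef = jnp.ones((4, 3))              # 4 quadrature points x 3 shape functions

# in_axes=(0, None): map over the element axis of conns while holding the
# reference-element shapes fixed, yielding one copy per element.
shapes = vmap(lambda elConns, elShape: elShape, (0, None))(conns, shapeOnRef)
print(shapes.shape)  # (2, 4, 3)
```

The same `in_axes` idiom drives the `shapeGrads` and `vols` computations above: element-wise arguments are mapped with `0`, shared arguments with `None`.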
@@ -0,0 +1,99 @@
from collections import namedtuple

from optimism.JaxConfig import *
Review discussion on this import:

Reviewer A: You shouldn't need this import. It's a legacy thing I'm trying to remove. If it doesn't work without it, show me the error and I'll fix it.

Reviewer B: I happen to like `import *` from JaxConfig. For me, its intent was to avoid having to constantly duplicate lots of lines like `from jax import grad, ...`, which I found annoying in every file. It's an example where the DRY (don't repeat yourself) principle is in conflict with other principles. I like DRY here, but I know there is a trade-off. I understand that it pulls in extra things, but that is essentially the point! Without it, the tops of files are more verbose and it takes longer to get to the actual work of the file. If JaxConfig pulls in things that are duplicated or that we don't want, we should delete them from JaxConfig.

Reviewer B: Plus, the way I think about it at least, I don't want to type `jax.grad`, because that is essentially hard-coding the implementation to jax (which we are hard-coded to anyway, but whatever). It's like in C++, where you can say `using MatrixClass = Eigen::Matrix`, and, potentially, if one is very careful, it is possible to switch out the matrix by just changing it to `MatrixClass = Armadillo::Matrix`. So I don't think of it as polluting the namespace; I think of it as concretizing the specific `grad` our library is using.

Reviewer A: I understand what you mean, but there are better ways of doing that. The wildcard import is discouraged pretty much universally because it makes it so hard to understand where things are coming from, and names can collide and unintentionally overwrite the one you want. It's essentially like a GOTO: you'd have to follow all the wildcard imports to see where a name actually comes from.

Reviewer B: OK, can we basically put things like `grad = jax.grad`, etc. in `__init__` then? Basically things which are essentially always used, and in the way we generally want to use them, so we can establish a standard way of doing it. That was the intent of the Config file. In C++, the equivalent is basically `#include "types.hpp"`, where types are put into the same namespace as the project on purpose. That is essentially the same thing I want to do in Python. The issue is that putting it in `__init__` puts it in EVERY file, whereas `from optimism.types import *` is more selective. So I don't really appreciate what the issue with that is. We are not doing `from jax import *`; we are including a types header.

Reviewer A: I think we should discuss this offline. Ryan, you can leave things as they are until we come to an agreement.

Reviewer A: Ok, go ahead and leave the wildcard import. We'll make a sweeping change in the future to get rid of them. Try to leave out the JaxConfig import in the future.
from optimism import FunctionSpace
from optimism import Interpolants
from optimism.TensorMath import tensor_2D_to_3D


IvsUpdateInverseFunctions = namedtuple('IvsUpdateInverseFunctions',
                                       ['ivs_update_jac_ivs_prev',
                                        'ivs_update_jac_disp_vjp',
                                        'ivs_update_jac_coords_vjp'])

PathDependentResidualInverseFunctions = namedtuple('PathDependentResidualInverseFunctions',
                                                   ['residual_jac_ivs_prev_vjp',
                                                    'residual_jac_coords_vjp'])

ResidualInverseFunctions = namedtuple('ResidualInverseFunctions',
                                      ['residual_jac_coords_vjp'])


def _compute_element_field_gradient(U, elemShapeGrads, elemConnectivity, modify_element_gradient):
    elemNodalDisps = U[elemConnectivity]
    elemGrads = vmap(FunctionSpace.compute_quadrature_point_field_gradient, (None, 0))(elemNodalDisps, elemShapeGrads)
    elemGrads = modify_element_gradient(elemGrads)
    return elemGrads


def _compute_field_gradient(shapeGrads, conns, nodalField, modify_element_gradient):
    return vmap(_compute_element_field_gradient, (None, 0, 0, None))(nodalField, shapeGrads, conns, modify_element_gradient)


def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_state_new, output_shape):
    # ravel the element and quadrature-point axes so the material update maps
    # over single quadrature points
    dgQuadPointRavel = dispGrads.reshape(dispGrads.shape[0]*dispGrads.shape[1], *dispGrads.shape[2:])
    stQuadPointRavel = states.reshape(states.shape[0]*states.shape[1], *states.shape[2:])
    statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt)
    return statesNew.reshape(output_shape)


def create_ivs_update_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
    fs = functionSpace
    shapeOnRef = Interpolants.compute_shapes(fs.mesh.parentElement, fs.quadratureRule.xigauss)

    if mode2D == 'plane strain':
        grad_2D_to_3D = vmap(tensor_2D_to_3D)
    elif mode2D == 'axisymmetric':
        raise NotImplementedError

    modify_element_gradient = grad_2D_to_3D
    if pressureProjectionDegree is not None:
        raise NotImplementedError

    def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=0.0):
        dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
        update_gradient = jacfwd(materialModel.compute_state_new, argnums=1)
        grad_shape = stateVariables.shape + (stateVariables.shape[2],)
        return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,
                                                            update_gradient, grad_shape)

    def compute_ivs_update_parameterized(U, stateVariables, coords, dt=0.0):
        shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients)
        dispGrads = _compute_field_gradient(shapeGrads, fs.mesh.conns, U, modify_element_gradient)
        update_func = materialModel.compute_state_new
        output_shape = stateVariables.shape
        return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,
                                                            update_func, output_shape)

    compute_partial_ivs_update_partial_coords = jit(lambda u, ivs, x, av, dt=0.0:
        vjp(lambda z: compute_ivs_update_parameterized(u, ivs, z, dt), x)[1](av)[0])

    def compute_ivs_update(U, stateVariables, dt=0.0):
        dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
        update_func = materialModel.compute_state_new
        output_shape = stateVariables.shape
        return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,
                                                            update_func, output_shape)

    compute_partial_ivs_update_partial_disp = jit(lambda x, ivs, av, dt=0.0:
        vjp(lambda z: compute_ivs_update(z, ivs, dt), x)[1](av)[0])

    return IvsUpdateInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev),
                                     compute_partial_ivs_update_partial_disp,
                                     compute_partial_ivs_update_partial_coords)


def create_path_dependent_residual_inverse_functions(energyFunction):

    compute_partial_residual_partial_ivs_prev = jit(lambda u, q, iv, x, vx:
        vjp(lambda z: grad(energyFunction, 0)(u, q, z, x), iv)[1](vx)[0])

    compute_partial_residual_partial_coords = jit(lambda u, q, iv, x, vx:
        vjp(lambda z: grad(energyFunction, 0)(u, q, iv, z), x)[1](vx)[0])

    return PathDependentResidualInverseFunctions(compute_partial_residual_partial_ivs_prev,
                                                 compute_partial_residual_partial_coords)


def create_residual_inverse_functions(energyFunction):

    compute_partial_residual_partial_coords = jit(lambda u, q, x, vx:
        vjp(lambda z: grad(energyFunction, 0)(u, q, z), x)[1](vx)[0])

    return ResidualInverseFunctions(compute_partial_residual_partial_coords)
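The `jit(lambda ...: vjp(...)[1](vx)[0])` pattern used throughout this file pulls an adjoint (cotangent) vector back through one argument of a function. A self-contained sketch of the same shape, using a toy energy rather than the optimism API (all names here are illustrative):

```python
import jax.numpy as jnp
from jax import grad, jit, vjp

# Toy energy e(u, x) = sum(x * u**2); its residual is grad(e, 0) = 2*x*u.
def energy(u, x):
    return jnp.sum(x * u**2)

# Same structure as compute_partial_residual_partial_coords above: the VJP of
# the residual with respect to x, contracted against a cotangent vector vx.
residual_jac_coords_vjp = jit(lambda u, x, vx:
    vjp(lambda z: grad(energy, 0)(u, z), x)[1](vx)[0])

u = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
vx = jnp.array([1.0, 1.0])
# residual_i = 2*x_i*u_i, so d(residual_i)/dx_j = 2*u_i on the diagonal and
# the pullback is vx_i * 2*u_i.
print(residual_jac_coords_vjp(u, x, vx))  # [2. 4.]
```

Because `vjp` contracts against a vector instead of building the full Jacobian, the adjoint sensitivities cost about as much as one extra residual evaluation, which is the point of formulating the inverse functions this way.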
@@ -0,0 +1,67 @@
from optimism.test.MeshFixture import MeshFixture
from collections import namedtuple
import numpy as onp


class FiniteDifferenceFixture(MeshFixture):
    def assertFiniteDifferenceCheckHasVShape(self, errors, tolerance=1e-6):
        minError = min(errors)
        self.assertLess(minError, tolerance, "Smallest finite difference error not less than tolerance.")
        self.assertLess(minError, errors[0], "Finite difference error does not decrease from initial step size.")
        self.assertLess(minError, errors[-1], "Finite difference error does not increase after reaching minimum. Try more finite difference steps.")

    def build_direction_vector(self, numDesignVars, seed=123):
        # seeded for reproducibility; returns a unit perturbation direction
        onp.random.seed(seed)
        directionVector = onp.random.uniform(-1.0, 1.0, numDesignVars)
        normVector = directionVector / onp.linalg.norm(directionVector)

        return onp.array(normVector)

    def compute_finite_difference_error(self, stepSize, initialParameters):
        storedState = self.forward_solve(initialParameters)
        originalObjective = self.compute_objective_function(storedState, initialParameters)
        gradient = self.compute_gradient(storedState, initialParameters)

        directionVector = self.build_direction_vector(initialParameters.shape[0])
        directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

        perturbedParameters = initialParameters + stepSize * directionVector
        storedState = self.forward_solve(perturbedParameters)
        perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

        fd_value = (perturbedObjective - originalObjective) / stepSize
        error = abs(directionalDerivative - fd_value)

        return error

    def compute_finite_difference_errors(self, stepSize, steps, initialParameters, printOutput=True):
        storedState = self.forward_solve(initialParameters)
        originalObjective = self.compute_objective_function(storedState, initialParameters)
        gradient = self.compute_gradient(storedState, initialParameters)

        directionVector = self.build_direction_vector(initialParameters.shape[0])
        directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

        fd_values = []
        errors = []
        for i in range(0, steps):
            perturbedParameters = initialParameters + stepSize * directionVector
            storedState = self.forward_solve(perturbedParameters)
            perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

            fd_value = (perturbedObjective - originalObjective) / stepSize
            fd_values.append(fd_value)

            error = abs(directionalDerivative - fd_value)
            errors.append(error)

            stepSize *= 1e-1

        if printOutput:
            print("\n     grad'*dir     |     FD approx     |     abs error")
            print("--------------------------------------------------------------------------------")
            for i in range(0, steps):
                print(f"  {directionalDerivative}  |  {fd_values[i]}  |  {errors[i]}")

        return errors
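`assertFiniteDifferenceCheckHasVShape` relies on the characteristic behavior of finite differences: the error falls with step size while truncation error dominates, then rises again once floating-point round-off takes over. A standalone illustration on a scalar function, with no optimism dependencies:

```python
# Forward-difference error of f(x) = x**3 at x = 1; exact derivative is 3.
f = lambda x: x**3
exact = 3.0

errors = []
stepSize = 1e-1
for _ in range(12):
    fd = (f(1.0 + stepSize) - f(1.0)) / stepSize
    errors.append(abs(fd - exact))
    stepSize *= 1e-1  # same decade-by-decade shrink as the fixture above

# Truncation error dominates the left arm of the V, round-off the right arm:
# the minimum error sits strictly below both endpoints, which is exactly
# what assertFiniteDifferenceCheckHasVShape verifies.
print(min(errors) < errors[0] and min(errors) < errors[-1])  # True
```

If the gradient under test were wrong, the left arm would plateau at the gradient error instead of descending, breaking the V shape; that is why the fixture checks both arms rather than a single step size.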
Review comment: Thanks, those warnings were annoying.