Skip to content

Commit

Permalink
Merge pull request #96 from sandialabs/chamel/equinox-quadrature
Browse files Browse the repository at this point in the history
moving QuadratureRule over to an equinox.Module and patched up a few …
  • Loading branch information
cmhamel authored Oct 28, 2024
2 parents 9ea926d + 802ba44 commit bff3a67
Show file tree
Hide file tree
Showing 13 changed files with 208 additions and 47 deletions.
Binary file added examples/hole_array/hole_array.exo
Binary file not shown.
152 changes: 152 additions & 0 deletions examples/hole_array/hole_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import jax
import jax.numpy as np

from optimism import EquationSolver as EqSolver
from optimism import FunctionSpace
from optimism.material import Neohookean as MatModel
from optimism import Mechanics
from optimism.FunctionSpace import EssentialBC
from optimism.FunctionSpace import DofManager
from optimism import Objective
from optimism import SparseMatrixAssembler
from optimism import QuadratureRule
from optimism import ReadExodusMesh
from optimism import VTKWriter

import time


if __name__ == '__main__':
mesh = ReadExodusMesh.read_exodus_mesh('./hole_array.exo')
quad_rule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2)
quad_rule_face = QuadratureRule.create_quadrature_rule_1D(4)
func_space = FunctionSpace.construct_function_space(mesh, quad_rule)

ebcs = [
EssentialBC(nodeSet='yminus_nodeset', component=0),
EssentialBC(nodeSet='yminus_nodeset', component=1),
EssentialBC(nodeSet='yplus_nodeset', component=0),
EssentialBC(nodeSet='yplus_nodeset', component=1)
]

dofManager = DofManager(func_space, 2, ebcs)

props = {'elastic modulus': 3. * 10.0 * (1. - 2. * 0.3),
'poisson ratio': 0.3,
'version': 'coupled'}

mat_model = MatModel.create_material_model_functions(props)
mech_funcs = Mechanics.create_mechanics_functions(func_space, mode2D='plane strain', materialModel=mat_model)

eq_settings = EqSolver.get_settings(
use_incremental_objective=False,
max_trust_iters=100,
tr_size=0.25,
min_tr_size=1e-15,
tol=5e-8
)

internal_variables = mech_funcs.compute_initial_state()

def get_ubcs(p):
yLoc = p[0]
V = np.zeros(mesh.coords.shape)
index = (mesh.nodeSets['yplus_nodeset'], 1)
V = V.at[index].set(yLoc)
return dofManager.get_bc_values(V)


def create_field(Uu, p):
return dofManager.create_field(Uu, get_ubcs(p))


def energy_function(Uu, p):
U = create_field(Uu, p)
# internal_variables = p[1]
return mech_funcs.compute_strain_energy(U, internal_variables)


def energy_function_with_contact(Uu, lam, p):
return energy_function(Uu, p)


def assemble_sparse(Uu, p):
U = create_field(Uu, p)
internal_variables = p[1]
element_stiffnesses = mech_funcs.compute_element_stiffnesses(U, internal_variables)
return SparseMatrixAssembler.\
assemble_sparse_stiffness_matrix(element_stiffnesses, func_space.mesh.conns, dofManager)


def update_params_function(step, Uu, p):
# update displacement BCs
max_disp = 5.
max_steps = 20
# disp = p[0]
# disp = disp - max_disp / max_steps
disp = -(step / max_steps) * max_disp

p = Objective.param_index_update(p, 0, disp)

# update contact stuff
# if step % search_frequency == 0:
# U = create_field(Uu, p)
# interaction_list_1 = get_potential_interaction_list(contact_edges_1, contact_edges_1, mesh, U, max_contact_neighbors)
# interaction_list_1 = np.array([filter_edge_neighbors(eneighbors, contact_edges_1[e]) for e, eneighbors in enumerate(interaction_list_1)])
# interaction_list_2 = get_potential_interaction_list(contact_edges_2, contact_edges_2, mesh, U, max_contact_neighbors)
# interaction_list_2 = np.array([filter_edge_neighbors(eneighbors, contact_edges_2[e]) for e, eneighbors in enumerate(interaction_list_2)])
# interaction_lists = (interaction_list_1, interaction_list_2)
# p = Objective.param_index_update(p, 1, interaction_lists)

return p


def plot_solution(dispField, plotName, p):
writer = VTKWriter.VTKWriter(mesh, baseFileName=plotName)
writer.add_nodal_field(name='displacement',
nodalData=dispField,
fieldType=VTKWriter.VTKFieldType.VECTORS)

bcs = np.array(dofManager.isBc, dtype=int)
writer.add_nodal_field(name='bcs',
nodalData=bcs,
fieldType=VTKWriter.VTKFieldType.VECTORS,
dataType=VTKWriter.VTKDataType.INT)

writer.write()


def run():
Uu = dofManager.get_unknown_values(np.zeros(mesh.coords.shape))
disp = 0.0
ivs = mech_funcs.compute_initial_state()
p = Objective.Params(disp, ivs)
precond_strategy = Objective.PrecondStrategy(assemble_sparse)
objective = Objective.Objective(energy_function, Uu, p, precond_strategy)

step = 0
maxDisp = 5.0

plot_solution(create_field(Uu, p), 'output-0000', p)

steps = 20
for step in range(1, steps):
print('--------------------------------------')
print('LOAD STEP ', step)
disp = disp - maxDisp / steps

p = Objective.param_index_update(p, 0, disp)
Uu,_ = EqSolver.nonlinear_equation_solve(objective, Uu, p, eq_settings)
plot_solution(create_field(Uu, p), 'output-%s' % str(step + 1).zfill(4), p)

# run_without_contact()

if __name__ == '__main__':
times = []
for n in range(10):
start_time = time.time()
run()
total_time = time.time() - start_time
print(f' Sim {n + 1} time = {total_time}')
times.append(total_time)
print(f'Average time = {sum(times) / len(times)}')
14 changes: 7 additions & 7 deletions optimism/Mechanics.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _compute_updated_internal_variables_multi_block(functionSpace, U, states, dt

def _compute_initial_state_multi_block(fs, blockModels):

numQuadPoints = QuadratureRule.len(fs.quadratureRule)
numQuadPoints = len(fs.quadratureRule)
# Store the same number of state variables for every material to make
# vmapping easy.
#
Expand Down Expand Up @@ -225,8 +225,8 @@ def compute_element_stiffnesses(U, stateVariables, dt=0.0):


def compute_output_energy_densities_and_stresses(U, stateVariables, dt=0.0):
energy_densities = np.zeros((Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule)))
stresses = np.zeros((Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule), 3, 3))
energy_densities = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule)))
stresses = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule), 3, 3))
for blockKey in materialModels:
compute_output_energy_density = materialModels[blockKey].compute_energy_density
output_lagrangian = strain_energy_density_to_lagrangian_density(compute_output_energy_density)
Expand Down Expand Up @@ -291,7 +291,7 @@ def compute_output_energy_densities_and_stresses(U, stateVariables, dt=0.0):


def compute_initial_state():
shape = Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule), 1
shape = Mesh.num_elements(fs.mesh), len(fs.quadratureRule), 1
return np.tile(materialModel.compute_initial_state(), shape)

def lagrangian_qoi(U, gradU, Q, X, dt):
Expand Down Expand Up @@ -414,19 +414,19 @@ def compute_output_potential_densities_and_stresses(U, stateVariables, dt):
return FunctionSpace.evaluate_on_block(fs, U, stateVariables, dt, output_constitutive, slice(None), modify_element_gradient=modify_element_gradient)

def compute_kinetic_energy(V):
stateVariables = np.zeros((Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule)))
stateVariables = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule)))
return _compute_kinetic_energy(functionSpace, V, stateVariables, materialModel.density)

def compute_output_strain_energy(U, stateVariables, dt):
return _compute_strain_energy(functionSpace, U, stateVariables, dt, materialModel.compute_energy_density, modify_element_gradient)

def compute_initial_state():
shape = Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule), 1
shape = Mesh.num_elements(fs.mesh), len(fs.quadratureRule), 1
return np.tile(materialModel.compute_initial_state(), shape)

def compute_element_masses():
V = np.zeros_like(fs.mesh.coords)
stateVariables = np.zeros((Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule)))
stateVariables = np.zeros((Mesh.num_elements(fs.mesh), len(fs.quadratureRule)))
return _compute_element_masses(functionSpace, V, stateVariables, materialModel.density, modify_element_gradient)

def predict(U, V, A, dt):
Expand Down
53 changes: 30 additions & 23 deletions optimism/QuadratureRule.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from collections import namedtuple
from jax.lax import switch
from jaxtyping import Array, Float
import equinox as eqx
import jax.numpy as np
import math
import numpy as onp
import scipy.special

import jax.numpy as np
from jax.lax import switch


QuadratureRule = namedtuple('QuadratureRule', ['xigauss', 'wgauss'])
QuadratureRule.__doc__ = """Quadrature rule points and weights.
A ``namedtuple`` containing ``xigauss``, a numpy array of the
class QuadratureRule(eqx.Module):
"""
Quadrature rule points and weights.
An ``equinox`` ``Module`` containing ``xigauss``, a ``jax.numpy`` array of the
coordinates of the sample points in the reference domain, and
``wgauss``, a numpy array with the weights.
``wgauss``, a ``jax.numpy`` array with the weights.
"""
xigauss: Float[Array, "nq 2"]
wgauss: Float[Array, "nq"]

def len(quadRule):
"""Gets the number of points in a quadrature rule."""
return quadRule.xigauss.shape[0]
def __iter__(self):
yield self.xigauss
yield self.wgauss

def __len__(self):
return self.xigauss.shape[0]


def create_quadrature_rule_1D(degree):
Expand Down Expand Up @@ -64,49 +70,49 @@ def create_quadrature_rule_on_triangle(degree):
and the weights.
"""
if degree <= 1:
xi = onp.array([[3.33333333333333333E-01, 3.33333333333333333E-01]])
xi = np.array([[3.33333333333333333E-01, 3.33333333333333333E-01]])

w = onp.array([ 5.00000000000000000E-01 ])
w = np.array([ 5.00000000000000000E-01 ])
elif degree == 2:
xi = onp.array([[6.66666666666666667E-01, 1.66666666666666667E-01],
xi = np.array([[6.66666666666666667E-01, 1.66666666666666667E-01],
[1.66666666666666667E-01, 6.66666666666666667E-01],
[1.66666666666666667E-01, 1.66666666666666667E-01]])

w = onp.array([1.66666666666666666E-01,
w = np.array([1.66666666666666666E-01,
1.66666666666666667E-01,
1.66666666666666667E-01])
elif degree <= 4:
xi = onp.array([[1.081030181680700E-01, 4.459484909159650E-01],
xi = np.array([[1.081030181680700E-01, 4.459484909159650E-01],
[4.459484909159650E-01, 1.081030181680700E-01],
[4.459484909159650E-01, 4.459484909159650E-01],
[8.168475729804590E-01, 9.157621350977100E-02],
[9.157621350977100E-02, 8.168475729804590E-01],
[9.157621350977100E-02, 9.157621350977100E-02]])

w = onp.array([1.116907948390055E-01,
w = np.array([1.116907948390055E-01,
1.116907948390055E-01,
1.116907948390055E-01,
5.497587182766100E-02,
5.497587182766100E-02,
5.497587182766100E-02])
elif degree <= 5:
xi = onp.array([[3.33333333333333E-01, 3.33333333333333E-01],
xi = np.array([[3.33333333333333E-01, 3.33333333333333E-01],
[5.97158717897700E-02, 4.70142064105115E-01],
[4.70142064105115E-01, 5.97158717897700E-02],
[4.70142064105115E-01, 4.70142064105115E-01],
[7.97426985353087E-01, 1.01286507323456E-01],
[1.01286507323456E-01, 7.97426985353087E-01],
[1.01286507323456E-01, 1.01286507323456E-01]])

w = onp.array([1.12500000000000E-01,
w = np.array([1.12500000000000E-01,
6.61970763942530E-02,
6.61970763942530E-02,
6.61970763942530E-02,
6.29695902724135E-02,
6.29695902724135E-02,
6.29695902724135E-02])
elif degree <= 6:
xi = onp.array([[5.01426509658179E-01, 2.49286745170910E-01],
xi = np.array([[5.01426509658179E-01, 2.49286745170910E-01],
[2.49286745170910E-01, 5.01426509658179E-01],
[2.49286745170910E-01, 2.49286745170910E-01],
[8.73821971016996E-01, 6.30890144915020E-02],
Expand All @@ -119,7 +125,7 @@ def create_quadrature_rule_on_triangle(degree):
[6.36502499121399E-01, 3.10352451033784E-01],
[3.10352451033784E-01, 5.31450498448170E-02]])

w = onp.array([5.83931378631895E-02,
w = np.array([5.83931378631895E-02,
5.83931378631895E-02,
5.83931378631895E-02,
2.54224531851035E-02,
Expand All @@ -132,7 +138,7 @@ def create_quadrature_rule_on_triangle(degree):
4.14255378091870E-02,
4.14255378091870E-02])
elif degree <= 10:
xi = onp.array([[0.33333333333333333E+00, 0.33333333333333333E+00],
xi = np.array([[0.33333333333333333E+00, 0.33333333333333333E+00],
[0.4269134091050342E-02, 0.49786543295447483E+00],
[0.49786543295447483E+00, 0.4269134091050342E-02],
[0.49786543295447483E+00, 0.49786543295447483E+00],
Expand All @@ -158,7 +164,7 @@ def create_quadrature_rule_on_triangle(degree):
[0.6297073291529187E+00, 0.3327436005886386E+00],
[0.37549070258442674E-01, 0.6297073291529187E+00]])

w = onp.array([0.4176169990259819E-01,
w = np.array([0.4176169990259819E-01,
0.36149252960283717E-02,
0.36149252960283717E-02,
0.36149252960283717E-02,
Expand Down Expand Up @@ -231,6 +237,7 @@ def _gauss_quad_1D_3pt(_):
0., 0.])
return xi,w


def _gauss_quad_1D_4pt(_):
xi = np.array([-0.8611363115940526 , -0.33998104358485626, 0.33998104358485626,
0.8611363115940526 , 0.])
Expand Down
6 changes: 3 additions & 3 deletions optimism/inverse/test/test_J2Plastic_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update_internal_vars_test(Uu, ivs_prev):
dc_du, dc_dc_n = update_internal_variables_derivs(Uu, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
nQpsPerElem = len(self.quadRule)
nIntVars = 10
nFreeDofs = Uu.shape[0]

Expand All @@ -205,7 +205,7 @@ def update_internal_vars_test(U, ivs_prev):
dc_du, dc_dc_n = update_internal_variables_derivs(U, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
nQpsPerElem = len(self.quadRule)
nIntVars = 10
nDims = 2
nNodes = Mesh.num_nodes(self.mesh)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_state_derivs_computed_locally_at_plastic_step(self):
dc_dc_n = ivsUpdateInverseFuncs.ivs_update_jac_ivs_prev(U, self.ivs_prev)

nElems = Mesh.num_elements(self.mesh)
nQpsPerElem = QuadratureRule.len(self.quadRule)
nQpsPerElem = len(self.quadRule)
nIntVars = 10

self.assertEqual(dc_dc_n.shape, (nElems,nQpsPerElem,nIntVars,nIntVars))
Expand Down
2 changes: 1 addition & 1 deletion optimism/material/test/test_J2Plastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_plasticity_with_mesh(self):
Ubc = dofManager.get_bc_values(U)

nElems = Mesh.num_elements(mesh)
nQpsPerElem = QuadratureRule.len(quadRule)
nQpsPerElem = len(quadRule)
internalVariables = mechFuncs.compute_initial_state()

tOld = 0.0
Expand Down
2 changes: 1 addition & 1 deletion optimism/phasefield/PhaseField.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def compute_strain_energy_density(U, Q, dt=0.0):
return FunctionSpace.evaluate_on_block(fs, U, Q, dt, L_strain, slice(None), modify_element_gradient=modify_element_gradient)

def compute_initial_state():
return materialModel.compute_initial_state((Mesh.num_elements(fs.mesh), QuadratureRule.len(fs.quadratureRule), 1))
return materialModel.compute_initial_state((Mesh.num_elements(fs.mesh), len(fs.quadratureRule), 1))

L_compute_state_new = energy_density_to_lagrangian_density(materialModel.compute_state_new)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def setUp(self):
materialModel)

self.nElements = Mesh.num_elements(self.mesh)
self.nQuadPtsPerElem = QuadratureRule.len(quadRule)
self.nQuadPtsPerElem = len(quadRule)
self.stateVars = self.bvpFunctions.compute_initial_state()

dofToUnknown = self.dofManager.dofToUnknown.reshape(self.U.shape)
Expand Down
Loading

0 comments on commit bff3a67

Please sign in to comment.