Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Hypervisco #68

Merged
merged 8 commits into from
Nov 3, 2023
255 changes: 255 additions & 0 deletions examples/uniaxial/UniaxialCycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
from functools import partial
import jax
from jax import numpy as np
from matplotlib import pyplot as plt

from optimism import EquationSolver as EqSolver
from optimism import FunctionSpace
# from optimism.material import J2Plastic as Material
#from optimism.material import Neohookean
from optimism.material import Neohookean
from optimism.material import HyperViscoelastic as Material
from optimism import Mechanics
from optimism import Mesh
from optimism import Objective
from optimism import QuadratureRule
from optimism import SparseMatrixAssembler
from optimism import VTKWriter

dt = 0.01


class Uniaxial:

def __init__(self):
self.w = 0.1
self.L = 1.0
N = 15
M = 4
L = self.L
w = self.w
xRange = [0.0, L]
yRange = [0.0, w]

coords, conns = Mesh.create_structured_mesh_data(N, M, xRange, yRange)
blocks = {'block_0': np.arange(conns.shape[0])}
mesh = Mesh.construct_mesh_from_basic_data(coords, conns, blocks=blocks)
pOrder = 2
mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(mesh, order=pOrder)

nodeSets = {'left': np.flatnonzero(mesh.coords[:,0] < xRange[0] + 1e-8),
'right': np.flatnonzero(mesh.coords[:,0] > xRange[1] - 1e-8),
'bottom': np.flatnonzero(mesh.coords[:,1] < yRange[0] + 1e-8)}
self.mesh = Mesh.mesh_with_nodesets(mesh, nodeSets)

quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2*(pOrder-1))
self.fs = FunctionSpace.construct_function_space(self.mesh, quadRule)

ebcs = [FunctionSpace.EssentialBC(nodeSet='left', component=0),
FunctionSpace.EssentialBC(nodeSet='right', component=0),
FunctionSpace.EssentialBC(nodeSet='bottom', component=1)]

self.dofManager = FunctionSpace.DofManager(self.fs, dim=2, EssentialBCs=ebcs)

# E = 10.0
# nu = 0.25
# Y0 = 0.01*E
# H = E/100
# props = {'elastic modulus': E,
# 'poisson ratio': nu,
# 'yield strength': Y0,
# 'hardening model': 'linear',
# 'hardening modulus': H}

# materialModel = Material.create_material_model_functions(props)

self.K_eq = 1.e2
self.G_eq = 1.0

G_neq_1 = 5.0
tau_1 = 0.1

props = {
'equilibrium bulk modulus' : self.K_eq,
'equilibrium shear modulus' : self.G_eq,
#
'non equilibrium shear modulus': G_neq_1,
'relaxation time' : tau_1,
}
materialModel = Material.create_material_model_functions(props)

self.mechanicsFunctions = Mechanics.create_mechanics_functions(
self.fs, "plane strain", materialModel, dt=dt
)

self.outputForce = []
self.outputDisp = []


def assemble_sparse(self, Uu, p):
U = self.create_field(Uu, p)
internalVariables = p[1]
elementStiffnesses = self.mechanicsFunctions.\
compute_element_stiffnesses(U, internalVariables)
return SparseMatrixAssembler.assemble_sparse_stiffness_matrix(elementStiffnesses,
self.mesh.conns,
self.dofManager)


def energy_function(self, Uu, p):
U = self.create_field(Uu, p)
internalVariables = p[1]
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)


@partial(jax.jit, static_argnums=0)
@partial(jax.value_and_grad, argnums=2)
def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables):
U = self.dofManager.create_field(Uu, Ubc)
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables)


def create_field(self, Uu, p):
return self.dofManager.create_field(Uu, self.get_ubcs(p))


def get_ubcs(self, p):
endDisp = p[0]
EbcIndex = (self.mesh.nodeSets['right'],0)
V = np.zeros_like(self.mesh.coords).at[EbcIndex].set(endDisp)
return self.dofManager.get_bc_values(V)


def write_output(self, Uu, p, step):
print('writing output')
vtkFileName = 'uniaxial-' + str(step).zfill(3)
writer = VTKWriter.VTKWriter(self.mesh, baseFileName=vtkFileName)

U = self.create_field(Uu, p)
internalVariables = p[1]

writer.add_nodal_field(name='displacement', nodalData=U, fieldType=VTKWriter.VTKFieldType.VECTORS)

bcs = np.array(self.dofManager.isBc, dtype=int)
writer.add_nodal_field(name='bcs', nodalData=bcs, fieldType=VTKWriter.VTKFieldType.VECTORS, dataType=VTKWriter.VTKDataType.INT)

Ubc = self.dofManager.get_bc_values(U)
_,rxnBc = self.compute_reactions_from_bcs(Uu, Ubc, internalVariables)
reactions = np.zeros(U.shape).at[self.dofManager.isBc].set(rxnBc)
writer.add_nodal_field(name='reactions', nodalData=reactions, fieldType=VTKWriter.VTKFieldType.VECTORS)

if hasattr(Material, 'EQPS'):
eqpsField = internalVariables[:,:,Material.EQPS]
cellEqpsField = FunctionSpace.\
project_quadrature_field_to_element_field(self.fs, eqpsField)
writer.add_cell_field(name='eqps', cellData=cellEqpsField, fieldType=VTKWriter.VTKFieldType.SCALARS)

strainEnergyDensities, stresses = \
self.mechanicsFunctions.\
compute_output_energy_densities_and_stresses(U, internalVariables)
cellStrainEnergyDensities = FunctionSpace.\
project_quadrature_field_to_element_field(self.fs, strainEnergyDensities)
cellStresses = FunctionSpace.\
project_quadrature_field_to_element_field(self.fs, stresses)
writer.add_cell_field(name='strain_energy_density',
cellData=cellStrainEnergyDensities,
fieldType=VTKWriter.VTKFieldType.SCALARS)
writer.add_cell_field(name='stress',
cellData=cellStresses,
fieldType=VTKWriter.VTKFieldType.TENSORS)

writer.write()

self.outputForce.append(float(np.sum(reactions[self.mesh.nodeSets['right'],0])))
self.outputDisp.append(float(p[0]))


def run(self):

Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape))
settings = EqSolver.get_settings(max_cumulative_cg_iters=100,
max_trust_iters=1000,
use_preconditioned_inner_product_for_cg=True)

xDisp = 0.0
state = self.mechanicsFunctions.compute_initial_state()
p = Objective.Params(xDisp, state)

precondStrategy = Objective.PrecondStrategy(self.assemble_sparse)
objective = Objective.ScaledObjective(self.energy_function, Uu, p, precondStrategy)

self.write_output(Uu, p, step=0)

N = 250

# maxDisp = self.L*0.05
maxDisp = self.L
for i in range(1, N+1):
print('LOAD STEP ', i, '------------------------\n')
if i < N / 10:
xDisp = i/((N / 10)*maxDisp)
p = Objective.param_index_update(p, 0, xDisp)

Uu = EqSolver.nonlinear_equation_solve(objective, Uu, p, settings)

state = self.mechanicsFunctions.\
compute_updated_internal_variables(self.create_field(Uu, p), p[1])
p = Objective.param_index_update(p, 1, state)

self.write_output(Uu, p, i)


def make_FD_plot(self):
plt.figure(1)
plt.plot(self.outputDisp, self.outputForce, marker='o')
plt.xlabel('Displacement')
plt.ylabel('Force')
# plt.savefig('uniaxial_FD.pdf')

times = np.linspace(0.0, 2.0, num=len(self.outputDisp))
plt.figure(2)
plt.plot(times, self.outputDisp, marker='o')
plt.xlabel('Time')
plt.ylabel('Displacement')
# plt.savefig('uniaxial_TU.pdf')

plt.figure(3)
plt.plot(times, self.outputForce, marker='o')
plt.xlabel('Time')
plt.ylabel('Force')
# plt.savefig('uniaxial_TF.pdf')


if __name__=='__main__':
app = Uniaxial()
app.run()
app.make_FD_plot()

# reset mechanics functions to just do the equilibrium
app.outputDisp = []
app.outputForce = []

props = {
# 'bulk modulus' : app.K_eq,
# 'shear modulus': app.G_eq,
'elastic modulus': (9. * app.K_eq * app.G_eq) / (3. * app.K_eq + app.G_eq),
'poisson ratio' : (3. * app.K_eq - 2. * app.G_eq) / (2. * (3. * app.K_eq + app.G_eq)),
'version' : 'adagio'
}
materialModel = Neohookean.create_material_model_functions(props)

app.mechanicsFunctions = Mechanics.create_mechanics_functions(
app.fs, "plane strain", materialModel, dt=0.0
)

app.run()
app.make_FD_plot()

plt.figure(1)
plt.savefig('uniaxial_FD.pdf')

plt.figure(2)
plt.savefig('uniaxial_TU.pdf')

plt.figure(3)
plt.savefig('uniaxial_TF.pdf')
12 changes: 7 additions & 5 deletions optimism/Mechanics.py
Original file line number Diff line number Diff line change
@@ -246,7 +246,9 @@ def compute_initial_state():
######


def create_mechanics_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
def create_mechanics_functions(functionSpace, mode2D, materialModel,
pressureProjectionDegree=None,
dt=0.0):
Comment on lines +249 to +251
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Late comment. Can I ask, why did dt get moved into the factory function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, things weren't properly connected so dt wasn't making its way to the material models. Maybe I misunderstood the interface as it was though?

fs = functionSpace

if mode2D == 'plane strain':
@@ -267,23 +269,23 @@ def modify_element_gradient(elemGrads, elemShapes, elemVols, elemNodalDisps, ele
return grad_2D_to_3D(elemGrads, elemShapes, elemVols, elemNodalDisps, elemNodalCoords)


def compute_strain_energy(U, stateVariables, dt=0.0):
def compute_strain_energy(U, stateVariables, dt=dt):
return _compute_strain_energy(fs, U, stateVariables, dt, materialModel.compute_energy_density, modify_element_gradient)


def compute_updated_internal_variables(U, stateVariables, dt=0.0):
def compute_updated_internal_variables(U, stateVariables, dt=dt):
return _compute_updated_internal_variables(fs, U, stateVariables, dt, materialModel.compute_state_new, modify_element_gradient)


def compute_element_stiffnesses(U, stateVariables, dt=0.0):
def compute_element_stiffnesses(U, stateVariables, dt=dt):
return _compute_element_stiffnesses(U, stateVariables, dt, fs, materialModel.compute_energy_density, modify_element_gradient)


output_lagrangian = strain_energy_density_to_lagrangian_density(materialModel.compute_energy_density)
output_constitutive = value_and_grad(output_lagrangian, 1)


def compute_output_energy_densities_and_stresses(U, stateVariables, dt=0.0):
def compute_output_energy_densities_and_stresses(U, stateVariables, dt=dt):
return FunctionSpace.evaluate_on_block(fs, U, stateVariables, dt, output_constitutive, slice(None), modify_element_gradient=modify_element_gradient)


216 changes: 216 additions & 0 deletions optimism/material/HyperViscoelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import jax.numpy as np
from jax import vmap
from jax.scipy import linalg
from optimism import TensorMath
from optimism.material.MaterialModel import MaterialModel


# props
PROPS_K_eq = 0
PROPS_G_eq = 1
PROPS_G_neq = 2
PROPS_TAU = 3

NUM_PRONY_TERMS = -1

# isvs
VISCOUS_DISTORTION = slice(0, 9)

def create_material_model_functions(properties):

# prop processing
density = properties.get('density')
props = _make_properties(properties)

# energy function wrapper
def energy_density(dispGrad, state, dt):
return _energy_density(dispGrad, state, dt, props)

# wrapper for state var ics
def compute_initial_state(shape=(1,)):
# num_prony_terms = properties['number of prony terms']

# def vmap_body(_):
# return np.identity(3).ravel()

# state = np.hstack(vmap(vmap_body, in_axes=(0,))(np.arange(num_prony_terms)))
state = np.identity(3).ravel()
return state

# update state vars wrapper
def compute_state_new(dispGrad, state, dt):
state = _compute_state_new(dispGrad, state, dt, props)
return state

return MaterialModel(
energy_density,
compute_initial_state,
compute_state_new,
density
)

# implementation
def _make_properties(properties):
# assert properties['number of prony terms'] > 0, 'Need at least 1 prony term'
assert 'equilibrium bulk modulus' in properties.keys()
assert 'equilibrium shear modulus' in properties.keys()
# for n in range(1, properties['number of prony terms'] + 1):
# assert 'non equilibrium shear modulus %s' % n in properties.keys()
# assert 'relaxation time %s' % n in properties.keys()
assert 'non equilibrium shear modulus' in properties.keys()
assert 'relaxation time' in properties.keys()

print('Equilibrium properties')
print(' Bulk modulus = %s' % properties['equilibrium bulk modulus'])
print(' Shear modulus = %s' % properties['equilibrium shear modulus'])
print('Prony branch properties')
print(' Shear modulus = %s' % properties['non equilibrium shear modulus'])
print(' Relaxation time = %s' % properties['relaxation time'])
# this is dirty, fuck jax (can't use an int from a jax numpy array or else jit tries to trace that)
# global NUM_PRONY_TERMS
# NUM_PRONY_TERMS = properties['number of prony terms']

# first pack equilibrium properties
props = np.array([
properties['equilibrium bulk modulus'],
properties['equilibrium shear modulus'],
properties['non equilibrium shear modulus'],
properties['relaxation time']
])

# props = np.hstack((props, properties['number of prony terms']))

# for n in range(1, properties['number of prony terms'] + 1):
# print('Prony branch %s properties' % n)
# print(' Shear modulus = %s' % properties['non equilibrium shear modulus %s' % n])
# print(' Relaxation time = %s' % properties['relaxation time %s' % n])
# props = np.hstack(
# (props, np.array([properties['non equilibrium shear modulus %s' % n],
# properties['relaxation time %s' % n]])))



return props

def _energy_density(dispGrad, state, dt, props):
W_eq = _eq_strain_energy(dispGrad, props)
W_neq = _neq_strain_energy(dispGrad, state, dt, props)
return W_eq + W_neq

# TODO generalize to arbitrary strain energy density
def _eq_strain_energy(dispGrad, props):
K, G = props[PROPS_K_eq], props[PROPS_G_eq]
F = dispGrad + np.eye(3)
J = np.linalg.det(F)
J23 = np.power(J, -2.0 / 3.0)
I1Bar = J23 * np.tensordot(F,F)
Wvol = 0.5 * K * (0.5 * J**2 - 0.5 - np.log(J))
Wdev = 0.5 * G * (I1Bar - 3.0)
return Wdev + Wvol

def _neq_strain_energy(dispGrad, stateOld, dt, props):
I = np.identity(3)
F = dispGrad + I
state_new = _compute_state_new(dispGrad, stateOld, dt, props)
# Fvs = state_new.reshape((NUM_PRONY_TERMS, 3, 3))
Fv_new = state_new.reshape((3, 3))

G_neq = props[PROPS_G_neq]
tau = props[PROPS_TAU]
eta = G_neq * tau

Fe = F @ np.linalg.inv(Fv_new)
Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe)
Me = 2. * G_neq * Ee
M_bar = TensorMath.norm_of_deviator_squared(Me)
gamma_dot = M_bar / eta
# visco_energy = (dt / (G_neq * tau)) * M_bar**2
visco_energy = 0.5 * dt * eta * gamma_dot**2

W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) + visco_energy
# def vmap_body(n, Fv):
# G_neq = props[PROPS_G_neq + 2 * n]
# # tau = props[PROPS_TAU + 2 * n]
Comment on lines +131 to +133
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, If we end up implementing multiple Prony terms, I'd suggest trying with a jax.lax.fori loop, or jax.lax.scan. I think it might be more straightforward than vmap.


# Fe = F @ np.linalg.inv(Fv)
# Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe)

# # viscous shearing
# # Me = 2. * G_neq * Ee
# # M_bar = TensorMath.norm_of_deviator_squared(Me)
# # visco_energy = (dt / (G_neq * tau)) * M_bar**2

# # still need another term I think
# W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) #+ visco_energy
# return W_neq

# W_neq = np.sum(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fvs))

return W_neq

# state update
def _compute_state_new(dispGrad, stateOld, dt, props):
state_inc = _compute_state_increment(dispGrad, stateOld, dt, props)
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3))
# Fv_incs = state_inc.reshape((NUM_PRONY_TERMS, 3, 3))

# def vmap_body(n, Fv_old, Fv_inc):
# Fv_new = Fv_inc @ Fv_old
# return Fv_new.ravel()

# state_new = np.hstack(vmap(vmap_body, in_axes=(0, 0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds, Fv_incs))

Fv_old = stateOld.reshape((3, 3))
Fv_inc = state_inc.reshape((3, 3))
state_new = (Fv_inc @ Fv_old).ravel()
return state_new

def _compute_state_increment(dispGrad, stateOld, dt, props):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The J2 model had this function so that we could put the finite strain and the infinitesimal strain versions in one implementation. For this model, which only makes sense as a finite deformation model, we might as well do all of the update in the _compute_state_new() function which will make it easier to read. Maybe this will speed things up a little, too.

I = np.identity(3)
F = dispGrad + I
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3))

# def vmap_body(n, Fv_old):
# # TODO add shift factor
# G_neq = props[PROPS_G_neq + 2 * n]
# tau = props[PROPS_TAU + 2 * n]

# # kinematics
# Fe_trial = F @ np.linalg.inv(Fv_old)
# Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial)
# Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I

# # updates
# integration_factor = 1. / (1. + dt / tau)

# Me = 2.0 * G_neq * Ee_dev
# Me = integration_factor * Me

# Dv = (1. / (2. * G_neq * tau)) * Me
# A = dt * Dv

# Fv_inc = linalg.expm(A)

# return Fv_inc.ravel()

# state_inc = np.hstack(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds))

Fv_old = stateOld.reshape((3, 3))
G_neq = props[PROPS_G_neq]
tau = props[PROPS_TAU]

Fe_trial = F @ np.linalg.inv(Fv_old)
Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial)
Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I

integration_factor = 1. / (1. + dt / tau)

Me = 2.0 * G_neq * Ee_dev
Me = integration_factor * Me

Dv = (1. / (2. * G_neq * tau)) * Me
A = dt * Dv

Fv_inc = linalg.expm(A)

return Fv_inc.ravel()