-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Hypervisco #68
[WIP] Hypervisco #68
Changes from all commits
b0321db
fc704ce
4771751
ebf08ec
667bc4a
5403f5d
0e5f382
35d0bb5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
from functools import partial | ||
import jax | ||
from jax import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
from optimism import EquationSolver as EqSolver | ||
from optimism import FunctionSpace | ||
# from optimism.material import J2Plastic as Material | ||
#from optimism.material import Neohookean | ||
from optimism.material import Neohookean | ||
from optimism.material import HyperViscoelastic as Material | ||
from optimism import Mechanics | ||
from optimism import Mesh | ||
from optimism import Objective | ||
from optimism import QuadratureRule | ||
from optimism import SparseMatrixAssembler | ||
from optimism import VTKWriter | ||
|
||
dt = 0.01 | ||
|
||
|
||
class Uniaxial: | ||
|
||
def __init__(self): | ||
self.w = 0.1 | ||
self.L = 1.0 | ||
N = 15 | ||
M = 4 | ||
L = self.L | ||
w = self.w | ||
xRange = [0.0, L] | ||
yRange = [0.0, w] | ||
|
||
coords, conns = Mesh.create_structured_mesh_data(N, M, xRange, yRange) | ||
blocks = {'block_0': np.arange(conns.shape[0])} | ||
mesh = Mesh.construct_mesh_from_basic_data(coords, conns, blocks=blocks) | ||
pOrder = 2 | ||
mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(mesh, order=pOrder) | ||
|
||
nodeSets = {'left': np.flatnonzero(mesh.coords[:,0] < xRange[0] + 1e-8), | ||
'right': np.flatnonzero(mesh.coords[:,0] > xRange[1] - 1e-8), | ||
'bottom': np.flatnonzero(mesh.coords[:,1] < yRange[0] + 1e-8)} | ||
self.mesh = Mesh.mesh_with_nodesets(mesh, nodeSets) | ||
|
||
quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2*(pOrder-1)) | ||
self.fs = FunctionSpace.construct_function_space(self.mesh, quadRule) | ||
|
||
ebcs = [FunctionSpace.EssentialBC(nodeSet='left', component=0), | ||
FunctionSpace.EssentialBC(nodeSet='right', component=0), | ||
FunctionSpace.EssentialBC(nodeSet='bottom', component=1)] | ||
|
||
self.dofManager = FunctionSpace.DofManager(self.fs, dim=2, EssentialBCs=ebcs) | ||
|
||
# E = 10.0 | ||
# nu = 0.25 | ||
# Y0 = 0.01*E | ||
# H = E/100 | ||
# props = {'elastic modulus': E, | ||
# 'poisson ratio': nu, | ||
# 'yield strength': Y0, | ||
# 'hardening model': 'linear', | ||
# 'hardening modulus': H} | ||
|
||
# materialModel = Material.create_material_model_functions(props) | ||
|
||
self.K_eq = 1.e2 | ||
self.G_eq = 1.0 | ||
|
||
G_neq_1 = 5.0 | ||
tau_1 = 0.1 | ||
|
||
props = { | ||
'equilibrium bulk modulus' : self.K_eq, | ||
'equilibrium shear modulus' : self.G_eq, | ||
# | ||
'non equilibrium shear modulus': G_neq_1, | ||
'relaxation time' : tau_1, | ||
} | ||
materialModel = Material.create_material_model_functions(props) | ||
|
||
self.mechanicsFunctions = Mechanics.create_mechanics_functions( | ||
self.fs, "plane strain", materialModel, dt=dt | ||
) | ||
|
||
self.outputForce = [] | ||
self.outputDisp = [] | ||
|
||
|
||
def assemble_sparse(self, Uu, p): | ||
U = self.create_field(Uu, p) | ||
internalVariables = p[1] | ||
elementStiffnesses = self.mechanicsFunctions.\ | ||
compute_element_stiffnesses(U, internalVariables) | ||
return SparseMatrixAssembler.assemble_sparse_stiffness_matrix(elementStiffnesses, | ||
self.mesh.conns, | ||
self.dofManager) | ||
|
||
|
||
def energy_function(self, Uu, p): | ||
U = self.create_field(Uu, p) | ||
internalVariables = p[1] | ||
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables) | ||
|
||
|
||
@partial(jax.jit, static_argnums=0) | ||
@partial(jax.value_and_grad, argnums=2) | ||
def compute_reactions_from_bcs(self, Uu, Ubc, internalVariables): | ||
U = self.dofManager.create_field(Uu, Ubc) | ||
return self.mechanicsFunctions.compute_strain_energy(U, internalVariables) | ||
|
||
|
||
def create_field(self, Uu, p): | ||
return self.dofManager.create_field(Uu, self.get_ubcs(p)) | ||
|
||
|
||
def get_ubcs(self, p): | ||
endDisp = p[0] | ||
EbcIndex = (self.mesh.nodeSets['right'],0) | ||
V = np.zeros_like(self.mesh.coords).at[EbcIndex].set(endDisp) | ||
return self.dofManager.get_bc_values(V) | ||
|
||
|
||
def write_output(self, Uu, p, step): | ||
print('writing output') | ||
vtkFileName = 'uniaxial-' + str(step).zfill(3) | ||
writer = VTKWriter.VTKWriter(self.mesh, baseFileName=vtkFileName) | ||
|
||
U = self.create_field(Uu, p) | ||
internalVariables = p[1] | ||
|
||
writer.add_nodal_field(name='displacement', nodalData=U, fieldType=VTKWriter.VTKFieldType.VECTORS) | ||
|
||
bcs = np.array(self.dofManager.isBc, dtype=int) | ||
writer.add_nodal_field(name='bcs', nodalData=bcs, fieldType=VTKWriter.VTKFieldType.VECTORS, dataType=VTKWriter.VTKDataType.INT) | ||
|
||
Ubc = self.dofManager.get_bc_values(U) | ||
_,rxnBc = self.compute_reactions_from_bcs(Uu, Ubc, internalVariables) | ||
reactions = np.zeros(U.shape).at[self.dofManager.isBc].set(rxnBc) | ||
writer.add_nodal_field(name='reactions', nodalData=reactions, fieldType=VTKWriter.VTKFieldType.VECTORS) | ||
|
||
if hasattr(Material, 'EQPS'): | ||
eqpsField = internalVariables[:,:,Material.EQPS] | ||
cellEqpsField = FunctionSpace.\ | ||
project_quadrature_field_to_element_field(self.fs, eqpsField) | ||
writer.add_cell_field(name='eqps', cellData=cellEqpsField, fieldType=VTKWriter.VTKFieldType.SCALARS) | ||
|
||
strainEnergyDensities, stresses = \ | ||
self.mechanicsFunctions.\ | ||
compute_output_energy_densities_and_stresses(U, internalVariables) | ||
cellStrainEnergyDensities = FunctionSpace.\ | ||
project_quadrature_field_to_element_field(self.fs, strainEnergyDensities) | ||
cellStresses = FunctionSpace.\ | ||
project_quadrature_field_to_element_field(self.fs, stresses) | ||
writer.add_cell_field(name='strain_energy_density', | ||
cellData=cellStrainEnergyDensities, | ||
fieldType=VTKWriter.VTKFieldType.SCALARS) | ||
writer.add_cell_field(name='stress', | ||
cellData=cellStresses, | ||
fieldType=VTKWriter.VTKFieldType.TENSORS) | ||
|
||
writer.write() | ||
|
||
self.outputForce.append(float(np.sum(reactions[self.mesh.nodeSets['right'],0]))) | ||
self.outputDisp.append(float(p[0])) | ||
|
||
|
||
def run(self): | ||
|
||
Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape)) | ||
settings = EqSolver.get_settings(max_cumulative_cg_iters=100, | ||
max_trust_iters=1000, | ||
use_preconditioned_inner_product_for_cg=True) | ||
|
||
xDisp = 0.0 | ||
state = self.mechanicsFunctions.compute_initial_state() | ||
p = Objective.Params(xDisp, state) | ||
|
||
precondStrategy = Objective.PrecondStrategy(self.assemble_sparse) | ||
objective = Objective.ScaledObjective(self.energy_function, Uu, p, precondStrategy) | ||
|
||
self.write_output(Uu, p, step=0) | ||
|
||
N = 250 | ||
|
||
# maxDisp = self.L*0.05 | ||
maxDisp = self.L | ||
for i in range(1, N+1): | ||
print('LOAD STEP ', i, '------------------------\n') | ||
if i < N / 10: | ||
xDisp = i/((N / 10)*maxDisp) | ||
p = Objective.param_index_update(p, 0, xDisp) | ||
|
||
Uu = EqSolver.nonlinear_equation_solve(objective, Uu, p, settings) | ||
|
||
state = self.mechanicsFunctions.\ | ||
compute_updated_internal_variables(self.create_field(Uu, p), p[1]) | ||
p = Objective.param_index_update(p, 1, state) | ||
|
||
self.write_output(Uu, p, i) | ||
|
||
|
||
def make_FD_plot(self): | ||
plt.figure(1) | ||
plt.plot(self.outputDisp, self.outputForce, marker='o') | ||
plt.xlabel('Displacement') | ||
plt.ylabel('Force') | ||
# plt.savefig('uniaxial_FD.pdf') | ||
|
||
times = np.linspace(0.0, 2.0, num=len(self.outputDisp)) | ||
plt.figure(2) | ||
plt.plot(times, self.outputDisp, marker='o') | ||
plt.xlabel('Time') | ||
plt.ylabel('Displacement') | ||
# plt.savefig('uniaxial_TU.pdf') | ||
|
||
plt.figure(3) | ||
plt.plot(times, self.outputForce, marker='o') | ||
plt.xlabel('Time') | ||
plt.ylabel('Force') | ||
# plt.savefig('uniaxial_TF.pdf') | ||
|
||
|
||
if __name__=='__main__': | ||
app = Uniaxial() | ||
app.run() | ||
app.make_FD_plot() | ||
|
||
# reset mechanics functions to just do the equilibrium | ||
app.outputDisp = [] | ||
app.outputForce = [] | ||
|
||
props = { | ||
# 'bulk modulus' : app.K_eq, | ||
# 'shear modulus': app.G_eq, | ||
'elastic modulus': (9. * app.K_eq * app.G_eq) / (3. * app.K_eq + app.G_eq), | ||
'poisson ratio' : (3. * app.K_eq - 2. * app.G_eq) / (2. * (3. * app.K_eq + app.G_eq)), | ||
'version' : 'adagio' | ||
} | ||
materialModel = Neohookean.create_material_model_functions(props) | ||
|
||
app.mechanicsFunctions = Mechanics.create_mechanics_functions( | ||
app.fs, "plane strain", materialModel, dt=0.0 | ||
) | ||
|
||
app.run() | ||
app.make_FD_plot() | ||
|
||
plt.figure(1) | ||
plt.savefig('uniaxial_FD.pdf') | ||
|
||
plt.figure(2) | ||
plt.savefig('uniaxial_TU.pdf') | ||
|
||
plt.figure(3) | ||
plt.savefig('uniaxial_TF.pdf') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import jax.numpy as np | ||
from jax import vmap | ||
from jax.scipy import linalg | ||
from optimism import TensorMath | ||
from optimism.material.MaterialModel import MaterialModel | ||
|
||
|
||
# props | ||
PROPS_K_eq = 0 | ||
PROPS_G_eq = 1 | ||
PROPS_G_neq = 2 | ||
PROPS_TAU = 3 | ||
|
||
NUM_PRONY_TERMS = -1 | ||
|
||
# isvs | ||
VISCOUS_DISTORTION = slice(0, 9) | ||
|
||
def create_material_model_functions(properties): | ||
|
||
# prop processing | ||
density = properties.get('density') | ||
props = _make_properties(properties) | ||
|
||
# energy function wrapper | ||
def energy_density(dispGrad, state, dt): | ||
return _energy_density(dispGrad, state, dt, props) | ||
|
||
# wrapper for state var ics | ||
def compute_initial_state(shape=(1,)): | ||
# num_prony_terms = properties['number of prony terms'] | ||
|
||
# def vmap_body(_): | ||
# return np.identity(3).ravel() | ||
|
||
# state = np.hstack(vmap(vmap_body, in_axes=(0,))(np.arange(num_prony_terms))) | ||
state = np.identity(3).ravel() | ||
return state | ||
|
||
# update state vars wrapper | ||
def compute_state_new(dispGrad, state, dt): | ||
state = _compute_state_new(dispGrad, state, dt, props) | ||
return state | ||
|
||
return MaterialModel( | ||
energy_density, | ||
compute_initial_state, | ||
compute_state_new, | ||
density | ||
) | ||
|
||
# implementation | ||
def _make_properties(properties): | ||
# assert properties['number of prony terms'] > 0, 'Need at least 1 prony term' | ||
assert 'equilibrium bulk modulus' in properties.keys() | ||
assert 'equilibrium shear modulus' in properties.keys() | ||
# for n in range(1, properties['number of prony terms'] + 1): | ||
# assert 'non equilibrium shear modulus %s' % n in properties.keys() | ||
# assert 'relaxation time %s' % n in properties.keys() | ||
assert 'non equilibrium shear modulus' in properties.keys() | ||
assert 'relaxation time' in properties.keys() | ||
|
||
print('Equilibrium properties') | ||
print(' Bulk modulus = %s' % properties['equilibrium bulk modulus']) | ||
print(' Shear modulus = %s' % properties['equilibrium shear modulus']) | ||
print('Prony branch properties') | ||
print(' Shear modulus = %s' % properties['non equilibrium shear modulus']) | ||
print(' Relaxation time = %s' % properties['relaxation time']) | ||
# this is dirty, fuck jax (can't use an int from a jax numpy array or else jit tries to trace that) | ||
# global NUM_PRONY_TERMS | ||
# NUM_PRONY_TERMS = properties['number of prony terms'] | ||
|
||
# first pack equilibrium properties | ||
props = np.array([ | ||
properties['equilibrium bulk modulus'], | ||
properties['equilibrium shear modulus'], | ||
properties['non equilibrium shear modulus'], | ||
properties['relaxation time'] | ||
]) | ||
|
||
# props = np.hstack((props, properties['number of prony terms'])) | ||
|
||
# for n in range(1, properties['number of prony terms'] + 1): | ||
# print('Prony branch %s properties' % n) | ||
# print(' Shear modulus = %s' % properties['non equilibrium shear modulus %s' % n]) | ||
# print(' Relaxation time = %s' % properties['relaxation time %s' % n]) | ||
# props = np.hstack( | ||
# (props, np.array([properties['non equilibrium shear modulus %s' % n], | ||
# properties['relaxation time %s' % n]]))) | ||
|
||
|
||
|
||
return props | ||
|
||
def _energy_density(dispGrad, state, dt, props): | ||
W_eq = _eq_strain_energy(dispGrad, props) | ||
W_neq = _neq_strain_energy(dispGrad, state, dt, props) | ||
return W_eq + W_neq | ||
|
||
# TODO generalize to arbitrary strain energy density | ||
def _eq_strain_energy(dispGrad, props): | ||
K, G = props[PROPS_K_eq], props[PROPS_G_eq] | ||
F = dispGrad + np.eye(3) | ||
J = np.linalg.det(F) | ||
J23 = np.power(J, -2.0 / 3.0) | ||
I1Bar = J23 * np.tensordot(F,F) | ||
Wvol = 0.5 * K * (0.5 * J**2 - 0.5 - np.log(J)) | ||
Wdev = 0.5 * G * (I1Bar - 3.0) | ||
return Wdev + Wvol | ||
|
||
def _neq_strain_energy(dispGrad, stateOld, dt, props): | ||
I = np.identity(3) | ||
F = dispGrad + I | ||
state_new = _compute_state_new(dispGrad, stateOld, dt, props) | ||
# Fvs = state_new.reshape((NUM_PRONY_TERMS, 3, 3)) | ||
Fv_new = state_new.reshape((3, 3)) | ||
|
||
G_neq = props[PROPS_G_neq] | ||
tau = props[PROPS_TAU] | ||
eta = G_neq * tau | ||
|
||
Fe = F @ np.linalg.inv(Fv_new) | ||
Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe) | ||
Me = 2. * G_neq * Ee | ||
M_bar = TensorMath.norm_of_deviator_squared(Me) | ||
gamma_dot = M_bar / eta | ||
# visco_energy = (dt / (G_neq * tau)) * M_bar**2 | ||
visco_energy = 0.5 * dt * eta * gamma_dot**2 | ||
|
||
W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) + visco_energy | ||
# def vmap_body(n, Fv): | ||
# G_neq = props[PROPS_G_neq + 2 * n] | ||
# # tau = props[PROPS_TAU + 2 * n] | ||
Comment on lines
+131
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW, If we end up implementing multiple Prony terms, I'd suggest trying with a |
||
|
||
# Fe = F @ np.linalg.inv(Fv) | ||
# Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe) | ||
|
||
# # viscous shearing | ||
# # Me = 2. * G_neq * Ee | ||
# # M_bar = TensorMath.norm_of_deviator_squared(Me) | ||
# # visco_energy = (dt / (G_neq * tau)) * M_bar**2 | ||
|
||
# # still need another term I think | ||
# W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) #+ visco_energy | ||
# return W_neq | ||
|
||
# W_neq = np.sum(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fvs)) | ||
|
||
return W_neq | ||
|
||
# state update | ||
def _compute_state_new(dispGrad, stateOld, dt, props): | ||
state_inc = _compute_state_increment(dispGrad, stateOld, dt, props) | ||
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3)) | ||
# Fv_incs = state_inc.reshape((NUM_PRONY_TERMS, 3, 3)) | ||
|
||
# def vmap_body(n, Fv_old, Fv_inc): | ||
# Fv_new = Fv_inc @ Fv_old | ||
# return Fv_new.ravel() | ||
|
||
# state_new = np.hstack(vmap(vmap_body, in_axes=(0, 0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds, Fv_incs)) | ||
|
||
Fv_old = stateOld.reshape((3, 3)) | ||
Fv_inc = state_inc.reshape((3, 3)) | ||
state_new = (Fv_inc @ Fv_old).ravel() | ||
return state_new | ||
|
||
def _compute_state_increment(dispGrad, stateOld, dt, props): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The J2 model had this function so that we could put the finite strain and the infinitesimal strain versions in one implementation. For this model, which only makes sense as a finite deformation model, we might as well do all of the update in the |
||
I = np.identity(3) | ||
F = dispGrad + I | ||
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3)) | ||
|
||
# def vmap_body(n, Fv_old): | ||
# # TODO add shift factor | ||
# G_neq = props[PROPS_G_neq + 2 * n] | ||
# tau = props[PROPS_TAU + 2 * n] | ||
|
||
# # kinematics | ||
# Fe_trial = F @ np.linalg.inv(Fv_old) | ||
# Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial) | ||
# Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I | ||
|
||
# # updates | ||
# integration_factor = 1. / (1. + dt / tau) | ||
|
||
# Me = 2.0 * G_neq * Ee_dev | ||
# Me = integration_factor * Me | ||
|
||
# Dv = (1. / (2. * G_neq * tau)) * Me | ||
# A = dt * Dv | ||
|
||
# Fv_inc = linalg.expm(A) | ||
|
||
# return Fv_inc.ravel() | ||
|
||
# state_inc = np.hstack(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds)) | ||
|
||
Fv_old = stateOld.reshape((3, 3)) | ||
G_neq = props[PROPS_G_neq] | ||
tau = props[PROPS_TAU] | ||
|
||
Fe_trial = F @ np.linalg.inv(Fv_old) | ||
Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial) | ||
Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I | ||
|
||
integration_factor = 1. / (1. + dt / tau) | ||
|
||
Me = 2.0 * G_neq * Ee_dev | ||
Me = integration_factor * Me | ||
|
||
Dv = (1. / (2. * G_neq * tau)) * Me | ||
A = dt * Dv | ||
|
||
Fv_inc = linalg.expm(A) | ||
|
||
return Fv_inc.ravel() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Late comment. Can I ask, why did
dt
get moved into the factory function?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I remember correctly, things weren't properly connected so dt wasn't making its way to the material models. Maybe I misunderstood the interface as it was though?