Skip to content

Commit

Permalink
Merge pull request #99 from sandialabs/chamel/multi-branch-hyper-visco
Browse files Browse the repository at this point in the history
adding a multi branch (really hardcoded to 3 branches) hypervisco model.
  • Loading branch information
ralberd authored Nov 7, 2024
2 parents d215c9f + 26decde commit 81ccef8
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 0 deletions.
149 changes: 149 additions & 0 deletions optimism/material/MultiBranchHyperViscoelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import jax.numpy as np
from jax.scipy import linalg
from optimism import TensorMath
from optimism.material.MaterialModel import MaterialModel

import jax

PROPS_K_eq = 0
PROPS_G_eq = 1
PROPS_G_neq_1 = 2
PROPS_TAU_1 = 3
PROPS_G_neq_2 = 4
PROPS_TAU_2 = 5
PROPS_G_neq_3 = 6
PROPS_TAU_3 = 7


NUM_PRONY_TERMS = 3
VISCOUS_DISTORTION_SIZE = 9

def create_material_model_functions(properties):

density = properties.get('density')
props = _make_properties(properties)

def energy_density(dispGrad, state, dt):
return _energy_density(dispGrad, state, dt, props)

def compute_initial_state(shape=(1,)):
state = np.array([])
for n in range(NUM_PRONY_TERMS):
state = np.hstack((state, np.identity(3).ravel()))
return state

def compute_state_new(dispGrad, state, dt):
state = _compute_state_new(dispGrad, state, dt, props)
return state

def compute_material_qoi(dispGrad, state, dt):
return _compute_dissipated_energy(dispGrad, state, dt, props)

return MaterialModel(compute_energy_density = energy_density,
compute_initial_state = compute_initial_state,
compute_state_new = compute_state_new,
compute_material_qoi = compute_material_qoi,
density = density)

def _make_properties(properties):

print('Equilibrium properties')
print(' Bulk modulus = %s' % properties['equilibrium bulk modulus'])
print(' Shear modulus = %s' % properties['equilibrium shear modulus'])
print('Prony branch properties')
for n in range(NUM_PRONY_TERMS):
print(f' Shear modulus {n + 1} = %s' % properties[f'non equilibrium shear modulus {n + 1}'])
print(f' Relaxation time {n + 1} = %s' % properties[f'relaxation time {n + 1}'])

props = np.array([
properties['equilibrium bulk modulus'],
properties['equilibrium shear modulus'],
properties['non equilibrium shear modulus 1'],
properties['relaxation time 1'],
properties['non equilibrium shear modulus 2'],
properties['relaxation time 2'],
properties['non equilibrium shear modulus 3'],
properties['relaxation time 3']
])

return props

def _energy_density(dispGrad, state, dt, props):
W_eq = _eq_strain_energy(dispGrad, props)
W_neq = 0.0
Psi = 0.0
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(state, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))
Ee = Ee_trial - delta_Ev
W_neq = W_neq + _neq_strain_energy(Ee, props, _return_Gneq_id_for_branch(n))

Dv = delta_Ev / dt
Psi = Psi + _dissipation_potential(Dv, props, _return_Gneq_id_for_branch(n))

return W_eq + W_neq + dt * Psi

def _eq_strain_energy(dispGrad, props):
K, G = props[PROPS_K_eq], props[PROPS_G_eq]
F = dispGrad + np.eye(3)
J = np.linalg.det(F)
J23 = np.power(J, -2.0 / 3.0)
I1Bar = J23 * np.tensordot(F,F)
Wvol = 0.5 * K * (0.5 * J**2 - 0.5 - np.log(J))
Wdev = 0.5 * G * (I1Bar - 3.0)
return Wdev + Wvol

def _neq_strain_energy(elasticStrain, props, prop_id):
G_neq = props[prop_id]
return G_neq * TensorMath.norm_of_deviator_squared(elasticStrain)

def _dissipation_potential(Dv, props, prop_id):
G_neq = props[prop_id]
tau = props[prop_id + 1]
eta = G_neq * tau

return eta * TensorMath.norm_of_deviator_squared(Dv)

def _compute_dissipated_energy(dispGrad, state, dt, props):
Psi = 0.0
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(state, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))
Dv = delta_Ev / dt
Psi = Psi + dt * _dissipation_potential(Dv, props, _return_Gneq_id_for_branch(n))

return Psi

def _compute_state_new(dispGrad, stateOld, dt, props):
state_new = np.array([])
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(stateOld, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))

Fv_old = state_temp.reshape((3, 3))
Fv_new = linalg.expm(delta_Ev)@Fv_old
state_new = np.hstack((state_new, Fv_new.ravel()))
return state_new

def _compute_state_increment(elasticStrain, dt, props, prop_id):
tau = props[prop_id + 1]
integration_factor = 1. / (1. + dt / tau)

Ee_dev = TensorMath.dev(elasticStrain)
return dt * integration_factor * Ee_dev / tau # dt * D

def _compute_elastic_logarithmic_strain(dispGrad, stateOld):
F = dispGrad + np.identity(3)
Fv_old = stateOld.reshape((3, 3))

Fe_trial = F @ np.linalg.inv(Fv_old)
return TensorMath.log_sqrt_symm(Fe_trial.T @ Fe_trial)

def _return_state_for_branch(state, n):
return state.at[n * VISCOUS_DISTORTION_SIZE : (n + 1) * VISCOUS_DISTORTION_SIZE].get()

def _return_Gneq_id_for_branch(n):
return PROPS_G_neq_1 + 2 * n
113 changes: 113 additions & 0 deletions optimism/material/test/test_MultiBranchHyperVisco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import unittest

import jax
import jax.numpy as np
from jax.scipy import linalg
from matplotlib import pyplot as plt

from optimism.material import MultiBranchHyperViscoelastic as HyperVisco
from optimism.test.TestFixture import TestFixture
from optimism.material import MaterialUniaxialSimulator
from optimism.TensorMath import deviator
from optimism.TensorMath import log_symm

plotting = False

class HyperViscoModelFixture(TestFixture):
def setUp(self):

G_eq = 0.855 # MPa
K_eq = 1000*G_eq # MPa
G_neq_1 = 1.0
tau_1 = 1.0
G_neq_2 = 2.0
tau_2 = 10.0
G_neq_3 = 3.0
tau_3 = 100.0

self.G_neqs = np.array([G_neq_1, G_neq_2, G_neq_3])
self.taus = np.array([tau_1, tau_2, tau_3])
self.etas = self.G_neqs * self.taus

self.props = {
'equilibrium bulk modulus' : K_eq,
'equilibrium shear modulus' : G_eq,
'non equilibrium shear modulus 1': G_neq_1,
'relaxation time 1' : tau_1,
'non equilibrium shear modulus 2': G_neq_2,
'relaxation time 2' : tau_2,
'non equilibrium shear modulus 3': G_neq_3,
'relaxation time 3' : tau_3,
}

materialModel = HyperVisco.create_material_model_functions(self.props)
self.energy_density = jax.jit(materialModel.compute_energy_density)
self.compute_state_new = jax.jit(materialModel.compute_state_new)
self.compute_initial_state = materialModel.compute_initial_state
self.compute_material_qoi = jax.jit(materialModel.compute_material_qoi)

class HyperViscoUniaxialStrain(HyperViscoModelFixture):
def test_loading_only(self):
strain_rate = 1.0e-2
total_time = 100.0
n_steps = 100
dt = total_time / n_steps
times = np.linspace(0.0, total_time, n_steps)
Fs = jax.vmap(
lambda t: np.array(
[[np.exp(strain_rate * t), 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]]
)
)(times)
state_old = self.compute_initial_state()
energies = np.zeros(n_steps)
states = np.zeros((n_steps, state_old.shape[0]))
dissipated_energies = np.zeros(n_steps)

# numerical solution
for n, F in enumerate(Fs):
dispGrad = F - np.eye(3)
energies = energies.at[n].set(self.energy_density(dispGrad, state_old, dt))
state_new = self.compute_state_new(dispGrad, state_old, dt)
states = states.at[n, :].set(state_new)
dissipated_energies = dissipated_energies.at[n].set(self.compute_material_qoi(dispGrad, state_old, dt))
state_old = state_new

dissipated_energies_analytic = np.zeros(len(Fs))

for n in range(3):
Fvs = jax.vmap(lambda Fv: Fv.at[9 * n:9 * (n + 1)].get().reshape((3, 3)))(states)
Fes = jax.vmap(lambda F, Fv: F @ np.linalg.inv(Fv), in_axes=(0, 0))(Fs, Fvs)

Evs = jax.vmap(lambda Fv: log_symm(Fv))(Fvs)
Ees = jax.vmap(lambda Fe: log_symm(Fe))(Fes)

# analytic solution
e_v_11 = (2. / 3.) * strain_rate * times - \
(2. / 3.) * strain_rate * self.taus[n] * (1. - np.exp(-times / self.taus[n]))

e_e_11 = strain_rate * times - e_v_11
e_e_22 = 0.5 * e_v_11

Ee_analytic = jax.vmap(
lambda e_11, e_22: np.array(
[[e_11, 0., 0.],
[0., e_22, 0.],
[0., 0., e_22]]
), in_axes=(0, 0)
)(e_e_11, e_e_22)

Me_analytic = jax.vmap(lambda Ee: 2. * self.G_neqs[n] * deviator(Ee))(Ee_analytic)
Dv_analytic = jax.vmap(lambda Me: 1. / (2. * self.etas[n]) * deviator(Me))(Me_analytic)
dissipated_energies_analytic += jax.vmap(lambda Dv: dt * self.etas[n] * np.tensordot(deviator(Dv), deviator(Dv)) )(Dv_analytic)

# test
self.assertArrayNear(Evs[:, 0, 0], e_v_11, 3)
self.assertArrayNear(Ees[:, 0, 0], e_e_11, 3)
self.assertArrayNear(Ees[:, 1, 1], e_e_22, 3)

self.assertArrayNear(dissipated_energies, dissipated_energies_analytic, 3)

if __name__ == '__main__':
unittest.main()

0 comments on commit 81ccef8

Please sign in to comment.