Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding a multi branch (really hardcoded to 3 branches) hypervisco model. #99

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions optimism/material/MultiBranchHyperViscoelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import jax.numpy as np
from jax.scipy import linalg
from optimism import TensorMath
from optimism.material.MaterialModel import MaterialModel

import jax

PROPS_K_eq = 0
PROPS_G_eq = 1
PROPS_G_neq_1 = 2
PROPS_TAU_1 = 3
PROPS_G_neq_2 = 4
PROPS_TAU_2 = 5
PROPS_G_neq_3 = 6
PROPS_TAU_3 = 7


NUM_PRONY_TERMS = 3
VISCOUS_DISTORTION_SIZE = 9

def create_material_model_functions(properties):

density = properties.get('density')
props = _make_properties(properties)

def energy_density(dispGrad, state, dt):
return _energy_density(dispGrad, state, dt, props)

def compute_initial_state(shape=(1,)):
state = np.array([])
for n in range(NUM_PRONY_TERMS):
state = np.hstack((state, np.identity(3).ravel()))
return state

def compute_state_new(dispGrad, state, dt):
state = _compute_state_new(dispGrad, state, dt, props)
return state

def compute_material_qoi(dispGrad, state, dt):
return _compute_dissipated_energy(dispGrad, state, dt, props)

return MaterialModel(compute_energy_density = energy_density,
compute_initial_state = compute_initial_state,
compute_state_new = compute_state_new,
compute_material_qoi = compute_material_qoi,
density = density)

def _make_properties(properties):

print('Equilibrium properties')
print(' Bulk modulus = %s' % properties['equilibrium bulk modulus'])
print(' Shear modulus = %s' % properties['equilibrium shear modulus'])
print('Prony branch properties')
for n in range(NUM_PRONY_TERMS):
print(f' Shear modulus {n + 1} = %s' % properties[f'non equilibrium shear modulus {n + 1}'])
print(f' Relaxation time {n + 1} = %s' % properties[f'relaxation time {n + 1}'])

props = np.array([
properties['equilibrium bulk modulus'],
properties['equilibrium shear modulus'],
properties['non equilibrium shear modulus 1'],
properties['relaxation time 1'],
properties['non equilibrium shear modulus 2'],
properties['relaxation time 2'],
properties['non equilibrium shear modulus 3'],
properties['relaxation time 3']
])

return props

def _energy_density(dispGrad, state, dt, props):
W_eq = _eq_strain_energy(dispGrad, props)
W_neq = 0.0
Psi = 0.0
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(state, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))
Ee = Ee_trial - delta_Ev
W_neq = W_neq + _neq_strain_energy(Ee, props, _return_Gneq_id_for_branch(n))

Dv = delta_Ev / dt
Psi = Psi + _dissipation_potential(Dv, props, _return_Gneq_id_for_branch(n))

return W_eq + W_neq + dt * Psi

def _eq_strain_energy(dispGrad, props):
K, G = props[PROPS_K_eq], props[PROPS_G_eq]
F = dispGrad + np.eye(3)
J = np.linalg.det(F)
J23 = np.power(J, -2.0 / 3.0)
I1Bar = J23 * np.tensordot(F,F)
Wvol = 0.5 * K * (0.5 * J**2 - 0.5 - np.log(J))
Wdev = 0.5 * G * (I1Bar - 3.0)
return Wdev + Wvol

def _neq_strain_energy(elasticStrain, props, prop_id):
G_neq = props[prop_id]
return G_neq * TensorMath.norm_of_deviator_squared(elasticStrain)

def _dissipation_potential(Dv, props, prop_id):
G_neq = props[prop_id]
tau = props[prop_id + 1]
eta = G_neq * tau

return eta * TensorMath.norm_of_deviator_squared(Dv)

def _compute_dissipated_energy(dispGrad, state, dt, props):
Psi = 0.0
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(state, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))
Dv = delta_Ev / dt
Psi = Psi + dt * _dissipation_potential(Dv, props, _return_Gneq_id_for_branch(n))

return Psi

def _compute_state_new(dispGrad, stateOld, dt, props):
state_new = np.array([])
for n in range(NUM_PRONY_TERMS):
state_temp = _return_state_for_branch(stateOld, n)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, state_temp)
delta_Ev = _compute_state_increment(Ee_trial, dt, props, _return_Gneq_id_for_branch(n))

Fv_old = state_temp.reshape((3, 3))
Fv_new = linalg.expm(delta_Ev)@Fv_old
state_new = np.hstack((state_new, Fv_new.ravel()))
return state_new

def _compute_state_increment(elasticStrain, dt, props, prop_id):
tau = props[prop_id + 1]
integration_factor = 1. / (1. + dt / tau)

Ee_dev = TensorMath.dev(elasticStrain)
return dt * integration_factor * Ee_dev / tau # dt * D

def _compute_elastic_logarithmic_strain(dispGrad, stateOld):
F = dispGrad + np.identity(3)
Fv_old = stateOld.reshape((3, 3))

Fe_trial = F @ np.linalg.inv(Fv_old)
return TensorMath.log_sqrt_symm(Fe_trial.T @ Fe_trial)

def _return_state_for_branch(state, n):
return state.at[n * VISCOUS_DISTORTION_SIZE : (n + 1) * VISCOUS_DISTORTION_SIZE].get()

def _return_Gneq_id_for_branch(n):
return PROPS_G_neq_1 + 2 * n
113 changes: 113 additions & 0 deletions optimism/material/test/test_MultiBranchHyperVisco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import unittest

import jax
import jax.numpy as np
from jax.scipy import linalg
from matplotlib import pyplot as plt

from optimism.material import MultiBranchHyperViscoelastic as HyperVisco
from optimism.test.TestFixture import TestFixture
from optimism.material import MaterialUniaxialSimulator
from optimism.TensorMath import deviator
from optimism.TensorMath import log_symm

plotting = False

class HyperViscoModelFixture(TestFixture):
def setUp(self):

G_eq = 0.855 # MPa
K_eq = 1000*G_eq # MPa
G_neq_1 = 1.0
tau_1 = 1.0
G_neq_2 = 2.0
tau_2 = 10.0
G_neq_3 = 3.0
tau_3 = 100.0

self.G_neqs = np.array([G_neq_1, G_neq_2, G_neq_3])
self.taus = np.array([tau_1, tau_2, tau_3])
self.etas = self.G_neqs * self.taus

self.props = {
'equilibrium bulk modulus' : K_eq,
'equilibrium shear modulus' : G_eq,
'non equilibrium shear modulus 1': G_neq_1,
'relaxation time 1' : tau_1,
'non equilibrium shear modulus 2': G_neq_2,
'relaxation time 2' : tau_2,
'non equilibrium shear modulus 3': G_neq_3,
'relaxation time 3' : tau_3,
}

materialModel = HyperVisco.create_material_model_functions(self.props)
self.energy_density = jax.jit(materialModel.compute_energy_density)
self.compute_state_new = jax.jit(materialModel.compute_state_new)
self.compute_initial_state = materialModel.compute_initial_state
self.compute_material_qoi = jax.jit(materialModel.compute_material_qoi)

class HyperViscoUniaxialStrain(HyperViscoModelFixture):
def test_loading_only(self):
strain_rate = 1.0e-2
total_time = 100.0
n_steps = 100
dt = total_time / n_steps
times = np.linspace(0.0, total_time, n_steps)
Fs = jax.vmap(
lambda t: np.array(
[[np.exp(strain_rate * t), 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]]
)
)(times)
state_old = self.compute_initial_state()
energies = np.zeros(n_steps)
states = np.zeros((n_steps, state_old.shape[0]))
dissipated_energies = np.zeros(n_steps)

# numerical solution
for n, F in enumerate(Fs):
dispGrad = F - np.eye(3)
energies = energies.at[n].set(self.energy_density(dispGrad, state_old, dt))
state_new = self.compute_state_new(dispGrad, state_old, dt)
states = states.at[n, :].set(state_new)
dissipated_energies = dissipated_energies.at[n].set(self.compute_material_qoi(dispGrad, state_old, dt))
state_old = state_new

dissipated_energies_analytic = np.zeros(len(Fs))

for n in range(3):
Fvs = jax.vmap(lambda Fv: Fv.at[9 * n:9 * (n + 1)].get().reshape((3, 3)))(states)
Fes = jax.vmap(lambda F, Fv: F @ np.linalg.inv(Fv), in_axes=(0, 0))(Fs, Fvs)

Evs = jax.vmap(lambda Fv: log_symm(Fv))(Fvs)
Ees = jax.vmap(lambda Fe: log_symm(Fe))(Fes)

# analytic solution
e_v_11 = (2. / 3.) * strain_rate * times - \
(2. / 3.) * strain_rate * self.taus[n] * (1. - np.exp(-times / self.taus[n]))

e_e_11 = strain_rate * times - e_v_11
e_e_22 = 0.5 * e_v_11

Ee_analytic = jax.vmap(
lambda e_11, e_22: np.array(
[[e_11, 0., 0.],
[0., e_22, 0.],
[0., 0., e_22]]
), in_axes=(0, 0)
)(e_e_11, e_e_22)

Me_analytic = jax.vmap(lambda Ee: 2. * self.G_neqs[n] * deviator(Ee))(Ee_analytic)
Dv_analytic = jax.vmap(lambda Me: 1. / (2. * self.etas[n]) * deviator(Me))(Me_analytic)
dissipated_energies_analytic += jax.vmap(lambda Dv: dt * self.etas[n] * np.tensordot(deviator(Dv), deviator(Dv)) )(Dv_analytic)

# test
self.assertArrayNear(Evs[:, 0, 0], e_v_11, 3)
self.assertArrayNear(Ees[:, 0, 0], e_e_11, 3)
self.assertArrayNear(Ees[:, 1, 1], e_e_22, 3)

self.assertArrayNear(dissipated_energies, dissipated_energies_analytic, 3)

if __name__ == '__main__':
unittest.main()
Loading