Skip to content

Commit

Permalink
Merge branch 'main' into ralberd/path_dependent_adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ralberd committed Dec 5, 2023
2 parents 638dab4 + d23dc34 commit 1bb5080
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 133 deletions.
54 changes: 54 additions & 0 deletions optimism/QuadratureRule.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,60 @@ def create_quadrature_rule_on_triangle(degree):
4.14255378091870E-02,
4.14255378091870E-02,
4.14255378091870E-02])
elif degree <= 10:
xi = onp.array([[0.33333333333333333E+00, 0.33333333333333333E+00],
[0.4269134091050342E-02, 0.49786543295447483E+00],
[0.49786543295447483E+00, 0.4269134091050342E-02],
[0.49786543295447483E+00, 0.49786543295447483E+00],
[0.14397510054188759E+00, 0.42801244972905617E+00],
[0.42801244972905617E+00, 0.14397510054188759E+00],
[0.42801244972905617E+00, 0.42801244972905617E+00],
[0.6304871745135507E+00, 0.18475641274322457E+00],
[0.18475641274322457E+00, 0.6304871745135507E+00],
[0.18475641274322457E+00, 0.18475641274322457E+00],
[0.9590375628566448E+00, 0.20481218571677562E-01],
[0.20481218571677562E-01, 0.9590375628566448E+00],
[0.20481218571677562E-01, 0.20481218571677562E-01],
[0.3500298989727196E-01, 0.1365735762560334E+00],
[0.1365735762560334E+00, 0.8284234338466947E+00],
[0.8284234338466947E+00, 0.3500298989727196E-01],
[0.1365735762560334E+00, 0.3500298989727196E-01],
[0.8284234338466947E+00, 0.1365735762560334E+00],
[0.3500298989727196E-01, 0.8284234338466947E+00],
[0.37549070258442674E-01, 0.3327436005886386E+00],
[0.3327436005886386E+00, 0.6297073291529187E+00],
[0.6297073291529187E+00, 0.37549070258442674E-01],
[0.3327436005886386E+00, 0.37549070258442674E-01],
[0.6297073291529187E+00, 0.3327436005886386E+00],
[0.37549070258442674E-01, 0.6297073291529187E+00]])

w = onp.array([0.4176169990259819E-01,
0.36149252960283717E-02,
0.36149252960283717E-02,
0.36149252960283717E-02,
0.3724608896049025E-01,
0.3724608896049025E-01,
0.3724608896049025E-01,
0.39323236701554264E-01,
0.39323236701554264E-01,
0.39323236701554264E-01,
0.3464161543553752E-02,
0.3464161543553752E-02,
0.3464161543553752E-02,
0.147591601673897E-01,
0.147591601673897E-01,
0.147591601673897E-01,
0.147591601673897E-01,
0.147591601673897E-01,
0.147591601673897E-01,
0.1978968359803062E-01,
0.1978968359803062E-01,
0.1978968359803062E-01,
0.1978968359803062E-01,
0.1978968359803062E-01,
0.1978968359803062E-01])
else:
raise ValueError("Quadrature of precision this high is not implemented.")

return QuadratureRule(xi, w)

Expand Down
142 changes: 17 additions & 125 deletions optimism/material/HyperViscoelastic.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,29 @@
import jax.numpy as np
from jax import vmap
from jax.scipy import linalg
from optimism import TensorMath
from optimism.material.MaterialModel import MaterialModel


# props
PROPS_K_eq = 0
PROPS_G_eq = 1
PROPS_G_neq = 2
PROPS_TAU = 3

NUM_PRONY_TERMS = -1

# isvs
VISCOUS_DISTORTION = slice(0, 9)

def create_material_model_functions(properties):

# prop processing
density = properties.get('density')
props = _make_properties(properties)

# energy function wrapper
def energy_density(dispGrad, state, dt):
return _energy_density(dispGrad, state, dt, props)

# wrapper for state var ics
def compute_initial_state(shape=(1,)):
# num_prony_terms = properties['number of prony terms']

# def vmap_body(_):
# return np.identity(3).ravel()

# state = np.hstack(vmap(vmap_body, in_axes=(0,))(np.arange(num_prony_terms)))
state = np.identity(3).ravel()
return state

# update state vars wrapper
def compute_state_new(dispGrad, state, dt):
state = _compute_state_new(dispGrad, state, dt, props)
return state
Expand All @@ -49,55 +35,29 @@ def compute_state_new(dispGrad, state, dt):
density
)

# implementation
def _make_properties(properties):
# assert properties['number of prony terms'] > 0, 'Need at least 1 prony term'
assert 'equilibrium bulk modulus' in properties.keys()
assert 'equilibrium shear modulus' in properties.keys()
# for n in range(1, properties['number of prony terms'] + 1):
# assert 'non equilibrium shear modulus %s' % n in properties.keys()
# assert 'relaxation time %s' % n in properties.keys()
assert 'non equilibrium shear modulus' in properties.keys()
assert 'relaxation time' in properties.keys()

print('Equilibrium properties')
print(' Bulk modulus = %s' % properties['equilibrium bulk modulus'])
print(' Shear modulus = %s' % properties['equilibrium shear modulus'])
print('Prony branch properties')
print(' Shear modulus = %s' % properties['non equilibrium shear modulus'])
print(' Relaxation time = %s' % properties['relaxation time'])
# this is dirty, fuck jax (can't use an int from a jax numpy array or else jit tries to trace that)
# global NUM_PRONY_TERMS
# NUM_PRONY_TERMS = properties['number of prony terms']

# first pack equilibrium properties
props = np.array([
properties['equilibrium bulk modulus'],
properties['equilibrium shear modulus'],
properties['non equilibrium shear modulus'],
properties['relaxation time']
])

# props = np.hstack((props, properties['number of prony terms']))

# for n in range(1, properties['number of prony terms'] + 1):
# print('Prony branch %s properties' % n)
# print(' Shear modulus = %s' % properties['non equilibrium shear modulus %s' % n])
# print(' Relaxation time = %s' % properties['relaxation time %s' % n])
# props = np.hstack(
# (props, np.array([properties['non equilibrium shear modulus %s' % n],
# properties['relaxation time %s' % n]])))



return props

def _energy_density(dispGrad, state, dt, props):
W_eq = _eq_strain_energy(dispGrad, props)
W_neq = _neq_strain_energy(dispGrad, state, dt, props)
return W_eq + W_neq

# TODO generalize to arbitrary strain energy density
def _eq_strain_energy(dispGrad, props):
K, G = props[PROPS_K_eq], props[PROPS_G_eq]
F = dispGrad + np.eye(3)
Expand All @@ -109,108 +69,40 @@ def _eq_strain_energy(dispGrad, props):
return Wdev + Wvol

def _neq_strain_energy(dispGrad, stateOld, dt, props):
I = np.identity(3)
F = dispGrad + I
state_new = _compute_state_new(dispGrad, stateOld, dt, props)
# Fvs = state_new.reshape((NUM_PRONY_TERMS, 3, 3))
Fv_new = state_new.reshape((3, 3))

G_neq = props[PROPS_G_neq]
tau = props[PROPS_TAU]
eta = G_neq * tau

Fe = F @ np.linalg.inv(Fv_new)
Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe)
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, stateOld)
delta_Ev = _compute_state_increment(Ee_trial, dt, props)
Ee = Ee_trial - delta_Ev

Me = 2. * G_neq * Ee
M_bar = TensorMath.norm_of_deviator_squared(Me)
gamma_dot = M_bar / eta
# visco_energy = (dt / (G_neq * tau)) * M_bar**2
visco_energy = 0.5 * dt * eta * gamma_dot**2

W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) + visco_energy
# def vmap_body(n, Fv):
# G_neq = props[PROPS_G_neq + 2 * n]
# # tau = props[PROPS_TAU + 2 * n]
return G_neq * TensorMath.norm_of_deviator_squared(Ee) + visco_energy

# Fe = F @ np.linalg.inv(Fv)
# Ee = 0.5 * TensorMath.mtk_log_sqrt(Fe.T @ Fe)

# # viscous shearing
# # Me = 2. * G_neq * Ee
# # M_bar = TensorMath.norm_of_deviator_squared(Me)
# # visco_energy = (dt / (G_neq * tau)) * M_bar**2

# # still need another term I think
# W_neq = G_neq * TensorMath.norm_of_deviator_squared(Ee) #+ visco_energy
# return W_neq

# W_neq = np.sum(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fvs))

return W_neq

# state update
def _compute_state_new(dispGrad, stateOld, dt, props):
state_inc = _compute_state_increment(dispGrad, stateOld, dt, props)
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3))
# Fv_incs = state_inc.reshape((NUM_PRONY_TERMS, 3, 3))

# def vmap_body(n, Fv_old, Fv_inc):
# Fv_new = Fv_inc @ Fv_old
# return Fv_new.ravel()

# state_new = np.hstack(vmap(vmap_body, in_axes=(0, 0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds, Fv_incs))
Ee_trial = _compute_elastic_logarithmic_strain(dispGrad, stateOld)
delta_Ev = _compute_state_increment(Ee_trial, dt, props)

Fv_old = stateOld.reshape((3, 3))
Fv_inc = state_inc.reshape((3, 3))
state_new = (Fv_inc @ Fv_old).ravel()
return state_new

def _compute_state_increment(dispGrad, stateOld, dt, props):
I = np.identity(3)
F = dispGrad + I
# Fv_olds = stateOld.reshape((NUM_PRONY_TERMS, 3, 3))
Fv_new = linalg.expm(delta_Ev)@Fv_old
return Fv_new.ravel()

# def vmap_body(n, Fv_old):
# # TODO add shift factor
# G_neq = props[PROPS_G_neq + 2 * n]
# tau = props[PROPS_TAU + 2 * n]

# # kinematics
# Fe_trial = F @ np.linalg.inv(Fv_old)
# Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial)
# Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I

# # updates
# integration_factor = 1. / (1. + dt / tau)

# Me = 2.0 * G_neq * Ee_dev
# Me = integration_factor * Me

# Dv = (1. / (2. * G_neq * tau)) * Me
# A = dt * Dv

# Fv_inc = linalg.expm(A)

# return Fv_inc.ravel()

# state_inc = np.hstack(vmap(vmap_body, in_axes=(0, 0))(np.arange(NUM_PRONY_TERMS), Fv_olds))

Fv_old = stateOld.reshape((3, 3))
G_neq = props[PROPS_G_neq]
def _compute_state_increment(elasticStrain, dt, props):
tau = props[PROPS_TAU]

Fe_trial = F @ np.linalg.inv(Fv_old)
Ee_trial = 0.5 * TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial)
Ee_dev = Ee_trial - (1. / 3.) * np.trace(Ee_trial) * I

integration_factor = 1. / (1. + dt / tau)

Me = 2.0 * G_neq * Ee_dev
Me = integration_factor * Me

Dv = (1. / (2. * G_neq * tau)) * Me
A = dt * Dv
Ee_dev = TensorMath.compute_deviatoric_tensor(elasticStrain)
return dt * integration_factor * Ee_dev / tau # dt * D

Fv_inc = linalg.expm(A)
def _compute_elastic_logarithmic_strain(dispGrad, stateOld):
F = dispGrad + np.identity(3)
Fv_old = stateOld.reshape((3, 3))

return Fv_inc.ravel()
Fe_trial = F @ np.linalg.inv(Fv_old)
return TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial)
62 changes: 62 additions & 0 deletions optimism/material/test/test_HyperVisco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest

import jax
import jax.numpy as np
from jax.scipy import linalg

from optimism.material import HyperViscoelastic as HyperVisco
from optimism.test.TestFixture import TestFixture

def make_disp_grad_from_strain(strain):
return linalg.expm(strain) - np.identity(3)


class HyperViscoModelFixture(TestFixture):
def setUp(self):

G_eq = 0.855 # MPa
K_eq = 1000*G_eq # MPa
G_neq_1 = 5.0
tau_1 = 0.1
self.props = {
'equilibrium bulk modulus' : K_eq,
'equilibrium shear modulus' : G_eq,
'non equilibrium shear modulus': G_neq_1,
'relaxation time' : tau_1,
}

materialModel = HyperVisco.create_material_model_functions(self.props)

self.energy_density = jax.jit(materialModel.compute_energy_density)
self.compute_state_new = materialModel.compute_state_new
self.compute_initial_state = materialModel.compute_initial_state

def test_zero_point(self):
dispGrad = np.zeros((3,3))
initialState = self.compute_initial_state()
dt = 1.0

energy = self.energy_density(dispGrad, initialState, dt)
self.assertNear(energy, 0.0, 12)

state = self.compute_state_new(dispGrad, initialState, dt)
self.assertArrayNear(state, np.eye(3).ravel(), 12)

def test_regression_nonzero_point(self):
key = jax.random.PRNGKey(1)
dispGrad = jax.random.uniform(key, (3, 3))
initialState = self.compute_initial_state()
dt = 1.0

energy = self.energy_density(dispGrad, initialState, dt)
self.assertNear(energy, 133.3469451269987, 12)

state = self.compute_state_new(dispGrad, initialState, dt)
stateGold = np.array([0.988233534321, 0.437922586964, 0.433881277313,
0.437922586964, 1.378870045574, 0.079038974065,
0.433881277313, 0.079038974065, 1.055381729505])
self.assertArrayNear(state, stateGold, 12)


if __name__ == '__main__':
unittest.main()
17 changes: 9 additions & 8 deletions optimism/test/test_QuadratureRule.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def are_positive_weights(QuadratureRuleFactory, degree):

class TestQuadratureRules(TestFixture.TestFixture):
endpoints = (0.0, 1.0) # better if the quadrature rule module provided this
max_degree_2D = 6
max_degree_2D = 10
max_degree_1D = 25


Expand Down Expand Up @@ -87,13 +87,14 @@ def test_triangle_quadrature_points_in_domain(self):

def test_triangle_quadrature_exactness(self):
for degree in range(self.max_degree_2D + 1):
qr = QuadratureRule.create_quadrature_rule_on_triangle(degree)
for i in range(degree + 1):
for j in range(degree + 1 - i):
monomial = qr.xigauss[:,0]**i * qr.xigauss[:,1]**j
quadratureAnswer = np.sum(monomial * qr.wgauss)
exactAnswer = integrate_2D_monomial_on_triangle(i, j)
self.assertNear(quadratureAnswer, exactAnswer, 14)
with self.subTest(i=degree):
qr = QuadratureRule.create_quadrature_rule_on_triangle(degree)
for i in range(degree + 1):
for j in range(degree + 1 - i):
monomial = qr.xigauss[:,0]**i * qr.xigauss[:,1]**j
quadratureAnswer = np.sum(monomial * qr.wgauss)
exactAnswer = integrate_2D_monomial_on_triangle(i, j)
self.assertNear(quadratureAnswer, exactAnswer, 14)

if __name__ == '__main__':
unittest.main()

0 comments on commit 1bb5080

Please sign in to comment.