Skip to content

Commit

Permalink
push lie integrators
Browse files Browse the repository at this point in the history
  • Loading branch information
Eelco Hoogendoorn committed Aug 27, 2024
1 parent 5197467 commit 48b87a6
Show file tree
Hide file tree
Showing 2 changed files with 305 additions and 0 deletions.
134 changes: 134 additions & 0 deletions numga/examples/physics/lie_integrators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Implementations of
https://pure.tue.nl/ws/portalfiles/portal/3801945/900594800955441.pdf
https://www.research.unipd.it/retrieve/e14fb26f-e9d2-3de1-e053-1705fe0ac030/Ortolan_PhD11.pdf
might want to add manual computed jacobians in here as well?
need to be able to compose concrete operators to do so
would be a good test case for such a rewrite
"""

import jax
import jax.numpy as jnp


def newton_solver(fn, n=10):
"""Return callable that performs n newton steps on fn"""
jac_fn = jax.jacfwd(fn)
step = lambda i, x: x - jnp.linalg.solve(jac_fn(x), fn(x))
return lambda x: jax.lax.fori_loop(0, n, step, x)


def newton_solver_wrap(func, init):
"""wrap newton solve of a multivector function"""
func_wrap = lambda x: func(init.copy(values=x)).values
solver = newton_solver(func_wrap)
return init.copy(values=solver(init.values))


# second order accurate cayley log/exp approximations
def exp2(b):
return b.exp_quadratic()
def log2(r):
return r.motor_log_quadratic()
# higher order accurate log/exp approximations
def exp4(b):
return exp2(b / 2).squared()
def log4(r):
return log2(r.motor_square_root()) * 2


def variational_lie_verlet(motor, rate, inertia, inertia_inv, dt, ext_forque):
"""Variational Lie-Verlet"""
energy = lambda rate: \
inertia(rate).wedge(rate) * rate * (dt / 2)
forque = lambda motor, rate: \
ext_forque(motor, rate) - inertia(rate).commutator(rate)

implicit = lambda rh: -rh + rate + \
inertia_inv(forque(motor, rh) - energy(rh)) * (dt / 2)
rate_half = newton_solver_wrap(implicit, init=rate)

motor_new = motor * exp2(rate_half * dt / -4)

rate_new = rate_half + \
inertia_inv(forque(motor_new, rate_half) + energy(rate_half)) * (dt / 2)

return motor_new, rate_new


def explicit_lie_newmark(motor, rate, inertia, inertia_inv, dt, ext_forque):
"""Explicit Lie-Newmark method"""
impulse = lambda motor, rate: \
(ext_forque(motor, rate) - inertia(rate).commutator(rate)) * (dt / 2)

half_rate_step = inertia_inv(impulse(motor, rate))
rate_half = rate + half_rate_step
motor_new = exp2(rate_half * dt / 2) * motor
# motor_new = (rate_half * dt / 2).exp() * motor

implicit = lambda rn: \
-rn + rate_half + inertia_inv(impulse(motor_new, rn))
rate_new = newton_solver_wrap(implicit, rate_half + half_rate_step)

return motor_new, rate_new


def explicit_lie_newmark_rev(motor, rate, inertia, inertia_inv, dt, ext_forque):
"""My own crazy mix"""
impulse = lambda motor, rate: \
(ext_forque(motor, rate) - inertia(rate).commutator(rate)) * (dt / 2)

half_rate_step = inertia_inv(impulse(motor, rate))
implicit = lambda rh: \
-rh + rate + inertia_inv(impulse(motor, rh))
rate_half = newton_solver_wrap(implicit, rate + half_rate_step)

motor_new = exp2(rate_half * dt / 2) * motor

# double-sided implicit does not seem to work; need forward/backward cancellation
# half_rate_step = inertia_inv(impulse(motor_new, rate_half))
# implicit = lambda rn: -rn + rate_half + inertia_inv(impulse(motor_new, rn))
# rate_new = newton_solver_wrap(implicit, rate_half + half_rate_step)

half_rate_step = inertia_inv(impulse(motor_new, rate_half))
rate_new = rate_half + half_rate_step

return motor_new, rate_new


def new3(motor, rate, inertia, inertia_inv, dt, ext_forque):
""""""
# FIXME: works like garbage?
motor_half = exp2(rate * dt / 4) * motor
impulse = lambda motor, rate: \
(ext_forque(motor, rate) - inertia(rate).commutator(rate)) * dt
rate_new = rate + inertia_inv(impulse(motor, rate))

rhs = exp4(rate * dt / 4) >> inertia(rate)
implicit = lambda r: \
exp4(r * dt / -4) >> inertia(r) - rhs
rate_new = newton_solver_wrap(implicit, rate_new)

motor_new = exp2(rate_new * dt / 4) * motor_half

return motor_new, rate_new


def explicit_rk1(motor, rate, inertia, inertia_inv, dt, ext_forque):
"""RK1 integration of lie state"""
dr = lambda r: inertia_inv(ext_forque(motor, r) - inertia(r).commutator(r))
from numga.examples.integrators import RK1
rate = RK1(dr, rate, dt)
motor = exp2(rate * dt / 2) * motor
return motor, rate


def explicit_rk4(motor, rate, inertia, inertia_inv, dt, ext_forque):
"""RK4 integration of lie state"""
dr = lambda r: inertia_inv(ext_forque(motor, r) - inertia(r).commutator(r))
from numga.examples.integrators import RK4
rate = RK4(dr, rate, dt)
motor = exp2(rate * dt / 2) * motor
# motor = (rate * dt / 2).exp() * motor
return motor, rate
171 changes: 171 additions & 0 deletions numga/examples/physics/test_lie_integrators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import numpy as np
from numga.examples.physics.lie_integrators import *


def test_newton():
func = lambda x: x**2 - 1
solver = newton_solver(func)
r= solver(jnp.array([1.5]))
print(r)


def test_log_exp():
from numga.backend.numpy.context import NumpyContext as Context

ctx = Context('x+y+z+')
b = ctx.multivector.bivector([2,0,0])
r = exp2(b)
print(r)
print(log2(r))

ctx = Context('x+y+z+w+')
b = ctx.multivector.bivector([1,2,3,4,5,6])
r = exp2(b)
print(r)
print(log2(r))



def make_n_cube(N):
b = ((np.arange(2 ** N)[:, None] & (1 << np.arange(N))) > 0)
return (2 * b - 1)

def make_n_rect(N):
return make_n_cube(N) * (np.arange(N) + 1)


def test_tennis_racket():
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from numga.backend.jax.context import JaxContext as Context
# works for p=2,3,4,5
# p>3 is fascinating; some medial axes become seemingly chaotic
# but stranger still, some medial axes actually stabilize?
# also, different unstable axes appear to show qualitatively different behavior
context = Context((4, 0, 0), dtype=jnp.float64)

dt = 0.2
runtime = 2000


nd = context.algebra.description.n_dimensions
nb = len(context.subspace.bivector())

# create a point cloud with distinct moments of inertia on each axis
points = make_n_rect(nd)
points = context.multivector.vector(points).dual()

inertia = points.inertia_map().sum(axis=-3)
inertia_inv = inertia.inverse()

rate = context.multivector.bivector((np.eye(nb) + np.random.normal(size=(nb, nb)) * 1e-5))
motor = context.multivector.motor() * np.ones((nb))
kinetic = lambda rate: inertia(rate).wedge(rate)


import functools
e = context.multivector.empty() #* np.ones((nb))
from numga.examples.physics.lie_integrators import variational_lie_verlet as integrator
integrator = functools.partial(integrator, dt=dt, ext_forque=lambda m, r: e)
integrator = jax.vmap(integrator, (0,0, None, None))
integrator = jax.jit(integrator)

states = []
for i in range(int(runtime / dt)):
motor, rate = integrator(motor, rate, inertia, inertia_inv)
# states.append(kinetic(rate).values)
# states.append(motor.values)
states.append(rate.values)

states = jnp.array(states)
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nb, 1)
for i in range(nb):
ax[i].plot(states[:, i])
plt.show()


def test_2dpga():
"""test energy drift in 2d pga
in the force-free case, things seem quite alright
RK4 does a decent job; dissipative for large timesteps
Testing of linear posistion stability raises a lot of questions though
lie-verlet seems very broken,
though lie-newmark seems to do ok if imperfect,
but not much better than rk4, if not worse?
"""
np.random.seed(0)
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from numga.backend.jax.context import JaxContext as Context
# context = Context((3, 0, 0), dtype=jnp.float64)
context = Context('x+y+w0', dtype=jnp.float64)

# dt = 1/4
dt = .1
runtime = 2000


nd = context.algebra.description.n_dimensions
nb = len(context.subspace.bivector())

# create a point cloud with distinct moments of inertia on each axis
points = make_n_rect(nd)
points = context.multivector.vector(points).dual().normalized()

inertia = points.inertia_map().sum(axis=-3)
inertia_inv = inertia.inverse()

# rate = context.multivector.bivector((np.eye(nb) + np.random.normal(size=(nb, nb)) * 1e-5)[1])
rate = context.multivector.bivector([1,1,1])
# rate = context.multivector.bivector(np.random.normal(size=(nb))*0.3)
print(rate)
motor = context.multivector.motor()
kinetic = lambda rate: inertia(rate).wedge(rate)


import functools
e = context.multivector.empty()
# from numga.examples.physics.lie_integrators import variational_lie_verlet as integrator
# from numga.examples.physics.lie_integrators import explicit_lie_newmark as integrator
from numga.examples.physics.lie_integrators import explicit_rk4 as integrator
integrator = functools.partial(integrator, dt=dt, ext_forque=lambda m, r: e)
integrator = jax.jit(integrator)

states = []
energy = []
for i in range(int(runtime / dt)):
motor, rate = integrator(motor, rate, inertia, inertia_inv)
energy.append(kinetic(rate).values)
states.append(motor.values)
# states.append(rate.values)

# print(jnp.array(states))
import matplotlib.pyplot as plt
plt.plot(np.array(energy))
plt.show()
plt.plot(np.array(states))#-200:])
plt.show()


def test_potential():
"""
test conservations props of rotor based potentials
"""


def test_verlet():
"""
need to test interaction between lie integrators and verlet correction steps
should we track rotor delta between pre and post integrate;
and derive a rate-delta from that?
# FIXME: initial delta attemps seem broken?
or should we backtrace the entire forward integrator?
what we do right now is essentially backtrace forward euler
"""

0 comments on commit 48b87a6

Please sign in to comment.