Skip to content

Commit

Permalink
some fixes to la_inverse and added test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Eelco Hoogendoorn committed Nov 1, 2023
1 parent 5a40828 commit f864f9d
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 40 deletions.
17 changes: 7 additions & 10 deletions numga/backend/jax/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ class JaxOperator(AbstractConcreteOperator):
def shape(self):
"""Shape of the kernel, modulo subspace axis"""
return self.kernel.shape[:-(self.arity+1)]
# @property
# def kernel_shape(self):
# """Shape of the kernel, modulo subspace axis"""
# return self.kernel.shape[self.arity:]

def broadcast_allocate(self, inputs: Tuple[JaxMultiVector], output=None) -> JaxMultiVector:
"""Allocation for a set of inputs, with multivector components as last axis,
Expand Down Expand Up @@ -124,19 +120,20 @@ class JaxDenseOperator(JaxOperator):

def __init__(self, *args, **kwargs):
super(JaxDenseOperator, self).__init__(*args, **kwargs)
# put kernel on device
# self.jax_kernel = np.array(self.kernel)
# precompute reshape operations
self.shapes = self.broadcasting_shapes
# contraction over all kernel input axes
self.sum_axes = tuple(-(a + 2) for a in range(self.arity))

@partial(jax.jit, static_argnums=(0,))
def __call__(self, *inputs: Tuple[JaxMultiVector]) -> JaxMultiVector:
# kernel, shapes = self.precompute
shape = jnp.broadcast_shapes(*(i.shape for i in inputs))
return self.context.multivector(
values=jnp.sum(
math.prod((i.values.reshape(i.shape + s) for i, s in zip(inputs, self.shapes)), start=self.kernel),
axis=range(len(shape), len(shape)+self.arity)# FIXME: negative indexing rather than len(shape)?
math.prod(
(i.values.reshape(i.shape + s) for i, s in zip(inputs, self.shapes)),
start=self.kernel
),
axis=self.sum_axes
),
subspace=self.output
)
Expand Down
5 changes: 3 additions & 2 deletions numga/backend/numpy/multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def flatten(self):

def la_inverse(self):
"""Inverse of x such that x * x.inverse() == 1 == x.inverse() * x"""
op = self.operator.product(self.subspace, self.subspace)
inverse_subspace = self.operator.inverse_factor(self.subspace).output
op = self.operator.product(self.subspace, inverse_subspace)
k = op.partial({0: self}).kernel
idx, = np.flatnonzero(op.output.blades == 0) # grab index of scalar of output; zero or raises
r = np.linalg.solve( # use least squares to solve for inverse
np.einsum('...ji,...ki->...jk', k, k), # k.T * k
k[..., idx], #equal to k.T * unit_scalar
)
return self.copy(r)
return self.context.multivector(values=r, subspace=inverse_subspace)
12 changes: 0 additions & 12 deletions numga/backend/numpy/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,6 @@ def transpose(self):
axes=(self.operator.axes[1], self.operator.axes[0]),
)
)
# def quadratic(self, l, r=None):
# """Compute quadratic form of unary operator; l*O*r"""
# assert self.arity == 1
# r = l if r is None else r
# output = self.broadcast_allocate((l, r), self.algebra.subspace.scalar())
# np.einsum(
# '...ij,...i,...j->...',
# self.kernel, l.values, r.values,
# out=output.values[..., 0],
# optimize=True
# )
# return output


class NumpyEinsumOperator(NumpyOperator):
Expand Down
50 changes: 39 additions & 11 deletions numga/backend/test/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@
from numga.backend.numpy.operator import NumpySparseOperator

import pytest
from numga.multivector.test.util import random_motor, random_subspace


def test_basic():
print()
ga = NumpyContext('x+y+z+w0')
Q, V, B = ga.subspace.even_grade(), ga.subspace.vector(), ga.subspace.bivector()
q, v = ga.multivector(Q), ga.multivector(V)
q, v = random_motor(ga), random_subspace(ga, V)
print(q)
print(v)

output = q.sandwich(v)
op = ga.operator.sandwich(Q, V)
op_partial = op.partial({0: q, 2: q})
op_partial2 = q.sandwich_map(V) # same as above
assert np.allclose(op_partial.kernel, op_partial2.kernel)

r0 = (q * v * ~q).select_subspace(v.subspace) # direct sandwich product
r1 = q.sandwich(v) # using optimized sandwich operator
r2 = op_partial(v) # using pre-bound q argument
r3 = op_partial2(v)

op = q.sandwich_map()
print('our projective matrix')
print(op.operator.kernel)
print(op.operator.axes)
print(op(v))
assert np.allclose(r0.values, r1.values)
assert np.allclose(r0.values, r2.values)
assert np.allclose(r0.values, r3.values)

def print_op(op):
print(op.operator.axes)
Expand Down Expand Up @@ -108,11 +116,29 @@ def test_operator_composition():
print()


def test_inverse():
def check_inverse(x, i):
assert np.allclose((x * i - 1).values, 0, atol=1e-9)
assert np.allclose((i * x - 1).values, 0, atol=1e-9)
def check_inverse(x, i):
assert np.allclose((x * i - 1).values, 0, atol=1e-9)
assert np.allclose((i * x - 1).values, 0, atol=1e-9)


def test_inverse_counterexample():
"""test counterexample that inverse_factor method fails to solve
afaik, the inverse_factor method should work for all multivectors < 6d
"""
ga = NumpyContext('x+y+z+a+b+c+')
mv = ga.multivector
x = 2 + mv.xy + mv.ab + mv.xyzabc

i = x.la_inverse()
check_inverse(x, i)

with pytest.raises(Exception):
i = x.inverse()
check_inverse(x, i)


def test_inverse():
"""test some general inversion cases"""
ga = NumpyContext('x+y+z+w+')
V = ga.subspace.vector()
x = ga.multivector.vector(values=np.random.normal(size=(2, len(V))))
Expand All @@ -139,6 +165,8 @@ def check_inverse(x, i):
q = x.inverse_factor()
check_inverse(x, q / x.scalar_product(q))


def test_inverse_degenerate():
with pytest.raises(Exception):
ga = NumpyContext('x+w0')
x = ga.multivector.w
Expand Down
11 changes: 7 additions & 4 deletions numga/operator/test/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,20 @@ def test_basic():
def test_commutator():
"""Visualize the output grade of the commutator of pairs of i-j vectors"""
print()
algebra = Algebra('x+y+z+w+')
algebra = Algebra.from_pqr(8, 0, 0)
n = algebra.n_dimensions + 1
dims = range(n)
v = [algebra.subspace.k_vector(i) for i in dims]
r = -np.ones((n, n))
r = np.empty((n, n), dtype=object)

for i in dims:
for j in dims:
try:
r[i, j] = algebra.operator.anti_commutator(v[i], v[j]).output.grade()
s = str(np.unique(algebra.operator.commutator(v[i], v[j]).output.grades()))
r[i, j] = s.rjust(5, ' ')
except:
pass
print(r)
print('\n'.join(' '.join(q) for q in r))


def test_square_signs():
Expand All @@ -103,6 +105,7 @@ def test_square_signs():
from numga.multivector.test.util import random_subspace
import numpy.testing as npt


def test_inertia():
"""Test equivalence of composed ternary operators to their direct expression form"""
from numga.backend.numpy.context import NumpyContext
Expand Down
2 changes: 1 addition & 1 deletion numga/subspace/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def n_blades(self):
def __hash__(self):
return hash((id(self.algebra), hash(self.blades.tostring())))
def __eq__(self, other: "SubSpaceInterface"):
return (self is other) or ((self.algebra is other.algebra) and np.alltrue(self.blades == other.blades))
return (self is other) or ((self.algebra is other.algebra) and np.all(self.blades == other.blades))

def __contains__(self, other: "SubSpaceInterface"):
return set(other.blades) in set(self.blades)

0 comments on commit f864f9d

Please sign in to comment.