Skip to content

Commit

Permalink
Adding the ND basis
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 18, 2024
1 parent f4bfa50 commit ae91ef5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
32 changes: 32 additions & 0 deletions high_order_layers_torch/Basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
"""

basis = []
# TODO: get rid of this loop!
for j in range(self.n):
basis_j = self.basis(x, j)
basis.append(basis_j)
Expand All @@ -406,6 +407,37 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
return out_sum


class BasisFlatND:
"""
Single N dimensional element.
"""

def __init__(
self, n: int, basis: Callable[[Tensor, list[int]], float], dimensions: int
):
self.n = n
self.basis = basis
self.dimensions = dimensions
a = torch.arange(n)
self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
"""
:param x: size[batch, input, dimension]
:param w: size[input, output, basis]
:returns: size[batch, output]
"""

basis = []
for index in range(self.indexes):
basis_j = self.basis(x, index=index)
basis.append(basis_j)
basis = torch.stack(basis)
out_sum = torch.einsum("ijk,lki->jl", basis, w)

return out_sum


class BasisFlatProd:
"""
Single segment.
Expand Down
12 changes: 11 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import pytest

import math
from high_order_layers_torch.FunctionalConvolution import *
from high_order_layers_torch.LagrangePolynomial import *
from high_order_layers_torch.LagrangePolynomial import LagrangePoly, LagrangeBasisND
Expand All @@ -20,6 +20,7 @@
from high_order_layers_torch.networks import *
from high_order_layers_torch.PolynomialLayers import *
import torch
from high_order_layers_torch.Basis import BasisFlatND

torch.set_default_device(device="cpu")

Expand All @@ -38,6 +39,15 @@ def test_variable_dimension_input(n, in_features, out_features, segments):
layer(a)
"""

def test_basis_nd() :
dimensions = 3
n=5
lb = LagrangeBasisND(n=n, dimensions=dimensions)
basis = BasisFlatND(n=n, dimensions=dimensions, basis=lb)

# The indexes should be unique so we cover all indices
assert len(set(basis.indexes)) == math.pow(5, dimensions)


@pytest.mark.parametrize("dimensions", [1, 2, 3, 4])
def test_lagrange_basis(dimensions):
Expand Down

0 comments on commit ae91ef5

Please sign in to comment.