From ae91ef58945f22f2a19973b60e6f5399c0590b8a Mon Sep 17 00:00:00 2001 From: jloveric Date: Mon, 17 Jun 2024 19:09:51 -0700 Subject: [PATCH] Adding the ND basis --- high_order_layers_torch/Basis.py | 32 ++++++++++++++++++++++++++++++++ tests/test_layers.py | 12 +++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/high_order_layers_torch/Basis.py b/high_order_layers_torch/Basis.py index 6346e17..f9197d8 100644 --- a/high_order_layers_torch/Basis.py +++ b/high_order_layers_torch/Basis.py @@ -397,6 +397,7 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor: """ basis = [] + # TODO: get rid of this loop! for j in range(self.n): basis_j = self.basis(x, j) basis.append(basis_j) @@ -406,6 +407,37 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor: return out_sum +class BasisFlatND: + """ + Single N dimensional element. + """ + + def __init__( + self, n: int, basis: Callable[[Tensor, list[int]], float], dimensions: int + ): + self.n = n + self.basis = basis + self.dimensions = dimensions + a = torch.arange(n) + self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T + + def interpolate(self, x: Tensor, w: Tensor) -> Tensor: + """ + :param x: size[batch, input, dimension] + :param w: size[input, output, basis] + :returns: size[batch, output] + """ + + basis = [] + for index in range(self.indexes): + basis_j = self.basis(x, index=index) + basis.append(basis_j) + basis = torch.stack(basis) + out_sum = torch.einsum("ijk,lki->jl", basis, w) + + return out_sum + + class BasisFlatProd: """ Single segment. diff --git a/tests/test_layers.py b/tests/test_layers.py index d7f42be..54cdf0c 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,7 +1,7 @@ import os import pytest - +import math from high_order_layers_torch.FunctionalConvolution import * from high_order_layers_torch.LagrangePolynomial import * from high_order_layers_torch.LagrangePolynomial import LagrangePoly, LagrangeBasisND @@ -20,6 +20,7 @@ from high_order_layers_torch.networks import * from high_order_layers_torch.PolynomialLayers import * import torch +from high_order_layers_torch.Basis import BasisFlatND torch.set_default_device(device="cpu") @@ -38,6 +39,15 @@ def test_variable_dimension_input(n, in_features, out_features, segments): layer(a) """ +def test_basis_nd() : + dimensions = 3 + n=5 + lb = LagrangeBasisND(n=n, dimensions=dimensions) + basis = BasisFlatND(n=n, dimensions=dimensions, basis=lb) + + # The indexes should be unique so we cover all indices + assert len(set(basis.indexes)) == math.pow(5, dimensions) + @pytest.mark.parametrize("dimensions", [1, 2, 3, 4]) def test_lagrange_basis(dimensions):