diff --git a/high_order_layers_torch/Basis.py b/high_order_layers_torch/Basis.py
index 7f59839..092a2e2 100644
--- a/high_order_layers_torch/Basis.py
+++ b/high_order_layers_torch/Basis.py
@@ -408,43 +408,42 @@ def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
         return out_sum


-class BasisFlatND:
-    """
-    Single N dimensional element.
-    """
-
-    def __init__(
-        self,
-        n: int,
-        basis: Callable[[Tensor, list[int]], float],
-        dimensions: int,
-        **kwargs
-    ):
-        self.n = n
-        self.basis = basis
-        self.dimensions = dimensions
-        a = torch.arange(n)
-        self.indexes = (
-            torch.stack(torch.meshgrid([a] * dimensions))
-            .reshape(dimensions, -1)
-            .T.long()
-        )
-        self.num_basis = basis.num_basis
+# class BasisFlatND:
+
+#     def __init__(
+#         self,
+#         n: int,
+#         basis: Callable[[Tensor, list], float],
+#         dimensions: int,
+#         **kwargs
+#     ):
+#         self.n = n
+#         self.basis = basis
+#         self.dimensions = dimensions
+#         a = torch.arange(n)
+#         self.indexes = (
+#             torch.stack(torch.meshgrid([a] * dimensions))
+#             .reshape(dimensions, -1)
+#             .T.long()
+#         )
+#         self.num_basis = basis.num_basis
+
+#     def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
+#         """
+#         :param x: size[batch, input, dimension]
+#         :param w: size[output, input, basis]
+#         :returns: size[batch, output]
+#         """
+#         basis = []
+#         for index in self.indexes:
+#             basis_j = self.basis(x, index=index)
+#             basis.append(basis_j)
+#         basis = torch.stack(basis)
+#         out_sum = torch.einsum("ijk,lki->jl", basis, w)
+
+#         return out_sum

-    def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
-        """
-        :param x: size[batch, input, dimension]
-        :param w: size[output, input, basis]
-        :returns: size[batch, output]
-        """
-        basis = []
-        for index in self.indexes:
-            basis_j = self.basis(x, index=index)
-            basis.append(basis_j)
-        basis = torch.stack(basis)
-        out_sum = torch.einsum("ijk,lki->jl", basis, w)
-        return out_sum


 class BasisFlatProd:
diff --git a/high_order_layers_torch/LagrangePolynomial.py b/high_order_layers_torch/LagrangePolynomial.py
index 99ae585..e10b581 100644
--- a/high_order_layers_torch/LagrangePolynomial.py
+++ b/high_order_layers_torch/LagrangePolynomial.py
@@ -22,6 +22,64 @@ def chebyshevLobatto(n: int):
     return -torch.cos(torch.pi * torch.arange(n) / (n - 1))


+class LagrangeBasisND:
+    """
+    Single N dimensional element with Lagrange basis interpolation.
+    """
+
+    def __init__(
+        self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
+    ):
+        self.n = n
+        self.dimensions = dimensions
+        self.X = (length / 2.0) * chebyshevLobatto(n).to(device)
+        self.device = device
+        self.denominators = self._compute_denominators()
+        self.num_basis = int(math.pow(n, dimensions))
+
+        a = torch.arange(n)
+        self.indexes = (
+            torch.stack(torch.meshgrid([a] * dimensions, indexing="ij"))
+            .reshape(dimensions, -1)
+            .T.long().to(self.device)
+        )
+
+    def _compute_denominators(self):
+        # denom[j, m] = X[j] - X[m], the standard Lagrange denominator. The zero
+        # diagonal is replaced with 1.0; those entries are masked out downstream.
+        X_diff = self.X.unsqueeze(1) - self.X.unsqueeze(0)  # [n, n]
+        denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
+        return denom
+
+    def _compute_basis(self, x, indexes):
+        """
+        Computes the basis values for all index combinations.
+        :param x: [batch, inputs, dimensions]
+        :param indexes: [num_basis, dimensions]
+        :returns: basis values [num_basis, batch, inputs]
+        """
+        x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, n]
+        mask = (indexes.unsqueeze(1).unsqueeze(2).unsqueeze(4) != torch.arange(self.n, device=self.device).view(1, 1, 1, 1, self.n))
+        denominators = self.denominators[indexes]  # [num_basis, dimensions, n]
+
+        b = torch.where(mask, x_diff.unsqueeze(0) / denominators.unsqueeze(1).unsqueeze(2), torch.tensor(1.0, device=self.device))
+        r = torch.prod(b, dim=-1)  # [num_basis, batch, inputs, dimensions]
+
+        return r.prod(dim=-1)  # [num_basis, batch, inputs]
+
+    def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Interpolates the input using the Lagrange basis.
+        :param x: size[batch, inputs, dimensions]
+        :param w: size[output, inputs, num_basis]
+        :returns: size[batch, output]
+        """
+        basis = self._compute_basis(x, self.indexes)  # [num_basis, batch, inputs]
+        out_sum = torch.einsum("ibk,oki->bo", basis, w)  # [batch, output]
+
+        return out_sum
+
+
 class FourierBasis:
     def __init__(self, length: float):
         """
@@ -75,47 +133,6 @@ def __call__(self, x, j: int):
         return ans


-class LagrangeBasisND:
-
-    def __init__(
-        self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
-    ):
-        self.n = n
-        self.dimensions = dimensions
-        self.X = (length / 2.0) * chebyshevLobatto(n).to(device)
-        self.device = device
-        self.denominators = self._compute_denominators()
-        self.num_basis = int(math.pow(n, dimensions))
-
-    def _compute_denominators(self):
-
-        X_diff = self.X.unsqueeze(0) - self.X.unsqueeze(1)  # [n, n]
-        denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
-        return denom
-
-    def __call__(self, x, index: list):
-        """
-        TODO: I believe we can make this even more efficient if we
-        calculate all basis at once instead of one at a time and
-        we'll be able to do the whole thing as a cartesian product,
-        but this is pretty fast - O(n)*O(dims). The x_diff computation
-        is redundant as it's the same for every basis. This function
-        will be called O(n^dims) times so O(dims*n^(dims+1))
-        :param x: [batch, inputs, dimensions]
-        :param index : [dimensions]
-        :returns: basis value [batch, inputs]
-        """
-        x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, n]
-        indices = torch.tensor(index, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
-        mask = torch.arange(self.n, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(0) != indices
-        denominators = self.denominators[index]  # [dimensions, n]
-
-        b = torch.where(mask, x_diff / denominators, torch.tensor(1.0, device=self.device))
-        r = torch.prod(b, dim=-1)  # [batch, inputs, dimensions]
-
-        return r.prod(dim=-1)
-
-
 class LagrangeBasis1:
     """
     TODO: Degenerate case, test this and see if it works with everything else.
@@ -181,7 +198,7 @@ class LagrangePolyFlat(BasisFlat):
     def __init__(self, n: int, length: float = 2.0, **kwargs):
         super().__init__(n, get_lagrange_basis(n, length), **kwargs)

-
+"""
 class LagrangePolyFlatND(BasisFlatND):
     def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
         super().__init__(
@@ -190,6 +207,16 @@ def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
             dimensions=dimensions,
             **kwargs
         )
+"""
+
+class LagrangePolyFlatND(LagrangeBasisND):
+    def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
+        super().__init__(
+            n,
+            length=length,
+            dimensions=dimensions,
+            **kwargs
+        )


 class LagrangePolyFlatProd(BasisFlatProd):
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 786bca7..f1f3a7d 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -20,7 +20,7 @@
 from high_order_layers_torch.networks import *
 from high_order_layers_torch.PolynomialLayers import *
 import torch
-from high_order_layers_torch.Basis import BasisFlatND
+#from high_order_layers_torch.Basis import BasisFlatND

 torch.set_default_device(device="cpu")

@@ -39,6 +39,9 @@ def test_variable_dimension_input(n, in_features, out_features, segments):
     layer(a)
 """

+"""
+The BasisFlatND and per-index LagrangeBasisND implementations that these two tests
+covered have been combined into the new vectorized LagrangeBasisND, which computes faster.
 def test_basis_nd() :
     dimensions = 3
     n=5
@@ -75,7 +78,7 @@ def test_lagrange_basis(dimensions):
     print("res2", res)
     assert res[1] == 1
     assert torch.abs(res[0]) < 1e-12
-
+"""

 def test_nodes():
     ans = chebyshevLobatto(20)
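
The sketch below is not part of the patch; it is a minimal illustration of how the new vectorized LagrangeBasisND might be exercised, with every size chosen arbitrarily. It checks the shape contract documented in interpolate, and the partition-of-unity property of tensor-product Lagrange bases: since the n one-dimensional basis functions sum to 1 at any point, their products over dimensions do as well. Note that _compute_basis and indexes are internal details of the class.

import torch

from high_order_layers_torch.LagrangePolynomial import LagrangeBasisND

# Arbitrary illustrative sizes: a degree-4 (n=5) basis on a 2D element.
n, dims, batch, inputs, outputs = 5, 2, 7, 3, 4
basis = LagrangeBasisND(n=n, dimensions=dims)

x = 2 * torch.rand(batch, inputs, dims) - 1  # evaluation points in [-1, 1]^dims
w = torch.rand(outputs, inputs, basis.num_basis)  # one weight per tensor-product index

out = basis.interpolate(x, w)
assert out.shape == (batch, outputs)

# Partition of unity: the n^dims basis values sum to 1 at every point.
vals = basis._compute_basis(x, basis.indexes)  # [num_basis, batch, inputs]
assert torch.allclose(vals.sum(dim=0), torch.ones(batch, inputs), atol=1e-4)

Because all n^dims index combinations are evaluated in one batched expression, the Python loop that BasisFlatND.interpolate ran per index disappears, which is the speedup the comment in tests/test_layers.py refers to.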
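The commented-out test_lagrange_basis is not ported in this patch, so a replacement along the following lines might be added. The test name is hypothetical, and the checks mirror the old assertions (res[1] == 1, res[0] close to 0) by verifying that each Lagrange basis function is 1 at its own node and 0 at every other node.

import torch

from high_order_layers_torch.LagrangePolynomial import LagrangeBasisND


def test_lagrange_basis_nd_cardinal():
    n, dims = 5, 3  # same n and dimensions as the commented-out test_basis_nd
    basis = LagrangeBasisND(n=n, dimensions=dims)

    # Evaluate every basis function at every grid node. The n^dims nodes act as
    # the batch, with a single input channel: x is [num_basis, 1, dims].
    x = basis.X[basis.indexes].unsqueeze(1)

    # vals[k, b, 0] is basis function k at node b; Lagrange bases are cardinal,
    # so the result should be the identity matrix.
    vals = basis._compute_basis(x, basis.indexes)
    assert torch.allclose(vals.squeeze(-1), torch.eye(basis.num_basis), atol=1e-5)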