diff --git a/config/block_mnist.yaml b/config/block_mnist.yaml
index 5300a95..a5ec06d 100644
--- a/config/block_mnist.yaml
+++ b/config/block_mnist.yaml
@@ -1,8 +1,9 @@
 max_epochs: 1
 accelerator: 'cuda'
-n: 10
+n: 7
 batch_size: 16
-layer_type: polynomial_3d
+layer_type: polynomial_3d # continuous_3d
+segments: 5
 train_fraction: 1.0
 
 defaults:
diff --git a/examples/block_mnist.py b/examples/block_mnist.py
index 6ae55d2..2ffd8f2 100644
--- a/examples/block_mnist.py
+++ b/examples/block_mnist.py
@@ -75,6 +75,7 @@ def __init__(self, cfg: DictConfig):
         layer1 = high_order_fc_layers(
             layer_type=cfg.layer_type,
             n=[3,n,n],
+            segments=cfg.segments,
             in_features=1,
             out_features=10,
             intialization="constant_random",
diff --git a/high_order_layers_torch/LagrangePolynomial.py b/high_order_layers_torch/LagrangePolynomial.py
index 7376c82..1248ef6 100644
--- a/high_order_layers_torch/LagrangePolynomial.py
+++ b/high_order_layers_torch/LagrangePolynomial.py
@@ -144,6 +144,20 @@ def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         return out_sum
 
 
+class LagrangeBasisPiecewiseND(LagrangeBasisND):
+    def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Interpolates the input using the Lagrange basis.
+        :param x: size[batch, inputs, dimensions]
+        :param w: size[batch, inputs, output, num_basis]
+        :returns: size[batch, output]
+        """
+        basis = self._compute_basis(x, self.indexes)  # [num_basis, batch, inputs]
+        out_sum = torch.einsum("ibk,bkoi->bo", basis, w)  # [batch, output]
+
+        return out_sum
+
+
 class FourierBasis:
     def __init__(self, length: float):
         """
diff --git a/high_order_layers_torch/PolynomialLayers.py b/high_order_layers_torch/PolynomialLayers.py
index e0fd3c5..5b7c417 100644
--- a/high_order_layers_torch/PolynomialLayers.py
+++ b/high_order_layers_torch/PolynomialLayers.py
@@ -356,6 +356,145 @@ def refine(
         refine_polynomial_layer(layer_in=self, layer_out=layer_out)
 
 
+class PiecewiseND(torch.nn.Module):
+    def __init__(
+        self,
+        n: Union[List[int], int],
+        in_features: int,
+        out_features: int,
+        segments: Union[List[int], int],
+        length: float = 2.0,
+        weight_magnitude: float = 1.0,
+        device: str = "cpu",
+        initialize: str = "constant_random",
+        **kwargs,
+    ):
+        super().__init__()
+        self.n = [n] * in_features if isinstance(n, int) else n
+        self.segments = (
+            [segments] * in_features if isinstance(segments, int) else segments
+        )
+        self.dimensions = in_features
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device = device
+        self._length = length
+        self._half = 0.5 * length
+
+        self.lagrange_basis = LagrangeBasisPiecewiseND(self.n, length=length, device=device)
+
+        # Calculate the total number of weights needed
+        self.weights_per_segment = math.prod(self.n)
+        self.total_segments = math.prod(self.segments)  # per block
+        total_weights = in_features * out_features * self.total_segments * self.weights_per_segment
+
+        self.w = nn.Parameter(torch.empty(total_weights, device=device))
+
+        if initialize == "constant_random":
+            self._constant_random_initialization(weight_magnitude)
+        else:
+            self.w.data.uniform_(
+                -weight_magnitude / in_features, weight_magnitude / in_features
+            )
+
+    def _constant_random_initialization(self, weight_magnitude):
+        # TODO: verify this.
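+        # Draws one random constant in [-weight_magnitude, weight_magnitude]
+        # per (output, segment) pair and repeats it across that segment's
+        # basis weights, so the layer starts out piecewise constant.
+        # Note: this fills out_features * total_segments * weights_per_segment
+        # entries, which matches total_weights only when in_features == 1.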
+        segment_values = (
+            torch.rand(self.out_features, self.total_segments, device=self.device)
+            * 2
+            * weight_magnitude
+            - weight_magnitude
+        )
+        self.w.data = segment_values.repeat_interleave(self.weights_per_segment)
+
+    def which_segment(self, x: torch.Tensor) -> torch.Tensor:
+        return (
+            (
+                (x + self._half)
+                / self._length
+                * torch.tensor(self.segments, device=self.device)
+            )
+            .long()
+            .clamp(torch.tensor(0, device=self.device), torch.tensor(self.segments, device=self.device) - 1)
+        )
+
+    def x_local(self, x_global: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+        x_min = self._eta(index)
+        x_max = self._eta(index + 1)
+        return self._length * ((x_global - x_min) / (x_max - x_min)) - self._half
+
+    def x_global(self, x_local: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
+        x_min = self._eta(index)
+        x_max = self._eta(index + 1)
+        return ((x_local + self._half) / self._length) * (x_max - x_min) + x_min
+
+    def _eta(self, index: torch.Tensor) -> torch.Tensor:
+        return (
+            index.float() / torch.tensor(self.segments, device=self.device).float() * 2
+            - 1
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        # Get segment indices for each dimension
+        segment_indices = self.which_segment(x)
+
+        # Calculate local coordinates within each segment
+        x_local = self.x_local(x, segment_indices)
+
+        # Determine which weights are active
+        weight_indices = self._get_weight_indices(segment_indices)
+
+        # Gather and reshape the weights for each segment
+        w = self._reshape_weights(weight_indices)
+
+        # Interpolate using the Lagrange basis
+        result = self.lagrange_basis.interpolate(x_local, w)
+
+        return result
+
+    def _get_weight_indices(self, segment_indices):
+        # Convert N-dimensional segment indices to flat indices
+        batch_size, num_inputs = segment_indices.shape[:2]
+        flat_indices = torch.zeros(
+            batch_size, num_inputs, dtype=torch.long, device=self.device
+        )
+
+        for dim in range(self.dimensions):
+            flat_indices += segment_indices[..., dim] * math.prod(
+                self.segments[dim + 1 :]
+            )
+
+        # Expand the indices to cover every basis function within each segment.
+        # Each basis function has a weight; now that the correct segment is
+        # known, grab all of the weights belonging to that segment. There is
+        # probably a faster way, but this should work.
+        expanded_indices = flat_indices.unsqueeze(
+            -1
+        ) * self.weights_per_segment + torch.arange(
+            self.weights_per_segment, device=self.device
+        )
+
+        # [batch, input, weights_per_segment]; note this is missing the output dimension
+        return expanded_indices
+
+    def _reshape_weights(self, weight_indices):
+        # Reshape weights based on the weight indices
+        batch_size, num_inputs, _ = weight_indices.shape
+        weight_indices = weight_indices.unsqueeze(2).expand(-1, -1, self.out_features, -1)
+
+        # Select the weights for each input point
+        selected_weights = self.w[weight_indices]
+
+        # Reshape to [batch_size, num_inputs, out_features, weights_per_segment]
+        reshaped_weights = selected_weights.view(
+            batch_size, num_inputs, self.out_features, self.weights_per_segment
+        )
+
+        # interpolate() consumes this layout directly, so no permute is needed
+        return reshaped_weights
+
+
 class PiecewisePolynomial(Piecewise):
     def __init__(
         self,
diff --git a/high_order_layers_torch/layers.py b/high_order_layers_torch/layers.py
index 4dd16d1..2ecdd52 100644
--- a/high_order_layers_torch/layers.py
+++ b/high_order_layers_torch/layers.py
@@ -272,6 +272,7 @@ def switch_discontinuous(**kwargs):
     "polynomial_3d" : Polynomial3D,
     "polynomial_4d" : Polynomial4D,
     "polynomial_5d" : Polynomial5D,
+    "continuous_nd" : PiecewiseND,
 }
 
 convolutional_layers = {
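
For reference, a minimal smoke-test sketch of the new layer, based only on the constructor and shape conventions in this diff. The values (a batch of 16 three-dimensional points in [-1, 1], n=[3, 7, 7], segments=5, mirroring config/block_mnist.yaml) are illustrative assumptions; once the "continuous_nd" registration above lands, the same layer should presumably also be reachable through high_order_fc_layers(layer_type="continuous_nd", ...) as in examples/block_mnist.py.

# Sketch only: exercises PiecewiseND directly with the signature added in this
# diff; the example values are assumptions, not part of the patch.
import torch

from high_order_layers_torch.PolynomialLayers import PiecewiseND

layer = PiecewiseND(
    n=[3, 7, 7],        # basis points per dimension of each input block
    in_features=1,      # one N-dimensional block per sample
    out_features=10,
    segments=5,         # segments per dimension (expanded to a list internally)
    initialize="constant_random",
)

x = torch.rand(16, 1, 3) * 2.0 - 1.0  # [batch, inputs, dimensions] in [-1, 1]
y = layer(x)                          # expected shape [batch, out_features]
print(y.shape)                        # torch.Size([16, 10])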