Skip to content

Commit

Permalink
Add piecewise nd polynomials
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 23, 2024
1 parent 0afb5c5 commit 8905ed3
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 2 deletions.
5 changes: 3 additions & 2 deletions config/block_mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
max_epochs: 1
accelerator: 'cuda'
n: 10
n: 7
batch_size: 16
layer_type: polynomial_3d
layer_type: polynomial_3d # continuous_3d
segments: 5
train_fraction: 1.0

defaults:
Expand Down
1 change: 1 addition & 0 deletions examples/block_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, cfg: DictConfig):
layer1 = high_order_fc_layers(
layer_type=cfg.layer_type,
n=[3,n,n],
segments = cfg.segments,
in_features=1,
out_features=10,
intialization="constant_random",
Expand Down
14 changes: 14 additions & 0 deletions high_order_layers_torch/LagrangePolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:

return out_sum

class LagrangeBasisPiecewiseND(LagrangeBasisND) :
def interpolate(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""
Interpolates the input using the Lagrange basis.
:param x: size[batch, inputs, dimensions]
:param w: size[output, inputs, num_basis]
:returns: size[batch, output]
"""
basis = self._compute_basis(x, self.indexes) # [num_basis, batch, inputs]
out_sum = torch.einsum("ibk,bkoi->bo", basis, w) # [batch, output]

return out_sum


class FourierBasis:
def __init__(self, length: float):
"""
Expand Down
139 changes: 139 additions & 0 deletions high_order_layers_torch/PolynomialLayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,145 @@ def refine(
refine_polynomial_layer(layer_in=self, layer_out=layer_out)


class PiecewiseND(torch.nn.Module):
def __init__(
self,
n: Union[List[int], int],
in_features: int,
out_features: int,
segments: Union[List[int], int],
length: float = 2.0,
weight_magnitude: float = 1.0,
device: str = "cpu",
initialize: str = "constant_random",
**kwargs,
):
super().__init__()
self.n = [n] * in_features if isinstance(n, int) else n
self.segments = (
[segments] * in_features if isinstance(segments, int) else segments
)
self.dimensions = in_features
self.in_features = in_features
self.out_features = out_features
self.device = device
self._length = length
self._half = 0.5 * length

self.lagrange_basis = LagrangeBasisPiecewiseND(self.n, length=length, device=device)

# Calculate total number of weights needed
self.weights_per_segment = math.prod(self.n)
self.total_segments = math.prod(self.segments) # per block
total_weights = in_features*out_features * self.total_segments * self.weights_per_segment

self.w = nn.Parameter(torch.empty(total_weights, device=device))

if initialize == "constant_random":
self._constant_random_initialization(weight_magnitude)
else:
self.w.data.uniform_(
-weight_magnitude / in_features, weight_magnitude / in_features
)

def _constant_random_initialization(self, weight_magnitude):
# TODO: verify this.
segment_values = (
torch.rand(self.out_features, self.total_segments, device=self.device)
* 2
* weight_magnitude
- weight_magnitude
)
self.w.data = segment_values.repeat_interleave(self.weights_per_segment)

def which_segment(self, x: torch.Tensor) -> torch.Tensor:
return (
(
(x + self._half)
/ self._length
* torch.tensor(self.segments, device=self.device)
)
.long()
.clamp(torch.tensor(0, device=self.device), torch.tensor(self.segments, device=self.device) - 1)
)

def x_local(self, x_global: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
x_min = self._eta(index)
x_max = self._eta(index + 1)
return self._length * ((x_global - x_min) / (x_max - x_min)) - self._half

def x_global(self, x_local: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
x_min = self._eta(index)
x_max = self._eta(index + 1)
return ((x_local + self._half) / self._length) * (x_max - x_min) + x_min

def _eta(self, index: torch.Tensor) -> torch.Tensor:
return (
index.float() / torch.tensor(self.segments, device=self.device).float() * 2
- 1
)

def forward(self, x: torch.Tensor) -> torch.Tensor:

# Get segment indices for each dimension
segment_indices = self.which_segment(x)

# Calculate local coordinates
x_local = self.x_local(x, segment_indices)

# Determine which weights are active
weight_indices = self._get_weight_indices(segment_indices)

# Reshape weights for each segment
w = self._reshape_weights(weight_indices)

# Interpolate using Lagrange basis
result = self.lagrange_basis.interpolate(x_local, w)

return result

def _get_weight_indices(self, segment_indices):
# Convert N-dimensional segment indices to flat indices
batch_size, num_inputs = segment_indices.shape[:2]
flat_indices = torch.zeros(
batch_size, num_inputs, dtype=torch.long, device=self.device
)

for dim in range(self.dimensions):
flat_indices += segment_indices[..., dim] * math.prod(
self.segments[dim + 1 :]
)

# Expand indices for all basis functions within each segment
# basically, each basis has a weight, and now we know the corrected
# segment, here we are grabbing all the weights belonging to that
# segment, I believe there is a faster way, but, this should work
expanded_indices = flat_indices.unsqueeze(
-1
) * self.weights_per_segment + torch.arange(
self.weights_per_segment, device=self.device
)

# [batch, input, weights_per_segment] note this is missing the output dimension
return expanded_indices

def _reshape_weights(self, weight_indices):
# Reshape weights based on weight indices
batch_size, num_inputs, _ = weight_indices.shape
weight_indices = weight_indices.unsqueeze(2).expand(-1, -1, self.out_features, -1)

# Select weights for each input point
selected_weights = self.w[weight_indices]

# Reshape to [out_features, num_inputs, batch_size, weights_per_segment]
reshaped_weights = selected_weights.view(
batch_size, num_inputs, self.out_features, self.weights_per_segment
)

# Permute to get the desired shape [out_features, num_inputs, batch_size, weights_per_segment]
return reshaped_weights


class PiecewisePolynomial(Piecewise):
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions high_order_layers_torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def switch_discontinuous(**kwargs):
"polynomial_3d" : Polynomial3D,
"polynomial_4d" : Polynomial4D,
"polynomial_5d" : Polynomial5D,
"continuous_nd" : PiecewiseND,
}

convolutional_layers = {
Expand Down

0 comments on commit 8905ed3

Please sign in to comment.