
Commit

Make this a bit more efficient
jloveric committed Jun 19, 2024
1 parent f6cb205 commit fc8b886
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions high_order_layers_torch/LagrangePolynomial.py
@@ -88,31 +88,32 @@ def __init__(
        self.num_basis = int(math.pow(n, dimensions))

    def _compute_denominators(self):
-        denom = torch.ones([self.n, self.n], dtype=torch.float32, device=self.device)
-
-        for j in range(self.n):
-            for m in range(self.n):
-                if m != j:
-                    denom[j, m] = self.X[j] - self.X[m]
+        X_diff = self.X.unsqueeze(0) - self.X.unsqueeze(1)  # [n, n]
+        denom = torch.where(X_diff == 0, torch.tensor(1.0, device=self.device), X_diff)
        return denom

-    def __call__(self, x, index: list[int]):
+    def __call__(self, x, index: list):
"""
TODO: I believe we can make this even more efficient if we
calculate all basis at once instead of one at a time and
we'll be able to do the whole thing as a cartesian product,
but this is pretty fast - O(n)*O(dims). The x_diff computation
is redundant as it's the same for every basis. This function
will be called O(n^dims) times so O(dims*n^(dims+1))
:param x: [batch, inputs, dimensions]
:param index : [dimensions]
:returns: basis value [batch, inputs]
"""
-        x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, basis]
-        r = 1.0
-        for i, basis_i in enumerate(index):
-            b = torch.where(
-                torch.arange(self.n, device=self.device) != basis_i,
-                x_diff[:, :, i, :] / self.denominators[basis_i],
-                torch.tensor(1.0, device=self.device),
-            )
-            r *= torch.prod(b, dim=-1)
-
-        return r
+        x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, n]
+        indices = torch.tensor(index, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
+        mask = torch.arange(self.n, device=self.device).unsqueeze(0).unsqueeze(0).unsqueeze(0) != indices
+        denominators = self.denominators[index]  # [dimensions, n]
+
+        b = torch.where(mask, x_diff / denominators, torch.tensor(1.0, device=self.device))
+        r = torch.prod(b, dim=-1)  # [batch, inputs, dimensions]
+
+        return r.prod(dim=-1)


class LagrangeBasis1:
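
Not part of the commit: a minimal, self-contained sketch of the broadcast-and-mask idea used in the `_compute_denominators` hunk above, with a hypothetical stand-in grid `X` in place of `self.X`. Depending on which tensor is unsqueezed along which axis, entry `[j, m]` of the difference matrix holds either `X[j] - X[m]` or its negative; the sketch follows the removed loop's `X[j] - X[m]` convention.

```python
import torch

n = 5
X = torch.linspace(-1.0, 1.0, n)  # hypothetical stand-in for self.X

# Nested-loop version (the removed code): denom[j, m] = X[j] - X[m], 1.0 on the diagonal
loop_denom = torch.ones(n, n)
for j in range(n):
    for m in range(n):
        if m != j:
            loop_denom[j, m] = X[j] - X[m]

# Broadcast version: one pairwise-difference matrix, with the zero diagonal patched to 1.0
diff = X.unsqueeze(1) - X.unsqueeze(0)  # diff[j, m] = X[j] - X[m]
vec_denom = torch.where(diff == 0, torch.tensor(1.0), diff)

assert torch.allclose(loop_denom, vec_denom)
```

The `torch.where` patch on the diagonal only exists so the later division never hits zero; the diagonal factor is masked out again when the basis is evaluated.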
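Also not part of the commit: a hedged, self-contained sketch of the broadcast-and-mask pattern in the new `__call__` body above, checked against a plain per-dimension Lagrange basis product. The sizes, the `index` choice, and the stand-in grid `X` are all hypothetical, and the denominators use the removed loop's `X[j] - X[m]` convention.

```python
import torch

torch.manual_seed(0)
n, dims, batch, inputs = 4, 2, 3, 5          # hypothetical sizes
X = torch.linspace(-1.0, 1.0, n)             # stand-in for self.X
index = [1, 3]                               # one basis index per dimension (hypothetical)

# Denominators with the loop's convention: denominators[j, m] = X[j] - X[m], 1.0 on the diagonal
denominators = X.unsqueeze(1) - X.unsqueeze(0)
denominators = torch.where(denominators == 0, torch.tensor(1.0), denominators)

x = torch.rand(batch, inputs, dims) * 2 - 1  # sample points in [-1, 1]

# Vectorized evaluation: mask out the m == index[d] factor, divide, then reduce
x_diff = x.unsqueeze(-1) - X                                   # [batch, inputs, dims, n]
idx = torch.tensor(index).reshape(1, 1, dims, 1)
mask = torch.arange(n).reshape(1, 1, 1, n) != idx
b = torch.where(mask, x_diff / denominators[index], torch.tensor(1.0))
r_vec = torch.prod(b, dim=-1).prod(dim=-1)                     # [batch, inputs]

# Reference: product over dimensions of the scalar Lagrange basis L_j(x)
def lagrange_basis(xs, j):
    out = torch.ones_like(xs)
    for m in range(n):
        if m != j:
            out = out * (xs - X[m]) / (X[j] - X[m])
    return out

r_ref = torch.ones(batch, inputs)
for d, j in enumerate(index):
    r_ref = r_ref * lagrange_basis(x[..., d], j)

assert torch.allclose(r_vec, r_ref, atol=1e-6)
```

The single product over the last two axes replaces the Python loop over dimensions, so each call reduces to a handful of tensor ops instead of `dims` separate masked products.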
