Skip to content

Commit

Permalink
switched psi tensor computation to double precision and implemented a…
Browse files Browse the repository at this point in the history
… fudge factor for theta_cutoff to avoid aliasing issues with the grid width
  • Loading branch information
bonevbs committed Jan 14, 2025
1 parent 55bbcb2 commit 15d0750
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 28 deletions.
27 changes: 15 additions & 12 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
def _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps=1e-3,
transpose_normalization=False,
basis_norm_mode="none",
merge_quadrature=False,
Expand All @@ -106,21 +106,25 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape

lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices
lats_out = torch.from_numpy(lats_out)

# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]

# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0

out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
# array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)

for t in range(nlat_out):
for p in range(nlon_out):
Expand All @@ -147,13 +151,14 @@ def _precompute_convolution_tensor_dense(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals

# take care of normalization
# take care of normalization and cast to float
out = _normalize_convolution_tensor_dense(
out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
)
out = out.to(dtype=torch.float32)

return out

Expand Down Expand Up @@ -239,7 +244,6 @@ def test_disco_convolution(
psi_dense = _precompute_convolution_tensor_dense(
out_shape,
in_shape,
kernel_shape,
filter_basis,
grid_in=grid_out,
grid_out=grid_in,
Expand All @@ -256,7 +260,6 @@ def test_disco_convolution(
psi_dense = _precompute_convolution_tensor_dense(
in_shape,
out_shape,
kernel_shape,
filter_basis,
grid_in=grid_in,
grid_out=grid_out,
Expand Down
23 changes: 15 additions & 8 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _precompute_convolution_tensor_s2(
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
Expand Down Expand Up @@ -164,20 +165,23 @@ def _precompute_convolution_tensor_s2(

# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
lats_out = torch.from_numpy(lats_out)

# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]

# compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0

# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

out_idx = []
out_vals = []
Expand Down Expand Up @@ -207,7 +211,7 @@ def _precompute_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
Expand All @@ -217,8 +221,8 @@ def _precompute_convolution_tensor_s2(
out_vals.append(vals)

# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)

out_vals = _normalize_convolution_tensor_s2(
out_idx,
Expand All @@ -232,6 +236,9 @@ def _precompute_convolution_tensor_s2(
merge_quadrature=merge_quadrature,
)

out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()

return out_idx, out_vals


Expand Down
24 changes: 16 additions & 8 deletions torch_harmonics/distributed/distributed_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
Expand Down Expand Up @@ -103,21 +104,25 @@ def _precompute_distributed_convolution_tensor_s2(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
lats_out = torch.from_numpy(lats_out)

# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]

# compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0

# effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

out_idx = []
out_vals = []
Expand Down Expand Up @@ -147,7 +152,7 @@ def _precompute_distributed_convolution_tensor_s2(
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
Expand All @@ -157,8 +162,8 @@ def _precompute_distributed_convolution_tensor_s2(
out_vals.append(vals)

# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)

out_vals = _normalize_convolution_tensor_s2(
out_idx,
Expand Down Expand Up @@ -189,6 +194,9 @@ def _precompute_distributed_convolution_tensor_s2(
# for the indices we need to recompute them to refer to local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)

out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()

return out_idx, out_vals


Expand Down

0 comments on commit 15d0750

Please sign in to comment.