diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 38b9476..d551ce1 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -84,11 +84,11 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
 def _precompute_convolution_tensor_dense(
     in_shape,
     out_shape,
-    kernel_shape,
     filter_basis,
     grid_in="equiangular",
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
+    theta_eps=1e-3,
     transpose_normalization=False,
     basis_norm_mode="none",
     merge_quadrature=False,
@@ -106,21 +106,25 @@ def _precompute_convolution_tensor_dense(
     nlat_out, nlon_out = out_shape
 
     lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
-    lats_in = torch.from_numpy(lats_in).float()
+    lats_in = torch.from_numpy(lats_in)
     lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
-    lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices
+    lats_out = torch.from_numpy(lats_out)
 
     # compute the phi differences. We need to make the linspace exclusive to not double the last point
-    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
-    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
+    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
+
+    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
+    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
 
     # compute quadrature weights that will be merged into the Psi tensor
     if transpose_normalization:
-        quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
     else:
-        quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
 
-    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
+    # array for accumulating non-zero indices
+    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
 
     for t in range(nlat_out):
         for p in range(nlon_out):
@@ -147,13 +151,14 @@ def _precompute_convolution_tensor_dense(
             phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
             # find the indices where the rotated position falls into the support of the kernel
-            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
+            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
             out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals
 
-    # take care of normalization
+    # take care of normalization and cast to float
     out = _normalize_convolution_tensor_dense(
         out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
     )
+    out = out.to(dtype=torch.float32)
 
     return out
 
@@ -239,7 +244,6 @@ def test_disco_convolution(
             psi_dense = _precompute_convolution_tensor_dense(
                 out_shape,
                 in_shape,
-                kernel_shape,
                 filter_basis,
                 grid_in=grid_out,
                 grid_out=grid_in,
@@ -256,7 +260,6 @@
             psi_dense = _precompute_convolution_tensor_dense(
                 in_shape,
                 out_shape,
-                kernel_shape,
                 filter_basis,
                 grid_in=grid_in,
                 grid_out=grid_out,
diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py
index 452fd9c..c9a5d7e 100644
--- a/torch_harmonics/convolution.py
+++ b/torch_harmonics/convolution.py
@@ -134,6 +134,7 @@ def _precompute_convolution_tensor_s2(
     grid_in="equiangular",
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
+    theta_eps = 1e-3,
     transpose_normalization=False,
     basis_norm_mode="mean",
     merge_quadrature=False,
@@ -164,20 +165,23 @@ def _precompute_convolution_tensor_s2(
 
     # precompute input and output grids
     lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
-    lats_in = torch.from_numpy(lats_in).float()
+    lats_in = torch.from_numpy(lats_in)
     lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
-    lats_out = torch.from_numpy(lats_out).float()
+    lats_out = torch.from_numpy(lats_out)
 
     # compute the phi differences
     # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
-    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
 
     # compute quadrature weights and merge them into the convolution tensor.
     # These quadrature integrate to 1 over the sphere.
     if transpose_normalization:
-        quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
     else:
-        quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
+
+    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
+    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
 
     out_idx = []
     out_vals = []
@@ -207,7 +211,7 @@ def _precompute_convolution_tensor_s2(
         phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
         # find the indices where the rotated position falls into the support of the kernel
-        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
+        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
 
         # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
         idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
@@ -217,8 +221,8 @@
         out_vals.append(vals)
 
     # concatenate the indices and values
-    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
-    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
+    out_idx = torch.cat(out_idx, dim=-1)
+    out_vals = torch.cat(out_vals, dim=-1)
 
     out_vals = _normalize_convolution_tensor_s2(
         out_idx,
@@ -232,6 +236,9 @@
         merge_quadrature=merge_quadrature,
     )
 
+    out_idx = out_idx.contiguous()
+    out_vals = out_vals.to(dtype=torch.float32).contiguous()
+
     return out_idx, out_vals
 
 
diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py
index f052165..05fec97 100644
--- a/torch_harmonics/distributed/distributed_convolution.py
+++ b/torch_harmonics/distributed/distributed_convolution.py
@@ -75,6 +75,7 @@ def _precompute_distributed_convolution_tensor_s2(
     grid_in="equiangular",
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
+    theta_eps = 1e-3,
     transpose_normalization=False,
     basis_norm_mode="mean",
     merge_quadrature=False,
@@ -103,21 +104,25 @@ def _precompute_distributed_convolution_tensor_s2(
     nlat_in, nlon_in = in_shape
     nlat_out, nlon_out = out_shape
 
+    # precompute input and output grids
     lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
-    lats_in = torch.from_numpy(lats_in).float()
+    lats_in = torch.from_numpy(lats_in)
     lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
-    lats_out = torch.from_numpy(lats_out).float()
+    lats_out = torch.from_numpy(lats_out)
 
     # compute the phi differences
     # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
-    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
 
     # compute quadrature weights and merge them into the convolution tensor.
     # These quadrature integrate to 1 over the sphere.
     if transpose_normalization:
-        quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
     else:
-        quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
+        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
+
+    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
+    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
 
     out_idx = []
     out_vals = []
@@ -147,7 +152,7 @@ def _precompute_distributed_convolution_tensor_s2(
         phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
 
         # find the indices where the rotated position falls into the support of the kernel
-        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
+        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
 
         # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
         idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
@@ -157,8 +162,8 @@
         out_vals.append(vals)
 
     # concatenate the indices and values
-    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
-    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
+    out_idx = torch.cat(out_idx, dim=-1)
+    out_vals = torch.cat(out_vals, dim=-1)
 
     out_vals = _normalize_convolution_tensor_s2(
         out_idx,
@@ -189,6 +194,9 @@
     # for the indices we need to recompute them to refer to local indices of the input tenor
     out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
 
+    out_idx = out_idx.contiguous()
+    out_vals = out_vals.to(dtype=torch.float32).contiguous()
+
     return out_idx, out_vals
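
Reviewer note, not part of the patch: the functional changes above are that the convolution tensor is now precomputed in float64 (with the cast to float32 deferred until after normalization) and that the kernel support test uses a padded cutoff, theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff. The sketch below is an illustration under assumed numbers, reusing the names theta_cutoff and theta_eps from the patched functions, and compares the float32 round-off picked up by a boundary colatitude with the padding that theta_eps provides.

# Illustrative sketch (not part of the patch): scale of float32 round-off on a
# colatitude that lies exactly on the cutoff ring, versus the padding added by
# theta_eps. The cos/arccos round trip mimics the rotation arithmetic of the
# precompute; the specific values are an example only.
import math
import torch

theta_cutoff = 0.01 * math.pi
theta_eps = 1e-3
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

# boundary colatitude, once exact (float64) and once perturbed by float32 math
theta64 = torch.tensor([theta_cutoff], dtype=torch.float64)
theta32 = torch.arccos(torch.cos(theta64.to(torch.float32)))

roundoff = (theta32.double() - theta64).abs().item()
padding = theta_eps * theta_cutoff

print(f"float32 round-off on the boundary ring: {roundoff:.2e}")
print(f"padding provided by theta_eps:          {padding:.2e}")
# the perturbed value may land on either side of the strict cutoff, but the
# padded cutoff absorbs the round-off in both directions
print("inside strict cutoff:", bool(theta32.item() <= theta_cutoff))
print("inside padded cutoff:", bool(theta32.item() <= theta_cutoff_eff))

Because the padding (theta_eps * theta_cutoff) is orders of magnitude larger than this round-off yet still well below a typical grid spacing, the effective cutoff appears to keep rings sitting on the cutoff circle inside the kernel support without pulling in extra ones, which matches the "avoid aliasing with grid width" comment in the patch.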