Skip to content

Commit

Permalink
formating changes to resample module
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Jan 14, 2025
1 parent 4d8755b commit 8680e02
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
15 changes: 4 additions & 11 deletions torch_harmonics/distributed/distributed_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,12 @@ def _upscale_longitudes(self, x: torch.Tensor):
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]

return x

# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x

def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
Expand All @@ -171,8 +164,8 @@ def _upscale_latitudes(self, x: torch.Tensor):
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]

return x
Expand Down
15 changes: 4 additions & 11 deletions torch_harmonics/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,12 @@ def _upscale_longitudes(self, x: torch.Tensor):
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]

return x

# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x

def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
Expand All @@ -156,8 +149,8 @@ def _upscale_latitudes(self, x: torch.Tensor):
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]

return x
Expand Down

0 comments on commit 8680e02

Please sign in to comment.