AMP hotfix (#47)
* AMP hotfix

* Bumping up version to 0.7.1
bonevbs authored Aug 27, 2024
1 parent 1bfda53 commit 5d7e9b0
Showing 5 changed files with 48 additions and 46 deletions.
4 changes: 4 additions & 0 deletions Changelog.md
@@ -2,6 +2,10 @@

## Versioning

+### v0.7.1
+
+* Hotfix to AMP in SFNO example
+
### v0.7.0

* CUDA-accelerated DISCO convolutions
6 changes: 3 additions & 3 deletions README.md
@@ -209,23 +209,23 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri

## Remarks on automatic mixed precision (AMP) support

-Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the transformed dimensions are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:
+Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the transformed dimensions are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically:

```python
import torch
import torch_harmonics as th

sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

-with torch.cuda.amp.autocast(enabled = True):
+with torch.autocast(device_type="cuda", enabled = True):
    # do some AMP converted math here
    x = some_math(x)
    # convert tensor to float32
    x = x.to(torch.float32)
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
-    with torch.cuda.amp.autocast(enabled = False):
+    with torch.autocast(device_type="cuda", enabled = False):
        xt = sht(x)

    # continue operating on the transformed tensor
```
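
As a quick sanity check, a small helper (purely illustrative, not part of torch-harmonics) can tell you whether the transformed dimensions are powers of two and therefore safe for reduced-precision FFTs:

```python
def is_power_of_two(n: int) -> bool:
    # a positive integer is a power of two iff exactly one bit is set
    return n > 0 and (n & (n - 1)) == 0

# the 512 x 1024 grid above is fine for float16/bfloat16 FFTs
print(is_power_of_two(512), is_power_of_two(1024))  # True True

# a 361 x 720 equiangular grid is not, so keep the SHT in float32 as shown above
print(is_power_of_two(361), is_power_of_two(720))   # False False
```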
21 changes: 10 additions & 11 deletions examples/train_sfno.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -38,7 +38,6 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
-from torch.cuda import amp

import numpy as np
import pandas as pd
@@ -55,7 +54,7 @@ def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
if relative:
loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)

if not squared:
loss = torch.sqrt(loss)
loss = loss.mean()
@@ -124,7 +123,7 @@ def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
l2_loss = torch.sum(l2_norm2, dim=(-1,-2))

-# strictly speaking this is not exactly h1 loss
+# strictly speaking this is not exactly h1 loss
if not squared:
loss = torch.sqrt(h1_loss) + torch.sqrt(l2_loss)
else:
@@ -143,7 +142,7 @@ def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
# weight = weight.reshape(*weight.shape, 1, 1)

loss = weight * solver.integrate_grid((prd - tar)**2, dimensionless=True, polar_opt=polar_opt)
if relative:
loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
@@ -194,7 +193,7 @@ def autoregressive_inference(model,
# classical model
start_time = time.time()
for i in range(1, autoreg_steps+1):

# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)

@@ -212,7 +211,7 @@
ref = dataset.solver.spec2grid(uspec)
prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()


return losses, fno_times, nwp_times

@@ -267,8 +266,8 @@ def train_model(model,
model.train()

for inp, tar in dataloader:
-with amp.autocast(enabled=enable_amp):
-
+with torch.autocast(device_type="cuda", enabled=enable_amp):
+
prd = model(inp)
for _ in range(nfuture):
@@ -357,7 +356,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt_solver = 150
nsteps = dt//dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
-# There is still an issue with parallel dataloading. Do NOT use it at the moment
+# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)

@@ -418,7 +417,7 @@ def count_parameters(model):
# optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
-gscaler = amp.GradScaler(enabled=enable_amp)
+gscaler = torch.GradScaler("cuda", enabled=enable_amp)

start_time = time.time()

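For orientation, here is a minimal sketch of how the updated `torch.autocast` and `torch.GradScaler` calls typically fit together in a training step such as the one in `train_model` above; the model, optimizer, and data below are stand-ins rather than the SFNO setup from this file:

```python
import torch

model = torch.nn.Linear(16, 16).cuda()            # stand-in for the SFNO model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
gscaler = torch.GradScaler("cuda", enabled=True)

inp = torch.randn(4, 16, device="cuda")
tar = torch.randn(4, 16, device="cuda")

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", enabled=True):
    prd = model(inp)                              # forward pass runs in reduced precision
    loss = torch.nn.functional.mse_loss(prd, tar)

gscaler.scale(loss).backward()                    # scale the loss to avoid float16 gradient underflow
gscaler.step(optimizer)                           # unscales gradients and skips the step on inf/nan
gscaler.update()                                  # adjust the scale factor for the next iteration
```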
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

__version__ = "0.7.0"
__version__ = "0.7.1"

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
61 changes: 30 additions & 31 deletions torch_harmonics/examples/sfno/models/layers.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -33,7 +33,6 @@
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint
-from torch.cuda import amp
import math

from torch_harmonics import *
@@ -53,36 +52,36 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.

if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)

with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()

# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)

# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
@@ -128,9 +127,9 @@ class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
-return drop_path(x, self.drop_prob, self.training)
+return drop_path(x, self.drop_prob, self.training)

class MLP(nn.Module):
def __init__(self,
@@ -171,11 +170,11 @@ def __init__(self,
self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
else:
self.fwd = nn.Sequential(fc1, act, fc2)

@torch.jit.ignore
def checkpoint_forward(self, x):
return checkpoint(self.fwd, x)

def forward(self, x):
if self.checkpointing:
return self.checkpoint_forward(x)
@@ -221,14 +220,14 @@ def __init__(self,

def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")

class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
"""

def __init__(self,
forward_transform,
inverse_transform,
@@ -277,18 +276,18 @@ def __init__(self,

# get the right contraction function
self._contract = _contract

if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))


def forward(self, x):

dtype = x.dtype
x = x.float()
residual = x

-with amp.autocast(enabled=False):
+with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
@@ -298,20 +297,20 @@ def forward(self, x):
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)

-with amp.autocast(enabled=False):
+with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)

if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)

return x, residual

class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""

def __init__(self,
forward_transform,
inverse_transform,
@@ -366,39 +365,39 @@ def __init__(self,
raise NotImplementedError(f"Unknown operator type {self.operator_type}")

# form weight tensors
-self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
+self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
fixed_rank_modes=False, **decomposition_kwargs)

# initialization of weights
scale = math.sqrt(gain / in_channels)
self.weight.normal_(0, scale)

# get the right contraction function
from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))


def forward(self, x):

dtype = x.dtype
x = x.float()
residual = x

-with amp.autocast(enabled=False):
+with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)

x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)

-with amp.autocast(enabled=False):
+with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)

if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)

return x, residual
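
For readers unfamiliar with the surrounding code, the forward passes above follow a common pattern: cast to `float32`, run the spherical harmonic transform with autocast disabled, apply the learned contraction in spectral space, and transform back before restoring the original dtype. Here is a standalone sketch of that pattern using the public `RealSHT`/`InverseRealSHT` transforms; the grid size and the random diagonal weight are placeholders, not the module's actual parameterization:

```python
import torch
import torch_harmonics as th

nlat, nlon, channels = 128, 256, 8
sht = th.RealSHT(nlat, nlon, grid="equiangular").cuda()
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").cuda()

x = torch.randn(2, channels, nlat, nlon, device="cuda")

with torch.autocast(device_type="cuda", enabled=True):
    xf = x.float()                                    # the transforms expect full precision
    with torch.autocast(device_type="cuda", enabled=False):
        coeffs = sht(xf)                              # complex spherical harmonic coefficients

    # diagonal spectral "convolution": one complex weight per channel and (l, m) mode
    weight = torch.randn(channels, coeffs.shape[-2], coeffs.shape[-1],
                         dtype=coeffs.dtype, device=coeffs.device)
    coeffs = coeffs * weight

    with torch.autocast(device_type="cuda", enabled=False):
        y = isht(coeffs)                              # back to the grid in float32

y = y.to(x.dtype)                                     # restore the activation dtype (a no-op here)
```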
