Use batched operations for PowerSGD
This implements the method proposed in pytorch#74907.

Pull Request resolved: pytorch#75157
Approved by: https://github.com/wayi1, https://github.com/rohan-varma
MagiaSN authored and pytorchmergebot committed Apr 18, 2022
1 parent 3a38f17 commit 5654e63
Showing 2 changed files with 87 additions and 36 deletions.
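
For context, here is a minimal sketch of how the new option could be used once this change is in place. The model, bucket size, and process-group setup are illustrative placeholders, not part of this commit:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes torch.distributed.init_process_group(...) has already been called.
model = DDP(
    MyModel().cuda(),          # MyModel is a placeholder for the user's module.
    device_ids=[torch.cuda.current_device()],
    bucket_cap_mb=50,          # A larger bucket lets more same-shaped tensors land in one bucket.
)
state = powerSGD.PowerSGDState(
    process_group=None,        # None uses the default process group.
    matrix_approximation_rank=1,
    use_error_feedback=True,
    warm_start=True,
    batch_tensors_with_same_shape=True,  # The option added by this commit.
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)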
5 changes: 4 additions & 1 deletion test/distributed/test_c10d_nccl.py
@@ -2014,12 +2014,15 @@ def _test_powerSGD_ddp_comm_hook_nccl(self, gradient_as_bucket_view=False):

# Get GPU model with the hook registered.
# Test the hook with different algorithmic configs.
for use_error_feedback, warm_start in product([True, False], [True, False]):
for use_error_feedback, warm_start, batch_tensors_with_same_shape in product(
[True, False], [True, False], [True, False],
):
state = powerSGD.PowerSGDState(
process_group=process_group,
matrix_approximation_rank=1,
use_error_feedback=use_error_feedback,
warm_start=warm_start,
batch_tensors_with_same_shape=batch_tensors_with_same_shape,
)
for hook in [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]:
gpu_model = self._gpu_model_with_ddp_comm_hook(
118 changes: 83 additions & 35 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -1,5 +1,7 @@
from collections import defaultdict
import logging
import math
from typing import Dict

import numpy as np
import torch
@@ -10,55 +12,57 @@
logger = logging.getLogger(__name__)


def _orthogonalize(matrix, epsilon=0):
def _orthogonalize(matrices, epsilon=0):
"""
Decide between Gram-Schmidt or QR factorization to orthogonalize the matrix.
Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
"""
assert len(matrix.shape) == 2 and matrix.shape[1] <= matrix.shape[0]
assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]

rank = matrix.shape[1]
dtype = matrix.dtype
num_matrices = matrices.shape[0]
rank = matrices.shape[2]
dtype = matrices.dtype
if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
_orthogonalize_gram_schmidt(matrix, epsilon=epsilon)
_orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
else:
torch.linalg.qr(
matrix,
matrices,
out=(
matrix,
torch.empty(rank, rank, device=matrix.device, dtype=dtype)
matrices,
torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype)
)
)

def _orthogonalize_gram_schmidt(matrix, epsilon=0):
def _orthogonalize_gram_schmidt(matrices, epsilon=0):
"""
Applies Gram-Schmidt procedure to orthogonalize a given 2D tensor.
If epsilon is 0, this is equivalent to `torch.qr(matrix, out=(matrix, _))`,
Applies Gram-Schmidt procedure to orthogonalize a batch of matrices.
If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`,
"""
num_cols = matrix.shape[1]
num_cols = matrices.shape[2]
for i in range(num_cols):
# Normalize the i'th column.
col = matrix[:, i : i + 1]
col = matrices[:, :, i : i + 1]
# If no epsilon is added here, division by zero may be caused by vanishing gradients.
# This epsilon is not needed if the input matrix covers the gradients of at least one entire layer in the neural network.
# This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
# in the neural network.
if epsilon == 0:
# Note that col ** 2 can underflow/overflow if we use FP16.
# May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
try:
col /= torch.norm(col)
col /= torch.norm(col, dim=1, keepdim=True)
except ZeroDivisionError:
logger.error(
"The matrix to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
"The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
"as `orthogonalization_epsilon` in PowerSGD state."
)
# Recover the values from NaNs to 0s.
col.fill_(0.0)
else:
col /= torch.norm(col) + epsilon
col /= torch.norm(col, dim=1, keepdim=True) + epsilon
# Project it on the rest and remove it.
if i + 1 < num_cols:
rest = matrix[:, i + 1 :]
rest -= torch.sum(col * rest, dim=0) * col
rest = matrices[:, :, i + 1 :]
rest -= torch.sum(col * rest, dim=1, keepdim=True) * col


def _should_compress(
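
As an aside (not part of this diff), a rough standalone sketch of what the batched Gram-Schmidt above computes: after the loop, each matrix in the batch has orthonormal columns. The helper name below is just a local stand-in for the _orthogonalize_gram_schmidt logic:

import torch

def batched_gram_schmidt(matrices, epsilon=1e-8):
    # matrices has shape (num_matrices, n, rank), matching the batched layout above.
    num_cols = matrices.shape[2]
    for i in range(num_cols):
        col = matrices[:, :, i : i + 1]
        col /= torch.norm(col, dim=1, keepdim=True) + epsilon
        if i + 1 < num_cols:
            rest = matrices[:, :, i + 1 :]
            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col

batch = torch.randn(4, 10, 3)   # 4 matrices, each 10 x 3.
batched_gram_schmidt(batch)
# Each M^T M should now be close to the 3 x 3 identity.
print(torch.allclose(torch.bmm(batch.transpose(1, 2), batch),
                     torch.eye(3).expand(4, 3, 3), atol=1e-4))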
@@ -128,6 +132,8 @@ class PowerSGDState(object):
4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.
5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with the same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., ``bucket_cap_mb`` arg in DDP constructor) to make more same-shaped tensors appear in the same bucket; however, this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.
.. warning ::
If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
@@ -145,6 +151,7 @@ class PowerSGDState(object):
# The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
"use_error_feedback",
"warm_start",
"batch_tensors_with_same_shape",
# The fields below are internal state.
"rng",
"error_dict",
@@ -169,11 +176,12 @@ def __init__(
orthogonalization_epsilon=0,
random_seed=0,
compression_stats_logging_frequency=10_000,
batch_tensors_with_same_shape: bool = False,
):
logger.info(
"PowerSGD config: matrix_approximation_rank = {}; start_powerSGD_iter = {}; "
"min_compression_rate = {}; orthogonalization_epsilon = {}; use_error_feedback = {}; warm_start = {}; "
"random_seed = {}; compression_stats_logging_frequency = {}".format(
"random_seed = {}; compression_stats_logging_frequency = {}; batch_tensors_with_same_shape = {}".format(
matrix_approximation_rank,
start_powerSGD_iter,
min_compression_rate,
@@ -182,6 +190,7 @@ def __init__(
warm_start,
random_seed,
compression_stats_logging_frequency,
batch_tensors_with_same_shape,
)
)

@@ -230,9 +239,9 @@ def __init__(
self.rng = np.random.RandomState(random_seed)
# Since there is only a single state instance for all the input buckets,
# need to maintain a dictionary that maps each bucket index to the local error.
self.error_dict = {}
self.p_memory_dict = {}
self.q_memory_dict = {}
self.error_dict: Dict[int, torch.Tensor] = {}
self.p_memory_dict: Dict[int, torch.Tensor] = {}
self.q_memory_dict: Dict[int, torch.Tensor] = {}
# Iteration/step in the training loop.
self.iter = 0
# Compression stats accumulators
@@ -244,6 +253,12 @@ def __init__(
1, compression_stats_logging_frequency
)
self.next_stats_report = 0
# Batching tensors with the same shape can increase parallelism in compression / decompression computation.
# This requires a larger bucket size to make more same-shaped tensors appear in one bucket; however,
# this may reduce the overlap between computation and communication, and increase the memory footprint
# due to stacking tensors.
# Turn on if compression / decompression computation is a bottleneck.
self.batch_tensors_with_same_shape = batch_tensors_with_same_shape

def maybe_increase_iter(self, bucket):
# Since bucket 0 is the last bucket to allreduce in an iteration.
@@ -431,26 +446,48 @@ def powerSGD_hook(
total_Qs_size, device=device, dtype=dtype
)

# Batch tensors to compress by shape.
shape_to_tensors = defaultdict(list)
for tensor in tensors_to_compress:
shape_to_tensors[tensor.shape].append(tensor)

# This function decides whether to batch tensors with the same shape according to the argument,
# so the following process can share the same code.
def maybe_batched_tensors_to_compress():
for tensors in shape_to_tensors.values():
if state.batch_tensors_with_same_shape:
batch_size = len(tensors)
if batch_size == 1:
# Use the original tensor to avoid copy.
yield tensors[0].unsqueeze(0)
else:
yield torch.stack(tensors)
else:
for tensor in tensors:
yield tensor.unsqueeze(0)

# Create Ps and Qs that point to the allocated memory.
tensors_to_compress = []
ps = []
qs = []
p_idx = 0
q_idx = 0
for tensor in tensors_to_compress:
n, m = tensor.shape
for tensor in maybe_batched_tensors_to_compress():
batch_size, n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
tensors_to_compress.append(tensor)
ps.append(
state.p_memory_dict[bucket_index][
p_idx : p_idx + n * matrix_approximation_rank
].view(n, matrix_approximation_rank)
p_idx : p_idx + batch_size * n * matrix_approximation_rank
].view(batch_size, n, matrix_approximation_rank)
)
qs.append(
state.q_memory_dict[bucket_index][
q_idx : q_idx + m * matrix_approximation_rank
].view(m, matrix_approximation_rank)
q_idx : q_idx + batch_size * m * matrix_approximation_rank
].view(batch_size, m, matrix_approximation_rank)
)
p_idx += n * matrix_approximation_rank
q_idx += m * matrix_approximation_rank
p_idx += batch_size * n * matrix_approximation_rank
q_idx += batch_size * m * matrix_approximation_rank

# If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
# The exception is the first iteration when PowerSGD is applied.
Expand All @@ -477,7 +514,7 @@ def powerSGD_hook(

# Compute Ps.
for tensor, q, p in zip(tensors_to_compress, qs, ps):
torch.matmul(tensor, q, out=p)
torch.bmm(tensor, q, out=p)

# This allreduce is only applied to uncompressed tensors,
# so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
Expand Down Expand Up @@ -511,7 +548,7 @@ def compute_qs(fut):

# Compute Qs.
for tensor, p, q in zip(tensors_to_compress, ps, qs):
torch.matmul(tensor.t(), p, out=q)
torch.bmm(tensor.transpose(1, 2), p, out=q)

# TODO: The above procedure does two matmul+allreduce steps per iteration --
# one left multiplication and one right multiplication.
Expand All @@ -530,7 +567,18 @@ def decompress(fut):
state.q_memory_dict[bucket_index] = fut.value().div_(world_size)

for p, q, tensor in zip(ps, qs, tensors_to_compress):
torch.matmul(p, q.t(), out=tensor)
torch.bmm(p, q.transpose(1, 2), out=tensor)

# Copy batched tensors back to original buffer.
if state.batch_tensors_with_same_shape:
for tensor in tensors_to_compress:
if tensor.shape[0] == 1:
# Skip tensors with batch_size == 1, since each is already the original tensor.
continue
original_tensors = shape_to_tensors[tensor.shape[1:]]
for i, original_tensor in enumerate(original_tensors):
original_tensor.copy_(tensor[i])

if torch.cuda.is_available():
torch.cuda.synchronize(device)

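To make the batched path above concrete, here is a rough end-to-end sketch of one compression / decompression round for a group of same-shaped gradient tensors. The allreduce steps and error feedback are omitted, and the sizes are illustrative:

import torch

rank = 2
grads = [torch.randn(64, 32) for _ in range(8)]   # 8 same-shaped gradient tensors.

# Batch by shape, as maybe_batched_tensors_to_compress does.
M = torch.stack(grads)                            # (8, 64, 32)
Q = torch.randn(M.shape[0], M.shape[2], rank)     # (8, 32, rank); random init or warm start.

# Compress: P = M Q (allreduced in the real hook), orthogonalize P, then Q = M^T P.
P = torch.bmm(M, Q)                               # (8, 64, rank)
P, _ = torch.linalg.qr(P)                         # Orthonormal columns per matrix.
Q = torch.bmm(M.transpose(1, 2), P)               # (8, 32, rank), also allreduced in the real hook.

# Decompress: low-rank approximation of each gradient, copied back to the original buffers.
M_hat = torch.bmm(P, Q.transpose(1, 2))           # (8, 64, 32)
for g, approx in zip(grads, M_hat):
    g.copy_(approx)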
