Make DistributedDataParallel use new reducer (pytorch#18953)
Summary:
Pull Request resolved: pytorch#18953

This removes the Python-side bucketing code from DistributedDataParallel
and replaces it with calls to the new C++-based bucketing and reducing
code. To confirm correctness, we ran the same test with both the previous
and the new implementation and confirmed that they are numerically
equivalent.

Performance improves by a couple of percent or more, including in
single-machine, multi-GPU runs.

Closes pytorch#13273.

Reviewed By: mrshenli

Differential Revision: D14580911

fbshipit-source-id: 44e76f8b0b7e58dd6c91644e3df4660ca2ee4ae2
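
For context, the user-facing API is unchanged by this refactor; only the internals moved to C++. Below is a minimal, hypothetical sketch of typical usage (the process group parameters are placeholders and not part of this diff):

# Minimal sketch of DistributedDataParallel usage; the wrapper API is
# unchanged by this change -- gradient bucketing and allreduce now happen
# in the C++ reducer. The init_method/rank/world_size values are placeholders.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:23456",
    rank=0,
    world_size=1,
)
model = nn.Linear(10, 10)
ddp_model = DistributedDataParallel(model)
output = ddp_model(torch.randn(20, 10))
output.sum().backward()  # gradients are bucketed and reduced by the new reducer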
pietern authored and facebook-github-bot committed Apr 15, 2019
1 parent 6ed57e0 commit a0263ec
Showing 5 changed files with 221 additions and 180 deletions.
46 changes: 46 additions & 0 deletions test/test_c10d.py
@@ -1934,6 +1934,52 @@ def test_forward_backward_optimizer(self):
optimizer.step()


class ComputeBucketAssignmentTest(TestCase):
def test_single_limit_single_dtype(self):
tensors = [
torch.empty([100], dtype=torch.float),
torch.empty([200], dtype=torch.float),
torch.empty([100], dtype=torch.float),
torch.empty([50], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0], [1], [2], [3]], result)

def test_single_limit_multi_dtype(self):
tensors = [
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0, 2], [1, 3], [4], [5]], result)

def test_multi_limit_single_dtype(self):
tensors = [
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
self.assertEqual([[0], [1, 2], [3]], result)

def test_multi_limit_multi_dtype(self):
tensors = [
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)


if __name__ == '__main__':
assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process"

7 changes: 7 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -500,6 +500,13 @@ They are used in specifying strategies for reduction collectives, e.g.,
py::call_guard<py::gil_scoped_release>());
#endif

module.def(
"_compute_bucket_assignment_by_size",
&::c10d::compute_bucket_assignment_by_size,
py::arg("tensors"),
py::arg("bucket_size"),
py::call_guard<py::gil_scoped_release>());

Py_RETURN_TRUE;
}

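As an illustrative sketch (not part of this diff), the binding above is callable from Python as in the tests; note that the size limits are in bytes, accumulated as numel * element_size:

# Hypothetical usage sketch of the newly exposed binding (see tests above).
# A tensor of 100 float32 elements accounts for 400 bytes.
import torch
import torch.distributed as dist

tensors = [
    torch.empty([100], dtype=torch.float),  # 400 bytes
    torch.empty([50], dtype=torch.float),   # 200 bytes
]
print(dist._compute_bucket_assignment_by_size(tensors, [400]))
# [[0], [1]] -- the first tensor alone reaches the 400-byte limit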
110 changes: 110 additions & 0 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -7,6 +7,7 @@
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/utils/hash.h>
#include <torch/csrc/utils/memory.h>

namespace c10d {
@@ -361,6 +362,13 @@ void Reducer::prepare_for_backward(
bucket.pending = bucket.replicas.size();
}

// If no outputs are specified, we assume that autograd hooks for ALL
// variables will be called, and we don't have to search the autograd graph
// for presence of these hooks.
if (outputs.empty()) {
return;
}

// Seed queue with the grad functions of all outputs.
for (const auto& output : outputs) {
const auto& grad_fn = output.grad_fn();
@@ -433,4 +441,106 @@ void Reducer::finalize_backward() {
}
}

namespace {

// Tensors may be coalesced into buckets. Buckets must contain tensors of
// the same type, on the same device, so a bucket can be identified by a
// composite key of a tensor's type identifier and its device.
struct BucketKey {
BucketKey(c10::ScalarType type, c10::Device device)
: type(std::move(type)), device(std::move(device)) {}

const c10::ScalarType type;
const c10::Device device;

// See torch/csrc/utils/hash.h for dispatch code.
static size_t hash(const BucketKey& key) {
return torch::get_hash(key.type, key.device);
}
};

inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) {
return lhs.type == rhs.type && lhs.device == rhs.device;
}

} // namespace

// This is equivalent to take_tensors but returns indices into the
// tensor list argument for bucket assignment. Also, it is aware
// of device placement and will not allow buckets to span devices.
std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
const std::vector<at::Tensor>& tensors,
std::vector<size_t> bucket_size_limits) {
std::vector<std::vector<size_t>> result;
result.reserve(tensors.size());

// Keep an iterator into the size limit vector per tensor type and device.
// This is done so that consecutive bucket size limits can be used per type.
std::unordered_map<
BucketKey,
std::vector<size_t>::iterator,
torch::hash<BucketKey>>
bucket_size_limit_iterators;

// Local accumulator type for a single bucket.
struct BucketAccumulator {
std::vector<size_t> indices;
size_t size = 0;
};

// Keep vector of indices and size accumulator by tensor type and device.
std::unordered_map<BucketKey, BucketAccumulator, torch::hash<BucketKey>>
buckets;

for (size_t i = 0; i < tensors.size(); i++) {
const auto& tensor = tensors[i];
AT_ASSERTM(!tensor.is_sparse(), "No support for sparse tensors.");
auto key = BucketKey(tensor.scalar_type(), tensor.device());
auto& bucket = buckets[key];
bucket.indices.push_back(i);
bucket.size += tensor.numel() * tensor.element_size();

// Initialize bucket size limit iterator if necessary.
if (bucket_size_limit_iterators.count(key) == 0) {
bucket_size_limit_iterators[key] = bucket_size_limits.begin();
}

auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
const auto bucket_size_limit = *bucket_size_limit_iterator;
if (bucket.size >= bucket_size_limit) {
result.emplace_back(std::move(bucket.indices));
bucket = BucketAccumulator();

// Advance to the next bucket size limit for this type/device.
auto next = bucket_size_limit_iterator + 1;
if (next != bucket_size_limits.end()) {
bucket_size_limit_iterator = next;
}
}
}

// Add remaining buckets.
for (auto& it : buckets) {
auto& bucket = it.second;
if (!bucket.indices.empty()) {
result.emplace_back(std::move(bucket.indices));
}
}

// Sort resulting buckets by the minimum tensor index they include.
// We assume that the order of the tensors is the order in which they are
// used (or the reverse order in which their gradients are produced).
// This sorting step ensures that the buckets are ready in consecutive order.
std::sort(
result.begin(),
result.end(),
[](const std::vector<size_t>& a, const std::vector<size_t>& b) {
const auto amin = std::min_element(a.begin(), a.end());
const auto bmin = std::min_element(b.begin(), b.end());
return *amin < *bmin;
});

return result;
}

} // namespace c10d
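For illustration only, here is a rough Python sketch of the same greedy assignment logic: buckets are keyed by (dtype, device), sizes accumulate in bytes, each key walks through consecutive size limits, and the result is sorted by the minimum index in each bucket. The C++ code above is the authoritative implementation.

# Rough, unofficial Python sketch of compute_bucket_assignment_by_size.
from collections import defaultdict

import torch

def compute_bucket_assignment_by_size(tensors, bucket_size_limits):
    # Per-(dtype, device) accumulator: tensor indices and running size in bytes.
    buckets = defaultdict(lambda: {"indices": [], "size": 0})
    limit_pos = {}   # per-key position into bucket_size_limits
    result = []

    for i, tensor in enumerate(tensors):
        key = (tensor.dtype, str(tensor.device))
        bucket = buckets[key]
        bucket["indices"].append(i)
        bucket["size"] += tensor.numel() * tensor.element_size()

        pos = limit_pos.setdefault(key, 0)
        if bucket["size"] >= bucket_size_limits[pos]:
            # Close this bucket and start a fresh one for the same key.
            result.append(bucket["indices"])
            buckets[key] = {"indices": [], "size": 0}
            # Advance to the next size limit for this type/device, if any.
            if pos + 1 < len(bucket_size_limits):
                limit_pos[key] = pos + 1

    # Flush partially filled buckets.
    for bucket in buckets.values():
        if bucket["indices"]:
            result.append(bucket["indices"])

    # Order buckets by the smallest tensor index they contain.
    result.sort(key=min)
    return result

# Reproduces the single-limit multi-dtype test above:
tensors = [torch.empty([50], dtype=torch.float) if i % 2 == 0
           else torch.empty([25], dtype=torch.double) for i in range(6)]
print(compute_bucket_assignment_by_size(tensors, [400]))
# [[0, 2], [1, 3], [4], [5]]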
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/reducer.h
@@ -139,4 +139,8 @@ class Reducer {
std::vector<std::vector<int64_t>> backward_stats_;
};

std::vector<std::vector<size_t>> compute_bucket_assignment_by_size(
const std::vector<at::Tensor>& tensors,
std::vector<size_t> bucket_size);

} // namespace c10d
