Fix the normalization scheduler to accept DID loop split.
wujingyue committed Feb 8, 2025
1 parent e332322 commit cf1fac4
Showing 4 changed files with 72 additions and 14 deletions.
csrc/multidevice/utils.cpp: 25 changes (15 additions & 10 deletions)
@@ -32,16 +32,6 @@ NVF_API bool distributedEnabled() {

namespace {

std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
  std::unordered_set<IterDomain*> sharded_ids;
  std::copy_if(
      tv->getLoopDomain().begin(),
      tv->getLoopDomain().end(),
      std::inserter(sharded_ids, sharded_ids.begin()),
      [](auto id) { return id->isDeviceDim(); });
  return sharded_ids;
}

// Returns the position where an axis is allocated in a tv, skipping trivial
// dimensions (i.e. DID, reduction and broadcast). Returns -1 if id is not in
// tv's loop domain WAR: today we assume that the loop domain match with the
@@ -230,6 +220,21 @@ int64_t getShardedLogicalAxis(
  return logical_id_to_axis.at(id);
}

int64_t getShardedLoopAxis(
    const TensorView* tv,
    const ParallelType parallel_type) {
  NVF_ERROR(
      isParallelTypeDeviceDim(parallel_type),
      "Expect a DID but found: ",
      parallel_type);
  for (int64_t i : c10::irange(tv->nDims())) {
    if (tv->getLoopDomain()[i]->isDeviceDim()) {
      return i;
    }
  }
  return -1;
}

at::Tensor shardTensor(
    at::Tensor tensor,
    const int64_t axis,
csrc/multidevice/utils.h: 4 changes (4 additions & 0 deletions)
@@ -135,6 +135,10 @@ void unshard(TensorView*);
// extent if that IterDomain is sharded.
int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);

// Returns the index of the loop axis that's parallelized on `parallel_type`.
// If it's not found, returns -1.
int64_t getShardedLoopAxis(const TensorView* tv, ParallelType parallel_type);

// Shards the input tensor along `axis`. How the tensor gets sliced along `axis`
// is determined by `mesh` and `device_id`. Returns the sharded tensor.
at::Tensor shardTensor(
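For orientation, here is a minimal sketch of how a caller might use the new helper. It is a fragment rather than self-contained code and is not part of this commit: `tv` is assumed to be some TensorView* visible in scheduler code, and the error message is illustrative.

// Hypothetical caller (sketch only): find the DIDx-parallelized loop axis and
// require it to be outermost before scheduling the remaining dimensions.
int64_t did_axis = getShardedLoopAxis(tv, ParallelType::DIDx);
if (did_axis == -1) {
  // No device-parallel loop axis: nothing to skip.
} else {
  NVF_ERROR(
      did_axis == 0,
      "Expected DIDx on the outermost loop axis, found: ",
      did_axis);
}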
csrc/scheduler/reduction_utils.cpp: 15 changes (11 additions & 4 deletions)
@@ -31,18 +31,25 @@ TensorView* scheduleReductionTV(
  // Inner here though is only relative to the other axis. When
  // rparams->fastest_dim == false, the reduction axis is logically outside the
  // iteration axis.
  //
  // Multidevice scheduling: we assume only the outermost domain can be
  // parallelized with DIDx at this point and in that case this reduction
  // scheduler only schedules the remaining domains while leaving the DIDx
  // domain unchanged.
  const bool has_outermost_dim_sharded = isSharded(reduction_tv);
  int64_t sharded_axis = getShardedLoopAxis(reduction_tv, ParallelType::DIDx);
  if (sharded_axis >= 0) {
    NVF_ERROR(
        sharded_axis == 0,
        "Expect 1D mesh and DIDx only appear outermost in loop, but found: ",
        reduction_tv->getLoopDomain());
  }
  NVF_ERROR(
      !has_outermost_dim_sharded || !rparams->schedule_3D,
      sharded_axis == -1 || !rparams->schedule_3D,
      "Mixing interdevice and 3D schedule is not supported");
  const int iter_axis = has_outermost_dim_sharded ? 1 : 0;
  const int iter_axis = (sharded_axis >= 0) ? 1 : 0;
  const int outer_reduce_axis = rparams->schedule_3D ? 1 : 0;
  const int inner_reduce_axis =
      rparams->schedule_3D ? 2 : has_outermost_dim_sharded + has_iter_axis;
      rparams->schedule_3D ? 2 : (sharded_axis >= 0) + has_iter_axis;

  const bool is_outer_grid_persistence = rparams->persistent_kernel &&
      rparams->cross_grid_inner_reduction && !rparams->fastest_dim;
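To make the assumption in the comment above concrete, here is a hedged sketch of the loop-domain layout the scheduler now expects. The split/parallelize sequence mirrors the new DivideBySum test below; `reduction_tv` and `d` (the mesh size) are assumed from context, and the extents are illustrative.

// Sketch of a DID loop split that leaves DIDx outermost, so
// getShardedLoopAxis(reduction_tv, ParallelType::DIDx) returns 0 and
// scheduleReductionTV only touches the remaining axes.
//   loop domain before: [i{h}, r{s}]
reduction_tv->split(0, d, /*inner_split=*/false); // [d, i{h/d}, r{s}]
reduction_tv->axis(0)->parallelize(ParallelType::DIDx);
//   loop domain after:  [DIDx{d}, i{h/d}, r{s}], i.e. sharded_axis == 0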
tests/cpp/test_multidevice_sharding.cpp: 42 changes (42 additions & 0 deletions)
@@ -222,6 +222,48 @@ TEST_F(MultiDeviceTest, BackpropMeshes) {
<< "be sharded in the same way as x.";
}

TEST_F(MultiDeviceTest, DivideBySum) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const int64_t d = communicator_->size();

  // [b, h, s, s]
  TensorView* x = makeContigTensor(4);
  TensorView* sum_x = sum(x, {-1});
  TensorView* sum_x_broadcasted = broadcast(sum_x, {false, false, false, true});
  TensorView* y = div(x, sum_x_broadcasted);
  fusion->addInput(x);
  fusion->addOutput(y);

  auto mesh = DeviceMesh::createForNumDevices(d);
  for (auto* tv : {x, sum_x, sum_x_broadcasted, y}) {
    tv->setDeviceMesh(mesh);
    tv->split(1, d, /*inner_split=*/false);
    tv->axis(1)->parallelize(ParallelType::DIDx);
    tv->reorder({{1, 0}});
  }
  for (auto* tv : {x, y}) {
    tv->setAllocationDomain(tv->getLoopDomain(), true);
  }

  const int64_t b = 2;
  const int64_t h = d * 3;
  const int64_t s = 5;
  at::Tensor unsharded_x_tensor = at::randint(5, {b, h, s, s}, tensor_options);
  at::Tensor x_tensor = shardTensor(unsharded_x_tensor, x);

  FusionExecutorCache executor_cache(std::move(fusion));
  at::Tensor y_tensor = executor_cache.runFusionWithInputs({x_tensor})[0];
  testValidate(
      executor_cache.fusion(),
      {y_tensor},
      {x_tensor},
      {x_tensor / x_tensor.sum(-1, true)},
      __LINE__,
      __FILE__);
}

TEST_F(MultiDeviceTest, LayerNorm) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
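As a reading aid for the new test, a short walk-through of the tensor shapes involved. The per-rank slice is an assumption derived from the outer split by `d`; the commit itself does not restate shardTensor's slicing behavior.

// DivideBySum shapes (illustrative, with d = number of devices):
//   unsharded x:  [b, h, s, s] = [2, 3 * d, 5, 5]
//   split(1, d, /*inner_split=*/false), DIDx on axis 1, reorder({{1, 0}}):
//     loop domain becomes [DIDx{d}, b, h / d, s, s]
//   per-rank x_tensor after shardTensor, assuming an even contiguous slice of
//   the h axis: [2, 3, 5, 5], roughly unsharded_x_tensor.chunk(d, /*dim=*/1)[rank]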
