Fix PrecomputedValues::bindTensorMetaData for DID loop split (#3854)
in the same way as ExpressionEvaluator::bindTensorDomain and several
other places. Caveat: having to fix multiple places in the same way
probably indicates a pre-existing duplication of logic.

Fixes #3817
wujingyue authored Feb 10, 2025
1 parent 97765cc commit 0510726
Showing 2 changed files with 12 additions and 14 deletions.
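
For context on the fix: with a DID loop split, the input a rank receives at runtime is its local shard, whose size along the sharded dimension is the logical (unsharded) extent divided by the number of devices. `unshardedSizes` scales those dimensions back up so the logical extents can be bound correctly. A minimal sketch of the arithmetic in Python, with hypothetical names (the real helper lives in `multidevice/utils.h` and derives the factors from the TensorView's DID-parallelized loop domain):

```python
# Hypothetical sketch of the size reconstruction unshardedSizes performs.
# shard_factors[i] is the number of devices dimension i is split across
# (1 for unsharded dimensions).
def unsharded_sizes(local_sizes: list[int], shard_factors: list[int]) -> list[int]:
    return [size * factor for size, factor in zip(local_sizes, shard_factors)]

# With 4 devices and the heads dimension loop-split by DIDx, a local shard
# of shape [2, 3, 1024, 64] maps back to the logical shape [2, 12, 1024, 64].
assert unsharded_sizes([2, 3, 1024, 64], [1, 4, 1, 1]) == [2, 12, 1024, 64]
```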
15 changes: 6 additions & 9 deletions csrc/evaluator_common.cpp

@@ -12,6 +12,7 @@
 #include <expr_evaluator.h>
 #include <instrumentation.h>
 #include <ir/utils.h>
+#include <multidevice/utils.h>
 #include <runtime/executor_kernel_arg.h>
 #include <tensor_metadata.h>
 
@@ -348,9 +349,11 @@ void PrecomputedValues::bindTensorMetaData(
       tensor.dim() == static_cast<int64_t>(logical_domain.size()),
       "Something went wrong configuring launch. Inputs do not match.");
 
-  for (const auto dim : c10::irange(logical_domain.size())) {
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
+  for (const auto dim :
+       c10::irange(static_cast<int64_t>(logical_domain.size()))) {
     IterDomain* id = logical_domain[dim];
-    const auto dim_size = tensor.size(static_cast<int64_t>(dim));
+    const auto dim_size = logical_sizes.at(dim);
     if (id->isBroadcast()) {
       // DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast
       // and .ExpandedBroadcast.
@@ -359,13 +362,7 @@
         bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
       }
     } else {
-      if (id->isDeviceDim()) {
-        bindValue(
-            id->extent()->evaluatorIndex(),
-            tv->getDeviceMesh().size(id->getParallelType()));
-      } else {
-        bindValue(id->extent()->evaluatorIndex(), dim_size);
-      }
+      bindValue(id->extent()->evaluatorIndex(), dim_size);
     }
   }
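
Read in isolation, the fixed loop now binds every non-broadcast logical IterDomain's extent directly to the corresponding unsharded size; the `isDeviceDim()` special case is gone because `logical_sizes` already accounts for the device split. A rough Python rendering of the control flow, using stand-in types rather than nvFuser's real IterDomain:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeIterDomain:
    """Stand-in for nvFuser's IterDomain; illustration only."""
    extent: str  # name of the extent symbol to bind
    is_broadcast: bool = False
    expanded_extent: Optional[str] = None

def bind_tensor_metadata(logical_domain, logical_sizes, bindings):
    # Mirrors the fixed loop: logical_sizes are already unsharded, so
    # device-parallel dimensions need no special case.
    for id, dim_size in zip(logical_domain, logical_sizes):
        if id.is_broadcast:
            # DIDs are ignored for broadcast; bind only an expanded extent.
            if id.expanded_extent is not None:
                bindings[id.expanded_extent] = dim_size
        else:
            bindings[id.extent] = dim_size

bindings = {}
domain = [FakeIterDomain("b"), FakeIterDomain("h"), FakeIterDomain("s")]
bind_tensor_metadata(domain, [2, 12, 1024], bindings)
assert bindings == {"b": 2, "h": 12, "s": 1024}
```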
11 changes: 6 additions & 5 deletions tests/python/test_multidevice.py

@@ -456,10 +456,7 @@ def assert_close(actual: nvfuser.DistributedTensor, expected: torch.Tensor):
 @pytest.mark.parametrize("qkv_format", [QkvFormat.BHSE, QkvFormat.BSHE])
 @pytest.mark.mpi
 def test_sdpa_loop_split(multidevice_test, qkv_format: QkvFormat):
-    d, b, s, h, e = multidevice_test.size, 2, 1024, 12, 768
-
-    if h % d != 0:
-        pytest.skip(f"We only support even split, so {h} has to be divisible by {d}.")
+    d = multidevice_test.size
     mesh = nvfuser.DeviceMesh(range(d))
 
     class Model(FusionDefinition):
@@ -476,7 +473,7 @@ def definition(self) -> None:
 
             self.q, self.k, self.v, self.out_grad = [
                 self.define_tensor(
-                    shape=[b, h, s, e // h],
+                    shape=[-1, -1, -1, -1],
                     dtype=DataType.BFloat16,
                     stride_order=stride_order,
                 )
@@ -542,6 +539,10 @@ def multidevice_schedule(self) -> None:
             for t in output_tvs:
                 self.sched.set_allocation_as_loop(t)
 
+    b, s, h, e = 2, 1024, 12, 768
+    if h % d != 0:
+        pytest.skip(f"We only support even split, so {h} has to be divisible by {d}.")
+
     torch.cuda.set_device(multidevice_test.local_rank)
 
     def make_unsharded_tensor() -> torch.Tensor:
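
A note on the test change: the inputs are now defined with `shape=[-1, -1, -1, -1]`, so their extents are symbolic and get bound at runtime from the unsharded input sizes, which is the path the C++ fix repairs; the concrete sizes move below the fusion definition. Under the loop split each rank holds `h // d` of the `h` heads, hence the even-split skip. A small sketch of the resulting shapes, assuming a mesh size of 4 for illustration:

```python
# Hypothetical illustration of the test's sharding arithmetic.
d = 4  # assumed mesh size, for illustration only
b, s, h, e = 2, 1024, 12, 768
assert h % d == 0, f"We only support even split, so {h} has to be divisible by {d}."

unsharded_shape = (b, h, s, e // h)         # logical q/k/v shape: (2, 12, 1024, 64)
local_shard_shape = (b, h // d, s, e // h)  # per-rank shard:      (2, 3, 1024, 64)
```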
