Skip to content

Commit

Permalink
Fix bindTensorMetaData for DID loop split
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Feb 8, 2025
1 parent cf1fac4 commit 5a3dcf8
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <multidevice/utils.h>
#include <runtime/executor_kernel_arg.h>
#include <tensor_metadata.h>

Expand Down Expand Up @@ -348,9 +349,11 @@ void PrecomputedValues::bindTensorMetaData(
tensor.dim() == static_cast<int64_t>(logical_domain.size()),
"Something went wrong configuring launch. Inputs do not match.");

for (const auto dim : c10::irange(logical_domain.size())) {
std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
for (const auto dim :
c10::irange(static_cast<int64_t>(logical_domain.size()))) {
IterDomain* id = logical_domain[dim];
const auto dim_size = tensor.size(static_cast<int64_t>(dim));
const auto dim_size = logical_sizes.at(dim);
if (id->isBroadcast()) {
// DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast
// and .ExpandedBroadcast.
Expand All @@ -359,13 +362,7 @@ void PrecomputedValues::bindTensorMetaData(
bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
}
} else {
if (id->isDeviceDim()) {
bindValue(
id->extent()->evaluatorIndex(),
tv->getDeviceMesh().size(id->getParallelType()));
} else {
bindValue(id->extent()->evaluatorIndex(), dim_size);
}
bindValue(id->extent()->evaluatorIndex(), dim_size);
}
}

Expand Down

0 comments on commit 5a3dcf8

Please sign in to comment.