From 5a3dcf8368e5eba7b40212bad212e61f4e8bf557 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 7 Feb 2025 23:07:55 -0800 Subject: [PATCH] Fix bindTensorMetaData for DID loop split --- csrc/evaluator_common.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 34852d7d279..ecee4dc12cd 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -348,9 +349,11 @@ void PrecomputedValues::bindTensorMetaData( tensor.dim() == static_cast(logical_domain.size()), "Something went wrong configuring launch. Inputs do not match."); - for (const auto dim : c10::irange(logical_domain.size())) { + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); + for (const auto dim : + c10::irange(static_cast(logical_domain.size()))) { IterDomain* id = logical_domain[dim]; - const auto dim_size = tensor.size(static_cast(dim)); + const auto dim_size = logical_sizes.at(dim); if (id->isBroadcast()) { // DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast // and .ExpandedBroadcast. @@ -359,13 +362,7 @@ void PrecomputedValues::bindTensorMetaData( bindValue(id->expandedExtent()->evaluatorIndex(), dim_size); } } else { - if (id->isDeviceDim()) { - bindValue( - id->extent()->evaluatorIndex(), - tv->getDeviceMesh().size(id->getParallelType())); - } else { - bindValue(id->extent()->evaluatorIndex(), dim_size); - } + bindValue(id->extent()->evaluatorIndex(), dim_size); } }