Skip to content

Commit

Permalink
Adding input sharding support
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Mar 3, 2025
1 parent 5f758e5 commit 9ce0256
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
61 changes: 61 additions & 0 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// loguru includes
#include "loguru/loguru.hpp"

// llvm includes
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"

// llvm mlir includes
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
Expand All @@ -28,10 +32,12 @@
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO
#include "tt/runtime/runtime.h"
#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
Expand Down Expand Up @@ -82,6 +88,7 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
return m_status;
}

collectInputShardings(mlir_module);
collectOutputTypes(mlir_module);

convertFromSHLOToTTIR(mlir_module);
Expand Down Expand Up @@ -134,6 +141,60 @@ void ModuleBuilder::convertFromVHLOToSHLO(
printModule(mlir_module);
}

mlir::LogicalResult ModuleBuilder::fillMeshShardingFromGSPMDString(
mlir::StringAttr shardingStr,
mlir::tt::sharding_utils::MeshSharding &meshSharding) {
auto error =
meshSharding.convertGSPMDShardingToMeshSharding(shardingStr.getValue());
if (auto e = error.takeError()) {
DLOG_F(ERROR, "Failed to convert sharding attribute to mesh sharding");
return llvm::LogicalResult::failure();
}
return llvm::LogicalResult::success();
}

void ModuleBuilder::collectInputShardings(
const mlir::OwningOpRef<mlir::ModuleOp> &module) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings");
m_input_shardings.clear();

module.get().walk([&](mlir::Operation *op) {
mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

if (!funcOp) {
return;
}
if (!funcOp.isPublic()) {
return;
}
for (mlir::BlockArgument argument : funcOp.getArguments()) {

mlir::tt::sharding_utils::MeshSharding meshSharding;

for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
mlir::BlockArgument argument = funcOp.getArgument(i);

auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));
if (!shardingAttr) {
continue;
}

mlir::LogicalResult conversionResult =
fillMeshShardingFromGSPMDString(shardingAttr, meshSharding);

if (conversionResult.failed()) {
m_status = tt_pjrt_status::kInternal;
return;
}
break;
}
m_input_shardings.push_back(meshSharding);
}
});
}

void ModuleBuilder::collectOutputTypes(
const mlir::OwningOpRef<mlir::ModuleOp> &module) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputTypes");
Expand Down
21 changes: 21 additions & 0 deletions src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
// tt-mlir includes
#include "tt/runtime/types.h"

#define TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h"

// tt-xla includes
#include "status.h"

Expand All @@ -36,6 +39,11 @@ class ModuleBuilder {
return m_is_output_scalar;
};

const std::vector<mlir::tt::sharding_utils::MeshSharding> &
getInputShardings() const {
return m_input_shardings;
}

// This needs to return the number of addressable devices from the StableHLO
// code. Currently hardcoded to one, as we only support one-chip execution.
size_t getNumAddressableDevices() const { return 1; }
Expand All @@ -52,6 +60,10 @@ class ModuleBuilder {
// scalar or not.
void collectOutputTypes(const mlir::OwningOpRef<mlir::ModuleOp> &module);

// Fills up the m_input_shardings array with information about the sharding of
// specifc inputs.
void collectInputShardings(const mlir::OwningOpRef<mlir::ModuleOp> &module);

// Converts StableHLO module to TTIR module.
void convertFromSHLOToTTIR(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

Expand All @@ -68,6 +80,12 @@ class ModuleBuilder {
// Checks if a particular type is scalar.
bool isScalarType(mlir::Type type);

// Fills up a mlir::tt::sharding_utils::MeshSharding object with the
// information froma StringAttribute representing GSPMD sharding.
mlir::LogicalResult fillMeshShardingFromGSPMDString(
mlir::StringAttr shardingStr,
mlir::tt::sharding_utils::MeshSharding &meshSharding);

// MLIR context handle.
std::unique_ptr<mlir::MLIRContext> m_context;

Expand All @@ -79,6 +97,9 @@ class ModuleBuilder {

// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;

// For every input, holds the sharding information.
std::vector<mlir::tt::sharding_utils::MeshSharding> m_input_shardings;
};

} // namespace tt::pjrt
Expand Down

0 comments on commit 9ce0256

Please sign in to comment.