diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index 55cb112..17412c7 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -12,6 +12,10 @@ // loguru includes #include "loguru/loguru.hpp" +// llvm includes +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" + // llvm mlir includes #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -28,10 +32,12 @@ #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Version.h" #include "stablehlo/transforms/Passes.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" // tt-mlir includes #define TTMLIR_ENABLE_STABLEHLO #include "tt/runtime/runtime.h" +#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h" #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" #include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" @@ -82,6 +88,7 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code, return m_status; } + collectInputShardings(mlir_module); collectOutputTypes(mlir_module); convertFromSHLOToTTIR(mlir_module); @@ -134,6 +141,60 @@ void ModuleBuilder::convertFromVHLOToSHLO( printModule(mlir_module); } +mlir::LogicalResult ModuleBuilder::fillMeshShardingFromGSPMDString( + mlir::StringAttr shardingStr, + mlir::tt::sharding_utils::MeshSharding &meshSharding) { + auto error = + meshSharding.convertGSPMDShardingToMeshSharding(shardingStr.getValue()); + if (auto e = error.takeError()) { + DLOG_F(ERROR, "Failed to convert sharding attribute to mesh sharding"); + return llvm::LogicalResult::failure(); + } + return llvm::LogicalResult::success(); +} + +void ModuleBuilder::collectInputShardings( + const mlir::OwningOpRef &module) { + DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings"); + m_input_shardings.clear(); + + module.get().walk([&](mlir::Operation *op) { + mlir::func::FuncOp funcOp = mlir::dyn_cast(op); + mlir::ModuleOp moduleOp = mlir::dyn_cast(op); + + if (!funcOp) { + return; + } + if (!funcOp.isPublic()) { + return; + } + for (mlir::BlockArgument argument : funcOp.getArguments()) { + + mlir::tt::sharding_utils::MeshSharding meshSharding; + + for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) { + mlir::BlockArgument argument = funcOp.getArgument(i); + + auto shardingAttr = llvm::dyn_cast_if_present( + funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)); + if (!shardingAttr) { + continue; + } + + mlir::LogicalResult conversionResult = + fillMeshShardingFromGSPMDString(shardingAttr, meshSharding); + + if (conversionResult.failed()) { + m_status = tt_pjrt_status::kInternal; + return; + } + break; + } + m_input_shardings.push_back(meshSharding); + } + }); +} + void ModuleBuilder::collectOutputTypes( const mlir::OwningOpRef &module) { DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputTypes"); diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 548cc32..99fe04e 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -18,6 +18,9 @@ // tt-mlir includes #include "tt/runtime/types.h" +#define TTMLIR_ENABLE_STABLEHLO +#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h" + // tt-xla includes #include "status.h" @@ -36,6 +39,11 @@ class ModuleBuilder { return m_is_output_scalar; }; + const std::vector & + getInputShardings() const { + return m_input_shardings; + } + // This needs to return the number of addressable devices from the StableHLO // code. Currently hardcoded to one, as we only support one-chip execution. size_t getNumAddressableDevices() const { return 1; } @@ -52,6 +60,10 @@ class ModuleBuilder { // scalar or not. void collectOutputTypes(const mlir::OwningOpRef &module); + // Fills up the m_input_shardings array with information about the sharding of + // specifc inputs. + void collectInputShardings(const mlir::OwningOpRef &module); + // Converts StableHLO module to TTIR module. void convertFromSHLOToTTIR(mlir::OwningOpRef &mlir_module); @@ -68,6 +80,12 @@ class ModuleBuilder { // Checks if a particular type is scalar. bool isScalarType(mlir::Type type); + // Fills up a mlir::tt::sharding_utils::MeshSharding object with the + // information froma StringAttribute representing GSPMD sharding. + mlir::LogicalResult fillMeshShardingFromGSPMDString( + mlir::StringAttr shardingStr, + mlir::tt::sharding_utils::MeshSharding &meshSharding); + // MLIR context handle. std::unique_ptr m_context; @@ -79,6 +97,9 @@ class ModuleBuilder { // For every output, holds if the type is a scalar or not. std::vector m_is_output_scalar; + + // For every input, holds the sharding information. + std::vector m_input_shardings; }; } // namespace tt::pjrt