-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding input sharding collection in ModuleBuilder #289
base: main
Are you sure you want to change the base?
Conversation
|
src/common/module_builder.cc
Outdated
auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>( | ||
funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: can we avoid casting to stringattr and directly convert from the fetched attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left it like this so I can use the already made functions in tt-mlir
. See: lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
in tt-mlir
const std::vector<mlir::tt::sharding_utils::MeshSharding> & | ||
getInputShardings() const { | ||
return m_input_shardings; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a getter there for future use, do you think I should erase it?
93a81fd
to
9ce0256
Compare
9ce0256
to
260fb4d
Compare
As part of adding support for multichip, adding collection of the information about the device sharding of the inputs of the the StableHLO graph during compilation. That way, we will know how to create MultiDevice input tensors when executing.