Getting number of addressable devices from the mlir module instead of hardcoding #262
Codecov Report
Attention: Patch coverage is
✅ All tests successful. No failed tests found.
Additional details and impacted files
@@ Coverage Diff @@
## main #262 +/- ##
==========================================
- Coverage 78.25% 78.07% -0.19%
==========================================
Files 21 21
Lines 1044 1049 +5
==========================================
+ Hits 817 819 +2
- Misses 227 230 +3
☔ View full report in Codecov by Sentry.
std::vector<DeviceInstance *> addressable_devices_;
size_t num_addressable_devices_;
If num_addressable_devices_ must be equal to addressable_devices_.size() at all times, do we really need it?
They aren't equal. addressable_devices_ represents the total number of chips, and num_addressable_devices_ tells us how many of them we are using in the particular loaded executable.
Can we call them num_devices_ or num_used_devices_? I get that addressable means the same as used in this context, but just to avoid this confusion with addressable_devices_.
Yeah, I had a similar pending comment, to rename num_addressable_devices_ to m_num_devices_to_utilize, but then I started wondering how JAX and PJRT actually use the addressable devices. This led me to the realisation that we have two different groups of "addressable devices":
- PJRT_Client_AddressableDevices - "Returns a list of devices that are addressable from the client. Addressable devices are those that the client can issue commands to. All devices are addressable in a single-process environment."
- PJRT_LoadedExecutable_AddressableDevices - "Returns a list of devices this executable will run on."
This means that we shouldn't just pass the full addressable_devices list from the client to the LoadedExecutableInstance; instead we should get the exact subset of devices that the executable needs to run on from the StableHLO. Then we won't need num_addressable_devices anymore, since the size of the vector will be that number. FYI @ajakovljevicTT
@kmitrovicTT I agree that the naming is a bit weird, since we just adopted what XLA and PJRT use.
@mrakitaTT yes, I had this exact dilemma at some point, but the devices for compilation are passed through extra compilation options (protobuf), and without them we have no idea what exact devices the LoadedExecutableInstance will run on. Since we can only rely on what we can see in the StableHLO code, we can get the number of devices from it and later get the exact devices for execution by checking the devices of the individual BufferInstances.
Discussed offline: we can't get the exact list of devices from the compile options yet, but since we are currently returning all devices as addressable devices this shouldn't be a problem. Let's then just rename occurrences of num_addressable_devices to num_devices_to_utilize in ModuleBuilder and LoadedExecutableInstance to make it more obvious what that number actually counts.
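To make the distinction from this thread concrete, here is a minimal sketch, not the plugin's actual code. `DeviceInstance` is reduced to a stand-in struct and `LoadedExecutableSketch` is a hypothetical name; the only point is that the full client-side device list and the count of devices the executable uses are two separate pieces of state.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the plugin's DeviceInstance.
struct DeviceInstance {
  int chip_id;
};

// Hypothetical, simplified executable. It still receives every device the
// client can address, plus the number of devices the compiled module uses.
class LoadedExecutableSketch {
public:
  LoadedExecutableSketch(
      const std::vector<DeviceInstance *> &addressable_devices,
      size_t num_devices_to_utilize)
      : addressable_devices_(addressable_devices),
        num_devices_to_utilize_(num_devices_to_utilize) {
    // The executable cannot use more devices than the client can address.
    assert(num_devices_to_utilize_ <= addressable_devices_.size());
  }

  // All chips visible to the client (the current behaviour in the thread).
  const std::vector<DeviceInstance *> &addressable_devices() const {
    return addressable_devices_;
  }

  // How many of those chips this particular executable runs on.
  size_t num_devices_to_utilize() const { return num_devices_to_utilize_; }

private:
  std::vector<DeviceInstance *> addressable_devices_;
  size_t num_devices_to_utilize_;
};
```

Once the exact device subset can be derived at compile time, the vector alone would carry this information and the separate counter could go away, as suggested above.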
devices_[i] = new DeviceInstance(chip_ids[i], *this,
system_desc->chip_descs()->Get(i)->arch());
This seems like a weird line break.
module_builder_->getIsOutputScalar()),
addressable_devices_, module_builder_->getNumAddressableDevices());
This also.
Fix comments and you're good to go as far as I am concerned.
@@ -46,6 +48,10 @@ class LoadedExecutableInstance {
return addressable_devices_;
}

const size_t get_num_addressable_devices() const {
const not needed before size_t
m_num_addressable_devices = 1;
DLOG_F(
WARNING,
"mhlo.num_partitions not found, using default number of devices: 1.");
No need for dot at the end of the log messages
void ModuleBuilder::collectNumAddressableDevices(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (auto attr = mlir_module->getOperation()->getAttrOfType<mlir::IntegerAttr>(
"mhlo.num_partitions")) {
@ajakovljevicTT do we need mhlo.num_replicas * mhlo.num_partitions here? Take a look at this and this.
You are correct, thanks for catching this! In our examples, mhlo.num_replicas was always 1, so it was not a problem, but yes. @sdjukicTT please change.
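A rough sketch of the computation being asked for, under stated assumptions: the module attributes are modelled here as a plain `std::map` instead of MLIR attributes, and `attrOrDefault` and `numDevicesToUtilize` are hypothetical helper names, not functions from the plugin. Each attribute falls back to 1 when it is missing, matching the default already used for mhlo.num_partitions.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for the integer attributes attached to the module.
using ModuleAttrs = std::map<std::string, size_t>;

// Returns the attribute value, or `fallback` when the attribute is absent.
size_t attrOrDefault(const ModuleAttrs &attrs, const std::string &name,
                     size_t fallback) {
  auto it = attrs.find(name);
  return it == attrs.end() ? fallback : it->second;
}

// Devices used by the executable: replicas times partitions, where a
// missing attribute counts as 1.
size_t numDevicesToUtilize(const ModuleAttrs &attrs) {
  size_t num_replicas = attrOrDefault(attrs, "mhlo.num_replicas", 1);
  size_t num_partitions = attrOrDefault(attrs, "mhlo.num_partitions", 1);
  return num_replicas * num_partitions;
}
```

With mhlo.num_replicas equal to 1 the result is the same as reading mhlo.num_partitions alone, which is why the existing examples did not expose the difference.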
// Sanity check, as we only support execution on one chip currently.
assert(args->num_devices == 1);
assert(args->num_devices == num_addressable_devices_);
Seems to me that we need more changes than this in this function to make it work for multichip; for example, we still have a hardcoded dev_index = 0. Let's sync offline (@ajakovljevicTT)
Looking at the comments in PJRT_LoadedExecutable_Execute_Args, it seems like we should either:
- Run on the addressable devices passed to the LoadedExecutable from the compiler, where num_devices should match the size of the addressable devices list
- Or run on the execute_device specified in PJRT_LoadedExecutable_Execute_Args
@mrakitaTT This change is scoped so that it only addresses the addressable_devices number. The PR that I will put out soon (which relies on this one and some of the other PRs by @sdjukicTT) will execute on the devices read from the list of PJRT_Buffers that XLA gives us in the Execute() function.
Ok, let's just not forget to add a TODO to support, in the future, the scenario when execute_device is specified.
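The two options discussed in this thread could be sketched roughly as below. This is an illustration only: `Device`, `ExecuteArgsSketch` and `selectExecutionDevices` are hypothetical names, the struct mirrors just two fields of PJRT_LoadedExecutable_Execute_Args, and taking the first N addressable devices is a placeholder assumption (the follow-up PR mentioned above plans to read the devices from the PJRT_Buffers instead). The check that num_devices is 1 when execute_device is set reflects my reading of the header comments and should be verified against the PJRT C API.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a PJRT device.
struct Device {
  int id;
};

// Hypothetical subset of PJRT_LoadedExecutable_Execute_Args.
struct ExecuteArgsSketch {
  size_t num_devices;
  Device *execute_device; // nullptr when the caller did not pin a device
};

std::vector<Device *>
selectExecutionDevices(const ExecuteArgsSketch &args,
                       const std::vector<Device *> &addressable_devices,
                       size_t num_devices_to_utilize) {
  if (args.execute_device != nullptr) {
    // Assumption: a pinned launch passes arguments for one device only.
    if (args.num_devices != 1) {
      throw std::invalid_argument("execute_device requires num_devices == 1");
    }
    return {args.execute_device};
  }
  // Otherwise the caller must supply exactly the number of devices that was
  // counted from the module at compile time.
  if (args.num_devices != num_devices_to_utilize ||
      num_devices_to_utilize > addressable_devices.size()) {
    throw std::invalid_argument("num_devices does not match the executable");
  }
  // Placeholder choice: use the first N addressable devices.
  return std::vector<Device *>(
      addressable_devices.begin(),
      addressable_devices.begin() + num_devices_to_utilize);
}
```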
@@ -70,10 +70,8 @@ tt_pjrt_status
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();

// Sanity check, as we only support execution on one chip currently.
This comment is now outdated; we should edit it to explain that we are making sure the number of devices matches the number of devices we counted from the SHLO module.
Ticket
#229
Problem description
Number of addressable devices is hardcoded to 1, which is a problem for multichip.
What's changed
- mhlo.num_partitions attribute in VHLO module.
- getCurrentSystemDesc, for less convoluted code.
- num_addressable_devices from ExecutableImage to LoadedExecutableInstance because it is related to addressable_devices_ of the LoadedExecutableInstance.
Checklist