Getting number of addressable devices from the mlir module instead of hardcoding #262

Open
wants to merge 2 commits into
base: main
from

sdjukicTT
Contributor

@sdjukicTT sdjukicTT commented Feb 18, 2025

Ticket

#229

Problem description

Number of addressable devices is hardcoded to 1, which is a problem for multichip.

What's changed

  • Now we get the number of addressable devices from the mhlo.num_partitions attribute in the VHLO module.
  • Changed device_id to be the same as chip_id from getCurrentSystemDesc, for less convoluted code.
  • Moved num_addressable_devices from ExecutableImage to LoadedExecutableInstance because it is related to addressable_devices_ of the LoadedExecutableInstance.

Checklist

  • New/Existing tests provide coverage for changes

github-actions bot commented Feb 18, 2025

TT-XLA Tests: 599 ran, 431 passed ✅, 168 skipped ⚠️, 0 failed
No test annotations available

codecov-commenter commented Feb 18, 2025

Codecov Report

Attention: Patch coverage is 94.44444% with 1 line in your changes missing coverage. Please review.

Project coverage is 78.07%. Comparing base (72560da) to head (ea88e1e).
Report is 1 commits behind head on main.

✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
src/common/module_builder.cc | 85.71% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #262      +/-   ##
==========================================
- Coverage   78.25%   78.07%   -0.19%     
==========================================
  Files          21       21              
  Lines        1044     1049       +5     
==========================================
+ Hits          817      819       +2     
- Misses        227      230       +3     


@sdjukicTT sdjukicTT force-pushed the sdjukic/num-addressable-devices branch from ea88e1e to cb58764 on February 20, 2025 17:18
Comment on lines 67 to +68
std::vector<DeviceInstance *> addressable_devices_;
size_t num_addressable_devices_;
Contributor

If num_addressable_devices_ must be equal to addressable_devices_.size() at all times, do we really need it?

Contributor Author

They aren't equal. addressable_devices_ represents total chips, and num_addressable_devices_ tells us how many of them we are using in the particular loaded executable.

Contributor

Can we call it num_devices_ or num_used_devices_? I get that "addressable" means the same as "used" in this context, but a different name would avoid confusion with addressable_devices_.

Contributor

Yeah, I had a similar pending comment, to rename num_addressable_devices_ to m_num_devices_to_utilize, but then I started wondering how JAX and PJRT actually use the addressable devices. This led me to the realisation that we have two different groups of "addressable devices":

  1. PJRT_Client_AddressableDevices - "Returns a list of devices that are addressable from the client. Addressable devices are those that the client can issue commands to. All devices are addressable in a single-process environment."
  2. PJRT_LoadedExecutable_AddressableDevices - "Returns a list of devices this executable will run on."

This means that we shouldn't just pass the full addressable_devices list from the client to the LoadedExecutableInstance; instead, we should get the exact subset of devices that the executable needs to run on from the StableHLO. Then we won't need num_addressable_devices anymore, since the size of the vector will be that number. FYI @ajakovljevicTT

Contributor

@kmitrovicTT I agree that the naming is a bit weird since we had just adopted what xla and PJRT use.

@mrakitaTT yes, I had this exact dilemma at some point, but the devices for compilation are passed through extra compilation options (protobuf), and without them we have no idea which exact devices the LoadedExecutableInstance will run on. Since we can only rely on what we can see in the StableHLO code, we can get the number of devices from it and later get the exact devices for execution by checking which devices the individual BufferInstances are on.

Contributor

Discussed offline, we can't get the exact list of devices yet from the compile options but since we are currently returning all devices as addressable devices this shouldn't be a problem. Let's then just rename occurrences of num_addressable_devices to num_devices_to_utilize in ModuleBuilder and LoadedExecutableInstance to make it more obvious what that number actually counts.
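To make the distinction in this thread concrete, here is a minimal sketch of the agreed interim design, assuming a stripped-down executable class. `LoadedExecutableSketch`, the one-field `DeviceInstance`, and the constructor-time check are hypothetical stand-ins for illustration, not the actual tt-xla classes:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical minimal device type; the real DeviceInstance carries more state.
struct DeviceInstance {
  int chip_id;
};

// Sketch of an executable that keeps the full client device list but also
// records how many of those devices the compiled module actually needs.
class LoadedExecutableSketch {
public:
  LoadedExecutableSketch(std::vector<DeviceInstance *> addressable_devices,
                         std::size_t num_devices_to_utilize)
      : addressable_devices_(std::move(addressable_devices)),
        num_devices_to_utilize_(num_devices_to_utilize) {
    // Assumed invariant: the module cannot need more devices than exist.
    if (num_devices_to_utilize_ > addressable_devices_.size()) {
      throw std::invalid_argument(
          "module requires more devices than are addressable");
    }
  }

  // All devices handed over by the client (currently every chip).
  const std::vector<DeviceInstance *> &getAddressableDevices() const {
    return addressable_devices_;
  }

  // Number of devices counted from the module, not the size of the list.
  std::size_t getNumDevicesToUtilize() const { return num_devices_to_utilize_; }

private:
  std::vector<DeviceInstance *> addressable_devices_;
  std::size_t num_devices_to_utilize_;
};
```

The two getters can legitimately disagree (for example two chips present, one used), which is why a name other than num_addressable_devices reads more clearly.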

Comment on lines +166 to +167
devices_[i] = new DeviceInstance(chip_ids[i], *this,
system_desc->chip_descs()->Get(i)->arch());
Contributor

This seems like a weird line break.

Comment on lines +194 to +195
module_builder_->getIsOutputScalar()),
addressable_devices_, module_builder_->getNumAddressableDevices());
Contributor

This also.

Contributor

@kmitrovicTT kmitrovicTT left a comment

Fix comments and you're good to go as far as I am concerned.

@@ -46,6 +48,10 @@ class LoadedExecutableInstance {
return addressable_devices_;
}

const size_t get_num_addressable_devices() const {
Contributor

const not needed before size_t

m_num_addressable_devices = 1;
DLOG_F(
WARNING,
"mhlo.num_partitions not found, using default number of devices: 1.");
Contributor

No need for dot at the end of the log messages

void ModuleBuilder::collectNumAddressableDevices(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (auto attr = mlir_module->getOperation()->getAttrOfType<mlir::IntegerAttr>(
"mhlo.num_partitions")) {
Contributor

@ajakovljevicTT do we need mhlo.num_replicas * mhlo.num_partitions here? Take a look at this and this.

Contributor

You are correct, thanks for catching this! In our examples, mhlo.num_replicas was always 1, so it was not a problem, but yes. @sdjukicTT please change.
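The requested change can be sketched without the MLIR API as follows. This is only an illustration under stated assumptions: `ModuleAttrs` is a hypothetical map standing in for the module's integer attributes, `findAttr` and `collectNumDevicesToUtilize` are made-up names, and defaulting a missing mhlo.num_replicas to 1 is an assumption that mirrors the existing fallback for mhlo.num_partitions:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the integer attributes on the MLIR module; the
// real code reads them through the MLIR API instead of a map.
using ModuleAttrs = std::unordered_map<std::string, std::size_t>;

// Returns the attribute value if present, std::nullopt otherwise.
inline std::optional<std::size_t> findAttr(const ModuleAttrs &attrs,
                                           const std::string &name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) {
    return std::nullopt;
  }
  return it->second;
}

// Device count = num_replicas * num_partitions. Each factor falls back to 1
// when its attribute is missing (assumed default, see the lead-in).
inline std::size_t collectNumDevicesToUtilize(const ModuleAttrs &attrs) {
  const std::size_t num_partitions =
      findAttr(attrs, "mhlo.num_partitions").value_or(1);
  const std::size_t num_replicas =
      findAttr(attrs, "mhlo.num_replicas").value_or(1);
  return num_replicas * num_partitions;
}
```

With mhlo.num_replicas absent or equal to 1, the result matches the partition count alone, which is consistent with the earlier examples not exposing the problem.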

// Sanity check, as we only support execution on one chip currently.
assert(args->num_devices == 1);
assert(args->num_devices == num_addressable_devices_);
Contributor

Seems to me that we need more changes than this in this function to make it work for multichip, for example we still have a hardcoded dev_index = 0. Let's sync offline (@ajakovljevicTT)

Contributor

Looking at the comments in the PJRT_LoadedExecutable_Execute_Args it seems like we should either:

  1. Run on the addressable devices passed to the LoadedExecutable from the compiler, where num_devices should match the size of the addressable devices list
  2. Or run on the execute_device specified in PJRT_LoadedExecutable_Execute_Args

Contributor

@ajakovljevicTT ajakovljevicTT Mar 3, 2025

@mrakitaTT This change is scoped so that it only addresses the number of addressable devices. The PR that I will put out soon (which relies on this one and some of the other PRs by @sdjukicTT) will execute on devices that are read from the list of PJRT_Buffers that XLA gives us in the Execute() function.

Contributor

Ok, let's just not forget to add a TODO to support, in the future, the scenario where execute_device is specified.
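The two execution scenarios listed earlier in this thread could be sketched roughly as below. `ExecuteArgsSketch` is a hypothetical struct that mirrors only the `num_devices` and `execute_device` fields discussed here, and `selectExecutionDevices` is an illustrative helper name, not a function from the PJRT C API or from this PR:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical minimal device type for illustration.
struct DeviceInstance {
  int chip_id;
};

// Hypothetical reduction of the execute arguments to the two fields that
// matter for device selection.
struct ExecuteArgsSketch {
  std::size_t num_devices;
  DeviceInstance *execute_device;  // nullptr when no device is pinned
};

// Picks the devices to run on, following the two scenarios from the thread.
inline std::vector<DeviceInstance *> selectExecutionDevices(
    const ExecuteArgsSketch &args,
    const std::vector<DeviceInstance *> &executable_devices) {
  if (args.execute_device != nullptr) {
    // Scenario 2: the caller pinned a single device explicitly.
    return {args.execute_device};
  }
  // Scenario 1: run on the executable's own devices; the caller-supplied
  // count must match the size of that list.
  if (args.num_devices != executable_devices.size()) {
    throw std::invalid_argument(
        "num_devices does not match the executable's device list");
  }
  return executable_devices;
}
```

Until the pinned-device path is supported, the first branch is where the TODO would sit.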

@@ -70,10 +70,8 @@ tt_pjrt_status
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();

// Sanity check, as we only support execution on one chip currently.
Contributor

This comment is now outdated; it should be edited to explain that we are checking that the number of devices matches the number we counted from the SHLO module.

5 participants