Return custom num addressable devices in PJRT_LoadedExecutable_AddressableDevices #229

Open
ajakovljevicTT opened this issue Feb 4, 2025 · 0 comments
Labels
multichip Multichip issues

ajakovljevicTT commented Feb 4, 2025

Currently, we hardcode the number of addressable devices returned by PJRT_LoadedExecutable_AddressableDevices. To support multichip, this number has to be determined per executable, with the information inferred from the StableHLO code, possibly via the mhlo.num_partitions attribute (e.g. mhlo.num_partitions = 2). This will require changes to the ExecutableImage::get_num_addressable_devices() function. It might also necessitate changing the following part of the code, which returns the concrete addressable devices but is currently not tested:

args->addressable_devices = const_cast<PJRT_Device **>(
        reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
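
One possible way to infer the device count from the StableHLO text is sketched below. This is only an illustration, not the existing implementation: `inferNumPartitions` and `numAddressableDevices` are hypothetical helpers, the sketch assumes the module is available as a string containing `mhlo.num_partitions = N`, and falling back to 1 when the attribute is missing is an assumption. A real change would more likely read the attribute from the parsed MLIR module than from text.

```cpp
#include <cstddef>
#include <optional>
#include <regex>
#include <string>

// Hypothetical helper (not part of the existing code base): scans the textual
// StableHLO module for `mhlo.num_partitions = N` and returns N if present.
std::optional<std::size_t> inferNumPartitions(const std::string &module_text) {
  static const std::regex pattern(R"(mhlo\.num_partitions\s*=\s*(\d+))");
  std::smatch match;
  if (!std::regex_search(module_text, match, pattern)) {
    return std::nullopt;
  }
  return static_cast<std::size_t>(std::stoul(match[1].str()));
}

// Hypothetical helper: assumes a single addressable device when the
// attribute is absent (an assumption, not confirmed behavior).
std::size_t numAddressableDevices(const std::string &module_text) {
  return inferNumPartitions(module_text).value_or(1);
}
```

ExecutableImage::get_num_addressable_devices() could then return a value computed this way at compile time instead of the hardcoded constant.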