Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getting number of addressable devices from the mlir module instead of hardcoding #262

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions inc/common/pjrt_implementation/executable_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ class ExecutableImage {

public:
ExecutableImage(const tt::runtime::Binary &binary, std::string code,
const std::vector<bool> &is_output_scalar,
size_t num_addressable_devices);
const std::vector<bool> &is_output_scalar);

operator PJRT_Executable *() {
return reinterpret_cast<PJRT_Executable *>(this);
Expand Down Expand Up @@ -54,10 +53,6 @@ class ExecutableImage {
// Checks if the output on the i-th index is a scalar.
bool isOutputScalar(size_t index) const;

const size_t get_num_addressable_devices() const {
return m_num_addressable_devices;
}

private:
// The reference count. Must be disposed when reaching zero.
std::atomic<int> m_ref_count;
Expand All @@ -70,7 +65,6 @@ class ExecutableImage {

size_t m_arg_count;
size_t m_result_count;
size_t m_num_addressable_devices;

// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;
Expand Down
11 changes: 9 additions & 2 deletions inc/common/pjrt_implementation/loaded_executable_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class LoadedExecutableInstance {
public:
LoadedExecutableInstance(
ClientInstance &client, ExecutableImage *image,
const std::vector<DeviceInstance *> &addressable_devices)
const std::vector<DeviceInstance *> &addressable_devices,
size_t num_addressable_devices)
: client_(client), image_(image),
addressable_devices_(addressable_devices) {}
addressable_devices_(addressable_devices),
num_addressable_devices_(num_addressable_devices) {}
~LoadedExecutableInstance();

operator PJRT_LoadedExecutable *() {
Expand All @@ -46,6 +48,10 @@ class LoadedExecutableInstance {
return addressable_devices_;
}

const size_t get_num_addressable_devices() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const not needed before size_t

return num_addressable_devices_;
}

// Loads all executables to addressable devices.
tt_pjrt_status LoadAll();

Expand All @@ -59,6 +65,7 @@ class LoadedExecutableInstance {
ClientInstance &client_;
ExecutableImage *image_; // Ref-counted semantics.
std::vector<DeviceInstance *> addressable_devices_;
size_t num_addressable_devices_;
Comment on lines 67 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If num_addressable_devices_ must be equal to addressable_devices_.size() at all times, do we really need it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They aren't equal. addressable_devices_ represents total chips, and num_addressable_devices_ tells us how many of them we are using in the particular loaded executable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call them num_devices_ or num_used_devices_? I get that addressable means the same as used in this context, but just to avoid this confusion with addressable_devices_.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I had the similar pending comment, to rename from num_addressable_devices_ into m_num_devices_to_utilize, but then I started wondering how JAX and PJRT actually use the addressable devices. This lead me to a realisation that we have two different groups of "addressable devices":

  1. PJRT_Client_AddressableDevices - "Returns a list of devices that are addressable from the client. Addressable devices are those that the client can issue commands to. All devices are addressable in a single-process environment."
  2. PJRT_LoadedExecutable_AddressableDevices - "Returns a list of devices this executable will run on."

This means that we shouldn't just pass the full addressable_devices list from client to the LoadedExecutableInstance, instead we should get the exact subset of devices that the executable needs to run on from the StableHLO. Then we won't need num_addressable_devices anymore, the size of the vector will be the num. FYI @ajakovljevicTT

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmitrovicTT I agree that the naming is a bit weird since we had just adopted what xla and PJRT use.

@mrakitaTT yes, I had this exact dilemma at some point, but the devices for compilation are passed through extra compilation options (protobuf), and without it, we have no idea what exact devices the LoadedExecutableInstance will run on. Since we can only rely on the things we can see from the StableHLO code, we can get the number of devices and later get the exact devices for execution by checking what the are the devices of individual BufferInstances.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline, we can't get the exact list of devices yet from the compile options but since we are currently returning all devices as addressable devices this shouldn't be a problem. Let's then just rename occurrences of num_addressable_devices to num_devices_to_utilize in ModuleBuilder and LoadedExecutableInstance to make it more obvious what that number actually counts.

std::vector<ResidentExecutable> resident_executables_;
};

Expand Down
15 changes: 15 additions & 0 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
return m_status;
}

collectNumAddressableDevices(mlir_module);

convertFromVHLOToSHLO(mlir_module);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
Expand Down Expand Up @@ -118,6 +120,19 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) {
return vhlo_module;
}

void ModuleBuilder::collectNumAddressableDevices(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (auto attr = mlir_module->getOperation()->getAttrOfType<mlir::IntegerAttr>(
"mhlo.num_partitions")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ajakovljevicTT do we need mhlo.num_replicas * mhlo.num_partitions here? Take a look at this and this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct, thanks for catching this! In our examples, mhlo.num_replicas was always 1, so it was not a problem, but yes. @sdjukicTT please change.

m_num_addressable_devices = attr.getInt();
} else {
m_num_addressable_devices = 1;
DLOG_F(
WARNING,
"mhlo.num_partitions not found, using default number of devices: 1.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for dot at the end of the log messages

}
}

void ModuleBuilder::convertFromVHLOToSHLO(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName());
Expand Down
12 changes: 9 additions & 3 deletions src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ class ModuleBuilder {
return m_is_output_scalar;
};

// This needs to return the number of addressable devices from the StableHLO
// code. Currently hardcoded to one, as we only support one-chip execution.
size_t getNumAddressableDevices() const { return 1; }
size_t getNumAddressableDevices() const { return m_num_addressable_devices; }

private:
// Creates VHLO module from the input program code.
mlir::OwningOpRef<mlir::ModuleOp>
createVHLOModule(const std::string_view &code);

// Sets m_num_addressable_devices to the number of devices from the VHLO
// module.
void
collectNumAddressableDevices(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

// Converts VHLO module to StableHLO module.
void convertFromVHLOToSHLO(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

Expand Down Expand Up @@ -79,6 +82,9 @@ class ModuleBuilder {

// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;

// Number of devices the binary is intended to run on.
size_t m_num_addressable_devices;
};

} // namespace tt::pjrt
Expand Down
9 changes: 4 additions & 5 deletions src/common/pjrt_implementation/client_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ tt_pjrt_status ClientInstance::PopulateDevices() {

devices_.resize(devices_count);
for (size_t i = 0; i < devices_count; ++i) {
devices_[i] =
new DeviceInstance(i, *this, system_desc->chip_descs()->Get(i)->arch());
devices_[i] = new DeviceInstance(chip_ids[i], *this,
system_desc->chip_descs()->Get(i)->arch());
Comment on lines +166 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a weird line break.

}

// For now, just make all devices addressable.
Expand All @@ -191,9 +191,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
*this,
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->getIsOutputScalar(),
module_builder_->getNumAddressableDevices()),
addressable_devices_);
module_builder_->getIsOutputScalar()),
addressable_devices_, module_builder_->getNumAddressableDevices());
Comment on lines +194 to +195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also.

*out_executable = executable.release();

return nullptr;
Expand Down
6 changes: 2 additions & 4 deletions src/common/pjrt_implementation/executable_image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ const std::string_view kMlirFormat = "mlir";

ExecutableImage::ExecutableImage(const tt::runtime::Binary &binary,
std::string code,
const std::vector<bool> &is_output_scalar,
size_t num_addressable_devices)
const std::vector<bool> &is_output_scalar)
: m_ref_count(1), m_binary(binary), m_code(code),
m_arg_count(binary.getProgramInputs(0).size()),
m_result_count(binary.getProgramOutputs(0).size()),
m_is_output_scalar(is_output_scalar),
m_num_addressable_devices(num_addressable_devices) {
m_is_output_scalar(is_output_scalar) {
if (m_result_count != m_is_output_scalar.size()) {
// TODO: We should throw error instead, otherwise execution will continue
// and crash later.
Expand Down
13 changes: 5 additions & 8 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
const std::vector<DeviceInstance *> &addressable_devices =
loaded_executable->addressable_devices();
int num_addressable_devices =
loaded_executable->image_->get_num_addressable_devices();
loaded_executable->get_num_addressable_devices();
args->addressable_devices = const_cast<PJRT_Device **>(
reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
args->num_addressable_devices = num_addressable_devices;
Expand Down Expand Up @@ -70,10 +70,8 @@ tt_pjrt_status
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();

// Sanity check, as we only support execution on one chip currently.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is now outdated, should edit to explain that we are making sure that the number of devices matches the number of devices we counted from the SHLO module.

assert(args->num_devices == 1);
assert(args->num_devices == num_addressable_devices_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to me that we need more changes than this in this function to make it work for multichip, for example we still have a hardcoded dev_index = 0. Let's sync offline (@ajakovljevicTT)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the comments in the PJRT_LoadedExecutable_Execute_Args it seems like we should either:

  1. Run on the addressable devices passed to the LoadedExecutable from the compiler, where num_devices should match the size of the addressable devices list
  2. Or run on the execute_device specified in PJRT_LoadedExecutable_Execute_Args

Copy link
Contributor

@ajakovljevicTT ajakovljevicTT Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mrakitaTT This change is scoped so that it only adressess the addressable_devices number. The PR that I will put out soon (that relies on this one and some of the other PRs by @sdjukicTT), will execute on devices that are read from the list of PJRT_Buffers that the xla gives us in the Execute() function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's just not forget to add some TODO to support in the future the scenario when execute_device is specified.


int dev_index = 0;
const tt::runtime::Binary &binary = image_->get_binary();
Expand All @@ -89,7 +87,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
rt_inputs.emplace_back(buffer->getTensor());
int64_t buffer_device_id =
buffer->device().device_description()->getDeviceId();
device_ids.insert(chip_ids[buffer_device_id]);
device_ids.insert(buffer_device_id);
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
}

Expand All @@ -99,11 +97,10 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
// TODO: Now we will run only on the first one, but this should be somehow
// explicit.
if (device_ids.size() == 0) {
device_ids_vector.push_back(chip_ids[0]);
device_ids_vector.push_back(
addressable_devices_[0]->device_description()->getDeviceId());
}

assert(device_ids_vector.size() == 1);

tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector);

std::vector<tt::runtime::Tensor> rt_outputs =
Expand Down
Loading