Skip to content

Commit

Permalink
added m_num_addressable_devices, changed device_id to match chip_id
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjukicTT committed Feb 20, 2025
1 parent 7082b6e commit cb58764
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 26 deletions.
8 changes: 1 addition & 7 deletions inc/common/pjrt_implementation/executable_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ class ExecutableImage {

public:
ExecutableImage(const tt::runtime::Binary &binary, std::string code,
const std::vector<bool> &is_output_scalar,
size_t num_addressable_devices);
const std::vector<bool> &is_output_scalar);

operator PJRT_Executable *() {
return reinterpret_cast<PJRT_Executable *>(this);
Expand Down Expand Up @@ -54,10 +53,6 @@ class ExecutableImage {
// Checks if the output on the i-th index is a scalar.
bool isOutputScalar(size_t index) const;

const size_t get_num_addressable_devices() const {
return m_num_addressable_devices;
}

private:
// The reference count. Must be disposed when reaching zero.
std::atomic<int> m_ref_count;
Expand All @@ -70,7 +65,6 @@ class ExecutableImage {

size_t m_arg_count;
size_t m_result_count;
size_t m_num_addressable_devices;

// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;
Expand Down
11 changes: 9 additions & 2 deletions inc/common/pjrt_implementation/loaded_executable_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class LoadedExecutableInstance {
public:
LoadedExecutableInstance(
ClientInstance &client, ExecutableImage *image,
const std::vector<DeviceInstance *> &addressable_devices)
const std::vector<DeviceInstance *> &addressable_devices,
size_t num_addressable_devices)
: client_(client), image_(image),
addressable_devices_(addressable_devices) {}
addressable_devices_(addressable_devices),
num_addressable_devices_(num_addressable_devices) {}
~LoadedExecutableInstance();

operator PJRT_LoadedExecutable *() {
Expand All @@ -46,6 +48,10 @@ class LoadedExecutableInstance {
return addressable_devices_;
}

const size_t get_num_addressable_devices() const {
return num_addressable_devices_;
}

// Loads all executables to addressable devices.
tt_pjrt_status LoadAll();

Expand All @@ -59,6 +65,7 @@ class LoadedExecutableInstance {
ClientInstance &client_;
ExecutableImage *image_; // Ref-counted semantics.
std::vector<DeviceInstance *> addressable_devices_;
size_t num_addressable_devices_;
std::vector<ResidentExecutable> resident_executables_;
};

Expand Down
15 changes: 15 additions & 0 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
return m_status;
}

collectNumAddressableDevices(mlir_module);

convertFromVHLOToSHLO(mlir_module);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
Expand Down Expand Up @@ -118,6 +120,19 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) {
return vhlo_module;
}

void ModuleBuilder::collectNumAddressableDevices(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (auto attr = mlir_module->getOperation()->getAttrOfType<mlir::IntegerAttr>(
"mhlo.num_partitions")) {
m_num_addressable_devices = attr.getInt();
} else {
m_num_addressable_devices = 1;
DLOG_F(
WARNING,
"mhlo.num_partitions not found, using default number of devices: 1.");
}
}

void ModuleBuilder::convertFromVHLOToSHLO(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName());
Expand Down
12 changes: 9 additions & 3 deletions src/common/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ class ModuleBuilder {
return m_is_output_scalar;
};

// This needs to return the number of addressable devices from the StableHLO
// code. Currently hardcoded to one, as we only support one-chip execution.
size_t getNumAddressableDevices() const { return 1; }
size_t getNumAddressableDevices() const { return m_num_addressable_devices; }

private:
// Creates VHLO module from the input program code.
mlir::OwningOpRef<mlir::ModuleOp>
createVHLOModule(const std::string_view &code);

// Sets m_num_addressable_devices to the number of devices from the VHLO
// module.
void
collectNumAddressableDevices(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

// Converts VHLO module to StableHLO module.
void convertFromVHLOToSHLO(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

Expand Down Expand Up @@ -79,6 +82,9 @@ class ModuleBuilder {

// For every output, holds if the type is a scalar or not.
std::vector<bool> m_is_output_scalar;

// Number of devices the binary is intended to run on.
size_t m_num_addressable_devices;
};

} // namespace tt::pjrt
Expand Down
9 changes: 4 additions & 5 deletions src/common/pjrt_implementation/client_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ tt_pjrt_status ClientInstance::PopulateDevices() {

devices_.resize(devices_count);
for (size_t i = 0; i < devices_count; ++i) {
devices_[i] =
new DeviceInstance(i, *this, system_desc->chip_descs()->Get(i)->arch());
devices_[i] = new DeviceInstance(chip_ids[i], *this,
system_desc->chip_descs()->Get(i)->arch());
}

// For now, just make all devices addressable.
Expand All @@ -191,9 +191,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
*this,
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->getIsOutputScalar(),
module_builder_->getNumAddressableDevices()),
addressable_devices_);
module_builder_->getIsOutputScalar()),
addressable_devices_, module_builder_->getNumAddressableDevices());
*out_executable = executable.release();

return nullptr;
Expand Down
6 changes: 2 additions & 4 deletions src/common/pjrt_implementation/executable_image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ const std::string_view kMlirFormat = "mlir";

ExecutableImage::ExecutableImage(const tt::runtime::Binary &binary,
std::string code,
const std::vector<bool> &is_output_scalar,
size_t num_addressable_devices)
const std::vector<bool> &is_output_scalar)
: m_ref_count(1), m_binary(binary), m_code(code),
m_arg_count(binary.getProgramInputs(0).size()),
m_result_count(binary.getProgramOutputs(0).size()),
m_is_output_scalar(is_output_scalar),
m_num_addressable_devices(num_addressable_devices) {
m_is_output_scalar(is_output_scalar) {
if (m_result_count != m_is_output_scalar.size()) {
// TODO: We should throw error instead, otherwise execution will continue
// and crash later.
Expand Down
9 changes: 4 additions & 5 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) {
const std::vector<DeviceInstance *> &addressable_devices =
loaded_executable->addressable_devices();
int num_addressable_devices =
loaded_executable->image_->get_num_addressable_devices();
loaded_executable->get_num_addressable_devices();
args->addressable_devices = const_cast<PJRT_Device **>(
reinterpret_cast<PJRT_Device *const *>(addressable_devices.data()));
args->num_addressable_devices = num_addressable_devices;
Expand Down Expand Up @@ -70,8 +70,6 @@ tt_pjrt_status
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute");

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();

// Sanity check, as we only support execution on one chip currently.
assert(args->num_devices == 1);

Expand All @@ -89,7 +87,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
rt_inputs.emplace_back(buffer->getTensor());
int64_t buffer_device_id =
buffer->device().device_description()->getDeviceId();
device_ids.insert(chip_ids[buffer_device_id]);
device_ids.insert(buffer_device_id);
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id());
}

Expand All @@ -99,7 +97,8 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
// TODO: Now we will run only on the first one, but this should be somehow
// explicit.
if (device_ids.size() == 0) {
device_ids_vector.push_back(chip_ids[0]);
device_ids_vector.push_back(
addressable_devices_[0]->device_description()->getDeviceId());
}

assert(device_ids_vector.size() == 1);
Expand Down

0 comments on commit cb58764

Please sign in to comment.