-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Getting number of addressable devices from the mlir module instead of hardcoding #262
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,9 +29,11 @@ class LoadedExecutableInstance { | |
public: | ||
LoadedExecutableInstance( | ||
ClientInstance &client, ExecutableImage *image, | ||
const std::vector<DeviceInstance *> &addressable_devices) | ||
const std::vector<DeviceInstance *> &addressable_devices, | ||
size_t num_addressable_devices) | ||
: client_(client), image_(image), | ||
addressable_devices_(addressable_devices) {} | ||
addressable_devices_(addressable_devices), | ||
num_addressable_devices_(num_addressable_devices) {} | ||
~LoadedExecutableInstance(); | ||
|
||
operator PJRT_LoadedExecutable *() { | ||
|
@@ -46,6 +48,10 @@ class LoadedExecutableInstance { | |
return addressable_devices_; | ||
} | ||
|
||
const size_t get_num_addressable_devices() const { | ||
return num_addressable_devices_; | ||
} | ||
|
||
// Loads all executables to addressable devices. | ||
tt_pjrt_status LoadAll(); | ||
|
||
|
@@ -59,6 +65,7 @@ class LoadedExecutableInstance { | |
ClientInstance &client_; | ||
ExecutableImage *image_; // Ref-counted semantics. | ||
std::vector<DeviceInstance *> addressable_devices_; | ||
size_t num_addressable_devices_; | ||
Comment on lines
67
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They aren't equal. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we call them There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I had the similar pending comment, to rename from
This means that we shouldn't just pass the full There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kmitrovicTT I agree that the naming is a bit weird since we had just adopted what xla and PJRT use. @mrakitaTT yes, I had this exact dilemma at some point, but the devices for compilation are passed through extra compilation options (protobuf), and without it, we have no idea what exact devices the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline, we can't get the exact list of devices yet from the compile options but since we are currently returning all devices as addressable devices this shouldn't be a problem. Let's then just rename occurrences of |
||
std::vector<ResidentExecutable> resident_executables_; | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -77,6 +77,8 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code, | |
return m_status; | ||
} | ||
|
||
collectNumAddressableDevices(mlir_module); | ||
|
||
convertFromVHLOToSHLO(mlir_module); | ||
if (!tt_pjrt_status_is_ok(m_status)) { | ||
return m_status; | ||
|
@@ -118,6 +120,19 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) { | |
return vhlo_module; | ||
} | ||
|
||
void ModuleBuilder::collectNumAddressableDevices( | ||
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) { | ||
if (auto attr = mlir_module->getOperation()->getAttrOfType<mlir::IntegerAttr>( | ||
"mhlo.num_partitions")) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ajakovljevicTT do we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are correct, thanks for catching this! In our examples, |
||
m_num_addressable_devices = attr.getInt(); | ||
} else { | ||
m_num_addressable_devices = 1; | ||
DLOG_F( | ||
WARNING, | ||
mrakitaTT marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"mhlo.num_partitions not found, using default number of devices: 1."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for dot at the end of the log messages |
||
} | ||
} | ||
|
||
void ModuleBuilder::convertFromVHLOToSHLO( | ||
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) { | ||
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,8 +163,8 @@ tt_pjrt_status ClientInstance::PopulateDevices() { | |
|
||
devices_.resize(devices_count); | ||
for (size_t i = 0; i < devices_count; ++i) { | ||
devices_[i] = | ||
new DeviceInstance(i, *this, system_desc->chip_descs()->Get(i)->arch()); | ||
devices_[i] = new DeviceInstance(chip_ids[i], *this, | ||
system_desc->chip_descs()->Get(i)->arch()); | ||
Comment on lines
+166
to
+167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like a weird line break. |
||
} | ||
|
||
// For now, just make all devices addressable. | ||
|
@@ -191,9 +191,8 @@ PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, | |
*this, | ||
new ExecutableImage(module_builder_->getBinary(), | ||
std::string(program->code, program->code_size), | ||
module_builder_->getIsOutputScalar(), | ||
module_builder_->getNumAddressableDevices()), | ||
addressable_devices_); | ||
module_builder_->getIsOutputScalar()), | ||
addressable_devices_, module_builder_->getNumAddressableDevices()); | ||
Comment on lines
+194
to
+195
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also. |
||
*out_executable = executable.release(); | ||
|
||
return nullptr; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ void LoadedExecutableInstance::BindApi(PJRT_Api *api) { | |
const std::vector<DeviceInstance *> &addressable_devices = | ||
loaded_executable->addressable_devices(); | ||
int num_addressable_devices = | ||
loaded_executable->image_->get_num_addressable_devices(); | ||
loaded_executable->get_num_addressable_devices(); | ||
args->addressable_devices = const_cast<PJRT_Device **>( | ||
reinterpret_cast<PJRT_Device *const *>(addressable_devices.data())); | ||
args->num_addressable_devices = num_addressable_devices; | ||
|
@@ -70,10 +70,8 @@ tt_pjrt_status | |
LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { | ||
DLOG_F(LOG_DEBUG, "LoadedExecutableInstance::Execute"); | ||
|
||
auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); | ||
|
||
// Sanity check, as we only support execution on one chip currently. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment is now outdated, should edit to explain that we are making sure that the number of devices matches the number of devices we counted from the SHLO module. |
||
assert(args->num_devices == 1); | ||
assert(args->num_devices == num_addressable_devices_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems to me that we need more changes than this in this function to make it work for multichip, for example we still have a hardcoded There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at the comments in the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mrakitaTT This change is scoped so that it only adressess the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, let's just not forget to add some TODO to support in the future the scenario when |
||
|
||
int dev_index = 0; | ||
const tt::runtime::Binary &binary = image_->get_binary(); | ||
|
@@ -89,7 +87,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { | |
rt_inputs.emplace_back(buffer->getTensor()); | ||
int64_t buffer_device_id = | ||
buffer->device().device_description()->getDeviceId(); | ||
device_ids.insert(chip_ids[buffer_device_id]); | ||
device_ids.insert(buffer_device_id); | ||
DLOG_F(INFO, "Runtime input id: %d", buffer->unique_id()); | ||
} | ||
|
||
|
@@ -99,11 +97,10 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { | |
// TODO: Now we will run only on the first one, but this should be somehow | ||
// explicit. | ||
if (device_ids.size() == 0) { | ||
device_ids_vector.push_back(chip_ids[0]); | ||
device_ids_vector.push_back( | ||
addressable_devices_[0]->device_description()->getDeviceId()); | ||
} | ||
|
||
assert(device_ids_vector.size() == 1); | ||
|
||
tt::runtime::Device device = tt::runtime::openDevice(device_ids_vector); | ||
|
||
std::vector<tt::runtime::Tensor> rt_outputs = | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const not needed before
size_t