Storing system descriptor #265
@@ -28,6 +28,7 @@ ClientInstance::ClientInstance(std::unique_ptr<Platform> platform)

ClientInstance::~ClientInstance() {
  DLOG_F(LOG_DEBUG, "ClientInstance::~ClientInstance");
  std::remove(ModuleBuilder::system_desc_path.data());
Why is the ClientInstance destructor doing this? Why even do this in the first place? I don't think I have ever seen a string being cleaned up.
}

PJRT_Error *ClientInstance::Initialize() {

@@ -164,6 +165,7 @@ void ClientInstance::BindApi(PJRT_Api *api) {
tt_pjrt_status ClientInstance::PopulateDevices() {
  DLOG_F(LOG_DEBUG, "ClientInstance::PopulateDevices");
  auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
  system_desc.store(ModuleBuilder::system_desc_path.data());
  int devices_count = chip_ids.size();

  devices_.resize(devices_count);
Add comment above.
Any particular reason why this is not just a plain string?
Yes, this can be a plain string; there is no benefit to keeping it as a `string_view`. See more info here regarding string vs string_view (`std::string_view` implements the same thing as Abseil's).
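One concrete reason to prefer a plain string here: the diff passes `system_desc_path.data()` to `std::remove`, which expects a null-terminated C string, while `std::string_view::data()` makes no such guarantee. A minimal sketch of the pitfall follows; `kFullPath`, `kDirView`, and both helper functions are hypothetical names used only for illustration, not code from this PR.

```cpp
#include <cstddef>
#include <cstring>
#include <string>
#include <string_view>

// Hypothetical constant standing in for a path held as a string_view.
inline constexpr std::string_view kFullPath = "/tmp/system_desc.bin";

// A view over only the first four characters ("/tmp") of kFullPath.
inline constexpr std::string_view kDirView = kFullPath.substr(0, 4);

// Length a C API would see when handed view.data() directly. The view's
// size is ignored; the C API reads until it finds a '\0'. This is only
// well-defined here because the underlying buffer is a string literal.
inline std::size_t lengthSeenByCApi(std::string_view view) {
  return std::strlen(view.data());
}

// Copying into a std::string first yields a buffer that is terminated
// exactly at the end of the view.
inline std::size_t lengthSeenViaString(std::string_view view) {
  return std::strlen(std::string(view).c_str());
}
```

With a view that covers a whole string literal, `.data()` happens to work, which is probably why the current code behaves correctly. A `std::string` with `.c_str()` removes the reliance on that coincidence.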
Ideally, we would store the `tt::runtime::SystemDesc` object inside the `ClientInstance` and pass it through the `Compile` function to the `ModuleBuilder`, which would then pass it to `createTTIRToTTNNBackendPipeline`. Since tt-mlir currently only supports passing a path to the descriptor, and is not easy to change to support passing an already parsed object, as a temporary solution we need to save it on disk and pass the path to the compiler.

I propose these changes:

- Add `tt::runtime::SystemDesc m_system_descriptor` in `ClientInstance`, which will be set in the `ClientInstance::PopulateDevices()` function from the `system_desc` variable.
- Add `std::string m_cached_system_descriptor_path`, which should be set inside the `ClientInstance` constructor to the combination of the `std::filesystem::temp_directory_path()` directory and some file name that should be unique from other programs, for example `tt_pjrt_system_descriptor` plus maybe the name of the device architecture, and maybe even some client id if there is one in the pjrt structures.
- In `ClientInstance::PopulateDevices()`, initialize `m_system_descriptor` with `system_desc`, and then store it into `m_cached_system_descriptor_path`, with a TODO comment to remove that once tt-mlir supports passing the system descriptor object. Check whether the `store` method checks for errors; if not, we should check.
- In the `ClientInstance::Compile` function, pass `m_cached_system_descriptor_path` to the `module_builder_->buildModule` call.
- In `ClientInstance::~ClientInstance()`, remove the cached system descriptor, as you already do.
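The proposed lifecycle (build the path in the constructor, store in `PopulateDevices()`, remove in the destructor) could be sketched roughly as below. This is only an illustration under assumptions: `FakeSystemDesc` is a hypothetical stand-in for `tt::runtime::SystemDesc`, whose real `store` signature and error behavior are not shown in this thread, and `ClientInstanceSketch`, `cacheSystemDescriptor`, and the `unique_suffix` parameter are invented names.

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>
#include <utility>

// Hypothetical stand-in for tt::runtime::SystemDesc. The real store()
// may report errors differently (or not at all).
struct FakeSystemDesc {
  std::string payload;

  // Writes the descriptor to `path`; returns false on I/O failure.
  bool store(const std::string &path) const {
    std::ofstream out(path, std::ios::binary | std::ios::trunc);
    out << payload;
    return static_cast<bool>(out.flush());
  }
};

// Hypothetical sketch of the members proposed for ClientInstance.
class ClientInstanceSketch {
public:
  // `unique_suffix` is an assumption: e.g. device architecture or client id.
  explicit ClientInstanceSketch(const std::string &unique_suffix)
      : m_cached_system_descriptor_path(
            (std::filesystem::temp_directory_path() /
             ("tt_pjrt_system_descriptor_" + unique_suffix))
                .string()) {}

  ~ClientInstanceSketch() {
    // Non-throwing overload, since destructors should not throw.
    std::error_code ec;
    std::filesystem::remove(m_cached_system_descriptor_path, ec);
  }

  ClientInstanceSketch(const ClientInstanceSketch &) = delete;
  ClientInstanceSketch &operator=(const ClientInstanceSketch &) = delete;

  // Mirrors the step proposed for PopulateDevices(): keep the descriptor
  // and cache it on disk, reporting whether the write succeeded.
  bool cacheSystemDescriptor(FakeSystemDesc system_desc) {
    m_system_descriptor = std::move(system_desc);
    // TODO: remove once tt-mlir accepts a parsed system descriptor object.
    return m_system_descriptor.store(m_cached_system_descriptor_path);
  }

  // The path that Compile would forward to the module builder.
  const std::string &cachedSystemDescriptorPath() const {
    return m_cached_system_descriptor_path;
  }

private:
  FakeSystemDesc m_system_descriptor;
  std::string m_cached_system_descriptor_path;
};
```

Owning the path as a per-instance `std::string` keeps creation and cleanup of the file in the same object, so it is clearer why the destructor removes it.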