-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modifications to "get_working_device" function in Pytorch. Added the capability to set a Pytorch device using the "set_device" function #891
Conversation
@@ -25,4 +25,4 @@ | |||
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI | |||
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2 | |||
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental | |||
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental | |||
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unneeded changes
Args: | ||
device_name (str): The name of the device, e.g., 'cuda:0' or 'cpu'. | ||
|
||
If the specified device is not valid or available, it prints an error message without changing the current device. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we should alert or raise an error? If just a warning - fix the comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a warning. Changed the comment.
else: | ||
logger.Logger.warning(message) | ||
|
||
def get_device(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type hint
return self.DEVICE | ||
|
||
@staticmethod | ||
def is_valid_device(device_name: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
returned type is missing
|
||
return True, "Valid device" | ||
|
||
return True, "Valid device" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if it's 'cpu'? I think it should be checked here before True is returned
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't a check for CPU because it is always a valid choice. This checks that if you choose a GPU it exists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better to verify that it's 'cpu' and not unrelated string, but it's up to you
device_manager = DeviceManager() | ||
device_manager.set_device(device_name) | ||
|
||
def get_working_device(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type hints
…capability to set a Pytorch device using the "set_device" function. Implemented using DeviceManager with Singleton Pattern.
…capability to set a Pytorch device using the "set_device" function. Implemented using DeviceManager with Singleton Pattern.
9c73d22
to
7c848d3
Compare
Pull Request Description:
Refine PyTorch device management using DeviceManager singleton class.
DeviceManager
for consistent and centralized management of PyTorch devices.get_working_device
function in PyTorch for improved retrieval of the current device setting using the singletonDeviceManager
.set_device
function usingDeviceManager
, enabling the capability to set the PyTorch device across various project parts dynamically.DeviceManager
to handle specific CUDA device indices ('cuda:x') and the default CUDA device ('cuda').Checklist before requesting a review: