Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifications to "get_working_device" function in Pytorch. Added the capability to set a Pytorch device using the "set_device" function #891

Merged
merged 5 commits into from
Dec 31, 2023

Conversation

lior-dikstein
Copy link
Collaborator

@lior-dikstein lior-dikstein commented Dec 25, 2023

Pull Request Description:

Refine PyTorch device management using DeviceManager singleton class.

  • Adopted the Singleton pattern in DeviceManager for consistent and centralized management of PyTorch devices.
  • Modified the get_working_device function in PyTorch for improved retrieval of the current device setting using the singleton DeviceManager.
  • Implemented the set_device function using DeviceManager, enabling the capability to set the PyTorch device across various project parts dynamically.
  • Incorporated flexibility in DeviceManager to handle specific CUDA device indices ('cuda:x') and the default CUDA device ('cuda').

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All function and files are well documented.
  • All function and classes have type hints.
  • There is a licenses in all file.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unittest (if necessary).

@@ -25,4 +25,4 @@
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data, keras_kpi_data_experimental
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data, pytorch_kpi_data_experimental
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unneeded changes

Args:
device_name (str): The name of the device, e.g., 'cuda:0' or 'cpu'.

If the specified device is not valid or available, it prints an error message without changing the current device.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should alert or raise an error? If just a warning - fix the comment.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a warning. Changed the comment.

else:
logger.Logger.warning(message)

def get_device(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add type hint

return self.DEVICE

@staticmethod
def is_valid_device(device_name: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

returned type is missing


return True, "Valid device"

return True, "Valid device"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if it's 'cpu'? I think it should be checked here before True is returned

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't a check for CPU because it is always a valid choice. This checks that if you choose a GPU it exists.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to verify that it's 'cpu' and not unrelated string, but it's up to you

device_manager = DeviceManager()
device_manager.set_device(device_name)

def get_working_device():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type hints

liord added 4 commits December 31, 2023 12:07
…capability to set a Pytorch device using the "set_device" function. Implemented using DeviceManager with Singleton Pattern.
…capability to set a Pytorch device using the "set_device" function. Implemented using DeviceManager with Singleton Pattern.
@lior-dikstein lior-dikstein merged commit 7db1ae7 into main Dec 31, 2023
22 of 24 checks passed
@lior-dikstein lior-dikstein deleted the pytorch_device branch December 31, 2023 13:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Does not work on system with multiple CUDA devices
2 participants