diff --git a/training/src/utils/ddp_zero1.py b/training/src/utils/ddp_zero1.py index da07c7bfc..479725e35 100644 --- a/training/src/utils/ddp_zero1.py +++ b/training/src/utils/ddp_zero1.py @@ -9,8 +9,13 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.types import _PATH -# from lightning_lite.utilities.types import _PATH +try: # pytorch_lightning <= 1.7 + from pytorch_lightning.utilities.types import _PATH +except ImportError: # pytorch_lightning >= 1.8 + try: + from lightning_lite.utilities.types import _PATH + except ImportError: # pytorch_lightning >= 1.9 + from lightning_fabric.utilities.types import _PATH # Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get diff --git a/training/src/utils/ddp_zero2.py b/training/src/utils/ddp_zero2.py index e526abfc3..f6003380c 100644 --- a/training/src/utils/ddp_zero2.py +++ b/training/src/utils/ddp_zero2.py @@ -13,9 +13,14 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.types import _PATH -# from lightning_lite.utilities.types import _PATH from pytorch_lightning.utilities.exceptions import MisconfigurationException +try: # pytorch_lightning <= 1.7 + from pytorch_lightning.utilities.types import _PATH +except ImportError: # pytorch_lightning >= 1.8 + try: + from lightning_lite.utilities.types import _PATH + except ImportError: # pytorch_lightning >= 1.9 + from lightning_fabric.utilities.types import _PATH class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):