From 009a3e71ece11b2ec20629883f7d66212db52c7d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Mar 2023 01:43:39 -0700 Subject: [PATCH] [Training] Fix lightning _PATH import --- training/src/utils/ddp_zero1.py | 9 +++++++-- training/src/utils/ddp_zero2.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/training/src/utils/ddp_zero1.py b/training/src/utils/ddp_zero1.py index da07c7bfc..479725e35 100644 --- a/training/src/utils/ddp_zero1.py +++ b/training/src/utils/ddp_zero1.py @@ -9,8 +9,13 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.types import _PATH -# from lightning_lite.utilities.types import _PATH +try: # pytorch_lightning <= 1.7 + from pytorch_lightning.utilities.types import _PATH +except ImportError: # pytorch_lightning >= 1.8 + try: + from lightning_lite.utilities.types import _PATH + except ImportError: # pytorch_lightning >= 1.9 + from lightning_fabric.utilities.types import _PATH # Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get diff --git a/training/src/utils/ddp_zero2.py b/training/src/utils/ddp_zero2.py index e526abfc3..f6003380c 100644 --- a/training/src/utils/ddp_zero2.py +++ b/training/src/utils/ddp_zero2.py @@ -13,9 +13,14 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.types import _PATH -# from lightning_lite.utilities.types import _PATH from pytorch_lightning.utilities.exceptions import MisconfigurationException +try: # pytorch_lightning <= 1.7 + from pytorch_lightning.utilities.types import _PATH +except ImportError: # pytorch_lightning >= 1.8 + try: + from lightning_lite.utilities.types import _PATH + except ImportError: # pytorch_lightning >= 1.9 + from lightning_fabric.utilities.types import _PATH class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):