From a6ee7c2044e6ddf7d19ae3ad663149e51d6f89e7 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Mon, 23 Dec 2024 19:22:27 +0000
Subject: [PATCH] Move `parallel_state.py` to the `distributed` folder

---
 optimum/habana/accelerate/accelerator.py                   | 2 +-
 optimum/habana/accelerate/data_loader.py                   | 2 +-
 optimum/habana/accelerate/state.py                         | 2 +-
 optimum/habana/distributed/contextparallel.py              | 2 +-
 optimum/habana/{ => distributed}/parallel_state.py         | 0
 optimum/habana/transformers/models/llama/modeling_llama.py | 3 ++-
 6 files changed, 6 insertions(+), 5 deletions(-)
 rename optimum/habana/{ => distributed}/parallel_state.py (100%)

diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py
index e5fb539a5b..8926866cc3 100644
--- a/optimum/habana/accelerate/accelerator.py
+++ b/optimum/habana/accelerate/accelerator.py
@@ -59,7 +59,7 @@
 from accelerate.utils.other import is_compiled_module
 from torch.optim.lr_scheduler import LRScheduler
 
-from .. import parallel_state
+from ..distributed import parallel_state
 
 
 if is_deepspeed_available():
diff --git a/optimum/habana/accelerate/data_loader.py b/optimum/habana/accelerate/data_loader.py
index afe8fc1cd8..bab31edb0e 100644
--- a/optimum/habana/accelerate/data_loader.py
+++ b/optimum/habana/accelerate/data_loader.py
@@ -22,7 +22,7 @@
 )
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
 
-from .. import parallel_state
+from ..distributed import parallel_state
 from .state import GaudiAcceleratorState
 from .utils.operations import (
     broadcast,
diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py
index b9c4e794f7..6e507acc2c 100644
--- a/optimum/habana/accelerate/state.py
+++ b/optimum/habana/accelerate/state.py
@@ -21,7 +21,7 @@
 
 from optimum.utils import logging
 
-from .. import parallel_state
+from ..distributed import parallel_state
 from .utils import GaudiDistributedType
 
 
diff --git a/optimum/habana/distributed/contextparallel.py b/optimum/habana/distributed/contextparallel.py
index 0b48465542..2020b6a84e 100644
--- a/optimum/habana/distributed/contextparallel.py
+++ b/optimum/habana/distributed/contextparallel.py
@@ -1,6 +1,6 @@
 import torch
 
-from ..parallel_state import (
+from .parallel_state import (
     get_sequence_parallel_group,
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
diff --git a/optimum/habana/parallel_state.py b/optimum/habana/distributed/parallel_state.py
similarity index 100%
rename from optimum/habana/parallel_state.py
rename to optimum/habana/distributed/parallel_state.py
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 55d4475a87..6cd3de0a72 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -21,7 +21,8 @@
 )
 from transformers.utils import is_torchdynamo_compiling
 
-from .... import distributed, parallel_state
+from .... import distributed
+from ....distributed import parallel_state
 from ....distributed.strategy import DistributedStrategy, NoOpStrategy
 from ....distributed.tensorparallel import (
     reduce_from_tensor_model_parallel_region,
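
Note on downstream usage (illustration only, not part of the patch): after this
change, external callers import `parallel_state` from the `distributed` package
rather than from the `optimum.habana` root. A minimal sketch, assuming
`optimum-habana` with this patch applied and an already-initialized
sequence-parallel group; it uses only the accessors that `contextparallel.py`
imports in the hunks above:

    from optimum.habana.distributed import parallel_state

    # Query the sequence-parallel topology managed by parallel_state.
    group = parallel_state.get_sequence_parallel_group()
    rank = parallel_state.get_sequence_parallel_rank()
    world_size = parallel_state.get_sequence_parallel_world_size()
    print(f"sequence-parallel rank {rank} of {world_size}")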