diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index fa8b59513d804..7017c67003ca3 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -251,7 +251,7 @@ def set_world_ranks(self) -> None:
         pass
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
-        self._pod_progress_bar_force_stdout()
+        _pod_progress_bar_force_stdout(self.global_rank)
 
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
@@ -315,13 +315,18 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
             description=cls.__name__,
         )
 
-    def _pod_progress_bar_force_stdout(self) -> None:
-        # Why is it required? The way `pytorch_xla.distributed` streams logs
-        # from different vms to the main worker doesn't work well with tqdm
-        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
-        # The print statement seems to force tqdm to flush stdout.
-        import torch_xla.core.xla_env_vars as xenv
-        from torch_xla.utils.utils import getenv_as
-
-        if self.global_rank == 0 and getenv_as(xenv.TPUVM_MODE, int, 0) == 1:
-            print()
+
+def _pod_progress_bar_force_stdout(global_rank: int) -> None:
+    if _using_pjrt():
+        # this was removed in https://github.com/pytorch/xla/pull/5240
+        return
+
+    # Why is it required? The way `pytorch_xla.distributed` streams logs
+    # from different vms to the main worker doesn't work well with tqdm
+    # Ref: https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_dist.py#L227
+    # The print statement seems to force tqdm to flush stdout.
+    import torch_xla.core.xla_env_vars as xenv
+    from torch_xla.utils.utils import getenv_as
+
+    if global_rank == 0 and getenv_as(xenv.TPUVM_MODE, int, 0) == 1:
+        print()
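
Because the helper is now a module-level function that only takes `global_rank`, it can be exercised without instantiating `XLAStrategy`. The sketch below is not part of the diff and makes a few assumptions: `torch_xla` is importable in the test environment, `_using_pjrt` can be monkeypatched as an attribute of `lightning.pytorch.strategies.xla` (it is referenced there as a bare name), and the environment variable behind `xenv.TPUVM_MODE` is named `TPUVM_MODE`.

```python
# Minimal test sketch for the refactored helper (assumptions noted above).
from lightning.pytorch.strategies import xla as xla_strategy


def test_pod_progress_bar_force_stdout(monkeypatch, capsys):
    # Pretend the legacy XRT runtime is active so the helper does not early-return.
    # Assumption: `_using_pjrt` is resolvable as a module attribute here.
    monkeypatch.setattr(xla_strategy, "_using_pjrt", lambda: False)
    # Assumption: xenv.TPUVM_MODE maps to the "TPUVM_MODE" environment variable.
    monkeypatch.setenv("TPUVM_MODE", "1")

    # Rank 0 emits the empty print() used to force tqdm to flush stdout.
    xla_strategy._pod_progress_bar_force_stdout(global_rank=0)
    assert capsys.readouterr().out == "\n"

    # Non-zero ranks stay silent.
    xla_strategy._pod_progress_bar_force_stdout(global_rank=1)
    assert capsys.readouterr().out == ""
```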