Skip to content

Commit

Permalink
[TPU] Do not force stdout with PJRT (#18765)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Oct 10, 2023
1 parent 2f21670 commit 929eea1
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/lightning/pytorch/strategies/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def set_world_ranks(self) -> None:
pass

def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
self._pod_progress_bar_force_stdout()
_pod_progress_bar_force_stdout(self.global_rank)

def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
Expand Down Expand Up @@ -315,13 +315,18 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
description=cls.__name__,
)

def _pod_progress_bar_force_stdout(self) -> None:
    """Emit an empty ``print()`` so tqdm flushes stdout on a TPU pod.

    The way ``pytorch_xla.distributed`` streams logs from the different VMs
    to the main worker does not play well with tqdm's in-place updates;
    printing an empty line forces tqdm to flush stdout.
    Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
    """
    # Imported lazily: torch_xla is only available on TPU environments.
    import torch_xla.core.xla_env_vars as xenv
    from torch_xla.utils.utils import getenv_as

    # Only the main worker drives the progress bar; TPUVM_MODE gates pod mode.
    # Keep the rank check first so the env var is not read on other ranks.
    if self.global_rank == 0 and getenv_as(xenv.TPUVM_MODE, int, 0) == 1:
        print()
def _pod_progress_bar_force_stdout(global_rank: int) -> None:
    """Emit an empty ``print()`` so tqdm flushes stdout on a TPU pod.

    Only acts on the main worker (``global_rank == 0``) when running in
    TPUVM pod mode under the legacy XRT runtime; under PJRT this is a no-op.
    """
    # PJRT no longer streams logs through xla_dist, so nothing to work around.
    # The mechanism was removed in https://github.com/pytorch/xla/pull/5240
    if _using_pjrt():
        return

    # Imported lazily: torch_xla is only available on TPU environments.
    import torch_xla.core.xla_env_vars as xenv
    from torch_xla.utils.utils import getenv_as

    # The way `pytorch_xla.distributed` streams logs from the different VMs
    # to the main worker does not play well with tqdm; an empty print forces
    # tqdm to flush stdout.
    # Ref: https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_dist.py#L227
    # Keep the rank check first so the env var is not read on other ranks.
    if global_rank == 0 and getenv_as(xenv.TPUVM_MODE, int, 0) == 1:
        print()

0 comments on commit 929eea1

Please sign in to comment.