From 377534072bc1d13af179ae0d2965d21e34a242a8 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 9 Oct 2023 09:34:30 -0700
Subject: [PATCH] Split `Precision.init_context` (#18734)

---
 src/lightning/fabric/CHANGELOG.md | 2 +-
 .../fabric/plugins/precision/bitsandbytes.py | 7 +++--
 .../fabric/plugins/precision/deepspeed.py | 5 +++-
 .../fabric/plugins/precision/double.py | 7 +++--
 .../fabric/plugins/precision/fsdp.py | 7 +++--
 .../fabric/plugins/precision/half.py | 7 +++--
 .../fabric/plugins/precision/precision.py | 6 ++++-
 .../plugins/precision/transformer_engine.py | 7 +++--
 src/lightning/fabric/strategies/fsdp.py | 5 ++--
 src/lightning/fabric/strategies/strategy.py | 2 +-
 src/lightning/fabric/strategies/xla_fsdp.py | 2 +-
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 .../pytorch/plugins/precision/deepspeed.py | 23 +++++++---------
 .../pytorch/plugins/precision/double.py | 26 +++++--------------
 .../pytorch/plugins/precision/fsdp.py | 5 +++-
 .../pytorch/plugins/precision/half.py | 20 +++++---------
 src/lightning/pytorch/strategies/fsdp.py | 2 +-
 src/lightning/pytorch/strategies/strategy.py | 2 +-
 src/lightning/pytorch/trainer/call.py | 2 +-
 .../plugins/precision/test_all.py | 4 +--
 .../plugins/precision/test_bitsandbytes.py | 4 +--
 .../plugins/precision/test_deepspeed.py | 4 +--
 .../plugins/precision/test_half.py | 4 +--
 .../precision/test_transformer_engine.py | 6 ++---
 .../plugins/precision/test_all.py | 4 +--
 .../precision/test_deepspeed_precision.py | 2 +-
 .../plugins/precision/test_double.py | 4 +--
 .../plugins/precision/test_half.py | 4 +--
 28 files changed, 88 insertions(+), 87 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index d47a281668a9b..d539e251d2ba5 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122)) -- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462)) +- Added `lightning.fabric.plugins.Precision.module_init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462)) - `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently directly on device and dtype ([#17607](https://github.com/Lightning-AI/lightning/pull/17607)) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 85801efb4e6dc..a1e1d6bc20b6a 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -116,7 +116,10 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: m.compute_type_is_set = False return module - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self.dtype) + + def module_init_context(self) -> ContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants to skip some submodules raise RuntimeError( @@ -125,7 +128,7 @@ def init_context(self) -> ContextManager: " may initialize the layers on-device, defeating the purpose of quantization. You can remove" " `ignore_modules` or remove the `init_module` context manager." 
) - dtype_ctx = _DtypeContextManager(self.dtype) + dtype_ctx = self.tensor_init_context() # TODO: this could also support replacing `Embedding` and `Conv1D` context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls}) stack = ExitStack() diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 32cde97f235c0..7f5e95bc15976 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -66,11 +66,14 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_dtype) return module - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() + def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 03395626ea991..3e38ccce67be1 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -30,11 +30,14 @@ class DoublePrecision(Precision): def convert_module(self, module: Module) -> Module: return module.double() - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(torch.double) + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() + def forward_context(self) -> ContextManager: - return _DtypeContextManager(torch.double) + return self.tensor_init_context() def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double) diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 3358c20b40364..054aa23c64314 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -103,13 +103,16 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": buffer_dtype=buffer_dtype, ) - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + def module_init_context(self) -> ContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) def forward_context(self) -> ContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) - return _DtypeContextManager(self._desired_input_dtype) + return self.tensor_init_context() def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py index 2244fa9d1cc5a..77d02d0c000c2 100644 --- a/src/lightning/fabric/plugins/precision/half.py +++ b/src/lightning/fabric/plugins/precision/half.py @@ -39,11 +39,14 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) - def 
init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self._desired_input_dtype) + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() + def forward_context(self) -> ContextManager: - return _DtypeContextManager(self._desired_input_dtype) + return self.tensor_init_context() def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index 5278b59b7f966..fbff54f8e3595 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -53,7 +53,11 @@ def convert_module(self, module: Module) -> Module: """ return module - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: + """Controls how tensors get created (device, dtype).""" + return nullcontext() + + def module_init_context(self) -> ContextManager: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index a9486de0b92c4..1aa4f66a27e8a 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -93,8 +93,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: module = module.to(dtype=self.dtype) return module - def init_context(self) -> ContextManager: - dtype_ctx = _DtypeContextManager(self.dtype) + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self.dtype) + + def module_init_context(self) -> ContextManager: + dtype_ctx = self.tensor_init_context() stack = ExitStack() if self.replace_layers: import transformer_engine.pytorch as te diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 51226250c2ba1..d2438678ecb59 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -329,8 +329,9 @@ def module_to_device(self, module: Module) -> None: pass def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: - precision_init_ctx = self.precision.init_context() + precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() + empty_ctx = _EmptyInit(enabled=bool(empty_init)) stack = ExitStack() if _TORCH_GREATER_EQUAL_2_1 and empty_init: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: @@ -338,7 +339,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag # These operations are applied to each submodule 'bottom up' in the module hierarchy. 
stack.enter_context(torch.device("meta")) elif _TORCH_GREATER_EQUAL_1_13: - stack.enter_context(_EmptyInit(enabled=bool(empty_init))) + stack.enter_context(empty_ctx) stack.enter_context(precision_init_ctx) stack.enter_context(module_sharded_ctx) return stack diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 9747835690321..8f72683976189 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -120,7 +120,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: def tensor_init_context(self) -> ContextManager: """Controls how tensors get created (device, dtype).""" - precision_init_ctx = self.precision.init_context() + precision_init_ctx = self.precision.tensor_init_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_2_0: stack.enter_context(self.root_device) diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 30ebd5c1b5c5a..348fe1894c488 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -194,7 +194,7 @@ def module_to_device(self, module: Module) -> None: pass def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: - precision_init_ctx = self.precision.init_context() + precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_1_13: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b2651d1e772ab..79bd3438f95d0 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -83,7 +83,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for meta-device initialization with `Trainer.init_module(empty_init=True)` in FSDP ([#18385](https://github.com/Lightning-AI/lightning/pull/18385)) -- Added `lightning.pytorch.plugins.PrecisionPlugin.init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004)) +- Added `lightning.pytorch.plugins.PrecisionPlugin.module_init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004)) - Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882)) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 288086b390023..7fa409f1d6b5f 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -23,7 +23,7 @@ import lightning.pytorch as pl from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.types import Steppable from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin from lightning.pytorch.utilities import GradClipAlgorithmType @@ -77,18 +77,13 @@ def convert_module(self, module: Module) -> Module: def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) - @contextmanager - def init_context(self) -> Generator[None, None, None]: + def tensor_init_context(self) -> ContextManager: if "true" not in self.precision: - yield - return - - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(self._desired_dtype) - try: - yield - finally: - torch.set_default_dtype(default_dtype) + return nullcontext() + return _DtypeContextManager(self._desired_dtype) + + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 73a49c61b268b..72b5d0d6da2bf 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Generator, Literal +from typing import Any, ContextManager, Generator, Literal import torch import torch.nn as nn @@ -20,7 +20,7 @@ from torch import Tensor import lightning.pytorch as pl -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation @@ -34,19 +34,11 @@ class DoublePrecisionPlugin(PrecisionPlugin): def convert_module(self, module: nn.Module) -> nn.Module: return module.double() - @contextmanager - def init_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type when initializing module parameters or tensors. 
- - See: :func:`torch.set_default_dtype` + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(torch.float64) - """ - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float64) - try: - yield - finally: - torch.set_default_dtype(default_dtype) + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() @contextmanager def forward_context(self) -> Generator[None, None, None]: @@ -55,12 +47,8 @@ def forward_context(self) -> Generator[None, None, None]: See: :func:`torch.set_default_dtype` """ - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float64) - try: + with self.tensor_init_context(): yield - finally: - torch.set_default_dtype(default_dtype) def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index ff47cb0415e5a..5a124ab6b676d 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -112,7 +112,10 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": buffer_dtype=buffer_dtype, ) - def init_context(self) -> ContextManager: + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self._desired_input_dtype) + + def module_init_context(self) -> ContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) def forward_context(self) -> ContextManager: diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index 1c38869c4ecb6..a7ef8c82afe86 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Generator, Literal +from typing import Any, ContextManager, Generator, Literal import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module -from lightning.fabric.plugins.precision.utils import _convert_fp_tensor +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin @@ -40,19 +40,11 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) - @contextmanager - def init_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type when initializing module parameters or tensors. 
- - See: :func:`torch.set_default_dtype` + def tensor_init_context(self) -> ContextManager: + return _DtypeContextManager(self._desired_input_dtype) - """ - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(self._desired_input_dtype) - try: - yield - finally: - torch.set_default_dtype(default_dtype) + def module_init_context(self) -> ContextManager: + return self.tensor_init_context() @contextmanager def forward_context(self) -> Generator[None, None, None]: diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index c40f36df970d7..88404a6184014 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -360,7 +360,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No empty_init_context = _EmptyInit(enabled=bool(empty_init)) else: empty_init_context = nullcontext() - with empty_init_context, self.precision_plugin.init_context(): + with empty_init_context, self.precision_plugin.tensor_init_context(): yield @contextmanager diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index fd6b011633e0d..3b6467d5bd886 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -502,7 +502,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No """ device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext() empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext() - with empty_init_context, device_context, self.precision_plugin.init_context(): + with empty_init_context, device_context, self.precision_plugin.tensor_init_context(): yield @contextmanager diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 2eab1bac09c0f..b9c270b620c1e 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -105,7 +105,7 @@ def _call_configure_model(trainer: "pl.Trainer") -> None: # we don't normally check for this before calling the hook. 
it is done here to avoid instantiating the context # managers if is_overridden("configure_model", trainer.lightning_module): - with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(): + with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501 _call_lightning_module_hook(trainer, "configure_model") diff --git a/tests/tests_fabric/plugins/precision/test_all.py b/tests/tests_fabric/plugins/precision/test_all.py index 136e49214c57f..5e86a35647489 100644 --- a/tests/tests_fabric/plugins/precision/test_all.py +++ b/tests/tests_fabric/plugins/precision/test_all.py @@ -17,9 +17,9 @@ def test_default_dtype_is_restored(precision): precision = FSDPPrecision("16-true") contexts = ( - (precision.init_context, precision.forward_context) + (precision.module_init_context, precision.forward_context) if not isinstance(precision, DeepSpeedPrecision) - else (precision.init_context,) + else (precision.module_init_context,) ) for context in contexts: assert torch.get_default_dtype() is torch.float32 diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index 82683b7d1e59a..e4d8763da35f1 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -46,7 +46,7 @@ def __init__(self, in_features, out_features, bias=True, *_, **__): # same logic as in `test_default_dtype_is_restored` assert torch.get_default_dtype() is torch.float32 - with pytest.raises(RuntimeError, match="foo"), precision.init_context(): + with pytest.raises(RuntimeError, match="foo"), precision.module_init_context(): assert torch.get_default_dtype() is not torch.float32 raise RuntimeError("foo") assert torch.get_default_dtype() is torch.float32 @@ -65,7 +65,7 @@ def __init__(self): _NF4Linear = vars(module)["_NF4Linear"] _NF4Linear._quantize_weight = Mock() - with precision.init_context(): + with precision.module_init_context(): assert torch.get_default_dtype() == torch.float16 model = MyModule() assert isinstance(model.l1, _NF4Linear) diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed.py b/tests/tests_fabric/plugins/precision/test_deepspeed.py index 4279c14bc226f..248f616646842 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed.py @@ -79,9 +79,9 @@ def test_selected_dtype(precision, expected_dtype): ("16-true", torch.float16), ], ) -def test_init_context(precision, expected_dtype): +def test_module_init_context(precision, expected_dtype): plugin = DeepSpeedPrecision(precision=precision) - with plugin.init_context(): + with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == expected_dtype assert model.weight.dtype == expected_dtype diff --git a/tests/tests_fabric/plugins/precision/test_half.py b/tests/tests_fabric/plugins/precision/test_half.py index ed92766706aa4..4037feebbd178 100644 --- a/tests/tests_fabric/plugins/precision/test_half.py +++ b/tests/tests_fabric/plugins/precision/test_half.py @@ -36,9 +36,9 @@ def test_selected_dtype(precision, expected_dtype): ("16-true", torch.half), ], ) -def test_init_context(precision, expected_dtype): +def test_module_init_context(precision, expected_dtype): plugin = HalfPrecision(precision=precision) - with plugin.init_context(): + with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert 
torch.get_default_dtype() == expected_dtype assert model.weight.dtype == expected_dtype diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index b721f71a3b568..b44df5233b453 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -56,7 +56,7 @@ def test_transformer_engine_plugin(monkeypatch): # same logic as in `test_default_dtype_is_restored` assert torch.get_default_dtype() is torch.float32 - with pytest.raises(RuntimeError, match="foo"), precision.init_context(): + with pytest.raises(RuntimeError, match="foo"), precision.module_init_context(): assert torch.get_default_dtype() is not torch.float32 raise RuntimeError("foo") assert torch.get_default_dtype() is torch.float32 @@ -95,7 +95,7 @@ def __init__(self): assert mock_calls[1][1][1]._extract_mock_name() == "mock.pytorch.LayerNorm()" precision.replace_layers = False - with precision.init_context(): + with precision.module_init_context(): model = MyModule() assert isinstance(model.l1, torch.nn.Linear) assert isinstance(model.l2, torch.nn.LayerNorm) @@ -110,7 +110,7 @@ class TELayerNormMock(Mock): transformer_engine_mock.pytorch.Linear = TELinearMock transformer_engine_mock.pytorch.LayerNorm = TELayerNormMock precision.replace_layers = True - with precision.init_context(): + with precision.module_init_context(): assert torch.get_default_dtype() == torch.float16 model = MyModule() assert isinstance(model.l1, TELinearMock) diff --git a/tests/tests_pytorch/plugins/precision/test_all.py b/tests/tests_pytorch/plugins/precision/test_all.py index 3c380544dbafe..8b58ae9b0eebd 100644 --- a/tests/tests_pytorch/plugins/precision/test_all.py +++ b/tests/tests_pytorch/plugins/precision/test_all.py @@ -22,9 +22,9 @@ def test_default_dtype_is_restored(precision): precision = FSDPPrecisionPlugin("16-true") contexts = ( - (precision.init_context, precision.forward_context) + (precision.module_init_context, precision.forward_context) if not isinstance(precision, DeepSpeedPrecisionPlugin) - else (precision.init_context,) + else (precision.module_init_context,) ) for context in contexts: assert torch.get_default_dtype() is torch.float32 diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 9598d11813587..8c4d9a6b198e9 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -50,7 +50,7 @@ def test_selected_dtype(precision, expected_dtype): ) def test_module_init_context(precision, expected_dtype): plugin = DeepSpeedPrecisionPlugin(precision=precision) - with plugin.init_context(): + with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == expected_dtype assert model.weight.dtype == expected_dtype diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py index 8c0f7717f209c..f295bd5601f45 100644 --- a/tests/tests_pytorch/plugins/precision/test_double.py +++ b/tests/tests_pytorch/plugins/precision/test_double.py @@ -184,9 +184,9 @@ def test_convert_module(): assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64 -def test_init_context(): +def test_module_init_context(): plugin = DoublePrecisionPlugin() - with plugin.init_context(): + with 
plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == torch.double assert model.weight.dtype == torch.double diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index 26bbc3e524481..89a7cddf13b2c 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -37,9 +37,9 @@ def test_selected_dtype(precision, expected_dtype): ("16-true", torch.half), ], ) -def test_init_context(precision, expected_dtype): +def test_module_init_context(precision, expected_dtype): plugin = HalfPrecisionPlugin(precision=precision) - with plugin.init_context(): + with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == expected_dtype assert model.weight.dtype == expected_dtype
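
Usage sketch (not part of the patch, mirroring the patch's own tests): after this change, precision plugins expose two context managers instead of one. `tensor_init_context()` only swaps the default floating-point dtype, while `module_init_context()` may do strictly more: `BitsandbytesPrecision` and `TransformerEnginePrecision` additionally patch layer classes there, and the FSDP strategy may enter an empty-init or meta-device context. The snippet below assumes a Lightning installation with this change applied (2.1+):

    import torch
    from lightning.fabric.plugins.precision.half import HalfPrecision

    precision = HalfPrecision("bf16-true")

    # Tensor creation: only the default floating-point dtype is swapped,
    # and it is restored when the context exits.
    with precision.tensor_init_context():
        t = torch.zeros(4)  # created as bfloat16
    assert t.dtype == torch.bfloat16
    assert torch.get_default_dtype() is torch.float32

    # Module creation: for HalfPrecision this simply delegates to
    # tensor_init_context(), but plugins such as BitsandbytesPrecision also
    # replace torch.nn.Linear here -- a side effect that would be wrong for
    # plain tensor creation, which is why the old init_context() was split.
    with precision.module_init_context():
        layer = torch.nn.Linear(2, 2)
    assert layer.weight.dtype == torch.bfloat16

Splitting the hook lets `Strategy.tensor_init_context()` pick up only the dtype/device behavior, avoiding class-replacement side effects that are safe only when whole modules are being instantiated.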