Split Precision.init_context (#18734)

awaelchli authored Oct 9, 2023
1 parent 0e04760 commit 3775340
Showing 28 changed files with 88 additions and 87 deletions.
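The theme of the diff: the old single `Precision.init_context()` hook is split into `tensor_init_context()`, which only controls the dtype (and device) that new tensors are created with, and `module_init_context()`, which applies the same rules plus any module-specific work such as quantized-layer class replacement or meta-device initialization. A minimal sketch of the resulting behavior, assuming the `HalfPrecision` import path below (the usage itself is illustrative, not part of the commit):

```python
import torch
from lightning.fabric.plugins.precision.half import HalfPrecision

precision = HalfPrecision("bf16-true")

# module_init_context(): parameters created inside follow the plugin's dtype
with precision.module_init_context():
    model = torch.nn.Linear(4, 4)
assert model.weight.dtype == torch.bfloat16

# tensor_init_context(): loose tensors follow the same dtype rules
with precision.tensor_init_context():
    scale = torch.zeros(4)
assert scale.dtype == torch.bfloat16

# the global default dtype is restored once each context exits
assert torch.get_default_dtype() is torch.float32
```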
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
```diff
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))
 
 
-- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
+- Added `lightning.fabric.plugins.Precision.module_init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
 
 
 - `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently directly on device and dtype ([#17607](https://github.com/Lightning-AI/lightning/pull/17607))
```
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
```diff
@@ -116,7 +116,10 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
                 m.compute_type_is_set = False
         return module
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self.dtype)
+
+    def module_init_context(self) -> ContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
             raise RuntimeError(
@@ -125,7 +128,7 @@ def init_context(self) -> ContextManager:
                 " may initialize the layers on-device, defeating the purpose of quantization. You can remove"
                 " `ignore_modules` or remove the `init_module` context manager."
             )
-        dtype_ctx = _DtypeContextManager(self.dtype)
+        dtype_ctx = self.tensor_init_context()
         # TODO: this could also support replacing `Embedding` and `Conv1D`
         context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
         stack = ExitStack()
```
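`_ClassReplacementContextManager` is internal to Lightning; a sketch of the idea it presumably implements: temporarily rebind `torch.nn.Linear` so modules built inside the context pick up the quantized replacement class (illustrative, not the actual implementation):

```python
import torch

class ClassReplacement:
    """Swap torch.nn.Linear for a replacement class inside a `with` block."""

    def __init__(self, replacement: type) -> None:
        self._replacement = replacement

    def __enter__(self) -> None:
        self._original = torch.nn.Linear
        torch.nn.Linear = self._replacement

    def __exit__(self, *exc: object) -> None:
        # always restore the original class, even if the body raised
        torch.nn.Linear = self._original
```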
5 changes: 4 additions & 1 deletion src/lightning/fabric/plugins/precision/deepspeed.py
```diff
@@ -66,11 +66,14 @@ def convert_module(self, module: Module) -> Module:
             return module.to(dtype=self._desired_dtype)
         return module
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)
 
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
```
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/double.py
```diff
@@ -30,11 +30,14 @@ class DoublePrecision(Precision):
     def convert_module(self, module: Module) -> Module:
         return module.double()
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(torch.double)
 
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def forward_context(self) -> ContextManager:
-        return _DtypeContextManager(torch.double)
+        return self.tensor_init_context()
 
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
```
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/fsdp.py
```diff
@@ -103,13 +103,16 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self._desired_input_dtype)
+
+    def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
-        return _DtypeContextManager(self._desired_input_dtype)
+        return self.tensor_init_context()
 
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
```
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/half.py
```diff
@@ -39,11 +39,14 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)
 
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
+
     def forward_context(self) -> ContextManager:
-        return _DtypeContextManager(self._desired_input_dtype)
+        return self.tensor_init_context()
 
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
```
6 changes: 5 additions & 1 deletion src/lightning/fabric/plugins/precision/precision.py
```diff
@@ -53,7 +53,11 @@ def convert_module(self, module: Module) -> Module:
         """
         return module
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        """Controls how tensors get created (device, dtype)."""
+        return nullcontext()
+
+    def module_init_context(self) -> ContextManager:
         """Instantiate module parameters or tensors in the precision type this plugin handles.
 
         This is optional and depends on the precision limitations during optimization.
```
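With both hooks on the base class, a third-party plugin only decides which dtype each kind of allocation gets. A hypothetical subclass; only the hook names and the `_DtypeContextManager` helper come from this diff, the class itself is invented for illustration:

```python
import torch
from typing import ContextManager
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _DtypeContextManager

class Float32Precision(Precision):  # hypothetical plugin
    def tensor_init_context(self) -> ContextManager:
        return _DtypeContextManager(torch.float32)

    def module_init_context(self) -> ContextManager:
        # modules need no extra patching here, so reuse the tensor rules
        return self.tensor_init_context()
```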
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/transformer_engine.py
```diff
@@ -93,8 +93,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
             module = module.to(dtype=self.dtype)
         return module
 
-    def init_context(self) -> ContextManager:
-        dtype_ctx = _DtypeContextManager(self.dtype)
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self.dtype)
+
+    def module_init_context(self) -> ContextManager:
+        dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
         if self.replace_layers:
             import transformer_engine.pytorch as te
```
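Implementations like the one above hand back a single `contextlib.ExitStack` that bundles several context managers. A standalone demonstration of the pattern; note that `enter_context()` enters each child immediately and only defers the exits, which run LIFO:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def scope(name: str):
    print("enter", name)
    yield
    print("exit", name)

stack = ExitStack()
stack.enter_context(scope("dtype"))        # prints "enter dtype" right here
stack.enter_context(scope("layer patch"))  # prints "enter layer patch"
with stack:  # __enter__ is effectively a no-op; the children are already active
    pass     # model construction would happen in this region
# on exit: "exit layer patch" then "exit dtype"
```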
5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
```diff
@@ -329,16 +329,17 @@ def module_to_device(self, module: Module) -> None:
         pass
 
     def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
-        precision_init_ctx = self.precision.init_context()
+        precision_init_ctx = self.precision.module_init_context()
         module_sharded_ctx = self.module_sharded_context()
+        empty_ctx = _EmptyInit(enabled=bool(empty_init))
         stack = ExitStack()
         if _TORCH_GREATER_EQUAL_2_1 and empty_init:
             # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
             # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
             # These operations are applied to each submodule 'bottom up' in the module hierarchy.
             stack.enter_context(torch.device("meta"))
         elif _TORCH_GREATER_EQUAL_1_13:
-            stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
+            stack.enter_context(empty_ctx)
         stack.enter_context(precision_init_ctx)
         stack.enter_context(module_sharded_ctx)
         return stack
```
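The `empty_init` fast path relies on `torch.device("meta")` acting as a context manager in recent PyTorch: modules created under it carry shape and dtype metadata but no storage, and get materialized later (here in `setup`, during FSDP wrapping). A minimal illustration, assuming torch >= 2.0:

```python
import torch

with torch.device("meta"):
    layer = torch.nn.Linear(10_000, 10_000)  # no parameter memory allocated

print(layer.weight.device)  # device(type='meta')

# materialization: allocate real (uninitialized) storage, then init weights
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
```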
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/strategy.py
```diff
@@ -120,7 +120,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
 
     def tensor_init_context(self) -> ContextManager:
         """Controls how tensors get created (device, dtype)."""
-        precision_init_ctx = self.precision.init_context()
+        precision_init_ctx = self.precision.tensor_init_context()
         stack = ExitStack()
         if _TORCH_GREATER_EQUAL_2_0:
             stack.enter_context(self.root_device)
```
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/xla_fsdp.py
```diff
@@ -194,7 +194,7 @@ def module_to_device(self, module: Module) -> None:
         pass
 
     def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
-        precision_init_ctx = self.precision.init_context()
+        precision_init_ctx = self.precision.module_init_context()
         module_sharded_ctx = self.module_sharded_context()
         stack = ExitStack()
         if _TORCH_GREATER_EQUAL_1_13:
```
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
```diff
@@ -83,7 +83,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for meta-device initialization with `Trainer.init_module(empty_init=True)` in FSDP ([#18385](https://github.com/Lightning-AI/lightning/pull/18385))
 
 
-- Added `lightning.pytorch.plugins.PrecisionPlugin.init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))
+- Added `lightning.pytorch.plugins.PrecisionPlugin.module_init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))
 
 
 - Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
```
23 changes: 9 additions & 14 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
```diff
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -23,7 +23,7 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
 from lightning.fabric.utilities.types import Steppable
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -77,18 +77,13 @@ def convert_module(self, module: Module) -> Module:
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
 
-    @contextmanager
-    def init_context(self) -> Generator[None, None, None]:
+    def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
-            yield
-            return
-
-        default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(self._desired_dtype)
-        try:
-            yield
-        finally:
-            torch.set_default_dtype(default_dtype)
+            return nullcontext()
+        return _DtypeContextManager(self._desired_dtype)
+
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
 
     def backward(  # type: ignore[override]
         self,
```
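`_DtypeContextManager` absorbs the hand-rolled save/set/restore logic deleted above. A sketch of what it presumably does, reconstructed from that deleted code rather than from the helper's actual source:

```python
import torch

class DtypeContext:
    """Set the global default dtype on entry and restore the previous one on exit."""

    def __init__(self, dtype: torch.dtype) -> None:
        self._new_dtype = dtype

    def __enter__(self) -> None:
        self._previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self._new_dtype)

    def __exit__(self, *exc: object) -> None:
        # runs on normal exit and on exceptions alike, so the default always comes back
        torch.set_default_dtype(self._previous_dtype)
```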
26 changes: 7 additions & 19 deletions src/lightning/pytorch/plugins/precision/double.py
```diff
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Generator, Literal
+from typing import Any, ContextManager, Generator, Literal
 
 import torch
 import torch.nn as nn
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 
 import lightning.pytorch as pl
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation
@@ -34,19 +34,11 @@ class DoublePrecisionPlugin(PrecisionPlugin):
     def convert_module(self, module: nn.Module) -> nn.Module:
         return module.double()
 
-    @contextmanager
-    def init_context(self) -> Generator[None, None, None]:
-        """A context manager to change the default tensor type when initializing module parameters or tensors.
-
-        See: :func:`torch.set_default_dtype`
-
-        """
-        default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(torch.float64)
-        try:
-            yield
-        finally:
-            torch.set_default_dtype(default_dtype)
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(torch.float64)
+
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
@@ -55,12 +47,8 @@ def forward_context(self) -> Generator[None, None, None]:
         See: :func:`torch.set_default_dtype`
 
         """
-        default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(torch.float64)
-        try:
+        with self.tensor_init_context():
             yield
-        finally:
-            torch.set_default_dtype(default_dtype)
 
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
```
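`forward_context()` keeps its generator form but now delegates to `tensor_init_context()` instead of duplicating the dtype bookkeeping. The same pattern in isolation, with stand-in functions rather than Lightning code:

```python
from contextlib import contextmanager
import torch

@contextmanager
def double_scope():
    previous = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

@contextmanager
def forward_context():
    # delegate instead of repeating the save/set/restore logic
    with double_scope():
        yield

with forward_context():
    assert torch.get_default_dtype() is torch.float64
assert torch.get_default_dtype() is torch.float32
```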
5 changes: 4 additions & 1 deletion src/lightning/pytorch/plugins/precision/fsdp.py
```diff
@@ -112,7 +112,10 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )
 
-    def init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self._desired_input_dtype)
+
+    def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
     def forward_context(self) -> ContextManager:
```
20 changes: 6 additions & 14 deletions src/lightning/pytorch/plugins/precision/half.py
```diff
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Generator, Literal
+from typing import Any, ContextManager, Generator, Literal
 
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
 
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 
 
@@ -40,19 +40,11 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)
 
-    @contextmanager
-    def init_context(self) -> Generator[None, None, None]:
-        """A context manager to change the default tensor type when initializing module parameters or tensors.
-
-        See: :func:`torch.set_default_dtype`
-
-        """
-        default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(self._desired_input_dtype)
-        try:
-            yield
-        finally:
-            torch.set_default_dtype(default_dtype)
+    def tensor_init_context(self) -> ContextManager:
+        return _DtypeContextManager(self._desired_input_dtype)
+
+    def module_init_context(self) -> ContextManager:
+        return self.tensor_init_context()
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
```
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/fsdp.py
```diff
@@ -360,7 +360,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
             empty_init_context = _EmptyInit(enabled=bool(empty_init))
         else:
             empty_init_context = nullcontext()
-        with empty_init_context, self.precision_plugin.init_context():
+        with empty_init_context, self.precision_plugin.tensor_init_context():
             yield
 
     @contextmanager
```
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/strategy.py
```diff
@@ -502,7 +502,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
         """
         device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
         empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
-        with empty_init_context, device_context, self.precision_plugin.init_context():
+        with empty_init_context, device_context, self.precision_plugin.tensor_init_context():
             yield
 
     @contextmanager
```
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/call.py
```diff
@@ -105,7 +105,7 @@ def _call_configure_model(trainer: "pl.Trainer") -> None:
     # we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
     # managers
     if is_overridden("configure_model", trainer.lightning_module):
-        with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context():
+        with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context():  # noqa: E501
             _call_lightning_module_hook(trainer, "configure_model")
```
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_all.py
```diff
@@ -17,9 +17,9 @@ def test_default_dtype_is_restored(precision):
         precision = FSDPPrecision("16-true")
 
     contexts = (
-        (precision.init_context, precision.forward_context)
+        (precision.module_init_context, precision.forward_context)
         if not isinstance(precision, DeepSpeedPrecision)
-        else (precision.init_context,)
+        else (precision.module_init_context,)
     )
     for context in contexts:
         assert torch.get_default_dtype() is torch.float32
```
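The property this test guards, shown standalone: the global default dtype must come back even when the context body raises (assuming the same `HalfPrecision` import path as above):

```python
import torch
from lightning.fabric.plugins.precision.half import HalfPrecision

precision = HalfPrecision("16-true")
assert torch.get_default_dtype() is torch.float32
try:
    with precision.module_init_context():
        assert torch.get_default_dtype() is torch.float16
        raise RuntimeError("boom")
except RuntimeError:
    pass
assert torch.get_default_dtype() is torch.float32  # restored despite the error
```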
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
```diff
@@ -46,7 +46,7 @@ def __init__(self, in_features, out_features, bias=True, *_, **__):
 
     # same logic as in `test_default_dtype_is_restored`
     assert torch.get_default_dtype() is torch.float32
-    with pytest.raises(RuntimeError, match="foo"), precision.init_context():
+    with pytest.raises(RuntimeError, match="foo"), precision.module_init_context():
         assert torch.get_default_dtype() is not torch.float32
         raise RuntimeError("foo")
     assert torch.get_default_dtype() is torch.float32
@@ -65,7 +65,7 @@ def __init__(self):
     _NF4Linear = vars(module)["_NF4Linear"]
     _NF4Linear._quantize_weight = Mock()
 
-    with precision.init_context():
+    with precision.module_init_context():
         assert torch.get_default_dtype() == torch.float16
         model = MyModule()
     assert isinstance(model.l1, _NF4Linear)
```
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_deepspeed.py
```diff
@@ -79,9 +79,9 @@ def test_selected_dtype(precision, expected_dtype):
         ("16-true", torch.float16),
     ],
 )
-def test_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype):
     plugin = DeepSpeedPrecision(precision=precision)
-    with plugin.init_context():
+    with plugin.module_init_context():
         model = torch.nn.Linear(2, 2)
         assert torch.get_default_dtype() == expected_dtype
     assert model.weight.dtype == expected_dtype
```