
Commit

move load custom func to base
NouamaneTazi committed Nov 21, 2024
1 parent 9e1d76f commit bc25a35
Showing 3 changed files with 135 additions and 127 deletions.
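The change below moves the custom optimizer-state loader out of src/nanotron/helpers.py and into src/nanotron/optim/base.py, next to BaseOptimizer and the Args/Kwargs/StateDict type aliases it uses. A minimal sketch of what that means for an import site, assuming a caller that wants the loader after this commit (the commented-out old location is shown only for contrast):

# Previously the loader was defined in nanotron.helpers:
# from nanotron.helpers import custom_load_state_dict
# After this commit it lives in the optimizer base module, which is
# also where helpers.py itself now imports it from:
from nanotron.optim.base import custom_load_state_dict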
125 changes: 1 addition & 124 deletions src/nanotron/helpers.py
@@ -4,22 +4,16 @@
import math
import os
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from functools import partial
from itertools import chain
from math import ceil
from typing import (
    Any,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np
@@ -28,15 +22,14 @@
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from typing_extensions import TypeAlias

from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.distributed import ProcessGroup
from nanotron.logging import LogItem, log_rank
from nanotron.models.base import NanotronModel
from nanotron.optim.base import BaseOptimizer, Optimizer
from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
from nanotron.optim.gradient_accumulator import (
    FP32GradBucketManager,
    FP32GradientAccumulator,
@@ -58,11 +51,6 @@
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize.metadata import TrainingMetadata

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]


logger = logging.get_logger(__name__)


@@ -310,117 +298,6 @@ def merge_named_param_groups(
    return named_param_groups


# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
@staticmethod
def _process_value_according_to_param_policy(
    param: torch.Tensor,
    value: torch.Tensor,
    param_id: int,
    param_groups: List[Dict[Any, Any]],
    map_location: Optional[Union[str, torch.device]],
    key: Hashable = None,
) -> torch.Tensor:
    # If map_location is specified, use it instead of param.device
    target_device = map_location if map_location is not None else param.device

    fused = False
    capturable = False
    assert param_groups is not None
    for pg in param_groups:
        if param_id in pg["params"]:
            fused = pg["fused"] if "fused" in pg else False
            capturable = pg["capturable"] if "capturable" in pg else False
            break

    if key == "step":
        if capturable or fused:
            return value.to(dtype=torch.float32, device=target_device)
        else:
            return value
    else:
        if param.is_floating_point():
            return value.to(dtype=param.dtype, device=target_device)
        else:
            return value.to(device=target_device)


# Modified from torch.optim.Optimizer.load_state_dict
@torch._disable_dynamo
def custom_load_state_dict(
    self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu"
) -> None:
    r"""Loads the optimizer state.
    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
        map_location (str or torch.device, optional): Device where to load the optimizer states.
            If None, states will be loaded to the same device as their corresponding parameters.
            Default: None
    """

    # shallow copy, to be consistent with module API
    state_dict = state_dict.copy()

    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
        hook_result = pre_hook(self, state_dict)
        if hook_result is not None:
            state_dict = hook_result

    # Validate the state_dict
    groups = self.param_groups

    # Deepcopy as we write into saved_groups later to update state
    saved_groups = deepcopy(state_dict["param_groups"])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of " "parameter groups")
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
        )

    # Update the state
    id_map = dict(
        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
    )

    def _cast(param, value, param_id=None, param_groups=None, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
        elif isinstance(value, dict):
            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        if k in id_map:
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({"state": state, "param_groups": param_groups})

    for post_hook in self._optimizer_load_state_dict_post_hooks.values():
        post_hook(self)


def init_optimizer_and_grad_accumulator(
parametrization_method: ParametrizationMethod,
model: nn.Module,
134 changes: 133 additions & 1 deletion src/nanotron/optim/base.py
@@ -1,7 +1,28 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import torch
from typing_extensions import TypeAlias

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]


class BaseOptimizer(ABC):
@@ -46,3 +67,114 @@ def inherit_from(self, cls) -> bool:


Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer)


# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
@staticmethod
def _process_value_according_to_param_policy(
    param: torch.Tensor,
    value: torch.Tensor,
    param_id: int,
    param_groups: List[Dict[Any, Any]],
    map_location: Optional[Union[str, torch.device]],
    key: Hashable = None,
) -> torch.Tensor:
    # If map_location is specified, use it instead of param.device
    target_device = map_location if map_location is not None else param.device

    fused = False
    capturable = False
    assert param_groups is not None
    for pg in param_groups:
        if param_id in pg["params"]:
            fused = pg["fused"] if "fused" in pg else False
            capturable = pg["capturable"] if "capturable" in pg else False
            break

    if key == "step":
        if capturable or fused:
            return value.to(dtype=torch.float32, device=target_device)
        else:
            return value
    else:
        if param.is_floating_point():
            return value.to(dtype=param.dtype, device=target_device)
        else:
            return value.to(device=target_device)


# Modified from torch.optim.Optimizer.load_state_dict
@torch._disable_dynamo
def custom_load_state_dict(
    self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = "cpu"
) -> None:
    r"""Loads the optimizer state.
    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
        map_location (str or torch.device, optional): Device where to load the optimizer states.
            If None, states will be loaded to the same device as their corresponding parameters.
            Default: None
    """

    # shallow copy, to be consistent with module API
    state_dict = state_dict.copy()

    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
        hook_result = pre_hook(self, state_dict)
        if hook_result is not None:
            state_dict = hook_result

    # Validate the state_dict
    groups = self.param_groups

    # Deepcopy as we write into saved_groups later to update state
    saved_groups = deepcopy(state_dict["param_groups"])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of " "parameter groups")
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
        )

    # Update the state
    id_map = dict(
        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
    )

    def _cast(param, value, param_id=None, param_groups=None, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
        elif isinstance(value, dict):
            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        if k in id_map:
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({"state": state, "param_groups": param_groups})

    for post_hook in self._optimizer_load_state_dict_post_hooks.values():
        post_hook(self)
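Since custom_load_state_dict takes the optimizer as its explicit first argument, it can be used as a free function on a plain torch.optim.Optimizer instance. The sketch below is illustrative only: it assumes Python 3.10+ (so the bare @staticmethod helper above is directly callable) and a recent PyTorch that defines the optimizer load_state_dict hook registries; the toy model and checkpoint layout are made up for the example.

import torch

from nanotron.optim.base import custom_load_state_dict

# Toy model and optimizer, for illustration only.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populates exp_avg / exp_avg_sq / step state

# Round-trip the optimizer state, forcing every state tensor onto the CPU.
# Passing map_location=None would keep each tensor on its parameter's device.
checkpoint = {"optimizer": optimizer.state_dict()}
custom_load_state_dict(optimizer, checkpoint["optimizer"], map_location="cpu")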
3 changes: 1 addition & 2 deletions src/nanotron/trainer.py
@@ -1,4 +1,5 @@
import datetime
import gc
import json
import os
import shutil
@@ -437,8 +438,6 @@ def train(

        prof = get_profiler(config=self.config)
        # free memory
        import gc

        gc.collect()
        torch.cuda.empty_cache()
        with prof:
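The trainer.py hunk above only hoists import gc to the top of the module; the memory cleanup before entering the profiler context is unchanged. As a standalone illustration of that pattern (the helper name free_memory_before_profiling is hypothetical, not part of the diff):

import gc

import torch

def free_memory_before_profiling() -> None:
    # Drop unreachable Python objects first so their CUDA tensors are released,
    # then return cached blocks to the CUDA driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()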
