Skip to content

Commit

Permalink
[misc] fit torch api upgradation and remove legecy import (#6093)
Browse files Browse the repository at this point in the history
* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
  • Loading branch information
ver217 authored Oct 18, 2024
1 parent 5ddad48 commit 58d8b8a
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion colossalai/accelerator/cuda_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,4 @@ def autocast(
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
2 changes: 1 addition & 1 deletion colossalai/kernel/jit/option.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch

from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
Expand Down Expand Up @@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear

embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
Expand Down
10 changes: 8 additions & 2 deletions colossalai/pipeline/schedule/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import torch
import torch.cuda
from packaging.version import Version
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten


# this register are for torch under version 1.13.1, maybe removed in the future
Expand All @@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))


_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
if Version(torch.__version__) <= Version("1.13.1"):
try:
from torch.utils._pytree import register_pytree_node as _register_pytree_node
except ImportError:
from torch.utils._pytree import _register_pytree_node
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


def tree_map_hf(fn: Any, pytree: Any):
Expand Down
11 changes: 6 additions & 5 deletions colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import torch.nn

from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float

Expand All @@ -27,6 +22,12 @@ class RuntimeMemTracer:

def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)

self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()
Expand Down
3 changes: 2 additions & 1 deletion colossalai/zero/gemini/placement_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk

from .chunk import Chunk, ChunkManager
Expand Down Expand Up @@ -172,6 +171,8 @@ def evict_tensors(
Returns:
int: the volume of memory that is evicted
"""
from colossalai.legacy.utils.memory import colo_device_memory_capacity

start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:

1. torch.cuda.amp
1. torch.amp
2. apex.amp
3. naive amp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
AMP 代表自动混合精度训练。
在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:

1. torch.cuda.amp
1. torch.amp
2. apex.amp
3. naive amp

Expand Down

0 comments on commit 58d8b8a

Please sign in to comment.