Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the parallel conversion from ZeRO checkpoints to FP32/FP16/BF16 param weight #6655

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
101 changes: 63 additions & 38 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@

# DeepSpeed Team

from functools import partial
from itertools import chain
import argparse
import glob
import itertools
import math
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
from collections import OrderedDict
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from itertools import chain

import torch
import tqdm
#from pprint import pprint
Expand Down Expand Up @@ -109,7 +111,7 @@ def _save_checkpoint(file_path, chkpt_sd):
torch.save(chkpt_sd, file_path)


def extract_zero_shards(dir, ds_checkpoint, indices_3D):
def extract_zero_shards(dir, ds_checkpoint, weight_only, data_type, indices_3D):
pp_index, tp_index, dp_index = indices_3D
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

Expand All @@ -121,19 +123,20 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
# print(f'{pipeline_replicated_params=}')

# dict
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
# list
fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
param_groups_cnt = len(state_groups)

for param_group_id in range(param_groups_cnt):
param_state = OrderedDict()

flat_state = dict(
exp_avg=state_groups[param_group_id]["exp_avg"],
exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
fp32=fp32_groups[param_group_id],
)
for param_group_id in range(len(state_groups)):
if weight_only:
flat_state = dict(fp32=fp32_groups[param_group_id].detach(), )
else:
flat_state = dict(
exp_avg=state_groups[param_group_id]["exp_avg"],
exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
fp32=fp32_groups[param_group_id],
)

if "step" in state_groups[param_group_id]:
flat_state["step"] = state_groups[param_group_id]["step"]
Expand All @@ -145,29 +148,38 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):

# pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
for state_key in flat_state.keys():
dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
fragment_mapping.start, fragment_mapping.numel)
dump_param_fragment(param_state, dir, tp_index, dp_index, state_key, flat_state[state_key], name,
fragment_mapping.start, fragment_mapping.numel, data_type, weight_only)

return dp_index, param_state


def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, weight_only, data_type, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
)
param_state = OrderedDict()

if weight_only:
flat_state = dict(fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0].detach(), )
else:
flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
)

offset = 0
for name, shape in param_shapes.items():
unpartitioned_numel = shape.numel()
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
for state_key in flat_state.keys():
dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
padding_free_numel)
dump_param_fragment(param_state, temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
padding_free_numel, data_type, weight_only)
offset += partitioned_numel

return dp_index, param_state


cnt = 0

Expand All @@ -176,23 +188,29 @@ def dp_index_to_str(dp_index):
return f"{dp_index:0>2d}"


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
def dump_param_fragment(param_state, dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel,
data_type, weight_only):

global cnt # temp hack

param_base_path = os.path.join(dir, param_name, str(tp_index))
os.makedirs(param_base_path, exist_ok=True)

cnt += 1

path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")

#print(f"{param_name}: {offset}: {numel} => {path}")

# State might be a python int or a tensor
if state_name != "step" and torch.is_tensor(state_flat_tensor):
state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, state_flat_tensor)

if data_type == "FP16":
state_flat_tensor = state_flat_tensor.to(torch.float16)
elif data_type == "BF16":
state_flat_tensor = state_flat_tensor.to(torch.bfloat16)

if weight_only:
param_state[param_name] = state_flat_tensor
else:
param_base_path = os.path.join(dir, param_name, str(tp_index))
os.makedirs(param_base_path, exist_ok=True)
path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")
_save_checkpoint(path, state_flat_tensor)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
Expand Down Expand Up @@ -360,19 +378,26 @@ def _do_parallel_work(do_work, work_chunks, num_workers):
return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
def _extract_zero_shard_files(args, ds_checkpoint, temp_dir, weight_only=False, data_type="FP32"):
_3d_range_list = list(
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
#pprint(f'{_3d_range_list=}')

do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint, weight_only, data_type)
return _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)


def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir):
do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir)
_do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
def _extract_zero_shard_files_stage3(args,
optim_files,
param_shapes,
dp_degree,
temp_dir,
weight_only=False,
data_type="FP32"):
do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir, weight_only,
data_type)
return _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
Expand Down
92 changes: 76 additions & 16 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)

from deepspeed.checkpoint import DeepSpeedCheckpoint

from deepspeed.checkpoint.ds_to_universal import _inject_missing_state, _extract_zero_shard_files, _extract_zero_shard_files_stage3, _get_model_state_files, _parse_model_states_stage3, _get_optim_files


@dataclass
class zero_model_state:
Expand Down Expand Up @@ -101,6 +105,7 @@ def get_model_state_files(checkpoint_dir):

def parse_model_states(files):
zero_model_states = []
zero_stage = None
for file in files:
state_dict = torch.load(file, map_location=device, weights_only=False)

Expand Down Expand Up @@ -142,7 +147,9 @@ def parse_model_states(files):
frozen_param_fragments=frozen_param_fragments)
zero_model_states.append(z_model_state)

return zero_model_states
if zero_stage is None:
zero_stage = state_dict['ds_config']['zero_optimization']['stage']
return zero_stage, zero_model_states


def parse_optim_states(files, ds_checkpoint_dir):
Expand Down Expand Up @@ -195,20 +202,15 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_
"""
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

model_files = get_model_state_files(ds_checkpoint_dir)

zero_model_states = parse_model_states(model_files)
zero_stage, zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
return _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states,
exclude_frozen_parameters)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
return _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states,
exclude_frozen_parameters)


Expand Down Expand Up @@ -322,10 +324,22 @@ def zero2_align(x):
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters):
def _consolidate_ucp_checkpoints(args, state_dict, slice_shapes):
zero_output_folder = os.path.join(args.output_dir, "zero")

for param in slice_shapes.keys():
ucp_checkpoint_path = os.path.join(zero_output_folder, param, "fp32.pt")
weight = torch.load(ucp_checkpoint_path, map_location=device)
state_dict[param] = weight['param']


def _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters):

state_dict = OrderedDict()

ds_checkpoint = DeepSpeedCheckpoint(ds_checkpoint_dir)
_inject_missing_state(ds_checkpoint)

# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
Expand All @@ -335,7 +349,20 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer
if not exclude_frozen_parameters:
_zero2_merge_frozen_params(state_dict, zero_model_states)

_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
param_shards = _extract_zero_shard_files(args,
ds_checkpoint,
temp_dir=None,
weight_only=True,
data_type=args.data_type)

param_shards.sort(key=lambda x: x[0])

for _, param in param_shards:
for key, value in param.items():
if key in state_dict:
state_dict[key] = torch.cat((state_dict[key], value), 0)
else:
state_dict[key] = value

# recover shared parameters
for pair in zero_model_states[0].shared_params:
Expand Down Expand Up @@ -487,10 +514,15 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters):
def _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters):
state_dict = OrderedDict()

model_files = _get_model_state_files(ds_checkpoint_dir)
optim_files = _get_optim_files(ds_checkpoint_dir)
param_shapes = _parse_model_states_stage3(model_files)
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
world_size = len(model_files)

# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
Expand All @@ -500,7 +532,22 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
if not exclude_frozen_parameters:
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
param_shards = _extract_zero_shard_files_stage3(args,
optim_files,
param_shapes,
world_size,
temp_dir=None,
weight_only=True,
data_type=args.data_type)

param_shards.sort(key=lambda x: x[0])

for _, param in param_shards:
for key, value in param.items():
if key in state_dict:
state_dict[key] = torch.cat((state_dict[key], value), 0)
else:
state_dict[key] = value

# recover shared parameters
for pair in zero_model_states[0].shared_params:
Expand Down Expand Up @@ -535,7 +582,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
exclude_frozen_parameters=False,
lazy_mode=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
Convert ZeRO 2 or 3 checkpoint into a single fp32/fp16/bf16 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
via a model hub.

Expand Down Expand Up @@ -748,6 +795,19 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
parser.add_argument('--num_extract_workers',
default=4,
type=int,
help='How many parallel processes to extract zero shards')
parser.add_argument('--no_strict',
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
parser.add_argument(
'--data_type',
default='FP32',
choices=['FP32', 'FP16', 'BF16'],
help="Specify the output tensor data type format (FP32, FP16, BF16, FP8, BF8). Default is FP32.")
args = parser.parse_args()

debug = args.debug
Expand Down
Loading