diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index 24cc342e78d1..e69ecd9acb5a 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -10,7 +10,10 @@
 # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
 # application.
 #
-# example: python zero_to_fp32.py . pytorch_model.bin
+# example:
+#   python zero_to_fp32.py . output_dir/
+# or
+#   python zero_to_fp32.py . output_dir/ --safe_serialization
 
 import argparse
 import torch
@@ -18,6 +21,8 @@
 import math
 import os
 import re
+import json
+from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
 
@@ -139,7 +144,6 @@ def parse_model_states(files):
 
 
 def parse_optim_states(files, ds_checkpoint_dir):
-
     total_files = len(files)
     state_dicts = []
     for f in files:
@@ -420,12 +424,10 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    for name, shape in param_shapes.items():
-
+    for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
-
         partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
 
         if debug:
@@ -521,21 +523,76 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
+                                               output_dir,
+                                               max_shard_size="5GB",
+                                               safe_serialization=False,
+                                               tag=None,
+                                               exclude_frozen_parameters=False):
     """
-    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    Convert ZeRO 2 or 3 checkpoint into fp32 consolidated ``state_dict`` file(s) that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
 
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
-        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``output_dir``: directory where the pytorch fp32 state_dict output files will be saved
+        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
+        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
         - ``exclude_frozen_parameters``: exclude frozen parameters
     """
-
+    # Dependency pre-check
+    if safe_serialization:
+        try:
+            from safetensors.torch import save_file
+        except ImportError:
+            print('If you want to use `safe_serialization`, please `pip install safetensors`')
+            raise
+    if max_shard_size is not None:
+        try:
+            from huggingface_hub import split_torch_state_dict_into_shards
+        except ImportError:
+            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
+            raise
+
+    # Convert zero checkpoint to state_dict
     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
-    print(f"Saving fp32 state dict to {output_file}")
-    torch.save(state_dict, output_file)
+
+    # Shard the model if it is too big.
+    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
+    if max_shard_size is not None:
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(state_dict,
+                                                              filename_pattern=filename_pattern,
+                                                              max_shard_size=max_shard_size)
+    else:
+        from collections import namedtuple
+        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
+        state_dict_split = StateDictSplit(is_sharded=False,
+                                          filename_to_tensors={weights_name: list(state_dict.keys())})
+
+    # Save the model
+    os.makedirs(output_dir, exist_ok=True)  # ensure the output directory exists before writing shards
+    filename_to_tensors = state_dict_split.filename_to_tensors.items()
+    for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
+        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+        output_path = os.path.join(output_dir, shard_file)
+        if safe_serialization:
+            save_file(shard, output_path, metadata={"format": "pt"})
+        else:
+            torch.save(shard, output_path)
+
+    # Save index if sharded
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
+        save_index_file = os.path.join(output_dir, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
 
 
 def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
@@ -578,15 +635,27 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("checkpoint_dir",
                         type=str,
                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument("output_dir",
+                        type=str,
+                        help="directory where the pytorch fp32 state_dict output files will be saved "
+                        "(e.g. path/checkpoint-12-output/)")
     parser.add_argument(
-        "output_file",
+        "--max_shard_size",
         type=str,
-        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+        default="5GB",
+        help="The maximum size for a checkpoint before being sharded. Checkpoint shards will then each be of a size "
+        "lower than this size. If expressed as a string, it needs to be digits followed by a unit (like `5MB`). "
+        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances "
+        "without CPU OOM issues.")
+    parser.add_argument(
+        "--safe_serialization",
+        default=False,
+        action='store_true',
+        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
@@ -599,6 +668,8 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
     debug = args.debug
 
     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
-                                               args.output_file,
+                                               args.output_dir,
+                                               max_shard_size=args.max_shard_size,
+                                               safe_serialization=args.safe_serialization,
                                                tag=args.tag,
                                                exclude_frozen_parameters=args.exclude_frozen_parameters)
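
Note (illustrative, not part of the patch): the sketch below shows one way to load the converted output back without DeepSpeed, assuming the default non-safetensors naming produced by the code above; `output_dir` and the commented-out `model` are hypothetical placeholders.

```python
# Minimal sketch: rebuild a full fp32 state_dict from the files that
# convert_zero_checkpoint_to_fp32_state_dict() writes to output_dir.
import json
import os

import torch

output_dir = "output_dir/"  # hypothetical: the directory passed to zero_to_fp32.py
index_path = os.path.join(output_dir, "pytorch_model.bin.index.json")

state_dict = {}
if os.path.exists(index_path):
    # Sharded checkpoint: weight_map tells us which shard file holds each tensor.
    with open(index_path, encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]
    for shard_file in sorted(set(weight_map.values())):
        state_dict.update(torch.load(os.path.join(output_dir, shard_file), map_location="cpu"))
else:
    # Single-file checkpoint (model smaller than max_shard_size, or sharding disabled).
    state_dict = torch.load(os.path.join(output_dir, "pytorch_model.bin"), map_location="cpu")

# With --safe_serialization, read model.safetensors / model-*.safetensors shards
# via safetensors.torch.load_file() instead of torch.load().

# model.load_state_dict(state_dict)  # `model` is the nn.Module you trained
```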