forked from huggingface/nanotron
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfix_checkpoint_bad_naming.py
51 lines (39 loc) · 1.97 KB
/
fix_checkpoint_bad_naming.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Fixes the problem where '{type.value}_{suffix_name}.safetensors' was duplicated in checkpoint files
For example, this script will rename the following:
```
checkpoints/10/model/model/decoder/0/pp_block/attn/o_proj/model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors
to
checkpoints/10/model/model/decoder/0/pp_block/attn/o_proj/model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors
```
Example Usage:
python scripts/fix_checkpoint_bad_naming.py /fsx/nouamane/projects/nanotron/checkpoints/10
"""
import argparse
import os
import re
from pathlib import Path
def update_checkpoint(checkpoint_dir: str) -> None:
    """Rename checkpoint files whose name carries a duplicated prefix.

    Walks ``checkpoint_dir`` recursively. Every file that ends with
    ``.safetensors`` and contains ``model_model`` is renamed in place: the
    repeated ``_model`` is dropped and the first (spurious) ``.safetensors``
    inside the name is removed, e.g.
    ``model_model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors``
    becomes ``model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors``.
    Files that do not match are left untouched.

    Args:
        checkpoint_dir: Root directory of the checkpoint; anything accepted by
            ``os.walk`` works (``str`` or path-like).

    Raises:
        AssertionError: If a matching file name does not contain exactly two
            literal ``.safetensors`` occurrences once the prefix is fixed.
    """
    # '_model' that is directly preceded by 'model', i.e. the duplicated prefix.
    duplicated_prefix = re.compile(r"(?<=model)_(model)")
    # The dot is escaped so that only a literal '.safetensors' matches; an
    # unescaped '.' would also match e.g. '_safetensors' inside a tensor name.
    extension = re.compile(r"\.safetensors")

    print(f"Updating checkpoint in {checkpoint_dir}")
    for root, _, files in os.walk(checkpoint_dir):
        for file in files:
            if not file.endswith(".safetensors"):
                continue
            if duplicated_prefix.search(file) is None:
                continue

            # Drop the repeated '_model'.
            new_file = duplicated_prefix.sub("", file)

            # At this point we expect something like
            # "model_weight.safetensors_pp-rank-0-of-1_tp-rank-0-of-2.safetensors",
            # i.e. exactly two '.safetensors'. Raised explicitly (rather than
            # via `assert`) so the check also runs under `python -O`.
            extension_count = len(extension.findall(new_file))
            if extension_count != 2:
                raise AssertionError(
                    f"Expected exactly 2 occurrences of '.safetensors' in {new_file!r}, "
                    f"found {extension_count}"
                )

            # Remove the first occurrence so that we get
            # "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors".
            new_file = extension.sub("", new_file, count=1)

            print(f"Renaming {file} to {new_file}")
            os.rename(os.path.join(root, file), os.path.join(root, new_file))
def main():
    """Command-line entry point: read the checkpoint directory and fix its file names."""
    cli = argparse.ArgumentParser(description="Update checkpoint from 1.3 to 1.4")
    cli.add_argument("checkpoint_dir", type=Path, help="Path to the checkpoint directory")
    # The single positional argument is handed straight to the renaming routine.
    update_checkpoint(cli.parse_args().checkpoint_dir)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()