# cal_ppl.py
import os
import json
from argparse import ArgumentParser

import torch
from tqdm import tqdm

from src.model_runner import init_model
from src.utils import read_config
from src.data.dataset_pretrain import PretrainDataset  # only needed for the commented-out "pt" stage below
from src.data.dataset_sft import SFTDataset

# Adapted from https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/cal_ppl.py
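# Background: for a sequence of tokens y_1 .. y_N, perplexity is the exponential of
# the mean token-level negative log-likelihood,
#     PPL = exp(-(1/N) * sum_i log p(y_i | y_<i)).
# The loop below gets the per-token NLL from CrossEntropyLoss(reduction="none"),
# averages it over the unmasked positions of each sample, and exponentiates.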
def cal_ppl(
    model_name_or_path: str,
    train_file: str = "config/train.yaml",
    # stage: Literal["pt", "sft", "rm"] = "sft",
    dataset: str = "data/sft_data.csv",
    ddp: bool = False,
    device: str = "cuda",
):
    r"""
    Calculate the perplexity (PPL) of a trained model on the given dataset.

    Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset path_to_dataset

    Per-sample perplexities are saved to <model_name_or_path>/ppl.json.
    """
    config_file = os.path.join(model_name_or_path, "config.yaml")
    model_config = read_config(config_file)
    model, tokenizer = init_model(model_config, flag='train')
    model.to(device)
    model.eval()  # evaluation only: disable dropout and other training-time behaviour
    train_config = read_config(train_file)
    # if stage == 'pt':
    #     train_ds = PretrainDataset(train_config['train_data_path'],
    #                                max_length=model_config['max_seq_len'],
    #                                memmap=True)
    # else:
    train_ds = SFTDataset(dataset,
                          max_length=model_config['max_seq_len'],
                          tokenizer=tokenizer)
    if ddp:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=train_config['batch_size'],
        pin_memory=False,
        drop_last=False,
        shuffle=False,
        num_workers=0 if os.name == 'nt' else 4,
        sampler=train_sampler)
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    total_ppl = 0
    perplexities = []
    with torch.no_grad():
        for X, Y, loss_mask in tqdm(train_loader):
            X = X.to(device)
            Y = Y.to(device)
            outputs = model(X, Y)
            # Drop the last logit and the first label so each logit is scored
            # against the following token (same shift as the LLaMA-Factory script).
            shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
            shift_labels: torch.Tensor = Y[..., 1:]
            # Rebuild the mask from the labels: positions padded with -100 are ignored.
            loss_mask = shift_labels != -100
            flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
            flatten_labels = shift_labels.contiguous().view(-1)
            # Per-token negative log-likelihood, reshaped back to (batch, seq_len - 1).
            token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
            token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
            # Masked mean NLL per sample; its exp() is the per-sample perplexity.
            sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
            total_ppl += sentence_logps.exp().sum().item()
            perplexities.extend(sentence_logps.exp().tolist())

    save_name = os.path.join(model_name_or_path, "ppl.json")
    with open(save_name, "w", encoding="utf-8") as f:
        json.dump(perplexities, f, indent=2)
    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
    print("Perplexities have been saved at {}.".format(save_name))

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default='./out/pretrain_layer12_dim768_seq768',
                        help="path to the trained model directory (must contain config.yaml)")
    parser.add_argument("--train_file", type=str, default="config/train.yaml",
                        help="path to the training config (provides batch_size)")
    parser.add_argument("--dataset", type=str, default="data/sft_data.csv",
                        help="path to the SFT dataset csv to evaluate on")
    args = parser.parse_args()
    cal_ppl(args.model_name_or_path, args.train_file, args.dataset)
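
# Example usage (illustrative; the paths are just the argparse defaults above and
# may differ in your setup):
#
#   python cal_ppl.py \
#       --model_name_or_path ./out/pretrain_layer12_dim768_seq768 \
#       --dataset data/sft_data.csv
#
# The script writes the per-sample perplexities to <model_name_or_path>/ppl.json,
# which can be inspected afterwards, e.g.:
#
#   import json
#   with open("./out/pretrain_layer12_dim768_seq768/ppl.json") as f:
#       ppls = json.load(f)
#   print(len(ppls), sum(ppls) / len(ppls))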