rewards.py
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def preplexity_eval(model, tokenizer, sentence, device="cpu"):
    """Compute the perplexity of `sentence` under a causal LM with a sliding window."""
    encodings = tokenizer.encode(sentence, return_tensors="pt")
    max_length = model.config.n_positions
    stride = 512
    seq_len = encodings.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on the last loop
        # Slice each window from the full encoding; reassigning the full tensor here
        # would shrink the sequence and break later windows.
        input_ids = encodings[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # mask tokens already scored in a previous window
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # The loss is calculated with CrossEntropyLoss, which averages over the
            # target tokens. Multiply it by trg_len to get the summation instead of
            # the average; the true average over all tokens is taken in the last step.
            neg_log_likelihood = outputs.loss * trg_len
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
    return ppl
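

# Usage sketch (an assumption, not part of the original file): preplexity_eval
# expects a causal LM whose config exposes `n_positions`, e.g. GPT-2.
#
#   from transformers import AutoModelForCausalLM
#   lm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   lm_model = AutoModelForCausalLM.from_pretrained("gpt2")
#   ppl = preplexity_eval(lm_model, lm_tokenizer, "The quick brown fox jumps over the lazy dog.")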
def get_toxicity_score(text, model, tokenizer):
    """Return the probability of the toxic class for `text` under a binary classifier."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.softmax(outputs.logits, dim=-1)
    toxic_class_index = 1  # class index for the toxic label
    toxic_score = probabilities[0, toxic_class_index].item()
    return toxic_score
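

# Minimal usage sketch (an assumption, not part of the original training code).
# get_toxicity_score works with any binary sequence classifier whose class index 1
# is the "toxic" label; the checkpoint name and its label order below are
# assumptions used purely for illustration.
if __name__ == "__main__":
    tox_name = "s-nlp/roberta_toxicity_classifier"  # assumed checkpoint
    tox_tokenizer = AutoTokenizer.from_pretrained(tox_name)
    tox_model = AutoModelForSequenceClassification.from_pretrained(tox_name)
    tox_model.eval()
    print("toxicity:", get_toxicity_score("You are a wonderful person.", tox_model, tox_tokenizer))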