rolling ppl with sliding window #2553

Merged: 7 commits, Jan 23, 2024
146 changes: 44 additions & 102 deletions eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
100644 → 100755
@@ -1,141 +1,83 @@
import copy
import json
import numpy as np
import os
import pyonmttok
import time
from onmt.constants import CorpusTask, DefaultTokens
from onmt.inference_engine import InferenceEnginePY
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
import onmt.opts as opts
from onmt.utils.logging import init_logger
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from onmt.transforms import get_transforms_cls


def compute_file_ppl(output_filename):
with open(output_filename, "r") as f:
run_results = json.load(f)
nlls = []
lengths = []
for i, _res in enumerate(run_results["scored_results"]):
print(_res)
nlls.append(_res[0])
lengths.append(_res[1])
file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
print("wikitext-2 ppl: %.4f" % file_ppl)


def tokenize_dataset(opt, context_length):
print("Tokenization...")

# Prepare the dataset
# Clean and Concat the dataset
x = open(opt.src, "r").readlines()
x = [_x.rstrip("\n") for _x in x]
y = DefaultTokens.SEP.join(x)

with open(opt.src + ".temp", "w") as writer:
writer.write(y)

# ########################## #
# Build the dataset iterator #
# ########################## #

# Build the vocab
vocab_path_in = "/nas-labs/LM/big_llms/llama/7B/llama.vocab"
voc = []
with open(vocab_path_in, "r", encoding="utf-8") as reader:
for line in reader:
line = line.strip("\n")
voc.append(line)
vocabs = {}
src_vocab = pyonmttok.build_vocab_from_tokens(voc)
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["data_task"] = "lm"
vocabs["decoder_start_token"] = "<s>"

transforms_cls = get_transforms_cls(opt._all_transform)

new_opt = opt
new_opt.gpu = -1
new_opt.parallel_mode = "data_parallel"
new_opt.src = opt.src + ".temp"

dataset_iter = build_dynamic_dataset_iter(
new_opt, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=-1
)

input_tokens = []
for batch, i in dataset_iter:
for i in range(batch["src"].size()[0]):
start_ids = batch["src"][i, :, 0].cpu().numpy().tolist()
input_tokens += [
vocabs["src"].lookup_index(id)
for id in start_ids
if id != vocabs["src"].lookup_token(DefaultTokens.PAD)
]

def make_chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]

# #################### #
# Tokenize the dataset #
# #################### #
with open(opt.src + f".tokenized.context_{context_length}", "w") as writer:
for _chunk in make_chunks(input_tokens, context_length - 1):
writer.write(" ".join(_chunk) + "\n")
print(len(_chunk))
xx = [_x for _x in x if _x != " \n"]
from onmt.transforms.tokenize import SentencePieceTransform

tokenizer = SentencePieceTransform(opt)
tokenizer.warm_up()
tokens = tokenizer._tokenize(xx)
print("Done !")

z = open(opt.src + f".tokenized.context_{context_length}", "r").readlines()
print(len(z[0].split(" ")))
return tokens


def evaluate(opt):
"""Score the wikitext2 testset"""
"""Score the wikitext2 testset

The perplexity of the file is calculated with a window size of max_seq_length = 4096 tokens.
At each step, the window shifts by stride = 512 tokens, and its first max_seq_length - stride
tokens are treated as context tokens: their log-probabilities are not taken into account,
so this rolling perplexity is calculated without scoring any token twice."""

ArgumentParser.validate_translate_opts(opt)
ArgumentParser._get_all_transform_translate(opt)
ArgumentParser._validate_transforms_opts(opt)
ArgumentParser.validate_translate_opts_dynamic(opt)
logger = init_logger(opt.log_file)
set_random_seed(opt.seed, use_gpu(opt))

run_results = {}
dir_name = os.path.dirname(opt.models[0])
base_name = os.path.basename(opt.models[0])

output_filename = os.path.join(
dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]
)
# Tokenize the dataset.
opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokens = tokenize_dataset(opt, context_length=512)

# Build the translator (along with the model).
engine_opt = copy.copy(opt)
engine_opt._all_transform = []
engine = InferenceEnginePY(engine_opt)

# Tokenize the dataset.
opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokenize_dataset(opt, context_length=512)

# Score the tokenized dataset
engine.opt.src = opt.src + f".tokenized.context_{512}"
start_time = time.time()
scored_results = engine.score_file()
engine.terminate()
run_results["scored_results"] = scored_results
# Score the dataset.
stride = 512
max_seq_length = 4096

with open(output_filename, "w") as f:
json.dump(run_results, f, ensure_ascii=False, indent=2)
seq_len = len(tokens)
src = []
for begin_loc in range(0, seq_len, stride):
end_loc = min(begin_loc + max_seq_length, seq_len)
src.append(" ".join(tokens[begin_loc:end_loc]))

compute_file_ppl(output_filename)
start_time = time.time()
engine.translator.return_gold_log_probs = True
score_results = engine.score_list(src=src)
nlls = []
lengths = []
for _, log_probs, _ in score_results:
lengths.append(stride)
# zero out the context tokens
nlls += [
log_probs[i][0]
for i, _ in enumerate(log_probs)
if i > (max_seq_length - stride)
]
ppl = np.exp(-np.sum(nlls) / np.sum(lengths))

engine.terminate()
end_time = time.time()
logger.info("total run time %.2f" % (end_time - start_time))
logger.info(
"wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" # noqa: E501
% (ppl)
)


def _get_parser():
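For readers of this diff, here is a minimal, self-contained sketch of the sliding-window scoring that evaluate() implements above. It is an illustration only, not code from this PR: score_window is a hypothetical callable standing in for engine.score_list, returning one gold log-probability per token of the window it is given.

import numpy as np

def rolling_perplexity(tokens, score_window, max_seq_length=4096, stride=512):
    # Slide a max_seq_length window over the token stream in steps of stride.
    # The first max_seq_length - stride positions of each window are context:
    # their log-probabilities are discarded, so no token is scored twice.
    sum_log_probs = 0.0
    n_scored = 0
    context = max_seq_length - stride
    for begin_loc in range(0, len(tokens), stride):
        end_loc = min(begin_loc + max_seq_length, len(tokens))
        window_log_probs = score_window(tokens[begin_loc:end_loc])
        sum_log_probs += sum(lp for i, lp in enumerate(window_log_probs) if i >= context)
        # Like the PR code, each step counts stride scored tokens.
        n_scored += stride
    # Perplexity is exp of the mean negative log-likelihood.
    return float(np.exp(-sum_log_probs / n_scored))

With stride equal to max_seq_length this degenerates to non-overlapping chunks; the smaller stride trades extra compute for more context per scored token.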
1 change: 1 addition & 0 deletions onmt/inference_engine.py
@@ -163,6 +163,7 @@ def _translate(self, infer_iter):

def _score(self, infer_iter):
self.translator.with_scores = True
self.return_gold_log_probs = True
return self.translator._score(infer_iter)

def score_list_parallel(self, src):
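With return_gold_log_probs enabled, score_list returns one (gold_score, gold_log_probs, tgt_length) triple per example, as built in Translator._score below and unpacked in evaluate() above. A toy sketch of consuming those triples (the numbers are made up for illustration):

# Each entry: (summed gold score, per-token gold log-probs, target length).
score_results = [
    (-4.1, [[-1.2], [-0.8], [-2.1]], 3),
]
for gold_score, log_probs, length in score_results:
    mean_nll = -sum(lp[0] for lp in log_probs) / length
    print("score %.1f, mean NLL per token %.3f" % (gold_score, mean_nll))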
2 changes: 1 addition & 1 deletion onmt/tests/pull_request_chk.sh
@@ -558,7 +558,7 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model_lm.pt \
-ban_unk_token \
-length_penalty none \
-out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1
diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(python -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(${PYTHON} -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}
rm $TMP_OUT_DIR/gen_sampling
73 changes: 41 additions & 32 deletions onmt/translate/translator.py
@@ -133,6 +133,7 @@ def __init__(
logger=None,
seed=-1,
with_score=False,
return_gold_log_probs=False,
):
self.model = model
self.vocabs = vocabs
@@ -205,6 +206,8 @@ def __init__(
set_random_seed(seed, self._use_cuda)
self.with_score = with_score

self.return_gold_log_probs = return_gold_log_probs

@classmethod
def from_opt(
cls,
@@ -280,26 +283,17 @@ def _log(self, msg):
print(msg)

def _gold_score(
self,
batch,
enc_out,
src_len,
use_src_map,
enc_final_hs,
batch_size,
src,
self, batch, enc_out, src_len, use_src_map, enc_final_hs, batch_size, src
):
if "tgt" in batch.keys() and not self.tgt_file_prefix:
gs = self._score_target(
batch,
enc_out,
src_len,
batch["src_map"] if use_src_map else None,
gs, glp = self._score_target(
batch, enc_out, src_len, batch["src_map"] if use_src_map else None
)
self.model.decoder.init_state(src, enc_out, enc_final_hs)
else:
gs = [0] * batch_size
return gs
glp = None
return gs, glp

def _translate(
self,
@@ -584,12 +578,25 @@ def _score(self, infer_iter):
self.with_scores = True
scored_bucket = {}
for batch, bucket_idx in infer_iter:
batch_data = self.translate_batch(batch, attn_debug=False)
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
if self.return_gold_log_probs:
batch_gold_log_probs = (
batch_data["gold_log_probs"].cpu().numpy().tolist()
)
else:
batch_gold_log_probs = None
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
batch_inds_in_bucket = batch["ind_in_bucket"]
for i, _score in enumerate(batch_gold_scores):
scored_bucket[batch_inds_in_bucket[i]] = (_score, batch_tgt_lengths[i])
log_probs = (
batch_gold_log_probs[i] if self.return_gold_log_probs else None
)
scored_bucket[batch_inds_in_bucket[i]] = (
_score,
log_probs,
batch_tgt_lengths[i],
)
score_results = [scored_bucket[i] for i in range(len(scored_bucket))]
return score_results

@@ -720,6 +727,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
def report_results(
self,
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
@@ -730,6 +738,7 @@ def report_results(
"attention": None,
"batch": batch,
"gold_score": gold_score,
"gold_log_probs": gold_log_probs,
}

results["scores"] = decode_strategy.scores
@@ -900,7 +909,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):

self.model.decoder.init_state(src, enc_out, enc_final_hs)

gold_score = self._gold_score(
gold_score, gold_log_probs = self._gold_score(
batch,
enc_out,
src_len,
@@ -961,6 +970,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):

return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
Expand All @@ -982,7 +992,7 @@ def _score_target(self, batch, enc_out, src_len, src_map):
gold = tgt[:, 1:, :]
gold_scores = log_probs.gather(2, gold)
gold_scores = gold_scores.sum(dim=1).view(-1)
return gold_scores
return gold_scores, None


class GeneratorLM(Inference):
@@ -1001,8 +1011,9 @@ def _align_forward(self, batch, predictions):
"""
raise NotImplementedError

def translate_batch(self, batch, attn_debug):
def translate_batch(self, batch, attn_debug, scoring=False):
"""Translate a batch of sentences."""
max_length = 0 if scoring else self.max_length
with torch.no_grad():
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
decode_strategy = GreedySearchLM(
@@ -1015,7 +1026,7 @@ def translate_batch(self, batch, attn_debug):
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=self.max_length,
max_length=max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
@@ -1039,7 +1050,7 @@ def translate_batch(self, batch, attn_debug):
n_best=self.n_best,
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=self.max_length,
max_length=max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
@@ -1095,14 +1106,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):

# (2) init decoder
self.model.decoder.init_state(src, None, None)
gold_score = self._gold_score(
batch,
None,
src_len,
use_src_map,
None,
batch_size,
src,
gold_score, gold_log_probs = self._gold_score(
batch, None, src_len, use_src_map, None, batch_size, src
)

# (3) prep decode_strategy. Possibly repeat src objects.
@@ -1158,6 +1163,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):

return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
@@ -1177,7 +1183,10 @@ def _score_target(self, batch, enc_out, src_len, src_map):
)

log_probs[:, :, self._tgt_pad_idx] = 0
gold_scores = log_probs.gather(2, tgt)
gold_scores = gold_scores.sum(dim=1).view(-1)
gold_log_probs = log_probs.gather(2, tgt)
gold_scores = gold_log_probs.sum(dim=1).view(-1)

if self.return_gold_log_probs:
return gold_scores, gold_log_probs

return gold_scores
return gold_scores, None
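As a side note, a small self-contained sketch of the gather pattern that GeneratorLM._score_target uses above to pull per-token gold log-probabilities; the shapes (batch 2, length 5, vocab 10) and the pad index are toy assumptions, not the actual OpenNMT-py tensors.

import torch

log_probs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)  # (batch, len, vocab)
tgt = torch.randint(0, 10, (2, 5, 1))                         # gold token ids
pad_idx = 0
log_probs[:, :, pad_idx] = 0                # padding positions contribute nothing

gold_log_probs = log_probs.gather(2, tgt)   # (2, 5, 1): log-prob of each gold token
gold_scores = gold_log_probs.sum(dim=1).view(-1)  # one summed log-likelihood per sequence
print(gold_scores.shape)                    # torch.Size([2])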