From 10d9cb06e7665aaeb0f862a05a47da4a1cca8c12 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Thu, 18 Jan 2024 18:33:35 +0100 Subject: [PATCH 1/7] rolling ppl with window size 1000 and stride 512 --- .../WIKITEXT2/run_wikitext-2_benchmark.py | 132 ++++-------------- 1 file changed, 31 insertions(+), 101 deletions(-) mode change 100644 => 100755 eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py old mode 100644 new mode 100755 index 34663a6e34..8a8afa4d5a --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -1,99 +1,28 @@ import copy -import json import numpy as np -import os -import pyonmttok import time -from onmt.constants import CorpusTask, DefaultTokens from onmt.inference_engine import InferenceEnginePY -from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter import onmt.opts as opts from onmt.utils.logging import init_logger from onmt.utils.parse import ArgumentParser from onmt.utils.misc import use_gpu, set_random_seed -from onmt.transforms import get_transforms_cls - - -def compute_file_ppl(output_filename): - with open(output_filename, "r") as f: - run_results = json.load(f) - nlls = [] - lengths = [] - for i, _res in enumerate(run_results["scored_results"]): - print(_res) - nlls.append(_res[0]) - lengths.append(_res[1]) - file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths)) - print("wikitext-2 ppl: %.4f" % file_ppl) def tokenize_dataset(opt, context_length): print("Tokenization...") - # Prepare the dataset + # Clean and Concat the dataset x = open(opt.src, "r").readlines() - x = [_x.rstrip("\n") for _x in x] - y = DefaultTokens.SEP.join(x) - - with open(opt.src + ".temp", "w") as writer: - writer.write(y) - - # ########################## # - # Build the dataset iterator # - # ########################## # - - # Build the vocab - vocab_path_in = "/nas-labs/LM/big_llms/llama/7B/llama.vocab" - voc = [] - with open(vocab_path_in, "r", encoding="utf-8") as reader: - for line in reader: - line = line.strip("\n") - voc.append(line) - vocabs = {} - src_vocab = pyonmttok.build_vocab_from_tokens(voc) - vocabs["src"] = src_vocab - vocabs["tgt"] = src_vocab - vocabs["data_task"] = "lm" - vocabs["decoder_start_token"] = "" - - transforms_cls = get_transforms_cls(opt._all_transform) - - new_opt = opt - new_opt.gpu = -1 - new_opt.parallel_mode = "data_parallel" - new_opt.src = opt.src + ".temp" - - dataset_iter = build_dynamic_dataset_iter( - new_opt, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=-1 - ) - - input_tokens = [] - for batch, i in dataset_iter: - for i in range(batch["src"].size()[0]): - start_ids = batch["src"][i, :, 0].cpu().numpy().tolist() - input_tokens += [ - vocabs["src"].lookup_index(id) - for id in start_ids - if id != vocabs["src"].lookup_token(DefaultTokens.PAD) - ] - - def make_chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - # #################### # - # Tokenize the dataset # - # ################### # - with open(opt.src + f".tokenized.context_{context_length}", "w") as writer: - for _chunk in make_chunks(input_tokens, context_length - 1): - writer.write(" ".join(_chunk) + "\n") - print(len(_chunk)) - + xx = [_x for _x in x if _x != ' \n'] + print(xx[:2]) + from onmt.transforms.tokenize import SentencePieceTransform + tokenizer = SentencePieceTransform(opt) + tokenizer.warm_up() + tokens = tokenizer._tokenize(xx) 
print("Done !") - - z = open(opt.src + f".tokenized.context_{context_length}", "r").readlines() - print(len(z[0].split(" "))) + print(len(tokens)) + print(tokens[:100]) + return tokens def evaluate(opt): @@ -105,37 +34,38 @@ def evaluate(opt): logger = init_logger(opt.log_file) set_random_seed(opt.seed, use_gpu(opt)) - run_results = {} - dir_name = os.path.dirname(opt.models[0]) - base_name = os.path.basename(opt.models[0]) - - output_filename = os.path.join( - dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3] - ) + # Tokenize the dataset. + opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw" + tokens = tokenize_dataset(opt, context_length=512) # Build the translator (along with the model. engine_opt = copy.copy(opt) engine_opt._all_transform = [] engine = InferenceEnginePY(engine_opt) - # Tokenize the dataset. - opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw" - tokenize_dataset(opt, context_length=512) + # Score the dataset. + stride = 512 + max_seq_length = 4096 + max_seq_length = 1000 + seq_len = len(tokens) + print('seq_len: ', seq_len) + score_results = [] + nlls = [] + src = [] + for begin_loc in range(0, seq_len, stride): + end_loc = min(begin_loc + max_seq_length - 1, seq_len) + src.append(' '.join(tokens[begin_loc:end_loc])) - # Score the tokeznized dataset - engine.opt.src = opt.src + f".tokenized.context_{512}" start_time = time.time() - scored_results = engine.score_file() + score_results = engine.score_list(src=src) + nlls = [_score for (_score, _length) in score_results] + lengths = [_length for (_score, _length) in score_results] + ppl = np.exp(-np.sum(nlls) / np.sum(lengths)) + print(ppl) engine.terminate() - run_results["scored_results"] = scored_results - - with open(output_filename, "w") as f: - json.dump(run_results, f, ensure_ascii=False, indent=2) - - compute_file_ppl(output_filename) - end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) + logger.info("wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" % (ppl)) # noqa: E501 def _get_parser(): From 06bcac68aff57d6613baedf6c5eeb9fa9306f1fd Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Fri, 19 Jan 2024 08:53:39 +0100 Subject: [PATCH 2/7] applied black --- eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 8a8afa4d5a..898f612688 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -13,9 +13,10 @@ def tokenize_dataset(opt, context_length): # Clean and Concat the dataset x = open(opt.src, "r").readlines() - xx = [_x for _x in x if _x != ' \n'] + xx = [_x for _x in x if _x != " \n"] print(xx[:2]) from onmt.transforms.tokenize import SentencePieceTransform + tokenizer = SentencePieceTransform(opt) tokenizer.warm_up() tokens = tokenizer._tokenize(xx) @@ -48,13 +49,13 @@ def evaluate(opt): max_seq_length = 4096 max_seq_length = 1000 seq_len = len(tokens) - print('seq_len: ', seq_len) + print("seq_len: ", seq_len) score_results = [] nlls = [] src = [] for begin_loc in range(0, seq_len, stride): end_loc = min(begin_loc + max_seq_length - 1, seq_len) - src.append(' '.join(tokens[begin_loc:end_loc])) + src.append(" ".join(tokens[begin_loc:end_loc])) start_time = time.time() score_results = engine.score_list(src=src) @@ -65,7 +66,10 @@ def evaluate(opt): engine.terminate() 
end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) - logger.info("wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" % (ppl)) # noqa: E501 + logger.info( + "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" + % (ppl) + ) # noqa: E501 def _get_parser(): From de2a558737b15b02c539e1867637e521160b7d9c Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Fri, 19 Jan 2024 08:59:43 +0100 Subject: [PATCH 3/7] fixed flake error --- eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 898f612688..415c5479b8 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -47,7 +47,7 @@ def evaluate(opt): # Score the dataset. stride = 512 max_seq_length = 4096 - max_seq_length = 1000 + max_seq_length = 2048 seq_len = len(tokens) print("seq_len: ", seq_len) score_results = [] @@ -67,9 +67,9 @@ def evaluate(opt): end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) logger.info( - "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" + "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" # noqa: E501 % (ppl) - ) # noqa: E501 + ) def _get_parser(): From 68c22af0cd742d90910154d049e0945e021de80b Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Mon, 22 Jan 2024 16:48:12 +0100 Subject: [PATCH 4/7] fixed tokenization --- eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py | 12 +++--------- onmt/translate/translator.py | 9 +++++---- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 415c5479b8..e278e81858 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -10,19 +10,15 @@ def tokenize_dataset(opt, context_length): print("Tokenization...") - # Clean and Concat the dataset x = open(opt.src, "r").readlines() xx = [_x for _x in x if _x != " \n"] - print(xx[:2]) from onmt.transforms.tokenize import SentencePieceTransform tokenizer = SentencePieceTransform(opt) tokenizer.warm_up() tokens = tokenizer._tokenize(xx) print("Done !") - print(len(tokens)) - print(tokens[:100]) return tokens @@ -46,23 +42,21 @@ def evaluate(opt): # Score the dataset. 
stride = 512 - max_seq_length = 4096 max_seq_length = 2048 + engine_opt.batch_type = "sents" + engine_opt.batch_size = 1 seq_len = len(tokens) - print("seq_len: ", seq_len) score_results = [] nlls = [] src = [] for begin_loc in range(0, seq_len, stride): - end_loc = min(begin_loc + max_seq_length - 1, seq_len) + end_loc = min(begin_loc + max_seq_length, seq_len) src.append(" ".join(tokens[begin_loc:end_loc])) - start_time = time.time() score_results = engine.score_list(src=src) nlls = [_score for (_score, _length) in score_results] lengths = [_length for (_score, _length) in score_results] ppl = np.exp(-np.sum(nlls) / np.sum(lengths)) - print(ppl) engine.terminate() end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index b8c9f57203..b86ad0315d 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -584,7 +584,7 @@ def _score(self, infer_iter): self.with_scores = True scored_bucket = {} for batch, bucket_idx in infer_iter: - batch_data = self.translate_batch(batch, attn_debug=False) + batch_data = self.translate_batch(batch, attn_debug=False, scoring=True) batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist() batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist() batch_inds_in_bucket = batch["ind_in_bucket"] @@ -1001,8 +1001,9 @@ def _align_forward(self, batch, predictions): """ raise NotImplementedError - def translate_batch(self, batch, attn_debug): + def translate_batch(self, batch, attn_debug, scoring=False): """Translate a batch of sentences.""" + max_length = 0 if scoring else self.max_length with torch.no_grad(): if self.sample_from_topk != 0 or self.sample_from_topp != 0: decode_strategy = GreedySearchLM( @@ -1015,7 +1016,7 @@ def translate_batch(self, batch, attn_debug): batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length, - max_length=self.max_length, + max_length=max_length, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, @@ -1039,7 +1040,7 @@ def translate_batch(self, batch, attn_debug): n_best=self.n_best, global_scorer=self.global_scorer, min_length=self.min_length, - max_length=self.max_length, + max_length=max_length, return_attention=attn_debug or self.replace_unk, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, From 94da04cb74e4c79b008cfe43381d9c9ebe3f68ef Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 23 Jan 2024 12:27:01 +0100 Subject: [PATCH 5/7] zero out the context tokens --- .../WIKITEXT2/run_wikitext-2_benchmark.py | 28 ++++++--- onmt/inference_engine.py | 1 + onmt/translate/translator.py | 63 ++++++++++--------- 3 files changed, 57 insertions(+), 35 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index e278e81858..4438d55b57 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -23,7 +23,13 @@ def tokenize_dataset(opt, context_length): def evaluate(opt): - """Score the wikitext2 testset""" + """Score the wikitext2 testset + + The perplexity of the file is calculated with a window size of max_seq_length = 2048 tokens. + At each step, the window shifts by 512 tokens, and its first max_seq_length - stride + tokens are considered context tokens. 
This means that their logits are not + taken into account, allowing this rolling perplexity to be calculated without overlap.""" + ArgumentParser.validate_translate_opts(opt) ArgumentParser._get_all_transform_translate(opt) ArgumentParser._validate_transforms_opts(opt) @@ -43,20 +49,28 @@ def evaluate(opt): # Score the dataset. stride = 512 max_seq_length = 2048 - engine_opt.batch_type = "sents" - engine_opt.batch_size = 1 + seq_len = len(tokens) - score_results = [] - nlls = [] src = [] for begin_loc in range(0, seq_len, stride): end_loc = min(begin_loc + max_seq_length, seq_len) src.append(" ".join(tokens[begin_loc:end_loc])) + start_time = time.time() + engine.translator.return_gold_log_probs = True score_results = engine.score_list(src=src) - nlls = [_score for (_score, _length) in score_results] - lengths = [_length for (_score, _length) in score_results] + nlls = [] + lengths = [] + for _, log_probs, _ in score_results: + lengths.append(512) + # zero out the context tokens + nlls += [ + log_probs[i][0] + for i, _ in enumerate(log_probs) + if i > (max_seq_length - stride) + ] ppl = np.exp(-np.sum(nlls) / np.sum(lengths)) + engine.terminate() end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py index a87b4ae76a..b088f497ef 100755 --- a/onmt/inference_engine.py +++ b/onmt/inference_engine.py @@ -163,6 +163,7 @@ def _translate(self, infer_iter): def _score(self, infer_iter): self.translator.with_scores = True + self.return_gold_log_probs = True return self.translator._score(infer_iter) def score_list_parallel(self, src): diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index b86ad0315d..145ea1eab5 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -133,6 +133,7 @@ def __init__( logger=None, seed=-1, with_score=False, + return_gold_log_probs=False, ): self.model = model self.vocabs = vocabs @@ -205,6 +206,8 @@ def __init__( set_random_seed(seed, self._use_cuda) self.with_score = with_score + self.return_gold_log_probs = return_gold_log_probs + @classmethod def from_opt( cls, @@ -280,26 +283,16 @@ def _log(self, msg): print(msg) def _gold_score( - self, - batch, - enc_out, - src_len, - use_src_map, - enc_final_hs, - batch_size, - src, + self, batch, enc_out, src_len, use_src_map, enc_final_hs, batch_size, src ): if "tgt" in batch.keys() and not self.tgt_file_prefix: - gs = self._score_target( - batch, - enc_out, - src_len, - batch["src_map"] if use_src_map else None, + gs, glp = self._score_target( + batch, enc_out, src_len, batch["src_map"] if use_src_map else None ) self.model.decoder.init_state(src, enc_out, enc_final_hs) else: gs = [0] * batch_size - return gs + return gs, glp def _translate( self, @@ -586,10 +579,23 @@ def _score(self, infer_iter): for batch, bucket_idx in infer_iter: batch_data = self.translate_batch(batch, attn_debug=False, scoring=True) batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist() + if self.return_gold_log_probs: + batch_gold_log_probs = ( + batch_data["gold_log_probs"].cpu().numpy().tolist() + ) + else: + batch_gold_log_probs = None batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist() batch_inds_in_bucket = batch["ind_in_bucket"] for i, _score in enumerate(batch_gold_scores): - scored_bucket[batch_inds_in_bucket[i]] = (_score, batch_tgt_lengths[i]) + log_probs = ( + batch_gold_log_probs[i] if self.return_gold_log_probs else None + ) + scored_bucket[batch_inds_in_bucket[i]] = ( + _score, + 
log_probs, + batch_tgt_lengths[i], + ) score_results = [scored_bucket[i] for i in range(len(scored_bucket))] return score_results @@ -720,6 +726,7 @@ def _score_target(self, batch, enc_out, src_len, src_map): def report_results( self, gold_score, + gold_log_probs, batch, batch_size, decode_strategy, @@ -730,6 +737,7 @@ def report_results( "attention": None, "batch": batch, "gold_score": gold_score, + "gold_log_probs": gold_log_probs, } results["scores"] = decode_strategy.scores @@ -900,7 +908,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): self.model.decoder.init_state(src, enc_out, enc_final_hs) - gold_score = self._gold_score( + gold_score, gold_log_probs = self._gold_score( batch, enc_out, src_len, @@ -961,6 +969,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): return self.report_results( gold_score, + gold_log_probs, batch, batch_size, decode_strategy, @@ -982,7 +991,7 @@ def _score_target(self, batch, enc_out, src_len, src_map): gold = tgt[:, 1:, :] gold_scores = log_probs.gather(2, gold) gold_scores = gold_scores.sum(dim=1).view(-1) - return gold_scores + return gold_scores, None class GeneratorLM(Inference): @@ -1096,14 +1105,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): # (2) init decoder self.model.decoder.init_state(src, None, None) - gold_score = self._gold_score( - batch, - None, - src_len, - use_src_map, - None, - batch_size, - src, + gold_score, gold_log_probs = self._gold_score( + batch, None, src_len, use_src_map, None, batch_size, src ) # (3) prep decode_strategy. Possibly repeat src objects. @@ -1159,6 +1162,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True): return self.report_results( gold_score, + gold_log_probs, batch, batch_size, decode_strategy, @@ -1178,7 +1182,10 @@ def _score_target(self, batch, enc_out, src_len, src_map): ) log_probs[:, :, self._tgt_pad_idx] = 0 - gold_scores = log_probs.gather(2, tgt) - gold_scores = gold_scores.sum(dim=1).view(-1) + gold_log_probs = log_probs.gather(2, tgt) + gold_scores = gold_log_probs.sum(dim=1).view(-1) + + if self.return_gold_log_probs: + return gold_scores, gold_log_probs - return gold_scores + return gold_scores, None From e7207bb8f1a97ed353e9b18eb812d608fc3ae057 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 23 Jan 2024 14:31:35 +0100 Subject: [PATCH 6/7] fixed unit test error --- eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py | 4 ++-- onmt/tests/pull_request_chk.sh | 2 +- onmt/translate/translator.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 4438d55b57..86ed6183b9 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -25,7 +25,7 @@ def tokenize_dataset(opt, context_length): def evaluate(opt): """Score the wikitext2 testset - The perplexity of the file is calculated with a window size of max_seq_length = 2048 tokens. + The perplexity of the file is calculated with a window size of max_seq_length = 4096 tokens. At each step, the window shifts by 512 tokens, and its first max_seq_length - stride tokens are considered context tokens. This means that their logits are not taken into account, allowing this rolling perplexity to be calculated without overlap.""" @@ -48,7 +48,7 @@ def evaluate(opt): # Score the dataset. 
stride = 512 - max_seq_length = 2048 + max_seq_length = 4096 seq_len = len(tokens) src = [] diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index 99bfc81680..29293b925c 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -558,7 +558,7 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model_lm.pt \ -ban_unk_token \ -length_penalty none \ -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1 -diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(python -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling +diff ${DATA_DIR}/data_lm/gen-nucleus-sampling-sol$(${PYTHON} -c "import torch; print(torch.__version__[0])").txt $TMP_OUT_DIR/gen_sampling [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/gen_sampling diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 145ea1eab5..85a8dc1ad9 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -292,6 +292,7 @@ def _gold_score( self.model.decoder.init_state(src, enc_out, enc_final_hs) else: gs = [0] * batch_size + glp = None return gs, glp def _translate( From e5829dbfab8f01dc48930c4746344d09166b3a34 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 23 Jan 2024 16:11:03 +0100 Subject: [PATCH 7/7] some code cleaning --- eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 86ed6183b9..2ff4d28d2e 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -27,7 +27,7 @@ def evaluate(opt): The perplexity of the file is calculated with a window size of max_seq_length = 4096 tokens. At each step, the window shifts by 512 tokens, and its first max_seq_length - stride - tokens are considered context tokens. This means that their logits are not + tokens are considered as context tokens. This means that their logits are not taken into account, allowing this rolling perplexity to be calculated without overlap.""" ArgumentParser.validate_translate_opts(opt) @@ -62,7 +62,7 @@ def evaluate(opt): nlls = [] lengths = [] for _, log_probs, _ in score_results: - lengths.append(512) + lengths.append(stride) # zero out the context tokens nlls += [ log_probs[i][0]
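
For readers following the series, below is a minimal standalone sketch of the strided rolling-perplexity computation that the docstring and the final hunks describe (window of max_seq_length tokens advancing by stride, with the leading context tokens excluded from the score). It is illustrative only and not part of the patches: score_window is a hypothetical stand-in for engine.score_list that is assumed to return one log-probability per token of the window.

import numpy as np

def rolling_perplexity(tokens, score_window, max_seq_length=4096, stride=512):
    # The window advances by `stride` tokens; the first
    # max_seq_length - stride tokens of each window are context whose
    # log-probabilities are discarded, so no token is scored twice.
    nlls, lengths = [], []
    context = max_seq_length - stride
    for begin_loc in range(0, len(tokens), stride):
        end_loc = min(begin_loc + max_seq_length, len(tokens))
        window = tokens[begin_loc:end_loc]
        log_probs = score_window(window)   # one log-prob per token in `window`
        scored = log_probs[context:]       # drop the context tokens
        nlls += [-lp for lp in scored]     # negative log-likelihoods
        lengths.append(len(scored))
    return float(np.exp(np.sum(nlls) / np.sum(lengths)))

# Example with a dummy scorer assigning a uniform probability over a 32k
# vocabulary, so the resulting perplexity is simply the vocabulary size.
dummy = lambda window: [-np.log(32000.0)] * len(window)
print(rolling_perplexity(list(range(20000)), dummy))  # ~32000.0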