From b67e492049fc7cfe066146bc58eca037946cb368 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Thu, 18 Jan 2024 10:01:37 +0100 Subject: [PATCH] fix "\n" tokenization + phi-2 new layer names (#2552) --- eval_llm/MMLU/run_mmlu_opennmt.py | 6 +- .../WIKITEXT2/run_wikitext-2_benchmark.py | 5 +- onmt/decoders/ensemble.py | 5 +- onmt/inputters/text_corpus.py | 12 +-- onmt/inputters/text_utils.py | 20 ++--- onmt/transforms/fuzzymatch.py | 2 +- onmt/transforms/inlinetags.py | 14 ++-- onmt/transforms/misc.py | 16 ++-- onmt/transforms/normalize.py | 4 +- onmt/transforms/terminology.py | 16 ++-- onmt/transforms/tokenize.py | 15 +++- onmt/transforms/uppercase.py | 4 +- onmt/translate/translation_server.py | 8 +- onmt/utils/alignment.py | 6 +- tools/convert_HF.py | 78 +++++++++++++------ 15 files changed, 126 insertions(+), 85 deletions(-) diff --git a/eval_llm/MMLU/run_mmlu_opennmt.py b/eval_llm/MMLU/run_mmlu_opennmt.py index bede7475e7..c6be92d900 100644 --- a/eval_llm/MMLU/run_mmlu_opennmt.py +++ b/eval_llm/MMLU/run_mmlu_opennmt.py @@ -167,12 +167,12 @@ def evaluate(opt): prompt_end = format_example(test_df, i, include_answer=False) train_prompt = gen_prompt(dev_df, task, k) prompt = train_prompt + prompt_end - """ - while len(prompt.split()) > 768: + + while len(prompt.split(" ")) > 768: prompt_split = prompt.split("\n\n") prompt_split.pop(1) prompt = "\n\n".join(prompt_split) - """ + label = test_df.iloc[i, test_df.shape[1] - 1] records.append({"prompt": prompt, "answer": label}) src.append(prompt.replace("\n", "⦅newline⦆")) diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 94dca750cb..34663a6e34 100644 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -119,7 +119,7 @@ def evaluate(opt): engine = InferenceEnginePY(engine_opt) # Tokenize the dataset. 
- opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw" + opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw" tokenize_dataset(opt, context_length=512) # Score the tokeznized dataset @@ -140,8 +140,7 @@ def evaluate(opt): def _get_parser(): parser = ArgumentParser(description="run_wikitext-2_benchmark.py") - opts.config_opts(parser) - opts.translate_opts(parser, dynamic=True) + opts.translate_opts(parser) return parser diff --git a/onmt/decoders/ensemble.py b/onmt/decoders/ensemble.py index 3ed4a90f24..55ae134da5 100644 --- a/onmt/decoders/ensemble.py +++ b/onmt/decoders/ensemble.py @@ -65,7 +65,10 @@ def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs): for i, model_decoder in enumerate(self.model_decoders) ] ) - mean_attns = self.combine_attns(attns) + if attns[0]["std"] is not None: + mean_attns = self.combine_attns(attns) + else: + mean_attns = attns return EnsembleDecoderOutput(dec_outs), mean_attns def combine_attns(self, attns): diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py index af0b63b7dc..e38bdfee12 100644 --- a/onmt/inputters/text_corpus.py +++ b/onmt/inputters/text_corpus.py @@ -174,20 +174,20 @@ def __init__( def _process(self, stream): for i, example in enumerate(stream): - example["src"] = example["src"].strip("\n").split() - example["src_original"] = example["src_original"].strip("\n").split() + example["src"] = example["src"].strip().split(" ") + example["src_original"] = example["src_original"].strip().split(" ") if "src_feats" in example: example["src_feats"] = [ - feat.strip("\n").split() for feat in example["src_feats"] + feat.strip().split(" ") for feat in example["src_feats"] ] line_number = i * self.stride + self.offset example["cid_line_number"] = line_number example["cid"] = self.cid if "align" in example: - example["align"] = example["align"].strip("\n").split() + example["align"] = example["align"].strip().split(" ") if example["tgt"] is not None: - example["tgt"] = example["tgt"].strip("\n").split() - example["tgt_original"] = example["tgt_original"].strip("\n").split() + example["tgt"] = example["tgt"].strip().split(" ") + example["tgt_original"] = example["tgt_original"].strip().split(" ") if ( len(example["src"]) == 0 or len(example["tgt"]) == 0 diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py index 42ebb3c1c7..83da07cc62 100644 --- a/onmt/inputters/text_utils.py +++ b/onmt/inputters/text_utils.py @@ -121,31 +121,33 @@ def numericalize(vocabs, example): numeric = example numeric["src"]["src_ids"] = [] if vocabs["data_task"] == ModelTask.SEQ2SEQ: - src_text = example["src"]["src"].split() + src_text = example["src"]["src"].split(" ") numeric["src"]["src_ids"] = vocabs["src"](src_text) if example["tgt"] is not None: numeric["tgt"]["tgt_ids"] = [] - tgt_text = example["tgt"]["tgt"].split() + tgt_text = example["tgt"]["tgt"].split(" ") numeric["tgt"]["tgt_ids"] = vocabs["tgt"]( [decoder_start_token] + tgt_text + [DefaultTokens.EOS] ) elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL: - src_text = example["src"]["src"].split() + src_text = example["src"]["src"].split(" ") if decoder_start_token != "": src_text = [decoder_start_token] + src_text numeric["src"]["src_ids"] = vocabs["src"](src_text) if example["tgt"] is not None: numeric["tgt"]["tgt_ids"] = [] - tgt_text = example["tgt"]["tgt"].split() + tgt_text = example["tgt"]["tgt"].split(" ") numeric["tgt"]["tgt_ids"] = vocabs["tgt"](tgt_text + [DefaultTokens.EOS]) + if decoder_start_token == "": + 
numeric["tgt"]["tgt_ids"] = numeric["tgt"]["tgt_ids"][1:] else: raise ValueError(f"Something went wrong with task {vocabs['data_task']}") if "feats" in example["src"]: numeric_feats = [] for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]): - numeric_feats.append(fv(feat.split())) + numeric_feats.append(fv(feat.split(" "))) numeric["src"]["feats"] = numeric_feats return numeric @@ -329,7 +331,7 @@ def textbatch_to_tensor(vocabs, batch, device, is_train=False): infer_iter = [] for i, ex in enumerate(batch): # Keep it consistent with dynamic data - ex["srclen"] = len(ex["src"]["src"].split()) + ex["srclen"] = len(ex["src"]["src"].split(" ")) ex["in_in_bucket"] = i ex["cid"] = "text" ex["cid_line_number"] = i @@ -354,7 +356,7 @@ def _addcopykeys(vocabs, example): Returns: ``example``, changed as described. """ - src = example["src"]["src"].split() + src = example["src"]["src"].split(" ") src_ex_vocab = pyonmttok.build_vocab_from_tokens( Counter(src), maximum_size=0, @@ -377,10 +379,10 @@ def _addcopykeys(vocabs, example): if vocabs["data_task"] == ModelTask.SEQ2SEQ: tgt = ( [DefaultTokens.UNK] - + example["tgt"]["tgt"].split() + + example["tgt"]["tgt"].split(" ") + [DefaultTokens.UNK] ) elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL: - tgt = example["tgt"]["tgt"].split() + [DefaultTokens.UNK] + tgt = example["tgt"]["tgt"].split(" ") + [DefaultTokens.UNK] example["alignment"] = src_ex_vocab(tgt) return example diff --git a/onmt/transforms/fuzzymatch.py b/onmt/transforms/fuzzymatch.py index 6341156785..3d787f23f1 100644 --- a/onmt/transforms/fuzzymatch.py +++ b/onmt/transforms/fuzzymatch.py @@ -216,6 +216,6 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs): assert len(src_segments) == len(fuzzied_src) for idx, (example, _, _) in enumerate(batch): if fuzzied_src[idx] != "": - example["src"] = fuzzied_src[idx].split() + example["src"] = fuzzied_src[idx].split(" ") return batch diff --git a/onmt/transforms/inlinetags.py b/onmt/transforms/inlinetags.py index 223f0045a8..f34792510b 100644 --- a/onmt/transforms/inlinetags.py +++ b/onmt/transforms/inlinetags.py @@ -73,8 +73,8 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple: maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None ) - tokenized_source_string = source_only.split() - tokenized_target_string = tgt_example.split() + tokenized_source_string = source_only.split(" ") + tokenized_target_string = tgt_example.split(" ") src_offset, tgt_offset = 0, 0 src_with_tags, tgt_with_tags = list(), list() @@ -140,12 +140,12 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple: src_term = " ".join( tokenized_source_string[ - source_index : source_index + len(pair[0].split()) + source_index : source_index + len(pair[0].split(" ")) ] ) tgt_term = " ".join( tokenized_target_string[ - target_index : target_index + len(pair[1].split()) + target_index : target_index + len(pair[1].split(" ")) ] ) @@ -210,11 +210,11 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple: tgt_with_tags.append(tgt_example[tgt_offset:]) return ( - "".join(src_with_tags).replace("∥", " ").split(), - "".join(tgt_with_tags).replace("∥", " ").split(), + "".join(src_with_tags).replace("∥", " ").split(" "), + "".join(tgt_with_tags).replace("∥", " ").split(" "), ), is_match else: - return (src_example.split(), tgt_example.split()), is_match + return (src_example.split(" "), tgt_example.split(" ")), is_match @register_transform(name="inlinetags") diff --git a/onmt/transforms/misc.py b/onmt/transforms/misc.py 
index d526f5f4ac..2614ce758f 100644 --- a/onmt/transforms/misc.py +++ b/onmt/transforms/misc.py @@ -136,8 +136,8 @@ def get_specials(cls, opts): prefix_dict = cls.get_prefix_dict(opts) src_specials, tgt_specials = set(), set() for _, prefix in prefix_dict.items(): - src_specials.update(prefix["src"].split()) - tgt_specials.update(prefix["tgt"].split()) + src_specials.update(prefix["src"].split(" ")) + tgt_specials.update(prefix["tgt"].split(" ")) return (src_specials, tgt_specials) def warm_up(self, vocabs=None): @@ -149,9 +149,9 @@ def _prepend(self, example, prefix): """Prepend `prefix` to `tokens`.""" for side, side_prefix in prefix.items(): if example.get(side) is not None: - example[side] = side_prefix.split() + example[side] + example[side] = side_prefix.split(" ") + example[side] elif len(side_prefix) > 0: - example[side] = side_prefix.split() + example[side] = side_prefix.split(" ") return example def apply(self, example, is_train=False, stats=None, **kwargs): @@ -250,8 +250,8 @@ def get_specials(cls, opts): suffix_dict = cls.get_suffix_dict(opts) src_specials, tgt_specials = set(), set() for _, suffix in suffix_dict.items(): - src_specials.update(suffix["src"].split()) - tgt_specials.update(suffix["tgt"].split()) + src_specials.update(suffix["src"].split(" ")) + tgt_specials.update(suffix["tgt"].split(" ")) return (src_specials, tgt_specials) def warm_up(self, vocabs=None): @@ -263,9 +263,9 @@ def _append(self, example, suffix): """Prepend `suffix` to `tokens`.""" for side, side_suffix in suffix.items(): if example.get(side) is not None: - example[side] = example[side] + side_suffix.split() + example[side] = example[side] + side_suffix.split(" ") elif len(side_suffix) > 0: - example[side] = side_suffix.split() + example[side] = side_suffix.split(" ") return example def apply(self, example, is_train=False, stats=None, **kwargs): diff --git a/onmt/transforms/normalize.py b/onmt/transforms/normalize.py index 3cf11975c5..30ecd244c1 100644 --- a/onmt/transforms/normalize.py +++ b/onmt/transforms/normalize.py @@ -329,7 +329,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs): self.pre_dict[corpus_name], self.post_dict[corpus_name], ) - example["src"] = src_str.split() + example["src"] = src_str.split(" ") if example["tgt"] is not None: tgt_str = self.mpn.normalize( @@ -341,6 +341,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs): self.pre_dict[corpus_name], self.post_dict[corpus_name], ) - example["tgt"] = tgt_str.split() + example["tgt"] = tgt_str.split(" ") return example diff --git a/onmt/transforms/terminology.py b/onmt/transforms/terminology.py index 7d775ec8ad..e0ddaf6c5e 100644 --- a/onmt/transforms/terminology.py +++ b/onmt/transforms/terminology.py @@ -57,7 +57,7 @@ def _create_internal_termbase(self, termbase_path): for pair in pairs: src_term, tgt_term = map(str, pair.split("\t")) src_lemma = " ".join( - "∥".join(tok.lemma_.split()) for tok in self.src_nlp(src_term) + "∥".join(tok.lemma_.split(" ")) for tok in self.src_nlp(src_term) ).strip() tgt_lemma = " ".join( tok.lemma_ for tok in self.tgt_nlp(tgt_term) @@ -93,7 +93,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple: # Perform tokenization with spacy for consistency. 
tokenized_source = [tok.text for tok in doc_src] - lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src] + lemmatized_source = ["∥".join(tok.lemma_.lower().split(" ")) for tok in doc_src] lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt] lemmatized_source_string = " ".join(lemmatized_source) @@ -143,7 +143,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple: lemma_list_index += len(w) + 1 # We need to know if the term is multiword - num_words_in_src_term = len(src_entry.split()) + num_words_in_src_term = len(src_entry.split(" ")) src_term = " ".join( tokenized_source[ lemma_list_index : lemma_list_index + num_words_in_src_term @@ -164,7 +164,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple: if is_match: source_with_terms.append(lemmatized_source_string[offset:]) - tokenized_source_with_terms = "".join(source_with_terms).split() + tokenized_source_with_terms = "".join(source_with_terms).split(" ") if not ( len(tokenized_source) @@ -173,7 +173,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple: ): final_string = " ".join(tokenized_source) fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) - return fixed_punct.split(), not is_match + return fixed_punct.split(" "), not is_match # Construct the final source from the lemmatized list # that contains the terms. We compare the tokens in the @@ -195,17 +195,17 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple: final_string = " ".join( completed_tokenized_source + [self.delimiter] - + augmented_part.split() + + augmented_part.split(" ") ) else: final_string = " ".join(completed_tokenized_source) fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) - return fixed_punct.split(), is_match + return fixed_punct.split(" "), is_match else: final_string = " ".join(tokenized_source) fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) - return fixed_punct.split(), not is_match + return fixed_punct.split(" "), not is_match @register_transform(name="terminology") diff --git a/onmt/transforms/tokenize.py b/onmt/transforms/tokenize.py index 54ad0a88df..17424a94bb 100644 --- a/onmt/transforms/tokenize.py +++ b/onmt/transforms/tokenize.py @@ -283,7 +283,7 @@ def apply_reverse(self, translated): if isinstance(translated, list): return self._detokenize(translated, "tgt") else: - return self._detokenize(translated.split(), "tgt") + return self._detokenize(translated.split(" "), "tgt") def _repr_args(self): """Return str represent key arguments for class.""" @@ -353,7 +353,7 @@ def apply_reverse(self, translated): if isinstance(translated, list): return self._detokenize(translated, "tgt") else: - return self._detokenize(translated.split(), "tgt") + return self._detokenize(translated.split(" "), "tgt") @register_transform(name="onmt_tokenize") @@ -550,7 +550,14 @@ def tokenize_string(self, sentence, side="src", is_train=False): self.maptable[b] for b in sentence.replace(DefaultTokens.SEP, "\n").encode("utf-8") ) - segmented = tokenizer(sentence) + segmented1 = tokenizer(sentence) + segmented = [] + # ugly patch to make sure "\n\n" is split in two items + for s in segmented1: + if s == "ĊĊ": + segmented.extend(["Ċ", "Ċ"]) + else: + segmented.append(s) else: segmented = tokenizer(sentence) return segmented @@ -572,7 +579,7 @@ def apply_reverse(self, translated): if isinstance(translated, list): return self._detokenize(translated, "tgt") else: - return self._detokenize(translated.split(), "tgt") + 
return self._detokenize(translated.split(" "), "tgt") def _repr_args(self): """Return str represent key arguments for class.""" diff --git a/onmt/transforms/uppercase.py b/onmt/transforms/uppercase.py index fb6a96b249..8027efa4e7 100644 --- a/onmt/transforms/uppercase.py +++ b/onmt/transforms/uppercase.py @@ -47,7 +47,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs): for c in unicodedata.normalize("NFD", src_str.upper()) if unicodedata.category(c) != "Mn" ) - example["src"] = src_str.split() + example["src"] = src_str.split(" ") if example["tgt"] is not None: tgt_str = " ".join(example["tgt"]) @@ -56,6 +56,6 @@ def apply(self, example, is_train=False, stats=None, **kwargs): for c in unicodedata.normalize("NFD", tgt_str.upper()) if unicodedata.category(c) != "Mn" ) - example["tgt"] = tgt_str.split() + example["tgt"] = tgt_str.split(" ") return example diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py index 22b5e2b062..6dbb269153 100644 --- a/onmt/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -937,7 +937,7 @@ def maybe_detokenize(self, sequence, side="tgt"): """De-tokenize the sequence (or not) Same args/returns as :func:``tokenize()``""" - if self.tokenizers_opt is not None and "".join(sequence.split()) != "": + if self.tokenizers_opt is not None and "".join(sequence.split(" ")) != "": return self.detokenize(sequence, side) return sequence @@ -950,9 +950,9 @@ def detokenize(self, sequence, side="tgt"): raise ValueError("No tokenizer loaded") if self.tokenizers_opt[side]["type"] == "sentencepiece": - detok = self.tokenizers[side].DecodePieces(sequence.split()) + detok = self.tokenizers[side].DecodePieces(sequence.split(" ")) elif self.tokenizers_opt[side]["type"] == "pyonmttok": - detok = self.tokenizers[side].detokenize(sequence.split()) + detok = self.tokenizers[side].detokenize(sequence.split(" ")) return detok @@ -976,7 +976,7 @@ def maybe_convert_align(self, src, tgt, align, align_scores): "To get decoded alignment, joiner/spacer " "should be used in both side's tokenizer." ) - elif "".join(tgt.split()) != "": + elif "".join(tgt.split(" ")) != "": align = to_word_align( src, tgt, align, align_scores, src_marker, tgt_marker ) diff --git a/onmt/utils/alignment.py b/onmt/utils/alignment.py index 22d8a801eb..5ab92265bb 100644 --- a/onmt/utils/alignment.py +++ b/onmt/utils/alignment.py @@ -115,14 +115,14 @@ def to_word_align( assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!" assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!" 
- src, tgt = src.strip().split(), tgt.strip().split() + src, tgt = src.strip().split(" "), tgt.strip().split(" ") subword_align = { - (int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split()) + (int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split(" ")) } subword_align_scores = dict( (int(a), float(b)) - for a, b in (x.split("-") for x in subword_align_scores.split()) + for a, b in (x.split("-") for x in subword_align_scores.split(" ")) ) src_map = ( diff --git a/tools/convert_HF.py b/tools/convert_HF.py index 6269b9d16c..81973603ba 100755 --- a/tools/convert_HF.py +++ b/tools/convert_HF.py @@ -74,29 +74,20 @@ ".feed_forward.experts.7.layer_norm.weight": ".post_attention_layernorm.weight", } key_maps["PhiForCausalLM"] = { - "layer_prefix": "transformer.h.", - "decoder.embeddings.make_embedding.emb_luts.0.weight": "transformer.embd.wte.weight", - "decoder.layer_norm.weight": "lm_head.ln.weight", - "decoder.layer_norm.bias": "lm_head.ln.bias", - "generator.weight": "lm_head.linear.weight", - "generator.bias": "lm_head.linear.bias", - ".self_attn.linear_query.": ( - ".mixer.Wqkv.", - "[:hidden_size]", # noqa E501 - ), - ".self_attn.linear_keys.": ( - ".mixer.Wqkv.", - "[hidden_size:2*hidden_size]", # noqa E501 - ), - ".self_attn.linear_values.": ( - ".mixer.Wqkv.", - "[-hidden_size:]", # noqa E501 - ), - ".self_attn.final_linear.": ".mixer.out_proj.", + "layer_prefix": "model.layers.", + "decoder.embeddings.make_embedding.emb_luts.0.weight": "model.embed_tokens.weight", + "decoder.layer_norm.weight": "model.final_layernorm.weight", + "decoder.layer_norm.bias": "model.final_layernorm.bias", + "generator.weight": "lm_head.weight", + "generator.bias": "lm_head.bias", + ".self_attn.linear_query.": ".self_attn.q_proj.", + ".self_attn.linear_keys.": ".self_attn.k_proj.", + ".self_attn.linear_values.": ".self_attn.v_proj.", + ".self_attn.final_linear.": ".self_attn.dense.", ".feed_forward.w_1.": ".mlp.fc1.", ".feed_forward.w_2.": ".mlp.fc2.", - ".layer_norm_1.weight": (".ln.weight", ""), - ".layer_norm_1.bias": (".ln.bias", ""), + ".layer_norm_1.weight": (".input_layernorm.weight", ""), + ".layer_norm_1.bias": (".input_layernorm.bias", ""), } ln_table = { "LlamaForCausalLM": "rms", @@ -190,6 +181,10 @@ def __init__(self, model_path: str): "You used a local directory but tokenizer.model", " and/or tokenizer.json are missing", ) + if os.path.exists(os.path.join(opt.model_dir, "tokenizer_config.json")): + tokenizer_config_json = os.path.join(opt.model_dir, "tokenizer_config.json") + else: + tokenizer_config_json = None else: directory_path, _ = os.path.split(opt.output) os.makedirs(directory_path, exist_ok=True) @@ -224,6 +219,17 @@ def __init__(self, model_path: str): raise huggingface_hub.utils.EntryNotFoundError( "Something went wrong the repo does not contain any config.json file" ) + try: + tokenizer_config_json = huggingface_hub.hf_hub_download( + repo_id=opt.model_dir, + filename="tokenizer_config.json", + local_dir=directory_path, + token=opt.token, + ) + except huggingface_hub.utils.EntryNotFoundError: + raise huggingface_hub.utils.EntryNotFoundError( + "Something went wrong the repo does not contain any tokenizer_config.json file" + ) try: wmap_path = huggingface_hub.hf_hub_download( repo_id=opt.model_dir, @@ -325,6 +331,8 @@ def __init__(self, model_path: str): norm_eps = config["rms_norm_eps"] elif "layer_norm_epsilon" in config.keys(): norm_eps = config["layer_norm_epsilon"] + elif "layer_norm_eps" in config.keys(): + norm_eps = config["layer_norm_eps"] 
else: norm_eps = 1e-6 if "rope_theta" in config.keys(): @@ -333,6 +341,8 @@ def __init__(self, model_path: str): rope_theta = 1e4 if "rotary_dim" in config.keys(): rotary_dim = config["rotary_dim"] + elif "partial_rotary_factor" in config.keys(): + rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads)) else: rotary_dim = 0 if "sliding_window" in config.keys(): @@ -404,7 +414,7 @@ def __init__(self, model_path: str): params = ["weight", "bias"] add_qkvbias = False - aff_ffnbias = False + add_ffnbias = False rotary_interleave = False if arch == "PhiForCausalLM": parallel_residual = True @@ -689,11 +699,28 @@ def get_weight(checkpoint, tensor_name): directory_path, _ = os.path.split(opt.output) os.makedirs(directory_path, exist_ok=True) + if tokenizer_config_json is not None: + with open(tokenizer_config_json, encoding="utf-8") as f: + data = json.load(f) + if "add_bos_token" in data.keys(): + add_bos_token = data["add_bos_token"] + else: + add_bos_token = False + else: + add_bos_token = True vocabs = {} if tokenizer_model is not None: tokenizer = Tokenizer(model_path=tokenizer_model) vocab = tokenizer.vocab - vocab[3] = DefaultTokens.PAD + if "<|startoftext|>" in vocab: + index = vocab.index("<|startoftext|>") + vocab[index] = DefaultTokens.BOS + if "<|endoftext|>" in vocab: + index = vocab.index("<|endoftext|>") + vocab[index] = DefaultTokens.EOS + if "<0x00>" in vocab: + index = vocab.index("<0x00>") + vocab[index] = DefaultTokens.PAD src_vocab = pyonmttok.build_vocab_from_tokens( vocab, maximum_size=tokenizer.n_words, @@ -722,7 +749,10 @@ def get_weight(checkpoint, tensor_name): vocabs["src"] = src_vocab vocabs["tgt"] = src_vocab vocabs["data_task"] = "lm" - vocabs["decoder_start_token"] = decoder_start_table[arch] + if add_bos_token: + vocabs["decoder_start_token"] = decoder_start_table[arch] + else: + vocabs["decoder_start_token"] = "" onmt_cp["vocab"] = {} onmt_cp["vocab"] = vocabs_to_dict(vocabs)
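
For reference, the core of the "\n" tokenization fix appears in two places in the patch above: token splitting switches from the default str.split() (which splits on any whitespace, including "\n", and drops empty fields) to split(" "), and the byte-level BPE path in onmt/transforms/tokenize.py expands the merged piece "ĊĊ" (the GPT-2 byte mapping of "\n\n") into two "Ċ" pieces. A minimal, self-contained sketch of that behaviour, assuming plain Python with no OpenNMT-py imports and purely illustrative function names:

# Sketch of the "\n" handling changed in this patch. Names here are
# illustrative; only the logic mirrors the diff above.

def whitespace_tokens(line: str) -> list:
    # str.split() with no argument splits on ANY whitespace (spaces, tabs,
    # "\n") and drops empty fields, so a token stream that encodes newlines
    # does not round-trip. split(" ") cuts only on single spaces, which is
    # what the patch switches to in the inputters and transforms.
    return line.strip().split(" ")

def split_double_newline(pieces: list) -> list:
    # Mirrors the "ugly patch" in onmt/transforms/tokenize.py: GPT-2-style
    # byte-level BPE maps "\n" to "Ċ", so "\n\n" can come back as a single
    # merged piece "ĊĊ"; the fix expands it into two "Ċ" pieces.
    out = []
    for piece in pieces:
        if piece == "ĊĊ":
            out.extend(["Ċ", "Ċ"])
        else:
            out.append(piece)
    return out

if __name__ == "__main__":
    line = "The cat sat ⦅newline⦆ on the mat"
    print(whitespace_tokens(line))
    # ['The', 'cat', 'sat', '⦅newline⦆', 'on', 'the', 'mat']
    print(split_double_newline(["Hello", "ĊĊ", "world"]))
    # ['Hello', 'Ċ', 'Ċ', 'world']

Note that split(" ") also preserves empty fields for consecutive spaces, so the token stream round-trips exactly, including the ⦅newline⦆ placeholder the MMLU script substitutes for "\n".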
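
On the convert_HF.py side, the patch remaps PhiForCausalLM to the renamed Hugging Face checkpoint layout (model.layers.*, q_proj/k_proj/v_proj, dense, input_layernorm, final_layernorm) and, when the config has no explicit rotary_dim, derives it from partial_rotary_factor. A hedged sketch of that fallback, using phi-2-like config values that are assumed here for illustration rather than read from a real checkpoint:

# Sketch of the rotary-dimension fallback added in tools/convert_HF.py.
# The config values are assumptions (phi-2-like), used only as an example.

config = {
    "hidden_size": 2560,
    "num_attention_heads": 32,
    "partial_rotary_factor": 0.4,  # no explicit "rotary_dim" key
}

hidden_size = config["hidden_size"]
heads = config["num_attention_heads"]

if "rotary_dim" in config:
    rotary_dim = config["rotary_dim"]
elif "partial_rotary_factor" in config:
    # Same formula as the patch: a fraction of the per-head dimension.
    rotary_dim = int(config["partial_rotary_factor"] * (hidden_size // heads))
else:
    rotary_dim = 0

print(rotary_dim)  # int(0.4 * (2560 // 32)) = int(0.4 * 80) = 32

The same file also reads add_bos_token from tokenizer_config.json and sets decoder_start_token to an empty string when it is false, alongside the vocab remapping of <|startoftext|>, <|endoftext|>, and <0x00> to the default BOS, EOS, and PAD tokens.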