From 1f59eeb9bb4298aa6857754284d49cecc483ec36 Mon Sep 17 00:00:00 2001 From: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Date: Thu, 18 Jan 2024 16:47:25 +0100 Subject: [PATCH] Fix chatml template (#1248) * first draft * 64 * sourabs suggestion * wip tests * make style happy * add check * docstring * fix docstring * Update tests/test_model_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * move tests * add todo for abstract class * make style happy * add slow tests and imports * add documentation * sft_trainer.mdx aktualisieren Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fix template & add test --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- tests/test_dataset_formatting.py | 17 +++++++++++++++++ trl/models/utils.py | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index 63198f32c2..84d34e393e 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -147,3 +147,20 @@ def test_setup_chat_format(self): self.assertTrue(len(modified_tokenizer) == original_tokenizer_len + 2) self.assertTrue(self.model.get_input_embeddings().weight.shape[0] % 64 == 0) self.assertTrue(self.model.get_input_embeddings().weight.shape[0] == original_tokenizer_len + 64) + + def test_example_with_setup_model(self): + modified_model, modified_tokenizer = setup_chat_format( + self.model, + self.tokenizer, + ) + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help you?"}, + ] + prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) + + self.assertEqual( + prompt, + "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n", + ) diff --git a/trl/models/utils.py b/trl/models/utils.py index f667ba9d77..2ec88845c3 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -29,10 +29,10 @@ def assistant(self): def chat_template(self): return ( "{% for message in messages %}" - f"{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}" + f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" "{% endfor %}" "{% if add_generation_prompt %}" - f"{{ '{self.assistant}\n' }}" + f"{{{{ '{self.assistant}\n' }}}}" "{% endif %}" )