diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 63198f32c2..84d34e393e 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -147,3 +147,20 @@ def test_setup_chat_format(self):
         self.assertTrue(len(modified_tokenizer) == original_tokenizer_len + 2)
         self.assertTrue(self.model.get_input_embeddings().weight.shape[0] % 64 == 0)
         self.assertTrue(self.model.get_input_embeddings().weight.shape[0] == original_tokenizer_len + 64)
+
+    def test_example_with_setup_model(self):
+        modified_model, modified_tokenizer = setup_chat_format(
+            self.model,
+            self.tokenizer,
+        )
+        messages = [
+            {"role": "system", "content": "You are helpful"},
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi, how can I help you?"},
+        ]
+        prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
+
+        self.assertEqual(
+            prompt,
+            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
+        )
diff --git a/trl/models/utils.py b/trl/models/utils.py
index f667ba9d77..2ec88845c3 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -29,10 +29,10 @@ def assistant(self):
     def chat_template(self):
         return (
             "{% for message in messages %}"
-            f"{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}"
+            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
             "{% endfor %}"
             "{% if add_generation_prompt %}"
-            f"{{ '{self.assistant}\n' }}"
+            f"{{{{ '{self.assistant}\n' }}}}"
             "{% endif %}"
         )
 