From 928d14445e31b3586ce8b73ca70ecb02dc603369 Mon Sep 17 00:00:00 2001
From: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
Date: Thu, 18 Jan 2024 11:05:32 +0100
Subject: [PATCH] Add `setup_chat_format` for adding new special tokens to model for training chat models (#1242)

* first draft

* 64

* sourab's suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 docs/source/sft_trainer.mdx      | 23 +++++++++
 tests/slow/test_sft_slow.py      | 44 ++++++++++++++++++
 tests/test_dataset_formatting.py | 27 ++++++++++-
 trl/__init__.py                  |  1 +
 trl/models/__init__.py           |  1 +
 trl/models/utils.py              | 80 ++++++++++++++++++++++++++++++++
 6 files changed, 175 insertions(+), 1 deletion(-)
 create mode 100644 trl/models/utils.py

diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index 9c4df57677..7469661534 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -154,6 +154,29 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
 data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
 ```
 
+### Add Special Tokens for Chat Format
+
+Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
+The [`setup_chat_format`] function in `trl` sets up a model and tokenizer for conversational AI tasks. This function:
+- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a message.
+- Resizes the model’s embedding layer to accommodate the new tokens.
+- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like structure. The default is `chatml` from OpenAI.
+- _Optionally_, you can pass `resize_to_multiple_of` to pad the embedding layer to a multiple of that value, e.g. 64. If you want to see more formats supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl).
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import setup_chat_format
+
+# Load model and tokenizer
+model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+# Set up the chat format with the default 'chatml' format
+model, tokenizer = setup_chat_format(model, tokenizer)
+```
+
+With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
+
 ### Dataset format support
 
 The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset directly to the trainer without any pre-processing.
 The following formats are supported:
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 68299dbe34..1b88bf9130 100644
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -23,6 +23,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
 
 from trl import SFTTrainer, is_peft_available
+from trl.models.utils import setup_chat_format
 
 from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
 from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
@@ -344,3 +345,46 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
         trainer.train()
 
         release_memory(model, trainer)
+
+    @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
+    @require_peft
+    @require_bitsandbytes
+    def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
+        """
+        Simply tests that passing a transformers model prepared with `setup_chat_format`, together with a
+        peft + bnb config, to `SFTTrainer` loads and runs the trainer as expected.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")
+
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                logging_strategy="no",
+                report_to="none",
+                per_device_train_batch_size=2,
+                max_steps=10,
+                fp16=True,
+            )
+
+            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+
+            model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            model, tokenizer = setup_chat_format(model, tokenizer)
+
+            trainer = SFTTrainer(
+                model,
+                args=args,
+                tokenizer=tokenizer,
+                train_dataset=train_dataset,
+                packing=packing,
+                max_seq_length=self.max_seq_length,
+                peft_config=self.peft_config,
+            )
+
+            self.assertTrue(isinstance(trainer.model, PeftModel))
+
+            trainer.train()
+
+            release_memory(model, trainer)
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 156afdef09..63198f32c2 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -2,9 +2,10 @@
 from typing import Callable
 
 from datasets import Dataset, load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
+from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
 
 
 class DatasetFormattingTestCase(unittest.TestCase):
@@ -122,3 +123,27 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
         dataset = Dataset.from_dict({"text": "test"})
         formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
         self.assertIsNone(formatting_func)
+
+
+class SetupChatFormatTestCase(unittest.TestCase):
+    def setUp(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+
+    def test_setup_chat_format(self):
+        original_tokenizer_len = len(self.tokenizer)
+        modified_model, modified_tokenizer = setup_chat_format(
+            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
+        )
+
+        _chatml = ChatMlSpecialTokens()
+        # Check if special tokens are correctly set
+        self.assertTrue(modified_tokenizer.eos_token == "<|im_end|>")
+        self.assertTrue(modified_tokenizer.pad_token == "<|im_end|>")
+        self.assertTrue(modified_tokenizer.bos_token == "<|im_start|>")
+        self.assertTrue(modified_tokenizer.eos_token == _chatml.eos_token)
+        self.assertTrue(modified_tokenizer.pad_token == _chatml.pad_token)
+        self.assertTrue(modified_tokenizer.bos_token == _chatml.bos_token)
+        self.assertTrue(len(modified_tokenizer) == original_tokenizer_len + 2)
+        self.assertTrue(modified_model.get_input_embeddings().weight.shape[0] % 64 == 0)
+        self.assertTrue(modified_model.get_input_embeddings().weight.shape[0] == original_tokenizer_len + 64)
diff --git a/trl/__init__.py b/trl/__init__.py
index 8a6b789be0..6567d44d30 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -18,6 +18,7 @@
     AutoModelForSeq2SeqLMWithValueHead,
     PreTrainedModelWrapper,
     create_reference_model,
+    setup_chat_format,
 )
 from .trainer import (
     DataCollatorForCompletionOnlyLM,
diff --git a/trl/models/__init__.py b/trl/models/__init__.py
index 6ccce25e5e..ec20345533 100644
--- a/trl/models/__init__.py
+++ b/trl/models/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 from .modeling_base import PreTrainedModelWrapper, create_reference_model
 from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
+from .utils import setup_chat_format
 
 
 SUPPORTED_ARCHITECTURES = (
diff --git a/trl/models/utils.py b/trl/models/utils.py
new file mode 100644
index 0000000000..f667ba9d77
--- /dev/null
+++ b/trl/models/utils.py
@@ -0,0 +1,80 @@
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+# TODO: Add Abstract Base Class if more formats are added
+@dataclass
+class ChatMlSpecialTokens:
+    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
+
+    bos_token: str = "<|im_start|>"
+    eos_token: str = "<|im_end|>"
+    pad_token: str = "<|im_end|>"
+
+    @property
+    def system(self):
+        return f"{self.bos_token}system"
+
+    @property
+    def user(self):
+        return f"{self.bos_token}user"
+
+    @property
+    def assistant(self):
+        return f"{self.bos_token}assistant"
+
+    @property
+    def chat_template(self):
+        return (
+            "{% for message in messages %}"
+            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}}}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}"
+            f"{{{{ '{self.assistant}\n' }}}}"
+            "{% endif %}"
+        )
+
+
+FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
+
+
+def setup_chat_format(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    format: Optional[Literal["chatml"]] = "chatml",
+    resize_to_multiple_of: Optional[int] = None,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    """
+    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
+    embedding layer of the model based on the new special tokens.
+
+    Args:
+        model (`~transformers.PreTrainedModel`): The model to be modified.
+        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
+        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
+        resize_to_multiple_of (`Optional[int]`): Number to pad the embedding layer size to a multiple of. Defaults to None.
+
+    Returns:
+        model (`~transformers.PreTrainedModel`): The modified model.
+        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
+ """ + # check if format available and retrieve + if format not in FORMAT_MAPPING: + raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}") + + chat_format = FORMAT_MAPPING[format]() + + # set special tokens and them + tokenizer.eos_token = chat_format.eos_token + tokenizer.pad_token = chat_format.pad_token + tokenizer.bos_token = chat_format.bos_token + tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) + # set chat format for tokenizer + tokenizer.chat_template = chat_format.chat_template + + # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 + model.resize_token_embeddings( + len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None + ) + + return model, tokenizer