Skip to content

Commit

Permalink
Add setup_chat_format for adding new special tokens to model for tr…
Browse files Browse the repository at this point in the history
…aining chat models (#1242)

* first draft

* 64

* sourabs suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <[email protected]>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* sft_trainer.mdx aktualisieren

Co-authored-by: Younes Belkada <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
  • Loading branch information
philschmid and younesbelkada authored Jan 18, 2024
1 parent 3319993 commit 928d144
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 1 deletion.
23 changes: 23 additions & 0 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,29 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```

### Add Special Tokens for Chat Format

Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of each message in a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _Optionally_, you can pass `resize_to_multiple_of` to pad the size of the embedding layer to a multiple of the given value, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Set up the chat format with default 'chatml' format.
# The (possibly resized) model and the updated tokenizer are returned.
model, tokenizer = setup_chat_format(model, tokenizer)

```

With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.

### Dataset format support

The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
Expand Down
44 changes: 44 additions & 0 deletions tests/slow/test_sft_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer, is_peft_available
from trl.models.utils import setup_chat_format

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
Expand Down Expand Up @@ -344,3 +345,46 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
trainer.train()

release_memory(model, trainer)

@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
@require_peft
@require_bitsandbytes
def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
    """
    Smoke test for `setup_chat_format` combined with a 4-bit quantized transformers model and a peft config:
    the `SFTTrainer` must wrap the model in a `PeftModel` and complete a short training run.
    """
    with tempfile.TemporaryDirectory() as output_dir:
        # 4-bit quantized base model and its tokenizer
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        base_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
        base_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Install the chat special tokens / template before handing over to the trainer
        model, tokenizer = setup_chat_format(base_model, base_tokenizer)

        dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")

        training_args = TrainingArguments(
            output_dir=output_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=10,
            fp16=True,
        )

        trainer = SFTTrainer(
            model,
            args=training_args,
            tokenizer=tokenizer,
            train_dataset=dataset,
            packing=packing,
            max_seq_length=self.max_seq_length,
            peft_config=self.peft_config,
        )

        self.assertTrue(isinstance(trainer.model, PeftModel))

        trainer.train()

    release_memory(model, trainer)
27 changes: 26 additions & 1 deletion tests/test_dataset_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Callable

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.models.utils import ChatMlSpecialTokens, setup_chat_format


class DatasetFormattingTestCase(unittest.TestCase):
Expand Down Expand Up @@ -122,3 +123,27 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
dataset = Dataset.from_dict({"text": "test"})
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
self.assertIsNone(formatting_func)


class SetupChatFormatTestCase(unittest.TestCase):
    """Checks that `setup_chat_format` installs the ChatML tokens and resizes the embeddings."""

    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

    def test_setup_chat_format(self):
        original_tokenizer_len = len(self.tokenizer)
        modified_model, modified_tokenizer = setup_chat_format(
            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
        )

        _chatml = ChatMlSpecialTokens()
        # Special tokens must match both the literal ChatML strings and the dataclass defaults.
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
        self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
        self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
        self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
        self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
        self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
        self.assertEqual(modified_tokenizer.chat_template, _chatml.chat_template)

        # Exactly two new tokens (bos and eos; pad reuses the eos string).
        self.assertEqual(len(modified_tokenizer), original_tokenizer_len + 2)

        # Check the embedding matrix of the model that was actually returned.
        embedding_rows = modified_model.get_input_embeddings().weight.shape[0]
        self.assertEqual(embedding_rows % 64, 0)
        # NOTE(review): "+ 64" assumes the original vocabulary size is itself a multiple of 64,
        # so that adding two tokens rounds up by exactly one step — confirm if the tokenizer changes.
        self.assertEqual(embedding_rows, original_tokenizer_len + 64)
1 change: 1 addition & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
setup_chat_format,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
Expand Down
1 change: 1 addition & 0 deletions trl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format


SUPPORTED_ARCHITECTURES = (
Expand Down
80 changes: 80 additions & 0 deletions trl/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

from transformers import PreTrainedModel, PreTrainedTokenizer


# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

    bos_token: str = "<|im_start|>"
    eos_token: str = "<|im_end|>"
    pad_token: str = "<|im_end|>"

    @property
    def system(self) -> str:
        """Prefix that opens a system message: the bos token followed by the role name."""
        return f"{self.bos_token}system"

    @property
    def user(self) -> str:
        """Prefix that opens a user message: the bos token followed by the role name."""
        return f"{self.bos_token}user"

    @property
    def assistant(self) -> str:
        """Prefix that opens an assistant message: the bos token followed by the role name."""
        return f"{self.bos_token}assistant"

    @property
    def chat_template(self) -> str:
        """
        Jinja chat template that renders each message as `<bos><role>\\n<content><eos>\\n` and, when
        `add_generation_prompt` is set, appends the assistant prefix followed by a newline.

        `eos_token` inside the template is a template variable, not interpolated here.
        """
        # Inside an f-string `{{{{` / `}}}}` collapse to `{{` / `}}`, which are the Jinja expression
        # delimiters. With only doubled braces the output would be single braces and Jinja would
        # emit the expression as literal text instead of evaluating it.
        return (
            "{% for message in messages %}"
            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            f"{{{{ '{self.assistant}\n' }}}}"
            "{% endif %}"
        )


# Maps each supported `format` name to the class describing its special tokens and chat template.
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
    embedding layer of the model based on the new special tokens.

    Both `model` and `tokenizer` are modified in place and also returned.

    Args:
        model (`~transformers.PreTrainedModel`): The model to be modified.
        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`Optional[int]`): If set, the embedding layer is padded so that its size is a
            multiple of this number. Defaults to None (no padding).
    Returns:
        model (`~transformers.PreTrainedModel`): The modified model.
        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
    Raises:
        ValueError: If `format` is not a key of `FORMAT_MAPPING`.
    """
    # check if format available and retrieve
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # set the special tokens and register them with the tokenizer
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
    # set chat format for tokenizer
    tokenizer.chat_template = chat_format.chat_template

    # resize the embedding layer to fit the new tokens, optionally padded to a multiple of
    # `resize_to_multiple_of` (e.g. 64, see https://x.com/karpathy/status/1621578354024677377)
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of)

    # Keep the model's own notion of bos/eos/pad in sync with the tokenizer; otherwise the model
    # configuration would still reference the previous token ids (e.g. as stopping criterion).
    for config_name in ("config", "generation_config"):
        model_config = getattr(model, config_name, None)
        if model_config is not None:
            model_config.pad_token_id = tokenizer.pad_token_id
            model_config.bos_token_id = tokenizer.bos_token_id
            model_config.eos_token_id = tokenizer.eos_token_id

    return model, tokenizer

0 comments on commit 928d144

Please sign in to comment.