peft_module_casting_to_bf16 util method, append_concat_token flag, remove callback `PeftSavingCallback` (#1110)

* SFT Trainer enhancements

* remove the callback `PeftSavingCallback`

* bump the version of transformers to `4.31.0`

* remove `PeftSavingCallback` from all places.
pacman100 authored Dec 19, 2023
1 parent d708ec2 commit f100ca3
Showing 6 changed files with 30 additions and 52 deletions.
27 changes: 0 additions & 27 deletions docs/source/sft_trainer.mdx
@@ -279,33 +279,6 @@ trainer = SFTTrainer(
trainer.train()
```

Note that when training adapters, we manually add a saving callback to automatically save only the adapters:
```python
class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
kwargs["model"].save_pretrained(checkpoint_path)

if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
```
If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training.
```python
...

callbacks = [YourCustomCallback(), PeftSavingCallback()]

trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
dataset_text_field="text",
peft_config=peft_config,
callbacks=callbacks
)

trainer.train()
```

You can also continue training your `PeftModel`. To do so, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without passing the `peft_config` argument.
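
A minimal sketch of that workflow, assuming an adapter that was previously saved to a local directory (the adapter path and dataset below are placeholders):

```python
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM

from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

# Load the base model, then attach the previously trained adapters in trainable mode.
base_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
model = PeftModel.from_pretrained(base_model, "path/to/adapter", is_trainable=True)

# Pass the PeftModel directly and omit the `peft_config` argument.
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
)

trainer.train()
```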

### Training adapters with base 8 bit models
2 changes: 1 addition & 1 deletion setup.py
@@ -61,7 +61,7 @@

REQUIRED_PKGS = [
"torch>=1.4.0",
"transformers>=4.18.0",
"transformers>=4.31.0",
"numpy>=1.18.2",
"accelerate",
"datasets",
1 change: 1 addition & 0 deletions trl/trainer/__init__.py
@@ -23,6 +23,7 @@
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
)

# isort: on
8 changes: 1 addition & 7 deletions trl/trainer/reward_trainer.py
@@ -26,7 +26,7 @@

from ..import_utils import is_peft_available
from .training_configs import RewardConfig
from .utils import PeftSavingCallback, RewardDataCollatorWithPadding, compute_accuracy
from .utils import RewardDataCollatorWithPadding, compute_accuracy


if is_peft_available():
@@ -147,12 +147,6 @@ def __init__(

model = get_peft_model(model, peft_config)

if is_peft_available() and isinstance(model, PeftModel):
if callbacks is None:
callbacks = [PeftSavingCallback()]
else:
callbacks += [PeftSavingCallback()]

if compute_metrics is None:
compute_metrics = compute_accuracy

7 changes: 3 additions & 4 deletions trl/trainer/sft_trainer.py
@@ -40,8 +40,8 @@
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
PeftSavingCallback,
neftune_post_forward_hook,
peft_module_casting_to_bf16,
)


@@ -201,9 +201,8 @@ def make_inputs_require_grad(module, input, output):
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

model = get_peft_model(model, peft_config)

if callbacks is None:
callbacks = [PeftSavingCallback]
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)

if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
37 changes: 24 additions & 13 deletions trl/trainer/utils.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import warnings
from collections import deque
@@ -22,7 +21,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase


class AdaptiveKLController:
@@ -204,6 +203,7 @@ class RewardDataCollatorWithPadding:
return_tensors (`str`, `optional`, defaults to `"pt"`):
The tensor type to use.
"""

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
max_length: Optional[int] = None
@@ -281,6 +281,7 @@ class DPODataCollatorWithPadding:
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
Whether or not your model has an encoder-decoder architecture.
"""

pad_token_id: int = 0
label_pad_token_id: int = -100
is_encoder_decoder: Optional[bool] = False
@@ -358,6 +359,8 @@ class ConstantLengthDataset(IterableDataset):
Id of the end of sequence token if the passed tokenizer does not have an EOS token.
shuffle ('bool', *optional*, defaults to True)
Shuffle the examples before they are returned
append_concat_token ('bool', *optional*, defaults to True)
If true, appends `eos_token_id` at the end of each sample being packed.
"""

def __init__(
@@ -372,6 +375,7 @@ def __init__(
chars_per_token=3.6,
eos_token_id=0,
shuffle=True,
append_concat_token=True,
):
self.tokenizer = tokenizer

@@ -388,6 +392,7 @@ def __init__(
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.shuffle = shuffle
self.append_concat_token = append_concat_token
if formatting_func is None:
self.formatting_func = lambda x: x[dataset_text_field]
else:
@@ -424,7 +429,9 @@ def __iter__(self):
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
if self.append_concat_token:
tokenized_input = tokenized_input + [self.concat_token_id]
all_token_ids.extend(tokenized_input)
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
@@ -440,16 +447,6 @@
}


class PeftSavingCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
if args.should_save:
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
kwargs["model"].save_pretrained(checkpoint_path)

if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class RunningMoments:
def __init__(self, accelerator):
"""
@@ -620,3 +617,17 @@ def neftune_post_forward_hook(module, input, output):
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output


def peft_module_casting_to_bf16(model):
from peft.tuners.tuners_utils import BaseTunerLayer

for name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
module = module.to(torch.bfloat16)
if "norm" in name:
module = module.to(torch.float32)
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if hasattr(module, "weight"):
if module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
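
For reference, a minimal usage sketch of the new `append_concat_token` flag, assuming the `EleutherAI/gpt-neo-125m` checkpoint referenced in the docs above and the `imdb` dataset as a stand-in corpus; the sequence length is arbitrary:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from trl.trainer.utils import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
dataset = load_dataset("imdb", split="train")

# Pack examples into fixed-length sequences; with append_concat_token=False,
# no eos_token_id is inserted between the concatenated samples.
packed_dataset = ConstantLengthDataset(
    tokenizer,
    dataset,
    dataset_text_field="text",
    seq_length=512,
    append_concat_token=False,
)

# Each packed example is a 1-D tensor of length seq_length.
print(next(iter(packed_dataset))["input_ids"].shape)
```

The new `peft_module_casting_to_bf16` helper, by contrast, is wired into `SFTTrainer` itself: as the `trl/trainer/sft_trainer.py` hunk above shows, it is called when `args.bf16` is set and the base model was loaded in 4-bit.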
