Commit

more tests for DPO
younesbelkada committed Jan 12, 2024
1 parent 4582f4f commit d82983e
Showing 1 changed file with 74 additions and 1 deletion.
75 changes: 74 additions & 1 deletion tests/slow/test_dpo_slow.py
@@ -24,7 +24,7 @@

from trl import DPOTrainer, is_peft_available

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
from .testing_constants import (
    DPO_GEN_DURING_EVAL,
    DPO_LOSS_TYPES,
@@ -228,3 +228,76 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
            trainer.save_model()

        release_memory(model, trainer)


@require_torch_multi_gpu
class DPOTrainerSlowTesterMultiGPU(DPOTrainerSlowTester):
    @parameterized.expand(
        list(
            itertools.product(
                MODELS_TO_TEST,
                DPO_LOSS_TYPES,
                DPO_PRECOMPUTE_LOGITS,
            )
        )
    )
    @require_bitsandbytes
    @require_peft
    def test_dpo_peft_model_qlora_multi(self, model_id, loss_type, pre_compute_logits):
"""
A test that tests the simple usage of `DPOTrainer` using QLoRA + different scenarios of gradient checkpointing.
"""
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# Currently accelerate does not let you let's say the reference and active model
# on different devices. The canonical way to do so is either load them on the same device
# or just use a single model with PEFT.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=quantization_config, device_map={"": 0}
        )
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=quantization_config, device_map={"": 0}
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=2,
                remove_unused_columns=False,
                gradient_accumulation_steps=2,
                learning_rate=9e-1,
                evaluation_strategy="steps",
                fp16=True,
                logging_strategy="no",
                report_to="none",
                gradient_checkpointing=True,
            )

            # train a DPO model with a LoRA adapter
            trainer = DPOTrainer(
                model=model,
                ref_model=ref_model,
                beta=0.1,
                args=training_args,
                tokenizer=tokenizer,
                train_dataset=self.dataset,
                eval_dataset=self.dataset,
                generate_during_eval=False,
                loss_type=loss_type,
                precompute_ref_log_probs=pre_compute_logits,
                peft_config=self.peft_config,
                max_length=self.max_length,
            )

            self.assertIsInstance(trainer.model, PeftModel)

            # train the model
            trainer.train()

            # save the trained model or adapter
            trainer.save_model()

        release_memory(model, trainer)
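
The in-code comment above notes that accelerate does not currently support placing the reference and active models on different devices, and that the canonical alternative is a single PEFT model. Below is a minimal, hypothetical sketch of that single-model setup outside the test harness; the checkpoint name, toy dataset, and LoRA settings are illustrative placeholders, and it assumes the `DPOTrainer` behavior of using the base weights (adapters disabled) as the implicit reference when `ref_model=None` and a `peft_config` is supplied.

# Hypothetical sketch: one quantized PEFT model acts as both the policy and the
# implicit reference, so no second copy of the weights is loaded.
import tempfile

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import DPOTrainer

model_id = "org/tiny-llama-for-testing"  # placeholder checkpoint, not one of the test constants
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config, device_map={"": 0}
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# A toy preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello", "How are you?"],
        "chosen": [" Hi, nice to meet you.", " I am fine, thanks."],
        "rejected": [" Leave me alone.", " None of your business."],
    }
)

peft_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = TrainingArguments(
        output_dir=tmp_dir,
        per_device_train_batch_size=2,
        max_steps=2,
        remove_unused_columns=False,
        fp16=True,
        report_to="none",
    )

    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # base weights with adapters disabled serve as the reference
        beta=0.1,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        peft_config=peft_config,
        max_length=128,
    )
    trainer.train()

Compared with the test above, this keeps only one set of quantized weights in memory, which is why the comment describes the single PEFT model as the canonical setup.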
