diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8fe76dd674..f0ab08f6c1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -16,6 +16,7 @@ import random import warnings from collections import defaultdict +from copy import deepcopy from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch @@ -345,7 +346,8 @@ def make_inputs_require_grad(module, input, output): def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin - config_kwargs = deepspeed_plugin.deepspeed_config + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + if model is not None: if hasattr(model, "config"): hidden_size = (