[Multi-Adapter PPO] Fix and Refactor reward model adapter (#982)
* reward adapter loaded as part of init

  more flexible, clearer args

* fixed script for multi-GPU

  unwrap the model since it is wrapped in DDP; downside: with the reward adapter it seems we need to use find_unused_parameters=True

* remove gradient from reward score calculation

* change supported_args back to None
mnoukhov authored Nov 21, 2023
1 parent aea1da8 commit b307faf
Showing 3 changed files with 102 additions and 83 deletions.
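
The main API change in this commit is that the reward adapter is attached while the model is created, via `from_pretrained`, rather than bolted on after loading. A minimal usage sketch, assuming placeholder model and adapter ids that are not part of this commit:

```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

# Policy adapter trained by PPO.
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "huggyllama/llama-7b",                            # placeholder policy model id
    peft_config=lora_config,
    reward_adapter="trl-lib/llama-7b-hh-rm-adapter",  # placeholder reward adapter id
    reward_adapter_name="reward_adapter",             # new optional kwarg; default shown
)
assert model.supports_rm_adapter  # the frozen score head is now set up during init
```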
7 changes: 4 additions & 3 deletions examples/scripts/ppo_multi_adapter.py
@@ -19,7 +19,7 @@
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer
from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available
from trl.core import LengthSampler
@@ -88,7 +88,7 @@ def tokenize(example):
reward_adapter=script_args.rm_adapter,
use_safetensors=script_args.use_safetensors,
)
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)

tokenizer.pad_token = tokenizer.eos_token

@@ -127,6 +127,7 @@ def collator(data):
"top_p": 0.9,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"max_new_tokens": 32,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
@@ -142,7 +143,7 @@ def collator(data):
# Compute reward score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token

# Run PPO step
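
Taken together, the script changes amount to the reward step below. A hedged sketch, assuming `ppo_trainer`, `tokenizer`, and a `batch` holding generated responses already exist as in the file above:

```python
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
    ppo_trainer.accelerator.device
)

# On multiple GPUs ppo_trainer.model is wrapped in DistributedDataParallel,
# which hides custom methods such as compute_reward_score, so unwrap it first.
policy_model = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model)
raw_rewards = policy_model.compute_reward_score(**inputs)  # gradient-free after this commit

# Per-sample reward: the positive-class score of the last token.
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))]
```

As the commit message notes, running with the extra reward adapter seems to require `find_unused_parameters=True` in the DDP setup.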
174 changes: 96 additions & 78 deletions trl/models/modeling_base.py
@@ -29,7 +29,6 @@

if is_peft_available():
from peft import (
LoraConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
@@ -38,7 +37,6 @@
get_peft_model,
prepare_model_for_kbit_training,
)
from peft.peft_model import set_peft_model_state_dict

if is_transformers_greater_than("4.33.0"):
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -77,7 +75,9 @@ class PreTrainedModelWrapper(nn.Module):
else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM)
)

def __init__(self, pretrained_model=None, **kwargs):
def __init__(
self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs
):
super().__init__()
self.pretrained_model = pretrained_model

@@ -93,6 +93,12 @@ def __init__(self, pretrained_model=None, **kwargs):
if hasattr(pretrained_model, "gradient_checkpointing_enable"):
self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable

self.supports_rm_adapter = supports_rm_adapter
self.rm_adapter_name = rm_adapter_name
self.policy_adapter_name = "default"
if score_module is not None:
self.score = score_module

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
@@ -120,6 +126,7 @@ class and the arguments that are specific to trl models. The kwargs
if kwargs is not None:
peft_config = kwargs.pop("peft_config", None)
reward_adapter = kwargs.pop("reward_adapter", None)
reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
is_trainable = kwargs.pop("is_trainable", False)
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
token = pretrained_kwargs.get("token", None)
@@ -242,8 +249,24 @@ class and the arguments that are specific to trl models. The kwargs
pretrained_model.active_peft_config, PromptLearningConfig
):
raise ValueError("PromptLearningConfig is not supported for PPO training.")

# Add reward modeling adapter if specified
if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
score_module = cls.add_and_load_reward_modeling_adapter(
pretrained_model, reward_adapter, reward_adapter_name, token=token
)
multi_adapter_args = {
"score_module": score_module,
"supports_rm_adapter": True,
"rm_adapter_name": reward_adapter_name,
}
else:
multi_adapter_args = {"supports_rm_adapter": False}

# Then, create the full model by instantiating the wrapper class
model = cls(pretrained_model, **trl_model_args)
model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)

# if resume_training, load the state_dict again - this is ok since the
# state_dict is removed from the model after loading it.
@@ -306,14 +329,6 @@ class and the arguments that are specific to trl models. The kwargs
if is_resuming_training:
model.post_init(state_dict=state_dict)

if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
model.add_and_load_reward_modeling_adapter(reward_adapter, token=token)
model.supports_rm_adapter = True
else:
model.supports_rm_adapter = False

return model

@classmethod
@@ -415,6 +430,62 @@ def _split_kwargs(cls, kwargs):

return supported_kwargs, unsupported_kwargs, peft_kwargs

@classmethod
def add_and_load_reward_modeling_adapter(
cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None
):
r"""
Add and load a reward modeling adapter. This method can only be used if the
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the
score head in order to produce the reward.
"""
pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False)
pretrained_model.train()

filename = os.path.join(adapter_model_id, "adapter_model.bin")
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")

for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
score_name = score_name_candidate
# we have found the correct head name and can break
break

score_dict = {}

for name, param in adapter_state_dict.items():
if score_name in name:
key_name = ".".join(name.split(".")[-1:])
score_dict[key_name] = param.to(cls._get_current_device())

num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=cls._get_current_device(),
dtype=pretrained_model.dtype,
)
score.load_state_dict(score_dict)
for param in score.parameters():
param.requires_grad = False

return score
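
Read on its own, the head-extraction logic of the new classmethod boils down to the sketch below. It is illustrative only: it assumes a local `adapter_model.bin`, skips the Hub-download fallback and the device/dtype placement, and the function name is not part of the commit:

```python
import torch
import torch.nn as nn


def build_score_head(adapter_checkpoint: str, supported_rm_modules=("score",)) -> nn.Linear:
    state_dict = torch.load(adapter_checkpoint, map_location="cpu")

    # Find which of the supported module names the reward adapter uses for its head.
    score_name = next(
        candidate for candidate in supported_rm_modules
        if any(candidate in key for key in state_dict)
    )

    # Keep only the head parameters, renamed to plain "weight" / "bias".
    score_dict = {
        key.split(".")[-1]: value for key, value in state_dict.items() if score_name in key
    }

    num_labels, hidden_dim = score_dict["weight"].shape
    score = nn.Linear(hidden_dim, num_labels, bias="bias" in score_dict)
    score.load_state_dict(score_dict)

    # The reward head is never trained during PPO, so freeze it.
    for param in score.parameters():
        param.requires_grad = False
    return score
```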

def push_to_hub(self, *args, **kwargs):
r"""
Push the pretrained model to the hub. This method is a wrapper around
@@ -474,61 +545,7 @@ def post_init(self, *args, **kwargs):
"""
raise NotImplementedError

def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter", token=None):
r"""
Add and load a reward modeling adapter. This method can only be used if the
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the
score head in order to produce the reward.
"""
filename = os.path.join(adapter_model_id, "adapter_model.bin")
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
else:
local_filename = filename

adapter_state_dict = torch.load(local_filename, map_location="cpu")
rm_adapter_peft_config = LoraConfig.from_pretrained(adapter_model_id)

for score_name_candidate in self.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
score_name = score_name_candidate
# we have found the correct head name and can break
break

score_dict = {}
copy_adapter_state_dict = adapter_state_dict.copy()

for name, _ in copy_adapter_state_dict.items():
if score_name in name:
key_name = ".".join(name.split(".")[-1:])
score_dict[key_name] = adapter_state_dict.pop(name).to(self._get_current_device())

self.pretrained_model.add_adapter(adapter_name, rm_adapter_peft_config)
self.rm_adapter_name = adapter_name

num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=self._get_current_device(),
dtype=self.pretrained_model.dtype,
)
self.score.load_state_dict(score_dict)

# load the adapter to the model
set_peft_model_state_dict(self.pretrained_model, adapter_state_dict, adapter_name=adapter_name)

def compute_reward_score(self, input_ids, attention_mask=None, ppo_adapter_name="default", **kwargs):
def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
r"""
Computes the reward score for a given input. The method has first to enable the adapter
and then compute the reward score. After that the model disables the reward modeling
@@ -541,19 +558,20 @@ def compute_reward_score(self, input_ids, attention_mask=None, ppo_adapter_name=
self.pretrained_model.set_adapter(self.rm_adapter_name)
self.pretrained_model.eval()

base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
with torch.no_grad():
base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
**kwargs,
)

last_hidden_states = base_model_output.hidden_states[-1]
scores = self.score(last_hidden_states)
last_hidden_states = base_model_output.hidden_states[-1]
scores = self.score(last_hidden_states)

self.pretrained_model.set_adapter(ppo_adapter_name)
self.pretrained_model.train()
self.pretrained_model.set_adapter(self.policy_adapter_name)
self.pretrained_model.eval()

return scores

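
From the caller's side, the behavioural change in `compute_reward_score` is that scoring now runs under `torch.no_grad()` and the model is switched back to the policy adapter afterwards. A short sketch, assuming `model` was loaded with a reward adapter and `inputs` were tokenized as in the example script (device placement omitted):

```python
scores = model.compute_reward_score(**inputs)

# The reward forward pass runs under torch.no_grad(), so the returned tensor of
# shape (batch, seq_len, num_labels) carries no autograd graph.
assert scores.requires_grad is False

# Afterwards the policy adapter ("default") is active again, so the PPO step
# keeps updating only the policy LoRA weights.
last_token_rewards = scores[:, -1, :]
```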
4 changes: 2 additions & 2 deletions trl/models/modeling_value_head.py
@@ -104,7 +104,7 @@ def __init__(self, pretrained_model, **kwargs):
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class.
"""
super().__init__(pretrained_model)
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)

if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
@@ -285,7 +285,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
)

def __init__(self, pretrained_model, **kwargs):
super().__init__(pretrained_model)
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
self.is_encoder_decoder = True

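
The one-line change in both value-head classes matters because the multi-adapter arguments assembled in `from_pretrained` (`score_module`, `supports_rm_adapter`, `rm_adapter_name`) have to reach `PreTrainedModelWrapper.__init__`. A toy sketch (not the real trl classes) of the pattern:

```python
import torch.nn as nn


class BaseWrapper(nn.Module):
    def __init__(self, pretrained_model=None, score_module=None,
                 supports_rm_adapter=False, rm_adapter_name=None, **kwargs):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.supports_rm_adapter = supports_rm_adapter
        self.rm_adapter_name = rm_adapter_name
        if score_module is not None:
            self.score = score_module


class ValueHeadModel(BaseWrapper):
    def __init__(self, pretrained_model, **kwargs):
        # Before this commit the base class was called without **kwargs,
        # so the multi-adapter arguments were silently dropped here.
        super().__init__(pretrained_model, **kwargs)


model = ValueHeadModel(nn.Identity(), supports_rm_adapter=True, rm_adapter_name="reward_adapter")
print(model.supports_rm_adapter)  # True only because **kwargs are forwarded
```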
