Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSO (Statistical Rejection Sampling Improves Preference Optimization) #902

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions examples/research_projects/rso/generate_from_sft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from accelerate import Accelerator
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser
)

from trl.trainer.utils import generate


@dataclass
class ScriptArguments:
# model parameters
model_name_or_path: Optional[str] = field(default=None, metadata={"help": "the model name"})
mixed_precision: Optional[str] = field(default="fp16", metadata={"help": "the model dtype"})
# data parameters
dataset_name: Optional[str] = field(default="Dahoas/full-hh-rlhf", metadata={"help": "the HF data path"})
split: Optional[str] = field(default="train", metadata={"help": "the dataset split to use for generation"})
batch_size: Optional[int] = field(default=8, metadata={"help": "the generation batch size"})
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
save_dataset_path: Optional[str] = field(default="sft_gen_dataset", metadata={"help": "the path for saving the generated dataset"})
# generation parameters
max_new_tokens: Optional[int] = field(
default=128, metadata={"help": "the maximum number of tokens generated per sample"}
)
temperature: Optional[float] = field(default=1.0, metadata={"help": "the sampling temperature"})
top_p: Optional[float] = field(default=1.0, metadata={"help": "top_p sampling argument"})
top_k: Optional[float] = field(default=50, metadata={"help": "top_k sampling argument"})
num_return_sequences: Optional[int] = field(default=64, metadata={"help": "the number of return sequences"})
# instrumentation
sanity_check: Optional[bool] = field(default=False)


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

accelerator = Accelerator(
mixed_precision=script_args.mixed_precision
)

# load sft policy
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
# for generation
tokenizer.padding_side = "left"

# define gen_kwargs
generation_kwargs = {
"top_k": script_args.top_k,
"top_p": script_args.top_p,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"temperature": script_args.temperature,
"max_new_tokens": script_args.max_new_tokens,
"num_return_sequences": script_args.num_return_sequences,
}

# load and preprocess the dataset
dataset = load_dataset(script_args.dataset_name)[script_args.split]

if script_args.sanity_check:
dataset = dataset.select(range(min(len(dataset), 100)))

def tokenize_fn(samples):
model_inputs = tokenizer(samples["prompt"])

return {
**model_inputs,
}

dataset = dataset.map(tokenize_fn, batched=True, remove_columns=list(dataset.features))
dataset = dataset.filter(lambda x: len(x["input_ids"])<script_args.max_prompt_length)

data_collator = DataCollatorForSeq2Seq(tokenizer, max_length=script_args.max_prompt_length, pad_to_multiple_of=8)

dataloader = DataLoader(dataset, batch_size=script_args.batch_size, shuffle=False, collate_fn=data_collator)

model, dataloader = accelerator.prepare(model, dataloader)

# generate responses from sft policy
prompts, responses = generate(model, dataloader, tokenizer, accelerator, **generation_kwargs)

generated_dataset = Dataset.from_dict({"prompt": prompts, "response": responses})

# save the generated dataset
generated_dataset.save_to_disk(script_args.save_dataset_path)
187 changes: 187 additions & 0 deletions examples/research_projects/rso/rso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import warnings
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from accelerate import Accelerator
from datasets import Dataset, load_from_disk
from torch.utils.data import DataLoader
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
HfArgumentParser
)

from trl.trainer.utils import conduct_rejection_sampling, compute_reward_score

@dataclass
class ScriptArguments:
reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "the model name"})
mixed_precision: Optional[str] = field(default="fp16", metadata={"help": "the model dtype"})
# data parameters
dataset_name: Optional[str] = field(default=None, metadata={"help": "the generated dataset path"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the scoring batch size"})
save_dataset_path: Optional[str] = field(default="sft_gen_dataset_ranked", metadata={"help": "the path for saving the dataset"})

# rejection sampling
num_samples: Optional[int] = field(default=8, metadata={"help": "the number of samples to keep after rejection sampling"})
beta: Optional[int] = field(default=.5, metadata={"help": "TO DO"})
ranking_method: Optional[str] = field(default="first_round", metadata={"help": " or tournament TO DO"})

# instrumentation
sanity_check: Optional[bool] = field(default=False)


def first_round_ranking(responses: List[str], rewards: List[float]) -> Tuple[List[str], List[str]]:
"""Conducts first round ranking. Starts from n responses and construct n/2 pairs to be assigned
to chosen or rejected based on there rewards.

Args:
responses: accecpted candidates from rejection sampling
rewards: response rewards.

Returns:
chosen: chosen samples.
rejected: rejected samples.
"""

chosen = []
rejected = []

def pick(responses):
selected = random.randrange(len(responses))
return responses.pop(selected)

responses = [(response, reward) for response, reward in zip(responses,rewards)]
while responses:
selected1 = pick(responses)
selected2 = pick(responses)
if selected1[1]>selected2[1]:
chosen.append(selected1[0])
rejected.append(selected2[0])
else:
chosen.append(selected2[0])
rejected.append(selected1[0])

return chosen, rejected


def tournament_ranking(responses: List[str], rewards: List[float]):
"""Conducts tournament ranking. Starts from n responses and construct n-1 pairs to be assigned
to chosen or rejected based on there rewards.

Args:
responses: accecpted candidates from rejection sampling.
rewards: response rewards.

Returns:
chosen: chosen samples.
rejected: rejected samples.
"""
sorted_responses = [response for _, response in sorted(zip(rewards, responses), reverse=True)]

chosen = [sorted_responses[i] for i in range(0, len(responses), 2)]
rejected =[sorted_responses[i] for i in range(1, len(responses), 2)]

return chosen, rejected


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


if script_args.num_samples%2!=0:
warnings.warn(
"Creating pairs requires an even number for num_samples."
f"Setting num_samples to {script_args.num_samples+1} instead of {script_args.num_samples}"
)
script_args.num_samples += 1

accelerator = Accelerator(
mixed_precision=script_args.mixed_precision
)

# load reward model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(script_args.reward_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(script_args.reward_model_name_or_path)

# load and preprocess the dataset
dataset = load_from_disk(script_args.dataset_name)

if script_args.sanity_check:
dataset = dataset.dataset(range(min(len(dataset), 500)))

def tokenize_fn(samples):
# create the text column first
text = [prompt + " " + response for prompt, response in zip(samples["prompt"], samples["response"])]
model_inputs = tokenizer(text)

return {
**model_inputs,
}

reward_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=list(dataset.features))

data_collator = DataCollatorWithPadding(tokenizer)

dataloader = DataLoader(reward_dataset, batch_size=script_args.batch_size, shuffle=False, collate_fn=data_collator)

model, dataloader = accelerator.prepare(model, dataloader)

rewards = compute_reward_score(model, dataloader, accelerator)

rewards = rewards[: len(dataset)]

dataset = dataset.add_column("rewards", rewards)

# perform rejection sampling
df = dataset.to_pandas()
df = df.groupby("prompt").agg({"response":lambda x: list(x), "rewards":lambda x: list(x)}).reset_index()

# conduct rejected sampling algorithm as in https://arxiv.org/pdf/2309.06657.pdf
df["accepted"], df["rewards"] = zip(*df.apply(
lambda x: conduct_rejection_sampling(
x["response"],
x["rewards"],
script_args.num_samples,
script_args.beta
),
axis=1
)
)

# perform ranking
ranking_fn = tournament_ranking if "tournament" in script_args.ranking_method else first_round_ranking

df["chosen"], df["rejected"] = zip(*df.apply(lambda x: ranking_fn(x["accepted"], x["rewards"]), axis=1))
df = df.filter(["prompt", "chosen", "rejected"])
df = df.explode(["chosen", "rejected"])

dataset = Dataset.from_pandas(df)

# save the dataset for later finetuning with DPO
dataset.save_to_disk(script_args.save_dataset_path)







3 changes: 3 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
generate,
compute_reward_score,
conduct_rejection_sampling,
)

# isort: on
Expand Down
Loading