Skip to content

Commit

Permalink
🧳 Move zen generation script and fix tests (#2393)
Browse files Browse the repository at this point in the history
* Move zen

* step -> stepwise_supervision

* Fix train_test_split shuffle issue

* Fix tests

* Update tests/test_sft_trainer.py

Co-authored-by: Kashif Rasul <[email protected]>

* Fix typo in key name

---------

Co-authored-by: Kashif Rasul <[email protected]>
  • Loading branch information
qgallouedec and kashif authored Nov 26, 2024
1 parent baee06f commit 43df3a4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 35 deletions.
36 changes: 18 additions & 18 deletions examples/datasets/zen.py → scripts/generate_zen_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class ScriptArguments:
Fraction of the dataset to include in the test split.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/zen"`):
repo_id (`str`, *optional*, defaults to `"trl-internal-testing/zen"`):
Hugging Face repository ID to push the dataset to.
"""

test_size: float = 0.1
push_to_hub: bool = False
repo_id: str = "trl-lib/zen"
repo_id: str = "trl-internal-testing/zen"


def main(test_size, push_to_hub, repo_id):
Expand Down Expand Up @@ -62,7 +62,7 @@ def main(test_size, push_to_hub, repo_id):
"Namespaces are one honking great idea -- let's do more of those!",
],
})
standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size)
standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_language_modeling_dataset.push_to_hub(repo_id, config_name="standard_language_modeling")

Expand All @@ -89,7 +89,7 @@ def main(test_size, push_to_hub, repo_id):
"Namespaces are one honking great",
],
})
standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size)
standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_prompt_only_dataset.push_to_hub(repo_id, config_name="standard_prompt_only")

Expand Down Expand Up @@ -137,7 +137,7 @@ def main(test_size, push_to_hub, repo_id):
" idea -- let's do more of those!",
],
})
standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size)
standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_prompt_completion_dataset.push_to_hub(repo_id, config_name="standard_prompt_completion")

Expand Down Expand Up @@ -206,7 +206,7 @@ def main(test_size, push_to_hub, repo_id):
" watermelon -- let's plant some!",
],
})
standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size)
standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_preference_dataset.push_to_hub(repo_id, config_name="standard_preference")

Expand Down Expand Up @@ -254,7 +254,7 @@ def main(test_size, push_to_hub, repo_id):
"Namespaces are one honking great watermelon -- let's plant some!",
],
})
standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size)
standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="standard_implicit_prompt_preference")

Expand Down Expand Up @@ -303,11 +303,11 @@ def main(test_size, push_to_hub, repo_id):
],
"label": [True, False, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, False, False],
})
standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size)
standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference")

standard_step_dataset = Dataset.from_dict({
standard_stepwise_supervision_dataset = Dataset.from_dict({
"prompt": [
"Beautiful is better than",
"Explicit is better than",
Expand Down Expand Up @@ -350,7 +350,7 @@ def main(test_size, push_to_hub, repo_id):
[" of those great ideas,", " that solve many problems."],
[" the code should still aim for balance."],
],
"label": [
"labels": [
[False, True],
[False, True, False],
[False, True],
Expand All @@ -371,9 +371,9 @@ def main(test_size, push_to_hub, repo_id):
[False]
]
})
standard_step_dataset = standard_step_dataset.train_test_split(test_size=test_size)
standard_stepwise_supervision_dataset = standard_stepwise_supervision_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_step_dataset.push_to_hub(repo_id, config_name="standard_step")
standard_stepwise_supervision_dataset.push_to_hub(repo_id, config_name="standard_stepwise_supervision")

conversational_language_modeling_dataset = Dataset.from_dict({
"messages": [
Expand All @@ -398,7 +398,7 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}],
],
})
conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size)
conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_language_modeling_dataset.push_to_hub(repo_id, config_name="conversational_language_modeling")

Expand All @@ -425,7 +425,7 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "Any great ideas?"}],
],
})
conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size)
conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_prompt_only_dataset.push_to_hub(repo_id, config_name="conversational_prompt_only")

Expand Down Expand Up @@ -473,7 +473,7 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "assistant", "content": "Namespaces are one honking great idea."}],
],
})
conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size)
conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_prompt_completion_dataset.push_to_hub(repo_id, config_name="conversational_prompt_completion")

Expand Down Expand Up @@ -542,7 +542,7 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "assistant", "content": "Recursion."}],
],
})
conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size)
conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_preference_dataset.push_to_hub(repo_id, config_name="conversational_preference")

Expand Down Expand Up @@ -590,7 +590,7 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Recursion."}],
],
})
conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size)
conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="conversational_implicit_prompt_preference")

Expand Down Expand Up @@ -639,7 +639,7 @@ def main(test_size, push_to_hub, repo_id):
],
"label": [True, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, True, True],
})
conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size)
conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_unpaired_preference_dataset.push_to_hub(repo_id, config_name="conversational_unpaired_preference")
# fmt: on
Expand Down
18 changes: 10 additions & 8 deletions tests/test_bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ def test_tokenize_and_process_tokens(self):
self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])

fn_kwargs = {
"prefix": "",
Expand All @@ -178,13 +178,15 @@ def test_tokenize_and_process_tokens(self):
self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
self.assertListEqual(processed_dataset["label"], train_dataset["label"])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(
processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]
)
self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])
self.assertListEqual(
processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]
)

@require_sklearn
def test_bco_trainer_without_providing_ref_model(self):
Expand Down
18 changes: 10 additions & 8 deletions tests/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,10 @@ def test_tokenize_and_process_tokens(self):
self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [31137])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [374, 2664, 1091, 16965, 13])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])

# Test corruption of (prompt, completion) pairs for KL dataset
for batch_size in [2, 3]:
Expand Down Expand Up @@ -196,13 +196,15 @@ def test_tokenize_and_process_tokens(self):
self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
self.assertListEqual(processed_dataset["label"], train_dataset["label"])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [31137])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(
processed_dataset["completion_input_ids"][0], [31137, 374, 2664, 1091, 16965, 13, 151645]
processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]
)
self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
self.assertListEqual(processed_dataset["completion_labels"][0], [-100, 374, 2664, 1091, 16965, 13, 151645])
self.assertListEqual(
processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]
)

def test_kto_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ def test_sft_trainer_eval_packing(self):
)

self.assertEqual(len(trainer.train_dataset["input_ids"]), 46) # w/ this dataset, we end up with 46 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), 5) # w/ this dataset, we end up with 5 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6) # w/ this dataset, we end up with 6 seqs

def test_sft_trainer_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down

0 comments on commit 43df3a4

Please sign in to comment.