diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py index 46961f7..f88bd59 100644 --- a/llm_unlearn_ucl/unlearn_harm.py +++ b/llm_unlearn_ucl/unlearn_harm.py @@ -400,7 +400,7 @@ def main(args) -> None: question_prefix_str = "### Question:" answer_prefix_str = "### Answer:" - elif args.unlearning_dataset == "math_qa": + elif args.unlearning_dataset == "allenai/math_qa": full_bad_dataset = load_dataset("math_qa", split="train") if args.shuffle_seed: # shuffle the dataset with a given seed for reproducibility