From 7ccee52668c3dc325342d3bb2316496d35b475f3 Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Fri, 22 Mar 2024 00:31:29 +0000
Subject: [PATCH 1/4] Restore bytedance implementation.

---
 llm_unlearn_ucl/unlearn_harm.py | 88 ++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 46 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 739cbe7..eaf7608 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -472,56 +472,52 @@ def main(args) -> None:
     else:
         # NOTE: Original ByteDance Unlearning.
-        train_bad_loader_gen = iter(train_bad_loaders[0])
-        train_normal_loader_gen = iter(train_normal_loaders[0])
         bad_loader_len = len(train_bad_loaders[0])
         normal_loader_len = len(train_normal_loaders[0])
         epoch_num = 0
-        for idx in range(args.max_unlearn_steps):
-            try:
-                bad_batch = next(train_bad_loader_gen)
-            except StopIteration:
-                epoch_num += 1
-                train_bad_loader_gen = iter(train_bad_loaders[0])
-                bad_batch = next(train_bad_loader_gen)
-            try:
-                normal_batch = next(train_normal_loader_gen)
-            except StopIteration:
-                train_normal_loader_gen = iter(train_normal_loaders[0])
-                normal_batch = next(train_normal_loader_gen)
-            loss, bad_loss = run_training_batch(
-                model=model,
-                pretrained_model=pretrained_model,
-                tokenizer=tokenizer,
-                device=device,
-                normal_ans=normal_ans,
-                bad_batch=bad_batch,
-                normal_batch=normal_batch,
-                idx=idx,
-                epoch=epoch_num,
-                bad_loader_size=bad_loader_len,
-                normal_loader_size=normal_loader_len,
-                question_prefix_str=question_prefix_str,
-                answer_prefix_str=answer_prefix_str,
-            )
-            accelerator.backward(loss)
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
-            final_model_tag = idx
-            if idx % args.save_every == 0:
-                model_tokenizer_save_dir = Path(
-                    os.path.join(args.model_save_dir, f"idx_{idx}")
+        while idx < args.max_unlearn_steps:
+            for bad_batch, normal_batch in zip(
+                train_bad_loaders[0], train_normal_loaders[0]
+            ):
+                loss, bad_loss = run_training_batch(
+                    model=model,
+                    pretrained_model=pretrained_model,
+                    tokenizer=tokenizer,
+                    device=device,
+                    normal_ans=normal_ans,
+                    bad_batch=bad_batch,
+                    normal_batch=normal_batch,
+                    idx=idx,
+                    epoch=epoch_num,
+                    bad_loader_size=bad_loader_len,
+                    normal_loader_size=normal_loader_len,
+                    question_prefix_str=question_prefix_str,
+                    answer_prefix_str=answer_prefix_str,
                 )
-                model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
-
-                model.save_pretrained(model_tokenizer_save_dir, from_pt=True)
-                tokenizer.save_pretrained(model_tokenizer_save_dir)
-            running_loss.append(bad_loss.item())
-            while len(running_loss) > args.num_running_loss:
-                running_loss.popleft()
-            avg_loss = abs(np.mean(running_loss))
-            if avg_loss > args.max_bad_loss:
+                accelerator.backward(loss)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                idx += 1
+                final_model_tag = idx
+                if idx % args.save_every == 0:
+                    # Save model and tokenizer.
+                    model_tokenizer_save_dir = Path(
+                        os.path.join(args.model_save_dir, f"idx_{idx}")
+                    )
+                    model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
+
+                    model.save_pretrained(model_tokenizer_save_dir, from_pt=True)
+                    tokenizer.save_pretrained(model_tokenizer_save_dir)
+                running_loss.append(bad_loss.item())
+                while len(running_loss) > args.num_running_loss:
+                    running_loss.popleft()
+                avg_loss = abs(np.mean(running_loss))
+                if avg_loss > args.max_bad_loss:
+                    break
+
+            epoch_num += 1
+            if abs(np.mean(running_loss)) > args.max_bad_loss:
                 print(
                     f"bad_loss {avg_loss} exceeding args.max_bad_loss {args.max_bad_loss}. Unlearning stopped."
                 )

From 5dd90b4cf07bdfd1fdbd126ae92c3b1ba471bc6c Mon Sep 17 00:00:00 2001
From: Zhe Yu
Date: Fri, 22 Mar 2024 01:01:14 +0000
Subject: [PATCH 2/4] fix: Bytedance training loop terminate on max_unlearn_steps

---
 llm_unlearn_ucl/unlearn_harm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index eaf7608..af812a6 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -513,9 +513,11 @@ def main(args) -> None:
                 while len(running_loss) > args.num_running_loss:
                     running_loss.popleft()
                 avg_loss = abs(np.mean(running_loss))
-                if avg_loss > args.max_bad_loss:
+                if avg_loss > args.max_bad_loss or idx >= args.max_unlearn_steps:
                     break
 
+            if idx >= args.max_unlearn_steps:
+                print("max_unlearn_steps reached. Unlearning stopped.")
             epoch_num += 1
             if abs(np.mean(running_loss)) > args.max_bad_loss:
                 print(

From 1a74ac82e94cd6b10b44e523bf62b976d08bb099 Mon Sep 17 00:00:00 2001
From: Andrzej Szablewski
Date: Fri, 22 Mar 2024 01:10:31 +0000
Subject: [PATCH 3/4] Added a missing break statement.

Signed-off-by: Andrzej Szablewski
---
 llm_unlearn_ucl/unlearn_harm.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index af812a6..f7e0e7d 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -512,14 +512,16 @@ def main(args) -> None:
                 running_loss.append(bad_loss.item())
                 while len(running_loss) > args.num_running_loss:
                     running_loss.popleft()
-                avg_loss = abs(np.mean(running_loss))
-                if avg_loss > args.max_bad_loss or idx >= args.max_unlearn_steps:
+
+                if abs(np.mean(running_loss)) > args.max_bad_loss or idx >= args.max_unlearn_steps:
                     break
 
+            epoch_num += 1
+
             if idx >= args.max_unlearn_steps:
                 print("max_unlearn_steps reached. Unlearning stopped.")
-            epoch_num += 1
-            if abs(np.mean(running_loss)) > args.max_bad_loss:
+                break
+            if avg_loss := abs(np.mean(running_loss)) > args.max_bad_loss:
                 print(
                     f"bad_loss {avg_loss} exceeding args.max_bad_loss {args.max_bad_loss}. Unlearning stopped."
                 )

From e1d558ec0857af6edfa9a202886ea88fec7f944a Mon Sep 17 00:00:00 2001
From: Andrzej Szablewski
Date: Fri, 22 Mar 2024 01:11:25 +0000
Subject: [PATCH 4/4] Changed default wandb logging frequency and renamed the corresponding argument.

Signed-off-by: Andrzej Szablewski
---
 llm_unlearn_ucl/parse_args.py   | 4 ++--
 llm_unlearn_ucl/unlearn_harm.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llm_unlearn_ucl/parse_args.py b/llm_unlearn_ucl/parse_args.py
index df33d40..34b7857 100644
--- a/llm_unlearn_ucl/parse_args.py
+++ b/llm_unlearn_ucl/parse_args.py
@@ -120,9 +120,9 @@ def parse_args() -> argparse.Namespace:
     )
 
     parser.add_argument(
-        "--wandb_log_feq",
+        "--wandb_log_freq",
         type=int,
-        default=50,
+        default=1,
         help="The logging frequency for wandb to upload data",
     )
 
diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index f7e0e7d..fcbc672 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -180,7 +180,7 @@ def run_training_batch(
 
     # NOTE: backwardnd optimisation is done outside of this function in the
     # training loop for gradient accumulation compatibility.
-    if bool(args.wandb_log) and (idx % args.wandb_log_feq == 0):
+    if bool(args.wandb_log) and (idx % args.wandb_log_freq == 0):
         wandb.log(
             {
                 "batch": idx,
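
Taken together, PATCH 1/4 through PATCH 3/4 leave the ByteDance unlearning loop stopping on either a diverging bad loss or an exhausted step budget. The stand-alone sketch below illustrates only that control flow: the real training step, checkpointing, and wandb logging are replaced by stand-ins or omitted, the walrus expression from PATCH 3/4 is written out as an explicit avg_loss assignment, and the outer-loop break after the bad-loss message is an assumption, since the hunks only show the print call for that case.

    # Termination-logic sketch; loader contents and losses are dummy values.
    from collections import deque

    import numpy as np

    max_unlearn_steps = 10  # args.max_unlearn_steps
    max_bad_loss = 100.0    # args.max_bad_loss
    num_running_loss = 3    # args.num_running_loss

    # Stand-ins for the accelerate-prepared DataLoaders used by the script.
    train_bad_loaders = [[{"batch": i} for i in range(4)]]
    train_normal_loaders = [[{"batch": i} for i in range(4)]]

    def run_training_batch(bad_batch, normal_batch):
        # Stand-in for the real unlearning step; returns (loss, bad_loss).
        return 0.5, 0.5

    running_loss = deque()
    idx = 0
    epoch_num = 0

    while idx < max_unlearn_steps:
        for bad_batch, normal_batch in zip(train_bad_loaders[0], train_normal_loaders[0]):
            loss, bad_loss = run_training_batch(bad_batch, normal_batch)
            idx += 1

            # Keep a bounded window of the most recent bad losses.
            running_loss.append(bad_loss)
            while len(running_loss) > num_running_loss:
                running_loss.popleft()

            # Inner break: diverging bad loss or step budget exhausted.
            if abs(np.mean(running_loss)) > max_bad_loss or idx >= max_unlearn_steps:
                break

        epoch_num += 1

        if idx >= max_unlearn_steps:
            print("max_unlearn_steps reached. Unlearning stopped.")
            break
        avg_loss = abs(np.mean(running_loss))
        if avg_loss > max_bad_loss:
            print(f"bad_loss {avg_loss} exceeding max_bad_loss {max_bad_loss}. Unlearning stopped.")
            break  # assumed: only the print is visible in the hunks

    print(f"Stopped at idx={idx} after {epoch_num} epoch(s).")

With the dummy values above, the sketch stops through the max_unlearn_steps branch once idx reaches 10, during the third pass over the four-batch loaders.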