diff --git a/llm_unlearn_ucl/deepspeed_4gpu.yaml b/llm_unlearn_ucl/deepspeed_4gpu.yaml
new file mode 100644
index 0000000..018beb9
--- /dev/null
+++ b/llm_unlearn_ucl/deepspeed_4gpu.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/llm_unlearn_ucl/unlearn_harm.py b/llm_unlearn_ucl/unlearn_harm.py
index 95ddd28..b8169f0 100644
--- a/llm_unlearn_ucl/unlearn_harm.py
+++ b/llm_unlearn_ucl/unlearn_harm.py
@@ -132,14 +132,14 @@ def main(args) -> None:
     ) % args.batch_size == 0, "samples in each 'sequence' (--samples_count / --sequential) should be a multiple of batch_size."
 
     if args.wandb_log:
-        accelerator = Accelerator(log_with="wandb")
+        accelerator: Accelerator = Accelerator(log_with="wandb")
         accelerator.init_trackers(
             project_name=args.wandb_project_name,
             config=vars(args),
             init_kwargs={"wandb": {"name": args.wandb_run_name}},
         )
     else:
-        accelerator = Accelerator()
+        accelerator: Accelerator = Accelerator()
     device = accelerator.device
 
     # setup logging
@@ -365,8 +365,28 @@ def main(args) -> None:
                optimizer.zero_grad()
 
            # NOTE: This only handles deepspeed zero and zero2, zero3 will require change
-            if accelerator.is_local_main_process:
-                if args.sequential == 1 and epoch_num % args.save_every == 0:
+            if args.sequential == 1 and epoch_num % args.save_every == 0:
+                # NOTE: special case for zero 3
+                if (
+                    accelerator.deepspeed_config is not None
+                    and accelerator.deepspeed_config["zero_optimization"]["stage"]
+                    == 3
+                ):
+                    print("Zero 3 optim: Saving model shards from all GPUs!")
+                    model_tokenizer_save_dir = Path(
+                        os.path.join(args.model_save_dir, f"idx_{epoch_num}")
+                    )
+                    model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.save_pretrained(
+                        model_tokenizer_save_dir,
+                        is_main_process=accelerator.is_main_process,
+                        save_function=accelerator.save,
+                        state_dict=accelerator.get_state_dict(model),
+                    )
+                    tokenizer.save_pretrained(model_tokenizer_save_dir)
+                    print(f"Saved zero-3 model at step {epoch_num}.")
+                elif accelerator.is_local_main_process:
                    accelerator.wait_for_everyone()  # for model saving
                    # NOTE: Batch unlearning, save for every epoch
                    model_tokenizer_save_dir = Path(
@@ -424,8 +444,29 @@ def main(args) -> None:
                    optimizer.zero_grad()
                    idx += 1
                final_model_tag = idx
-                if accelerator.is_local_main_process:
-                    if idx % args.save_every == 0:
+                if idx % args.save_every == 0:
+                    # NOTE: special case for zero 3
+                    if (
+                        accelerator.deepspeed_config is not None
+                        and accelerator.deepspeed_config["zero_optimization"]["stage"]
+                        == 3
+                    ):
+                        print("Zero 3 optim: Saving model shards from all GPUs!")
+                        model_tokenizer_save_dir = Path(
+                            os.path.join(args.model_save_dir, f"idx_{idx}")
+                        )
+                        model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
+                        unwrapped_model = accelerator.unwrap_model(model)
+                        unwrapped_model.save_pretrained(
+                            model_tokenizer_save_dir,
+                            is_main_process=accelerator.is_main_process,
+                            save_function=accelerator.save,
+                            state_dict=accelerator.get_state_dict(model),
+                        )
+                        tokenizer.save_pretrained(model_tokenizer_save_dir)
+                        print(f"Saved zero-3 model at step {idx}.")
+                    elif accelerator.is_local_main_process:
+                        # If not using zero 3, just save the entire model on the main process (it's not sharded)
                        accelerator.wait_for_everyone()  # for model saving
                        # Save model and tokenizer.
                        model_tokenizer_save_dir = Path(
@@ -468,7 +509,30 @@ def main(args) -> None:
        model = model.merge_and_unload()
 
    # Save final model.
-    if accelerator.is_local_main_process:
+    # NOTE: special case for zero 3
+    if (
+        accelerator.deepspeed_config is not None
+        and accelerator.deepspeed_config["zero_optimization"]["stage"] == 3
+    ):
+        print("Zero 3 optim: Saving model shards from all GPUs!")
+        model_tokenizer_save_dir = Path(
+            os.path.join(args.model_save_dir, f"idx_{final_model_tag}")
+        )
+        model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
+        unwrapped_model = accelerator.unwrap_model(model)
+        unwrapped_model.save_pretrained(
+            model_tokenizer_save_dir,
+            is_main_process=accelerator.is_main_process,
+            save_function=accelerator.save,
+            state_dict=accelerator.get_state_dict(model),
+        )
+        tokenizer.save_pretrained(model_tokenizer_save_dir)
+        print(f"Saved final zero-3 model at step {final_model_tag}.")
+        print("Unlearning finished")
+        logger.info("Unlearning finished")
+        if bool(args.wandb_log):
+            accelerator.end_training()
+    elif accelerator.is_local_main_process:
        model_tokenizer_save_dir = Path(
            os.path.join(args.model_save_dir, f"idx_{final_model_tag}")
        )
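
The same ZeRO-3-aware save sequence now appears three times in `unlearn_harm.py`. Below is a minimal sketch of the pattern the diff relies on, consolidated into one function; the helper name `save_model_and_tokenizer` is hypothetical, and the `accelerator`, `model`, and `tokenizer` objects are assumed to be the ones the script already prepares. The key point: under ZeRO stage 3 the parameters are sharded across ranks, so every process must enter the save path and `accelerator.get_state_dict(model)` gathers the consolidated weights (16-bit here, per `zero3_save_16bit_model: true`); under stages 0-2 each rank holds a full copy, so the main process can serialize alone.

```python
from pathlib import Path

from accelerate import Accelerator


def save_model_and_tokenizer(
    accelerator: Accelerator, model, tokenizer, save_dir: str
) -> None:
    # Hypothetical helper consolidating the save block repeated in the diff.
    save_path = Path(save_dir)
    ds_config = accelerator.deepspeed_config  # None unless DeepSpeed is active
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        # ZeRO-3 shards parameters across all ranks, so every process must
        # call save_pretrained; get_state_dict() gathers the full weights,
        # and only the main process actually writes them to disk.
        save_path.mkdir(parents=True, exist_ok=True)
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_path,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=accelerator.get_state_dict(model),
        )
        tokenizer.save_pretrained(save_path)
    elif accelerator.is_local_main_process:
        # ZeRO stages 0-2 keep a full parameter copy on every rank, so the
        # main process alone can serialize the model and tokenizer.
        save_path.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
```

With the new YAML, a 4-GPU run would be launched with something like `accelerate launch --config_file llm_unlearn_ucl/deepspeed_4gpu.yaml llm_unlearn_ucl/unlearn_harm.py ...`, so Accelerate picks up `distributed_type: DEEPSPEED` and `zero_stage: 3`.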