Add special case for saving model when running with ZERO-3 optimisation. #149

Open · wants to merge 3 commits into base: main · Changes from 2 commits
78 changes: 71 additions & 7 deletions llm_unlearn_ucl/unlearn_harm.py
@@ -132,14 +132,14 @@ def main(args) -> None:
) % args.batch_size == 0, "samples in each 'sequence' (--samples_count / --sequential) should be a multiple of batch_size."

if args.wandb_log:
accelerator = Accelerator(log_with="wandb")
accelerator: Accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=args.wandb_project_name,
config=vars(args),
init_kwargs={"wandb": {"name": args.wandb_run_name}},
)
else:
accelerator = Accelerator()
accelerator: Accelerator = Accelerator()
device = accelerator.device

# setup logging
@@ -365,8 +365,28 @@ def main(args) -> None:
optimizer.zero_grad()

# NOTE: This only handles deepspeed zero and zero2, zero3 will require change
if accelerator.is_local_main_process:
if args.sequential == 1 and epoch_num % args.save_every == 0:
if args.sequential == 1 and epoch_num % args.save_every == 0:
# NOTE: special case for zero 3
if (
accelerator.deepspeed_config is not None
and accelerator.deepspeed_config["zero_optimization"]["stage"]
== 3
):
print("Zero 3 optim: Saving model shards from all GPUs!")
model_tokenizer_save_dir = Path(
os.path.join(args.model_save_dir, f"idx_{epoch_num}")
)
model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
model_tokenizer_save_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
tokenizer.save_pretrained(model_tokenizer_save_dir)
print(f"Saved zero-3 model at step {epoch_num}.")
elif accelerator.is_local_main_process:
accelerator.wait_for_everyone() # for model saving
# NOTE: Batch unlearning, save for every epoch
model_tokenizer_save_dir = Path(
@@ -424,8 +444,29 @@ def main(args) -> None:
optimizer.zero_grad()
idx += 1
final_model_tag = idx
if accelerator.is_local_main_process:
if idx % args.save_every == 0:
if idx % args.save_every == 0:
# NOTE: special case for zero 3
if (
accelerator.deepspeed_config is not None
and accelerator.deepspeed_config["zero_optimization"]["stage"]
== 3
):
print("Zero 3 optim: Saving model shards from all GPUs!")
model_tokenizer_save_dir = Path(
os.path.join(args.model_save_dir, f"idx_{idx}")
)
model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
model_tokenizer_save_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
tokenizer.save_pretrained(model_tokenizer_save_dir)
print(f"Saved zero-3 model at step {epoch_num}.")
elif accelerator.is_local_main_process:
# If not using zero 3, just save the entire model on the main process (it's not sharded)
accelerator.wait_for_everyone() # for model saving
# Save model and tokenizer.
model_tokenizer_save_dir = Path(
@@ -468,7 +509,30 @@ def main(args) -> None:
model = model.merge_and_unload()

# Save final model.
if accelerator.is_local_main_process:
# NOTE: special case for zero 3
Owner

Why are we not handling model saving in the main process, I wonder?

Collaborator Author

@Adamliu1 good question, I based this off of https://huggingface.co/docs/accelerate/usage_guides/deepspeed#saving-and-loading, and I believe it's because each process (on each of the 4 GPUs) has a portion of the model that needs to be synced, hence we cannot just unwrap it on a single process.
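
For reference, a minimal sketch of the ZeRO-3 saving pattern those Accelerate docs describe, assuming `accelerator`, `model`, and `tokenizer` have already been prepared with a DeepSpeed ZeRO-3 config (the `save_dir` path is hypothetical, not part of this PR):

```python
from pathlib import Path

save_dir = Path("saved_model")  # hypothetical output directory
save_dir.mkdir(parents=True, exist_ok=True)

# Every rank must reach these calls: under ZeRO-3 each process holds only a
# shard of the parameters, and get_state_dict() gathers the full state dict.
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    save_dir,
    is_main_process=accelerator.is_main_process,  # only the main process writes files
    save_function=accelerator.save,
    state_dict=accelerator.get_state_dict(model),  # consolidates the ZeRO-3 shards
)
tokenizer.save_pretrained(save_dir)
```

This is also why the new branch is not guarded by `is_local_main_process`: if only the main process called `get_state_dict()`, the other ranks would never join the gather of their parameter shards.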

if (
accelerator.deepspeed_config is not None
and accelerator.deepspeed_config["zero_optimization"]["stage"] == 3
):
print("Zero 3 optim: Saving model shards from all GPUs!")
model_tokenizer_save_dir = Path(
os.path.join(args.model_save_dir, f"idx_{final_model_tag}")
)
model_tokenizer_save_dir.mkdir(parents=True, exist_ok=True)
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
model_tokenizer_save_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
tokenizer.save_pretrained(model_tokenizer_save_dir)
print(f"Saved final zero-3 model at step {epoch_num}.")
print("Unlearning finished")
logger.info("Unlearning finished")
if bool(args.wandb_log):
accelerator.end_training()
elif accelerator.is_local_main_process:
model_tokenizer_save_dir = Path(
os.path.join(args.model_save_dir, f"idx_{final_model_tag}")
)