[FSDP][Example] Fix FSDP example for use_orig_params=False #3298

Open · wants to merge 5 commits into base: main
19 changes: 18 additions & 1 deletion docs/source/usage_guides/fsdp.md
@@ -180,12 +180,29 @@ accelerate merge-weights pytorch_model_fsdp_0/ output_path

## Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages
* `FULL_SHARD` maps to the DeepSpeed `ZeRO Stage-3`. Shards optimizer states, gradients and parameters.
* `SHARD_GRAD_OP` maps to the DeepSpeed `ZeRO Stage-2`. Shards optimizer states and gradients.
* `SHARD_GRAD_OP` maps to DeepSpeed `ZeRO Stage-2`. Shards optimizer states and gradients. A key difference from `ZeRO Stage-2` is that `SHARD_GRAD_OP` also shards the model parameters outside of computation (forward/backward passes).
* `NO_SHARD` maps to `ZeRO Stage-0`. No sharding: each GPU keeps a full copy of the model, optimizer states and gradients.
* `HYBRID_SHARD` maps to `ZeRO++ Stage-3` wherein `zero_hpz_partition_size=<num_gpus_per_node>`. This shards optimizer states, gradients and parameters within each node, while each node keeps a full copy.
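
If you configure FSDP programmatically instead of via `accelerate config`, the sharding strategy can be set on the FSDP plugin. Below is a minimal sketch, assuming an `accelerate` version whose `FullyShardedDataParallelPlugin` exposes a `sharding_strategy` field:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy

# SHARD_GRAD_OP roughly corresponds to ZeRO Stage-2 (see the mapping above).
fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```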

## A few caveats to be aware of

- PyTorch FSDP auto-wraps sub-modules. With `use_orig_params=False`, it flattens the parameters of each sub-module and shards them in place.
Because of this, any optimizer created before the model is wrapped is broken and occupies more memory, and you may also observe correctness issues during training.
Hence, it is highly recommended and more efficient to prepare the model before creating the optimizer. Example:
```diff
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
+ model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
- )

+ optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+ optimizer, train_dataloader, eval_dataloader, lr_scheduler
+ )
```
- When using multiple models, pass the optimizers to the `prepare` call in the same order as their corresponding models, otherwise `accelerator.save_state()` and `accelerator.load_state()` will result in wrong/unexpected behaviour (see the sketch after this list).
- This feature is incompatible with `--predict_with_generate` in the `run_translation.py` script of the `Transformers` library.
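
For illustration, a minimal sketch of the multiple-model case, assuming the same setup (`accelerator`, `lr`) as the example above; the names `model_a`, `model_b`, `optimizer_a` and `optimizer_b` are hypothetical placeholders:

```python
# Prepare both models first (required for FSDP with `use_orig_params=False`).
model_a = accelerator.prepare(model_a)
model_b = accelerator.prepare(model_b)

optimizer_a = torch.optim.AdamW(params=model_a.parameters(), lr=lr)
optimizer_b = torch.optim.AdamW(params=model_b.parameters(), lr=lr)

# Pass the optimizers in the same order as their models were prepared so that
# `accelerator.save_state()` / `accelerator.load_state()` associate each
# optimizer with the right model.
optimizer_a, optimizer_b = accelerator.prepare(optimizer_a, optimizer_b)
```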

7 changes: 5 additions & 2 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -245,6 +245,9 @@ def collate_fn(examples):
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path, return_dict=True, low_cpu_mem_usage=True
)
# In FSDP with `use_orig_params=False`, we need to `.prepare` the model before instantiating the optimizer.
# For simplicity, we prepare the model beforehand in all cases.
model = accelerator.prepare(model)

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@@ -267,8 +270,8 @@ def collate_fn(examples):
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

overall_step = 0