Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing MuP #1061

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
28 changes: 26 additions & 2 deletions README-MUP.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,36 @@
"mup-rp-embedding-mult": 1.0,
```

## Install package

```
cd mup
pip install -e .
```

## Generate base shapes

1. Set use-mup to true
2. Set save-base-shapes to true
3. Run once. gpt-neox will instantiate a base model and a delta model, then save one file per rank named <base-shapes-file>.<rank>. gpt-neox will exit immediately.
4. Set save-base-shapes to false

## Generate coord check plots (optional)
## Testing the implementation

The most simple test is to use the coordinate check:
1. Keep use-mup true
2. Set coord-check to true
3. Run once. gpt-neox will output jpg images similar to https://github.com/microsoft/mutransformers/blob/main/README.md#coord-check. gpt-neox will exit immediately
3. Run once. gpt-neox will output jpg images similar to those below and exit immediately
4. Set coord-check to false
What you are gonna get is some stastistics of pre-activations for models only differing by the width. If done correctly these should be approximately horizontal
![](mup/figures/coord_check_up.0.jpg)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

<font size="1"> *Healthy coordinate check*</font>
![](mup/figures/coord_check_sp.0.jpg)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

<font size="1"> *Something's wrong*</font>

A second kind of test is to pick any configuration and learning rate (that doesn't lead to diverging training) and simply run a few different experiments fixing everything except for the width. Since with mup wider is always better the results should look like the figure below
![](mup/figures/width_check.png)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This image is missing a text alternative. This is a problem for people using screen readers.

<font size="1"> *Healthy training*</font>

## Tune mup hyperparameters and LR

Expand All @@ -47,3 +64,10 @@ The values under `mup hp search` were added and correspond to appendix F.4 from
## Transfer

With the best LR set and the best mup HPs set, revert the value of hidden-size in the scaled-up config and run again.

## Usage under distributed setting

The code is setup so that each individual rank takes care of its own piece of model and dumps a different shape file to be picked up for training. The easiest way to do the right thing is to generate the base shapes with the same number of devices and same tensor/pipe parallelism that should be used later on. Consider also the following
- Data parallelism: nothing changes for mup, you can copy paste a base_shape N times for each data-parallel rank
- Pipe parallelism: still nothing changes but different ranks need to deal with different layers so check above
- **Tensor parallelism: has a huge effect on mup**. Column parallel layers get chopped on the input dimension changing the actual width of the parameter. Think carefully about what you are doing if you are not sticking to what's written above
10 changes: 8 additions & 2 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,14 @@ def __init__(
coeff = max(1, self.layer_number)
self.norm_factor *= coeff

if neox_args.use_mup:
self.norm_factor = self.hidden_size_per_attention_head
# TODO
#right now there's no way to correctly set use_mup here, possible options:
#- refactor model init (hard)
#- do this via another config argument, e.g. "mup_norm_factor" (probably easy)
#- ignore, this never changed anything in my experiments
#
#if neox_args.use_mup:
# self.norm_factor = self.hidden_size_per_attention_head

self.rpe = rpe

Expand Down
10 changes: 6 additions & 4 deletions megatron/model/word_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.use_mup = neox_args.use_mup
self.use_mup = neox_args.use_mup # TODO: as of now this will always be false
self.mup_embedding_mult = neox_args.mup_embedding_mult
self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult

Expand Down Expand Up @@ -155,9 +155,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
# Dropout.
embeddings = self.embedding_dropout(embeddings)

if self.use_mup:
with torch.no_grad():
embeddings.mul_(self.mup_embedding_mult)
# TODO:
# not only this always false because of the way the model is initialized, but this also throws an error
# if self.use_mup:
# with torch.no_grad():
# embeddings.mul_(self.mup_embedding_mult)

return embeddings

Expand Down
13 changes: 8 additions & 5 deletions megatron/mpu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def __init__(
self.init_method = init_method
self.stride = stride
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
self.use_mup = False

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
Expand Down Expand Up @@ -539,6 +539,7 @@ def mup_reinitialize_weights(self, neox_args):
partition_dim=0,
stride=self.stride,
)
self.use_mup = True

def set_parallel_output(self, value: bool):
assert isinstance(value, bool)
Expand All @@ -547,8 +548,9 @@ def set_parallel_output(self, value: bool):
) # if gather_output is True, parallel output is False, so we set the opposite

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
if self.mup_rescale_parameters:
if hasattr(self.weight, "infshape"):
input_ /= self.weight.infshape.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
Expand Down Expand Up @@ -623,7 +625,7 @@ def __init__(
self.stride = stride
self.keep_master_weight_for_test = keep_master_weight_for_test
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
self.use_mup = False

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
Expand Down Expand Up @@ -728,13 +730,14 @@ def mup_reinitialize_weights(self, neox_args):
partition_dim=1,
stride=self.stride,
)
self.use_mup = True

def set_parallel_output(self, parallel_output: bool):
assert isinstance(parallel_output, bool)
self.parallel_output = parallel_output

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
if self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
if self.input_is_parallel:
Expand Down
15 changes: 10 additions & 5 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,18 @@ def mup_weights_reinit(neox_args, model):
def has_method(o, name):
return callable(getattr(o, name, None))

# HACK: it uses the mother class name to avoid re-initializing the output layer, highly prone to future bugs
# HACK: only works with non-tied input-output layers

previous = ""
for layer in model.modules():
# This normally would happen in set_base_shapes if we actually were able to use the MuReadout class
if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters:
layer._rescale_parameters()

if has_method(layer, "mup_reinitialize_weights"):
layer.mup_reinitialize_weights(neox_args)
if previous != "ParallelLinearPipe":
if has_method(layer, "mup_reinitialize_weights"):
layer.mup_reinitialize_weights(neox_args)
previous = layer.__class__.__name__


def save_base_shapes(neox_args, base_shapes, use_cache):
Expand Down Expand Up @@ -530,9 +535,9 @@ def get_optimizer(model, neox_args):
# Use Adam
if neox_args.use_mup:
try:
from mup import MuAdam
from mup import MuAdamW # TODO: was there any particular reason for not using MuAdamW?

adam_optimizer = MuAdam
adam_optimizer = MuAdamW
except ModuleNotFoundError:
print("Please install mup https://github.com/microsoft/mup")
raise Exception
Expand Down
Binary file added test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading