diff --git a/.flake8 b/.flake8 index d067c43ce..636dc598f 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] exclude = .git max-line-length = 120 -ignore = E203, E501, W503, W605, F821, E266 +ignore = E203, E501, W503, W605, F821, E266, E731 per-file-ignores = */__init__.py: F401 examples/*.py: E402 diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index c958e9bf2..dbde2dbd1 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -21,8 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Run entry tests with pytest run: | diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 46828d5b8..3af69bacf 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -21,9 +21,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: Test with pytest run: | - XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry" + XLA_FLAGS=--xla_force_host_platform_device_count=8 PYTHONPATH=tests:src:. pytest tests -m "not entry and not slow" diff --git a/README.md b/README.md index b240f3560..13097d7dd 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,13 @@ Haliax's documentation is available at [haliax.readthedocs.io](https://haliax.re * **Distributed Training**: We support distributed training on TPUs (and soon, GPUs), including FSDP and tensor parallelism. * **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors). * **Performance**: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText. -* **Reproducibility**: Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption. * **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training. -* **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training. +* **Optimization**: Levanter supports the new [Sophia](https://arxiv.org/abs/2305.14342) optimizer, which can be 2x as fast as Adam. We also support ses [Optax](https://github.com/deepmind/optax) for optimization with AdamW, etc. +* **Logging**: Levanter supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability +to log inside of JAX `jit`-ted functions. +* **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption. * **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now. -* **Optimization**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization. Our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is available in the [dev branch](https://github.com/stanford-crfm/levanter/tree/dev). @@ -150,7 +151,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/backpack.yaml b/config/backpack.yaml index 5b6cef3cb..493be77a3 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -10,7 +10,7 @@ model: num_senses: 16 sense_intermediate_scale: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "backpack" ] diff --git a/config/doremi/doremi_nano.yaml b/config/doremi/doremi_nano.yaml new file mode 100644 index 000000000..397e91239 --- /dev/null +++ b/config/doremi/doremi_nano.yaml @@ -0,0 +1,28 @@ +data: + configs: + wikitext: + id: dlwh/wikitext_103_detokenized + w2: + id: dlwh/wikitext_103_detokenized + train_weights: + wikitext: 0.5 + w2: 0.5 +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml index 50ccbd882..a3633bf65 100644 --- a/config/gpt2_1536.yaml +++ b/config/gpt2_1536.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_20b.yaml b/config/gpt2_20b.yaml index 76bf6ba96..6f5f40e1b 100644 --- a/config/gpt2_20b.yaml +++ b/config/gpt2_20b.yaml @@ -12,7 +12,7 @@ model: use_bias: false fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml index affb67aa5..36a3d4fd2 100644 --- a/config/gpt2_7b.yaml +++ b/config/gpt2_7b.yaml @@ -11,7 +11,7 @@ model: resid_pdrop: 0.0 fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_data_mix.yaml b/config/gpt2_data_mix.yaml deleted file mode 100644 index 073e3b46b..000000000 --- a/config/gpt2_data_mix.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: - configs: - owt: - train_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" - wikitext: - id: dlwh/wikitext_103_detokenized - train_weights: - owt: 0.6 - wikitext: 0.4 - tokenizer: gpt2 - cache_dir: "gs://levanter-data/tokenized/data_mix" -model: - type: gpt2 - hidden_dim: 32 - num_heads: 4 - num_layers: 2 -trainer: - num_train_steps: 100 - train_batch_size: 32 diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index 525a92c99..8a8aea8d7 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -8,13 +8,13 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 16 + per_device_parallelism: -1 optimizer: learning_rate: 2E-4 weight_decay: 0.1 diff --git a/config/gpt2_large_sophia_g.yaml b/config/gpt2_large_sophia_g.yaml new file mode 100644 index 000000000..53a1d0806 --- /dev/null +++ b/config/gpt2_large_sophia_g.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-g"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-g + learning_rate: 2E-4 + weight_decay: 0.15 diff --git a/config/gpt2_large_sophia_h.yaml b/config/gpt2_large_sophia_h.yaml new file mode 100644 index 000000000..314801728 --- /dev/null +++ b/config/gpt2_large_sophia_h.yaml @@ -0,0 +1,21 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 1280 + num_heads: 20 + num_layers: 36 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + num_train_steps: 200000 + mp: p=f32,c=bfloat16 + +optimizer: + type: sophia-h + learning_rate: 1.7E-4 + weight_decay: 0.2 diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml index 9ea4408bc..47e21799c 100644 --- a/config/gpt2_medium.yaml +++ b/config/gpt2_medium.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_micro.yaml b/config/gpt2_micro.yaml index 274ecddaa..0a8283e78 100644 --- a/config/gpt2_micro.yaml +++ b/config/gpt2_micro.yaml @@ -6,7 +6,7 @@ model: num_heads: 8 num_layers: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..5612fc104 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -14,8 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_eval_parallelism: 1 - per_device_parallelism: 1 + per_device_parallelism: 16 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml new file mode 100644 index 000000000..9ada16aa3 --- /dev/null +++ b/config/gpt2_nano_tb.yaml @@ -0,0 +1,26 @@ +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_eval_parallelism: 1 + per_device_parallelism: 1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + tracker: + type: tensorboard + logdir: tb_logs/ diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index 74d0e031a..c657fe787 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4c8434f38..6242a37bc 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -8,9 +8,10 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "itest"] + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] mp: p=f32,c=bfloat16 model_axis_size: 1 diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml index 0785e9103..ca9fa2ca6 100644 --- a/config/gpt2_small_fast_mix.yaml +++ b/config/gpt2_small_fast_mix.yaml @@ -21,7 +21,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext+wiki", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index f30743c1d..a0336da45 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_sophia_g.yaml b/config/gpt2_small_fast_sophia_g.yaml new file mode 100644 index 000000000..0f86ac503 --- /dev/null +++ b/config/gpt2_small_fast_sophia_g.yaml @@ -0,0 +1,24 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest", "sophia-g"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-g + learning_rate: 1E-3 + weight_decay: 0.15 diff --git a/config/gpt2_small_fast_sophia_h.yaml b/config/gpt2_small_fast_sophia_h.yaml new file mode 100644 index 000000000..671acec8f --- /dev/null +++ b/config/gpt2_small_fast_sophia_h.yaml @@ -0,0 +1,24 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: .85E-3 + weight_decay: 0.2 diff --git a/config/gpt2_small_fast_sophiah.yaml b/config/gpt2_small_fast_sophiah.yaml new file mode 100644 index 000000000..71675312c --- /dev/null +++ b/config/gpt2_small_fast_sophiah.yaml @@ -0,0 +1,26 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + wandb: + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + type: sophia-h + learning_rate: 0.8E-3 + weight_decay: 0.1 + warmup: 0.01 + gamma: 0.005 diff --git a/config/gpt2_small_fast_wiki.yaml b/config/gpt2_small_fast_wiki.yaml index 407d8705b..a25736434 100644 --- a/config/gpt2_small_fast_wiki.yaml +++ b/config/gpt2_small_fast_wiki.yaml @@ -9,7 +9,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml new file mode 100644 index 000000000..19512c3dd --- /dev/null +++ b/config/gpt2_small_pile.yaml @@ -0,0 +1,23 @@ +data: !include data/pile_source_old.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 2048 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "pile", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 8 + + train_batch_size: 256 + num_train_steps: 50000 +optimizer: + learning_rate: 6e-4 + weight_decay: 0.1 diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml index e02e4bd1f..a79ec8052 100644 --- a/config/gpt2_small_pile_mixture.yaml +++ b/config/gpt2_small_pile_mixture.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2"] diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml new file mode 100644 index 000000000..fd82ab226 --- /dev/null +++ b/config/gpt2_small_sophiah.yaml @@ -0,0 +1,19 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "openwebtext", "gpt2", "sophia-h"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + + train_batch_size: 512 +optimizer: !include optim/sophia-h_small.yaml diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 8230b56a5..a58c7ceb0 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -8,11 +8,11 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 - per_device_parallelism: 1 + per_device_parallelism: -1 optimizer: learning_rate: 1E-4 weight_decay: 0.1 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index 68931f3fa..b4ebe705f 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -11,7 +11,8 @@ model: # initialize_from_hf: "meta-llama/Llama-2-7b-hf" # use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml index e03be7168..edb72a7e4 100644 --- a/config/llama2_7b_continued.yaml +++ b/config/llama2_7b_continued.yaml @@ -6,7 +6,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "llama2"] diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index c3ae4cdb8..58415022e 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -12,7 +12,7 @@ model: num_kv_heads: 4 num_layers: 2 trainer: - wandb: + tracker: project: "levanter" tags: ["openwebtext", "llama"] mp: p=f32 diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml index f49267ca1..6b19d0ab5 100644 --- a/config/lora/mpt_biomed.yaml +++ b/config/lora/mpt_biomed.yaml @@ -11,7 +11,8 @@ lora: alpha: 32.0 target_modules: ["Wqkv"] trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["mpt", "lora", "pubmed"] diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml deleted file mode 100644 index a7eaf800b..000000000 --- a/config/mpt_7b_continued.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: !include data/pile_source_old.yaml -model: - type: mpt -initialize_from_hf: true -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pile", "mpt"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 4 - per_device_eval_parallelism: 4 - - train_batch_size: 1024 - num_train_steps: 10000 - steps_per_eval: 500 -optimizer: - learning_rate: 1.2e-4 - weight_decay: 0.1 diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml deleted file mode 100644 index 44961df46..000000000 --- a/config/mpt_7b_continued_biomedlm.yaml +++ /dev/null @@ -1,27 +0,0 @@ -data: - train_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" - cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded-neox/" - tokenizer: "EleutherAI/gpt-neox-20b" -model: - type: mpt -initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pubmed", "mpt", "continued"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 8 - - train_batch_size: 2048 - num_train_steps: 50000 - steps_per_eval: 1000 -optimizer: - learning_rate: 1.2e-5 - weight_decay: 0.1 diff --git a/config/optim/sophia-h_large.yaml b/config/optim/sophia-h_large.yaml new file mode 100644 index 000000000..6644f20b8 --- /dev/null +++ b/config/optim/sophia-h_large.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 3E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_medium.yaml b/config/optim/sophia-h_medium.yaml new file mode 100644 index 000000000..5c411f109 --- /dev/null +++ b/config/optim/sophia-h_medium.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 4E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_small.yaml b/config/optim/sophia-h_small.yaml new file mode 100644 index 000000000..0bb8ea2a7 --- /dev/null +++ b/config/optim/sophia-h_small.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 6E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/config/optim/sophia-h_xl.yaml b/config/optim/sophia-h_xl.yaml new file mode 100644 index 000000000..fe2c868b3 --- /dev/null +++ b/config/optim/sophia-h_xl.yaml @@ -0,0 +1,7 @@ +type: sophia-h +learning_rate: 1.2E-4 +weight_decay: 0.2 +min_lr_ratio: 0.1 +gamma: 0.01 +# sophia needs a minimum amount of warmup or it doesn't do well +warmup: 2000 diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 7927da154..8336b1eb8 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -35,7 +35,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] @@ -179,12 +180,34 @@ The default step-based checkpoint policy is to save a checkpoint every 10,000 st -## WandB +## Trackers and Logging -We mostly use wandb for logging, including using wandb for allocating the run id. We may change this. -These all live in a nested object `wandb` inside `trainer`. Most of these are the same as the corresponding `wandb.init` -parameters. +We mostly use [W&B](https://wandb.ai/site) for tracking values and other metadata about a run. However, we also support +Tensorboard and a few other trackers. You can also use multiple trackers at once, or even write your own. +See [Trackers](dev/Trackers.md) for more information. + +### W&B + +Wandb is the default tracker and is installed by default. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +Because wandb is the default, you can also just do: + +```yaml +trainer: + tracker: + project: my-project + entity: my-entity +``` + | Parameter | Description | Default | @@ -206,6 +229,35 @@ of your main script. To use it, you must also set the right environment variables. Something like `XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*`. We will automatically parse out the env variable. +### Tensorboard + +Tensorboard is also supported. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: tensorboard + logdir: logs +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + ## Ray Config Levanter will by default automatically start a Ray cluster with all @@ -213,11 +265,11 @@ the machines being used for training. This is useful for distributed preprocessing. You can disable this behavior using `auto_start_cluster: false`. -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|---------| -| `address` | The address of the Ray cluster to connect to. | `None` | -| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | -| `auto_start_cluster`| Whether to start a Ray cluster automatically. | `True` | +| Parameter | Description | Default | +|----------------------|-------------------------------------------------------------------------|---------| +| `address` | The address of the Ray cluster to connect to. | `None` | +| `start_workers` | Whether to start Ray workers. If `False`, you must start them yourself. | `True` | +| `auto_start_cluster` | Whether to start a Ray cluster automatically. | `True` | ## Distributed Config @@ -227,18 +279,18 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th **Don't use this on TPU, and possibly not on SLURM either.** -| Parameter | Description | Default | -|---------------------|-----------------------------------------------------------------------------|-------------------------| -| `coordinator_address`| The address of the coordinator. If `None`, we'll use the default address. | `None` | -| `num_processes` | The number of processes in the cluster. | `None` | -| `process_id` | The process id of this process. | `None` | -| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | +| Parameter | Description | Default | +|-----------------------|---------------------------------------------------------------------------|-------------------------| +| `coordinator_address` | The address of the coordinator. If `None`, we'll use the default address. | `None` | +| `num_processes` | The number of processes in the cluster. | `None` | +| `process_id` | The process id of this process. | `None` | +| `local_device_ids` | The local device ids of this process. | ${CUDA_VISIBLE_DEVICES} | ## Optimizer -[levanter.trainer.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: +[levanter.optim.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: | Parameter | Description | Default | |-----------------|-------------------------------------------------------------------|----------| @@ -277,8 +329,26 @@ We won't go into detail here. You can see the auto-generated docs below. ::: levanter.checkpoint.Checkpointer -### Wandb -::: levanter.logging.WandbConfig +### Trackers and Metrics + +See also [Trackers](dev/Trackers.md) for more information. Basic configuration is shown below. + +#### Single Tracker + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + + + +::: levanter.tracker.wandb.WandbConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + ### Distributed and Ray @@ -288,7 +358,7 @@ We won't go into detail here. You can see the auto-generated docs below. ### Optimizer -::: levanter.trainer.OptimizerConfig +::: levanter.optim.OptimizerConfig ### LM Model diff --git a/docs/Levanter-1.0-Release.md b/docs/Levanter-1.0-Release.md index 8fed293dd..05c66683a 100644 --- a/docs/Levanter-1.0-Release.md +++ b/docs/Levanter-1.0-Release.md @@ -539,7 +539,7 @@ learn differently from Transformers. ## A few other features * **Training**: Levanter uses [Optax](https://github.com/deepmind/optax) for optimization, - though our new optimizer, [Sofia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! + though our new optimizer, [Sophia](https://arxiv.org/abs/2305.14342), is coming to Levanter soon! * **Logging**: Logging is done with [WandB](https://wandb.ai/), complete with a fancy online visualization of the validation set during training. * **Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now. * **Export**: We also support exporting models to the Hugging Face Hub, with export compatible with Pytorch and Transformers via [SafeTensors](https://github.com/huggingface/safetensors). @@ -627,7 +627,7 @@ trained on the [Lakh MIDI](https://colinraffel.com/projects/lmd/) corpus. The la This is just the beginning for Levanter. In the future, look for: * more models on interesting problem domains, * scaled up versions of new architectures developed here at Stanford and elsewhere, -* new training techniques, including the newly released [Sofia](https://arxiv.org/abs/2305.14342) optimizer, +* new training techniques, including the newly released [Sophia](https://arxiv.org/abs/2305.14342) optimizer, * and larger models! Levanter is still a work in progress, but we are excited to share it with the community. We hope that Levanter will be diff --git a/docs/LoRA.md b/docs/LoRA.md index ddbb32358..4d8cdf099 100644 --- a/docs/LoRA.md +++ b/docs/LoRA.md @@ -107,6 +107,7 @@ parameters are sharded correctly, if you're using more than one device. @dataclass class TrainArgs: lora: LoraConfig = LoraConfig() + trainer: TrainerConfig = TrainerConfig() # ... some other stuff hf_save_path: Optional[str] = None # Path to save the HuggingFace checkpoint. @@ -120,7 +121,7 @@ class TrainArgs: def train(config: TrainArgs): ... - with config.trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: ... @hax.named_jit(axis_resources=parameter_axis_mapping, donate_args=(True)) @@ -143,15 +144,12 @@ using the `lora_trainable_params_filter` function, which takes a model and retur ```python def train(config: TrainArgs): ... - with config.trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: ... lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) ``` ### 3. Serialize a PEFT-compatible checkpoint diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index edf33e0af..4c543b04f 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -214,7 +214,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" # TODO tags: ["gpt2"] diff --git a/docs/design/Trainer-Abstraction.md b/docs/design/Trainer-Abstraction.md new file mode 100644 index 000000000..f03ec286a --- /dev/null +++ b/docs/design/Trainer-Abstraction.md @@ -0,0 +1,157 @@ +# Trainer Abstraction Cleanup + +## Current Status (2024-01-23) + +### Trainer's current jobs + +Trainer currently has these jobs: + +* Handle registering and running callbacks +* Handle checkpointing (delegated to `Checkpointer`) +* Handling initialization, including loading from a checkpoint (partially delegated to `Checkpointer`) +* train_step/train_steps/training loop +* holding execution environment details (e.g. mixed precision policy, device mesh, etc) +* handles making data loaders (with the right sharding etc) +* sets up microbatching/grad accum (mostly factored out into nice pieces +* actually taking the step + +It would be nice if this were orthogonalized as much as possible. + +Hooks are already mostly delegated out to TrainerHooks so that's not too bad, and checkpoints are well encapsulated in the Checkpointer class, +other than the initialization/resume logic. + +Execution Environment is just work to abstract, and dovetails well with other work (i.e. just-in-time mixed precision). + +A lot of changes live in the doremi branch, because it needs an augmented trainer state to do its work + + + +### Other things that bother me + +* the cached_property loss_fn is smelly and actually behaves badly because jit(grad(jit(f))) doesn't work well +* I don't love the story for extending TrainerState + +### TrainerState extension + +We want TrainerState to be extensible, which means that: + +* it needs to inheritable +* methods like train_step need to be able to be overridden +* core "train_step" logic needs to be reusable (e.g. the logic for accumulating gradients, etc) in a way that + returns the right type (e.g. TrainerState or subclass) + +In Haliax we do initialization with a static/classmethod on modules, rather than the ctor. It's useful to have +a "plain old constructor" for various modules + +## Initialization/Resume + + +### Requirements + +There are 3 core cases to consider: + +1. No checkpoint, initialize from scratch +2. Checkpoint exists, load checkpoint and initialize "unserialized"/missing state +3. Partial checkpoint exists (e.g. only model), load checkpoint and initialize "unserialized"/missing state + +Typically, (3) is a full checkpoint, but we only want to load the model. This is useful for things like +fine-tuning a model, where we want to load the model but not the optimizer state. + +On top of that, we currently differentiate between passing in a model_init function and a model. This +complicates things a bit, but model_init is preferred because: + +1. it's more memory/time efficient when initializing from checkpoint +2. it makes it easier to get sharding and mixed precision right immediately. + +For (1), I think the time isn't a big deal, but we need a way of dealing +with the memory. One could maybe delete the passed in model (preserving only the shape) +once we determine the checkpoint exists? + +For (2), we also want to get the mixed precision and sharding set up correctly immediately. Passing in a model_init +allows us to wrap it in the right jit and partition magic to get that right. +We can and should expose (2) as a function... + + +Another complexity is `is_trainable`, which is a FilterSpec that allows you to specify which parts of the model +are trainable. This is useful for things like fine-tuning, where you want to freeze some layers. We use is_trainable in +4 ways: + +* only the is_trainable parts of a model get an optimizer_state associated with them +* we only serialize/deserialize the is_trainable parts of a model +* we only compute gradients for the is_trainable parts of a model +* We store the non-trainable parts of the model in compute precision, and the trainable parts in the param precision + +### Current Initialization w/o checkpoint + +This is conceptually what happens when there is no checkpointing: + +```python +@hax.named_jit(out_axis_resources=parameter_mapping) +def initialize(optimizer, model_init, is_trainable, mp): + model = model_init() + trainable_model = eqx.filter(model, is_trainable) + optimizer_state = optimizer.init(trainable_model) + + model = _cast_model_by_trainability(model, is_trainable, mp) + + state = TrainerState( + _step=0, + model=model, + optimizer_state=optimizer_state, + is_trainable=is_trainable, + ) + + state = hax.shard(state, parameter_mapping) + + return state + + +def _cast_model_by_trainability(model, is_trainable, mp): + trainable_model, non_trainable_model = eqx.partition(model, is_trainable) + non_trainable_model = mp.cast_to_compute(non_trainable_model) + trainable_model = mp.cast_to_param(trainable_model) + model = eqx.combine(trainable_model, non_trainable_model) + return model +``` + + + +### Current logic for initialization w/ checkpoint + +The logic for initial_state is pretty complex. There are 3 cases to consider: + +1. No checkpoint, initialize from scratch +2. Checkpoint exists, load checkpoint and initialize "unserialized"/missing state +3. Partial checkpoint exists (e.g. only model), load checkpoint and initialize "unserialized"/missing state + +At the moment the flow is: + +```python + +state_shape = eval_shape(_initialize_from_scratch(model_init_fn)) +if checkpoint_exists: + partial_state = load_checkpoint(state_shape, path) +elif partial_checkpoint_exists: + partial_checkpoint = load_checkpoint(state_shape.model, path, subpath="model") + partial_state = dataclasses.replace(partial_state, model=partial_checkpoint) + +state = jit(lambda s: combine(s, _initialize_from_scratch(model_init_fn)), partial_state) +``` + +I'd like to hoist this out so it's not dependent on the Trainer class, and so that it's easier to test. + +One of the things I was trying to accomplish was to define a checkpointed_or_initialize function that was just + +```python +state_shape = eval_shape(f) +if checkpoint_exists: + partial_state = load_checkpoint(state_shape, path) +else: + partial_state = eqx.filter(state_shape, lamba v: False) + +state = jit(lambda s: combine(s, f()), partial_state) + +``` + +But this doesn't actually compose well: you can't really do IO inside of eval_shape, so you can't really combine two +of those... or can you diff --git a/docs/Port-Models.md b/docs/dev/Port-Models.md similarity index 98% rename from docs/Port-Models.md rename to docs/dev/Port-Models.md index f75fa7534..41228c2a3 100644 --- a/docs/Port-Models.md +++ b/docs/dev/Port-Models.md @@ -287,7 +287,7 @@ model: num_layers: 2 ``` -For more details on the training configuration, please refer to [Configuration Guide](./Configuration-Guide.md). +For more details on the training configuration, please refer to [Configuration Guide](../Configuration-Guide.md). ### Launch Training Job Once you have your training configuration ready and your training environment set up, you can launch a training job with the following command: @@ -299,7 +299,7 @@ HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ python levanter/src/levanter/main/train_lm.py --config_path $CONFIG_PATH ``` -Check out [Training on Your Own Data](./Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. +Check out [Training on Your Own Data](../Training-On-Your-Data.md) for more detailed guide on how to spin off a training cluster and launch a training job. ### Profile Your Model If you are interested in profiling the training throughput of your model, good news is that it comes for free with automatic job monitoring in Levanter, powered through Weights & Biases. diff --git a/docs/dev/Trackers.md b/docs/dev/Trackers.md new file mode 100644 index 000000000..1f1677d52 --- /dev/null +++ b/docs/dev/Trackers.md @@ -0,0 +1,104 @@ +# Trackers and Metrics + +Logging values and other metadata about a run is a core requirement for any ML framework. +Until recently, Levanter had a hard dependency on [W&B](https://wandb.ai/site) for tracking such values. + +In the latest version, we introduce the [levanter.tracker.Tracker][] interface, which allows you to use any tracking backend you want. +The interface name is taken from the [HuggingFace Accelerate](https://github.com/huggingface/accelerate/blob/0f2686c8d3e6d949c4b7efa15d7f2dee44f7ce91/src/accelerate/tracking.py#L395) +framework. + +Given Levanter's historical dependency on W&B, the interface is designed to look similar to W&B's API. +The methods currently exposed are: + +* [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it. +* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step. +* [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version. +* [levanter.tracker.get_tracker][]: returns a tracker with the given name. +* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit. + +A basic example of using the tracker interface is shown below: + +```python +import wandb +from levanter.tracker import current_tracker, log_metrics, log_summary +from levanter.tracker.wandb import WandbTracker + +with current_tracker(WandbTracker(wandb.init())): + for step in range(100): + log_metrics({"loss": 100 -0.01 * step}, step=step) + + log_summary({"best_loss": 0.0}) +``` + +A more typical example would be to use it in a config file, as we do with Trainer: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + +## Adding your own tracker + +To add your own tracker, you need to implement the [levanter.tracker.Tracker][] interface. +You will also want to register your config with TrackerConfig as a "choice" in the choice type. +Follow the pattern for Tensorboard and W&B. + +TODO: expand this section. + + +## API Reference + +### Core Functions + +::: levanter.tracker.current_tracker + +::: levanter.tracker.log_metrics + +::: levanter.tracker.log_summary + +::: levanter.tracker.get_tracker + +::: levanter.tracker.jit_log_metrics + +### Trackers + +::: levanter.tracker.Tracker + +::: levanter.tracker.tracker.CompositeTracker + +::: levanter.tracker.tracker.NoopTracker + +::: levanter.tracker.tensorboard.TensorboardTracker + +::: levanter.tracker.wandb.WandbTracker + +### Tracker Config + +::: levanter.tracker.TrackerConfig + +::: levanter.tracker.tracker.NoopConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + +::: levanter.tracker.wandb.WandbConfig diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..c83beddaa 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -6,9 +6,9 @@ from dataclasses import dataclass from typing import Optional +import equinox as eqx import jax.random as jrandom import transformers -import wandb import haliax as hax @@ -21,7 +21,7 @@ save_merged_hf_checkpoint_callback, save_peft_checkpoint_callback, ) -from levanter.models.lm_model import LmExample, LmHeadModel +from levanter.models.lm_model import LmHeadModel from levanter.trainer import Trainer from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -49,7 +49,7 @@ class TrainArgs(alpaca.TrainArgs): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -80,7 +80,9 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - with config.trainer.device_mesh: + # end major difference from Alpaca + + with Trainer(config.trainer, optimizer) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -98,22 +100,20 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - - # end major difference from Alpaca - - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) # log some info about the model all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) + + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index e02b0738a..58d5bc17c 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -18,7 +18,8 @@ from levanter.data import Dataset from levanter.data.sharded_dataset import JsonDataset, JsonlDataset, WrappedHFDataset from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils import fsspec_utils from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.py_utils import non_caching_cycle @@ -193,7 +194,7 @@ def get_prompts(prompt_path) -> dict: def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -226,12 +227,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping @@ -248,7 +244,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) if state.step != 0: diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 6b369bf77..0f4ba005f 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -5,11 +5,11 @@ from dataclasses import dataclass from typing import Optional, Union +import equinox as eqx import jax import jax.random as jrandom import numpy as np import transformers -import wandb import haliax as hax @@ -25,7 +25,8 @@ save_peft_checkpoint_callback, ) from levanter.models.lm_model import LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -88,7 +89,7 @@ def __iter__(self): else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - yield LmExample(input_ids, loss_mask) + yield LmExample.causal(input_ids, loss_mask=loss_mask) def mk_dataset(config: TrainArgs, tokenizer: transformers.PreTrainedTokenizerBase): @@ -126,7 +127,7 @@ def format_output(ex): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -147,7 +148,7 @@ def train(config: TrainArgs): optimizer = config.optimizer.build(config.trainer.num_train_steps) - with config.trainer.device_mesh: + with Trainer(config.trainer, optimizer) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -165,22 +166,20 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - - # end major difference from Alpaca - - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) # log some info about the model all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) + + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index e1facd8cb..d7fc87653 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -93,7 +93,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?) #retry pip install -U "jax[tpu]==0.4.5" libtpu-nightly==0.1.dev20230216 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -retry pip install -U "jax[tpu]==0.4.21" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retry pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter diff --git a/mkdocs.yml b/mkdocs.yml index 35fdaf5c4..28fdb9849 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -97,7 +97,8 @@ nav: - "tutorials/Fine-Tuning-Semantic-Parsing.md" - "Hardware-Agnostic-Training.md" - 'Developer Guide': - - 'Port-Models.md' + - 'dev/Port-Models.md' + - 'dev/Trackers.md' - 'FAQ' : 'faq.md' - Other: - 'Levanter-1.0-Release.md' diff --git a/pyproject.toml b/pyproject.toml index 14f010c1b..657ae21dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,13 @@ dependencies = [ # jax = {version = ">=0.4.10,<0.5.0"} # "haliax>=1.3,<2.0", # Haliax changes in step with levanter, so we'll just use the git version except for releases. - "haliax @ git+https://github.com/stanford-crfm/haliax.git", + "haliax @ git+https://github.com/stanford-crfm/haliax.git@jamp", "equinox>=0.10.7", "jaxtyping>=0.2.20", "transformers>=4.22.0", "optax", "wandb", - "draccus>=0.6", + "draccus>=0.7.1", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.16.1", diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index ed64660f7..bfb55fc23 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -12,7 +12,7 @@ fi echo "Launching GPT2 small fast on TPU with git branch $GIT_BRANCH" bash infra/babysit-tpu-vm.sh levanter-itest-32 -p -z us-east1-d -t v3-32 -b $GIT_BRANCH -- \ + HUGGING_FACE_HUB_TOKEN=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 \ XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*" \ WANDB_API_KEY=$WANDB_API_KEY levanter/infra/run.sh python levanter/src/levanter/main/train_lm.py \ - --config_path levanter/config/gpt2_small_fast.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 33bcd249d..ecabba8df 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,4 +3,10 @@ import levanter.data as data import levanter.distributed as distributed import levanter.logging as logging +import levanter.models as models +import levanter.optim as optim +import levanter.tracker as tracker +import levanter.trainer as trainer import levanter.visualization as visualization +from levanter.tracker import current_tracker, get_tracker +from levanter.trainer import initialize diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2292c714a..0c462a997 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,5 +1,5 @@ import copy -import logging +import logging as pylogging import os import re import subprocess @@ -11,20 +11,24 @@ import humanfriendly import jax -import wandb from tqdm import tqdm -from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb +import levanter.tracker +from levanter.logging import save_xla_dumps_to_wandb +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 + total_load_time = 0.0 + total_loss_time = 0.0 n = 0 if name is not None: @@ -33,10 +37,20 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n desc = "eval" pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) - for batch in pbar: + iter_ = iter(pbar) + while True: + time_in = time.time() + batch = next(iter_, None) + if batch is None: + break + load_time = time.time() - time_in + total_load_time += load_time loss = loss_fn(model, batch) total_loss += loss.item() n += 1 + loss_time = time.time() - time_in - load_time + total_loss_time += loss_time + pbar.set_postfix(loss=total_loss / n) if max_batches is not None and n >= max_batches: @@ -45,6 +59,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n + logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + return total_loss @@ -57,11 +74,10 @@ def compute_validation_loss( def compute_loss(info: StepInfo): loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) - if wandb.run is not None: - prefix = "eval" - if name: - prefix += "/" + name - wandb.log({f"{prefix}/loss": loss}, step=info.step) + prefix = "eval" + if name: + prefix += "/" + name + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -73,12 +89,14 @@ def compute_loss(info: StepInfo): return compute_loss -def log_to_wandb(step: StepInfo): - wandb.log({"train/loss": step.loss, "global_step": step.step}, step=step.step) +def log_step_info(step: StepInfo): + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") def wandb_xla_logger(config: WandbConfig): + import wandb + last_mtime = wandb.run and wandb.run.start_time or time.time() def log_xla_to_wandb(step: StepInfo): @@ -108,14 +126,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -125,7 +143,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -152,7 +170,7 @@ def update_pbar(step: StepInfo): def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): """ - Logs memory usage to wandb. This runs a loop that samples memory usage every `sample_interval` seconds. + Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke the hook. @@ -218,7 +236,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - wandb.log({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -229,14 +247,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - wandb.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - wandb.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage @@ -262,6 +280,9 @@ def compute_and_viz_log_probs(step: StepInfo): path = os.path.join(html_dir, f"step_{step}.html") viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) + # TODO: convert to generic logging + import wandb + wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 15f16a203..d29df582f 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -7,7 +7,7 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Callable, List, Optional, Sequence, TypeVar, Union import equinox import fsspec @@ -28,8 +28,7 @@ PathLike = Union[str, pathlib.Path] -M = TypeVar("M") -S = TypeVar("S") +M = TypeVar("M", bound=PyTree) @dataclass(frozen=True) @@ -102,19 +101,15 @@ def __init__( def load_checkpoint( self, - model: M, - training_state: S, + state: M, path: Optional[PathLike] = None, *, discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, S, int]]: + env: Optional[haliax.ResourceEnv] = None, + ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint( - model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + return load_checkpoint(state, path, discover_latest=discover_latest, env=env) def load_model( self, @@ -122,18 +117,16 @@ def load_model( path: Optional[str] = None, *, discover_latest: bool = True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: - if path is None: - path = self.base_path - ckpt = load_checkpoint( - model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) - if ckpt is None: + env: Optional[haliax.ResourceEnv] = None, + ) -> Optional[M]: + """ + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + """ + ret_dict = self.load_checkpoint({"model": model}, path, discover_latest=discover_latest, env=env) + if ret_dict is None: return None - model, _, step = ckpt - return model, step + return ret_dict["model"] def on_step(self, info, force: bool = False): step = info.step @@ -212,17 +205,15 @@ def _rm_checkpoint(self, checkpoint): cp_path = os.path.join(plain_path, checkpoint) logger.info(f"Deleting checkpoint {checkpoint} from {cp_path}") fs.rm(cp_path, recursive=True) - # don't let this take down a run except Exception: # pylint: disable=broad-except logger.exception("Failed to delete checkpoint", exc_info=True) def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - model = equinox.filter(info.model, self.keep_params) + state = equinox.filter(info.state, info.state.is_trainable) save_checkpoint( - model=model, - training_state=(info.opt_state, info.next_key), + state, step=info.step, checkpoint_path=path, ) @@ -231,7 +222,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. @@ -249,10 +240,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike) fs, plain_path = _get_fs_and_plain_path(checkpoint_path) fs.makedirs(plain_path, exist_ok=True) - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) - if training_state is not None: - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state) - + tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) logger.info(f"Saved checkpoint for step {step}") @@ -271,22 +259,28 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - model: M, - training_state: S, + tree: M, checkpoint_path: PathLike, + env: Optional[haliax.ResourceEnv] = None, *, + subpath: Optional[str] = None, discover_latest=True, - axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, - mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[Tuple[M, S, int]]: +) -> M: """ - Load a checkpoint from a given path. + Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint + in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint + loads only that subpath of the checkpoint. This is useful for loading, e.g., just the model and not + the entire training state. + + Args: + tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any] + checkpoint_path: the path to load the checkpoint from + subpath: the subpath to load from the checkpoint + discover_latest: whether to discover the latest checkpoint in the given path + env: the resource env to use for loading the checkpoint. if None, the current resource env is used + Returns: + the loaded checkpoint, with the same structure as the exemplar tree - Returns the loaded model state, training state, and step. If discover_latest is True, - the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at - the given path will be loaded. If no checkpoint is found, returns None - - If training_state is None, no training state will be loaded. """ fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -297,28 +291,47 @@ def load_checkpoint( checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore if checkpoint_path is None or not fs.exists(checkpoint_path): - return None + raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) + if subpath: + checkpoint_path = os.path.join(checkpoint_path, subpath) - if training_state is None: - training_state = None - else: - training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh - ) + try: + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, env) + return tree + except: # noqa + from levanter.trainer import TrainerState - return model, training_state, metadata["step"] + if not isinstance(tree, TrainerState): + raise + else: + logger.warning("Attempting to load old-style checkpoint") + model, training_state = tree.model, (tree.opt_state, tree.training_key) + + model = tree_deserialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model, env) + + if training_state is None: + opt_state = None + key = None + else: + training_state = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "training_state"), training_state, env + ) + opt_state, key = training_state + + # TODO: pretty sure this is right, but should verify + step = metadata["step"] + new_state = dataclasses.replace( + tree, _step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + ) + return new_state def load_metadata(checkpoint_path, fs=None): if fs is None: - fs: AbstractFileSystem fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path)) with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in: metadata = json.load(metadata_in) @@ -381,13 +394,12 @@ class CheckpointerConfig: def expanded_path(self, run_id): return os.path.expanduser(os.path.join(self.base_path, run_id)) - def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer: + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, - keep_params=keep_params, ) def __post_init__(self): @@ -403,3 +415,6 @@ def __post_init__(self): interval["until"] is None or interval["until"] > prev_interval["until"] ), "Checkpoint intervals must be monotonic" prev_interval = interval + + +# TODO: add partial checkpoint loading diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 25339d245..138b33200 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -519,10 +519,7 @@ def load_pretrained( f"Model vocab size ({Vocab.size}) does not match tokenizer vocab size ({tokenizer_Vocab.size})" ) - if axis_mapping is not None: - lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) - else: - lev_model = haliax.auto_sharded(lev_model) + lev_model = haliax.shard(lev_model, axis_mapping) # once more for good measure gc.collect() diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py index 1c326911b..0df3be216 100644 --- a/src/levanter/compat/torch_serialization.py +++ b/src/levanter/compat/torch_serialization.py @@ -89,8 +89,8 @@ def jax_tree_from_state_dict(tree: PyTree, state_dict: StateDict, prefix: Option raise ValueError("Cannot extract a leaf value from a torch dict without a prefix") array = state_dict[prefix] - mesh = haliax.partitioning._get_mesh() - if mesh.devices.size > 1: # this happens with the default mesh + mesh = haliax.current_resource_env().mesh + if mesh is not None: # this happens with the default mesh pspec = haliax.partitioning.pspec_for_axis(tree.axes) sharding = jax.sharding.NamedSharding(mesh, pspec) array = jax.make_array_from_callback(tree.array.shape, sharding, lambda indices: array[indices]) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index d3710207e..a879cc06a 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -8,12 +8,11 @@ import jax.numpy as jnp import jax.tree_util as jtu from jax.experimental import multihost_utils -from jax.sharding import Mesh, PartitionSpec +from jax.sharding import PartitionSpec from jaxtyping import Array, PyTree import haliax as hax from haliax import NamedArray -from haliax.partitioning import ResourceMapping from haliax.util import is_named_array import levanter.mesh @@ -34,25 +33,23 @@ class BatchLoader(Iterable[Ex], abc.ABC): + """ + Args: + Batch: the batch size + resource_env: the resource environment, if None then use the current one + max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread + """ + Batch: hax.Axis - mesh: Mesh - axis_resources: Optional[ResourceMapping] - - def __init__(self, max_capacity: Optional[int], axis_resources: Optional[ResourceMapping]): - """ - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread - :param axis_resources: - """ + + def __init__(self, Batch: hax.Axis, resource_env: hax.ResourceEnv, max_capacity: Optional[int]): self.max_capacity = max_capacity - self.axis_resources = axis_resources + self.resource_env = resource_env or hax.current_resource_env() + self.Batch = Batch def __iter__(self) -> Iterator[Ex]: - ax_resources = self.axis_resources - if ax_resources is None: - ax_resources = hax.partitioning.current_thread_local_mapping() - def produce_batches(): - with hax.axis_mapping(ax_resources): + with self.resource_env: for batch in self._produce_batches(): yield batch @@ -110,7 +107,7 @@ def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Arra def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, NamedShapeSpec]): raw_array = jax.make_array_from_callback( to_raw_shape(item_leaf_shape), - jax.sharding.NamedSharding(self.mesh, self._pspec_for(item_leaf_shape)), + jax.sharding.NamedSharding(self.resource_env.mesh, self._pspec_for(item_leaf_shape)), lambda indices: get_local_data_for_leaf(indices, leaf_index), ) if isinstance(item_leaf_shape, NamedShapeSpec): @@ -131,10 +128,10 @@ def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, Nam def _pspec_for(self, shape_spec: Union[ShapeSpec, NamedShapeSpec]) -> PartitionSpec: if isinstance(shape_spec, ShapeSpec): # type: ignore - batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources) + batch_name = hax.partitioning.physical_axis_name(self.Batch, self.resource_env) return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1))) else: - return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore + return hax.partitioning.pspec_for_axis(shape_spec.shape, self.resource_env) # type: ignore class ShardedBatchLoader(BatchLoader[Ex]): @@ -153,37 +150,40 @@ class ShardedBatchLoader(BatchLoader[Ex]): load, by determining which row(s) of the device mesh the process is responsible for. :arg local_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh :arg Batch: the batch size + :arg env: the resource environment, if None then use the current one :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread """ def __init__( self, local_dataset: ShardableDataset[Ex], - mesh: Mesh, Batch: hax.Axis, - axis_resources: Optional[ResourceMapping] = None, + env: Optional[hax.ResourceEnv] = None, max_capacity: Optional[int] = 10, *, override_process_data_pos: Optional[int] = None, # for testing override_process_data_groups: Optional[int] = None, # for testing ): - self.mesh = mesh - self.Batch = Batch - - process_data_pos = override_process_data_pos or levanter.mesh.process_mesh_position(mesh)[0] - num_data_process_groups = override_process_data_groups or levanter.mesh.process_mesh_size(mesh)[0] + env = env or hax.current_resource_env() + # TODO: this could be better + mesh = env.mesh + if mesh is not None: + process_data_pos = override_process_data_pos or levanter.mesh.process_mesh_position(mesh)[0] + num_data_process_groups = override_process_data_groups or levanter.mesh.process_mesh_size(mesh)[0] + else: + process_data_pos = override_process_data_pos or 0 + num_data_process_groups = override_process_data_groups or 1 if not override_process_data_groups: assert num_data_process_groups <= jax.process_count() self.process_data_pos = process_data_pos self.num_data_process_groups = num_data_process_groups - assert self.Batch.size % num_data_process_groups == 0 + assert Batch.size % num_data_process_groups == 0 self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups) - super().__init__(max_capacity, axis_resources) + super().__init__(Batch, env, max_capacity) def _produce_batches(self) -> Iterator[PyTree]: one_item_generator = non_caching_cycle(self.item_dataset) @@ -228,27 +228,24 @@ class ReplicatedBatchLoader(BatchLoader[Ex]): Note: this class discards the final batch if it is smaller than the batch size. - :arg item_dataset: a dataset that is shardable and can be iterated over - :arg mesh: the device mesh - :arg Batch: the batch size - :arg axis_resources: the resources for the batch axis - :param max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread + Args: + item_dataset: the dataset to load + Batch: the batch size + env: the resource environment + max_capacity: if not None, the maximum number of batches to keep in memory at once. If <0 then load in the main thread """ def __init__( self, item_dataset: Dataset[Ex], - mesh: Mesh, Batch: hax.Axis, - axis_resources: Optional[ResourceMapping] = None, + env: Optional[hax.ResourceEnv] = None, max_capacity: Optional[int] = 10, ): assert item_dataset is not None self.item_dataset = item_dataset - self.mesh = mesh - self.Batch = Batch - super().__init__(max_capacity, axis_resources) + super().__init__(Batch, env, max_capacity) def _produce_batches(self): for batch in _batched(self.item_dataset, self.Batch.size): diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 569bbe711..f5faa9b36 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -16,7 +16,6 @@ import pyarrow as pa import pyarrow.parquet as pq import ray -import wandb from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle @@ -31,6 +30,8 @@ TimeRemainingColumn, ) +import levanter.tracker + from .. import logging from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch @@ -739,7 +740,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - wandb.log(to_log, commit=self.commit) + levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py new file mode 100644 index 000000000..1924ebe29 --- /dev/null +++ b/src/levanter/doremi.py @@ -0,0 +1,305 @@ +import dataclasses +import logging +from typing import Callable, Iterator, Optional, Tuple, TypeVar + +import equinox as eqx +import jax.numpy as jnp +import jax.random as jrandom +import optax +from jaxtyping import PRNGKeyArray + +import haliax as hax +from haliax.types import IntScalar + +import levanter.tracker +from levanter.callbacks import eval_loss_loop +from levanter.data import ShardableDataset +from levanter.data.mixture import MixtureDataset +from levanter.logging import capture_time +from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState, take_opt_step +from levanter.types import ComputeLossFunction, ModuleComputeLoss +from levanter.utils.tree_utils import inference_mode + + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") + + +# TODO: should we put ref in the state? If so, need to tell it to not serialize it +class DoremiState(TrainerState): + alpha: hax.NamedArray + average_alpha: hax.NamedArray + + def update_alpha(self, alpha): + average_alpha = self.average_alpha + (alpha - self.average_alpha) / (self._step + 1) + return dataclasses.replace(self, alpha=alpha, average_alpha=average_alpha) + + +class DoReMiTrainer(Trainer): + # we just use the DoReMi trainer for state management + + def __init__(self, trainer: TrainerConfig, optimizer: optax.GradientTransformation, initial_alpha: hax.NamedArray): + super().__init__(trainer, optimizer) + self.initial_alpha = initial_alpha + + # TODO: I'd like to not need to override trainer for this + def _initialize_state_from_scratch(self, model: Callable[[], M], training_key, is_trainable): + base_state = super()._initialize_state_from_scratch(model, training_key, is_trainable) + + return DoremiState(**base_state.__dict__, alpha=self.initial_alpha, average_alpha=self.initial_alpha) + + +@dataclasses.dataclass +class DoReMiConfig: + # This is designed to be used with estimate_mixture_weights + domain_weight_step_size: float = 1.0 + smoothing: float = 1e-3 + sampling_weights: Optional[dict[str, float]] = None + + +DEFAULT_DOREMI_TRAINER_CONFIG = TrainerConfig( + num_train_steps=10000, + train_batch_size=512, +) + + +def estimate_mixture_weights( + initial_proxy: M, + ref: M, + data_sources: dict[str, ShardableDataset[T]], + sampling_weights: Optional[dict[str, float]] = None, + *, + validation_sets: Optional[dict[str, ShardableDataset[T]]] = None, + trainer_config: TrainerConfig = DEFAULT_DOREMI_TRAINER_CONFIG, + optimizer: optax.GradientTransformation = optax.adamw(1e-3), + loss_fn: ComputeLossFunction[M, T] = ModuleComputeLoss(), + domain_weight_step_size: float = 1.0, + smoothing: float = 1e-3, + key: PRNGKeyArray, +) -> dict[str, float]: + """ + Estimate the mixture weights for the data sources using DoReMi. + https://arxiv.org/abs/2305.10429 + + Args: + trainer_config: Trainer config + initial_proxy: Initial proxy model + ref: Reference model + data_sources: Data sources to estimate the weights for + sampling_weights: Sampling weights for the data sources. If not provided, will use uniform sampling weights. + loss_fn: Loss function to use for the proxy and ref models. If not provided, will use the model's compute_loss + domain_weight_step_size: Step size for the domain weights + smoothing: Smoothing for the domain weights + key: PRNG key + """ + if len(data_sources) <= 1: + raise ValueError("Must have at least two data sources") + + training_key, data_key = jrandom.split(key) + domain_indices = list(data_sources.keys()) + domain_to_index = {domain: index for index, domain in enumerate(domain_indices)} + + # Initialize domain weights. + # TODO: should we initialize to the ref or to uniform? + Domain = hax.Axis("domain", len(domain_indices)) + initial_alpha = hax.ones(Domain) / Domain.size + + trainer = DoReMiTrainer(trainer_config, optimizer, initial_alpha) + with trainer: + ref = _prepare_ref_model(ref, trainer_config) + + if validation_sets is not None: + + @eqx.filter_jit + def eval_loss(model, *batch, **batch_kwargs): + model = inference_mode(model, True) + with trainer.compute_env: + return trainer.loss_fn(model, *batch, **batch_kwargs, key=None) + + for domain, dataset in validation_sets.items(): + loss = eval_loss_loop( + eval_loss, + ref, + trainer.replicated_loader(dataset, trainer.EvalBatch), + name=f"ref {domain}", + max_batches=trainer_config.max_eval_batches, + ) + print(f"Loss of ref model on domain {domain}: {loss:.3f}") + levanter.tracker.log_metrics({f"eval/ref/{domain}/loss": loss}, step=0, commit=False) + + if validation_sets is not None: + for domain, dataset in validation_sets.items(): + trainer.add_eval_hook(dataset, name=domain) + + if sampling_weights is not None: + assert set(sampling_weights.keys()) == set(data_sources.keys()) + sampling_weights = { + domain: weight / sum(sampling_weights.values()) for domain, weight in sampling_weights.items() + } + else: + sampling_weights = {domain: 1 / len(data_sources) for domain in data_sources.keys()} + + # Loss is \sum_d alpha_d * (proxy - ref) (basically the unclipped excess loss with the new alpha) + # Note that (\sum_d \alpha_d ref) is a constant in the model params, so we can ignore it for gradient computation + # (JAX would ignore it for us I think but it's nice to be explicit and lets us log better) + @hax.named_jit(axis_resources=trainer.parameter_axis_mapping, donate_args=(True,)) + def doremi_step(state: DoremiState, ref, batch, domains): + proxy = inference_mode(state.model, False) + with trainer.compute_env: + # calculate per-token losses for proxy and ref + proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy) + ref_losses = loss_fn(ref, batch, reduction_axis=()) + + # calculate excess losses, aggregate per-domain losses + excess_losses = proxy_losses - ref_losses + clipped_losses = hax.maximum(excess_losses, 0) + per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains) + + # Update domain weights + alpha = state.alpha * hax.exp(domain_weight_step_size * per_domain_losses) + alpha /= hax.sum(alpha) + alpha = (1 - smoothing) * alpha + initial_alpha * smoothing + + # Update proxy model weights θt for the objective L(θt−1, αt) (using Adam, Adafactor, etc.) + # Note DoReMi says to use the unclipped excess loss here. Confirmed with Michael + loss, grad_loss = eqx.filter_value_and_grad(_domain_weighted_loss)(excess_losses, Domain, domains, alpha) + grad = proxy_loss_bwd(grad_loss)[0] + + partial_loss = lambda model: loss_fn(model, *batch) + model, opt_state = take_opt_step( + optimizer, state.model, state.opt_state, grad, partial_loss, state.is_trainable + ) + + new_state = dataclasses.replace(state, model=model, opt_state=opt_state, _step=state._step + 1) + new_state = new_state.update_alpha(alpha) + + # log metrics + distance_from_uniform = hax.sum(hax.abs(alpha - initial_alpha)) + mean_excess_loss = hax.mean(excess_losses).scalar() + mean_proxy_loss = hax.mean(proxy_losses).scalar() + alpha_distance = hax.sum(hax.abs(new_state.average_alpha - state.average_alpha)) + alpha_dict = _decode_domain_array(Domain, new_state.average_alpha, domain_to_index) + per_domain_dict = _decode_domain_array(Domain, per_domain_losses, domain_to_index) + + levanter.tracker.jit_log_metrics( + { + "change_in_alpha": alpha_distance.scalar(), + "alpha_distance_from_uniform": distance_from_uniform.scalar(), + "train/mean_excess_loss": mean_excess_loss, + "train/mean_proxy_loss": mean_proxy_loss, + **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, + # just skip domains with no excess loss + # TODO: we need to skip logging things that are 0, but can't do that in jit, have to do it in python + **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items()}, + }, + step=state._step, + ) + + return loss, alpha_distance, new_state + + # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts + with trainer: + tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key) + state: DoremiState = trainer.initial_state(training_key, model=initial_proxy) + del initial_proxy + train_loader = iter(trainer.sharded_loader(tagged_mixture, trainer.TrainBatch)) + + if state.step > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm + + for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + next(train_loader) + + while state.step < trainer.num_train_steps: + example, ex_domains = next(train_loader) + + with capture_time() as step_time: + loss, alpha_distance, state = doremi_step(state, ref, example, ex_domains) + loss = loss.item() # type: ignore + + new_info = StepInfo(state, loss, step_time()) + + trainer.run_hooks(new_info) + + trainer.run_hooks(new_info, force=True) + + alpha = state.average_alpha + final_weights = _decode_domain_array(Domain, alpha, domain_to_index) + + levanter.tracker.log_summary({"final_alpha": final_weights}) + + return {k: float(v) for k, v in final_weights.items()} + + +def _decode_domain_array(Domain, alpha, domain_name_to_index): + final_weights = {domain: alpha[Domain, index].scalar() for domain, index in domain_name_to_index.items()} + return final_weights + + +def _compute_per_domain_losses(losses, Domain, domains): + """Compute per-domain average losses from a batch of losses""" + # out[d] = E[losses | domain=d] + one_hot_domains = hax.nn.one_hot(domains, Domain) # Domain x Batch + per_domain_losses = hax.dot(one_hot_domains, losses, axis=losses.axes, out_axes=(Domain,)) + # count the number of losses for each domain + norm = hax.dot(one_hot_domains, losses != 0, axis=losses.axes, out_axes=(Domain,)) + norm = hax.maximum(norm, 1.0) # don't nan if there are no losses for a domain + return per_domain_losses / norm + + +def _domain_weighted_loss(losses, Domain, domains, alpha): + """Average loss weighted by domain weights""" + per_domain_losses = _compute_per_domain_losses(losses, Domain, domains) + return hax.dot(alpha, per_domain_losses, axis=Domain).scalar() + + +def _prepare_ref_model(ref, trainer): + return hax.named_jit( + lambda m: trainer.mp.cast_to_compute(inference_mode(m, True)), + axis_resources=trainer.parameter_axis_mapping, + donate_args=True, + )(ref) + + +def domain_tagged_mixture( + data_sources: dict[str, ShardableDataset[T]], + weights: dict[str, float], + domain_to_index: dict[str, int], + *, + key: PRNGKeyArray, +) -> MixtureDataset[Tuple[T, IntScalar]]: + """ + Domain tagged mixture dataset. This dataset will yield from the datasets according to the weights, + and will yield the domain index as a second element of the tuple. + """ + tagged_datasets = { + domain: DomainTaggedDataset(data_sources[domain], domain_index) + for domain, domain_index in domain_to_index.items() + } + + return MixtureDataset(tagged_datasets, weights, key=key) + + +class DomainTaggedDataset(ShardableDataset[Tuple[T, hax.NamedArray]]): # named array is a scalar int + def __init__( + self, + dataset: ShardableDataset[T], + domain_index: int | hax.NamedArray, + ): + self.dataset = dataset + + if isinstance(domain_index, int): + self.domain_index = hax.named(jnp.array(domain_index, dtype=int), ()) + else: + self.domain_index = domain_index + + def shard(self, shard_id: int, num_shards: int) -> "DomainTaggedDataset[T]": + return DomainTaggedDataset(self.dataset.shard(shard_id, num_shards), self.domain_index) + + def __iter__(self) -> Iterator[Tuple[T, hax.NamedArray]]: + for item in self.dataset: + yield item, self.domain_index diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 7ffa90c91..78588669f 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,50 +1,27 @@ import contextlib -import dataclasses -import logging import logging as pylogging import os -import tempfile import time -import warnings -from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import List, Union -import draccus import jax -import wandb -from draccus import field -from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax import MultiStepsState -from levanter.utils import jax_utils -from levanter.utils.jax_utils import jnp_to_python +pylogger = pylogging.getLogger(__name__) -logger = pylogging.getLogger(__name__) - -def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state - - def wrap_key(key): - if prefix: - return f"{prefix}/{key}" - return key - - if hasattr(opt_state, "hyperparams"): - params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - wandb.log(params, step=step) - - -def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(log_dir: Union[str, Path], run_id: str, level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. :param path: Path for writing log file :param level: Default logging level """ + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + path = log_dir / f"{run_id}.log" + process_index = jax.process_index() log_format = f"%(asctime)s - {process_index} - %(name)s - %(filename)s:%(lineno)d - %(levelname)s :: %(message)s" # use ISO 8601 format for timestamps, except no TZ, because who cares @@ -62,13 +39,21 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + from levanter.tracker.wandb import is_wandb_available + + if not is_wandb_available(): + pylogger.warning("Wandb is not available, so we can't save XLA dumps") + return + + import wandb + # attempt to parse xla_flags to see if we're dumping assembly files flags = os.getenv("XLA_FLAGS", None) if flags is not None and "xla_dump_to" in flags: # parse the path # this isn't robust to quotes path = flags.split("xla_dump_to=")[1].split(" ")[0] - logger.info(f"Found xla_dump_to={path}, logging to wandb") + pylogger.info(f"Found xla_dump_to={path}, logging to wandb") if wandb.run: # only want to save the files that were generated during this run # XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run @@ -80,7 +65,7 @@ def include_file(path: str): wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file) else: - logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") + pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") @contextlib.contextmanager @@ -98,23 +83,6 @@ def fn(): end = time.time() -@contextlib.contextmanager -def log_time_to_wandb(name: str, *, step=None): - with capture_time() as fn: - yield fn - wandb.log({name: fn()}, step=step) - - -def jittable_wandb_log(data, *, step=None): - """uses jax effect callback to log to wandb from the host""" - if is_wandb_available(): - jax.debug.callback(wandb.log, data, step=step) - - -def is_wandb_available(): - return wandb is not None and wandb.run is not None - - def silence_transformer_nag(): # this is a hack to silence the transformers' "None of PyTorch, TensorFlow 2.0 or Flax have been found..." thing # which is annoying and not useful @@ -123,172 +91,3 @@ def silence_transformer_nag(): os.environ["TRANSFORMERS_VERBOSITY"] = "error" import transformers # noqa: F401 - - -@dataclass -class WandbConfig: - """ - Configuration for wandb. - """ - - entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. - name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. - tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. - id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project - group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. - resume: Optional[Union[bool, str]] = None # - """ - Set the resume behavior. Options: "allow", "must", "never", "auto" or None. - By default, if the new run has the same ID as a previous run, this run overwrites that data. - Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) - document for more details. - """ - - save_code: Union[bool, str] = True - """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we - typically don't run from the root of the repo).""" - - save_xla_dumps: bool = False - """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - - def init(self, run_id: Optional[str], hparams=None, **extra_hparams): - import wandb - - if run_id is not None and self.id is not None and run_id != self.id: - warnings.warn( - f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" - " config." - ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - logger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - logger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - allow_val_change=True, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in stack: - dirname = os.path.dirname(frame.filename) - # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - logger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 0af676ef5..3e0dee750 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -367,14 +367,14 @@ def save_peft_checkpoint_callback( If hf_repo is provided, this will upload the checkpoint to the huggingface hub, passing any additional kwargs to the huggingface_hub.upload_folder function. - Args - base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path - may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp - config: the LoRA config to use - base_model_name_or_path: the name or path of the base model - tokenizer: If provided, will save the tokenizer to the checkpoint - upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch - hf_upload_kwargs: kwargs to pass to the upload function + Args: + base_path: the base path to save the checkpoint to. `/step-` will be appended to this. base_path + may be a GCS bucket path, in which case the checkpoint will be uploaded to GCS after being written to a tmp + config: the LoRA config to use + base_model_name_or_path: the name or path of the base model + tokenizer: If provided, will save the tokenizer to the checkpoint + upload_to_hf: the repo to upload to. If a string, will be interpreted as a repo name + branch + hf_upload_kwargs: kwargs to pass to the upload function """ def cb(step: StepInfo): diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 0b0636f4b..4dd46e63c 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -1,14 +1,13 @@ import logging import os -from dataclasses import dataclass - -import wandb +from dataclasses import dataclass, field import levanter from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logger +from levanter.logging import init_logging +from levanter.tracker import NoopConfig, TrackerConfig logger = logging.getLogger(__name__) @@ -16,19 +15,17 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - pass + tracker: TrackerConfig = field(default_factory=NoopConfig) @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logger("cache_dataset.log") + init_logging(".", "cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer - wandb.init(mode="offline") - for split in ["train", "validation"]: print(f"Caching {split} to {args.cache_dir}.") # connect or start the actor @@ -49,6 +46,7 @@ def main(args: RayCachedLMDatasetConfig): rows_per_chunk=args.rows_per_chunk, await_finished=False, monitors=monitors, + batch_size=128, ) cache.await_finished() diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py new file mode 100644 index 000000000..294d75c72 --- /dev/null +++ b/src/levanter/main/doremi_lm.py @@ -0,0 +1,144 @@ +import logging +from dataclasses import dataclass, field +from typing import Union + +import equinox as eqx +import jax.random as jrandom + +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning + +import levanter +from levanter.compat.hf_checkpoints import HFCompatConfig +from levanter.data.text import CausalLmDataset, LMMixtureDatasetConfig +from levanter.doremi import DoReMiConfig, estimate_mixture_weights +from levanter.models.gpt2 import Gpt2Config +from levanter.models.lm_model import LmConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import TrainerConfig +from levanter.utils.tree_utils import inference_mode + + +logger = logging.getLogger(__name__) + + +@dataclass +class TrainLmConfig: + ref_model_path: str + ref_model_from_hf: bool = False + + data: LMMixtureDatasetConfig = field(default_factory=LMMixtureDatasetConfig) + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=Gpt2Config) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + doremi: DoReMiConfig = field(default_factory=DoReMiConfig) + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class""" + use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint + + # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying + # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint + + +def main(config: TrainLmConfig): + levanter.initialize(config) + + tokenizer = config.data.the_tokenizer + + # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, + # I recommend skipping it for now + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot specify both initialize_from_hf and initialize_from") + + assert isinstance(config.model, HFCompatConfig) + converter = config.model.default_hf_checkpoint_converter + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. You may want to check this.") + + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + if config.use_hf_model_config: + # TODO: log diff of old and new config + # NB: gross mutability + config.model = converter.config_from_hf_config(converter.default_hf_config) + elif isinstance(config.model, HFCompatConfig): + converter = config.model.default_hf_checkpoint_converter + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None + + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + parameter_axis_mapping = config.trainer.parameter_axis_mapping + + with config.trainer.device_mesh: + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + # initialize the ref model + if config.ref_model_from_hf: + assert converter is not None + ref_model = converter.load_pretrained(type(config.model)) + else: + ref_model_shape = eqx.filter_eval_shape(config.model.build, Vocab, key=jrandom.PRNGKey(0)) + ref_model = levanter.checkpoint.load_checkpoint( + ref_model_shape, config.ref_model_path, env=config.trainer.param_env, subpath="model" + ) + + ref_model = inference_mode(ref_model, True) + + training_key, model_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2) + + @named_jit(axis_resources=parameter_axis_mapping) + def init_proxy_model(): + return config.model.build(Vocab, key=model_key) + + proxy_model = init_proxy_model() + + train_datasets = config.data.training_sets(ref_model.Pos.size) + valid_datasets = config.data.validation_sets(ref_model.Pos.size) + + train_datasets = { + k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + for k, v in train_datasets.items() + } + valid_datasets = { + k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + for k, v in valid_datasets.items() + } + + mixture_weights = estimate_mixture_weights( + proxy_model, + ref=ref_model, + data_sources=train_datasets, + trainer_config=config.trainer, + optimizer=optimizer, + domain_weight_step_size=config.doremi.domain_weight_step_size, + sampling_weights=config.doremi.sampling_weights, + validation_sets=valid_datasets, + key=training_key, + ) + + print(mixture_weights) + + # dump to a yaml file + weights_path = "mixture_weights.yaml" + with open(weights_path, "w") as f: + import yaml + + yaml.dump(mixture_weights, f) + + # log as an artifact + levanter.tracker.current_tracker().log_artifact(weights_path, name="mixture_weights.yaml") + + +if __name__ == "__main__": + levanter.config.main(main)() diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 6262eb428..24be4b11d 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -41,7 +41,7 @@ class EvalLmConfig: def main(config: EvalLmConfig): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer Batch = Axis("batch", config.trainer.eval_batch_size) @@ -51,13 +51,17 @@ def main(config: EvalLmConfig): if config.eval_on_train: raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) else: - raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore + validation_set = config.data.validation_set(Pos.size) + if validation_set is None: + raise ValueError("Can't eval on validation_set b/c there isn't one!") + + raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) # type: ignore - eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping - with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + with config.trainer.param_env: + eval_loader = ReplicatedBatchLoader(raw_dataset, Batch) key = jax.random.PRNGKey(0) vocab_size = len(tokenizer) @@ -81,14 +85,10 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) - - assert ckpt is not None - model, _, _ = ckpt + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) + model = hax.shard(model, config.trainer.param_env) - # TODO: switch to throwing instead of returning None loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) del model diff --git a/src/levanter/main/export_lm_to_hf.py b/src/levanter/main/export_lm_to_hf.py index 50a8e4b92..7fd4d073d 100644 --- a/src/levanter/main/export_lm_to_hf.py +++ b/src/levanter/main/export_lm_to_hf.py @@ -51,10 +51,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(trainable, None, config.checkpoint_path) + trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") - assert ckpt is not None - trainable, _, _ = ckpt + assert trainable is not None model = eqx.combine(trainable, non_trainable) if config.override_vocab_size: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index d19c80943..862aff5d5 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -3,15 +3,15 @@ from dataclasses import dataclass, field from typing import Optional +import equinox as eqx import jax.random as jrandom -import wandb import haliax.random import levanter from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.data.text import CausalLmDataset, LMDatasetConfig, LmExample +from levanter.data.text import CausalLmDataset, LMDatasetConfig from levanter.lora import ( LoraConfig, lora_trainable_params_filter, @@ -19,7 +19,8 @@ save_merged_hf_checkpoint_callback, save_peft_checkpoint_callback, ) -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count from levanter.utils.py_utils import non_caching_cycle @@ -33,7 +34,7 @@ class LoraLmConfig: lora: LoraConfig = field(default_factory=LoraConfig) data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) peft_save_path: Optional[str] = None # path to save peft-compatible checkpoints peft_hf_upload: Optional[str] = None @@ -46,6 +47,7 @@ class LoraLmConfig: def main(config: LoraLmConfig): + levanter.initialize(config) tokenizer = config.data.the_tokenizer converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code) @@ -54,7 +56,6 @@ def main(config: LoraLmConfig): converter = converter.replaced(tokenizer=tokenizer) - config.trainer.initialize(config) model_config = converter.default_config # randomness in jax is tightly controlled by "keys" which are the states of the random number generators @@ -67,7 +68,9 @@ def main(config: LoraLmConfig): Pos = model_config.Pos KeyPos = model_config.KeyPos - with config.trainer.device_mesh: + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + with Trainer(config.trainer, optimizer) as trainer: # how we shard parameters across devices parameter_axis_mapping = config.trainer.parameter_axis_mapping @@ -83,32 +86,38 @@ def loraize_hf_model(model): lora_param_filter = lora_trainable_params_filter(model) - def compute_loss(model, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - - optimizer = config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - state = trainer.initial_state(training_key, model=model) + eval_datasets = config.data.validation_sets(Pos.size) + + state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter) all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter)) + + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") # data loaders - eval_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore + if len(eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + + for name, eval_dataset in eval_datasets.items(): + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + trainer.add_eval_hook(eval_dataset, name=name) train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) train_loader = trainer.sharded_loader(train_dataset, Batch) # boilerplate hooks and such - trainer.add_default_hooks(eval_dataset) trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.peft_save_path is not None: full_save_path = os.path.join(config.peft_save_path, trainer.run_id) @@ -134,7 +143,7 @@ def compute_loss(model, example: LmExample, key=None): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(iter_data) ## OK, actually run training! diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 982f72358..3c82e2fe2 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -16,8 +15,9 @@ from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig from levanter.models.gpt2 import Gpt2Config -from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel -from levanter.trainer import OptimizerConfig, Trainer, TrainerConfig +from levanter.models.lm_model import LmConfig, LmExample +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -29,7 +29,7 @@ class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) # config related to continued pretraining initialize_from_hf: Union[bool, str] = False @@ -45,8 +45,12 @@ class TrainLmConfig: hf_upload: Optional[str] = None hf_save_steps: int = 10000 + update_hessian_steps: int = 10 + def main(config: TrainLmConfig): + levanter.initialize(config) + tokenizer = config.data.the_tokenizer # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, @@ -75,39 +79,35 @@ def main(config: TrainLmConfig): else: converter = None - # initialize training config *after* we've done the hf stuff b/c we might have changed the model config - config.trainer.initialize(config) - - # randomness in jax is tightly controlled by "keys" which are the states of the random number generators - # this makes deterministic training pretty easy - seed = config.trainer.seed - data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) - - # some axes we need - Batch = config.trainer.TrainBatch - EvalBatch = config.trainer.EvalBatch - Pos = config.model.Pos - KeyPos = config.model.KeyPos - - # We have two axis_mappings: one for storing the model and optimizer states, and one for compute - # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh - compute_axis_mapping = config.trainer.compute_axis_mapping - parameter_axis_mapping = config.trainer.parameter_axis_mapping - - def compute_loss(model: LmHeadModel, example: LmExample, key=None): - return model.compute_loss(example, key=key).scalar() - optimizer = config.optimizer.build(config.trainer.num_train_steps) # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) - - eval_datasets = config.data.validation_sets(Pos.size) - train_dataset = CausalLmDataset( - config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id - ) + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker + with Trainer(config.trainer, optimizer) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + eval_datasets = config.data.validation_sets(Pos.size) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id + ) - with trainer.device_mesh: # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. @@ -134,10 +134,11 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. Starting from scratch.") - wandb.summary["parameter_count"] = parameter_count(state.model) - - # boilerplate hooks and such - trainer.add_default_hooks() + levanter.tracker.log_summary( + { + "parameter_count": parameter_count(state.model), + } + ) if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") @@ -146,6 +147,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) trainer.add_eval_hook(eval_dataset, name=name) + # Register hooks trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.hf_save_path is not None: full_save_path = os.path.join(config.hf_save_path, trainer.run_id) @@ -183,7 +185,7 @@ def compute_log_probs(model, example: LmExample): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(train_loader) ## OK, actually run training! diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 370b20d59..c43525cd1 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -36,33 +36,28 @@ class VizGpt2Config: def main(config: VizGpt2Config): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer - EvalBatch = Axis("batch", config.trainer.eval_batch_size) - # some axes we use outside the model proper + EvalBatch = config.trainer.EvalBatch Pos = config.model.Pos KeyPos = config.model.KeyPos + validation_set = config.data.validation_set(Pos.size) + assert validation_set is not None eval_loader = ReplicatedBatchLoader( - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore - config.trainer.device_mesh, - EvalBatch, + CausalLmDataset(validation_set, Pos, KeyPos), EvalBatch, config.trainer.compute_env ) - # some axes we use outside the model proper - Pos = config.model.Pos - KeyPos = config.model.KeyPos - compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping - with config.trainer.device_mesh, hax.axis_mapping(parameter_axis_mapping): + with config.trainer.param_env: key = jax.random.PRNGKey(0) vocab_size = len(tokenizer) - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), compute_axis_mapping) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size)) if vocab_size != Vocab.size: logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") @@ -83,12 +78,11 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - assert ckpt is not None - model, _, _ = ckpt + assert model is not None - model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) + model = hax.shard(model) compute_and_visualize_log_probs( path=config.path, diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index a6f27c7a5..4531e44d8 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -320,12 +320,8 @@ class Gpt2Embeddings(StateDictSerializationMixin, eqx.Module): def init(Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2Embeddings": k_wte, k_wpe, k_out = jrandom.split(key, 3) - token_embeddings = hnn.Embedding.init( - Vocab, config.Embed, key=k_wte, initializer_range=config.initializer_range - ) - position_embeddings = hnn.Embedding.init( - config.Pos, config.Embed, key=k_wpe, initializer_range=config.initializer_range / 2 - ) + token_embeddings = hnn.Embedding.init(Vocab, config.Embed, config.initializer_range, key=k_wte) + position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, config.initializer_range / 2, key=k_wpe) dropout = hnn.Dropout(pdrop=config.embed_pdrop) return Gpt2Embeddings(Vocab, config, token_embeddings, position_embeddings, dropout) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 0d0a7d70a..b68a33f2b 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -117,19 +117,23 @@ def compute_loss( key=None, reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, - ) -> NamedArray: + ) -> jnp.ndarray | NamedArray: """ Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ logits = self(example.tokens, example.attn_mask, key=key) + # TODO: would be nice if we made the dtype configurable + logits = logits.astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) - return cross_entropy_loss( + loss = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask ) + return loss + @property def vocab_size(self) -> int: return self.Vocab.size diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 9c31e63b6..e01c43c0c 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -475,8 +475,7 @@ def from_hf_pretrained( lev_model = eqx.filter_eval_shape(MptLmHeadModel.init, Vocab, lev_config, key=PRNGKey(0)) lev_model = lev_model.from_state_dict(state_dict) - if axis_mapping is not None: - lev_model = haliax.shard_with_axis_mapping(lev_model, axis_mapping) + lev_model = haliax.shard(lev_model, axis_mapping) return lev_model diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py new file mode 100644 index 000000000..319ddf84d --- /dev/null +++ b/src/levanter/optim/__init__.py @@ -0,0 +1,16 @@ +from .config import AdamConfig, OptimizerConfig +from .second_order import ( + AnySecondOrderTransformation, + HessianUpdateFn, + SecondOrderTransformation, + chain_second_order, + inject_hyperparams, +) +from .sophia import ( + ScaleBySophiaState, + SophiaGConfig, + SophiaGObjective, + SophiaHConfig, + scale_by_sophia_g, + scale_by_sophia_h, +) diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py new file mode 100644 index 000000000..ad8708c6a --- /dev/null +++ b/src/levanter/optim/config.py @@ -0,0 +1,160 @@ +import abc +import re +import warnings +from dataclasses import dataclass +from typing import Optional + +import draccus +import equinox as eqx +import jax +import optax +from jax import numpy as jnp + +from levanter.utils.jax_utils import leaf_key_paths + + +@dataclass +class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): + learning_rate: float = 6e-4 + weight_decay: float = 0.0 + + min_lr_ratio: float = 0.1 + warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup + warmup: float = 0.01 + """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" + cooldown: float = 0.0 + """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" + lr_schedule: str = "cosine" # constant, cosine, linear + weight_decay_modules: Optional[list[str] | str] = None + """A regex or a list of strings to identify where to mask weight. + For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "adam" + + @abc.abstractmethod + def build(self, num_train_steps: int): + raise NotImplementedError + + def build_weight_decay_mask(self): + if self.weight_decay_modules is None: + return None + else: + # mask based on regex or module path + def _apply_on(x, key_path): + if isinstance(self.weight_decay_modules, str): + compiled_regex = re.compile(self.weight_decay_modules) + return compiled_regex.match(key_path) is not None + else: + return any(key_path.__contains__(target) for target in self.weight_decay_modules) + + def mask_fn(model): + return jax.tree_util.tree_map( + _apply_on, + model, + leaf_key_paths(model, is_leaf=eqx.is_array), + is_leaf=eqx.is_array, + ) + + return mask_fn + + def lr_scheduler(self, num_train_steps): + warmup_steps = self._convert_warmup(num_train_steps) + cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) + lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps + min_lr = self.learning_rate * self.min_lr_ratio + + match self.lr_schedule: + case "constant": + schedule = optax.constant_schedule(self.learning_rate) + case "cosine": + schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) + case "linear": + schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) + case "inv_sqrt": + schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) + case _: + raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") + + schedules = [] + boundaries = [] + + if warmup_steps != 0: + warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) + schedules.append(warmup) + boundaries.append(warmup_steps) + + schedules.append(schedule) + + if cooldown_steps != 0: + final_main_lr = schedule(lr_decay_steps) + cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) + schedules.append(cooldown) + boundaries.append(num_train_steps - cooldown_steps) + + if len(schedules) > 1: + schedule = optax.join_schedules(schedules, boundaries) + + return schedule + + def _convert_warmup(self, num_train_steps: int): + if self.warmup_ratio is not None: + warnings.warn("warmup_ratio is deprecated. Use warmup instead") + return int(self.warmup_ratio * num_train_steps) + else: + return _convert_ratio_or_steps(self.warmup, num_train_steps) + + +def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): + def schedule(count): + decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) + return jnp.maximum(lr * decay, min_lr) + + return schedule + + +def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): + if ratio_or_steps < 1.0: + return int(ratio_or_steps * num_train_steps) + else: + return int(ratio_or_steps) + + +@dataclass +class HessianOptConfig(OptimizerConfig, abc.ABC): + update_interval: int = 10 + """How often to update the hessian approximation.""" + + +@OptimizerConfig.register_subclass("adam") +@dataclass +class AdamConfig(OptimizerConfig): + weight_decay: float = 0.1 + beta1: float = 0.9 + beta2: float = 0.999 + epsilon: float = 1e-8 + max_grad_norm: Optional[float] = 1.0 + + def build(self, num_train_steps): + """Creates the optimizer""" + # indirection makes it work with optax.inject_hyperparams so we can log the learning rate + def _optimizer(learning_rate): + components = [] + + if self.max_grad_norm: + components.append(optax.clip_by_global_norm(self.max_grad_norm)) + + components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) + + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = optax.chain(*components) + + return optimizer + + return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) diff --git a/src/levanter/optim/second_order.py b/src/levanter/optim/second_order.py new file mode 100644 index 000000000..036c6e157 --- /dev/null +++ b/src/levanter/optim/second_order.py @@ -0,0 +1,233 @@ +import functools +import inspect +import typing +from typing import Callable, Iterable, List, NamedTuple, Optional, Union + +import chex +import jax +import optax +from jax import numpy as jnp +from optax import InjectHyperparamsState + + +class HessianUpdateFn(typing.Protocol): + """A callable type for the""" + + def __call__( + self, + state, + fn, + model, + *batch, + **batch_kwargs, + ) -> optax.OptState: + """Returns the updated `state` given the `hessian` and `state`.""" + pass + + +class SecondOrderTransformation(NamedTuple): + """A triple of pure functions that together define a second-order optimizer.""" + + init: optax.TransformInitFn + update: optax.TransformUpdateFn + update_hessian: HessianUpdateFn + + +AnySecondOrderTransformation = Union[SecondOrderTransformation, optax.GradientTransformation] +"""A type that can be used to represent either a first or second order transformation.""" + + +def chain_second_order(*args: AnySecondOrderTransformation) -> SecondOrderTransformation: + """Applies a list of chainable update transformations. Analogous to optax.chain, + but for second order transformations. + """ + + init_fns = [] + update_fns = [] + update_hessian_fns: List[Optional[HessianUpdateFn]] = [] + + for arg in args: + if isinstance(arg, SecondOrderTransformation): + init_fns.append(arg.init) + update_fns.append(arg.update) + update_hessian_fns.append(arg.update_hessian) + else: + init_fns.append(arg.init) + update_fns.append(arg.update) + update_hessian_fns.append(None) + + def init_fn(params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None): + if len(update_fns) != len(state): + raise ValueError( + "The number of updates and states has to be the same in chain! Make sure you have called init first!" + ) + + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params) + new_state.append(new_s) + return updates, tuple(new_state) + + def update_hessian_fn(state, fn, model, *batch, **batch_kwargs): + if len(update_hessian_fns) != len(state): + raise ValueError( + "The number of updates and states has to be the same in chain! Make sure you have called init first!" + ) + + new_state = [] + for s, update_fn in zip(state, update_hessian_fns): + if update_fn is None: + new_state.append(s) + else: + new_s = update_fn(s, fn, model, *batch, **batch_kwargs) + new_state.append(new_s) + return tuple(new_state) + + return SecondOrderTransformation(init_fn, update_fn, update_hessian_fn) + + +def inject_hyperparams( + inner_factory: Callable[..., SecondOrderTransformation], + static_args: Union[str, Iterable[str]] = (), + hyperparam_dtype: Optional[jnp.dtype] = None, +) -> Callable[..., SecondOrderTransformation]: + """ + Second Order version of optax.inject_hyperparams. + + Original docstring: + + Wrapper that injects hyperparameters into the inner GradientTransformation. + + This wrapper allows you to pass schedules (i.e. a function that returns a + numeric value given a step count) instead of constants for + hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean + flags cannot be scheduled). + + For example, to use ``scale_by_adam`` with a piecewise linear + schedule for beta_1 and constant for beta_2:: + + scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)( + b1=optax.piecewise_linear_schedule(...), + b2=0.99) + + You may manually change numeric hyperparameters that were not scheduled + through the ``hyperparams`` dict in the ``InjectHyperparamState``:: + + state = scheduled_adam.init(params) + updates, state = scheduled_adam.update(grads, state) + state.hyperparams['b2'] = 0.95 + updates, state = scheduled_adam.update(updates, state) # uses b2 = 0.95 + + Manually overriding scheduled hyperparameters will have no effect (e.g. + in the code sample above, you cannot manually adjust ``b1``). + + Args: + inner_factory: a function that returns the inner + ``optax.GradientTransformation`` given the hyperparameters. + static_args: a string or iterable of strings specifying which + callable parameters are not schedules. inject_hyperparams treats all + callables as schedules by default, so if a hyperparameter is a + non-schedule callable, you must specify that using this argument. + hyperparam_dtype: Optional datatype override. If specified, all float + hyperparameters will be cast to this type. + + Returns: + A callable that returns a ``optax.GradientTransformation``. This callable + accepts the same arguments as ``inner_factory``, except you may provide + schedules in place of the constant arguments. + """ + static_args = {static_args} if isinstance(static_args, str) else set(static_args) + inner_signature = inspect.signature(inner_factory) + + if not static_args.issubset(inner_signature.parameters): + raise ValueError( + "`static_args` must specify a subset of `inner_factory`'s parameters. " + f"Given `static_args`: {static_args}. `inner_factory` parameters: " + f"{set(inner_signature.parameters.keys())}" + ) + + @functools.wraps(inner_factory) + def wrapped_transform(*args, **kwargs) -> SecondOrderTransformation: + bound_arguments = inner_signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + + sched_hps, numeric_hps, other_hps = {}, {}, {} + for name, value in bound_arguments.arguments.items(): + if name in static_args or isinstance(value, bool): + other_hps[name] = value + elif callable(value): + sched_hps[name] = value + elif isinstance(value, (int, float, chex.Array)): + numeric_hps[name] = value + else: + other_hps[name] = value + + def schedule_fn(count, dtype): + return {k: _convert_floats(f(count), dtype) for k, f in sched_hps.items()} + + def init_fn(params): + count = jnp.zeros([], jnp.int32) + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(numeric_hps) + else: + dtype = hyperparam_dtype + hparams = {k: jnp.asarray(_convert_floats(v, dtype)) for k, v in numeric_hps.items()} + hparams.update(schedule_fn(count, dtype)) + return InjectHyperparamsState( # pylint:disable=too-many-function-args + count, hparams, inner_factory(**other_hps, **hparams).init(params) + ) + + def update_fn(updates, state, params=None): + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(updates) + else: + dtype = hyperparam_dtype + hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} + hparams.update(schedule_fn(state.count, dtype)) + updates, inner_state = inner_factory(**other_hps, **hparams).update(updates, state.inner_state, params) + # pylint:disable=too-many-function-args + return updates, InjectHyperparamsState(state.count + 1, hparams, inner_state) + # pylint:enable=too-many-function-args + + def _find_first_floating_dtype(updates): + dtype = jnp.float32 + for v in jax.tree_util.tree_leaves(updates): + if isinstance(v, jnp.ndarray): + if isinstance(v.dtype, jnp.floating): + dtype = v.dtype + break + return dtype + + def update_hessian(state, fn, model, *batch, **batch_kwargs): + if hyperparam_dtype is None: + dtype = _find_first_floating_dtype(batch) + else: + dtype = hyperparam_dtype + hparams = {k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()} + hparams.update(schedule_fn(state.count, dtype)) + new_inner_state = inner_factory(**other_hps, **hparams).update_hessian( + state.inner_state, + fn, + model, + *batch, + **batch_kwargs, + ) + + # pylint:disable=too-many-function-args + return InjectHyperparamsState(state.count, hparams, new_inner_state) + # pylint:enable=too-many-function-args + + return SecondOrderTransformation(init_fn, update_fn, update_hessian) + + return wrapped_transform + + +# Cribbed from optax._src.schedule, which recently deleted this function. +def _convert_floats(x, dtype): + """Convert float-like inputs to dtype, rest pass through.""" + if jax.dtypes.scalar_type_of(x) == float: + return jnp.asarray(x, dtype=dtype) + return x diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py new file mode 100644 index 000000000..9df275c29 --- /dev/null +++ b/src/levanter/optim/sophia.py @@ -0,0 +1,428 @@ +import abc +import functools +import typing +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional, TypeVar, runtime_checkable + +import equinox as eqx +import jax +import jaxtyping +import optax +from jax import numpy as jnp +from jax.random import PRNGKey +from jaxtyping import PRNGKeyArray + +import levanter.tracker +from levanter.optim.config import HessianOptConfig, OptimizerConfig +from levanter.optim.second_order import SecondOrderTransformation, chain_second_order, inject_hyperparams +from levanter.optim.util import hvp, tree_gaussian_like +from levanter.utils.jax_utils import parameter_count, tree_filter_like + + +M = TypeVar("M") +Ex = TypeVar("Ex") + +GAMMA_SOPHIA_G = 0.05 +GAMMA_SOPHIA_H = 0.01 + + +class ScaleBySophiaState(NamedTuple): + """State for Sophia and similar.""" + + count: jaxtyping.Array # shape=(), dtype=jnp.int32. + hessian_count: jaxtyping.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates # momentum + h: optax.Updates # EMA of hessian diagonal + hess_key: PRNGKey + + +@runtime_checkable +class SophiaGObjective(typing.Protocol): + """ + Class for objective functions that can be used with Sophia-G + + Sophia-G is a second order optimizer that uses the Gauss-Newton-Bartlett approximation to the Hessian + to compute the second order update. This requires the objective function be of the form loss(logits(x)) + where logits(x) is the activation of the model for the given example x. This is the case for most models + that are trained with "typical" losses. + """ + + def logits(self, parameters: M, example: Ex, *args, **kwargs) -> Any: + """ + Returns the logits/activations of the model for the given example, + or just sufficient statistics for the example for non-categorical models. + """ + ... + + def sample(self, logits, example: Ex, *, key: PRNGKey) -> Ex: + """ + Samples a new example with the same shape as the original example, but with + the "labels" replaced with some sampled values + """ + ... + + def loss(self, logits, example: Ex): + """ + Just computes the loss, e.g. cross entropy. + + Should return the mean loss over the batch, not the sum. + + TODO: should we reconsider this? + """ + ... + + def __call__(self, parameters: M, example: Ex, *args, **kwargs): + """ + Just a convenience method for invoking the objective for "normal" training w/o sophia-g + """ + logits = self.logits(parameters, example, *args, **kwargs) + return self.loss(logits, example) + + def num_data_points(self, example: Ex) -> int: + """ + Returns the number of data points in the example. This should take into account the loss mask + or any other masking that might be applied to the example. + + By default, we just return 1, and you can just pull the term into the hyperparams of Sophia if you want. + + Returns: + The number of data points in the example + """ + return 1 + + +@dataclass +class BaseSophiaConfig(HessianOptConfig): + """Base class for sophia variants. Doesn't implement the state update""" + + weight_decay: float = 0.1 + beta1: float = 0.96 + beta2: float = 0.99 + + epsilon: float = 1e-12 + clip_threshold: Optional[float] = 1.0 + rng_seed: int = 0 + + @abc.abstractmethod + def compute_hessian( + self, + fn, + model, + *batch, + hess_key: PRNGKey, + **batch_kwargs, + ): + raise NotImplementedError + + def build(self, num_train_steps: int): + def _optimizer(learning_rate, gamma) -> SecondOrderTransformation: + components = [] + key = jax.random.PRNGKey(self.rng_seed) + + components.append( + _sophia_gradient_transform( + sophia_hess_fn=self.compute_hessian, + update_interval=self.update_interval, + b1=self.beta1, + b2=self.beta2, + eps=self.epsilon, + gamma=gamma, + initial_key=key, + clip_threshold=self.clip_threshold, + ) + ) + + # Algorithm 3, step 11 (Note, this comes after clipping b/c it's not supposed to be clipped) + # In the paper, it comes as a prior step, but doesn't get clipped + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = chain_second_order(*components) + + return optimizer + + # Hong suggested using cosine decay for gamma + # gamma_decay_schedule = optax.cosine_decay_schedule(self.gamma, num_train_steps // 2, 0) # type: ignore + constant_gamma_schedule = optax.constant_schedule(self.gamma) # type: ignore + # gamma_schedule = optax.join_schedules([constant_gamma_schedule, gamma_decay_schedule], [num_train_steps // 2]) + + return inject_hyperparams(_optimizer)( + learning_rate=self.lr_scheduler(num_train_steps), gamma=constant_gamma_schedule + ) + + +@OptimizerConfig.register_subclass("sophia-g") +@dataclass +class SophiaGConfig(BaseSophiaConfig): + gamma: float = GAMMA_SOPHIA_G + + def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): + return stochastic_diag_gauss_newton(fn, model, *batch, **batch_kwargs, hess_key=hess_key) + + +@OptimizerConfig.register_subclass("sophia-h") +@dataclass +class SophiaHConfig(BaseSophiaConfig): + gamma: float = GAMMA_SOPHIA_H + + def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs): + return stochastic_hessian_diagonal(fn, model, *batch, **batch_kwargs, hess_key=hess_key) + + +def sophia_h( + lr: float = 0.85e-3, + *, + b1: float = 0.965, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_H, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + key: PRNGKey, +) -> SecondOrderTransformation: + """Sophia-H: https://arxiv.org/pdf/2305.14342.pdf Algorithm 1&3""" + components = [] + + components.append(scale_by_sophia_h(b1, b2, eps, gamma, clip_threshold, update_interval, key=key)) + + if weight_decay > 0: + components.append(optax.add_decayed_weights(weight_decay)) + + components.append(optax.scale(-lr)) + + return chain_second_order(*components) + + +def scale_by_sophia_h( + b1=0.965, + b2=0.99, + eps=1e-8, + gamma=GAMMA_SOPHIA_H, + clip_threshold: Optional[float] = 1.0, + update_interval=10, + *, + key: PRNGKey, +): + + return _sophia_gradient_transform( + sophia_hess_fn=stochastic_hessian_diagonal, + update_interval=update_interval, + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + initial_key=key, + ) + + +def sophia_g( + lr: float = 1e-3, + *, + b1: float = 0.99, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_G, + weight_decay: float = 0.0, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + key: PRNGKey, +) -> SecondOrderTransformation: + """Sophia-G: https://arxiv.org/pdf/2305.14342.pdf Algorithm 2&3""" + components = [] + + components.append(scale_by_sophia_g(b1, b2, eps, gamma, clip_threshold, update_interval, key=key)) + + if weight_decay > 0: + components.append(optax.add_decayed_weights(weight_decay)) + + components.append(optax.scale(-lr)) + + return chain_second_order(*components) + + +def scale_by_sophia_g( + b1: float = 0.99, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = GAMMA_SOPHIA_G, + clip_threshold: Optional[float] = 1.0, + update_interval=10, + *, + key: PRNGKeyArray, +): + + return _sophia_gradient_transform( + sophia_hess_fn=stochastic_diag_gauss_newton, + update_interval=update_interval, + b1=b1, + b2=b2, + eps=eps, + gamma=gamma, + clip_threshold=clip_threshold, + initial_key=key, + ) + + +def _sophia_gradient_transform( + sophia_hess_fn, + update_interval: int, + b1: float, + b2: float, + eps: float, + gamma: float, + clip_threshold: Optional[float], + initial_key: PRNGKeyArray, + mu_dtype: Optional[Any] = None, +) -> SecondOrderTransformation: + mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None + + def init_fn(params): + mu = jax.tree_util.tree_map(lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) # First moment + h = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleBySophiaState( + count=jnp.zeros([], jnp.int32), hessian_count=jnp.zeros([], jnp.int32), mu=mu, h=h, hess_key=initial_key + ) + + def update_fn(updates, state, params=None): + mu = update_moment(updates, state.mu, b1, 1) + # nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + mu_hat = bias_correction(mu, b1, state.count + 1) + h_hat = state.h + # track how often hessian is used + mu_leaves = jax.tree_util.tree_leaves(mu_hat) + h_leaves = jax.tree_util.tree_leaves(h_hat) + + stats: dict[str, Any] = { + "optim/param_norm": jnp.sqrt(sum(jnp.sum(p**2) for p in jax.tree_util.tree_leaves(params))), + "optim/momentum_norm": jnp.sqrt(sum(jnp.sum(m**2) for m in mu_leaves)), + "optim/hessian_norm": jnp.sqrt(sum(jnp.sum(h**2) for h in h_leaves)), + } + + # with sophia-g the max(h, 0) is not needed but no harm + updates = jax.tree_util.tree_map( + # lambda m, v: m / jnp.maximum(jnp.maximum(jnp.abs(m), gamma * jnp.maximum(v, 0)), eps), mu_hat, h_hat + lambda m, h: m / jnp.maximum(gamma * h, eps), + mu_hat, + h_hat, + ) + + if clip_threshold is not None: + unclipped_count = sum(jnp.sum(jnp.abs(u) < clip_threshold) for u in jax.tree_util.tree_leaves(updates)) + updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates) + stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates) + + # this doesn't work well on CPU, so skip if cpu + if jax.lib.xla_bridge.get_backend().platform != "cpu": + levanter.tracker.jit_log_metrics(stats, step=state.count) + + if mu_dtype is not None: + mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) + + return updates, ScaleBySophiaState( + count=state.count + 1, hessian_count=state.hessian_count, mu=mu, h=h_hat, hess_key=state.hess_key + ) + + def update_hessian(state, fn, model, *batch, **batch_kwargs): + def _do_update(): + key, next_key = jax.random.split(state.hess_key) + new_hess = sophia_hess_fn(fn, model, *batch, hess_key=key, **batch_kwargs) + + new_hess = tree_filter_like(state.h, new_hess) + + # EMAs of hessian + nu = update_moment(new_hess, state.h, b2, 1) + return ScaleBySophiaState( + count=state.count, hessian_count=state.hessian_count + 1, mu=state.mu, h=nu, hess_key=next_key + ) + + def _dont_update(): + return state + + return jax.lax.cond( + jnp.equal(state.count % update_interval, 0), + lambda _: _do_update(), + lambda _: _dont_update(), + state.count, + ) + + return SecondOrderTransformation(init_fn, update_fn, update_hessian) + + +# use this for Sophia-G +def stochastic_diag_gauss_newton(fn: SophiaGObjective, model, example, *args, hess_key: PRNGKey, **kwargs): + """ + + Approximate the diagonal of the Hessian using an approximation to the Gauss Newton matrix. + This is Algorithm 2 of https://arxiv.org/pdf/2305.14342.pdf + + Args: + fn (SophiaGObjective): objective function + model: model whose Hessian to compute + hess_key: key for sampling + *args, **kwargs: passed to fn's logits + """ + if not isinstance(fn, SophiaGObjective): + raise ValueError("objective must be a SophiaGObjective") + + # Step 3 + logits, model_backward = eqx.filter_vjp(lambda model: fn.logits(model, example, *args, **kwargs), model) + + # Step 4 + y_hat = fn.sample(logits, example, key=hess_key) + + # Step 5 + grad_loss_logits = eqx.filter_grad(fn.loss)(logits, y_hat) + pseudo_g = model_backward(grad_loss_logits)[0] + + # Step 6 + bs = fn.num_data_points(example) + h = jax.tree_util.tree_map(lambda x: x**2 * bs, pseudo_g) + + return h + + +# Use this for Sophia-H +def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs): + """Compute the diagonal of the Hessian of a function using a normal distribution. + + https://arxiv.org/pdf/2305.14342.pdf Algorithm 1 + + Args: + fn: function to compute the Hessian of + model: model to compute the Hessian of + hess_key: key for the normal distribution + """ + # cf https://arxiv.org/pdf/2006.00719.pdf eqn 9 + # https://www-users.cse.umn.edu/~saad/PDF/umsi-2005-082.pdf + # https://arxiv.org/pdf/2208.03268.pdf + g = tree_gaussian_like(hess_key, model) + # TODO: consider allowing for n > 1 gaussians? + product = hvp(lambda m: fn(m, *args, **kwargs), model, g) + hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g) + + return hessian + + +# Cribbed from optax._src.transform +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment) diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py new file mode 100644 index 000000000..7fd3a41df --- /dev/null +++ b/src/levanter/optim/util.py @@ -0,0 +1,23 @@ +import equinox as eqx +import jax + +from levanter.utils.jax_utils import is_inexact_arrayish + + +def hvp(f, x, v): + """Compute the Hessian-vector product of a function.""" + return eqx.filter_jvp(eqx.filter_grad(f), (x,), (v,))[1] + + +def tree_gaussian_like(key, tree): + """ + Samples a tree of gaussian noise with the same structure as `tree`, except for leaves which are not inexact arrays, + for which it returns None + """ + leaves, structure = jax.tree_util.tree_flatten(tree) + keys = jax.random.split(key, len(leaves)) + rand_n = lambda x, key: jax.random.normal(key, x.shape) if is_inexact_arrayish(x) else None + g = jax.tree_util.tree_map(rand_n, leaves, list(keys)) + g = jax.tree_util.tree_unflatten(structure, g) + + return g diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 9809665f3..25f7fa594 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -13,11 +13,10 @@ import jax.tree_util as jtu import numpy as np import tensorstore -from jax.sharding import Mesh from tensorstore import TensorStore import haliax as hax -from haliax.partitioning import ResourceMapping +import haliax.tree_util as htu from haliax.util import is_named_array from levanter.utils import jax_utils @@ -26,15 +25,17 @@ logger = logging.getLogger(__name__) +def _is_named_or_none(x): + return x is None or is_named_array(x) + + def tree_serialize_leaves_tensorstore(checkpoint_dir, pytree): - leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) - specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) + leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=_is_named_or_none) + specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=_is_named_or_none) # TODO: jax array_ser has a fancy async manager thing to checkpoint while training, would be good but not right now. - # array_ser only supports saving sharded arrays, so we can't use its top-level function run_serialization. - # however we're inspired by its implementation, meaning we'll make a tree of futures and wait on them. async def _do_serialize(): - futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=is_named_array) + futures = jtu.tree_map(_serialize_one_leaf, pytree, specs, is_leaf=_is_named_or_none) return await asyncio.gather(*jtu.tree_leaves(futures)) asyncio.run(_do_serialize()) @@ -86,9 +87,9 @@ async def load_array_from_tensorstore(spec): return await t.read("C") -async def _deserialize_one_leaf(like, spec, axis_mapping, mesh): +async def _deserialize_one_leaf(like, spec, env): if is_named_array(like): - return await _deserialize_named_array(like, spec, axis_mapping, mesh) + return await _deserialize_named_array(like, spec, env) elif isinstance(like, jax.Array): if not like.is_fully_addressable: return await array_ser.async_deserialize(like.sharding, spec, global_shape=like.shape, dtype=like.dtype) @@ -105,22 +106,20 @@ async def _deserialize_one_leaf(like, spec, axis_mapping, mesh): raise TypeError(f"Can't deserialize {type(like)}") -async def _deserialize_named_array(like, spec, axis_mapping, mesh): +async def _deserialize_named_array(like, spec, env): # the main thing we're worried about is deserialized NamedArrays that are not yet arrays but are ShapedDtypeStructs. # These don't (currently) have sharding info, but we can infer it from the axes if isinstance(like.array, jax.ShapeDtypeStruct): - sharding = hax.partitioning.sharding_for_axis(like.axes, axis_mapping, mesh) + sharding = hax.partitioning.sharding_for_axis(like.axes, env) array = await array_ser.async_deserialize(sharding, spec, global_shape=like.array.shape, dtype=like.dtype) assert sharding.is_equivalent_to(array.sharding, len(like.array.shape)) return hax.NamedArray(array, like.axes) else: - array = await _deserialize_one_leaf(like.array, spec, axis_mapping, mesh) + array = await _deserialize_one_leaf(like.array, spec, env) return hax.NamedArray(array, like.axes) -def tree_deserialize_leaves_tensorstore( - checkpoint_dir, pytree, axis_mapping: Optional[ResourceMapping] = None, mesh: Optional[Mesh] = None -): +def tree_deserialize_leaves_tensorstore(checkpoint_dir, pytree, env: Optional[hax.ResourceEnv] = None): """ Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape as the one provided. This method is capable of deserializing NamedArrays that are the result of an eval_shape call @@ -135,10 +134,11 @@ def tree_deserialize_leaves_tensorstore( :return: a pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint """ # TODO: support ShapeDtypeStructs that are not NamedArrays + env = env or hax.current_resource_env() leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) - specs = jtu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths, is_leaf=is_named_array) + specs = htu.tree_map(partial(_tensorstore_spec_for, checkpoint_dir), leaf_key_paths) - deser_partial = functools.partial(_deserialize_one_leaf, axis_mapping=axis_mapping, mesh=mesh) + deser_partial = functools.partial(_deserialize_one_leaf, env=env) async def _do_deserialize(): futures = jtu.tree_map(deser_partial, pytree, specs, is_leaf=is_named_array) diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py new file mode 100644 index 000000000..69156c6a6 --- /dev/null +++ b/src/levanter/tracker/__init__.py @@ -0,0 +1,29 @@ +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker_fns import ( + current_tracker, + get_tracker, + jit_log_metrics, + log_configuration, + log_hyperparameters, + log_metrics, + log_summary, + set_global_tracker, +) + + +__all__ = [ + "Tracker", + "TrackerConfig", + "CompositeTracker", + "log_optimizer_hyperparams", + "NoopTracker", + "current_tracker", + "get_tracker", + "jit_log_metrics", + "log_configuration", + "log_metrics", + "log_summary", + "log_hyperparameters", + "set_global_tracker", +] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py new file mode 100644 index 000000000..1091840c5 --- /dev/null +++ b/src/levanter/tracker/helpers.py @@ -0,0 +1,75 @@ +import dataclasses +import logging +import os +from typing import Optional + +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +import levanter.tracker +from levanter.utils.jax_utils import jnp_to_python + + +logger = logging.getLogger(__name__) + + +def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): + try: + from optax._src.wrappers import MultiStepsState + + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + except ImportError: + pass + + def wrap_key(key): + if prefix: + return f"{prefix}/{key}" + return key + + if hasattr(opt_state, "hyperparams"): + params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} + levanter.tracker.log_metrics(params, step=step) + + +def hparams_to_dict(hparams, **extra_hparams): + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + if extra_hparams: + hparams_to_save.update(extra_hparams) + return hparams_to_save + + +def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: + # sniff out the main directory (since we typically don't run from the root of the repo) + # we'll walk the stack and directories for the files in the stack the until we're at a git root + import os + import traceback + + stack = traceback.extract_stack() + # start from the top of the stack and work our way down since we want to hit the main file first + top_git_root = None + for frame in stack: + dirname = os.path.dirname(frame.filename) + # bit hacky but we want to skip anything that's in the python env + if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): + continue + # see if it's under a git root + try: + repo = Repo(dirname, search_parent_directories=True) + top_git_root = repo.working_dir + break + except (NoSuchPathError, InvalidGitRepositoryError): + logger.debug(f"Skipping {dirname} since it's not a git root") + pass + return top_git_root + + +def generate_pip_freeze(): + from importlib.metadata import distributions + + dists = distributions() + return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py new file mode 100644 index 000000000..bd3ee70ba --- /dev/null +++ b/src/levanter/tracker/tensorboard.py @@ -0,0 +1,81 @@ +import logging +import os +import typing +from dataclasses import dataclass +from typing import Any, Optional + +from levanter.tracker import Tracker, TrackerConfig + + +pylogger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + from tensorboardX import SummaryWriter # noqa: F401 + + +class TensorboardTracker(Tracker): + name: str = "tensorboard" + + def __init__(self, writer: "SummaryWriter"): + self.writer = writer + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, global_step=None) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.error("TensorboardLogger does not support logging artifacts yet") + pass + + +@TrackerConfig.register_subclass("tensorboard") +@dataclass +class TensorboardConfig(TrackerConfig): + logdir: str = "tblogs" + comment: Optional[str] = "" + purge_step: Optional[int] = None + max_queue: Optional[int] = 10 + flush_secs: Optional[int] = 120 + filename_suffix: Optional[str] = "" + write_to_disk: Optional[bool] = True + + def init(self, run_id: Optional[str]) -> TensorboardTracker: + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + + pylogger.info(f"Writing Tensorboard logs to {dir_to_write}") + + from tensorboardX import SummaryWriter # noqa: F811 + + writer = SummaryWriter( + dir_to_write, + comment=self.comment, + purge_step=self.purge_step, + max_queue=self.max_queue, + flush_secs=self.flush_secs, + filename_suffix=self.filename_suffix, + write_to_disk=self.write_to_disk, + ) + + return TensorboardTracker(writer) + + +def _flatten_nested_dict(d): + def items(): + for key, value in d.items(): + if isinstance(value, dict): + for subkey, subvalue in _flatten_nested_dict(value).items(): + yield key + "/" + subkey, subvalue + else: + yield key, value + + return dict(items()) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py new file mode 100644 index 000000000..8b6816f17 --- /dev/null +++ b/src/levanter/tracker/tracker.py @@ -0,0 +1,117 @@ +import abc +import dataclasses +import typing +from typing import Any, List, Optional + +import draccus + + +class Tracker(abc.ABC): + """ + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [levanter.tracker.current_tracker][] context manager, but can also be used directly. + + The name is borrowed from HF Accelerate. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + """ + + name: str + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the tracker. Step is always required. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + def __enter__(self): + import levanter.tracker.tracker_fns as tracker_fns + + if hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is already set as the global tracker") + setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) + self._tracker_cm.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is not set as the global tracker") + self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) + delattr(self, "_tracker_cm") + + +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): + self.loggers = loggers + + def log_hyperparameters(self, hparams: dict[str, Any]): + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for tracker in self.loggers: + tracker.log_summary(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + for tracker in self.loggers: + tracker.log_artifact(artifact_path, name=name, type=type) + + +class TrackerConfig(draccus.PluginRegistry, abc.ABC): + discover_packages_path = "levanter.tracker" + + @abc.abstractmethod + def init(self, run_id: Optional[str]) -> Tracker: + raise NotImplementedError + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "wandb" + + +class NoopTracker(Tracker): + name: str = "noop" + + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + pass + + def log_summary(self, metrics: dict[str, Any]): + pass + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +@TrackerConfig.register_subclass("noop") +@dataclasses.dataclass +class NoopConfig(TrackerConfig): + def init(self, run_id: Optional[str]) -> Tracker: + return NoopTracker() diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py new file mode 100644 index 000000000..2ed2b1928 --- /dev/null +++ b/src/levanter/tracker/tracker_fns.py @@ -0,0 +1,235 @@ +import dataclasses +import logging +import os +import tempfile +import typing +import warnings +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional + +import draccus +import jax + +from levanter.tracker import CompositeTracker, Tracker +from levanter.tracker.helpers import hparams_to_dict +from levanter.tracker.tensorboard import TensorboardTracker +from levanter.tracker.wandb import WandbTracker +from levanter.utils.jax_utils import is_inside_jit + + +logger = logging.getLogger(__name__) + + +_global_tracker: Optional["Tracker"] = None + + +def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_tracker.log(metrics, step=step) + + +def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + try: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log(metrics, step=step) + except Exception: + logger.exception("Error logging metrics") + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(_no_throw_log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + Args: + metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +def log_hyperparameters(hparams: dict[str, Any]): + """ + Log hyperparameters to the global tracker. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + _global_tracker.log_hyperparameters(hparams) + + +def log_configuration(hparams: Any, config_name: Optional[str] = None): + """ + Logs a configuration object to the global tracker. If the configuration object is a dataclass, + it is dumped to a yaml file and logged as an artifact. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + hparams_dict = hparams_to_dict(hparams) + _global_tracker.log_hyperparameters(hparams_dict) + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + name = config_name or "config.yaml" + _global_tracker.log_artifact(config_path, name=name, type="config") + + +def set_global_tracker(tracker: Tracker): + """ + Set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + In general, it's preferred to use the context manager returned by `current_tracker` instead of this function + except for once at the beginning of the program. + + Args: + tracker: The tracker to set as the global tracker + force: Whether to force setting the global tracker even if it is already set + + Examples: + >>> from levanter.tracker import set_global_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> set_global_tracker(WandbTracker()) + >>> log_metrics({"foo": 1}, step=0) + """ + global _global_tracker + if _global_tracker is not None: + warnings.warn("Global tracker is already set. Overwriting it.") + _global_tracker = tracker + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + Args: + tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Returns: + If no tracker is provided, returns the current global tracker. + If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... current_tracker().log({"foo": 2}, step=1) + """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +@typing.overload +def get_tracker(name: Literal["wandb"]) -> WandbTracker: + ... + + +@typing.overload +def get_tracker(name: Literal["tensorboard"]) -> TensorboardTracker: + ... + + +@typing.overload +def get_tracker(name: str) -> Tracker: + ... + + +def get_tracker(name: str) -> Tracker: + """ + Lookup a tracker in the current global tracker with the provided name. + + Args: + name: Name of the tracker to lookup + + Returns: + The tracker with the provided name + + Examples: + >>> from levanter.tracker import get_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... get_tracker("wandb").log_metrics({"foo": 2}, step=1) + """ + tracker = current_tracker() + if isinstance(tracker, CompositeTracker): + for t in tracker.loggers: + if t.name == name: + return t + elif tracker.name == name: + return tracker + + raise KeyError(f"Tracker with name {name} not found") + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + return self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py new file mode 100644 index 000000000..d217ab000 --- /dev/null +++ b/src/levanter/tracker/wandb.py @@ -0,0 +1,199 @@ +import logging +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +import jax +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +from levanter.tracker import Tracker +from levanter.tracker.helpers import generate_pip_freeze, infer_experiment_git_root +from levanter.tracker.tracker import TrackerConfig +from levanter.utils import jax_utils + + +if typing.TYPE_CHECKING: + import wandb + import wandb.sdk.lib.disabled + + +logger = logging.getLogger(__name__) + +WandbRun = Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.disabled.RunDisabled"] + + +class WandbTracker(Tracker): + name: str = "wandb" + run: WandbRun + + def __init__(self, run: Optional[WandbRun]): + import wandb + + if run is None: + if wandb.run is None: + logger.warning("Wandb run is not initialized. Initializing a new run.") + runx = wandb.init() + if runx is None: + raise RuntimeError("Wandb run is not initialized.") + self.run = runx + else: + self.run = wandb.run + else: + self.run = run + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.run.config.update(hparams, allow_val_change=True) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if step is None and not commit: + step = self.run.step + + self.run.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + self.run.summary.update(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + self.run.log_artifact(artifact_path, name=name, type=type) + + +def is_wandb_available(): + try: + import wandb + except ImportError: + return False + return wandb is not None and wandb.run is not None + + +@TrackerConfig.register_subclass("wandb") +@dataclass +class WandbConfig(TrackerConfig): + """ + Configuration for wandb. + """ + + entity: Optional[str] = None # An entity is a username or team name where you send runs + project: Optional[str] = None # The name of the project where you are sending the enw run. + name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. + tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. + id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project + group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. + resume: Optional[Union[bool, str]] = None + """ + Set the resume behavior. Options: "allow", "must", "never", "auto" or None. + By default, if the new run has the same ID as a previous run, this run overwrites that data. + Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) + document for more details. + """ + + save_code: Union[bool, str] = True + """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str]) -> WandbTracker: + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + hparams_to_save = {} + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + if jax.process_index() == 0: + mode = self.mode + else: + mode = "disabled" + + git_settings = self._git_settings() + + if "git_commit" in git_settings: + hparams_to_save["git_commit"] = git_settings["git_commit"] + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=git_settings, + allow_val_change=True, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = jax.default_backend() + + return WandbTracker(r) + + def _git_settings(self): + other_settings = dict() + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = infer_experiment_git_root() or "." # type: ignore + else: + code_dir = None + if code_dir is not None: + logger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + logger.warning(f"Could not find git repo at {code_dir}") + pass + return other_settings diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 5fe6e0302..8c0c7e663 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -1,26 +1,35 @@ import atexit import copy -import functools +import dataclasses import logging as pylogging import os -import re import sys import typing import warnings from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) import equinox as eqx import jax -import jax.numpy as jnp import jmp import numpy as np -import optax -import wandb from draccus import field -from jax import ShapeDtypeStruct from jax.experimental import multihost_utils from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree @@ -29,18 +38,24 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit -from haliax.types import Scalar +from haliax.types import IntScalar, Scalar +import levanter.checkpoint import levanter.logging -from levanter.checkpoint import CheckpointerConfig +import levanter.tracker +import levanter.tracker.wandb +from levanter import tracker +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched -from levanter.logging import WandbConfig, capture_time -from levanter.types import FilterSpec +from levanter.logging import capture_time +from levanter.optim import SecondOrderTransformation +from levanter.tracker import TrackerConfig +from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish, leaf_key_paths +from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -48,26 +63,43 @@ X = TypeVar("X") # Input M = TypeVar("M", bound=PyTree) -S = TypeVar("S", bound=PyTree) DEFAULT_JAX_CONFIG = { "jax_threefry_partitionable": True, "jax_softmax_custom_jvp": True, } -# A note on the semantics of "step" vs "next_step": -# The "step" of a TrainerState is the state after `step` steps have been taken. -# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. +class TrainerState(eqx.Module, Generic[M]): + """ + This is the state of the trainer. It contains the model, optimizer state, and random key. + It is an equinox Module becaues it is a PyTree that gets passed to the core `train_step` method + of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar. -@dataclass -class TrainerState(Generic[M]): - step: int + It's designed to be extended by subclasses. + """ + + _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) model: M opt_state: OptState training_key: PRNGKeyArray + is_trainable: PyTree[FilterSpec] # = eqx.field(static=True) + + @cached_property + def step(self) -> int: + return int(self._step) + + @property + def trainable_model(self) -> M: + return eqx.filter(self.model, self.is_trainable) +S = TypeVar("S", bound=TrainerState) + + +# A note on the semantics of "step" vs "next_step": +# The "step" of a TrainerState is the state after `step` steps have been taken. +# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass class StepInfo(Generic[M]): state: TrainerState[M] @@ -76,7 +108,6 @@ class StepInfo(Generic[M]): model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - next_key = property(lambda self: self.state.training_key) step = property(lambda self: self.state.step - 1) """ @@ -112,52 +143,54 @@ def decorator(fn: Callable[[StepInfo], None]): return decorator(fn) +# A note on extending Trainer: +# First, consider whether you can do what you want with hooks. Hooks can cover a lot of use cases. +# Sometimes, however, you need to do something more complicated. In that case, you can extend Trainer. +# In order to do that, you need to: +# * Extend TrainerState to add your additional state +# * Override `_train_step` to add your additional logic +# * Override `initial_state` or `_initialize_state_from_scratch` to initialize your additional state. (The latter is +# simpler and means you don't need to handle the checkpointing logic yourself.) +# * You might also need to override `training_steps` if you want to make the type checker happy. + + class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks - is_trainable_param: Optional[PyTree[FilterSpec]] + tracker: levanter.tracker.Tracker _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] def __init__( self, config: "TrainerConfig", optimizer: GradientTransformation, - loss_fn: Callable, + loss_fn: Optional[ComputeLossFunction] = None, *, - is_trainable: PyTree[FilterSpec] = True, + add_default_hooks: bool = True, ): """ Args: config: the trainer config - optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.trainer.OptimizerConfig][] + optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.optim.OptimizerConfig][] loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns a scalar loss. It should be jit-able and should not have any side effects. - is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable - parameters for the optimizer state and for computing gradients. Non-trainable parameters are also - not checkpointed. If you don't specify this, all parameters are assumed to be trainable. """ self.hooks = TrainerHooks() self.config = config - self._raw_loss_function = loss_fn self.optimizer = optimizer - self.is_trainable_param = is_trainable - - @cached_property - def loss_fn(self): - """ - Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute - """ + self.loss_fn = loss_fn or ModuleComputeLoss() + if isinstance(config.tracker, Sequence): + self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) + else: + self.tracker = config.tracker.init(self.run_id) - @named_jit(in_axis_resources=self.parameter_axis_mapping, axis_resources=self.compute_axis_mapping) - @functools.wraps(self._raw_loss_function) - def fn(model, *batch, **batch_kwargs): - with hax.axis_mapping(self.compute_axis_mapping): - model = self.mp.cast_to_compute(model) - return self._raw_loss_function(model, *batch, **batch_kwargs) + self._cmanagers = [] - return fn + if add_default_hooks: + self._add_default_hooks() @property def run_id(self) -> str: @@ -170,6 +203,10 @@ def mp(self) -> jmp.Policy: """Returns the mixed precision policy""" return self.config.mp + @property + def num_train_steps(self) -> int: + return self.config.num_train_steps + @typing.overload def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1): ... @@ -192,6 +229,14 @@ def parameter_axis_mapping(self) -> ResourceMapping: def compute_axis_mapping(self) -> ResourceMapping: return self.config.compute_axis_mapping + @property + def param_env(self) -> hax.ResourceEnv: + return self.config.param_env + + @property + def compute_env(self) -> hax.ResourceEnv: + return self.config.compute_env + @property def device_mesh(self) -> Mesh: return self.config.device_mesh @@ -204,121 +249,155 @@ def TrainBatch(self): def EvalBatch(self): return self.config.EvalBatch + def __enter__(self): + this_managers = [ + levanter.current_tracker(self.tracker), + self.param_env, + ] + self._cmanagers.append(this_managers) + + for cmanager in this_managers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + assert len(self._cmanagers) > 0, "Trainer.__exit__ called without corresponding Trainer.__enter__" + cur_managers = self._cmanagers.pop() + problems = [] + for cmanager in reversed(cur_managers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + def initial_state( - self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None - ) -> TrainerState: + self, + training_key: PRNGKeyArray, + model: Optional[M] = None, + model_init: Optional[Callable[[], M]] = None, + *, + is_trainable: PyTree[FilterSpec] = True, + ) -> TrainerState[M]: """ - Initializes the model, optimizer state, and random key. Also handles loading a checkpoint if needed. + Either loads a checkpoint or initializes a fresh trainer state. This is the recommended way to initialize + a trainer state. + + This method is smart enough to handle subclasses of TrainerState. If you want to extend TrainerState, you + can override _initialize_state_from_scratch + + Args + is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable + parameters for the optimizer state and for computing gradients. Non-trainable parameters are also + not checkpointed. If you don't specify this, all parameters are assumed to be trainable. Returns: - model, opt_state, key, resume_step + TrainerState: the initial state, """ - if model is not None and model_init is not None: raise ValueError("only one of model and model_init should be specified") elif model is None and model_init is None: raise ValueError("one of model and model_init must be specified") if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) + del model assert model_init is not None - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + # first try to load a full trainer state checkpoint + checkpoint_path = self.config.load_checkpoint_path + if checkpoint_path is None: + checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + do_load_checkpoint = self.config.load_checkpoint + initial_model_path = self.config.initialize_from - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ) + # we don't save the full trainer state, so we need to filter out the non-trainable parameters - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - else: - model = eqx.combine(trainable_model, model_shape) - - if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - - step = completed_step + 1 - elif self.config.initialize_from is not None: - # initialize from a levanter checkpoint - logger.info(f"Initializing model from checkpoint {self.config.initialize_from}") - match levanter.checkpoint.load_checkpoint( - model_shape, - None, - self.config.initialize_from, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ): - # new_model is probably only the trainable parameters, so we init the rest - case base_model, _, loaded_step: - logger.info(f"Initialized from step {loaded_step} of {self.config.initialize_from}") - old_model_init = model_init - - model_init = jax.tree_util.Partial(lambda m: eqx.combine(m, old_model_init()), base_model) - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( - model_init - ) - - step = 0 - case None: - raise ValueError(f"Could not load model from checkpoint {self.config.initialize_from}") - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + def init_state_and_model(model_init, training_key, is_trainable): + model = model_init() + state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return state + + trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, model_init, training_key, is_trainable) + saveable_state_shape = _make_saveable_trainer_state(trainer_state_shape, is_trainable) + + if do_load_checkpoint is not False: + try: + state = load_checkpoint(saveable_state_shape, checkpoint_path, self.param_env) + except FileNotFoundError: + if do_load_checkpoint: + raise + else: + state = None + + # if that fails, try to load just a model from a checkpoint for initialization + if state is None and initial_model_path is not None: + logger.info(f"Initializing from {initial_model_path}") + # todo: we are potentially holding two models in memory at once here, if we pass in a model + # instead of a model_init and we use initialize_from. We could avoid this by deleting + # any to-be-loaded parameters from the model before loading, but that's a bit more complicated + loaded_model = load_checkpoint( + saveable_state_shape.model, + initial_model_path, + env=self.param_env, + subpath="model", + ) + + # we don't necessarily load the full model, so we need to combine it with the model init + model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) + + # now we initialize a fresh trainer state, possibly just to finish any missing fields + @named_jit(axis_resources=self.param_env, donate_args=(True, True, True, False)) + def init_state(partial_state, model_init, training_key, is_trainable): + model = model_init() + fresh_state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return eqx.combine(partial_state, fresh_state) - return TrainerState(step, model, opt_state, training_key) + state = init_state(state, model_init, training_key, is_trainable) + + return state def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ with capture_time() as step_time: - key, new_key = jax.random.split(state.training_key) - loss, new_model, new_optstate = self._train_step_fn( - state.model, state.opt_state, *batch, **batch_kwargs, key=key - ) + loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore - return StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + return StepInfo(new_state, loss, step_time()) def training_steps( self, state: TrainerState[M], train_loader, run_hooks: bool = True - ) -> typing.Iterator[StepInfo]: + ) -> typing.Iterator[StepInfo[M]]: """ Generator that yields training steps and runs hooks. """ iter_data = iter(train_loader) + with levanter.current_tracker(self.tracker): + while state.step < self.num_train_steps: + with capture_time() as loading_time: + example = next(iter_data) - while state.step < self.config.num_train_steps: - with capture_time() as loading_time: - example = next(iter_data) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) - # TODO: refactor logging - wandb.log({"throughput/loading_time": loading_time()}, step=state.step) + info = self.train_step(state, example) - info = self.train_step(state, example) - state = info.state + if run_hooks: + with capture_time() as hook_time: + self.run_hooks(info) - if run_hooks: - with capture_time() as hook_time: - self.run_hooks(info) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) - wandb.log({"throughput/hook_time": hook_time()}, step=state.step) - - yield info + state = info.state + yield info def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bool = True) -> StepInfo[M]: """ @@ -333,16 +412,13 @@ def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bo return info - def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): + def _add_default_hooks(self): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) - self.add_hook(callbacks.log_to_wandb, every=1) - if eval_dataset is not None: - self.add_eval_hook(eval_dataset) - self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) + self.add_hook(callbacks.log_step_info, every=1) # engine.add_hook(callbacks.log_memory_usage(), every=1) - checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) + checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @@ -352,10 +428,12 @@ def add_eval_hook(self, eval_dataset, name: Optional[str] = None): if eval_loader and (self.config.max_eval_batches is None or self.config.max_eval_batches > 0): - @eqx.filter_jit + @named_jit(axis_resources=self.param_env, donate_args=(False)) def eval_loss(model, *batch, **batch_kwargs): model = inference_mode(model, True) - return self.loss_fn(model, *batch, **batch_kwargs, key=None) + # TODO: should we do in full precision? + with self.compute_env: + return self.loss_fn(model, *batch, **batch_kwargs, key=None) self.add_hook( callbacks.compute_validation_loss( @@ -375,7 +453,7 @@ def replicated_loader(self, dataset: Dataset[X], batch_axis: Axis) -> Replicated Returns: ReplicatedBatchLoader: the batch loader """ - return ReplicatedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + return ReplicatedBatchLoader(dataset, batch_axis, self.config.compute_env) def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> ShardedBatchLoader[X]: """Creates a sharded batch loader for the given dataset. Generally you should use this @@ -388,33 +466,34 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar Returns: ShardedBatchLoader: the batch loader """ - return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + return ShardedBatchLoader(dataset, batch_axis, self.config.compute_env) @cached_property - def _train_step_fn(self): - @named_jit( - axis_resources=self.parameter_axis_mapping, - out_axis_resources=self.parameter_axis_mapping, - donate_args=(True, True), - ) - def train_step(model, opt_state, *batch, **batch_kwargs): - model = inference_mode(model, False) + def _jit_train_step_fn(self): + return named_jit(self._train_step, axis_resources=self.parameter_axis_mapping, donate_args=(True,)) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = self.partition_trainable_params(model) + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + model = inference_mode(state.model, False) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs) + def loss_fn(model, *batch, **batch_kwargs): + with self.compute_env: + return self.loss_fn(model, *batch, **batch_kwargs).scalar() - loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) + # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us + loss, grads = self._compute_gradients_microbatched(loss_fn, model, batch, **batch_kwargs, key=key) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - model = eqx.apply_updates(model, updates) + with self.param_env: + partial_loss = lambda model: loss_fn(model, *batch, **batch_kwargs) + model, opt_state = take_opt_step( + self.optimizer, model, state.opt_state, grads, partial_loss, state.is_trainable + ) + + new_state = dataclasses.replace(state, model=model, opt_state=opt_state) - return loss, model, opt_state + new_state = dataclasses.replace(new_state, _step=state._step + 1, training_key=new_key) - return train_step + return loss, new_state def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) @@ -427,75 +506,51 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar ) return grad_fn(model, *batch, **batch_kwargs) - def _init_model_and_opt_state(self, model_init): - model = model_init() - # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - trainable = self.mp.cast_to_param(trainable) - non_trainable = self.mp.cast_to_compute(non_trainable) - model = eqx.combine(trainable, non_trainable) - opt_state = self.optimizer.init(trainable) - return model, opt_state - - def _init_non_trainable_params(self, model_init): - model = model_init() + def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - non_trainable = self.mp.cast_to_compute(non_trainable) - return non_trainable + model = cast_params_by_trainability(model, self.mp, is_trainable) + opt_state = init_optimizer_for_trainables(self.optimizer, model, is_trainable) - def trainable_params_only(self, model: M) -> M: - """ - Filters out non-trainable parameters from the model. This is used internally to - for the optimizer state and to compute gradients, but you can also use it to filter out - params for logging or something. - """ - return self.partition_trainable_params(model)[0] + return TrainerState(0, model, opt_state, training_key, is_trainable) - def partition_trainable_params(self, model): - """ - Partitions the model into trainable and non-trainable parameters. This is used internally - for the gradient calculation and checkpointing, but you can also use it to filter out params for logging - or something. - Returns: - trainable, non-trainable - """ +def init_optimizer_for_trainables(optimizer, model, is_trainable): + trainable, _ = _partition_trainable_params(model, is_trainable) + opt_state = optimizer.init(trainable) + return opt_state - def trainable_and_diffable(pred): - if callable(pred): - return lambda x: pred(x) and is_inexact_arrayish(x) - elif pred is True: - return is_inexact_arrayish - else: - return pred - - combined_mask = jax.tree_util.tree_map(trainable_and_diffable, self.is_trainable_param) - return eqx.partition(model, combined_mask) - - def maybe_load_checkpoint( - self, model: M, training_state: S, *, axis_mapping=None, mesh=None - ) -> Optional[Tuple[M, S, int]]: - """Loads a checkpoint if one exists and we're supposed to load it, - otherwise returns the model and training state as is""" - if self.config.load_checkpoint is not False: - # TODO: don't remake the checkpointer every time - checkpointer = self.config.checkpointer.create(self.run_id) - load_checkpoint_path = self.config.load_checkpoint_path - - if load_checkpoint_path is None: - load_checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) - - ckpt = checkpointer.load_checkpoint( - model, training_state, load_checkpoint_path, axis_mapping=axis_mapping, mesh=mesh - ) - if ckpt is None and self.config.load_checkpoint is True: - raise ValueError(f"Could not load checkpoint from {load_checkpoint_path}") +def cast_params_by_trainability(model, mp, is_trainable): + """ + Casts the parameters of a model to the appropriate precision based on the is_trainable filter spec. + Trainable parameters are cast to param precision, non-trainable parameters are cast to compute precision. + """ - return ckpt - else: - return None + trainable, non_trainable = _partition_trainable_params(model, is_trainable) + trainable = mp.cast_to_param(trainable) + non_trainable = mp.cast_to_compute(non_trainable) + model = eqx.combine(trainable, non_trainable) + return model + + +def _make_saveable_trainer_state(trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state + + +def _initialize_global_tracker(config, run_id): + if isinstance(config, Sequence): + tracker = levanter.tracker.CompositeTracker([c.init(run_id) for c in config]) + else: + tracker = config.init(run_id) + + levanter.tracker.set_global_tracker(tracker) @dataclass @@ -503,18 +558,19 @@ class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy - wandb: WandbConfig = field(default_factory=WandbConfig) + wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") run_base_dir: Path = Path("runs/") id: Optional[str] = None # run id. if None, will be set to a random string + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig) + # config related to partitioning batch_axis: Optional[str] = "batch" # Batch axis for data parallel. fsdp_axis: Optional[Union[str, List[str]]] = "embed" # Axis/Axes to use for FSDP tensor_parallel_axes: Optional[List[str]] = None # Axes, if any, to use for tensor parallelism - # TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that. axis_resources: Mapping[str, str] = field(default_factory=dict) """mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred""" parameter_axis_resources: Mapping[str, str] = field(default_factory=dict) # overrides axis_mapping for parameter @@ -555,15 +611,6 @@ class TrainerConfig: # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes shutdown_at_exit: Union[bool, float] = False - @property - def run_name(self) -> str: - try: - import wandb - - return wandb.run and (wandb.run.name or wandb.run.id) or "unnamed" - except ImportError: - return "unnamed" - @property def TrainBatch(self): return Axis("batch", self.train_batch_size) @@ -576,23 +623,31 @@ def EvalBatch(self): def microbatch_size(self): return self.per_device_parallelism * self.data_axis_size - def initialize(self, all_config): + def __post_init__(self): + if self.wandb is not None: + warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) + self.tracker = self.wandb + + def initialize(self): """Initializes jax, wandb, logging, setting the run name/id in the process""" + self._initialize_jax_config() # Can't do full logging setup until we've initialized jax b/c we use jax for rank id pylogging.basicConfig(level=pylogging.INFO) self.distributed.initialize() - self._maybe_set_id() - self._initialize_logging() - self.ray.initialize() - self._initialize_jax_config() self._validate_and_set_defaults() - self.wandb.init(self.id, all_config) + + id = self._maybe_set_id() + levanter.logging.init_logging(self.log_dir, f"{id}.log") + _initialize_global_tracker(self.tracker, id) + + self.ray.initialize() if self.require_accelerator is None: self.require_accelerator = not sys.platform.startswith("darwin") if self.require_accelerator: - assert jax.default_backend() != "cpu", "Accelerator required but not found" + if jax.default_backend() == "cpu": + raise RuntimeError("No accelerator found. Please run on a TPU or GPU.") if self.shutdown_at_exit is not False: if isinstance(self.shutdown_at_exit, bool): @@ -616,6 +671,14 @@ def data_axis_size(self): assert jax.device_count() % self.model_axis_size == 0 return jax.device_count() // self.model_axis_size + @cached_property + def compute_env(self) -> hax.ResourceEnv: + return hax.ResourceEnv(self.compute_axis_mapping, self.mp, self.device_mesh) + + @cached_property + def param_env(self) -> hax.ResourceEnv: + return hax.ResourceEnv(self.parameter_axis_mapping, self.mp, self.device_mesh) + @cached_property def compute_axis_mapping(self) -> ResourceMapping: """Mapping from logical axis to physical axis for compute.""" @@ -651,10 +714,6 @@ def _initialize_jax_config(self): for key, value in self.jax_config.items(): jax.config.update(key, value) - def _initialize_logging(self): - self.log_dir.mkdir(parents=True, exist_ok=True) - levanter.logging.init_logger(self.log_dir / f"{self.id}.log") - def _maybe_set_id(self): # always do this so we don't get weird hangs if the id isn't set right # for random ids, we want to ensure that all hosts have the same id @@ -667,7 +726,7 @@ def _maybe_set_id(self): # TODO: this doesn't work with wandb sweeps. need to reconcile when we merge if "RUN_ID" in os.environ: self.id = os.environ["RUN_ID"] - elif self.wandb.id is not None: + elif self.wandb is not None and self.wandb.id is not None: self.id = self.wandb.id else: # wandb run ids are 8 characters [a-z0-9], which we'll emulate here @@ -678,6 +737,8 @@ def _maybe_set_id(self): logger.info(f"Setting run id to {self.id}") + return self.id + # we can't do this in post_init because we don't want to call jax.device_count before calling distributed.initialize def _validate_and_set_defaults(self): if jax.device_count() % self.model_axis_size != 0: @@ -691,9 +752,14 @@ def _validate_and_set_defaults(self): ): raise ValueError("either model_axis_size or local_device_count must be divisible by the other") + assert self.train_batch_size != -1 or self.per_device_parallelism != -1 + if self.per_device_parallelism == -1: self.per_device_parallelism = self.train_batch_size // self.data_axis_size + if self.train_batch_size == -1: + self.train_batch_size = self.per_device_parallelism * self.data_axis_size + # validate size of per_device_parallelism if self.train_batch_size % (self.per_device_parallelism * self.data_axis_size) != 0: raise ValueError( @@ -705,135 +771,64 @@ def _validate_and_set_defaults(self): self.per_device_eval_parallelism = self.per_device_parallelism -@dataclass -class OptimizerConfig: - # Config related to optimizer (always adam for now) - learning_rate: float = 6e-4 - weight_decay: float = 0.0 - beta1: float = 0.9 - beta2: float = 0.999 - epsilon: float = 1e-8 - max_grad_norm: Optional[float] = 1.0 - - min_lr_ratio: float = 0.1 - warmup_ratio: Optional[float] = None # Deprecated. fraction of training steps to use as warmup - warmup: float = 0.01 - """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" - cooldown: float = 0.0 - """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" - lr_schedule: str = "cosine" # constant, cosine, linear - """a regex or a list of strings to identify where to mask weight. """ - """For nano-GPT, this field can be set as - `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" - weight_decay_modules: Optional[Union[List[str], str]] = None - - def build(self, num_train_steps: int) -> GradientTransformation: - """Creates the optimizer""" - - # indirection makes it work with optax.inject_hyperparams so we can log the learning rate - def _optimizer(learning_rate): - components = [] - - if self.max_grad_norm: - components.append(optax.clip_by_global_norm(self.max_grad_norm)) - - components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) - - if self.weight_decay > 0: - components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) - - # - learning rate for descent - components.append(optax.scale(-learning_rate)) - - optimizer = optax.chain(*components) - - return optimizer - - return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) - - def build_weight_decay_mask(self): - if self.weight_decay_modules is None: - return None - else: - # mask based on regex or module path - def _apply_on(x, key_path): - if isinstance(self.weight_decay_modules, str): - compiled_regex = re.compile(self.weight_decay_modules) - return compiled_regex.match(key_path) is not None - else: - return any(key_path.__contains__(target) for target in self.weight_decay_modules) - - def mask_fn(model): - return jax.tree_util.tree_map( - _apply_on, - model, - leaf_key_paths(model, is_leaf=eqx.is_array), - is_leaf=eqx.is_array, - ) - - return mask_fn - - def lr_scheduler(self, num_train_steps): - warmup_steps = self._convert_warmup(num_train_steps) - cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps) - lr_decay_steps = num_train_steps - warmup_steps - cooldown_steps - min_lr = self.learning_rate * self.min_lr_ratio - - match self.lr_schedule: - case "constant": - schedule = optax.constant_schedule(self.learning_rate) - case "cosine": - schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio) - case "linear": - schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps - warmup_steps) - case "inv_sqrt": - schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000) - case _: - raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}") - - schedules = [] - boundaries = [] - - if warmup_steps != 0: - warmup = optax.linear_schedule(0.0, self.learning_rate, warmup_steps) - schedules.append(warmup) - boundaries.append(warmup_steps) - - schedules.append(schedule) - - if cooldown_steps != 0: - final_main_lr = schedule(lr_decay_steps) - cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps) - schedules.append(cooldown) - boundaries.append(num_train_steps - cooldown_steps) - - if len(schedules) > 1: - schedule = optax.join_schedules(schedules, boundaries) - - return schedule - - def _convert_warmup(self, num_train_steps: int): - if self.warmup_ratio is not None: - warnings.warn("warmup_ratio is deprecated. Use warmup instead") - return int(self.warmup_ratio * num_train_steps) - else: - return _convert_ratio_or_steps(self.warmup, num_train_steps) +class AllConfig(Protocol): + trainer: TrainerConfig -def _inv_sqrt_decay_schedule(lr: float, min_lr: float, warmup_steps: int, timescale: float = 10000): - def schedule(count): - decay = jnp.minimum(1.0, 1.0 / jnp.sqrt(jnp.maximum(count + warmup_steps, 1) / timescale)) - return jnp.maximum(lr * decay, min_lr) +def initialize(config: TrainerConfig | AllConfig): + """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config + as hyperparameters and as an artifact""" + if isinstance(config, TrainerConfig): + trainer_config = config + else: + trainer_config = config.trainer - return schedule + trainer_config.initialize() + levanter.tracker.log_configuration(config) -def _params_only(t): - return eqx.filter(t, is_inexact_arrayish) +def _partition_trainable_params(model, filter): + """ + Partitions the model into trainable and non-trainable parameters. This is used internally + for the gradient calculation and checkpointing, but you can also use it to filter out params for logging + or something. + + Returns: + trainable, non-trainable + """ + + def trainable_and_diffable(pred): + if callable(pred): + return lambda x: pred(x) and is_inexact_arrayish(x) + elif pred is True: + return is_inexact_arrayish + else: + return pred + + combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) + return eqx.partition(model, combined_mask) -def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int): - if ratio_or_steps < 1.0: - return int(ratio_or_steps * num_train_steps) +def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: + if isinstance(x, hax.NamedArray): + return x.scalar() else: - return int(ratio_or_steps) + return x + + +def take_opt_step( + optimizer, + model: M, + opt_state: OptState, + grads: M, + obj_fn: Optional[Callable[[M], Scalar]] = None, + is_trainable: PyTree[FilterSpec] = True, +) -> tuple[M, OptState]: + train_grads = eqx.filter(grads, is_trainable) + trainable_model = eqx.filter(model, is_trainable) + updates, opt_state = optimizer.update(train_grads, opt_state, params=trainable_model) + # Sophia, e.g. + if isinstance(optimizer, SecondOrderTransformation): + opt_state = optimizer.update_hessian(opt_state, obj_fn, model) + model = eqx.apply_updates(model, updates) + return model, opt_state diff --git a/src/levanter/types.py b/src/levanter/types.py index 954578d27..60d7b82a0 100644 --- a/src/levanter/types.py +++ b/src/levanter/types.py @@ -1,17 +1,21 @@ -from typing import Any, Callable, Protocol, Tuple, TypeVar, Union +from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union + +import haliax as hax +from haliax.types import Scalar M = TypeVar("M") # Model +M_con = TypeVar("M_con", contravariant=True) # Model X = TypeVar("X", contravariant=True) # Input class ValAndGradFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: + def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[Scalar, M]: ... -class ValFn(Protocol[M, X]): - def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: +class ValFn(Protocol[M_con, X]): + def __call__(self, model: M_con, *inputs: X, **input_kwargs) -> Scalar: ... @@ -21,3 +25,36 @@ def __call__(self, model: M, *inputs: X, **input_kwargs) -> Tuple[float, M]: treated as-is, while callables are called on each element of the pytree. If the callable returns True, the element is kept, otherwise it is filtered out. """ + + +class ComputeLossFunction(Protocol[M_con, X]): + """ + Function signature for "compute_loss" functions in Levanter: these + couple the computation of the logits and the evaluation of the loss + """ + + def __call__( + self, + model: M_con, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + ... + + +class ModuleComputeLoss(ComputeLossFunction[M, X]): + """ + Loss that just delegates to the model's compute_loss method. + """ + + def __call__( + self, + model, + *inputs: X, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + **kwargs, + ) -> Scalar | hax.NamedArray: + return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index ff9fdb7af..879162a65 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,8 +18,9 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - # Empirically I never see it get past 10 (usually more like 5-8), so we'll say 8 - return min(max(1, logical_cpu_core_count() - 2), 8) + # Really it's dependent on the number of docs, but that's not something we + # can easily know here. + return min(max(1, logical_cpu_core_count() - 2), 16) else: return 1 diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 258318497..92af33648 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -1,11 +1,14 @@ import contextlib import json +import warnings from dataclasses import fields from typing import Any, Callable, Optional, TypeVar import equinox as eqx import jax +import numpy as np from jax import numpy as jnp +from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree from haliax.jax_utils import is_jax_array_like @@ -26,8 +29,24 @@ def jnp_to_python(a: jnp.ndarray): @contextlib.contextmanager def use_cpu_device(): """Temporarily sets the default device to CPU""" - with jax.default_device(jax.local_devices(backend="cpu")[0]): - yield + # If we have a mesh, we need to make a new version of that mesh + from haliax import current_resource_env + + mesh = current_resource_env().mesh + cpu = jax.local_devices(backend="cpu")[0] + if mesh is None: + with jax.default_device(cpu): + yield + else: + mesh_axis_names = mesh.axis_names + new_mesh = Mesh(np.array([cpu]).reshape((1,) * len(mesh_axis_names)), axis_names=mesh_axis_names) + with jax.default_device(cpu), new_mesh: + yield + + +def is_inside_jit(): + """Returns True if we're currently inside a jit""" + return isinstance(jnp.zeros(()), jax.core.Tracer) def flops_estimate(fn, *args, **kwargs): @@ -135,7 +154,12 @@ def leaf_key_paths( rec_value = rec(field, field_name) rec_values.append(rec_value) - return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values) + + _, tree_def = eqx.tree_flatten_one_level(pytree) + out = jax.tree_util.tree_unflatten(tree_def, rec_values) + return out + # this doesn't work reliably because tree_at doesn't like none values + # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None) else: leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf) if len(leaves) == 1: @@ -150,7 +174,9 @@ def join_key(prefix, k): return f"{prefix}.{k}" if prefix else k -def key_iterator(key: PRNGKeyArray): +def key_iterator(key: PRNGKeyArray | int): + if isinstance(key, int): + key = jax.random.PRNGKey(key) while True: key, subkey = jax.random.split(key) yield subkey @@ -167,3 +193,28 @@ def is_inexact_arrayish(x): return jnp.issubdtype(x.dtype, jnp.inexact) else: return False + + +def tree_filter_like(template: X, tree: X) -> X: + """ + Filters a tree to only include the leaves that are not None in the template. + + This is useful for filtering out nontrainable parameters from a tree. + """ + + def match_like(templ_leaf, tree_leaf): + if templ_leaf is None: + return None + else: + if tree_leaf is None: + warnings.warn(f"Template has a non-None value where tree is None. Template value: {templ_leaf}") + return tree_leaf + + return jax.tree_util.tree_map(match_like, template, tree, is_leaf=lambda x: x is None) + + +def as_arrayish(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return x + else: + return jnp.asarray(x) diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index 38ecfc49c..afc11c051 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,4 +1,5 @@ import os +import sys from dataclasses import dataclass from typing import Callable, TypeVar @@ -142,3 +143,25 @@ def cached_classproperty(func: Callable[..., PropReturn]) -> PropReturn: cached_classproperty.__doc__ = _CachedClassProperty.__doc__ + + +def actual_sizeof(obj): + """similar to sys.getsizeof, but recurses into dicts and lists and other objects""" + seen = set() + size = 0 + objects = [obj] + while objects: + need_to_see = [] + for obj in objects: + if id(obj) in seen: + continue + seen.add(id(obj)) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + need_to_see.extend(obj.values()) + elif hasattr(obj, "__dict__"): + need_to_see.extend(obj.__dict__.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + need_to_see.extend(obj) + objects = need_to_see + return size diff --git a/tests/data/hero_data.npy b/tests/data/hero_data.npy new file mode 100644 index 000000000..f39678d79 Binary files /dev/null and b/tests/data/hero_data.npy differ diff --git a/tests/test_backpack.py b/tests/test_backpack.py index 0b1af96e1..f6b841f5a 100644 --- a/tests/test_backpack.py +++ b/tests/test_backpack.py @@ -8,7 +8,6 @@ import haliax import haliax as hax from haliax import Axis -from haliax.partitioning import round_axis_for_partitioning from levanter.models.backpack import BackpackConfig, BackpackLMHeadModel from levanter.trainer import TrainerConfig @@ -22,7 +21,7 @@ def test_backpack_predict(): trainer_config = TrainerConfig() - Vocab = round_axis_for_partitioning(Axis("vocab", VOCAB_SIZE), trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", VOCAB_SIZE) model_config = BackpackConfig() model_key = PRNGKey(0) model = BackpackLMHeadModel.init(Vocab, model_config, key=model_key) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index c22525fd6..db54b2569 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import pathlib import tempfile @@ -26,10 +27,11 @@ def _dummy_step_info(step): return StepInfo( state=TrainerState( # + 1 b/c step here is next step - step=step + 1, + _step=step + 1, model=None, opt_state=(), training_key=(), + is_trainable=True, ), loss=0.0, step_duration=0.0, @@ -139,42 +141,41 @@ def advance_time(delta_seconds): assert _get_checkpoint_steps(tmpdir) == [2, 4, 6, 8, 10, 15, 20, 30, 40, 49] # 49 is last temporary checkpoint +def _make_state(step, key): + model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) + optim = optax.adam(1e-4) + opt_state = optim.init(arrays_only(model)) + + return TrainerState(step, model, opt_state, key, True) + + def test_checkpoint_simple(): key0 = jax.random.PRNGKey(0) key1 = jax.random.PRNGKey(1) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - optim = optax.adam(1e-4) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) - rep_model, rep_state, rep_key = make_state(key1) + initial_state = _make_state(10, key0) + rep_state = _make_state(2, key1) - assert_trees_not_close(initial_model, rep_model) + assert_trees_not_close(initial_state.model, rep_state.model) with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint( - initial_model, - (initial_opt_state, initial_key), - step=10, + initial_state, + step=initial_state.step, checkpoint_path=tmpdir, ) - restored_model, (restored_optstate, rkey), step = load_checkpoint( - rep_model, - (rep_state, rep_key), + restored_state = load_checkpoint( + rep_state, checkpoint_path=tmpdir, discover_latest=False, ) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(initial_model)), + jax.tree_util.tree_leaves(arrays_only(restored_state.model)), + jax.tree_util.tree_leaves(arrays_only(initial_state.model)), ) - assert all(np.isclose(rkey, initial_key)) - assert step == 10 + assert all(np.isclose(restored_state.training_key, initial_state.training_key)) + assert restored_state.step == initial_state.step def test_checkpoint_steps(): @@ -183,13 +184,7 @@ def test_checkpoint_steps(): optim = optax.adam(1e-4) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) + initial_state = _make_state(10, key0) data = jax.random.uniform(key0, (2, 2)) @eqx.filter_grad @@ -197,41 +192,33 @@ def loss_fn(model, data): m = jax.vmap(model) return jnp.mean(jnp.square(m(data))) - model, state = initial_model, initial_opt_state + state = initial_state for i in range(3): - grad = loss_fn(model, data) - updates, state = optim.update(grad, state) - model = eqx.apply_updates(model, updates) + grad = loss_fn(state.model, data) + updates, new_state = optim.update(grad, state.opt_state) + model = eqx.apply_updates(state.model, updates) + state = dataclasses.replace(state, _step=state.step + 1, model=model, opt_state=new_state) - assert_trees_not_close(model, initial_model) - assert_trees_not_close(state, initial_opt_state) + assert_trees_not_close(state, initial_state) - rep_model, rep_state, rep_key = make_state(key1) - assert_trees_not_close(model, rep_model) + rep_state = _make_state(42, key1) assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir) - restored_model, restored_optstate, step = load_checkpoint( - rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False - ) + save_checkpoint(state, step=3, checkpoint_path=tmpdir) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(model)), - ) - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_optstate)), + jax.tree_util.tree_leaves(arrays_only(restored_state)), jax.tree_util.tree_leaves(arrays_only(state)), ) - assert step == 3 def test_checkpoint_discovery(): with tempfile.TemporaryDirectory() as tempdir: - save_checkpoint(model=1, training_state=2, step=10, checkpoint_path=f"{tempdir}/step-10") - save_checkpoint(model=3, training_state=4, step=20, checkpoint_path=f"{tempdir}/step-20") - save_checkpoint(model=5, training_state=6, step=30, checkpoint_path=f"{tempdir}/step-30") + save_checkpoint(dict(model=1, training_state=2), step=10, checkpoint_path=f"{tempdir}/step-10") + save_checkpoint(dict(model=3, training_state=4), step=20, checkpoint_path=f"{tempdir}/step-20") + save_checkpoint(dict(model=5, training_state=6), step=30, checkpoint_path=f"{tempdir}/step-30") latest = discover_latest_checkpoint(tempdir) assert latest == f"{tempdir}/step-30" diff --git a/tests/test_doremi.py b/tests/test_doremi.py new file mode 100644 index 000000000..c6ac76a47 --- /dev/null +++ b/tests/test_doremi.py @@ -0,0 +1,155 @@ +import equinox +import jax +import jax.random +import optax +import pytest + +import haliax as hax + +from levanter.callbacks import eval_loss_loop +from levanter.data.dataset import ShardableDataset +from levanter.data.mixture import MixtureDataset +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import key_iterator +from levanter.utils.py_utils import non_caching_cycle + + +class Example(equinox.Module): + x: hax.NamedArray + y: hax.NamedArray + + +Block = hax.Axis("Block", 1024) + + +class LogitDataset(ShardableDataset[Example]): + def __init__(self, W, noise, x_mask, x_bias, *, key): + self.W = W + self.noise = noise + self.x_mask = x_mask + self.x_bias = x_bias + self.key = key + + def __iter__(self): + key_iter = key_iterator(self.key) + Dim = self.W.axes[0] + while True: + x_block = hax.random.normal(next(key_iter), (Block, Dim)) * self.x_mask + self.x_bias + noise = hax.random.normal(next(key_iter), (Block,)) * self.noise + y_block = (hax.nn.sigmoid(hax.dot(x_block, self.W, axis=Dim) + noise) > 0.5).astype(float) + for i in range(Block.size): + yield Example(x=x_block[Block, i], y=y_block[Block, i]) + + def shard(self, shard_id: int, num_shards: int): + return LogitDataset(self.W, self.noise, self.x_mask, self.x_bias, key=jax.random.fold_in(self.key, shard_id)) + + +@pytest.mark.slow +def test_estimate_mixture_weights(): + # we create 3 simple logistic regression datasets + # 1. x is moderately predictive of y (y ~ [0, 0.5, 0.5] x + N(0, noise^2) > 0.5) + # 2. x is not predictive of y at all, y is highly random (y ~ N(0, 1)) + # 3. x is highly predictive of y, but it's very easy (y = sigmoid([1, 0, 0] x > 0.5) + + Dim = hax.Axis("Dim", 5) + Batch = hax.Axis("Batch", 32) + + keys = key_iterator(0) + + # W = hax.random.normal(next(keys), (Dim,)) + W1 = hax.named([0.0, 0.5, 0.5, 0.0, 0.0], (Dim,)) + x1_mask = hax.named([0.0, 1.0, 1.0, 0.0, 0.0], (Dim,)) + W2 = hax.named([0.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x2_mask = hax.named([0.0, 0.0, 0.0, 1.0, 1.0], (Dim,)) + W3 = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_mask = hax.named([1.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + x3_bias = hax.named([4.0, 0.0, 0.0, 0.0, 0.0], (Dim,)) + + # y = sigmoid(Wx + b + N(0, noise^2)) > 0.5 + ds1 = LogitDataset(W1, 0.1, x1_mask, 0.0, key=next(keys)) + ds2 = LogitDataset(W2, 2.0, x2_mask, 0.0, key=next(keys)) + ds3 = LogitDataset(W3, 0.0, x3_mask, x3_bias, key=next(keys)) + + # TODO: remove key as a requirement for models + def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None): + del key + y_pred = model(example.x) + return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) + + tiny_trainer_config = TrainerConfig( + num_train_steps=600, + train_batch_size=Batch.size, + tracker=(), + id="kmaklfmaf", + per_device_parallelism=Batch.size // len(jax.devices()), + ) + + optimizer = optax.adam(1e-2) + + trainer = Trainer(tiny_trainer_config, optimizer, compute_loss_fn) + + def fit_to_dataset(dataset): + initial_model = init_model() + with trainer: + state = trainer.initial_state(next(keys), model=initial_model) + loader = trainer.replicated_loader(dataset, Batch) + loader = non_caching_cycle(loader) + + loss = 0.0 + + # state = trainer.train(state, loader, run_hooks=False) + for state in trainer.training_steps(state, loader, run_hooks=False): + if state.step >= 200: + loss += state.loss + + return state.model, (loss / (state.step - 200)) + + def init_model(): + return hax.nn.Linear.init( + Dim, + (), + use_bias=True, + key=next(keys), + ) + + m1, loss1 = fit_to_dataset(ds1) + m2, loss2 = fit_to_dataset(ds2) + m3, loss3 = fit_to_dataset(ds3) + + assert loss3 < loss1 < loss2 + + datasets = {"d1": ds1, "d2": ds2, "d3": ds3} + + ref_model, ref_loss = fit_to_dataset(MixtureDataset(datasets, weights={k: 1 / 3.0 for k in datasets.keys()})) + + # let's see the loss on each dataset + l1_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds1, Batch), max_batches=10, name="d1" + ) + l2_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds2, Batch), max_batches=10, name="d2" + ) + l3_ref = eval_loss_loop( + compute_loss_fn, ref_model, trainer.replicated_loader(ds3, Batch), max_batches=10, name="d3" + ) + + assert l3_ref < l1_ref < l2_ref + + from levanter.doremi import estimate_mixture_weights + + w = estimate_mixture_weights( + initial_proxy=init_model(), + ref=ref_model, + data_sources=datasets, + trainer_config=tiny_trainer_config, + key=next(keys), + loss_fn=compute_loss_fn, + ) + + w1 = w["d1"] + w2 = w["d2"] + w3 = w["d3"] + + assert w1 > w3 > w2 + assert abs(w1 + w2 + w3 - 1.0) < 1e-3 + assert w2 < 0.05 # the noise distribution should get a very low weight diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index f1193f4f4..a6bf3c8d9 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,8 +11,9 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerState from levanter.utils.py_utils import logical_cpu_core_count @@ -43,7 +44,9 @@ def test_eval_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + state = TrainerState(0, model, model, jax.random.PRNGKey(0), True) + + save_checkpoint(state, 0, f"{f}/ckpt") config = eval_lm.EvalLmConfig( data=data_config, diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..ed6a0d4c0 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -34,7 +34,7 @@ def test_export_lm_to_hf(): # in our trainer, we only export the trainable params trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) - save_checkpoint(trainable, None, 0, f"{tmpdir}/ckpt") + save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") try: config = export_lm_to_hf.ConvertLmConfig( @@ -50,8 +50,7 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_hf_checkpoints.py b/tests/test_hf_checkpoints.py index 4f5b58b1d..169bb3999 100644 --- a/tests/test_hf_checkpoints.py +++ b/tests/test_hf_checkpoints.py @@ -50,7 +50,7 @@ def test_save_backpack_model_with_code(): with tempfile.TemporaryDirectory() as tmpdir: converter._save_pretrained_local( - lev_model, tmpdir, save_reference_code=True, save_tokenizer=True, max_shard_size=100000 + lev_model, tmpdir, save_tokenizer=True, save_reference_code=True, max_shard_size=1e8 ) new_converter = converter.replaced(reference_checkpoint=tmpdir, trust_remote_code=True) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index c9c2bbc61..34c4bb941 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -17,7 +17,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel from levanter.models.loss import next_token_loss -from levanter.trainer import OptimizerConfig +from levanter.optim import AdamConfig from levanter.utils.tree_utils import inference_mode from test_utils import skip_if_no_torch @@ -142,7 +142,7 @@ def compute_loss(model, input_ids): assert onp.isclose(jax_g, torch_g.detach().cpu().numpy(), rtol=1e-2, atol=1e-2).all(), f"{jax_g} != {torch_g}" # now we also want to check that the optimizers do similar things - optimizer_config = OptimizerConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") + optimizer_config = AdamConfig(weight_decay=0.0, learning_rate=1e-3, warmup_ratio=0.0, lr_schedule="constant") if optimizer_config.max_grad_norm is not None: torch.nn.utils.clip_grad_norm_(torch_model.parameters(), optimizer_config.max_grad_norm) diff --git a/tests/test_levanter_hf_consistency.py b/tests/test_levanter_hf_consistency.py index 9a0aadb61..5fafe9791 100644 --- a/tests/test_levanter_hf_consistency.py +++ b/tests/test_levanter_hf_consistency.py @@ -5,7 +5,6 @@ import haliax as hax from haliax import Axis -from haliax.partitioning import round_axis_for_partitioning from levanter.checkpoint import load_checkpoint from levanter.models.backpack import BackpackLMHeadModel @@ -34,8 +33,7 @@ def test_hf_backpack_consistency(): model_config: BackpackConfig = BackpackConfig.from_hf_config(hf_model_config) trainer_config = TrainerConfig() - vocab_size = hf_model_config.vocab_size - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", hf_model_config.vocab_size) model_key = PRNGKey(0) model_levanter = BackpackLMHeadModel.init(Vocab, model_config, key=model_key) model_levanter, (_, _), _ = load_checkpoint( @@ -59,18 +57,17 @@ def test_hf_gpt2_consistency(): from levanter.models.gpt2 import Gpt2Config - model_config: GPT2Config = Gpt2Config.from_hf_config(hf_model_config) + model_config = Gpt2Config.from_hf_config(hf_model_config) trainer_config = TrainerConfig() - vocab_size = hf_model_config.vocab_size - Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), trainer_config.compute_axis_mapping) + Vocab = Axis("vocab", hf_model_config.vocab_size) model_key = PRNGKey(0) model_levanter = Gpt2LMHeadModel.init(Vocab, model_config, key=model_key) - model_levanter, (_, _), _ = load_checkpoint( + model_levanter = load_checkpoint( model_levanter, - (None, None), checkpoint_path=LEVANTER_GPT2_CHECKPOINT, discover_latest=True, + subpath="model", ) mp = trainer_config.mp model_levanter = mp.cast_to_param(model_levanter) diff --git a/tests/test_logging.py b/tests/test_logging.py index dc74c78ed..ab7cc35f2 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.logging import WandbConfig +from levanter.tracker.helpers import infer_experiment_git_root def test_infer_experiment_git_root(): @@ -13,12 +13,11 @@ def test_infer_experiment_git_root(): except (InvalidGitRepositoryError, NoSuchPathError): pytest.skip("test not running in a git repo") - root = WandbConfig._infer_experiment_git_root() + root = infer_experiment_git_root() # ensure that 1) this is a git root and 2) this source file is underneath assert root is not None assert pathlib.Path(root).exists() repo = Repo(root) assert repo.working_dir == root - print(root, __file__) assert pathlib.Path(__file__).is_relative_to(root), f"{__file__} is not relative to {root}" diff --git a/tests/test_lora.py b/tests/test_lora.py index 5ba011bce..46cc2e0f8 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -198,6 +198,7 @@ def test_lora_load_in_peft(): @skip_if_no_torch def test_lora_merged_load_in_hf(): + jax.config.update("jax_traceback_filtering", "off") import torch converter: HFCheckpointConverter = Gpt2Config.default_hf_checkpoint_converter @@ -212,7 +213,7 @@ def test_lora_merged_load_in_hf(): causal_mask = hax.nn.attention.causal_mask(model.Pos, config.KeyPos) - with (tempfile.TemporaryDirectory() as tmpdir): + with tempfile.TemporaryDirectory() as tmpdir: converter.save_pretrained(model, f"{tmpdir}/model") lora_config = LoraConfig(r=8, target_modules=["c_attn"]) diff --git a/tests/test_mpt.py b/tests/test_mpt.py index 8b1fd2bb9..8f3384c3d 100644 --- a/tests/test_mpt.py +++ b/tests/test_mpt.py @@ -10,7 +10,7 @@ from levanter.models.mpt import MptConfig, MptLmHeadModel from levanter.utils.tree_utils import inference_mode -from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch +from test_utils import check_model_works_with_seqlen, skip_if_no_torch @pytest.mark.skip(reason="MPT is broken in the latest version of transformers") @@ -104,15 +104,6 @@ def test_mpt_nano_compare(attn_impl): # lev_model = MptLmHeadModel.from_hf_pretrained("mosaicml/mpt-7b") -@parameterize_with_configs("mpt*.yaml") -def test_mpt_configs(config_file): - from levanter.main.train_lm import TrainLmConfig - - config_class = TrainLmConfig - - check_load_config(config_class, config_file) - - def test_pass_different_length_seq(): config = MptConfig( max_seq_len=32, diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py new file mode 100644 index 000000000..50b3461ee --- /dev/null +++ b/tests/test_py_utils.py @@ -0,0 +1,8 @@ +from levanter.utils.py_utils import actual_sizeof + + +def test_actual_sizeof(): + d1 = {"a": 1, "b": 2} + d2 = {"a": "this is a string", "b": "this is another string"} + + assert actual_sizeof(d1) < actual_sizeof(d2) diff --git a/tests/test_replicated_loader.py b/tests/test_replicated_loader.py index 431a1c0bb..347c153d5 100644 --- a/tests/test_replicated_loader.py +++ b/tests/test_replicated_loader.py @@ -40,12 +40,12 @@ def test_local_batched_data_loading_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = ReplicatedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -60,12 +60,12 @@ def test_local_batched_data_loading_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(cache, mesh, Batch) + loader = ReplicatedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -105,11 +105,11 @@ def test_structured_batches_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -125,11 +125,11 @@ def test_structured_batches_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -180,12 +180,12 @@ def test_structured_batches_model_axis_1_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -203,12 +203,12 @@ def test_structured_batches_model_axis_2_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA}): + with haliax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -227,10 +227,10 @@ def test_structured_batches_model_axis_2_subsharded(): ) Height = Axis("Height", 16) Width = Axis("Width", 16) - with mesh, haliax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): + with haliax.resource_env({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}, mesh=mesh): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ReplicatedBatchLoader(dataset, mesh, Batch) + loader = ReplicatedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: diff --git a/tests/test_shard_cache.py b/tests/test_shard_cache.py index c1b55ca1e..6b54970c5 100644 --- a/tests/test_shard_cache.py +++ b/tests/test_shard_cache.py @@ -13,10 +13,12 @@ def setup_module(module): + print("setting up") ray.init("local", num_cpus=2 * logical_cpu_core_count()) # 2x cpu count is faster on my m1 def teardown_module(module): + print("shutting down") ray.shutdown() @@ -175,9 +177,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: ) # now block until the cache is done - print("at wait") cache.await_finished(timeout=10) - print("done waiting") # now check that the chunks are in the right order # TODO: this is a bit gross @@ -233,7 +233,7 @@ def back_to_py(batch: pa.RecordBatch): assert [list(x) for x in chunk] == [[i] * 10 for i in range(10)] with pytest.raises(TimeoutError): - cache.get_chunk(1, timeout=0.1) + cache.get_chunk(1, timeout=0.5) ray.get(blocker_to_wait_on_test.unblock.remote()) diff --git a/tests/test_sharded_loader.py b/tests/test_sharded_loader.py index 19f72bcfe..e83a21e02 100644 --- a/tests/test_sharded_loader.py +++ b/tests/test_sharded_loader.py @@ -44,11 +44,11 @@ def test_sharded_data_loading_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) + loader = ShardedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -63,11 +63,11 @@ def test_sharded_data_loading_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 cache = _small_dataset(seq_len) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(cache, mesh, Batch) + loader = ShardedBatchLoader(cache, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -107,11 +107,11 @@ def test_structured_batches_model_axis_1(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -141,10 +141,10 @@ def test_can_batch_named_scalars(): model_axis_size = 1 mesh = Mesh(np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL)) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -160,11 +160,11 @@ def test_structured_batches_model_axis_2(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): seq_len = 128 dataset = StructuredDataset(seq_len, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -216,12 +216,12 @@ def test_structured_batches_model_axis_1_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -237,12 +237,12 @@ def test_structured_batches_model_axis_2_with_names(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): Height = Axis("Height", 16) Width = Axis("Width", 16) dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -261,10 +261,10 @@ def test_structured_batches_model_axis_2_subsharded(): ) Height = Axis("Height", 16) Width = Axis("Width", 16) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}): + with hax.resource_env({"batch": ResourceAxis.DATA, Height.name: ResourceAxis.MODEL}, mesh=mesh): dataset = StructuredDatasetWithNames(Height, Width, 0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) for batch in batches: @@ -279,10 +279,10 @@ def test_sharded_loader_doesnt_throw_away_data(): np.array(devices).reshape(-1, model_axis_size), (ResourceAxis.DATA, ResourceAxis.MODEL), ) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + with hax.resource_env({"batch": ResourceAxis.DATA}, mesh=mesh): dataset = ScalarDataset(0, 256, 1) Batch = Axis("batch", len(devices)) - loader = ShardedBatchLoader(dataset, mesh, Batch) + loader = ShardedBatchLoader(dataset, Batch) batches = list(itertools.islice(loader, 10)) dataset_examples = list(itertools.islice(dataset, 10 * Batch.size)) diff --git a/tests/test_sophia.py b/tests/test_sophia.py new file mode 100644 index 000000000..7e759c330 --- /dev/null +++ b/tests/test_sophia.py @@ -0,0 +1,56 @@ +import os + +import equinox as eqx +import equinox.nn as nn +import jax +import jax.numpy as jnp +import numpy as np + +import levanter +import levanter.optim.sophia + + +def test_sophia_h(): + key = jax.random.PRNGKey(0) + model = nn.Linear(4, 4, use_bias=False, key=key) + data = np.load(f"{os.path.dirname(__file__)}/data/hero_data.npy").astype("float32") + optimizer = levanter.optim.sophia.sophia_h( + lr=1, b1=0, b2=0.99, gamma=2, weight_decay=0.0, clip_threshold=1, key=key + ) + model = jax.tree_util.tree_map(lambda x: jnp.ones_like(x), model) + + opt_state = optimizer.init(model) + + def loss_fn(model, data): + out = eqx.filter_vmap(model)(data) + return jnp.mean(out**2) * 4 + + jit_update = eqx.filter_jit(optimizer.update_hessian) + + for i in range(1000): + opt_state = jit_update(opt_state, loss_fn, model, data) + + # print('Test-estimated hessian: most coordinates should be approximately 2') + # print('Estimated hessian:', opt_state[0].h.weight) + assert jnp.allclose(opt_state[0].h.weight, 2, rtol=0.2, atol=0.3) # this is very approximate + + grad_loss_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) + + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state) + model = eqx.apply_updates(model, model_updates) + + # loss should be 15.74834156036377 + assert jnp.allclose(loss, 15.74834156036377) + + # print("Test-model param after 1 step: most coordinates should be very loosely 0.5") + assert jnp.allclose(model.weight, 0.5, rtol=0.2, atol=0.1) # this is very approximate + + # print("Test-loss: loss should shrink by approximately 75% after each iteration") + for i in range(10): + loss, grad = grad_loss_fn(model, data) + model_updates, opt_state = optimizer.update(grad, opt_state) + model = eqx.apply_updates(model, model_updates) + + # print('Step:', i , "Loss:", loss.item()) + assert loss < 15.74834156036377 * 0.75 ** (i + 1) diff --git a/tests/test_tracker.py b/tests/test_tracker.py new file mode 100644 index 000000000..15485b83e --- /dev/null +++ b/tests/test_tracker.py @@ -0,0 +1,80 @@ +# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. +import dataclasses +from typing import Tuple + +import pytest +import yaml + +import levanter.tracker +from levanter.tracker import CompositeTracker, TrackerConfig + + +def test_tracker_plugin_stuff_works(): + assert TrackerConfig.get_choice_class("wandb") is not None + with pytest.raises(KeyError): + TrackerConfig.get_choice_class("foo") + + +def test_tracker_plugin_default_works(): + config = """ + tracker: + entity: foo + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig + + import draccus + + tconfig = draccus.decode(ConfigHolder, parsed).tracker + + assert isinstance(tconfig, TrackerConfig.get_choice_class("wandb")) + + assert tconfig.entity == "foo" # type: ignore + + +def test_tracker_plugin_multi_parsing_work(): + config = """ + tracker: + type: noop + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig | Tuple[TrackerConfig, ...] + + import draccus + + from levanter.tracker.tracker import NoopConfig + + assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopConfig) + + config = """ + tracker: + - type: noop + - type: wandb + """ + parsed = yaml.safe_load(config) + decoded = draccus.decode(ConfigHolder, parsed).tracker + assert decoded == (NoopConfig(), TrackerConfig.get_choice_class("wandb")()) + + +def test_get_tracker_by_name(): + wandb_config = TrackerConfig.get_choice_class("wandb") + if wandb_config is None: + pytest.skip("wandb not installed") + + from levanter.tracker import NoopTracker + + wandb1 = wandb_config(mode="disabled").init(None) + tracker = CompositeTracker([wandb1, NoopTracker()]) + + with tracker: + assert levanter.tracker.get_tracker("wandb") is wandb1 + assert levanter.tracker.get_tracker("noop") is not None + + with pytest.raises(KeyError): + levanter.tracker.get_tracker("foo") diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3cd762d8b..f95b27efb 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_utils.py b/tests/test_utils.py index b2b060c28..08df42f69 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -139,21 +139,21 @@ def try_load_path(path): else: return True - return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_path(path), reason="Checkpoint not accessible") def skip_if_hf_model_not_accessible(model_id: str): def try_load_hf(model_id): try: - from transformers import AutoModelForCausalLM + from transformers import AutoConfig - AutoModelForCausalLM.from_pretrained(model_id) + AutoConfig.from_pretrained(model_id) except Exception: return False else: return True - return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible")(lambda x: x) + return pytest.mark.skipif(not try_load_hf(model_id), reason="HuggingFace model not accessible") class IdentityProcessor(BatchProcessor[BatchEncoding]): diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 665c98772..71d117055 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,14 +11,18 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count def setup_module(module): ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) + try: + ray.init("local", num_cpus=ray_designated_cores) + except AssertionError: + # don't get upset if ray is already running + pass def teardown_module(module): @@ -43,7 +47,7 @@ def test_viz_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + save_checkpoint({"model": model}, 0, f"{f}/ckpt") config = viz_logprobs.VizGpt2Config( data=data_config, diff --git a/tests/test_weight_decay_mask.py b/tests/test_weight_decay_mask.py index 52834e679..cc94c5749 100644 --- a/tests/test_weight_decay_mask.py +++ b/tests/test_weight_decay_mask.py @@ -5,7 +5,7 @@ import haliax as hax from levanter.models.gpt2 import Gpt2Config -from levanter.trainer import OptimizerConfig +from levanter.optim import AdamConfig def test_weight_decay_masking(): @@ -43,7 +43,7 @@ def apply_weight_decay(tree): gpt_config = Gpt2Config() Vocab = hax.Axis("vocab", 100) model = gpt_config.build(Vocab, key=jrandom.PRNGKey(0)) - string_list_config = OptimizerConfig( + string_list_config = AdamConfig( weight_decay_modules=[ "attn.c_attn.weight", "attn.c_proj.weight", @@ -53,7 +53,7 @@ def apply_weight_decay(tree): "position_embeddings.weight", ] ) - regex_config = OptimizerConfig( + regex_config = AdamConfig( weight_decay_modules=r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings", ) # masking using `equinox.tree_at`