Support Eleuther LM-Eval-Harness in Levanter (#675)
Adds Eleuther's LM Eval Harness as a callback in Levanter. It's much slower than it needs to be because I'm not doing any sequence packing, but it gets the job done. Scores on Llama 3 seem reasonable, so I think this is right.

Closes #564

Co-authored-by: Jason Wang <[email protected]>
1 parent 586225e, commit d125206
Showing 14 changed files with 1,070 additions and 28 deletions.
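Only a subset of the 14 files is reproduced below: three new eval-harness configs, a Llama-7B/Dolma training config, and the pyproject.toml changes. The training-style configs are meant for Levanter's usual entry point, e.g. something along the lines of `python -m levanter.main.train_lm --config_path <one of the new configs>` (the exact file paths are not shown in this excerpt); the `eval_harness` and `eval_harness_steps` fields control which harness tasks run and how often the callback fires during training.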
@@ -0,0 +1,26 @@
eval_harness:
  task_spec: ["piqa", "hellaswag"]
  max_examples: 32
eval_harness_steps: 50
data:
  id: dlwh/wikitext_103_detokenized
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 4

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
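In this config the harness runs inside the training loop: `task_spec` and `max_examples` bound the work per evaluation, and `eval_harness_steps: 50` asks for an evaluation every 50 steps. A minimal sketch of how such a periodic hook could be wired (hypothetical names throughout, not Levanter's actual callback API):

    # Hypothetical sketch of a periodic eval-harness hook. `run_harness` and the
    # fields on `step_info` stand in for whatever the trainer actually provides.
    def make_harness_hook(run_harness, eval_harness_steps: int, max_examples: int):
        def hook(step_info):
            # Fire every `eval_harness_steps` steps, skipping step 0.
            if step_info.step > 0 and step_info.step % eval_harness_steps == 0:
                metrics = run_harness(step_info.model, limit=max_examples)
                step_info.tracker.log(metrics, step=step_info.step)
        return hook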
@@ -0,0 +1,52 @@
eval_harness:
  task_spec:
    - task: commonsense_qa  # 5-way multiple-choice questions based on common-sense, everyday scenarios
      num_fewshot: 10
    - task: agieval_lsat_ar  # 3-shot tests in legal domain
      num_fewshot: 3
    - task: arc_easy  # 10-shot, four-way MCQ questions involving grade 3-9 basic science
      num_fewshot: 10
    - task: arc_challenge  # a (harder) version of arc_easy
      num_fewshot: 10
    - task: boolq  # answer yes/no questions based on a passage
      num_fewshot: 10
    - task: copa  # use causal reasoning to predict the correct outcome of a given scenario
      num_fewshot: 0
    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 0
      task_alias: hellaswag_0shot
    - task: hellaswag  # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 10
      task_alias: hellaswag_10shot
    - task: lambada  # predict the endings of text passages
      num_fewshot: 0
    - task: openbookqa  # 4-way multiple choice question answering task that requires multi-step reasoning
      num_fewshot: 0
    - task: piqa  # answer questions based on a passage
      num_fewshot: 10
    - task: wsc273  # Winograd Schema Challenge
      num_fewshot: 0
    - task: winogrande  # Winograd challenge, extended to more domains
      num_fewshot: 0
    # requires generation
    ## - task: squadv2  # reading comprehension benchmark
    #   num_fewshot: 10
  max_eval_length: 4096
tokenizer: meta-llama/Meta-Llama-3-8B
model:
  type: llama
#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930
checkpoint_path: meta-llama/Meta-Llama-3-8B
checkpoint_is_hf: true
trainer:
  mp: f32
  profiler: true

  per_device_parallelism: -1
  train_batch_size: 512

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
ray:
  auto_start_cluster: false
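Each `task_spec` entry carries a task name, a few-shot count, and optionally a `task_alias` so the same task can be reported at several shot counts (hellaswag at 0 and 10 shots above). For orientation, the stock PyTorch-based lm-evaluation-harness exposes the same knobs through its Python API; a rough standalone example against the HF checkpoint (the Levanter integration instead goes through the `no_torch` fork pinned in pyproject.toml below and supplies its own LM adapter):

    import lm_eval  # lm-evaluation-harness 0.4.x

    # Standalone evaluation with the stock harness and a Hugging Face backend.
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args="pretrained=meta-llama/Meta-Llama-3-8B",
        tasks=["hellaswag", "piqa", "winogrande"],
        num_fewshot=0,
        limit=None,  # analogous to max_examples in the Levanter config
    )
    print(results["results"])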
@@ -0,0 +1,24 @@
eval_harness:
  task_spec: ["hellaswag"]
tokenizer: "gpt2"
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100
  profiler: true

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 32

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
@@ -0,0 +1,175 @@
#data: !include data/dolma_olmo_paloma.yaml
data:
  cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
  tokenizer: "allenai/OLMo-1B"  # requires `pip install ai2-olmo`
  # tokenizer: "meta-llama/Llama-2-7b-hf"
  stop_strategy: restart
  shuffle_buffer_size: 100000
  configs:
    dolma-algebraic-stack:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
    dolma-arxiv:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
    dolma-gutenberg:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
    dolma-c4:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
    dolma-cc:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz  # 239 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz  # 153 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
    dolma-cc-news:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
    dolma-falcon:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
    dolma-megawika:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
    dolma-owmath:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
    dolma-pes2o:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
    dolma-reddit:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
    dolma-stackexchange:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
    dolma-starcoder:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
    dolma-flan:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
    dolma-wiki:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
    # these are just for eval
    "paloma/4chan":
      validation_urls:
        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
    "paloma/c4_100_domains":
      validation_urls:
        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
    "paloma/c4_en":
      validation_urls:
        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
    "paloma/dolma-v1_5":
      validation_urls:
        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
    "paloma/dolma_100_programing_languages":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
    "paloma/dolma_100_subreddits":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
    "paloma/falcon-refinedweb":
      validation_urls:
        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
    "paloma/gab":
      validation_urls:
        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
    "paloma/m2d2_s2orc_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
    "paloma/m2d2_wikipedia_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
    "paloma/manosphere_meta_sep":
      validation_urls:
        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
    "paloma/mc4":
      validation_urls:
        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
    "paloma/ptb":
      validation_urls:
        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
    "paloma/redpajama":
      validation_urls:
        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
    "paloma/twitterAAE_HELM_fixed":
      validation_urls:
        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
    "paloma/wikitext_103":
      validation_urls:
        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
  train_weights:
    # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma
    dolma-algebraic-stack: 12.6  # 12.6 * 1.0
    dolma-arxiv: 28.0  # 28.0 * 1.0
    dolma-gutenberg: 5.3  # 5.3 * 1.0
    dolma-c4: 69.2  # 138.4 * 0.5
    dolma-cc: 597.75  # 1,195.5 * 0.5
    dolma-cc-news: 14.3  # 1.0
    dolma-falcon: 456.4  # 1.0, refined web
    dolma-megawika: 4.6  # 1.0
    dolma-owmath: 12.6  # 1.0
    dolma-pes2o: 57.2  # 1.0
    dolma-reddit: 79.9  # 1.0
    dolma-stackexchange: 19.6  # 1.0
    dolma-starcoder: 263.8  # 1.0
    dolma-flan: 16.5  # 6.5 * 1.0
    dolma-wiki: 7.4  # 3.7 * 2.0
    paloma/4chan: 0.0
    paloma/c4_100_domains: 0.0
    paloma/c4_en: 0.0
    paloma/dolma-v1_5: 0.0
    paloma/dolma_100_programing_languages: 0.0
    paloma/dolma_100_subreddits: 0.0
    paloma/falcon-refinedweb: 0.0
    paloma/gab: 0.0
    paloma/m2d2_s2orc_unsplit: 0.0
    paloma/m2d2_wikipedia_unsplit: 0.0
    paloma/manosphere_meta_sep: 0.0
    paloma/mc4: 0.0
    paloma/ptb: 0.0
    paloma/redpajama: 0.0
    paloma/twitterAAE_HELM_fixed: 0.0
    paloma/wikitext_103: 0.0
model:  # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  # flash_attention_block_size: 1024

  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048  # olmo actually uses 2160, table 5 of https://arxiv.org/pdf/2402.00838
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  replica_dcn_axis_size: 2
optimizer:
  learning_rate: 3E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  beta1: 0.9
  beta2: 0.95
  warmup: 2000
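The `train_urls` entries use bash-style brace ranges ({0000..0015} and so on) as shorthand for long lists of shard files. A quick way to check what one of these patterns expands to is the `braceexpand` package, which follows the same expansion rules (whether Levanter uses this exact library internally isn't shown in this diff):

    from braceexpand import braceexpand  # pip install braceexpand

    pattern = "gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz"
    shards = list(braceexpand(pattern))
    print(len(shards))   # 16 shard URLs
    print(shards[0])     # ...algebraic-stack-train-0000.json.gz
    print(shards[-1])    # ...algebraic-stack-train-0015.json.gz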
@@ -7,7 +7,11 @@ name = "levanter"
version = "1.2"
authors = [
    { name = "David Hall", email = "[email protected]" },
    { name = "Jason Wang"},
    { name = "Ahmed Ahmed"},
    { name = "Ivan Zhou", email = "[email protected]" },
    { name = "Will Held"},
    { name = "Virginia Adams"}
]
description = "Scalable Training for Foundation Models with Named Tensors and JAX"
readme = "README.md"
@@ -47,10 +51,11 @@ dependencies = [
    "pydantic<3",
    "rich~=13.0",
    "filelock~=3.13",
    # "ai2-olmo",
    "async-lru~=2.0",
    "tqdm-loggable>=0.2",
    "deepdiff"
    "deepdiff",
    # "lm-eval==0.4.2",
    "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"
]
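The harness itself comes in through a pinned fork of lm-evaluation-harness: the commented-out `lm-eval==0.4.2` pin is replaced by a direct git requirement on the `no_torch` branch, presumably so the harness can be used without dragging in PyTorch. Outside of pyproject.toml, the same dependency would be installed with `pip install "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"`.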

[project.urls]

@@ -100,11 +105,11 @@ markers = [
[project.optional-dependencies]
test = [
    "pytest",
    "pytest-forked",
    "pytest-asyncio",
    "flake8",
    "soundfile",
    "librosa",
    "pytest-forked",
    "pytest-asyncio",
]

[tool.setuptools.packages.find]