Support Eleuther LM-Eval-Harness in Levanter (#675)
Adds Eleuther's LM Eval Harness as a callback in Levanter. It's much
slower than it needs to be because I'm not doing any sequence packing,
but it gets the job done. Scores on Llama 3 seem reasonable, so I think
this is right.

Closes #564

---------

Co-authored-by: Jason Wang <[email protected]>
dlwh and blahBlahhhJ authored Dec 4, 2024
1 parent 586225e commit d125206
Showing 14 changed files with 1,070 additions and 28 deletions.
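For orientation, here is a minimal, self-contained sketch of the idea the commit message describes: an evaluation hook that runs every N training steps and logs per-task scores. This is not the actual Levanter implementation; `run_eval_harness`, `make_harness_callback`, and the hook shape are illustrative assumptions.

```python
from typing import Callable, Dict, List


def run_eval_harness(tasks: List[str], max_examples: int) -> Dict[str, float]:
    """Stand-in for the real harness call; returns a dummy score per task."""
    return {task: 0.0 for task in tasks}


def make_harness_callback(tasks: List[str], every: int, max_examples: int = 32) -> Callable[[int], None]:
    """Build a hook that evaluates every `every` steps and logs the results."""

    def hook(step: int) -> None:
        if step > 0 and step % every == 0:
            scores = run_eval_harness(tasks, max_examples)
            for task, score in scores.items():
                print(f"step {step}: {task} = {score:.3f}")

    return hook


# Usage: a trainer would invoke the hook once per step.
callback = make_harness_callback(["piqa", "hellaswag"], every=50)
for step in range(1, 101):
    callback(step)
```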
26 changes: 26 additions & 0 deletions config/gpt2_nano_harness.yaml
@@ -0,0 +1,26 @@
eval_harness:
  task_spec: ["piqa", "hellaswag"]
  max_examples: 32
eval_harness_steps: 50
data:
  id: dlwh/wikitext_103_detokenized
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 4

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
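As a rough standalone counterpart to the `eval_harness` block above (same tasks, 32 examples per task), one can run the upstream lm-eval Python API directly against a Hugging Face model. This sketch uses the stock `lm-eval` package and its torch-backed `hf` model, not the Levanter callback or the `no_torch` fork pinned later in this diff; the choice of `gpt2` as the model is only for illustration.

```python
# pip install lm-eval  (upstream package; this PR pins a no_torch fork instead)
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                    # Hugging Face transformers backend
    model_args="pretrained=gpt2",  # any small HF model works for a smoke test
    tasks=["piqa", "hellaswag"],
    num_fewshot=0,
    limit=32,                      # mirrors max_examples: 32 in the config above
)
print(results["results"])
```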
52 changes: 52 additions & 0 deletions config/harness/eval_llama3.yaml
@@ -0,0 +1,52 @@
eval_harness:
  task_spec:
    - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
      num_fewshot: 10
    - task: agieval_lsat_ar # 3-shot tests in legal domain
      num_fewshot: 3
    - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
      num_fewshot: 10
    - task: arc_challenge # a (harder) version of arc_easy
      num_fewshot: 10
    - task: boolq # answer yes/no questions based on a passage
      num_fewshot: 10
    - task: copa # use causal reasoning to predict the correct outcome of a given scenario
      num_fewshot: 0
    - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 0
      task_alias: hellaswag_0shot
    - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
      num_fewshot: 10
      task_alias: hellaswag_10shot
    - task: lambada # predict the endings of text passages
      num_fewshot: 0
    - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
      num_fewshot: 0
    - task: piqa # answer questions based on a passage
      num_fewshot: 10
    - task: wsc273 # Winograd Schema Challenge
      num_fewshot: 0
    - task: winogrande # Winograd challenge, extended to more domains
      num_fewshot: 0
    # requires generation
    ## - task: squadv2 # reading comprehension benchmark
    #   num_fewshot: 10
  max_eval_length: 4096
tokenizer: meta-llama/Meta-Llama-3-8B
model:
  type: llama
#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930
checkpoint_path: meta-llama/Meta-Llama-3-8B
checkpoint_is_hf: true
trainer:
  mp: f32
  profiler: true

  per_device_parallelism: -1
  train_batch_size: 512

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
ray:
  auto_start_cluster: false
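A quick way to sanity-check a task list like the one above is to load the YAML and print each entry; the `task_alias` field is what keeps the two hellaswag entries (0-shot and 10-shot) distinct in reported results. A small sketch, assuming the file is saved as `config/harness/eval_llama3.yaml` and PyYAML is installed:

```python
import yaml  # pip install pyyaml

with open("config/harness/eval_llama3.yaml") as f:
    cfg = yaml.safe_load(f)

for entry in cfg["eval_harness"]["task_spec"]:
    name = entry["task"]
    alias = entry.get("task_alias", name)
    print(f"{alias:24s} task={name:16s} num_fewshot={entry['num_fewshot']}")
```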
24 changes: 24 additions & 0 deletions config/harness/harness_nano.yaml
@@ -0,0 +1,24 @@
eval_harness:
  task_spec: ["hellaswag"]
tokenizer: "gpt2"
model:
  type: gpt2
  hidden_dim: 32
  num_heads: 4
  num_layers: 2
trainer:
  mp: f32
  num_train_steps: 100
  profiler: true

  checkpointer:
    keep:
      - every: 50
    save_interval: 5m

  per_device_parallelism: -1
  train_batch_size: 32

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
175 changes: 175 additions & 0 deletions config/olmo/olmo_7b_repro.yaml
@@ -0,0 +1,175 @@
#data: !include data/dolma_olmo_paloma.yaml
data:
  cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
  tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo`
  # tokenizer: "meta-llama/Llama-2-7b-hf"
  stop_strategy: restart
  shuffle_buffer_size: 100000
  configs:
    dolma-algebraic-stack:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
    dolma-arxiv:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
    dolma-gutenberg:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
    dolma-c4:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
    dolma-cc:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
    dolma-cc-news:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
        - gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
    dolma-falcon:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
    dolma-megawika:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
    dolma-owmath:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
    dolma-pes2o:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
    dolma-reddit:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
    dolma-stackexchange:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
    dolma-starcoder:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
    dolma-flan:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
    dolma-wiki:
      train_urls:
        - gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
    # these are just for eval
    "paloma/4chan":
      validation_urls:
        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
    "paloma/c4_100_domains":
      validation_urls:
        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
    "paloma/c4_en":
      validation_urls:
        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
    "paloma/dolma-v1_5":
      validation_urls:
        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
    "paloma/dolma_100_programing_languages":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
    "paloma/dolma_100_subreddits":
      validation_urls:
        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
    "paloma/falcon-refinedweb":
      validation_urls:
        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
    "paloma/gab":
      validation_urls:
        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
    "paloma/m2d2_s2orc_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
    "paloma/m2d2_wikipedia_unsplit":
      validation_urls:
        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
    "paloma/manosphere_meta_sep":
      validation_urls:
        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
    "paloma/mc4":
      validation_urls:
        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
    "paloma/ptb":
      validation_urls:
        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
    "paloma/redpajama":
      validation_urls:
        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
    "paloma/twitterAAE_HELM_fixed":
      validation_urls:
        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
    "paloma/wikitext_103":
      validation_urls:
        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
  train_weights:
    # sampling proportion comes from https://huggingface.co/datasets/allenai/dolma
    dolma-algebraic-stack: 12.6 # 12.6 * 1.0
    dolma-arxiv: 28.0 # 28.0 * 1.0
    dolma-gutenberg: 5.3 # 5.3 * 1.0
    dolma-c4: 69.2 # 138.4 * 0.5
    dolma-cc: 597.75 # 1,195.5 * 0.5
    dolma-cc-news: 14.3 # 1.0
    dolma-falcon: 456.4 # 1.0, refined web
    dolma-megawika: 4.6 # 1.0
    dolma-owmath: 12.6 # 1.0
    dolma-pes2o: 57.2 # 1.0
    dolma-reddit: 79.9 # 1.0
    dolma-stackexchange: 19.6 # 1.0
    dolma-starcoder: 263.8 # 1.0
    dolma-flan: 16.5 # 6.5 * 1.0
    dolma-wiki: 7.4 # 3.7 * 2.0
    paloma/4chan: 0.0
    paloma/c4_100_domains: 0.0
    paloma/c4_en: 0.0
    paloma/dolma-v1_5: 0.0
    paloma/dolma_100_programing_languages: 0.0
    paloma/dolma_100_subreddits: 0.0
    paloma/falcon-refinedweb: 0.0
    paloma/gab: 0.0
    paloma/m2d2_s2orc_unsplit: 0.0
    paloma/m2d2_wikipedia_unsplit: 0.0
    paloma/manosphere_meta_sep: 0.0
    paloma/mc4: 0.0
    paloma/ptb: 0.0
    paloma/redpajama: 0.0
    paloma/twitterAAE_HELM_fixed: 0.0
    paloma/wikitext_103: 0.0
model: # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  # flash_attention_block_size: 1024

  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048 # olmo actually uses 2160 table 5 of https://arxiv.org/pdf/2402.00838
  num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  replica_dcn_axis_size: 2
optimizer:
  learning_rate: 3E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  beta1: 0.9
  beta2: 0.95
  warmup: 2000
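The `train_weights` in the config above are relative sampling weights rather than probabilities (the Paloma sets are weighted 0.0 because they are evaluation-only), and presumably they get normalized into mixture proportions. A small sketch of that arithmetic for a few of the sources, with the weights copied from the config; the normalization step itself is an assumption about how the mixture is sampled:

```python
# Relative weights copied from train_weights above (a subset, for illustration).
weights = {
    "dolma-c4": 69.2,        # 138.4 * 0.5
    "dolma-cc": 597.75,      # 1,195.5 * 0.5
    "dolma-falcon": 456.4,
    "dolma-starcoder": 263.8,
    "paloma/c4_en": 0.0,     # eval-only, never sampled for training
}

total = sum(weights.values())
for name, w in weights.items():
    print(f"{name:20s} {w:8.2f}  ->  {w / total:6.1%} of sampled training data")
```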
13 changes: 9 additions & 4 deletions pyproject.toml
@@ -7,7 +7,11 @@ name = "levanter"
version = "1.2"
authors = [
    { name = "David Hall", email = "[email protected]" },
    { name = "Jason Wang"},
    { name = "Ahmed Ahmed"},
    { name = "Ivan Zhou", email = "[email protected]" },
    { name = "Will Held"},
    { name = "Virginia Adams"}
]
description = "Scalable Training for Foundation Models with Named Tensors and JAX"
readme = "README.md"
@@ -47,10 +51,11 @@ dependencies = [
    "pydantic<3",
    "rich~=13.0",
    "filelock~=3.13",
    # "ai2-olmo",
    "async-lru~=2.0",
    "tqdm-loggable>=0.2",
    "deepdiff"
    "deepdiff",
    # "lm-eval==0.4.2",
    "lm-eval @ git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"
]

[project.urls]
@@ -100,11 +105,11 @@ markers = [
[project.optional-dependencies]
test = [
    "pytest",
    "pytest-forked",
    "pytest-asyncio",
    "flake8",
    "soundfile",
    "librosa",
    "pytest-forked",
    "pytest-asyncio",
]

[tool.setuptools.packages.find]
8 changes: 6 additions & 2 deletions src/levanter/checkpoint.py
@@ -380,11 +380,15 @@ def load_checkpoint(
        logger.warning("Loading checkpoint in jit. This is not recommended and probably won't work.")

    if discover_latest:
        checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore
        discovered_checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore
    else:
        discovered_checkpoint_path = checkpoint_path

    if checkpoint_path is None or not fs.exists(checkpoint_path):
    if discovered_checkpoint_path is None or not fs.exists(discovered_checkpoint_path):
        raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")

    checkpoint_path = discovered_checkpoint_path

    logger.info(f"Loading checkpoint from {checkpoint_path}")
    metadata = load_metadata(checkpoint_path, fs)

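The change above appears to fix error reporting when checkpoint discovery finds nothing: previously `checkpoint_path` was overwritten (possibly with `None`) before the existence check, so the `FileNotFoundError` could report `None` instead of the path the caller asked for. A simplified, self-contained illustration of the pattern, not the Levanter code; `find_latest` and `load` are hypothetical stand-ins for the real discovery and loading logic:

```python
import os
from typing import Optional


def find_latest(base_path: str) -> Optional[str]:
    """Stand-in for checkpoint discovery: returns None when nothing is found."""
    candidates = sorted(os.listdir(base_path)) if os.path.isdir(base_path) else []
    return os.path.join(base_path, candidates[-1]) if candidates else None


def load(base_path: str, discover_latest: bool = True) -> str:
    # Keep the caller's path for error messages; only commit to the
    # discovered path once we know it actually exists.
    discovered = find_latest(base_path) if discover_latest else base_path
    if discovered is None or not os.path.exists(discovered):
        raise FileNotFoundError(f"Could not find checkpoint at {base_path}")
    return discovered


try:
    load("/tmp/no-such-checkpoint-dir")
except FileNotFoundError as e:
    print(e)  # reports the path the caller passed, not None
```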
(Diffs for the remaining 8 changed files are not shown.)
