Skip to content

Commit

Permalink
Add Mixtral Perplexity and Top-1/5 tests to CI (#10850)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtairum authored Jul 30, 2024
1 parent 1df00cf commit 69633e1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
10 changes: 5 additions & 5 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ def run_test_perplexity(
# ("prefill", 1024, 64, -, -, -),
# ("prefill", 2048, 64, -, -, -),
# ("prefill", 4096, 64, -, -, -),
("decode", 128, 64, 8.70, 0.52, 0.75),
("decode", 1024, 64, 4.90, 0.62, 0.83),
("decode", 2048, 64, 4.23, 0.64, 0.85),
("decode", 128, 64, 8.80, 0.52, 0.75),
# ("decode", 1024, 64, 5.10, 0.62, 0.83),
# ("decode", 2048, 64, 4.23, 0.64, 0.85),
# ("decode", 4096, 32, 10.59, 0.49, 0.73),
),
ids=[
Expand All @@ -266,8 +266,8 @@ def run_test_perplexity(
# "prefill_2048",
# "prefill_4096",
"decode_128",
"decode_1024",
"decode_2048",
# "decode_1024",
# "decode_2048",
# "decode_4096",
],
)
Expand Down
12 changes: 8 additions & 4 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ def forward(self, x):
@pytest.mark.parametrize(
"iterations, expected_top1, expected_top5",
(
(64, 0.92, 0.99),
(128, 0.92, 0.99),
(256, 0.92, 0.99),
(64, 0.93, 0.99),
# (128, 0.92, 0.99),
# (256, 0.92, 0.99),
),
ids=(
"64seqlen",
# "128seqlen",
# "256seqlen"
),
ids=("64seqlen", "128seqlen", "256seqlen"),
)
def test_mixtral_model_inference(
t3k_device_mesh, use_program_cache, reset_seeds, iterations, expected_top1, expected_top5
Expand Down
4 changes: 4 additions & 0 deletions tests/scripts/t3000/run_t3000_perplexity_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ run_t3000_perplexity_tests() {
# Llama-70B perplexity tests
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/t3000/llama2_70b/demo/eval_t3000.py --timeout=7200 ; fail+=$?

# Mixtral8x7B perplexity tests
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py --timeout=3600 ; fail+=$?
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/t3000/mixtral8x7b/tests/test_mixtral_topk.py --timeout=3600 ; fail+=$?

# Record the end time
end_time=$(date +%s)
duration=$((end_time - start_time))
Expand Down

0 comments on commit 69633e1

Please sign in to comment.