All gather async llama ci (#17746)
### What's changed
Added the Llama-shape CCL all-gather async test to CI and added end-to-end (e2e) perf measurement.
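
The measurement pattern is small: wrap the traced run in a `BenchmarkProfiler` start/end pair, divide the measured duration by the iteration count, and assert the per-iteration latency against a per-case target. A minimal sketch of that pattern, where `measure_trace_latency_us` and `run_traced_iterations` are hypothetical stand-ins for the traced all-gather execution shown in the diff below:

```python
# Minimal sketch of the e2e perf measurement added in this change.
# `run_traced_iterations` is a hypothetical callable standing in for executing
# the captured trace `num_iters` times and synchronizing the mesh devices.
from models.perf.benchmarking_utils import BenchmarkProfiler


def measure_trace_latency_us(run_traced_iterations, num_iters, perf_target_us=None):
    profiler = BenchmarkProfiler()

    profiler.start("all-gather-async-trace")
    run_traced_iterations()
    profiler.end("all-gather-async-trace")

    # Per-iteration latency in microseconds, checked against the test's target.
    latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6
    if perf_target_us is not None:
        assert (
            latency_us < perf_target_us
        ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us"
    return latency_us
```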

### Checklist
- [x] All post commit:
https://github.com/tenstorrent/tt-metal/actions/runs/13246317576
caixunshiren authored Feb 10, 2025
1 parent 0d5c997 commit 2d6c93d
Showing 3 changed files with 33 additions and 7 deletions.
1 change: 1 addition & 0 deletions tests/nightly/tg/ccl/test_ccl_async_TG_llama_nightly.py
@@ -14,6 +14,7 @@
    teardown_fabric_interface,
    create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler


def report_mismatches(golden, actual, max_printable=None):
@@ -64,6 +65,7 @@ def run_with_trace(
    n_buffer=None,
    num_iter=20,
    use_all_gather_async=False,
    profiler=BenchmarkProfiler(),
):
    # Compile Run
    logger.info("Compiling model")
@@ -131,10 +133,15 @@

    # Run the op
    logger.info("Starting Trace perf test...")
    profiler.start("all-gather-async-trace")
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    ttnn.release_trace(mesh_device, trace_id)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)
    profiler.end("all-gather-async-trace")
    logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
    logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
    logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")

    return tt_out_tensor

@@ -160,6 +167,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
    tile=(32, 32),
    trace_mode=False,
    debug=False,
    profiler=BenchmarkProfiler(),
    # New all-gather-async and persistent fabric params
    use_all_gather_async=False,
    enable_persistent_fabric=False,
@@ -270,6 +278,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
            all_gather_topology=ttnn.Topology.Linear,
            num_iter=num_iters,
            use_all_gather_async=use_all_gather_async,
            profiler=profiler,
        )

    else:
30 changes: 23 additions & 7 deletions tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py
@@ -23,6 +23,7 @@
from tests.ttnn.unit_tests.operations.ccl.test_all_reduce_async import (
    run_all_reduce_with_mesh_tensor_along_row,
)
from models.perf.benchmarking_utils import BenchmarkProfiler


PREFETCHER_NOC1_RING = [
@@ -79,22 +80,25 @@ def get_core_range_set(output_core_grid):
"num_devices, num_links",
[
(4, 3),
(4, 2),
(4, 1),
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
@pytest.mark.parametrize(
"num_iters",
[
5000,
],
)
@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
"tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout",
"tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout, perf_target_us",
    (
        ( # AllGather after SDPA (~160 us)
        ( # AllGather after SDPA
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            (1, 32, 32, 128),
            1,
@@ -108,8 +112,9 @@ def get_core_range_set(output_core_grid):
                }
            ),
            ttnn.TILE_LAYOUT,
            32,
        ),
        ( # AllGather after Binary Mult+Silu (~160 us)
        ( # AllGather after Binary Mult+Silu
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (1, 1, 32, 3840),
            3,
@@ -118,6 +123,7 @@ def get_core_range_set(output_core_grid):
            (32, 160),
            get_core_range_set(PREFETCHER_NOC1_RING),
            ttnn.TILE_LAYOUT,
            25,
        ),
    ),
)
@@ -143,7 +149,8 @@ def test_line_all_gather_sharded_on_TG_rows_llama(
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters=100,
    num_iters,
    perf_target_us,
):
    if len(mesh_device.get_devices()) != 32:
        pytest.skip("Not TG!")
@@ -162,6 +169,8 @@ def test_line_all_gather_sharded_on_TG_rows_llama(
    else:
        output_shard_spec = None

    profiler = BenchmarkProfiler()

    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
        mesh_device,
        num_devices,
@@ -180,13 +189,20 @@ def test_line_all_gather_sharded_on_TG_rows_llama(
        output_shard_spec=output_shard_spec,
        num_all_gather_instances=replication_factor,
        cluster_axis=1,
        profiler=profiler,
        trace_mode=True,
        use_all_gather_async=True,
        enable_persistent_fabric=True,
        create_persistent_fabric=True,
        teardown_persistent_fabric=True,
    )

    latency_us = profiler.get_duration("all-gather-async-trace") / num_iters * 1e6
    if perf_target_us is not None:
        assert (
            latency_us < perf_target_us
        ), f"Measured latency {latency_us} us is greater than target {perf_target_us} us"


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
