Add cases for TG

Aswinmcw committed Jan 29, 2025
1 parent 834e10f commit 493a4f3
Showing 2 changed files with 132 additions and 16 deletions.
79 changes: 79 additions & 0 deletions tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_async_perf.py
@@ -8,6 +8,9 @@
 from tests.ttnn.unit_tests.operations.ccl.test_new_all_gather import (
     run_all_gather_impl,
 )
+from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
+)
 
 
 @skip_for_grayskull("Requires eth connected devices to run")
@@ -18,6 +21,12 @@
         # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT),
         # (4, 1, [1, 1, 1024, 1024], 3, ttnn.TILE_LAYOUT),
         # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 7168], 0, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 1, 32, 2048], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 3584], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 32], 0, ttnn.TILE_LAYOUT),
+        # (4, 1, [1, 1, 8, 32], 2, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
@@ -67,3 +76,73 @@ def test_all_gather_async_t3000(
         mem_config=mem_config,
         trace_mode=True,
     )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [4])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("device_params", [{"trace_region_size": 1824800}], indirect=True)
+def test_all_gather_async_tg(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=0,
+        use_all_gather_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+        trace_mode=True,
+    )
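
For reference, the new TG cases above parametrize the per-chip output shape, and the helper derives each device's input shard from it. The minimal sketch below illustrates that relationship, assuming the usual all-gather convention of concatenating one shard per device along `dim`; the helper `per_chip_input_shape` is hypothetical and added here only for illustration.

def per_chip_input_shape(per_chip_output_shape, dim, num_devices):
    # Each device contributes one shard along `dim`, so the gathered output
    # is `num_devices` times larger than the input along that dimension.
    assert per_chip_output_shape[dim] % num_devices == 0
    shard = list(per_chip_output_shape)
    shard[dim] //= num_devices
    return shard

# Example: the (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT) case gathers
# eight [1, 1, 32, 2048] shards along dim 1.
assert per_chip_input_shape([1, 8, 32, 2048], dim=1, num_devices=8) == [1, 1, 32, 2048]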
69 changes: 53 additions & 16 deletions tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
@@ -57,28 +57,30 @@ def run_with_trace(
     num_links,
     cluster_axis,
     output_mem_config,
+    ccl_semaphore_handles,
+    worker_sub_device_id,
+    enable_persistent_fabric,
     n_worker=None,
     n_buffer=None,
     num_iter=20,
+    use_all_gather_async=False,
 ):
     # Compile Run
     logger.info("Compiling model")
-    tt_out_tensor = ttnn.all_gather(
-        input_tensor,
-        dim=dim,
-        cluster_axis=cluster_axis,
-        mesh_device=mesh_device,
-        num_links=num_links,
-        memory_config=output_mem_config,
-        topology=all_gather_topology,
-    )
-    for d in mesh_device.get_devices():
-        ttnn.synchronize_device(d)
-
-    # Capture trace
-    logger.info("Capturing trace")
-    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
-    for i in range(num_iter):
+    if use_all_gather_async:
+        tt_out_tensor = ttnn.experimental.all_gather_async(
+            input_tensor,
+            dim,
+            cluster_axis=cluster_axis,
+            mesh_device=mesh_device,
+            topology=ttnn.Topology.Linear,
+            multi_device_global_semaphore=ccl_semaphore_handles,
+            num_links=num_links,
+            memory_config=output_mem_config,
+            subdevice_id=worker_sub_device_id,
+            enable_persistent_fabric_mode=enable_persistent_fabric,
+        )
+    else:
         tt_out_tensor = ttnn.all_gather(
             input_tensor,
             dim=dim,
@@ -88,6 +90,37 @@
             memory_config=output_mem_config,
             topology=all_gather_topology,
         )
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Capture trace
+    logger.info("Capturing trace")
+    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+    for i in range(num_iter):
+        if use_all_gather_async:
+            logger.info("Running all-gather async")
+            tt_out_tensor = ttnn.experimental.all_gather_async(
+                input_tensor,
+                dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                topology=ttnn.Topology.Linear,
+                multi_device_global_semaphore=ccl_semaphore_handles,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                subdevice_id=worker_sub_device_id,
+                enable_persistent_fabric_mode=enable_persistent_fabric,
+            )
+        else:
+            tt_out_tensor = ttnn.all_gather(
+                input_tensor,
+                dim=dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                topology=all_gather_topology,
+            )
     ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
     for d in mesh_device.get_devices():
         ttnn.synchronize_device(d)
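
The updated run_with_trace keeps the usual compile-once, capture, replay structure for both ops: one eager run so the device programs are compiled and cached, a capture of num_iter iterations into a trace, and then a timed replay. The condensed sketch below outlines that pattern; run_op stands in for whichever of ttnn.experimental.all_gather_async or ttnn.all_gather is selected by use_all_gather_async, and the replay and cleanup calls (ttnn.execute_trace, ttnn.release_trace) are assumed from the usual ttnn trace API, since they fall outside this hunk.

import ttnn

def trace_and_replay(mesh_device, run_op, num_iter):
    # 1. Eager "compile" run: builds and caches the device programs.
    tt_out_tensor = run_op()
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    # 2. Capture: enqueue num_iter iterations into a trace on command queue 0.
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for _ in range(num_iter):
        tt_out_tensor = run_op()
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    # 3. Replay the captured work and release the trace
    #    (assumed API, not shown in this hunk).
    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=False)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)
    ttnn.release_trace(mesh_device, trace_id)
    return tt_out_tensor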
@@ -224,8 +257,12 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
             mesh_device=mesh_device,
             num_links=num_links,
             output_mem_config=output_mem_config,
+            ccl_semaphore_handles=ccl_semaphore_handles,
+            worker_sub_device_id=worker_sub_device_id,
+            enable_persistent_fabric=enable_persistent_fabric,
             all_gather_topology=ttnn.Topology.Linear,
             num_iter=num_iters,
+            use_all_gather_async=use_all_gather_async,
         )
     else:
         for _ in range(num_iters):