diff --git a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_async_perf.py b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_async_perf.py
index bd5abea6b7b..7a482172d91 100644
--- a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_async_perf.py
+++ b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_async_perf.py
@@ -8,6 +8,9 @@
 from tests.ttnn.unit_tests.operations.ccl.test_new_all_gather import (
     run_all_gather_impl,
 )
+from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
+)


 @skip_for_grayskull("Requires eth connected devices to run")
@@ -18,6 +21,12 @@
         # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT),
         # (4, 1, [1, 1, 1024, 1024], 3, ttnn.TILE_LAYOUT),
         # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 7168], 0, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 1, 32, 2048], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 3584], 0, ttnn.TILE_LAYOUT),
+        (4, 1, [1, 1, 32, 32], 0, ttnn.TILE_LAYOUT),
+        # (4, 1, [1, 1, 8, 32], 2, ttnn.TILE_LAYOUT),
     ],
 )
 @pytest.mark.parametrize(
@@ -67,3 +76,73 @@ def test_all_gather_async_t3000(
         mem_config=mem_config,
         trace_mode=True,
     )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, per_chip_output_shape, dim, layout",
+    [
+        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
+        (8, 1, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+        ttnn.bfloat8_b,
+    ],
+)
+@pytest.mark.parametrize(
+    "buffer_type",
+    [
+        ttnn.BufferType.DRAM,
+        ttnn.BufferType.L1,
+    ],
+)
+@pytest.mark.parametrize("replication_factor", [4])
+@pytest.mark.parametrize("enable_async", [True])
+@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
+@pytest.mark.parametrize("device_params", [{"trace_region_size": 1824800}], indirect=True)
+def test_all_gather_async_tg(
+    mesh_device,
+    num_devices,
+    per_chip_output_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    buffer_type,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    replication_factor,
+    num_iters=1,
+):
+    if len(mesh_device.get_devices()) != 32:
+        pytest.skip("Not TG!")
+    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
+        mesh_device,
+        num_devices,
+        per_chip_output_shape,
+        ttnn.TensorMemoryLayout.INTERLEAVED,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        buffer_type,
+        use_program_cache,
+        function_level_defaults,
+        enable_async=enable_async,
+        num_iters=num_iters,
+        num_all_gather_instances=replication_factor,
+        cluster_axis=0,
+        use_all_gather_async=True,
+        enable_persistent_fabric=True,
+        create_persistent_fabric=True,
+        teardown_persistent_fabric=True,
+        trace_mode=True,
+    )
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
index 7534038d205..0e080a3e219 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py
@@ -57,28 +57,30 @@ def run_with_trace(
     num_links,
     cluster_axis,
     output_mem_config,
+    ccl_semaphore_handles,
+    worker_sub_device_id,
+    enable_persistent_fabric,
     n_worker=None,
     n_buffer=None,
     num_iter=20,
+    use_all_gather_async=False,
 ):
     # Compile Run
     logger.info("Compiling model")
-    tt_out_tensor = ttnn.all_gather(
-        input_tensor,
-        dim=dim,
-        cluster_axis=cluster_axis,
-        mesh_device=mesh_device,
-        num_links=num_links,
-        memory_config=output_mem_config,
-        topology=all_gather_topology,
-    )
-    for d in mesh_device.get_devices():
-        ttnn.synchronize_device(d)
-
-    # Capture trace
-    logger.info("Capturing trace")
-    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
-    for i in range(num_iter):
+    if use_all_gather_async:
+        tt_out_tensor = ttnn.experimental.all_gather_async(
+            input_tensor,
+            dim,
+            cluster_axis=cluster_axis,
+            mesh_device=mesh_device,
+            topology=ttnn.Topology.Linear,
+            multi_device_global_semaphore=ccl_semaphore_handles,
+            num_links=num_links,
+            memory_config=output_mem_config,
+            subdevice_id=worker_sub_device_id,
+            enable_persistent_fabric_mode=enable_persistent_fabric,
+        )
+    else:
         tt_out_tensor = ttnn.all_gather(
             input_tensor,
             dim=dim,
@@ -88,6 +90,37 @@ def run_with_trace(
             memory_config=output_mem_config,
             topology=all_gather_topology,
         )
+    for d in mesh_device.get_devices():
+        ttnn.synchronize_device(d)
+
+    # Capture trace
+    logger.info("Capturing trace")
+    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
+    for i in range(num_iter):
+        if use_all_gather_async:
+            logger.info("Running all-gather async")
+            tt_out_tensor = ttnn.experimental.all_gather_async(
+                input_tensor,
+                dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                topology=ttnn.Topology.Linear,
+                multi_device_global_semaphore=ccl_semaphore_handles,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                subdevice_id=worker_sub_device_id,
+                enable_persistent_fabric_mode=enable_persistent_fabric,
+            )
+        else:
+            tt_out_tensor = ttnn.all_gather(
+                input_tensor,
+                dim=dim,
+                cluster_axis=cluster_axis,
+                mesh_device=mesh_device,
+                num_links=num_links,
+                memory_config=output_mem_config,
+                topology=all_gather_topology,
+            )
     ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
     for d in mesh_device.get_devices():
         ttnn.synchronize_device(d)
@@ -224,8 +257,12 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
             mesh_device=mesh_device,
             num_links=num_links,
             output_mem_config=output_mem_config,
+            ccl_semaphore_handles=ccl_semaphore_handles,
+            worker_sub_device_id=worker_sub_device_id,
+            enable_persistent_fabric=enable_persistent_fabric,
             all_gather_topology=ttnn.Topology.Linear,
             num_iter=num_iters,
+            use_all_gather_async=use_all_gather_async,
        )
     else:
         for _ in range(num_iters):
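For context on the trace path the second file extends: `run_with_trace` follows the usual compile-then-capture pattern, i.e. run the op once eagerly so its program is compiled and cached, synchronize every device, then record the identical calls between `begin_trace_capture` and `end_trace_capture` for replay. The sketch below condenses the async branch of that flow; it is illustrative only, the helper name `capture_all_gather_trace` is not part of the patch, and it assumes the caller has already created the mesh, global semaphores, sub-device, and persistent fabric exactly as the test harness does.

```python
import ttnn


def capture_all_gather_trace(
    mesh_device,
    input_tensor,
    dim,
    cluster_axis,
    num_links,
    output_mem_config,
    ccl_semaphore_handles,
    worker_sub_device_id,
    num_iter=20,
):
    """Condensed sketch of run_with_trace's async branch; all CCL setup
    (semaphores, sub-device, persistent fabric) is assumed done by the caller."""

    def gather():
        # Same call and kwargs as the patch uses in both the compile run
        # and the captured loop.
        return ttnn.experimental.all_gather_async(
            input_tensor,
            dim,
            cluster_axis=cluster_axis,
            mesh_device=mesh_device,
            topology=ttnn.Topology.Linear,
            multi_device_global_semaphore=ccl_semaphore_handles,
            num_links=num_links,
            memory_config=output_mem_config,
            subdevice_id=worker_sub_device_id,
            enable_persistent_fabric_mode=True,
        )

    # Eager "compile" run: program compilation is host-side work that should
    # happen (and be cached) before capture, not inside the trace.
    tt_out_tensor = gather()
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    # Record num_iter identical invocations; replaying the trace later skips
    # host dispatch overhead, which is what the perf test measures.
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for _ in range(num_iter):
        tt_out_tensor = gather()
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)
    return trace_id, tt_out_tensor
```

The eager warm-up run is the important design point here: only steady-state dispatches should be recorded, which is why the patch keeps the "Compile Run" step ahead of the capture loop in both the legacy and async branches.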