tg test
Aswinmcw committed Nov 29, 2024
1 parent 53c32c0 commit 2c42c56
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_nightly.py
@@ -289,3 +289,113 @@ def test_line_all_gather_on_TG_cols_nightly(
num_all_gather_instances=replication_factor,
cluster_axis=0,
)


def tg(
mesh_device,
num_devices_per_line,
per_chip_output_shape,
tensor_memory_layout,
dim,
num_links,
input_dtype,
layout,
buffer_type: ttnn.BufferType,
use_program_cache,
function_level_defaults,
enable_async,
input_shard_spec: ttnn.ShardSpec = None,
num_all_gather_instances: int = 1,
num_iters: int = 1,
cluster_axis: int = 0,
tile=(32, 32),
trace_mode=False,
debug=False,
):
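    """
    Minimal TG sanity check: shard a (1, 1, 32, 1024) bfloat16 tensor across the
    8x4 mesh along dim 3, run ttnn.all_gather, and verify that every device ends
    up with the full tensor. Most parameters are accepted to mirror the other
    tests in this file but are not used by this quick check; only `debug` is
    consulted, to gate the prints.
    """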
    import ttnn
    import torch
    from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal

# mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(8, 4), mesh_type=ttnn.MeshType.Ring)

    # Construct the test tensor: 32 chunks of 32x32, one per device of the 8x4 mesh
    torch_tensor = torch.rand((1, 1, 32, 32 * 8 * 4), dtype=torch.bfloat16)

# Convert to ttnn.Tensor, tilize and move onto devices across mesh DRAM
mesh_tensor = ttnn.from_torch(
torch_tensor,
layout=ttnn.TILE_LAYOUT,
device=mesh_device,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3),
)

    # Execute all-gather on the sharded tensor; `num_links=1` uses a single ethernet link
    output_tensor = ttnn.all_gather(mesh_tensor, dim=3, num_links=1)
    if debug:
        print(output_tensor.shape)

    # After the gather, every device should hold the full tensor; compare each against the input
    for i, t in enumerate(ttnn.get_device_tensors(output_tensor)):
        tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
        eq, output = comp_equal(tt_output_tensor, torch_tensor)
        if debug:
            print(eq, output)
        assert eq, f"device {i} FAILED: {output}"
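
# A note on the shape bookkeeping above (a sketch, assuming the 8x4 TG mesh used
# by the parametrization below): the (1, 1, 32, 1024) input is split along dim 3
# into 32 shards of (1, 1, 32, 32), one per device, and after
# ttnn.all_gather(dim=3) every device should again hold the full
# (1, 1, 32, 1024) tensor; hence each per-device output is compared against the
# original torch_tensor.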


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
"num_devices, num_links, per_chip_output_shape, dim, layout",
[
(8, 4, [1, 1, 32, 32 * 8 * 4], 3, ttnn.TILE_LAYOUT),
# (8, 4, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
# (8, 4, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
# (8, 4, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
# (8, 4, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
ttnn.bfloat8_b,
],
)
@pytest.mark.parametrize(
"buffer_type",
[
ttnn.BufferType.DRAM,
ttnn.BufferType.L1,
],
)
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("replication_factor", [4]) # 1, 4])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_line_all_gather_on_TG_cols_nightly_check(
mesh_device,
num_devices,
per_chip_output_shape,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async,
replication_factor,
num_iters=1,
):
tg(
mesh_device,
num_devices,
per_chip_output_shape,
ttnn.TensorMemoryLayout.INTERLEAVED,
dim,
num_links,
input_dtype,
layout,
buffer_type,
use_program_cache,
function_level_defaults,
enable_async=enable_async,
num_iters=num_iters,
num_all_gather_instances=replication_factor,
cluster_axis=0,
)
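
# To run only this new check on a TG machine (assuming an 8x4 mesh is available
# and the usual test environment is set up), a pytest invocation along these
# lines should work:
#   pytest tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_nightly.py \
#       -k test_line_all_gather_on_TG_cols_nightly_check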
