#10874: Use t3k_device_mesh fixture for concurrent instances
Aswinmcw committed Aug 7, 2024
1 parent 0eb3fe1 commit 172371d
Showing 1 changed file with 18 additions and 23 deletions.
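The change replaces the per-device `all_devices` list with the `t3k_device_mesh` fixture, so each concurrent test instance shards a single torch tensor across the T3000 mesh instead of chunking it and aggregating per-device tensors by hand. Below is a minimal sketch of the new pattern, not an excerpt of the final file: the fixture name and the ttnn calls (`ttnn.from_torch` with `ShardTensorToMesh`, `ttnn.to_device`, `ttnn.line_all_gather`, `ttnn.get_device_tensors`) mirror the diff that follows, while the shape, `dim`, `num_links`, and the `ttnn.to_torch`/`torch.equal` check are illustrative assumptions.

```python
import pytest
import torch
import ttnn
from ttnn import ShardTensorToMesh


def test_line_all_gather_sketch(t3k_device_mesh):
    # t3k_device_mesh is assumed to be the repo's 8-device T3000 pytest fixture.
    if t3k_device_mesh.get_num_devices() != 8:
        pytest.skip("Not T3000!")

    dim = 3
    input_shape = [1, 1, 32, 8 * 32]  # illustrative: one 32x32 shard per device

    # One reference tensor for the whole mesh; ShardTensorToMesh splits it
    # along `dim` and places one shard on each device of the mesh.
    input_tensor = torch.rand(input_shape).bfloat16()
    ttnn_tensor = ttnn.from_torch(input_tensor, mesh_mapper=ShardTensorToMesh(t3k_device_mesh, dim=dim))
    input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_device_mesh)

    # line_all_gather reassembles the shards along `dim`, so every device
    # ends up holding the full tensor again.
    tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=1)

    # Each per-device output should match the single reference tensor
    # (the real test converts via t.cpu().to(ROW_MAJOR).to_torch() and uses
    # comp_equal/comp_pcc depending on the parametrized dtype).
    for t in ttnn.get_device_tensors(tt_out_tensor):
        assert torch.equal(ttnn.to_torch(t), input_tensor)
```

The diff below applies this pattern inside run_line_all_gather_instances, keeping the two four-device rows for the concurrent instances while comparing every per-device output against the single reference tensor.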
41 changes: 18 additions & 23 deletions tests/ttnn/unit_tests/operations/test_line_all_gather_nightly.py
@@ -135,7 +135,7 @@ def test_line_all_gather_on_t3000_post_commit(


 def run_line_all_gather_instances(
-    all_devices,
+    t3k_device_mesh,
     num_devices,
     num_instances,
     input_shape,
@@ -149,10 +149,10 @@ def run_line_all_gather_instances(
     enable_async,
     num_iters=1,
 ):
-    if len(all_devices) != 8:
+    if t3k_device_mesh.get_num_devices() != 8:
         pytest.skip("Not T3000!")
 
-    for device in all_devices:
+    for device in t3k_device_mesh.get_devices():
         device.enable_async(enable_async)
 
     logger.info(f"Input shape: {input_shape}")
@@ -167,45 +167,40 @@ def run_line_all_gather_instances(
     # devices = get_devices_for_t3000(all_devices, num_devices)
     # for device in devices:
     #     device.disable_and_clear_program_cache()
+    t3k_device = []
+
+    for device in t3k_device_mesh.get_devices():
+        t3k_device.append(device)
 
     t3000_device_rows = [
-        [all_devices[4], all_devices[0], all_devices[3], all_devices[7]],
-        [all_devices[5], all_devices[1], all_devices[2], all_devices[6]],
+        [t3k_device[4], t3k_device[0], t3k_device[3], t3k_device[7]],
+        [t3k_device[5], t3k_device[1], t3k_device[2], t3k_device[6]],
     ]
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    input_mesh_tensors = []
-    input_tensor_to_compare = []
-    for devices in t3000_device_rows:
-        input_tensor = torch.rand(input_shape).bfloat16()
-        input_tensor_to_compare.append(input_tensor)
-
-        input_tensors = torch.chunk(input_tensor, len(devices), dim)
-        tt_input_tensors = []
-        for i, t in enumerate(input_tensors):
-            tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))
+    input_tensor = torch.rand(input_shape).bfloat16()
 
-        input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
-        input_mesh_tensors.append(input_tensor_mesh)
+    ttnn_tensor = ttnn.from_torch(input_tensor, mesh_mapper=ShardTensorToMesh(t3k_device_mesh, dim=dim))
+    input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_device_mesh)
 
     result_mesh_tensors = []
     for i, devices in enumerate(t3000_device_rows):
-        tt_out_tensor = ttnn.line_all_gather(input_mesh_tensors[i], dim, num_links=num_links, memory_config=mem_config)
+        tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config)
         result_mesh_tensors.append(tt_out_tensor)
 
     ## Wait for completion
     for i, devices in enumerate(t3000_device_rows):
         for d in devices:
             ttl.device.Synchronize(d)
 
-    for count, tt_out_tensor in enumerate(result_mesh_tensors):
+    for tt_out_tensor in result_mesh_tensors:
         for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
             tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
             if input_dtype == ttl.tensor.DataType.BFLOAT16:
-                eq, output = comp_equal(tt_output_tensor, input_tensor_to_compare[count])
+                eq, output = comp_equal(tt_output_tensor, input_tensor)
             else:
-                eq, output = comp_pcc(tt_output_tensor, input_tensor_to_compare[count])
+                eq, output = comp_pcc(tt_output_tensor, input_tensor)
             if not eq:
                 logger.error(f"output mismatch for tensor {i}")
             assert eq, f"{i} FAILED: {output}"
@@ -247,7 +242,7 @@ def run_line_all_gather_instances(
 )
 @pytest.mark.parametrize("enable_async", [True, False])
 def test_line_all_gather_on_t3000_post_commit_instances(
-    all_devices,
+    t3k_device_mesh,
     num_devices,
     num_instances,
     input_shape,
@@ -262,7 +257,7 @@ def test_line_all_gather_on_t3000_post_commit_instances(
     num_iters=1,
 ):
     run_line_all_gather_instances(
-        all_devices,
+        t3k_device_mesh,
         num_devices,
         num_instances,
         input_shape,
