diff --git a/tests/sweep_framework/sweeps/line_all_gather.py b/tests/sweep_framework/sweeps/line_all_gather.py
new file mode 100644
index 000000000000..2abd6fba3b40
--- /dev/null
+++ b/tests/sweep_framework/sweeps/line_all_gather.py
@@ -0,0 +1,178 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from models.utility_functions import torch_random
+from loguru import logger
+import tt_lib as ttl
+from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
+from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
+
+# Override the default timeout in seconds for hang detection.
+TIMEOUT = 30
+
+
+def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
+    if layout == ttl.tensor.Layout.ROW_MAJOR and input_dtype == ttl.tensor.DataType.BFLOAT8_B:
+        return True, "Invalid combination"
+
+    if num_devices < 2:
+        return True, "Requires multiple devices to run"
+    elif num_devices == 2 and num_links <= 2:
+        return True, "Not enough links to run"
+
+    if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
+        return True, "Unsupported test case"
+
+    ## Check that we can readback results
+    fast_dispatch_page_size_limit = 55 * 1024
+    elem_size = 2 if input_dtype == ttl.tensor.DataType.BFLOAT16 else 1
+    if layout == ttl.tensor.Layout.ROW_MAJOR and (input_shape[dim] * elem_size) > fast_dispatch_page_size_limit:
+        # Fast dispatch currently can't break up readback of large pages into multiple smaller pages and is
+        # limited to ~55K pages.
+        return True, "Fast dispatch can't support reading back this page size in one shot"
+
+    # Check that we can fit in L1 (if L1 config)
+    tensor_size_bytes = elem_size
+    for i in input_shape:
+        tensor_size_bytes *= i
+    num_l1_banks = 64
+    if mem_config.buffer_type == ttl.tensor.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
+        return True, "L1 buffer can't support large tensor sizes"
+
+    # Check that each chip has a non-zero amount of data available
+    min_sized_chunks_on_dim = input_shape[dim]
+    if dim == 3:
+        min_sized_chunks_on_dim //= 32
+    if dim == 2:
+        if layout == ttl.tensor.Layout.TILE:
+            min_sized_chunks_on_dim //= 32
+    if min_sized_chunks_on_dim < num_devices:
+        return (
+            True,
+            f"Input shape {input_shape} incompatible with {num_devices} on dim {dim} because some chips will have no tensor",
+        )
+
+    if (
+        input_shape == [8, 8, 256, 384]
+        and dim == 1
+        and layout == ttl.tensor.Layout.TILE
+        and input_dtype == ttl.tensor.DataType.BFLOAT8_B
+    ):
+        return True, "Known failure"
+
+    return False, ""
+
+
+# Parameters provided to the test vector generator are defined here.
+# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
+# Each suite has a key name (in this case "suite_1") which will associate the test vectors with this specific suite of inputs.
+# Developers can create their own generator functions and pass them to the parameters as inputs.
+parameters = {
+    "suite_1": {
+        "num_devices": [4, 8],
+        "num_links": [1, 2],
+        "input_shape": [
+            [8, 1, 33, 256],
+            [8, 1, 256, 32],
+            [8, 8, 256, 384],
+            [8, 5, 13, 512],
+            [8, 5, 32, 512],
+            [1, 1, 32, 16384],
+        ],
+        "dim": [0, 1, 3],
+        "layout": [ttl.tensor.Layout.ROW_MAJOR, ttl.tensor.Layout.TILE],
+        "input_dtype": [ttl.tensor.DataType.BFLOAT16],
+        "mem_config": [ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.DRAM)],
+        "enable_async": [True, False],
+    },
+}
+
+
+# Invalidate vector is called during the generation phase where each vector will be passed in.
+# If invalidated, the vector will still be stored but will be skipped.
+# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid.
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
+    is_invalid, reason = is_unsupported_case(
+        test_vector["input_shape"], test_vector["dim"], test_vector["mem_config"],
+        test_vector["num_devices"], test_vector["num_links"], test_vector["input_dtype"], test_vector["layout"],
+    )
+    return (True, reason) if is_invalid else (False, None)
+
+
+def skip(
+    *, all_devices, input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout, **_
+) -> Tuple[bool, Optional[str]]:
+    if len(all_devices) != 8:
+        return True, "Not T3000!"
+    (is_known_failure, message) = is_unsupported_case(
+        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
+    )
+    if is_known_failure:
+        return True, f"Skipping unsupported case {message}."
+    return False, None
+
+
+# These are the run instructions for the test, defined by the developer.
+# The run function must take the above-defined parameters as inputs.
+# The runner will call this run function with each test vector, and the returned results from this function will be stored.
+def run(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+    num_iters=1,
+) -> list:
+    for device in all_devices:
+        device.enable_async(enable_async)
+
+    devices = get_devices_for_t3000(all_devices, num_devices)
+    # for device in devices:
+    #     device.disable_and_clear_program_cache()
+
+    logger.info(f"Input shape: {input_shape}")
+    logger.info(f"dim: {dim}")
+
+    input_tensor = torch.rand(input_shape).bfloat16()
+
+    input_tensors = torch.chunk(input_tensor, num_devices, dim)
+    tt_input_tensors = []
+    for i, t in enumerate(input_tensors):
+        tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))
+
+    input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
+    for i in range(num_iters):
+        start_time = start_measuring_time()
+        tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config)
+        e2e_perf = stop_measuring_time(start_time)
+
+        for d in devices:
+            ttl.device.Synchronize(d)
+        logger.info(f"Done iteration {i}")
+
+    for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
+        tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
+        if input_dtype == ttl.tensor.DataType.BFLOAT16:
+            eq, output = comp_equal(tt_output_tensor, input_tensor)
+        else:
+            eq, output = comp_pcc(tt_output_tensor, input_tensor)
+        if not eq:
+            logger.error(f"output mismatch for tensor {i}")
+        # assert eq, f"{i} FAILED: {output}"
+    return [eq, output, e2e_perf]
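
For reviewers unfamiliar with the sweep framework, the comments in the file above describe suites as dicts of argument names mapped to lists of candidate values, with every combination expanded into a test vector and each vector passed through invalidate_vector before being stored. The sketch below is a minimal, framework-independent illustration of that expansion and invalidation flow under those stated assumptions; expand_suite and classify are hypothetical helpers written only for this note and are not part of the sweep framework or this PR.

    # Minimal sketch of suite expansion and invalidation (illustration only).
    import itertools
    from typing import Any, Dict, Iterator

    def expand_suite(suite: Dict[str, list]) -> Iterator[Dict[str, Any]]:
        # Yield one test vector per combination of the suite's value lists.
        keys = list(suite.keys())
        for values in itertools.product(*(suite[k] for k in keys)):
            yield dict(zip(keys, values))

    # Toy suite; the real suite above also carries ttl/ttnn enums and memory configs.
    toy_suite = {"num_devices": [4, 8], "num_links": [1, 2], "dim": [0, 1, 3]}
    vectors = list(expand_suite(toy_suite))
    assert len(vectors) == 2 * 2 * 3  # full Cartesian product of the value lists

    def classify(vector):
        # Example invalidation rule mirroring one check in is_unsupported_case;
        # invalidated vectors are still stored, just skipped at run time.
        is_invalid = vector["num_devices"] == 2 and vector["num_links"] <= 2
        return (True, "Not enough links to run") if is_invalid else (False, None)

    skipped = [v for v in vectors if classify(v)[0]]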