Showing 1 changed file with 178 additions and 0 deletions.
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
from loguru import logger
import tt_lib as ttl
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull, get_devices_for_t3000

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
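

# Return (True, reason) for parameter combinations that the line all-gather op cannot run,
# and (False, "") otherwise; skip() below uses this to filter vectors on the target machine.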
def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
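    # BFLOAT8_B is a block-float format that is only stored in tiles, so it cannot be paired
    # with a row-major layout.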
    if layout == ttl.tensor.Layout.ROW_MAJOR and input_dtype == ttl.tensor.DataType.BFLOAT8_B:
        return True, "Invalid combination"

    if num_devices < 2:
        return True, "Requires multiple devices to run"
    elif num_devices == 2 and num_links <= 2:
        return True, "Not enough links to run"

    if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
        return True, "Unsupported test case"

    # Check that we can read back results
    fast_dispatch_page_size_limit = 55 * 1024
    elem_size = 2 if input_dtype == ttl.tensor.DataType.BFLOAT16 else 1
    if layout == ttl.tensor.Layout.ROW_MAJOR and (input_shape[dim] * elem_size) > fast_dispatch_page_size_limit:
        # Fast dispatch currently can't break up readback of large pages into multiple smaller pages and is
        # limited to ~55K pages.
        return True, "Fast dispatch can't support reading back this page size in one shot"

    # Check that we can fit in L1 (if L1 config)
    tensor_size_bytes = elem_size
    for i in input_shape:
        tensor_size_bytes *= i
    num_l1_banks = 64
    if mem_config.buffer_type == ttl.tensor.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
        return True, "L1 buffer can't support large tensor sizes"

    # Check that each chip has a non-zero amount of data available
    min_sized_chunks_on_dim = input_shape[dim]
    if dim == 3:
        min_sized_chunks_on_dim //= 32
    if dim == 2:
        if layout == ttl.tensor.Layout.TILE:
            min_sized_chunks_on_dim //= 32
    if min_sized_chunks_on_dim < num_devices:
        return (
            True,
            f"Input shape {input_shape} incompatible with {num_devices} devices on dim {dim} because some chips will have no tensor",
        )

    if (
        input_shape == [8, 8, 256, 384]
        and dim == 1
        and layout == ttl.tensor.Layout.TILE
        and input_dtype == ttl.tensor.DataType.BFLOAT8_B
    ):
        return True, "Known failure"

    return False, ""


# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites that contain the arguments to the run function as keys,
# and lists of possible inputs as values.
# Each suite has a key name (in this case "suite_1") which associates its test vectors with that
# specific suite of inputs.
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "suite_1": {
        "num_devices": [4, 8],
        "num_links": [1, 2],
        "input_shape": [
            [8, 1, 33, 256],
            [8, 1, 256, 32],
            [8, 8, 256, 384],
            [8, 5, 13, 512],
            [8, 5, 32, 512],
            [1, 1, 32, 16384],
        ],
        "dim": [0, 1, 3],
        "layout": [ttl.tensor.Layout.ROW_MAJOR, ttl.tensor.Layout.TILE],
        "input_dtype": [ttl.tensor.DataType.BFLOAT16],
        "mem_config": [ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.DRAM)],
        "enable_async": [True, False],
    },
}


# Invalidate vector is called during the generation phase where each vector will be passed in.
# If invalidated, the vector will still be stored but will be skipped.
# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid.
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    if test_vector["layout"] == ttl.tensor.Layout.ROW_MAJOR and test_vector["input_dtype"] == ttl.tensor.DataType.BFLOAT8_B:
        return True, "Row major layout is not supported for BFLOAT8_B inputs"
    return False, None
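

# Unlike invalidate_vector, skip() sees the actual devices at run time: it rejects anything that
# is not an 8-chip T3000 system and then defers to is_unsupported_case for per-vector checks.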
def skip(
    *, all_devices, input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout, **_
) -> Tuple[bool, Optional[str]]:
    if len(all_devices) != 8:
        return True, "Not T3000!"
    (is_known_failure, message) = is_unsupported_case(
        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
    )
    if is_known_failure:
        return True, f"Skipping unsupported case {message}."
    return False, None


# This is the run function for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
def run(
    all_devices,
    num_devices,
    input_shape,
    dim,
    num_links,
    input_dtype,
    layout,
    mem_config,
    use_program_cache,
    function_level_defaults,
    enable_async,
    num_iters=1,
) -> list:
    for device in all_devices:
        device.enable_async(enable_async)

    logger.info(f"Input shape: {input_shape}")
    logger.info(f"dim: {dim}")

    devices = get_devices_for_t3000(all_devices, num_devices)
    # for device in devices:
    #     device.disable_and_clear_program_cache()
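
    # Build one golden tensor on host, then shard it along the gather dim so each device holds one slice.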
    input_tensor = torch.rand(input_shape).bfloat16()

    input_tensors = torch.chunk(input_tensor, num_devices, dim)
    tt_input_tensors = []
    for i, t in enumerate(input_tensors):
        tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))
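
    # Assemble the per-device shards into a single mesh tensor and time each line_all_gather call end to end.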
    input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
    for i in range(num_iters):
        start_time = start_measuring_time()
        tt_out_tensor = ttnn.line_all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config)
        e2e_perf = stop_measuring_time(start_time)

        for d in devices:
            ttl.device.Synchronize(d)
        logger.info(f"Done iteration {i}")
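
    # After the all-gather, every device should hold the full concatenated tensor; compare each
    # device's output against the original host tensor (bit-exact for BFLOAT16, PCC otherwise).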
    for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
        tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
        if input_dtype == ttl.tensor.DataType.BFLOAT16:
            eq, output = comp_equal(tt_output_tensor, input_tensor)
        else:
            eq, output = comp_pcc(tt_output_tensor, input_tensor)
        if not eq:
            logger.error(f"output mismatch for tensor {i}")
        # assert eq, f"{i} FAILED: {output}"
    return [eq, output, e2e_perf]