Preprocess Conv2D weights on Device #18272

Draft · wants to merge 2 commits into base: main
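For orientation, here is a sketch of the `ttnn.Conv2dConfig` flags this draft exercises in the test changes below. The field names are taken from the hunks in this diff; treating them as a minimal standalone config is an assumption, not released API documentation.

```python
# Sketch only: Conv2dConfig flags touched by this PR, as they appear in the test hunks below.
import ttnn

conv_config = ttnn.Conv2dConfig(
    preprocess_weights_on_device=True,  # let the device convert the conv weights itself
    always_preprocess_weights=True,     # name suggests preprocessing runs on every call (assumption)
    transpose_shards=True,              # plumbed through run_conv(); see issue #17897 in the diff
)
```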
5 changes: 4 additions & 1 deletion tests/sweep_framework/sweep_utils/conv2d_common.py
@@ -275,14 +275,17 @@ def run_conv2d_short_sweep(
dtype=output_dtype,
weights_dtype=weights_dtype,
output_layout=output_layout,
preprocess_weights_on_device=True,
)
else:
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
if has_bias:
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)

tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, device=device)
conv_config = ttnn.Conv2dConfig()
conv_config = ttnn.Conv2dConfig(
preprocess_weights_on_device=True,
)

start_time = start_measuring_time()
[tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
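The hunk above, truncated at the `ttnn.conv2d` call, turns on `preprocess_weights_on_device` in the short-sweep path and unpacks `[weights_device, bias_device]` from `ttnn.conv2d`. Below is a minimal end-to-end sketch of that calling pattern, assuming the keyword names visible in this diff (`return_output_dim`, `return_weights_and_bias`) and using illustrative shapes and a 1x1 kernel.

```python
# Minimal sketch of the calling pattern exercised by run_conv2d_short_sweep in this PR.
# Shapes and the 1x1 convolution are illustrative; keyword names are those visible in the diff.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

batch, in_ch, out_ch, height, width = 1, 32, 64, 56, 56
torch_input = torch.randn(batch, height, width, in_ch, dtype=torch.bfloat16)   # NHWC, as in the tests
torch_weight = torch.randn(out_ch, in_ch, 1, 1, dtype=torch.bfloat16)          # OIHW
torch_bias = torch.randn(1, 1, 1, out_ch, dtype=torch.bfloat16)

# Weights and bias stay in host row-major layout; the op preprocesses them on device.
tt_input = ttnn.from_torch(torch_input, ttnn.bfloat16, device=device)
tt_weight = ttnn.from_torch(torch_weight, ttnn.bfloat16)
tt_bias = ttnn.from_torch(torch_bias, ttnn.bfloat16)

conv_config = ttnn.Conv2dConfig(preprocess_weights_on_device=True)

[tt_output, [out_h, out_w], [weights_device, bias_device]] = ttnn.conv2d(
    input_tensor=tt_input,
    weight_tensor=tt_weight,
    bias_tensor=tt_bias,
    in_channels=in_ch,
    out_channels=out_ch,
    device=device,
    kernel_size=(1, 1),
    stride=(1, 1),
    padding=(0, 0),
    batch_size=batch,
    input_height=height,
    input_width=width,
    conv_config=conv_config,
    return_output_dim=True,
    return_weights_and_bias=True,
)

# weights_device / bias_device come back in the on-device conv layout.
torch_output = ttnn.to_torch(ttnn.from_device(tt_output))
ttnn.close_device(device)
```

As the sweep change suggests, the point is that weight and bias tensors can be handed over in host layout and converted on device, with the prepared tensors returned for reuse.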
39 changes: 32 additions & 7 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -58,6 +58,7 @@ def run_conv(
config_override,
dilation=1,
use_shallow_conv_variant=False,
transpose_shards=True, # https://github.com/tenstorrent/tt-metal/issues/17897
fp32_accum=False,
packer_l1_acc=False,
output_layout=ttnn.TILE_LAYOUT,
@@ -73,6 +74,7 @@
output_mesh_composer=None,
enable_split_reader=False,
activation="",
preprocess_weights_on_device=True,
):
if isinstance(device, ttnn.MeshDevice):
assert input_mesh_mapper is not None, "Expected mesh mapper for input tensor when using device mesh"
@@ -92,7 +94,7 @@
torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))

torch_weight_tensor = randomize_torch_tensor(torch_tensor_map, conv_weight_shape)
torch_bias_tensor = randomize_torch_tensor(torch_tensor_map, conv_bias_shape) if has_bias else None
torch_bias_tensor = randomize_torch_tensor(torch_tensor_map, conv_bias_shape) * 10 if has_bias else None

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor_nchw,
@@ -138,6 +140,9 @@ def run_conv(
enable_subblock_padding=False,
output_layout=output_layout,
activation=activation,
transpose_shards=transpose_shards,
preprocess_weights_on_device=preprocess_weights_on_device,
always_preprocess_weights=True,
)
compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
@@ -157,7 +162,7 @@
conv_config.override_sharding_config = True
print("Setting num_cores_nhw to 98")

[tt_output_tensor_on_device, [out_height, out_width]] = ttnn.conv2d(
[tt_output_tensor_on_device, [out_height, out_width], [d_w, d_b]] = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
@@ -178,8 +183,8 @@
groups=groups,
memory_config=memory_config,
return_output_dim=True,
return_weights_and_bias=True,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor, mesh_composer=output_mesh_composer)

@@ -195,6 +200,8 @@

if not fp32_accum:
pcc = 0.985
if input_channels * filter_height * filter_width > 10000:
pcc = 0.97
elif math_fidelity == ttnn.MathFidelity.LoFi and activations_dtype == ttnn.bfloat8_b:
pcc = 0.996
else:
@@ -388,6 +395,9 @@ def test_conv_features(
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")

if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat16 and packer_l1_acc and fp32_accum:
pytest.skip("skipping due to pack_untilize_dst issue!")

run_conv(
device,
torch_tensor_map,
@@ -411,6 +421,7 @@
has_bias=True,
fp32_accum=fp32_accum,
packer_l1_acc=packer_l1_acc,
preprocess_weights_on_device=True,
)


@@ -782,7 +793,7 @@ def test_conv_for_segformer_512x512(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
@@ -965,6 +976,7 @@ def test_resnet50_conv_wh(
pad_w,
config_override=config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
packer_l1_acc=packer_l1_acc,
fp32_accum=False,
has_bias=has_bias,
@@ -1026,6 +1038,7 @@ def test_conv_mem_config_wh(
shard_layout=shard_layout,
config_override=config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
packer_l1_acc=True,
fp32_accum=False,
has_bias=True,
@@ -1211,7 +1224,7 @@ def test_resnet50_conv_wh_fp32(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
@@ -1353,7 +1366,7 @@ def test_sd_conv(
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"fp32_accum",
@@ -1494,7 +1507,7 @@ def test_sd_conv_wh(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
@@ -1646,6 +1659,7 @@ def test_unet_conv_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
)
@@ -1744,6 +1758,7 @@ def test_unet_conv_groups_2_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
groups=groups,
@@ -1841,6 +1856,7 @@ def test_unet_conv_groups_4_6_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
groups=groups,
)
@@ -1939,12 +1955,14 @@ def test_unet_conv_groups_8_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
groups=groups,
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, config_override",
@@ -2006,6 +2024,7 @@ def test_halo_reshard_conv(
)


@skip_for_grayskull()
@pytest.mark.skip("New API needs to be tested")
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
@@ -2247,6 +2266,7 @@ def test_conv_groups(
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups",
@@ -2367,6 +2387,7 @@ def test_yolov4_conv_groups_larger_than_one(
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
" output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups",
@@ -2655,6 +2676,7 @@ def test_shallow_conv_with_tiled_input(device):

# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("tiled_input", [True, False])
@pytest.mark.parametrize("input_on_device", [True, False])
@@ -2780,6 +2802,9 @@ def test_small_in_large_out_channels_auto_shard(device, torch_tensor_map):
padding = (0, 0)
height = 128
width = 128
if device.core_grid.y != 8 and is_wormhole_b0():
pytest.skip("Needs 8x8 grid for wormhole_b0")

run_conv(
device,
torch_tensor_map,
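In the file above, `run_conv` now asks for `return_weights_and_bias=True` and unpacks `[d_w, d_b]`, which suggests the device-preprocessed tensors are meant to be fed back into later calls. Here is a hypothetical caching wrapper along those lines; whether `ttnn.conv2d` accepts already-preprocessed tensors unchanged on a second call is an assumption, and `conv2d_cached`, `cache`, and `conv_kwargs` are names invented for this sketch.

```python
# Hypothetical helper (not from this PR): reuse the device-preprocessed conv weights
# returned by ttnn.conv2d across repeated calls instead of re-preprocessing each time.
import ttnn

def conv2d_cached(tt_input, tt_weight, tt_bias, cache, conv_kwargs):
    """Run ttnn.conv2d, caching the device-preprocessed weight/bias between calls."""
    weight_tensor, bias_tensor = cache.get("wb", (tt_weight, tt_bias))
    out, [out_h, out_w], [w_dev, b_dev] = ttnn.conv2d(
        input_tensor=tt_input,
        weight_tensor=weight_tensor,
        bias_tensor=bias_tensor,
        return_output_dim=True,
        return_weights_and_bias=True,
        **conv_kwargs,  # in_channels, out_channels, kernel_size, stride, padding, device, ...
    )
    cache["wb"] = (w_dev, b_dev)  # later iterations pass the preprocessed tensors back in
    return out, (out_h, out_w)
```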
130 changes: 0 additions & 130 deletions tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py
@@ -196,133 +196,3 @@ def test_prepare_conv_weights(
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing


@skip_for_grayskull()
@skip_for_blackhole()
# @skip_for_wormhole_b0()
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override",
(
# rn50 layer1
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
),
)
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["has_bias", "no_bias"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 2**15}], indirect=True)
def test_prepare_bias(
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
packer_l1_acc,
config_override,
has_bias,
device,
):
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

if batch_size == 20 and (
output_channels == 64 or (stride_h == 2 and (output_channels == 256 or output_channels == 128))
):
pytest.skip("Skipping test because it won't fit in L1!")

inp_shape = (batch_size, input_channels, input_height, input_width)
conv_weight_shape = (output_channels, input_channels, filter_height, filter_width)
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16)
torch_input_tensor = torch.randn(inp_shape, dtype=torch.bfloat16)
torch_bias_tensor = torch.randn((1, 1, 1, output_channels), dtype=torch.bfloat16) if has_bias else None

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(1, 1),
groups=1,
).permute(0, 2, 3, 1)

tt_input_tensor = ttnn.from_torch(torch_input_tensor.transpose(-3, -2).transpose(-2, -1), ttnn.bfloat16)
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16) if has_bias else None

conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32),
enable_act_double_buffer=False,
enable_split_reader=False,
enable_subblock_padding=False,
)
compute_config = ttnn.init_device_compute_kernel_config(device.arch(), packer_l1_acc=packer_l1_acc)
if config_override and "act_block_h" in config_override:
conv_config.act_block_h_override = config_override["act_block_h"]

if config_override and "act_block_w_div" in config_override:
conv_config.act_block_w_div = config_override["act_block_w_div"]

if config_override and "num_cores_nhw" in config_override:
if config_override["num_cores_nhw"] == 98:
conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (11, 7)), ttnn.CoreRange((0, 8), (1, 8))})
conv_config.override_sharding_config = True
print("Setting num_cores_nhw to 98")

conv_kwargs = {
"input_layout": ttnn.ROW_MAJOR_LAYOUT,
"in_channels": input_channels,
"out_channels": output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (filter_height, filter_width),
"stride": (stride_h, stride_w),
"padding": (pad_h, pad_w),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

tt_input_tensor = ttnn.to_device(tt_input_tensor, device)

tt_bias_tensor_formatted = (
ttnn.prepare_conv_bias(
bias_tensor=tt_bias_tensor, input_memory_config=tt_input_tensor.memory_config(), **conv_kwargs
)
if has_bias
else None
)

tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None
(k := next(iter(conv_kwargs)), conv_kwargs.pop(k)) ##removing 1st element from dict
tt_output_tensor_on_device = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
bias_tensor=tt_bias_tensor_formatted,
**conv_kwargs,
compute_config=compute_config,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape)

pcc = 0.99
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing