
Add support for Resnet50 batch 32.
mywoodstock committed Mar 2, 2025
1 parent a019371 commit 8150c70
Showing 18 changed files with 167 additions and 86 deletions.
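The first file below extends the golden PCC table in the ResNet-50 test infra with batch-32 entries. As a rough sketch of how such a table is consumed (the expected_pcc helper and its fallback behavior are assumptions for illustration, not code from this commit; the keys and values are taken from the diff):

    import ttnn

    # Expected PCC is keyed by batch size, then by
    # (math fidelity, weight dtype, activation dtype).
    golden_pcc = {
        32: {
            (ttnn.MathFidelity.HiFi4, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.97,
            (ttnn.MathFidelity.HiFi2, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.95,
            (ttnn.MathFidelity.LoFi, ttnn.bfloat8_b, ttnn.bfloat8_b): 0.88,
        },
    }

    def expected_pcc(batch_size, math_fidelity, weight_dtype, act_dtype):
        # Hypothetical helper: returns None when no golden value is recorded for the config.
        return golden_pcc.get(batch_size, {}).get((math_fidelity, weight_dtype, act_dtype))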
22 changes: 21 additions & 1 deletion models/demos/ttnn_resnet/tests/resnet50_test_infra.py
@@ -14,6 +14,7 @@
preprocess_model_parameters,
)
from models.utility_functions import (
is_blackhole,
is_wormhole_b0,
is_grayskull,
divup,
@@ -141,6 +142,23 @@ def load_resnet50_model(model_location_generator):
ttnn.bfloat8_b,
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433
},
32: {
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.97,
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.95,
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.88,
},
}

golden_pcc = {
@@ -255,10 +273,12 @@ def setup_l1_sharded_input(self, device, torch_input_tensor=None):
if self.batch_size == 16:
core_grid = ttnn.CoreGrid(y=8, x=6)
elif self.batch_size == 20:
if is_grayskull():
if is_grayskull() or is_blackhole():
core_grid = ttnn.CoreGrid(y=8, x=10)
elif is_wormhole_b0():
core_grid = ttnn.CoreGrid(y=5, x=6) # untested due to unsupported batch20 on WH
elif self.batch_size == 32:
core_grid = ttnn.CoreGrid(y=10, x=13)
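# 13 x 10 = 130 cores used to L1-shard the batch-32 input (grid values taken from this change).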
num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
# torch tensor
torch_input_tensor = self.torch_input_tensor if torch_input_tensor is None else torch_input_tensor
125 changes: 90 additions & 35 deletions models/demos/ttnn_resnet/tt/ttnn_functional_resnet50.py
@@ -39,8 +39,19 @@
mcast_in0=True,
),
20: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(8, 4),
in0_block_w=2,
compute_with_storage_grid_size=(8, 8),
in0_block_w=1,
out_subblock_h=1,
out_subblock_w=1,
per_core_M=1,
per_core_N=1,
fuse_batch=True,
fused_activation=None,
mcast_in0=True,
),
32: ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(8, 8),
in0_block_w=1,
out_subblock_h=1,
out_subblock_w=1,
per_core_M=1,
@@ -197,8 +208,6 @@ def run_downsample_if_req(
enable_subblock_padding=enable_subblock_padding,
),
}
if is_blackhole():
conv_kwargs["conv_config"].enable_split_reader = False

if not ttnn.is_tensor_storage_on_device(self.ds_conv_weight_tensor):
self.ds_conv_weight_tensor = ttnn.prepare_conv_weights(
@@ -292,8 +301,6 @@ def __call__(
transpose_shards=transpose_shards,
),
}
if is_blackhole():
conv_kwargs_1["conv_config"].enable_split_reader = False

if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor):
self.conv1_weight_tensor = ttnn.prepare_conv_weights(
@@ -415,7 +422,30 @@ def __call__(
if is_blackhole():
conv_kwargs_2["conv_config"].act_block_h_override = 2 * 32
conv_kwargs_2["conv_config"].enable_subblock_padding = False
conv_kwargs_2["conv_config"].enable_split_reader = False
if (
batch_size == 32
and layer_module
and (
layer_module == "layer1_module2"
or layer_module == "layer1_module3"
or layer_module == "layer2_module2"
or layer_module == "layer2_module3"
or layer_module == "layer2_module4"
)
):
conv_kwargs_2["conv_config"].act_block_h_override = 0
elif (
batch_size == 20
and layer_module
and (layer_module == "layer4_module2" or layer_module == "layer4_module3")
):
conv_kwargs_2["conv_config"].act_block_h_override = 0
elif (
batch_size == 16
and layer_module
and (layer_module == "layer1_module2" or layer_module == "layer1_module3")
):
conv_kwargs_2["conv_config"].act_block_h_override = 0
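# Setting act_block_h_override to 0 for these blocks removes the override, so conv2 falls back to its default activation blocking.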

if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor):
self.conv2_weight_tensor = ttnn.prepare_conv_weights(
@@ -449,6 +479,7 @@ def __call__(
return_output_dim=True,
return_weights_and_bias=False,
)

if layer_module and layer_module == "layer4_module1":
if ops_parallel_config and "layer4_module1_input" not in ops_parallel_config:
x_memory_config = ttnn.get_memory_config(out)
@@ -485,8 +516,6 @@ def __call__(
transpose_shards=transpose_shards,
),
}
if is_blackhole():
conv_kwargs_3["conv_config"].enable_split_reader = False

if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor):
self.conv3_weight_tensor = ttnn.prepare_conv_weights(
@@ -598,9 +627,6 @@ def __init__(
self.conv1_output_channels = self.conv1_weight_tensor.shape[0]
assert self.conv1_weight_tensor.shape[2] == 4

self.max_pool_reader_patterns_cache = {}
max_pool_parallel_config_override = {}

self.layer1 = self._make_layer(
parameters=parameters.layer1,
planes=64,
@@ -666,12 +692,18 @@ def __init__(
) # num_classes = 1000

self.transpose_shards = True
act_block_h_override = 1568

if is_wormhole_b0() or is_blackhole():
act_block_h_override = 0

if is_wormhole_b0():
self.transpose_shards = False
act_block_h_override = 1568

if is_blackhole() and self.batch_size < 20:
self.transpose_shards = False
else:
act_block_h_override = 0

if is_blackhole() and self.batch_size == 32:
act_block_h_override = 49 * 32
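# 49 tiles of 32 rows = 1568, matching the activation block height used for Wormhole above.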

input_channels_alignment = 16
self.conv1_config = ttnn.Conv2dConfig(
@@ -699,10 +731,6 @@ def __init__(
self.conv1_config.act_block_h_override = 64
else:
self.conv1_config.act_block_h_override = 49 * 32
if is_blackhole():
# self.conv1_config.act_block_h_override = 7 * 32
# self.conv1_config.act_block_h_override = 2 * 32
self.conv1_config.enable_split_reader = False

self.conv1_kernel_size = (4, 4)
self.conv1_stride = (1, 1)
@@ -738,6 +766,9 @@ def __init__(
if self.batch_size == 16:
num_cores_x = 8
num_cores_y = 8
self.fold_compute_grid_size = ttnn.CoreRangeSet(
{ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))}
)
elif self.batch_size == 20:
if is_grayskull():
num_cores_x = 10
@@ -746,9 +777,19 @@ def __init__(
num_cores_x = 8
num_cores_y = 5
elif is_blackhole():
num_cores_x = 8
num_cores_y = 10
self.fold_compute_grid_size = (num_cores_x, num_cores_y)
num_cores_x = 10
num_cores_y = 8
self.fold_compute_grid_size = ttnn.CoreRangeSet(
{ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1))}
)
elif self.batch_size == 32:
core_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(12, 8)),
ttnn.CoreRange(ttnn.CoreCoord(0, 9), ttnn.CoreCoord(10, 9)),
}
)
self.fold_compute_grid_size = core_grid
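# With inclusive CoreRange bounds, the batch-32 fold grid above is a 13 x 9 block (117 cores) plus an 11-core partial row, 128 cores in total.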

conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16)
conv_dummy_tensor = ttnn.from_torch(conv_dummy_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
@@ -763,7 +804,7 @@ def __init__(
self.conv1_output_width,
device.compute_with_storage_grid_size(),
self.conv1_config.input_channels_alignment,
is_grayskull(),
is_grayskull() or is_blackhole(),
)

def __del__(self):
@@ -871,6 +912,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
)
self.conv1_weight_tensor = ttnn.to_device(self.conv1_weight_tensor, device)
self.conv1_bias_tensor = ttnn.to_device(self.conv1_bias_tensor, device)

x, [x_height, x_width] = ttnn.conv2d(
input_tensor=fold_output_tensor,
weight_tensor=self.conv1_weight_tensor,
@@ -881,6 +923,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
return_output_dim=True,
return_weights_and_bias=False,
)

# Relu is fused with conv1
if self.batch_size == 20:
x = ttnn.reallocate(x)
@@ -902,6 +945,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
x = ttnn.reshape(x, (1, 1, x_height * x_width * self.batch_size, 64))

if is_blackhole():
## 112
core_range_set = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
@@ -927,15 +971,16 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
)
x = ttnn.to_memory_config(x, mem_config)

x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"])

if self.batch_size == 20 and is_grayskull():
x = ttnn.reallocate(x)

if not is_blackhole():
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"])

logger.debug(f"==== Running layer 1 module 1")
layer1_module1_input_shape = ttnn.Shape(x.padded_shape)

reshard = False
reshard = is_blackhole()
height_shard = True

x, x_height, x_width = self.layer1_module1(
@@ -975,6 +1020,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
enable_act_double_buffer=False,
enable_split_reader=True,
enable_subblock_padding=not is_grayskull(),
layer_module="layer1_module2",
)

logger.debug(f"==== Running layer 1 module 3")
@@ -989,14 +1035,15 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
enable_act_double_buffer=False,
enable_split_reader=True,
enable_subblock_padding=not is_grayskull(),
layer_module="layer1_module3",
)

layer2_module1_input_shape = ttnn.Shape(x.padded_shape)

reshard = not (is_wormhole_b0() or is_grayskull())
reshard = is_blackhole() or not (is_wormhole_b0() or is_grayskull())
height_shard = True

if is_blackhole():
if is_blackhole() and self.batch_size < 20:
## 98
core_range_set = ttnn.CoreRangeSet(
{
@@ -1031,8 +1078,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
height_sharding=height_shard,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True,
enable_split_reader=False,
enable_split_reader=True,
enable_subblock_padding=False,
layer_module="layer2_module1",
)

if is_first_run:
Expand All @@ -1055,8 +1103,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True,
enable_split_reader=False,
enable_split_reader=True,
enable_subblock_padding=False,
layer_module="layer2_module2",
)

logger.debug(f"==== Running layer 2 module 3")
@@ -1069,8 +1118,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True,
enable_split_reader=False,
enable_split_reader=True,
enable_subblock_padding=False,
layer_module="layer2_module3",
)

logger.debug(f"==== Running layer 2 module 4")
@@ -1083,8 +1133,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
conv_op_cache,
transpose_shards=self.transpose_shards,
enable_act_double_buffer=True,
enable_split_reader=False,
enable_split_reader=True,
enable_subblock_padding=False,
layer_module="layer2_module4",
)

layer3_module1_input_shape = ttnn.Shape(x.padded_shape)
@@ -1211,11 +1262,11 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
enable_subblock_padding=False,
)

reshard = is_grayskull()
reshard = is_grayskull() or (is_blackhole() and self.batch_size == 20)
height_shard = False

layer4_module1_input_shape = ttnn.Shape(x.padded_shape)
if is_blackhole():
if is_blackhole() and self.batch_size != 20:
# 104
grid_size = (13, 8)
core_range_set = ttnn.CoreRangeSet(
@@ -1275,6 +1326,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
enable_act_double_buffer=True,
enable_split_reader=False,
enable_subblock_padding=False,
layer_module="layer4_module2",
)

logger.debug(f"==== Running layer 4 module 3")
@@ -1289,9 +1341,12 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
enable_act_double_buffer=True,
enable_split_reader=False,
enable_subblock_padding=False,
layer_module="layer4_module3",
)

grid_size = (8, 4)
if self.batch_size > 16:
grid_size = (8, 8)
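# An 8 x 8 (64-core) grid feeds the shard grid below when the batch size exceeds 16.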
shard_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
(next changed file; filename not shown in this view)
@@ -68,7 +68,6 @@ def test_perf_trace(
)


@run_for_wormhole_b0()
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_command_queues": 2}], indirect=True)
@pytest.mark.parametrize(
@@ -95,10 +94,9 @@
)


@run_for_wormhole_b0()
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 32768, "num_command_queues": 2, "trace_region_size": 1332224}], indirect=True
"device_params", [{"l1_small_size": 32768, "num_command_queues": 2, "trace_region_size": 1470464}], indirect=True
)
@pytest.mark.parametrize(
"batch_size, expected_inference_time, expected_compile_time",