diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py
index a38aa60a770c..c323baf28930 100644
--- a/models/demos/convnet_mnist/tt/convnet_mnist.py
+++ b/models/demos/convnet_mnist/tt/convnet_mnist.py
@@ -21,7 +21,6 @@ def convnet_mnist(
         weights_dtype=ttnn.bfloat16,
         math_fidelity=ttnn.MathFidelity.LoFi,
         activation="",
-        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
         math_approx_mode_enabled=True,
         fp32_dest_acc_enabled=False,
         packer_l1_accum_enabled=False,
diff --git a/models/demos/wormhole/convnet_mnist/README.md b/models/demos/wormhole/convnet_mnist/README.md
new file mode 100644
index 000000000000..faa26e49d5a6
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/README.md
@@ -0,0 +1,24 @@
+# Introduction
+
+ConvNet MNIST implements a convolutional neural network to classify handwritten digits from the MNIST dataset. The MNIST dataset contains grayscale images of handwritten digits (0-9), each resized to 32x32 pixels for this model.
+
+# Platforms:
+    WH N300
+
+## How to Run
+
+To run the demo for digit classification using the MNIST model, follow these instructions:
+
+- Use the following command to run the MNIST model.
+
+```
+pytest models/demos/wormhole/convnet_mnist/demo/demo.py
+```
+
+Maxpool and Softmax are currently executed in torch inside the model (see the issues below).
+ISSUES:
+    #12664 - [softmax](https://github.com/tenstorrent/tt-metal/issues/12664)
+    #12642 - [maxpool](https://github.com/tenstorrent/tt-metal/issues/12642)
+
+
+### Owner: [vigneshkumarkeerthivasan](https://github.com/vigneshkeerthivasanx)
diff --git a/models/demos/wormhole/convnet_mnist/convnet_mnist_preprocessing.py b/models/demos/wormhole/convnet_mnist/convnet_mnist_preprocessing.py
new file mode 100644
index 000000000000..99681afebe95
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/convnet_mnist_preprocessing.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+
+
+def custom_preprocessor(parameters, device):
+    parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device)
+    parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device)
+
+    parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device)
+    parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device)
+    parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device)
+    parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device)
+
+    return parameters
diff --git a/models/demos/wormhole/convnet_mnist/convnet_mnist_utils.py b/models/demos/wormhole/convnet_mnist/convnet_mnist_utils.py
new file mode 100644
index 000000000000..74755f817b3c
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/convnet_mnist_utils.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+
+def get_test_data(batch_size=64):
+    transform = transforms.Compose(
+        [
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=(0.05,), std=(0.05,)),
+        ]
+    )
+
+    test_dataset = torchvision.datasets.MNIST(
+        root="./data",
+        train=False,
+        download=True,
+    )
+
+    batch = []
+    images = []
+    outputs = []
+
+    for i in range(batch_size):
+        img, output = test_dataset[i]
+        tensor = transform(img).unsqueeze(0)
+        batch.append(tensor)
+        images.append(img)
+        outputs.append(output)
+
+    batch = torch.cat(batch)
+    return batch, images, outputs
diff --git a/models/demos/wormhole/convnet_mnist/demo/demo.py b/models/demos/wormhole/convnet_mnist/demo/demo.py
new file mode 100644
index 000000000000..e62a0cd87d86
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/demo/demo.py
@@ -0,0 +1,87 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import ttnn
+import pytest
+
+from pathlib import Path
+from loguru import logger
+
+from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
+    convnet_mnist,
+    custom_preprocessor,
+)
+from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
+from models.demos.wormhole.convnet_mnist.convnet_mnist_utils import get_test_data
+from models.experimental.convnet_mnist.reference.convnet import ConvNet
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.utility_functions import is_wormhole_b0, skip_for_grayskull
+
+
+def model_location_generator(rel_path):
+    internal_weka_path = Path("/mnt/MLPerf")
+    has_internal_weka = (internal_weka_path / "bit_error_tests").exists()
+
+    if has_internal_weka:
+        return Path("/mnt/MLPerf") / rel_path
+    else:
+        return Path("/opt/tt-metal-models") / rel_path
+
+
+@skip_for_grayskull()
+@pytest.mark.parametrize(
+    "batch_size",
+    ((16),),
+)
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_convnet_mnist(mesh_device, batch_size, reset_seeds):
+    model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/")
+    state_dict = str(model_path / "convnet_mnist.pt")
+    state_dict = torch.load(state_dict)
+
+    test_input, images, output = get_test_data(batch_size)
+
+    model = ConvNet()
+    model.load_state_dict(state_dict)
+    model.eval()
+    torch_output = model(test_input)
+    batch_size = len(test_input)
+
+    inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
+    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)
+
+    with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
+        parameters = preprocess_model_parameters(
+            initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor
+        )
+        parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)
+
+    ttnn_input = torch.permute(test_input, (0, 2, 3, 1))
+    ttnn_input = ttnn.from_torch(
+        ttnn_input, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=inputs_mesh_mapper
+    )
+
+    ttnn_output = convnet_mnist(
+        input_tensor=ttnn_input,
+        device=mesh_device,
+        parameters=parameters,
+        mesh_mapper=inputs_mesh_mapper,
+        mesh_composer=output_mesh_composer,
+    )
+
+    ttnn_output = ttnn.to_torch(ttnn_output, mesh_composer=output_mesh_composer)
+
+    _, torch_predicted = torch.max(torch_output.data, -1)
+    _, ttnn_predicted = torch.max(ttnn_output.data, -1)
+
+    correct = 0
+    for i in range(batch_size):
+        if output[i] == ttnn_predicted[i]:
+            correct += 1
+    accuracy = correct / (batch_size)
+
+    logger.info(f" Accuracy for {batch_size} Samples : {accuracy}")
+    logger.info(f"torch_predicted {torch_predicted.squeeze()}")
+    logger.info(f"ttnn_predicted {ttnn_predicted.squeeze()}")
diff --git a/models/demos/wormhole/convnet_mnist/tests/test_performance.py b/models/demos/wormhole/convnet_mnist/tests/test_performance.py
new file mode 100644
index 000000000000..5faefd098be5
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/tests/test_performance.py
@@ -0,0 +1,146 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+import ttnn
+import time
+from pathlib import Path
+
+from loguru import logger
+import ttnn
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.utility_functions import (
+    enable_persistent_kernel_cache,
+    disable_persistent_kernel_cache,
+)
+from models.perf.perf_utils import prep_perf_report
+from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
+from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
+    convnet_mnist,
+    custom_preprocessor,
+)
+from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
+from models.experimental.convnet_mnist.reference.convnet import ConvNet
+from models.utility_functions import is_wormhole_b0, skip_for_grayskull
+
+
+def get_expected_times(convnet_mnist):
+    return (15.0, 9.2)
+
+
+def model_location_generator(rel_path):
+    internal_weka_path = Path("/mnt/MLPerf")
+    has_internal_weka = (internal_weka_path / "bit_error_tests").exists()
+
+    if has_internal_weka:
+        return Path("/mnt/MLPerf") / rel_path
+    else:
+        return Path("/opt/tt-metal-models") / rel_path
+
+
+@skip_for_grayskull()
+@pytest.mark.models_performance_bare_metal
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        (2, 1, 32, 32),
+    ],
+)
+def test_convnet_mnist(
+    mesh_device,
+    input_shape,
+    reset_seeds,
+):
+    disable_persistent_kernel_cache()
+
+    model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/")
+    state_dict = str(model_path / "convnet_mnist.pt")
+    state_dict = torch.load(state_dict)
+
+    input_tensor = torch.randn(input_shape, dtype=torch.bfloat16)
+    batch_size = input_tensor.shape[0]
+    input_tensor = torch.permute(input_tensor, (0, 2, 3, 1))
+
+    model = ConvNet()
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
+    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)
+
+    with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
+        parameters = preprocess_model_parameters(
+            initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor
+        )
+        parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)
+
+    durations = []
+    for i in range(2):
+        start = time.time()
+        ttnn_input = ttnn.from_torch(
+            input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=inputs_mesh_mapper
+        )
+
+        ttnn_output = convnet_mnist(
+            input_tensor=ttnn_input,
+            device=mesh_device,
+            parameters=parameters,
+            mesh_mapper=inputs_mesh_mapper,
+            mesh_composer=output_mesh_composer,
+        )
+        output = ttnn.from_device(ttnn_output)
+        output = ttnn.to_torch(output, mesh_composer=output_mesh_composer)
+        end = time.time()
+        durations.append(end - start)
+        enable_persistent_kernel_cache()
+
+    inference_and_compile_time, inference_time, *_ = durations
+
+    expected_compile_time, expected_inference_time = get_expected_times("convnet_mnist")
+    prep_perf_report(
+        model_name="convnet_mnist",
+        batch_size=batch_size,
+        inference_and_compile_time=inference_and_compile_time,
+        inference_time=inference_time,
+        expected_compile_time=expected_compile_time,
+        expected_inference_time=expected_inference_time,
+        comments="",
+        inference_time_cpu=0.0,
+    )
+
+    logger.info(f"Compile time: {inference_and_compile_time - inference_time}")
+    logger.info(f"Inference time: {inference_time}")
+    logger.info(f"Samples per second: {1 / inference_time * batch_size}")
+
+
+@skip_for_grayskull()
+@pytest.mark.parametrize(
+    "batch_size, expected_perf",
+    [
+        [1, 2885],
+    ],
+)
+@pytest.mark.models_device_performance_bare_metal
+def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
+    subdir = "ttnn_convnet_mnist"
+    num_iterations = 1
+    margin = 0.03
+
+    command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist_wh.py"
+    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
+
+    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
+    expected_perf_cols = {inference_time_key: expected_perf}
+
+    post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
+    expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
+    prep_device_perf_report(
+        model_name=f"ttnn_convnet_mnist_wh_{batch_size}",
+        batch_size=batch_size,
+        post_processed_results=post_processed_results,
+        expected_results=expected_results,
+        comments="",
+    )
diff --git a/models/demos/wormhole/convnet_mnist/tt/convnet_mnist.py b/models/demos/wormhole/convnet_mnist/tt/convnet_mnist.py
new file mode 100644
index 000000000000..8f9738b7dcb1
--- /dev/null
+++ b/models/demos/wormhole/convnet_mnist/tt/convnet_mnist.py
@@ -0,0 +1,169 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import ttnn
+import torch.nn.functional as F
+from torch import nn
+
+
+def convnet_mnist(
+    input_tensor,
+    parameters,
+    device,
+    mesh_mapper,
+    mesh_composer,
+):
+    batch_size = input_tensor.shape[0]
+    torch_maxpool = True
+
+    conv_config = ttnn.Conv2dConfig(
+        dtype=ttnn.bfloat16,
+        weights_dtype=ttnn.bfloat16,
+        math_fidelity=ttnn.MathFidelity.LoFi,
+        activation="",
+        math_approx_mode_enabled=True,
+        fp32_dest_acc_enabled=False,
+        packer_l1_accum_enabled=False,
+        input_channels_alignment=32,
+        transpose_shards=False,
+        reshard_if_not_optimal=True,
+        deallocate_activation=True,
+        reallocate_halo_output=True,
+    )
+
+    [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+        input_tensor=input_tensor,
+        weight_tensor=parameters.conv1.weight,
+        in_channels=1,
+        out_channels=32,
+        device=device,
+        bias_tensor=parameters.conv1.bias,
+        kernel_size=(3, 3),
+        stride=(1, 1),
+        padding=(0, 0),
+        batch_size=batch_size,
+        input_height=input_tensor.shape[1],
+        input_width=input_tensor.shape[2],
+        conv_config=conv_config,
+        conv_op_cache={},
+        debug=True,
+        groups=1,
+    )
+
+    x = ttnn.relu(x)
+
+    if torch_maxpool:  # Can be removed once issue #12642 is resolved
+        x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
+        x = ttnn.reshape(x, (batch_size, 30, 30, 32))
+        x = ttnn.to_torch(x, mesh_composer=mesh_composer)
+        x = torch.permute(x, (0, 3, 1, 2))
+        x = F.max_pool2d(x, 2)
+        x = torch.permute(x, (0, 2, 3, 1))
+        x = ttnn.from_torch(
+            x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=mesh_mapper
+        )
+
+    else:
+        x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG)
+        x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
+        x = ttnn.max_pool2d(
+            input_tensor=x,
+            batch_size=batch_size,
+            input_h=30,
+            input_w=30,
+            channels=32,
+            kernel_size=[2, 2],
+            stride=[1, 1],
+            padding=[0, 0],
+            dilation=[1, 1],
+            device=device,
+        )
+
+    [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+        input_tensor=x,
+        weight_tensor=parameters.conv2.weight,
+        in_channels=32,
+        out_channels=64,
+        device=device,
+        bias_tensor=parameters.conv2.bias,
+        kernel_size=(3, 3),
+        stride=(1, 1),
+        padding=(0, 0),
+        batch_size=batch_size,
+        input_height=15,
+        input_width=15,
+        conv_config=conv_config,
+        conv_op_cache={},
+        debug=False,
+        groups=1,
+    )
+
+    x = ttnn.relu(x)
+
+    if torch_maxpool:  # Can be removed once issue #12642 is resolved
+        x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
+        x = ttnn.reshape(x, (batch_size, 13, 13, 64))
+        x = ttnn.to_torch(x, mesh_composer=mesh_composer)
+        x = torch.permute(x, (0, 3, 1, 2))
+        x = F.max_pool2d(x, 2)
+        x = ttnn.from_torch(
+            x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=mesh_mapper
+        )
+
+    else:
+        x = ttnn.sharded_to_interleaved(x, ttnn.DRAM_MEMORY_CONFIG)
+        x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
+        x = ttnn.max_pool2d(
+            input_tensor=x,
+            batch_size=batch_size,
+            input_h=out_height,
+            input_w=out_width,
+            channels=x.shape[-1],
+            kernel_size=[2, 2],
+            stride=[1, 1],
+            padding=[0, 0],
+            dilation=[1, 1],
+            device=device,
+        )
+    x = ttnn.from_device(x)
+    x = ttnn.reshape(x, (x.shape[0], -1))
+    x = ttnn.to_device(x, device)
+    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
+    x = ttnn.linear(
+        x,
+        parameters.fc1.weight,
+        bias=parameters.fc1.bias,
+        activation="relu",
+    )
+
+    x = ttnn.linear(
+        x,
+        parameters.fc2.weight,
+        bias=parameters.fc2.bias,
+    )
+
+    output = torch.softmax(ttnn.to_torch(x, mesh_composer=mesh_composer), dim=-1)
+    output = ttnn.from_torch(output, device=device, dtype=ttnn.bfloat16, mesh_mapper=mesh_mapper)
+    return output
+
+
+def preprocess_conv_parameter(parameter, *, dtype):
+    parameter = ttnn.from_torch(parameter, dtype=dtype)
+    return parameter
+
+
+def custom_preprocessor(model, device):
+    parameters = {}
+    if isinstance(model, nn.Conv2d):
+        weight = model.weight
+        bias = model.bias
+        while weight.dim() < 4:
+            weight = weight.unsqueeze(0)
+        while bias.dim() < 4:
+            bias = bias.unsqueeze(0)
+        parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16)
+        parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.bfloat16)
+
+    return parameters
diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh
index c251fa4ccb3b..df331b41b521 100755
--- a/tests/scripts/run_performance.sh
+++ b/tests/scripts/run_performance.sh
@@ -33,6 +33,8 @@ run_perf_models_other() {
 
     env pytest -n auto models/demos/convnet_mnist/tests -m $test_marker
 
+    env pytest -n auto models/demos/wormhole/convnet_mnist/tests -m $test_marker
+
     ## Merge all the generated reports
     env python models/perf/merge_perf_results.py
 }
@@ -95,6 +97,8 @@ run_device_perf_models() {
 
     env pytest models/demos/convnet_mnist/tests/ -m $test_marker
 
+    env pytest models/demos/wormhole/convnet_mnist/tests -m $test_marker
+
    if [ "$tt_arch" == "grayskull" ]; then
        #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with
        #Model Device perf regression tests to make sure thy run on no-soft-reset BMs
diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh
index 1091798b6f99..9fe7402b81e6 100755
--- a/tests/scripts/single_card/run_single_card_demo_tests.sh
+++ b/tests/scripts/single_card/run_single_card_demo_tests.sh
@@ -42,6 +42,8 @@ run_common_func_tests() {
 
     # ConvNet Mnist
     pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$?
 
+    pytest --disable-warnings models/demos/wormhole/convnet_mnist/demo/demo.py --timeout 600; fail+=$?
+
     return $fail
 }
diff --git a/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist_wh.py b/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist_wh.py
new file mode 100644
index 000000000000..e1f61741997b
--- /dev/null
+++ b/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist_wh.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import ttnn
+import pytest
+
+from pathlib import Path
+
+from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
+    convnet_mnist,
+    custom_preprocessor,
+)
+from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
+from models.demos.wormhole.convnet_mnist.convnet_mnist_utils import get_test_data
+from models.experimental.convnet_mnist.reference.convnet import ConvNet
+from ttnn.model_preprocessing import preprocess_model_parameters
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import is_wormhole_b0, skip_for_grayskull
+
+
+def model_location_generator(rel_path):
+    internal_weka_path = Path("/mnt/MLPerf")
+    has_internal_weka = (internal_weka_path / "bit_error_tests").exists()
+
+    if has_internal_weka:
+        return Path("/mnt/MLPerf") / rel_path
+    else:
+        return Path("/opt/tt-metal-models") / rel_path
+
+
+@skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+def test_convnet_mnist(mesh_device, reset_seeds):
+    model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/")
+    state_dict = str(model_path / "convnet_mnist.pt")
+    state_dict = torch.load(state_dict)
+
+    test_input, images, output = get_test_data(16)
+
+    model = ConvNet()
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    torch_output = model(test_input)
+
+    inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
+    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)
+
+    with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
+        parameters = preprocess_model_parameters(
+            initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor
+        )
+        parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)
+
+    ttnn_input = torch.permute(test_input, (0, 2, 3, 1))
+    ttnn_input = ttnn.from_torch(
+        ttnn_input, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=inputs_mesh_mapper
+    )
+
+    ttnn_output = convnet_mnist(
+        input_tensor=ttnn_input,
+        device=mesh_device,
+        parameters=parameters,
+        mesh_mapper=inputs_mesh_mapper,
+        mesh_composer=output_mesh_composer,
+    )
+    ttnn_output = ttnn.to_torch(ttnn_output, mesh_composer=output_mesh_composer)
+
+    assert_with_pcc(torch_output, ttnn_output, 0.99)
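
Note on the data-parallel flow used throughout this PR: the demo and tests shard the input batch across the devices of the WH N300 mesh with `ttnn.ShardTensorToMesh`, replicate the weights with `ttnn.ReplicateTensorToMesh` inside `ttnn.distribute`, and gather the per-device outputs with `ttnn.ConcatMeshToTensor`. The sketch below is only an illustration of that flow, not part of the patch: it uses calls that already appear in the diff, assumes `mesh_device` comes from the standard ttnn pytest fixture, and assumes `parameters` was built the same way as in the tests above.

```python
# Minimal sketch (assumptions noted above): run the ConvNet MNIST model
# data-parallel across a mesh device; NCHW torch input in, torch output out.
import torch
import ttnn

from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import convnet_mnist


def run_data_parallel(mesh_device, parameters, torch_input):
    # Split the batch (dim 0) across the devices in the mesh ...
    inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
    # ... and concatenate the per-device outputs back into one host tensor.
    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

    # The model expects NHWC, row-major, bfloat16 input (see ttnn.conv2d above).
    ttnn_input = ttnn.from_torch(
        torch.permute(torch_input, (0, 2, 3, 1)),
        dtype=ttnn.bfloat16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        mesh_mapper=inputs_mesh_mapper,
    )

    ttnn_output = convnet_mnist(
        input_tensor=ttnn_input,
        parameters=parameters,
        device=mesh_device,
        mesh_mapper=inputs_mesh_mapper,
        mesh_composer=output_mesh_composer,
    )
    # Single torch tensor holding the full batch of class probabilities.
    return ttnn.to_torch(ttnn_output, mesh_composer=output_mesh_composer)
```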