diff --git a/models/demos/wormhole/mnist/README.md b/models/demos/wormhole/mnist/README.md new file mode 100644 index 00000000000..6a296871b77 --- /dev/null +++ b/models/demos/wormhole/mnist/README.md @@ -0,0 +1,15 @@ +## INTRODUCTION +The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks. + +## How to Run + +To run the demo for digit classification using the MNIST model, follow these instructions: + +- Use the following command to run the MNIST model. + ``` + pytest models/demos/wormhole/mnist/demo/demo.py::test_demo_dataset + ``` + +# Additional Information + +The input tensor for reshape op is in the host. diff --git a/models/demos/wormhole/mnist/demo/demo.py b/models/demos/wormhole/mnist/demo/demo.py new file mode 100644 index 00000000000..134089cd9ec --- /dev/null +++ b/models/demos/wormhole/mnist/demo/demo.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn + +from torchvision import transforms, datasets +from loguru import logger + +from torch.utils.data import DataLoader +from models.demos.wormhole.mnist.reference.mnist import MnistModel +from models.demos.wormhole.mnist.tt import tt_mnist + +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import is_wormhole_b0, skip_for_grayskull + + +def run_demo_dataset( + batch_size, + iterations, + model_location_generator, +): + # Data preprocessing/loading + transform = transforms.Compose([transforms.ToTensor()]) + test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) + + mesh_device = None + inputs_mesh_mapper = None + weights_mesh_mapper = None + output_mesh_composer = None + parameters = None + mesh_device_flag = False + + state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) + model = MnistModel(state_dict) + model = model.eval() + if is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2: + mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2)) + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + mesh_device_flag = True + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + else: + if is_wormhole_b0(): + mesh_device = ttnn.open_device(device_id=0) + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + + correct = 0 + for iters in range(iterations): + dataloader = DataLoader(test_dataset, batch_size=batch_size) + x, labels = next(iter(dataloader)) + dataset_predictions = [] + ttnn_predictions = [] + dataset_ttnn_correct = 0 + x = ttnn.from_torch(x, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper, device=mesh_device) + tt_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + predicted_probabilities = torch.nn.functional.softmax(tt_output, dim=1) + _, predicted_label = torch.max(predicted_probabilities, 1) + tt_output = tt_output + for i in range(batch_size): + dataset_predictions.append(labels[i]) + ttnn_predictions.append(predicted_label[i]) + logger.info(f"Iter: {iters} Sample {i}:") + logger.info(f"Expected Label: {dataset_predictions[i]}") + logger.info(f"Predicted Label: {ttnn_predictions[i]}") + + if dataset_predictions[i] == ttnn_predictions[i]: + dataset_ttnn_correct += 1 + correct += 1 + dataset_ttnn_accuracy = dataset_ttnn_correct / (batch_size) + logger.info( + f"ImageNet Inference Accuracy for iter {iters} of {batch_size} input samples : {dataset_ttnn_accuracy}" + ) + + accuracy = correct / (batch_size * iterations) + logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}") + + +@skip_for_grayskull() +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("iterations", [1]) +def test_demo_dataset( + batch_size, + iterations, + model_location_generator, +): + return run_demo_dataset( + batch_size=batch_size, + iterations=iterations, + model_location_generator=model_location_generator, + ) diff --git a/models/demos/wormhole/mnist/reference/mnist.py b/models/demos/wormhole/mnist/reference/mnist.py new file mode 100644 index 00000000000..b9cfb7365f4 --- /dev/null +++ b/models/demos/wormhole/mnist/reference/mnist.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class MnistModel(torch.nn.Module): + def __init__(self, state_dict): + super().__init__() + + self.fc1 = torch.nn.Linear(784, 120) + self.fc2 = torch.nn.Linear(120, 84) + self.fc3 = torch.nn.Linear(84, 10) + + self.load_state_dict(state_dict) + + def forward(self, x): + x = x.view(x.shape[0], -1) + + x = self.fc1(x) + x = torch.nn.functional.relu(x) + + x = self.fc2(x) + x = torch.nn.functional.relu(x) + + x = self.fc3(x) + x = torch.nn.functional.relu(x) + + return torch.nn.functional.softmax(x) diff --git a/models/demos/wormhole/mnist/tests/test_perf_mnist.py b/models/demos/wormhole/mnist/tests/test_perf_mnist.py new file mode 100644 index 00000000000..e438d3e1440 --- /dev/null +++ b/models/demos/wormhole/mnist/tests/test_perf_mnist.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import time +import pytest +import torch +from loguru import logger +from torchvision import transforms, datasets +from models.perf.perf_utils import prep_perf_report +from models.demos.wormhole.mnist.tt import tt_mnist +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.mnist.reference.mnist import MnistModel +from models.utility_functions import is_grayskull, is_wormhole_b0 +from torch.utils.data import DataLoader +from torchvision import transforms, datasets +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.utility_functions import is_grayskull, is_wormhole_b0, skip_for_grayskull +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + +transform = transforms.Compose([transforms.ToTensor()]) +test_dataset = datasets.MNIST(root="./data", train=False, transform=None, download=True) + + +@skip_for_grayskull() +def get_expected_times(tt_mnist): + if is_grayskull(): + return { + tt_mnist: (2.3, 0.0041), + }[tt_mnist] + elif is_wormhole_b0(): + return { + tt_mnist: (6.43, 0.0073), + }[tt_mnist] + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize( + "batch_size", + [4], +) +@pytest.mark.parametrize( + "tt_mnist", + [tt_mnist], +) +def test_performance_mnist(batch_size, tt_mnist, model_location_generator): + mesh_device = None + inputs_mesh_mapper = None + weights_mesh_mapper = None + output_mesh_composer = None + parameters = None + mesh_device_flag = False + + state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) + model = MnistModel(state_dict) + model = model.eval() + disable_persistent_kernel_cache() + transform = transforms.Compose([transforms.ToTensor()]) + test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) + dataloader = DataLoader(test_dataset, batch_size=batch_size) + x, labels = next(iter(dataloader)) + + if is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2: + mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2)) + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + mesh_device_flag = True + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + else: + if is_wormhole_b0(): + mesh_device = ttnn.open_device(device_id=0) + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + + x = ttnn.from_torch(x, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper, device=mesh_device) + durations = [] + + for _ in range(2): + start = time.time() + + ttnn_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters) + end = time.time() + durations.append(end - start) + # enable_persistent_kernel_cache() + + inference_and_compile_time, *inference_times = durations + average_inference_time = sum(inference_times) / len(inference_times) + expected_compile_time, expected_inference_time = get_expected_times(tt_mnist) + + prep_perf_report( + model_name="MNIST", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=average_inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + if mesh_device_flag: + ttnn.close_mesh_device(mesh_device) + else: + ttnn.close_device(mesh_device) + + logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") + logger.info(f"Inference time: {average_inference_time}") + logger.info(f"Inference times: {inference_times}") + logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}") + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "batch_size, expected_perf", + [ + [8, 28670.34], + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_perf_device_bare_metal(batch_size, expected_perf): + subdir = "ttnn_mnist" + num_iterations = 1 + margin = 0.03 + + command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"tt_mnist{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) diff --git a/models/demos/wormhole/mnist/tt/tt_mnist.py b/models/demos/wormhole/mnist/tt/tt_mnist.py new file mode 100644 index 00000000000..3bcd9fece60 --- /dev/null +++ b/models/demos/wormhole/mnist/tt/tt_mnist.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch + + +def mnist(mesh_device, batch_size, x, parameters): + x = ttnn.from_device(x) + x = ttnn.reshape(x, (x.shape[0], -1)) + + x = ttnn.to_device(x, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG) + x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) + + x = ttnn.linear( + x, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu" + ) + + x = ttnn.linear( + x, parameters.fc2.weight, bias=parameters.fc2.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu" + ) + + x = ttnn.linear( + x, parameters.fc3.weight, bias=parameters.fc3.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu" + ) + + x = ttnn.softmax(x) + return x diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 8c15b65ecf9..bca4e300476 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,6 +17,8 @@ run_perf_models_other() { if [ "$tt_arch" == "wormhole_b0" ]; then env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests/test_perf_mnist.py -m $test_marker fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -106,6 +108,8 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yam pytets models/demos/wormhole/mnist/tests/test_perf_mnist.py::test_performance_mnist -m $test_marker fi ## Merge all the generated reports diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 957d161a627..84e46144ffd 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -28,6 +28,8 @@ run_common_func_tests() { # ConvNet Mnist pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$? + #MNIST + pytest --disable-warnings models/demos/wormhole/mnist/demo/demo.py --timeout 600; fail+=$? return $fail } @@ -65,6 +67,8 @@ run_n300_func_tests() { run_common_func_tests; fail+=$? + WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/experimental/functional_mnist/demo/demo.py; fail+=$? + if [[ $fail -ne 0 ]]; then exit 1 fi diff --git a/tests/ttnn/integration_tests/mnist/test_mnist.py b/tests/ttnn/integration_tests/mnist/test_mnist.py new file mode 100644 index 00000000000..8abd8ba1e3c --- /dev/null +++ b/tests/ttnn/integration_tests/mnist/test_mnist.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 +import torch +import ttnn +import pytest +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.mnist.reference.mnist import MnistModel +from models.utility_functions import is_wormhole_b0, skip_for_grayskull +from models.demos.wormhole.mnist.tt import tt_mnist +from torch.utils.data import DataLoader +from torchvision import transforms, datasets + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "batch_size", + [4], +) +def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): + inputs_mesh_mapper = None + weights_mesh_mapper = None + output_mesh_composer = None + parameters = None + mesh_device_flag = False + + state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) + model = MnistModel(state_dict) + model = model.eval() + transform = transforms.Compose([transforms.ToTensor()]) + test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) + dataloader = DataLoader(test_dataset, batch_size=batch_size) + x, labels = next(iter(dataloader)) + torch_output = model(x) + if is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2: + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + mesh_device_flag = True + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + else: + if is_wormhole_b0(): + mesh_device = ttnn.open_device(device_id=0) + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + ) + x = ttnn.from_torch(x, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper, device=mesh_device) + tt_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + # if mesh_device_flag: + # ttnn.close_mesh_device(mesh_device) + # else: + # ttnn.close_device(mesh_device) + assert_with_pcc(torch_output, tt_output, 0.99)