#13398: Data-parallel support for MNIST model
sabira-mcw committed Oct 8, 2024
1 parent 34f1b62 commit b22b3de
Showing 8 changed files with 394 additions and 0 deletions.
15 changes: 15 additions & 0 deletions models/demos/wormhole/mnist/README.md
@@ -0,0 +1,15 @@
## Introduction
The MNIST model uses only fully connected (linear) layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model processes each 28x28-pixel image by flattening it into a 784-element vector and passing it through a stack of linear layers to predict the corresponding digit (0-9). This demonstrates that even a simple architecture can handle image classification.
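
For reference, the snippet below is a minimal PyTorch sketch of that layer layout; it mirrors `models/demos/wormhole/mnist/reference/mnist.py` in this change, and the class name is only illustrative.

```
import torch


class MnistMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 120)  # 28*28 flattened pixels -> 120
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)    # 10 digit classes

    def forward(self, x):
        x = x.view(x.shape[0], -1)  # flatten each 28x28 image into a 784-element vector
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        return torch.nn.functional.softmax(x, dim=-1)
```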

## How to Run

To run the demo for digit classification using the MNIST model, follow these instructions:

- Use the following command to run the MNIST model.
```
pytest models/demos/wormhole/mnist/demo/demo.py::test_demo_dataset
```

## Additional Information

The input tensor for the reshape op resides on the host; it is moved back to the device after flattening.
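
The host-side flatten looks roughly like the following (same ttnn calls as `models/demos/wormhole/mnist/tt/tt_mnist.py`; the tensor and device names are illustrative):

```
x = ttnn.from_device(x)                # bring the input back to the host
x = ttnn.reshape(x, (x.shape[0], -1))  # flatten 28x28 -> 784 on the host
x = ttnn.to_device(x, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
```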
102 changes: 102 additions & 0 deletions models/demos/wormhole/mnist/demo/demo.py
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn

from torchvision import transforms, datasets
from loguru import logger

from torch.utils.data import DataLoader
from models.demos.wormhole.mnist.reference.mnist import MnistModel
from models.demos.wormhole.mnist.tt import tt_mnist

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import is_wormhole_b0, skip_for_grayskull


def run_demo_dataset(
batch_size,
iterations,
model_location_generator,
):
# Data preprocessing/loading
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

mesh_device = None
inputs_mesh_mapper = None
weights_mesh_mapper = None
output_mesh_composer = None
parameters = None
mesh_device_flag = False

state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
model = MnistModel(state_dict)
model = model.eval()
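# Data-parallel path: on a 2-device Wormhole mesh, shard the input batch across the devices
# and replicate the model weights on each device; otherwise fall back to a single device.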
if is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2:
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2))
inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)
mesh_device_flag = True
with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
device=mesh_device,
)
else:
if is_wormhole_b0():
mesh_device = ttnn.open_device(device_id=0)
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
device=mesh_device,
)

correct = 0
for iters in range(iterations):
dataloader = DataLoader(test_dataset, batch_size=batch_size)
x, labels = next(iter(dataloader))
dataset_predictions = []
ttnn_predictions = []
dataset_ttnn_correct = 0
x = ttnn.from_torch(x, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper, device=mesh_device)
tt_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters)
tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer)
predicted_probabilities = torch.nn.functional.softmax(tt_output, dim=1)
_, predicted_label = torch.max(predicted_probabilities, 1)
for i in range(batch_size):
dataset_predictions.append(labels[i])
ttnn_predictions.append(predicted_label[i])
logger.info(f"Iter: {iters} Sample {i}:")
logger.info(f"Expected Label: {dataset_predictions[i]}")
logger.info(f"Predicted Label: {ttnn_predictions[i]}")

if dataset_predictions[i] == ttnn_predictions[i]:
dataset_ttnn_correct += 1
correct += 1
dataset_ttnn_accuracy = dataset_ttnn_correct / (batch_size)
logger.info(
f"ImageNet Inference Accuracy for iter {iters} of {batch_size} input samples : {dataset_ttnn_accuracy}"
)

accuracy = correct / (batch_size * iterations)
logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")


@skip_for_grayskull()
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("iterations", [1])
def test_demo_dataset(
batch_size,
iterations,
model_location_generator,
):
return run_demo_dataset(
batch_size=batch_size,
iterations=iterations,
model_location_generator=model_location_generator,
)
30 changes: 30 additions & 0 deletions models/demos/wormhole/mnist/reference/mnist.py
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch


class MnistModel(torch.nn.Module):
def __init__(self, state_dict):
super().__init__()

self.fc1 = torch.nn.Linear(784, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)

self.load_state_dict(state_dict)

def forward(self, x):
x = x.view(x.shape[0], -1)

x = self.fc1(x)
x = torch.nn.functional.relu(x)

x = self.fc2(x)
x = torch.nn.functional.relu(x)

x = self.fc3(x)
x = torch.nn.functional.relu(x)

return torch.nn.functional.softmax(x, dim=-1)
151 changes: 151 additions & 0 deletions models/demos/wormhole/mnist/tests/test_perf_mnist.py
@@ -0,0 +1,151 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import time

import pytest
import torch
import ttnn
from loguru import logger
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from models.demos.wormhole.mnist.reference.mnist import MnistModel
from models.demos.wormhole.mnist.tt import tt_mnist
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.perf.perf_utils import prep_perf_report
from models.utility_functions import (
    disable_persistent_kernel_cache,
    enable_persistent_kernel_cache,
    is_grayskull,
    is_wormhole_b0,
    skip_for_grayskull,
)
from ttnn.model_preprocessing import preprocess_model_parameters


@skip_for_grayskull()
def get_expected_times(tt_mnist):
if is_grayskull():
return {
tt_mnist: (2.3, 0.0041),
}[tt_mnist]
elif is_wormhole_b0():
return {
tt_mnist: (6.43, 0.0073),
}[tt_mnist]


@pytest.mark.models_performance_bare_metal
@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize(
"batch_size",
[4],
)
@pytest.mark.parametrize(
"tt_mnist",
[tt_mnist],
)
def test_performance_mnist(batch_size, tt_mnist, model_location_generator):
mesh_device = None
inputs_mesh_mapper = None
weights_mesh_mapper = None
output_mesh_composer = None
parameters = None
mesh_device_flag = False

state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
model = MnistModel(state_dict)
model = model.eval()
disable_persistent_kernel_cache()
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
dataloader = DataLoader(test_dataset, batch_size=batch_size)
x, labels = next(iter(dataloader))

if is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2:
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 2))
inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)
mesh_device_flag = True
with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
device=mesh_device,
)
else:
if is_wormhole_b0():
mesh_device = ttnn.open_device(device_id=0)
parameters = preprocess_model_parameters(
initialize_model=lambda: model,
device=mesh_device,
)

x = ttnn.from_torch(x, dtype=ttnn.bfloat16, mesh_mapper=inputs_mesh_mapper, device=mesh_device)
durations = []

for _ in range(2):
start = time.time()

ttnn_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters)
end = time.time()
durations.append(end - start)
# enable_persistent_kernel_cache()

inference_and_compile_time, *inference_times = durations
average_inference_time = sum(inference_times) / len(inference_times)
expected_compile_time, expected_inference_time = get_expected_times(tt_mnist)

prep_perf_report(
model_name="MNIST",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=average_inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

if mesh_device_flag:
ttnn.close_mesh_device(mesh_device)
else:
ttnn.close_device(mesh_device)

logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
logger.info(f"Inference time: {average_inference_time}")
logger.info(f"Inference times: {inference_times}")
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")


@skip_for_grayskull()
@pytest.mark.parametrize(
"batch_size, expected_perf",
[
[8, 28670.34],
],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal(batch_size, expected_perf):
subdir = "ttnn_mnist"
num_iterations = 1
margin = 0.03

command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"tt_mnist{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
)
29 changes: 29 additions & 0 deletions models/demos/wormhole/mnist/tt/tt_mnist.py
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


def mnist(mesh_device, batch_size, x, parameters):
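# The flatten/reshape runs on the host (see the README note): pull the input off the device,
# reshape it, then move it back to device L1 memory in tile layout.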
x = ttnn.from_device(x)
x = ttnn.reshape(x, (x.shape[0], -1))

x = ttnn.to_device(x, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)

x = ttnn.linear(
x, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu"
)

x = ttnn.linear(
x, parameters.fc2.weight, bias=parameters.fc2.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu"
)

x = ttnn.linear(
x, parameters.fc3.weight, bias=parameters.fc3.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu"
)

x = ttnn.softmax(x)
return x
4 changes: 4 additions & 0 deletions tests/scripts/run_performance.sh
@@ -17,6 +17,8 @@ run_perf_models_other() {

if [ "$tt_arch" == "wormhole_b0" ]; then
env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py -m $test_marker

env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests/test_perf_mnist.py -m $test_marker
fi

env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker
@@ -106,6 +108,8 @@ run_device_perf_models() {
env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker

env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker

env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mnist/tests/test_perf_mnist.py::test_performance_mnist -m $test_marker
fi

## Merge all the generated reports
4 changes: 4 additions & 0 deletions tests/scripts/single_card/run_single_card_demo_tests.sh
@@ -28,6 +28,8 @@ run_common_func_tests() {
# ConvNet Mnist
pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$?

# MNIST
pytest --disable-warnings models/demos/wormhole/mnist/demo/demo.py --timeout 600; fail+=$?
return $fail
}

@@ -65,6 +67,8 @@ run_n300_func_tests() {

run_common_func_tests; fail+=$?

WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/demos/wormhole/mnist/demo/demo.py; fail+=$?

if [[ $fail -ne 0 ]]; then
exit 1
fi