#13404: All Ops in TT-NN & on-device for Lenet model
sabira-mcw committed Dec 10, 2024
1 parent 516ecb1 commit 030e5f8
Showing 5 changed files with 28 additions and 16 deletions.
models/demos/lenet/tests/test_perf_lenet.py (8 changes: 4 additions & 4 deletions)
@@ -24,11 +24,11 @@
 def get_expected_times(tt_lenet):
     if is_grayskull():
         return {
-            tt_lenet: (5.94, 0.63291),
+            tt_lenet: (7.2, 0.05),
         }[tt_lenet]
     elif is_wormhole_b0():
         return {
-            tt_lenet: (8.14, 0.8243),
+            tt_lenet: (10.1557, 0.045),
         }[tt_lenet]


@@ -106,9 +106,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
     num_iterations = 1
     margin = 0.03
     if is_grayskull():
-        expected_perf = 193314.92814121
+        expected_perf = 110955.849
     elif is_wormhole_b0():
-        expected_perf = 113208.6151
+        expected_perf = 60971.775

     command = f"pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
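For context, each tuple returned by get_expected_times is consumed by the perf harness as a pair of per-arch upper bounds. A minimal sketch of that pattern is below; the (compile time, inference time) ordering is an assumption inferred from the magnitudes, and the check_lenet_perf helper name is illustrative, not part of this commit.

from models.utility_functions import is_grayskull, is_wormhole_b0


def get_expected_times(tt_lenet):
    # New bounds from this commit, assumed to be (compile time s, inference time s).
    if is_grayskull():
        return {tt_lenet: (7.2, 0.05)}[tt_lenet]
    elif is_wormhole_b0():
        return {tt_lenet: (10.1557, 0.045)}[tt_lenet]


def check_lenet_perf(tt_lenet, compile_time, inference_time):
    # Fail the perf test if either measured value exceeds its per-arch bound.
    expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)
    assert compile_time < expected_compile_time, f"compile {compile_time}s exceeds {expected_compile_time}s"
    assert inference_time < expected_inference_time, f"inference {inference_time}s exceeds {expected_inference_time}s"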
models/demos/lenet/tt/tt_lenet.py (26 changes: 15 additions & 11 deletions)
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import ttnn
-import torch.nn as nn


 def conv(device, input_tensor, batch_size, parameters):
@@ -48,20 +47,26 @@ def conv(device, input_tensor, batch_size, parameters):
 def lenet(input_tensor, batch_size, device, parameters):
     conv_1, out_height, out_width = conv(device, input_tensor, batch_size, parameters.layer1)
     conv_1 = ttnn.sharded_to_interleaved(conv_1, ttnn.L1_MEMORY_CONFIG)
-    conv_1 = ttnn.reshape(conv_1, (batch_size, out_height, out_width, conv_1.shape[-1]))
-    conv_1 = ttnn.permute(conv_1, (0, 3, 1, 2))
-    conv_1 = ttnn.to_torch(conv_1)
+    conv_1 = ttnn.to_layout(conv_1, layout=ttnn.ROW_MAJOR_LAYOUT)
+    conv_1 = ttnn.pad(conv_1, [(0, 10)], value=0.0)

-    max = nn.MaxPool2d(kernel_size=2, stride=2)
-    maxpool_1 = max(conv_1)
-    maxpool_1 = ttnn.from_torch(
-        maxpool_1, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
+    maxpool_1 = ttnn.max_pool2d(
+        input_tensor=conv_1,
+        batch_size=batch_size,
+        input_h=out_height,
+        input_w=out_width,
+        channels=conv_1.shape[3],
+        kernel_size=[2, 2],
+        stride=[2, 2],
+        padding=[0, 0],
+        dilation=[1, 1],
     )
-    maxpool_1 = ttnn.permute(maxpool_1, (0, 2, 3, 1))

+    maxpool_1 = ttnn.sharded_to_interleaved(maxpool_1, ttnn.L1_MEMORY_CONFIG)
     maxpool_1 = ttnn.reshape(maxpool_1, (batch_size, 14, 14, maxpool_1.shape[3]))
     conv_2, out_height, out_width = conv(device, maxpool_1, batch_size, parameters.layer2)

+    conv_2 = ttnn.to_layout(conv_2, layout=ttnn.ROW_MAJOR_LAYOUT)

     maxpool_2 = ttnn.max_pool2d(
         input_tensor=conv_2,
         batch_size=batch_size,
@@ -73,7 +78,6 @@ def lenet(input_tensor, batch_size, device, parameters):
         padding=[0, 0],
         dilation=[1, 1],
     )
-
     maxpool_2 = ttnn.sharded_to_interleaved(maxpool_2, ttnn.L1_MEMORY_CONFIG)
     maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT)
     maxpool_2 = ttnn.reshape(maxpool_2, (batch_size, 5, 5, maxpool_2.shape[3]))
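The core change in tt_lenet.py is that the first max-pool no longer round-trips through torch (to_torch, nn.MaxPool2d, from_torch) and instead runs on device with ttnn.max_pool2d. A minimal sketch of that on-device pattern follows, reusing the same arguments as the new code; the helper name, the example halved output shape, and the assumption that the input is the row-major, channels-last conv output are illustrative, not part of this commit.

import ttnn


def maxpool2d_on_device(x, batch_size, input_h, input_w):
    # x: ttnn tensor holding the conv output in row-major layout with channels
    # in the last dimension (the same form lenet() passes in above) -- assumed here.
    out = ttnn.max_pool2d(
        input_tensor=x,
        batch_size=batch_size,
        input_h=input_h,
        input_w=input_w,
        channels=x.shape[3],
        kernel_size=[2, 2],
        stride=[2, 2],
        padding=[0, 0],
        dilation=[1, 1],
    )
    # Mirror what lenet() does next: move the pooled result to interleaved L1 and
    # restore an explicit (N, H/2, W/2, C) view before it feeds the next conv.
    out = ttnn.sharded_to_interleaved(out, ttnn.L1_MEMORY_CONFIG)
    return ttnn.reshape(out, (batch_size, input_h // 2, input_w // 2, out.shape[3]))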
tests/scripts/run_performance.sh (4 changes: 4 additions & 0 deletions)
@@ -43,6 +43,8 @@ run_perf_models_other() {

     env pytest -n auto models/demos/squeezebert/tests/test_performance.py -m $test_marker

+    env pytest -n auto models/demos/lenet/tests -m $test_marker
+
     ## Merge all the generated reports
     env python models/perf/merge_perf_results.py
 }
@@ -97,6 +99,8 @@ run_device_perf_models() {

     env pytest models/demos/squeezebert/tests -m $test_marker

+    env pytest models/demos/lenet/tests -m $test_marker
+
     if [ "$tt_arch" == "grayskull" ]; then
         #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with
         #Model Device perf regression tests to make sure thy run on no-soft-reset BMs
tests/scripts/single_card/run_single_card_demo_tests.sh (3 changes: 3 additions & 0 deletions)
@@ -39,6 +39,9 @@ run_common_func_tests() {
     # Mnist
     pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$?

+    # Lenet
+    pytest --disable-warnings models/demos/lenet/demo/demo.py --timeout 600; fail+=$?
+
     # SqueezeBERT
     pytest --disable-warnings models/demos/squeezebert/demo/demo.py --timeout 600; fail+=$?

tests/ttnn/integration_tests/lenet/test_lenet.py (3 changes: 2 additions & 1 deletion)
@@ -7,6 +7,7 @@

 from tests.ttnn.utils_for_testing import assert_with_pcc
 from ttnn.model_preprocessing import preprocess_model_parameters
+from models.utility_functions import is_grayskull
 from models.demos.lenet.tt import tt_lenet
 from models.demos.lenet import lenet_utils

@@ -37,4 +38,4 @@ def test_lenet(device, batch_size, model_location_generator, reset_seeds):

     tt_output = ttnn.to_torch(tt_output)

-    assert_with_pcc(torch_output, tt_output, 0.9993) # 0.9993022969312866
+    assert_with_pcc(torch_output, tt_output, 0.997 if is_grayskull() else 0.9993)
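The relaxed threshold on Grayskull (0.997 vs. 0.9993) refers to the Pearson correlation coefficient (PCC) between the torch golden output and the TT-NN output. A minimal sketch of what such a check computes is below, assuming the usual flattened-tensor PCC definition; the pcc helper is illustrative, while the test itself uses assert_with_pcc from tests.ttnn.utils_for_testing.

import torch


def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient over the flattened outputs;
    # 1.0 means the device output tracks the golden output perfectly.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()


# Usage mirroring the test's per-arch tolerance:
# assert pcc(torch_output, tt_output) >= (0.997 if is_grayskull() else 0.9993)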
