Skip to content

Commit

Permalink
#10244: Move some utils/test infra for resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jul 25, 2024
1 parent 33fe2d6 commit 8b8160d
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 267 deletions.
2 changes: 1 addition & 1 deletion models/demos/resnet/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from models.demos.resnet.tt.metalResnetBlock50 import ResNet, Bottleneck

resnet_model_config = {
"MATH_FIDELITY": tt_lib.tensor.MathFidelity.HiFi2,
"MATH_FIDELITY": tt_lib.tensor.MathFidelity.LoFi,
"WEIGHTS_DTYPE": tt_lib.tensor.DataType.BFLOAT8_B,
"ACTIVATIONS_DTYPE": tt_lib.tensor.DataType.BFLOAT8_B,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
divup,
run_for_wormhole_b0,
)
from tests.ttnn.integration_tests.resnet.test_ttnn_functional_resnet50_new import create_test_infra
from models.demos.ttnn_resnet.tests.ttnn_resnet_test_infra import create_test_infra

try:
from tracy import signpost
Expand Down
4 changes: 2 additions & 2 deletions models/demos/ttnn_resnet/tests/test_perf_ttnn_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
setup_l1_sharded_input,
setup_dram_sharded_input,
)
from models.demos.ttnn_resnet.tt.custom_preprocessing import custom_preprocessor
from models.demos.ttnn_resnet.tt.custom_preprocessing import create_custom_mesh_preprocessor
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50

try:
Expand Down Expand Up @@ -308,7 +308,7 @@ def run_perf_resnet(
torch_resnet50.eval()

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_resnet50, custom_preprocessor=custom_preprocessor, device=None
initialize_model=lambda: torch_resnet50, custom_preprocessor=create_custom_mesh_preprocessor(None), device=None
)
torch_resnet50.to(torch.bfloat16)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
divup,
skip_for_grayskull,
)
from tests.ttnn.integration_tests.resnet.test_ttnn_functional_resnet50_new import create_test_infra
from models.demos.ttnn_resnet.tests.ttnn_resnet_test_infra import create_test_infra

try:
from tracy import signpost
Expand Down
259 changes: 259 additions & 0 deletions models/demos/ttnn_resnet/tests/ttnn_resnet_test_infra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import torch
import torchvision

import ttnn
from ttnn.model_preprocessing import (
preprocess_model_parameters,
)

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.demos.ttnn_resnet.tt.custom_preprocessing import create_custom_mesh_preprocessor
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50


## copied from ttlib version test:
# golden pcc is ordered fidelity, weight dtype, activation dtype
golden_pcc = {
8: {
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat16,
ttnn.bfloat16,
): 0.983301, # PCC: 0.9833017469734239 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB
# ): 0.990804, # Max ATOL Delta: 1.607335090637207, Max RTOL Delta: 115.62200164794922, PCC: 0.9908042840544742
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat8_b,
ttnn.bfloat16,
): 0.986301, # Max ATOL Delta: 1.5697126388549805, Max RTOL Delta: 21.3042049407959, PCC: 0.9863013351442654
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat16,
ttnn.bfloat8_b,
): 0.973763, # Max ATOL Delta: 2.455164909362793, Max RTOL Delta: inf, PCC: 0.9737631427307492
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966628
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat16,
ttnn.bfloat16,
): 0.983400, # Max ATOL Delta: 1.7310011386871338, Max RTOL Delta: 369.5689392089844, PCC: 0.9834004200555363
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat8_b,
ttnn.bfloat16,
): 0.984828, # Max ATOL Delta: 1.6054553985595703, Max RTOL Delta: 59.124324798583984, PCC: 0.9848281996919587
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat16,
ttnn.bfloat8_b,
): 0.934073, # Max ATOL Delta: 4.330164909362793, Max RTOL Delta: inf, PCC: 0.9340735819578696
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635019
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat16,
ttnn.bfloat16,
): 0.938909, # Max ATOL Delta: 3.861414909362793, Max RTOL Delta: 240.63145446777344, PCC: 0.9389092547575272
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat8_b,
ttnn.bfloat16,
): 0.959609, # Max ATOL Delta: 3.205164909362793, Max RTOL Delta: 141.7057342529297, PCC: 0.9596095155046113
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat16,
ttnn.bfloat8_b,
): 0.854903, # Max ATOL Delta: 7.830164909362793, Max RTOL Delta: inf, PCC: 0.8549035869182201
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433
},
16: {
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966632
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.941, # PCC: 0.9414369437627494 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB
# ): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635021
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419435
},
20: {
(
ttnn.MathFidelity.HiFi4,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966628
(
ttnn.MathFidelity.HiFi2,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.941, # PCC: 0.9419975597174123 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB
# ): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635021
(
ttnn.MathFidelity.LoFi,
ttnn.bfloat8_b,
ttnn.bfloat8_b,
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433
},
}


class ResNet50TestInfra:
def __init__(
self,
device,
batch_size,
act_dtype,
weight_dtype,
math_fidelity,
use_pretrained_weight,
dealloc_input,
final_output_mem_config,
inputs_mesh_mapper=None,
weights_mesh_mapper=None,
output_mesh_composer=None,
):
super().__init__()
torch.manual_seed(0)
self.pcc_passed = False
self.pcc_message = "Did you forget to call validate()?"
self.device = device
self.batch_size = batch_size
self.act_dtype = act_dtype
self.weight_dtype = weight_dtype
self.math_fidelity = math_fidelity
self.dealloc_input = dealloc_input
self.final_output_mem_config = final_output_mem_config
self.inputs_mesh_mapper = inputs_mesh_mapper
self.weights_mesh_mapper = weights_mesh_mapper
self.output_mesh_composer = output_mesh_composer

torch_model = (
torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()
if use_pretrained_weight
else torchvision.models.resnet50().eval()
)

model_config = {
"MATH_FIDELITY": math_fidelity,
"WEIGHTS_DTYPE": weight_dtype,
"ACTIVATIONS_DTYPE": act_dtype,
}

num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
input_shape = (batch_size * num_devices, 3, 224, 224)

self.torch_input_tensor = torch.rand(input_shape, dtype=torch.float32)

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model,
custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
device=None,
)

torch_model.to(torch.bfloat16)
self.torch_input_tensor = self.torch_input_tensor.to(torch.bfloat16)

## golden

self.torch_output_tensor = torch_model(self.torch_input_tensor)

## ttnn

self.ttnn_resnet50_model = resnet50(
device=device,
parameters=parameters,
batch_size=batch_size,
model_config=model_config,
dealloc_input=dealloc_input,
final_output_mem_config=final_output_mem_config,
mesh_mapper=weights_mesh_mapper,
)
self.ops_parallel_config = {}

def preprocess_torch_input(self, torch_input_tensor=None):
torch_input_tensor = self.torch_input_tensor if torch_input_tensor is None else torch_input_tensor
self.input_tensor = self.ttnn_resnet50_model.preprocessing(torch_input_tensor, self.inputs_mesh_mapper)

def run(self, torch_input_tensor=None):
# Note: currently not including the time to flip from torch to ttnn tensors.
# self.preprocess_torch_input(torch_input_tensor)
self.output_tensor = self.ttnn_resnet50_model(self.input_tensor, self.device, self.ops_parallel_config)
return self.output_tensor

def validate(self, output_tensor=None):
output_tensor = self.output_tensor if output_tensor is None else output_tensor
output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
output_tensor = torch.reshape(output_tensor, (output_tensor.shape[0], 1000))

valid_pcc = 1.0
if self.batch_size >= 8:
valid_pcc = golden_pcc[self.batch_size][(self.math_fidelity, self.weight_dtype, self.act_dtype)]
else:
if self.act_dtype == ttnn.bfloat8_b:
if self.math_fidelity == ttnn.MathFidelity.LoFi:
valid_pcc = 0.87
else:
valid_pcc = 0.94
else:
if self.math_fidelity == ttnn.MathFidelity.LoFi:
valid_pcc = 0.93
else:
valid_pcc = 0.982
print(valid_pcc)
self.pcc_passed, self.pcc_message = assert_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc)

logger.info(
f"ResNet50 batch_size={self.batch_size}, act_dtype={self.act_dtype}, weight_dtype={self.weight_dtype}, math_fidelity={self.math_fidelity}, PCC={self.pcc_message}"
)


def create_test_infra(
device,
batch_size,
act_dtype,
weight_dtype,
math_fidelity,
use_pretrained_weight=True,
dealloc_input=True,
final_output_mem_config=ttnn.L1_MEMORY_CONFIG,
inputs_mesh_mapper=None,
weights_mesh_mapper=None,
output_mesh_composer=None,
):
return ResNet50TestInfra(
device,
batch_size,
act_dtype,
weight_dtype,
math_fidelity,
use_pretrained_weight,
dealloc_input,
final_output_mem_config,
inputs_mesh_mapper,
weights_mesh_mapper,
output_mesh_composer,
)
9 changes: 9 additions & 0 deletions models/demos/ttnn_resnet/tt/custom_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ def custom_preprocessor(
ttnn_module_args=ttnn_module_args,
)
return parameters


def create_custom_mesh_preprocessor(mesh_mapper=None):
def custom_mesh_preprocessor(model, name, ttnn_module_args, convert_to_ttnn):
return custom_preprocessor(
model, name, ttnn_module_args, convert_to_ttnn, custom_mesh_preprocessor, mesh_mapper
)

return custom_mesh_preprocessor
2 changes: 1 addition & 1 deletion tests/ttnn/integration_tests/resnet/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import pytest

from tests.ttnn.integration_tests.resnet.test_ttnn_functional_resnet50 import create_test_infra
from models.demos.ttnn_resnet.tests.ttnn_resnet_test_infra import create_test_infra
from models.utility_functions import skip_for_wormhole_b0, skip_for_grayskull
import ttnn
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
Expand Down
Loading

0 comments on commit 8b8160d

Please sign in to comment.