-
Notifications
You must be signed in to change notification settings - Fork 116
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#10244: Move some utils/test infra for resnet
- Loading branch information
Showing
8 changed files
with
275 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
259 changes: 259 additions & 0 deletions
259
models/demos/ttnn_resnet/tests/ttnn_resnet_test_infra.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from loguru import logger | ||
import torch | ||
import torchvision | ||
|
||
import ttnn | ||
from ttnn.model_preprocessing import ( | ||
preprocess_model_parameters, | ||
) | ||
|
||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
from models.demos.ttnn_resnet.tt.custom_preprocessing import create_custom_mesh_preprocessor | ||
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50 | ||
|
||
|
||
## copied from ttlib version test: | ||
# golden pcc is ordered fidelity, weight dtype, activation dtype | ||
golden_pcc = { | ||
8: { | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat16, | ||
ttnn.bfloat16, | ||
): 0.983301, # PCC: 0.9833017469734239 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB | ||
# ): 0.990804, # Max ATOL Delta: 1.607335090637207, Max RTOL Delta: 115.62200164794922, PCC: 0.9908042840544742 | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat16, | ||
): 0.986301, # Max ATOL Delta: 1.5697126388549805, Max RTOL Delta: 21.3042049407959, PCC: 0.9863013351442654 | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat16, | ||
ttnn.bfloat8_b, | ||
): 0.973763, # Max ATOL Delta: 2.455164909362793, Max RTOL Delta: inf, PCC: 0.9737631427307492 | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966628 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat16, | ||
ttnn.bfloat16, | ||
): 0.983400, # Max ATOL Delta: 1.7310011386871338, Max RTOL Delta: 369.5689392089844, PCC: 0.9834004200555363 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat16, | ||
): 0.984828, # Max ATOL Delta: 1.6054553985595703, Max RTOL Delta: 59.124324798583984, PCC: 0.9848281996919587 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat16, | ||
ttnn.bfloat8_b, | ||
): 0.934073, # Max ATOL Delta: 4.330164909362793, Max RTOL Delta: inf, PCC: 0.9340735819578696 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635019 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat16, | ||
ttnn.bfloat16, | ||
): 0.938909, # Max ATOL Delta: 3.861414909362793, Max RTOL Delta: 240.63145446777344, PCC: 0.9389092547575272 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat16, | ||
): 0.959609, # Max ATOL Delta: 3.205164909362793, Max RTOL Delta: 141.7057342529297, PCC: 0.9596095155046113 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat16, | ||
ttnn.bfloat8_b, | ||
): 0.854903, # Max ATOL Delta: 7.830164909362793, Max RTOL Delta: inf, PCC: 0.8549035869182201 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433 | ||
}, | ||
16: { | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966632 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.941, # PCC: 0.9414369437627494 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB | ||
# ): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635021 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419435 | ||
}, | ||
20: { | ||
( | ||
ttnn.MathFidelity.HiFi4, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.978099, # Max ATOL Delta: 1.955164909362793, Max RTOL Delta: inf, PCC: 0.9780993165966628 | ||
( | ||
ttnn.MathFidelity.HiFi2, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.941, # PCC: 0.9419975597174123 TODO: NEED DEBUGGING WHY THIS IS SLIGHTLY LOWER THAN TTLIB | ||
# ): 0.944435, # Max ATOL Delta: 4.705164909362793, Max RTOL Delta: inf, PCC: 0.9444350983635021 | ||
( | ||
ttnn.MathFidelity.LoFi, | ||
ttnn.bfloat8_b, | ||
ttnn.bfloat8_b, | ||
): 0.884609, # Max ATOL Delta: 6.455164909362793, Max RTOL Delta: inf, PCC: 0.8846098380419433 | ||
}, | ||
} | ||
|
||
|
||
class ResNet50TestInfra: | ||
def __init__( | ||
self, | ||
device, | ||
batch_size, | ||
act_dtype, | ||
weight_dtype, | ||
math_fidelity, | ||
use_pretrained_weight, | ||
dealloc_input, | ||
final_output_mem_config, | ||
inputs_mesh_mapper=None, | ||
weights_mesh_mapper=None, | ||
output_mesh_composer=None, | ||
): | ||
super().__init__() | ||
torch.manual_seed(0) | ||
self.pcc_passed = False | ||
self.pcc_message = "Did you forget to call validate()?" | ||
self.device = device | ||
self.batch_size = batch_size | ||
self.act_dtype = act_dtype | ||
self.weight_dtype = weight_dtype | ||
self.math_fidelity = math_fidelity | ||
self.dealloc_input = dealloc_input | ||
self.final_output_mem_config = final_output_mem_config | ||
self.inputs_mesh_mapper = inputs_mesh_mapper | ||
self.weights_mesh_mapper = weights_mesh_mapper | ||
self.output_mesh_composer = output_mesh_composer | ||
|
||
torch_model = ( | ||
torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval() | ||
if use_pretrained_weight | ||
else torchvision.models.resnet50().eval() | ||
) | ||
|
||
model_config = { | ||
"MATH_FIDELITY": math_fidelity, | ||
"WEIGHTS_DTYPE": weight_dtype, | ||
"ACTIVATIONS_DTYPE": act_dtype, | ||
} | ||
|
||
num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() | ||
input_shape = (batch_size * num_devices, 3, 224, 224) | ||
|
||
self.torch_input_tensor = torch.rand(input_shape, dtype=torch.float32) | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: torch_model, | ||
custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper), | ||
device=None, | ||
) | ||
|
||
torch_model.to(torch.bfloat16) | ||
self.torch_input_tensor = self.torch_input_tensor.to(torch.bfloat16) | ||
|
||
## golden | ||
|
||
self.torch_output_tensor = torch_model(self.torch_input_tensor) | ||
|
||
## ttnn | ||
|
||
self.ttnn_resnet50_model = resnet50( | ||
device=device, | ||
parameters=parameters, | ||
batch_size=batch_size, | ||
model_config=model_config, | ||
dealloc_input=dealloc_input, | ||
final_output_mem_config=final_output_mem_config, | ||
mesh_mapper=weights_mesh_mapper, | ||
) | ||
self.ops_parallel_config = {} | ||
|
||
def preprocess_torch_input(self, torch_input_tensor=None): | ||
torch_input_tensor = self.torch_input_tensor if torch_input_tensor is None else torch_input_tensor | ||
self.input_tensor = self.ttnn_resnet50_model.preprocessing(torch_input_tensor, self.inputs_mesh_mapper) | ||
|
||
def run(self, torch_input_tensor=None): | ||
# Note: currently not including the time to flip from torch to ttnn tensors. | ||
# self.preprocess_torch_input(torch_input_tensor) | ||
self.output_tensor = self.ttnn_resnet50_model(self.input_tensor, self.device, self.ops_parallel_config) | ||
return self.output_tensor | ||
|
||
def validate(self, output_tensor=None): | ||
output_tensor = self.output_tensor if output_tensor is None else output_tensor | ||
output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer) | ||
output_tensor = torch.reshape(output_tensor, (output_tensor.shape[0], 1000)) | ||
|
||
valid_pcc = 1.0 | ||
if self.batch_size >= 8: | ||
valid_pcc = golden_pcc[self.batch_size][(self.math_fidelity, self.weight_dtype, self.act_dtype)] | ||
else: | ||
if self.act_dtype == ttnn.bfloat8_b: | ||
if self.math_fidelity == ttnn.MathFidelity.LoFi: | ||
valid_pcc = 0.87 | ||
else: | ||
valid_pcc = 0.94 | ||
else: | ||
if self.math_fidelity == ttnn.MathFidelity.LoFi: | ||
valid_pcc = 0.93 | ||
else: | ||
valid_pcc = 0.982 | ||
print(valid_pcc) | ||
self.pcc_passed, self.pcc_message = assert_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc) | ||
|
||
logger.info( | ||
f"ResNet50 batch_size={self.batch_size}, act_dtype={self.act_dtype}, weight_dtype={self.weight_dtype}, math_fidelity={self.math_fidelity}, PCC={self.pcc_message}" | ||
) | ||
|
||
|
||
def create_test_infra( | ||
device, | ||
batch_size, | ||
act_dtype, | ||
weight_dtype, | ||
math_fidelity, | ||
use_pretrained_weight=True, | ||
dealloc_input=True, | ||
final_output_mem_config=ttnn.L1_MEMORY_CONFIG, | ||
inputs_mesh_mapper=None, | ||
weights_mesh_mapper=None, | ||
output_mesh_composer=None, | ||
): | ||
return ResNet50TestInfra( | ||
device, | ||
batch_size, | ||
act_dtype, | ||
weight_dtype, | ||
math_fidelity, | ||
use_pretrained_weight, | ||
dealloc_input, | ||
final_output_mem_config, | ||
inputs_mesh_mapper, | ||
weights_mesh_mapper, | ||
output_mesh_composer, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.