Skip to content

Commit

Permalink
#0: resolving PR issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mbahnasTT committed Feb 19, 2025
1 parent f34a56d commit 7847c35
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 94 deletions.
4 changes: 2 additions & 2 deletions models/demos/yolov4/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

- Use the following command to run the yolov4 with a giraffe image:
```
pytest models/demos/yolov4/demo/test_ttnn_yolov4.py
pytest models/demos/yolov4/demo/demo.py
```

- Use the following command to run the yolov4 with different input image:
```
pytest --disable-warnings --input-path=<PATH_TO_INPUT_IMAGE> models/demos/yolov4/demo/test_ttnn_yolov4.py
pytest --disable-warnings --input-path=<PATH_TO_INPUT_IMAGE> models/demos/yolov4/demo/demo.py
```

Once you run the command, The output file named `ttnn_prediction_demo.jpg` will be generated.
110 changes: 110 additions & 0 deletions models/demos/yolov4/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,113 @@ def do_detect(model, img, conf_thresh, nms_thresh, n_classes, device=None, class
class_names = load_class_names(class_name)
img = cv2.imread(imgfile)
plot_boxes_cv2(img, boxes[0], "torch_prediction_demo.jpg", class_names)


def gen_yolov4_boxes_confs(output):
n_classes = 80

yolo1 = YoloLayer(
anchor_mask=[0, 1, 2],
num_classes=n_classes,
anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
num_anchors=9,
stride=8,
)

yolo2 = YoloLayer(
anchor_mask=[3, 4, 5],
num_classes=n_classes,
anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
num_anchors=9,
stride=16,
)

yolo3 = YoloLayer(
anchor_mask=[6, 7, 8],
num_classes=n_classes,
anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
num_anchors=9,
stride=32,
)

y1 = yolo1(output[0])
y2 = yolo2(output[1])
y3 = yolo3(output[2])

return y1, y2, y3


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_yolov4(device, reset_seeds, model_location_generator):
torch.manual_seed(0)
model_path = model_location_generator("models", model_subdir="Yolo")

if model_path == "models":
if not os.path.exists("tests/ttnn/integration_tests/yolov4/yolov4.pth"): # check if yolov4.th is availble
os.system(
"tests/ttnn/integration_tests/yolov4/yolov4_weights_download.sh"
) # execute the yolov4_weights_download.sh file

weights_pth = "tests/ttnn/integration_tests/yolov4/yolov4.pth"
else:
weights_pth = str(model_path / "yolov4.pth")

ttnn_model = TtYOLOv4(weights_pth, device)

imgfile = "models/demos/yolov4/demo/giraffe_320.jpg"
width = 320
height = 320
img = cv2.imread(imgfile)
img = cv2.resize(img, (width, height))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if type(img) == np.ndarray and len(img.shape) == 3: # cv2 image
img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
elif type(img) == np.ndarray and len(img.shape) == 4:
img = torch.from_numpy(img.transpose(0, 3, 1, 2)).float().div(255.0)
torch_input = torch.autograd.Variable(img)

input_tensor = torch.permute(torch_input, (0, 2, 3, 1))
ttnn_input = ttnn.from_torch(input_tensor, ttnn.bfloat16)

torch_model = Yolov4()

new_state_dict = {}
ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items()}

keys = [name for name, parameter in torch_model.state_dict().items()]
values = [parameter for name, parameter in ds_state_dict.items()]

for i in range(len(keys)):
new_state_dict[keys[i]] = values[i]

torch_model.load_state_dict(new_state_dict)
torch_model.eval()

torch_output_tensor = torch_model(torch_input)

ref1, ref2, ref3 = gen_yolov4_boxes_confs(torch_output_tensor)
ref_boxes, ref_confs = get_region_boxes([ref1, ref2, ref3])

ttnn_output_tensor = ttnn_model(ttnn_input)
result_boxes_padded = ttnn.to_torch(ttnn_output_tensor[0])
result_confs = ttnn.to_torch(ttnn_output_tensor[1])

result_boxes_padded = result_boxes_padded.permute(0, 2, 1, 3)
result_boxes_list = []
# Unpadding
result_boxes_list.append(result_boxes_padded[:, 0:6100])
result_boxes_list.append(result_boxes_padded[:, 6128:6228])
result_boxes_list.append(result_boxes_padded[:, 6256:6356])
result_boxes = torch.cat(result_boxes_list, dim=1)

## Giraffe image detection
conf_thresh = 0.3
nms_thresh = 0.4
output = [result_boxes.to(torch.float16), result_confs.to(torch.float16)]

boxes = post_processing(img, conf_thresh, nms_thresh, output)
namesfile = "models/demos/yolov4/demo/coco.names"
class_names = load_class_names(namesfile)
img = cv2.imread(imgfile)
plot_boxes_cv2(img, boxes[0], "ttnn_yolov4_320_prediction_demo.jpg", class_names)
4 changes: 2 additions & 2 deletions models/demos/yolov4/tests/test_perf_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_yolov4(

logger.info(f"Compiling model with warmup run")
profiler.start(f"inference_and_compile_time")
ttnn_output_tensor = ttnn_model(device, ttnn_input)
ttnn_output_tensor = ttnn_model(ttnn_input)

profiler.end(f"inference_and_compile_time")

Expand All @@ -80,7 +80,7 @@ def test_yolov4(
for idx in range(iterations):
profiler.start("inference_time")
profiler.start(f"inference_time_{idx}")
ttnn_output_tensor = ttnn_model(device, ttnn_input)
ttnn_output_tensor = ttnn_model(ttnn_input)

profiler.end(f"inference_time_{idx}")
profiler.end("inference_time")
Expand Down
7 changes: 4 additions & 3 deletions models/demos/yolov4/ttnn/yolov4.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, path, device) -> None:
self.boxes_confs_2 = TtGenBoxes(device)

self.downs = [] # [self.down1]
self.device = device

def __call__(self, input_tensor):
d1 = self.down1(input_tensor)
Expand All @@ -61,9 +62,9 @@ def __call__(self, input_tensor):
if orig:
return x4, x5, x6
else:
x4_boxes_confs = self.boxes_confs_0(device, x4)
x5_boxes_confs = self.boxes_confs_1(device, x5)
x6_boxes_confs = self.boxes_confs_2(device, x6)
x4_boxes_confs = self.boxes_confs_0(self.device, x4)
x5_boxes_confs = self.boxes_confs_1(self.device, x5)
x6_boxes_confs = self.boxes_confs_2(self.device, x6)

confs_1 = ttnn.to_layout(x4_boxes_confs[1], ttnn.ROW_MAJOR_LAYOUT)
confs_2 = ttnn.to_layout(x5_boxes_confs[1], ttnn.ROW_MAJOR_LAYOUT)
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/run_python_model_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ run_python_model_tests_wormhole_b0() {
# higher sequence lengths and different formats trigger memory issues
pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"
pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py -k "pretrained_weight_false"
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py -k "pretrained_weight_false"
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/demo/demo.py -k "pretrained_weight_false"

# Unet Shallow
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -svv models/experimental/functional_unet/tests/test_unet_model.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ def test_down3(device, reset_seeds, model_location_generator):
ref = torch_model(torch_input)
ref = ref.permute(0, 2, 3, 1)
result = result.reshape(ref.shape)
assert_with_pcc(result, ref, 0.95) # PCC 0.95 - The PCC will improve once #3612 is resolved.
assert_with_pcc(result, ref, 0.96) # PCC 0.96 - The PCC will improve once #3612 is resolved.
2 changes: 1 addition & 1 deletion tests/ttnn/integration_tests/yolov4/test_ttnn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ttnn
from models.demos.yolov4.reference.head import Head
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0
from models.utility_functions import skip_for_grayskull
import pytest
import time
from models.demos.yolov4.ttnn.head import TtHead
Expand Down
1 change: 1 addition & 0 deletions tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_neck(device, reset_seeds, model_location_generator):
torch.manual_seed(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,6 @@ def test_yolov4_post_processing(device, reset_seeds, model_location_generator):
result_2_conf = ttnn.to_torch(result_2[1])
result_3_conf = ttnn.to_torch(result_3[1])

# result_1_conf = result_1_conf.permute(0, 1, 3, 2)
# result_2_conf = result_2_conf.permute(0, 1, 3, 2)
# result_3_conf= result_3_conf.permute(0, 1, 3, 2)

# result_1_conf = result_1_conf.reshape(1, 4800, 80)
# result_2_conf = result_2_conf.reshape(1, 1200, 80)
# result_3_conf= result_3_conf.reshape(1, 300, 80)

assert_with_pcc(ref1[0], result_1_bb, 0.99)
assert_with_pcc(ref2[0], result_2_bb, 0.99)
assert_with_pcc(ref3[0], result_3_bb, 0.99)
Expand Down
6 changes: 4 additions & 2 deletions tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ttnn
from models.demos.yolov4.reference.yolov4 import Yolov4
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0
from models.utility_functions import skip_for_grayskull
from models.demos.yolov4.ttnn.yolov4 import TtYOLOv4
from models.demos.yolov4.demo.demo import YoloLayer, get_region_boxes, post_processing, plot_boxes_cv2, load_class_names
import cv2
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_yolov4(device, reset_seeds, model_location_generator):
ref1, ref2, ref3 = gen_yolov4_boxes_confs(torch_output_tensor)
ref_boxes, ref_confs = get_region_boxes([ref1, ref2, ref3])

ttnn_output_tensor = ttnn_model(device, ttnn_input)
ttnn_output_tensor = ttnn_model(ttnn_input)
result_boxes_padded = ttnn.to_torch(ttnn_output_tensor[0])
result_confs = ttnn.to_torch(ttnn_output_tensor[1])

Expand All @@ -117,6 +117,7 @@ def test_yolov4(device, reset_seeds, model_location_generator):
assert_with_pcc(ref_boxes, result_boxes, 0.99)
assert_with_pcc(ref_confs, result_confs, 0.71)

"""
## Giraffe image detection
conf_thresh = 0.3
nms_thresh = 0.4
Expand All @@ -127,3 +128,4 @@ def test_yolov4(device, reset_seeds, model_location_generator):
class_names = load_class_names(namesfile)
img = cv2.imread(imgfile)
plot_boxes_cv2(img, boxes[0], "ttnn_prediction_demo.jpg", class_names)
"""
74 changes: 0 additions & 74 deletions tests/ttnn/integration_tests/yolov4/test_ttnn_yolov4_orig.py

This file was deleted.

0 comments on commit 7847c35

Please sign in to comment.