diff --git a/models/experimental/yolov11/README.md b/models/experimental/yolov11/README.md
new file mode 100644
index 00000000000..2875db7b45d
--- /dev/null
+++ b/models/experimental/yolov11/README.md
@@ -0,0 +1,32 @@
+## YOLOv11n - Model
+
+### Introduction
+
+**YOLOv11** is the latest iteration in the YOLO series, offering improvements in accuracy, speed, and efficiency for real-time object detection. It features an enhanced architecture and optimized training methods, making it suitable for a range of computer vision tasks.
+
+### Model Details
+
+* **Entry Point:** `models/experimental/yolov11/tt/ttnn_yolov11.py`
+* **Weights:** `models/experimental/yolov11/reference/yolov11n.pt`
+
+### Batch Size
+
+* Default: 1
+* Recommended: 1 for optimal performance
+
+### Running YOLOv11 Demo
+
+* **Single Image (640x640x3 or 224x224x3):** `pytest models/experimental/yolov11/demo/demo.py`
+* **Dataset Evaluation:** `pytest models/experimental/yolov11/demo/evaluate.py`
+  * Validation accuracy: 0.5616 on 250 images (coco-2017)
+
+### Input and Output Data
+
+* **Input Directory:** `models/experimental/yolov11/demo/images`
+* **Output Directory:** `models/experimental/yolov11/demo/runs`
+  * Torch model output: `torch_model`
+  * TTNN model output: `tt_model`
+
+### Pending Issues
+
+* [#17835](https://github.com/tenstorrent/tt-metal/issues/17835) - Tracing fails in Yolov11n model
diff --git a/models/experimental/yolov11/demo/demo.py b/models/experimental/yolov11/demo/demo.py
new file mode 100644
index 00000000000..eac2092d02b
--- /dev/null
+++ b/models/experimental/yolov11/demo/demo.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+import os
+import cv2
+import sys
+import ttnn
+import torch
+import pytest
+import torch.nn as nn
+from loguru import logger
+from datetime import datetime
+from models.utility_functions import disable_persistent_kernel_cache
+from models.experimental.yolov11.reference import yolov11
+from models.experimental.yolov11.reference.yolov11 import attempt_load
+from models.experimental.yolov11.tt import ttnn_yolov11
+from models.experimental.yolov11.tt.model_preprocessing import (
+    create_yolov11_input_tensors,
+    create_yolov11_model_parameters,
+)
+from models.experimental.yolov11.demo.demo_utils import LoadImages, preprocess, postprocess
+
+try:
+    sys.modules["ultralytics"] = yolov11
+    sys.modules["ultralytics.nn.tasks"] = yolov11
+    sys.modules["ultralytics.nn.modules.conv"] = yolov11
+    sys.modules["ultralytics.nn.modules.block"] = yolov11
+    sys.modules["ultralytics.nn.modules.head"] = yolov11
+except KeyError:
+    print("models.experimental.yolov11.reference.yolov11 not found.")
+
+
+def save_yolo_predictions_by_model(result, save_dir, image_path, model_name):
+    model_save_dir = os.path.join(save_dir, model_name)
+    os.makedirs(model_save_dir, exist_ok=True)
+
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    if model_name == "torch_model":
+        bounding_box_color, label_color = (0, 255, 0), (0, 255, 0)
+    else:
+        bounding_box_color, label_color = (255, 0, 0), (255, 0, 0)
+
+    boxes = result["boxes"]["xyxy"]
+    scores = result["boxes"]["conf"]
+    classes = result["boxes"]["cls"]
+    names = result["names"]
+
+    for box, score, cls in zip(boxes, scores, classes):
+        x1, y1, x2, y2 = map(int, box)
+        label = f"{names[int(cls)]} {score.item():.2f}"
+        cv2.rectangle(image, (x1, y1), (x2, y2), bounding_box_color, 3)
+        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
0.5, label_color, 2) + + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_name = f"prediction_{timestamp}.jpg" + output_path = os.path.join(model_save_dir, output_name) + cv2.imwrite(output_path, image) + print(f"Predictions saved to {output_path}") + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize( + "source, model_type,resolution", + [ + # 224*224 + # ("models/experimental/yolov11/demo/images/cycle_girl.jpg", "torch_model", [3, 224, 224]), + # ("models/experimental/yolov11/demo/images/cycle_girl.jpg", "tt_model", [3, 224, 224]), + # ("models/experimental/yolov11/demo/images/dog.jpg", "torch_model", [3, 224, 224]), + # ("models/experimental/yolov11/demo/images/dog.jpg", "tt_model", [3, 224, 224]), + # 640*640 + # ("models/experimental/yolov11/demo/images/cycle_girl.jpg", "torch_model", [3, 640, 640]), + ("models/experimental/yolov11/demo/images/cycle_girl.jpg", "tt_model", [3, 640, 640]), + # ("models/experimental/yolov11/demo/images/dog.jpg", "torch_model", [3, 640, 640]), + # ("models/experimental/yolov11/demo/images/dog.jpg", "tt_model", [3, 640, 640]), + ], +) +def test_demo(device, source, model_type, resolution): + disable_persistent_kernel_cache() + state_dict = attempt_load("yolov11n.pt", map_location="cpu").state_dict() + model = yolov11.YoloV11() + ds_state_dict = {k: v for k, v in state_dict.items()} + new_state_dict = {} + for (name1, parameter1), (name2, parameter2) in zip(model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + model.load_state_dict(new_state_dict) + if model_type == "torch_model": + model.eval() + logger.info("Inferencing using Torch Model") + else: + torch_input, ttnn_input = create_yolov11_input_tensors( + device, input_channels=resolution[0], input_height=resolution[1], input_width=resolution[2] + ) + parameters = create_yolov11_model_parameters(model, torch_input, device=device) + model = ttnn_yolov11.YoloV11(device, parameters) + logger.info("Inferencing using ttnn Model") + + save_dir = "models/experimental/yolov11/demo/runs" + dataset = LoadImages(path=source) + model_save_dir = os.path.join(save_dir, model_type) + os.makedirs(model_save_dir, exist_ok=True) + + names = { + 0: "person", + 1: "bicycle", + 2: "car", + 3: "motorcycle", + 4: "airplane", + 5: "bus", + 6: "train", + 7: "truck", + 8: "boat", + 9: "traffic light", + 10: "fire hydrant", + 11: "stop sign", + 12: "parking meter", + 13: "bench", + 14: "bird", + 15: "cat", + 16: "dog", + 17: "horse", + 18: "sheep", + 19: "cow", + 20: "elephant", + 21: "bear", + 22: "zebra", + 23: "giraffe", + 24: "backpack", + 25: "umbrella", + 26: "handbag", + 27: "tie", + 28: "suitcase", + 29: "frisbee", + 30: "skis", + 31: "snowboard", + 32: "sports ball", + 33: "kite", + 34: "baseball bat", + 35: "baseball glove", + 36: "skateboard", + 37: "surfboard", + 38: "tennis racket", + 39: "bottle", + 40: "wine glass", + 41: "cup", + 42: "fork", + 43: "knife", + 44: "spoon", + 45: "bowl", + 46: "banana", + 47: "apple", + 48: "sandwich", + 49: "orange", + 50: "broccoli", + 51: "carrot", + 52: "hot dog", + 53: "pizza", + 54: "donut", + 55: "cake", + 56: "chair", + 57: "couch", + 58: "potted plant", + 59: "bed", + 60: "dining table", + 61: "toilet", + 62: "TV", + 63: "laptop", + 64: "mouse", + 65: "remote", + 66: "keyboard", + 67: "cell phone", + 68: "microwave", + 69: "oven", + 70: "toaster", + 71: "sink", + 72: 
"refrigerator", + 73: "book", + 74: "clock", + 75: "vase", + 76: "scissors", + 77: "teddy bear", + 78: "hair drier", + 79: "toothbrush", + } + + for batch in dataset: + paths, im0s, s = batch + im = preprocess(im0s, resolution) + if model_type == "torch_model": + preds = model(im) + else: + img = torch.permute(im, (0, 2, 3, 1)) + img = img.reshape( + 1, + 1, + img.shape[0] * img.shape[1] * img.shape[2], + img.shape[3], + ) + ttnn_im = ttnn.from_torch(img, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) + preds = model(x=ttnn_im) + preds = ttnn.to_torch(preds, dtype=torch.float32) + results = postprocess(preds, im, im0s, batch, names)[0] + save_yolo_predictions_by_model(results, save_dir, source, model_type) + print("Inference done") diff --git a/models/experimental/yolov11/demo/demo_utils.py b/models/experimental/yolov11/demo/demo_utils.py new file mode 100644 index 00000000000..d603baae08f --- /dev/null +++ b/models/experimental/yolov11/demo/demo_utils.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os +import cv2 +import math +import time +import torch +import torchvision +import numpy as np +from pathlib import Path +from loguru import logger + + +def imread(filename: str, flags: int = cv2.IMREAD_COLOR): + return cv2.imdecode(np.fromfile(filename, np.uint8), flags) + + +IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"} + + +class LoadImages: + def __init__(self, path, batch=1, vid_stride=1): + files = [] + for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: + a = str(Path(p).absolute()) + if os.path.isfile(a): + files.append(a) + else: + raise FileNotFoundError(f"{p} does not exist") + + images = [] + for f in files: + suffix = f.split(".")[-1].lower() + if suffix in IMG_FORMATS: + images.append(f) + ni = len(images) + + self.files = images + self.nf = ni + self.ni = ni + self.mode = "image" + self.vid_stride = vid_stride + self.bs = batch + if self.nf == 0: + raise FileNotFoundError(f"No images or videos found in {p}") + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + paths, imgs, info = [], [], [] + while len(imgs) < self.bs: + if self.count >= self.nf: + if imgs: + return paths, imgs, info + else: + raise StopIteration + + path = self.files[self.count] + + self.mode = "image" + im0 = imread(path) + if im0 is None: + logger.warning(f"WARNING ⚠️ Image Read Error {path}") + else: + paths.append(path) + imgs.append(im0) + info.append(f"image {self.count + 1}/{self.nf} {path}: ") + self.count += 1 + if self.count >= self.ni: + break + + return paths, imgs, info + + def _new_video(self, path): + self.frame = 0 + self.cap = cv2.VideoCapture(path) + self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + if not self.cap.isOpened(): + raise FileNotFoundError(f"Failed to open video {path}") + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) + + def __len__(self): + return math.ceil(self.nf / self.bs) + + +def LetterBox(img, new_shape=(224, 224), auto=False, scaleFill=False, scaleup=True, center=True, stride=32): + shape = img.shape[:2] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: + r = min(r, 1.0) + + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + if auto: + dw, dh = np.mod(dw, stride), np.mod(dh, stride) + + if center: 
+ dw /= 2 + dh /= 2 + + if shape[::-1] != new_unpad: + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)) if center else 0, int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)) if center else 0, int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) + + return img + + +def pre_transform(im, LetterBox_shape=(224, 224)): + return [LetterBox(img=x, new_shape=LetterBox_shape) for x in im] + + +def preprocess(im, resolution): + device = "cpu" + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + if resolution[1] == 224: + LetterBox_shape = (224, 224) + else: + LetterBox_shape = (640, 640) + im = np.stack(pre_transform(im, LetterBox_shape)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.half() if device != "cpu" else im.float() + if not_tensor: + im /= 255 + return im + + +def empty_like(x): + return ( + torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32) + ) + + +def xywh2xyxy(x): + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) + xy = x[..., :2] + wh = x[..., 2:] / 2 + y[..., :2] = xy - wh + y[..., 2:] = xy + wh + return y + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, +): + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + + if isinstance(prediction, (list, tuple)): + prediction = prediction[0] + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] + nc = nc or (prediction.shape[1] - 4) + nm = prediction.shape[1] - nc - 4 + mi = 4 + nc + xc = prediction[:, 4:mi].amax(1) > conf_thres + + time_limit = 2.0 + max_time_img * bs + multi_label &= nc > 1 + + prediction = prediction.transpose(-1, -2) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): + x = x[xc[xi]] + + if not x.shape[0]: + continue + + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + n = x.shape[0] + if not n: + continue + if n > max_nms: + x = x[x[:, 4].argsort(descending=True)[:max_nms]] + + c = x[:, 5:6] * (0 if agnostic else max_wh) + scores = x[:, 4] + + boxes = x[:, :4] + c + i = torchvision.ops.nms(boxes, scores, iou_thres) + i = i[:max_det] + + output[xi] = x[i] + if (time.time() - t) > 
time_limit: + logger.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break + + return output + + +def Boxes(data): + return {"xyxy": data[:, :4], "conf": data[:, -2], "cls": data[:, -1]} + + +def Results(orig_img, path, names, boxes): + return {"orig_img": orig_img, "path": path, "names": names, "boxes": Boxes(boxes)} + + +def clip_boxes(boxes, shape): + if isinstance(boxes, torch.Tensor): + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) + else: + boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) + boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) + return boxes + + +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): + if ratio_pad is None: + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + boxes[..., 0] -= pad[0] + boxes[..., 1] -= pad[1] + if not xywh: + boxes[..., 2] -= pad[0] + boxes[..., 3] -= pad[1] + boxes[..., :4] /= gain + return clip_boxes(boxes, img0_shape) + + +def postprocess(preds, img, orig_imgs, batch, names): + args = {"conf": 0.5, "iou": 0.7, "agnostic_nms": False, "max_det": 300, "classes": None} + + preds = non_max_suppression( + preds, + args["conf"], + args["iou"], + agnostic=args["agnostic_nms"], + max_det=args["max_det"], + classes=args["classes"], + ) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, batch[0]): + pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append(Results(orig_img, path=img_path, names=names, boxes=pred)) + + return results diff --git a/models/experimental/yolov11/demo/evaluate.py b/models/experimental/yolov11/demo/evaluate.py new file mode 100644 index 00000000000..90ec0f8d298 --- /dev/null +++ b/models/experimental/yolov11/demo/evaluate.py @@ -0,0 +1,436 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import fiftyone +import os +import json +import torch +import cv2 +from datetime import datetime +import ttnn +from functools import partial +from loguru import logger +import sys +from tqdm import tqdm +from models.utility_functions import disable_persistent_kernel_cache +import pytest +from models.experimental.yolov11.demo.demo_utils import LoadImages, preprocess, postprocess +from torch import nn +import numpy as np +import shutil +from sklearn.metrics import precision_recall_curve, average_precision_score +import warnings + +warnings.filterwarnings("ignore") + + +def iou(pred_box, gt_box): + """Calculate IoU (Intersection over Union) between two bounding boxes.""" + x1_pred, y1_pred, x2_pred, y2_pred = pred_box[:4] + x1_gt, y1_gt, x2_gt, y2_gt = gt_box + + # Calculate the intersection area + ix = max(0, min(x2_pred, x2_gt) - max(x1_pred, x1_gt)) + iy = max(0, min(y2_pred, y2_gt) - max(y1_pred, y1_gt)) + intersection = ix * iy + + # Calculate the union area + union = (x2_pred - x1_pred) * (y2_pred - y1_pred) + (x2_gt - x1_gt) * (y2_gt - y1_gt) - intersection + return intersection / union + + +def calculate_map(predictions, ground_truths, iou_threshold=0.5, num_classes=3): + """Calculate mAP for object detection.""" + ap_scores = [] + + # Iterate through each class + for class_id in range(num_classes): + y_true = [] + y_scores = [] + + for pred, gt in zip(predictions, ground_truths): + pred_boxes = [p for p in pred if p[5] == class_id] + gt_boxes = [g for g in gt if g[4] == class_id] + + for pred_box in pred_boxes: + best_iou = 0 + matched_gt = None + + for gt_box in gt_boxes: + iou_score = iou(pred_box[:4], gt_box[:4]) # Compare the [x1, y1, x2, y2] part of the box + if iou_score > best_iou: + best_iou = iou_score + matched_gt = gt_box + + # If IoU exceeds threshold, consider it a true positive + if best_iou >= iou_threshold: + y_true.append(1) # True Positive + y_scores.append(pred_box[4]) + gt_boxes.remove(matched_gt) # Remove matched ground truth + else: + y_true.append(0) # False Positive + y_scores.append(pred_box[4]) + + # Ground truth boxes that were not matched are false negatives + for gt_box in gt_boxes: + y_true.append(0) # False Negative + y_scores.append(0) # No detection + if len(y_true) == 0 or len(y_scores) == 0: + # print(f"No predictions or ground truth for class {class_id}") + continue + + # Calculate precision-recall and average precision for this class + precision, recall, _ = precision_recall_curve(y_true, y_scores) + ap = average_precision_score(y_true, y_scores) + ap_scores.append(ap) + + # Calculate mAP as the mean of the AP scores + mAP = np.mean(ap_scores) + return mAP + + +class Ensemble(nn.ModuleList): + def __init__(self): + super(Ensemble, self).__init__() + + def forward(self, x, augment=False): + y = [] + for module in self: + y.append(module(x, augment)[0]) + y = torch.cat(y, 1) + return y, None + + +def attempt_load(weights, model_path, map_location=None): + model = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + w = model_path # depends on model which we take + ckpt = torch.load(w, map_location=map_location) + model.append(ckpt["ema" if ckpt.get("ema") else "model"].float().eval()) + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None + + if len(model) == 1: + return model[-1] + else: + for k in ["names", "stride"]: + setattr(model, k, 
getattr(model[-1], k)) + return model + + +def save_yolo_predictions_by_model(result, save_dir, image_path, model_name): + model_save_dir = os.path.join(save_dir, model_name) + os.makedirs(model_save_dir, exist_ok=True) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if model_name == "torch_model": + bounding_box_color, label_color = (0, 255, 0), (0, 255, 0) + else: + bounding_box_color, label_color = (255, 0, 0), (255, 255, 0) + + boxes = result["boxes"]["xyxy"] + scores = result["boxes"]["conf"] + classes = result["boxes"]["cls"] + names = result["names"] + + for box, score, cls in zip(boxes, scores, classes): + x1, y1, x2, y2 = map(int, box) + label = f"{names[int(cls)]} {score.item():.2f}" + cv2.rectangle(image, (x1, y1), (x2, y2), bounding_box_color, 3) + cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, label_color, 2) + + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_name = f"prediction_{timestamp}.jpg" + output_path = os.path.join(model_save_dir, output_name) + + cv2.imwrite(output_path, image) + + print(f"Predictions saved to {output_path}") + + +def evaluation( + device, + res, + model_type, + model, + parameters, + input_dtype, + input_layout, + save_dir, + model_name=None, + additional_layer=None, +): + disable_persistent_kernel_cache() + + dataset = fiftyone.zoo.load_zoo_dataset( + "coco-2017", + split="validation", + max_samples=250, + ) + + source_list = [i["filepath"] for i in dataset] + data_set = LoadImages(path=[i["filepath"] for i in dataset]) + + with open("/home/ubuntu/fiftyone/coco-2017/info.json", "r") as file: + # Parse the JSON data + data = json.load(file) + classes = data["classes"] + + model_save_dir = os.path.join(save_dir, model_type) + os.makedirs(model_save_dir, exist_ok=True) + + index = 0 + predicted_bbox = [] + for batch in tqdm(data_set, desc="Processing dataset"): + sample = [] + paths, im0s, s = batch + if model_name == "YOLOv4": + sized = cv2.resize(im0s[0], (res[0], res[1])) + sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB) + if type(sized) == np.ndarray and len(sized.shape) == 3: # cv2 image + img = torch.from_numpy(sized.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0) + elif type(sized) == np.ndarray and len(sized.shape) == 4: + img = torch.from_numpy(sized.transpose(0, 3, 1, 2)).float().div(255.0) + else: + im = preprocess(im0s, resolution=res) + + if model_name == "YOLOv4": + input_shape = img.shape + input_tensor = torch.permute(img, (0, 2, 3, 1)) + # input_tensor = ttnn.from_torch(input_tensor, ttnn.bfloat16) + input_tensor = torch.permute(img, (0, 2, 3, 1)) # put channel at the end + input_tensor = torch.nn.functional.pad( + input_tensor, (0, 13, 0, 0, 0, 0, 0, 0) + ) # pad channel dim from 3 to 16 + N, H, W, C = input_tensor.shape + input_tensor = torch.reshape(input_tensor, (N, 1, H * W, C)) + + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(7, 7), + ), + } + ) + n_cores = 64 + shard_spec = ttnn.ShardSpec(shard_grid, [N * H * W // n_cores, C], ttnn.ShardOrientation.ROW_MAJOR) + input_mem_config = ttnn.MemoryConfig( + ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec + ) + ttnn_im = ttnn.from_torch( + input_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.ROW_MAJOR_LAYOUT, + device=device, + memory_config=input_mem_config, + ) + else: + ttnn_im = im.permute((0, 2, 3, 1)) + if model_name == "YOLOv11": # only for yolov11 + ttnn_im = 
ttnn_im.reshape( + 1, + 1, + ttnn_im.shape[0] * ttnn_im.shape[1] * ttnn_im.shape[2], + ttnn_im.shape[3], + ) + if model_name != "YOLOv4": + ttnn_im = ttnn.from_torch(ttnn_im, dtype=input_dtype, layout=input_layout, device=device) + + if model_type == "torch_model": + preds = model(im) + else: + preds = model(ttnn_im) + if model_name == "YOLOv11": + preds = ttnn.to_torch(preds, dtype=torch.float32) + elif model_name == "YOLOv4": + output_tensor1 = ttnn.to_torch(preds[0]) + output_tensor1 = output_tensor1.reshape(1, 40, 40, 255) + output_tensor1 = torch.permute(output_tensor1, (0, 3, 1, 2)) + + output_tensor2 = ttnn.to_torch(preds[1]) + output_tensor2 = output_tensor2.reshape(1, 20, 20, 255) + output_tensor2 = torch.permute(output_tensor2, (0, 3, 1, 2)) + + output_tensor3 = ttnn.to_torch(preds[2]) + output_tensor3 = output_tensor3.reshape(1, 10, 10, 255) + output_tensor3 = torch.permute(output_tensor3, (0, 3, 1, 2)) + + yolo1 = additional_layer( + anchor_mask=[0, 1, 2], + num_classes=len(classes), + anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401], + num_anchors=9, + stride=8, + ) + + yolo2 = additional_layer( + anchor_mask=[3, 4, 5], + num_classes=len(classes), + anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401], + num_anchors=9, + stride=16, + ) + + yolo3 = additional_layer( + anchor_mask=[6, 7, 8], + num_classes=len(classes), + anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401], + num_anchors=9, + stride=32, + ) + y1 = yolo1(output_tensor1) + y2 = yolo2(output_tensor2) + y3 = yolo3(output_tensor3) + from models.demos.yolov4.demo.demo import get_region_boxes + + output = get_region_boxes([y1, y2, y3]) + + else: + preds[0] = ttnn.to_torch(preds[0], dtype=torch.float32) + if model_name == "YOLOv4": + from models.demos.yolov4.demo.demo import post_processing + + results = post_processing(img, 0.3, 0.4, output) + else: + results = postprocess(preds, im, im0s, batch, classes)[0] + + if model_name == "YOLOv4": + predicted_temp = results[0] + for i in predicted_temp: + del i[5] + predicted_bbox.append(predicted_temp) + index += 1 + else: + pred = results["boxes"]["xyxy"].tolist() + h, w = results["orig_img"].shape[0], results["orig_img"].shape[1] + + for index_of_prediction, (conf, values) in enumerate( + zip(results["boxes"]["conf"].tolist(), results["boxes"]["cls"].tolist()) + ): + pred[index_of_prediction][0] /= w # normalizing the output since groundtruth values are normalized + pred[index_of_prediction][1] /= h # normalizing the output since groundtruth values are normalized + pred[index_of_prediction][2] /= w # normalizing the output since groundtruth values are normalized + pred[index_of_prediction][3] /= h # normalizing the output since groundtruth values are normalized + pred[index_of_prediction].append(conf) + pred[index_of_prediction].append(int(values)) + + predicted_bbox.append(pred) + save_yolo_predictions_by_model(results, save_dir, source_list[index], model_type) + index += 1 + + ground_truth = [] + for i in tqdm(dataset, desc="Processing dataset"): + sample = [] + for j in i["ground_truth"]["detections"]: + bb_temp = j["bounding_box"] + bb_temp[2] += bb_temp[0] + bb_temp[3] += bb_temp[1] + bb_temp.append(classes.index(j["label"])) + sample.append(bb_temp) + ground_truth.append(sample) + + class_indices = [box[5] for image in predicted_bbox for box in image] + num_classes = max(class_indices) + 1 + + iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 
0.95] + mAPval_50_95 = [] + for iou_threshold in iou_thresholds: + # Calculate mAP + mAP = calculate_map(predicted_bbox, ground_truth, num_classes=num_classes, iou_threshold=iou_threshold) + print(f"Mean Average Precision (mAP): {mAP:.4f} for IOU Threshold: {iou_threshold:.4f}") + mAPval_50_95.append(mAP) + + print("mAPval_50_95", mAPval_50_95) + mAPval50_95_value = sum(mAPval_50_95) / len(mAPval_50_95) + + print(f"Mean Average Precision for val 50-95 (mAPval 50-95): {mAPval50_95_value:.4f}") + + +@pytest.mark.parametrize( + "model_type", + [ + ("tt_model"), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("res", [(224, 224)]) +def test_yolo11n(device, model_type, res, reset_seeds): + from models.experimental.yolov11.tt import ttnn_yolov11 # depends on model which we take + from models.experimental.yolov11.reference import yolov11 # depends on model which we take + from models.experimental.yolov11.tt.model_preprocessing import ( + create_yolov11_input_tensors, + create_yolov11_model_parameters, + ) + + try: + sys.modules["ultralytics"] = yolov11 + sys.modules["ultralytics.nn.tasks"] = yolov11 + sys.modules["ultralytics.nn.modules.conv"] = yolov11 + sys.modules["ultralytics.nn.modules.block"] = yolov11 + sys.modules["ultralytics.nn.modules.head"] = yolov11 + + except KeyError: + print("models.experimental.yolov11.reference.yolov11_utils not found.") + + if model_type == "torch_model": + model = attempt_load( + "yolo11n", model_path="models/experimental/yolov11/reference/yolov11n.pt", map_location="cpu" + ) + state_dict = model.state_dict() + model = yolov11.YoloV11() + ds_state_dict = {k: v for k, v in state_dict.items()} + new_state_dict = {} + for (name1, parameter1), (name2, parameter2) in zip(model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + model.load_state_dict(new_state_dict) + model.eval() + logger.info("Inferencing using Torch Model") + else: + torch_input, ttnn_input = create_yolov11_input_tensors( + device, input_channels=3, input_height=224, input_width=224 + ) + torch_model = attempt_load( + "yolo11n", model_path="models/experimental/yolov11/reference/yolov11n.pt", map_location="cpu" + ) + state_dict = torch_model.state_dict() + torch_model = yolov11.YoloV11() + ds_state_dict = {k: v for k, v in state_dict.items()} + new_state_dict = {} + for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + parameters = create_yolov11_model_parameters(torch_model, torch_input, device=device) + model = ttnn_yolov11.YoloV11(device, parameters) + logger.info("Inferencing using ttnn Model") + + save_dir = "models/experimental/yolov11/demo/runs" + model_name = "YOLOv11" + + model_path = "models/experimental/yolov11/reference/yolov11n.pt" + + input_layout = ttnn.TILE_LAYOUT + input_dtype = ttnn.bfloat8_b + + evaluation( + device=device, + res=res, + model_type=model_type, + model=model, + parameters=None, + input_dtype=input_dtype, + input_layout=input_layout, + save_dir=save_dir, + model_name=model_name, + ) diff --git a/models/experimental/yolov11/demo/images/cycle_girl.jpg b/models/experimental/yolov11/demo/images/cycle_girl.jpg new file mode 100644 index 00000000000..d6ad93b55b9 Binary files /dev/null and 
b/models/experimental/yolov11/demo/images/cycle_girl.jpg differ diff --git a/models/experimental/yolov11/demo/images/dog.jpg b/models/experimental/yolov11/demo/images/dog.jpg new file mode 100644 index 00000000000..2c921b5bd9e Binary files /dev/null and b/models/experimental/yolov11/demo/images/dog.jpg differ diff --git a/models/experimental/yolov11/reference/yolov11.py b/models/experimental/yolov11/reference/yolov11.py new file mode 100644 index 00000000000..bba15a9c99e --- /dev/null +++ b/models/experimental/yolov11/reference/yolov11.py @@ -0,0 +1,1008 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +import os +import torch +import torch.nn as nn +import torch.nn.functional as f +import math +import torch +import time + + +def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y + sy, sx = torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + + +class DFL(nn.Module): + def __init__(self): + super(DFL, self).__init__() + self.conv = nn.Conv2d(16, 1, kernel_size=1, stride=1, bias=False) + + def forward(self, x): + return self.conv(x) + + +class Conv(nn.Module): + def __init__(self, in_channel, out_channel, kernel=1, stride=1, padding=0, dilation=1, groups=1, enable_act=True): + super().__init__() + self.enable_act = enable_act + if enable_act: + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channel, eps=0.001, momentum=0.03) + self.act = nn.SiLU(inplace=True) + else: + self.conv = nn.Conv2d( + in_channel, + out_channel, + kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channel, eps=0.001, momentum=0.03) + + def forward(self, x): + if self.enable_act: + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + else: + x = self.conv(x) + x = self.bn(x) + return x + + +class Bottleneck(nn.Module): + def __init__( + self, in_channel, out_channel, kernel=[1, 1], stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=[1, 1] + ): + super().__init__() + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + + def forward(self, x): + input = x + x = self.cv1(x) + x = self.cv2(x) + return input + x + + +class SPPF(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel=[1, 1], + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=[1, 1], + m_kernel=5, + m_padding=2, + ): + super().__init__() + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + 
in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + self.m = nn.MaxPool2d(kernel_size=m_kernel, stride=1, padding=m_padding) + + def forward(self, x): + x = self.cv1(x) + x1 = x + m1 = self.m(x) + m2 = self.m(m1) + m3 = self.m(m2) + y = torch.cat((x1, m1, m2, m3), 1) + x = self.cv2(y) + return x + + +class C3k(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups): + super().__init__() + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + self.cv3 = Conv( + in_channel[2], + out_channel[2], + kernel[2], + stride=stride[2], + padding=padding[2], + dilation=dilation[2], + groups=groups[2], + ) + self.m = nn.Sequential( + Bottleneck( + in_channel[3:5], + out_channel[3:5], + kernel[3:5], + stride=stride[3:5], + padding=padding[3:5], + dilation=dilation[3:5], + groups=groups[3:5], + ), + Bottleneck( + in_channel[5:7], + out_channel[5:7], + kernel[5:7], + stride=stride[5:7], + padding=padding[5:7], + dilation=dilation[5:7], + groups=groups[5:7], + ), + ) + + def forward(self, x): + x1 = self.cv1(x) + x2 = self.cv2(x) + x = self.m(x1) + x = torch.cat((x, x2), 1) + x = self.cv3(x) + return x + + +class C3k2(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups, is_bk_enabled=False): + super().__init__() + self.is_bk_enabled = is_bk_enabled + if is_bk_enabled: + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + self.m = nn.ModuleList( + [ + Bottleneck( + in_channel[2:4], + out_channel[2:4], + kernel[2:4], + stride=stride[2:4], + padding=padding[2:4], + dilation=dilation[2:4], + groups=groups[2:4], + ), + ] + ) + + else: + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + self.m = nn.ModuleList( + [ + C3k( + in_channel[2:9], + out_channel[2:9], + kernel[2:9], + stride[2:9], + padding[2:9], + dilation[2:9], + groups[2:9], + ), + ] + ) + + def forward(self, x): + if self.is_bk_enabled: + x = self.cv1(x) + y = list(x.chunk(2, 1)) + y.extend(self.m[0](y[-1])) + y[-1] = y[-1].unsqueeze(0) + x = torch.cat(y, 1) + x = self.cv2(x) + else: + x = self.cv1(x) + y = list(x.chunk(2, 1)) + y.extend(self.m[0](y[-1])) + y[-1] = y[-1].unsqueeze(0) + x = torch.cat(y, 1) + x = self.cv2(x) + return x + + +class Attention(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups): + super().__init__() + self.num_heads = 2 + self.key_dim = 32 + self.head_dim = 64 + self.scale = self.key_dim**-0.5 + + self.qkv = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + enable_act=False, + ) + self.proj = Conv( + in_channel[1], + out_channel[1], + 
kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + enable_act=False, + ) + self.pe = Conv( + in_channel[2], + out_channel[2], + kernel[2], + stride=stride[2], + padding=padding[2], + dilation=dilation[2], + groups=groups[2], + enable_act=False, + ) + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + qkv = self.qkv(x) + q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( + [self.key_dim, self.key_dim, self.head_dim], dim=2 + ) + attn = (q.transpose(-2, -1) @ k) * self.scale + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) + x = self.proj(x) + return x + + +class PSABlock(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups): + super().__init__() + self.attn = Attention( + in_channel[0:3], out_channel[0:3], kernel[0:3], stride[0:3], padding[0:3], dilation[0:3], groups[0:3] + ) + self.ffn = nn.Sequential( + Conv( + in_channel[3], + out_channel[3], + kernel[3], + stride=stride[3], + padding=padding[3], + dilation=dilation[3], + groups=groups[3], + ), + Conv( + in_channel[4], + out_channel[4], + kernel[4], + stride=stride[4], + padding=padding[4], + dilation=dilation[4], + groups=groups[4], + enable_act=False, + ), + ) + + def forward(self, x): + x1 = x + x = self.attn(x) + x = x1 + x + x1 = x + x = self.ffn(x) + return x + x1 + + +class C2PSA(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups): + super().__init__() + self.out_channel = out_channel + self.cv1 = Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ) + self.cv2 = Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ) + self.m = nn.Sequential( + PSABlock( + in_channel[2:7], + out_channel[2:7], + kernel[2:7], + stride=stride[2:7], + padding=padding[2:7], + dilation=dilation[2:7], + groups=groups[2:7], + ) + ) + + def forward(self, x): + x = self.cv1(x) + a, b = x.split((int(self.out_channel[0] / 2), int(self.out_channel[0] / 2)), 1) + x = self.m(b) + x = self.cv2(torch.cat((a, x), 1)) + return x + + +class Detect(nn.Module): + def __init__(self, in_channel, out_channel, kernel, stride, padding, dilation, groups): + super().__init__() + self.out_channel = out_channel + self.in_channel = in_channel + self.cv2 = nn.ModuleList( + [ + nn.Sequential( + Conv( + in_channel[0], + out_channel[0], + kernel[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups[0], + ), + Conv( + in_channel[1], + out_channel[1], + kernel[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + groups=groups[1], + ), + nn.Conv2d( + in_channel[2], + out_channel[2], + kernel[2], + stride=stride[2], + padding=padding[2], + dilation=dilation[2], + groups=groups[2], + ), + ), + nn.Sequential( + Conv( + in_channel[3], + out_channel[3], + kernel[3], + stride=stride[3], + padding=padding[3], + dilation=dilation[3], + groups=groups[3], + ), + Conv( + in_channel[4], + out_channel[4], + kernel[4], + stride=stride[4], + padding=padding[4], + dilation=dilation[4], + groups=groups[4], + ), + nn.Conv2d( + in_channel[5], + out_channel[5], + kernel[5], + stride=stride[5], + padding=padding[5], + dilation=dilation[5], + groups=groups[5], + ), + ), + nn.Sequential( + Conv( + in_channel[6], + out_channel[6], + 
kernel[6], + stride=stride[6], + padding=padding[6], + dilation=dilation[6], + groups=groups[6], + ), + Conv( + in_channel[7], + out_channel[7], + kernel[7], + stride=stride[7], + padding=padding[7], + dilation=dilation[7], + groups=groups[7], + ), + nn.Conv2d( + in_channel[8], + out_channel[8], + kernel[8], + stride=stride[8], + padding=padding[8], + dilation=dilation[8], + groups=groups[8], + ), + ), + ] + ) + self.cv3 = nn.ModuleList( + [ + nn.Sequential( + nn.Sequential( + Conv( + in_channel[9], + out_channel[9], + kernel[9], + stride=stride[9], + padding=padding[9], + dilation=dilation[9], + groups=groups[9], + ), + Conv( + in_channel[10], + out_channel[10], + kernel[10], + stride=stride[10], + padding=padding[10], + dilation=dilation[10], + groups=groups[10], + ), + ), + nn.Sequential( + Conv( + in_channel[11], + out_channel[11], + kernel[11], + stride=stride[11], + padding=padding[11], + dilation=dilation[11], + groups=groups[11], + ), + Conv( + in_channel[12], + out_channel[12], + kernel[12], + stride=stride[12], + padding=padding[12], + dilation=dilation[12], + groups=groups[12], + ), + ), + nn.Conv2d( + in_channel[13], + out_channel[13], + kernel[13], + stride=stride[13], + padding=padding[13], + dilation=dilation[13], + groups=groups[13], + ), + ), + nn.Sequential( + nn.Sequential( + Conv( + in_channel[14], + out_channel[14], + kernel[14], + stride=stride[14], + padding=padding[14], + dilation=dilation[14], + groups=groups[14], + ), + Conv( + in_channel[15], + out_channel[15], + kernel[15], + stride=stride[15], + padding=padding[15], + dilation=dilation[15], + groups=groups[15], + ), + ), + nn.Sequential( + Conv( + in_channel[16], + out_channel[16], + kernel[16], + stride=stride[16], + padding=padding[16], + dilation=dilation[16], + groups=groups[16], + ), + Conv( + in_channel[17], + out_channel[17], + kernel[17], + stride=stride[17], + padding=padding[17], + dilation=dilation[17], + groups=groups[17], + ), + ), + nn.Conv2d( + in_channel[18], + out_channel[18], + kernel[18], + stride=stride[18], + padding=padding[18], + dilation=dilation[18], + groups=groups[18], + ), + ), + nn.Sequential( + nn.Sequential( + Conv( + in_channel[19], + out_channel[19], + kernel[19], + stride=stride[19], + padding=padding[19], + dilation=dilation[19], + groups=groups[19], + ), + Conv( + in_channel[20], + out_channel[20], + kernel[20], + stride=stride[20], + padding=padding[20], + dilation=dilation[20], + groups=groups[20], + ), + ), + nn.Sequential( + Conv( + in_channel[21], + out_channel[21], + kernel[21], + stride=stride[21], + padding=padding[21], + dilation=dilation[21], + groups=groups[21], + ), + Conv( + in_channel[22], + out_channel[22], + kernel[22], + stride=stride[22], + padding=padding[22], + dilation=dilation[22], + groups=groups[22], + ), + ), + nn.Conv2d( + in_channel[23], + out_channel[23], + kernel[23], + stride=stride[23], + padding=padding[23], + dilation=dilation[23], + groups=groups[23], + ), + ), + ] + ) + self.dfl = DFL() + + def forward(self, y1, y2, y3): + x1 = self.cv2[0](y1) + x2 = self.cv2[1](y2) + x3 = self.cv2[2](y3) + x4 = self.cv3[0](y1) + x5 = self.cv3[1](y2) + x6 = self.cv3[2](y3) + + y1 = torch.cat((x1, x4), 1) + y2 = torch.cat((x2, x5), 1) + y3 = torch.cat((x3, x6), 1) + y_all = [y1, y2, y3] + + y1 = torch.reshape(y1, (y1.shape[0], y1.shape[1], y1.shape[2] * y1.shape[3])) + y2 = torch.reshape(y2, (y2.shape[0], y2.shape[1], y2.shape[2] * y2.shape[3])) + y3 = torch.reshape(y3, (y3.shape[0], y3.shape[1], y3.shape[2] * y3.shape[3])) + + y = torch.cat((y1, y2, y3), 2) 
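+        # Editor-added descriptive comment: the three flattened detection-head outputs are
+        # concatenated along the anchor axis above; below, the result is split into DFL
+        # box-distribution logits (ya) and class logits (yb). The DFL conv turns ya into
+        # per-side distances, which are combined with the anchor points into (cx, cy, w, h)
+        # boxes, scaled by the per-level strides, and concatenated with sigmoid class scores.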
+ + ya, yb = y.split((self.out_channel[0], self.out_channel[10]), 1) + + ya = torch.reshape(ya, (ya.shape[0], int(ya.shape[1] / self.in_channel[24]), self.in_channel[24], ya.shape[2])) + ya = torch.permute(ya, (0, 2, 1, 3)) + ya = f.softmax(ya, dim=1) + c = self.dfl(ya) + c1 = torch.reshape(c, (c.shape[0], c.shape[1] * c.shape[2], c.shape[3])) + c2 = c1 + c1 = c1[:, 0:2, :] + c2 = c2[:, 2:4, :] + anchor, strides = (y_all.transpose(0, 1) for y_all in make_anchors(y_all, [8, 16, 32], 0.5)) + anchor.unsqueeze(0) + c1 = anchor - c1 + c2 = anchor + c2 + z1 = c2 - c1 + z2 = c1 + c2 + z2 = z2 / 2 + z = torch.concat((z2, z1), 1) + z = z * strides + yb = torch.sigmoid(yb) + out = torch.concat((z, yb), 1) + return out + + +class YoloV11(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential( + Conv(3, 16, kernel=3, stride=2, padding=1), # 0 + Conv(16, 32, kernel=3, stride=2, padding=1), # 1 + C3k2( # 2 + [32, 48, 16, 8], + [32, 64, 8, 16], + [1, 1, 3, 3], + [1, 1, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + is_bk_enabled=True, + ), + Conv(64, 64, kernel=3, stride=2, padding=1), # 3 + C3k2( # 4 + [64, 96, 32, 16], + [64, 128, 16, 32], + [1, 1, 3, 3], + [1, 1, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + is_bk_enabled=True, + ), + Conv(128, 128, kernel=3, stride=2, padding=1), # 5 + C3k2( + [128, 192, 64, 64, 64, 32, 32, 32, 32], # 6 + [128, 128, 32, 32, 64, 32, 32, 32, 32], + [1, 1, 1, 1, 1, 3, 3, 3, 3], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + ), + Conv(128, 256, kernel=3, stride=2, padding=1), # 7 + C3k2( + [256, 384, 128, 128, 128, 64, 64, 64, 64], # 8 + [256, 256, 64, 64, 128, 64, 64, 64, 64], + [1, 1, 1, 1, 1, 3, 3, 3, 3], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + ), + SPPF([256, 512], [128, 256], [1, 1], [1, 1]), # 9 + C2PSA( + [256, 256, 128, 128, 128, 128, 256], # 10 + [256, 256, 256, 128, 128, 256, 128], + [1, 1, 1, 1, 3, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 128, 1, 1], + ), + nn.Upsample(scale_factor=2.0, mode="nearest"), + Concat(), + C3k2( # 13 + [384, 192, 64, 32], + [128, 128, 32, 64], + [1, 1, 3, 3], + [1, 1, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + is_bk_enabled=True, + ), + nn.Upsample(scale_factor=2.0, mode="nearest"), + Concat(), + C3k2( # 16 + [256, 96, 32, 16], + [64, 64, 16, 32], + [1, 1, 3, 3], + [1, 1, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + is_bk_enabled=True, + ), + Conv(64, 64, kernel=3, stride=2, padding=1), # 17 + Concat(), + C3k2( # 19 + [192, 192, 64, 32], + [128, 128, 32, 64], + [1, 1, 3, 3], + [1, 1, 1, 1], + [0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + is_bk_enabled=True, + ), + Conv(128, 128, kernel=3, stride=2, padding=1), # 20 + Concat(), + C3k2( + [384, 384, 128, 128, 128, 64, 64, 64, 64], # 22 + [256, 256, 64, 64, 128, 64, 64, 64, 64], + [1, 1, 1, 1, 1, 3, 3, 3, 3], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1], + ), + Detect( # 23 + [ + 64, + 64, + 64, + 128, + 64, + 64, + 256, + 64, + 64, + 64, + 64, + 80, + 80, + 80, + 128, + 128, + 80, + 80, + 80, + 256, + 256, + 80, + 80, + 80, + 16, + ], + [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 80, 80, 80, 80, 128, 80, 80, 80, 80, 256, 80, 80, 80, 80, 1], + [3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 1, 1], 
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 64, 1, 80, 1, 1, 128, 1, 80, 1, 1, 256, 1, 80, 1, 1, 1], + ), + ) + + def forward(self, x): + x = self.model[0](x) # 0 + x = self.model[1](x) # 1 + x = self.model[2](x) # 2 + x = self.model[3](x) # 3 + x = self.model[4](x) # 4 + x4 = x + x = self.model[5](x) # 5 + x = self.model[6](x) # 6 + x6 = x + x = self.model[7](x) # 7 + x = self.model[8](x) # 8 + x = self.model[9](x) # 9 + x = self.model[10](x) # 10 + x10 = x + x = f.upsample(x, scale_factor=2.0) # 11 + x = torch.cat((x, x6), 1) # 12 + x = self.model[13](x) # 13 + x13 = x + x = f.upsample(x, scale_factor=2.0) # 14 + x = torch.cat((x, x4), 1) # 15 + x = self.model[16](x) # 16 + x16 = x + x = self.model[17](x) # 17 + x = torch.cat((x, x13), 1) # 18 + x = self.model[19](x) # 19 + x19 = x + x = self.model[20](x) # 20 + x = torch.cat((x, x10), 1) # 21 + x = self.model[22](x) # 22 + x22 = x + x = self.model[23](x16, x19, x22) # 23 + return x + + +class Concat(nn.Module): + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class DWConv(Conv): + """Depth-wise convolution.""" + + def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation + """Initialize Depth-wise convolution with given parameters.""" + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) + + +class BaseModel(nn.Module): + def forward(self, x, *args, **kwargs): + if isinstance(x, dict): + return self.loss(x, *args, **kwargs) + return self.predict(x, *args, **kwargs) + + def predict(self, x, profile=False, visualize=False, augment=False, embed=None): + return self._predict_once(x, profile, visualize, embed) + + def _predict_once(self, x, profile=False, visualize=False, embed=None): + y, dt, embeddings = [], [], [] + for m in self.model: + if m.f != -1: + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] + x = m(x) + y.append(x if m.i in self.save else None) + return x + + +class DetectionModel(BaseModel): + def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): + super().__init__() + + +class Ensemble(nn.ModuleList): + def __init__(self): + super(Ensemble, self).__init__() + + def forward(self, x, augment=False): + y = [] + for module in self: + y.append(module(x, augment)[0]) + y = torch.cat(y, 1) + return y, None + + +def attempt_download(file, repo="ultralytics/assets", key="reference"): + tests = Path(__file__).parent.parent / key + file_path = tests / Path(str(file).strip().replace("'", "").lower()) + if not file_path.exists(): + name = "yolov11n.pt" + msg = f"{file_path} missing, try downloading from https://github.com/{repo}/releases/" + + try: + url = f"https://github.com/{repo}/releases/download/v8.3.0/{name}" + + print(f"Downloading {url} to {file_path}...") + torch.hub.download_url_to_file(url, file_path) + assert file_path.exists() and file_path.stat().st_size > 1e6, f"Download failed for {name}" + + except Exception as e: + print(f"Error downloading from GitHub: {e}. 
Trying secondary source...") + url = f"https://storage.googleapis.com/{repo}/ckpt/{name}" + print(f"Downloading {url} to {file_path}...") + os.system(f"curl -L {url} -o {file_path}") + if not file_path.exists() or file_path.stat().st_size < 1e6: + file_path.unlink(missing_ok=True) + print(f"ERROR: Download failure for {msg}") + else: + print(f"Download succeeded from secondary source!") + return file_path + + +def attempt_load(weights, map_location=None): + model = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + weight_path = attempt_download(w) + ckpt = torch.load(weight_path, map_location=map_location) + model.append(ckpt["ema" if ckpt.get("ema") else "model"].float().eval()) + + for m in model.modules(): + if isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)): + m.inplace = True + elif isinstance(m, nn.Upsample): + m.recompute_scale_factor = None + + if len(model) == 1: + return model[-1] + else: + for k in ["names", "stride"]: + setattr(model, k, getattr(model[-1], k)) + return model diff --git a/models/experimental/yolov11/tests/test_yolov11.py b/models/experimental/yolov11/tests/test_yolov11.py new file mode 100644 index 00000000000..9cd51b40db8 --- /dev/null +++ b/models/experimental/yolov11/tests/test_yolov11.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import sys +import ttnn +import time +import torch +import pytest +import torch.nn as nn +from loguru import logger +from models.utility_functions import is_wormhole_b0 +from models.perf.perf_utils import prep_perf_report +from models.experimental.yolov11.tt import ttnn_yolov11 +from models.experimental.yolov11.reference import yolov11 +from models.experimental.yolov11.reference.yolov11 import attempt_load +from models.utility_functions import enable_persistent_kernel_cache, disable_persistent_kernel_cache +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.experimental.yolov11.tt.model_preprocessing import ( + create_yolov11_input_tensors, + create_yolov11_model_parameters, +) + +try: + sys.modules["ultralytics"] = yolov11 + sys.modules["ultralytics.nn.tasks"] = yolov11 + sys.modules["ultralytics.nn.modules.conv"] = yolov11 + sys.modules["ultralytics.nn.modules.block"] = yolov11 + sys.modules["ultralytics.nn.modules.head"] = yolov11 + +except KeyError: + print("models.experimental.yolov11.reference.yolov11 not found.") + + +def get_expected_times(name): + base = {"yolov11": (130.70, 0.594)} + return base[name] + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("batch_size", [(1)]) +@pytest.mark.parametrize("input_tensor", [torch.rand((1, 3, 640, 640))], ids=["input_tensor"]) +def test_yolov11(device, input_tensor, batch_size): + disable_persistent_kernel_cache() + torch_input, ttnn_input = create_yolov11_input_tensors( + device, + batch=input_tensor.shape[0], + input_channels=input_tensor.shape[1], + input_height=input_tensor.shape[2], + input_width=input_tensor.shape[3], + ) + torch_model = attempt_load("yolov11n.pt", map_location="cpu") + state_dict = torch_model.state_dict() + torch_model = yolov11.YoloV11() + ds_state_dict = {k: v for k, v in state_dict.items()} + new_state_dict = {} + for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, 
torch.FloatTensor): + new_state_dict[name1] = parameter2 + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + parameters = create_yolov11_model_parameters(torch_model, torch_input, device=device) + model = ttnn_yolov11.YoloV11(device, parameters) + durations = [] + + for i in range(2): + start = time.time() + ttnn_model_output = model(ttnn_input) + end = time.time() + durations.append(end - start) + ttnn.deallocate(ttnn_model_output) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times("yolov11") + + prep_perf_report( + model_name="models/experimental/yolov11", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.parametrize( + "batch_size, expected_perf", + [ + [1, 81.94], + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_perf_device_bare_metal_yolov11(batch_size, expected_perf): + subdir = "ttnn_yolov11" + num_iterations = 1 + margin = 0.03 + expected_perf = expected_perf if is_wormhole_b0() else 0 + + command = f"pytest models/experimental/yolov11/demo/demo.py::test_demo" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + + logger.info(f"{expected_results}") + + prep_device_perf_report( + model_name=f"ttnn_yolov11{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) diff --git a/models/experimental/yolov11/tests/test_yolov11_perfomant.py b/models/experimental/yolov11/tests/test_yolov11_perfomant.py new file mode 100644 index 00000000000..a2d162931e3 --- /dev/null +++ b/models/experimental/yolov11/tests/test_yolov11_perfomant.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import run_for_wormhole_b0 +from models.experimental.yolov11.tests.yolov11_perfomant import ( + run_yolov11_trace_inference, + run_yolov11_trace_2cqs_inference, +) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "trace_region_size": 1843200}], indirect=True) +# @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) +def test_run_yolov11_trace_inference( + device, + use_program_cache, + # enable_async_mode, + model_location_generator, +): + run_yolov11_trace_inference( + device, + model_location_generator, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize( + "device_params", [{"l1_small_size": 24576, "trace_region_size": 3686400, "num_command_queues": 2}], indirect=True +) +def test_run_yolov11_trace_2cqs_inference( + device, + use_program_cache, + # enable_async_mode, + model_location_generator, +): + run_yolov11_trace_2cqs_inference( + device, + model_location_generator, + ) diff --git a/models/experimental/yolov11/tests/yolov11_perfomant.py b/models/experimental/yolov11/tests/yolov11_perfomant.py new file mode 100644 index 00000000000..5363cb62d9f --- /dev/null +++ b/models/experimental/yolov11/tests/yolov11_perfomant.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from models.utility_functions import is_wormhole_b0, profiler +from models.experimental.yolov11.tests.yolov11_test_infra import create_test_infra + +try: + from tracy import signpost + + use_signpost = True +except ModuleNotFoundError: + use_signpost = False + + +def buffer_address(tensor): + addr = [] + for ten in ttnn.get_device_tensors(tensor): + addr.append(ten.buffer_address()) + return addr + + +# TODO: Create ttnn apis for this +ttnn.buffer_address = buffer_address + + +def run_yolov11_trace_inference( + device, + model_location_generator, +): + test_infra = create_test_infra( + device, + model_location_generator=model_location_generator, + ) + tt_inputs_host, input_mem_config = test_infra.setup_l1_sharded_input(device) + + # First run configures convs JIT + test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) + spec = test_infra.input_tensor.spec + test_infra.run() + test_infra.validate() + test_infra.dealloc_output() + + # Optimized run + test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) + test_infra.run() + test_infra.validate() + test_infra.dealloc_output() + + # Capture + test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) + test_infra.dealloc_output() + trace_input_addr = ttnn.buffer_address(test_infra.input_tensor) + tid = ttnn.begin_trace_capture(device, cq_id=0) + test_infra.run() + tt_image_res = ttnn.allocate_tensor_on_device(spec, device) + ttnn.end_trace_capture(device, tid, cq_id=0) + assert trace_input_addr == ttnn.buffer_address(tt_image_res) + + # More optimized run with caching + if use_signpost: + signpost(header="start") + ttnn.copy_host_to_device_tensor(tt_inputs_host, tt_image_res, 0) + ttnn.execute_trace(device, tid, cq_id=0, blocking=True) + if use_signpost: + signpost(header="stop") + test_infra.validate() + + ttnn.release_trace(device, tid) + test_infra.dealloc_output() + + +def run_yolov11_trace_2cqs_inference( + device, + model_location_generator, +): + test_infra = create_test_infra( + device, + 
model_location_generator=model_location_generator, + ) + tt_inputs_host, sharded_mem_config_DRAM, input_mem_config = test_infra.setup_dram_sharded_input(device) + tt_image_res = tt_inputs_host.to(device, sharded_mem_config_DRAM) + op_event = ttnn.create_event(device) + write_event = ttnn.create_event(device) + # Initialize the op event so we can write + ttnn.record_event(0, op_event) + + # First run configures convs JIT + ttnn.wait_for_event(1, op_event) + ttnn.copy_host_to_device_tensor(tt_inputs_host, tt_image_res, 1) + ttnn.record_event(1, write_event) + ttnn.wait_for_event(0, write_event) + test_infra.input_tensor = ttnn.to_memory_config(tt_image_res, input_mem_config) + shape = test_infra.input_tensor.shape + dtype = test_infra.input_tensor.dtype + layout = test_infra.input_tensor.layout + ttnn.record_event(0, op_event) + test_infra.run() + test_infra.validate() + test_infra.dealloc_output() + + # Optimized run + ttnn.wait_for_event(1, op_event) + ttnn.copy_host_to_device_tensor(tt_inputs_host, tt_image_res, 1) + ttnn.record_event(1, write_event) + ttnn.wait_for_event(0, write_event) + test_infra.input_tensor = ttnn.to_memory_config(tt_image_res, input_mem_config) + ttnn.record_event(0, op_event) + test_infra.run() + test_infra.validate() + + # Capture + ttnn.wait_for_event(1, op_event) + ttnn.copy_host_to_device_tensor(tt_inputs_host, tt_image_res, 1) + ttnn.record_event(1, write_event) + ttnn.wait_for_event(0, write_event) + test_infra.input_tensor = ttnn.to_memory_config(tt_image_res, input_mem_config) + ttnn.record_event(0, op_event) + test_infra.dealloc_output() + trace_input_addr = ttnn.buffer_address(test_infra.input_tensor) + tid = ttnn.begin_trace_capture(device, cq_id=0) + test_infra.run() + input_tensor = ttnn.allocate_tensor_on_device( + shape, + dtype, + layout, + device, + input_mem_config, + ) + ttnn.end_trace_capture(device, tid, cq_id=0) + assert trace_input_addr == ttnn.buffer_address(input_tensor) + + # More optimized run with caching + if use_signpost: + signpost(header="start") + for iter in range(0, 2): + ttnn.wait_for_event(1, op_event) + ttnn.copy_host_to_device_tensor(tt_inputs_host, tt_image_res, 1) + ttnn.record_event(1, write_event) + ttnn.wait_for_event(0, write_event) + # TODO: Add in place support to ttnn to_memory_config + input_tensor = ttnn.reshard(tt_image_res, input_mem_config, input_tensor) + ttnn.record_event(0, op_event) + ttnn.execute_trace(device, tid, cq_id=0, blocking=False) + ttnn.synchronize_devices(device) + + if use_signpost: + signpost(header="stop") + + ttnn.release_trace(device, tid) diff --git a/models/experimental/yolov11/tests/yolov11_test_infra.py b/models/experimental/yolov11/tests/yolov11_test_infra.py new file mode 100644 index 00000000000..bfa1534b897 --- /dev/null +++ b/models/experimental/yolov11/tests/yolov11_test_infra.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ + # SPDX-License-Identifier: Apache-2.0 + + from loguru import logger + import torch + from tests.ttnn.utils_for_testing import assert_with_pcc + import ttnn + from models.experimental.yolov11.reference import yolov11 + from models.experimental.yolov11.tt import ttnn_yolov11 + from models.experimental.yolov11.reference.yolov11 import attempt_load + import sys + from models.utility_functions import ( + is_wormhole_b0, + is_grayskull, + divup, + ) + from models.experimental.yolov11.tt.model_preprocessing import ( + create_yolov11_input_tensors, + create_yolov11_model_parameters, + ) + import torch.nn as nn + + try: + sys.modules["ultralytics"] = yolov11 + sys.modules["ultralytics.nn.tasks"] = yolov11 + sys.modules["ultralytics.nn.modules.conv"] = yolov11 + sys.modules["ultralytics.nn.modules.block"] = yolov11 + sys.modules["ultralytics.nn.modules.head"] = yolov11 + + except KeyError: + print("models.experimental.yolov11.reference.yolov11 not found.") + + + def load_yolov11_model(): + torch_model = attempt_load("yolov11n.pt", map_location="cpu") + state_dict = torch_model.state_dict() + torch_model = yolov11.YoloV11() + ds_state_dict = {k: v for k, v in state_dict.items()} + new_state_dict = {} + for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + return torch_model + + + class Yolov11TestInfra: + def __init__( + self, + device, + model_location_generator=None, + ): + super().__init__() + torch.manual_seed(0) + self.pcc_passed = False + self.pcc_message = "Did you forget to call validate()?" + self.device = device + self.model_location_generator = model_location_generator + self.torch_input, self.ttnn_input = create_yolov11_input_tensors(device) + torch_model = load_yolov11_model() + parameters = create_yolov11_model_parameters(torch_model, self.torch_input, device=device) + self.torch_output = torch_model(self.torch_input) + self.ttnn_yolov11_model = ttnn_yolov11.YoloV11(device, parameters) + + def run(self): + self.output_tensor = self.ttnn_yolov11_model(self.input_tensor) + + def setup_l1_sharded_input(self, device, torch_input_tensor=None): + if is_wormhole_b0(): + core_grid = ttnn.CoreGrid(y=8, x=8) + else: + exit("Unsupported device") + num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() + # use the provided torch tensor if given, otherwise fall back to the one built in __init__ + torch_input_tensor = self.torch_input if torch_input_tensor is None else torch_input_tensor + + n, c, h, w = torch_input_tensor.shape + # sharded mem config for fold input + num_cores = core_grid.x * core_grid.y + shard_h = (n * w * h + num_cores - 1) // num_cores + grid_size = core_grid + grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_spec = ttnn.ShardSpec(shard_grid, (shard_h, 16), ttnn.ShardOrientation.ROW_MAJOR) + input_mem_config = ttnn.MemoryConfig( + ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.types.BufferType.L1, shard_spec + ) + torch_input_tensor = torch_input_tensor.permute(0, 2, 3, 1) + torch_input_tensor = torch_input_tensor.reshape(1, 1, h * w * n, c) + tt_inputs_host = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + tt_inputs_host = ttnn.pad(tt_inputs_host, [1, 1, n * h * w, 16], [0, 0, 0, 0], 0) + return tt_inputs_host, input_mem_config + + def setup_dram_sharded_input(self, device,
torch_input_tensor=None, mesh_mapper=None, mesh_composer=None): + tt_inputs_host, input_mem_config = self.setup_l1_sharded_input(device) + dram_grid_size = device.dram_grid_size() + dram_shard_spec = ttnn.ShardSpec( + ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))} + ), + [ + divup(tt_inputs_host.volume() // tt_inputs_host.shape[-1], (dram_grid_size.x * dram_grid_size.y)), + 16, + ], + ttnn.ShardOrientation.ROW_MAJOR, + ) + sharded_mem_config_DRAM = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec + ) + + return tt_inputs_host, sharded_mem_config_DRAM, input_mem_config + + def validate(self, output_tensor=None): + output_tensor = self.output_tensor if output_tensor is None else output_tensor + output_tensor = ttnn.to_torch(self.output_tensor) + output_tensor = output_tensor.reshape((self.torch_output).shape) + + valid_pcc = 0.98 + self.pcc_passed, self.pcc_message = assert_with_pcc(self.torch_output, output_tensor, pcc=valid_pcc) + + logger.info(f"Yolov11, PCC={self.pcc_message}") + + def dealloc_output(self): + ttnn.deallocate(self.output_tensor) + + +def create_test_infra( + device, + model_location_generator=None, +): + return Yolov11TestInfra( + device, + model_location_generator, + ) diff --git a/models/experimental/yolov11/tt/model_preprocessing.py b/models/experimental/yolov11/tt/model_preprocessing.py new file mode 100644 index 00000000000..595a4bb3821 --- /dev/null +++ b/models/experimental/yolov11/tt/model_preprocessing.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from ttnn.model_preprocessing import infer_ttnn_module_args +from models.experimental.yolov11.reference.yolov11 import YoloV11 +import torch.nn as nn +from ttnn.model_preprocessing import preprocess_model_parameters, fold_batch_norm2d_into_conv2d +from models.experimental.yolov11.reference.yolov11 import Conv, DFL, Detect + + +def create_yolov11_input_tensors(device, batch=1, input_channels=3, input_height=224, input_width=224): + torch_input_tensor = torch.randn(batch, input_channels, input_height, input_width) + ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) + ttnn_input_tensor = ttnn_input_tensor.reshape( + 1, + 1, + ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2], + ttnn_input_tensor.shape[3], + ) + ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) + return torch_input_tensor, ttnn_input_tensor + + +def make_anchors(device, feats, strides, grid_cell_offset=0.5): + anchor_points, stride_tensor = [], [] + assert feats is not None + for i, stride in enumerate(strides): + h, w = feats[i], feats[i] + sx = torch.arange(end=w) + grid_cell_offset + sy = torch.arange(end=h) + grid_cell_offset + sy, sx = torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride)) + + a = torch.cat(anchor_points).transpose(0, 1).unsqueeze(0) + b = torch.cat(stride_tensor).transpose(0, 1) + + return ( + ttnn.from_torch(a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device), + ttnn.from_torch(b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device), + ) + + +def preprocess_params(d, parameters, path=None, depth=0, max_depth=6): + if path is None: + path = [] # Initialize the path for the first call + + if isinstance(d, dict): + 
# If the dictionary has the 'conv' key, handle it + if "conv" in d: + weight_full_path = ".".join(path + ["conv", "weight"]) + bias_full_path = ".".join(path + ["conv", "bias"]) + weight, bias = preprocess(parameters, weight_full_path, bias_full_path) + d.conv.bias = None + d.conv.weight = weight + if bias is not None: + d.conv.bias = bias + + # Recurse deeper only if we haven't reached the max depth + if depth < max_depth: + for key, value in d.items(): + if isinstance(value, dict): # If the value is a dictionary, continue recursion + if depth == 0: + d[key] = preprocess_params(value, parameters, path + [key, "module"], depth + 1, max_depth) + else: + d[key] = preprocess_params(value, parameters, path + [key], depth + 1, max_depth) + + return d + + +class DotDict(dict): + """ + A dictionary subclass that allows attribute-style access to its keys. + """ + + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + try: + del self[key] + except KeyError: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + + +def preprocess(d: dict, weights_path: str, bias_path: str): + tt_bias = None + weight_keys = weights_path.split(".") + bias_keys = bias_path.split(".") + w_current = DotDict(d) # Convert the top-level dictionary to DotDict + b_current = DotDict(d) + + for key in weight_keys: + w_current = getattr(w_current, key) # Use getattr for dot notation access + for key in bias_keys: + b_current = getattr(b_current, key) # Use getattr for dot notation access + + tt_weight = ttnn.from_torch(w_current, dtype=ttnn.float32) + if b_current is not None: + b_current = torch.reshape(b_current, (1, 1, 1, -1)) + tt_bias = ttnn.from_torch(b_current, dtype=ttnn.float32) + + return tt_weight, tt_bias + + +def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, nn.Conv2d): + parameters["weight"] = ttnn.from_torch(model.weight, dtype=ttnn.float32) + if model.bias is not None: + bias = model.bias.reshape((1, 1, 1, -1)) + parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) + # else: + # parameters["bias"] = None + + if isinstance(model, Conv): + weight, bias = fold_batch_norm2d_into_conv2d(model.conv, model.bn) + parameters["conv"] = {} + parameters["conv"]["weight"] = ttnn.from_torch(weight, dtype=ttnn.float32) + bias = bias.reshape((1, 1, 1, -1)) + parameters["conv"]["bias"] = ttnn.from_torch(bias, dtype=ttnn.float32) + + return parameters + + +def create_yolov11_model_parameters(model: YoloV11, input_tensor: torch.Tensor, device): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=custom_preprocessor, + device=device, + ) + parameters.conv_args = {} + parameters.conv_args = infer_ttnn_module_args(model=model, run_model=lambda model: model(input_tensor), device=None) + + parameters["model_args"] = model + + feats = [ + input_tensor.shape[3] // 8, + input_tensor.shape[3] // 16, + input_tensor.shape[3] // 32, + ] + strides = [8.0, 16.0, 32.0] + + anchors, strides = make_anchors(device, feats, strides) # Optimization: Processing make anchors outside model run + + if "model" in parameters: + parameters.model[23]["anchors"] = anchors + parameters.model[23]["strides"] = strides + + return parameters + + +def create_yolov11_model_parameters_detect( + model: YoloV11, input_tensor_1: torch.Tensor, 
input_tensor_2, input_tensor_3, device +): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=custom_preprocessor, + device=device, + ) + parameters.conv_args = {} + parameters.conv_args = infer_ttnn_module_args( + model=model, run_model=lambda model: model(input_tensor_1, input_tensor_2, input_tensor_3), device=None + ) + + feats = [28, 14, 7] + strides = [8.0, 16.0, 32.0] + + anchors, strides = make_anchors(device, feats, strides) # Optimization: Processing make anchors outside model run + + parameters["anchors"] = anchors + parameters["strides"] = strides + + parameters["model"] = model + + return parameters diff --git a/models/experimental/yolov11/tt/ttnn_yolov11.py b/models/experimental/yolov11/tt/ttnn_yolov11.py new file mode 100644 index 00000000000..72ad8b6b944 --- /dev/null +++ b/models/experimental/yolov11/tt/ttnn_yolov11.py @@ -0,0 +1,690 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import math +from tt_lib.utils import ( + _nearest_y, +) +from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores + + +class Yolov11_Conv2D: + def __init__( + self, + conv, + conv_pth, + bn=None, + device=None, + cache={}, + activation="", + activation_dtype=ttnn.bfloat8_b, + weights_dtype=ttnn.bfloat8_b, + use_1d_systolic_array=True, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + is_detect=False, + is_dfl=False, + ): + self.is_detect = is_detect + self.is_dfl = is_dfl + self.conv = conv + self.device = device + self.in_channels = conv.in_channels + self.out_channels = conv.out_channels + self.kernel_size = conv.kernel_size + self.padding = conv.padding + self.stride = conv.stride + self.groups = conv.groups + self.use_1d_systolic_array = use_1d_systolic_array + self.deallocate_activation = False + self.cache = cache + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=True, + math_approx_mode=True, + ) + self.conv_config = ttnn.Conv2dConfig( + dtype=activation_dtype, + weights_dtype=weights_dtype, + shard_layout=shard_layout, + deallocate_activation=self.deallocate_activation, + enable_act_double_buffer=False, + enable_split_reader=False, + enable_subblock_padding=False, + reshard_if_not_optimal=True if self.use_1d_systolic_array else False, + activation=activation, + input_channels_alignment=32, + ) + config_override = None + if config_override and "act_block_h" in config_override: + self.conv_config.act_block_h_override = config_override["act_block_h"] + + if "bias" in conv_pth: + bias = ttnn.from_device(conv_pth.bias) + self.bias = bias + else: + self.bias = None + + weight = ttnn.from_device(conv_pth.weight) + self.weight = weight + + def __call__(self, x): + if self.is_detect: + input_height = int(math.sqrt(x.shape[2])) + input_width = int(math.sqrt(x.shape[2])) + batch_size = x.shape[0] + elif self.is_dfl: + input_height = x.shape[1] + input_width = x.shape[2] + batch_size = x.shape[0] + else: + batch_size = self.conv.batch_size + input_height = self.conv.input_height + input_width = self.conv.input_width + + [x, [output_height, output_width], [self.weight, self.bias]] = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.weight, + bias_tensor=self.bias, + device=self.device, + in_channels=self.in_channels, + out_channels=self.out_channels, + input_height=input_height, + input_width=input_width, + batch_size=batch_size, + 
kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + conv_config=self.conv_config, + conv_op_cache=self.cache, + groups=self.groups, + compute_config=self.compute_config, + return_output_dim=True, + return_weights_and_bias=True, + ) + return x + + +def sharded_concat(input_tensors, num_cores=64, dim=3): # expected input tensors to be in fp16, RM, same (h*w) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}) + in_shard_width = input_tensors[0].shape[-1] + shard_height = (input_tensors[0].shape[2] + num_cores - 1) // num_cores + input_sharded_memory_config = ttnn.create_sharded_memory_config( + (shard_height, in_shard_width), + core_grid=shard_grid, + strategy=ttnn.ShardStrategy.HEIGHT, + use_height_and_width_as_shard_shape=True, + ) + out_shard_width = 0 + for i in range(len(input_tensors)): + out_shard_width += input_tensors[i].shape[-1] + input_tensors[i] = ttnn.to_memory_config(input_tensors[i], input_sharded_memory_config) + output_sharded_memory_config = ttnn.create_sharded_memory_config( + (shard_height, out_shard_width), + core_grid=shard_grid, + strategy=ttnn.ShardStrategy.HEIGHT, + use_height_and_width_as_shard_shape=True, + ) + output = ttnn.concat(input_tensors, dim, memory_config=output_sharded_memory_config) + output = ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG) + + return output + + +class Conv: + def __init__(self, device, parameter, conv_pt, enable_act=True, is_detect=False): + self.enable_act = enable_act + self.conv = Yolov11_Conv2D(parameter.conv, conv_pt.conv, device=device, is_detect=is_detect) + + def __call__(self, device, x): + if self.enable_act: + x = self.conv(x) + if x.is_sharded(): + x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG) + x = ttnn.silu(x) + + else: + x = self.conv(x) + return x + + +class Bottleneck: + def __init__(self, device, parameter, conv_pt): + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + + def __call__(self, device, x): + input = x + x = self.cv1(device, x) + x = self.cv2(device, x) + return input + x + + +class SPPF: + def __init__(self, device, parameter, conv_pt): + self.parameter = parameter + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + + def __call__(self, device, x): + x = self.cv1(device, x) + if x.get_layout() != ttnn.ROW_MAJOR_LAYOUT: + x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) + x1 = x + m1 = ttnn.max_pool2d( + x, + batch_size=self.parameter.cv2.conv.batch_size, + input_h=self.parameter.cv2.conv.input_height, + input_w=self.parameter.cv2.conv.input_width, + channels=self.parameter.cv2.conv.in_channels, + kernel_size=[5, 5], + stride=[1, 1], + padding=[2, 2], + dilation=[1, 1], + ) + m2 = ttnn.max_pool2d( + m1, + batch_size=self.parameter.cv2.conv.batch_size, + input_h=self.parameter.cv2.conv.input_height, + input_w=self.parameter.cv2.conv.input_width, + channels=self.parameter.cv2.conv.in_channels, + kernel_size=[5, 5], + stride=[1, 1], + padding=[2, 2], + dilation=[1, 1], + ) + m3 = ttnn.max_pool2d( + m2, + batch_size=self.parameter.cv2.conv.batch_size, + input_h=self.parameter.cv2.conv.input_height, + input_w=self.parameter.cv2.conv.input_width, + channels=self.parameter.cv2.conv.in_channels, + kernel_size=[5, 5], + stride=[1, 1], + padding=[2, 2], + dilation=[1, 1], + ) + use_sharded_concat = True + if use_sharded_concat: + y = sharded_concat([x1, m1, m2, m3]) + else: + y = ttnn.concat([x1, m1, 
m2, m3], dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + x = self.cv2(device, y) + ttnn.deallocate(x1) + ttnn.deallocate(m1) + ttnn.deallocate(m2) + ttnn.deallocate(m3) + return x + + +class C3K: + def __init__(self, device, parameter, conv_pt): + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + self.cv3 = Conv(device, parameter.cv3, conv_pt.cv3) + self.k1 = Bottleneck(device, parameter.m[0], conv_pt.m[0]) + self.k2 = Bottleneck(device, parameter.m[1], conv_pt.m[1]) + + def __call__(self, device, x): + x1 = self.cv1(device, x) + x2 = self.cv2(device, x) + + k1 = self.k1(device, x1) + k2 = self.k2(device, k1) + use_shard_concat = False + if use_shard_concat: + x2 = ttnn.to_layout(x2, ttnn.ROW_MAJOR_LAYOUT) + x2 = ttnn.to_dtype(x2, ttnn.bfloat16) + k2 = ttnn.to_layout(k2, ttnn.ROW_MAJOR_LAYOUT) + k2 = ttnn.to_dtype(k2, ttnn.bfloat16) + x = sharded_concat([k2, x2]) + else: + x = ttnn.concat((k2, x2), 3, memory_config=ttnn.L1_MEMORY_CONFIG) + x = self.cv3(device, x) + ttnn.deallocate(x1) + ttnn.deallocate(x2) + ttnn.deallocate(k1) + ttnn.deallocate(k2) + return x + + +class C3k2: + def __init__(self, device, parameter, conv_pt, is_bk_enabled=False): + self.is_bk_enabled = is_bk_enabled + self.parameter = parameter + + if is_bk_enabled: + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + self.k = Bottleneck(device, parameter[0], conv_pt.m[0]) + else: + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + self.c3k = C3K(device, parameter[0], conv_pt.m[0]) + + def __call__(self, device, x): + x = self.cv1(device, x) + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) + y1 = x[:, :, :, : x.shape[-1] // 2] + y2 = x[:, :, :, x.shape[-1] // 2 : x.shape[-1]] + if self.is_bk_enabled: + y2 = ttnn.to_layout(y2, layout=ttnn.TILE_LAYOUT) + y3 = self.k(device, y2) + else: + y3 = self.c3k(device, y2) + + if y2.get_layout() != ttnn.ROW_MAJOR_LAYOUT: + y2 = ttnn.to_layout(y2, ttnn.ROW_MAJOR_LAYOUT) + if y3.get_layout() != ttnn.ROW_MAJOR_LAYOUT: + y3 = ttnn.to_layout(y3, ttnn.ROW_MAJOR_LAYOUT) + use_shard_concat = True + if use_shard_concat: + x = sharded_concat([y1, y2, y3]) + else: + x = ttnn.concat((y1, y2, y3), 3, memory_config=ttnn.L1_MEMORY_CONFIG) + x = self.cv2(device, x) + + ttnn.deallocate(y1) + ttnn.deallocate(y2) + ttnn.deallocate(y3) + return x + + +class Attention: + def __init__(self, device, parameter, conv_pt): + self.qkv = Conv(device, parameter.qkv, conv_pt.qkv, enable_act=False) + self.proj = Conv(device, parameter.proj, conv_pt.proj, enable_act=False) + self.pe = Conv(device, parameter.pe, conv_pt.pe, enable_act=False) + self.num_heads = 2 + self.key_dim = 32 + self.head_dim = 64 + self.scale = self.key_dim**-0.5 + + def __call__(self, device, x, batch_size=1): + qkv = self.qkv(device, x) + qkv = ttnn.sharded_to_interleaved(qkv, memory_config=ttnn.L1_MEMORY_CONFIG) + qkv = ttnn.permute(qkv, (0, 3, 1, 2)) + qkv = ttnn.to_layout(qkv, layout=ttnn.ROW_MAJOR_LAYOUT) + qkv = ttnn.to_dtype(qkv, ttnn.bfloat16) + qkv = ttnn.to_layout(qkv, layout=ttnn.TILE_LAYOUT) + qkv = ttnn.reshape(qkv, (batch_size, self.num_heads, self.key_dim * 2 + self.head_dim, qkv.shape[-1])) + q, k, v = ( + qkv[:, :, : self.key_dim, :], + qkv[:, :, self.key_dim : self.head_dim, :], + qkv[:, :, self.head_dim :, :], + ) + + q_permuted = ttnn.permute(q, (0, 1, 3, 2)) + attn = ttnn.matmul(q_permuted, k, memory_config=ttnn.L1_MEMORY_CONFIG) + attn = 
ttnn.multiply(attn, self.scale) + attn = ttnn.softmax(attn, dim=-1) + attn = ttnn.permute(attn, (0, 1, 3, 2)) + x1 = ttnn.matmul(v, attn, memory_config=ttnn.L1_MEMORY_CONFIG) + x1 = ttnn.reshape(x1, (1, 1, (x1.shape[0] * x1.shape[1] * x1.shape[2]), x1.shape[3])) + x1 = ttnn.permute(x1, (0, 1, 3, 2)) + v = ttnn.reshape(v, (1, 1, (v.shape[0] * v.shape[1] * v.shape[2]), v.shape[3])) + v = ttnn.permute(v, (0, 1, 3, 2)) + x2 = self.pe(device=device, x=v) + x = ttnn.add(x1, x2, memory_config=x2.memory_config()) + x = self.proj(device=device, x=x) + ttnn.deallocate(x1) + ttnn.deallocate(qkv) + ttnn.deallocate(q_permuted) + ttnn.deallocate(attn) + ttnn.deallocate(q) + ttnn.deallocate(k) + ttnn.deallocate(v) + ttnn.deallocate(x2) + return x + + +def determine_num_cores_for_upsample(nhw: int, width: int, max_cores=64) -> int: + gcd_nhw_width = math.gcd(nhw, width) + cores = nhw // gcd_nhw_width + if cores > max_cores: + for divisor in range(max_cores, 0, -1): + if nhw % divisor == 0 and (nhw // divisor) % width == 0: + cores = divisor + break + return cores + + +def get_core_grid_from_num_cores(num_cores: int, grid_rows: int = 8, grid_cols: int = 8): + rows = num_cores // grid_cols + assert rows <= grid_rows, "Not enough cores for specified core grid" + ranges = [] + if rows != 0: + ranges.append( + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(grid_rows - 1, rows - 1), + ) + ) + remainder = num_cores % grid_rows + if remainder != 0: + assert rows + 1 <= grid_rows, "Not enough cores for specified core grid" + ranges.append( + ttnn.CoreRange( + ttnn.CoreCoord(0, rows), + ttnn.CoreCoord(remainder - 1, rows), + ) + ) + return ttnn.CoreRangeSet({*ranges}) + + +class PSABlock: + def __init__(self, device, parameter, conv_pt): + self.attn = Attention(device=device, parameter=parameter.attn, conv_pt=conv_pt.attn) + self.ffn_conv1 = Conv(device, parameter.ffn[0], conv_pt.ffn[0]) + self.ffn_conv2 = Conv(device, parameter.ffn[1], conv_pt.ffn[1], enable_act=False) + + def __call__(self, device, x): + x1 = x + x = self.attn(device, x) + x = ttnn.add(x1, x, memory_config=x.memory_config()) + x1 = x + x = self.ffn_conv1(device, x) + x = self.ffn_conv2(device, x) + x = ttnn.add(x, x1, memory_config=x1.memory_config()) + return x + + +class C2PSA: + def __init__(self, device, parameter, conv_pt): + self.out_channel_0 = parameter.cv1.conv.out_channels + self.cv1 = Conv(device, parameter.cv1, conv_pt.cv1) + self.cv2 = Conv(device, parameter.cv2, conv_pt.cv2) + self.psablock = PSABlock(device, parameter.m[0], conv_pt.m[0]) + + def __call__(self, device, x): + x = self.cv1(device, x) + a, b = x[:, :, :, : int(self.out_channel_0 / 2)], x[:, :, :, int(self.out_channel_0 / 2) :] + x = self.psablock(device, b) + x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) + x = ttnn.concat((a, x), dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + x = self.cv2(device, x) + ttnn.deallocate(a) + ttnn.deallocate(b) + return x + + +class Detect: + def __init__(self, device, parameter, conv_pt): + self.cv2_0_0 = Conv(device, parameter.cv2[0][0], conv_pt.cv2[0][0], is_detect=True) + self.cv2_0_1 = Conv(device, parameter.cv2[0][1], conv_pt.cv2[0][1], is_detect=True) + self.cv2_0_2 = Yolov11_Conv2D(parameter.cv2[0][2], conv_pt.cv2[0][2], device=device, is_detect=True) + + self.cv2_1_0 = Conv(device, parameter.cv2[1][0], conv_pt.cv2[1][0], is_detect=True) + self.cv2_1_1 = Conv(device, parameter.cv2[1][1], conv_pt.cv2[1][1], is_detect=True) + self.cv2_1_2 = Yolov11_Conv2D(parameter.cv2[1][2], conv_pt.cv2[1][2], 
device=device, is_detect=True) + + self.cv2_2_0 = Conv(device, parameter.cv2[2][0], conv_pt.cv2[2][0], is_detect=True) + self.cv2_2_1 = Conv(device, parameter.cv2[2][1], conv_pt.cv2[2][1], is_detect=True) + self.cv2_2_2 = Yolov11_Conv2D(parameter.cv2[2][2], conv_pt.cv2[2][2], device=device, is_detect=True) + + self.cv3_0_0_0 = Conv(device, parameter.cv3[0][0][0], conv_pt.cv3[0][0][0], is_detect=True) + self.cv3_0_0_1 = Conv(device, parameter.cv3[0][0][1], conv_pt.cv3[0][0][1], is_detect=True) + self.cv3_0_1_0 = Conv(device, parameter.cv3[0][1][0], conv_pt.cv3[0][1][0], is_detect=True) + self.cv3_0_1_1 = Conv(device, parameter.cv3[0][1][1], conv_pt.cv3[0][1][1], is_detect=True) + self.cv3_0_2_0 = Yolov11_Conv2D(parameter.cv3[0][2], conv_pt.cv3[0][2], device=device, is_detect=True) + + self.cv3_1_0_0 = Conv(device, parameter.cv3[1][0][0], conv_pt.cv3[1][0][0], is_detect=True) + self.cv3_1_0_1 = Conv(device, parameter.cv3[1][0][1], conv_pt.cv3[1][0][1], is_detect=True) + self.cv3_1_1_0 = Conv(device, parameter.cv3[1][1][0], conv_pt.cv3[1][1][0], is_detect=True) + self.cv3_1_1_1 = Conv(device, parameter.cv3[1][1][1], conv_pt.cv3[1][1][1], is_detect=True) + self.cv3_1_2_0 = Yolov11_Conv2D(parameter.cv3[1][2], conv_pt.cv3[1][2], device=device, is_detect=True) + + self.cv3_2_0_0 = Conv(device, parameter.cv3[2][0][0], conv_pt.cv3[2][0][0], is_detect=True) + self.cv3_2_0_1 = Conv(device, parameter.cv3[2][0][1], conv_pt.cv3[2][0][1], is_detect=True) + self.cv3_2_1_0 = Conv(device, parameter.cv3[2][1][0], conv_pt.cv3[2][1][0], is_detect=True) + self.cv3_2_1_1 = Conv(device, parameter.cv3[2][1][1], conv_pt.cv3[2][1][1], is_detect=True) + self.cv3_2_2_0 = Yolov11_Conv2D(parameter.cv3[2][2], conv_pt.cv3[2][2], device=device, is_detect=True) + + self.dfl = Yolov11_Conv2D(parameter.dfl.conv, conv_pt.dfl.conv, device=device, is_dfl=True) + self.anchors = conv_pt.anchors + self.strides = conv_pt.strides + + def __call__(self, device, y1, y2, y3): + x1 = self.cv2_0_0(device, y1) + x1 = self.cv2_0_1(device, x1) + x1 = self.cv2_0_2(x1) + x2 = self.cv2_1_0(device, y2) + x2 = self.cv2_1_1(device, x2) + x2 = self.cv2_1_2(x2) + + x3 = self.cv2_2_0(device, y3) + x3 = self.cv2_2_1(device, x3) + x3 = self.cv2_2_2(x3) + + x4 = self.cv3_0_0_0(device, y1) + x4 = self.cv3_0_0_1(device, x4) + x4 = self.cv3_0_1_0(device, x4) + x4 = self.cv3_0_1_1(device, x4) + x4 = self.cv3_0_2_0(x4) + + x5 = self.cv3_1_0_0(device, y2) + x5 = self.cv3_1_0_1(device, x5) + x5 = self.cv3_1_1_0(device, x5) + x5 = self.cv3_1_1_1(device, x5) + x5 = self.cv3_1_2_0(x5) + + x6 = self.cv3_2_0_0(device, y3) + x6 = self.cv3_2_0_1(device, x6) + x6 = self.cv3_2_1_0(device, x6) + x6 = self.cv3_2_1_1(device, x6) + x6 = self.cv3_2_2_0(x6) + + x1 = ttnn.sharded_to_interleaved(x1, memory_config=ttnn.L1_MEMORY_CONFIG) + x2 = ttnn.sharded_to_interleaved(x2, memory_config=ttnn.L1_MEMORY_CONFIG) + x3 = ttnn.sharded_to_interleaved(x3, memory_config=ttnn.L1_MEMORY_CONFIG) + x4 = ttnn.sharded_to_interleaved(x4, memory_config=ttnn.L1_MEMORY_CONFIG) + x5 = ttnn.sharded_to_interleaved(x5, memory_config=ttnn.L1_MEMORY_CONFIG) + x6 = ttnn.sharded_to_interleaved(x6, memory_config=ttnn.L1_MEMORY_CONFIG) + + y1 = ttnn.concat((x1, x4), -1, memory_config=ttnn.L1_MEMORY_CONFIG) + y2 = ttnn.concat((x2, x5), -1, memory_config=ttnn.L1_MEMORY_CONFIG) + y3 = ttnn.concat((x3, x6), -1, memory_config=ttnn.L1_MEMORY_CONFIG) + + y = ttnn.concat((y1, y2, y3), dim=2, memory_config=ttnn.L1_MEMORY_CONFIG) + y = ttnn.squeeze(y, dim=0) + ya, yb = y[:, :, :64], y[:, :, 64:144] + 
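# ya: first 64 channels are the DFL box-regression logits (4 sides x 16 bins, reshaped and fed to self.dfl below); yb: remaining 80 channels are per-class scores (assumed COCO classes, consistent with the 64:144 slice) +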
ttnn.deallocate(y1) + ttnn.deallocate(y2) + ttnn.deallocate(y3) + ttnn.deallocate(x1) + ttnn.deallocate(x2) + ttnn.deallocate(x3) + ttnn.deallocate(x4) + ttnn.deallocate(x5) + ttnn.deallocate(x6) + ttnn.deallocate(y) + ya = ttnn.reallocate(ya) + yb = ttnn.reallocate(yb) + ya = ttnn.reshape(ya, (ya.shape[0], y.shape[1], 4, 16)) + ya = ttnn.softmax(ya, dim=-1) + ya = ttnn.permute(ya, (0, 2, 1, 3)) + c = self.dfl(ya) + ttnn.deallocate(ya) + c = ttnn.sharded_to_interleaved(c, memory_config=ttnn.L1_MEMORY_CONFIG) + c = ttnn.to_layout(c, layout=ttnn.ROW_MAJOR_LAYOUT) + c = ttnn.permute(c, (0, 3, 1, 2)) + c = ttnn.reshape(c, (c.shape[0], 1, 4, int(c.shape[3] / 4))) + c = ttnn.reshape(c, (c.shape[0], c.shape[1] * c.shape[2], c.shape[3])) + c1, c2 = c[:, :2, :], c[:, 2:4, :] + + anchor, strides = self.anchors, self.strides + anchor = ttnn.to_memory_config(anchor, memory_config=ttnn.L1_MEMORY_CONFIG) + strides = ttnn.to_memory_config(strides, memory_config=ttnn.L1_MEMORY_CONFIG) + c1 = ttnn.to_layout(c1, layout=ttnn.TILE_LAYOUT) + c2 = ttnn.to_layout(c2, layout=ttnn.TILE_LAYOUT) + + c1 = anchor - c1 + c2 = anchor + c2 + + z1 = c2 - c1 + z2 = c1 + c2 + z2 = ttnn.div(z2, 2) + + z = ttnn.concat((z2, z1), dim=1, memory_config=ttnn.L1_MEMORY_CONFIG) + z = ttnn.multiply(z, strides) + yb = ttnn.permute(yb, (0, 2, 1)) + yb = ttnn.sigmoid(yb) + ttnn.deallocate(c) + ttnn.deallocate(z1) + ttnn.deallocate(z2) + ttnn.deallocate(c1) + ttnn.deallocate(c2) + ttnn.deallocate(anchor) + ttnn.deallocate(strides) + z = ttnn.reallocate(z) + yb = ttnn.reallocate(yb) + z = ttnn.to_layout(z, layout=ttnn.ROW_MAJOR_LAYOUT) + yb = ttnn.to_layout(yb, layout=ttnn.ROW_MAJOR_LAYOUT) + out = ttnn.concat((z, yb), dim=1, memory_config=ttnn.L1_MEMORY_CONFIG) + ttnn.deallocate(yb) + ttnn.deallocate(z) + return out + + +class YoloV11: + def __init__(self, device, parameters): + self.device = device + + self.conv1 = Conv(device, parameters.conv_args[0], parameters.model[0]) + self.conv2 = Conv(device, parameters.conv_args[1], parameters.model[1]) + self.c3k2_1 = C3k2(device, parameters.conv_args[2], parameters.model[2], is_bk_enabled=True) + self.conv3 = Conv(device, parameters.conv_args[3], parameters.model[3]) + self.c3k2_2 = C3k2(device, parameters.conv_args[4], parameters.model[4], is_bk_enabled=True) + self.conv5 = Conv(device, parameters.conv_args[5], parameters.model[5]) + self.c3k2_3 = C3k2(device, parameters.conv_args[6], parameters.model[6], is_bk_enabled=False) + self.conv6 = Conv(device, parameters.conv_args[7], parameters.model[7]) + self.c3k2_4 = C3k2(device, parameters.conv_args[8], parameters.model[8], is_bk_enabled=False) + self.sppf = SPPF(device, parameters.conv_args[9], parameters.model[9]) + self.c2psa = C2PSA(device, parameters.conv_args[10], parameters.model[10]) + self.c3k2_5 = C3k2( + device, + parameters.conv_args[13], + parameters.model[13], + is_bk_enabled=True, + ) + self.c3k2_6 = C3k2( + device, + parameters.conv_args[16], + parameters.model[16], + is_bk_enabled=True, + ) + self.conv7 = Conv(device, parameters.conv_args[17], parameters.model[17]) + self.c3k2_7 = C3k2( + device, + parameters.conv_args[19], + parameters.model[19], + is_bk_enabled=True, + ) + self.conv8 = Conv(device, parameters.conv_args[20], parameters.model[20]) + self.c3k2_8 = C3k2( + device, + parameters.conv_args[22], + parameters.model[22], + is_bk_enabled=False, + ) + self.detect = Detect(device, parameters.model_args.model[23], parameters.model[23]) + + def __call__(self, x): + x = self.conv1(self.device, x) + x = 
self.conv2(self.device, x) + x = self.c3k2_1(self.device, x) + x = self.conv3(self.device, x) + x = self.c3k2_2(self.device, x) + x4 = x + x = self.conv5(self.device, x) + x = self.c3k2_3(self.device, x) + x6 = x + x = self.conv6(self.device, x) + x = self.c3k2_4(self.device, x) + x = self.sppf(self.device, x) + x = self.c2psa(self.device, x) + x10 = x + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.reshape(x, (x.shape[0], int(math.sqrt(x.shape[2])), int(math.sqrt(x.shape[2])), x.shape[3])) + nhw = x.shape[0] * x.shape[1] * x.shape[2] + num_cores = determine_num_cores_for_upsample(nhw, x.shape[2]) + core_grid = get_core_grid_from_num_cores(num_cores) + shardspec = ttnn.create_sharded_memory_config_( + x.shape, core_grid, ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR + ) + if x.is_sharded(): + x = ttnn.reshard(x, shardspec) + else: + x = ttnn.interleaved_to_sharded(x, shardspec) + x = ttnn.upsample(x, scale_factor=2, memory_config=x.memory_config()) # 11 + if x.is_sharded(): + x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) + x = ttnn.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3])) + x6 = ttnn.to_layout(x6, layout=ttnn.ROW_MAJOR_LAYOUT) + shard_height = (x[0].shape[2] + 64 - 1) // 64 + input_sharded_memory_config_1 = ttnn.create_sharded_memory_config( + (shard_height, x.shape[-1]), + core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}), + strategy=ttnn.ShardStrategy.HEIGHT, + use_height_and_width_as_shard_shape=True, + ) + input_sharded_memory_config_2 = ttnn.create_sharded_memory_config( + (shard_height, x6.shape[-1]), + core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}), + strategy=ttnn.ShardStrategy.HEIGHT, + use_height_and_width_as_shard_shape=True, + ) + x = ttnn.to_memory_config(x, input_sharded_memory_config_1) + x6 = ttnn.to_memory_config(x6, input_sharded_memory_config_2) + out_sharded_memory_config_ = ttnn.create_sharded_memory_config( + (shard_height, x.shape[-1] + x6.shape[-1]), + core_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}), + strategy=ttnn.ShardStrategy.HEIGHT, + use_height_and_width_as_shard_shape=True, + ) + x = ttnn.concat((x, x6), -1, memory_config=out_sharded_memory_config_) + + ttnn.deallocate(x6) + if x.shape[2] == 196: + x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) + x = self.c3k2_5(self.device, x) # 13 + x13 = x + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.reshape(x, (x.shape[0], int(math.sqrt(x.shape[2])), int(math.sqrt(x.shape[2])), x.shape[3])) + nhw = x.shape[0] * x.shape[1] * x.shape[2] + num_cores = determine_num_cores_for_upsample(nhw, x.shape[2]) + core_grid = get_core_grid_from_num_cores(num_cores) + shardspec = ttnn.create_sharded_memory_config_( + x.shape, core_grid, ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR + ) + if x.is_sharded(): + x = ttnn.reshard(x, shardspec) + else: + x = ttnn.interleaved_to_sharded(x, shardspec) + x = ttnn.upsample(x, scale_factor=2, memory_config=x.memory_config()) + if x.is_sharded(): + x = ttnn.sharded_to_interleaved(x, memory_config=ttnn.L1_MEMORY_CONFIG) + x = ttnn.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3])) + x4 = ttnn.to_layout(x4, layout=ttnn.ROW_MAJOR_LAYOUT) + x = sharded_concat([x, x4]) + ttnn.deallocate(x4) + x = self.c3k2_6(self.device, x) # 16 + x16 = x + x = self.conv7(self.device, x) # 17 + x = ttnn.concat((x, 
x13), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 18 + ttnn.deallocate(x13) + x = self.c3k2_7(self.device, x) # 19 + x19 = x + x = self.conv8(self.device, x) + x = ttnn.concat((x, x10), -1, memory_config=ttnn.L1_MEMORY_CONFIG) # 21 + ttnn.deallocate(x10) + x = self.c3k2_8(self.device, x) # 22 + x22 = x + x = self.detect(self.device, x16, x19, x22) + ttnn.deallocate(x16) + ttnn.deallocate(x19) + ttnn.deallocate(x22) + return x diff --git a/tt_metal/python_env/requirements-dev.txt b/tt_metal/python_env/requirements-dev.txt index f1599339107..b24bc9cbde8 100644 --- a/tt_metal/python_env/requirements-dev.txt +++ b/tt_metal/python_env/requirements-dev.txt @@ -68,3 +68,4 @@ blobfile==2.1.1 # Required for llama3 numpy>=1.24.4,<2 huggingface-hub==0.25.2 pydantic==2.9.2 # Required for Superset benchmarking +fiftyone==0.25.2 # Required for Yolov11 Evaluation