-
Notifications
You must be signed in to change notification settings - Fork 443
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MMDet MaskRCNN ResNet50/SwinTransformer Decouple (#3281)
* migrate mmdet maskrcnn modules * ignore mypy, ruff errors * skip mypy error * add maskrcnn * add cross-entropy loss * style changes * mypy changes and style changes * update style * remove box structures * modify resnet * add annotation * fix all mypy issues * fix mypy issues * style changes * remove unused losses * remove focal_loss_pb * fix all rull and mypy issues * remove duplicates * remove mmdet mask structures * remove duplicates * style changes * add new test * test style change * chagne device for unit test * add deployment files * remove deployment from inst-seg * update deployment * add mmdeploy maskrcnn opset * replace mmcv.cnn module * remove upsample building * remove upsample building * use batch_nms from otx * add swintransformer * add transformers * add swin transformer * update instance_segmentation/maskrcnn.py * update nms * change rotate detection recipe
- Loading branch information
1 parent
4d80c49
commit fffa2eb
Showing
73 changed files
with
6,627 additions
and
432 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""Functions for mmdeploy adapters.""" | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
import importlib | ||
|
||
|
||
def is_mmdeploy_enabled() -> bool: | ||
"""Checks if the 'mmdeploy' Python module is installed and available for use. | ||
Returns: | ||
bool: True if 'mmdeploy' is installed, False otherwise. | ||
Example: | ||
>>> is_mmdeploy_enabled() | ||
True | ||
""" | ||
return importlib.util.find_spec("mmdeploy") is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# This class and its supporting functions are adapted from the mmdet. | ||
# Please refer to https://github.com/open-mmlab/mmdetection/ | ||
|
||
"""MMDet Accuracy.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from torch import Tensor | ||
|
||
|
||
def accuracy( | ||
pred: Tensor, | ||
target: Tensor, | ||
topk: int | tuple[int] = 1, | ||
thresh: float | None = None, | ||
) -> list[Tensor] | Tensor: | ||
"""Calculate accuracy according to the prediction and target. | ||
Args: | ||
pred (torch.Tensor): The model prediction, shape (N, num_class) | ||
target (torch.Tensor): The target of each prediction, shape (N, ) | ||
topk (int | tuple[int], optional): If the predictions in ``topk`` | ||
matches the target, the predictions will be regarded as | ||
correct ones. Defaults to 1. | ||
thresh (float, optional): If not None, predictions with scores under | ||
this threshold are considered incorrect. Default to None. | ||
Returns: | ||
float | tuple[float]: If the input ``topk`` is a single integer, | ||
the function will return a single float as accuracy. If | ||
``topk`` is a tuple containing multiple integers, the | ||
function will return a tuple containing accuracies of | ||
each ``topk`` number. | ||
""" | ||
if not isinstance(topk, (int, tuple)): | ||
msg = f"topk must be int or tuple of int, got {type(topk)}" | ||
raise TypeError(msg) | ||
if isinstance(topk, int): | ||
topk = (topk,) | ||
return_single = True | ||
else: | ||
return_single = False | ||
|
||
maxk = max(topk) | ||
if pred.size(0) == 0: | ||
accu = [pred.new_tensor(0.0) for i in range(len(topk))] | ||
return accu[0] if return_single else accu | ||
if pred.ndim != 2 or target.ndim != 1: | ||
msg = "Input tensors must have 2 dims for pred and 1 dim for target" | ||
raise ValueError(msg) | ||
if pred.size(0) != target.size(0): | ||
msg = "Input tensors must have the same size along the 0th dim" | ||
raise ValueError(msg) | ||
if maxk > pred.size(1): | ||
msg = f"maxk {maxk} exceeds pred dimension {pred.size(1)}" | ||
raise ValueError(msg) | ||
pred_value, pred_label = pred.topk(maxk, dim=1) | ||
pred_label = pred_label.t() # transpose to shape (maxk, N) | ||
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) | ||
if thresh is not None: | ||
# Only prediction values larger than thresh are counted as correct | ||
correct = correct & (pred_value > thresh).t() | ||
res = [] | ||
for k in topk: | ||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) | ||
res.append(correct_k.mul_(100.0 / pred.size(0))) | ||
return res[0] if return_single else res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.