Decouple mmpretrain: Efficientnet-V2-S (#3353)
* Initial commit for decouple mmpretrain

* Remove mv3 with mmlab

* Fix XAI explain functions

* Revisit ignored_label for multi-label & h-label head

* Fix some pipeline

* Fix wrong config value

* Fix XAI with H-label

* Fix H-label & export with explain

* Resolve some conflicts

* Add TODO comments

* Rename mv_cfgs to backbone_configs

* Add timm.py

* Decouple Efficientnet-v2 with Timm

* Revisit forward_tracing

* Replace recipes

* Remove comments

* Remove mmconfigs for effnet-V2
harimkang authored Apr 22, 2024
1 parent 0388471 commit cd521af
Showing 18 changed files with 666 additions and 389 deletions.
src/otx/algo/classification/backbones/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -4,7 +4,7 @@
"""Backbone modules for OTX custom model."""

from .otx_efficientnet import OTXEfficientNet
from .otx_efficientnet_v2 import OTXEfficientNetV2
from .timm import TimmBackbone
from .mobilenet_v3 import OTXMobileNetV3

__all__ = ["OTXEfficientNet", "OTXEfficientNetV2", "OTXMobileNetV3"]
__all__ = ["OTXEfficientNet", "TimmBackbone", "OTXMobileNetV3"]
@@ -1,6 +1,6 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #

 """EfficientNetV2 model.
 Original papers:
@@ -11,18 +11,19 @@

 import os

+import torch
 import timm
-from mmengine.runner import load_checkpoint
-from mmpretrain.registry import MODELS
+from otx.algo.utils.mmengine_utils import load_from_http, load_checkpoint_to_model
 from torch import nn
+from typing import Literal

 PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/"
 pretrained_urls = {
     "efficientnetv2_s_21k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21k-6337ad01.pth",
     "efficientnetv2_s_1k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
 }

-NAME_DICT = {
+TIMM_MODEL_NAME_DICT = {
     "mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k",
     "mobilenetv3_large_1k": "mobilenetv3_large_100_miil",
     "tresnet": "tresnet_m",
@@ -33,22 +34,34 @@
"efficientnetv2_b0": "tf_efficientnetv2_b0",
}


class TimmModelsWrapper(nn.Module):
"""Timm model wrapper."""

def __init__(self, model_name, pretrained=False, pooling_type="avg", **kwargs):
TimmModelType = Literal[
"mobilenetv3_large_21k",
"mobilenetv3_large_1k",
"tresnet",
"efficientnetv2_s_21k",
"efficientnetv2_s_1k",
"efficientnetv2_m_21k",
"efficientnetv2_m_1k",
"efficientnetv2_b0",
]


class TimmBackbone(nn.Module):
def __init__(
self,
backbone: TimmModelType,
pretrained=False,
pooling_type="avg",
**kwargs,
):
super().__init__(**kwargs)
self.model_name = model_name
self.backbone = backbone
self.pretrained = pretrained
if model_name in ["mobilenetv3_large_100_miil_in21k", "mobilenetv3_large_100_miil"]:
self.is_mobilenet = True
else:
self.is_mobilenet = False
self.is_mobilenet = backbone.startswith("mobilenet")

self.model = timm.create_model(NAME_DICT[self.model_name], pretrained=pretrained, num_classes=1000)
self.model = timm.create_model(TIMM_MODEL_NAME_DICT[self.backbone], pretrained=pretrained, num_classes=1000)
if self.pretrained:
print(f"init weight - {pretrained_urls[self.model_name]}")
print(f"init weight - {pretrained_urls[self.backbone]}")
self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx.
self.num_head_features = self.model.num_features
self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features
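
As a side note, a short sketch of the timm calls the wrapper leans on (illustrative; the model name and input size are assumptions, not taken from the commit):

import timm
import torch

# timm models expose num_features and forward_features(); the wrapper above keeps
# num_features and detaches the classifier so only the feature extractor is used.
model = timm.create_model("tf_efficientnetv2_s", pretrained=False, num_classes=1000)
features = model.forward_features(torch.randn(1, 3, 224, 224))
print(model.num_features, features.shape)  # channel count of the final feature map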
@@ -85,20 +98,14 @@ def get_config_optim(self, lrs):

         return parameters

-
-@MODELS.register_module()
-class OTXEfficientNetV2(TimmModelsWrapper):
-    """EfficientNetV2 for OTX."""
-
-    def __init__(self, version="s_21k", **kwargs):
-        self.model_name = "efficientnetv2_" + version
-        super().__init__(model_name=self.model_name, **kwargs)
-
-    def init_weights(self, pretrained=None):
+    def init_weights(self, pretrained: str | bool | None = None):
         """Initialize weights."""
+        checkpoint = None
         if isinstance(pretrained, str) and os.path.exists(pretrained):
-            load_checkpoint(self, pretrained)
+            checkpoint = torch.load(pretrained, None)
             print(f"init weight - {pretrained}")
         elif pretrained is not None:
-            load_checkpoint(self, pretrained_urls[self.model_name])
-            print(f"init weight - {pretrained_urls[self.model_name]}")
+            checkpoint = load_from_http(pretrained_urls[self.key])
+            print(f"init weight - {pretrained_urls[self.key]}")
+        if checkpoint is not None:
+            load_checkpoint_to_model(self, checkpoint)
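
A rough sketch of the new weight-loading path, assuming load_from_http(url) returns a checkpoint dict and load_checkpoint_to_model(module, checkpoint) copies it into the module, as the calls above imply; the local path is hypothetical:

backbone = TimmBackbone(backbone="efficientnetv2_s_21k")

# Local file: read with torch.load, then copied into the module.
backbone.init_weights(pretrained="/path/to/efficientnetv2_s.pth")  # hypothetical path

# Hub weights: fetched with load_from_http rather than mmengine's load_checkpoint,
# which is what drops the mmengine/mmpretrain dependency here.
backbone.init_weights(pretrained=True)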