-
Notifications
You must be signed in to change notification settings - Fork 216
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add torchvision resnet backbone and convert detr official w…
…eights into detrex format (#208) * add torchvision resnet backbone and convert detr weights * add converted model links * refine dab-detr-r50-dc5 model with torchvision resnet * add pattern embedding in DAB-DETR * add dab-detr-3pattern * refine modelzoo and dab-3patterns dab-dc5 links * add dab-detr 3patterns configs * add dab-r101 and update modelzoo * refine modelzoo
- Loading branch information
Showing
15 changed files
with
353 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# coding=utf-8 | ||
# Copyright 2022 The IDEA Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ------------------------------------------------------------------------------------------------ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# ------------------------------------------------------------------------------------------------ | ||
# Modified from: | ||
# https://github.com/facebookresearch/detr/blob/main/models/backbone.py | ||
# ------------------------------------------------------------------------------------------------ | ||
|
||
from collections import OrderedDict | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torchvision | ||
from torchvision.models._utils import IntermediateLayerGetter | ||
|
||
from detectron2.utils.comm import is_main_process | ||
|
||
|
||
class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d whose batch statistics and affine parameters are fixed.

    Adapted from ``torchvision.ops.misc.FrozenBatchNorm2d`` with ``eps`` added
    before the rsqrt. According to the upstream DETR note, without it any
    model other than torchvision.models.resnet[18,34,50,101] produces NaNs.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers, not parameters, so the optimizer never sees them.

    Args:
        n: size of each per-channel buffer (the number of channels).
        eps: value added to ``running_var`` before the reciprocal square
            root. Defaults to ``1e-5``, the value previously hard-coded in
            ``forward``.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load buffers, ignoring a ``num_batches_tracked`` entry if present.

        This module registers no ``num_batches_tracked`` buffer, so the key is
        dropped before delegating to ``nn.Module``; otherwise it would be
        reported as an unexpected key.
        """
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the frozen affine normalisation ``x * scale + bias``.

        NOTE(review): assumes ``x`` has channels on dim 1 (e.g. N, C, H, W),
        since every buffer is reshaped to ``(1, C, 1, 1)`` -- confirm with
        callers.
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalisation and affine transform into one multiply-add.
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
|
||
|
||
class BackboneBase(nn.Module):
    """Wrap a backbone so that it returns selected intermediate feature maps.

    Args:
        backbone: network whose named parameters are (partially) frozen and
            whose intermediate outputs are collected.
        train_backbone: when false, every backbone parameter is frozen; when
            true, only parameters whose name contains ``layer2``, ``layer3``
            or ``layer4`` keep their gradients.
        num_channels: stored on the instance as ``self.num_channels``.
        return_layers: mapping forwarded to ``IntermediateLayerGetter``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        train_backbone: bool,
        num_channels: int,
        return_layers: dict,
    ):
        super().__init__()
        trainable_stages = ("layer2", "layer3", "layer4")
        for param_name, param in backbone.named_parameters():
            keeps_grad = train_backbone and any(
                stage in param_name for stage in trainable_stages
            )
            if not keeps_grad:
                param.requires_grad_(False)

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, x):
        """Run the wrapped body and return its outputs as a plain ``dict``."""
        features = self.body(x)
        return {key: feature for key, feature in features.items()}
|
||
|
||
class TorchvisionResNet(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: attribute name of the constructor in ``torchvision.models``
            (e.g. ``"resnet50"``).
        train_backbone: forwarded to ``BackboneBase``.
        return_layers: mapping forwarded to ``BackboneBase``. ``None``
            (the default) means ``{"layer4": "res5"}``.
        dilation: passed as the last entry of
            ``replace_stride_with_dilation``.
    """

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_layers: dict = None,
        dilation: bool = False,
    ):
        # Build the default per call rather than sharing one mutable dict
        # across every instance through the signature.
        if return_layers is None:
            return_layers = {"layer4": "res5"}
        # No weights are fetched here (pretrained=False).
        # NOTE(review): presumably they come from the training checkpoint --
        # confirm against the configs.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=False,
            norm_layer=FrozenBatchNorm2d,
        )
        num_channels = 512 if name in ("resnet18", "resnet34") else 2048
        super().__init__(backbone, train_backbone, num_channels, return_layers)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .dab_detr_r50_50ep import dataloader, lr_multiplier, optimizer, train
from .models.dab_detr_r50_dc5 import model

# Model: use ResNet-101 as the backbone of the DC5 model.
model.backbone.name = "resnet101"

# Training: output directory and the checkpoint used for initialisation.
train.output_dir = "./output/dab_detr_r101_dc5_50ep"
train.init_checkpoint = "https://download.pytorch.org/models/resnet101-63fe2227.pth"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .dab_detr_r50_50ep import dataloader, lr_multiplier, model, optimizer, train

# Model: three pattern embeddings, following Anchor-DETR.
model.transformer.num_patterns = 3

# Training: output directory and the checkpoint used for initialisation.
train.output_dir = "./output/dab_detr_r50_3patterns_50ep"
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
17 changes: 17 additions & 0 deletions
17
projects/dab_detr/configs/dab_detr_r50_dc5_3patterns_50ep.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from .dab_detr_r50_dc5_50ep import dataloader, lr_multiplier, model, optimizer, train

# Model: three pattern embeddings (following Anchor-DETR) and a position
# embedding temperature of 20.
model.transformer.num_patterns = 3
model.position_embedding.temperature = 20

# Training: output directory and the checkpoint used for initialisation.
train.output_dir = "./output/dab_detr_r50_dc5_3patterns_50ep"
train.init_checkpoint = "https://download.pytorch.org/models/resnet50-0676ba61.pth"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .dab_detr_r50 import model

# Same model as dab_detr_r50, but with three pattern embeddings
# (following Anchor-DETR).
model.transformer.num_patterns = 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,12 @@ | ||
from detectron2.config import LazyCall as L | ||
from detrex.modeling.backbone import ResNet, BasicStem, make_stage | ||
from detrex.modeling.backbone.torchvision_resnet import TorchvisionResNet | ||
|
||
from .dab_detr_r50 import model | ||
|
||
|
||
# Replace the backbone with the torchvision ResNet-50 wrapper: dilation
# enabled, and torchvision's "layer4" output exposed under the name "res5".
model.backbone = L(TorchvisionResNet)(
    name="resnet50",
    train_backbone=True,
    dilation=True,
    return_layers={"layer4": "res5"},
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.