Skip to content

Commit

Permalink
[Feature] Add torchvision resnet backbone and convert detr official w…
Browse files Browse the repository at this point in the history
…eights into detrex format (#208)

* add torchvision resnet backbone and convert detr weights

* add converted model links

* refine dab-detr-r50-dc5 model with torchvision resnet

* add pattern embedding in DAB-DETR

* add dab-detr-3pattern

* refine modelzoo and dab-3patterns dab-dc5 links

* add dab-detr 3patterns configs

* add dab-r101 and update modelzoo

* refine modelzoo
  • Loading branch information
rentainhe authored Feb 15, 2023
1 parent e495888 commit 629f9cc
Show file tree
Hide file tree
Showing 15 changed files with 353 additions and 18 deletions.
106 changes: 106 additions & 0 deletions detrex/modeling/backbone/torchvision_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/models/backbone.py
# ------------------------------------------------------------------------------------------------

from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

from detectron2.utils.comm import is_main_process


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.

    All four tensors (``weight``, ``bias``, ``running_mean``, ``running_var``)
    are registered as buffers of shape ``(n,)``, so they are saved/loaded with
    the state dict but are never returned by ``parameters()`` and therefore
    never receive gradients or optimizer updates.

    Args:
        n (int): number of channels normalized by this layer.
        eps (float): value added to the running variance before the
            reciprocal square root for numerical stability. Default: 1e-5.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints produced by a regular BatchNorm2d carry a
        # ``num_batches_tracked`` entry that this module has no buffer for;
        # drop it so that strict loading does not report an unexpected key.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Apply the fixed affine normalization ``x * scale + bias``.

        The per-channel buffers are reshaped to ``(1, C, 1, 1)`` so they
        broadcast along dim 1 of ``x`` (presumably an NCHW feature map).
        """
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalization and affine transform into one scale and one bias.
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):
    """Wrap a backbone so that it returns a dict of intermediate feature maps.

    Args:
        backbone (nn.Module): network whose sub-modules are tapped.
        train_backbone (bool): when ``False`` every backbone parameter is
            frozen; when ``True`` only parameters whose name contains
            ``layer2``, ``layer3`` or ``layer4`` stay trainable.
        num_channels (int): stored as ``self.num_channels`` for consumers.
        return_layers (dict): forwarded to ``IntermediateLayerGetter``; maps
            backbone module names to the keys used in the output dict.
    """

    def __init__(
        self,
        backbone: nn.Module,
        train_backbone: bool,
        num_channels: int,
        return_layers: dict,
    ):
        super().__init__()
        trainable_stages = ("layer2", "layer3", "layer4")
        for param_name, param in backbone.named_parameters():
            in_trainable_stage = any(stage in param_name for stage in trainable_stages)
            # Freeze everything when the backbone is not trained, otherwise
            # freeze only what lies outside layer2/layer3/layer4.
            if not train_backbone or not in_trainable_stage:
                param.requires_grad_(False)

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, x):
        """Run the wrapped body and return its outputs as a plain ``dict``."""
        features = self.body(x)
        return {key: feature for key, feature in features.items()}


class TorchvisionResNet(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Builds ``torchvision.models.<name>`` without pretrained weights, using
    ``FrozenBatchNorm2d`` as the norm layer, and hands it to ``BackboneBase``.

    Args:
        name (str): attribute of ``torchvision.models`` to instantiate,
            e.g. ``"resnet50"``.
        train_backbone (bool): forwarded to ``BackboneBase``; controls which
            backbone parameters remain trainable.
        return_layers (dict): mapping from backbone module names to output
            keys. Defaults to ``{"layer4": "res5"}`` when ``None``.
        dilation (bool): if ``True``, the stride of the last stage is replaced
            with dilation (the "DC5" variant).
    """

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_layers: dict = None,
        dilation: bool = False,
    ):
        # Build the default per call instead of sharing one mutable dict
        # object across every instance created with the default argument.
        if return_layers is None:
            return_layers = {"layer4": "res5"}

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=False,
            norm_layer=FrozenBatchNorm2d,
        )
        # NOTE(review): this is the channel count of the last stage (layer4);
        # presumably it is only meaningful when layer4 is among the returned
        # layers - confirm against consumers of ``num_channels``.
        num_channels = 512 if name in ("resnet18", "resnet34") else 2048
        super().__init__(backbone, train_backbone, num_channels, return_layers)
74 changes: 73 additions & 1 deletion docs/source/tutorials/Model_Zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,51 @@


## COCO Object Detection Baselines
Here we provides our pretrained baselines with **detrex**. And more pretrained weights will be released in the future version.
Here we provide our pretrained baselines with **detrex**. More pretrained weights will be released in future versions. We also provide converted pretrained weights for users, which are marked as `(converted)`.

### DETR
<table class="docutils"><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Name</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">Pretrained</th>
<th valign="bottom">Epochs</th>
<th valign="bottom">box<br/>AP</th>
<th valign="bottom">Download</th>
<!-- TABLE BODY -->
<!-- ROW: detr_r50 -->
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/detr/configs/detr_r50_300ep.py">DETR-R50 (converted)</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">500</td>
<td align="center">42.0</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.1.0/converted_detr_r50_500ep.pth">model</a></td>
</tr>
<!-- ROW: detr_r50_dc5 -->
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/detr/configs/detr_r50_dc5_300ep.py">DETR-R50-DC5 (converted)</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">500</td>
<td align="center">43.4</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_detr_r50_dc5.pth">model</a></td>
</tr>
<!-- ROW: detr_r101 -->
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/detr/configs/detr_r101_300ep.py">DETR-R101 (converted)</a></td>
<td align="center">R-101</td>
<td align="center">IN1k</td>
<td align="center">500</td>
<td align="center">43.5</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.1.0/converted_detr_r101_500ep.pth">model</a></td>
</tr>
<!-- ROW: detr_r101_dc5 -->
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/detr/configs/detr_r101_dc5_300ep.py">DETR-R101-DC5 (converted)</a></td>
<td align="center">R-101</td>
<td align="center">IN1k</td>
<td align="center">500</td>
<td align="center">44.9</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_detr_r101_dc5.pth">model</a></td>
</tr>
</tbody></table>

### Deformable-DETR
<table class="docutils"><tbody>
Expand Down Expand Up @@ -74,13 +118,41 @@ Here we provides our pretrained baselines with **detrex**. And more pretrained w
<td align="center">50</td>
<td align="center">43.3</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.1.0/dab_detr_r50_50ep.pth"> model </a></td>
</tr>
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_r50_3patterns_50ep.py">DAB-DETR-R50-3patterns (converted)</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">42.8</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_3patterns.pth">model</a></td>
</tr>
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_r50_dc5_50ep.py">DAB-DETR-R50-DC5 (converted)</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">44.6</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_dc5.pth">model</a></td>
</tr>
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_r50_dc5_3patterns_50ep.py">DAB-DETR-R50-DC5-3patterns (converted)</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">45.7</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_dc5_3patterns.pth">model</a></td>
</tr>
<tr><td align="left"> <a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_r101_50ep.py"> DAB-DETR-R101 </a> </td>
<td align="center">R101</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">44.0</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.1.0/dab_detr_r101_50ep.pth"> model </a></td>
</tr>
<tr><td align="left"><a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_r101_dc5_50ep.py">DAB-DETR-R101-DC5 (converted)</a></td>
<td align="center">R-101</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">45.7</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_detr_r101_dc5.pth">model</a></td>
</tr>
<tr><td align="left"> <a href="https://github.com/IDEA-Research/detrex/blob/main/projects/dab_detr/configs/dab_detr_swin_t_in1k_50ep.py"> DAB-DETR-Swin-T </a> </td>
<td align="center">Swin-Tiny-224</td>
Expand Down
44 changes: 44 additions & 0 deletions projects/dab_detr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,50 @@ Here we provide the pretrained `DAB-DETR` weights based on detrex.
</tbody></table>


## Converted Models
Here are the pretrained weights converted from the official [DAB-DETR](https://github.com/IDEA-Research/DAB-DETR) repo.
<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<th valign="bottom">Name</th>
<th valign="bottom">Backbone</th>
<th valign="bottom">Pretrain</th>
<th valign="bottom">Epochs</th>
<th valign="bottom">box<br/>AP</th>
<th valign="bottom">download</th>
<!-- TABLE BODY -->
<!-- ROW: dab_detr_r50_3patterns_50ep -->
<tr><td align="left"><a href="configs/dab_detr_r50_3patterns_50ep.py">DAB-DETR-R50-3patterns</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">42.8</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_3patterns.pth">model</a></td>
</tr>
<!-- ROW: dab_detr_r50_dc5_50ep -->
<tr><td align="left"><a href="configs/dab_detr_r50_dc5_50ep.py">DAB-DETR-R50-DC5</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">44.6</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_dc5.pth">model</a></td>
</tr>
<tr><td align="left"><a href="configs/dab_detr_r50_dc5_3patterns_50ep.py">DAB-DETR-R50-DC5-3patterns</a></td>
<td align="center">R-50</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">45.7</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_dab_detr_r50_dc5_3patterns.pth">model</a></td>
</tr>
<tr><td align="left"><a href="configs/dab_detr_r101_dc5_50ep.py">DAB-DETR-R101-DC5</a></td>
<td align="center">R-101</td>
<td align="center">IN1k</td>
<td align="center">50</td>
<td align="center">45.7</td>
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_detr_r101_dc5.pth">model</a></td>
</tr>
</tbody></table>

## Training
All configs can be trained with:
```bash
Expand Down
14 changes: 14 additions & 0 deletions projects/dab_detr/configs/dab_detr_r101_dc5_50ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# DAB-DETR config using a torchvision ResNet-101 backbone with the DC5 model
# definition. Training, dataloader, optimizer and LR schedule settings are
# reused from the R50 50-epoch config; only the items below are overridden.
from .dab_detr_r50_50ep import (
train,
dataloader,
optimizer,
lr_multiplier,
)
# The DC5 model definition is shared with the R50 variant; the backbone name
# is switched to ResNet-101 at the bottom of this file.
from .models.dab_detr_r50_dc5 import model

# modify training config
# Initialize from the ResNet-101 checkpoint hosted on download.pytorch.org.
train.init_checkpoint = "https://download.pytorch.org/models/resnet101-63fe2227.pth"
train.output_dir = "./output/dab_detr_r101_dc5_50ep"

# modify model
model.backbone.name = "resnet101"
14 changes: 14 additions & 0 deletions projects/dab_detr/configs/dab_detr_r50_3patterns_50ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# DAB-DETR R50 config with 3 pattern embeddings. Everything, including the
# model, is inherited from the R50 50-epoch config and then overridden below.
from .dab_detr_r50_50ep import (
train,
dataloader,
optimizer,
lr_multiplier,
model,
)

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/dab_detr_r50_3patterns_50ep"

# using 3 pattern embeddings as in Anchor-DETR
model.transformer.num_patterns = 3
17 changes: 17 additions & 0 deletions projects/dab_detr/configs/dab_detr_r50_dc5_3patterns_50ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# DAB-DETR R50-DC5 config with 3 pattern embeddings. Everything, including the
# model, is inherited from the R50-DC5 50-epoch config and overridden below.
from .dab_detr_r50_dc5_50ep import (
train,
dataloader,
optimizer,
lr_multiplier,
model,
)

# modify training config
# Initialize from the ResNet-50 checkpoint hosted on download.pytorch.org.
train.init_checkpoint = "https://download.pytorch.org/models/resnet50-0676ba61.pth"
train.output_dir = "./output/dab_detr_r50_dc5_3patterns_50ep"

# using 3 pattern embeddings as in Anchor-DETR
model.transformer.num_patterns = 3

# modify model
# NOTE(review): this overrides whatever temperature the imported R50-DC5
# config sets - confirm 20 is the intended value for this variant.
model.position_embedding.temperature = 20
5 changes: 4 additions & 1 deletion projects/dab_detr/configs/dab_detr_r50_dc5_50ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@
from .models.dab_detr_r50_dc5 import model

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.init_checkpoint = "https://download.pytorch.org/models/resnet50-0676ba61.pth"
train.output_dir = "./output/dab_detr_r50_dc5_50ep"

# modify model
model.position_embedding.temperature = 10
1 change: 1 addition & 0 deletions projects/dab_detr/configs/models/dab_detr_r50.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
num_layers=6,
modulate_hw_attn=True,
),
num_patterns=0, # pattern embedding as in Anchor-DETR
),
embed_dim=256,
num_classes=80,
Expand Down
5 changes: 5 additions & 0 deletions projects/dab_detr/configs/models/dab_detr_r50_3patterns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Model-only config: the base DAB-DETR R50 model with pattern embeddings enabled.
from .dab_detr_r50 import model


# using 3 pattern embeddings as in Anchor-DETR
model.transformer.num_patterns = 3
17 changes: 6 additions & 11 deletions projects/dab_detr/configs/models/dab_detr_r50_dc5.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from detectron2.config import LazyCall as L
from detrex.modeling.backbone import ResNet, BasicStem, make_stage
from detrex.modeling.backbone.torchvision_resnet import TorchvisionResNet

from .dab_detr_r50 import model


model.backbone = L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
stages=L(make_stage)(
depth=50,
stride_in_1x1=False,
norm="FrozenBN",
res5_dilation=2,
),
out_features=["res2", "res3", "res4", "res5"],
freeze_at=1,
model.backbone=L(TorchvisionResNet)(
name="resnet50",
train_backbone=True,
dilation=True,
return_layers={"layer4": "res5"}
)
16 changes: 14 additions & 2 deletions projects/dab_detr/modeling/dab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,18 @@ def forward(


class DabDetrTransformer(nn.Module):
def __init__(self, encoder=None, decoder=None):
def __init__(self, encoder=None, decoder=None, num_patterns=0):
super(DabDetrTransformer, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.embed_dim = self.encoder.embed_dim

# using patterns designed as AnchorDETR
assert isinstance(num_patterns, int), "num_patterns should be int but got {}".format(type(num_patterns))
self.num_patterns = num_patterns
if self.num_patterns > 0:
self.patterns = nn.Embedding(self.num_patterns, self.embed_dim)

self.init_weights()

def init_weights(self):
Expand All @@ -285,7 +291,13 @@ def forward(self, x, mask, anchor_box_embed, pos_embed):
query_key_padding_mask=mask,
)
num_queries = anchor_box_embed.shape[0]
target = torch.zeros(num_queries, bs, self.embed_dim, device=anchor_box_embed.device)


if self.num_patterns == 0:
target = torch.zeros(num_queries, bs, self.embed_dim, device=anchor_box_embed.device)
else:
target = self.patterns.weight[:, None, None, :].repeat(1, num_queries, bs, 1).flatten(0, 1)
anchor_box_embed = anchor_box_embed.repeat(self.num_patterns, 1, 1)

hidden_state, reference_boxes = self.decoder(
query=target,
Expand Down
Loading

0 comments on commit 629f9cc

Please sign in to comment.