-
Notifications
You must be signed in to change notification settings - Fork 216
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add anchor detr model * add config * add converted model weights * refine README * refine README and modelzoo * refine README * refine README * refine config * add detrex pretrained model
- Loading branch information
Showing
15 changed files
with
1,426 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
## Anchor DETR: Query Design for Transformer-Based Object Detection | ||
|
||
Yingming Wang, Xiangyu Zhang, Tong Yang, Jian Sun | ||
|
||
[[`arXiv`](https://arxiv.org/abs/2109.07107)] [[`BibTeX`](#citing-anchor-detr)] | ||
|
||
<div align="center"> | ||
<img src="./assets/anchor_detr_arch.png"/> | ||
</div><br/> | ||
|
||
## Pretrained Weights | ||
Here's our pretrained Anchor-DETR weights based on detrex. | ||
|
||
<table><tbody> | ||
<!-- START TABLE --> | ||
<!-- TABLE HEADER --> | ||
<th valign="bottom">Name</th> | ||
<th valign="bottom">Backbone</th> | ||
<th valign="bottom">Pretrain</th> | ||
<th valign="bottom">Epochs</th> | ||
<th valign="bottom">box<br/>AP</th> | ||
<th valign="bottom">download</th> | ||
<!-- TABLE BODY --> | ||
<tr><td align="left"><a href="configs/anchor_detr_r50_50ep.py">Anchor-DETR-R50</a></td> | ||
<td align="center">R-50</td> | ||
<td align="center">IN1k</td> | ||
<td align="center">50</td> | ||
<td align="center">41.9</td> | ||
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/anchor_detr_r50_50ep.pth">model</a></td> | ||
</tr> | ||
</tbody></table> | ||
|
||
## Converted Weights | ||
<table><tbody> | ||
<!-- START TABLE --> | ||
<!-- TABLE HEADER --> | ||
<th valign="bottom">Name</th> | ||
<th valign="bottom">Backbone</th> | ||
<th valign="bottom">Pretrain</th> | ||
<th valign="bottom">Epochs</th> | ||
<th valign="bottom">box<br/>AP</th> | ||
<th valign="bottom">download</th> | ||
<!-- TABLE BODY --> | ||
<tr><td align="left"><a href="configs/anchor_detr_r50_50ep.py">Anchor-DETR-R50</a></td> | ||
<td align="center">R-50</td> | ||
<td align="center">IN1k</td> | ||
<td align="center">50</td> | ||
<td align="center">42.2</td> | ||
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_anchor_detr_r50_50ep.pth">model</a></td> | ||
</tr> | ||
<tr><td align="left"><a href="configs/detr_r50_dc5_300ep.py">Anchor-DETR-R50-DC5</a></td> | ||
<td align="center">R-50</td> | ||
<td align="center">IN1k</td> | ||
<td align="center">50</td> | ||
<td align="center">44.2</td> | ||
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_anchor_detr_r50_dc5_50ep.pth">model</a></td> | ||
</tr> | ||
<tr><td align="left"><a href="configs/detr_r101_300ep.py">Anchor-DETR-R101</a></td> | ||
<td align="center">R-101</td> | ||
<td align="center">IN1k</td> | ||
<td align="center">50</td> | ||
<td align="center">43.5</td> | ||
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_anchor_detr_r101_50ep.pth">model</a></td> | ||
</tr> | ||
<tr><td align="left"><a href="configs/detr_r101_dc5_300ep.py">Anchor-DETR-R101-DC5</a></td> | ||
<td align="center">R-101</td> | ||
<td align="center">IN1k</td> | ||
<td align="center">50</td> | ||
<td align="center">45.1</td> | ||
<td align="center"> <a href="https://github.com/IDEA-Research/detrex-storage/releases/download/v0.3.0/converted_anchor_detr_r101_dc5_50ep.pth">model</a></td> | ||
</tr> | ||
</tbody></table> | ||
|
||
**Note:** Here we borrowed the pretrained weight from [Anchor-DETR](https://github.com/megvii-research/AnchorDETR) official repo. And our detrex training results will be released in the future version. | ||
|
||
## Training | ||
Training Anchor-DETR-R50 model: | ||
```bash | ||
cd detrex | ||
python tools/train_net.py --config-file projects/anchor_detr/configs/anchor_detr_r50_50ep.py --num-gpus 8 | ||
``` | ||
By default, we use 8 GPUs with total batch size as 64 for training. | ||
|
||
## Evaluation | ||
Model evaluation can be done as follows: | ||
```bash | ||
cd detrex | ||
python tools/train_net.py --config-file projects/anchor_detr/configs/path/to/config.py \ | ||
--eval-only train.init_checkpoint=/path/to/model_checkpoint | ||
``` | ||
|
||
|
||
## Citing Anchor-DETR | ||
```BibTex | ||
@inproceedings{wang2022anchor, | ||
title={Anchor detr: Query design for transformer-based detector}, | ||
author={Wang, Yingming and Zhang, Xiangyu and Yang, Tong and Sun, Jian}, | ||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, | ||
volume={36}, | ||
number={3}, | ||
pages={2567--2575}, | ||
year={2022} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .anchor_detr_r50_50ep import ( | ||
train, | ||
dataloader, | ||
optimizer, | ||
lr_multiplier, | ||
) | ||
from .models.anchor_detr_r50 import model | ||
|
||
# modify training config | ||
train.init_checkpoint = "https://download.pytorch.org/models/resnet101-63fe2227.pth" | ||
train.output_dir = "./output/anchor_detr_r101_50ep" | ||
|
||
# modify model | ||
model.backbone.name = "resnet101" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from .anchor_detr_r50_50ep import ( | ||
train, | ||
dataloader, | ||
optimizer, | ||
lr_multiplier, | ||
) | ||
from .models.anchor_detr_r50 import model | ||
|
||
# modify training config | ||
train.init_checkpoint = "https://download.pytorch.org/models/resnet101-63fe2227.pth" | ||
train.output_dir = "./output/anchor_detr_r101_dc5_50ep" | ||
|
||
# modify model | ||
model.backbone.name = "resnet101" | ||
model.backbone.dilation = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from detrex.config import get_config | ||
from .models.anchor_detr_r50 import model | ||
|
||
dataloader = get_config("common/data/coco_detr.py").dataloader | ||
optimizer = get_config("common/optim.py").AdamW | ||
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep | ||
train = get_config("common/train.py").train | ||
|
||
# modify training config | ||
train.init_checkpoint = "https://download.pytorch.org/models/resnet50-0676ba61.pth" | ||
train.output_dir = "./output/anchor_detr_r50_50ep" | ||
|
||
# max training iterations | ||
train.max_iter = 375000 | ||
|
||
# run evaluation every 5000 iters | ||
train.eval_period = 5000 | ||
|
||
# log training infomation every 20 iters | ||
train.log_period = 20 | ||
|
||
# save checkpoint every 5000 iters | ||
train.checkpointer.period = 5000 | ||
|
||
# gradient clipping for training | ||
train.clip_grad.enabled = True | ||
train.clip_grad.params.max_norm = 0.1 | ||
train.clip_grad.params.norm_type = 2 | ||
|
||
# set training devices | ||
train.device = "cuda" | ||
model.device = train.device | ||
|
||
# modify optimizer config | ||
optimizer.lr = 1e-4 | ||
optimizer.betas = (0.9, 0.999) | ||
optimizer.weight_decay = 1e-4 | ||
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1 | ||
|
||
# modify dataloader config | ||
dataloader.train.num_workers = 16 | ||
|
||
# please notice that this is total batch size. | ||
# surpose you're using 4 gpus for training and the batch size for | ||
# each gpu is 16/4 = 4 | ||
dataloader.train.total_batch_size = 16 | ||
|
||
# dump the testing results into output_dir for visualization | ||
dataloader.evaluator.output_dir = train.output_dir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .anchor_detr_r50_50ep import ( | ||
train, | ||
dataloader, | ||
optimizer, | ||
lr_multiplier, | ||
) | ||
from .models.anchor_detr_r50 import model | ||
|
||
# modify training config | ||
train.init_checkpoint = "https://download.pytorch.org/models/resnet50-0676ba61.pth" | ||
train.output_dir = "./output/anchor_detr_r50_dc5_50ep" | ||
|
||
# modify model | ||
model.backbone.dilation = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from detectron2.config import LazyCall as L | ||
|
||
from detrex.modeling.matcher import HungarianMatcher | ||
from detrex.modeling.criterion import SetCriterion | ||
from detrex.modeling.backbone.torchvision_resnet import TorchvisionResNet | ||
|
||
from projects.anchor_detr.modeling import ( | ||
AnchorDETR, | ||
AnchorDETRTransformer, | ||
) | ||
|
||
|
||
model = L(AnchorDETR)( | ||
backbone=L(TorchvisionResNet)( | ||
name="resnet50", | ||
train_backbone=True, | ||
dilation=False, | ||
return_layers={"layer4": "res5"} | ||
), | ||
in_features=["res5"], # only use last level feature in Conditional-DETR | ||
in_channels=2048, | ||
embed_dim=256, | ||
transformer=L(AnchorDETRTransformer)( | ||
embed_dim=256, | ||
num_heads=8, | ||
num_encoder_layers=6, | ||
num_decoder_layers=6, | ||
dim_feedforward=1024, | ||
dropout=0., | ||
activation="relu", | ||
num_query_position=300, | ||
num_query_pattern=3, | ||
spatial_prior="learned", # choose from ["learned", "grid"] | ||
attention_type="RCDA", # choose from ["RCDA", "nn.MultiheadAttention"] | ||
num_classes=80, | ||
), | ||
criterion=L(SetCriterion)( | ||
num_classes=80, | ||
matcher=L(HungarianMatcher)( | ||
cost_class=2.0, | ||
cost_bbox=5.0, | ||
cost_giou=2.0, | ||
cost_class_type="focal_loss_cost", | ||
alpha=0.25, | ||
gamma=2.0, | ||
), | ||
weight_dict={ | ||
"loss_class": 2.0, | ||
"loss_bbox": 5.0, | ||
"loss_giou": 2.0, | ||
}, | ||
loss_class_type="focal_loss", | ||
alpha=0.25, | ||
gamma=2.0, | ||
), | ||
aux_loss=True, | ||
pixel_mean=[123.675, 116.280, 103.530], | ||
pixel_std=[58.395, 57.120, 57.375], | ||
select_box_nums_for_evaluation=100, | ||
device="cuda", | ||
) | ||
|
||
# set aux loss weight dict | ||
if model.aux_loss: | ||
weight_dict = model.criterion.weight_dict | ||
aux_weight_dict = {} | ||
for i in range(model.transformer.num_decoder_layers - 1): | ||
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) | ||
weight_dict.update(aux_weight_dict) | ||
model.criterion.weight_dict = weight_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .anchor_detr import AnchorDETR | ||
from .anchor_detr_transformer import AnchorDETRTransformer |
Oops, something went wrong.