[Feature] Support PnP-DETR (#237)
* support PnP-DETR

* refine README
rentainhe authored Mar 18, 2023
1 parent 69b50b8 commit 0f7e6a8
Showing 10 changed files with 756 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -104,6 +104,7 @@ Results and models are available in [model zoo](https://detrex.readthedocs.io/en

- [x] [DETR (ECCV'2020)](./projects/detr/)
- [x] [Deformable-DETR (ICLR'2021 Oral)](./projects/deformable_detr/)
- [x] [PnP-DETR (ICCV'2021)](./projects/pnp_detr/)
- [x] [Conditional-DETR (ICCV'2021)](./projects/conditional_detr/)
- [x] [Anchor-DETR (AAAI 2022)](./projects/anchor_detr/)
- [x] [DAB-DETR (ICLR'2022)](./projects/dab_detr/)
23 changes: 12 additions & 11 deletions projects/README.md
@@ -1,14 +1,15 @@
Here are projects built on detrex; they show how to use detrex as a library and make your own projects more maintainable.

## Projects by detrex
-- [End-to-End Object Detection with Transformers](./detr)
-- [Deformable DETR: Deformable Transformers for End-to-End Object Detection](./deformable_detr/)
-- [Conditional DETR for Fast Training Convergence](./conditional_detr/)
-- [Anchor DETR: Query Design for Transformer-Based Detector](./anchor_detr/)
-- [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](./dab_detr/)
-- [DN-DETR: Accelerate DETR Training by Introducing Query DeNoising](./dn_detr/)
-- [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection](./dino)
-- [Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment](./group_detr/)
-- [DETRs with Hybrid Matching](./h_deformable_detr/)
-- [Mask DINO: Towards A Unified Transformer-based Framework for Object Detection and Segmentation](./maskdino/)
-- [NMS strikes back](./deta/)
+- [End-to-End Object Detection with Transformers (ECCV'2020)](./detr)
+- [Deformable DETR: Deformable Transformers for End-to-End Object Detection (ICLR'2021 Oral)](./deformable_detr/)
+- [PnP-DETR: Towards Efficient Visual Analysis with Transformers (ICCV'2021)](./pnp_detr/)
+- [Conditional DETR for Fast Training Convergence (ICCV'2021)](./conditional_detr/)
+- [Anchor DETR: Query Design for Transformer-Based Detector (AAAI'2022)](./anchor_detr/)
+- [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR (ICLR'2022)](./dab_detr/)
+- [DN-DETR: Accelerate DETR Training by Introducing Query DeNoising (CVPR'2022 Oral)](./dn_detr/)
+- [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection (ICLR'2023)](./dino)
+- [Group DETR: Fast DETR Training with Group-Wise One-to-Many Assignment (ArXiv'2022)](./group_detr/)
+- [DETRs with Hybrid Matching (CVPR'2023)](./h_deformable_detr/)
+- [Mask DINO: Towards A Unified Transformer-based Framework for Object Detection and Segmentation (CVPR'2023)](./maskdino/)
+- [NMS strikes back (ArXiv'2022)](./deta/)
38 changes: 38 additions & 0 deletions projects/pnp_detr/README.md
@@ -0,0 +1,38 @@
## PnP-DETR: Towards Efficient Visual Analysis with Transformers

Tao Wang, Li Yuan, Yunpeng Chen, Jiashi Feng, Shuicheng Yan

[[`arXiv`](https://arxiv.org/abs/2109.07036)] [[`BibTeX`](#citing-pnp-detr)]

<div align="center">
<img src="./assets/PnP-DETR.png"/>
</div><br/>


## Training
To train the PnP-DETR model for 300 epochs:
```bash
cd detrex
python tools/train_net.py --config-file projects/pnp_detr/configs/pnp_detr_r50_300ep.py --num-gpus 8
```
By default, we train on 8 GPUs with a total batch size of 64.

## Evaluation
Model evaluation can be done as follows:
```bash
cd detrex
python tools/train_net.py --config-file projects/pnp_detr/configs/path/to/config.py \
--eval-only train.init_checkpoint=/path/to/model_checkpoint
```


## Citing PnP-DETR
```BibTex
@inproceedings{wang2021pnp,
title={PnP-DETR: Towards Efficient Visual Analysis with Transformers},
author={Wang, Tao and Yuan, Li and Chen, Yunpeng and Feng, Jiashi and Yan, Shuicheng},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={4661--4670},
year={2021}
}
```
Binary file added projects/pnp_detr/assets/PnP-DETR.png
93 changes: 93 additions & 0 deletions projects/pnp_detr/configs/models/pnp_detr_r50.py
@@ -0,0 +1,93 @@
from detectron2.config import LazyCall as L

from detrex.modeling.backbone import ResNet, BasicStem
from detrex.modeling.matcher import HungarianMatcher
from detrex.modeling.criterion.criterion import SetCriterion
from detrex.layers.position_embedding import PositionEmbeddingSine

from projects.pnp_detr.modeling import (
    PnPDETR,
    PnPDetrTransformer,
    PnPDetrTransformerEncoder,
    PnPDetrTransformerDecoder,
)

model = L(PnPDETR)(
    backbone=L(ResNet)(
        stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=50,
            stride_in_1x1=False,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    in_features=["res5"],
    in_channels=2048,
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
    ),
    transformer=L(PnPDetrTransformer)(
        encoder=L(PnPDetrTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            num_layers=6,
            post_norm=False,
        ),
        decoder=L(PnPDetrTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            attn_dropout=0.1,
            feedforward_dim=2048,
            ffn_dropout=0.1,
            num_layers=6,
            return_intermediate=True,
            post_norm=True,
        ),
        sample_topk_ratio=1/3.,
        score_pred_net="2layer-fc-256",  # choose from ["2layer-fc-256", "2layer-fc-16", "1layer-fc"]
        kproj_net="1layer-fc",
        unsample_abstract_number=30,

    ),
    embed_dim=256,
    num_classes=80,
    num_queries=100,
    test_time_sample_ratio=0.5,  # default to 0.5, should be set to a float number between 0 and 1
    criterion=L(SetCriterion)(
        num_classes=80,
        matcher=L(HungarianMatcher)(
            cost_class=1,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="ce_cost",
        ),
        weight_dict={
            "loss_class": 1,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
        },
        loss_class_type="ce_loss",
        eos_coef=0.1,
    ),
    aux_loss=True,
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    device="cuda",
)

# set aux loss weight dict
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {}
    for i in range(model.transformer.decoder.num_layers - 1):
        aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
    aux_weight_dict["sample_reg_loss"] = 1e-4
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict
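
The auxiliary-loss block above copies each base loss weight once per intermediate decoder layer and then adds a small weight for the sampler regularization term. A minimal standalone sketch of that expansion follows, assuming plain dicts in place of the LazyConfig objects; `expand_aux_weights` is a hypothetical helper name that does not appear in the commit.

```python
# Standalone sketch: plain dicts stand in for the LazyConfig objects.
# `expand_aux_weights` is a hypothetical helper, not part of detrex.
def expand_aux_weights(weight_dict, num_decoder_layers, sample_reg_weight=1e-4):
    """Return a copy of weight_dict with per-layer auxiliary entries added."""
    expanded = dict(weight_dict)
    # One suffixed copy of every base weight per intermediate decoder layer.
    for i in range(num_decoder_layers - 1):
        expanded.update({f"{k}_{i}": v for k, v in weight_dict.items()})
    # Extra term for the sampler regularization loss.
    expanded["sample_reg_loss"] = sample_reg_weight
    return expanded


base = {"loss_class": 1, "loss_bbox": 5.0, "loss_giou": 2.0}
weights = expand_aux_weights(base, num_decoder_layers=6)
print(len(weights))            # 3 base + 3 * 5 auxiliary + 1 sampler term = 19
print(weights["loss_giou_4"])  # 2.0
```

With six decoder layers, the suffixes run from `_0` to `_4`; the final layer keeps the unsuffixed keys.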
8 changes: 8 additions & 0 deletions projects/pnp_detr/configs/pnp_detr_r101_300ep.py
@@ -0,0 +1,8 @@
from .pnp_detr_r50_300ep import train, dataloader, optimizer, lr_multiplier, model

# modify model config
model.backbone.stages.depth = 101

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-101.pkl"
train.output_dir = "./output/detr_r101_300ep"
23 changes: 23 additions & 0 deletions projects/pnp_detr/configs/pnp_detr_r50_300ep.py
@@ -0,0 +1,23 @@
from detrex.config import get_config
from .models.pnp_detr_r50 import model

dataloader = get_config("common/data/coco_detr.py").dataloader
lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_50ep
optimizer = get_config("common/optim.py").AdamW
train = get_config("common/train.py").train

# modify training config
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
train.output_dir = "./output/detr_r50_300ep"
train.max_iter = 554400

# modify lr_multiplier
lr_multiplier.scheduler.milestones = [369600, 554400]

# modify optimizer config
optimizer.weight_decay = 1e-4
optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1

# modify dataloader config
dataloader.train.num_workers = 16
dataloader.train.total_batch_size = 64
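
The iteration counts in this config are consistent with a 300-epoch schedule at a total batch size of 64. The sketch below reproduces the arithmetic; it assumes the COCO train2017 split has 118,287 images and that an epoch is counted as a whole number of batches. Neither assumption is stated in the commit, and `iterations_for` is a hypothetical helper.

```python
# Assumption: size of the COCO train2017 split (not stated in the commit).
COCO_TRAIN2017_IMAGES = 118_287


def iterations_for(epochs, total_batch_size, num_images=COCO_TRAIN2017_IMAGES):
    """Iterations needed for `epochs` passes, counting whole batches per epoch."""
    return epochs * (num_images // total_batch_size)


max_iter = iterations_for(300, 64)  # matches train.max_iter
lr_drop = iterations_for(200, 64)   # matches the first scheduler milestone
print(max_iter, lr_drop)            # 554400 369600
```

Under these assumptions the first milestone falls at epoch 200, and the second coincides with the end of training.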
6 changes: 6 additions & 0 deletions projects/pnp_detr/modeling/__init__.py
@@ -0,0 +1,6 @@
from .detr import PnPDETR
from .transformer import (
    PnPDetrTransformerEncoder,
    PnPDetrTransformerDecoder,
    PnPDetrTransformer,
)