
Update README (#275)
* update readme

* add annotations to Deform.DETR

* refine README

* update README
rentainhe authored Jul 5, 2023
1 parent 301d6c8 commit 1e645c6
Showing 5 changed files with 31 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -29,7 +29,7 @@
<!-- <a href="https://arxiv.org/abs/2306.07265">📚Read detrex Benchmarking Paper</a> <sup><i><font size="3" color="#FF0000">New</font></i></sup> |
<a href="https://rentainhe.github.io/projects/detrex/">🏠Project Page</a> <sup><i><font size="3" color="#FF0000">New</font></i></sup> | [🏷️Cite detrex](#citation) -->

[📚Read detrex Benchmarking Paper](https://arxiv.org/abs/2306.07265) | [🏠Project Page](https://rentainhe.github.io/projects/detrex/) | [🏷️Cite detrex](#citation)
[📚Read detrex Benchmarking Paper](https://arxiv.org/abs/2306.07265) | [🏠Project Page](https://rentainhe.github.io/projects/detrex/) | [🏷️Cite detrex](#citation) | [🚢DeepDataSpace](https://github.com/IDEA-Research/deepdataspace)

</div>

@@ -59,7 +59,7 @@ detrex is an open-source toolbox that provides state-of-the-art Transformer-base

- **Modular Design.** detrex decomposes the Transformer-based detection framework into various components which help users easily build their own customized models.

- **State-of-the-art Methods.** detrex provides a series of Transformer-based detection algorithms, including [DINO](https://arxiv.org/abs/2203.03605), which achieved the SOTA among DETR-like models with **63.3 AP**!
- **Strong Baselines.** detrex provides a series of strong baselines for Transformer-based detection models. We have further boosted model performance by **0.2 AP** to **1.1 AP** by optimizing hyper-parameters for most of the supported algorithms.

- **Easy to Use.** detrex is designed to be **light-weight** and easy for users to get started with:
- [LazyConfig System](https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html) for more flexible syntax and cleaner config files.
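The LazyConfig System linked in the README hunk above is easier to grasp with a concrete illustration. The snippet below is a minimal sketch of that config style, assuming detectron2 is installed; the `nn.Linear` target and the numbers are placeholders, not detrex defaults.

```python
# Minimal LazyConfig-style sketch (assumes detectron2; values are illustrative).
from detectron2.config import LazyCall as L, instantiate
from torch import nn

# A config is a plain, editable description of how to build an object ...
linear_cfg = L(nn.Linear)(in_features=256, out_features=80)
linear_cfg.out_features = 91  # configs can be tweaked before anything is built

# ... and is only materialized when explicitly instantiated.
layer = instantiate(linear_cfg)
print(layer)  # Linear(in_features=256, out_features=91, bias=True)
```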
12 changes: 12 additions & 0 deletions detrex/layers/multi_scale_deform_attn.py
@@ -189,6 +189,7 @@ def __init__(
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
# per query: num_heads * num_levels * num_points sampling offsets (x, y) and attention weights over the multi-level feature inputs
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
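As a quick check of the annotation above, here is the arithmetic behind the two linear projections, using the sizes assumed by the shape comments in the forward pass below (256-dim embeddings, 8 heads, 4 levels, 4 points); other configs may differ.

```python
# Sanity check of the projection sizes implied by the shape comments below
# (embed_dim=256, 8 heads, 4 levels, 4 points are assumed values, not universal).
embed_dim, num_heads, num_levels, num_points = 256, 8, 4, 4

offset_out = num_heads * num_levels * num_points * 2  # one (x, y) offset per head/level/point
weight_out = num_heads * num_levels * num_points      # one attention weight per sampled point

assert (offset_out, weight_out) == (256, 128)
```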
@@ -284,13 +285,18 @@ def forward(

assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

# value projection
value = self.value_proj(value)
# fill "0" for the padding part
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], float(0))
# [bs, all hw, 256] -> [bs, all hw, 8, 32]
value = value.view(bs, num_value, self.num_heads, -1)
# [bs, all hw, 8, 4, 4, 2]: 8 heads, 4 feature levels, 4 sampling points, 2 offset coordinates
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
# [bs, all hw, 8, 16]: 4 levels x 4 sampling points = 16 attention weights per head
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
@@ -305,6 +311,12 @@ def forward(

# bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2:

# reference_points [bs, all hw, 4, 2] -> [bs, all hw, 1, 4, 1, 2]
# sampling_offsets [bs, all hw, 8, 4, 4, 2]
# offset_normalizer [4, 2] -> [1, 1, 1, 4, 1, 2]
# sampling locations = reference_points + normalized sampling_offsets

offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
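The shape annotations above can be exercised end to end on toy tensors. The following is an illustrative sketch of the broadcasting they describe, not the module's exact code; the batch size, query count, and feature-map sizes are made up.

```python
# Toy reproduction of the broadcasting described in the comments above.
import torch

bs, num_query, num_heads, num_levels, num_points = 2, 10, 8, 4, 4

reference_points = torch.rand(bs, num_query, num_levels, 2)                        # normalized (x, y)
sampling_offsets = torch.randn(bs, num_query, num_heads, num_levels, num_points, 2)
spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]])              # (H, W) per level

# offsets are predicted in feature-map pixels, so divide by (W, H) to stay in normalized coords
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

sampling_locations = (
    reference_points[:, :, None, :, None, :]                                # [bs, q, 1, 4, 1, 2]
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]    # normalizer: [1, 1, 1, 4, 1, 2]
)
print(sampling_locations.shape)  # torch.Size([2, 10, 8, 4, 4, 2])
```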
1 change: 1 addition & 0 deletions projects/README.md
@@ -13,4 +13,5 @@ Here are projects that are built on detrex which show you use detrex as a librar
- [DETRs with Hybrid Matching (CVPR'2023)](./h_deformable_detr/)
- [Mask DINO: Towards A Unified Transformer-based Framework for Object Detection and Segmentation (CVPR'2023)](./maskdino/)
- [NMS strikes back (ArXiv'2022)](./deta/)
- [CO-MOT: Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking (ArXiv'2023)](./co_mot/)
- [Enhanced Training of Query-Based Object Detection via Selective Query Recollection (CVPR'2023)](./sqr_detr/)
1 change: 0 additions & 1 deletion projects/co_mot/README.md
@@ -11,7 +11,6 @@
# CO-MOT: Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking


[![arXiv]](https://arxiv.org/abs/2305.12724)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-dancetrack)](https://paperswithcode.com/sota/multi-object-tracking-on-dancetrack?p=bridging-the-gap-between-end-to-end-and-non)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-bdd100k)](https://paperswithcode.com/sota/multi-object-tracking-on-bdd100k?p=bridging-the-gap-between-end-to-end-and-non)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-mot17)](https://paperswithcode.com/sota/multi-object-tracking-on-mot17?p=bridging-the-gap-between-end-to-end-and-non)
17 changes: 16 additions & 1 deletion projects/deformable_detr/modeling/deformable_transformer.py
@@ -260,10 +260,14 @@ def init_weights(self):
nn.init.normal_(self.level_embeds)

def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
# S is the sum of H*W over all feature levels
N, S, C = memory.shape
proposals = []
_cur = 0
for lvl, (H, W) in enumerate(spatial_shapes):
# memory_padding_mask covers all feature levels: [bs, sum of H*W]
# slice out the padding mask for this level: [bs, H, W, 1]
# and count the valid height and width of this level's feature map
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
@@ -282,18 +286,23 @@ def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shap
_cur += H * W

output_proposals = torch.cat(proposals, 1)
# Keep only proposals that are not too close to the boundaries.
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
-1, keepdim=True
)
# inverse sigmoid
output_proposals = torch.log(output_proposals / (1 - output_proposals))
# fill "inf" to the masked region
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
# fill "inf" to the proposals near the boundary
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# fill "0" to the input memory features according to the proposals
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
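
To make the proposal post-processing annotated above easier to follow, here is a standalone sketch on toy data: validity masking near the borders, the inverse sigmoid, and the "inf" fills for padded or invalid positions. The shapes and the fake padding mask are illustrative only.

```python
# Toy sketch of the proposal post-processing annotated above (illustrative values).
import torch

output_proposals = torch.rand(2, 100, 4)            # [bs, sum of H*W, (cx, cy, w, h)] in (0, 1)
memory_padding_mask = torch.zeros(2, 100, dtype=torch.bool)
memory_padding_mask[:, 90:] = True                  # pretend the last 10 positions are padding

# valid proposals stay safely inside (0.01, 0.99) on every coordinate
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)

# inverse sigmoid, then mark padded / invalid positions with "inf"
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

print(torch.isinf(output_proposals).any(-1).sum())  # number of masked-out positions
```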

@@ -316,16 +325,22 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
"""
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
# TODO check this 0.5
# generate reference points
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
)

# normalize reference points
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)

# concat all the reference points together: [bs, all hw, 2]
reference_points = torch.cat(reference_points_list, 1)
# [bs, all h*w, 4, 2]
# each reference point gets an initial sampling location on every feature level
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
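
For the reference-point grid above, a single-level, no-padding sketch (valid_ratios effectively 1) shows where the 0.5 comes from: points are placed at pixel centers and then normalized. Sizes are toy values, not the function's actual inputs.

```python
# Single-level sketch of the reference-point grid (no padding, toy sizes).
import torch

H, W = 4, 6
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H - 0.5, H),  # pixel centers, hence the 0.5 offset
    torch.linspace(0.5, W - 0.5, W),
)
ref = torch.stack((ref_x.reshape(-1) / W, ref_y.reshape(-1) / H), -1)  # normalized (x, y)
print(ref.shape)  # torch.Size([24, 2]): one reference point per spatial location
```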

