
Commit

doc
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Jan 15, 2024
1 parent 6d61b22 commit 6fb51f5
Showing 2 changed files with 38 additions and 5 deletions.
6 changes: 6 additions & 0 deletions docs/source/networks.rst
@@ -248,6 +248,12 @@ Blocks
.. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
:members:

`Attention utilities`
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.attention_utils
.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos

N-Dim Fourier Transform
~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.fft_utils_t
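For orientation (this sketch is not part of the commit), a minimal example of calling `get_rel_pos` as documented below; the sizes are illustrative, and the table length of 2 * max(q_size, k_size) - 1 rows as well as the (q_size, k_size, head_dim) output shape are assumptions taken from the mvitv2/SAM reference implementation linked in the docstring, not stated explicitly in this diff.

import torch

from monai.networks.blocks.attention_utils import get_rel_pos

head_dim = 32
q_size, k_size = 14, 14
# One learned embedding per possible relative offset between a query and a key
# position; 2 * max(q_size, k_size) - 1 rows (random weights, for illustration only).
rel_pos = torch.randn(2 * max(q_size, k_size) - 1, head_dim)

rel = get_rel_pos(q_size, k_size, rel_pos)
# rel[i, j] holds the embedding for the offset between query position i and key
# position j; expected shape: (q_size, k_size, head_dim).
print(rel.shape)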
37 changes: 32 additions & 5 deletions monai/networks/blocks/attention_utils.py
@@ -19,7 +19,8 @@
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
@@ -51,10 +52,36 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
def add_decomposed_rel_pos(
attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
r"""
Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Only 2D and 3D are supported.
Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
`d` apart will have the same embedding value (unlike absolute positional embedding).
.. math::
Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale
where
.. math::
E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}
with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
respectively spatial positions of element :math:`i` and :math:`j`
When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:
.. math::
R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
with :math:`n = 1...dim`
Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
:math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
@@ -63,7 +90,7 @@ def add_decomposed_rel_pos(
k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
attn (Tensor): attention logits with added relative positional embeddings.
"""
rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
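To make the documented shapes concrete (again, not part of the commit), a minimal 2D sketch of calling `add_decomposed_rel_pos`; the batch size, channel count and spatial sizes are illustrative, and the per-axis table length of 2 * max - 1 rows is an assumption carried over from the referenced mvitv2 implementation. It also illustrates the complexity claim in the docstring: with an 8 x 8 query and key, a full 2D table would index 15 * 15 relative offsets, while the decomposed form only needs 15 + 15 entries (one 1D table per axis).

import torch
from torch import nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

B, C = 1, 32                      # batch (times heads) and per-head channels
q_h, q_w = 8, 8                   # spatial size of the query
k_h, k_w = 8, 8                   # spatial size of the key

attn = torch.zeros(B, q_h * q_w, k_h * k_w)   # raw attention logits QK^T
q = torch.randn(B, q_h * q_w, C)              # flattened query tokens

# One relative-position table per spatial axis (the "decomposed" part):
# R^{d1} for height, R^{d2} for width, each covering every possible 1D offset.
rel_pos_lst = nn.ParameterList(
    [
        nn.Parameter(torch.randn(2 * max(q_h, k_h) - 1, C)),
        nn.Parameter(torch.randn(2 * max(q_w, k_w) - 1, C)),
    ]
)

attn = add_decomposed_rel_pos(attn, q, rel_pos_lst, (q_h, q_w), (k_h, k_w))
print(attn.shape)  # unchanged: (B, q_h * q_w, k_h * k_w)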
