diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index f9375f1e97..556bf12d50 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -248,6 +248,12 @@ Blocks
 .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
     :members:
 
+`Attention utilities`
+~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: monai.networks.blocks.attention_utils
+.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
+.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos
+
 N-Dim Fourier Transform
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: monai.networks.blocks.fft_utils_t
diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py
index a9fd6c89ae..8c9002a16e 100644
--- a/monai/networks/blocks/attention_utils.py
+++ b/monai/networks/blocks/attention_utils.py
@@ -19,7 +19,8 @@
 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
     Get relative positional embeddings according to the relative positions of
-    query and key sizes.
+    query and key sizes.
+
     Args:
         q_size (int): size of query q.
         k_size (int): size of key k.
@@ -51,10 +52,36 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
 def add_decomposed_rel_pos(
     attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
 ) -> torch.Tensor:
-    """
-    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    r"""
+    Calculate decomposed Relative Positional Embeddings, following the MViTv2 implementation:
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+    Only 2D and 3D are supported.
+
+    Relative positional embeddings encode the relative position of tokens in the attention matrix:
+    tokens spaced a distance `d` apart have the same embedding value (unlike absolute positional embeddings).
+
+    .. math::
+        Attn_{logits}(Q, K) = (QK^{T} + E_{rel}) \cdot scale
+
+    where
+
+    .. math::
+        E_{ij}^{(rel)} = Q_{i} \cdot R_{p(i), p(j)}
+
+    with :math:`R_{p(i), p(j)} \in \mathbb{R}^{dim}` and :math:`p(i), p(j)` the spatial positions of
+    elements :math:`i` and :math:`j`, respectively.
+
+    With "decomposed" relative positional embeddings, the embedding is split along the spatial axes:
+
+    .. math::
+        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+    with one embedding table :math:`R^{dk}` per spatial axis, :math:`k = 1, ..., n`.
+
+    This decomposition reduces the complexity from :math:`\mathcal{O}(d1 \cdot ... \cdot dn)` to
+    :math:`\mathcal{O}(d1 + ... + dn)` compared with classical relative positional embeddings.
 
     Args:
         attn (Tensor): attention map.
         q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
@@ -63,7 +90,7 @@ def add_decomposed_rel_pos(
         k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
 
     Returns:
-        attn (Tensor): attention map with added relative positional embeddings.
+        attn (Tensor): attention logits with added relative positional embeddings.
     """
     rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
     rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
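
To make the decomposed embedding concrete, here is a minimal 2D usage sketch against the function added in this diff. The sizes (`B`, `C`, `q_size`, `k_size`) are hypothetical, the tables are zero-initialized purely for a shape check, and the per-axis table length `2 * max(q_dim, k_dim) - 1` assumes the layout `get_rel_pos` expects in the MViTv2 implementation this code follows.

```python
import torch
import torch.nn as nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

B, C = 1, 32                 # hypothetical batch size and per-head channel dim
q_size = k_size = (8, 8)     # 2D spatial extents of query and key tokens
n_q = q_size[0] * q_size[1]  # 64 query tokens
n_k = k_size[0] * k_size[1]  # 64 key tokens

q = torch.randn(B, n_q, C)   # queries flattened to (B, s_dim_1 * s_dim_2, C)
k = torch.randn(B, n_k, C)
attn = q @ k.transpose(-2, -1)  # raw QK^T logits, shape (B, n_q, n_k)

# One learnable table per spatial axis; 2 * max(q_dim, k_dim) - 1 rows cover
# every possible relative offset along that axis. Zero-initialized here, so
# this call leaves the logits unchanged and only exercises the shapes.
rel_pos_lst = nn.ParameterList(
    [nn.Parameter(torch.zeros(2 * max(qd, kd) - 1, C)) for qd, kd in zip(q_size, k_size)]
)

attn = add_decomposed_rel_pos(attn, q, rel_pos_lst, q_size, k_size)
print(attn.shape)  # torch.Size([1, 64, 64])
```

In a real attention block, `rel_pos_lst` would be registered on the module and trained jointly with the projection weights; the flattened token ordering of `q` and `attn` just has to match `q_size`/`k_size`.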