feat: 3D decomposed relative positional embeddings
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Jan 1, 2024
1 parent a46b374 commit 57fc23b
Showing 1 changed file with 53 additions and 39 deletions: monai/networks/blocks/selfattention.py
@@ -36,17 +36,18 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
-        use_rel_pos: bool = False,
-        input_size: Optional[Tuple[int, int]] = None,
+        use_rel_pos: Optional[str] = None,
+        input_size: Optional[Tuple] = None,
     ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
-            rel_pos (bool): If True, add relative positional embeddings to the attention map. Only support 2D inputs.
-            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+            use_rel_pos (str, optional): Add relative positional embeddings to the attention map.
+                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526); 2D and 3D inputs are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                 positional parameter size.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
@@ -74,12 +75,11 @@ def __init__(
         self.use_rel_pos = use_rel_pos
         self.input_size = input_size

-        if self.use_rel_pos:
+        if self.use_rel_pos == "decomposed":
             assert input_size is not None, "Input size must be provided if using relative positional encoding."
-            assert len(input_size) == 2, "Relative positional embedding is only supported for 2D"
             # initialize relative positional embeddings
-            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
-            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
+            self.rel_pos_arr = nn.ParameterList(
+                [nn.Parameter(torch.zeros(2 * dim_input_size - 1, self.head_dim)) for dim_input_size in input_size]
+            )

     def forward(self, x: torch.Tensor):
         """
@@ -93,18 +93,18 @@ def forward(self, x: torch.Tensor):
         q, k, v = output[0], output[1], output[2]
         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

-        if self.use_rel_pos:
+        if self.use_rel_pos == "decomposed":
             batch = x.shape[0]
-            h, w = self.input_size if self.input_size is not None else (0, 0)
+            h, w = self.input_size[:2] if self.input_size is not None else (0, 0)
+            d = self.input_size[2] if self.input_size is not None and len(self.input_size) > 2 else 1
             att_mat = add_decomposed_rel_pos(
-                att_mat.view(batch * self.num_heads, h * w, h * w),
-                q.view(batch * self.num_heads, h * w, -1),
-                self.rel_pos_h,
-                self.rel_pos_w,
-                (h, w),
-                (h, w),
+                att_mat.view(batch * self.num_heads, h * w * d, h * w * d),
+                q.reshape(batch * self.num_heads, h * w * d, -1),
+                self.rel_pos_arr,
+                (h, w) if d == 1 else (h, w, d),
+                (h, w) if d == 1 else (h, w, d),
             )
-            att_mat = att_mat.reshape(batch, self.num_heads, h * w, h * w)
+            att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)

         att_mat = att_mat.softmax(dim=-1)

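A minimal sketch of a forward pass over a 3D token grid, under the same assumptions (class name SABlock, illustrative sizes); the sequence length must equal the product of input_size:

import torch

from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=32, num_heads=4, use_rel_pos="decomposed", input_size=(4, 4, 4))
x = torch.randn(2, 4 * 4 * 4, 32)  # (batch, H * W * D, hidden_size)
out = block(x)
print(out.shape)  # expected: torch.Size([2, 64, 32])
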
@@ -154,39 +154,53 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor


 def add_decomposed_rel_pos(
-    attn: torch.Tensor,
-    q: torch.Tensor,
-    rel_pos_h: torch.Tensor,
-    rel_pos_w: torch.Tensor,
-    q_size: Tuple[int, int],
-    k_size: Tuple[int, int],
+    attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
 ) -> torch.Tensor:
     """
     Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
     https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    Only 2D and 3D are supported.
     Args:
         attn (Tensor): attention map.
-        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
+        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for the nth axis.
+        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
+        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
     Returns:
         attn (Tensor): attention map with added relative positional embeddings.
     """
-    q_h, q_w = q_size
-    k_h, k_w = k_size
-    rh = get_rel_pos(q_h, k_h, rel_pos_h)
-    rw = get_rel_pos(q_w, k_w, rel_pos_w)
+    rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
+    rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])

     batch, _, dim = q.shape
-    r_q = q.reshape(batch, q_h, q_w, dim)
-    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
-    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
-
-    attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
-        batch, q_h * q_w, k_h * k_w
-    )
+    if len(rel_pos_lst) == 2:
+        q_h, q_w = q_size[:2]
+        k_h, k_w = k_size[:2]
+        r_q = q.reshape(batch, q_h, q_w, dim)
+        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+        attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+            batch, q_h * q_w, k_h * k_w
+        )
+    elif len(rel_pos_lst) == 3:
+        q_h, q_w, q_d = q_size[:3]
+        k_h, k_w, k_d = k_size[:3]
+
+        rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])
+
+        r_q = q.reshape(batch, q_h, q_w, q_d, dim)
+        rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
+        rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
+        rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd)
+
+        attn = (
+            attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+            + rel_h[:, :, :, :, :, None, None]
+            + rel_w[:, :, :, :, None, :, None]
+            + rel_d[:, :, :, :, None, None, :]
+        ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

     return attn
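
A standalone sketch of the helper for the 3D case, assuming add_decomposed_rel_pos is importable from monai.networks.blocks.selfattention as in this diff; shapes are illustrative:

import torch
from torch import nn

from monai.networks.blocks.selfattention import add_decomposed_rel_pos

batch_heads, head_dim = 8, 16
h, w, d = 4, 4, 4
seq = h * w * d
attn = torch.zeros(batch_heads, seq, seq)
q = torch.randn(batch_heads, seq, head_dim)
# One (2 * dim - 1, head_dim) table per axis, matching the initialization in __init__ above.
rel_pos_lst = nn.ParameterList([nn.Parameter(torch.zeros(2 * s - 1, head_dim)) for s in (h, w, d)])
out = add_decomposed_rel_pos(attn, q, rel_pos_lst, (h, w, d), (h, w, d))
print(out.shape)  # expected: torch.Size([8, 64, 64])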
