From 24196a3e8a21b90d939f95bcf97d15aa65aa8012 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 20 Nov 2024 10:59:22 -0800
Subject: [PATCH] allow for qk norm to be turned off for na vit nested tensor

---
 setup.py                               |  2 +-
 vit_pytorch/na_vit_nested_tensor.py    | 13 +++++++------
 vit_pytorch/na_vit_nested_tensor_3d.py | 13 +++++++------
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index c511aee..f6da4c3 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.7',
+  version = '1.8.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/na_vit_nested_tensor.py b/vit_pytorch/na_vit_nested_tensor.py
index 04882c8..48b5988 100644
--- a/vit_pytorch/na_vit_nested_tensor.py
+++ b/vit_pytorch/na_vit_nested_tensor.py
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )
 
 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)
 
@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
 
-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
 
         self.dropout = dropout
 
@@ -111,13 +111,13 @@ def transpose_head_seq(t):
         return self.to_out(out)
 
 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
@@ -146,6 +146,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -184,7 +185,7 @@ def __init__(
 
         self.dropout = nn.Dropout(emb_dropout)
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
 
         # final attention pooling queries
 
diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py
index 1f6ab59..e160bc7 100644
--- a/vit_pytorch/na_vit_nested_tensor_3d.py
+++ b/vit_pytorch/na_vit_nested_tensor_3d.py
@@ -41,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
     )
 
 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
         super().__init__()
         self.norm = nn.LayerNorm(dim, bias = False)
 
@@ -56,8 +56,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         # in the paper, they employ qk rmsnorm, a way to stabilize attention
         # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
 
-        self.query_norm = nn.LayerNorm(dim_head, bias = False)
-        self.key_norm = nn.LayerNorm(dim_head, bias = False)
+        self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
+        self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
 
         self.dropout = dropout
 
@@ -123,13 +123,13 @@ def transpose_head_seq(t):
         return self.to_out(out)
 
 class Transformer(Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
         super().__init__()
         self.layers = ModuleList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
@@ -161,6 +161,7 @@ def __init__(
         dropout = 0.,
         emb_dropout = 0.,
         num_registers = 4,
+        qk_rmsnorm = True,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -209,7 +210,7 @@ def __init__(
 
         self.dropout = nn.Dropout(emb_dropout)
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
 
         # final attention pooling queries
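
A minimal usage sketch of the new flag, assuming the NaViT class and constructor keywords exposed by vit_pytorch/na_vit_nested_tensor.py in this repo; only qk_rmsnorm comes from this patch, the remaining hyperparameter values are illustrative:

import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

model = NaViT(
    image_size = 256,       # illustrative hyperparameters, not part of this patch
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    qk_rmsnorm = False      # new in this patch: disables the per-head query/key LayerNorm
)

# variable-resolution images are passed as a list of (c, h, w) tensors
images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128)
]

preds = model(images)       # (2, num_classes) logits

With qk_rmsnorm = False, the Attention modules swap their query/key LayerNorm for nn.Identity(), recovering the unnormalized attention used before this option existed; the default of True keeps the previous behavior.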