Commit d3d9c68
Showing 3 changed files with 556 additions and 0 deletions.
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
import torch.nn.functional as F

# Attention (a multi-head attention module that takes separate query and key/value
# inputs) and DropPath (stochastic depth) are used below but are not imported in this
# file; they are presumably provided by another file in this commit.

class DownsampleLayer(nn.Module):
    """U-Net encoder block: two 3x3 Conv-BN-ReLU layers followed by a stride-2 downsampling conv."""

    def __init__(self, in_ch, out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.Conv_BN_ReLU_2(x)      # full-resolution features, used as the skip connection
        out_2 = self.downsample(out)      # half-resolution features, passed to the next stage
        return out, out_2

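# A minimal shape-check sketch for DownsampleLayer, assuming a hypothetical 2x3x64x64
# input batch: the block returns both the full-resolution features (for the skip
# connection) and the stride-2 downsampled features.
def _downsample_layer_sketch():
    layer = DownsampleLayer(in_ch=3, out_ch=32)
    x = torch.randn(2, 3, 64, 64)
    skip, down = layer(x)
    assert skip.shape == (2, 32, 64, 64)   # spatial size preserved, out_ch channels
    assert down.shape == (2, 32, 32, 32)   # spatial size halved by the stride-2 conv
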
class UpSampleLayer(nn.Module):
    """U-Net decoder block: two 3x3 Conv-BN-ReLU layers, a stride-2 transposed conv, and
    concatenation with the corresponding encoder skip connection."""

    def __init__(self, in_ch, out_ch):
        super(UpSampleLayer, self).__init__()
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch*2),
            nn.ReLU()
        )
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(in_channels=out_ch*2, out_channels=out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x, out):
        x_out = self.Conv_BN_ReLU_2(x)            # expand to out_ch*2 channels
        x_out = self.upsample(x_out)              # double the spatial resolution
        cat_out = torch.cat((x_out, out), dim=1)  # concatenate with the encoder skip tensor
        return cat_out

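# A minimal shape-check sketch for UpSampleLayer, assuming a hypothetical 64-channel
# 16x16 decoder input and a 32-channel 32x32 encoder skip tensor: the transposed conv
# doubles the spatial size, and the output carries out_ch plus the skip channels.
def _upsample_layer_sketch():
    layer = UpSampleLayer(in_ch=64, out_ch=32)
    x = torch.randn(2, 64, 16, 16)      # decoder input
    skip = torch.randn(2, 32, 32, 32)   # encoder skip connection
    merged = layer(x, skip)
    assert merged.shape == (2, 64, 32, 32)   # 32 (upsampled) + 32 (skip) channels
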
class TransformerEncoderLayer(Module):
    """
    Inspired by torch.nn.TransformerEncoderLayer and timm.
    Pre-norm attention (self- or cross-attention) with stochastic depth, followed by an MLP block.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 attention_dropout=0.1, drop_path_rate=0.1, out_dim=None):
        super(TransformerEncoderLayer, self).__init__()
        if out_dim is None:
            out_dim = d_model
        self.q_norm = LayerNorm(d_model)
        self.kv_norm = LayerNorm(d_model)
        self.self_attn = Attention(dim=d_model, num_heads=nhead,
                                   attention_dropout=attention_dropout, projection_dropout=dropout)

        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout1 = Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        self.linear2 = Linear(dim_feedforward, out_dim)
        self.dropout2 = Dropout(dropout)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()

        self.activation = F.gelu

    def forward(self, src_q: torch.Tensor, src_kv=None, *args, **kwargs) -> torch.Tensor:
        if src_kv is None:
            # Called with a single tensor (as in TransformerClassifier below): plain self-attention.
            src_kv = src_q
        src = src_q + self.drop_path(self.self_attn(self.q_norm(src_q), self.kv_norm(src_kv)))
        src = self.norm1(src)
        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
        src = src + self.drop_path(self.dropout2(src2))
        return src

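# A minimal usage sketch for TransformerEncoderLayer, assuming Attention and DropPath
# are available from the rest of the repository and that tokens are (batch, sequence,
# embedding) tensors. With one argument the layer runs self-attention; with two, it
# attends from src_q to src_kv, so the output keeps the query sequence length.
def _transformer_encoder_layer_sketch():
    layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512)
    q = torch.randn(2, 100, 256)
    kv = torch.randn(2, 49, 256)
    assert layer(q).shape == (2, 100, 256)       # self-attention
    assert layer(q, kv).shape == (2, 100, 256)   # cross-attention to a different sequence
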
class TransformerClassifier(Module):
    def __init__(self,
                 seq_pool=True,
                 embedding_dim=768,
                 num_layers=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 num_classes=1000,
                 dropout=0.1,
                 attention_dropout=0.1,
                 stochastic_depth=0.1,
                 positional_embedding='learnable',
                 sequence_length=None):
        super().__init__()
        positional_embedding = positional_embedding if \
            positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
        dim_feedforward = int(embedding_dim * mlp_ratio)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.seq_pool = seq_pool
        self.num_tokens = 0

        assert sequence_length is not None or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} and" \
            f" the sequence length was not specified."

        if not seq_pool:
            # Prepend a learnable class token instead of using sequence pooling.
            sequence_length += 1
            self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
                                       requires_grad=True)
            self.num_tokens = 1
        else:
            # Sequence pooling: a learned attention weighting over the output tokens.
            self.attention_pool = Linear(self.embedding_dim, 1)

        if positional_embedding != 'none':
            if positional_embedding == 'learnable':
                self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
                                                requires_grad=True)
                init.trunc_normal_(self.positional_emb, std=0.2)
            else:
                self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
                                                requires_grad=False)
        else:
            self.positional_emb = None

        self.dropout = Dropout(p=dropout)
        # Linearly increasing stochastic-depth rates across the transformer blocks.
        dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
        self.blocks = ModuleList([
            TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                    dim_feedforward=dim_feedforward, dropout=dropout,
                                    attention_dropout=attention_dropout, drop_path_rate=dpr[i])
            for i in range(num_layers)])
        self.norm = LayerNorm(embedding_dim)

        self.fc = Linear(embedding_dim, num_classes)
        self.apply(self.init_weight)

    def forward(self, x):
        # if self.positional_emb is None and x.size(1) < self.sequence_length:
        #     x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)

        if not self.seq_pool:
            # Prepend the class token when sequence pooling is disabled.
            cls_token = self.class_emb.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        if self.positional_emb is not None:
            x = x + self.positional_emb

        x = self.dropout(x)

        for blk in self.blocks:
            x = blk(x)
        x_seq = self.norm(x)

        if self.seq_pool:
            # Sequence pooling: weight the tokens by a learned attention score.
            x = torch.matmul(F.softmax(self.attention_pool(x_seq), dim=1).transpose(-1, -2), x).squeeze(-2)
        else:
            x = x_seq[:, 0]

        x = self.fc(x)
        return x, x_seq

    @staticmethod
    def init_weight(m):
        if isinstance(m, Linear):
            init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, Linear) and m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            init.constant_(m.bias, 0)
            init.constant_(m.weight, 1.0)

    @staticmethod
    def sinusoidal_embedding(n_channels, dim):
        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                                for p in range(n_channels)])
        pe[:, 0::2] = torch.sin(pe[:, 0::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        return pe.unsqueeze(0)
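
# A minimal end-to-end sketch for TransformerClassifier, assuming hypothetical sizes
# (49 pre-computed 256-dimensional tokens per image) and that Attention is available
# from the rest of the repository. With the default seq_pool=True, the model returns
# class logits together with the normalized token sequence.
def _transformer_classifier_sketch():
    model = TransformerClassifier(embedding_dim=256, num_layers=2, num_heads=4,
                                  mlp_ratio=2.0, num_classes=10, sequence_length=49)
    tokens = torch.randn(2, 49, 256)
    logits, token_features = model(tokens)
    assert logits.shape == (2, 10)
    assert token_features.shape == (2, 49, 256)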