
Commit

unet_gan
Ksuriuri committed Jul 21, 2022
0 parents commit d3d9c68
Showing 3 changed files with 556 additions and 0 deletions.
Binary file added unet_gan_fitS0/__pycache__/ops.cpython-36.pyc
190 changes: 190 additions & 0 deletions unet_gan_fitS0/ops.py
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
import torch.nn.functional as F
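# NOTE: `Attention` (a cross-attention module called as attention(query, key_value)) and
# `DropPath` (stochastic depth) are used below but are neither defined nor imported in this file;
# they are assumed to be provided elsewhere in the repository (DropPath also ships with timm).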


class DownsampleLayer(nn.Module):
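    """
    U-Net-style encoder block: two 3x3 Conv-BN-ReLU layers followed by a stride-2 conv.
    forward() returns both the full-resolution features (kept for the skip connection)
    and the downsampled features for the next stage.
    """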
def __init__(self, in_ch, out_ch):
super(DownsampleLayer, self).__init__()
self.Conv_BN_ReLU_2 = nn.Sequential(
nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
        self.downsample = nn.Sequential(
nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)

def forward(self, x):
out = self.Conv_BN_ReLU_2(x)
out_2 = self.downsample(out)
return out, out_2


class UpSampleLayer(nn.Module):
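    """
    U-Net-style decoder block: two 3x3 Conv-BN-ReLU layers (at 2*out_ch channels) followed by a
    stride-2 transposed conv; the upsampled result is concatenated with the encoder features `out`.
    """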
def __init__(self, in_ch, out_ch):
super(UpSampleLayer, self).__init__()
self.Conv_BN_ReLU_2 = nn.Sequential(
nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU(),
nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU()
)
        self.upsample = nn.Sequential(
nn.ConvTranspose2d(in_channels=out_ch*2, out_channels=out_ch, kernel_size=3, stride=2,
padding=1, output_padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)

def forward(self, x, out):
x_out = self.Conv_BN_ReLU_2(x)
x_out = self.upsample(x_out)
cat_out = torch.cat((x_out, out), dim=1)
return cat_out


class TransformerEncoderLayer(Module):
"""
Inspired by torch.nn.TransformerEncoderLayer and timm.
"""

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1, out_dim=None):
super(TransformerEncoderLayer, self).__init__()
if out_dim is None:
out_dim = d_model
self.q_norm = LayerNorm(d_model)
self.kv_norm = LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)

self.linear1 = Linear(d_model, dim_feedforward)
self.dropout1 = Dropout(dropout)
self.norm1 = LayerNorm(d_model)
self.linear2 = Linear(dim_feedforward, out_dim)
self.dropout2 = Dropout(dropout)

self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()

self.activation = F.gelu

def forward(self, src_q: torch.Tensor, src_kv, *args, **kwargs) -> torch.Tensor:
src = src_q + self.drop_path(self.self_attn(self.q_norm(src_q), self.kv_norm(src_kv)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
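
# Illustrative call (comment only; instantiating the layer requires the external Attention and
# DropPath modules):
#   layer = TransformerEncoderLayer(d_model=256, nhead=4)
#   out = layer(src_q, src_kv)   # src_q: (B, Nq, 256), src_kv: (B, Nkv, 256) -> out: (B, Nq, 256)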


class TransformerClassifier(Module):
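    """
    Stack of TransformerEncoderLayer blocks with an optional learnable or sinusoidal positional
    embedding, attention-based sequence pooling, and a linear classifier head. forward() returns
    both the classifier output and the normalized token sequence.
    """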
def __init__(self,
seq_pool=True,
embedding_dim=768,
num_layers=12,
num_heads=12,
mlp_ratio=4.0,
num_classes=1000,
dropout=0.1,
attention_dropout=0.1,
stochastic_depth=0.1,
positional_embedding='learnable',
sequence_length=None):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
self.num_tokens = 0

        assert sequence_length is not None or positional_embedding == 'none', \
            f"Positional embedding is set to {positional_embedding} but" \
            f" the sequence length was not specified."

# print("seq ", seq_pool, 'emb_dim ', embedding_dim, 'pos_emb ', positional_embedding)

if not seq_pool:
sequence_length += 1
self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
self.num_tokens = 1
else:
self.attention_pool = Linear(self.embedding_dim, 1)

if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
self.positional_emb = None

self.dropout = Dropout(p=dropout)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
self.blocks = ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
self.norm = LayerNorm(embedding_dim)

self.fc = Linear(embedding_dim, num_classes)
self.apply(self.init_weight)

def forward(self, x):
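        # x: (B, sequence_length, embedding_dim)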
# if self.positional_emb is None and x.size(1) < self.sequence_length:
# x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)

# if not self.seq_pool:
# cls_token = self.class_emb.expand(x.shape[0], -1, -1)
# x = torch.cat((cls_token, x), dim=1)

if self.positional_emb is not None:
x = x + self.positional_emb

x = self.dropout(x)

for blk in self.blocks:
x = blk(x)
x_seq = self.norm(x)

        # Attention-based sequence pooling: weights computed from the normalized tokens x_seq and
        # applied to x (assumes seq_pool=True, since attention_pool is only created in that branch).
        # The commented-out alternative would take the class token instead:
        # if not self.seq_pool: x = x_seq[:, 0]
        x = torch.matmul(F.softmax(self.attention_pool(x_seq), dim=1).transpose(-1, -2), x).squeeze(-2)

x = self.fc(x)
return x, x_seq

@staticmethod
def init_weight(m):
if isinstance(m, Linear):
init.trunc_normal_(m.weight, std=.02)
if isinstance(m, Linear) and m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, LayerNorm):
init.constant_(m.bias, 0)
init.constant_(m.weight, 1.0)

@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
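
An illustrative shape check (not part of this commit): the snippet below assumes it is run from the unet_gan_fitS0 directory next to ops.py, and the batch size, channel widths, and input resolution are arbitrary. TransformerClassifier is not instantiated here because it depends on the externally defined Attention and DropPath modules, but its static sinusoidal_embedding helper can be called directly.

import torch
from ops import DownsampleLayer, UpSampleLayer, TransformerClassifier

# Encoder block: returns the full-resolution features (kept for the skip connection)
# and the stride-2 downsampled features for the next stage.
down = DownsampleLayer(in_ch=1, out_ch=64)
x = torch.randn(2, 1, 64, 64)
skip, enc = down(x)            # skip: (2, 64, 64, 64), enc: (2, 64, 32, 32)

# Decoder block: expects 2*out_ch input channels, upsamples by 2 with a transposed conv,
# then concatenates with the matching encoder features.
up = UpSampleLayer(in_ch=128, out_ch=64)
bottleneck = torch.randn(2, 128, 32, 32)
dec = up(bottleneck, skip)     # dec: (2, 128, 64, 64)

# The fixed positional embedding is a plain static method and needs no model instance.
pe = TransformerClassifier.sinusoidal_embedding(n_channels=16, dim=768)   # pe: (1, 16, 768)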
