-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
76 lines (58 loc) · 2.47 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import math
from typing import Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
from torchvision import datasets, transforms
class Sampling(nn.Module):
    """Draw a reparameterized sample from N(mean, exp(logvar)).

    The distribution parameters are captured when the module is built; each
    call to ``forward`` draws fresh standard-normal noise shaped like
    ``logvars`` and returns ``noise * exp(logvar / 2) + mean``.
    """

    def __init__(self, means: torch.Tensor, logvars: torch.Tensor):
        super().__init__()
        # Kept under singular attribute names; forward() reads them back.
        self.mean = means
        self.logvar = logvars

    def forward(self) -> torch.Tensor:
        noise = torch.randn_like(self.logvar)
        # exp(logvar / 2) is the standard deviation sqrt(exp(logvar)).
        sigma = torch.exp(self.logvar / 2)
        return noise * sigma + self.mean
class Encoder(nn.Module):
    """Dense VAE encoder producing Gaussian code parameters and a sampled code.

    Args:
        codings_size: Dimensionality of the latent code.
        inp_shape: Per-sample input shape (without the batch dimension). Any
            number of dimensions is accepted; the first linear layer takes the
            product of all of them as its input width.
    """

    def __init__(self, codings_size: int = 20, inp_shape: tuple = (28, 28)) -> None:
        super().__init__()
        self.codings_size = codings_size
        self.inp_shape = inp_shape
        # nn.Flatten() flattens every dimension after the first (start_dim=1).
        self.flatten = nn.Flatten()
        # Use the product of *all* dims so shapes such as (1, 28, 28) give 784
        # features; multiplying only the first two dims would give 28 here.
        self.linear_1 = nn.Linear(in_features=math.prod(self.inp_shape), out_features=150)
        self.linear_2 = nn.Linear(in_features=150, out_features=100)
        self.coding_mean = nn.Linear(in_features=100, out_features=self.codings_size)
        self.coding_logvar = nn.Linear(in_features=100, out_features=self.codings_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` and sample a latent code.

        ``x`` is presumably shaped (batch, *inp_shape); confirm against callers.

        Returns:
            ``(coding_mean, coding_logvar, codings)``, where ``codings`` is drawn
            with ``Sampling`` as ``eps * exp(coding_logvar / 2) + coding_mean``.
        """
        x = self.flatten(x)
        x = F.selu(self.linear_1(x))
        x = F.selu(self.linear_2(x))
        coding_mean = self.coding_mean(x)
        coding_logvar = self.coding_logvar(x)
        codings = Sampling(means=coding_mean, logvars=coding_logvar)()
        return coding_mean, coding_logvar, codings
class Decoder(nn.Module):
    """Dense VAE decoder mapping a latent code to a flattened reconstruction.

    Args:
        codings_size: Dimensionality of the latent code fed to ``forward``.
        inp_shape: Per-sample shape of the data being reconstructed. Any number
            of dimensions is accepted; the output width is the product of all
            of them.

    The output is not reshaped to ``inp_shape``: ``forward`` returns a tensor
    whose last dimension has ``prod(inp_shape)`` features, passed through a
    sigmoid so every value lies in [0, 1].
    """

    def __init__(self, codings_size: int = 20, inp_shape: tuple = (28, 28)) -> None:
        super().__init__()
        self.codings_size = codings_size
        self.inp_shape = inp_shape
        self.linear_decoder_1 = nn.Linear(in_features=self.codings_size, out_features=100)
        self.linear_decoder_2 = nn.Linear(in_features=100, out_features=150)
        # Product of *all* dims so e.g. (1, 28, 28) reconstructs 784 features,
        # matching what nn.Flatten produces from such an input.
        self.linear_decoder_3 = nn.Linear(in_features=150, out_features=math.prod(self.inp_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``x`` (last dim ``codings_size``) into flat outputs."""
        x = F.selu(self.linear_decoder_1(x))
        x = F.selu(self.linear_decoder_2(x))
        x = torch.sigmoid(self.linear_decoder_3(x))
        return x
class VAE_Dense(nn.Module):
    """Variational autoencoder that chains an encoder and a decoder.

    Args:
        Encoder: Callable (typically an ``nn.Module``) returning
            ``(mean, logvar, codings)`` for an input batch.
        Decoder: Callable (typically an ``nn.Module``) mapping ``codings`` to a
            reconstruction.
    """

    def __init__(self, Encoder, Decoder) -> None:
        super().__init__()
        # Capitalized attribute names are part of the interface: when these are
        # nn.Modules they register as submodules "Encoder" / "Decoder".
        self.Encoder = Encoder
        self.Decoder = Decoder

    def forward(self, x):
        """Encode ``x``, decode the sampled code; return ``(reconstruction, mean, logvar)``."""
        mean_vec, logvar_vec, z = self.Encoder(x)
        return self.Decoder(z), mean_vec, logvar_vec