From c1454134cd1cef407ba67dcc438340516aaf03d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 23 Nov 2020 13:07:25 +0100 Subject: [PATCH 01/55] remove splitted_L1 --- radionets/dl_framework/learner.py | 3 --- radionets/dl_framework/loss_functions.py | 22 +--------------------- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index b5c1285c..3d20c4fb 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -15,7 +15,6 @@ likelihood, likelihood_phase, spe, - splitted_L1, ) from radionets.dl_framework.callbacks import ( normalize_tfm, @@ -120,8 +119,6 @@ def define_learner( loss_func = loss_msssim_amp elif loss_func == "spe": loss_func = spe - elif loss_func == "splitted_L1": - loss_func = splitted_L1 else: print("\n No matching loss function or architecture! Exiting. \n") sys.exit(1) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index f96f6b42..558c1561 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -452,24 +452,4 @@ def spe_(x, y): value = 0 k = sum(loss) loss = k / len(x) - return loss - -def splitted_L1(x, y): - inp_real = x[:, 0, :] - inp_imag = x[:, 1, :] - - tar_real = y[:, 0, :] - tar_imag = y[:, 1, :] - - loss_real = ( - torch.sum(1 / inp_real.shape[1] * torch.sum(torch.abs(inp_real - tar_real), 1)) - * 1 - / inp_real.shape[0] - ) - loss_imag = ( - torch.sum(1 / inp_imag.shape[1] * torch.sum(torch.abs(inp_imag - tar_imag), 1)) - * 1 - / inp_real.shape[0] - ) - - return loss_real + loss_imag \ No newline at end of file + return loss \ No newline at end of file From 4dfcdccca0e406d017f97d0ec8b83717a697f8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 23 Nov 2020 13:07:59 +0100 Subject: [PATCH 02/55] add psnr and ssim --- radionets/dl_framework/inspection.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/radionets/dl_framework/inspection.py b/radionets/dl_framework/inspection.py index fbe9fc66..1193d0a6 100644 --- a/radionets/dl_framework/inspection.py +++ b/radionets/dl_framework/inspection.py @@ -128,8 +128,8 @@ def fft_pred(pred, truth, amp_phase=True): a = pred[:, 0, :, :] b = pred[:, 1, :, :] - a_true = truth[0, :, :] - b_true = truth[1, :, :] + a_true = truth[:, 0, :, :] + b_true = truth[:, 1, :, :] if amp_phase: amp_pred_rescaled = (10 ** (10 * a) - 1) / 10 ** 10 @@ -147,7 +147,7 @@ def fft_pred(pred, truth, amp_phase=True): ifft_pred = np.fft.ifft2(compl_pred) ifft_true = np.fft.ifft2(compl_true) - return np.absolute(ifft_pred)[0], np.absolute(ifft_true) + return np.absolute(ifft_pred), np.absolute(ifft_true) def reshape_2d(array): @@ -339,3 +339,22 @@ def make_axes_nice(fig, ax, im, title): cbar.set_label("Intensity / a.u.") # cbar.formatter.set_powerlimits((0, 0)) # cbar.update_ticks() + +def psnr(pred, truth): + L = [np.amax(t)**2 for t in truth] + psnr = [10*np.log10(L[i]/(np.mean((truth[i]-pred[i])**2))) for i in range(len(pred))] + return np.mean(psnr) + +def ssim(pred,true): + mean_pred = [np.mean(p) for p in pred] + std_pred = [np.std(p) for p in pred] + mean_true = [np.mean(t) for t in true] + std_true = [np.std(t) for t in true] + cov = [1/(len(pred[i])**2-1)*np.sum((pred[i]-mean_pred[i])*(true[i]-mean_true[i])) for i in range(len(pred))] + c1 = [(0.01*np.amax(t))**2 for t in true] + c2 = [(0.03*np.amax(t))**2 for t in true] + c3 = [c/2 for c in c2] + l = [(2*mean_pred[i]*mean_true[i]+c1[i])/(mean_pred[i]**2+mean_true[i]**2+c1[i]) for i in range(len(pred))] + c = [(2*std_pred[i]*std_true[i]+c2[i])/(std_pred[i]**2+std_true[i]**2+c2[i]) for i in range(len(pred))] + s = [(cov[i]+c3[i])/(std_pred[i]*std_true[i]+c3[i]) for i in range(len(pred))] + return np.mean([l[i]*c[i]*s[i] for i in range(len(l))]) \ No newline at end of file From c344996bbdd4d78f28dec21fbd32d576c0600da5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 23 Nov 2020 13:09:57 +0100 Subject: [PATCH 03/55] add EDSR and RDNet --- .../dl_framework/architectures/superRes.py | 92 +++++++++++++++++++ radionets/dl_framework/model.py | 60 ++++++++++++ radionets/dl_training/utils.py | 2 + 3 files changed, 154 insertions(+) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 92d93696..26353b4f 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -7,6 +7,8 @@ ResBlock_amp, ResBlock_phase, SRBlock, + EDSRBaseBlock, + RDB, ) @@ -403,3 +405,93 @@ def forward(self, x): x = self.final(x) return x + +class EDSRBase(nn.Module): + def __init__(self, img_size): + super().__init__() + # torch.cuda.set_device(1) + self.img_size = img_size + + self.preBlock = nn.Sequential( + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2) + ) + + # ResBlock 16 + self.blocks = nn.Sequential( + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + EDSRBaseBlock(64, 64), + ) + + self.postBlock = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1) + ) + + self.final = nn.Sequential( + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2) + ) + + def forward(self, x): + x = self.preBlock(x) + + x = x + self.postBlock(self.blocks(x)) + + x = self.final(x) + + return x + + +class RDNet(nn.Module): + def __init__(self, img_size): + super().__init__() + # torch.cuda.set_device(1) + self.img_size = img_size + + self.preBlock = nn.Sequential( + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2, bias=False) + ) + + # ResBlock 6 + self.block1 = RDB(64, 32) + self.block2 = RDB(64, 32) + self.block3 = RDB(64, 32) + self.block4 = RDB(64, 32) + self.block5 = RDB(64, 32) + self.block6 = RDB(64, 32) + + + self.postBlock = nn.Sequential( + nn.Conv2d(6*64, 64, 1, stride=1, padding=0, bias=False), + nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False) + ) + + self.final = nn.Sequential( + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2, bias=False) + ) + + def forward(self, x): + x = self.preBlock(x) + + x1 = self.block1(x) + x2 = self.block2(x1) + x3 = self.block3(x2) + x4 = self.block4(x3) + x5 = self.block5(x4) + x6 = self.block6(x5) + + x = x + self.postBlock(torch.cat((x1,x2,x3,x4,x5,x6), dim=1)) + x = self.final(x) + return x \ No newline at end of file diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 8f26aa12..24ed20ab 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -664,3 +664,63 @@ def _conv_block(self, ni, nf, stride): nn.Conv2d(nf, nf, 3, stride=1, padding=1), nn.BatchNorm2d(nf) ) + +class EDSRBaseBlock(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni,nf,stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1) + self.pool = nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)#nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1), + nn.PReLU(), + nn.Conv2d(nf, nf, 3, stride=1, padding=1) + ) + +class RDB(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.conv1 = self._conv_block(ni,nf,stride) + self.conv2 = self._conv_block(ni+nf,nf,stride) + self.conv3 = self._conv_block(ni+2*nf,nf,stride) + self.conv4 = self._conv_block(ni+3*nf,nf,stride) + self.conv5 = self._conv_block(ni+4*nf,nf,stride) + self.conv6 = self._conv_block(ni+5*nf,nf,stride) + + self.convFusion = nn.Conv2d(ni+6*nf, ni, 1, stride=1, padding=0, groups=2, bias=False) + + def forward(self, x): + x1_c = self.conv1(x) + cat = self._cat_split(x, x1_c) + x2_c = self.conv2(cat) + cat = self._cat_split(cat, x2_c) + x3_c = self.conv3(cat) + cat = self._cat_split(cat, x3_c) + x4_c = self.conv4(cat) + cat = self._cat_split(cat, x4_c) + x5_c = self.conv5(cat) + cat = self._cat_split(cat, x5_c) + x6_c = self.conv6(cat) + cat = self._cat_split(cat, x6_c) + + + return self.convFusion(cat) + x + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False), + nn.PReLU() + ) + + def _cat_split(self, x, y): + x1, x2 = torch.chunk(x,2, dim=1) + y1, y2 = torch.chunk(y,2, dim=1) + return torch.cat((x1,y1,x2,y2), dim=1) + + + diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 7f3735ff..0284bbc1 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -75,6 +75,8 @@ def define_arch(arch_name, img_size): or arch_name == "superRes_res34" or arch_name == "SRResNet" or arch_name == "SRResNet_corr" + or arch_name == "EDSRBase" + or arch_name == "RDNet" ): arch = getattr(architecture, arch_name)(img_size) else: From 4e96c469cf4cb50f1e8a48f5f1847970898a144d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 10 Dec 2020 13:04:41 +0100 Subject: [PATCH 04/55] add Feedback Arch and SRResNet_sym with transformation of input --- .../dl_framework/architectures/superRes.py | 109 +++++++++++++++- radionets/dl_framework/model.py | 121 ++++++++++++++++++ 2 files changed, 227 insertions(+), 3 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 26353b4f..57e25580 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -9,6 +9,11 @@ SRBlock, EDSRBaseBlock, RDB, + FBB, + Lambda, + better_symmetry, + tf_shift, + btf_shift ) @@ -362,7 +367,7 @@ def forward(self, x): class SRResNet_corr(nn.Module): def __init__(self, img_size): super().__init__() - torch.cuda.set_device(1) + # torch.cuda.set_device(1) self.img_size = img_size self.preBlock = nn.Sequential( @@ -397,6 +402,10 @@ def __init__(self, img_size): nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2), ) + #new symmetry + + self.symmetry = Lambda(better_symmetry) + def forward(self, x): x = self.preBlock(x) @@ -404,7 +413,59 @@ def forward(self, x): x = self.final(x) - return x + return self.symmetry(x) + +class SRResNet_sym(nn.Module): + def __init__(self, img_size): + super().__init__() + # torch.cuda.set_device(1) + self.tf = Lambda(tf_shift) + + self.preBlock = nn.Sequential( + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2), nn.PReLU() + ) + + # ResBlock 16 + self.blocks = nn.Sequential( + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + ) + + self.postBlock = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64) + ) + + self.final = nn.Sequential( + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2), + ) + + #new symmetry + + self.btf = Lambda(btf_shift) + + def forward(self, x): + x = self.tf(x) + x = self.preBlock(x) + + x = x + self.postBlock(self.blocks(x)) + + x = self.final(x) + + return self.btf(x) class EDSRBase(nn.Module): def __init__(self, img_size): @@ -457,7 +518,7 @@ def forward(self, x): class RDNet(nn.Module): def __init__(self, img_size): super().__init__() - # torch.cuda.set_device(1) + torch.cuda.set_device(1) self.img_size = img_size self.preBlock = nn.Sequential( @@ -494,4 +555,46 @@ def forward(self, x): x = x + self.postBlock(torch.cat((x1,x2,x3,x4,x5,x6), dim=1)) x = self.final(x) + return x + + +class SRFBNet(nn.Module): + def __init__(self, img_size): + super().__init__() + # torch.cuda.set_device(1) + self.img_size = img_size + + self.preBlock = nn.Sequential( + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2, bias=False) + ) + + # ResBlock 6 + # self.block1 = FBB(64, 32, first=True) + # self.block2 = FBB(64, 32) + self.block1 = FBB(64, 32, first=True) + + self.postBlock = nn.Sequential( + nn.Conv2d(64, 64, 1, stride=1, padding=0, bias=False), + nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False) + ) + + self.final = nn.Sequential( + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2, bias=False) + ) + + def forward(self, x): + x = self.preBlock(x) + + + + x1 = torch.zeros(x.shape).cuda() + for i in range(4): + x1 = self.block1(torch.cat((x,x1), dim=1)) + if i == 0: + block = x1 + else: + block = torch.cat((block,x1), dim=0) + + x = torch.cat((x,x,x,x), dim=0) + self.postBlock(block) + x = self.final(x) return x \ No newline at end of file diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 095a86a1..e2796344 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -5,6 +5,7 @@ from pathlib import Path from math import sqrt, pi from fastcore.foundation import L +from torch.nn.common_types import _size_4_t class Lambda(nn.Module): @@ -234,6 +235,39 @@ def symmetry(x, mode="real"): x[:, i, j] = -x[:, u, v] return torch.rot90(x, 3, dims=(1, 2)) +def better_symmetry(x): + # rotation + x = torch.flip(x, [3]) + + # indices of upper and lower triangle + triu = torch.triu_indices(x.shape[2], x.shape[3], 1) + tril = torch.tril_indices(x.shape[2], x.shape[3], -1) + triu = torch.flip(triu, [1]) + + # sym amp and phase + x[:,0,tril[0], tril[1]] = x[:,0, triu[0], triu[1]] + x[:,1,tril[0], tril[1]] = -x[:,1, triu[0], triu[1]] + + # rotation + x = torch.flip(x, [3]) + + return x + +def tf_shift(x): + triu = torch.triu_indices(x.shape[2], x.shape[2], 0) + tf = torch.flip(x, [3])[:,:,triu[0], triu[1]].reshape(x.shape[0],x.shape[1],x.shape[2],int(x.shape[3]/2)+1) + + return tf + +def btf_shift(x): + btf = torch.zeros((x.shape[0],x.shape[1],x.shape[2], x.shape[3]*2-1)).cuda() + triu = torch.triu_indices(x.shape[2], x.shape[2], 0) + + btf[:,:,triu[0], triu[1]] = x[:,:].reshape(x.shape[0], x.shape[1], -1) + btf = torch.flip(btf, [3]) + + btf = better_symmetry(btf) + return btf def phase_range(phase): # if isinstance(phase, float): @@ -728,4 +762,91 @@ def _cat_split(self, x, y): return torch.cat((x1,y1,x2,y2), dim=1) +class FBB(nn.Module): + def __init__(self, ni, nf, stride=1, first=False): + super().__init__() + self.first = first + if first: + self.convCat = nn.Conv2d(ni*2, ni, 1, stride=1, padding=0, groups=2, bias=False) + else: + self.convCat = nn.Identity() + self.conv1 = self._conv_block(ni,nf,stride) + self.conv2 = self._conv_block(ni+nf,nf,stride) + self.conv3 = self._conv_block(ni+2*nf,nf,stride) + self.conv4 = self._conv_block(ni+3*nf,nf,stride) + self.conv5 = self._conv_block(ni+4*nf,nf,stride) + self.conv6 = self._conv_block(ni+5*nf,nf,stride) + + self.convFusion = nn.Conv2d(ni+6*nf, ni, 1, stride=1, padding=0, groups=2, bias=False) + + def forward(self, x): + # if self.first: + # comb = torch.chunk(x,2, dim=1) + # skip = comb[0] + # x = self._cat_split(comb[0], comb[1]) + # # x = self._cat_split(x, comb[2]) + # # x = self._cat_split(x, comb[3]) + # else: + # skip = x + + x_cc = self.convCat(x) + x1_c = self.conv1(x_cc) + cat = self._cat_split(x_cc, x1_c) + x2_c = self.conv2(cat) + cat = self._cat_split(cat, x2_c) + x3_c = self.conv3(cat) + cat = self._cat_split(cat, x3_c) + x4_c = self.conv4(cat) + cat = self._cat_split(cat, x4_c) + x5_c = self.conv5(cat) + cat = self._cat_split(cat, x5_c) + x6_c = self.conv6(cat) + cat = self._cat_split(cat, x6_c) + + + return self.convFusion(cat) + x_cc + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False), + nn.PReLU() + ) + + def _cat_split(self, x, y): + x1, x2 = torch.chunk(x,2, dim=1) + y1, y2 = torch.chunk(y,2, dim=1) + return torch.cat((x1,y1,x2,y2), dim=1) + + +class _CirculationPadNd(nn.Module): + __constants__ = ['padding'] + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.pad(input, self.padding, 'circular') + + def extra_repr(self) -> str: + return '{}'.format(self.padding) + +class CirculationPad2d(_CirculationPadNd): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(CirculationPad2d, self).__init__() + self.padding = _pair(padding) + +class CirculationShiftPad(nn.Module): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(_CirculationShiftPad, self).__init__() + self.padding = _pair(padding) + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = F.pad(input, self.padding, 'circular') + x[...,0,:] = 0 + x[...,-1,:] = 0 + x[...,:,0] = torch.roll(x[...,:,0],1) + x[...,:,-1] = torch.roll(x[...,:,-1],-1) + x[...,0,:] = 0 + x[...,-1,:] = 0 + return x \ No newline at end of file From d4e4a06caed954e51ad21bfcfaa03aec4d1576f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 10 Dec 2020 13:05:49 +0100 Subject: [PATCH 05/55] add lr scheduler as comment --- radionets/dl_framework/learner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 464254db..576d448c 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -54,6 +54,13 @@ def define_learner( train_conf["lr_stop"], ) } + # lr_max = train_conf["lr_max"] + # div = 25. + # div_final = 1e5 + # pct_start = 0.25 + # moms = (0.95, 0.85) + # sched = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final), + # 'mom': combined_cos(pct_start, moms[0], moms[1], moms[0])} cbfs.extend([ParamScheduler(sched)]) if train_conf["gpu"]: cbfs.extend( From 2b4f07d0218a23cbad13ba22d753eafc40a14600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 10 Dec 2020 13:07:15 +0100 Subject: [PATCH 06/55] add splitted_l1 --- radionets/dl_framework/loss_functions.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index df43de81..2b66f86e 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -132,6 +132,20 @@ def l1(x, y): loss = l1(x, y) return loss +def l1_rnn(x, y): + l1 = nn.L1Loss() + x = torch.chunk(x, 4, dim=0) + + l = 0 + for i in range(4): + l += l1(x[i], y) + return l/4 + +def splitted_l1(x, y): + l1 = nn.L1Loss() + l = 0.5*(10*l1(x[:,0], y[:,0]) + l1(x[:,1], y[:,1])) + return l + def splitted_mse(x, y): inp_real = x[:, 0, :] From 2bd6dc8229e5e93eeb30a7dc6bda6480446a02c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 10 Dec 2020 13:08:23 +0100 Subject: [PATCH 07/55] fix for HDF5 lock --- radionets/dl_training/scripts/start_training.py | 2 ++ radionets/simulations/scripts/simulate_images.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index 7aa465c9..5813507b 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -18,6 +18,8 @@ ) from radionets.evaluation.train_inspection import create_inspection_plots from pathlib import Path +import os +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" @click.command() diff --git a/radionets/simulations/scripts/simulate_images.py b/radionets/simulations/scripts/simulate_images.py index ec622154..8cf9ea77 100644 --- a/radionets/simulations/scripts/simulate_images.py +++ b/radionets/simulations/scripts/simulate_images.py @@ -2,6 +2,8 @@ import toml from radionets.simulations.simulate import create_fft_images, sample_fft_images from radionets.simulations.utils import check_outpath, read_config, calc_norm +import os +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" @click.command() From e646cfe7c33d7edb6d307214e810efe594b3e367 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 10 Dec 2020 13:09:46 +0100 Subject: [PATCH 08/55] add Net suffix for define_arch --- radionets/dl_training/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index c40ffe10..2e20f974 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -69,7 +69,7 @@ def check_outpath(model_path, train_conf): def define_arch(arch_name, img_size): - if "filter_deep" in arch_name or "resnet" in arch_name: + if "filter_deep" in arch_name or "resnet" in arch_name or "Net" in arch_name: arch = getattr(architecture, arch_name)(img_size) else: arch = getattr(architecture, arch_name)() From 8bbf24318c6ad5eda451b3bb998df13467bd1812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:28:05 +0100 Subject: [PATCH 09/55] SRResNet_sym_pad and VGG19 --- .../dl_framework/architectures/superRes.py | 115 +++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 57e25580..21b8f8df 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -2,6 +2,7 @@ from torch import nn from math import pi import numpy as np +from fastai.vision import models from radionets.dl_framework.model import ( GeneralELU, ResBlock_amp, @@ -13,7 +14,10 @@ Lambda, better_symmetry, tf_shift, - btf_shift + btf_shift, + CirculationShiftPad, + SRBlockPad, + BetterShiftPad ) @@ -467,6 +471,61 @@ def forward(self, x): return self.btf(x) +class SRResNet_sym_pad(nn.Module): + def __init__(self, img_size): + super().__init__() + # torch.cuda.set_device(1) + self.tf = Lambda(tf_shift) + + self.preBlock = nn.Sequential( + BetterShiftPad((4,4,4,4)), + nn.Conv2d(2, 64, 9, stride=1, padding=0, groups=2), nn.PReLU() + ) + + # ResBlock 16 + self.blocks = nn.Sequential( + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + ) + + self.postBlock = nn.Sequential( + BetterShiftPad((1,1,1,1)), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64) + ) + + self.final = nn.Sequential( + BetterShiftPad((4,4,4,4)), + nn.Conv2d(64, 2, 9, stride=1, padding=0, groups=2), + ) + + #new symmetry + + self.btf = Lambda(btf_shift) + + def forward(self, x): + x = self.tf(x) + x = self.preBlock(x) + + x = x + self.postBlock(self.blocks(x)) + + x = self.final(x) + + return self.btf(x) + class EDSRBase(nn.Module): def __init__(self, img_size): super().__init__() @@ -597,4 +656,56 @@ def forward(self, x): x = torch.cat((x,x,x,x), dim=0) + self.postBlock(block) x = self.final(x) - return x \ No newline at end of file + return x + + +# SRGAN + +class VGG19(nn.Module): + def __init__(self): + super().__init__() + + self.block = nn.Sequential(#63*63 + nn.Conv2d(2,64,3, stride=1), + nn.LeakyReLU(0.2), + + nn.Conv2d(64,64,3, stride=2), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.2), + + nn.Conv2d(64,128,3, stride=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128,128,3, stride=2), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128,256,3, stride=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256,256,3, stride=2), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256,512,3, stride=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512,512,3, stride=2), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + ) + self.fc = nn.Sequential( + nn.Flatten(), + nn.Linear(512,1024), + nn.LeakyReLU(0.2), + nn.Linear(1024,1), + nn.Sigmoid() + ) + + + def forward(self, x): + x = self.block(x) + return self.fc(x) From 7c2c15f3badf1688395ce3cdd3fd21a0b92ec8e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:29:07 +0100 Subject: [PATCH 10/55] SRGAN --- radionets/dl_framework/callbacks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index da60e627..a48e1192 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -165,13 +165,14 @@ def begin_batch(self): class SaveTempCallback(Callback): _order = 95 - def __init__(self, model_path): + def __init__(self, model_path, gan): self.model_path = model_path + self.gan = gan def after_epoch(self): p = Path(self.model_path).parent p.mkdir(parents=True, exist_ok=True) if (self.epoch + 1) % 10 == 0: out = p / f"temp_{self.epoch + 1}.model" - save_model(self, out) + save_model(self, out, self.gan) print(f"\nFinished Epoch {self.epoch + 1}, model saved.\n") From 63f83ae4506639a64d1a65ad48a99fa2bcc36e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:30:28 +0100 Subject: [PATCH 11/55] pytorch prediction --- radionets/dl_framework/inspection.py | 60 +++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/radionets/dl_framework/inspection.py b/radionets/dl_framework/inspection.py index b00140c7..13e9182c 100644 --- a/radionets/dl_framework/inspection.py +++ b/radionets/dl_framework/inspection.py @@ -1,4 +1,5 @@ import torch +import torch.fft as fft import numpy as np import pandas as pd from tqdm import tqdm @@ -148,6 +149,49 @@ def fft_pred(pred, truth, amp_phase=True): return np.absolute(ifft_pred), np.absolute(ifft_true) +def fft_pred_torch(pred, truth, amp_phase=True): + """ + Transform predicted image and true image to local domain. + + Parameters + ---------- + pred: 4D array [1, channel, height, width] + prediction from eval_model + truth: 3D array [channel, height, width] + true image + amp_phase: Bool + trained on Amp/Phase or Re/Im + + Returns + ------- + ifft_pred, ifft_true: two 2D arrays [height, width] + predicted and true image in local domain + """ + a = pred[:, 0, :, :] + b = pred[:, 1, :, :] + + a_true = truth[:, 0, :, :] + b_true = truth[:, 1, :, :] + + if amp_phase: + amp_pred_rescaled = (10 ** (10 * a) - 1) / 10 ** 10 + phase_pred = b + + amp_true_rescaled = (10 ** (10 * a_true) - 1) / 10 ** 10 + phase_true = b_true + + compl_pred = amp_pred_rescaled * (torch.cos(phase_pred) + 1j * torch.sin(phase_pred))#torch.exp(1j * phase_pred) + compl_true = amp_true_rescaled * (torch.cos(phase_true) + 1j * torch.sin(phase_true))#torch.exp(1j * phase_true) + else: + compl_pred = a + 1j * b + compl_true = a_true + 1j * b_true + + ifft_pred = fft.ifftn(compl_pred) + ifft_true = fft.ifftn(compl_true,2) + + return torch.absolute(ifft_pred), torch.absolute(ifft_true) + + def reshape_2d(array): """ Reshape 1d arrays into 2d ones. @@ -359,4 +403,18 @@ def ssim(pred,true): l = [(2*mean_pred[i]*mean_true[i]+c1[i])/(mean_pred[i]**2+mean_true[i]**2+c1[i]) for i in range(len(pred))] c = [(2*std_pred[i]*std_true[i]+c2[i])/(std_pred[i]**2+std_true[i]**2+c2[i]) for i in range(len(pred))] s = [(cov[i]+c3[i])/(std_pred[i]*std_true[i]+c3[i]) for i in range(len(pred))] - return np.mean([l[i]*c[i]*s[i] for i in range(len(l))]) \ No newline at end of file + return np.mean([l[i]*c[i]*s[i] for i in range(len(l))]) + +def ssim_torch(pred,true): + mean_pred = [torch.mean(p) for p in pred] + std_pred = [torch.std(p) for p in pred] + mean_true = [torch.mean(t) for t in true] + std_true = [torch.std(t) for t in true] + cov = [1/(len(pred[i])**2-1)*torch.sum((pred[i]-mean_pred[i])*(true[i]-mean_true[i])) for i in range(len(pred))] + c1 = [(0.01*torch.amax(t))**2 for t in true] + c2 = [(0.03*torch.amax(t))**2 for t in true] + c3 = [c/2 for c in c2] + l = [(2*mean_pred[i]*mean_true[i]+c1[i])/(mean_pred[i]**2+mean_true[i]**2+c1[i]) for i in range(len(pred))] + c = [(2*std_pred[i]*std_true[i]+c2[i])/(std_pred[i]**2+std_true[i]**2+c2[i]) for i in range(len(pred))] + s = [(cov[i]+c3[i])/(std_pred[i]*std_true[i]+c3[i]) for i in range(len(pred))] + return torch.mean([l[i]*c[i]*s[i] for i in range(len(l))]) \ No newline at end of file From d1ca026732169c929903ab3f311acfdce0de1751 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:32:10 +0100 Subject: [PATCH 12/55] SRGAN --- radionets/dl_framework/learner.py | 32 +++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 576d448c..f1ae596c 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -13,6 +13,8 @@ from fastai.callback.data import CudaCallback from fastai.callback.schedule import ParamScheduler, combined_cos import radionets.dl_framework.loss_functions as loss_functions +from fastai.vision import gan, models +import torchvision def get_learner( @@ -26,6 +28,18 @@ def get_learner( return Learner(dls, arch, loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func) +def get_learner_gan( + data, generator, discriminator, lr, gen_loss_func, crit_loss_func, cb_funcs=None, opt_func=Adam, **kwargs +): + init_cnn(generator) + init_cnn(discriminator) + dls = DataLoaders.from_dsets( + data.train_ds, + data.valid_ds, + ) + return gan.GANLearner(dls, generator, discriminator, gen_loss_func, crit_loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func) + + def define_learner( data, arch, @@ -71,7 +85,7 @@ def define_learner( if not test: cbfs.extend( [ - SaveTempCallback(model_path=model_path), + SaveTempCallback(model_path=model_path, gan=train_conf["gan"]), AvgLossCallback, # DataAug, ] @@ -90,7 +104,17 @@ def define_learner( loss_func = getattr(loss_functions, train_conf["loss_func"]) # Combine model and data in learner - learn = get_learner( - data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func - ) + if train_conf["gan"]: + gl = getattr(loss_functions, "gen_loss") + dl = getattr(loss_functions, "disc_loss") + # vgg = torchvision.models.vgg19_bn() + # vgg.to(device) + # print(arch) + learn = get_learner_gan( + data, arch[0], arch[1], lr=lr, gen_loss_func=gl, crit_loss_func=dl, opt_func=opt_func, cb_funcs=cbfs + ) + else: + learn = get_learner( + data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func + ) return learn From 1246069dbd9d3794c6fe9b5e340f68caeef24a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:32:22 +0100 Subject: [PATCH 13/55] SRGAN --- radionets/dl_framework/loss_functions.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 2b66f86e..3d89443d 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -11,6 +11,8 @@ rot, calc_spec, ) +from fastai.vision import gan +import radionets.dl_framework.inspection as inspec class FeatureLoss(nn.Module): @@ -143,9 +145,17 @@ def l1_rnn(x, y): def splitted_l1(x, y): l1 = nn.L1Loss() - l = 0.5*(10*l1(x[:,0], y[:,0]) + l1(x[:,1], y[:,1])) + l = (10*l1(x[:,0], y[:,0]) + l1(x[:,1], y[:,1]))/2 return l +def l1_ssim(x,y): + fft_x, fft_y = inspec.fft_pred_torch(x,y) + l1 = nn.L1Loss() + print(inspec.ssim_torch(fft_x, fft_y).shape) + l = (l1(fft_x, fft_y) + (1-inspec.ssim_torch(fft_x, fft_y)))/2 + return l + + def splitted_mse(x, y): inp_real = x[:, 0, :] @@ -486,3 +496,14 @@ def spe_(x, y): k = sum(loss) loss = k / len(x) return loss + + +#SRGAN +def gen_loss(x,y,z): + l = gan.gan_loss_from_func(nn.L1Loss(), nn.L1Loss(), weights_gen=(1e-3,1))[0] + return l(x,y,z) + + +def disc_loss(x,y): + l = gan.gan_loss_from_func(nn.L1Loss(), nn.L1Loss(), weights_gen=(1e-3,1))[1] + return l(x,y) \ No newline at end of file From 894dbd6d9ac2545685391587be5c8d26d4651d9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:35:14 +0100 Subject: [PATCH 14/55] save SRGAN + new transformation --- radionets/dl_framework/model.py | 128 ++++++++++++++++++++++++++------ 1 file changed, 105 insertions(+), 23 deletions(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index e2796344..39d66d1d 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -508,22 +508,40 @@ def load_pre_model(learn, pre_path, visualize=False): learn.recorder.lrs = checkpoint["recorder_lrs"] -def save_model(learn, model_path): - torch.save( - { - "model": learn.model.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "loss": learn.loss, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), - "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), - "recorder_losses": learn.recorder.losses, - "recorder_lrs": learn.recorder.lrs, - }, - model_path, - ) +def save_model(learn, model_path, gan): + # print(learn.model.generator) + if gan: + torch.save( + { + "model": learn.model.generator.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) + else: + torch.save( + { + "model": learn.model.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) class LocallyConnected2d(nn.Module): @@ -704,6 +722,29 @@ def _conv_block(self, ni, nf, stride): nn.BatchNorm2d(nf), ) +class SRBlockPad(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni, nf, stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1) + self.pool = ( + nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True) + ) # nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + BetterShiftPad((1,1,1,1)), + nn.Conv2d(ni, nf, 3, stride=stride, padding=0), + nn.BatchNorm2d(nf), + nn.PReLU(), + BetterShiftPad((1,1,1,1)), + nn.Conv2d(nf, nf, 3, stride=1, padding=0), + nn.BatchNorm2d(nf), + ) + class EDSRBaseBlock(nn.Module): def __init__(self, ni, nf, stride=1): super().__init__() @@ -838,15 +879,56 @@ class CirculationShiftPad(nn.Module): padding: _size_4_t def __init__(self, padding: _size_4_t) -> None: - super(_CirculationShiftPad, self).__init__() + super(CirculationShiftPad, self).__init__() self.padding = _pair(padding) def forward(self, input: torch.Tensor) -> torch.Tensor: x = F.pad(input, self.padding, 'circular') - x[...,0,:] = 0 - x[...,-1,:] = 0 - x[...,:,0] = torch.roll(x[...,:,0],1) - x[...,:,-1] = torch.roll(x[...,:,-1],-1) - x[...,0,:] = 0 - x[...,-1,:] = 0 + x[...,:self.padding[2],:] = 0 + x[...,-self.padding[3]:,:] = 0 + x[...,:,:self.padding[0]] = torch.roll(x[...,:,:self.padding[0]],1,2) + x[...,:,-self.padding[1]:] = torch.roll(x[...,:,-self.padding[1]:],-1,2) + x[...,:self.padding[2],:] = 0 + x[...,-self.padding[3]:,:] = 0 + return x + +def better_padding(input, padding): + in_shape = input.shape + paddable_shape = in_shape[2:] + + out_shape = in_shape[:2] + for idx, size in enumerate(paddable_shape): + out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],) + + # fill empty tensor of new shape with input + out = torch.zeros(out_shape, dtype=input.dtype, layout=input.layout, + device=input.device) + + out[..., padding[-2]:(out_shape[2]-padding[-1]), padding[-4]:(out_shape[3]-padding[-3])] = input + + # pad left + i0 = out_shape[3] - padding[-4] - padding[-3] + i1 = out_shape[3] -padding[-3] + o0 = 0 + o1 = padding[-4] + out[:, :, padding[-2]:out_shape[2]-padding[-1], o0:o1] = out[:, :, padding[-2]-1:out_shape[2]-padding[-1]-1, i0:i1] + + # pad right + i0 = padding[-4] + i1 = padding[-4] + padding[-3] + o0 = out_shape[3] - padding[-3] + o1 = out_shape[3] + out[:, :, padding[-2]:out_shape[2]-padding[-1], o0:o1] = out[:, :, padding[-2]+1:out_shape[2]-padding[-1]+1, i0:i1] + + return out + +class BetterShiftPad(nn.Module): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(BetterShiftPad, self).__init__() + self.padding = _pair(padding) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = better_padding(input, self.padding) return x \ No newline at end of file From 9359440302046d979193a5274a5df15c5fc356df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:36:22 +0100 Subject: [PATCH 15/55] SRGAN --- radionets/dl_training/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 2e20f974..4e4387da 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -41,6 +41,7 @@ def read_config(config): train_conf["inspection"] = config["general"]["inspection"] train_conf["separate"] = False train_conf["format"] = config["general"]["output_format"] + train_conf["gan"] = config["general"]["gan"] train_conf["param_scheduling"] = config["param_scheduling"]["use"] train_conf["lr_start"] = config["param_scheduling"]["lr_start"] @@ -71,6 +72,8 @@ def check_outpath(model_path, train_conf): def define_arch(arch_name, img_size): if "filter_deep" in arch_name or "resnet" in arch_name or "Net" in arch_name: arch = getattr(architecture, arch_name)(img_size) + elif "SRGAN" in arch_name: + arch = [getattr(architecture, "SRResNet_sym_pad")(img_size), getattr(architecture, "VGG19")()] else: arch = getattr(architecture, arch_name)() return arch @@ -81,7 +84,7 @@ def pop_interrupt(learn, train_conf): model_path = train_conf["model_path"] # save model print("Saving the model after epoch {}".format(learn.epoch)) - save_model(learn, model_path) + save_model(learn, model_path, train_conf["gan"]) # plot loss plot_loss(learn, model_path) @@ -96,7 +99,7 @@ def pop_interrupt(learn, train_conf): def end_training(learn, train_conf): # Save model - save_model(learn, Path(train_conf["model_path"])) + save_model(learn, Path(train_conf["model_path"]), train_conf["gan"]) # Plot loss plot_loss(learn, Path(train_conf["model_path"])) From 7cbd036a176f4bb31c9ee18394a7e533e81eb6dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 11 Jan 2021 13:37:15 +0100 Subject: [PATCH 16/55] Process VLBI Data --- radionets/simulations/process_vlbi.py | 127 ++++++++++++++++++++++++++ radionets/simulations/utils.py | 17 ++++ 2 files changed, 144 insertions(+) create mode 100644 radionets/simulations/process_vlbi.py diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py new file mode 100644 index 00000000..9a74abee --- /dev/null +++ b/radionets/simulations/process_vlbi.py @@ -0,0 +1,127 @@ +import os +from tqdm import tqdm +from numpy import savez_compressed +from radionets.simulations.utils import ( + get_fft_bundle_paths, + get_real_bundle_paths, + prepare_fft_images, + interpol, +) +from radionets.dl_framework.data import ( + open_fft_bundle, + save_fft_pair, + save_fft_pair_list, +) +from radionets.simulations.uv_simulations import sample_freqs +import h5py +import numpy as np +from astropy.io import fits +from PIL import Image +import cv2 + + +def process_data( + data_path, + # amp_phase, + # real_imag, + # fourier, + # compressed, + # interpolation, + # specific_mask, + # antenna_config, + # lon=None, + # lat=None, + # steps=None, +): + + print(f"\n Loading VLBI data set.\n") + bundle_paths = get_real_bundle_paths(data_path) + size = len(bundle_paths[0]) + img = np.zeros((size,256,256)) + samps = np.zeros((size,4,21000)) + for i in tqdm(range(size)): + sampled = bundle_paths[0][i] + target = bundle_paths[1][i] + + with fits.open(target) as hdul: + img[i] = hdul[0].data + + with fits.open(sampled) as hdul: + data = hdul[4].data + samps[i] = [np.append(data['UCOORD']/hdul[1].data['EFF_WAVE'],-data['UCOORD']/hdul[1].data['EFF_WAVE']),np.append(data['VCOORD']/hdul[1].data['EFF_WAVE'],-data['VCOORD']/hdul[1].data['EFF_WAVE']),np.append(data['VISAMP'],data['VISAMP']),np.append(data['VISPHI'],-data['VISPHI'])] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + N = 127 + mask = np.zeros((N,N,21000)) + umax = max(u_0) + delta_u = 2*umax/N + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in tqdm(range(samps.shape[0])): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.deg2rad(np.matmul(mask, samps[i][3].T)/points) + img_resized[i] = cv2.resize(img[i], (N,N)) + + truth_fft = np.array([np.fft.fftshift(np.fft.fft2(im)) for im in img_resized]) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/samp_train0.h5" + save_fft_pair(out, samp_img[:100], fft_scaled_truth[:100]) + out = data_path + "/samp_valid0.h5" + save_fft_pair(out, samp_img[100:], fft_scaled_truth[100:]) + + # return samp_img, fft_scaled_truth + # f = h5py.File(path, "r") + # z = np.array(f["z"]) + # size = fft.shape[-1] + + # fft_scaled = prepare_fft_images(fft.copy(), amp_phase, real_imag) + # truth_fft = np.array([np.fft.fftshift(np.fft.fft2(img)) for img in truth]) + # fft_scaled_truth = prepare_fft_images(truth_fft, amp_phase, real_imag) + + # if specific_mask is True: + # fft_samp = sample_freqs( + # fft_scaled.copy(), + # antenna_config, + # size, + # lon, + # lat, + # steps, + # plot=False, + # test=False, + # ) + # else: + # fft_samp = sample_freqs( + # fft_scaled.copy(), + # antenna_config, + # size=size, + # specific_mask=False, + # ) + # if interpolation: + # for i in range(len(fft_samp[:, 0, 0, 0])): + # fft_samp[i] = interpol(fft_samp[i]) + + # out = data_path + "/samp_" + path.name.split("_")[-1] + + # if fourier: + # if compressed: + # savez_compressed(out, x=fft_samp, y=fft_scaled) + # os.remove(path) + # else: + # save_fft_pair(out, fft_samp, fft_scaled_truth) + # else: + # save_fft_pair_list(out, fft_samp, truth, z) diff --git a/radionets/simulations/utils.py b/radionets/simulations/utils.py index 3c83a334..cc1585df 100644 --- a/radionets/simulations/utils.py +++ b/radionets/simulations/utils.py @@ -10,6 +10,7 @@ from tqdm import tqdm from skimage.transform import resize from scipy import interpolate +from natsort import natsorted from radionets.dl_framework.data import ( save_fft_pair, open_fft_pair, @@ -231,6 +232,22 @@ def get_fft_bundle_paths(data_path, ftype, mode): return bundle_paths +def get_real_bundle_paths(data_path): + bundles = get_bundles(data_path) + bundles_target = get_bundles(bundles[3]) + bundles_input = get_bundles(bundles[1]) + bundle_paths_target = [ + path for path in bundles_target if re.findall(f"[0-9].fits", path.name) + ] + bundle_paths_input = [ + path for path in bundles_input if re.findall(f"[0-9].oifits", path.name) + ] + bundle_paths_input = natsorted(bundle_paths_input) + bundle_paths_target = natsorted(bundle_paths_target) + + return [bundle_paths_input, bundle_paths_target] + + def prepare_fft_images(fft_images, amp_phase, real_imag): if amp_phase: amp, phase = split_amp_phase(fft_images) From dad3edcd8e0b10d3b3188bc6ffe71182d04cce07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 21 Jan 2021 13:42:39 +0100 Subject: [PATCH 17/55] env + fine_tune --- environment.yml | 7 ++++--- radionets/dl_training/scripts/start_training.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/environment.yml b/environment.yml index 6892cbc4..55972ba1 100644 --- a/environment.yml +++ b/environment.yml @@ -5,11 +5,12 @@ channels: - pytorch - defaults dependencies: - - python>=3.6 - - pytorch>=1.6 + - python + - pytorch - torchvision - - cudatoolkit=10.2 + - cudatoolkit=11.0 - cartopy + - natsort - pip - pip: - -e . diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index 5813507b..4fcf3c78 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -93,6 +93,7 @@ def main(configuration_path, mode): # Train the model, except interrupt try: + # learn.fine_tune(train_conf["num_epochs"]) learn.fit(train_conf["num_epochs"]) except KeyboardInterrupt: pop_interrupt(learn, train_conf) From 39fdb79656062c496f6caf4dc70bba827bae1689 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 09:51:02 +0200 Subject: [PATCH 18/55] delete gan --- radionets/dl_framework/callbacks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index a48e1192..da60e627 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -165,14 +165,13 @@ def begin_batch(self): class SaveTempCallback(Callback): _order = 95 - def __init__(self, model_path, gan): + def __init__(self, model_path): self.model_path = model_path - self.gan = gan def after_epoch(self): p = Path(self.model_path).parent p.mkdir(parents=True, exist_ok=True) if (self.epoch + 1) % 10 == 0: out = p / f"temp_{self.epoch + 1}.model" - save_model(self, out, self.gan) + save_model(self, out) print(f"\nFinished Epoch {self.epoch + 1}, model saved.\n") From e81b149e376f73e67407f251ec730501b5301296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 09:54:47 +0200 Subject: [PATCH 19/55] delete gan --- radionets/dl_framework/learner.py | 34 ++++++------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index f1ae596c..3ecaae76 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -13,7 +13,7 @@ from fastai.callback.data import CudaCallback from fastai.callback.schedule import ParamScheduler, combined_cos import radionets.dl_framework.loss_functions as loss_functions -from fastai.vision import gan, models +from fastai.vision import models import torchvision @@ -28,18 +28,6 @@ def get_learner( return Learner(dls, arch, loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func) -def get_learner_gan( - data, generator, discriminator, lr, gen_loss_func, crit_loss_func, cb_funcs=None, opt_func=Adam, **kwargs -): - init_cnn(generator) - init_cnn(discriminator) - dls = DataLoaders.from_dsets( - data.train_ds, - data.valid_ds, - ) - return gan.GANLearner(dls, generator, discriminator, gen_loss_func, crit_loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func) - - def define_learner( data, arch, @@ -85,9 +73,9 @@ def define_learner( if not test: cbfs.extend( [ - SaveTempCallback(model_path=model_path, gan=train_conf["gan"]), + SaveTempCallback(model_path=model_path), AvgLossCallback, - # DataAug, + DataAug, ] ) if train_conf["telegram_logger"]: @@ -104,17 +92,7 @@ def define_learner( loss_func = getattr(loss_functions, train_conf["loss_func"]) # Combine model and data in learner - if train_conf["gan"]: - gl = getattr(loss_functions, "gen_loss") - dl = getattr(loss_functions, "disc_loss") - # vgg = torchvision.models.vgg19_bn() - # vgg.to(device) - # print(arch) - learn = get_learner_gan( - data, arch[0], arch[1], lr=lr, gen_loss_func=gl, crit_loss_func=dl, opt_func=opt_func, cb_funcs=cbfs - ) - else: - learn = get_learner( - data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func - ) + learn = get_learner( + data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func + ) return learn From c5499c5a16d35e6c2bdeea009287748fe3629691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 09:55:38 +0200 Subject: [PATCH 20/55] delete gan --- radionets/dl_framework/loss_functions.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 3d89443d..ff025dec 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -11,7 +11,6 @@ rot, calc_spec, ) -from fastai.vision import gan import radionets.dl_framework.inspection as inspec @@ -496,14 +495,3 @@ def spe_(x, y): k = sum(loss) loss = k / len(x) return loss - - -#SRGAN -def gen_loss(x,y,z): - l = gan.gan_loss_from_func(nn.L1Loss(), nn.L1Loss(), weights_gen=(1e-3,1))[0] - return l(x,y,z) - - -def disc_loss(x,y): - l = gan.gan_loss_from_func(nn.L1Loss(), nn.L1Loss(), weights_gen=(1e-3,1))[1] - return l(x,y) \ No newline at end of file From 8dfda4f03df82a825e9ab8d852d4d82d16531e9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 09:57:17 +0200 Subject: [PATCH 21/55] delete gan --- radionets/dl_framework/model.py | 53 ++++++++++++--------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 39d66d1d..72f40ca7 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -6,6 +6,8 @@ from math import sqrt, pi from fastcore.foundation import L from torch.nn.common_types import _size_4_t +import numpy as np +import radionets.simulations.utils as utils class Lambda(nn.Module): @@ -508,40 +510,23 @@ def load_pre_model(learn, pre_path, visualize=False): learn.recorder.lrs = checkpoint["recorder_lrs"] -def save_model(learn, model_path, gan): +def save_model(learn, model_path): # print(learn.model.generator) - if gan: - torch.save( - { - "model": learn.model.generator.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "loss": learn.loss, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), - "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), - "recorder_losses": learn.recorder.losses, - "recorder_lrs": learn.recorder.lrs, - }, - model_path, - ) - else: - torch.save( - { - "model": learn.model.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "loss": learn.loss, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), - "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), - "recorder_losses": learn.recorder.losses, - "recorder_lrs": learn.recorder.lrs, - }, - model_path, - ) + torch.save( + { + "model": learn.model.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) class LocallyConnected2d(nn.Module): @@ -908,7 +893,7 @@ def better_padding(input, padding): # pad left i0 = out_shape[3] - padding[-4] - padding[-3] - i1 = out_shape[3] -padding[-3] + i1 = out_shape[3] - padding[-3] o0 = 0 o1 = padding[-4] out[:, :, padding[-2]:out_shape[2]-padding[-1], o0:o1] = out[:, :, padding[-2]-1:out_shape[2]-padding[-1]-1, i0:i1] From bfe8dfc147f284cba12f9bfe95b23eaf79e21ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 10:00:16 +0200 Subject: [PATCH 22/55] delete gan --- .../dl_framework/architectures/superRes.py | 54 +------------------ 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 21b8f8df..10ce4e78 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -656,56 +656,4 @@ def forward(self, x): x = torch.cat((x,x,x,x), dim=0) + self.postBlock(block) x = self.final(x) - return x - - -# SRGAN - -class VGG19(nn.Module): - def __init__(self): - super().__init__() - - self.block = nn.Sequential(#63*63 - nn.Conv2d(2,64,3, stride=1), - nn.LeakyReLU(0.2), - - nn.Conv2d(64,64,3, stride=2), - nn.BatchNorm2d(64), - nn.LeakyReLU(0.2), - - nn.Conv2d(64,128,3, stride=1), - nn.BatchNorm2d(128), - nn.LeakyReLU(0.2), - - nn.Conv2d(128,128,3, stride=2), - nn.BatchNorm2d(128), - nn.LeakyReLU(0.2), - - nn.Conv2d(128,256,3, stride=1), - nn.BatchNorm2d(256), - nn.LeakyReLU(0.2), - - nn.Conv2d(256,256,3, stride=2), - nn.BatchNorm2d(256), - nn.LeakyReLU(0.2), - - nn.Conv2d(256,512,3, stride=1), - nn.BatchNorm2d(512), - nn.LeakyReLU(0.2), - - nn.Conv2d(512,512,3, stride=2), - nn.BatchNorm2d(512), - nn.LeakyReLU(0.2), - ) - self.fc = nn.Sequential( - nn.Flatten(), - nn.Linear(512,1024), - nn.LeakyReLU(0.2), - nn.Linear(1024,1), - nn.Sigmoid() - ) - - - def forward(self, x): - x = self.block(x) - return self.fc(x) + return x \ No newline at end of file From 9dcb8dd601a82f04845e2198a9e74e6affeffa73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 10:01:00 +0200 Subject: [PATCH 23/55] delete gan --- radionets/dl_training/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 4e4387da..2e20f974 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -41,7 +41,6 @@ def read_config(config): train_conf["inspection"] = config["general"]["inspection"] train_conf["separate"] = False train_conf["format"] = config["general"]["output_format"] - train_conf["gan"] = config["general"]["gan"] train_conf["param_scheduling"] = config["param_scheduling"]["use"] train_conf["lr_start"] = config["param_scheduling"]["lr_start"] @@ -72,8 +71,6 @@ def check_outpath(model_path, train_conf): def define_arch(arch_name, img_size): if "filter_deep" in arch_name or "resnet" in arch_name or "Net" in arch_name: arch = getattr(architecture, arch_name)(img_size) - elif "SRGAN" in arch_name: - arch = [getattr(architecture, "SRResNet_sym_pad")(img_size), getattr(architecture, "VGG19")()] else: arch = getattr(architecture, arch_name)() return arch @@ -84,7 +81,7 @@ def pop_interrupt(learn, train_conf): model_path = train_conf["model_path"] # save model print("Saving the model after epoch {}".format(learn.epoch)) - save_model(learn, model_path, train_conf["gan"]) + save_model(learn, model_path) # plot loss plot_loss(learn, model_path) @@ -99,7 +96,7 @@ def pop_interrupt(learn, train_conf): def end_training(learn, train_conf): # Save model - save_model(learn, Path(train_conf["model_path"]), train_conf["gan"]) + save_model(learn, Path(train_conf["model_path"])) # Plot loss plot_loss(learn, Path(train_conf["model_path"])) From fce638da06db413ba8adfb82710fef0c9b485521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 22 Apr 2021 10:02:15 +0200 Subject: [PATCH 24/55] changed for vipy/eht simulations; VERY HARD CODE --- radionets/simulations/process_vlbi.py | 99 +++++++++++---------------- 1 file changed, 39 insertions(+), 60 deletions(-) diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index 9a74abee..9b6923ca 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -18,6 +18,10 @@ from astropy.io import fits from PIL import Image import cv2 +import radionets.dl_framework.data as dt +import re +from natsort import natsorted +from PIL import Image def process_data( @@ -35,30 +39,45 @@ def process_data( ): print(f"\n Loading VLBI data set.\n") - bundle_paths = get_real_bundle_paths(data_path) - size = len(bundle_paths[0]) + bundles = dt.get_bundles('/net/big-tank/POOL/users/sfroese/vipy/eht/m87/') + freq = 227297*10**6 # hard code #eht 227297 + bundles_target = dt.get_bundles(bundles[0]) + bundles_input = dt.get_bundles(bundles[1]) + bundle_paths_target = natsorted(bundles_target) + bundle_paths_input = natsorted(bundles_input) + size = len(bundle_paths_target) img = np.zeros((size,256,256)) - samps = np.zeros((size,4,21000)) + samps = np.zeros((size,4,21036)) # hard code for i in tqdm(range(size)): - sampled = bundle_paths[0][i] - target = bundle_paths[1][i] + sampled = bundle_paths_input[i] + target = bundle_paths_target[i] - with fits.open(target) as hdul: - img[i] = hdul[0].data - + img[i] = np.asarray(Image.open(str(target))) + # img[i] = img[i]/np.sum(img[i]) + with fits.open(sampled) as hdul: - data = hdul[4].data - samps[i] = [np.append(data['UCOORD']/hdul[1].data['EFF_WAVE'],-data['UCOORD']/hdul[1].data['EFF_WAVE']),np.append(data['VCOORD']/hdul[1].data['EFF_WAVE'],-data['VCOORD']/hdul[1].data['EFF_WAVE']),np.append(data['VISAMP'],data['VISAMP']),np.append(data['VISPHI'],-data['VISPHI'])] + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] print(f"\n Gridding VLBI data set.\n") # Generate Mask u_0 = samps[0][0] v_0 = samps[0][1] - N = 127 - mask = np.zeros((N,N,21000)) - umax = max(u_0) - delta_u = 2*umax/N + N = 63 # hard code + mask = np.zeros((N,N,u_0.shape[0])) + fov = 0.00018382*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) for i in range(N): for j in range(N): u_cell = (j-N/2)*delta_u @@ -73,55 +92,15 @@ def process_data( for i in tqdm(range(samps.shape[0])): samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 - samp_img[i][1] = np.deg2rad(np.matmul(mask, samps[i][3].T)/points) + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) - truth_fft = np.array([np.fft.fftshift(np.fft.fft2(im)) for im in img_resized]) + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) fft_scaled_truth = prepare_fft_images(truth_fft, True, False) out = data_path + "/samp_train0.h5" - save_fft_pair(out, samp_img[:100], fft_scaled_truth[:100]) + save_fft_pair(out, samp_img[:500], fft_scaled_truth[:500]) out = data_path + "/samp_valid0.h5" - save_fft_pair(out, samp_img[100:], fft_scaled_truth[100:]) - - # return samp_img, fft_scaled_truth - # f = h5py.File(path, "r") - # z = np.array(f["z"]) - # size = fft.shape[-1] - - # fft_scaled = prepare_fft_images(fft.copy(), amp_phase, real_imag) - # truth_fft = np.array([np.fft.fftshift(np.fft.fft2(img)) for img in truth]) - # fft_scaled_truth = prepare_fft_images(truth_fft, amp_phase, real_imag) - - # if specific_mask is True: - # fft_samp = sample_freqs( - # fft_scaled.copy(), - # antenna_config, - # size, - # lon, - # lat, - # steps, - # plot=False, - # test=False, - # ) - # else: - # fft_samp = sample_freqs( - # fft_scaled.copy(), - # antenna_config, - # size=size, - # specific_mask=False, - # ) - # if interpolation: - # for i in range(len(fft_samp[:, 0, 0, 0])): - # fft_samp[i] = interpol(fft_samp[i]) - - # out = data_path + "/samp_" + path.name.split("_")[-1] - - # if fourier: - # if compressed: - # savez_compressed(out, x=fft_samp, y=fft_scaled) - # os.remove(path) - # else: - # save_fft_pair(out, fft_samp, fft_scaled_truth) - # else: - # save_fft_pair_list(out, fft_samp, truth, z) + save_fft_pair(out, samp_img[500:], fft_scaled_truth[500:]) From 61e9450781fb385aabaad33b8e332724e1bd1f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Wed, 19 May 2021 18:36:35 +0200 Subject: [PATCH 25/55] gridding for model loss --- radionets/dl_framework/data.py | 21 +++- radionets/simulations/process_vlbi.py | 143 ++++++++++++++++++++++++-- 2 files changed, 153 insertions(+), 11 deletions(-) diff --git a/radionets/dl_framework/data.py b/radionets/dl_framework/data.py index 25673a18..42c46203 100644 --- a/radionets/dl_framework/data.py +++ b/radionets/dl_framework/data.py @@ -32,7 +32,7 @@ def do_normalisation(x, norm): class h5_dataset: - def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False): + def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False, vgg=False): """ Save the bundle paths and the number of bundles in one file. """ @@ -41,6 +41,7 @@ def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False) self.tar_fourier = tar_fourier self.amp_phase = amp_phase self.source_list = source_list + self.vgg = vgg def __call__(self): return print("This is the h5_dataset class.") @@ -99,6 +100,8 @@ def open_image(self, var, i): else: if self.source_list: data_channel = data + elif self.vgg: + data_channel = data else: if data.shape[1] == 2: raise ValueError( @@ -202,6 +205,18 @@ def save_fft_pair(path, x, y, z=None, name_x="x", name_y="y", name_z="z"): hf.create_dataset(name_z, data=z) hf.close() +def save_fft_pair_with_response(path, x, y, base_mask, A, name_x="x", name_y="y", name_base_mask="base_mask", name_A='A'): + """ + write fft_pairs created in second analysis step to h5 file + write response matrices & baselines + """ + with h5py.File(path, "w") as hf: + hf.create_dataset(name_x, data=x) + hf.create_dataset(name_y, data=y) + hf.create_dataset(name_base_mask, data=base_mask) + hf.create_dataset(name_A, data=A) + hf.close() + def open_fft_pair(path): """ @@ -217,7 +232,7 @@ def mean_and_std(array): return array.mean(), array.std() -def load_data(data_path, mode, fourier=False, source_list=False): +def load_data(data_path, mode, fourier=False, source_list=False, vgg=False): """ Load data set from a directory and return it as h5_dataset. @@ -237,5 +252,5 @@ def load_data(data_path, mode, fourier=False, source_list=False): """ bundle_paths = get_bundles(data_path) data = [path for path in bundle_paths if re.findall("samp_" + mode, path.name)] - ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list) + ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list, vgg=vgg) return ds diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index 9b6923ca..daa3548f 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -10,7 +10,7 @@ from radionets.dl_framework.data import ( open_fft_bundle, save_fft_pair, - save_fft_pair_list, + save_fft_pair_with_response, ) from radionets.simulations.uv_simulations import sample_freqs import h5py @@ -20,8 +20,23 @@ import cv2 import radionets.dl_framework.data as dt import re -from natsort import natsorted +from natsort import natsorted, ns from PIL import Image +import os +import vipy.simulation.utils as ut +import vipy.layouts.layouts as layouts +import astropy.constants as const +from astropy import units as un +import vipy.simulation.scan as scan + +# set env flags to catch BLAS used for scipy/numpy +# to only use 1 cpu, n_cpus will be totally controlled by csky +# flags from mirco +os.environ['MKL_NUM_THREADS'] = "12" +os.environ['NUMEXPR_NUM_THREADS'] = "12" +os.environ['OMP_NUM_THREADS'] = "12" +os.environ['OPENBLAS_NUM_THREADS'] = "12" +os.environ['VECLIB_MAXIMUM_THREADS'] = "12" def process_data( @@ -39,15 +54,15 @@ def process_data( ): print(f"\n Loading VLBI data set.\n") - bundles = dt.get_bundles('/net/big-tank/POOL/users/sfroese/vipy/eht/m87/') + bundles = dt.get_bundles('/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/') freq = 227297*10**6 # hard code #eht 227297 - bundles_target = dt.get_bundles(bundles[0]) - bundles_input = dt.get_bundles(bundles[1]) + bundles_target = dt.get_bundles(bundles[1]) + bundles_input = dt.get_bundles(bundles[0]) bundle_paths_target = natsorted(bundles_target) bundle_paths_input = natsorted(bundles_input) size = len(bundle_paths_target) img = np.zeros((size,256,256)) - samps = np.zeros((size,4,21036)) # hard code + samps = np.zeros((size,4,21000)) # hard code for i in tqdm(range(size)): sampled = bundle_paths_input[i] target = bundle_paths_target[i] @@ -101,6 +116,118 @@ def process_data( fft_scaled_truth = prepare_fft_images(truth_fft, True, False) out = data_path + "/samp_train0.h5" - save_fft_pair(out, samp_img[:500], fft_scaled_truth[:500]) + save_fft_pair(out, samp_img[:2300], fft_scaled_truth[:2300]) out = data_path + "/samp_valid0.h5" - save_fft_pair(out, samp_img[500:], fft_scaled_truth[500:]) + save_fft_pair(out, samp_img[2300:], fft_scaled_truth[2300:]) + + + +def process_data_dirty_model(data_path, freq, n_positions, fov_asec): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + freq = freq*10**6 # hard code #eht 227297 + uvfits = dt.get_bundles(bundles[3]) + imgs = dt.get_bundles(bundles[2]) + configs = dt.get_bundles(bundles[0]) + uv_srt = natsorted(uvfits, alg=ns.PATH) + img_srt = natsorted(imgs, alg=ns.PATH) + size = 1000 + for p in tqdm(range(n_positions)): + N = 63 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines) + + img = np.zeros((size,256,256)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + sampled = uv_srt[i] + target = img_srt[i] + + img[i-p*1000] = np.asarray(Image.open(str(target))) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/samp_train"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], base_mask, A) + out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], base_mask, A) + + +def response(config, N, unique_telescopes, unique_baselines): + rc = ut.read_config(config) + array_layout = layouts.get_array_layout('vlba') + src_crd = rc['src_coord'] + + wave = const.c/((float(rc['channel'].split(':')[0])/2)*10**6/un.second)/un.meter + rd = scan.rd_grid(rc['fov_size']*np.pi/(3600*180),N, src_crd) + E = scan.getE(rd, array_layout, wave, src_crd) + A = np.zeros((N,N,int(unique_baselines))) + counter = 0 + for i in range(int(unique_telescopes)): + for j in range(int(unique_telescopes)): + if i == j or j < i: + continue + A[:,:,counter] = E[:,:,i]*E[:,:,j] + counter += 1 + + return A + From ecd68c92ce26cb0f1557ca5f76268f5c03ae89e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Wed, 19 May 2021 19:20:18 +0200 Subject: [PATCH 26/55] vgg19 feature loss & training; gan implementation --- .../dl_framework/architectures/superRes.py | 214 +++++++++++++++++- radionets/dl_framework/callbacks.py | 13 +- radionets/dl_framework/learner.py | 34 ++- radionets/dl_framework/loss_functions.py | 91 ++++++++ radionets/dl_framework/model.py | 59 +++-- .../dl_training/scripts/start_training.py | 35 ++- radionets/dl_training/utils.py | 7 +- radionets/evaluation/utils.py | 12 +- 8 files changed, 428 insertions(+), 37 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 7f1ad9d7..dd84565a 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -15,11 +15,14 @@ btf_shift, CirculationShiftPad, SRBlockPad, - BetterShiftPad + BetterShiftPad, Lambda, symmetry, ) from functools import partial +import torchvision +import radionets.evaluation.utils as ut +import numpy as np class superRes_simple(nn.Module): @@ -372,7 +375,7 @@ def forward(self, x): class SRResNet_corr(nn.Module): def __init__(self, img_size): super().__init__() - # torch.cuda.set_device(1) + torch.cuda.set_device(1) self.img_size = img_size self.preBlock = nn.Sequential( @@ -411,15 +414,22 @@ def __init__(self, img_size): self.symmetry = Lambda(better_symmetry) - def forward(self, x): - x = x[:, 0].unsqueeze(1) + #pi layer + self.pi = nn.Tanh() + def forward(self, x): x = self.preBlock(x) x = x + self.postBlock(self.blocks(x)) x = self.final(x) + # x[:,0][x[:,0]<0] = 0 + # x[:,0][x[:,0]>2] = 2 + # x[:,1] = np.pi*self.pi(x[:,1]) + + + return self.symmetry(x) class SRResNet_sym(nn.Module): @@ -659,4 +669,200 @@ def forward(self, x): x = torch.cat((x,x,x,x), dim=0) + self.postBlock(block) x = self.final(x) + return x + + +class vgg19_feature_maps(nn.Module): + + def __init__(self, i, j): + super().__init__() + # load pretrained vgg19 + # vgg19 = torchvision.models.vgg19(pretrained=True) + # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_group2.model') + # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2_prelu', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_groups2_prelu.model') + model = ut.load_pretrained_model(arch_name='vgg19_blackhole_fft', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_fft.model') + + vgg19 = model.vgg + + conv_counter = 0 + maxpool_counter = 0 + truncate_at = 0 + for layer in vgg19.features: + truncate_at += 1 + + if isinstance(layer, nn.MaxPool2d): + maxpool_counter += 1 + conv_counter = 0 + if isinstance(layer, nn.Conv2d): + conv_counter += 1 + + if maxpool_counter == i - 1 and conv_counter == j: + break + + self.truncated_vgg19 = nn.Sequential(*list(vgg19.features)[:truncate_at + 1]) + + def forward(self, x): + amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase = x[:,1] + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + img = torch.absolute(ifft) + shift = torch.fft.fftshift(img) + with torch.no_grad(): + feature = self.truncated_vgg19(img.unsqueeze(1)) + + return feature + + + +class vgg19_blackhole(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(1) + vgg19 = torchvision.models.vgg19(pretrained=False) + + # customize vgg19 + vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) + vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) + + # for i, layer in enumerate(vgg19.features): + # if isinstance(layer, nn.Conv2d): + # vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) + # if isinstance(layer, nn.ReLU): + # vgg19.features[i] = nn.PReLU() + + self.vgg = vgg19 + + def forward(self, x): + return self.vgg(x) + + +class vgg19_blackhole_group2(nn.Module): + def __init__(self): + super().__init__() + # torch.cuda.set_device(1) + vgg19 = torchvision.models.vgg19(pretrained=False) + + # customize vgg19 + + vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) + vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) + + for i, layer in enumerate(vgg19.features): + if isinstance(layer, nn.Conv2d): + vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) + # if isinstance(layer, nn.ReLU): + # vgg19.features[i] = nn.PReLU() + + self.vgg = vgg19 + + def forward(self, x): + #ifft + return self.vgg(x) + +class vgg19_blackhole_group2_prelu(nn.Module): + def __init__(self): + super().__init__() + # torch.cuda.set_device(1) + vgg19 = torchvision.models.vgg19(pretrained=False) + + # customize vgg19 + vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) + vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) + + for i, layer in enumerate(vgg19.features): + if isinstance(layer, nn.Conv2d): + vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) + if isinstance(layer, nn.ReLU): + vgg19.features[i] = nn.PReLU() + + self.vgg = vgg19 + + def forward(self, x): + return self.vgg(x) + +class vgg19_blackhole_fft(nn.Module): + def __init__(self): + super().__init__() + # torch.cuda.set_device(1) + vgg19 = torchvision.models.vgg19(pretrained=False) + + # customize vgg19 + + vgg19.features[0] = nn.Conv2d(1, 64, 3, stride=1, padding=1) + vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) + + # for i, layer in enumerate(vgg19.features): + # if isinstance(layer, nn.Conv2d): + # vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) + # # if isinstance(layer, nn.ReLU): + # vgg19.features[i] = nn.PReLU() + + self.vgg = vgg19 + + def forward(self, x): + #ifft + amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase = x[:,1] + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + img = torch.absolute(ifft) + shift = torch.fft.fftshift(img) + return self.vgg(img.unsqueeze(1)) + +class discriminator(nn.Module): + def __init__(self): + super().__init__() + + self.preBlock = nn.Sequential(nn.Conv2d(1, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2)) + + self.block1 = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + + self.main = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7) + + self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) + + def forward(self, x): + if x.shape[1] == 2: + amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase_x = x[:,1] + compl_x = amp_x * torch.exp(1j * phase_x) + ifft_x = torch.fft.ifft2(compl_x) + img_x = torch.absolute(ifft_x) + shift_x = torch.fft.fftshift(img_x).unsqueeze(1) + x = shift_x + x = self.preBlock(x) + x = self.main(x) + x = torch.flatten(x, 1) + x = self.postBlock(x) + + + return x + + +class automap(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(1) + self.fcs = nn.Sequential(nn.Linear(63*63*2,63*63),nn.Tanh(),nn.Linear(63*63,63*63),nn.Tanh()) + + self.convs = nn.Sequential(nn.Conv2d(1,64,5, stride=1, padding=2),nn.ReLU(),nn.Conv2d(64,64,5,stride=1, padding=2),nn.ReLU(),nn.Conv2d(64,1,7,stride=1, padding=3)) + + + def forward(self, x): + amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase_x = x[:,1] + compl_x = amp_x * torch.exp(1j * phase_x) + x[:,0] = compl_x.real + x[:,1] = compl_x.imag + x = torch.flatten(x, 1) + x = self.fcs(x) + x = x.reshape((x.shape[0],1,63,63)) + x = self.convs(x) return x \ No newline at end of file diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index 9d3f8a65..4344ec95 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -112,6 +112,9 @@ def normalize_tfm(self): class DataAug(Callback): _order = 3 + def __init__(self, vgg): + self.vgg = vgg + def before_batch(self): x = self.xb[0].clone() y = self.yb[0].clone() @@ -119,8 +122,9 @@ def before_batch(self): for i in range(x.shape[0]): x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) x[i, 1] = torch.rot90(x[i, 1], int(randint[i])) - y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) - y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) + if not self.vgg: + y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) + y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) self.learn.xb = [x] self.learn.yb = [y] @@ -128,13 +132,14 @@ def before_batch(self): class SaveTempCallback(Callback): _order = 95 - def __init__(self, model_path): + def __init__(self, model_path, gan=False): self.model_path = model_path + self.gan = gan def after_epoch(self): p = Path(self.model_path).parent p.mkdir(parents=True, exist_ok=True) if (self.epoch + 1) % 10 == 0: out = p / f"temp_{self.epoch + 1}.model" - save_model(self, out) + save_model(self, out, self.gan) print(f"\nFinished Epoch {self.epoch + 1}, model saved.\n") diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index c0c43c07..9f044813 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -14,7 +14,10 @@ from fastai.callback.schedule import ParamScheduler, combined_cos import radionets.dl_framework.loss_functions as loss_functions from fastai.vision import models +# from radionets.dl_framework.architectures import superRes import torchvision +from radionets.dl_training.utils import define_arch +from fastai.vision.gan import GANLearner def get_learner( @@ -36,6 +39,7 @@ def define_learner( cbfs=[], test=False, lr_find=False, + gan=False, ): model_path = train_conf["model_path"] model_name = ( @@ -72,12 +76,20 @@ def define_learner( CudaCallback, ] ) - if not test: + if not test and not gan: cbfs.extend( [ SaveTempCallback(model_path=model_path), AvgLossCallback, - DataAug, + DataAug(vgg=train_conf["vgg"]), + ] + ) + if gan: + cbfs.extend( + [ + SaveTempCallback(model_path=model_path, gan=gan), + AvgLossCallback, + DataAug(vgg=train_conf["vgg"]), ] ) if train_conf["telegram_logger"] and not lr_find: @@ -93,6 +105,24 @@ def define_learner( else: loss_func = getattr(loss_functions, train_conf["loss_func"]) + if gan: + gen_loss_func = getattr(loss_functions, 'gen_loss_func') + crit_loss_func = getattr(loss_functions, 'crit_loss_func') + + generator = arch + critic = define_arch( + arch_name='discriminator', img_size=train_conf["image_size"] + ) + init_cnn(generator) + init_cnn(critic) + dls = DataLoaders.from_dsets( + data.train_ds, + data.valid_ds, + bs=data.train_dl.batch_size, + ) + learn = GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, lr=lr, cbs=cbfs, opt_func=opt_func) + return learn + # Combine model and data in learner learn = get_learner( data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index f8aefd05..5d38a311 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from pytorch_msssim import MS_SSIM from scipy.optimize import linear_sum_assignment +from radionets.dl_framework.architectures import superRes class FeatureLoss(nn.Module): @@ -120,6 +121,96 @@ def l1(x, y): loss = l1(x, y) return loss +# vgg19 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') +def vgg19_feature_loss(x, y): + print() + if 'vgg19_feature_model_12' not in globals(): + global vgg19_feature_model_12 + # global vgg19_feature_model_22 + # global vgg19_feature_model_34 + # global vgg19_feature_model_44 + # global vgg19_feature_model_54 + vgg19_feature_model_12 = superRes.vgg19_feature_maps(1,2).eval().to('cuda:1') + # vgg19_feature_model_22 = superRes.vgg19_feature_maps(2,2).eval().to('cuda:1') + # vgg19_feature_model_34 = superRes.vgg19_feature_maps(3,4).eval().to('cuda:1') + # vgg19_feature_model_44 = superRes.vgg19_feature_maps(4,4).eval().to('cuda:1') + # vgg19_feature_model_54 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') + + + mse = nn.MSELoss() + l1 = nn.L1Loss() + + # up1 = nn.Upsample(size=7, mode='nearest').to('cuda:1') + # up2 = nn.Upsample(size=15, mode='nearest').to('cuda:1') + # up3 = nn.Upsample(size=31, mode='nearest').to('cuda:1') + # up4 = nn.Upsample(size=63, mode='nearest').to('cuda:1') + # c34 = nn.Conv2d(256, 512, 1).to('cuda:1') + # c22 = nn.Conv2d(128, 512, 1).to('cuda:1') + # c12 = nn.Conv2d(64, 512, 1).to('cuda:1') + + # upx1 = up1(vgg19_feature_model_54(x)) + # upy1 = up1(vgg19_feature_model_54(y)) + + # upx2 = up2(vgg19_feature_model_44(x) + upx1) + # upy2 = up2(vgg19_feature_model_44(y) + upy1) + + # upx3 = up3(c34(vgg19_feature_model_34(x)) + upx2) + # upy3 = up3(c34(vgg19_feature_model_34(y)) + upy2) + + # upx4 = up4(c22(vgg19_feature_model_22(x)) + upx3) + # upy4 = up4(c22(vgg19_feature_model_22(y)) + upy3) + + # upx5 = (c12(vgg19_feature_model_12(x)) + upx4) + # upy5 = (c12(vgg19_feature_model_12(y)) + upy4) + + + # mix_x = (0.5*x[:,0]+0.5*x[:,1]).unsqueeze(1) + # mix_y = (0.5*y[:,0]+0.5*y[:,1]).unsqueeze(1) + # ones = torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])).to('cuda:1') + # x_3c = torch.cat((x, ones), dim=1) + # y_3c = torch.cat((y, ones), dim=1) + + # loss = l1(vgg19_feature_model_22(x), vgg19_feature_model_22(y))# + l1(vgg19_feature_model_12(x), vgg19_feature_model_12(y)) + l1(vgg19_feature_model_34(x), vgg19_feature_model_34(y)) + l1(vgg19_feature_model_44(x), vgg19_feature_model_44(y)) + l1(vgg19_feature_model_54(x), vgg19_feature_model_54(y)) + loss = l1(vgg19_feature_model_12(x), vgg19_feature_model_12(y)) + return loss + +def gen_loss_func(fake_pred, x, y): + l1 = nn.L1Loss() + bce = nn.BCELoss() + + # mask = torch.zeros(x.shape).to('cuda:1') + # mask[:,:,31-10:31+10,31-10:31+10]=1 + # xm = torch.einsum('bcij,bcjk->bcik',x,mask) + # ym = torch.einsum('bcij,bcjk->bcik',y,mask) + + # content_loss = l1(x,y) + content_loss = automap_l2(x,y) + # content_loss = vgg19_feature_loss(x,y) + adv_loss = bce(fake_pred, torch.ones_like(fake_pred)) + + return content_loss + 1e-3*adv_loss +def automap_l2(x,y): + amp_y = (10 ** (10 * y[:,0]) - 1) / 10 ** 10 + phase_y = y[:,1] + compl_y = amp_y * torch.exp(1j * phase_y) + ifft_y = torch.fft.ifft2(compl_y) + img_y = torch.absolute(ifft_y) + shift_y = torch.fft.fftshift(img_y).unsqueeze(1) + mse = nn.MSELoss() + return mse(x,shift_y) + + +def crit_loss_func(real_pred, fake_pred): + bce = nn.BCELoss() + loss = bce(real_pred, torch.ones_like(real_pred)) + bce(fake_pred, torch.zeros_like(fake_pred)) + + return loss + +def cross_entropy(x,y): + loss = nn.CrossEntropyLoss() + return loss(x, y.squeeze().long()) + + def l1_rnn(x, y): l1 = nn.L1Loss() x = torch.chunk(x, 4, dim=0) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index a4d92869..b9d8ff70 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -320,26 +320,47 @@ def load_pre_model(learn, pre_path, visualize=False): learn.recorder.lrs = checkpoint["recorder_lrs"] -def save_model(learn, model_path): +def save_model(learn, model_path, gan=False): # print(learn.model.generator) - torch.save( - { - "model": learn.model.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "loss": learn.loss, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "train_loss": learn.avg_loss.loss_train, - "valid_loss": learn.avg_loss.loss_valid, - "lrs": learn.avg_loss.lrs, - "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), - "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), - "recorder_losses": learn.recorder.losses, - "recorder_lrs": learn.recorder.lrs, - }, - model_path, - ) + if not gan: + torch.save( + { + "model": learn.model.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "train_loss": learn.avg_loss.loss_train, + "valid_loss": learn.avg_loss.loss_valid, + "lrs": learn.avg_loss.lrs, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) + else: + torch.save( + { + "model": learn.model.generator.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "train_loss": learn.avg_loss.loss_train, + "valid_loss": learn.avg_loss.loss_valid, + "lrs": learn.avg_loss.lrs, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) + class LocallyConnected2d(nn.Module): diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index d9ca653b..befdb75c 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -31,6 +31,7 @@ "train", "lr_find", "plot_loss", + "gan", ], case_sensitive=False, ), @@ -63,6 +64,7 @@ def main(configuration_path, mode): fourier=train_conf["fourier"], batch_size=train_conf["bs"], source_list=train_conf["source_list"], + vgg=train_conf["vgg"], ) # get image size @@ -94,7 +96,38 @@ def main(configuration_path, mode): # Train the model, except interrupt try: # learn.fine_tune(train_conf["num_epochs"]) - learn.fit(train_conf["num_epochs"]) + learn.fit(train_conf["num_epochs"]) + except KeyboardInterrupt: + pop_interrupt(learn, train_conf) + + end_training(learn, train_conf) + + if train_conf["inspection"]: + create_inspection_plots(train_conf, rand=True) + + if mode == "gan": + # check out path and look for existing model files + check_outpath(train_conf["model_path"], train_conf) + + click.echo("Start training of the GAN model.\n") + + # define_learner + learn = define_learner( + data, + arch, + train_conf, + gan=True + ) + + # load pretrained model + if train_conf["pre_model"] != "none": + learn.create_opt() + load_pre_model(learn, train_conf["pre_model"]) + + # Train the model, except interrupt + try: + # learn.fine_tune(train_conf["num_epochs"]) + learn.fit(train_conf["num_epochs"]) except KeyboardInterrupt: pop_interrupt(learn, train_conf) diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index d406f066..3745746f 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -8,10 +8,10 @@ from radionets.evaluation.train_inspection import create_inspection_plots -def create_databunch(data_path, fourier, source_list, batch_size): +def create_databunch(data_path, fourier, source_list, batch_size, vgg): # Load data sets - train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier) - valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier) + train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier, vgg=vgg) + valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier, vgg=vgg) # Create databunch with defined batchsize bs = batch_size @@ -34,6 +34,7 @@ def read_config(config): train_conf["lr"] = config["hypers"]["lr"] train_conf["fourier"] = config["general"]["fourier"] + train_conf["vgg"] = config["general"]["vgg"] train_conf["amp_phase"] = config["general"]["amp_phase"] train_conf["arch_name"] = config["general"]["arch_name"] train_conf["loss_func"] = config["general"]["loss_func"] diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index 4e75e5d3..17c2133e 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -209,7 +209,11 @@ def load_pretrained_model(arch_name, model_path, img_size=63): arch: architecture object architecture with pretrained weigths """ - if "filter_deep" in arch_name or "resnet" in arch_name: + if 'vgg19' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'automap' in arch_name: + arch = getattr(architecture, arch_name)() + elif "filter_deep" in arch_name or "resnet" or "Res" in arch_name: arch = getattr(architecture, arch_name)(img_size) else: arch = getattr(architecture, arch_name)() @@ -320,8 +324,8 @@ def fft_pred(pred, truth, amp_phase=True): a = pred[:, 0, :, :] b = pred[:, 1, :, :] - a_true = truth[0, :, :] - b_true = truth[1, :, :] + a_true = truth[:, 0, :, :] + b_true = truth[:, 1, :, :] if amp_phase: amp_pred_rescaled = (10 ** (10 * a) - 1) / 10 ** 10 @@ -339,4 +343,4 @@ def fft_pred(pred, truth, amp_phase=True): ifft_pred = np.fft.ifft2(compl_pred) ifft_true = np.fft.ifft2(compl_true) - return np.absolute(ifft_pred)[0], np.absolute(ifft_true) + return np.absolute(ifft_pred), np.absolute(ifft_true) From be58f68559a791e5461da76437d57732ec762459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 20 May 2021 19:59:41 +0200 Subject: [PATCH 27/55] dirty model .h5 file from process_vlbi.py are now readable --- radionets/dl_framework/callbacks.py | 18 ++++++++++--- radionets/dl_framework/data.py | 25 ++++++++++++++++--- radionets/dl_framework/learner.py | 4 +-- radionets/dl_framework/loss_functions.py | 10 ++++++++ .../dl_training/scripts/start_training.py | 1 + radionets/dl_training/utils.py | 7 +++--- radionets/simulations/process_vlbi.py | 4 +-- 7 files changed, 55 insertions(+), 14 deletions(-) diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index 4344ec95..c708e440 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -112,12 +112,18 @@ def normalize_tfm(self): class DataAug(Callback): _order = 3 - def __init__(self, vgg): + def __init__(self, vgg, physics_informed): self.vgg = vgg + self.physics_informed = physics_informed def before_batch(self): x = self.xb[0].clone() - y = self.yb[0].clone() + if self.physics_informed: + y = self.yb[0][0].clone() + base_mask = self.yb[0][1].clone() + A = self.yb[0][2].clone() + else: + y = self.yb[0].clone() randint = np.random.randint(0, 4, x.shape[0]) for i in range(x.shape[0]): x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) @@ -125,8 +131,14 @@ def before_batch(self): if not self.vgg: y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) + if self.physics_informed: + base_mask[i] = torch.rot90(base_mask[i], int(randint[i]), dims=[0,1]) + A[i] = torch.rot90(A[i], int(randint[i]), dims=[0,1]) self.learn.xb = [x] - self.learn.yb = [y] + if self.physics_informed: + self.learn.yb = [(y, base_mask, A)] + else: + self.learn.yb = [y] class SaveTempCallback(Callback): diff --git a/radionets/dl_framework/data.py b/radionets/dl_framework/data.py index 42c46203..f4c0ce86 100644 --- a/radionets/dl_framework/data.py +++ b/radionets/dl_framework/data.py @@ -32,7 +32,7 @@ def do_normalisation(x, norm): class h5_dataset: - def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False, vgg=False): + def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False, vgg=False, physics_informed=False): """ Save the bundle paths and the number of bundles in one file. """ @@ -42,6 +42,7 @@ def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False, self.amp_phase = amp_phase self.source_list = source_list self.vgg = vgg + self.physics_informed = physics_informed def __call__(self): return print("This is the h5_dataset class.") @@ -56,6 +57,12 @@ def __getitem__(self, i): if self.source_list: x = self.open_image("x", i) y = self.open_image("z", i) + elif self.physics_informed: + x = self.open_image("x", i) + y = self.open_image("y", i) + base_mask = self.open_image("base_mask", i) + A = self.open_image("A", i) + return x, (y, base_mask.squeeze(0), A.squeeze(0)) else: x = self.open_image("x", i) y = self.open_image("y", i) @@ -79,6 +86,12 @@ def open_image(self, var, i): h5py.File(self.bundles[bundle], "r") for bundle in bundle_unique ] bundle_paths_str = list(map(str, bundle_paths)) + + if var == 'base_mask' or var == 'A': # Baselines and response matrices are the same for every single src_position/bundle + image[image != 0] = 0 + + + data = torch.tensor( [ bund[var][img] @@ -88,7 +101,7 @@ def open_image(self, var, i): ] ] ) - if var == "x" or self.tar_fourier is True: + if var == "x" or var == 'y': if len(i) == 1: data_amp, data_phase = data[:, 0], data[:, 1] @@ -102,6 +115,8 @@ def open_image(self, var, i): data_channel = data elif self.vgg: data_channel = data + elif self.physics_informed: + data_channel = data else: if data.shape[1] == 2: raise ValueError( @@ -112,6 +127,8 @@ def open_image(self, var, i): data_channel = data.reshape(data.shape[-1] ** 2) else: data_channel = data.reshape(-1, data.shape[-1] ** 2) + # if var == 'base_mask' or var == 'A': + # print(data_channel.shape) return data_channel.float() @@ -232,7 +249,7 @@ def mean_and_std(array): return array.mean(), array.std() -def load_data(data_path, mode, fourier=False, source_list=False, vgg=False): +def load_data(data_path, mode, fourier=False, source_list=False, vgg=False, physics_informed=False): """ Load data set from a directory and return it as h5_dataset. @@ -252,5 +269,5 @@ def load_data(data_path, mode, fourier=False, source_list=False, vgg=False): """ bundle_paths = get_bundles(data_path) data = [path for path in bundle_paths if re.findall("samp_" + mode, path.name)] - ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list, vgg=vgg) + ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list, vgg=vgg, physics_informed=physics_informed) return ds diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 9f044813..7b64d553 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -81,7 +81,7 @@ def define_learner( [ SaveTempCallback(model_path=model_path), AvgLossCallback, - DataAug(vgg=train_conf["vgg"]), + DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), ] ) if gan: @@ -89,7 +89,7 @@ def define_learner( [ SaveTempCallback(model_path=model_path, gan=gan), AvgLossCallback, - DataAug(vgg=train_conf["vgg"]), + DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), ] ) if train_conf["telegram_logger"] and not lr_find: diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 5d38a311..6ed9fd06 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -121,6 +121,16 @@ def l1(x, y): loss = l1(x, y) return loss +def dirty_model(x, y): + print(x.shape) + print(y[0].shape) + print(y[1].shape) + print(y[2].shape) + + l1 = nn.L1Loss() + loss = l1(x, y[0]) + return loss + # vgg19 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') def vgg19_feature_loss(x, y): print() diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index befdb75c..4e6fda44 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -65,6 +65,7 @@ def main(configuration_path, mode): batch_size=train_conf["bs"], source_list=train_conf["source_list"], vgg=train_conf["vgg"], + physics_informed=train_conf["physics_informed"] ) # get image size diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 3745746f..0006061e 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -8,10 +8,10 @@ from radionets.evaluation.train_inspection import create_inspection_plots -def create_databunch(data_path, fourier, source_list, batch_size, vgg): +def create_databunch(data_path, fourier, source_list, batch_size, vgg, physics_informed): # Load data sets - train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier, vgg=vgg) - valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier, vgg=vgg) + train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier, vgg=vgg, physics_informed=physics_informed) + valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier, vgg=vgg, physics_informed=physics_informed) # Create databunch with defined batchsize bs = batch_size @@ -35,6 +35,7 @@ def read_config(config): train_conf["fourier"] = config["general"]["fourier"] train_conf["vgg"] = config["general"]["vgg"] + train_conf["physics_informed"] = config["general"]["physics_informed"] train_conf["amp_phase"] = config["general"]["amp_phase"] train_conf["arch_name"] = config["general"]["arch_name"] train_conf["loss_func"] = config["general"]["loss_func"] diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index daa3548f..b0e919ec 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -207,9 +207,9 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): fft_scaled_truth = prepare_fft_images(truth_fft, True, False) out = data_path + "/h5/samp_train"+ str(p) + ".h5" - save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], base_mask, A) + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) out = data_path + "/h5/samp_valid"+ str(p) + ".h5" - save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], base_mask, A) + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) def response(config, N, unique_telescopes, unique_baselines): From 8a1daf305614887709110ce6a963b8b85c9a8809 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:07:32 +0200 Subject: [PATCH 28/55] GANCS + CLEANNN --- .../dl_framework/architectures/superRes.py | 500 ++++++++++++++++-- 1 file changed, 466 insertions(+), 34 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index dd84565a..585e37db 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -11,6 +11,7 @@ FBB, Lambda, better_symmetry, + fft, tf_shift, btf_shift, CirculationShiftPad, @@ -18,12 +19,18 @@ BetterShiftPad, Lambda, symmetry, + SRBlock_noBias, + HardDC, + SoftDC, + calc_DirtyBeam, + gauss, ) from functools import partial import torchvision import radionets.evaluation.utils as ut import numpy as np - +import matplotlib.pyplot as plt +import numpy as np class superRes_simple(nn.Module): def __init__(self, img_size): @@ -375,7 +382,7 @@ def forward(self, x): class SRResNet_corr(nn.Module): def __init__(self, img_size): super().__init__() - torch.cuda.set_device(1) + # torch.cuda.set_device(1) self.img_size = img_size self.preBlock = nn.Sequential( @@ -415,7 +422,7 @@ def __init__(self, img_size): self.symmetry = Lambda(better_symmetry) #pi layer - self.pi = nn.Tanh() + # self.pi = nn.Tanh() def forward(self, x): x = self.preBlock(x) @@ -676,11 +683,13 @@ class vgg19_feature_maps(nn.Module): def __init__(self, i, j): super().__init__() + # torch.cuda.set_device(1) # load pretrained vgg19 # vgg19 = torchvision.models.vgg19(pretrained=True) # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_group2.model') # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2_prelu', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_groups2_prelu.model') - model = ut.load_pretrained_model(arch_name='vgg19_blackhole_fft', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_fft.model') + # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_fft', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_fft.model') + model = ut.load_pretrained_model(arch_name='vgg19_one_channel', model_path='/net/big-tank/POOL/projects/radio/simulations/jets/260521/model/temp_30.model') vgg19 = model.vgg @@ -700,17 +709,21 @@ def __init__(self, i, j): break self.truncated_vgg19 = nn.Sequential(*list(vgg19.features)[:truncate_at + 1]) + for param in self.truncated_vgg19.parameters(): + param.requires_grad = False def forward(self, x): - amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 - phase = x[:,1] - compl = amp_rescaled * torch.exp(1j * phase) - ifft = torch.fft.ifft2(compl) - img = torch.absolute(ifft) - shift = torch.fft.fftshift(img) - with torch.no_grad(): - feature = self.truncated_vgg19(img.unsqueeze(1)) - + if x.shape[1] == 2: + amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase = x[:,1] + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + img = torch.absolute(ifft) + shift = torch.fft.fftshift(img) + with torch.no_grad(): + feature = self.truncated_vgg19(img.unsqueeze(1)) + else: + feature = self.truncated_vgg19(x) return feature @@ -810,9 +823,27 @@ def forward(self, x): shift = torch.fft.fftshift(img) return self.vgg(img.unsqueeze(1)) +class vgg19_one_channel(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(1) + vgg19 = torchvision.models.vgg19(pretrained=False) + + # customize vgg19 + + vgg19.features[0] = nn.Conv2d(1, 64, 3, stride=1, padding=1) + vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=2, bias=True)) + + self.vgg = vgg19 + + def forward(self, x): + #ifft + return self.vgg(x) + class discriminator(nn.Module): def __init__(self): super().__init__() + torch.cuda.set_device(0) self.preBlock = nn.Sequential(nn.Conv2d(1, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2)) @@ -826,43 +857,444 @@ def __init__(self): self.main = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7) - self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) + # self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) #GAN + self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1)) #WGAN def forward(self, x): + if isinstance(x, tuple) or isinstance(x, list): + if len(x) == 2: + x = x[1] + else: + x = x[0] + if x.shape[1] == 2: amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 phase_x = x[:,1] compl_x = amp_x * torch.exp(1j * phase_x) ifft_x = torch.fft.ifft2(compl_x) img_x = torch.absolute(ifft_x) - shift_x = torch.fft.fftshift(img_x).unsqueeze(1) - x = shift_x - x = self.preBlock(x) - x = self.main(x) - x = torch.flatten(x, 1) - x = self.postBlock(x) + shift_x = torch.fft.ifftshift(img_x).unsqueeze(1) + else: + shift_x = x + # shift_x[torch.isnan(shift_x)] = 0 + pred = self.preBlock(shift_x) + pred = self.main(pred) + pred = torch.flatten(pred, 1) + pred = self.postBlock(pred) + return pred + +class SRResNet_dirtyModel_pretrainedL1(nn.Module): + def __init__(self, img_size): + super().__init__() + self.model = ut.load_pretrained_model(arch_name='SRResNet_dirtyModel', model_path='/net/big-tank/POOL/users/sfroese/vipy/jets/models/l1_symmetry.model') + def forward(self, x): + return self.model(x) + + +class SRResNet_dirtyModel(nn.Module): + def __init__(self, img_size): + super().__init__() + torch.cuda.set_device(1) + self.img_size = img_size + + self.preBlock = nn.Sequential( + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=1), nn.PReLU() + ) + + # ResBlock 16 + self.blocks = nn.Sequential( + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + ) + + self.postBlock = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64) + ) + # self.upscale = nn.Sequential( + # nn.Conv2d(64, 256, 3, stride=1, padding = 1), + # nn.PixelShuffle(2), + # nn.PReLU() + # ) + self.final = nn.Sequential( + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=1), + ) + + self.relu = nn.Hardtanh(0,1.1) + self.pi = nn.Hardtanh(-np.pi,np.pi) + + self.symmetry = Lambda(better_symmetry) + + + def forward(self, x): + + amp = x[:,0].clone().detach() + phase = x[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + dirty = torch.fft.ifftshift(torch.absolute(ifft)) + dirty = dirty.unsqueeze(1) + + + pred = self.preBlock(x) + + pred = pred + self.postBlock(self.blocks(pred)) + # pred = self.postBlock(self.blocks(pred)) + + + # pred = self.upscale(pred) + + pred = self.final(pred) + + pred[:,0] = self.relu(pred[:,0].clone()) + pred[:,1] = self.pi(pred[:,1].clone()) + + pred = self.symmetry(pred) + + + # pred = self.relu(pred) + # pred = nn.functional.interpolate(pred, scale_factor=0.5) + + return dirty, pred + + +class GANCS_generator_test(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(1) + self.blocks = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + ) + + self.post = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), + nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), + nn.Conv2d(64, 2, 1, stride=1, padding=0) + ) + + self.DC = HardDC(45, 10) + + def forward(self, x): + ap = x[0] + base_mask = x[1] + A = x[2] - return x + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + # dirty = input.clone().detach() + + + pred = self.blocks(input) + + pred = self.post(pred) + + pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) + -class automap(nn.Module): + return pred + +class GANCS_generator(nn.Module): def __init__(self): super().__init__() torch.cuda.set_device(1) - self.fcs = nn.Sequential(nn.Linear(63*63*2,63*63),nn.Tanh(),nn.Linear(63*63,63*63),nn.Tanh()) + self.blocks = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + ) + + self.post = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), + nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), + nn.Conv2d(64, 2, 1, stride=1, padding=0) + ) + + self.DC = HardDC(45, 10) + + def forward(self, x): + ap = x[0] + base_mask = x[1] + A = x[2] + + + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + # dirty = input.clone().detach() + + + pred = self.blocks(input) + + pred = self.post(pred) - self.convs = nn.Sequential(nn.Conv2d(1,64,5, stride=1, padding=2),nn.ReLU(),nn.Conv2d(64,64,5,stride=1, padding=2),nn.ReLU(),nn.Conv2d(64,1,7,stride=1, padding=3)) + pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) + + + return pred + + +class GANCS_critic(nn.Module): + def __init__(self): + super().__init__() + # self.blocks = nn.Sequential( + # nn.Conv2d(2, 4, 3, stride=2, padding=1), + # nn.LeakyReLU(0.2), + # nn.Conv2d(4, 8, 3, stride=2, padding=1), + # nn.LeakyReLU(0.2), + # nn.Conv2d(8, 16, 3, stride=2, padding=1), + # nn.LeakyReLU(0.2), + # nn.Conv2d(16, 32, 3, stride=2, padding=1), + # nn.LeakyReLU(0.2), + # nn.Conv2d(32, 32, 3, stride=1, padding=1), + # nn.LeakyReLU(0.2), + # nn.Conv2d(32, 32, 1, stride=1, padding=0), + # nn.LeakyReLU(0.2), + # nn.Conv2d(32, 1, 1, stride=1, padding=0), + # nn.AdaptiveAvgPool2d(1) + # ) + self.block1 = nn.Sequential(nn.Conv2d(2, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.blocks = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7, nn.AdaptiveAvgPool2d(1)) + + def forward(self, x): + if x.shape[1] == 2: + amp = x[:,0].clone().detach() + phase = x[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + x = torch.fft.ifftshift(ifft).unsqueeze(1) + input = torch.zeros((x.shape[0],2,x.shape[2], x.shape[3])).to('cuda') + input[:,0] = x.real.squeeze(1) + input[:,1] = x.imag.squeeze(1) + return self.blocks(input) + + +class GANCS_unrolled(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(0) + self.block1 = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + nn.Conv2d(64, 2, 3, stride=1, padding=1), + ) + self.DC1 = SoftDC(45, 10) + self.block2 = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + nn.Conv2d(64, 2, 3, stride=1, padding=1), + ) + self.DC2 = SoftDC(45, 10) + self.block3 = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + nn.Conv2d(64, 2, 3, stride=1, padding=1), + ) + self.DC3 = HardDC(45, 10) + # self.block4 = nn.Sequential( + # SRBlock(2, 64), + # SRBlock(64, 64), + # nn.Conv2d(64, 2, 3, stride=1, padding=1), + # ) + # self.DC4 = SoftDC(45, 10) + # self.block5 = nn.Sequential( + # SRBlock(2, 64), + # SRBlock(64, 64), + # nn.Conv2d(64, 2, 3, stride=1, padding=1), + # ) + # self.DC5 = SoftDC(45, 10) def forward(self, x): - amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 - phase_x = x[:,1] - compl_x = amp_x * torch.exp(1j * phase_x) - x[:,0] = compl_x.real - x[:,1] = compl_x.imag - x = torch.flatten(x, 1) - x = self.fcs(x) - x = x.reshape((x.shape[0],1,63,63)) - x = self.convs(x) - return x \ No newline at end of file + ap = x[0] + base_mask = x[1] + A = x[2] + + + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + measured = input.clone().detach() + # dirty = input.clone().detach() + + + pred = self.block1(input) + dc1 = self.DC1(pred, measured, A, base_mask) + # pred[:,0] = dc1.real.squeeze(1) + # pred[:,1] = dc1.imag.squeeze(1) + pred = self.block2(pred) + dc2 = self.DC2(pred, measured, A, base_mask) + # pred[:,0] = dc2.real.squeeze(1) + # pred[:,1] = dc2.imag.squeeze(1) + pred = self.block3(pred) + dc3 = self.DC3(pred, compl.unsqueeze(1), A, base_mask) + pred[:,0] = dc3.real.squeeze(1) + pred[:,1] = dc3.imag.squeeze(1) + # pred = self.block4(pred) + # pred = self.DC4(pred, measured, A, base_mask) + # pred = self.block5(pred) + # pred = self.DC5(pred, measured, A, base_mask) + + #pred = self.post(pred) + + #pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) + + return (pred[:,0]+1j*pred[:,1]).unsqueeze(1) + + +class CLEANNN(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(0) + + self.blocks = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64) + ) + + self.post = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), + nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), + nn.Conv2d(64, 2, 1, stride=1, padding=0) + ) + + self.lamb = nn.Parameter(torch.tensor(1).float()) + + # self.conv = nn.Conv2d(2, 2, 3, stride=1, padding=1, bias=False) + + # self.DC = HardDC(45, 10) + # self.beamBlock = nn.Sequential( + # SRBlock(2, 64), + # SRBlock(64, 64), + # SRBlock(64, 64), + # SRBlock(64, 64), + # SRBlock(64, 64), + # nn.Conv2d(64, 2, 1, stride=1, padding=0) + # ) + # self.block1 = nn.Sequential(nn.Conv2d(2, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + # self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + # self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + # self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + # self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + # self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + # self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + + # self.beamBlock = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7, nn.Conv2d(512, 1, 3, stride=1, padding=1), nn.AdaptiveAvgPool2d(1), nn.Hardtanh(1,3)) + + + def forward(self, x): + # print(len(x)) + ap = x[0] + base_mask = x[1] + A = x[2] + M = x[3] + + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + # measured = input.clone().detach() + + #calculate Dirty Beam + beam = calc_DirtyBeam(base_mask) + # beam_copy = beam.clone().detach() + + # M = torch.zeros(input.shape).to('cuda') + + + # for i in range(5): + out_b = self.blocks(input) + out_p = self.post(out_b) + + + # residual = input - self.lamb*torch.einsum('bclm,bclm->bclm', out_p, beam) + residual = input - self.lamb*out_p + + M = M + self.lamb*out_p + + # if i == 4: + # break + + # return (input[:,0]+1j*input[:,1]).unsqueeze(1) + # return (M[:,0]+1j*M[:,1]).unsqueeze(1) + + # gauss_params = self.beamBlock(beam) + + # clean_beam = torch.fft.ifft2(torch.fft.fft2(torch.cat([gauss(63,s) for s in gauss_params]).reshape(-1,M.shape[2],M.shape[2]))) + + # M = M + input + # return clean_beam + # M_compl = (M[:,0]+1j*M[:,1]) + + + # M_conv = torch.einsum('blm,blm->blm', torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(M_compl))), torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(clean_beam)))) + + # M_conv = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(M_conv))) + fft_residual = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(residual[:,0]+1j*residual[:,1]))) + + res_amp = torch.absolute(fft_residual) + res_phase = torch.angle(fft_residual) + + residual[:,0] = ((torch.log10(res_amp + 1e-10) / 10) + 1) + residual[:,1] = res_phase + + return residual, M \ No newline at end of file From 43153e06f1e25b02e1c498ab2d3885d0b053b547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:07:56 +0200 Subject: [PATCH 29/55] RNN Callback --- radionets/dl_framework/callbacks.py | 47 ++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index c708e440..52d49c39 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -91,6 +91,8 @@ def plot_lrs(self): plt.tight_layout() + + class NormCallback(Callback): _order = 2 @@ -117,26 +119,37 @@ def __init__(self, vgg, physics_informed): self.physics_informed = physics_informed def before_batch(self): - x = self.xb[0].clone() + # x = self.xb[0].clone() + y = self.yb[0].clone() if self.physics_informed: - y = self.yb[0][0].clone() - base_mask = self.yb[0][1].clone() - A = self.yb[0][2].clone() + # y = self.yb[0][0].clone() + # base_mask = self.yb[0][1].clone() + # A = self.yb[0][2].clone() + x = self.xb[0][0].clone() + base_mask = self.xb[0][1].clone() + A = self.xb[0][2].clone() else: y = self.yb[0].clone() + randint = np.random.randint(0, 4, x.shape[0]) for i in range(x.shape[0]): - x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) - x[i, 1] = torch.rot90(x[i, 1], int(randint[i])) + if x.shape[1] == 2: + x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) + x[i, 1] = torch.rot90(x[i, 1], int(randint[i])) + else: + x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) + if not self.vgg: y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) if self.physics_informed: base_mask[i] = torch.rot90(base_mask[i], int(randint[i]), dims=[0,1]) A[i] = torch.rot90(A[i], int(randint[i]), dims=[0,1]) - self.learn.xb = [x] + # self.learn.xb = [x] + self.learn.yb = [y] if self.physics_informed: - self.learn.yb = [(y, base_mask, A)] + # self.learn.yb = [(y, base_mask, A)] + self.learn.xb = [(x, base_mask, A)] else: self.learn.yb = [y] @@ -155,3 +168,21 @@ def after_epoch(self): out = p / f"temp_{self.epoch + 1}.model" save_model(self, out, self.gan) print(f"\nFinished Epoch {self.epoch + 1}, model saved.\n") + +# Best callback ever +class OverwriteOneBatch_CLEAN(Callback): + _order = 4 + def __init__(self, n_iter): + self.n_iter = n_iter + + def before_batch(self): + input = self.xb[0] + M = torch.zeros(input[0].shape).to('cuda') + self.learn.xb = [(self.xb[0][0],self.xb[0][1],self.xb[0][2],M)] + + for i in range(self.n_iter-1): + # self.model.zero_grad() + self._do_one_batch() + self.learn.xb = [(self.pred[0].clone().detach(),self.xb[0][1],self.xb[0][2],self.pred[1].clone().detach())] + + \ No newline at end of file From 39786a6fb3dc395961965d9582f6b8a8b2e48cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:08:58 +0200 Subject: [PATCH 30/55] tuple input h5 --- radionets/dl_framework/data.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/radionets/dl_framework/data.py b/radionets/dl_framework/data.py index f4c0ce86..bfff7364 100644 --- a/radionets/dl_framework/data.py +++ b/radionets/dl_framework/data.py @@ -62,7 +62,7 @@ def __getitem__(self, i): y = self.open_image("y", i) base_mask = self.open_image("base_mask", i) A = self.open_image("A", i) - return x, (y, base_mask.squeeze(0), A.squeeze(0)) + return (x, base_mask.squeeze(0), A.squeeze(0)), y #x, (y, base_mask.squeeze(0), A.squeeze(0)) else: x = self.open_image("x", i) y = self.open_image("y", i) @@ -102,7 +102,10 @@ def open_image(self, var, i): ] ) if var == "x" or var == 'y': - if len(i) == 1: + if data.shape[1] == 1: + + data_channel = data + elif len(i) == 1: data_amp, data_phase = data[:, 0], data[:, 1] data_channel = torch.cat([data_amp, data_phase], dim=0) From 46e231de5f16351ecdbe3d228bea1aa99291aab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:09:14 +0200 Subject: [PATCH 31/55] GANLearner --- radionets/dl_framework/learner.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 7b64d553..418c36b6 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -6,8 +6,9 @@ TelegramLoggerCallback, DataAug, AvgLossCallback, + OverwriteOneBatch_CLEAN, ) -from fastai.optimizer import Adam +from fastai.optimizer import Adam, RMSProp from fastai.learner import Learner from fastai.data.core import DataLoaders from fastai.callback.data import CudaCallback @@ -17,7 +18,7 @@ # from radionets.dl_framework.architectures import superRes import torchvision from radionets.dl_training.utils import define_arch -from fastai.vision.gan import GANLearner +from fastai.vision.gan import GANLearner, FixedGANSwitcher, _tk_diff, GANDiscriminativeLR def get_learner( @@ -82,6 +83,8 @@ def define_learner( SaveTempCallback(model_path=model_path), AvgLossCallback, DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), + # OverwriteOneBatch_CLEAN(5), + OverwriteOneBatch_CLEAN(10), ] ) if gan: @@ -90,6 +93,8 @@ def define_learner( SaveTempCallback(model_path=model_path, gan=gan), AvgLossCallback, DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), + # WGANL1Callback, + # GANDiscriminativeLR, ] ) if train_conf["telegram_logger"] and not lr_find: @@ -106,21 +111,25 @@ def define_learner( loss_func = getattr(loss_functions, train_conf["loss_func"]) if gan: - gen_loss_func = getattr(loss_functions, 'gen_loss_func') + # gen_loss_func = getattr(loss_functions, 'gen_loss_func') #non physics informed + gen_loss_func = getattr(loss_functions, 'l1_wgan_GANCS') crit_loss_func = getattr(loss_functions, 'crit_loss_func') generator = arch critic = define_arch( - arch_name='discriminator', img_size=train_conf["image_size"] + arch_name='GANCS_critic', img_size=train_conf["image_size"] ) - init_cnn(generator) + # init_cnn(generator) init_cnn(critic) dls = DataLoaders.from_dsets( data.train_ds, data.valid_ds, bs=data.train_dl.batch_size, ) - learn = GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, lr=lr, cbs=cbfs, opt_func=opt_func) + switcher = FixedGANSwitcher(n_crit=1, n_gen=1) #GAN + # learn = GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, lr=lr, cbs=cbfs, opt_func=opt_func, switcher=switcher) #GAN + # learn = GANLearner.wgan(dls, generator, critic, lr=lr, cbs=cbfs, opt_func=RMSProp) #WGAN + learn = GANLearner(dls, generator, critic, gen_loss_func, _tk_diff, clip=0.01, switch_eval=False, lr=lr, cbs=cbfs, opt_func=RMSProp) #WGAN-l1 return learn # Combine model and data in learner From c2ea72cc45f0c44997b6f0556950680019019a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:09:43 +0200 Subject: [PATCH 32/55] GANCS Loss + physics informed loss --- radionets/dl_framework/loss_functions.py | 173 ++++++++++++++++++++--- 1 file changed, 155 insertions(+), 18 deletions(-) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 6ed9fd06..3de6d212 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -6,6 +6,8 @@ from pytorch_msssim import MS_SSIM from scipy.optimize import linear_sum_assignment from radionets.dl_framework.architectures import superRes +from fastai.vision.gan import _tk_mean +import matplotlib.pyplot as plt class FeatureLoss(nn.Module): @@ -121,14 +123,101 @@ def l1(x, y): loss = l1(x, y) return loss +def l1_phyinfo(x, y): + l1 = nn.L1Loss() + return l1(x[1],y[0]) + +def l1_GANCS(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + # change to two channels real/imag + # input = torch.zeros(y.shape) + # input[:,0] = spatial.real + # input[:,1] = spatial.imag + + l1 = nn.L1Loss() + + return l1(x,spatial) + +def l1_CLEAN(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + + l1 = nn.L1Loss() + + return l1((x[1][:,0]+1j*x[1][:,1]).unsqueeze(1),spatial) + +def l1_wgan_GANCS(fake_pred,x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + l1 = nn.L1Loss() + lamb = 1e-5 + + return l1(x,spatial)+lamb*_tk_mean(fake_pred, x, spatial) + def dirty_model(x, y): - print(x.shape) - print(y[0].shape) - print(y[1].shape) - print(y[2].shape) + amp = x[1][:,0] + phase = x[1][:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + pred = torch.fft.ifftshift(torch.absolute(ifft)).unsqueeze(1) + + # amp_t = y[0][:,0] + # phase_t = y[0][:,1] + # amp_rescaled_t = (10 ** (10 * amp_t) - 1) / 10 ** 10 + # compl_t = amp_rescaled_t * torch.exp(1j * phase_t) + # ifft_t = torch.fft.ifft2(compl_t) + # true = torch.fft.ifftshift(torch.absolute(ifft_t)).unsqueeze(1) + + + base_nums = torch.zeros(45) #hard code + n_tel = 10 #hardcode + c = 0 + for i in range(n_tel): + for j in range(n_tel): + if j<=i: + continue + base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + base_mask = y[1] + A = y[2] + MD = torch.zeros(pred.shape, dtype=torch.complex64).to('cuda') + + + for idx, bn in enumerate(base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + AI = torch.einsum('blm,bclm->bclm',A[...,idx],pred) + MD += torch.einsum('blm,bclm->bclm',s_uv,torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(AI)))) #spatial + + points = base_mask.clone() + points[points != 0] = 1 + points = torch.sum(points,3) + points[points == 0] = 1 + + MD = torch.fft.ifftshift(torch.absolute(torch.fft.ifft2(MD/points.unsqueeze(1)))) l1 = nn.L1Loss() - loss = l1(x, y[0]) + mse = nn.MSELoss() + loss = l1(pred, x[0]) + # loss = vgg19_feature_loss(MD,x[0]) return loss # vgg19 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') @@ -181,7 +270,7 @@ def vgg19_feature_loss(x, y): # y_3c = torch.cat((y, ones), dim=1) # loss = l1(vgg19_feature_model_22(x), vgg19_feature_model_22(y))# + l1(vgg19_feature_model_12(x), vgg19_feature_model_12(y)) + l1(vgg19_feature_model_34(x), vgg19_feature_model_34(y)) + l1(vgg19_feature_model_44(x), vgg19_feature_model_44(y)) + l1(vgg19_feature_model_54(x), vgg19_feature_model_54(y)) - loss = l1(vgg19_feature_model_12(x), vgg19_feature_model_12(y)) + loss = l1(vgg19_feature_model_22(x), vgg19_feature_model_22(y)) return loss def gen_loss_func(fake_pred, x, y): @@ -193,27 +282,75 @@ def gen_loss_func(fake_pred, x, y): # xm = torch.einsum('bcij,bcjk->bcik',x,mask) # ym = torch.einsum('bcij,bcjk->bcik',y,mask) - # content_loss = l1(x,y) - content_loss = automap_l2(x,y) + content_loss = l1(x,y) + # content_loss = automap_l2(x,y) # content_loss = vgg19_feature_loss(x,y) adv_loss = bce(fake_pred, torch.ones_like(fake_pred)) return content_loss + 1e-3*adv_loss -def automap_l2(x,y): - amp_y = (10 ** (10 * y[:,0]) - 1) / 10 ** 10 - phase_y = y[:,1] - compl_y = amp_y * torch.exp(1j * phase_y) - ifft_y = torch.fft.ifft2(compl_y) - img_y = torch.absolute(ifft_y) - shift_y = torch.fft.fftshift(img_y).unsqueeze(1) - mse = nn.MSELoss() - return mse(x,shift_y) + +def gen_loss_wgan_l1(fake_pred, x, y): + l1 = nn.L1Loss() + content_loss = l1(x[1], y[0]) + + + adv_loss = _tk_mean(fake_pred, x, y) + lamb = 1.5 # first:0.9 + + return lamb*content_loss + (1-lamb)*adv_loss + + +def gen_loss_func_physics_informed(fake_pred, x, y): + + + + bce = nn.BCELoss() + l1 = nn.L1Loss() + ######### physics informed stuff + # base_nums = torch.zeros(45) #hard code + # n_tel = 10 #hardcode + # c = 0 + # for i in range(n_tel): + # for j in range(n_tel): + # if j<=i: + # continue + # base_nums[c] = 256 * (i + 1) + j + 1 + # c += 1 + + # base_mask = y[1] + # A = y[2] + # MD = torch.zeros(x[1].shape, dtype=torch.complex64).to('cuda') + + + + + # for idx, bn in enumerate(base_nums): + # s_uv = torch.sum((base_mask == bn),3) + # if not (base_mask == bn).any(): + # continue + # AI = torch.einsum('blm,bclm->bclm',A[...,idx],x[1]) + # MD += torch.einsum('blm,bclm->bclm',s_uv,torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(AI)))) #spatial + + # points = base_mask.clone() + # points[points != 0] = 1 + # points = torch.sum(points,3) + # points[points == 0] = 1 + + # MD = torch.fft.ifftshift(torch.absolute(torch.fft.ifft2(MD/points.unsqueeze(1)))) + + + content_loss = l1(x[1], y[0]) + # print(fake_pred.requires_grad) + adv_loss = bce(fake_pred, torch.ones_like(fake_pred)) + + return 1e-3*adv_loss +content_loss + def crit_loss_func(real_pred, fake_pred): bce = nn.BCELoss() loss = bce(real_pred, torch.ones_like(real_pred)) + bce(fake_pred, torch.zeros_like(fake_pred)) - + # print(fake_pred.requires_grad) return loss def cross_entropy(x,y): From 2982a6d375d20570bc1017d655da487530a804a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:10:35 +0200 Subject: [PATCH 33/55] HardDC + SoftDC + gauss --- radionets/dl_framework/model.py | 156 +++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 3 deletions(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index b9d8ff70..d2ce900d 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -291,7 +291,7 @@ def deconv(ni, nc, ks, stride, padding, out_padding): return layers -def load_pre_model(learn, pre_path, visualize=False): +def load_pre_model(learn, pre_path, visualize=False, gan=False): """ :param learn: object of type learner :param pre_path: string wich contains the path of the model @@ -303,7 +303,7 @@ def load_pre_model(learn, pre_path, visualize=False): if visualize: learn.load_state_dict(checkpoint["model"]) - + else: learn.model.load_state_dict(checkpoint["model"]) learn.opt.load_state_dict(checkpoint["opt"]) @@ -465,6 +465,27 @@ def _conv_block(self, ni, nf, stride): nn.BatchNorm2d(nf), ) +class SRBlock_noBias(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni, nf, stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1,bias=False) + self.pool = ( + nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True) + ) # nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1,bias=False), + nn.BatchNorm2d(nf), + nn.PReLU(), + nn.Conv2d(nf, nf, 3, stride=1, padding=1,bias=False), + nn.BatchNorm2d(nf), + ) + class SRBlockPad(nn.Module): def __init__(self, ni, nf, stride=1): super().__init__() @@ -674,4 +695,133 @@ def __init__(self, padding: _size_4_t) -> None: def forward(self, input: torch.Tensor) -> torch.Tensor: x = better_padding(input, self.padding) - return x \ No newline at end of file + return x + +class HardDC(nn.Module): + def __init__(self, base_nums, n_tel): + super().__init__() + self.base_nums = torch.zeros(base_nums) + self.n_tel = n_tel + self.weights = nn.Parameter(torch.tensor(1).float()) + + def forward(self, x, input, A, base_mask): + c = 0 + for i in range(self.n_tel): + for j in range(self.n_tel): + if j<=i: + continue + self.base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + + pred = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') + c = 0 + for idx, bn in enumerate(self.base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + + xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) + x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft + k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) + y_prime = torch.einsum('blm,bclm->bclm',(1-s_uv),k_prime.unsqueeze(1)) + + Y = torch.einsum('blm,bclm->bclm', s_uv, input) + + full_k_space = Y + self.weights*y_prime + + pred += full_k_space # maybe a conj(A) missing, see paper 1910.07048 + c += 1 + + points = base_mask.clone() + points[points != 0] = 1 + points = torch.sum(points,3) + points[points == 0] = 1 + + + return torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(pred/c))) #divide by c because we summed c fully sampled maps in pred??? + +class SoftDC(nn.Module): + def __init__(self, base_nums, n_tel): + super().__init__() + self.base_nums = torch.zeros(base_nums) + self.n_tel = n_tel + self.weights = nn.Parameter(torch.tensor(1).float()) + + def forward(self, x, measured, A, base_mask): + c = 0 + for i in range(self.n_tel): + for j in range(self.n_tel): + if j<=i: + continue + self.base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + + sum = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') + c = 0 + for idx, bn in enumerate(self.base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + + xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) + x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft + k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) + y_prime = torch.einsum('blm,bclm->bclm',s_uv,k_prime.unsqueeze(1)) + + # Y = torch.einsum('blm,bclm->bclm', s_uv, input) + + diff = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(y_prime))) + + sum += diff # maybe a conj(A) missing, see paper 1910.07048 + c += 1 + + sum = sum/c + + pred = torch.zeros(x.shape).to('cuda') + + pred[:,0] = sum.real.squeeze(1) + pred[:,1] = sum.imag.squeeze(1) + + + + + + return x + self.weights*(pred-measured) #divide by c because we summed c fully sampled maps in pred??? + + +def calc_DirtyBeam(base_mask): + s_uv = torch.sum(base_mask,3) + s_uv[s_uv != 0] = 1 + + b = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(s_uv))) + beam = torch.zeros((b.shape[0],2, b.shape[1], b.shape[2])).to('cuda') + beam[:,0] = b.real.squeeze(1) + beam[:,1] = b.imag.squeeze(1) + return beam + + +def gauss(kernel_size, sigma): + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_cord = torch.arange(kernel_size).to('cuda') + x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1) + + mean = (kernel_size - 1)/2. + variance = sigma**2. + + # Calculate the 2-dimensional gaussian kernel which is + # the product of two gaussian distributions for two different + # variables (in this case called x and y) + gaussian_kernel = (1./(2.*np.pi*variance)) *\ + torch.exp( + -torch.sum((xy_grid - mean)**2., dim=-1) /\ + (2*variance) + ) + # Make sure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel + + \ No newline at end of file From 5a338fa7226a9b094952f217e2c72c37d315dc7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:11:17 +0200 Subject: [PATCH 34/55] gan save --- radionets/dl_training/scripts/start_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index 4e6fda44..eaa52f41 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -92,7 +92,7 @@ def main(configuration_path, mode): # load pretrained model if train_conf["pre_model"] != "none": learn.create_opt() - load_pre_model(learn, train_conf["pre_model"]) + load_pre_model(learn, train_conf["pre_model"], gan=False) # Train the model, except interrupt try: @@ -130,9 +130,9 @@ def main(configuration_path, mode): # learn.fine_tune(train_conf["num_epochs"]) learn.fit(train_conf["num_epochs"]) except KeyboardInterrupt: - pop_interrupt(learn, train_conf) + pop_interrupt(learn, train_conf, True) - end_training(learn, train_conf) + end_training(learn, train_conf, True) if train_conf["inspection"]: create_inspection_plots(train_conf, rand=True) From cf597d881903a0071c484de6d995a9c14b09d165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:11:43 +0200 Subject: [PATCH 35/55] gan save only generator --- radionets/dl_training/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 0006061e..8a0bd853 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -78,12 +78,12 @@ def define_arch(arch_name, img_size): return arch -def pop_interrupt(learn, train_conf): +def pop_interrupt(learn, train_conf, gan=False): if click.confirm("KeyboardInterrupt, do you want to save the model?", abort=False): - model_path = train_conf["model_path"] + model_path = Path(train_conf["model_path"]) # save model print("Saving the model after epoch {}".format(learn.epoch)) - save_model(learn, model_path) + save_model(learn, model_path, gan=False) # plot loss plot_loss(learn, model_path) @@ -96,9 +96,9 @@ def pop_interrupt(learn, train_conf): sys.exit(1) -def end_training(learn, train_conf): +def end_training(learn, train_conf, gan=False): # Save model - save_model(learn, Path(train_conf["model_path"])) + save_model(learn, Path(train_conf["model_path"]), gan=False) # Plot loss plot_loss(learn, Path(train_conf["model_path"])) From 87b85a2a164e3caa220f9dbfb218aea71ef07714 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 09:13:13 +0200 Subject: [PATCH 36/55] gridding + response + baselines-mask --- radionets/simulations/process_vlbi.py | 95 +++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 5 deletions(-) diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index b0e919ec..88b717e9 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -126,7 +126,7 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): print(f"\n Loading VLBI data set.\n") bundles = dt.get_bundles(data_path) - freq = freq*10**6 # hard code #eht 227297 + freq = freq*10**6 # mhz hard code #eht 227297 uvfits = dt.get_bundles(bundles[3]) imgs = dt.get_bundles(bundles[2]) configs = dt.get_bundles(bundles[0]) @@ -150,7 +150,7 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): print(f"\n Load subset.\n") for i in np.arange(p*1000, p*1000+1000): sampled = uv_srt[i] - target = img_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 img[i-p*1000] = np.asarray(Image.open(str(target))) @@ -179,7 +179,11 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 # delta_u = 1/(fov*N/256) # hard code delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max - delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N # print(delta_u) for i in range(N): for j in range(N): @@ -206,7 +210,7 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) fft_scaled_truth = prepare_fft_images(truth_fft, True, False) - out = data_path + "/h5/samp_train"+ str(p) + ".h5" + out = data_path + "/h5/samp_train"+ str(p) +".h5" save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) out = data_path + "/h5/samp_valid"+ str(p) + ".h5" save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) @@ -217,7 +221,7 @@ def response(config, N, unique_telescopes, unique_baselines): array_layout = layouts.get_array_layout('vlba') src_crd = rc['src_coord'] - wave = const.c/((float(rc['channel'].split(':')[0])/2)*10**6/un.second)/un.meter + wave = const.c/((float(rc['channel'].split(':')[0]))*10**6/un.second)/un.meter rd = scan.rd_grid(rc['fov_size']*np.pi/(3600*180),N, src_crd) E = scan.getE(rd, array_layout, wave, src_crd) A = np.zeros((N,N,int(unique_baselines))) @@ -230,4 +234,85 @@ def response(config, N, unique_telescopes, unique_baselines): counter += 1 return A + + +def process_measurement(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=63 + with fits.open(file) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = hdul[0].header[37] + offset = hdul[2].data['IF FREQ'] + for o in offset[0][1:]: + baselines = np.append(baselines,hdul[0].data['Baseline']) + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines) + + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + u = np.array([]) + v = np.array([]) + for f in offset[0]: + u = np.append(u,data['UU--']*(freq+f)) + v = np.append(v,data['VV--']*(freq+f)) + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + img_resized = np.zeros((size,N,N)) + + samp_img = np.zeros((size,2,N,N)) + samp_img[0,0] = np.matmul(mask, samps[2].T)/points + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = np.matmul(mask, samps[3].T)/points + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/h5/samp_meas2.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + From 7f05e98ab93f26c471e0722242b6e617d0766fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 11:02:57 +0200 Subject: [PATCH 37/55] ConvGRUCell --- radionets/dl_framework/architectures/superRes.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 585e37db..2d9f41c1 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -1297,4 +1297,16 @@ def forward(self, x): residual[:,0] = ((torch.log10(res_amp + 1e-10) / 10) + 1) residual[:,1] = res_phase - return residual, M \ No newline at end of file + return residual, M + +class CLEANR(nn.Module): + def __init__(self): + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(2, 64, stride=2, padding=0), + nn.Tanh(), + ) + self.block2 = nn.GRU(64, 256) + + def forward(self, x): + return 0 \ No newline at end of file From 4a02a54e51a8cfeace2ea73847e7d3503323f24f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Mon, 14 Jun 2021 11:05:49 +0200 Subject: [PATCH 38/55] ConvGRUCell --- radionets/dl_framework/model.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index d2ce900d..9c624c63 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -824,4 +824,30 @@ def gauss(kernel_size, sigma): gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) return gaussian_kernel - \ No newline at end of file + +class ConvGRUCell(nn.Module): + def __init__(self, input_size, hidden_size, kernel_size, dilation=1, bias=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + + self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)/2) + self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)/2) + + def forward(self, x, hx=None): + if hx is None: + hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False) + + ih = self.Wih(x).chunk(3, dim=1) + hh = self.Whh(hx).chunk(3, dim=1) + + z = torch.sigmoid(ih[0] + hh[0]) + r = torch.sigmoid(ih[1] + hh[1]) + n = torch.tanh(ih[2]+ r*hh[2]) + + hx = (1-z)*hx + z*n + + return hx \ No newline at end of file From d7b76c537a8d682569cb57417546bb0c92a6efb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:51:37 +0200 Subject: [PATCH 39/55] ConvGRUCell + gradient Function --- radionets/dl_framework/model.py | 66 +++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 9c624c63..d9140c6e 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -1,3 +1,4 @@ +from numpy.lib.function_base import diff import torch from torch import nn import torch.nn.functional as F @@ -834,12 +835,12 @@ def __init__(self, input_size, hidden_size, kernel_size, dilation=1, bias=True): self.dilation = dilation self.bias = bias - self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)/2) - self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)/2) + self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) def forward(self, x, hx=None): if hx is None: - hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False) + hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False).to('cuda') ih = self.Wih(x).chunk(3, dim=1) hh = self.Whh(hx).chunk(3, dim=1) @@ -850,4 +851,61 @@ def forward(self, x, hx=None): hx = (1-z)*hx + z*n - return hx \ No newline at end of file + return hx + +def gradFunc(x, y, A, base_mask, n_tel, base_nums): + # does_require_grad = x.requires_grad + # with torch.enable_grad(): + # x.requires_grad_(True) + # print(y.shape) + # base_nums = torch.zeros(base_nums) + # c = 0 + # for i in range(n_tel): + # for j in range(n_tel): + # if j<=i: + # continue + # base_nums[c] = 256 * (i + 1) + j + 1 + # c += 1 + # difference = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') + # c = 0 + # for idx, bn in enumerate(base_nums): + # s_uv = torch.sum((base_mask == bn),3) + # if not (base_mask == bn).any(): + # continue + + # xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) + # x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft + # k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) + # y_prime = torch.einsum('blm,bclm->bclm',s_uv,k_prime.unsqueeze(1)) + + # # Y = torch.einsum('blm,bclm->bclm', s_uv, y.unsqueeze(1)) + + # # d = Y - y_prime + + # difference += y_prime + # c += 1 + # print(difference.shape) + # print(y.shape) + # print((difference-y).shape) + # print((difference-y.unsqueeze(1)).shape) + # points = base_mask.clone() + # points[points != 0] = 1 + # points = torch.sum(points,3) + # points[points == 0] = 1 + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) + pfx = torch.einsum('blm,blm->blm', mask, fx) + + diff = pfx-y + + ift = torch.fft.fftshift(torch.fft.ifft2(diff)) + + grad = torch.zeros((ift.size(0), 2) + ift.size()[1:]).to('cuda') + grad[:,0] = ift.real.squeeze(1) + grad[:,1] = ift.imag.squeeze(1) + # grad_x = torch.autograd.grad(spatial, inputs=x, retain_graph=does_require_grad, create_graph=does_require_grad)[0] + # grad_2c = torch.zeros((grad_x, 1) + grad_x()[2:], dtype=torch.complex64).to('cuda') + # x.requires_grad_(does_require_grad) + return grad \ No newline at end of file From ca03c071981208714762f2d3707e00011c1e7b72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:52:36 +0200 Subject: [PATCH 40/55] deactivate OverwriteOneBatch_CLEAN --- radionets/dl_framework/learner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 418c36b6..ed67ed62 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -84,7 +84,7 @@ def define_learner( AvgLossCallback, DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), # OverwriteOneBatch_CLEAN(5), - OverwriteOneBatch_CLEAN(10), + # OverwriteOneBatch_CLEAN(10), ] ) if gan: From f77c9bbbff40d4770f560f12fd9b640558f9e16c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:53:11 +0200 Subject: [PATCH 41/55] ConvRNN + RIM --- .../dl_framework/architectures/superRes.py | 72 +++++++++++++++++-- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 2d9f41c1..0174c3f8 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -24,6 +24,8 @@ SoftDC, calc_DirtyBeam, gauss, + ConvGRUCell, + gradFunc, ) from functools import partial import torchvision @@ -1299,14 +1301,70 @@ def forward(self, x): return residual, M -class CLEANR(nn.Module): +class ConvRNN(nn.Module): def __init__(self): super().__init__() - self.block1 = nn.Sequential( - nn.Conv2d(2, 64, stride=2, padding=0), - nn.Tanh(), + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 5, stride=1, dilation=1, padding=2), #padding = dilation * (ks-1) // 2 + nn.ReLU(), + ) + self.GRU1 = ConvGRUCell(64, 64, 1) + self.conv2 = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, dilation=2, padding=2), + nn.ReLU(), + ) + self.GRU2 = ConvGRUCell(64, 64, 1) + self.conv3 = nn.Sequential( + nn.Conv2d(64, 2, 3, stride=1, dilation=1, padding=1, bias=False) ) - self.block2 = nn.GRU(64, 256) - def forward(self, x): - return 0 \ No newline at end of file + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + c1 = self.conv1(x) + g1 = self.GRU1(c1, hx[0]) + c2 = self.conv2(g1) + g2 = self.GRU2(c2, hx[1]) + c3 = self.conv3(g2) + + + + return c3, [g1, g2] + +class RIM(nn.Module): + def __init__(self, n_steps=20): + super().__init__() + torch.cuda.set_device(1) + self.n_steps = n_steps + self.cRNN = ConvRNN() + + def forward(self, x, hx=None): + ap = x[0] + amp = ap[:,0] + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + data = compl.clone() + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + eta = torch.zeros(ap.shape).to('cuda') + eta[:,0] = spatial.real + eta[:,1] = spatial.imag + + + etas = [] + + for i in range(self.n_steps): + grad = gradFunc(eta, data, x[2], x[1], 8, 45) + input = torch.cat((eta,grad), dim=1) + + delta, hx = self.cRNN(input, hx) + eta = eta + delta + # plt.imshow(torch.absolute(eta[0,0]+1j*eta[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # print(i) + etas.append(eta) + + return etas \ No newline at end of file From 5b37b310398686fb2656773b73d059c518a2f26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:53:30 +0200 Subject: [PATCH 42/55] RNN Loss function --- radionets/dl_framework/loss_functions.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 3de6d212..63ae488b 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -156,6 +156,23 @@ def l1_CLEAN(x,y): return l1((x[1][:,0]+1j*x[1][:,1]).unsqueeze(1),spatial) +def l1_RIM(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + + l1 = nn.L1Loss() + loss = 0 + for eta in x: + loss += l1((eta[:,0]+1j*eta[:,1]).unsqueeze(1),spatial) + + loss = loss/len(x) + return loss + def l1_wgan_GANCS(fake_pred,x,y): amp = y[:,0].clone().detach() phase = y[:,1].clone().detach() From db0d8caf27d8a61e70994d2ca8442a6ccb3e33df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:54:22 +0200 Subject: [PATCH 43/55] cleanup gradfunc --- radionets/dl_framework/model.py | 42 +-------------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index d9140c6e..3c144913 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -854,44 +854,6 @@ def forward(self, x, hx=None): return hx def gradFunc(x, y, A, base_mask, n_tel, base_nums): - # does_require_grad = x.requires_grad - # with torch.enable_grad(): - # x.requires_grad_(True) - # print(y.shape) - # base_nums = torch.zeros(base_nums) - # c = 0 - # for i in range(n_tel): - # for j in range(n_tel): - # if j<=i: - # continue - # base_nums[c] = 256 * (i + 1) + j + 1 - # c += 1 - # difference = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') - # c = 0 - # for idx, bn in enumerate(base_nums): - # s_uv = torch.sum((base_mask == bn),3) - # if not (base_mask == bn).any(): - # continue - - # xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) - # x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft - # k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) - # y_prime = torch.einsum('blm,bclm->bclm',s_uv,k_prime.unsqueeze(1)) - - # # Y = torch.einsum('blm,bclm->bclm', s_uv, y.unsqueeze(1)) - - # # d = Y - y_prime - - # difference += y_prime - # c += 1 - # print(difference.shape) - # print(y.shape) - # print((difference-y).shape) - # print((difference-y.unsqueeze(1)).shape) - # points = base_mask.clone() - # points[points != 0] = 1 - # points = torch.sum(points,3) - # points[points == 0] = 1 mask = torch.sum(base_mask, 3) mask[mask != 0] = 1 @@ -905,7 +867,5 @@ def gradFunc(x, y, A, base_mask, n_tel, base_nums): grad = torch.zeros((ift.size(0), 2) + ift.size()[1:]).to('cuda') grad[:,0] = ift.real.squeeze(1) grad[:,1] = ift.imag.squeeze(1) - # grad_x = torch.autograd.grad(spatial, inputs=x, retain_graph=does_require_grad, create_graph=does_require_grad)[0] - # grad_2c = torch.zeros((grad_x, 1) + grad_x()[2:], dtype=torch.complex64).to('cuda') - # x.requires_grad_(does_require_grad) + return grad \ No newline at end of file From d2d5e9f32c6db36ef530e68273137f773e214f0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 15 Jun 2021 08:56:40 +0200 Subject: [PATCH 44/55] pred change to tuple --- radionets/evaluation/utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index 17c2133e..1fa79318 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -211,6 +211,12 @@ def load_pretrained_model(arch_name, model_path, img_size=63): """ if 'vgg19' in arch_name: arch = getattr(architecture, arch_name)() + elif 'GANCS' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'CLEANNN' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'RIM' in arch_name: + arch = getattr(architecture, arch_name)() elif 'automap' in arch_name: arch = getattr(architecture, arch_name)() elif "filter_deep" in arch_name or "resnet" or "Res" in arch_name: @@ -269,13 +275,20 @@ def eval_model(img, model): pred: n 1d arrays predicted images """ - if len(img.shape) == (3): - img = img.unsqueeze(0) model.eval() model.cuda() - with torch.no_grad(): - pred = model(img.float().cuda()) - return pred.cpu() + if isinstance(img, tuple): + img = (img[0].unsqueeze(0).float().cuda(), img[1].float().cuda(), img[2].float().cuda()) + with torch.no_grad(): + pred = model(img) + else: + if len(img.shape) == (3): + img = img.unsqueeze(0) + with torch.no_grad(): + pred = model(img.float().cuda()) + if isinstance(pred, tuple): + return pred + return pred def get_ifft(array, amp_phase=False): From 0da2c3b0763c712368db75ee8c066787ca28e4d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Thu, 15 Jul 2021 10:56:10 +0200 Subject: [PATCH 45/55] RIM --- .../dl_framework/architectures/superRes.py | 250 +++++++++++++++++- radionets/dl_framework/learner.py | 2 + radionets/dl_framework/loss_functions.py | 25 ++ radionets/dl_framework/model.py | 175 +++++++++++- radionets/evaluation/utils.py | 4 + radionets/simulations/process_vlbi.py | 232 ++++++++++++++-- 6 files changed, 655 insertions(+), 33 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 0174c3f8..8aef9bf7 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -12,6 +12,8 @@ Lambda, better_symmetry, fft, + gradFunc2, + manual_grad, tf_shift, btf_shift, CirculationShiftPad, @@ -26,6 +28,9 @@ gauss, ConvGRUCell, gradFunc, + gradFunc2, + gradFunc_putzky, + rnd_dirty_noise, ) from functools import partial import torchvision @@ -33,6 +38,7 @@ import numpy as np import matplotlib.pyplot as plt import numpy as np +# import irim.rim as rim class superRes_simple(nn.Module): def __init__(self, img_size): @@ -1330,7 +1336,7 @@ def forward(self, x, hx=None): - return c3, [g1, g2] + return c3, [g1.detach(), g2.detach()] class RIM(nn.Module): def __init__(self, n_steps=20): @@ -1338,6 +1344,8 @@ def __init__(self, n_steps=20): torch.cuda.set_device(1) self.n_steps = n_steps self.cRNN = ConvRNN() + # self.bn = nn.BatchNorm2d(2) + def forward(self, x, hx=None): ap = x[0] @@ -1345,26 +1353,248 @@ def forward(self, x, hx=None): phase = ap[:,1] amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 compl = amp_rescaled * torch.exp(1j * phase) #k measured - data = compl.clone() - ifft = torch.fft.ifft2(compl) - spatial = torch.fft.ifftshift(ifft) + data = compl.clone().detach() + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift) + ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center eta = torch.zeros(ap.shape).to('cuda') - eta[:,0] = spatial.real - eta[:,1] = spatial.imag + eta[:,0] = ifft_shift.real + eta[:,1] = ifft_shift.imag + etas = [] for i in range(self.n_steps): - grad = gradFunc(eta, data, x[2], x[1], 8, 45) + + grad = gradFunc(eta, data, x[2], x[1], 8, 45).detach() + # bn = self.bn(grad) input = torch.cat((eta,grad), dim=1) delta, hx = self.cRNN(input, hx) eta = eta + delta - # plt.imshow(torch.absolute(eta[0,0]+1j*eta[0,1]).cpu().detach().numpy()) + # plt.imshow(torch.absolute(bn[0,0]+1j*bn[0,1]).cpu().detach().numpy()) # plt.colorbar() # plt.show() - # print(i) etas.append(eta) - return etas \ No newline at end of file + return etas + + +class ConvRNN_deepClean(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 11, stride=3, dilation=1, padding=2), # use stride=4 for 63 px images + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64, 64, 11) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 11, stride=3, dilation=1, padding=2), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64, 64, 11) + self.conv3 = nn.Sequential( + nn.Conv2d(64, 2, 11, stride=1, dilation=1, padding=5, bias=False) + ) + # self.weight = nn.Parameter(torch.tensor([0.25])) + + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) + + c1 = self.conv1(complex2channels) + g1 = self.GRU1(c1, hx[0]) + c2 = self.conv2(g1) + g2 = self.GRU2(c2, hx[1]) + c3 = self.conv3(g2) + + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + + return channels2complex, [g1.detach(), g2.detach()] # ??? detach() ??? + +class RIM_DC(nn.Module): + def __init__(self, n_steps=10): + super().__init__() + torch.cuda.set_device(0) + # torch.set_default_dtype(torch.float64) ## this is really important since we do a lot of ffts. otherwise torch.zeros is float32 and we can't save complex128 into it! + self.n_steps = n_steps + + self.cRNN = ConvRNN_deepClean() + # self.type(torch.complex64) + # torch.backends.cudnn.enabled = False + + + def forward(self, x, hx=None): + ap = x[0] + amp = ap[:,0] + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + + + data = compl.clone().detach().unsqueeze(1) + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift, norm="forward") + eta = torch.fft.ifftshift(ifft).unsqueeze(1) # shift low freq to center + # print(eta.shape) + # eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda') + # eta[:,0] = ifft_shift.real + # eta[:,1] = ifft_shift.imag + + + + etas = [] + # plt.imshow(torch.abs(eta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + for i in range(self.n_steps): + + grad = gradFunc_putzky(eta.detach(), [data, x[1]]).detach() + # grad = manual_grad(eta.detach(), [data, x[1]]).detach() + # plt.imshow(torch.absolute(grad[0,0]+1j*grad[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # break + + input = torch.cat((eta.detach(),grad), dim=1) + + + delta, hx = self.cRNN(input, hx) + # plt.imshow(torch.abs(grad[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # plt.imshow(torch.absolute(grad[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # print(hx[0].requires_grad) + # plt.imshow(torch.abs(delta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + eta = eta.detach() + delta + # plt.imshow(torch.abs(eta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + etas.append(eta) + + + + return [eta/eta.shape[2]**2 for eta in etas] + + +class ConvRNN_SRBlock(nn.Module): + def __init__(self): + super().__init__() + self.pre = nn.Sequential( + SRBlock(4, 64), + SRBlock(64, 64), + ) + self.GRU1 = ConvGRUCell(64, 64, 3) + self.conv2 = SRBlock(64, 64) + self.GRU2 = ConvGRUCell(64, 64, 3) + self.conv3 = SRBlock(64, 2) + + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + c1 = self.pre(x) + g1 = self.GRU1(c1, hx[0]) + c2 = self.conv2(g1) + g2 = self.GRU2(c2, hx[1]) + c3 = self.conv3(g2) + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + + return c3, [g1.detach(), g2.detach()] # ??? detach() ??? + + +class RIM_SR(nn.Module): + def __init__(self, n_steps=2): + super().__init__() + torch.cuda.set_device(0) + # torch.set_default_dtype(torch.float64) ## this is really important since we do a lot of ffts. otherwise torch.zeros is float32 and we can't save complex128 into it! + self.n_steps = n_steps + self.cRNN = ConvRNN_SRBlock() + + + def forward(self, x, hx=None): + ap = x[0] + amp = ap[:,0] + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + + + data = compl.clone().detach() + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift) + ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center + eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda') + eta[:,0] = ifft_shift.real + eta[:,1] = ifft_shift.imag + + + + etas = [] + + for i in range(self.n_steps): + + grad = gradFunc(eta, data, x[2], x[1], 8, 45).detach() + # plt.imshow(torch.absolute(grad[0,0]+1j*grad[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # break + + input = torch.cat((eta.float(),grad.float()), dim=1) + + + delta, hx = self.cRNN(input, hx) + + # print(hx[0].requires_grad) + eta = eta + delta + # print(delta.requires_grad) + etas.append(eta) + + # plt.imshow(torch.absolute(delta[0,0]+1j*delta[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + return etas + +class putzky(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(0) + convrnn = rim.conv_rnn.ConvRNN(4) + self.rim = rim.rim.RIM(convrnn, gradFunc_putzky) + + def forward(self, x): + ap = x[0] + amp = ap[:,0] + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + + + data = compl.clone().detach() + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift) + ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center + eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda').float() + eta[:,0] = ifft_shift.real + eta[:,1] = ifft_shift.imag + + etas, hx = self.rim.forward(eta, [data,x[1]], n_steps=8, accumulate_eta=True) + return etas diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index ed67ed62..6e978991 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -19,6 +19,7 @@ import torchvision from radionets.dl_training.utils import define_arch from fastai.vision.gan import GANLearner, FixedGANSwitcher, _tk_diff, GANDiscriminativeLR +from fastai.callback.mixup import MixUp def get_learner( @@ -85,6 +86,7 @@ def define_learner( DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), # OverwriteOneBatch_CLEAN(5), # OverwriteOneBatch_CLEAN(10), + MixUp(), ] ) if gan: diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 63ae488b..0818ef40 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -173,6 +173,31 @@ def l1_RIM(x,y): loss = loss/len(x) return loss +def mse_RIM(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + compl_shift = torch.fft.fftshift(compl) + ifft = torch.fft.ifft2(compl_shift, norm="forward") + true = torch.fft.ifftshift(ifft).unsqueeze(1) + # spatial = torch.zeros((ifft.size(0),2) + ifft.size()[1:]).to('cuda') + # spatial[:,0] = torch.fft.ifftshift(ifft).real.squeeze(1) + # spatial[:,1] = torch.fft.ifftshift(ifft).imag.squeeze(1) + + complex2channels_y = torch.cat((true.real,true.imag), dim=1) + + mse = nn.MSELoss() + loss = 0 + for eta in x: + complex2channels_x = torch.cat((eta.real,eta.imag), dim=1) + # print(complex2channels_x.shape) + # print(complex2channels_y.shape) + loss += mse(complex2channels_x*eta.shape[2]**2,complex2channels_y) + + loss = loss/len(x) + return loss + def l1_wgan_GANCS(fake_pred,x,y): amp = y[:,0].clone().detach() phase = y[:,1].clone().detach() diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 3c144913..369bc885 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -849,23 +849,182 @@ def forward(self, x, hx=None): r = torch.sigmoid(ih[1] + hh[1]) n = torch.tanh(ih[2]+ r*hh[2]) + # import matplotlib.pyplot as plt + # plt.imshow(torch.abs(hx[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() hx = (1-z)*hx + z*n return hx def gradFunc(x, y, A, base_mask, n_tel, base_nums): + does_require_grad = x.requires_grad + with torch.enable_grad(): + x.requires_grad_(True) + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) # shift x low freq to corner & fft + pfx = torch.einsum('blm,blm->blm', mask, torch.fft.ifftshift(fx)) # shift low freq to center + py = torch.einsum('blm,blm->blm', mask, y) + + # import matplotlib.pyplot as plt + # plt.imshow(torch.absolute(py[0]-pfx[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + diff = (py-pfx)**2 + # import matplotlib.pyplot as plt + # plt.imshow(torch.absolute(py[0]-pfx[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # diff_shift = torch.fft.fftshift(diff) # shift low freq to corner + + # error = torch.sum(torch.fft.ifftshift(torch.fft.ifft2(diff_shift))) # ifft & shift low freq to center + + # grad = torch.zeros((ift.size(0), 2) + ift.size()[1:]).to('cuda') + # grad[:,0] = ift.real.squeeze(1) + # grad[:,1] = ift.imag.squeeze(1) + + grad_x = torch.autograd.grad(torch.sum(diff), inputs=x, retain_graph=does_require_grad, + create_graph=does_require_grad)[0] + + + # import matplotlib.pyplot as plt + # plt.imshow(np.absolute((grad_x[:,0]+1j*grad_x[:,1])[0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + # import matplotlib.pyplot as plt + # plt.imshow(np.absolute(torch.fft.ifftshift(torch.fft.ifft(torch.fft.fftshift(grad_x)))[0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + + + + # import matplotlib.pyplot as plt + # # print(grad_x.shape) + # plt.imshow(torch.absolute(diff[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + x.requires_grad_(does_require_grad) + + return grad_x + + +def gradFunc2(x, y, A, base_mask, n_tel, base_nums): + + mask = torch.sum(base_mask, 3) mask[mask != 0] = 1 - fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) - pfx = torch.einsum('blm,blm->blm', mask, fx) + fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) # shift x low freq to corner & fft + pfx = torch.einsum('blm,blm->blm', mask, torch.fft.ifftshift(fx)) # shift low freq to center + py = torch.einsum('blm,blm->blm', mask, y) # mask y otherwise diff is not zero if x=y since we do a lot of ffts + + diff = pfx-py + diff_shift = torch.fft.fftshift(diff) # shift low freq to corner + error = torch.fft.ifftshift(torch.fft.ifft2(diff_shift)) + - diff = pfx-y + grad = torch.zeros((error.size(0), 2) + error.size()[1:]).to('cuda') + grad[:,0] = error.real.squeeze(1) + grad[:,1] = error.imag.squeeze(1) - ift = torch.fft.fftshift(torch.fft.ifft2(diff)) + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # plt.imshow(torch.absolute(error[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() - grad = torch.zeros((ift.size(0), 2) + ift.size()[1:]).to('cuda') - grad[:,0] = ift.real.squeeze(1) - grad[:,1] = ift.imag.squeeze(1) + return grad - return grad \ No newline at end of file +def gradFunc_putzky(x, y): + base_mask = y[1] + data = y[0] + does_require_grad = x.requires_grad + with torch.enable_grad(): + x.requires_grad_(True) + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x), norm="forward") + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + pfx = torch.einsum('blm,bclm->bclm', torch.flip(mask, [1]), torch.fft.ifftshift(fx)) # shift low freq to center + # py = torch.einsum('blmforward,bclm->bclm', mask, data) + # plt.imshow(torch.abs(pfx-data)[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + difference = pfx-data + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))-torch.fft.ifftshift(torch.fft.ifft2(pfx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + + + chi2 = torch.sum(torch.square(torch.abs(difference))) + + + grad_x = torch.autograd.grad(chi2, inputs=x, retain_graph=does_require_grad, + create_graph=does_require_grad)[0] + + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # # test[test==0] = 1 + # # plt.figure(figsize=(12,8)) + # plt.imshow((torch.abs(grad_x))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + x.requires_grad_(does_require_grad) + + return grad_x + +def manual_grad(x, y): + base_mask = y[1] + data = y[0] + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x), norm="forward") + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + pfx = torch.einsum('blm,bclm->bclm', torch.flip(mask, [1]), torch.fft.ifftshift(fx)) # shift low freq to center + # py = torch.einsum('blmforward,bclm->bclm', mask, data) + # plt.imshow(torch.abs(pfx-data)[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + difference = pfx-data + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))-torch.fft.ifftshift(torch.fft.ifft2(pfx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + grad_x = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.fftshift(difference), norm='forward')) + + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # # test[test==0] = 1 + # # plt.figure(figsize=(12,8)) + # plt.imshow((torch.abs(grad_x))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + return grad_x + +def rnd_dirty_noise(x, basemask): + noise = torch.random.normal(size=x.shape) + ft_noise = torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(noise))) + diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index 1fa79318..8994ff11 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -217,6 +217,10 @@ def load_pretrained_model(arch_name, model_path, img_size=63): arch = getattr(architecture, arch_name)() elif 'RIM' in arch_name: arch = getattr(architecture, arch_name)() + elif 'RIM_SR' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'putzky' in arch_name: + arch = getattr(architecture, arch_name)() elif 'automap' in arch_name: arch = getattr(architecture, arch_name)() elif "filter_deep" in arch_name or "resnet" or "Res" in arch_name: diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index 88b717e9..c0162397 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -122,19 +122,19 @@ def process_data( -def process_data_dirty_model(data_path, freq, n_positions, fov_asec): +def process_data_dirty_model(data_path, freq, n_positions, fov_asec, layout): print(f"\n Loading VLBI data set.\n") bundles = dt.get_bundles(data_path) freq = freq*10**6 # mhz hard code #eht 227297 - uvfits = dt.get_bundles(bundles[3]) - imgs = dt.get_bundles(bundles[2]) + uvfits = dt.get_bundles(bundles[2]) + imgs = dt.get_bundles(bundles[1]) configs = dt.get_bundles(bundles[0]) uv_srt = natsorted(uvfits, alg=ns.PATH) img_srt = natsorted(imgs, alg=ns.PATH) size = 1000 for p in tqdm(range(n_positions)): - N = 63 # hard code + N = 64 # hard code with fits.open(uv_srt[p*1000]) as hdul: n_sampled = hdul[0].data.shape[0] #number of sampled points baselines = hdul[0].data['Baseline'] @@ -143,12 +143,13 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 # response matrices - A = response(configs[p], N, unique_telescopes, unique_baselines) + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) - img = np.zeros((size,256,256)) + img = np.zeros((size,128,128)) samps = np.zeros((size,4,n_sampled*2)) print(f"\n Load subset.\n") for i in np.arange(p*1000, p*1000+1000): + # print(i) sampled = uv_srt[i] target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 @@ -216,9 +217,9 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec): save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) -def response(config, N, unique_telescopes, unique_baselines): +def response(config, N, unique_telescopes, unique_baselines, layout='vlba'): rc = ut.read_config(config) - array_layout = layouts.get_array_layout('vlba') + array_layout = layouts.get_array_layout(layout) src_crd = rc['src_coord'] wave = const.c/((float(rc['channel'].split(':')[0]))*10**6/un.second)/un.meter @@ -241,7 +242,7 @@ def process_measurement(data_path, file, config, fov_asec): print(f"\n Loading VLBI data set.\n") configs = config size = 1 - N=63 + N=64 with fits.open(file) as hdul: n_sampled = hdul[0].data.shape[0] #number of sampled points baselines = hdul[0].data['Baseline'] @@ -251,10 +252,11 @@ def process_measurement(data_path, file, config, fov_asec): freq = hdul[0].header[37] offset = hdul[2].data['IF FREQ'] for o in offset[0][1:]: + break baselines = np.append(baselines,hdul[0].data['Baseline']) baselines = np.append(baselines,baselines) # response matrices - A = response(configs, N, unique_telescopes, unique_baselines) + A = response(configs, N, unique_telescopes, unique_baselines, 'vlba') samps = np.zeros((size,4,n_sampled*2)) print(f"\n Load subset.\n") @@ -266,9 +268,9 @@ def process_measurement(data_path, file, config, fov_asec): x = cmplx[...,0,0] y = cmplx[...,0,1] w = cmplx[...,0,2] - x = np.squeeze(x) - y = np.squeeze(y) - w = np.squeeze(w) + x = np.squeeze(x)[:,0] + y = np.squeeze(y)[:,0] + w = np.squeeze(w)[:,0] ap = np.sqrt(x**2+y**2) ph = np.angle(x+1j*y) u = np.array([]) @@ -276,8 +278,11 @@ def process_measurement(data_path, file, config, fov_asec): for f in offset[0]: u = np.append(u,data['UU--']*(freq+f)) v = np.append(v,data['VV--']*(freq+f)) + break samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] - + import matplotlib.pyplot as plt + plt.plot(samps[0], samps[1], 'x') + plt.show() print(f"\n Gridding VLBI data set.\n") # Generate Mask @@ -305,6 +310,8 @@ def process_measurement(data_path, file, config, fov_asec): img_resized = np.zeros((size,N,N)) samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) samp_img[0,0] = np.matmul(mask, samps[2].T)/points samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 samp_img[0,1] = np.matmul(mask, samps[3].T)/points @@ -312,7 +319,202 @@ def process_measurement(data_path, file, config, fov_asec): # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) - out = data_path + "/h5/samp_meas2.h5" + out = data_path + "/h5/samp_meas.h5" save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) +def process_eht(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=64 + with fits.open(file) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + + unique_telescopes = 8 + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = 227297e6#hdul[0].header[37] + offset = hdul[2].data['IF FREQ'] + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines, 'eht') + + # samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + u = np.array([]) + v = np.array([]) + u = np.append(u,data['UU---SIN']*(freq)) + v = np.append(v,data['VV---SIN']*(freq)) + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + + # plt.plot(samps[0], samps[1], 'x') + # plt.show() + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + import matplotlib.pyplot as plt + # plt.imshow(np.sum(mask, 2)) + # plt.show() + points = np.sum(mask, 2) + points[points==0] = 1 + + samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) + samp_img[0,0] = np.matmul(mask, samps[2].T)/points + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = np.matmul(mask, samps[3].T)/points + plt.imshow(samp_img[0,0]) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,1]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/h5/eht_hi.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + +def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layout): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + freq = freq*10**6 # mhz hard code #eht 227297 + uvfits = dt.get_bundles(bundles[2]) + imgs = dt.get_bundles(bundles[1]) + configs = dt.get_bundles(bundles[0]) + uv_srt = natsorted(uvfits, alg=ns.PATH) + img_srt = natsorted(imgs, alg=ns.PATH) + size = 1000 + for p in tqdm(range(n_positions)): + N = 64 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) + + img = np.zeros((size,128,128)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + # print(i) + sampled = uv_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 + + img[i-p*1000] = np.asarray(Image.open(str(target))) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + ### nooiiiiiseeeee + np.random.seed(42) + noise = np.random.normal(size=(size, N, N)) + m = np.zeros((1000,64,64)) + m[:] = np.sum(mask, 2) + m[m != 0] = 1 + ft_noise = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(noise))) + ft_noise[m == 0] = 0 + noise_dirty = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(ft_noise))) + + + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/samp_train"+ str(p) +".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) From 8db63fd9ed4656964bbbedba057336ac30f2d976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:27:02 +0100 Subject: [PATCH 46/55] rim update --- radionets/dl_framework/architecture.py | 13 +- .../dl_framework/architectures/superRes.py | 183 ++++++++++++------ radionets/dl_framework/model.py | 58 +++++- 3 files changed, 181 insertions(+), 73 deletions(-) diff --git a/radionets/dl_framework/architecture.py b/radionets/dl_framework/architecture.py index 2129e6e0..9ba1a31c 100644 --- a/radionets/dl_framework/architecture.py +++ b/radionets/dl_framework/architecture.py @@ -1,6 +1,9 @@ -from radionets.dl_framework.architectures.basics import * -from radionets.dl_framework.architectures.unet import * -from radionets.dl_framework.architectures.filter_deep import * +#from radionets.dl_framework.architectures.basics import * +#from radionets.dl_framework.architectures.unet import * +# from radionets.dl_framework.architectures.filter_deep import * from radionets.dl_framework.architectures.superRes import * -from radionets.dl_framework.architectures.res_exp import * -from radionets.dl_framework.architectures.lists import * +from radionets.dl_framework.architectures.superRes import SRResNet_dirtyModel, SRResNet_dirtyModel_pretrainedL1, GANCS_generator, GANCS_generator_test, RIM, RIM_DC +from radionets.dl_framework.architectures.superRes import RIM_DC_noDetach +#from radionets.dl_framework.architectures.superRes import automap +# from radionets.dl_framework.architectures.res_exp import * +#from radionets.dl_framework.architectures.lists import * diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 8aef9bf7..536f7a50 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -30,7 +30,8 @@ gradFunc, gradFunc2, gradFunc_putzky, - rnd_dirty_noise, + fft_conv, + ConvGRUCellBN, ) from functools import partial import torchvision @@ -1403,7 +1404,6 @@ def forward(self, x, hx=None): if not hx: hx = [None]*2 - complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) # print(complex2channels.shape) # print(complex2channels.dtype) @@ -1466,8 +1466,6 @@ def forward(self, x, hx=None): # break input = torch.cat((eta.detach(),grad), dim=1) - - delta, hx = self.cRNN(input, hx) # plt.imshow(torch.abs(grad[0,0]).cpu().detach().numpy()) # plt.colorbar() @@ -1491,110 +1489,167 @@ def forward(self, x, hx=None): return [eta/eta.shape[2]**2 for eta in etas] -class ConvRNN_SRBlock(nn.Module): +class ConvRNN_deepClean_noDetach(nn.Module): def __init__(self): super().__init__() - self.pre = nn.Sequential( - SRBlock(4, 64), - SRBlock(64, 64), + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 11, stride=4, dilation=1, padding=2), # use stride=4 for 63 px images # for blackhole model use 4, 64, 11, stride=3, dilation=1, padding=2 + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64, 64, 11) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 11, stride=4, dilation=1, padding=2), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64, 64, 11) + self.conv3 = nn.Sequential( + nn.Conv2d(64, 2, 11, stride=1, dilation=1, padding=5, bias=False) ) - self.GRU1 = ConvGRUCell(64, 64, 3) - self.conv2 = SRBlock(64, 64) - self.GRU2 = ConvGRUCell(64, 64, 3) - self.conv3 = SRBlock(64, 2) + # self.weight = nn.Parameter(torch.tensor([0.25])) def forward(self, x, hx=None): if not hx: hx = [None]*2 - c1 = self.pre(x) + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) + + c1 = self.conv1(complex2channels) g1 = self.GRU1(c1, hx[0]) c2 = self.conv2(g1) g2 = self.GRU2(c2, hx[1]) c3 = self.conv3(g2) + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) # plt.colorbar() # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + return channels2complex, [g1, g2] # ??? detach() ??? - return c3, [g1.detach(), g2.detach()] # ??? detach() ??? - - -class RIM_SR(nn.Module): - def __init__(self, n_steps=2): +class RIM_DC_noDetach(nn.Module): + def __init__(self, n_steps=10): super().__init__() - torch.cuda.set_device(0) + torch.cuda.set_device(1) # torch.set_default_dtype(torch.float64) ## this is really important since we do a lot of ffts. otherwise torch.zeros is float32 and we can't save complex128 into it! self.n_steps = n_steps - self.cRNN = ConvRNN_SRBlock() + + # self.cRNN = ConvRNN_deepClean_noDetach() + self.cRNN = ConvRNN_deepClean_noDetach_smallKernel() + # self.type(torch.complex64) + # torch.backends.cudnn.enabled = False - def forward(self, x, hx=None): + def forward(self, x, hx=None, factor=1): ap = x[0] amp = ap[:,0] + uv_cov = amp.unsqueeze(1).clone().detach() phase = ap[:,1] amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 compl = amp_rescaled * torch.exp(1j * phase) #k measured - data = compl.clone().detach() + data = compl.clone().detach().unsqueeze(1) compl_shift = torch.fft.fftshift(compl) # shift low freq to corner - ifft = torch.fft.ifft2(compl_shift) - ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center - eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda') - eta[:,0] = ifft_shift.real - eta[:,1] = ifft_shift.imag - + ifft = torch.fft.ifft2(compl_shift, norm="forward") + eta = torch.fft.ifftshift(ifft).unsqueeze(1)*factor # shift low freq to center + #calc beam + # uv_cov[uv_cov!=0] = 1 + # beam = abs(torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(uv_cov)))) + + # beam = beam/torch.max(torch.max(beam,2)[0],2)[0][:,:,None,None] + # plt.imshow(beam[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + etas = [] for i in range(self.n_steps): - grad = gradFunc(eta, data, x[2], x[1], 8, 45).detach() - # plt.imshow(torch.absolute(grad[0,0]+1j*grad[0,1]).cpu().detach().numpy()) - # plt.colorbar() - # plt.show() - # break - - input = torch.cat((eta.float(),grad.float()), dim=1) - + grad = gradFunc_putzky(eta, [data, x[1]]) + # if i == 0: + # eta = fft_conv(eta,beam) + input = torch.cat((eta,grad), dim=1) delta, hx = self.cRNN(input, hx) - - # print(hx[0].requires_grad) eta = eta + delta - # print(delta.requires_grad) - etas.append(eta) - - # plt.imshow(torch.absolute(delta[0,0]+1j*delta[0,1]).cpu().detach().numpy()) + # plt.imshow(abs(eta[0,0].cpu().detach().numpy())/(64**2), cmap='hot') # plt.colorbar() # plt.show() - - return etas + # plt.imshow(abs(grad[0,0].cpu().detach().numpy()), cmap='hot') + # plt.colorbar() + # plt.show() + etas.append(eta) + + # plt.imshow(abs(fft_conv(eta,beam)[0,0].cpu().detach().numpy())/(64**2), cmap='hot') + # plt.colorbar() + # plt.show() + # return [fft_conv(eta,beam)/eta.shape[2]**2 for eta in etas] + return [eta/eta.shape[2]**2 for eta in etas] + -class putzky(nn.Module): + +class ConvRNN_deepClean_noDetach_smallKernel(nn.Module): def __init__(self): super().__init__() - torch.cuda.set_device(0) - convrnn = rim.conv_rnn.ConvRNN(4) - self.rim = rim.rim.RIM(convrnn, gradFunc_putzky) + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 3, stride=1, dilation=2, padding=2), + nn.Tanh(), + ) + self.conv1b = nn.Sequential( + nn.Conv2d(4, 64, 1, stride=1, dilation=2, padding=0), + nn.Tanh(), + ) + self.conv1c = nn.Sequential( + nn.Conv2d(4, 64, 5, stride=1, dilation=2, padding=4), + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64*3, 64*3, 3) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 3, stride=1, dilation=2, padding=2), + nn.Tanh(), + ) + self.conv2b = nn.Sequential( + nn.ConvTranspose2d(64, 64, 1, stride=1, dilation=2, padding=0), + nn.Tanh(), + ) + self.conv2c = nn.Sequential( + nn.ConvTranspose2d(64, 64, 5, stride=1, dilation=2, padding=4), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64*3, 64*3, 3) + self.conv3 = nn.Sequential( + nn.Conv2d(64*3, 2, 3, stride=1, dilation=2, padding=2, bias=False) + ) + # self.weight = nn.Parameter(torch.tensor([0.25])) - def forward(self, x): - ap = x[0] - amp = ap[:,0] - phase = ap[:,1] - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 - data = compl.clone().detach() - compl_shift = torch.fft.fftshift(compl) # shift low freq to corner - ifft = torch.fft.ifft2(compl_shift) - ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center - eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda').float() - eta[:,0] = ifft_shift.real - eta[:,1] = ifft_shift.imag + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) - etas, hx = self.rim.forward(eta, [data,x[1]], n_steps=8, accumulate_eta=True) - return etas + c1 = self.conv1(complex2channels) + c1b = self.conv1b(complex2channels) + c1c = self.conv1c(complex2channels) + comb = torch.cat((c1,c1b,c1c),dim=1) + g1 = self.GRU1(comb, hx[0]) + g1abc = torch.split(g1,64,dim=1) + c2 = self.conv2(g1abc[0]) + c2b = self.conv2(g1abc[1]) + c2c = self.conv2(g1abc[2]) + comb2 = torch.cat((c2,c2b,c2c),dim=1) + g2 = self.GRU2(comb2, hx[1]) + c3 = self.conv3(g2) + + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + + return channels2complex, [g1, g2] # ??? detach() ??? diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index 369bc885..6966f064 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -857,6 +857,40 @@ def forward(self, x, hx=None): return hx +class ConvGRUCellBN(nn.Module): + def __init__(self, input_size, hidden_size, kernel_size, dilation=1, bias=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + + self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + + self.bn1 = nn.BatchNorm2d(3*hidden_size) + self.bn2 = nn.BatchNorm2d(3*hidden_size) + + def forward(self, x, hx=None): + if hx is None: + hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False).to('cuda') + + ih = self.bn1(self.Wih(x)).chunk(3, dim=1) + hh = self.bn2(self.Whh(hx)).chunk(3, dim=1) + + z = torch.sigmoid(ih[0] + hh[0]) + r = torch.sigmoid(ih[1] + hh[1]) + n = torch.tanh(ih[2]+ r*hh[2]) + + # import matplotlib.pyplot as plt + # plt.imshow(torch.abs(hx[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + hx = (1-z)*hx + z*n + + return hx + def gradFunc(x, y, A, base_mask, n_tel, base_nums): does_require_grad = x.requires_grad with torch.enable_grad(): @@ -961,12 +995,24 @@ def gradFunc_putzky(x, y): # plt.colorbar() # plt.show() difference = pfx-data + # import matplotlib.pyplot as plt + # plt.imshow(abs(difference[0,0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() # import matplotlib.pyplot as plt # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))-torch.fft.ifftshift(torch.fft.ifft2(pfx))))[0,0].cpu().detach().numpy()) # plt.colorbar() # plt.show() + #import matplotlib.pyplot as plt + #plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))))[0,0].cpu().detach().numpy()) + #plt.colorbar() + #plt.show() + #import matplotlib.pyplot as plt + #plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + #plt.colorbar() + #plt.show() chi2 = torch.sum(torch.square(torch.abs(difference))) @@ -1024,7 +1070,11 @@ def manual_grad(x, y): return grad_x -def rnd_dirty_noise(x, basemask): - noise = torch.random.normal(size=x.shape) - ft_noise = torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(noise))) - +def fft_conv(a,b): + multiply = (torch.fft.fft2(torch.fft.fftshift(a))*torch.fft.fft2(torch.fft.fftshift(b), norm="ortho")) + ifft =torch.fft.ifftshift(torch.fft.ifft2(multiply)) + import matplotlib.pyplot as plt + # plt.imshow(abs(ifft[0,0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + return ifft From a44d71d530264360922c992bcb60e8e91b57deec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:29:16 +0100 Subject: [PATCH 47/55] loss --- radionets/dl_framework/learner.py | 2 +- radionets/dl_framework/loss_functions.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index 6e978991..3033df3f 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -86,7 +86,7 @@ def define_learner( DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), # OverwriteOneBatch_CLEAN(5), # OverwriteOneBatch_CLEAN(10), - MixUp(), + # MixUp(), ] ) if gan: diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 0818ef40..85e5823f 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -181,9 +181,7 @@ def mse_RIM(x,y): compl_shift = torch.fft.fftshift(compl) ifft = torch.fft.ifft2(compl_shift, norm="forward") true = torch.fft.ifftshift(ifft).unsqueeze(1) - # spatial = torch.zeros((ifft.size(0),2) + ifft.size()[1:]).to('cuda') - # spatial[:,0] = torch.fft.ifftshift(ifft).real.squeeze(1) - # spatial[:,1] = torch.fft.ifftshift(ifft).imag.squeeze(1) + complex2channels_y = torch.cat((true.real,true.imag), dim=1) @@ -191,8 +189,6 @@ def mse_RIM(x,y): loss = 0 for eta in x: complex2channels_x = torch.cat((eta.real,eta.imag), dim=1) - # print(complex2channels_x.shape) - # print(complex2channels_y.shape) loss += mse(complex2channels_x*eta.shape[2]**2,complex2channels_y) loss = loss/len(x) From 71a4e8a7aa549b4143145a6215da0b8273e198ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:31:24 +0100 Subject: [PATCH 48/55] minor changes gpu/cpu --- radionets/evaluation/plotting.py | 3 ++- radionets/evaluation/train_inspection.py | 33 ++++++++++++++++-------- radionets/evaluation/utils.py | 20 +++++++++----- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/radionets/evaluation/plotting.py b/radionets/evaluation/plotting.py index 0c6153a8..98be804d 100644 --- a/radionets/evaluation/plotting.py +++ b/radionets/evaluation/plotting.py @@ -195,6 +195,7 @@ def visualize_with_fourier( out_path: str which contains the output path """ # reshaping and splitting in real and imaginary part if necessary + img_pred = img_pred.cpu() inp_real, inp_imag = img_input[0], img_input[1] real_pred, imag_pred = img_pred[0], img_pred[1] real_truth, imag_truth = img_truth[0], img_truth[1] @@ -353,7 +354,7 @@ def visualize_source_reconstruction( ): m_truth, n_truth, alpha_truth = calc_jet_angle(ifft_truth) m_pred, n_pred, alpha_pred = calc_jet_angle(ifft_pred) - x_space = torch.arange(0, 511, 1) + x_space = torch.arange(0, 63, 1) # plt.style.use("./paper_large_3.rc") fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 10), sharey=True) diff --git a/radionets/evaluation/train_inspection.py b/radionets/evaluation/train_inspection.py index f4f849cc..6c06de82 100644 --- a/radionets/evaluation/train_inspection.py +++ b/radionets/evaluation/train_inspection.py @@ -156,7 +156,7 @@ def create_source_plots(conf, num_images=3, rand=False): if not conf["fourier"]: click.echo("\n This is not a fourier dataset.\n") - pred = pred.numpy() + pred = pred.cpu().numpy() # inverse fourier transformation for prediction ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) @@ -191,7 +191,7 @@ def create_contour_plots(conf, num_images=3, rand=False): if not conf["fourier"]: click.echo("\n This is not a fourier dataset.\n") - pred = pred.numpy() + pred = pred.cpu().numpy() # inverse fourier transformation for prediction ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) @@ -233,7 +233,7 @@ def evaluate_viewing_angle(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) m_truth, n_truth, alpha_truth = calc_jet_angle(torch.tensor(ifft_truth)) m_pred, n_pred, alpha_pred = calc_jet_angle(torch.tensor(ifft_pred)) @@ -278,7 +278,7 @@ def evaluate_dynamic_range(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred) dr_truths = np.append(dr_truths, dr_truth) @@ -332,7 +332,7 @@ def evaluate_ms_ssim(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) if img_size < 160: ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth)) @@ -357,7 +357,7 @@ def evaluate_ms_ssim(conf): def evaluate_mean_diff(conf): # create DataLoader loader = create_databunch( - conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"] + conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"], conf["rim"], ) model_path = conf["model_path"] out_path = Path(model_path).parent / "evaluation" @@ -374,20 +374,31 @@ def evaluate_mean_diff(conf): # iterate trough DataLoader for i, (img_test, img_true) in enumerate(tqdm(loader)): + + if conf["rim"]: + pred = eval_model(img_test, model)[9] + else: + pred = eval_model(img_test, model) - pred = eval_model(img_test, model) if conf["model_path_2"] != "none": pred_2 = eval_model(img_test, model_2) pred = torch.cat((pred, pred_2), dim=1) - ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) - + ifft_truth = np.fft.ifftshift(get_ifft(img_true, amp_phase=conf["amp_phase"])) + if conf["rim"]: + ifft_pred = abs(pred.cpu()) + else: + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) + # import matplotlib.pyplot as plt + # plt.imshow(ifft_truth[0]) + # plt.colorbar() + # plt.show() for pred, truth in zip(ifft_pred, ifft_truth): blobs_pred, blobs_truth = calc_blobs(pred, truth) flux_pred, flux_truth = crop_first_component( pred, truth, blobs_truth[0], out_path ) + # print(blobs_truth) vals.extend([1 - flux_truth.mean() / flux_pred.mean()]) click.echo("\nCreating mean_diff histogram.\n") @@ -426,7 +437,7 @@ def evaluate_area(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) for pred, truth in zip(ifft_pred, ifft_truth): val = area_of_contour(pred, truth) diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index 8994ff11..e22318fa 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -8,13 +8,18 @@ from torch.utils.data import DataLoader -def create_databunch(data_path, fourier, source_list, batch_size): +def create_databunch(data_path, fourier, source_list, batch_size, rim): # Load data sets + if rim: + mode = "valid" + else: + mode = "test" test_ds = load_data( data_path, - mode="test", + mode=mode, fourier=fourier, source_list=source_list, + physics_informed=rim ) # Create databunch with defined batchsize @@ -39,6 +44,7 @@ def read_config(config): eval_conf["source_list"] = config["general"]["source_list"] eval_conf["arch_name_2"] = config["general"]["arch_name_2"] eval_conf["diff"] = config["general"]["diff"] + eval_conf["rim"] = config["general"]["rim"] eval_conf["vis_pred"] = config["inspection"]["visualize_prediction"] eval_conf["vis_source"] = config["inspection"]["visualize_source_reconstruction"] @@ -281,8 +287,8 @@ def eval_model(img, model): """ model.eval() model.cuda() - if isinstance(img, tuple): - img = (img[0].unsqueeze(0).float().cuda(), img[1].float().cuda(), img[2].float().cuda()) + if isinstance(img, tuple) or isinstance(img, list): + img = (img[0].float().cuda(), img[1].float().cuda(), img[2].float().cuda()) #img[0].unsqueeze(0) with torch.no_grad(): pred = model(img) else: @@ -357,7 +363,7 @@ def fft_pred(pred, truth, amp_phase=True): compl_pred = a + 1j * b compl_true = a_true + 1j * b_true - ifft_pred = np.fft.ifft2(compl_pred) - ifft_true = np.fft.ifft2(compl_true) - + ifft_pred = np.fft.ifft2(np.fft.fftshift(compl_pred)) + ifft_true = np.fft.ifft2(np.fft.fftshift(compl_true)) + # return ifft_pred.real, ifft_true.real return np.absolute(ifft_pred), np.absolute(ifft_true) From eac4a420b2572daed58893a4844d71eeb469b7fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:32:35 +0100 Subject: [PATCH 49/55] gridding --- radionets/simulations/process_vlbi.py | 181 ++++++++++++++++++++++++-- 1 file changed, 167 insertions(+), 14 deletions(-) diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index c0162397..5a30f9b8 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -128,8 +128,8 @@ def process_data_dirty_model(data_path, freq, n_positions, fov_asec, layout): bundles = dt.get_bundles(data_path) freq = freq*10**6 # mhz hard code #eht 227297 uvfits = dt.get_bundles(bundles[2]) - imgs = dt.get_bundles(bundles[1]) - configs = dt.get_bundles(bundles[0]) + imgs = dt.get_bundles(bundles[4]) + configs = dt.get_bundles(bundles[1]) uv_srt = natsorted(uvfits, alg=ns.PATH) img_srt = natsorted(imgs, alg=ns.PATH) size = 1000 @@ -330,13 +330,11 @@ def process_eht(data_path, file, config, fov_asec): size = 1 N=64 with fits.open(file) as hdul: - n_sampled = hdul[0].data.shape[0] #number of sampled points baselines = hdul[0].data['Baseline'] unique_telescopes = 8 unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 - freq = 227297e6#hdul[0].header[37] - offset = hdul[2].data['IF FREQ'] + freq = 229071e6#hdul[0].header[37] baselines = np.append(baselines,baselines) # response matrices A = response(configs, N, unique_telescopes, unique_baselines, 'eht') @@ -354,12 +352,12 @@ def process_eht(data_path, file, config, fov_asec): x = np.squeeze(x) y = np.squeeze(y) w = np.squeeze(w) - ap = np.sqrt(x**2+y**2) - ph = np.angle(x+1j*y) + ap = np.sqrt((x*w)**2+(y*w)**2) + ph = np.angle(x*w+1j*y*w) u = np.array([]) v = np.array([]) - u = np.append(u,data['UU---SIN']*(freq)) - v = np.append(v,data['VV---SIN']*(freq)) + u = np.append(u,data['UU']*(freq)) + v = np.append(v,data['VV']*(freq)) samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] # plt.plot(samps[0], samps[1], 'x') @@ -408,7 +406,7 @@ def process_eht(data_path, file, config, fov_asec): # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) - out = data_path + "/h5/eht_hi.h5" + out = data_path + "/eht_hi__startmod_weights.h5" save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) @@ -418,8 +416,8 @@ def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layou bundles = dt.get_bundles(data_path) freq = freq*10**6 # mhz hard code #eht 227297 uvfits = dt.get_bundles(bundles[2]) - imgs = dt.get_bundles(bundles[1]) - configs = dt.get_bundles(bundles[0]) + imgs = dt.get_bundles(bundles[4]) + configs = dt.get_bundles(bundles[1]) uv_srt = natsorted(uvfits, alg=ns.PATH) img_srt = natsorted(imgs, alg=ns.PATH) size = 1000 @@ -492,7 +490,7 @@ def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layou img_resized = np.zeros((size,N,N)) for i in range(samps.shape[0]): samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points - samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + # samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points img_resized[i] = cv2.resize(img[i], (N,N)) img_resized[i] = img_resized[i]/np.sum(img_resized[i]) @@ -506,11 +504,166 @@ def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layou ft_noise = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(noise))) ft_noise[m == 0] = 0 noise_dirty = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(ft_noise))) + + compl = samp_img[:,0]*np.exp(1j*samp_img[:,1]) + dirty_img = abs(np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(compl)))) + for idx, d in enumerate(dirty_img): + max = np.max(d) + std = np.std(noise_dirty[idx]) + snr = np.random.uniform(2,10) + alpha = max/(std*snr) + dirty_img[idx] = dirty_img[idx] + abs(noise_dirty[idx]*alpha) + + measured = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(dirty_img))) + mask = np.sum(mask, 2) + for idx, m in enumerate(measured): + samp_img[idx][0] = (np.log10(np.abs(m) + 1e-10) / 10) + 1 + samp_img[idx][0][mask == 0] = 0 + samp_img[idx][1] = np.angle(m) + samp_img[idx][1][mask == 0] = 0 + import matplotlib.pyplot as plt + plt.imshow(samp_img[0][0]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/samp_train"+ str(p) +".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + +def process_data_dirty_model_noisy_pointSource(data_path, freq, n_positions, fov_asec, layout): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + print(bundles) + freq = freq*10**6 # mhz hard code #eht 227297 + uvfits = dt.get_bundles(bundles[2]) + imgs = dt.get_bundles(bundles[4]) + configs = dt.get_bundles(bundles[1]) + uv_srt = natsorted(uvfits, alg=ns.PATH) + img_srt = natsorted(imgs, alg=ns.PATH) + size = 1000 + for p in tqdm(range(n_positions)): + N = 64 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) + + img = np.zeros((size,128,128)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + # print(i) + sampled = uv_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 + + img[i-p*1000] = np.asarray(Image.open(str(target))) + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + # samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + ### nooiiiiiseeeee + np.random.seed(42) + noise = np.random.normal(size=(size, N, N)) + m = np.zeros((1000,64,64)) + m[:] = np.sum(mask, 2) + m[m != 0] = 1 + ft_noise = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(noise))) + ft_noise[m == 0] = 0 + noise_dirty = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(ft_noise))) + + compl = samp_img[:,0]*np.exp(1j*samp_img[:,1]) + dirty_img = abs(np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(compl)))) + for idx, d in enumerate(dirty_img): + max = np.max(d) + std = np.std(noise_dirty[idx]) + snr = np.random.uniform(2,10) + alpha = max/(std*snr) + dirty_img[idx] = dirty_img[idx] + abs(noise_dirty[idx]*alpha) + + measured = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(dirty_img))) + mask = np.sum(mask, 2) + for idx, m in enumerate(measured): + samp_img[idx][0] = (np.log10(np.abs(m) + 1e-10) / 10) + 1 + samp_img[idx][0][mask == 0] = 0 + samp_img[idx][1] = np.angle(m) + samp_img[idx][1][mask == 0] = 0 + import matplotlib.pyplot as plt + plt.imshow(samp_img[0][0]) + plt.colorbar() + plt.show() + + + #point source label + position = np.zeros((size,N,N)) + result = np.array([np.unravel_index(np.argmax(r), r.shape) for r in img_resized]) + for i in range(2): + position[i,result[i][0],result[i][1]] = 1 + plt.imshow(position[0]) + plt.colorbar() + plt.show() - # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) fft_scaled_truth = prepare_fft_images(truth_fft, True, False) From e53caf22b98e11a6bfd1225cfb19436148bfa1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:33:07 +0100 Subject: [PATCH 50/55] eht layout --- radionets/simulations/layouts/eht.txt | 8 ++++++++ radionets/simulations/layouts/eht_spt.txt | 9 +++++++++ 2 files changed, 17 insertions(+) create mode 100644 radionets/simulations/layouts/eht.txt create mode 100644 radionets/simulations/layouts/eht_spt.txt diff --git a/radionets/simulations/layouts/eht.txt b/radionets/simulations/layouts/eht.txt new file mode 100644 index 00000000..ccf0da8a --- /dev/null +++ b/radionets/simulations/layouts/eht.txt @@ -0,0 +1,8 @@ +#station_name X Y Z dish_dia el_low el_high SEFD altitude +ALMA50 2225037.1851 -5441199.162 -2479303.4629 84.7 15 85 110 5030 +SMTO -1828796.2 -5054406.8 3427865.2 10 15 85 11900 3185 +LMT -768713.9637 -5988541.7982 2063275.9472 50 15 85 560 4640 +Hawaii8 -5464523.4 -2493147.08 2150611.75 20.8 15 85 4900 4205 +PV 5088967.9 -301681.6 3825015.8 30 15 85 2900 2850 +PdBI 4523998.4 468045.24 4460309.76 7 15 85 1600 2550 +GLT 1500692 -1191735 6066409 12 15 85 4744 3210 \ No newline at end of file diff --git a/radionets/simulations/layouts/eht_spt.txt b/radionets/simulations/layouts/eht_spt.txt new file mode 100644 index 00000000..b05373c8 --- /dev/null +++ b/radionets/simulations/layouts/eht_spt.txt @@ -0,0 +1,9 @@ +station_name X Y Z dish_dia el_low el_high SEFD altitude +ALMA50 2225037.1851 -5441199.162 -2479303.4629 84.7 15 85 110 5030 +SMTO -1828796.2 -5054406.8 3427865.2 10 15 85 11900 3185 +LMT -768713.9637 -5988541.7982 2063275.9472 50 15 85 560 4640 +Hawaii8 -5464523.4 -2493147.08 2150611.75 20.8 15 85 4900 4205 +PV 5088967.9 -301681.6 3825015.8 30 15 85 2900 2850 +PdBI 4523998.4 468045.24 4460309.76 7 15 85 1600 2550 +SPT 0 0 -6359587.3 12 15 85 7300 2800 +GLT 1500692 -1191735 6066409 12 15 85 4744 3210 \ No newline at end of file From bdf49f028c5983b1b434f7b1ce391eea56a8242e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:33:27 +0100 Subject: [PATCH 51/55] eht layout --- radionets/simulations/layouts/layouts.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/radionets/simulations/layouts/layouts.py b/radionets/simulations/layouts/layouts.py index e6a7e6a9..52f0b112 100644 --- a/radionets/simulations/layouts/layouts.py +++ b/radionets/simulations/layouts/layouts.py @@ -9,3 +9,9 @@ def vlba(): x, y, z, _, _ = np.genfromtxt(file_dir / "vlba.txt", unpack=True) ant_pos = np.array([x, y, z]) return ant_pos + +def eht(): + _, x, y, z, _, _, _, _, _ = np.genfromtxt(file_dir / "eht.txt", unpack=True) + print(x) + ant_pos = np.array([x, y, z]) + return ant_pos From bf75c6e8259c58071af23da2a80d3ce5ddd658af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Tue, 30 Nov 2021 10:34:11 +0100 Subject: [PATCH 52/55] markersize antenna distribution changed --- radionets/simulations/uv_plots.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/radionets/simulations/uv_plots.py b/radionets/simulations/uv_plots.py index 20990cfc..909a690d 100644 --- a/radionets/simulations/uv_plots.py +++ b/radionets/simulations/uv_plots.py @@ -100,7 +100,7 @@ class object with antenna positions and baselines between telescopes x_enu_ant, y_enu_ant, marker="o", - markersize=6, + markersize=15, color="#1f77b4", linestyle="none", label="Antenna positions", @@ -111,7 +111,7 @@ class object with antenna positions and baselines between telescopes marker="*", linestyle="none", color="#ff7f0e", - markersize=15, + markersize=20, transform=ccrs.Geodetic(), zorder=10, label="Projected source", @@ -120,8 +120,9 @@ class object with antenna positions and baselines between telescopes if baselines is True: plot_baselines(antenna) - plt.legend(fontsize=16, markerscale=1.5) + plt.legend(fontsize=16, markerscale=1.5,loc='lower center',bbox_to_anchor=(-0.2, 0)) plt.tight_layout() + return plt.gcf() def animate_baselines(source, antenna, filename, fps=5): From 77f74257c00fffa85a5922a38e8bea7dc82d2e9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Wed, 13 Apr 2022 10:56:37 +0200 Subject: [PATCH 53/55] gridding fix --- radionets/simulations/process_vlbi.py | 120 +++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py index 5a30f9b8..8c37ac01 100644 --- a/radionets/simulations/process_vlbi.py +++ b/radionets/simulations/process_vlbi.py @@ -356,8 +356,8 @@ def process_eht(data_path, file, config, fov_asec): ph = np.angle(x*w+1j*y*w) u = np.array([]) v = np.array([]) - u = np.append(u,data['UU']*(freq)) - v = np.append(v,data['VV']*(freq)) + u = np.append(u,data['UU---SIN']*(freq)) + v = np.append(v,data['VV---SIN']*(freq)) samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] # plt.plot(samps[0], samps[1], 'x') @@ -406,20 +406,118 @@ def process_eht(data_path, file, config, fov_asec): # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) - out = data_path + "/eht_hi__startmod_weights.h5" + out = data_path + "/eht_hi_test_DPG_startmod.h5" save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) +def process_eht_hist(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=64 + with fits.open(file) as hdul: + baselines = hdul[0].data['Baseline'] + + unique_telescopes = 8 + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = 229071e6#hdul[0].header[37] + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines, 'eht') + + # samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + u = np.array([]) + v = np.array([]) + u = np.append(u,data['UU---SIN']*(freq)) + v = np.append(v,data['VV---SIN']*(freq)) + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + + # plt.plot(samps[0], samps[1], 'x') + # plt.show() + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + binpos = np.arange(N//2+1)*delta_u + binsneg = -np.flip(np.arange(N//2+1)*delta_u) + bins = np.append(binsneg,binpos) + bins = np.unique(bins) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + import matplotlib.pyplot as plt + # plt.imshow(np.sum(mask, 2)) + # plt.show() + amp_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins, weights=np.append(ap,ap)) + phase_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins, weights=np.append(ph,-ph)) + points_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins) + points_cal[points_cal==0]=1 + amp_cal = amp_cal/points_cal + phase_cal = phase_cal/points_cal + + samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) + samp_img[0,0] = amp_cal + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = phase_cal + plt.imshow(points_cal) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,0]) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,1]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/eht_hi_test_DPG.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + + def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layout): print(f"\n Loading VLBI data set.\n") bundles = dt.get_bundles(data_path) freq = freq*10**6 # mhz hard code #eht 227297 - uvfits = dt.get_bundles(bundles[2]) - imgs = dt.get_bundles(bundles[4]) - configs = dt.get_bundles(bundles[1]) - uv_srt = natsorted(uvfits, alg=ns.PATH) - img_srt = natsorted(imgs, alg=ns.PATH) + uvfits = dt.get_bundles(bundles[3]) + imgs = dt.get_bundles(bundles[2]) + configs = dt.get_bundles(bundles[0]) + uv_srt = natsorted(uvfits, alg=ns.PATH)[50000:] + img_srt = natsorted(imgs, alg=ns.PATH)[50000:] + configs = natsorted(configs, alg=ns.PATH)[50:] size = 1000 for p in tqdm(range(n_positions)): N = 64 # hard code @@ -433,7 +531,7 @@ def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layou # response matrices A = response(configs[p], N, unique_telescopes, unique_baselines, layout) - img = np.zeros((size,128,128)) + img = np.zeros((size,256,256)) samps = np.zeros((size,4,n_sampled*2)) print(f"\n Load subset.\n") for i in np.arange(p*1000, p*1000+1000): @@ -531,9 +629,9 @@ def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layou truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) fft_scaled_truth = prepare_fft_images(truth_fft, True, False) - out = data_path + "/h5/samp_train"+ str(p) +".h5" + out = data_path + "/h5/bh/samp_train"+ str(p) +".h5" save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) - out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + out = data_path + "/h5/bh/samp_valid"+ str(p) + ".h5" save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) def process_data_dirty_model_noisy_pointSource(data_path, freq, n_positions, fov_asec, layout): From 903672cf9b5cac36200bc115546388e5e74359e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Wed, 13 Apr 2022 12:12:25 +0200 Subject: [PATCH 54/55] delete unused architectures --- .../dl_framework/architectures/superRes.py | 1110 ++--------------- 1 file changed, 70 insertions(+), 1040 deletions(-) diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 536f7a50..8aef9da4 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -39,268 +39,6 @@ import numpy as np import matplotlib.pyplot as plt import numpy as np -# import irim.rim as rim - -class superRes_simple(nn.Module): - def __init__(self, img_size): - super().__init__() - self.img_size = img_size - self.conv1_amp = nn.Sequential( - nn.Conv2d(1, 4, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv1_phase = nn.Sequential( - nn.Conv2d(1, 4, stride=2, kernel_size=3, padding=3 // 2), GeneralELU(1 - pi) - ) - self.conv2_amp = nn.Sequential( - nn.Conv2d(4, 8, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv2_phase = nn.Sequential( - nn.Conv2d(4, 8, stride=2, kernel_size=3, padding=3 // 2), GeneralELU(1 - pi) - ) - self.conv3_amp = nn.Sequential( - nn.Conv2d(8, 16, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv3_phase = nn.Sequential( - nn.Conv2d(8, 16, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.conv4_amp = nn.Sequential( - nn.Conv2d(16, 32, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv4_phase = nn.Sequential( - nn.Conv2d(16, 32, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.conv5_amp = nn.Sequential( - nn.Conv2d(32, 64, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv5_phase = nn.Sequential( - nn.Conv2d(32, 64, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.conv1_amp(amp) - phase = self.conv1_phase(phase) - - amp = self.conv2_amp(amp) - phase = self.conv2_phase(phase) - - amp = self.conv3_amp(amp) - phase = self.conv3_phase(phase) - - amp = self.conv4_amp(amp) - phase = self.conv4_phase(phase) - - amp = self.conv5_amp(amp) - phase = self.conv5_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb - - -class superRes_res18(nn.Module): - def __init__(self, img_size): - super().__init__() - torch.cuda.set_device(1) - self.img_size = img_size - - self.preBlock_amp = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU() - ) - self.preBlock_phase = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), - nn.BatchNorm2d(64), - GeneralELU(1 - pi), - ) - - self.maxpool_amp = nn.MaxPool2d(3, 2, 1) - self.maxpool_phase = nn.MaxPool2d(3, 2, 1) - - # first block - self.layer1_amp = nn.Sequential(ResBlock_amp(64, 64), ResBlock_amp(64, 64)) - self.layer1_phase = nn.Sequential( - ResBlock_phase(64, 64), ResBlock_phase(64, 64) - ) - - self.layer2_amp = nn.Sequential( - ResBlock_amp(64, 128, stride=2), ResBlock_amp(128, 128) - ) - self.layer2_phase = nn.Sequential( - ResBlock_phase(64, 128, stride=2), ResBlock_phase(128, 128) - ) - - self.layer3_amp = nn.Sequential( - ResBlock_amp(128, 256, stride=2), ResBlock_amp(256, 256) - ) - self.layer3_phase = nn.Sequential( - ResBlock_phase(128, 256, stride=2), ResBlock_phase(256, 256) - ) - - self.layer4_amp = nn.Sequential( - ResBlock_amp(256, 512, stride=2), ResBlock_amp(512, 512) - ) - self.layer4_phase = nn.Sequential( - ResBlock_phase(256, 512, stride=2), ResBlock_phase(512, 512) - ) - - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.preBlock_amp(amp) - phase = self.preBlock_phase(phase) - - amp = self.maxpool_amp(amp) - phase = self.maxpool_phase(phase) - - amp = self.layer1_amp(amp) - phase = self.layer1_phase(phase) - - amp = self.layer2_amp(amp) - phase = self.layer2_phase(phase) - - amp = self.layer3_amp(amp) - phase = self.layer3_phase(phase) - - amp = self.layer4_amp(amp) - phase = self.layer4_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb - - -class superRes_res34(nn.Module): - def __init__(self, img_size): - super().__init__() - self.img_size = img_size - - self.preBlock_amp = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU() - ) - self.preBlock_phase = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), - nn.BatchNorm2d(64), - GeneralELU(1 - pi), - ) - - self.maxpool_amp = nn.MaxPool2d(3, 2, 1) - self.maxpool_phase = nn.MaxPool2d(3, 2, 1) - - # first block - self.layer1_amp = nn.Sequential( - ResBlock_amp(64, 64), ResBlock_amp(64, 64), ResBlock_amp(64, 64) - ) - self.layer1_phase = nn.Sequential( - ResBlock_phase(64, 64), ResBlock_phase(64, 64), ResBlock_phase(64, 64) - ) - - self.layer2_amp = nn.Sequential( - ResBlock_amp(64, 128, stride=2), - ResBlock_amp(128, 128), - ResBlock_amp(128, 128), - ResBlock_amp(128, 128), - ) - self.layer2_phase = nn.Sequential( - ResBlock_phase(64, 128, stride=2), - ResBlock_phase(128, 128), - ResBlock_phase(128, 128), - ResBlock_phase(128, 128), - ) - - self.layer3_amp = nn.Sequential( - ResBlock_amp(128, 256, stride=2), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ) - self.layer3_phase = nn.Sequential( - ResBlock_phase(128, 256, stride=2), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ) - - self.layer4_amp = nn.Sequential( - ResBlock_amp(256, 512, stride=2), - ResBlock_amp(512, 512), - ResBlock_amp(512, 512), - ) - self.layer4_phase = nn.Sequential( - ResBlock_phase(256, 512, stride=2), - ResBlock_phase(512, 512), - ResBlock_phase(512, 512), - ) - - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.preBlock_amp(amp) - phase = self.preBlock_phase(phase) - - amp = self.maxpool_amp(amp) - phase = self.maxpool_phase(phase) - - amp = self.layer1_amp(amp) - phase = self.layer1_phase(phase) - - amp = self.layer2_amp(amp) - phase = self.layer2_phase(phase) - - amp = self.layer3_amp(amp) - phase = self.layer3_phase(phase) - - amp = self.layer4_amp(amp) - phase = self.layer4_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb class SRResNet(nn.Module): @@ -555,549 +293,102 @@ def forward(self, x): return self.btf(x) -class EDSRBase(nn.Module): - def __init__(self, img_size): + +class discriminator(nn.Module): + def __init__(self): super().__init__() - # torch.cuda.set_device(1) - self.img_size = img_size + torch.cuda.set_device(0) - self.preBlock = nn.Sequential( - nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2) - ) + self.preBlock = nn.Sequential(nn.Conv2d(1, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2)) - # ResBlock 16 - self.blocks = nn.Sequential( - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - EDSRBaseBlock(64, 64), - ) + self.block1 = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) - self.postBlock = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, padding=1) - ) + self.main = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7) - self.final = nn.Sequential( - nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2) - ) + # self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) #GAN + self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1)) #WGAN def forward(self, x): - x = self.preBlock(x) - - x = x + self.postBlock(self.blocks(x)) - - x = self.final(x) - - return x - + if isinstance(x, tuple) or isinstance(x, list): + if len(x) == 2: + x = x[1] + else: + x = x[0] + + if x.shape[1] == 2: + amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase_x = x[:,1] + compl_x = amp_x * torch.exp(1j * phase_x) + ifft_x = torch.fft.ifft2(compl_x) + img_x = torch.absolute(ifft_x) + shift_x = torch.fft.ifftshift(img_x).unsqueeze(1) + else: + shift_x = x + # shift_x[torch.isnan(shift_x)] = 0 + pred = self.preBlock(shift_x) + pred = self.main(pred) + pred = torch.flatten(pred, 1) + pred = self.postBlock(pred) + return pred -class RDNet(nn.Module): - def __init__(self, img_size): +class GANCS_generator(nn.Module): + def __init__(self): super().__init__() torch.cuda.set_device(1) - self.img_size = img_size - - self.preBlock = nn.Sequential( - nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2, bias=False) - ) - - # ResBlock 6 - self.block1 = RDB(64, 32) - self.block2 = RDB(64, 32) - self.block3 = RDB(64, 32) - self.block4 = RDB(64, 32) - self.block5 = RDB(64, 32) - self.block6 = RDB(64, 32) - - - self.postBlock = nn.Sequential( - nn.Conv2d(6*64, 64, 1, stride=1, padding=0, bias=False), - nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False) + self.blocks = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), ) - self.final = nn.Sequential( - nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2, bias=False) + self.post = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), + nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), + nn.Conv2d(64, 2, 1, stride=1, padding=0) ) + self.DC = HardDC(45, 10) + def forward(self, x): - x = self.preBlock(x) - - x1 = self.block1(x) - x2 = self.block2(x1) - x3 = self.block3(x2) - x4 = self.block4(x3) - x5 = self.block5(x4) - x6 = self.block6(x5) - - x = x + self.postBlock(torch.cat((x1,x2,x3,x4,x5,x6), dim=1)) - x = self.final(x) - return x - - -class SRFBNet(nn.Module): - def __init__(self, img_size): - super().__init__() - # torch.cuda.set_device(1) - self.img_size = img_size - - self.preBlock = nn.Sequential( - nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2, bias=False) - ) + ap = x[0] + base_mask = x[1] + A = x[2] - # ResBlock 6 - # self.block1 = FBB(64, 32, first=True) - # self.block2 = FBB(64, 32) - self.block1 = FBB(64, 32, first=True) - self.postBlock = nn.Sequential( - nn.Conv2d(64, 64, 1, stride=1, padding=0, bias=False), - nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False) - ) + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + # dirty = input.clone().detach() - self.final = nn.Sequential( - nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2, bias=False) - ) - def forward(self, x): - x = self.preBlock(x) + pred = self.blocks(input) + pred = self.post(pred) + pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) - x1 = torch.zeros(x.shape).cuda() - for i in range(4): - x1 = self.block1(torch.cat((x,x1), dim=1)) - if i == 0: - block = x1 - else: - block = torch.cat((block,x1), dim=0) - - x = torch.cat((x,x,x,x), dim=0) + self.postBlock(block) - x = self.final(x) - return x + return pred -class vgg19_feature_maps(nn.Module): - def __init__(self, i, j): +class GANCS_critic(nn.Module): + def __init__(self): super().__init__() - # torch.cuda.set_device(1) - # load pretrained vgg19 - # vgg19 = torchvision.models.vgg19(pretrained=True) - # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_group2.model') - # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_group2_prelu', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_groups2_prelu.model') - # model = ut.load_pretrained_model(arch_name='vgg19_blackhole_fft', model_path='/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/models/vgg19_fft.model') - model = ut.load_pretrained_model(arch_name='vgg19_one_channel', model_path='/net/big-tank/POOL/projects/radio/simulations/jets/260521/model/temp_30.model') - - vgg19 = model.vgg - - conv_counter = 0 - maxpool_counter = 0 - truncate_at = 0 - for layer in vgg19.features: - truncate_at += 1 - - if isinstance(layer, nn.MaxPool2d): - maxpool_counter += 1 - conv_counter = 0 - if isinstance(layer, nn.Conv2d): - conv_counter += 1 - - if maxpool_counter == i - 1 and conv_counter == j: - break - - self.truncated_vgg19 = nn.Sequential(*list(vgg19.features)[:truncate_at + 1]) - for param in self.truncated_vgg19.parameters(): - param.requires_grad = False - - def forward(self, x): - if x.shape[1] == 2: - amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 - phase = x[:,1] - compl = amp_rescaled * torch.exp(1j * phase) - ifft = torch.fft.ifft2(compl) - img = torch.absolute(ifft) - shift = torch.fft.fftshift(img) - with torch.no_grad(): - feature = self.truncated_vgg19(img.unsqueeze(1)) - else: - feature = self.truncated_vgg19(x) - return feature - - - -class vgg19_blackhole(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(1) - vgg19 = torchvision.models.vgg19(pretrained=False) - - # customize vgg19 - vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) - vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) - - # for i, layer in enumerate(vgg19.features): - # if isinstance(layer, nn.Conv2d): - # vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) - # if isinstance(layer, nn.ReLU): - # vgg19.features[i] = nn.PReLU() - - self.vgg = vgg19 - - def forward(self, x): - return self.vgg(x) - - -class vgg19_blackhole_group2(nn.Module): - def __init__(self): - super().__init__() - # torch.cuda.set_device(1) - vgg19 = torchvision.models.vgg19(pretrained=False) - - # customize vgg19 - - vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) - vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) - - for i, layer in enumerate(vgg19.features): - if isinstance(layer, nn.Conv2d): - vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) - # if isinstance(layer, nn.ReLU): - # vgg19.features[i] = nn.PReLU() - - self.vgg = vgg19 - - def forward(self, x): - #ifft - return self.vgg(x) - -class vgg19_blackhole_group2_prelu(nn.Module): - def __init__(self): - super().__init__() - # torch.cuda.set_device(1) - vgg19 = torchvision.models.vgg19(pretrained=False) - - # customize vgg19 - vgg19.features[0] = nn.Conv2d(2, 64, 3, stride=1, padding=1) - vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) - - for i, layer in enumerate(vgg19.features): - if isinstance(layer, nn.Conv2d): - vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) - if isinstance(layer, nn.ReLU): - vgg19.features[i] = nn.PReLU() - - self.vgg = vgg19 - - def forward(self, x): - return self.vgg(x) - -class vgg19_blackhole_fft(nn.Module): - def __init__(self): - super().__init__() - # torch.cuda.set_device(1) - vgg19 = torchvision.models.vgg19(pretrained=False) - - # customize vgg19 - - vgg19.features[0] = nn.Conv2d(1, 64, 3, stride=1, padding=1) - vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=6, bias=True)) - - # for i, layer in enumerate(vgg19.features): - # if isinstance(layer, nn.Conv2d): - # vgg19.features[i] = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size, layer.stride, layer.padding, groups=2) - # # if isinstance(layer, nn.ReLU): - # vgg19.features[i] = nn.PReLU() - - self.vgg = vgg19 - - def forward(self, x): - #ifft - amp_rescaled = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 - phase = x[:,1] - compl = amp_rescaled * torch.exp(1j * phase) - ifft = torch.fft.ifft2(compl) - img = torch.absolute(ifft) - shift = torch.fft.fftshift(img) - return self.vgg(img.unsqueeze(1)) - -class vgg19_one_channel(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(1) - vgg19 = torchvision.models.vgg19(pretrained=False) - - # customize vgg19 - - vgg19.features[0] = nn.Conv2d(1, 64, 3, stride=1, padding=1) - vgg19.classifier[6] = nn.Sequential(nn.Linear(in_features=4096, out_features=2, bias=True)) - - self.vgg = vgg19 - - def forward(self, x): - #ifft - return self.vgg(x) - -class discriminator(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(0) - - self.preBlock = nn.Sequential(nn.Conv2d(1, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2)) - - self.block1 = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) - self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) - self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) - self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) - self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) - self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) - self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) - - self.main = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7) - - # self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) #GAN - self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1)) #WGAN - - def forward(self, x): - if isinstance(x, tuple) or isinstance(x, list): - if len(x) == 2: - x = x[1] - else: - x = x[0] - - if x.shape[1] == 2: - amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 - phase_x = x[:,1] - compl_x = amp_x * torch.exp(1j * phase_x) - ifft_x = torch.fft.ifft2(compl_x) - img_x = torch.absolute(ifft_x) - shift_x = torch.fft.ifftshift(img_x).unsqueeze(1) - else: - shift_x = x - # shift_x[torch.isnan(shift_x)] = 0 - pred = self.preBlock(shift_x) - pred = self.main(pred) - pred = torch.flatten(pred, 1) - pred = self.postBlock(pred) - return pred - -class SRResNet_dirtyModel_pretrainedL1(nn.Module): - def __init__(self, img_size): - super().__init__() - self.model = ut.load_pretrained_model(arch_name='SRResNet_dirtyModel', model_path='/net/big-tank/POOL/users/sfroese/vipy/jets/models/l1_symmetry.model') - - def forward(self, x): - return self.model(x) - - -class SRResNet_dirtyModel(nn.Module): - def __init__(self, img_size): - super().__init__() - torch.cuda.set_device(1) - self.img_size = img_size - - self.preBlock = nn.Sequential( - nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=1), nn.PReLU() - ) - - # ResBlock 16 - self.blocks = nn.Sequential( - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - ) - - self.postBlock = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64) - ) - # self.upscale = nn.Sequential( - # nn.Conv2d(64, 256, 3, stride=1, padding = 1), - # nn.PixelShuffle(2), - # nn.PReLU() - # ) - self.final = nn.Sequential( - nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=1), - ) - - self.relu = nn.Hardtanh(0,1.1) - self.pi = nn.Hardtanh(-np.pi,np.pi) - - self.symmetry = Lambda(better_symmetry) - - - def forward(self, x): - - amp = x[:,0].clone().detach() - phase = x[:,1].clone().detach() - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) - ifft = torch.fft.ifft2(compl) - dirty = torch.fft.ifftshift(torch.absolute(ifft)) - dirty = dirty.unsqueeze(1) - - - pred = self.preBlock(x) - - pred = pred + self.postBlock(self.blocks(pred)) - # pred = self.postBlock(self.blocks(pred)) - - - # pred = self.upscale(pred) - - pred = self.final(pred) - - pred[:,0] = self.relu(pred[:,0].clone()) - pred[:,1] = self.pi(pred[:,1].clone()) - - pred = self.symmetry(pred) - - - # pred = self.relu(pred) - # pred = nn.functional.interpolate(pred, scale_factor=0.5) - - return dirty, pred - - -class GANCS_generator_test(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(1) - self.blocks = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - ) - - self.post = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), - nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), - nn.Conv2d(64, 2, 1, stride=1, padding=0) - ) - - self.DC = HardDC(45, 10) - - def forward(self, x): - ap = x[0] - base_mask = x[1] - A = x[2] - - - amp = ap[:,0].clone().detach() - phase = ap[:,1].clone().detach() - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - ifft = torch.fft.ifft2(compl) - spatial = torch.fft.ifftshift(ifft) - # change to two channels real/imag - input = torch.zeros(ap.shape).to('cuda') - input[:,0] = spatial.real - input[:,1] = spatial.imag - # dirty = input.clone().detach() - - - pred = self.blocks(input) - - pred = self.post(pred) - - pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) - - - return pred - -class GANCS_generator(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(1) - self.blocks = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - ) - - self.post = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), - nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), - nn.Conv2d(64, 2, 1, stride=1, padding=0) - ) - - self.DC = HardDC(45, 10) - - def forward(self, x): - ap = x[0] - base_mask = x[1] - A = x[2] - - - amp = ap[:,0].clone().detach() - phase = ap[:,1].clone().detach() - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - ifft = torch.fft.ifft2(compl) - spatial = torch.fft.ifftshift(ifft) - # change to two channels real/imag - input = torch.zeros(ap.shape).to('cuda') - input[:,0] = spatial.real - input[:,1] = spatial.imag - # dirty = input.clone().detach() - - - pred = self.blocks(input) - - pred = self.post(pred) - - pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) - - - return pred - - -class GANCS_critic(nn.Module): - def __init__(self): - super().__init__() - # self.blocks = nn.Sequential( - # nn.Conv2d(2, 4, 3, stride=2, padding=1), - # nn.LeakyReLU(0.2), - # nn.Conv2d(4, 8, 3, stride=2, padding=1), - # nn.LeakyReLU(0.2), - # nn.Conv2d(8, 16, 3, stride=2, padding=1), - # nn.LeakyReLU(0.2), - # nn.Conv2d(16, 32, 3, stride=2, padding=1), - # nn.LeakyReLU(0.2), - # nn.Conv2d(32, 32, 3, stride=1, padding=1), - # nn.LeakyReLU(0.2), - # nn.Conv2d(32, 32, 1, stride=1, padding=0), - # nn.LeakyReLU(0.2), - # nn.Conv2d(32, 1, 1, stride=1, padding=0), - # nn.AdaptiveAvgPool2d(1) - # ) self.block1 = nn.Sequential(nn.Conv2d(2, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) @@ -1121,267 +412,6 @@ def forward(self, x): input[:,1] = x.imag.squeeze(1) return self.blocks(input) - -class GANCS_unrolled(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(0) - self.block1 = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - nn.Conv2d(64, 2, 3, stride=1, padding=1), - ) - self.DC1 = SoftDC(45, 10) - self.block2 = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - nn.Conv2d(64, 2, 3, stride=1, padding=1), - ) - self.DC2 = SoftDC(45, 10) - self.block3 = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - nn.Conv2d(64, 2, 3, stride=1, padding=1), - ) - self.DC3 = HardDC(45, 10) - # self.block4 = nn.Sequential( - # SRBlock(2, 64), - # SRBlock(64, 64), - # nn.Conv2d(64, 2, 3, stride=1, padding=1), - # ) - # self.DC4 = SoftDC(45, 10) - # self.block5 = nn.Sequential( - # SRBlock(2, 64), - # SRBlock(64, 64), - # nn.Conv2d(64, 2, 3, stride=1, padding=1), - # ) - # self.DC5 = SoftDC(45, 10) - - def forward(self, x): - ap = x[0] - base_mask = x[1] - A = x[2] - - - amp = ap[:,0].clone().detach() - phase = ap[:,1].clone().detach() - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - ifft = torch.fft.ifft2(compl) - spatial = torch.fft.ifftshift(ifft) - # change to two channels real/imag - input = torch.zeros(ap.shape).to('cuda') - input[:,0] = spatial.real - input[:,1] = spatial.imag - measured = input.clone().detach() - # dirty = input.clone().detach() - - - pred = self.block1(input) - dc1 = self.DC1(pred, measured, A, base_mask) - # pred[:,0] = dc1.real.squeeze(1) - # pred[:,1] = dc1.imag.squeeze(1) - pred = self.block2(pred) - dc2 = self.DC2(pred, measured, A, base_mask) - # pred[:,0] = dc2.real.squeeze(1) - # pred[:,1] = dc2.imag.squeeze(1) - pred = self.block3(pred) - dc3 = self.DC3(pred, compl.unsqueeze(1), A, base_mask) - pred[:,0] = dc3.real.squeeze(1) - pred[:,1] = dc3.imag.squeeze(1) - # pred = self.block4(pred) - # pred = self.DC4(pred, measured, A, base_mask) - # pred = self.block5(pred) - # pred = self.DC5(pred, measured, A, base_mask) - - #pred = self.post(pred) - - #pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) - - return (pred[:,0]+1j*pred[:,1]).unsqueeze(1) - - -class CLEANNN(nn.Module): - def __init__(self): - super().__init__() - torch.cuda.set_device(0) - - self.blocks = nn.Sequential( - SRBlock(2, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64), - SRBlock(64, 64) - ) - - self.post = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), - nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), - nn.Conv2d(64, 2, 1, stride=1, padding=0) - ) - - self.lamb = nn.Parameter(torch.tensor(1).float()) - - # self.conv = nn.Conv2d(2, 2, 3, stride=1, padding=1, bias=False) - - # self.DC = HardDC(45, 10) - # self.beamBlock = nn.Sequential( - # SRBlock(2, 64), - # SRBlock(64, 64), - # SRBlock(64, 64), - # SRBlock(64, 64), - # SRBlock(64, 64), - # nn.Conv2d(64, 2, 1, stride=1, padding=0) - # ) - # self.block1 = nn.Sequential(nn.Conv2d(2, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) - # self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) - # self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) - # self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) - # self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) - # self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) - # self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) - - # self.beamBlock = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7, nn.Conv2d(512, 1, 3, stride=1, padding=1), nn.AdaptiveAvgPool2d(1), nn.Hardtanh(1,3)) - - - def forward(self, x): - # print(len(x)) - ap = x[0] - base_mask = x[1] - A = x[2] - M = x[3] - - amp = ap[:,0].clone().detach() - phase = ap[:,1].clone().detach() - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - ifft = torch.fft.ifft2(compl) - spatial = torch.fft.ifftshift(ifft) - # change to two channels real/imag - input = torch.zeros(ap.shape).to('cuda') - input[:,0] = spatial.real - input[:,1] = spatial.imag - # measured = input.clone().detach() - - #calculate Dirty Beam - beam = calc_DirtyBeam(base_mask) - # beam_copy = beam.clone().detach() - - # M = torch.zeros(input.shape).to('cuda') - - - # for i in range(5): - out_b = self.blocks(input) - out_p = self.post(out_b) - - - # residual = input - self.lamb*torch.einsum('bclm,bclm->bclm', out_p, beam) - residual = input - self.lamb*out_p - - M = M + self.lamb*out_p - - # if i == 4: - # break - - # return (input[:,0]+1j*input[:,1]).unsqueeze(1) - # return (M[:,0]+1j*M[:,1]).unsqueeze(1) - - # gauss_params = self.beamBlock(beam) - - # clean_beam = torch.fft.ifft2(torch.fft.fft2(torch.cat([gauss(63,s) for s in gauss_params]).reshape(-1,M.shape[2],M.shape[2]))) - - # M = M + input - # return clean_beam - # M_compl = (M[:,0]+1j*M[:,1]) - - - # M_conv = torch.einsum('blm,blm->blm', torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(M_compl))), torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(clean_beam)))) - - # M_conv = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(M_conv))) - fft_residual = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(residual[:,0]+1j*residual[:,1]))) - - res_amp = torch.absolute(fft_residual) - res_phase = torch.angle(fft_residual) - - residual[:,0] = ((torch.log10(res_amp + 1e-10) / 10) + 1) - residual[:,1] = res_phase - - return residual, M - -class ConvRNN(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Sequential( - nn.Conv2d(4, 64, 5, stride=1, dilation=1, padding=2), #padding = dilation * (ks-1) // 2 - nn.ReLU(), - ) - self.GRU1 = ConvGRUCell(64, 64, 1) - self.conv2 = nn.Sequential( - nn.Conv2d(64, 64, 3, stride=1, dilation=2, padding=2), - nn.ReLU(), - ) - self.GRU2 = ConvGRUCell(64, 64, 1) - self.conv3 = nn.Sequential( - nn.Conv2d(64, 2, 3, stride=1, dilation=1, padding=1, bias=False) - ) - - def forward(self, x, hx=None): - if not hx: - hx = [None]*2 - - c1 = self.conv1(x) - g1 = self.GRU1(c1, hx[0]) - c2 = self.conv2(g1) - g2 = self.GRU2(c2, hx[1]) - c3 = self.conv3(g2) - - - - return c3, [g1.detach(), g2.detach()] - -class RIM(nn.Module): - def __init__(self, n_steps=20): - super().__init__() - torch.cuda.set_device(1) - self.n_steps = n_steps - self.cRNN = ConvRNN() - # self.bn = nn.BatchNorm2d(2) - - - def forward(self, x, hx=None): - ap = x[0] - amp = ap[:,0] - phase = ap[:,1] - amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 - compl = amp_rescaled * torch.exp(1j * phase) #k measured - data = compl.clone().detach() - compl_shift = torch.fft.fftshift(compl) # shift low freq to corner - ifft = torch.fft.ifft2(compl_shift) - ifft_shift = torch.fft.ifftshift(ifft) # shift low freq to center - eta = torch.zeros(ap.shape).to('cuda') - eta[:,0] = ifft_shift.real - eta[:,1] = ifft_shift.imag - - - - etas = [] - - for i in range(self.n_steps): - - grad = gradFunc(eta, data, x[2], x[1], 8, 45).detach() - # bn = self.bn(grad) - input = torch.cat((eta,grad), dim=1) - - delta, hx = self.cRNN(input, hx) - eta = eta + delta - # plt.imshow(torch.absolute(bn[0,0]+1j*bn[0,1]).cpu().detach().numpy()) - # plt.colorbar() - # plt.show() - etas.append(eta) - - return etas - - class ConvRNN_deepClean(nn.Module): def __init__(self): super().__init__() From 347a7a3a2404a6006ed9a3edb24f48d4775e9b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Fr=C3=B6se?= Date: Wed, 13 Apr 2022 12:32:16 +0200 Subject: [PATCH 55/55] add rim to evaluation.toml --- tests/evaluate.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/evaluate.toml b/tests/evaluate.toml index 16b95091..853c6e41 100644 --- a/tests/evaluate.toml +++ b/tests/evaluate.toml @@ -20,6 +20,7 @@ arch_name = "SRResNet_bigger_no_symmetry" arch_name_2 = "none" output_format = "png" diff = true +rim = false [inspection] visualize_prediction = true @@ -38,4 +39,4 @@ evaluate_dynamic_range = false evaluate_ms_ssim = false evaluate_mean_diff = false evaluate_area = false -evaluate_point = false \ No newline at end of file +evaluate_point = false