From a06b947894e81ffec4c6a7c9bb305f66164153f2 Mon Sep 17 00:00:00 2001 From: Linfang-mumu <52233121+Linfang-mumu@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:13:45 +0800 Subject: [PATCH] Add files via upload --- CNN_train.py | 149 +++++++ CNN_train_2d.py | 155 +++++++ data_prepare.py | 380 ++++++++++++++++ data_prepare_2d.py | 365 +++++++++++++++ losses.py | 156 +++++++ model2_cpx.py | 840 +++++++++++++++++++++++++++++++++++ test_even_odd_ms_torchfft.py | 218 +++++++++ test_even_odd_ss_torchfft.py | 222 +++++++++ 8 files changed, 2485 insertions(+) create mode 100644 CNN_train.py create mode 100644 CNN_train_2d.py create mode 100644 data_prepare.py create mode 100644 data_prepare_2d.py create mode 100644 losses.py create mode 100644 model2_cpx.py create mode 100644 test_even_odd_ms_torchfft.py create mode 100644 test_even_odd_ss_torchfft.py diff --git a/CNN_train.py b/CNN_train.py new file mode 100644 index 0000000..953f24f --- /dev/null +++ b/CNN_train.py @@ -0,0 +1,149 @@ +#------------------------- python ------------------------------------------ +# jupyter notebook +import numpy as np +import pandas as pd +import datetime +from model2_cpx import Net_cpx, Net_MS_cpx +import time + +import torch.optim as optim +from scipy import io +import argparse +import os # nn.BatchNorm2d(2,affine=False), +import torch +# from SSIM import SSIM +from losses import SSIMLoss2D_MC + +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import h5py +import matplotlib +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import matplotlib.pyplot as plt +import numpy as np +import itertools + +os.environ["CUDA_VISIBLE_DEVICES"]="0" #USE gpu 1, gp0 cannot be used for some reason +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cpu") +print(device) +epoch_num = 50 #itration number +num_workers = 0 + +current_data= '//media/bisp/New Volume/Linfang/PF_CC398_218_170_218/PF55/MS/' +current_data_file = current_data + 'CC_brain/' + +os.makedirs(current_data+'/ssim_64_16_cpx'+'/', exist_ok=True) +model_save_path = current_data + '/ssim_64_16_cpx'+'/' +class prepareData(Dataset): + def __init__(self, train_or_test): + + self.files = os.listdir(current_data_file+train_or_test) + self.train_or_test= train_or_test + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + c=current_data_file+self.train_or_test+'/'+self.files[idx] + + data = torch.load(current_data_file+self.train_or_test+'/'+self.files[idx]) + return data['k-space'], data['label'] + +trainset = prepareData('train') +trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,shuffle=True, num_workers=num_workers) + +validationset = prepareData('validation') +validationloader = torch.utils.data.DataLoader(validationset, batch_size=1,shuffle=True, num_workers=num_workers) + +testset = prepareData('test') +testloader = torch.utils.data.DataLoader(testset, batch_size=1,shuffle=False, num_workers=num_workers) + +# model = Net_cpx().to(device) +model = Net_MS_cpx().to(device) +print(model) + +criterion1 = nn.L1Loss() + +lr = 0.0002 +nx = 218 +ny = 170 +# nx = 256 +# ny = 256 +nc = 6 +weight_decay = 0.000 +optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + +ssim = SSIMLoss2D_MC(in_chan=2) +loss_train_list = [] +loss_validation_list = [] + +for epoch in range(epoch_num): #set to 0 for no running the training + model.train() + + loss_batch = [] + time_start=time.time() + for i, data in enumerate(trainloader, 0): + inputs = data[0].reshape(-1,nc,ny,nx).to(device) ##single slice + label = data[1].reshape(-1,2,ny,nx).to(device) + if nc == 6: + labels= label + labels[:,0,:,:]= label[:,0,:,:] +inputs[:,0,:,:] + labels[:,1,:,:]= label[:,1,:,:] +inputs[:,3,:,:] + else: + labels = inputs + label + + outs = model(inputs) + loss = criterion1(outs, labels) + # loss = ssim(outs, labels,1) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + loss_batch.append(loss.item()) + if (i)%10==0: + print('epoch:%d - %d, loss:%.10f'%(epoch+1,i+1,loss.item())) + # break + # h=0 + loss_train_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_train_list) + time_end=time.time() + print('time cost for training',time_end-time_start,'s') + model.eval() # evaluation + loss_batch = [] + print('\n testing...') + time_start=time.time() + for i, data in enumerate(validationloader, 0): + inputs = data[0].reshape(-1,nc,ny,nx).to(device) + label = data[1].reshape(-1,2,ny,nx).to(device) + if nc ==6: + labels= label + labels[:,0,:,:]= label[:,0,:,:] +inputs[:,0,:,:] + labels[:,1,:,:]= label[:,1,:,:] +inputs[:,3,:,:] + else: + labels = inputs + label + + with torch.no_grad(): + outs = model(inputs) + loss = criterion1(outs, labels) + # loss = ssim(outs, labels,1) ####using the L1loss for several epochs, then using ssim to train the whole model + loss_batch.append(loss.item()) + + time_end=time.time() + print('time cost for testing',time_end-time_start,'s') + loss_validation_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_validation_list) + + torch.save(model, os.path.join(model_save_path, 'epoch-%d-%.10f.pth' % (epoch+1, loss.item()))) + + # if (epoch+1) % 2 == 0: + # lr = max(5e-5,lr*0.8) + # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + +print('Finished Training') \ No newline at end of file diff --git a/CNN_train_2d.py b/CNN_train_2d.py new file mode 100644 index 0000000..8876770 --- /dev/null +++ b/CNN_train_2d.py @@ -0,0 +1,155 @@ +#------------------------- python ------------------------------------------ +# jupyter notebook +import numpy as np +import pandas as pd +import datetime +from model2_cpx import Net_cpx_2D +import time +import torch.optim as optim +from scipy import io +import argparse +import os # nn.BatchNorm2d(2,affine=False), +import torch +# from SSIM import SSIM +from losses import SSIMLoss2D_MC + +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import h5py +import matplotlib +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import matplotlib.pyplot as plt +import numpy as np +import itertools + +os.environ["CUDA_VISIBLE_DEVICES"]="0" #USE gpu 1, gp0 cannot be used for some reason +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cpu") +print(device) +epoch_num = 50 #itration number +num_workers = 0 + +current_data= '//media/bisp/New Volume/Linfang/PF_CC398_218_170_218/PF60/SS//' +current_data_file = current_data + 'CC_brain_2D/' + + +os.makedirs(current_data+'/ssim_64_16_cpx'+'/', exist_ok=True) +model_save_path = current_data + '/ssim_64_16_cpx'+'/' +class prepareData(Dataset): + def __init__(self, train_or_test): + + self.files = os.listdir(current_data_file+train_or_test) + self.train_or_test= train_or_test + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + c=current_data_file+self.train_or_test+'/'+self.files[idx] + + data = torch.load(current_data_file+self.train_or_test+'/'+self.files[idx]) + return data['k-space'], data['label'] + +trainset = prepareData('train') +trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,shuffle=True, num_workers=num_workers) + +validationset = prepareData('validation') +validationloader = torch.utils.data.DataLoader(validationset, batch_size=1,shuffle=True, num_workers=num_workers) + +testset = prepareData('test') +testloader = torch.utils.data.DataLoader(testset, batch_size=1,shuffle=False, num_workers=num_workers) + + +model = Net_cpx_2D().to(device) +# model = torch.load(current_data +'/real_L1_64_16_cpx'+'/epoch-35-0.0229811855.pth')# repeat once +print(model) + +criterion1 = nn.L1Loss() + +lr = 0.0002 +nx = 218 +ny = 170 +nc = 2 +weight_decay = 0.000 +optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + +# ssim=SSIM(channels=2) +ssim = SSIMLoss2D_MC(in_chan=2) +loss_train_list = [] +loss_validation_list = [] + +for epoch in range(epoch_num): #set to 0 for no running the training + model.train() + loss_batch = [] + + time_start=time.time() + + for i, data in enumerate(trainloader, 0): + # break + inputs = data[0].reshape(-1,nc,ny,nx).to(device) ##single slice + label = data[1].reshape(-1,2,ny,nx).to(device) + if nc == 6: + labels= label + labels[:,0,:,:]= label[:,0,:,:] +inputs[:,0,:,:] + labels[:,1,:,:]= label[:,1,:,:] +inputs[:,3,:,:] + else: + labels = inputs + label + + outs = model(inputs) + loss = criterion1(outs, labels) + # loss = ssim(outs, labels,1) ####using the L1loss for several epochs, then using ssim to train the whole model + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + loss_batch.append(loss.item()) + + if (i)%10==0: + print('epoch:%d - %d, loss:%.10f'%(epoch+1,i+1,loss.item())) + # break + # h=0 + loss_train_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_train_list) + + time_end=time.time() + print('time cost for training',time_end-time_start,'s') + + model.eval() # evaluation + loss_batch = [] + print('\n testing...') + + time_start=time.time() + for i, data in enumerate(validationloader, 0): + inputs = data[0].reshape(-1,nc,ny,nx).to(device) + label = data[1].reshape(-1,2,ny,nx).to(device) + if nc ==6: + labels= label + labels[:,0,:,:]= label[:,0,:,:] +inputs[:,0,:,:] + labels[:,1,:,:]= label[:,1,:,:] +inputs[:,3,:,:] + else: + labels = inputs + label + + with torch.no_grad(): + outs = model(inputs) + loss = criterion1(outs, labels) + # loss = ssim(outs, labels,1) ######using the L1loss for several epochs, then using ssim to train the whole model + loss_batch.append(loss.item()) + + time_end=time.time() + print('time cost for testing',time_end-time_start,'s') + loss_validation_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_validation_list) + + torch.save(model, os.path.join(model_save_path, 'epoch-%d-%.10f.pth' % (epoch+1, loss.item()))) + + # if (epoch+1) % 2 == 0: + # lr = max(5e-5,lr*0.8) + # optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + +print('Finished Training') \ No newline at end of file diff --git a/data_prepare.py b/data_prepare.py new file mode 100644 index 0000000..8756ae0 --- /dev/null +++ b/data_prepare.py @@ -0,0 +1,380 @@ +import numpy as np +import pandas as pd +import datetime +import torch.optim as optim +from scipy import io +import argparse +import os +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import matplotlib +# print(matplotlib.__version__) +matplotlib.use('Agg') +import h5py + +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import numpy as np +import itertools +import torchvision +# from ._conv import register_converters as _register_converters +# current_file_data_save = '//nfs/bisp_data_server/Linfang/PF_recon/PF55/SS/NYU_brain_new/' +current_file_data_save = '//media/bisp/New Volume/Linfang/PF_CC398_218_170_218/PF55/EMS/CC_brain_testing_code/' +data_file = '/nfs/bisp_data_server/Linfang/CC359_dataset/multi_channel/train_val_12_channel/CC398_218_170_128/' +os.makedirs(current_file_data_save+'/test/', exist_ok=True) +os.makedirs(current_file_data_save+'/validation/', exist_ok=True) +os.makedirs(current_file_data_save+'/train/', exist_ok=True) +matrix_size = int(218) + +pf_line = int(np.floor(matrix_size*0.45)) +pf_line_com = matrix_size-pf_line +SS_flag =0 +MS_flag =0 +EMS_flag =1 +# for validation +# filePath ='///nfs/bisp_data_server/Linfang/PF_recon/singlecoil//validation/' +filePath =data_file +'/test/' +filename = os.listdir(filePath) +length = len(filename) +print('test') +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + # print(f.keys()) + # data = f['kspace'].value + + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + + # data1 = data[(sz[0]//2-8):(sz[0]//2+8),(sz[1]//2-matix_crop):(sz[1]//2+matix_crop),(sz[2]//2-matix_crop):(sz[2]//2+matix_crop)] + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + # sz1 = data1.shape + + + # plt.figure(1) + # plt.imshow(np.log(abs(data2[150,0, :, :])-1e-10), cmap='gray') + # plt.savefig('k_or') + + # # freq= np.fft.ifft2(data2 ,axes=(2,3)) + # # img = np.fft.ifftshift(freq,axes=(2,3)) + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + +#################################################### + # plt.figure(1) + # plt.imshow((abs(img[90,0, :, :])-1e-10), cmap='gray') + # plt.savefig('img_data_or') + + # plt.figure(1) + # plt.imshow((np.angle(img[90,0, :, :])-1e-10), cmap='gray') + # plt.savefig('img_angle_or') + # freq= np.fft.ifftshift(img ,axes=(2,3)) + # freq= np.fft.fft2(freq ,axes=(2,3)) + # freq = 1/matrix_size*np.fft.fftshift(freq,axes=(2,3)) + # plt.figure(1) + # plt.imshow(np.log(abs(freq[2,0, :, :])-1e-10), cmap='gray') + # plt.savefig('k_or') +############################################################## + scale_ref = round(np.max(np.absolute(img)),15) + # imgR = np.real(img)/scale_ref + # imgI = np.imag(img)/scale_ref + img = img/scale_ref + + + + + + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + # plt.figure(2) + # plt.imshow((abs(img[91,0, :, :])), cmap='gray') + # plt.savefig('img_data_or') + + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + # plt.figure(2) + # plt.imshow((abs(img_data[90,0, :, :])), cmap='gray') + # plt.savefig('img_data_zero') + + # plt.figure(2) + # plt.imshow((np.angle(img_data[90,0, :, :])), cmap='gray') + # plt.savefig('angle_data_zero') + + # img_label = img_ref*0 + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + +######################################################### + k_nor = np.fft.ifftshift(imgfull,axes=(2,3)) + k_nor= np.fft.fft2(k_nor ,axes=(2,3)) + k_nor = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_nor,axes=(2,3)) + + k_zero = np.fft.ifftshift(img_data,axes=(2,3)) + k_zero= np.fft.fft2(k_zero ,axes=(2,3)) + k_zero = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_zero,axes=(2,3)) + + + k_label = np.fft.ifftshift(img_label,axes=(2,3)) + k_label= np.fft.fft2(k_label ,axes=(2,3)) + k_label = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_label,axes=(2,3)) + + nc =0 + nslice = 90 + plt.figure(1) + plt.imshow(np.log(np.abs(k_nor[nslice,0,:,:])),'gray') + plt.savefig('0_original_k.png') + + plt.figure(2) + plt.imshow(abs(imgfull[nslice,0,:,:]),'gray') + plt.savefig('1_original_ref.png') + + plt.figure(3) + plt.imshow(np.log(np.abs(k_zero[nslice,nc,:,:])),'gray') + plt.savefig('2_zero_k.png') + + plt.figure(4) + plt.imshow(abs(img_data[nslice,nc,:,:]),'gray') + plt.savefig('3_zero_img.png') + + plt.figure(6) + plt.imshow(np.log(abs(k_label[nslice,0,:,:])),'gray') + plt.savefig('4_label_k.png') + + plt.figure(5) + plt.imshow(abs(img_label[nslice,0,:,:]),'gray') + plt.savefig('5_label_img.png') +# #################################################################### + plt.figure(3) + plt.imshow(np.log(np.abs(k_zero[nslice,1,:,:])),'gray') + plt.savefig('2_zero_k_1.png') + + plt.figure(4) + plt.imshow(abs(img_data[nslice,1,:,:]),'gray') + plt.savefig('3_zero_img_1.png') + + plt.figure(3) + plt.imshow(np.log(np.abs(k_zero[nslice,2,:,:])),'gray') + plt.savefig('2_zero_k_2.png') + + plt.figure(4) + plt.imshow(abs(img_data[nslice,2,:,:]),'gray') + plt.savefig('3_zero_img_2.png') +##################################################################### + + D = torch.from_numpy(imgdata).float() + L = torch.from_numpy(imglabel).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/test/'+str(idx)+'.pth') + f.close() + + + +# # for test +print('validation') +filePath =data_file +'/validation/' +filename = os.listdir(filePath) +length = len(filename) + +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + scale_ref = round(np.max(np.absolute(img)),15) + img = img/scale_ref + + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + + for i in range(16,int(sz[0]-20)): + D = torch.from_numpy(imgdata[i,:,:,:]).float() + L = torch.from_numpy(imglabel[i,:,:,:]).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/validation/'+str(idx)+'_'+str(i)+'.pth') + f.close() + + +# # for train +print('train') +# filePath ='///nfs/bisp_data_server/Linfang/PF_recon/singlecoil//train/' +filePath =data_file +'/train/' +filename = os.listdir(filePath) +length = len(filename) + +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + scale_ref = round(np.max(np.absolute(img)),15) + img = img/scale_ref + + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + + for i in range(16,int(sz[0]-20)): + D = torch.from_numpy(imgdata[i,:,:,:]).float() + L = torch.from_numpy(imglabel[i,:,:,:]).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/train/'+str(idx)+'_'+str(i)+'.pth') + f.close() + +os.system('python CNN_train.py') \ No newline at end of file diff --git a/data_prepare_2d.py b/data_prepare_2d.py new file mode 100644 index 0000000..125a857 --- /dev/null +++ b/data_prepare_2d.py @@ -0,0 +1,365 @@ +import numpy as np +import pandas as pd +import datetime +import torch.optim as optim +from scipy import io +import argparse +import os +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import matplotlib +# print(matplotlib.__version__) +matplotlib.use('Agg') +import h5py + +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import numpy as np +import itertools +import torchvision +# from ._conv import register_converters as _register_converters + +current_file_data_save = '//media/bisp/New Volume/Linfang/PF_CC398_218_170_218/PF55/SS/CC_brain_2D' +data_file = '/nfs/bisp_data_server/Linfang/CC359_dataset/multi_channel/train_val_12_channel/CC398_218_170_128/' +os.makedirs(current_file_data_save+'/test/', exist_ok=True) +os.makedirs(current_file_data_save+'/validation/', exist_ok=True) +os.makedirs(current_file_data_save+'/train/', exist_ok=True) +matrix_size = int(218) + +pf_line = int(np.floor(matrix_size*0.45)) +pf_line_com = matrix_size-pf_line +pf_line_o =int(np.floor(170*0.45)) +SS_flag =1 +MS_flag =0 +EMS_flag =0 +# for test +filePath =data_file +'/test/' +filename = os.listdir(filePath) +length = len(filename) +print('test') +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + # print(f.keys()) + # data = f['kspace'].value + + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + + # data1 = data[(sz[0]//2-8):(sz[0]//2+8),(sz[1]//2-matix_crop):(sz[1]//2+matix_crop),(sz[2]//2-matix_crop):(sz[2]//2+matix_crop)] + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + # sz1 = data1.shape + # plt.figure(1) + # plt.imshow(np.log(abs(data2[150,0, :, :])-1e-10), cmap='gray') + # plt.savefig('k_or') + # # freq= np.fft.ifft2(data2 ,axes=(2,3)) + # # img = np.fft.ifftshift(freq,axes=(2,3)) + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) +#################################################### + # plt.figure(1) + # plt.imshow((abs(img[90,0, :, :])-1e-10), cmap='gray') + # plt.savefig('img_data_or') + + # plt.figure(1) + # plt.imshow((np.angle(img[90,0, :, :])-1e-10), cmap='gray') + # plt.savefig('img_angle_or') + # freq= np.fft.ifftshift(img ,axes=(2,3)) + # freq= np.fft.fft2(freq ,axes=(2,3)) + # freq = 1/matrix_size*np.fft.fftshift(freq,axes=(2,3)) + # plt.figure(1) + # plt.imshow(np.log(abs(freq[2,0, :, :])-1e-10), cmap='gray') + # plt.savefig('k_or') +############################################################## + scale_ref = round(np.max(np.absolute(img)),15) + img = img/scale_ref + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 + test_data[:,0,:pf_line_o,:] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + +######################################################### + k_nor = np.fft.ifftshift(imgfull,axes=(2,3)) + k_nor= np.fft.fft2(k_nor ,axes=(2,3)) + k_nor = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_nor,axes=(2,3)) + + k_zero = np.fft.ifftshift(img_data,axes=(2,3)) + k_zero= np.fft.fft2(k_zero ,axes=(2,3)) + k_zero = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_zero,axes=(2,3)) + + + k_label = np.fft.ifftshift(img_label,axes=(2,3)) + k_label= np.fft.fft2(k_label ,axes=(2,3)) + k_label = 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(k_label,axes=(2,3)) + + nc =0 + nslice = 90 + plt.figure(1) + plt.imshow(np.log(np.abs(k_nor[nslice,0,:,:])),'gray') + plt.savefig('0_original_k.png') + + plt.figure(2) + plt.imshow(abs(imgfull[nslice,0,:,:]),'gray') + plt.savefig('1_original_ref.png') + + plt.figure(3) + plt.imshow(np.log(np.abs(k_zero[nslice,nc,:,:])),'gray') + plt.savefig('2_zero_k.png') + + plt.figure(4) + plt.imshow(abs(img_data[nslice,nc,:,:]),'gray') + plt.savefig('3_zero_img.png') + + plt.figure(6) + plt.imshow(np.log(abs(k_label[nslice,0,:,:])),'gray') + plt.savefig('4_label_k.png') + + plt.figure(5) + plt.imshow(abs(img_label[nslice,0,:,:]),'gray') + plt.savefig('5_label_img.png') +# #################################################################### +# plt.figure(3) +# plt.imshow(np.log(np.abs(k_zero[nslice,1,:,:])),'gray') +# plt.savefig('2_zero_k_1.png') + +# plt.figure(4) +# plt.imshow(abs(img_data[nslice,1,:,:]),'gray') +# plt.savefig('3_zero_img_1.png') + +# plt.figure(3) +# plt.imshow(np.log(np.abs(k_zero[nslice,2,:,:])),'gray') +# plt.savefig('2_zero_k_2.png') + +# plt.figure(4) +# plt.imshow(abs(img_data[nslice,2,:,:]),'gray') +# plt.savefig('3_zero_img_2.png') +##################################################################### + + D = torch.from_numpy(imgdata).float() + L = torch.from_numpy(imglabel).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/test/'+str(idx)+'.pth') + f.close() + + + +# # for validation +print('validation') +filePath =data_file +'/validation/' +filename = os.listdir(filePath) +length = len(filename) + +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + scale_ref = round(np.max(np.absolute(img)),15) + img = img/scale_ref + + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 + + test_data[:,0,:pf_line_o,:] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + + for i in range(16,int(sz[0]-20)): + D = torch.from_numpy(imgdata[i,:,:,:]).float() + L = torch.from_numpy(imglabel[i,:,:,:]).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/validation/'+str(idx)+'_'+str(i)+'.pth') + f.close() + + +# # for train +print('train') +filePath =data_file +'/train/' +filename = os.listdir(filePath) +length = len(filename) + +for idx in range(0,length): + f = h5py.File(filePath+ filename[idx],'r') + data_I = f['kspace_I'].value + data_R = f['kspace_R'].value + + data = data_I*1j +data_R + sz = data.shape + + freq= np.fft.ifftshift(data ,axes=0) + freq = np.fft.ifft(freq,axis =0) + data1 = np.sqrt(sz[0])*np.fft.fftshift(freq,axes=0) + data2 = np.reshape(data1,(-1,1,sz[1],sz[2])) + + freq= np.fft.ifftshift(data2 ,axes=(2,3)) + freq = np.fft.ifft2(freq,axes=(2,3)) + img = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + scale_ref = round(np.max(np.absolute(img)),15) + img = img/scale_ref + + + + + + ##single slice + if SS_flag ==1: + img_ref = img[1:sz[0]-1,:,:,:]*0 + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = img[k,:,:,:] + else: + img_ref = np.concatenate((img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0,img[1:sz[0]-1,:,:,:]*0),axis = 1) + for k in range(1,sz[0]-1): + img_ref[k-1,:,:,:] = np.concatenate((img[k,:,:,:], img[k-1,:,:,:], img[k+1,:,:,:]), axis =0) + + ##multi slice + freq= np.fft.ifftshift(img_ref ,axes=(2,3)) + freq= np.fft.fft2(freq ,axes=(2,3)) + test_k= 1/np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + test_data = np.copy(test_k) + test_data[:,0,:,:pf_line] = 0 + + test_data[:,0,:pf_line_o,:] = 0 +##############complementry sampling pattern + if EMS_flag ==1: + test_data[:,1:3,:,pf_line_com:] = 0 + + if MS_flag==1: + test_data[:,1:3,:,:pf_line] = 0 +########################### + + freq = np.fft.ifftshift(test_data,axes=(2,3)) + freq = np.fft.ifft2(freq ,axes=(2,3)) + img_data = np.sqrt(sz[1]*sz[2])*np.fft.fftshift(freq,axes=(2,3)) + + img_label =img_ref[:,0,:,:] -img_data[:,0,:,:] + img_label =img_label.reshape(sz[0]-2,-1,sz[1],sz[2]) + + imgdataR = np.copy(np.real(img_data)) + imgdataI = np.copy(np.imag(img_data)) + + imgdata = np.concatenate((imgdataR,imgdataI), axis=1) + + + + imglabelR = np.copy(np.real(img_label)) + imglabelI = np.copy(np.imag(img_label)) + imglabel = np.concatenate((imglabelR,imglabelI), axis=1) + + + imgfull=img_label[:,0,:,:] + img_data[:,0,:,:] + imgfull =imgfull.reshape(sz[0]-2,-1,sz[1],sz[2]) + + for i in range(16,int(sz[0]-20)): + D = torch.from_numpy(imgdata[i,:,:,:]).float() + L = torch.from_numpy(imglabel[i,:,:,:]).float() + data = {'k-space':D,'label':L} + torch.save(data,current_file_data_save +'/train/'+str(idx)+'_'+str(i)+'.pth') + f.close() + +os.system('python CNN_train_2d.py') \ No newline at end of file diff --git a/losses.py b/losses.py new file mode 100644 index 0000000..64159c6 --- /dev/null +++ b/losses.py @@ -0,0 +1,156 @@ +import torch +from torch import nn +from torch.nn import functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +class SSIMLoss2D(nn.Module): + """ + 2D SSIM loss module. + + """ + + def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): + """ + Args: + win_size: Window size for SSIM calculation. + k1: k1 parameter for SSIM calculation. + k2: k2 parameter for SSIM calculation. + """ + super().__init__() + self.win_size = win_size + self.k1, self.k2 = k1, k2 + self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size ** 2) + NP = win_size ** 2 + self.cov_norm = NP / (NP - 1) + + def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor): + #assert isinstance(self.w, torch.Tensor) + + C1 = (self.k1 * data_range) ** 2 + C2 = (self.k2 * data_range) ** 2 + ux = F.conv2d(X, self.w) # typing: ignore + uy = F.conv2d(Y, self.w) # + uxx = F.conv2d(X * X, self.w) + uyy = F.conv2d(Y * Y, self.w) + uxy = F.conv2d(X * Y, self.w) + vx = self.cov_norm * (uxx - ux * ux) + vy = self.cov_norm * (uyy - uy * uy) + vxy = self.cov_norm * (uxy - ux * uy) + A1, A2, B1, B2 = ( + 2 * ux * uy + C1, + 2 * vxy + C2, + ux ** 2 + uy ** 2 + C1, + vx + vy + C2, + ) + D = B1 * B2 + S = (A1 * A2) / (D + 1e-8) + + return 1 - S.mean() + + +class SSIMLoss2D_MC(nn.Module): + """ + 2D multichannel SSIM loss module. + + """ + + def __init__(self, win_size: int=7, k1: float=0.01, k2: float=0.03, in_chan: int=1): + """ + Args: + win_size: Window size for SSIM calculation. + k1: k1 parameter for SSIM calculation. + k2: k2 parameter for SSIM calculation. + in_chan: number of input channels + """ + super().__init__() + self.win_size = win_size + self.k1, self.k2 = k1, k2 + self.in_chan = in_chan + self.register_buffer("w", torch.ones(in_chan, 1, win_size, win_size) / win_size ** 2) + NP = win_size ** 2 + self.cov_norm = NP / (NP - 1) + + def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor): + #assert isinstance(self.w, torch.Tensor) + + + C1 = (self.k1 * data_range) ** 2 + C2 = (self.k2 * data_range) ** 2 + ux = F.conv2d(X, self.w.cuda(), groups=self.in_chan) # per-channel convolution + uy = F.conv2d(Y, self.w.cuda(), groups=self.in_chan) # + uxx = F.conv2d(X * X, self.w.cuda(), groups=self.in_chan) + uyy = F.conv2d(Y * Y, self.w.cuda(), groups=self.in_chan) + uxy = F.conv2d(X * Y, self.w.cuda(), groups=self.in_chan) + vx = self.cov_norm * (uxx - ux * ux) + vy = self.cov_norm * (uyy - uy * uy) + vxy = self.cov_norm * (uxy - ux * uy) + A1, A2, B1, B2 = ( + 2 * ux * uy + C1, + 2 * vxy + C2, + ux ** 2 + uy ** 2 + C1, + vx + vy + C2, + ) + D = B1 * B2 + S = (A1 * A2) / (D + 1e-8) + + return 1 - S.mean() + + +class SSIMLoss3D(nn.Module): + """ + 3D SSIM loss module + + Square window with uniform weights (non-Gaussian) following the implementation + from Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + :DOI:`10.1109/TIP.2003.819861` with slight modification for efficiency + + by Vick Lau 2021 + """ + + def __init__(self, win_size: int=7, k1: float=0.01, k2: float=0.03): + """Initialise 3D SSIM loss + + Args + ------- + win_size (int): Window size + k1 (float): k1 parameter + k2 (float): k2 parameter + Returns + ------- + torch.Tensor: 3D SSIM Loss + """ + super().__init__() + self.win_size = win_size + self.k1, self.k2 = k1, k2 + self.register_buffer("w", torch.ones(1, 1, win_size, win_size, win_size) / win_size**3) + NP = win_size ** 3 + self.cov_norm = NP / (NP - 1) # sample covariance instead of population + + def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor): + #assert isinstance(self.w, torch.Tensor) + + C1 = (self.k1 * data_range) ** 2 + C2 = (self.k2 * data_range) ** 2 + + # compute variances and covariances + ux = F.conv3d(X, self.w) + uy = F.conv3d(Y, self.w) + uxx = F.conv3d(X * X, self.w) + uyy = F.conv3d(Y * Y, self.w) + uxy = F.conv3d(X * Y, self.w) + vx = self.cov_norm * (uxx - ux*ux) + vy = self.cov_norm * (uyy - uy*uy) + vxy = self.cov_norm * (uxy - ux*uy) + A1, A2, B1, B2 = ( + 2*ux*uy + C1, + 2*vxy + C2, + ux**2 + uy**2 + C1, + vx + vy + C2, + ) + D = B1*B2 + S = (A1*A2) / (D + 1e-8) # eps for floating point stability + + return 1 - S.mean() # compute 1 - mean of SSIM diff --git a/model2_cpx.py b/model2_cpx.py new file mode 100644 index 0000000..328a6d8 --- /dev/null +++ b/model2_cpx.py @@ -0,0 +1,840 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +import torchvision + +RB =0 +class ComplexConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True): + super(ComplexConv2d, self).__init__() + padding = kernel_size // 2 + self.conv_r = nn.Conv2d(in_channels, out_channels, + kernel_size, stride, padding, bias) + self.conv_i = nn.Conv2d(in_channels, out_channels, + kernel_size, stride, padding, bias) + + def forward(self, x): + + input_r, input_i = torch.split(x, x.shape[1]//2, dim=1) + + y1 = self.conv_r(input_r)-self.conv_i(input_i) + y2 = self.conv_r(input_i)+self.conv_i(input_r) + + return torch.cat((y1, y2), dim=1) + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range, + rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(rgb_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std + for p in self.parameters(): + p.requires_grad = False + + +class Upsampler(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): + + m = [] + if (scale & (scale - 1)) == 0: # Is scale = 2^n? + for _ in range(int(math.log(scale, 2))): + m.append(conv(n_feats, 4 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(2)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + elif scale == 3: + m.append(conv(n_feats, 9 * n_feats, 3, bias)) + m.append(nn.PixelShuffle(3)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + else: + raise NotImplementedError + + super(Upsampler, self).__init__(*m) + + +class Upsampler2(nn.Sequential): + def __init__(self, conv, scale, n_feats, bn=False, act='relu', bias=True): + + m = [] + for _ in range(int(math.log(scale, 2))): + m.append(nn.UpsamplingNearest2d(scale_factor=2)) + m.append(conv(n_feats, n_feats, 3, bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if act == 'relu': + m.append(nn.ReLU(True)) + elif act == 'prelu': + m.append(nn.PReLU(n_feats)) + + super(Upsampler2, self).__init__(*m) + + +class BasicBlock(nn.Sequential): + def __init__( + self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, + bn=True, act=nn.ReLU(True)): + + m = [conv(in_channels, out_channels, kernel_size, bias=bias)] + if bn: + m.append(nn.BatchNorm2d(out_channels)) + if act is not None: + m.append(act) + + super(BasicBlock, self).__init__(*m) + +class ResBlock(nn.Module): + def __init__( + self, n_feats, kernel_size, + bias=True, bn=False, act=nn.ReLU(True), res_scale=1): + + super(ResBlock, self).__init__() + m = [] + for i in range(2): + m.append(ComplexConv2d(n_feats, n_feats, kernel_size, bias=bias)) + if bn: + m.append(nn.BatchNorm2d(n_feats)) + if i == 0: + m.append(act) + + self.body = nn.Sequential(*m) + self.res_scale = res_scale + + def forward(self, x): + res = self.body(x).mul(self.res_scale) + res += x + return res + +class Net_cpx(nn.Module): + def __init__(self): + super(Net_cpx, self).__init__() + + n_resblocks = 16 + n_feats = 32 + + kernel_size = 3 + self.scale = 1 + n_colors = 1 + + # act = ComplexReLU() + # act = nn.ReLU(True) + # self.sub_mean = MeanShift(rgb_range) + # self.add_mean = MeanShift(rgb_range, sign=1) + + # define head module + # m_head = ComplexConv2d(n_colors, n_feats, kernel_size) + + # define body module + m_body = [ + ResBlock(n_feats, kernel_size, res_scale=1 + ) for _ in range(n_resblocks) + ] + m_body.append(ComplexConv2d(n_feats, n_feats, kernel_size)) + + # define tail module + + # m_tail = [ + # ComplexConv2d(n_feats, n_colors, kernel_size) + # ] + + self.head = ComplexConv2d(n_colors, n_feats, kernel_size) + self.body = nn.Sequential(*m_body) + self.tail = ComplexConv2d(n_feats, n_colors, kernel_size) + + # self.cc = ComplexConv2d(1,64) + + self.register_parameter("t", nn.Parameter(-2*torch.ones(1))) + + def forward(self, x): + or_im = x + nsize =x.size() + pf = math.floor(nsize[3]*0.45) + pf_com =nsize[3] -pf + # print(nsize[3]) + # pf_ratio =torch.int(torch.floor(nsize[3]*0.45)) + or_k = torch.complex(or_im[:, 0, :, :], or_im[:, 1, :, :]) + or_k = torch.fft.ifftshift(or_k, dim=(1, 2)) + or_k = torch.fft.fft2(or_k, dim=(1, 2)) + or_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(or_k, dim=(1, 2)) + +###################################################### + # or_k_ = or_k + # or_k_ = or_k_.cpu() + # plt.figure(6) + # plt.imshow(np.log(abs(or_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('input_k_or') + + # or_im_ = torch.complex(or_im[:,0,:,:],or_im[:,1,:,:]) + # or_im_ = or_im_.cpu() + # plt.figure(7) + # plt.imshow((abs(or_im_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('input_img_or') +################################################# + + y = x + for i in range(2): + x = y + new_k = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(70) + # plt.imshow((abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_before_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1,2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1,2)) + +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(80) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_before_DC') + +################################### + new_k[:, :, :pf] = or_k[:, :, :pf] # only keep the measured data + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) + new_k = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + x = x + self.t*(new_k - or_im) # t learnable parameter ## only keep the measured data + +# ################################### + # test_im = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_after_soft_DC') + + # x_ = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # x_ = x_.cpu() + # plt.figure(61) + # plt.imshow(abs(x_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_after_soft_DC_input') + +################################### +# + res = self.head(x) +################################### + # test_im = res + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.fftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_head_output.png',normalize=True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_head_output.png',normalize=True) +################################### + + + + + res = self.body(res) + +# ################################### + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_body_output.png',normalize=True,scale_each = True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_body_output.png',normalize=True,scale_each = True) + + + # y = self.tail(res) + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # # test_k =test_im + # test_k = torch.fft.fftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow( + # np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output_residual') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output_residual') +################################### + y = self.tail(res) + x +##################################### + + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output') + + # plt.figure(62) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_before_hard_DC') +################################### + new_k = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_before_hard_DC') +################################### + new_k[:, :, pf_com:] = or_k[:, :, pf_com:] +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_after_hard_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(6) + # plt.imshow(abs(new_k_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_after_hard_DC') +################################### + y = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + return y + + + +class Net_MS_cpx(nn.Module): + def __init__(self): + super(Net_MS_cpx, self).__init__() + + n_resblocks = 16 + n_feats = 32 + + kernel_size = 3 + self.scale = 1 + n_colors = 1 + + # act = ComplexReLU() + # act = nn.ReLU(True) + # self.sub_mean = MeanShift(rgb_range) + # self.add_mean = MeanShift(rgb_range, sign=1) + + # define head module + # m_head = ComplexConv2d(n_colors, n_feats, kernel_size) + + # define body module + m_body = [ + ResBlock(n_feats, kernel_size, res_scale=1 + ) for _ in range(n_resblocks) + ] + m_body.append(ComplexConv2d(n_feats, n_feats, kernel_size)) + + # define tail module + + # m_tail = [ + # ComplexConv2d(n_feats, n_colors, kernel_size) + # ] + + self.head = ComplexConv2d(3, n_feats, kernel_size) + self.body = nn.Sequential(*m_body) + self.tail = ComplexConv2d(n_feats, n_colors, kernel_size) + + # self.cc = ComplexConv2d(1,64) + + self.register_parameter("t", nn.Parameter(-2*torch.ones(1))) + + def forward(self, x): + # or_im = x + nsize =x.size() + pf = math.floor(nsize[3]*0.40) + pf_com =nsize[3] -pf + # print(nsize[3]) + # pf_ratio =torch.int(torch.floor(nsize[3]*0.45)) + or_im = torch.complex(x[:, 0, :, :], x[:, 3, :, :]) #central slice + # nsize =or_im.size() + or_k = torch.fft.ifftshift(or_im, dim=(1, 2)) + or_k = torch.fft.fft2(or_k, dim=(1, 2)) + or_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(or_k, dim=(1, 2)) + +###################################################### + # or_k_ = or_k + # or_k_ = or_k_.cpu() + # plt.figure(6) + # plt.imshow(np.log(abs(or_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('input_k_or') + + # or_im_ = torch.complex(or_im[:,0,:,:],or_im[:,1,:,:]) + # or_im_ = or_im_.cpu() + # plt.figure(7) + # plt.imshow((abs(or_im_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('input_img_or') +################################################# + + + x_or = x + y = x[:,0:2,:,:] + y[:,1,:,:] = x[:,3,:,:] #the central slice + or_im = y + + for i in range(2): + x = x_or + x[:,0,:,:].data = y[:,0,:,:] + x[:,3,:,:].data = y[:,1,:,:].data + + new_k = torch.complex(x[:, 0, :, :], x[:, 3, :, :]) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(70) + # plt.imshow((abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_before_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1,2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1,2)) + +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(80) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_before_DC') + +################################### + new_k[:, :, :pf] = or_k[:, :, :pf] # only keep the measured data + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) + new_k = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + y1 = y + self.t*(new_k - or_im) # t learnable parameter ## only keep the measured data + x[:,0,:,:].data = y1[:,0,:,:] + x[:,3,:,:].data = y1[:,1,:,:] + +# ################################### + # test_im = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_after_soft_DC') + + # x_ = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # x_ = x_.cpu() + # plt.figure(61) + # plt.imshow(abs(x_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_after_soft_DC_input') + +################################### +# + res = self.head(x) +################################### + # test_im = res + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.fftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_head_output.png',normalize=True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_head_output.png',normalize=True) +################################### + + + + + res = self.body(res) + +# ################################### + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_body_output.png',normalize=True,scale_each = True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_body_output.png',normalize=True,scale_each = True) + + + # y = self.tail(res) + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # # test_k =test_im + # test_k = torch.fft.fftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow( + # np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output_residual') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output_residual') +################################### + y = self.tail(res) + y +##################################### + + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output') + + # plt.figure(62) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_before_hard_DC') +################################### + new_k = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_before_hard_DC') +################################### + new_k[:, :, pf_com:] = or_k[:, :, pf_com:] +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_after_hard_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(6) + # plt.imshow(abs(new_k_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_after_hard_DC') +################################### + y = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + return y + +class Net_cpx_2D(nn.Module): + def __init__(self): + super(Net_cpx_2D, self).__init__() + + n_resblocks = 16 + n_feats = 32 + + kernel_size = 3 + self.scale = 1 + n_colors = 1 + + # act = ComplexReLU() + # act = nn.ReLU(True) + # self.sub_mean = MeanShift(rgb_range) + # self.add_mean = MeanShift(rgb_range, sign=1) + + # define head module + # m_head = ComplexConv2d(n_colors, n_feats, kernel_size) + + # define body module + m_body = [ + ResBlock(n_feats, kernel_size, res_scale=1 + ) for _ in range(n_resblocks) + ] + m_body.append(ComplexConv2d(n_feats, n_feats, kernel_size)) + + # define tail module + + # m_tail = [ + # ComplexConv2d(n_feats, n_colors, kernel_size) + # ] + + self.head = ComplexConv2d(n_colors, n_feats, kernel_size) + self.body = nn.Sequential(*m_body) + self.tail = ComplexConv2d(n_feats, n_colors, kernel_size) + + # self.cc = ComplexConv2d(1,64) + + self.register_parameter("t", nn.Parameter(-2*torch.ones(1))) + + def forward(self, x): + or_im = x + nsize =x.size() + pf_1= math.floor(nsize[3]*0.40) + pf_com_1 =nsize[3] -pf_1 + pf_0= math.floor(nsize[2]*0.40) + pf_com_0 =nsize[2] -pf_0 + + # print(nsize[3]) + # pf_ratio =torch.int(torch.floor(nsize[3]*0.45)) + or_k = torch.complex(or_im[:, 0, :, :], or_im[:, 1, :, :]) + or_k = torch.fft.ifftshift(or_k, dim=(1, 2)) + or_k = torch.fft.fft2(or_k, dim=(1, 2)) + or_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(or_k, dim=(1, 2)) + +###################################################### + # or_k_ = or_k + # or_k_ = or_k_.cpu() + # plt.figure(6) + # plt.imshow(np.log(abs(or_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('input_k_or') + + # or_im_ = torch.complex(or_im[:,0,:,:],or_im[:,1,:,:]) + # or_im_ = or_im_.cpu() + # plt.figure(7) + # plt.imshow((abs(or_im_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('input_img_or') +################################################# + + y = x + for i in range(2): + x = y + new_k = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(70) + # plt.imshow((abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_before_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1,2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1,2)) + +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(80) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_before_DC') + +################################### + new_k[:, :pf_0, :] = or_k[:, :pf_0, :] # only keep the measured data + new_k[:, :, :pf_1] = or_k[:, :, :pf_1] # only keep the measured data + + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) + new_k = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + x = x + self.t*(new_k - or_im) # t learnable parameter ## only keep the measured data + +# ################################### + # test_im = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_after_soft_DC') + + # x_ = torch.complex(x[:, 0, :, :], x[:, 1, :, :]) + # x_ = x_.cpu() + # plt.figure(61) + # plt.imshow(abs(x_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_after_soft_DC_input') + +################################### +# + res = self.head(x) +################################### + # test_im = res + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.fftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_head_output.png',normalize=True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_head_output.png',normalize=True) +################################### + + + + + res = self.body(res) + +# ################################### + # test_im = torch.complex(res[:, :64, :, :], res[:, 64:, :, :]) + # # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(2, 3)) + # test_k = torch.fft.fft2(test_k, dim=(2, 3)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(2, 3)) + # res_dis = test_im.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(abs(slice1),'iter_'+str(i)+'_img_body_output.png',normalize=True,scale_each = True) + + # res_dis = test_k.permute(1,0,2,3) + # slice1 = torch.reshape(res_dis[:,2,:,:],(64,1,256,256)) + # torchvision.utils.save_image(torch.log(abs(slice1)),'iter_'+str(i)+'_k_body_output.png',normalize=True,scale_each = True) + + + # y = self.tail(res) + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # # test_k =test_im + # test_k = torch.fft.fftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow( + # np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output_residual') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output_residual') +################################### + y = self.tail(res) + x +##################################### + + # test_im = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # test_k = test_im + # test_k = torch.fft.ifftshift(test_im, dim=(1, 2)) + # test_k = torch.fft.fft2(test_k, dim=(1, 2)) + # test_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(test_k, dim=(1, 2)) + # test_k_ = test_k + # test_k_ = test_k_.cpu() + # plt.figure(10) + # plt.imshow(np.log(abs(test_k_.detach_().numpy()[2, :, :])-1e-10), cmap='gray') + # plt.savefig('iter_'+str(i)+'_kspace_tail_output') + + + # y_ = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + # y_ = y_.cpu() + # plt.figure(71) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('iter_'+str(i)+'_img_tail_output') + + # plt.figure(62) + # plt.imshow(abs(y_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_before_hard_DC') +################################### + new_k = torch.complex(y[:, 0, :, :], y[:, 1, :, :]) + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.fft2(new_k, dim=(1, 2)) + new_k = 1/math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_before_hard_DC') +################################### + or_k_ = or_k + or_k_[:, :pf_0, :] = new_k[:, :pf_0, :] # only keep the measured data + or_k_[:, :, :pf_1] = new_k[:, :, :pf_1] # only keep the measured data + new_k = or_k_ +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(62) + # plt.imshow(np.log(abs(new_k_.detach_().numpy()[2, :, :])), cmap='gray') + # plt.savefig('k_after_hard_DC') +################################### + new_k = torch.fft.ifftshift(new_k, dim=(1, 2)) + new_k = torch.fft.ifft2(new_k, dim=(1, 2)) + new_k = math.sqrt(nsize[2]*nsize[3])*torch.fft.fftshift(new_k, dim=(1, 2)) +################################### + # new_k_ = new_k + # new_k_ = new_k_.cpu() + # plt.figure(6) + # plt.imshow(abs(new_k_.detach_().numpy()[2, :, :]), cmap='gray') + # plt.savefig('img_after_hard_DC') +################################### + y = torch.stack((torch.real(new_k), torch.imag(new_k)), dim=1) + return y diff --git a/test_even_odd_ms_torchfft.py b/test_even_odd_ms_torchfft.py new file mode 100644 index 0000000..91c8b20 --- /dev/null +++ b/test_even_odd_ms_torchfft.py @@ -0,0 +1,218 @@ +import numpy as np +import pandas as pd +import datetime +# from model2 import Net1, Net2 +from model2_real import Net_real +# from model2_cpx import Net_cpx +import torch.optim as optim +from scipy import io +import argparse +import os +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import h5py +import matplotlib +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import matplotlib.pyplot as plt +import numpy as np +import itertools +import math +from losses import SSIMLoss2D_MC + +os.environ["CUDA_VISIBLE_DEVICES"]="0" #USE gpu 1, gp0 cannot be used for some reason +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cpu") +num_workers = 0 +# current_file = '//nfs/bisp_data_server/Linfang/PF_recon/PF55/SS/' +current_file ='//media/bisp/New Volume/Linfang/PF_DL_paper_cpx_ssim/PF55/EMS/' +current_file_name = current_file + '/CC_brain/' +# nslice = 75 + + +class prepareData(Dataset): + def __init__(self, train_or_test): + + self.files = os.listdir(current_file_name+train_or_test) + self.train_or_test= train_or_test + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + + data = torch.load(current_file_name+self.train_or_test+'/'+self.files[idx]) + return data['k-space'], data['label'] + + +testset = prepareData('test') + +# testset = prepareData('test_'+flag) +testloader = torch.utils.data.DataLoader(testset, batch_size=1,shuffle=False, num_workers=num_workers) +# model = torch.load(current_file+'/real_64_16'+'/epoch-33-0.0128679229.pth')# ss +# model = torch.load(current_file+'/cpx_64_16'+'/epoch-33-0.0146769155.pth')# ss +model = torch.load(current_file +'/real_ssim_64_16'+'/epoch-30-0.0734601021.pth')# repeat once +# model = torch.load(current_file+'/real_ssim_64_16'+'/epoch-34-0.0832024813.pth')# ss +# model = torch.load(current_file+'/model_'+flag+'/epoch-32-0.0069624754.pth')# ss +nx=218 +ny=170 +nc =6 +print(model) +save_file = '/real_64_16_ssim_results' +'/' +# save_file = '/real_64_16' +'/' +criterion1 = nn.L1Loss() +ssim = SSIMLoss2D_MC(in_chan=2) +model.eval() +loss_validation_list = [] +loss_batch = [] +loss = [] +print('\n testing...') +for i, data in enumerate(testloader, 0): + inputs = data[0].reshape(-1,nc,ny,nx).to(device) + label = data[1].reshape(-1,2,ny,nx).to(device) + # labels = inputs +label + labels= label + labels[:,0,:,:]= label[:,0,:,:] +inputs[:,0,:,:] + labels[:,1,:,:]= label[:,1,:,:] +inputs[:,3,:,:] + + + os.makedirs(current_file_name+save_file+str(i), exist_ok=True) + inpin = torch.complex(inputs[:,0,:,:],inputs[:,3,:,:]) + inpin = inpin.cpu() + sz = inpin.size() + nslice = sz[0]//2 + plt.figure(2) + plt.imshow(abs(inpin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_inpin') + plt.figure(2) + plt.imshow(np.angle(inpin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_inpin_angle') + + ref = torch.complex(labels[:,0,:,:],labels[:,1,:,:]) + ref = ref.cpu() + plt.figure(1) + plt.imshow(np.abs(ref[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_ref') + plt.figure(1) + plt.imshow(np.angle(ref[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_ref_angle') + + + labelin =torch.complex(labels[:,0,:,:],labels[:,1,:,:])-torch.complex(inputs[:,0,:,:],inputs[:,3,:,:]) + labelin = labelin.cpu() + plt.figure(3) + plt.imshow(np.abs(labelin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/6_label') + + + ref_ksp = torch.fft.ifftshift(ref,dim=(1, 2)) + ref_k= torch.fft.fft2(ref_ksp ,dim=(1, 2)) + ref_k =1/math.sqrt(sz[1]*sz[2])* torch.fft.fftshift(ref_k,dim=(1, 2)) +## for 3d k-sapce + # ref_k= np.fft.fftshift(ref_k,axes=0) + # ref_k = np.fft.fft(ref_k,axis=0) + + plt.figure(4) + plt.imshow(torch.log(abs(ref_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_ref_k') + + # img to k-space + inp_ksp = torch.fft.ifftshift(inpin,dim=(1, 2)) + inp_k= torch.fft.fft2(inp_ksp ,dim=(1, 2)) + inp_k = 1/math.sqrt(sz[1]*sz[2])*torch.fft.fftshift(inp_k,dim=(1, 2)) +## for 3d k-sapce + # inp_k= np.fft.fftshift(inp_k,axes=0) + # inp_k = np.fft.fft(inp_k,axis=0) + + plt.figure(5) + plt.imshow(torch.log(torch.abs(inp_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_inp_k') + + # print(label.shape) + # plt.imshow(np.log(np.abs(label[1,:,:])),cmap='gray') + # freq= np.fft.ifft2(label,axes=(2,3)) + # img= np.fft.ifftshift(freq,axes=(2,3)) + # plt.figure(1) + # plt.imshow(np.abs(img[12,0,:,:]),cmap='gray') + # plt.show() + +# Lable_tumor.append(labels) + with torch.no_grad(): + outs = model(inputs) + + + + + + # la_out = outs.cuda().data.cpu() + + # la_out= la_out.numpy() + + output = torch.complex(outs[:,0,:,:],outs[:,1,:,:]) + output = output.cpu() + plt.figure(6) + plt.imshow(abs(output[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_output_ref') + plt.figure(6) + plt.imshow(np.angle(output[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_output_ref_angle') + + label_out =torch.complex(outs[:,0,:,:],outs[:,1,:,:])-torch.complex(inputs[:,0,:,:],inputs[:,3,:,:]) + label_out = label_out.cpu() + plt.figure(7) + plt.imshow(abs(label_out[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/6_output_label') + + # img to k-space + outp_ksp = torch.fft.ifftshift(output,dim=(1,2)) + outp_k= torch.fft.fft2(outp_ksp ,dim=(1,2)) + outp_k= 1/math.sqrt(sz[1]*sz[2])*torch.fft.fftshift(outp_k,dim=(1,2)) +## for 3d k-sapce + # outp_k= np.fft.fftshift(outp_k,axes=0) + # outp_k = np.fft.fft(outp_k,axis=0) + + plt.figure(8) + plt.imshow(torch.log(abs(outp_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_outp_k') + + plt.figure(9) + residual = output-ref + plt.imshow(abs(residual[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/9_residual') + + # loss = criterion1(outs, labels) + loss = ssim(outs, labels,1) + loss_batch.append(loss.item()) + loss_validation_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_validation_list) + # output = outs.cuda().data.cpu() + # labelo =labels.cuda().data.cpu() + # inputo = inputs.cuda().data.cpu() + # outp_k = outp_k[1:-1,:,:] + # ref_k = ref_k[1:-1,:,:] + # inp_k = inp_k[1:-1,:,:] + os.makedirs(current_file_name+save_file+'/outputs/', exist_ok=True) + os.makedirs(current_file_name+save_file+'/reference/', exist_ok=True) + os.makedirs(current_file_name+save_file+'/inputs/', exist_ok=True) + f = h5py.File(current_file_name+save_file+'/outputs/'+str(i)+'.h5','w') + f['k-space'] = outp_k + g = h5py.File(current_file_name+save_file+'/reference/'+str(i)+'.h5','w') + g['k-space'] = ref_k + k = h5py.File(current_file_name+save_file+'/inputs/'+str(i)+'.h5','w') + k['k-space'] = inp_k + # f = h5py.File(current_file_name+save_file+'/outputs/'+str(i)+'.h5','w') + # f['k-space'] = output + # g = h5py.File(current_file_name+save_file+'/reference/'+str(i)+'.h5','w') + # g['k-space'] = ref + # k = h5py.File(current_file_name+save_file+'/inputs/'+str(i)+'.h5','w') + # k['k-space'] = inpin + f.close() + g.close() + k.close() + + \ No newline at end of file diff --git a/test_even_odd_ss_torchfft.py b/test_even_odd_ss_torchfft.py new file mode 100644 index 0000000..25322bc --- /dev/null +++ b/test_even_odd_ss_torchfft.py @@ -0,0 +1,222 @@ +import numpy as np +import pandas as pd +import datetime +# from model2 import Net1, Net2 +from model2_real_2D import Net_real +# from model2_cpx import Net_cpx +import torch.optim as optim +from scipy import io +import argparse +import os +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader +import h5py +import matplotlib.pyplot as plt +import h5py +import matplotlib +from PIL import Image +import math +from sklearn.metrics import confusion_matrix +import pylab as pl +import matplotlib.pyplot as plt +import numpy as np +import itertools +import math +from losses import SSIMLoss2D_MC +import time + +os.environ["CUDA_VISIBLE_DEVICES"]="0" #USE gpu 1, gp0 cannot be used for some reason +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cpu") +num_workers = 0 + +current_file = '//media/bisp/New Volume/Linfang/PF_CC398_218_170_218/PF60/SS//' +# current_file ='//media/bisp/New Volume/Linfang/PF_DL_paper_FSE/PF55/SS/' +# current_file_name = current_file + 'test_NYU_brain_0316_new/' +# current_file_name = current_file + 'NYU_brain_T1_complex/' +current_file_name = current_file + '/CC_brain_2D/' +# current_file_name = current_file + '/NYU_knee/' + + +class prepareData(Dataset): + def __init__(self, train_or_test): + + self.files = os.listdir(current_file_name+train_or_test) + self.train_or_test= train_or_test + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + + data = torch.load(current_file_name+self.train_or_test+'/'+self.files[idx]) + return data['k-space'], data['label'] + + +testset = prepareData('test') + +# testset = prepareData('test_'+flag) +testloader = torch.utils.data.DataLoader(testset, batch_size=1,shuffle=False, num_workers=num_workers) +# model = torch.load(current_file+'/real_64_16'+'/epoch-33-0.0128679229.pth')# ss +# model = torch.load(current_file+'/cpx_64_16'+'/epoch-33-0.0146769155.pth')# ss +# model = torch.load('/media/bisp/New Volume/Linfang/PF_DL_paper_cpx_ssim/PF55/SS/real_ssim_64_16'+'/epoch-9-0.0866945386.pth')# repeat once +# model = torch.load(current_file+'/real_ssim_64_16_swi'+'/epoch-10-0.0677850842.pth')# ss +# model = torch.load(current_file+'/real_ssim_64_16'+'/epoch-16-0.0836557150.pth')# ss60 +# model = torch.load(current_file+'/real_ssim_64_16'+'/epoch-18-0.0791845322.pth')# ss55 +model = torch.load(current_file+'/real_L1_64_16_cpx'+'/epoch-15-0.0184234809.pth')# ss51 +nx=218 +ny=170 +nc =2 +print(model) +save_file = '/real_64_16_ssim_brain' +'/' +# save_file = '/real_64_16' +'/' +criterion1 = nn.L1Loss() +ssim = SSIMLoss2D_MC(in_chan=2) +model.eval() +loss_validation_list = [] +loss_batch = [] +loss = [] +print('\n testing...') +time_start=time.time() +for i, data in enumerate(testloader, 0): + inputs = data[0].reshape(-1,nc,ny,nx).to(device) + label = data[1].reshape(-1,2,ny,nx).to(device) + labels = inputs +label + + os.makedirs(current_file_name+save_file+str(i), exist_ok=True) + inpin = torch.complex(inputs[:,0,:,:],inputs[:,1,:,:]) + inpin = inpin.cpu() + sz = inpin.size() + nslice = sz[0]//2 + plt.figure(2) + plt.imshow(abs(inpin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_inpin') + plt.figure(2) + plt.imshow(np.angle(inpin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_inpin_angle') + + ref = torch.complex(labels[:,0,:,:],labels[:,1,:,:]) + ref = ref.cpu() + plt.figure(1) + plt.imshow(np.abs(ref[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_ref') + plt.figure(1) + plt.imshow(np.angle(ref[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_ref_angle') + + + labelin =torch.complex(labels[:,0,:,:],labels[:,1,:,:])-torch.complex(inputs[:,0,:,:],inputs[:,1,:,:]) + labelin = labelin.cpu() + plt.figure(3) + plt.imshow(np.abs(labelin[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/6_label') + + + ref_ksp = torch.fft.ifftshift(ref,dim=(1, 2)) + ref_k= torch.fft.fft2(ref_ksp ,dim=(1, 2)) + ref_k =1/math.sqrt(sz[1]*sz[2])* torch.fft.fftshift(ref_k,dim=(1, 2)) +## for 3d k-sapce + # ref_k= np.fft.fftshift(ref_k,axes=0) + # ref_k = np.fft.fft(ref_k,axis=0) + + plt.figure(4) + plt.imshow(torch.log(abs(ref_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_ref_k') + + # img to k-space + inp_ksp = torch.fft.ifftshift(inpin,dim=(1, 2)) + inp_k= torch.fft.fft2(inp_ksp ,dim=(1, 2)) + inp_k = 1/math.sqrt(sz[1]*sz[2])*torch.fft.fftshift(inp_k,dim=(1, 2)) +## for 3d k-sapce + # inp_k= np.fft.fftshift(inp_k,axes=0) + # inp_k = np.fft.fft(inp_k,axis=0) + + plt.figure(5) + plt.imshow(torch.log(torch.abs(inp_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_inp_k') + + # print(label.shape) + # plt.imshow(np.log(np.abs(label[1,:,:])),cmap='gray') + # freq= np.fft.ifft2(label,axes=(2,3)) + # img= np.fft.ifftshift(freq,axes=(2,3)) + # plt.figure(1) + # plt.imshow(np.abs(img[12,0,:,:]),cmap='gray') + # plt.show() + +# Lable_tumor.append(labels) + with torch.no_grad(): + outs = model(inputs) + + + + + + # la_out = outs.cuda().data.cpu() + + # la_out= la_out.numpy() + + output = torch.complex(outs[:,0,:,:],outs[:,1,:,:]) + output = output.cpu() + plt.figure(6) + plt.imshow(abs(output[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/8_output_ref') + plt.figure(6) + plt.imshow(np.angle(output[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/18_output_ref_angle') + + label_out =torch.complex(outs[:,0,:,:],outs[:,1,:,:])-torch.complex(inputs[:,0,:,:],inputs[:,1,:,:]) + label_out = label_out.cpu() + plt.figure(7) + plt.imshow(abs(label_out[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/6_output_label') + + # img to k-space + outp_ksp = torch.fft.ifftshift(output,dim=(1,2)) + outp_k= torch.fft.fft2(outp_ksp ,dim=(1,2)) + outp_k= 1/math.sqrt(sz[1]*sz[2])*torch.fft.fftshift(outp_k,dim=(1,2)) +## for 3d k-sapce + # outp_k= np.fft.fftshift(outp_k,axes=0) + # outp_k = np.fft.fft(outp_k,axis=0) + + plt.figure(8) + plt.imshow(torch.log(abs(outp_k[nslice,:,:])),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/7_outp_k') + + plt.figure(9) + residual = output-ref + plt.imshow(abs(residual[nslice,:,:]),cmap='gray') + plt.savefig(current_file_name+save_file+str(i)+'/9_residual') + + # loss = criterion1(outs, labels) + loss = ssim(outs, labels,1) + loss_batch.append(loss.item()) + loss_validation_list.append(round(sum(loss_batch) / len(loss_batch),10)) + print(loss_validation_list) + # output = outs.cuda().data.cpu() + # labelo =labels.cuda().data.cpu() + # inputo = inputs.cuda().data.cpu() + # outp_k = outp_k[1:-1,:,:] + # ref_k = ref_k[1:-1,:,:] + # inp_k = inp_k[1:-1,:,:] + os.makedirs(current_file_name+save_file+'/outputs/', exist_ok=True) + os.makedirs(current_file_name+save_file+'/reference/', exist_ok=True) + os.makedirs(current_file_name+save_file+'/inputs/', exist_ok=True) + f = h5py.File(current_file_name+save_file+'/outputs/'+str(i)+'.h5','w') + f['k-space'] = outp_k + g = h5py.File(current_file_name+save_file+'/reference/'+str(i)+'.h5','w') + g['k-space'] = ref_k + k = h5py.File(current_file_name+save_file+'/inputs/'+str(i)+'.h5','w') + k['k-space'] = inp_k + # f = h5py.File(current_file_name+save_file+'/outputs/'+str(i)+'.h5','w') + # f['k-space'] = output + # g = h5py.File(current_file_name+save_file+'/reference/'+str(i)+'.h5','w') + # g['k-space'] = ref + # k = h5py.File(current_file_name+save_file+'/inputs/'+str(i)+'.h5','w') + # k['k-space'] = inpin + f.close() + g.close() + k.close() + time_end=time.time() + print('time cost for testing',time_end-time_start,'s') + \ No newline at end of file