Merge pull request #62 from Kevin2/same_img_size
Use the same image size in filter_deep
FeGeyer authored Oct 22, 2020
2 parents 4037271 + 5756293 commit 3caddc7
Showing 30 changed files with 1,121 additions and 285 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ __pycache__/
*.jpg
*.pdf
*.gif
+ *.png

# make
*_done
6 changes: 6 additions & 0 deletions .travis.yml
@@ -2,9 +2,15 @@ language: python
python:
- "3.7"

+ # directories:
+ # - "/tmp/texlive"
+ # - "$HOME/.texlive"

before_install:
- sudo apt-get -y install libgeos-dev
- sudo apt-get -y install libproj-dev
+ # - travis_wait 45 bash ./utilities/travis_setup.sh
+ # - export PATH="/tmp/texlive/bin/x86_64-linux:$PATH"

install:
- pip install .
38 changes: 19 additions & 19 deletions dl_framework/architectures.py
@@ -557,10 +557,10 @@ def __init__(self, img_size):
)

self.conv4_amp = nn.Sequential(
- *conv_amp(1, 4, (5, 5), 1, 3, 2)
+ *conv_amp(1, 4, (5, 5), 1, 4, 2)
)
self.conv4_phase = nn.Sequential(
- *conv_phase(1, 4, (5, 5), 1, 3, 2, add=1-pi)
+ *conv_phase(1, 4, (5, 5), 1, 4, 2, add=1-pi)
)
self.conv5_amp = nn.Sequential(
*conv_amp(4, 8, (5, 5), 1, 2, 1)
@@ -569,10 +569,10 @@
*conv_phase(4, 8, (5, 5), 1, 2, 1, add=1-pi)
)
self.conv6_amp = nn.Sequential(
- *conv_amp(8, 12, (3, 3), 1, 3, 2)
+ *conv_amp(8, 12, (3, 3), 1, 2, 2)
)
self.conv6_phase = nn.Sequential(
- *conv_phase(8, 12, (3, 3), 1, 3, 2, add=1-pi)
+ *conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
)
self.conv7_amp = nn.Sequential(
*conv_amp(12, 16, (3, 3), 1, 1, 1)
@@ -710,13 +710,13 @@ def __init__(self, img_size):
)

self.conv4_amp = nn.Sequential(
- *conv_amp(1, 4, (5, 5), 1, 3, 2)
+ *conv_amp(1, 4, (5, 5), 1, 4, 2)
)
self.conv5_amp = nn.Sequential(
*conv_amp(4, 8, (5, 5), 1, 2, 1)
)
self.conv6_amp = nn.Sequential(
- *conv_amp(8, 12, (3, 3), 1, 3, 2)
+ *conv_amp(8, 12, (3, 3), 1, 2, 2)
)
self.conv7_amp = nn.Sequential(
*conv_amp(12, 16, (3, 3), 1, 1, 1)
@@ -787,51 +787,51 @@ class filter_deep_phase(nn.Module):
def __init__(self, img_size):
super().__init__()
self.conv1_phase = nn.Sequential(
- *conv_phase(1, 4, (23, 23), 1, 11, 1, add=-2.1415)
+ *conv_phase(1, 4, (23, 23), 1, 11, 1, add=1-pi)
)
self.conv2_phase = nn.Sequential(
- *conv_phase(4, 8, (21, 21), 1, 10, 1, add=-2.1415)
+ *conv_phase(4, 8, (21, 21), 1, 10, 1, add=1-pi)
)
self.conv3_phase = nn.Sequential(
- *conv_phase(8, 12, (17, 17), 1, 8, 1, add=-2.1415)
+ *conv_phase(8, 12, (17, 17), 1, 8, 1, add=1-pi)
)
self.conv_con1_phase = nn.Sequential(
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
nn.BatchNorm2d(1),
- GeneralELU(-2.1415),
+ GeneralELU(1-pi),
)

self.conv4_phase = nn.Sequential(
- *conv_phase(1, 4, (5, 5), 1, 3, 2, add=-2.1415)
+ *conv_phase(1, 4, (5, 5), 1, 4, 2, add=1-pi)
)
self.conv5_phase = nn.Sequential(
- *conv_phase(4, 8, (5, 5), 1, 2, 1, add=-2.1415)
+ *conv_phase(4, 8, (5, 5), 1, 2, 1, add=1-pi)
)
self.conv6_phase = nn.Sequential(
- *conv_phase(8, 12, (3, 3), 1, 3, 2, add=-2.1415)
+ *conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
)
self.conv7_phase = nn.Sequential(
- *conv_phase(12, 16, (3, 3), 1, 1, 1, add=-2.1415)
+ *conv_phase(12, 16, (3, 3), 1, 1, 1, add=1-pi)
)
self.conv_con2_phase = nn.Sequential(
LocallyConnected2d(16, 1, img_size, 1, stride=1, bias=False),
nn.BatchNorm2d(1),
- GeneralELU(-2.1415),
+ GeneralELU(1-pi),
)

self.conv8_phase = nn.Sequential(
- *conv_phase(1, 4, (3, 3), 1, 1, 1, add=-2.1415)
+ *conv_phase(1, 4, (3, 3), 1, 1, 1, add=1-pi)
)
self.conv9_phase = nn.Sequential(
- *conv_phase(4, 8, (3, 3), 1, 1, 1, add=-2.1415)
+ *conv_phase(4, 8, (3, 3), 1, 1, 1, add=1-pi)
)
self.conv10_phase = nn.Sequential(
- *conv_phase(8, 12, (3, 3), 1, 2, 2, add=-2.1415)
+ *conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
)
self.conv_con3_phase = nn.Sequential(
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
nn.BatchNorm2d(1),
- GeneralELU(-2.1415),
+ GeneralELU(1-pi),
)
self.symmetry_imag = Lambda(partial(symmetry, mode='imag'))

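Note on the padding changes above: they make every convolution in filter_deep size-preserving, which is what this pull request refers to. A minimal sketch of the arithmetic, reading the positional arguments of conv_amp/conv_phase as (in_channels, out_channels, kernel_size, stride, padding, dilation) — an assumption about the helper signatures — and using a made-up image size:

    import torch
    import torch.nn as nn

    def out_size(size, kernel, stride, padding, dilation):
        # Output size of nn.Conv2d along one spatial dimension
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    img_size = 63  # illustrative input size

    # Old padding values change the image size ...
    print(out_size(img_size, 5, 1, 3, 2))  # 61, smaller than the input
    print(out_size(img_size, 3, 1, 3, 2))  # 65, larger than the input

    # ... the new values keep it constant
    print(out_size(img_size, 5, 1, 4, 2))  # 63
    print(out_size(img_size, 3, 1, 2, 2))  # 63

    # Cross-check with an actual layer
    conv = nn.Conv2d(1, 4, kernel_size=5, stride=1, padding=4, dilation=2)
    print(conv(torch.zeros(1, 1, img_size, img_size)).shape)  # [1, 4, 63, 63]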
18 changes: 16 additions & 2 deletions dl_framework/callbacks.py
@@ -135,8 +135,22 @@ def plot_lr(self):
plt.tight_layout()

def plot_loss(self, log=True):
- plt.plot(self.train_losses, label="train loss")
- plt.plot(self.valid_losses, label="valid loss")
+ import matplotlib as mpl

+ # make nice Latex friendly plots
+ # mpl.use("pgf")
+ # mpl.rcParams.update(
+ # {
+ # "font.size": 12,
+ # "font.family": "sans-serif",
+ # "text.usetex": True,
+ # "pgf.rcfonts": False,
+ # "pgf.texsystem": "lualatex",
+ # }
+ # )

+ plt.plot(self.train_losses, label="training loss")
+ plt.plot(self.valid_losses, label="validation loss")
if log:
plt.yscale("log")
plt.xlabel(r"Number of Epochs")
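Note: the pgf/LaTeX plot settings are added only as comments in this commit. A minimal standalone sketch of what enabling them would look like, assuming a local LuaLaTeX installation (the sample data and output file name are made up):

    import matplotlib as mpl
    import matplotlib.pyplot as plt

    # Select the pgf backend and LaTeX settings before drawing any figures;
    # this requires a working LuaLaTeX installation.
    mpl.use("pgf")
    mpl.rcParams.update(
        {
            "font.size": 12,
            "font.family": "sans-serif",
            "text.usetex": True,
            "pgf.rcfonts": False,
            "pgf.texsystem": "lualatex",
        }
    )

    plt.plot([0.9, 0.5, 0.3], label="training loss")
    plt.plot([1.0, 0.6, 0.4], label="validation loss")
    plt.yscale("log")
    plt.xlabel(r"Number of Epochs")
    plt.legend()
    plt.savefig("loss.pdf", bbox_inches="tight", pad_inches=0.01)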
2 changes: 2 additions & 0 deletions dl_framework/data.py
@@ -108,6 +108,8 @@ def open_image(self, var, i):

data_channel = torch.cat([data_amp, data_phase], dim=1)
else:
+ if data.shape[1] == 2:
+     raise ValueError("Two channeled data is used despite Fourier being False. Set Fourier to True!")
if len(i) == 1:
data_channel = data.reshape(data.shape[-1] ** 2)
else:
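Note: the added check guards against opening two-channel (amplitude/phase) data when Fourier mode is off. A minimal sketch of the same idea in isolation, with made-up shapes and a hypothetical helper name:

    import numpy as np

    def check_channels(data, fourier):
        # Fourier-space data carries two channels (amplitude and phase);
        # image-space data is expected to be single-channel.
        if not fourier and data.ndim == 3 and data.shape[1] == 2:
            raise ValueError(
                "Two channeled data is used despite Fourier being False. "
                "Set Fourier to True!"
            )

    batch = np.zeros((8, 2, 64 * 64))     # amplitude + phase
    check_channels(batch, fourier=True)   # passes
    check_channels(batch, fourier=False)  # raises ValueError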
14 changes: 13 additions & 1 deletion dl_framework/inspection.py
@@ -7,6 +7,18 @@
import dl_framework.architectures as architecture
from dl_framework.model import load_pre_model

+ # make nice Latex friendly plots
+ # mpl.use("pgf")
+ # mpl.rcParams.update(
+ # {
+ # "font.size": 12,
+ # "font.family": "sans-serif",
+ # "text.usetex": True,
+ # "pgf.rcfonts": False,
+ # "pgf.texsystem": "lualatex",
+ # }
+ # )


def load_pretrained_model(arch_name, model_path):
"""
@@ -121,7 +133,7 @@ def plot_loss(learn, model_path):
save_path = model_path.split(".model")[0]
print("\nPlotting Loss for: {}\n".format(name_model))
learn.recorder.plot_loss()
- plt.title(r"{}".format(name_model))
+ plt.title(r"{}".format(name_model.replace("_", " ")))
plt.savefig("{}_loss.pdf".format(save_path), bbox_inches="tight", pad_inches=0.01)
plt.clf()
mpl.rcParams.update(mpl.rcParamsDefault)
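Note: stripping underscores from the model name matters once LaTeX rendering (text.usetex) is active, because a bare underscore is math-mode syntax and aborts the LaTeX run with "Missing $ inserted". A tiny sketch with a made-up model name:

    name_model = "filter_deep_amp"  # illustrative

    # With text.usetex enabled, r"filter_deep_amp" as a title would fail;
    # replacing underscores keeps the title plain text.
    title = r"{}".format(name_model.replace("_", " "))
    print(title)  # filter deep amp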
26 changes: 23 additions & 3 deletions dl_framework/learner.py
@@ -7,7 +7,16 @@
from tqdm import tqdm
import sys
from functools import partial
- from dl_framework.loss_functions import init_feature_loss, loss_amp, loss_phase
+ from dl_framework.loss_functions import (
+     init_feature_loss,
+     loss_amp,
+     loss_phase,
+     loss_msssim,
+     loss_mse_msssim,
+     loss_mse_msssim_phase,
+     loss_mse_msssim_amp,
+     loss_msssim_amp,
+ )
from dl_framework.callbacks import (
AvgStatsCallback,
BatchTransformXCallback,
@@ -188,6 +197,7 @@ def define_learner(
opt_func=torch.optim.Adam,
):
cbfs.extend([
+ # commented out because of normed and limited input values
# partial(BatchTransformXCallback, norm),
])
if not test:
@@ -202,7 +212,7 @@
])
if not test and not lr_find:
cbfs.extend([
- partial(LoggerCallback, model_name=model_name),
+ partial(LoggerCallback, model_name=model_name),
data_aug,
])

@@ -216,8 +226,18 @@
loss_func = loss_amp
elif loss_func == "loss_phase":
loss_func = loss_phase
elif loss_func == "msssim":
loss_func = loss_msssim
elif loss_func == "mse_msssim":
loss_func = loss_mse_msssim
elif loss_func == "mse_msssim_phase":
loss_func = loss_mse_msssim_phase
elif loss_func == "mse_msssim_amp":
loss_func = loss_mse_msssim_amp
elif loss_func == "msssim_amp":
loss_func = loss_msssim_amp
else:
print("\n No matching loss function! Exiting. \n")
print("\n No matching loss function or architecture! Exiting. \n")
sys.exit(1)

# Combine model and data in learner
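Note: the elif chain above maps the loss-function name from the configuration to the corresponding callable. A dictionary lookup would express the same dispatch more compactly; this is only an illustrative alternative covering the names visible in this hunk, not what the commit uses:

    import sys

    from dl_framework.loss_functions import (
        loss_amp,
        loss_phase,
        loss_msssim,
        loss_mse_msssim,
        loss_mse_msssim_phase,
        loss_mse_msssim_amp,
        loss_msssim_amp,
    )

    LOSS_FUNCS = {
        "loss_amp": loss_amp,
        "loss_phase": loss_phase,
        "msssim": loss_msssim,
        "mse_msssim": loss_mse_msssim,
        "mse_msssim_phase": loss_mse_msssim_phase,
        "mse_msssim_amp": loss_mse_msssim_amp,
        "msssim_amp": loss_msssim_amp,
    }

    def resolve_loss(name):
        # Same hard exit as the commit for unknown names.
        try:
            return LOSS_FUNCS[name]
        except KeyError:
            print("\n No matching loss function! Exiting. \n")
            sys.exit(1)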
Diffs for the remaining 23 changed files are not shown.