Skip to content

Commit

Permalink
Add new model and change the clip size
Browse files Browse the repository at this point in the history
  • Loading branch information
kyloris0660 committed May 17, 2018
1 parent 6ed5f70 commit a0687f5
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/sk-2x_tensorlayer.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 9 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# parameters for model and training
INPUT_SIZE = 28
INPUT_SIZE = 36
NUM_CHANNELS = 3
PATCH_SIZE = 80
SCALE_FACTOR = 2
LABEL_SIZE = SCALE_FACTOR * INPUT_SIZE
JPEG_NOISE_LEVEL = 1
GAUSSIAN_NOISE_STD = 0.01
BATCH_SIZE = 16
NUM_EPOCHS = 100000
NUM_EPOCHS = 500000
MODEL_NAME = 'vgg7'
# train dataset path and etc
TRAIN_PATH = 'E:/image_data/data_train/'
TEST_PATH = 'E:/image_data/data_test/'
TRAINING_SUMMARY_PATH = 'E:/image_data/Training_summary/'
CHECKPOINT_PATH = 'E:/image_data/checkpoint/'
INFERENCE_SAVE_PATH = 'C:/Users/kyloris/Desktop/inference'
TRAIN_PATH = '/Users/kyloris/Projects/image_data/data_train/'
TEST_PATH = '/Users/kyloris/Projects/image_data/data_test/'
TRAINING_SUMMARY_PATH = '/Users/kyloris/Projects/image_data/Training_summary/'
CHECKPOINT_PATH = '/Users/kyloris/Projects/image_data/checkpoint/'
INFERENCE_SAVE_PATH = '/Users/kyloris/Projects/image_data/inference/'
OUTPUT_SAVE_PATH = './save/'
70 changes: 69 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from config import *


def create_model(patches):
def create_model(patches, model_name=MODEL_NAME):
if model_name == 'vgg7':
return vgg7(patches)
elif model_name == 'vgg12':
return vgg12(patches)


def vgg7(patches):
with tf.variable_scope('vgg7'):
net = tl.layers.InputLayer(patches, name='input_layer')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
Expand Down Expand Up @@ -44,6 +51,67 @@ def create_model(patches):
return net.outputs


def vgg12(patches):
with tf.variable_scope('vgg12'):
net = tl.layers.InputLayer(patches, name='input_layer')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 3, 16),
padding='VALID',
name='Conv1')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 16, 16),
padding='VALID',
name='Conv2')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 16, 16),
padding='VALID',
name='Conv3')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 16, 32),
padding='VALID',
name='Conv4')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 32, 64),
padding='VALID',
name='Conv5')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 64, 64),
padding='VALID',
name='Conv6')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 64, 64),
padding='VALID',
name='Conv7')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 64, 128),
padding='VALID',
name='Conv8')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 128, 128),
padding='VALID',
name='Conv9')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 128, 256),
padding='VALID',
name='Conv10')
net = tl.layers.Conv2dLayer(net, act=tf.nn.leaky_relu,
shape=(3, 3, 256, 256),
padding='VALID',
name='Conv11')

batch_size = int(net.outputs.shape[0])
rows = int(net.outputs.shape[1])
cows = int(net.outputs.shape[2])
channels = int(patches.get_shape()[3])

net = tl.layers.DeConv2dLayer(net,
shape=(4, 4, 3, 256),
output_shape=(batch_size, rows * 2, cows * 2, channels),
strides=(1, 2, 2, 1),
name='Deconv')
return net.outputs


def s_mse_loss(inference, ground_truth, name='mse_loss'):
with tf.name_scope(name):
slice_begin = (int(ground_truth.get_shape()[1]) - int(inference.get_shape()[1])) // 2
Expand Down
7 changes: 5 additions & 2 deletions sk-2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@


def main():
temp = input('Enter filename: ')
f_name = './' + temp + '.png'
f2x_name = OUTPUT_SAVE_PATH + temp + '_2x.png'
ckpt_state = tf.train.get_checkpoint_state(CHECKPOINT_PATH)
if not ckpt_state or not ckpt_state.model_checkpoint_path:
print('No check point files are found!')
Expand All @@ -28,7 +31,7 @@ def main():
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, ckpt_files[-1]) # load the lateast model

low_res_img = cv.imread('./005.png')
low_res_img = cv.imread(f_name)
output_size = int(inferences.get_shape()[1])
input_size = INPUT_SIZE
available_size = output_size // SCALE_FACTOR
Expand Down Expand Up @@ -63,7 +66,7 @@ def main():
high_res_img = tf.image.convert_image_dtype(high_res_img, tf.uint8, True)

high_res_img = high_res_img[:SCALE_FACTOR * img_rows, :SCALE_FACTOR * img_cols, ...]
cv.imwrite('./005-2x.png', high_res_img.eval(session=sess))
cv.imwrite(f2x_name, high_res_img.eval(session=sess))

print('Enhance Finished!')

Expand Down

0 comments on commit a0687f5

Please sign in to comment.