
Commit

fix: ignore update net1 in train2
andabi committed Mar 14, 2018
1 parent 61e3b84 commit 4eb0d8f
Showing 5 changed files with 46 additions and 8 deletions.
2 changes: 1 addition & 1 deletion hparams/default.yaml
@@ -67,7 +67,7 @@ train2:
   lr_cyclic_steps: 5000
   clip_value_max: 3.
   clip_value_min: -3.
-  clip_norm: 100
+  clip_norm: 10
   mol_step: 0.001
   num_epochs: 10000
   steps_per_epoch: 100
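
Note: clip_norm is the threshold passed to gradproc.GlobalNormClip in models.py below, which rescales all gradients together once their combined norm exceeds the value. A minimal sketch of the underlying TensorFlow op, with made-up gradient values:

import tensorflow as tf

# Toy gradients whose global norm is sqrt(3^2 + 4^2 + 12^2) = 13, above the threshold of 10.
grads = [tf.constant([3.0, 4.0]), tf.constant([12.0])]
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=10.0)
# Each gradient is scaled by 10 / 13, so their relative directions are preserved.
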
13 changes: 8 additions & 5 deletions models.py
@@ -12,6 +12,9 @@
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 from tensorpack.tfutils import (
     summary, get_current_tower_context, optimizer, gradproc)
+import re
+
+import tensorpack_extension


 class Net1(ModelDesc):
@@ -107,11 +110,11 @@ def _build_graph(self, inputs):
     def _get_optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=hp.train2.lr, trainable=False)
         opt = tf.train.AdamOptimizer(learning_rate=lr)
-        gradprocs = [gradproc.MapGradient(lambda grad: grad, regex='.*net2.*'),  # apply only gradients of net2
-                     gradproc.GlobalNormClip(hp.train2.clip_norm),
-                     # gradproc.PrintGradient()]
-                     ]
-
+        gradprocs = [
+            tensorpack_extension.FilterGradientVariables('.*net2.*', verbose=False),
+            gradproc.GlobalNormClip(hp.train2.clip_norm),
+            gradproc.PrintGradient(),
+        ]
         return optimizer.apply_grad_processors(opt, gradprocs)

     @auto_reuse_variable_scope
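
Note: the comment on the removed line suggests MapGradient with an identity lambda was expected to restrict updates to net2, but MapGradient only transforms the gradients of variables matching the regex and leaves the rest in place, so net1 kept being updated during train2. The new FilterGradientVariables processor (added below in tensorpack_extension.py) instead drops every gradient whose variable name does not match the regex. A tiny, self-contained illustration of that regex test, using hypothetical variable names:

import re

# Hypothetical variable names for illustration; the real graph has its own scopes.
var_names = ['net1/prenet/dense1/kernel', 'net2/cbhg/conv1d/W']
kept = [name for name in var_names if re.match('.*net2.*', name)]
print(kept)  # ['net2/cbhg/conv1d/W'] -- net1 variables would be skipped
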
35 changes: 35 additions & 0 deletions tensorpack_extension.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+#!/usr/bin/env python
+
+import re
+from tensorpack.utils import logger
+from tensorpack.tfutils.gradproc import GradientProcessor
+
+
+class FilterGradientVariables(GradientProcessor):
+    """
+    Skip the update of certain variables and print a warning.
+    """
+
+    def __init__(self, var_regex='.*', verbose=True):
+        """
+        Args:
+            var_regex (string): regular expression to match variables to update.
+            verbose (bool): whether to print warning about None gradients.
+        """
+        super(FilterGradientVariables, self).__init__()
+        self._regex = var_regex
+        self._verbose = verbose
+
+    def _process(self, grads):
+        g = []
+        to_print = []
+        for grad, var in grads:
+            if re.match(self._regex, var.op.name):
+                g.append((grad, var))
+            else:
+                to_print.append(var.op.name)
+        if self._verbose and len(to_print):
+            message = ', '.join(to_print)
+            logger.warn("No gradient w.r.t these trainable variables: {}".format(message))
+        return g
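
A quick way to sanity-check the processor in isolation; the variable scopes here are hypothetical stand-ins for the real net1/net2 weights, and _process is called directly only for this sketch:

import tensorflow as tf
from tensorpack_extension import FilterGradientVariables

# Hypothetical stand-ins for the net1 / net2 weights.
with tf.variable_scope('net1'):
    w1 = tf.get_variable('w', shape=[1])
with tf.variable_scope('net2'):
    w2 = tf.get_variable('w', shape=[1])

grads = [(tf.ones_like(w1), w1), (tf.ones_like(w2), w2)]
proc = FilterGradientVariables('.*net2.*', verbose=True)
kept = proc._process(grads)
# kept holds only the (gradient, net2/w) pair; a warning lists net1/w as skipped.
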
2 changes: 1 addition & 1 deletion train1.py
@@ -47,7 +47,7 @@ def train(args, logdir):
         steps_per_epoch=hp.train1.steps_per_epoch,
         # session_config=session_conf
     )
-    ckpt = args.ckpt if args.ckpt else tf.train.latest_checkpoint(logdir)
+    ckpt = '{}/{}'.format(logdir, args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir)
     if ckpt:
         train_conf.session_init = SaverRestore(ckpt)

2 changes: 1 addition & 1 deletion train2.py
@@ -45,7 +45,7 @@ def train(args, logdir1, logdir2):
     # )

     session_inits = []
-    ckpt2 = args.ckpt if args.ckpt else tf.train.latest_checkpoint(logdir2)
+    ckpt2 = '{}/{}'.format(logdir2, args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir2)
     if ckpt2:
         session_inits.append(SaverRestore(ckpt2))
     ckpt1 = tf.train.latest_checkpoint(logdir1)
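
The train1.py and train2.py changes are the same: --ckpt is now treated as a checkpoint name resolved inside the corresponding log directory rather than a full path. Roughly, with hypothetical values:

# Hypothetical values, for illustration only.
logdir2 = 'logdir/train2'
args_ckpt = 'model.ckpt-10000'

# Before: ckpt2 = args_ckpt  -> 'model.ckpt-10000' (had to be a full path)
# After:
ckpt2 = '{}/{}'.format(logdir2, args_ckpt)  # -> 'logdir/train2/model.ckpt-10000'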
