#207 migrate cnn module to tf2
arrrobase committed Apr 2, 2020
1 parent 5a04664 commit b1d03d7
Showing 13 changed files with 781 additions and 27 deletions.
20 changes: 20 additions & 0 deletions nems/db.py
@@ -601,6 +601,26 @@ def update_startdate(queueid=None):
return r


def update_job_pid(pid, queueid=None):
    """
    Update the PID of the given job in tQueue.
    """
    if queueid is None:
        if 'QUEUEID' in os.environ:
            queueid = os.environ['QUEUEID']
        else:
            log.warning("queueid not specified or found in os.environ")
            return 0

    engine = Engine()
    conn = engine.connect()
    sql = ("UPDATE tQueue SET pid={} WHERE id={}"
           .format(pid, queueid))
    r = conn.execute(sql)
    conn.close()
    return r
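
Not part of this commit: a minimal sketch of a parameterized alternative to the string-formatted SQL above, assuming SQLAlchemy's text() construct and the same Engine() helper from nems.db (the function name is hypothetical).

from sqlalchemy import text

def update_job_pid_parameterized(pid, queueid):
    """Hypothetical variant of update_job_pid using bound parameters."""
    engine = Engine()          # same nems.db Engine() helper as above
    conn = engine.connect()
    sql = text("UPDATE tQueue SET pid=:pid WHERE id=:qid")
    r = conn.execute(sql, {"pid": pid, "qid": queueid})  # values are bound, not string-formatted
    conn.close()
    return r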


def update_job_tick(queueid=None):
"""
update current machine's load in the cluster db and tick off a step
29 changes: 26 additions & 3 deletions nems/modelspec.py
@@ -541,6 +541,9 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
"""
if rec is None:
rec = self.recording
# rec = rec.copy()

# rec['resp'] = rec['resp'].rasterize()

if fit_index is not None:
self.fit_index = fit_index
@@ -595,18 +598,22 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
if epoch is not None:
try:
epoch_bounds = rec['resp'].get_epoch_bounds(epoch, mask=rec['mask'])
except ValueError:
# some signal types can't handle masks with epoch bounds
epoch_bounds = rec['resp'].get_epoch_bounds(epoch)
except:
log.warning(f'Quickplot: no valid epochs matching {epoch}. Will not subset data.')
epoch = None

if 'mask' in rec.signals.keys():
rec = rec.apply_mask()

rec_resp = rec['resp']
rec_pred = rec['pred']
rec_stim = rec['stim']

if (epoch is None) or (len(epoch_bounds)==0):
epoch_bounds = np.array([[0, rec['resp'].shape[1]/rec['resp'].fs]])
if epoch is None or len(epoch_bounds) == 0:
epoch_bounds = np.array([[0, rec['resp'].shape[1] / rec['resp'].fs]])

# figure out which occurrence
# not_empty = [np.any(np.isfinite(x)) for x in extracted] # occurrences containing non inf/nan data
@@ -675,7 +682,7 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
]

for s in sig_names:
if rec[s].shape[0]>1:
if rec[s].shape[0] > 1:
plot_fn = _lookup_fn_at('nems.plots.api.spectrogram')
title = s + ' spectrogram'
else:
@@ -987,6 +994,22 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0,

return layers

def modelspec2tf2(self, seed=0, use_modelspec_init=True):
    """Convert the modelspec to a list of TensorFlow layers for the new tf2 backend.

    Each module is passed through nems.tf.cnnlink_new.map_layer.
    """
    layers = []

    for m in self:
        layer = nems.tf.cnnlink_new.map_layer(layer=m,
                                              use_modelspec_init=use_modelspec_init,
                                              # seed=seed,
                                              )
        layers.append(layer)

    return layers
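
A hedged usage sketch (not from this commit): calling the new converter on an existing ModelSpec instance; the exact type of the returned layer objects is an assumption here.

# hypothetical usage; assumes `modelspec` is an existing nems ModelSpec instance
tf_layers = modelspec.modelspec2tf2(use_modelspec_init=True)  # one mapped layer per module
for layer in tf_layers:
    print(type(layer).__name__)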


def get_modelspec_metadata(modelspec):
"""Return a dict of the metadata for this modelspec.
2 changes: 1 addition & 1 deletion nems/modules/levelshift.py
@@ -1,4 +1,4 @@
def levelshift(rec, i, o, level):
def levelshift(rec, i, o, level, **kwargs):
'''
Parameters
----------
4 changes: 2 additions & 2 deletions nems/modules/nonlinearity.py
@@ -42,7 +42,7 @@ def _double_exponential(x, base, amplitude, shift, kappa):
return base + amplitude * exp(-exp(np.array(-exp(kappa)) * (x - shift)))


def double_exponential(rec, i, o, base, amplitude, shift, kappa):
def double_exponential(rec, i, o, base, amplitude, shift, kappa, **kwargs):
'''
A double exponential applied to all channels of a single signal.
rec Recording object
@@ -84,7 +84,7 @@ def _dlog(x, offset):
return np.log((x + d) / d)


def dlog(rec, i, o, offset):
def dlog(rec, i, o, offset, **kwargs):
"""
Log compression with variable offset
:param rec: recording object with signals i and o.
2 changes: 1 addition & 1 deletion nems/plots/timeseries.py
@@ -441,7 +441,7 @@ def pred_resp(rec=None, modelspec=None, ax=None, title=None,
--------
ax : axis containing plot
'''
sig_list = ['pred', 'resp']
sig_list = ['resp', 'pred']
sigs = [rec[s] for s in sig_list]
ax = timeseries_from_signals(sigs, channels=channels,
xlabel=xlabel, ylabel=ylabel, ax=ax,
62 changes: 61 additions & 1 deletion nems/plugins/default_fitters.py
@@ -142,6 +142,60 @@ def nrc(fitkey):
return xfspec


def newtf(fitkey):
    """Fit using the new tf2 backend (nems.tf.cnnlink_new.fit_tf).

    :param fitkey: fit keyword string; recognized options are 'b' (pick the best
        fit by mse_fit) and 'rbN' (N random initializations, then pick the best;
        N defaults to 10).
    :return: list of xforms steps (xfspec).
    """
    use_modelspec_init = True
    optimizer = 'adam'
    # max_iter = 10000
    max_iter = 10_000
    cost_function = 'squared_error'
    early_stopping_steps = 5
    early_stopping_tolerance = 5e-4
    learning_rate = 1e-4
    # seed = 0

    rand_count = 0
    pick_best = False

    options = _extract_options(fitkey)

    for op in options:
        if op == 'b':
            pick_best = True
        elif op.startswith('rb'):
            if len(op) == 2:
                rand_count = 10
            else:
                rand_count = int(op[2:])
            pick_best = True

    xfspec = []
    if rand_count > 0:
        xfspec.append(['nems.initializers.rand_phi', {'rand_count': rand_count}])

    xfspec.append(['nems.xforms.fit_wrapper',
                   {
                       'max_iter': max_iter,
                       'use_modelspec_init': use_modelspec_init,
                       'optimizer': optimizer,
                       'cost_function': cost_function,
                       'fit_function': 'nems.tf.cnnlink_new.fit_tf',
                       'early_stopping_steps': early_stopping_steps,
                       'early_stopping_tolerance': early_stopping_tolerance,
                       'learning_rate': learning_rate,
                   }])

    if pick_best:
        xfspec.append(['nems.analysis.test_prediction.pick_best_phi', {'criterion': 'mse_fit'}])

    return xfspec
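
For illustration only (assumes _extract_options splits the fit keyword on '.' and returns the trailing options), the option string 'newtf.rb5' would yield roughly this spec:

# hypothetical illustration of the resulting xfspec for fitkey 'newtf.rb5'
# ('rb5' => rand_count=5, pick_best=True per the parsing loop above)
xfspec = newtf('newtf.rb5')
# expected structure, roughly:
# [['nems.initializers.rand_phi', {'rand_count': 5}],
#  ['nems.xforms.fit_wrapper', {..., 'fit_function': 'nems.tf.cnnlink_new.fit_tf', ...}],
#  ['nems.analysis.test_prediction.pick_best_phi', {'criterion': 'mse_fit'}]]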


def tf(fitkey):
'''
Perform a Tensorflow fit, using Sam Norman-Haignere's CNN library
@@ -399,7 +453,13 @@ def init(kw):
elif op.startswith('ii'):
sel_options['include_through'] = int(op[2:])
elif op.startswith('i'):
sel_options['include_idx'].append(int(op[1:]))
if not tf:
sel_options['include_idx'].append(int(op[1:]))
else:
if len(op[1:]) == 0:
max_iter = None
else:
max_iter = int(op[1:])
elif op.startswith('ffneg'):
sel_options['freeze_after'] = -int(op[5:])
elif op.startswith('fneg'):
18 changes: 11 additions & 7 deletions nems/plugins/default_keywords.py
@@ -87,7 +87,7 @@ def wc(kw):

# This is the default for wc, but options might overwrite it.
fn = 'nems.modules.weight_channels.basic'
fn_kwargs = {'i': 'pred', 'o': 'pred', 'normalize_coefs': False}
fn_kwargs = {'i': 'pred', 'o': 'pred', 'normalize_coefs': False, 'chans': n_outputs}
p_coefficients = {'mean': np.full((n_outputs, n_inputs), 0.01),
'sd': np.full((n_outputs, n_inputs), 0.2)}
# add some variety across channels to help get the fitter started
@@ -132,7 +132,7 @@ def wc(kw):
# Generate evenly-spaced filter centers for the starting points
fn = 'nems.modules.weight_channels.gaussian'
fn_kwargs = {'i': 'pred', 'o': 'pred', 'n_chan_in': n_inputs,
'normalize_coefs': False}
'normalize_coefs': False, 'chans': n_outputs}
coefs = 'nems.modules.weight_channels.gaussian_coefficients'
mean = np.arange(n_outputs+1)/(n_outputs*2+2) + 0.25
mean = mean[1:]
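
For concreteness (illustration only, using the two lines above): with n_outputs = 3 the starting filter centers come out evenly spaced inside (0, 1).

import numpy as np

n_outputs = 3
mean = np.arange(n_outputs + 1) / (n_outputs * 2 + 2) + 0.25  # [0.25, 0.375, 0.5, 0.625]
mean = mean[1:]                                               # [0.375, 0.5, 0.625]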
@@ -318,7 +318,7 @@ def fir(kw):
template = {
'fn': 'nems.modules.fir.basic',
'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal,
'cross_channels': cross_channels},
'cross_channels': cross_channels, 'chans': n_coefs},
'plot_fns': ['nems.plots.api.mod_output',
'nems.plots.api.spectrogram_output',
'nems.plots.api.strf_heatmap',
@@ -334,7 +334,8 @@
template = {
'fn': 'nems.modules.fir.filter_bank',
'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal,
'bank_count': n_banks, 'cross_channels': cross_channels},
'bank_count': n_banks, 'cross_channels': cross_channels,
'chans': n_coefs},
'plot_fns': ['nems.plots.api.mod_output',
'nems.plots.api.spectrogram_output',
'nems.plots.api.strf_heatmap',
@@ -689,13 +690,14 @@ def lvl(kw):

template = {
'fn': 'nems.modules.levelshift.levelshift',
'fn_kwargs': {'i': 'pred', 'o': 'pred'},
'fn_kwargs': {'i': 'pred', 'o': 'pred', 'chans': n_shifts},
'plot_fns': ['nems.plots.api.mod_output',
'nems.plots.api.spectrogram_output',
'nems.plots.api.pred_resp'],
'plot_fn_idx': 2,
'prior': {'level': ('Normal', {'mean': np.zeros([n_shifts, 1]),
'sd': np.ones([n_shifts, 1])})}

}

return template
@@ -968,7 +970,8 @@ def dexp(kw):
template = {
'fn': 'nems.modules.nonlinearity.double_exponential',
'fn_kwargs': {'i': inout_name,
'o': inout_name},
'o': inout_name,
'chans': n_dims},
'plot_fns': ['nems.plots.api.mod_output',
'nems.plots.api.spectrogram_output',
'nems.plots.api.pred_resp',
@@ -1167,7 +1170,8 @@ def dlog(kw):
template = {
'fn': 'nems.modules.nonlinearity.dlog',
'fn_kwargs': {'i': 'pred',
'o': 'pred'},
'o': 'pred',
'chans': chans,},
'plot_fns': ['nems.plots.api.mod_output',
'nems.plots.api.spectrogram_output',
'nems.plots.api.pred_resp',
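
Not from the commit: a minimal illustration of why the module functions above gained **kwargs while these keyword templates now put a 'chans' count into fn_kwargs. The extra key (presumably consumed by the new tf layer mapping) is simply absorbed on the scipy side.

# hypothetical illustration, not library code
def old_levelshift(rec, i, o, level):
    return level

def new_levelshift(rec, i, o, level, **kwargs):   # signature pattern from this commit
    return level

fn_kwargs = {'i': 'pred', 'o': 'pred', 'level': 0.0, 'chans': 1}
# old_levelshift(None, **fn_kwargs)  -> TypeError: unexpected keyword argument 'chans'
new_levelshift(None, **fn_kwargs)    # -> 0.0; the extra 'chans' entry is ignored via **kwargs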
14 changes: 14 additions & 0 deletions nems/tf/callbacks.py
@@ -0,0 +1,14 @@
"""Custom tensorflow training callbacks."""

import tensorflow as tf


class DelayedStopper(tf.keras.callbacks.EarlyStopping):
"""Early stopper that waits before kicking in."""
def __init__(self, start_epoch=100, **kwargs):
super(DelayedStopper, self).__init__(**kwargs)
self.start_epoch = start_epoch

def on_epoch_end(self, epoch, logs=None):
if epoch > self.start_epoch:
super().on_epoch_end(epoch, logs)
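
A hedged usage sketch (not part of the commit): wiring DelayedStopper into a standard tf.keras training loop so early stopping only activates after a warm-up period; the toy model and data are hypothetical.

# hypothetical usage sketch, not code from this commit
import numpy as np
import tensorflow as tf

from nems.tf.callbacks import DelayedStopper

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(128, 4).astype(np.float32)
y = np.random.rand(128, 1).astype(np.float32)

stopper = DelayedStopper(start_epoch=100,   # ignore the first 100 epochs
                         monitor='loss',    # standard EarlyStopping kwargs pass through
                         patience=5,
                         min_delta=5e-4)
model.fit(x, y, epochs=500, verbose=0, callbacks=[stopper])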
2 changes: 1 addition & 1 deletion nems/tf/cnn.py
@@ -307,7 +307,7 @@ def train(self, stim, resp, learning_rate=0.01, max_iter=300, eval_interval=30,

# early stopping for > 5 non improving iterations
# though partly redundant with tolerance early stopping, catches nans
if self.best_loss_index < len(self.train_loss) - early_stopping_steps:
if self.best_loss_index <= len(self.train_loss) - early_stopping_steps:
log.info(f'Best epoch > {early_stopping_steps} iterations ago, stopping early!')
break
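
To see what the switch from < to <= changes, a small hedged example (illustration only): with early_stopping_steps = 5 and 10 recorded losses whose best value sits at index 5, the old test did not fire but the new one does, so training stops one non-improving iteration sooner.

# illustration of the boundary change, not library code
early_stopping_steps = 5
train_loss = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.6, 0.6, 0.6]  # best at index 5
best_loss_index = train_loss.index(min(train_loss))               # -> 5

old_rule = best_loss_index < len(train_loss) - early_stopping_steps   # 5 < 5  -> False
new_rule = best_loss_index <= len(train_loss) - early_stopping_steps  # 5 <= 5 -> True
print(old_rule, new_rule)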

(diffs for the remaining changed files not shown)
