From 8d6889b6295e86fd2570847aa3a296ad78e1d7df Mon Sep 17 00:00:00 2001 From: Alexander Tomlinson Date: Wed, 19 Feb 2020 16:48:50 -0800 Subject: [PATCH] #207 migrate cnn module to tf2 --- environment.yml | 1 + nems/db.py | 20 ++ nems/modelspec.py | 71 +++- nems/modules/levelshift.py | 2 +- nems/modules/nonlinearity.py | 4 +- nems/modules/stp.py | 188 +++++++--- nems/modules/weight_channels.py | 4 +- nems/plots/specgram.py | 15 +- nems/plots/timeseries.py | 2 +- nems/plugins/default_fitters.py | 96 +++++- nems/plugins/default_keywords.py | 46 ++- nems/tf/callbacks.py | 14 + nems/tf/cnn.py | 2 +- nems/tf/cnnlink.py | 254 +++++++++++++- nems/tf/cnnlink_new.py | 283 ++++++++++++++++ nems/tf/layers.py | 565 +++++++++++++++++++++++++++++++ nems/tf/loss_functions.py | 35 +- nems/tf/modelbuilder.py | 83 +++++ nems/xform_helper.py | 23 +- nems/xforms.py | 14 +- setup.py | 2 +- 21 files changed, 1631 insertions(+), 93 deletions(-) create mode 100644 nems/tf/callbacks.py create mode 100644 nems/tf/cnnlink_new.py create mode 100644 nems/tf/layers.py create mode 100644 nems/tf/modelbuilder.py diff --git a/environment.yml b/environment.yml index 6628365c..afe8c0de 100644 --- a/environment.yml +++ b/environment.yml @@ -15,3 +15,4 @@ dependencies: - pyqt - tensorflow - tensorboard + - tensorflow-probability diff --git a/nems/db.py b/nems/db.py index e207eadc..900118ce 100644 --- a/nems/db.py +++ b/nems/db.py @@ -601,6 +601,26 @@ def update_startdate(queueid=None): return r +def update_job_pid(pid, queueid=None): + """ + in tQueue, update the PID + """ + if queueid is None: + if 'QUEUEID' in os.environ: + queueid = os.environ['QUEUEID'] + else: + log.warning("queueid not specified or found in os.environ") + return 0 + + engine = Engine() + conn = engine.connect() + sql = ("UPDATE tQueue SET pid={} WHERE id={}" + .format(pid, queueid)) + r = conn.execute(sql) + conn.close() + return r + + def update_job_tick(queueid=None): """ update current machine's load in the cluster db and tick off a step diff --git a/nems/modelspec.py b/nems/modelspec.py index 07bc80cc..e9ff14e7 100644 --- a/nems/modelspec.py +++ b/nems/modelspec.py @@ -6,6 +6,7 @@ import logging import os import re +import typing from functools import partial import matplotlib.gridspec as gridspec @@ -541,6 +542,9 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None, """ if rec is None: rec = self.recording + # rec = rec.copy() + + # rec['resp'] = rec['resp'].rasterize() if fit_index is not None: self.fit_index = fit_index @@ -595,18 +599,22 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None, if epoch is not None: try: epoch_bounds = rec['resp'].get_epoch_bounds(epoch, mask=rec['mask']) + except ValueError: + # some signal types can't handle masks with epoch bounds + epoch_bounds = rec['resp'].get_epoch_bounds(epoch) except: log.warning(f'Quickplot: no valid epochs matching {epoch}. 
Will not subset data.') epoch = None if 'mask' in rec.signals.keys(): rec = rec.apply_mask() + rec_resp = rec['resp'] rec_pred = rec['pred'] rec_stim = rec['stim'] - if (epoch is None) or (len(epoch_bounds)==0): - epoch_bounds = np.array([[0, rec['resp'].shape[1]/rec['resp'].fs]]) + if epoch is None or len(epoch_bounds) == 0: + epoch_bounds = np.array([[0, rec['resp'].shape[1] / rec['resp'].fs]]) # figure out which occurrence # not_empty = [np.any(np.isfinite(x)) for x in extracted] # occurrences containing non inf/nan data @@ -675,7 +683,7 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None, ] for s in sig_names: - if rec[s].shape[0]>1: + if rec[s].shape[0] > 1: plot_fn = _lookup_fn_at('nems.plots.api.spectrogram') title = s + ' spectrogram' else: @@ -978,7 +986,7 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0, layer = nems.tf.cnnlink.map_layer(layer=layer, prev_layers=layers, fn=fn, idx=idx, modelspec=m, n_input_feats=n_input_feats, net_seed=net_seed, weight_scale=weight_scale, - use_modelspec_init=use_modelspec_init, distr=distr,) + use_modelspec_init=use_modelspec_init, fs=fs, distr=distr,) # necessary? layer['time_win_sec'] = layer['time_win_smp'] / fs @@ -987,6 +995,24 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0, return layers + def modelspec2tf2(self, seed=0, use_modelspec_init=True, fs=100): + """New version + + TODO + """ + layers = [] + + for m in self: + try: + tf_layer = nems.utils.lookup_fn_at(m['tf_layer']) + except KeyError: + raise NotImplementedError(f'Layer "{m["fn"]}" does not have a tf equivalent.') + + layer = tf_layer.from_ms_layer(m, use_modelspec_init=use_modelspec_init, fs=fs) + layers.append(layer) + + return layers + def get_modelspec_metadata(modelspec): """Return a dict of the metadata for this modelspec. @@ -1215,6 +1241,43 @@ def fit_mode_off(modelspec): modelspec.fast_eval_off() +def eval_ms_layer(data: np.ndarray, + layer_spec: typing.Union[None, str] = None, + stop: typing.Union[None, int] = None, + modelspec: ModelSpec = None + ) -> np.ndarray: + """Takes in a numpy array and applies a single ms layer to it. + + :param data: The input data. Shape of (reps, time, channels). + :param layer_spec: A layer spec for layers of a modelspec. + :param stop: What layer to eval to. Non inclusive. If not passed, will evaluate the whole layer spec. + :param modelspec: Optionally use an existing modelspec. Takes precedence over layer_spec. + + :return: The processed data. + """ + if layer_spec is None and modelspec is None: + raise ValueError('Either of "layer_spec" or "modelspec" must be specified.') + + if modelspec is not None: + ms = modelspec + else: + ms = nems.initializers.from_keywords(layer_spec) + + sig = nems.signal.RasterizedSignal.from_3darray( + fs=100, + array=np.swapaxes(data, 1, 2), + name='temp', + recording='temp', + epoch_name='REFERENCE' + ) + + rec = nems.recording.Recording({'stim': sig}) + rec = ms.evaluate(rec=rec, stop=stop) + + pred = np.swapaxes(rec['pred'].extract_epoch('REFERENCE'), 1, 2) + return pred + + def evaluate(rec, modelspec, start=None, stop=None): """Given a recording object and a modelspec, return a prediction in a new recording. 
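Usage sketch (illustration only, not part of the patch; the keyword string and shapes are assumed examples based on the eval_ms_layer docstring above):

    import numpy as np
    import nems.modelspec

    data = np.random.rand(10, 1000, 18)    # (reps, time, channels)

    # evaluate only the first two modules (wc, fir) of an assumed keyword spec
    pred = nems.modelspec.eval_ms_layer(data, layer_spec='wc.18x2-fir.2x15-lvl.1', stop=2)
    print(pred.shape)                      # (reps, time, channels) back out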
diff --git a/nems/modules/levelshift.py b/nems/modules/levelshift.py
index 635b5c7f..d114fdec 100644
--- a/nems/modules/levelshift.py
+++ b/nems/modules/levelshift.py
@@ -1,4 +1,4 @@
-def levelshift(rec, i, o, level):
+def levelshift(rec, i, o, level, **kwargs):
     '''
     Parameters
     ----------
diff --git a/nems/modules/nonlinearity.py b/nems/modules/nonlinearity.py
index ce1dd5a9..27a57b04 100644
--- a/nems/modules/nonlinearity.py
+++ b/nems/modules/nonlinearity.py
@@ -42,7 +42,7 @@ def _double_exponential(x, base, amplitude, shift, kappa):
     return base + amplitude * exp(-exp(np.array(-exp(kappa)) * (x - shift)))
 
 
-def double_exponential(rec, i, o, base, amplitude, shift, kappa):
+def double_exponential(rec, i, o, base, amplitude, shift, kappa, **kwargs):
     '''
     A double exponential applied to all channels of a single signal.
     rec        Recording object
@@ -84,7 +84,7 @@ def _dlog(x, offset):
     return np.log((x + d) / d)
 
 
-def dlog(rec, i, o, offset):
+def dlog(rec, i, o, offset, **kwargs):
     """
     Log compression with variable offset
     :param rec: recording object with signals i and o.
diff --git a/nems/modules/stp.py b/nems/modules/stp.py
index 2bf434b0..454cb518 100644
--- a/nems/modules/stp.py
+++ b/nems/modules/stp.py
@@ -6,8 +6,9 @@
 
 log = logging.getLogger(__name__)
 
+
 def short_term_plasticity(rec, i, o, u, tau, x0=None, crosstalk=0,
-                          reset_signal=None, quick_eval=False):
+                          reset_signal=None, quick_eval=False, **kwargs):
     '''
     STP applied to each input channel.
     parameterized by Tsodyks-Markram model:
@@ -20,10 +21,11 @@ def short_term_plasticity(rec, i, o, u, tau, x0=None, crosstalk=0,
     if reset_signal in rec.signals.keys():
         r = rec[reset_signal].as_continuous()
 
-    fn = lambda x : _stp(x, u, tau, x0, crosstalk, rec[i].fs, r, quick_eval)
+    fn = lambda x: _stp(x, u, tau, x0, crosstalk, rec[i].fs, r, quick_eval)
 
     return [rec[i].transform(fn, o)]
 
+
 def short_term_plasticity2(rec, i, o, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0,
                            reset_signal=None, quick_eval=False):
     '''
@@ -39,17 +41,22 @@ def short_term_plasticity2(rec, i, o, u, u2, tau, tau2, urat=0.5, x0=None, cross
     if reset_signal in rec.signals.keys():
         r = rec[reset_signal].as_continuous()
 
-    fn = lambda x : _stp2(x, u, u2, tau, tau2, urat, x0, crosstalk, rec[i].fs, r, quick_eval)
+    fn = lambda x: _stp2(x, u, u2, tau, tau2, urat, x0, crosstalk, rec[i].fs, r, quick_eval)
 
     return [rec[i].transform(fn, o)]
 
 
-def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=False, dep_only=False):
+def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=False, dep_only=False,
+         chunksize=5):
     """
     STP core function
     """
     s = X.shape
 
-    tstim = X.copy()
+    # X = X.astype('float64')
+    u = u.astype('float64')
+    tau = tau.astype('float64')
+
+    tstim = X.astype('float64')
     tstim[np.isnan(tstim)] = 0
     if x0 is not None:
         tstim -= np.expand_dims(x0, axis=1)
@@ -66,20 +73,29 @@ def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=Fa
     else:
         ui = u.copy()
 
-    #ui[ui > 1] = 1
-    #ui[ui < -0.4] = -0.4
+    # force tau to have positive sign (probably better done with bounds on fitter)
+    taui = np.absolute(tau.copy())
+    # taui[taui < 2/fs] = 2/fs
 
-    # convert tau units from sec to bins
-    taui = np.absolute(tau.copy()) * fs
-    taui[taui < 2] = 2
+    # ui[ui > 1] = 1
+    # ui[ui < -0.4] = -0.4
 
     # avoid ringing if combination of strong depression and
    # rapid recovery is too large
-    rat = ui**2 / taui
-    ui[rat>0.1] = np.sqrt(0.1 * taui[rat>0.1])
-    #taui[rat>0.08] = (ui[rat>0.08]**2) / 0.08
+    # rat = ui**2 / taui
+
+    # MK comment out
+    # 
ui[rat>0.1] = np.sqrt(0.1 * taui[rat>0.1]) + + # taui[rat>0.08] = (ui[rat>0.08]**2) / 0.08 + # print("rat: %s" % (ui**2 / taui)) - #print("rat: %s" % (ui**2 / taui)) + # convert u & tau units from sec to bins + taui = taui * fs + ui = ui / fs + + # convert chunksize from sec to bins + chunksize = int(chunksize * fs) # TODO : allow >1 STP channel per input? @@ -89,9 +105,90 @@ def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=Fa if crosstalk: # assumes dim of u is 1 ! tstim = np.mean(tstim, axis=0, keepdims=True) - + if len(ui.shape) == 1: + ui = np.expand_dims(ui, axis=1) + taui = np.expand_dims(taui, axis=1) for i in range(0, len(u)): - if quick_eval and (reset_signal is not None): + if quick_eval: + + a = 1 / taui + x = ui * tstim + + if reset_signal is None: + reset_times = np.arange(0, s[1] + chunksize - 1, chunksize) + else: + reset_times = np.argwhere(reset_signal[0, :])[:, 0] + reset_times = np.append(reset_times, s[1]) + + td = np.ones_like(x) + x0, imu0 = 0., 0. + for j in range(len(reset_times) - 1): + si = slice(reset_times[j], reset_times[j + 1]) + xi = x[:, si] + + ix = cumtrapz(a + xi, dx=1, initial=0, axis=1) + a + (x0 + xi[:, :1]) / 2 + + mu = np.exp(ix) + imu = cumtrapz(mu * xi, dx=1, initial=0, axis=1) + (x0 + mu[:, :1] * xi[:, :1]) / 2 + imu0 + + ff = np.bitwise_and(mu > 0, imu > 0) + _td = np.ones_like(mu) + _td[ff] = 1 - np.exp(np.log(imu[ff]) - np.log(mu[ff])) + td[:, si] = _td + + x0 = xi[:, -1:] + imu0 = imu[:, -1:] / mu[:, -1:] + + stim_out = tstim * np.pad(td[:, :-1], ((0, 0), (1, 0)), 'constant', constant_values=(1, 1)) + + """ + a = 1 / taui[i] + x = ui[i] * tstim[i, :] # / fs + + #if reset_signal is None: + reset_times = np.arange(0, len(x) + chunksize - 1, chunksize) + #else: + # reset_times = np.argwhere(reset_signal[0, :])[:, 0] + # reset_times = np.append(reset_times, len(x)) + td = np.ones_like(x) + td0, mu0, imu0, x0 = 1., 1., 0., 0. 
+ for j in range(len(reset_times) - 1): + si = slice(reset_times[j], reset_times[j + 1]) + #mu = np.zeros_like(x[si]) + #imu = np.zeros_like(x[si]) + + # ix = _cumtrapz(a + x[si]) + ix = cumtrapz(a + x[si], dx=1, initial=0) + ix += np.log(mu0) + a + (x0 + x[si][0]) / 2 + + mu = np.exp(ix) + # imu = _cumtrapz(mu*x[si]) + td0 + imu = cumtrapz(mu * x[si], dx=1, initial=0) # /mu0 + imu0/mu0 + imu += imu0 + (mu0 * x0 + mu[0] * x[si][0]) / 2 + + ff = np.bitwise_and(mu > 0, imu > 0) + _td = np.ones_like(mu) + # _td[mu>0] = 1 - imu[mu>0]/mu[mu>0] + _td[ff] = 1 - np.exp(np.log(imu[ff]) - np.log(mu[ff])) + td[si] = _td + + x0 = x[si][-1] + mu0 = mu[-1] + imu0 = imu[-1] + td0 = _td[-1] + mu0, imu0 = 1, imu0 / mu0 + + # if i==0: + # plt.figure(figsize=(16,3)) + # plt.plot(x[si]) + # plt.plot(_td) + if crosstalk: + stim_out *= np.expand_dims(td, 0) + else: + stim_out[i, :] *= td + """ + """ + if quick_eval and (reset_signal is not None): a = 1 / taui[i] x = ui[i] * tstim[i, :] / fs @@ -121,10 +218,12 @@ def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=Fa stim_out *= np.expand_dims(td, 0) else: stim_out[i, :] *= td + + """ else: - a = 1/taui[i] - ustim = 1.0/taui[i] + ui[i] * tstim[i, :] + a = 1 / taui[i] + ustim = 1.0 / taui[i] + ui[i] * tstim[i, :] # ustim = ui[i] * tstim[i, :] td = np.ones_like(ustim) # initialize dep state vector @@ -139,14 +238,14 @@ def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=Fa # delta = 1/taui[i] - td * (1/taui[i] - ui[i] * tstim[i, tt - 1]) # then a=1/taui[i] and ustim=1/taui[i] - ui[i] * tstim[i,:] delta = a - td[tt - 1] * ustim[tt - 1] - td[tt] = td[tt-1] + delta + td[tt] = td[tt - 1] + delta if td[tt] < 0: td[tt] = 0 else: # facilitation for tt in range(1, s[1]): - delta = a - td[tt-1] * ustim[tt - 1] - td[tt] = td[tt-1] + delta + delta = a - td[tt - 1] * ustim[tt - 1] + td[tt] = td[tt - 1] + delta if td[tt] > 5: td[tt] = 5 @@ -156,17 +255,22 @@ def _stp(X, u, tau, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=Fa stim_out[i, :] *= td if np.sum(np.isnan(stim_out)): - # import pdb - # pdb.set_trace() + # import pdb + # pdb.set_trace() log.info('nan value in stp stim_out') # print("(u,tau)=({0},{1})".format(ui,taui)) + try: + stim_out[np.isnan(X)] = np.nan + except: + import pdb + pdb.set_trace() - stim_out[np.isnan(X)] = np.nan return stim_out -def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=False, dep_only=False): +def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signal=None, quick_eval=False, + dep_only=False): """ STP core function """ @@ -194,16 +298,16 @@ def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signa taui = np.absolute(tau) * fs taui[taui < 2] = 2 taui2 = np.absolute(tau2) * fs - #taui2[taui2 < 2] = 2 + # taui2[taui2 < 2] = 2 # avoid ringing if combination of strong depression and # rapid recovery is too large - rat = ui**2 / taui - ui[rat>0.1] = np.sqrt(0.1 * taui[rat>0.1]) - rat = ui2**2 / taui2 - ui2[rat>0.1] = np.sqrt(0.1 * taui2[rat>0.1]) + rat = ui ** 2 / taui + ui[rat > 0.1] = np.sqrt(0.1 * taui[rat > 0.1]) + rat = ui2 ** 2 / taui2 + ui2[rat > 0.1] = np.sqrt(0.1 * taui2[rat > 0.1]) - #print("rat: %s" % (ui**2 / taui)) + # print("rat: %s" % (ui**2 / taui)) # TODO : allow >1 STP channel per input? 
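For reference (an editor's sketch, not part of the patch): with a = 1/tau and ustim = 1/tau + u*stim, the non-quick update in _stp/_stp2, td[tt] = td[tt-1] + a - td[tt-1]*ustim[tt-1], is forward-Euler integration of the depression ODE td' = (1 - td)/tau - u*td*stim, which relaxes toward the steady state a/ustim = 1/(1 + u*tau*stim). A minimal numpy check with assumed toy values, u and tau already in units of bins:

    import numpy as np

    u, tau, stim = 0.01, 10.0, 1.0
    a, ustim = 1.0 / tau, 1.0 / tau + u * stim
    td = np.ones(200)
    for tt in range(1, 200):
        # same recurrence as the depression branch above, clamped at 0
        td[tt] = max(td[tt - 1] + a - td[tt - 1] * ustim, 0.0)
    print(td[-1], a / ustim)   # both ~0.909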
@@ -216,8 +320,8 @@ def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signa for i in range(0, len(u)): - a = 1/taui[i] - ustim = 1.0/taui[i] + ui[i] * tstim[i, :] + a = 1 / taui[i] + ustim = 1.0 / taui[i] + ui[i] * tstim[i, :] # ustim = ui[i] * tstim[i, :] td = np.ones_like(ustim) # initialize dep state vector td2 = np.ones_like(ustim) # initialize dep state vector @@ -228,15 +332,15 @@ def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signa elif ui[i] > 0: # depression for tt in range(1, s[1]): - #td = di[i, tt - 1] # previous time bin depression - delta = (1 - td[tt-1]) / taui[i] - ui[i] * td[tt-1] * tstim[i, tt - 1] - delta2 = (1 - td2[tt-1]) / taui2[i] - ui2[i] * td2[tt-1] * tstim[i, tt - 1] + # td = di[i, tt - 1] # previous time bin depression + delta = (1 - td[tt - 1]) / taui[i] - ui[i] * td[tt - 1] * tstim[i, tt - 1] + delta2 = (1 - td2[tt - 1]) / taui2[i] - ui2[i] * td2[tt - 1] * tstim[i, tt - 1] # delta = 1/taui[i] - td * (1/taui[i] - ui[i] * tstim[i, tt - 1]) # then a=1/taui[i] and ustim=1/taui[i] - ui[i] * tstim[i,:] # delta = a - td[tt - 1] * ustim[tt - 1] - td[tt] = td[tt-1] + delta - td2[tt] = td2[tt-1] + delta2 + td[tt] = td[tt - 1] + delta + td2[tt] = td2[tt - 1] + delta2 if td[tt] < 0: td[tt] = 0 if td2[tt] < 0: @@ -244,15 +348,15 @@ def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signa else: # facilitation for tt in range(1, s[1]): - delta = a - td[tt-1] * ustim[tt - 1] - td[tt] = td[tt-1] + delta + delta = a - td[tt - 1] * ustim[tt - 1] + td[tt] = td[tt - 1] + delta if td[tt] > 5: td[tt] = 5 if crosstalk: stim_out *= np.expand_dims(td, 0) else: - stim_out[i, :] *= td*urat + td2*(1-urat) + stim_out[i, :] *= td * urat + td2 * (1 - urat) if np.sum(np.isnan(stim_out)): import pdb @@ -262,4 +366,4 @@ def _stp2(X, u, u2, tau, tau2, urat=0.5, x0=None, crosstalk=0, fs=1, reset_signa # print("(u,tau)=({0},{1})".format(ui,taui)) stim_out[np.isnan(X)] = np.nan - return stim_out + return stim_out \ No newline at end of file diff --git a/nems/modules/weight_channels.py b/nems/modules/weight_channels.py index 16704c9a..ef863448 100644 --- a/nems/modules/weight_channels.py +++ b/nems/modules/weight_channels.py @@ -20,7 +20,7 @@ def gaussian_coefficients(mean, sd, n_chan_in, **kwargs): # mean[mean < 0] = 0 # mean[mean > 1] = 1 sd = np.asanyarray(sd)[..., np.newaxis] - sd[sd < 0.00001] = 0.00001 + # sd[sd < 0.00001] = 0.00001 #coefficients = 1/(sd*(2*np.pi)**0.5) * np.exp(-0.5*((x-mean)/sd)**2) coefficients = np.exp(-0.5*((x-mean)/sd)**2) csum = np.sum(coefficients, axis=1, keepdims=True) @@ -32,7 +32,7 @@ def gaussian_coefficients(mean, sd, n_chan_in, **kwargs): #------------------------------------------------------------------------------- # Module functions #------------------------------------------------------------------------------- -def basic(rec, i, o, coefficients, normalize_coefs=False): +def basic(rec, i, o, coefficients, normalize_coefs=False, **kwargs): ''' Parameters ---------- diff --git a/nems/plots/specgram.py b/nems/plots/specgram.py index 9abc05f2..a4a67384 100644 --- a/nems/plots/specgram.py +++ b/nems/plots/specgram.py @@ -1,10 +1,13 @@ -import scipy.signal as sps -import scipy as scp -import numpy as np +import logging + import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from nems.plots.utils import ax_remove_box +import numpy as np + from nems.gui.decorators import scrollable +from nems.plots.utils import ax_remove_box + +log = logging.getLogger(__name__) + def 
plot_spectrogram(array, fs=None, ax=None, title=None, time_offset=0, cmap=None, clim=None, extent=True, time_range=None, **options): @@ -16,7 +19,7 @@ def plot_spectrogram(array, fs=None, ax=None, title=None, time_offset=0, if time_range is not None: if fs is not None: time_range = np.round(np.array(time_range)*fs).astype(int) - print('bin range: {}-{}'.format(time_range[0],time_range[1])) + log.debug('bin range: {}-{}'.format(time_range[0],time_range[1])) ax.imshow(array[:, np.arange(time_range[0],time_range[1])], origin='lower', interpolation='none', aspect='auto', cmap=cmap, clim=clim) diff --git a/nems/plots/timeseries.py b/nems/plots/timeseries.py index 3bba67b6..d209ee59 100644 --- a/nems/plots/timeseries.py +++ b/nems/plots/timeseries.py @@ -441,7 +441,7 @@ def pred_resp(rec=None, modelspec=None, ax=None, title=None, -------- ax : axis containing plot ''' - sig_list = ['pred', 'resp'] + sig_list = ['resp', 'pred'] sigs = [rec[s] for s in sig_list] ax = timeseries_from_signals(sigs, channels=channels, xlabel=xlabel, ylabel=ylabel, ax=ax, diff --git a/nems/plugins/default_fitters.py b/nems/plugins/default_fitters.py index 927d0111..10feec23 100644 --- a/nems/plugins/default_fitters.py +++ b/nems/plugins/default_fitters.py @@ -142,6 +142,94 @@ def nrc(fitkey): return xfspec +def newtf(fitkey): + """New tf fitter. + + TODO + """ + use_modelspec_init = False + optimizer = 'adam' + max_iter = 10_000 + cost_function = 'squared_error' + early_stopping_steps = 5 + early_stopping_tolerance = 5e-4 + learning_rate = 1e-4 + batch_size = None + # seed = 0 + + rand_count = 0 + pick_best = False + + options = _extract_options(fitkey) + + for op in options: + if op[:1] == 'i': + if len(op[1:]) == 0: + max_iter = int(1e10) # just a really large number + else: + max_iter = int(op[1:]) + elif op[:1] == 'n': + use_modelspec_init = True + if op == 'b': + pick_best = True + elif op.startswith('rb'): + pick_best = True + if len(op) == 2: + rand_count = 10 + else: + rand_count = int(op[2:]) + elif op.startswith('r'): + if len(op) == 1: + rand_count = 10 + else: + rand_count = int(op[1:]) + elif op.startswith('lr'): + learning_rate = op[2:] + if 'e' in learning_rate: + base, exponent = learning_rate.split('e') + learning_rate = int(base) * 10 ** -int(exponent) + else: + learning_rate = int(learning_rate) + elif op.startswith('et'): + early_stopping_tolerance = 1 * 10 ** -int(op[2:]) + elif op.startswith('bs'): + batch_size = int(op[2:]) + + xfspec = [] + if rand_count > 0: + xfspec.append(['nems.initializers.rand_phi', {'rand_count': rand_count}]) + + xfspec.append(['nems.xforms.fit_wrapper', + { + 'max_iter': max_iter, + 'use_modelspec_init': use_modelspec_init, + 'optimizer': optimizer, + 'cost_function': cost_function, + 'fit_function': 'nems.tf.cnnlink_new.fit_tf', + 'early_stopping_steps': early_stopping_steps, + 'early_stopping_tolerance': early_stopping_tolerance, + 'learning_rate': learning_rate, + 'batch_size': batch_size, + }]) + + if pick_best: + xfspec.append(['nems.analysis.test_prediction.pick_best_phi', {'criterion': 'mse_fit'}]) + + return xfspec + + +def tfinit(fitkey): + """New tf init. 
+ + TODO + """ + xfspec = newtf(fitkey) + idx = [xf[0] for xf in xfspec].index('nems.xforms.fit_wrapper') + xfspec[idx][1]['fit_function'] = 'nems.tf.cnnlink_new.fit_tf_init' + + return xfspec + + def tf(fitkey): ''' Perform a Tensorflow fit, using Sam Norman-Haignere's CNN library @@ -399,7 +487,13 @@ def init(kw): elif op.startswith('ii'): sel_options['include_through'] = int(op[2:]) elif op.startswith('i'): - sel_options['include_idx'].append(int(op[1:])) + if not tf: + sel_options['include_idx'].append(int(op[1:])) + else: + if len(op[1:]) == 0: + max_iter = None + else: + max_iter = int(op[1:]) elif op.startswith('ffneg'): sel_options['freeze_after'] = -int(op[5:]) elif op.startswith('fneg'): diff --git a/nems/plugins/default_keywords.py b/nems/plugins/default_keywords.py index 096ba9ad..1959038a 100644 --- a/nems/plugins/default_keywords.py +++ b/nems/plugins/default_keywords.py @@ -31,6 +31,8 @@ import numpy as np +from nems.tf import layers + log = logging.getLogger(__name__) @@ -87,7 +89,12 @@ def wc(kw): # This is the default for wc, but options might overwrite it. fn = 'nems.modules.weight_channels.basic' - fn_kwargs = {'i': 'pred', 'o': 'pred', 'normalize_coefs': False} + fn_kwargs = {'i': 'pred', + 'o': 'pred', + 'normalize_coefs': False, + 'chans': n_outputs, + } + tf_layer = 'nems.tf.layers.WeightChannelsBasic' p_coefficients = {'mean': np.full((n_outputs, n_inputs), 0.01), 'sd': np.full((n_outputs, n_inputs), 0.2)} # add some variety across channels to help get the fitter started @@ -131,8 +138,13 @@ def wc(kw): # Generate evenly-spaced filter centers for the starting points fn = 'nems.modules.weight_channels.gaussian' - fn_kwargs = {'i': 'pred', 'o': 'pred', 'n_chan_in': n_inputs, - 'normalize_coefs': False} + fn_kwargs = {'i': 'pred', + 'o': 'pred', + 'n_chan_in': n_inputs, + 'normalize_coefs': False, + 'chans': n_outputs, + } + tf_layer = 'nems.tf.layers.WeightChannelsGaussian' coefs = 'nems.modules.weight_channels.gaussian_coefficients' mean = np.arange(n_outputs+1)/(n_outputs*2+2) + 0.25 mean = mean[1:] @@ -178,7 +190,8 @@ def wc(kw): 'nems.plots.api.spectrogram_output', 'nems.plots.api.weight_channels_heatmap'], 'plot_fn_idx': 2, - 'prior': prior + 'prior': prior, + 'tf_layer': tf_layer, } if bounds is not None: template['bounds'] = bounds @@ -318,7 +331,8 @@ def fir(kw): template = { 'fn': 'nems.modules.fir.basic', 'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal, - 'cross_channels': cross_channels}, + 'cross_channels': cross_channels, 'chans': n_coefs}, + 'tf_layer': 'nems.tf.layers.FIR', 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.strf_heatmap', @@ -334,7 +348,8 @@ def fir(kw): template = { 'fn': 'nems.modules.fir.filter_bank', 'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal, - 'bank_count': n_banks, 'cross_channels': cross_channels}, + 'bank_count': n_banks, 'cross_channels': cross_channels, + 'chans': n_coefs}, 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.strf_heatmap', @@ -689,13 +704,15 @@ def lvl(kw): template = { 'fn': 'nems.modules.levelshift.levelshift', - 'fn_kwargs': {'i': 'pred', 'o': 'pred'}, + 'fn_kwargs': {'i': 'pred', 'o': 'pred', 'chans': n_shifts}, + 'tf_layer': 'nems.tf.layers.Levelshift', 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.pred_resp'], 'plot_fn_idx': 2, 'prior': {'level': ('Normal', {'mean': np.zeros([n_shifts, 1]), 'sd': np.ones([n_shifts, 1])})} + } 
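+    # 'chans' and 'tf_layer' are what the new tf path consumes: modelspec2tf2
+    # looks up 'tf_layer', and BaseLayer.from_ms_layer reads 'chans' from
+    # fn_kwargs to size the keras layer (see nems/tf/layers.py below)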
return template @@ -797,7 +814,8 @@ def stp(kw): template = { 'fn': 'nems.modules.stp.short_term_plasticity', 'fn_kwargs': {'i': 'pred', 'o': 'pred', 'crosstalk': crosstalk, - 'quick_eval': quick_eval, 'reset_signal': 'epoch_onsets'}, + 'quick_eval': quick_eval, 'reset_signal': 'epoch_onsets', + 'chans': n_synapse}, 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.before_and_after', @@ -817,6 +835,9 @@ def stp(kw): template['prior']['x0'] = ('Normal', {'mean': x0_mean, 'sd': u_sd}) template['bounds']['x0'] = (np.full_like(x0_mean, -np.inf), np.full_like(x0_mean, np.inf)) + if quick_eval: + template['tf_layer'] = 'nems.tf.layers.STPQuick' + return template @@ -968,7 +989,9 @@ def dexp(kw): template = { 'fn': 'nems.modules.nonlinearity.double_exponential', 'fn_kwargs': {'i': inout_name, - 'o': inout_name}, + 'o': inout_name, + 'chans': n_dims}, + 'tf_layer': 'nems.tf.layers.DoubleExponential', 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.pred_resp', @@ -1167,7 +1190,9 @@ def dlog(kw): template = { 'fn': 'nems.modules.nonlinearity.dlog', 'fn_kwargs': {'i': 'pred', - 'o': 'pred'}, + 'o': 'pred', + 'chans': chans,}, + 'tf_layer': 'nems.tf.layers.Dlog', 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.pred_resp', @@ -1228,6 +1253,7 @@ def relu(kw): 'fn': fname, 'fn_kwargs': {'i': 'pred', 'o': 'pred'}, + 'tf_layer': 'nems.tf.layers.Relu', 'plot_fns': ['nems.plots.api.mod_output', 'nems.plots.api.spectrogram_output', 'nems.plots.api.pred_resp', diff --git a/nems/tf/callbacks.py b/nems/tf/callbacks.py new file mode 100644 index 00000000..6955d1e4 --- /dev/null +++ b/nems/tf/callbacks.py @@ -0,0 +1,14 @@ +"""Custom tensorflow training callbacks.""" + +import tensorflow as tf + + +class DelayedStopper(tf.keras.callbacks.EarlyStopping): + """Early stopper that waits before kicking in.""" + def __init__(self, start_epoch=100, **kwargs): + super(DelayedStopper, self).__init__(**kwargs) + self.start_epoch = start_epoch + + def on_epoch_end(self, epoch, logs=None): + if epoch > self.start_epoch: + super().on_epoch_end(epoch, logs) diff --git a/nems/tf/cnn.py b/nems/tf/cnn.py index 6c09d2d2..e2da7e5e 100644 --- a/nems/tf/cnn.py +++ b/nems/tf/cnn.py @@ -307,7 +307,7 @@ def train(self, stim, resp, learning_rate=0.01, max_iter=300, eval_interval=30, # early stopping for > 5 non improving iterations # though partly redundant with tolerance early stopping, catches nans - if self.best_loss_index < len(self.train_loss) - early_stopping_steps: + if self.best_loss_index <= len(self.train_loss) - early_stopping_steps: log.info(f'Best epoch > {early_stopping_steps} iterations ago, stopping early!') break diff --git a/nems/tf/cnnlink.py b/nems/tf/cnnlink.py index 91c4585d..ac317632 100644 --- a/nems/tf/cnnlink.py +++ b/nems/tf/cnnlink.py @@ -14,11 +14,11 @@ import tensorflow as tf import nems +from nems import modelspec import nems.modelspec as ms import nems.tf.cnn as cnn import nems.utils from nems.initializers import init_static_nl -from nems.tf import initializers log = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def map_layer(layer: dict, prev_layers: list, fn: str, idx: int, modelspec, n_input_feats, net_seed, weight_scale, use_modelspec_init: bool, - distr: str='norm',) -> dict: + fs: int, distr: str='norm',) -> dict: """Maps a module to a tensorflow layer. 
Available conversions: @@ -374,6 +374,204 @@ def map_layer(layer: dict, prev_layers: list, fn: str, idx: int, modelspec, layer['Sd'] = tf.nn.conv1d(prev_layers[0]['S'], layer['d'], stride=1, padding='SAME') layer['Y'] = layer['X'] * layer['Sg'] + layer['Sd'] + elif fn == 'nems.modules.stp.short_term_plasticity': + # Credit Menoua Keshisian for some stp code. + quick_eval = fn_kwargs['quick_eval'] + crosstalk = fn_kwargs['crosstalk'] + #reset_signal = fn_kwargs['reset_signal'] + reset_signal = None + chunksize = 5 + _zero = tf.constant(0.0, dtype='float32') + _nan = tf.constant(0.0, dtype='float32') + + if phi is not None: + # reshape appropriately + u = np.reshape(phi['u'].astype('float32'), (1, len(phi['u']))) + tau = np.reshape(phi['tau'].astype('float32'), (1, len(phi['tau']))) + if 'x0' in phi: + x0 = np.reshape(phi['x0'].astype('float32'), (1, len(phi['x0']))) + else: + x0 = None + + if use_modelspec_init: + layer['u'] = tf.Variable(u, constraint=lambda u: tf.abs(u)) + layer['tau'] = tf.Variable(tau, constraint=lambda tau: tf.maximum(tf.abs(tau), 2.001/fs)) + if x0 is not None: + layer['x0'] = tf.Variable(x0) + else: + layer['u'] = tf.Variable(tf.abs(tf.random.normal(u.shape, stddev=weight_scale, + seed=cnn.seed_to_randint(net_seed + idx))), + constraint=lambda u: tf.abs(u)) + layer['tau'] = tf.Variable(tf.abs(tf.random.normal(u.shape, stddev=weight_scale, + seed=cnn.seed_to_randint(net_seed + 20 + idx)))+5.0/fs, + constraint=lambda tau: tf.maximum(tf.abs(tau), 2.001/fs)) + if x0 is not None: + layer['x0'] = tf.Variable(tf.random.normal(u.shape, stddev=weight_scale, + seed=cnn.seed_to_randint(net_seed + 30 + idx))) + + else: + # u/tau/x0 values are fixed + u = np.reshape(fn_kwargs['u'].astype('float32'), (1, len(fn_kwargs['u']))) + tau = np.reshape(fn_kwargs['tau'].astype('float32'), (1, len(fn_kwargs['tau']))) + + layer['u'] = tf.constant(u) + layer['tau'] = tf.constant(tau) + + if 'x0' in fn_kwargs: + x0 = np.reshape(fn_kwargs['x0'].astype('float32'), (1,1,len(fn_kwargs['x0']))) + else: + x0 = None + + layer['n_kern'] = u.shape[-1] + s = layer['X'].shape + #X = layer['X'] + tstim = tf.where(tf.math.is_nan(layer['X']), _zero, layer['X']) + if x0 is not None: # x0 should be tf variable to avoid retraces + # TODO: is this expanding along the right dim? tstim dims: (None, time, chans) + tstim = tstim - tf.expand_dims(x0, axis=1) + + tstim = tf.where(tstim < 0.0, _zero, tstim) + + #tstim = tf.cast(tstim, 'float64') + #u = tf.cast(layer['u'], 'float64') + #tau = tf.cast(layer['tau'], 'float64') + u = layer['u'] + tau = layer['tau'] + + # for quick eval, assume input range is approx -1 to +1 + if quick_eval: + ui = tf.math.abs(u) + else: + ui = u + + # ui[ui > 1] = 1 + # ui[ui < -0.4] = -0.4 + + taui = tf.math.abs(tau) + # taui = tf.where(taui < 2.0/fs, 2.0/fs, taui) # Avoid hard thresholds, non differentiable, set update constraint on variable + + # avoid ringing if combination of strong depression and rapid recovery is too large + #rat = ui ** 2 / taui + # ui = tf.where(rat > 0.1, tf.math.sqrt(0.1 * taui), ui) + # taui = tf.where(rat > 0.08, ui**2 / 0.08, taui) + + # convert a & tau units from sec to bins + taui = taui * fs + ui = ui / fs + + # convert chunksize from sec to bins + chunksize = int(chunksize * fs) + + if crosstalk: + # assumes dim of u is 1 ! 
+            tstim = tf.math.reduce_mean(tstim, axis=0, keepdims=True)
+
+        ui = tf.expand_dims(ui, axis=0)
+        taui = tf.expand_dims(taui, axis=0)
+
+        if quick_eval:
+
+            @tf.function
+            def _cumtrapz(x, dx=1., initial=0.):
+                x = (x[:, :-1] + x[:, 1:]) / 2.0
+                x = tf.pad(x, ((0, 0), (1, 0), (0, 0)), constant_values=initial)
+                return tf.cumsum(x, axis=1) * dx
+
+            a = tf.cast(1.0 / taui, 'float64')
+            x = ui * tstim
+
+            if reset_signal is None:
+                reset_times = tf.range(0, s[1]+chunksize-1, chunksize)
+            else:
+                reset_times = tf.where(reset_signal[0, :])[:, 0]
+                reset_times = tf.pad(reset_times, ((0, 1),), constant_values=s[1])
+
+            td = []
+            x0, imu0 = 0.0, 0.0
+            for j in range(reset_times.shape[0]-1):
+                xi = tf.cast(x[:, reset_times[j]:reset_times[j+1], :], 'float64')
+                ix = _cumtrapz(a + xi, dx=1, initial=0) + a + (x0 + xi[:, :1])/2.0
+
+                mu = tf.exp(ix)
+                imu = _cumtrapz(mu*xi, dx=1, initial=0) + (x0 + mu[:, :1]*xi[:, :1])/2.0 + imu0
+
+                valid = tf.logical_and(mu > 0.0, imu > 0.0)
+                mu = tf.where(valid, mu, 1.0)
+                imu = tf.where(valid, imu, 1.0)
+                _td = 1 - tf.exp(tf.math.log(imu) - tf.math.log(mu))
+                _td = tf.where(valid, _td, 1.0)
+
+                x0 = xi[:, -1:]
+                imu0 = imu[:, -1:] / mu[:, -1:]
+                td.append(tf.cast(_td, 'float32'))
+            td = tf.concat(td, axis=1)
+
+            #layer['Y'] = tstim * td
+            layer['Y'] = tstim * tf.pad(td[:, :-1, :], ((0, 0), (1, 0), (0, 0)), constant_values=1.0)
+            layer['Y'] = tf.where(tf.math.is_nan(layer['X']), _nan, layer['Y'])
+
+        else:
+            ustim = 1.0 / taui + ui * tstim
+            a = 1.0 / taui
+
+            """
+            tensor = tf.zeros(1, T)
+            indices = tf.range(T)
+
+            for (start, end) in slices:
+                slice = tf.logical_and(indices > start, indices < end)
+
+                xi = tf.where(slice, x, 0.0)
+                newval = some_function(xi)
+
+                tensor = tf.where(slice, newval, tensor)
+            """
+            """
+            @tf.function
+            def stp_op(s, ustim, a, ui):
+                new_td = tf.ones_like(ustim[:, 0:1, :])
+                new_td.set_shape((s[0], 1, s[2]))
+                td_list = [tf.squeeze(new_td)]
+                for tt in range(1, s[1]):
+                    new_td = a + new_td * (1.0 - ustim[:, (tt-1):tt, :])
+                    new_td = tf.where(tf.math.logical_and(ui > 0.0, new_td < 0.0), 0.0, new_td)
+                    new_td = tf.where(tf.math.logical_and(ui < 0.0, new_td > 5.0), 5.0, new_td)
+                    td_list.append(tf.squeeze(new_td))
+                    #td = tf.concat((td[:, :tt, :], new_td, td[:, (tt+1):s[1], :]), axis=1)
+                    td_list[tt].set_shape((s[0], s[2]))
+                td = tf.stack(td_list, axis=1)
+                return td
+            td = stp_op((s[0], s[1], s[2]), ustim, a, ui)
+            """
+
+            @tf.function
+            def stp_op(s, ustim, a, ui):
+                # iterate the depression recurrence td[t] = a + td[t-1] * (1 - ustim[t-1]),
+                # seeded with ones; a python range unrolls the loop so the list append works
+                _td = tf.ones_like(ustim[:, 0:1, :])
+                td = [_td]
+                for tt in range(1, s[1]):
+                    _td = a + _td * (1.0 - ustim[:, (tt - 1):tt, :])
+                    _td = tf.where(tf.math.logical_and(ui > 0.0, _td < 0.0), 0.0, _td)
+                    _td = tf.where(tf.math.logical_and(ui < 0.0, _td > 5.0), 5.0, _td)
+                    td.append(_td)
+                return tf.concat(td, axis=1)
+
+            td = stp_op((s[0], s[1], s[2]), ustim, a, ui)
+
+            log.debug('s: %s', s)
+            log.debug('ustim shape %s', ustim.get_shape())
+            log.debug('a shape %s', a.get_shape())
+            log.debug('ui shape %s', ui.get_shape())
+            log.debug('td shape %s', td.get_shape())
+
+            layer['Y'] = tstim * td
+            layer['Y'] = tf.where(tf.math.is_nan(layer['X']), np.nan, layer['Y'])
+
+            log.debug('Y shape %s', layer['Y'].get_shape())
+
+            # td = tf.cond(tf.math.reduce_any(ui != 0.0),
+            #              true_fn=lambda: stp_op((s[0], s[1], s[2]), ustim, a, ui),
+            #              false_fn=lambda: td)
+
     else:
         raise ValueError(f'Module "{fn}" not supported for mapping to cnn layer.')
 
@@ -453,6 +651,13 @@ def tf2modelspec(net, modelspec):
             m['phi']['g'] = net_layer_vals[i]['g'][0, :, :].T
             m['phi']['d'] = net_layer_vals[i]['d'][0, :, :].T
 
+        elif m['fn'] in ['nems.modules.stp.short_term_plasticity']:
+            if 'phi' in m.keys():
+                m['phi']['tau'] = net_layer_vals[i]['tau'].T
+                m['phi']['u'] = net_layer_vals[i]['u'].T
+                if 'x0' in m['phi']:
+                    m['phi']['x0'] = net_layer_vals[i]['x0'].T
+
         else:
             raise ValueError("NEMS module fn=%s not supported", m['fn'])
 
@@ -822,3 +1027,48 @@ def eval_tf(modelspec, est, log_dir):
     #E = np.nanstd(new_est['pred'].as_continuous()-new_est['pred_nems'].as_continuous())
 
     return new_est
+
+
+def eval_tf_layer(data: np.ndarray,
+                  layer_spec: str,
+                  state_data: np.ndarray = None,
+                  modelspec: modelspec.ModelSpec = None
+                  ) -> np.ndarray:
+    """Takes in a numpy array and applies a single tf layer to it.
+
+    :param data: The input data. Shape of (reps, time, channels).
+    :param layer_spec: A layer spec for a single layer of a modelspec.
+    :param state_data: State gain data, optional. Same shape as data.
+    :param modelspec: Optionally use an existing modelspec. Takes precedence over layer_spec.
+
+    :return: The processed data.
+    """
+    if layer_spec is None and modelspec is None:
+        raise ValueError('Either of "layer_spec" or "modelspec" must be specified.')
+
+    if modelspec is not None:
+        ms = modelspec
+    else:
+        ms = nems.initializers.from_keywords(layer_spec)
+
+    rep_dims, tps_per_stim, feat_dims = data.shape
+    if state_data is not None:
+        state_dims = state_data.shape[-1]
+    else:
+        state_dims = 0
+
+    tf_layers = ms.modelspec2tf(
+        tps_per_stim=tps_per_stim,
+        feat_dims=feat_dims,
+        state_dims=state_dims)
+    net = cnn.Net(
+        data_dims=(rep_dims, tps_per_stim, tf_layers[-1]['n_kern']),
+        n_feats=feat_dims,
+        sr_Hz=100,  # arbitrary, unused
+        layers=tf_layers
+    )
+
+    feed_dict = net.feed_dict(stim=data, state=state_data, indices=np.arange(rep_dims))
+    with net.sess.as_default():
+        pred = net.layers[0]['Y'].eval(feed_dict)
+
+    return pred
\ No newline at end of file
diff --git a/nems/tf/cnnlink_new.py b/nems/tf/cnnlink_new.py
new file mode 100644
index 00000000..7bd5883d
--- /dev/null
+++ b/nems/tf/cnnlink_new.py
@@ -0,0 +1,283 @@
+"""Helpers to build a Tensorflow Keras model from a modelspec."""
+
+import logging
+import os
+import typing
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+from matplotlib import pyplot as plt
+
+from nems import recording, get_setting, initializers, modelspec
+from nems.tf import callbacks, loss_functions, modelbuilder
+import nems.utils
+
+log = logging.getLogger(__name__)
+
+os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
+# tf.keras.backend.set_floatx('float64')
+
+
+def tf2modelspec(model, modelspec):
+    """Populates the modelspec phis with the model layers.
+
+    Does some checking on dims and names as well.
+    """
+    log.info('Populating modelspec with model weights.')
+
+    # first layer is the input shape, so skip it
+    for idx, layer in enumerate(model.layers[1:]):
+        ms = modelspec[idx]
+
+        if layer.name != ms['fn']:
+            raise AssertionError('Model layers and modelspec layers do not match up!')
+
+        phis = layer.weights_to_phi()
+        # check that phis/weights match up in both names and shapes
+        if phis.keys() != ms['phi'].keys():
+            raise AssertionError(f'Model layer "{layer.name}" weights and modelspec phis do not have matching names!')
+        for model_weights, ms_phis in zip(phis.values(), ms['phi'].values()):
+            if model_weights.shape != ms_phis.shape:
+                if layer.name == 'nems.modules.nonlinearity.double_exponential':
+                    continue  # dexp has weird weight shapes due to basic init
+                raise AssertionError(f'Model layer "{layer.name}" weights and modelspec phis do not have matching '
+                                     f'shapes!')
+
+        ms['phi'] = phis
+
+    return modelspec
+
+
+def eval_tf(model: tf.keras.Model, stim: np.ndarray) -> np.ndarray:
+    """Evaluate the model on the stim."""
+    return model.predict(stim)
+
+
+def compare_ms_tf(ms, model, rec_ms, stim_tf):
+    """Compares the evaluation between a modelspec and a keras model.
+
+    For the modelspec, uses a recording object. For the tf model, uses the formatted data from fit_tf."""
+    pred_ms = ms.evaluate(rec_ms)['pred']._data.flatten()
+    pred_tf = model.predict(stim_tf).flatten()
+
+    return np.nanstd(pred_ms - pred_tf)
+
+
+def fit_tf(
+        modelspec,
+        est: recording.Recording,
+        use_modelspec_init: bool = True,
+        optimizer: str = 'adam',
+        max_iter: int = 10000,
+        cost_function: str = 'squared_error',
+        early_stopping_steps: int = 5,
+        early_stopping_tolerance: float = 5e-4,
+        learning_rate: float = 1e-4,
+        batch_size: typing.Union[None, int] = None,
+        seed: int = 0,
+        filepath: typing.Union[str, Path] = None,
+        freeze_layers: typing.Union[None, list] = None,
+        **context
+        ) -> dict:
+    """Fits a modelspec to data by building and training a tf.keras model.
+
+    :param modelspec: The modelspec to fit.
+    :param est: The estimation recording; 'stim' and 'resp' are extracted from it for training.
+    :param use_modelspec_init: Whether to initialize the layers from the modelspec phis (vs. tf builtin inits).
+    :param optimizer: Name of the tf.keras optimizer.
+    :param max_iter: Maximum number of training epochs.
+    :param cost_function: Name of the loss function (see nems.tf.loss_functions.get_loss_fn).
+    :param early_stopping_steps: Patience for the early stopping callback (applied as 30 * steps epochs).
+    :param early_stopping_tolerance: Minimum delta for the early stopping callback.
+    :param learning_rate: Learning rate for the optimizer.
+    :param batch_size: Training batch size; None uses the keras default, 0 uses the full set of stims.
+    :param seed: Random seed, offset by the modelspec fit index.
+    :param filepath: Where to save model checkpoints; defaults to modelspec.meta['modelpath'].
+    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes, so are offset
+        from model layer indexes.
+ :param context: + + :return: + """ + log.info('Building tensorflow keras model from modelspec.') + nems.utils.progress_fun() + + # figure out where to save model checkpoints + if filepath is None: + filepath = modelspec.meta['modelpath'] + + # if job is running on slurm, need to change model checkpoint dir + job_id = os.environ.get('SLURM_JOBID', None) + if job_id is not None: + # keep a record of the job id + modelspec.meta['slurm_jobid'] = job_id + + log_dir_root = Path('/mnt/scratch') + assert log_dir_root.exists() + log_dir_sub = Path('SLURM_JOBID' + job_id) / str(modelspec.meta['batch'])\ + / modelspec.meta.get('cellid', "NOCELL")\ + / modelspec.meta['modelname'] + filepath = log_dir_root / log_dir_sub + + filepath = Path(filepath) + if not filepath.exists(): + filepath.mkdir(exist_ok=True, parents=True) + + filepath = filepath / 'checkpoints' + + # update seed based on fit index + seed += modelspec.fit_index + + # need to get duration of stims in order to reshape data + epoch_name = 'REFERENCE' # TODO: this should not be hardcoded + # also grab the fs + fs = est['stim'].fs + + # extract out the raw data, and reshape to (batch, time, channel) + stim_train = np.transpose(est['stim'].extract_epoch(epoch=epoch_name, mask=est['mask']), [0, 2, 1]) + resp_train = np.transpose(est['resp'].extract_epoch(epoch=epoch_name, mask=est['mask']), [0, 2, 1]) + log.info(f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.') + + # correlation for monitoring + # TODO: tf.utils? + def pearson(y_true, y_pred): + return tfp.stats.correlation(y_true, y_pred, event_axis=None, sample_axis=None) + + # get the layers and build the model + cost_fn = loss_functions.get_loss_fn(cost_function) + model_layers = modelspec.modelspec2tf2(use_modelspec_init=use_modelspec_init, seed=seed, fs=fs) + model = modelbuilder.ModelBuilder( + name='Test-model', + layers=model_layers, + learning_rate=learning_rate, + loss_fn=cost_fn, + optimizer=optimizer, + metrics=[pearson], + ).build_model(input_shape=stim_train.shape) + + # create the callbacks + early_stopping = callbacks.DelayedStopper(monitor='loss', + patience=30 * early_stopping_steps, + min_delta=early_stopping_tolerance, + verbose=1, + restore_best_weights=False) + checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=str(filepath), + save_best_only=False, + save_weights_only=True, + save_freq=100 * stim_train.shape[0], + monitor='loss', + verbose=1) + nan_terminate = tf.keras.callbacks.TerminateOnNaN() + + # freeze layers + if freeze_layers is not None: + for freeze_index in freeze_layers: + model.layers[freeze_index + 1].trainable = False + log.info(f'Freezing layer "{model.layers[freeze_index + 1].name}".') + + log.info('Fitting model...') + history = model.fit( + stim_train, + resp_train, + # validation_split=0.2, + verbose=2, + epochs=max_iter, + batch_size=stim_train.shape[0] if batch_size == 0 else batch_size, + callbacks=[ + nan_terminate, + early_stopping, + checkpoint, + # TODO: tensorboard + ] + ) + + # did we terminate on a nan? 
Load checkpoint if so + if np.all(np.isnan(model.predict(stim_train))): + log.warning('Model terminated on nan, restoring saved weights.') + try: + # this can fail if it nans out before a single checkpoint gets saved + model.load_weights(str(filepath)) + except tf.errors.NotFoundError: + pass + + modelspec = tf2modelspec(model, modelspec) + # compare the predictions from the model and modelspec + error = compare_ms_tf(modelspec, model, est, stim_train) + log.info(f'Mean difference between NEMS and TF model prediction: {error}') + + # add in some relevant meta information + modelspec.meta['n_parms'] = len(modelspec.phi_vector) + try: + n_epochs = len(history.history['loss']) + except KeyError: + n_epochs = 0 + try: + max_iter = modelspec.meta['extra_results'] + modelspec.meta['extra_results'] = max(max_iter, n_epochs) + except KeyError: + modelspec.meta['extra_results'] = n_epochs + + nems.utils.progress_fun() + + return {'modelspec': modelspec} + + +def fit_tf_init( + modelspec, + est: recording.Recording, + **kwargs + ) -> dict: + """TODO""" + layers_to_freeze = [ + 'levelshift', + 'relu', + 'stp', + 'state_dc_gain', + 'state_gain', + 'rdt_gain', + ] + + frozen_idxes = [] + for idx, ms in enumerate(modelspec): + for layer_to_freeze in layers_to_freeze: + if layer_to_freeze in ms['fn']: + frozen_idxes.append(idx) + break # break out of inner loop + + return fit_tf( + modelspec, + est, + freeze_layers=frozen_idxes, + **kwargs, + ) + + +def eval_tf_layer(data: np.ndarray, + layer_spec: typing.Union[None, str] = None, + stop: typing.Union[None, int] = None, + modelspec: modelspec.ModelSpec = None + ) -> np.ndarray: + """Takes in a numpy array and applies a single tf layer to it. + + :param data: The input data. Shape of (reps, time, channels). + :param layer_spec: A layer spec for layers of a modelspec. + :param stop: What layer to eval to. Non inclusive. If not passed, will evaluate the whole layer spec. + :param modelspec: Optionally use an existing modelspec. Takes precedence over layer_spec. + + :return: The processed data. + """ + if layer_spec is None and modelspec is None: + raise ValueError('Either of "layer_spec" or "modelspec" must be specified.') + + if modelspec is not None: + ms = modelspec + else: + ms = initializers.from_keywords(layer_spec) + + layers = ms.modelspec2tf2(use_modelspec_init=True)[:stop] + model = modelbuilder.ModelBuilder(layers=layers).build_model(input_shape=data.shape) + + pred = model.predict(data) + return pred diff --git a/nems/tf/layers.py b/nems/tf/layers.py new file mode 100644 index 00000000..fbd508f6 --- /dev/null +++ b/nems/tf/layers.py @@ -0,0 +1,565 @@ +import logging + +import tensorflow as tf + +log = logging.getLogger(__name__) + + +class BaseLayer(tf.keras.layers.Layer): + """Base layer with parent methods for converting from modelspec layer to tf layers and back.""" + def __init__(self, *args, **kwargs): + """Catch args/kwargs that aren't allowed kwargs of keras.layers.Layers""" + self.fs = kwargs.pop('fs', None) + + super(BaseLayer, self).__init__(*args, **kwargs) + + @classmethod + def from_ms_layer(cls, + ms_layer, + use_modelspec_init: bool = True, + fs: int = 100 + ): + """Parses modelspec layer to generate layer class. + + :param ms_layer: A layer of a modelspec. + :param use_modelspec_init: Whether to use the modelspec's initialization or use a tf builtin. + :param fs: The sampling rate of the data. 
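+
+        :return: An instance of this layer class, built from the modelspec layer.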
+ """ + log.info(f'Building tf layer for "{ms_layer["fn"]}".') + + kwargs = { + 'name': ms_layer['fn'], + 'fs': fs, + } + + if use_modelspec_init: + # convert the phis to tf constants + kwargs['initializer'] = {k: tf.constant_initializer(v) + for k, v in ms_layer['phi'].items()} + + # TODO: clean this up, maybe separate kwargs/fn_kwargs, or method to split out valid tf kwargs from rest + if 'bounds' in ms_layer: + kwargs['bounds'] = ms_layer['bounds'] + if 'chans' in ms_layer['fn_kwargs']: + kwargs['units'] = ms_layer['fn_kwargs']['chans'] + if 'crosstalk' in ms_layer['fn_kwargs']: + kwargs['crosstalk'] = ms_layer['fn_kwargs']['crosstalk'] + if 'reset_signal' in ms_layer['fn_kwargs']: + # kwargs['reset_signal'] = ms_layer['fn_kwargs']['reset_signal'] + kwargs['reset_signal'] = None + + return cls(**kwargs) + + @property + def layer_values(self): + """Returns key value pairs of the weight names and their values.""" + weight_names = [ + # extract the actual layer name from the tf layer naming format + # ex: "nems.modules.nonlinearity.double_exponential/base:0" + weight.name.split('/')[1].split(':')[0] + for weight in self.weights + ] + + layer_values = {layer_name: weight.numpy() for layer_name, weight in zip(weight_names, self.weights)} + return layer_values + + def weights_to_phi(self): + """In subclass, use self.weight_dict to get a dict of weight_name: weights.""" + raise NotImplementedError + + +class Dlog(BaseLayer): + """Simple dlog nonlinearity.""" + def __init__(self, + units=None, + name='dlog', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(Dlog, self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = initializer['offset'].value.shape[1] + else: + self.units = units + + self.initializer = {'offset': tf.random_normal_initializer(mean=0, stddev=0.5, seed=seed)} + if initializer is not None: + self.initializer.update(initializer) + + def build(self, input_shape): + self.offset = self.add_weight(name='offset', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['offset'], + trainable=True, + ) + + def call(self, inputs, training=True): + # clip bounds at ±2 to avoid huge compression/expansion + eb = tf.math.pow(tf.constant(10, dtype='float32'), tf.clip_by_value(self.offset, -2, 2)) + return tf.math.log((inputs + eb) / eb) + + def weights_to_phi(self): + layer_values = self.layer_values + layer_values['offset'] = layer_values['offset'].reshape((-1, 1)) + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class Levelshift(BaseLayer): + """Simple levelshift nonlinearity.""" + def __init__(self, + units=None, + name='levelshift', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(Levelshift, self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = initializer['level'].value.shape[1] + else: + self.units = units + + self.initializer = {'level': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)} + if initializer is not None: + self.initializer.update(initializer) + + def build(self, input_shape): + self.level = self.add_weight(name='level', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['level'], + trainable=True, + ) + + def call(self, inputs, training=True): + return tf.identity(inputs + self.level) 
+ + def weights_to_phi(self): + layer_values = self.layer_values + layer_values['level'] = layer_values['level'].reshape((-1, 1)) + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class Relu(BaseLayer): + """Simple relu nonlinearity.""" + def __init__(self, + name='relu', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(Relu, self).__init__(name=name, *args, **kwargs) + + self.initializer = {'offset': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)} + if initializer is not None: + self.initializer.update(initializer) + + def build(self, input_shape): + self.offset = self.add_weight(name='offset', + shape=(input_shape[-1],), + dtype='float32', + initializer=self.initializer['offset'], + trainable=True, + ) + + def call(self, inputs, training=True): + return tf.nn.relu(inputs + self.offset) + + def weights_to_phi(self): + layer_values = self.layer_values + layer_values['offset'] = layer_values['offset'].reshape((-1, 1)) + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class DoubleExponential(BaseLayer): + """Basic double exponential nonlinearity.""" + def __init__(self, + units=None, + name='dexp', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(DoubleExponential, self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = initializer['base'].value.shape[1] + else: + self.units = units + + self.initializer = { + 'base': tf.random_normal_initializer(mean=0, stddev=1, seed=seed), + 'amplitude': tf.random_normal_initializer(mean=1, stddev=0.5, seed=seed + 1), + 'shift': tf.random_normal_initializer(mean=0, stddev=1, seed=seed + 2), + 'kappa': tf.random_normal_initializer(mean=1, stddev=0.5, seed=seed + 3), + } + + if initializer is not None: + self.initializer.update(initializer) + + def build(self, input_shape): + self.base = self.add_weight(name='base', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['base'], + trainable=True, + ) + self.amplitude = self.add_weight(name='amplitude', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['amplitude'], + trainable=True, + ) + self.shift = self.add_weight(name='shift', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['shift'], + trainable=True, + ) + self.kappa = self.add_weight(name='kappa', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['kappa'], + trainable=True, + ) + + def call(self, inputs, training=True): + # formula: base + amp * e^(-e^(-e^kappa * (inputs - shift))) + return self.base + self.amplitude * tf.math.exp(-tf.math.exp(-tf.math.exp(self.kappa) * (inputs - self.shift))) + + def weights_to_phi(self): + layer_values = self.layer_values + + # if self.units == 1: + # shape = (1,) + # else: + # shape = (-1, 1) + shape = (-1, 1) + + layer_values['amplitude'] = layer_values['amplitude'].reshape(shape) + layer_values['base'] = layer_values['base'].reshape(shape) + layer_values['kappa'] = layer_values['kappa'].reshape(shape) + layer_values['shift'] = layer_values['shift'].reshape(shape) + + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class WeightChannelsBasic(BaseLayer): + """Basic weight channels.""" + def __init__(self, + # kind='basic', + units=None, + name='weight_channels_basic', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(WeightChannelsBasic, 
self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = initializer['coefficients'].value.shape[1] + else: + self.units = units + + self.initializer = {'coefficients': tf.random_normal_initializer(mean=0.01, stddev=0.2, seed=seed)} + if initializer is not None: + self.initializer.update(initializer) + + def build(self, input_shape): + self.coefficients = self.add_weight(name='coefficients', + shape=(self.units, input_shape[-1]), + dtype='float32', + initializer=self.initializer['coefficients'], + trainable=True, + ) + + def call(self, inputs, training=True): + tranposed = tf.transpose(self.coefficients) + return tf.nn.conv1d(inputs, tf.expand_dims(tranposed, 0), stride=1, padding='SAME') + + def weights_to_phi(self): + layer_values = self.layer_values + layer_values['coefficients'] = layer_values['coefficients'] + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class WeightChannelsGaussian(BaseLayer): + """Basic weight channels.""" + # TODO: convert per https://stackoverflow.com/a/52012658/1510542 in order to handle banks + def __init__(self, + # kind='gaussian', + bounds, + units=None, + name='weight_channels_gaussian', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(WeightChannelsGaussian, self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = initializer['mean'].value.shape[0] + else: + self.units = units + + self.initializer = { + 'mean': tf.random_normal_initializer(mean=0.5, stddev=0.4, seed=seed), + 'sd': tf.random_normal_initializer(mean=0, stddev=.4, seed=seed + 1), # this is halfnorm in NEMS + } + + if initializer is not None: + self.initializer.update(initializer) + + self.mean_constraint = tf.keras.constraints.MinMaxNorm( + min_value=bounds['mean'][0][0], + max_value=bounds['mean'][1][0]) + self.sd_constraint = tf.keras.constraints.MinMaxNorm( + min_value=bounds['sd'][0][0], + max_value=bounds['sd'][1][0]) + + def build(self, input_shape): + self.mean = self.add_weight(name='mean', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['mean'], + constraint=self.mean_constraint, + trainable=True, + ) + self.sd = self.add_weight(name='sd', + shape=(self.units,), + dtype='float32', + initializer=self.initializer['sd'], + constraint=self.sd_constraint, + trainable=True, + ) + + def call(self, inputs, training=True): + input_features = tf.cast(tf.shape(inputs)[-1], dtype='float32') + temp = tf.range(input_features) / input_features + temp = (tf.reshape(temp, [1, input_features, 1]) - self.mean) / self.sd + temp = tf.math.exp(-0.5 * tf.math.square(temp)) + kernel = temp / tf.math.reduce_sum(temp, axis=1) + + return tf.nn.conv1d(inputs, kernel, stride=1, padding='SAME') + + def weights_to_phi(self): + layer_values = self.layer_values + # don't need to do any reshaping + log.info(f'Converted {self.name} to modelspec phis.') + return layer_values + + +class FIR(BaseLayer): + """Basic weight channels.""" + def __init__(self, + units=None, + name='fir', + initializer=None, + seed=0, + *args, + **kwargs, + ): + super(FIR, self).__init__(name=name, *args, **kwargs) + + # try to infer the number of units if not specified + if units is None and initializer is None: + self.units = 1 + elif units is None: + self.units = 
+
+
+class FIR(BaseLayer):
+    """Finite impulse response (FIR) filter."""
+    def __init__(self,
+                 units=None,
+                 name='fir',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(FIR, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['coefficients'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {'coefficients': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.coefficients = self.add_weight(name='coefficients',
+                                            shape=(input_shape[-1], self.units),
+                                            dtype='float32',
+                                            initializer=self.initializer['coefficients'],
+                                            trainable=True,
+                                            )
+
+    def call(self, inputs, training=True):
+        # need to pad the front of the inputs so the convolution is causal
+        pad_size = self.units - 1
+        padded_input = tf.pad(inputs, [[0, 0], [pad_size, 0], [0, 0]])
+        transposed = tf.transpose(tf.reverse(self.coefficients, axis=[-1]))
+        return tf.nn.conv1d(padded_input, tf.expand_dims(transposed, -1), stride=1, padding='VALID')
+
+    def weights_to_phi(self):
+        layer_values = self.layer_values
+        # don't need to do any reshaping
+        log.info(f'Converted {self.name} to modelspec phis.')
+        return layer_values
+
+
+class STPQuick(BaseLayer):
+    """Quick version of STP."""
+    def __init__(self,
+                 bounds,
+                 crosstalk=False,
+                 reset_signal=None,
+                 units=None,
+                 name='stp',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(STPQuick, self).__init__(name=name, *args, **kwargs)
+
+        self.crosstalk = crosstalk
+        self.reset_signal = reset_signal
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['u'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {
+            'u': tf.random_normal_initializer(mean=0.01, stddev=0.05, seed=seed),
+            'tau': tf.random_normal_initializer(mean=0.05, stddev=0.01, seed=seed + 1),
+            'x0': None
+        }
+
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+        self.u_constraint = lambda u: tf.abs(u)
+        # max_tau = tf.maximum(tf.abs(self.initializer['tau'](self.units)), 2.001 / self.fs)
+        # self.tau_constraint = tf.keras.constraints.MaxNorm(max_value=2.001 / self.fs)
+        # self.tau_constraint = tf.keras.constraints.NonNeg()
+        self.tau_constraint = lambda tau: tf.maximum(tf.abs(tau), 2.001 / self.fs)
+
+    def build(self, input_shape):
+        self.u = self.add_weight(name='u',
+                                 shape=(self.units,),
+                                 dtype='float32',
+                                 initializer=self.initializer['u'],
+                                 constraint=self.u_constraint,
+                                 trainable=True,
+                                 )
+        self.tau = self.add_weight(name='tau',
+                                   shape=(self.units,),
+                                   dtype='float32',
+                                   initializer=self.initializer['tau'],
+                                   constraint=self.tau_constraint,
+                                   trainable=True,
+                                   )
+        if self.initializer['x0'] is None:
+            self.x0 = None
+        else:
+            self.x0 = self.add_weight(name='x0',
+                                      shape=(self.units,),
+                                      dtype='float32',
+                                      initializer=self.initializer['x0'],
+                                      trainable=True,
+                                      )
+
+    def call(self, inputs, training=True):
+        _zero = tf.constant(0.0, dtype='float32')
+        _nan = tf.constant(0.0, dtype='float32')  # NaN bins in the input are refilled with zeros on output
+
+        s = inputs.shape
+        tstim = tf.where(tf.math.is_nan(inputs), _zero, inputs)
+
+        if self.x0 is not None:  # x0 should be tf variable to avoid retraces
+            # TODO: is this expanding along the right dim? tstim dims: (None, time, chans)
+            tstim = tstim - tf.expand_dims(self.x0, axis=1)
+
+        # convert u & tau units from sec to bins
+        ui = tf.math.abs(tf.reshape(self.u, (1, -1))) / self.fs
+        taui = tf.math.abs(tf.reshape(self.tau, (1, -1))) * self.fs
+
+        # convert chunksize from sec to bins
+        chunksize = 5
+        chunksize = int(chunksize * self.fs)
+
+        if self.crosstalk:
+            # assumes dim of u is 1 !
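+            # crosstalk is meant to drive all synapses with one shared
+            # depression state; note that with (None, time, chans) inputs
+            # (see the TODO above) axis=0 below is the batch/rep axis, so
+            # double-check this is averaging over the intended dimension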
+            tstim = tf.math.reduce_mean(tstim, axis=0, keepdims=True)
+
+        ui = tf.expand_dims(ui, axis=0)
+        taui = tf.expand_dims(taui, axis=0)
+
+        @tf.function
+        def _cumtrapz(x, dx=1., initial=0.):
+            # trapezoidal cumulative integral along the time axis
+            x = (x[:, :-1] + x[:, 1:]) / 2.0
+            x = tf.pad(x, ((0, 0), (1, 0), (0, 0)), constant_values=initial)
+            return tf.cumsum(x, axis=1) * dx
+
+        a = tf.cast(1.0 / taui, 'float64')
+        x = ui * tstim
+
+        if self.reset_signal is None:
+            reset_times = tf.range(0, s[1] + chunksize - 1, chunksize)
+        else:
+            reset_times = tf.where(self.reset_signal[0, :])[:, 0]
+            reset_times = tf.pad(reset_times, ((0, 1),), constant_values=s[1])
+
+        td = []
+        x0, imu0 = 0.0, 0.0
+        for j in range(reset_times.shape[0] - 1):
+            xi = tf.cast(x[:, reset_times[j]:reset_times[j + 1], :], 'float64')
+            ix = _cumtrapz(a + xi, dx=1, initial=0) + a + (x0 + xi[:, :1]) / 2.0
+
+            mu = tf.exp(ix)
+            imu = _cumtrapz(mu * xi, dx=1, initial=0) + (x0 + mu[:, :1] * xi[:, :1]) / 2.0 + imu0
+
+            valid = tf.logical_and(mu > 0.0, imu > 0.0)
+            mu = tf.where(valid, mu, 1.0)
+            imu = tf.where(valid, imu, 1.0)
+            _td = 1 - tf.exp(tf.math.log(imu) - tf.math.log(mu))
+            _td = tf.where(valid, _td, 1.0)
+
+            x0 = xi[:, -1:]
+            imu0 = imu[:, -1:] / mu[:, -1:]
+            td.append(tf.cast(_td, 'float32'))
+        td = tf.concat(td, axis=1)
+
+        ret = tstim * tf.pad(td[:, :-1, :], ((0, 0), (1, 0), (0, 0)), constant_values=1.0)
+        ret = tf.where(tf.math.is_nan(inputs), _nan, ret)
+
+        return ret
+
+    def weights_to_phi(self):
+        layer_values = self.layer_values
+        # don't need to do any reshaping
+        log.info(f'Converted {self.name} to modelspec phis.')
+        return layer_values
+
diff --git a/nems/tf/loss_functions.py b/nems/tf/loss_functions.py
index a0f7fba2..ac0906bf 100644
--- a/nems/tf/loss_functions.py
+++ b/nems/tf/loss_functions.py
@@ -8,24 +8,37 @@
 log = logging.getLogger(__name__)
 
 
+def get_loss_fn(loss_fn_str: str):
+    """Utility to look up a loss function by its string name."""
+    mapping = {
+        'squared_error': loss_se,
+        # 'squared_error': loss_se_safe,
+        'poisson': poisson,
+        'nmse': loss_tf_nmse,
+        'nmse_shrinkage': loss_tf_nmse_shrinkage,
+    }
+
+    try:
+        return mapping[loss_fn_str]
+    except KeyError:
+        raise KeyError(f'Loss must be "squared_error", "poisson", "nmse",'
+                       f' or "nmse_shrinkage", not {loss_fn_str}')
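+
+# Example usage (illustrative): resolve a loss by its modelspec name and
+# hand it to the keras compile step, e.g.
+#
+#     loss_fn = get_loss_fn('nmse')
+#     model.compile(optimizer, loss=loss_fn)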
+
+
 def poisson(response, prediction):
     """Poisson loss."""
-    return tf.reduce_mean(prediction - response * tf.log(prediction + 1e-5), name='poisson')
+    return tf.math.reduce_mean(prediction - response * tf.math.log(prediction + 1e-5), name='poisson')
 
 
 def loss_se(response, prediction):
     """Squared error loss."""
-    # TODO : add percell option
+    return tf.math.reduce_mean(tf.math.square(response - prediction)) / tf.math.reduce_mean(tf.math.square(response))
 
-    # SVD -- was trying to reduce nan outputs, but doesn't seem to make a different. more important to avoid
-    # having the fitter get there. Possibly useful in the future though?
-    #x = tf.square(response - prediction)
-    #y = tf.square(response)
-    #x = tf.reduce_mean(tf.boolean_mask(x, tf.is_finite(x))) + tf.reduce_mean(tf.dtypes.cast(tf.is_inf(x), tf.float32))
-    #y = tf.reduce_mean(tf.boolean_mask(y, tf.is_finite(y))) + tf.reduce_mean(tf.dtypes.cast(tf.is_inf(y), tf.float32))
-    #return x/y
-    return tf.reduce_mean(tf.square(response - prediction)) / tf.reduce_mean(tf.square(response))
+
+def loss_se_safe(response, prediction):
+    """Squared error loss, with checking for NaNs."""
+    tf.debugging.check_numerics(prediction, message='Nan values in prediction.')
+    return tf.math.reduce_mean(tf.math.square(response - prediction)) / tf.math.reduce_mean(tf.math.square(response))
 
 
 def loss_tf_nmse_shrinkage(response, prediction):
@@ -37,7 +50,7 @@ def loss_tf_nmse(response, prediction, per_cell=True):
     """Normalized means squared error loss."""
     mE, sE = tf_nmse(response, prediction, per_cell=per_cell)
     if per_cell:
-        return tf.reduce_mean(mE)
+        return tf.math.reduce_mean(mE)
     else:
         return mE
diff --git a/nems/tf/modelbuilder.py b/nems/tf/modelbuilder.py
new file mode 100644
index 00000000..6560241c
--- /dev/null
+++ b/nems/tf/modelbuilder.py
@@ -0,0 +1,83 @@
+import logging
+
+import tensorflow as tf
+
+from nems.tf import loss_functions
+
+log = logging.getLogger(__name__)
+
+
+class ModelBuilder:
+    """Basic wrapper to build tensorflow keras models.
+
+    Not an actual subclass of tf.keras.Model, but instead a set of helper
+    methods to prepare layers and then return a model when ready.
+    """
+    def __init__(self,
+                 name='model',
+                 layers=None,
+                 learning_rate=1e-3,
+                 optimizer='adam',
+                 loss_fn=loss_functions.loss_se,
+                 metrics=None,
+                 *args,
+                 **kwargs
+                 ):
+        """Layers don't need to be added at init, only prior to fitting.
+
+        :param name: Name of the model.
+        :param layers: Optional list of layers to start with.
+        :param learning_rate: Learning rate passed to the optimizer.
+        :param optimizer: Either a string key ('adam' or 'sgd') or an optimizer instance.
+        :param loss_fn: Loss function passed to model.compile.
+        :param metrics: Optional metrics passed to model.compile.
+        """
+        self.optimizer_dict = {
+            'adam': tf.keras.optimizers.Adam,
+            'sgd': tf.keras.optimizers.SGD,
+        }
+
+        if layers is not None:
+            self.model_layers = layers
+        else:
+            self.model_layers = []
+
+        self.name = name
+        self.loss_fn = loss_fn
+        self.metrics = metrics
+
+        self.learning_rate = learning_rate
+        self.optimizer = optimizer
+
+    def add_layer(self, layer, idx=None):
+        """Adds a layer to the list of layers that will be part of the model.
+
+        :param layer: Layer to add.
+        :param idx: Insertion index; appends to the end if not given.
+        """
+        if idx is None:
+            self.model_layers.append(layer)
+            log.debug(f'Added "{layer.name}" to end of model.')
+        else:
+            self.model_layers.insert(idx, layer)
+            log.debug(f'Inserted "{layer.name}" at idx "{idx}".')
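+
+    # Typical usage (illustrative; names and shapes are assumptions):
+    #
+    #     layers = modelspec.modelspec2tf2(use_modelspec_init=True, fs=100)
+    #     builder = ModelBuilder(layers=layers, loss_fn=loss_functions.get_loss_fn('squared_error'))
+    #     model = builder.build_model(input_shape=(n_stim, n_time, n_chans))
+    #     model.fit(stim, resp)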
+
+    def build_model(self, input_shape):
+        """Creates a tf.keras.Model instance. Simple single-headed model.
+
+        :param input_shape: Shape of the input as a tuple of (batch, time, channel); needed to
+            build the graph and print a summary.
+        :return: tf.keras.Model instance.
+        """
+        inputs = tf.keras.Input(shape=input_shape[1:], name='input_node', batch_size=input_shape[0], dtype='float32')
+        outputs = inputs
+        for layer in self.model_layers:
+            outputs = layer(outputs)
+
+        model = tf.keras.Model(inputs=inputs, outputs=outputs, name=self.name)
+        log.info('Built model, printing summary.')
+        model.summary(print_fn=log.info)
+
+        # build the optimizer, falling back to a user-supplied optimizer instance
+        try:
+            optimizer = self.optimizer_dict[self.optimizer](learning_rate=self.learning_rate)
+        except KeyError:
+            optimizer = self.optimizer
+
+        model.compile(optimizer, loss=self.loss_fn, metrics=self.metrics)
+        return model
diff --git a/nems/xform_helper.py b/nems/xform_helper.py
index 0cd55b8a..334f4faf 100644
--- a/nems/xform_helper.py
+++ b/nems/xform_helper.py
@@ -193,19 +193,28 @@ def fit_model_xform(cellid, batch, modelname, autoPlot=True, saveInDB=False,
     else:
         cell_name = cellid
 
-    prefix = get_setting('NEMS_RESULTS_DIR')
-    destination = os.path.join(prefix, str(batch), cell_name, modelspec.get_longname())
+    if 'modelpath' not in modelspec.meta:
+        prefix = get_setting('NEMS_RESULTS_DIR')
+        destination = os.path.join(prefix, str(batch), cell_name, modelspec.get_longname())
+
+        log.info(f'Setting modelpath to "{destination}"')
+        modelspec.meta['modelpath'] = destination
+        modelspec.meta['figurefile'] = os.path.join(destination, 'figure.0000.png')
+    else:
+        destination = modelspec.meta['modelpath']
 
     # figure out URI for location to save results (either file or http, depending on USE_NEMS_BAPHY_API)
     if get_setting('USE_NEMS_BAPHY_API'):
-        prefix = 'http://'+get_setting('NEMS_BAPHY_API_HOST')+":"+str(get_setting('NEMS_BAPHY_API_PORT')) + '/results/'
-        save_destination = os.path.join(prefix, str(batch), cell_name, modelspec.get_longname())
+        prefix = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(get_setting('NEMS_BAPHY_API_PORT')) + \
+                 '/results'
+        save_loc = str(batch) + '/' + cell_name + '/' + modelspec.get_longname()
+        save_destination = prefix + '/' + save_loc
+        # set the modelspec meta save locations to be the filesystem and not baphy
+        modelspec.meta['modelpath'] = get_setting('NEMS_RESULTS_DIR') + '/' + save_loc
+        modelspec.meta['figurefile'] = modelspec.meta['modelpath'] + '/' + 'figure.0000.png'
     else:
         save_destination = destination
 
-    log.info(f'Setting modelpath to "{destination}"')
-    modelspec.meta['modelpath'] = destination
-    modelspec.meta['figurefile'] = os.path.join(destination, 'figure.0000.png')
     modelspec.meta['runtime'] = int(time.time() - startime)
     modelspec.meta.update(meta)
diff --git a/nems/xforms.py b/nems/xforms.py
index 8deebc8f..37ce3db2 100644
--- a/nems/xforms.py
+++ b/nems/xforms.py
@@ -1234,7 +1234,7 @@ def tree_path(recording, modelspecs, xfspec):
 
 def save_analysis(destination, recording, modelspec, xfspec, figures,
-                  log, add_tree_path=False, update_meta=True):
+                  log, add_tree_path=False, update_meta=True, save_rec=False):
     '''Save an analysis file collection to a particular destination.'''
     if add_tree_path:
         treepath = tree_path(recording, [modelspec], xfspec)
@@ -1262,6 +1262,12 @@ def save_analysis(destination, recording, modelspec, xfspec, figures,
         save_resource(base_uri + 'figure.{:04d}.png'.format(number), data=figure)
     save_resource(base_uri + 'log.txt', data=log)
     save_resource(xfspec_uri, json=xfspec)
+
+    if save_rec:
+        # TODO: copy actual recording file
+        rec_uri = base_uri + 'recording.tgz'
+        recording.save(rec_uri)
+
     return {'savepath': base_uri}
 
 
@@ -1279,7 +1285,7 @@ def load_analysis(filepath, eval_model=True, only=None):
 
 def _path_join(*args):
     if os.name == 'nt':
# deal with problems on Windows OS - path = "/".join(*args) + path = "/".join(args) else: path = os.path.join(*args) return path @@ -1302,6 +1308,10 @@ def _path_join(*args): ctx['figures_to_load'] = figures_to_load ctx['log'] = logstring + if 'recording.tgz' in os.listdir(filepath): + # TODO: change rec path in modelspec + pass + if eval_model: ctx, log_xf = evaluate(xfspec, ctx) elif only is not None: diff --git a/setup.py b/setup.py index bcaf1ad8..a1932344 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ EXTRAS_REQUIRES = { 'docs': ['sphinx', 'sphinx_rtd_theme', 'pygments-enaml'], 'nwb': ['allensdk'], - 'tensorflow': ['tensorflow'] + 'tensorflow': ['tensorflow', 'tensorboard', 'tensorflow-probability'] } setup(