diff --git a/nems/db.py b/nems/db.py
index e207eadc..900118ce 100644
--- a/nems/db.py
+++ b/nems/db.py
@@ -601,6 +601,26 @@ def update_startdate(queueid=None):
     return r
 
 
+def update_job_pid(pid, queueid=None):
+    """
+    in tQueue, update the PID
+    """
+    if queueid is None:
+        if 'QUEUEID' in os.environ:
+            queueid = os.environ['QUEUEID']
+        else:
+            log.warning("queueid not specified or found in os.environ")
+            return 0
+
+    engine = Engine()
+    conn = engine.connect()
+    sql = ("UPDATE tQueue SET pid={} WHERE id={}"
+           .format(pid, queueid))
+    r = conn.execute(sql)
+    conn.close()
+    return r
+
+
 def update_job_tick(queueid=None):
     """
     update current machine's load in the cluster db and tick off a step
diff --git a/nems/modelspec.py b/nems/modelspec.py
index 07bc80cc..2b813fe6 100644
--- a/nems/modelspec.py
+++ b/nems/modelspec.py
@@ -541,6 +541,9 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
         """
         if rec is None:
             rec = self.recording
+        # rec = rec.copy()
+
+        # rec['resp'] = rec['resp'].rasterize()
 
         if fit_index is not None:
             self.fit_index = fit_index
@@ -595,18 +598,22 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
         if epoch is not None:
             try:
                 epoch_bounds = rec['resp'].get_epoch_bounds(epoch, mask=rec['mask'])
+            except ValueError:
+                # some signal types can't handle masks with epoch bounds
+                epoch_bounds = rec['resp'].get_epoch_bounds(epoch)
             except:
                 log.warning(f'Quickplot: no valid epochs matching {epoch}. Will not subset data.')
                 epoch = None
 
         if 'mask' in rec.signals.keys():
             rec = rec.apply_mask()
+        rec_resp = rec['resp']
         rec_pred = rec['pred']
         rec_stim = rec['stim']
 
-        if (epoch is None) or (len(epoch_bounds)==0):
-            epoch_bounds = np.array([[0, rec['resp'].shape[1]/rec['resp'].fs]])
+        if epoch is None or len(epoch_bounds) == 0:
+            epoch_bounds = np.array([[0, rec['resp'].shape[1] / rec['resp'].fs]])
 
         # figure out which occurrence
         # not_empty = [np.any(np.isfinite(x)) for x in extracted]  # occurrences containing non inf/nan data
@@ -675,7 +682,7 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
         ]
 
         for s in sig_names:
-            if rec[s].shape[0]>1:
+            if rec[s].shape[0] > 1:
                 plot_fn = _lookup_fn_at('nems.plots.api.spectrogram')
                 title = s + ' spectrogram'
             else:
@@ -987,6 +994,22 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0,
 
         return layers
 
+    def modelspec2tf2(self, seed=0, use_modelspec_init=True):
+        """New version
+
+        TODO
+        """
+        layers = []
+
+        for m in self:
+            layer = nems.tf.cnnlink_new.map_layer(layer=m,
+                                                  use_modelspec_init=use_modelspec_init,
+                                                  # seed=seed,
+                                                  )
+            layers.append(layer)
+
+        return layers
+
 
 def get_modelspec_metadata(modelspec):
     """Return a dict of the metadata for this modelspec.
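Note: a minimal sketch of how the new modelspec2tf2 method is intended to compose with the ModelBuilder added later in this patch (not part of the patch; the fitted modelspec variable and the shapes are illustrative):

    # sketch: build keras layers from an existing modelspec (illustrative)
    from nems.tf import modelbuilder

    model_layers = modelspec.modelspec2tf2(use_modelspec_init=True)  # one tf layer per module
    model = modelbuilder.ModelBuilder(
        name='example-model',
        layers=model_layers,
    ).build_model(input_shape=(10, 550, 18))  # (batch, time, channel); example dims
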
diff --git a/nems/modules/levelshift.py b/nems/modules/levelshift.py
index 635b5c7f..d114fdec 100644
--- a/nems/modules/levelshift.py
+++ b/nems/modules/levelshift.py
@@ -1,4 +1,4 @@
-def levelshift(rec, i, o, level):
+def levelshift(rec, i, o, level, **kwargs):
     '''
     Parameters
     ----------
diff --git a/nems/modules/nonlinearity.py b/nems/modules/nonlinearity.py
index ce1dd5a9..27a57b04 100644
--- a/nems/modules/nonlinearity.py
+++ b/nems/modules/nonlinearity.py
@@ -42,7 +42,7 @@ def _double_exponential(x, base, amplitude, shift, kappa):
     return base + amplitude * exp(-exp(np.array(-exp(kappa)) * (x - shift)))
 
 
-def double_exponential(rec, i, o, base, amplitude, shift, kappa):
+def double_exponential(rec, i, o, base, amplitude, shift, kappa, **kwargs):
     '''
     A double exponential applied to all channels of a single signal.
        rec        Recording object
@@ -84,7 +84,7 @@ def _dlog(x, offset):
     return np.log((x + d) / d)
 
 
-def dlog(rec, i, o, offset):
+def dlog(rec, i, o, offset, **kwargs):
     """
     Log compression with variable offset
     :param rec: recording object with signals i and o.
diff --git a/nems/plots/timeseries.py b/nems/plots/timeseries.py
index 3bba67b6..d209ee59 100644
--- a/nems/plots/timeseries.py
+++ b/nems/plots/timeseries.py
@@ -441,7 +441,7 @@ def pred_resp(rec=None, modelspec=None, ax=None, title=None,
     --------
     ax : axis containing plot
     '''
-    sig_list = ['pred', 'resp']
+    sig_list = ['resp', 'pred']
     sigs = [rec[s] for s in sig_list]
     ax = timeseries_from_signals(sigs, channels=channels,
                                  xlabel=xlabel, ylabel=ylabel, ax=ax,
diff --git a/nems/plugins/default_fitters.py b/nems/plugins/default_fitters.py
index 927d0111..1ee39f06 100644
--- a/nems/plugins/default_fitters.py
+++ b/nems/plugins/default_fitters.py
@@ -142,6 +142,60 @@ def nrc(fitkey):
 
     return xfspec
 
 
+def newtf(fitkey):
+    """New tf fitter.
+
+    TODO
+    :param fitkey:
+    :return:
+    """
+    use_modelspec_init = True
+    optimizer = 'adam'
+    # max_iter = 10000
+    max_iter = 10_000
+    cost_function = 'squared_error'
+    early_stopping_steps = 5
+    early_stopping_tolerance = 5e-4
+    learning_rate = 1e-4
+    # seed = 0
+
+    rand_count = 0
+    pick_best = False
+
+    options = _extract_options(fitkey)
+
+    for op in options:
+        if op == 'b':
+            pick_best = True
+        elif op.startswith('rb'):
+            if len(op) == 2:
+                rand_count = 10
+            else:
+                rand_count = int(op[2:])
+            pick_best = True
+
+    xfspec = []
+    if rand_count > 0:
+        xfspec.append(['nems.initializers.rand_phi', {'rand_count': rand_count}])
+
+    xfspec.append(['nems.xforms.fit_wrapper',
+                   {
+                       'max_iter': max_iter,
+                       'use_modelspec_init': use_modelspec_init,
+                       'optimizer': optimizer,
+                       'cost_function': cost_function,
+                       'fit_function': 'nems.tf.cnnlink_new.fit_tf',
+                       'early_stopping_steps': early_stopping_steps,
+                       'early_stopping_tolerance': early_stopping_tolerance,
+                       'learning_rate': learning_rate,
+                   }])
+
+    if pick_best:
+        xfspec.append(['nems.analysis.test_prediction.pick_best_phi', {'criterion': 'mse_fit'}])
+
+    return xfspec
+
+
 def tf(fitkey):
     '''
     Perform a Tensorflow fit, using Sam Norman-Haignere's CNN library
@@ -399,7 +453,13 @@ def init(kw):
         elif op.startswith('ii'):
             sel_options['include_through'] = int(op[2:])
         elif op.startswith('i'):
-            sel_options['include_idx'].append(int(op[1:]))
+            if not tf:
+                sel_options['include_idx'].append(int(op[1:]))
+            else:
+                if len(op[1:]) == 0:
+                    max_iter = None
+                else:
+                    max_iter = int(op[1:])
         elif op.startswith('ffneg'):
             sel_options['freeze_after'] = -int(op[5:])
         elif op.startswith('fneg'):
diff --git a/nems/plugins/default_keywords.py b/nems/plugins/default_keywords.py
index 096ba9ad..c69dec73 100644
--- a/nems/plugins/default_keywords.py
+++ b/nems/plugins/default_keywords.py
@@ -87,7 +87,7 @@ def wc(kw):
     # This is the default for wc, but options might overwrite it.
     fn = 'nems.modules.weight_channels.basic'
-    fn_kwargs = {'i': 'pred', 'o': 'pred', 'normalize_coefs': False}
+    fn_kwargs = {'i': 'pred', 'o': 'pred', 'normalize_coefs': False, 'chans': n_outputs}
     p_coefficients = {'mean': np.full((n_outputs, n_inputs), 0.01),
                       'sd': np.full((n_outputs, n_inputs), 0.2)}
     # add some variety across channels to help get the fitter started
@@ -132,7 +132,7 @@ def wc(kw):
         # Generate evenly-spaced filter centers for the starting points
         fn = 'nems.modules.weight_channels.gaussian'
         fn_kwargs = {'i': 'pred', 'o': 'pred', 'n_chan_in': n_inputs,
-                     'normalize_coefs': False}
+                     'normalize_coefs': False, 'chans': n_outputs}
         coefs = 'nems.modules.weight_channels.gaussian_coefficients'
         mean = np.arange(n_outputs+1)/(n_outputs*2+2) + 0.25
         mean = mean[1:]
@@ -318,7 +318,7 @@ def fir(kw):
         template = {
             'fn': 'nems.modules.fir.basic',
             'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal,
-                          'cross_channels': cross_channels},
+                          'cross_channels': cross_channels, 'chans': n_coefs},
             'plot_fns': ['nems.plots.api.mod_output',
                          'nems.plots.api.spectrogram_output',
                          'nems.plots.api.strf_heatmap',
@@ -334,7 +334,8 @@ def fir(kw):
         template = {
             'fn': 'nems.modules.fir.filter_bank',
             'fn_kwargs': {'i': 'pred', 'o': 'pred', 'non_causal': non_causal,
-                          'bank_count': n_banks, 'cross_channels': cross_channels},
+                          'bank_count': n_banks, 'cross_channels': cross_channels,
+                          'chans': n_coefs},
             'plot_fns': ['nems.plots.api.mod_output',
                          'nems.plots.api.spectrogram_output',
                          'nems.plots.api.strf_heatmap',
@@ -689,13 +690,14 @@ def lvl(kw):
 
     template = {
         'fn': 'nems.modules.levelshift.levelshift',
-        'fn_kwargs': {'i': 'pred', 'o': 'pred'},
+        'fn_kwargs': {'i': 'pred', 'o': 'pred', 'chans': n_shifts},
        'plot_fns': ['nems.plots.api.mod_output',
                     'nems.plots.api.spectrogram_output',
                     'nems.plots.api.pred_resp'],
        'plot_fn_idx': 2,
        'prior': {'level': ('Normal', {'mean': np.zeros([n_shifts, 1]),
                                       'sd': np.ones([n_shifts, 1])})}
+        }
 
     return template
@@ -968,7 +970,8 @@ def dexp(kw):
     template = {
         'fn': 'nems.modules.nonlinearity.double_exponential',
         'fn_kwargs': {'i': inout_name,
-                      'o': inout_name},
+                      'o': inout_name,
+                      'chans': n_dims},
         'plot_fns': ['nems.plots.api.mod_output',
                      'nems.plots.api.spectrogram_output',
                      'nems.plots.api.pred_resp',
@@ -1167,7 +1170,8 @@ def dlog(kw):
     template = {
         'fn': 'nems.modules.nonlinearity.dlog',
         'fn_kwargs': {'i': 'pred',
-                      'o': 'pred'},
+                      'o': 'pred',
+                      'chans': chans,},
         'plot_fns': ['nems.plots.api.mod_output',
                      'nems.plots.api.spectrogram_output',
                      'nems.plots.api.pred_resp',
diff --git a/nems/tf/callbacks.py b/nems/tf/callbacks.py
new file mode 100644
index 00000000..6955d1e4
--- /dev/null
+++ b/nems/tf/callbacks.py
@@ -0,0 +1,14 @@
+"""Custom tensorflow training callbacks."""
+
+import tensorflow as tf
+
+
+class DelayedStopper(tf.keras.callbacks.EarlyStopping):
+    """Early stopper that waits before kicking in."""
+    def __init__(self, start_epoch=100, **kwargs):
+        super(DelayedStopper, self).__init__(**kwargs)
+        self.start_epoch = start_epoch
+
+    def on_epoch_end(self, epoch, logs=None):
+        if epoch > self.start_epoch:
+            super().on_epoch_end(epoch, logs)
diff --git a/nems/tf/cnn.py b/nems/tf/cnn.py
index 6c09d2d2..e2da7e5e 100644
--- a/nems/tf/cnn.py
+++ b/nems/tf/cnn.py
@@ -307,7 +307,7 @@ def train(self, stim, resp, learning_rate=0.01, max_iter=300, eval_interval=30,
 
         # early stopping for > 5 non improving iterations
         # though partly redundant with tolerance early stopping, catches nans
-        if self.best_loss_index < len(self.train_loss) - early_stopping_steps:
+        if self.best_loss_index <= len(self.train_loss) - early_stopping_steps:
             log.info(f'Best epoch > {early_stopping_steps} iterations ago, stopping early!')
             break
diff --git a/nems/tf/cnnlink_new.py b/nems/tf/cnnlink_new.py
new file mode 100644
index 00000000..20e30c8f
--- /dev/null
+++ b/nems/tf/cnnlink_new.py
@@ -0,0 +1,234 @@
+"""Helpers to build a Tensorflow Keras model from a modelspec."""
+
+import logging
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+from nems import recording, signal
+from nems.tf import callbacks, layers, loss_functions, modelbuilder
+
+log = logging.getLogger(__name__)
+
+
+# TODO: make a layer registry OR build in a lookup within the modelspec layers
+layer_conversion = {
+    'nems.modules.nonlinearity.dlog': layers.Dlog,
+    'nems.modules.weight_channels.gaussian': layers.WeightChannelsGaussian,
+    'nems.modules.fir.basic': layers.FIR,
+    'nems.modules.levelshift.levelshift': layers.Levelshift,
+    'nems.modules.nonlinearity.double_exponential': layers.DoubleExponential,
+    'nems.modules.nonlinearity.relu': layers.Relu,
+}
+
+
+def map_layer(layer,
+              # idx: int,
+              use_modelspec_init: bool = True,
+              # seed: int = 0,
+              ) -> tf.keras.layers.Layer:
+    """Maps a modelspec layer to a tf layer.
+
+    :param layer: A layer from the modelspec.
+    :param use_modelspec_init: Whether to use the modelspec's initialization or use a tf builtin.
+
+    :return: A custom subclass of `tf.keras.layers.Layer`.
+    """
+    log.info(f'Building tf layer for "{layer["fn"]}".')
+    try:
+        tf_layer = layer_conversion[layer['fn']]
+    except KeyError:
+        raise NotImplementedError(f'Layer "{layer["fn"]}" does not yet have a tf equivalent.')
+
+    kwargs = {
+        'name': layer['fn'],
+    }
+
+    if use_modelspec_init:
+        # convert the phis to tf constants, removing extra dimensions
+        kwargs['initializer'] = {k: tf.constant_initializer(v.squeeze())
+                                 for k, v in layer['phi'].items()}
+
+    if 'bounds' in layer:
+        kwargs['bounds'] = layer['bounds']
+    if 'chans' in layer['fn_kwargs']:
+        kwargs['units'] = layer['fn_kwargs']['chans']
+
+    return tf_layer(**kwargs)
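Note: map_layer keys everything off layer['fn'], layer['phi'], and layer['fn_kwargs']['chans']. A minimal sketch of the input it expects (not part of the patch; the values are illustrative, names follow the module above):

    # sketch: the modelspec entry map_layer consumes (illustrative values)
    layer_spec = {
        'fn': 'nems.modules.nonlinearity.dlog',
        'fn_kwargs': {'i': 'pred', 'o': 'pred', 'chans': 1},
        'phi': {'offset': np.array([[-1.0]])},  # squeezed into a tf constant initializer
    }
    tf_layer = map_layer(layer_spec, use_modelspec_init=True)  # -> layers.Dlog instance
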
+def unmap_layer(layer):
+    """Returns a phi from the layer weights.
+
+    Necessary because of reshaping that takes place when creating the weights.
+    """
+    weight_names = [
+        # extract the actual layer name from the tf layer naming format
+        # ex: "nems.modules.nonlinearity.double_exponential/base:0"
+        weight.name.split('/')[1].split(':')[0]
+        for weight in layer.weights
+    ]
+    layer_values = {layer_name: weight.numpy() for layer_name, weight in zip(weight_names, layer.weights)}
+
+    if layer.name == 'nems.modules.nonlinearity.relu':
+        layer_values['offset'] = layer_values['offset'].reshape((-1, 1))
+
+    elif layer.name == 'nems.modules.nonlinearity.dlog':
+        layer_values['offset'] = layer_values['offset'].reshape((-1, 1))
+
+    elif layer.name == 'nems.modules.nonlinearity.double_exponential':
+        if layer.units == 1:
+            shape = (1,)
+        else:
+            shape = (-1, 1)
+        layer_values['amplitude'] = layer_values['amplitude'].reshape(shape)
+        layer_values['base'] = layer_values['base'].reshape(shape)
+        layer_values['kappa'] = layer_values['kappa'].reshape(shape)
+        layer_values['shift'] = layer_values['shift'].reshape(shape)
+
+    elif layer.name == 'nems.modules.levelshift.levelshift':
+        layer_values['level'] = layer_values['level'].reshape((-1, 1))
+
+    elif layer.name == 'nems.modules.weight_channels.gaussian':
+        pass  # don't need to reshape
+
+    elif layer.name == 'nems.modules.weight_channels.basic':
+        layer_values['coefficients'] = layer_values['coefficients'].T
+
+    elif layer.name == 'nems.modules.fir.basic':
+        layer_values['coefficients'] = layer_values['coefficients'].T
+        # layer_values['coefficients'] = np.fliplr(layer_values['coefficients'].T)
+        # layer_values['coefficients'] = layer_values['coefficients'].reshape((1, 15))
+
+    else:
+        raise NotImplementedError(f'Cannot unmap "{layer.name}", not yet supported.')
+
+    return layer_values
+
+
+def tf2modelspec(model, modelspec):
+    """Populates the modelspec phis with the model layers."""
+    log.info('Populating modelspec with model weights.')
+
+    # first layer is the input shape, so skip it
+    for idx, layer in enumerate(model.layers[1:]):
+        ms = modelspec[idx]
+
+        if layer.name != ms['fn']:
+            raise AssertionError('Model layers and modelspec layers do not match up!')
+
+        phis = unmap_layer(layer)
+        # check that phis/weights match up in both names and shapes
+        if phis.keys() != ms['phi'].keys():
+            raise AssertionError(f'Model layer "{layer.name}" weights and modelspec phis do not have matching names!')
+        for model_weights, ms_phis in zip(phis.values(), ms['phi'].values()):
+            if model_weights.shape != ms_phis.shape:
+                raise AssertionError(f'Model layer "{layer.name}" weights and modelspec phis do not have matching '
+                                     f'shapes!')
+
+        ms['phi'] = phis
+
+    return modelspec
+
+
+def eval_tf(model: tf.keras.Model, stim: np.ndarray) -> np.ndarray:
+    """Evaluate the model on the stim."""
+    return model.predict(stim)
+
+
+def compare_ms_tf(modelspec, model, rec_ms, stim_tf):
+    """Compares the evaluation between a modelspec and a keras model.
+
+    For the modelspec, uses a recording object. For the tf model, uses the
+    formatted data from fit_tf.
+    """
+    pred_ms = modelspec.evaluate(rec_ms)['pred']._data.flatten()
+    pred_tf = model.predict(stim_tf).flatten()
+    return np.nanstd(pred_ms - pred_tf)
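Note: compare_ms_tf is the round-trip sanity check used at the end of fit_tf below: after tf2modelspec copies the weights back, the NEMS evaluation and the keras prediction should agree to numerical precision. A usage sketch (not part of the patch; the tolerance is illustrative, the code itself only logs the value):

    # sketch: round-trip check after fitting (names follow fit_tf below)
    error = compare_ms_tf(modelspec, model, est, stim_train)
    if error > 1e-5:  # illustrative tolerance
        log.warning(f'NEMS and TF predictions diverge (std of difference: {error})')
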
+
+
+def fit_tf(modelspec,
+           est: recording.Recording,
+           use_modelspec_init: bool = True,
+           optimizer: str = 'adam',
+           max_iter: int = 10000,
+           cost_function: str = 'squared_error',
+           early_stopping_steps: int = 5,
+           early_stopping_tolerance: float = 5e-4,
+           learning_rate: float = 1e-4,
+           seed: int = 0,
+           **context
+           ) -> dict:
+    """TODO
+
+    :param est:
+    :param modelspec:
+    :param use_modelspec_init:
+    :param optimizer:
+    :param max_iter:
+    :param cost_function:
+    :param early_stopping_steps:
+    :param early_stopping_tolerance:
+    :param learning_rate:
+    :param seed:
+    :param context:
+
+    :return:
+    """
+    log.info('Building tensorflow keras model from modelspec.')
+
+    # update seed based on fit index
+    seed += modelspec.fit_index
+
+    # need to get duration of stims in order to reshape data
+    epoch_name = 'REFERENCE'  # TODO: this should not be hardcoded
+
+    # extract out the raw data, and reshape to (batch, time, channel)
+    stim_train = np.transpose(est['stim'].extract_epoch(epoch=epoch_name, mask=est['mask']), [0, 2, 1])
+    resp_train = np.transpose(est['resp'].extract_epoch(epoch=epoch_name, mask=est['mask']), [0, 2, 1])
+    log.info(f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.')
+
+    # handling changes to weight log directory
+    # log_dir = modelspec.meta['modelpath']  # TODO: handle slurm, etc
+
+    # correlation for monitoring
+    # TODO: tf.utils?
+    def pearson(y_true, y_pred):
+        return tfp.stats.correlation(y_true, y_pred, event_axis=None, sample_axis=None)
+
+    # get the layers and build the model
+    cost_fn = loss_functions.get_loss_fn(cost_function)
+    model_layers = modelspec.modelspec2tf2(use_modelspec_init=use_modelspec_init, seed=seed)
+    model = modelbuilder.ModelBuilder(
+        name='Test-model',
+        layers=model_layers,
+        learning_rate=learning_rate,
+        loss_fn=cost_fn,
+        optimizer=optimizer,
+        metrics=[pearson],
+    ).build_model(input_shape=stim_train.shape)
+
+    # create the callbacks
+    early_stopping = callbacks.DelayedStopper(monitor='loss', patience=30 * early_stopping_steps,
+                                              min_delta=early_stopping_tolerance, verbose=1)
+
+    log.info('Fitting model...')
+    model.fit(
+        stim_train,
+        resp_train,
+        verbose=2,
+        epochs=max_iter,
+        callbacks=[
+            early_stopping
+            # TODO: tensorboard
+        ]
+    )
+
+    modelspec = tf2modelspec(model, modelspec)
+
+    # compare the predictions from the model and modelspec
+    error = compare_ms_tf(modelspec, model, est, stim_train)
+    log.info(f'Standard deviation of difference between NEMS and TF model prediction: {error}')
+
+    return {'modelspec': modelspec}
diff --git a/nems/tf/layers.py b/nems/tf/layers.py
new file mode 100644
index 00000000..c6fae54b
--- /dev/null
+++ b/nems/tf/layers.py
@@ -0,0 +1,304 @@
+import tensorflow as tf
+
+
+class Dlog(tf.keras.layers.Layer):
+    """Simple dlog nonlinearity."""
+    def __init__(self,
+                 units=None,
+                 name='dlog',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(Dlog, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['offset'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {'offset': tf.random_normal_initializer(mean=0, stddev=0.5, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.offset = self.add_weight(name='offset',
+                                      shape=(self.units,),
+                                      dtype='float32',
+                                      initializer=self.initializer['offset'],
+                                      trainable=True,
+                                      )
+
+    def call(self, inputs, training=True):
+        # clip bounds at ±2 to avoid huge compression/expansion
+        eb = tf.math.pow(tf.constant(10, dtype='float32'), tf.clip_by_value(self.offset, -2, 2))
+        return tf.math.log((inputs + eb) / eb)
+
+
+class Levelshift(tf.keras.layers.Layer):
+    """Simple levelshift."""
+    def __init__(self,
+                 units=None,
+                 name='levelshift',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(Levelshift, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['level'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {'level': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.level = self.add_weight(name='level',
+                                     shape=(self.units,),
+                                     dtype='float32',
+                                     initializer=self.initializer['level'],
+                                     trainable=True,
+                                     )
+
+    def call(self, inputs, training=True):
+        return tf.identity(inputs + self.level)
+
+
+class Relu(tf.keras.layers.Layer):
+    """Simple relu nonlinearity."""
+    def __init__(self,
+                 name='relu',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(Relu, self).__init__(name=name, *args, **kwargs)
+
+        self.initializer = {'offset': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.offset = self.add_weight(name='offset',
+                                      shape=(input_shape[-1],),
+                                      dtype='float32',
+                                      initializer=self.initializer['offset'],
+                                      trainable=True,
+                                      )
+
+    def call(self, inputs, training=True):
+        return tf.nn.relu(inputs + self.offset)
+
+
+class DoubleExponential(tf.keras.layers.Layer):
+    """Basic double exponential nonlinearity."""
+    def __init__(self,
+                 units=None,
+                 name='dexp',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(DoubleExponential, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['base'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {
+            'base': tf.random_normal_initializer(mean=0, stddev=1, seed=seed),
+            'amplitude': tf.random_normal_initializer(mean=1, stddev=0.5, seed=seed + 1),
+            'shift': tf.random_normal_initializer(mean=0, stddev=1, seed=seed + 2),
+            'kappa': tf.random_normal_initializer(mean=1, stddev=0.5, seed=seed + 3),
+        }
+
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.base = self.add_weight(name='base',
+                                    shape=(self.units,),
+                                    dtype='float32',
+                                    initializer=self.initializer['base'],
+                                    trainable=True,
+                                    )
+        self.amplitude = self.add_weight(name='amplitude',
+                                         shape=(self.units,),
+                                         dtype='float32',
+                                         initializer=self.initializer['amplitude'],
+                                         trainable=True,
+                                         )
+        self.shift = self.add_weight(name='shift',
+                                     shape=(self.units,),
+                                     dtype='float32',
+                                     initializer=self.initializer['shift'],
+                                     trainable=True,
+                                     )
+        self.kappa = self.add_weight(name='kappa',
+                                     shape=(self.units,),
+                                     dtype='float32',
+                                     initializer=self.initializer['kappa'],
+                                     trainable=True,
+                                     )
+
+    def call(self, inputs, training=True):
+        # formula: base + amp * e^(-e^(-e^kappa * (inputs - shift)))
+        return self.base + self.amplitude * tf.math.exp(-tf.math.exp(-tf.math.exp(self.kappa) * (inputs - self.shift)))
+
+
+class WeightChannelsBasic(tf.keras.layers.Layer):
+    """Basic weight channels."""
+    def __init__(self,
+                 # kind='basic',
+                 units=None,
+                 name='weight_channels_basic',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(WeightChannelsBasic, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['coefficients'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {'coefficients': tf.random_normal_initializer(mean=0.01, stddev=0.2, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.coefficients = self.add_weight(name='coefficients',
+                                            shape=(input_shape[-1], self.units),
+                                            dtype='float32',
+                                            initializer=self.initializer['coefficients'],
+                                            trainable=True,
+                                            )
+
+    def call(self, inputs, training=True):
+        return tf.nn.conv1d(inputs, tf.expand_dims(self.coefficients, 0), stride=1, padding='SAME')
+
+
+class WeightChannelsGaussian(tf.keras.layers.Layer):
+    """Gaussian-parameterized weight channels."""
+    # TODO: convert per https://stackoverflow.com/a/52012658/1510542
+    def __init__(self,
+                 # kind='gaussian',
+                 bounds,
+                 units=None,
+                 name='weight_channels_gaussian',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(WeightChannelsGaussian, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['mean'].value.shape[0]
+        else:
+            self.units = units
+
+        self.initializer = {
+            'mean': tf.random_normal_initializer(mean=0.5, stddev=0.4, seed=seed),
+            'sd': tf.random_normal_initializer(mean=0, stddev=.4, seed=seed + 1),  # this is halfnorm in NEMS
+        }
+
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+        self.bounds_mean_lower = tf.constant(bounds['mean'][0], dtype='float32')
+        self.bounds_mean_upper = tf.constant(bounds['mean'][1], dtype='float32')
+        self.bounds_sd_lower = tf.constant(bounds['sd'][0], dtype='float32')
+        self.bounds_sd_upper = tf.constant(bounds['sd'][1], dtype='float32')
+
+    def build(self, input_shape):
+        self.mean = self.add_weight(name='mean',
+                                    shape=(self.units,),
+                                    dtype='float32',
+                                    initializer=self.initializer['mean'],
+                                    trainable=True,
+                                    )
+        self.sd = self.add_weight(name='sd',
+                                  shape=(self.units,),
+                                  dtype='float32',
+                                  initializer=self.initializer['sd'],
+                                  trainable=True,
+                                  )
+
+    def call(self, inputs, training=True):
+        clipped_mean = tf.clip_by_value(self.mean, self.bounds_mean_lower, self.bounds_mean_upper)
+        clipped_sd = tf.clip_by_value(self.sd, self.bounds_sd_lower, self.bounds_sd_upper)
+
+        input_features = tf.cast(tf.shape(inputs)[-1], dtype='float32')
+        temp = tf.range(input_features) / input_features
+        temp = (tf.reshape(temp, [1, input_features, 1]) - clipped_mean) / clipped_sd
+        temp = tf.math.exp(-0.5 * tf.math.square(temp))
+        kernel = temp / tf.math.reduce_sum(temp, axis=1)
+
+        return tf.nn.conv1d(inputs, kernel, stride=1, padding='SAME')
+
+
+class FIR(tf.keras.layers.Layer):
+    """Basic FIR filter."""
+    def __init__(self,
+                 units=None,
+                 name='fir',
+                 initializer=None,
+                 seed=0,
+                 *args,
+                 **kwargs,
+                 ):
+        super(FIR, self).__init__(name=name, *args, **kwargs)
+
+        # try to infer the number of units if not specified
+        if units is None and initializer is None:
+            self.units = 1
+        elif units is None:
+            self.units = initializer['coefficients'].value.shape[1]
+        else:
+            self.units = units
+
+        self.initializer = {'coefficients': tf.random_normal_initializer(mean=0, stddev=1, seed=seed)}
+        if initializer is not None:
+            self.initializer.update(initializer)
+
+    def build(self, input_shape):
+        self.coefficients = self.add_weight(name='coefficients',
+                                            shape=(self.units, input_shape[-1]),
+                                            dtype='float32',
+                                            initializer=self.initializer['coefficients'],
+                                            trainable=True,
+                                            )
+
+    def call(self, inputs, training=True):
+        # need to pad the front of the inputs
+        npads = self.units - 1
+        padded_input = tf.pad(inputs, [[0, 0], [npads, 0], [0, 0]])
+        return tf.nn.conv1d(padded_input, tf.expand_dims(tf.reverse(self.coefficients, axis=[0]), -1), stride=1,
+                            padding='VALID')
diff --git a/nems/tf/loss_functions.py b/nems/tf/loss_functions.py
index a0f7fba2..f5d1bc93 100644
--- a/nems/tf/loss_functions.py
+++ b/nems/tf/loss_functions.py
@@ -8,24 +8,36 @@
 log = logging.getLogger(__name__)
 
 
+def get_loss_fn(loss_fn_str: str):
+    """Utility to get a function from a string name."""
+    mapping = {
+        'squared_error': loss_se_safe,
+        'poisson': poisson,
+        'nmse': loss_tf_nmse,
+        'nmse_shrinkage': loss_tf_nmse_shrinkage,
+    }
+
+    try:
+        return mapping[loss_fn_str]
+    except KeyError:
+        raise KeyError(f'Loss must be "squared_error", "poisson", "nmse",'
+                       f' or "nmse_shrinkage", not {loss_fn_str}')
+
+
 def poisson(response, prediction):
     """Poisson loss."""
-    return tf.reduce_mean(prediction - response * tf.log(prediction + 1e-5), name='poisson')
+    return tf.math.reduce_mean(prediction - response * tf.math.log(prediction + 1e-5), name='poisson')
 
 
 def loss_se(response, prediction):
     """Squared error loss."""
-    # TODO : add percell option
+    return tf.math.reduce_mean(tf.math.square(response - prediction)) / tf.math.reduce_mean(tf.math.square(response))
 
-    # SVD -- was trying to reduce nan outputs, but doesn't seem to make a different. more important to avoid
-    # having the fitter get there. Possibly useful in the future though?
-    #x = tf.square(response - prediction)
-    #y = tf.square(response)
-    #x = tf.reduce_mean(tf.boolean_mask(x, tf.is_finite(x))) + tf.reduce_mean(tf.dtypes.cast(tf.is_inf(x), tf.float32))
-    #y = tf.reduce_mean(tf.boolean_mask(y, tf.is_finite(y))) + tf.reduce_mean(tf.dtypes.cast(tf.is_inf(y), tf.float32))
-    #return x/y
-    return tf.reduce_mean(tf.square(response - prediction)) / tf.reduce_mean(tf.square(response))
+
+def loss_se_safe(response, prediction):
+    """Squared error loss, with checking for NaNs."""
+    tf.debugging.check_numerics(prediction, message='Nan values in prediction.')
+    return tf.math.reduce_mean(tf.math.square(response - prediction)) / tf.math.reduce_mean(tf.math.square(response))
 
 
 def loss_tf_nmse_shrinkage(response, prediction):
@@ -37,7 +49,7 @@ def loss_tf_nmse(response, prediction, per_cell=True):
     """Normalized mean squared error loss."""
     mE, sE = tf_nmse(response, prediction, per_cell=per_cell)
     if per_cell:
-        return tf.reduce_mean(mE)
+        return tf.math.reduce_mean(mE)
     else:
         return mE
diff --git a/nems/tf/modelbuilder.py b/nems/tf/modelbuilder.py
new file mode 100644
index 00000000..e949b5fe
--- /dev/null
+++ b/nems/tf/modelbuilder.py
@@ -0,0 +1,83 @@
+import tensorflow as tf
+import logging
+
+from nems.tf import loss_functions
+
+log = logging.getLogger(__name__)
+
+
+class ModelBuilder:
+    """Basic wrapper to build tensorflow keras models.
+
+    Not an actual subclass of tf.keras.Model, but instead a set of helper
+    methods to prepare layers and then return a model when ready.
+    """
+    def __init__(self,
+                 name='model',
+                 layers=None,
+                 learning_rate=1e-3,
+                 optimizer='adam',
+                 loss_fn=loss_functions.loss_se,
+                 metrics=None,
+                 *args,
+                 **kwargs
+                 ):
+        """Layers don't need to be added at init, only prior to fitting.
+
+        :param name: Model name.
+        :param layers: Initial list of layers; more can be added via add_layer.
+        :param learning_rate: Learning rate passed to the optimizer.
+        :param optimizer: Either a string key ('adam' or 'sgd') or an optimizer instance.
+        :param loss_fn: Loss function used to compile the model.
+        :param metrics: Optional metrics to monitor during fitting.
+        """
+        self.optimizer_dict = {
+            'adam': tf.keras.optimizers.Adam,
+            'sgd': tf.keras.optimizers.SGD,
+        }
+
+        if layers is not None:
+            self.model_layers = layers
+        else:
+            self.model_layers = []
+
+        self.name = name
+        self.loss_fn = loss_fn
+        self.metrics = metrics
+
+        self.learning_rate = learning_rate
+        self.optimizer = optimizer
+
+    def add_layer(self, layer, idx=None):
+        """Adds a layer to the list of layers that will be part of the model.
+
+        :param layer: Layer to add.
+        :param idx: Insertion index.
+        """
+        if idx is None:
+            self.model_layers.append(layer)
+            log.debug(f'Added "{layer.name}" to end of model.')
+        else:
+            self.model_layers.insert(idx, layer)
+            log.debug(f'Inserted "{layer.name}" to idx "{idx}".')
+
+    def build_model(self, input_shape):
+        """Creates a tf.keras.Model instance. Simple single-headed model.
+
+        :param input_shape: Allows summary creation. Shape should be a tuple of (batch, time, channel).
+        :return: tf.keras.Model instance.
+        """
+        inputs = tf.keras.Input(shape=input_shape[1:], name='input_node', batch_size=input_shape[0])
+        outputs = inputs
+        for layer in self.model_layers:
+            outputs = layer(outputs)
+
+        model = tf.keras.Model(inputs=inputs, outputs=outputs, name=self.name)
+        log.info('Built model, printing summary.')
+        model.summary(print_fn=log.info)
+
+        # build the optimizer from its string key, or fall back to
+        # whatever was passed in (e.g. an optimizer instance)
+        try:
+            optimizer = self.optimizer_dict[self.optimizer](learning_rate=self.learning_rate)
+        except KeyError:
+            optimizer = self.optimizer
+
+        model.compile(optimizer, loss=self.loss_fn, metrics=self.metrics)
+        return model
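Note: end to end, fit_tf builds layers from the modelspec and hands them to ModelBuilder. A standalone sketch of ModelBuilder itself (not part of the patch; shapes and data are illustrative):

    # sketch: standalone ModelBuilder usage with one of the new layers
    import numpy as np
    from nems.tf import layers, modelbuilder

    stim = np.random.rand(10, 550, 18).astype('float32')  # (batch, time, channel)
    resp = np.random.rand(10, 550, 18).astype('float32')

    model = modelbuilder.ModelBuilder(
        name='sketch-model',
        layers=[layers.Relu()],  # per-channel offsets, randomly initialized
        learning_rate=1e-3,
        optimizer='adam',
    ).build_model(input_shape=stim.shape)
    model.fit(stim, resp, epochs=10, verbose=2)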