Skip to content

Commit

Permalink
#207 migrate cnn module to tf2
Browse files Browse the repository at this point in the history
  • Loading branch information
arrrobase committed Apr 30, 2020
1 parent 5a04664 commit 8d6889b
Show file tree
Hide file tree
Showing 21 changed files with 1,631 additions and 93 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ dependencies:
- pyqt
- tensorflow
- tensorboard
- tensorflow-probability
20 changes: 20 additions & 0 deletions nems/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,26 @@ def update_startdate(queueid=None):
return r


def update_job_pid(pid, queueid=None):
"""
in tQueue, update the PID
"""
if queueid is None:
if 'QUEUEID' in os.environ:
queueid = os.environ['QUEUEID']
else:
log.warning("queueid not specified or found in os.environ")
return 0

engine = Engine()
conn = engine.connect()
sql = ("UPDATE tQueue SET pid={} WHERE id={}"
.format(pid, queueid))
r = conn.execute(sql)
conn.close()
return r


def update_job_tick(queueid=None):
"""
update current machine's load in the cluster db and tick off a step
Expand Down
71 changes: 67 additions & 4 deletions nems/modelspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import re
import typing
from functools import partial

import matplotlib.gridspec as gridspec
Expand Down Expand Up @@ -541,6 +542,9 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
"""
if rec is None:
rec = self.recording
# rec = rec.copy()

# rec['resp'] = rec['resp'].rasterize()

if fit_index is not None:
self.fit_index = fit_index
Expand Down Expand Up @@ -595,18 +599,22 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
if epoch is not None:
try:
epoch_bounds = rec['resp'].get_epoch_bounds(epoch, mask=rec['mask'])
except ValueError:
# some signal types can't handle masks with epoch bounds
epoch_bounds = rec['resp'].get_epoch_bounds(epoch)
except:
log.warning(f'Quickplot: no valid epochs matching {epoch}. Will not subset data.')
epoch = None

if 'mask' in rec.signals.keys():
rec = rec.apply_mask()

rec_resp = rec['resp']
rec_pred = rec['pred']
rec_stim = rec['stim']

if (epoch is None) or (len(epoch_bounds)==0):
epoch_bounds = np.array([[0, rec['resp'].shape[1]/rec['resp'].fs]])
if epoch is None or len(epoch_bounds) == 0:
epoch_bounds = np.array([[0, rec['resp'].shape[1] / rec['resp'].fs]])

# figure out which occurrence
# not_empty = [np.any(np.isfinite(x)) for x in extracted] # occurrences containing non inf/nan data
Expand Down Expand Up @@ -675,7 +683,7 @@ def quickplot(self, rec=None, epoch=None, occurrence=None, fit_index=None,
]

for s in sig_names:
if rec[s].shape[0]>1:
if rec[s].shape[0] > 1:
plot_fn = _lookup_fn_at('nems.plots.api.spectrogram')
title = s + ' spectrogram'
else:
Expand Down Expand Up @@ -978,7 +986,7 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0,

layer = nems.tf.cnnlink.map_layer(layer=layer, prev_layers=layers, fn=fn, idx=idx, modelspec=m,
n_input_feats=n_input_feats, net_seed=net_seed, weight_scale=weight_scale,
use_modelspec_init=use_modelspec_init, distr=distr,)
use_modelspec_init=use_modelspec_init, fs=fs, distr=distr,)

# necessary?
layer['time_win_sec'] = layer['time_win_smp'] / fs
Expand All @@ -987,6 +995,24 @@ def modelspec2tf(self, tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0,

return layers

def modelspec2tf2(self, seed=0, use_modelspec_init=True, fs=100):
"""New version
TODO
"""
layers = []

for m in self:
try:
tf_layer = nems.utils.lookup_fn_at(m['tf_layer'])
except KeyError:
raise NotImplementedError(f'Layer "{m["fn"]}" does not have a tf equivalent.')

layer = tf_layer.from_ms_layer(m, use_modelspec_init=use_modelspec_init, fs=fs)
layers.append(layer)

return layers


def get_modelspec_metadata(modelspec):
"""Return a dict of the metadata for this modelspec.
Expand Down Expand Up @@ -1215,6 +1241,43 @@ def fit_mode_off(modelspec):
modelspec.fast_eval_off()


def eval_ms_layer(data: np.ndarray,
layer_spec: typing.Union[None, str] = None,
stop: typing.Union[None, int] = None,
modelspec: ModelSpec = None
) -> np.ndarray:
"""Takes in a numpy array and applies a single ms layer to it.
:param data: The input data. Shape of (reps, time, channels).
:param layer_spec: A layer spec for layers of a modelspec.
:param stop: What layer to eval to. Non inclusive. If not passed, will evaluate the whole layer spec.
:param modelspec: Optionally use an existing modelspec. Takes precedence over layer_spec.
:return: The processed data.
"""
if layer_spec is None and modelspec is None:
raise ValueError('Either of "layer_spec" or "modelspec" must be specified.')

if modelspec is not None:
ms = modelspec
else:
ms = nems.initializers.from_keywords(layer_spec)

sig = nems.signal.RasterizedSignal.from_3darray(
fs=100,
array=np.swapaxes(data, 1, 2),
name='temp',
recording='temp',
epoch_name='REFERENCE'
)

rec = nems.recording.Recording({'stim': sig})
rec = ms.evaluate(rec=rec, stop=stop)

pred = np.swapaxes(rec['pred'].extract_epoch('REFERENCE'), 1, 2)
return pred


def evaluate(rec, modelspec, start=None, stop=None):
"""Given a recording object and a modelspec, return a prediction in a new recording.
Expand Down
2 changes: 1 addition & 1 deletion nems/modules/levelshift.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def levelshift(rec, i, o, level):
def levelshift(rec, i, o, level, **kwargs):
'''
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions nems/modules/nonlinearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _double_exponential(x, base, amplitude, shift, kappa):
return base + amplitude * exp(-exp(np.array(-exp(kappa)) * (x - shift)))


def double_exponential(rec, i, o, base, amplitude, shift, kappa):
def double_exponential(rec, i, o, base, amplitude, shift, kappa, **kwargs):
'''
A double exponential applied to all channels of a single signal.
rec Recording object
Expand Down Expand Up @@ -84,7 +84,7 @@ def _dlog(x, offset):
return np.log((x + d) / d)


def dlog(rec, i, o, offset):
def dlog(rec, i, o, offset, **kwargs):
"""
Log compression with variable offset
:param rec: recording object with signals i and o.
Expand Down
Loading

0 comments on commit 8d6889b

Please sign in to comment.