Skip to content

Commit

Permalink
Add toy example and mearec (neo base)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Feb 1, 2021
1 parent 083e8b9 commit c58ec5b
Show file tree
Hide file tree
Showing 16 changed files with 547 additions and 91 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ spikeinterface/core/tests/*/*/*.raw

spikeinterface/extractors/tests/*/*/*.json
spikeinterface/extractors/tests/*/*/*.raw
spikeinterface/extractors/tests/*.npz
spikeinterface/extractors/tests/extractor_testing_files/*



Expand Down
7 changes: 2 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
spikeextractors>=0.9.3
spiketoolkit>=0.7.2
spikesorters>=0.4.3
spikecomparison>=0.3.1
spikewidgets>=0.5.1
numpy
neo>=0.9.0
5 changes: 3 additions & 2 deletions spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, main_ids):

# main_ids will either channel_ids or units_ids
# it is used for properties and features
self._main_ids = np.array(main_ids, dtype=int)
#~ self._main_ids = np.array(main_ids, dtype=int)
self._main_ids = np.array(main_ids)

# dict at object level
self._annotations = {}
Expand Down Expand Up @@ -141,7 +142,7 @@ def set_property(self, key, values, ids=None):
"""
values = np.asarray(values)
if ids is None:
assert values.size == self._main_ids.size
assert values.shape[0] == self._main_ids.size
self._properties[key] = values
else:
assert key in self._properties, 'The key is not in properties'
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __repr__(self):
clsname = self.__class__.__name__
nseg = self.get_num_segments()
nchan = self.get_num_channels()
sf_khz = self.get_sampling_frequency()
sf_khz = self.get_sampling_frequency() / 1000.
txt = f'{clsname}: {nchan} channels - {nseg} segments - {sf_khz:0.1f}kHz'
if 'files_path' in self._kwargs:
txt += '\n files_path: {}'.format(self._kwargs['files_path'])
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __repr__(self):
clsname = self.__class__.__name__
nseg = self.get_num_segments()
nunits = self.get_num_units()
sf_khz = self.get_sampling_frequency()
sf_khz = self.get_sampling_frequency() / 1000.
txt = f'{clsname}: {nunits} nunits - {nseg} segments - {sf_khz:0.1f}kHz'
if 'file_path' in self._kwargs:
txt += '\n file_path: {}'.format(self._kwargs['file_path'])
Expand Down
2 changes: 2 additions & 0 deletions spikeinterface/extractors/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .extractorlist import *

from .toy_example import toy_example
9 changes: 9 additions & 0 deletions spikeinterface/extractors/extractorlist.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from .numpyextractors import NumpyRecording , NumpySorting

from .neoextractors import(
MEArecRecordingExtractor, MEArecSortingExtractor,
)


recording_extractor_full_list = [
NumpyRecording,

# neo based
MEArecRecordingExtractor,

##OLD
#~ MdaRecordingExtractor,
Expand Down Expand Up @@ -41,6 +47,9 @@
sorting_extractor_full_list = [
NumpySorting,

# neo based
MEArecSortingExtractor,

##OLD
#~ MdaSortingExtractor,
#~ MEArecSortingExtractor,
Expand Down
1 change: 1 addition & 0 deletions spikeinterface/extractors/neoextractors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mearecextractor import MEArecRecordingExtractor, MEArecSortingExtractor
70 changes: 70 additions & 0 deletions spikeinterface/extractors/neoextractors/mearecextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor

import neo


class MEArecRecordingExtractor(NeoBaseRecordingExtractor):
    """
    Class for reading data from a MEArec simulated data file.

    Parameters
    ----------
    file_path: str
        Path to the MEArec file.
    locs_2d: bool
        If True (default), reduce the stored 3D channel positions to 2D
        using the probe plane declared in the file ('xy', 'yz' or 'xz');
        when no plane information is present, the first axis is dropped.
    """
    mode = 'file'
    NeoRawIOClass = 'MEArecRawIO'

    def __init__(self, file_path, locs_2d=True):
        neo_kwargs = {'filename': file_path}
        NeoBaseRecordingExtractor.__init__(self, **neo_kwargs)

        # channel locations come from the MEArec generator object that the
        # neo rawio keeps around after parsing
        recgen = self.neo_reader._recgen
        locations = np.array(recgen.channel_positions)
        if locs_2d:
            if 'electrodes' in recgen.info.keys():
                if 'plane' in recgen.info['electrodes'].keys():
                    # keep only the two axes spanning the declared probe plane
                    probe_plane = recgen.info['electrodes']['plane']
                    if probe_plane == 'xy':
                        locations = locations[:, :2]
                    elif probe_plane == 'yz':
                        locations = locations[:, 1:]
                    elif probe_plane == 'xz':
                        locations = locations[:, [0, 2]]
            if locations.shape[1] == 3:
                # no plane info found above: fall back to dropping the first axis
                locations = locations[:, 1:]

        # debug prints removed: they polluted stdout on every construction
        self.set_channel_locations(locations)

    @staticmethod
    def write_recording(recording, save_path, check_suffix=True):
        # Alessio : I think we don't need this
        raise NotImplementedError


class MEArecSortingExtractor(NeoBaseSortingExtractor):
    """Sorting extractor reading spike trains from a MEArec simulated file."""
    mode = 'file'
    NeoRawIOClass = 'MEArecRawIO'
    # MEArec spike timestamps are already sample indices on the signal clock
    handle_raw_spike_directly = True

    def __init__(self, file_path, use_natural_unit_ids=True):
        NeoBaseSortingExtractor.__init__(
            self,
            sampling_frequency=None,  # auto guess is correct here
            use_natural_unit_ids=use_natural_unit_ids,
            filename=file_path,
        )

    @staticmethod
    def write_sorting(sorting, save_path, sampling_frequency, check_suffix=True):
        # Alessio : I think we don't need this
        raise NotImplementedError
211 changes: 211 additions & 0 deletions spikeinterface/extractors/neoextractors/neobaseextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import numpy as np
from spikeinterface.core import BaseRecording, BaseSorting, BaseRecordingSegment, BaseSortingSegment

import neo


class _NeoBaseExtractor:
    """Shared machinery for neo-rawio-backed recording and sorting extractors.

    Subclasses set ``NeoRawIOClass`` to the name of a class in ``neo.rawio``;
    the constructor instantiates it, parses the header and stores the kwargs
    for serialization.
    """
    NeoRawIOClass = None
    installed = True
    is_writable = False

    def __init__(self, **neo_kwargs):
        # resolve the rawio class by name; getattr avoids the eval() of the
        # original implementation (safer and clearer for the same lookup)
        neoIOclass = getattr(neo.rawio, self.NeoRawIOClass)
        self.neo_reader = neoIOclass(**neo_kwargs)
        self.neo_reader.parse_header()

        # multi-block neo files are not supported by spikeinterface
        assert self.neo_reader.block_count() == 1, \
            'This file has several neo blocks; spikeinterface supports single-block datasets only'

        self._kwargs = neo_kwargs


class NeoBaseRecordingExtractor(_NeoBaseExtractor, BaseRecording):
    """Map a neo rawio reader onto a spikeinterface BaseRecording.

    Requires all selected channels to share one group id and one raw dtype;
    builds a per-channel gain so traces come out in uV, then registers one
    recording segment per neo segment of the (single) block.
    """

    def __init__(self, channel_selection=None, **neo_kwargs):

        _NeoBaseExtractor.__init__(self, **neo_kwargs)

        # check channel
        # TODO propose a mechanism to select the appropriate channel groups
        # in neo one channel group shares the same dtype/sampling_rate/group_id
        #~ channel_indexes_list = self.neo_reader.get_group_signal_channel_indexes()

        self.channel_selection = channel_selection

        # check channel groups: ids, group id and raw dtype per channel,
        # optionally restricted to the user-provided selection
        chan_ids = self.neo_reader.header['signal_channels']['id']
        group_id = self.neo_reader.header['signal_channels']['group_id']
        raw_dtypes = self.neo_reader.header['signal_channels']['dtype']
        if self.channel_selection is not None:
            chan_ids = chan_ids[self.channel_selection]
            group_id = group_id[self.channel_selection]
            raw_dtypes = raw_dtypes[self.channel_selection]
        assert np.unique(group_id).size == 1,\
            'This file have several channel groups, use channel_selection=[...] to specify channel selection'

        assert np.unique(raw_dtypes).size == 1,\
            'This file have several dtype across channel, use channel_selection=[...] to specify channel selection'

        sampling_frequency = self.neo_reader.get_signal_sampling_rate(channel_indexes=self.channel_selection)
        # TODO propose a mechanism to select scaled/raw dtype
        scaled_dtype = 'float32'
        BaseRecording.__init__(self, sampling_frequency, chan_ids, scaled_dtype)

        # spikeinterface expects units to be uV implicitly: build a gain that
        # converts from the file's declared unit (V/mV/uV) to uV
        units = self.neo_reader.header['signal_channels']['units']
        if not np.all(np.isin(units, ['V', 'mV', 'uV'])):
            # check that units are V, mV or uV; best-effort warning only,
            # unknown units keep a gain of 1.0
            error = f'This extractor base on neo.{self.NeoRawIOClass} have strange units not in (V, mV, uV)'
            print(error)
        self.additional_gain = np.ones(units.size, dtype='float')
        self.additional_gain[units == 'V'] = 1e6
        self.additional_gain[units == 'mV'] = 1e3
        self.additional_gain[units == 'uV'] = 1.
        # shape (1, n_channels) so it broadcasts over (n_samples, n_channels)
        self.additional_gain = self.additional_gain.reshape(1, -1)

        # one spikeinterface segment per neo segment of block 0
        nseg = self.neo_reader.segment_count(block_index=0)
        for segment_index in range(nseg):
            rec_segment = NeoRecordingSegment(self.neo_reader, segment_index, self.additional_gain)
            self.add_recording_segment(rec_segment)

class NeoRecordingSegment(BaseRecordingSegment):
    """Expose one neo segment as a spikeinterface recording segment."""

    def __init__(self, neo_reader, segment_index, additional_gain):
        BaseRecordingSegment.__init__(self)
        self.neo_reader = neo_reader
        self.segment_index = segment_index
        self.additional_gain = additional_gain

    def get_num_samples(self):
        # signal length of this segment (single block assumed upstream)
        return self.neo_reader.get_signal_size(
            block_index=0,
            seg_index=self.segment_index,
            channel_indexes=None,
        )

    def get_traces(self, start_frame, end_frame, channel_indices):
        # neo rawio can address channels by names/ids/indexes, but ids and
        # names are not guaranteed unique in every format, so indexes are used
        reader = self.neo_reader
        raw_chunk = reader.get_analogsignal_chunk(
            block_index=0,
            seg_index=self.segment_index,
            i_start=start_frame,
            i_stop=end_frame,
            channel_indexes=channel_indices,
        )

        # first rescale to the format's natural units (can be anything)...
        traces = reader.rescale_signal_raw_to_float(
            raw_chunk, dtype='float32', channel_indexes=channel_indices)
        # ...then to uV via the per-channel additional gain
        traces *= self.additional_gain[:, channel_indices]
        return traces



class NeoBaseSortingExtractor(_NeoBaseExtractor, BaseSorting):
    """Map a neo rawio reader onto a spikeinterface BaseSorting.

    Unit ids are either the natural neo ids (``use_natural_unit_ids=True``,
    must be unique) or plain integer indexes. One sorting segment is
    registered per neo segment of the (single) block.
    """

    # when True the neo spike timestamps are already sample indices on the
    # signal clock, so no time->frame conversion is needed; readers override this
    handle_raw_spike_directly = True

    def __init__(self, sampling_frequency=None, use_natural_unit_ids=False, **neo_kwargs):
        _NeoBaseExtractor.__init__(self, **neo_kwargs)

        self.use_natural_unit_ids = use_natural_unit_ids

        if sampling_frequency is None:
            sampling_frequency = self._auto_guess_sampling_frequency()

        unit_channels = self.neo_reader.header['unit_channels']

        if use_natural_unit_ids:
            unit_ids = unit_channels['id']
            assert np.unique(unit_ids).size == unit_ids.size, 'unit_ids have duplications'
        else:
            # use integer based unit_ids
            unit_ids = np.arange(unit_channels.size, dtype='int64')

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        nseg = self.neo_reader.segment_count(block_index=0)
        for segment_index in range(nseg):
            if self.handle_raw_spike_directly:
                # timestamps already are frames: no t_start offset needed
                t_start = None
            else:
                t_start = self.neo_reader.get_signal_t_start(0, segment_index)

            sorting_segment = NeoSortingSegment(self.neo_reader, segment_index,
                                                self.use_natural_unit_ids, t_start, sampling_frequency)
            self.add_sorting_segment(sorting_segment)

    def _auto_guess_sampling_frequency(self):
        """
        Guess the sampling frequency from the signal channels.

        neo handles spikes in times (s or ms) while spikeinterface works in
        frames relative to the signals, so a sampling frequency is required.
        Many formats store spike timestamps on a higher-resolution clock than
        the signals; spikeinterface needs spike indexes at the signal speed
        (it makes no sense to have spikes at a 50kHz sample rate when the
        signal is 10kHz). neo handles differing spike/signal rates, so the
        frames-to-times conversion is format dependent; the maximum signal
        sampling rate is used here as the generic best guess.
        """
        # generic case: all channels are in the same neo group
        sig_channels = self.neo_reader.header['signal_channels']
        assert sig_channels.size > 0, 'sampling_frequency is not given and it is hard to guess it'
        sampling_frequency = np.max(sig_channels['sampling_rate'])

        return sampling_frequency


class NeoSortingSegment(BaseSortingSegment):
    """Expose one neo segment as a spikeinterface sorting segment."""

    def __init__(self, neo_reader, segment_index, use_natural_unit_ids, t_start, sampling_freq):
        BaseSortingSegment.__init__(self)
        self.neo_reader = neo_reader
        self.segment_index = segment_index
        self.use_natural_unit_ids = use_natural_unit_ids
        self._t_start = t_start
        self._sampling_freq = sampling_freq

        # lazily-built list of natural unit ids (order matches neo unit index)
        self._natural_ids = None

    def get_natural_ids(self):
        if self._natural_ids is None:
            self._natural_ids = list(self._parent_extractor().neo_reader.header['unit_channels']['id'])
        return self._natural_ids

    def get_unit_spike_train(self, unit_id, start_frame, end_frame):
        if self.use_natural_unit_ids:
            unit_index = self.get_natural_ids().index(unit_id)
        else:
            # unit_id already is the integer index
            unit_index = unit_id

        spike_timestamps = self.neo_reader.get_spike_timestamps(block_index=0,
                                                                seg_index=self.segment_index,
                                                                unit_index=unit_index)

        # BUGFIX: `handle_raw_spike_directly` was referenced as a bare name,
        # raising NameError at runtime; it is an attribute of the parent
        # extractor class (see NeoBaseSortingExtractor).
        if self._parent_extractor().handle_raw_spike_directly:
            # timestamps already are frames on the signal clock
            spike_frames = spike_timestamps
        else:
            # convert to seconds
            spike_times = self.neo_reader.rescale_spike_timestamp(spike_timestamps, dtype='float64')
            # then to samples relative to the recording signals
            spike_frames = ((spike_times - self._t_start) * self._sampling_freq).astype('int64')

        # clip to the requested window (both bounds inclusive)
        if start_frame is not None:
            spike_frames = spike_frames[spike_frames >= start_frame]

        if end_frame is not None:
            spike_frames = spike_frames[spike_frames <= end_frame]

        return spike_frames

Loading

0 comments on commit c58ec5b

Please sign in to comment.