
Commit

save to memory
samuelgarcia committed Mar 26, 2021
1 parent 310995a commit b90b2c6
Showing 8 changed files with 216 additions and 34 deletions.
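In short: the commit adds an in-memory caching path next to the existing folder-based one. Below is a minimal sketch of the resulting workflow, pieced together from the diff and the updated test; the toy NumpyRecording (shape, sampling frequency) is illustrative and not part of the commit.

import numpy as np
from spikeinterface.core.numpyextractors import NumpyRecording

# toy single-segment recording: 1000 samples x 4 channels (illustrative values)
traces = np.random.randn(1000, 4).astype('float32')
rec = NumpyRecording([traces], sampling_frequency=30000.)

# new in this commit: cache the traces into numpy arrays held in memory
rec_mem = rec.save(format='memory')

# the cached copy behaves like the original recording
assert np.array_equal(rec_mem.get_traces(segment_index=0), traces)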
2 changes: 0 additions & 2 deletions requirements.txt
@@ -1,6 +1,4 @@
numpy
neo
joblib
loky
probeinterface

40 changes: 31 additions & 9 deletions spikeinterface/core/base.py
@@ -191,13 +191,19 @@ def copy_metadata(self, other, only_main=False, ids=None):
inds = self.ids_to_indices(ids)

if only_main:
other._annotations = deepcopy({k: self._annotations[k] for k in ExtractorBase._main_annotations})
other._properties = deepcopy({k: self._properties[k][inds] for k in ExtractorBase._main_properties})
# other._features = deepcopy({self._features[k] for k in ExtractorBase._main_features})
ann_keys = ExtractorBase._main_annotations
prop_keys = ExtractorBase._main_properties
feat_keys = ExtractorBase._main_features
else:
other._annotations = deepcopy(self._annotations)
other._properties = deepcopy({k: self._properties[k][inds] for k in self._properties})
# other._features = deepcopy(self._features)
ann_keys = self._annotations.keys()
prop_keys = self._properties.keys()
# feat_keys = ExtractorBase._features.keys()


other._annotations = deepcopy({k: self._annotations[k] for k in ann_keys})
other._properties = deepcopy({k: self._properties[k][inds] for k in prop_keys if self._properties[k] is not None})
# other._features = deepcopy({k: self._features[k] for k in feat_keys})



def to_dict(self, include_annotations=True, include_properties=True, include_features=True):
@@ -414,7 +420,7 @@ def load(file_path):
def load_from_folder(folder):
return BaseExtractor.load(folder)

def _save_to_folder(self, folder, **save_kargs):
def _save(self, folder, **save_kargs):
# This is implemented in BaseRecording or BaseSorting
# it is internally called by the cache(...) main function
raise NotImplementedError
@@ -423,8 +429,24 @@ def _after_load(self, folder):
# This is implemented in BaseRecording or BaseSorting
# it is internally called by the load(...) main function
raise NotImplementedError

def save(self, **kargs):
"""
Route to save_to_folder() or save_to_memory() depending on the 'format' keyword argument
"""
if kargs.get('format', None) == 'memory':
return self.save_to_memory(**kargs)
else:
return self.save_to_folder(**kargs)

def save_to_memory(self, **kargs):
print('save_to_memory')
# used only by recording at the moment
cached = self._save(**kargs)
self.copy_metadata(cached)
return cached

def save(self, name=None, folder=None, dump_ext='json', verbose=True, **save_kargs):
def save_to_folder(self, name=None, folder=None, dump_ext='json', verbose=True, **save_kargs):
"""
Save consists of:
* computing the extractor by calling get_traces() in chunks
@@ -487,7 +509,7 @@ def save(self, name=None, folder=None, dump_ext='json', verbose=True, **save_kargs):
)

# save data (done the subclass)
cached = self._save_to_folder(folder, verbose=verbose, **save_kargs)
cached = self._save(folder=folder, verbose=verbose, **save_kargs)

# copy properties/
self.copy_metadata(cached)
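For context, a small sketch (not part of the diff) of how the new routing in save() is exercised; `rec` stands for any existing BaseRecording, and the BaseExtractor import path and folder name are assumptions based on the updated test.

from pathlib import Path
from spikeinterface.core import BaseExtractor

folder = Path('my_cache_folder') / 'simple_recording'

rec_on_disk = rec.save(folder=folder)    # no format='memory' -> save_to_folder() -> _save(folder=..., ...)
rec_in_mem = rec.save(format='memory')   # format='memory'    -> save_to_memory() -> _save(format='memory', ...)

# a folder save can be reloaded later
rec_again = BaseExtractor.load_from_folder(folder)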
29 changes: 23 additions & 6 deletions spikeinterface/core/baserecording.py
@@ -7,7 +7,7 @@


from .base import BaseExtractor, BaseSegment
from .core_tools import write_binary_recording
from .core_tools import write_binary_recording, write_memory_recording


class BaseRecording(BaseExtractor):
@@ -94,28 +94,45 @@ def is_filtered(self):
# is_filtered is handled with an annotation
return self._annotations.get('is_filtered', False)

def _save_to_folder(self, folder, format='binary', **save_kargs):
_job_keys = ['n_jobs', 'total_memory', 'chunk_size', 'chunk_memory', 'progress_bar', 'verbose']
def _save(self, format='binary', **save_kwargs):
"""
This replaces the old CacheRecordingExtractor but enables more engines
for caching the results. At the moment only binary with memmap is supported.
The plan is to also add zarr support.
"""
# TODO save properties as npz!!!!!

if format == 'binary':
# TODO save properties as npz!!!!!
folder = save_kwargs['folder']
files_path = [ folder / f'traces_cached_seg{i}.raw' for i in range(self.get_num_segments())]
dtype = save_kargs.get('dtype', 'float32')
keys = ['n_jobs', 'total_memory', 'chunk_size', 'chunk_memory', 'progress_bar', 'verbose']
job_kwargs = {k:save_kargs[k] for k in keys if k in save_kargs}
dtype = save_kwargs.get('dtype', 'float32')

job_kwargs = {k:save_kwargs[k] for k in self._job_keys if k in save_kwargs}
write_binary_recording(self, files_path=files_path, dtype=dtype, **job_kwargs)

from . binaryrecordingextractor import BinaryRecordingExtractor
cached = BinaryRecordingExtractor(files_path, self.get_sampling_frequency(),
self.get_num_channels(), dtype, channel_ids=self.get_channel_ids(), time_axis=0)

elif format == 'memory':
job_kwargs = {k:save_kwargs[k] for k in self._job_keys if k in save_kwargs}
traces_list = write_memory_recording(self, dtype=None, **job_kwargs)
from .numpyextractors import NumpyRecording

cached = NumpyRecording(traces_list, self.get_sampling_frequency(), channel_ids=self.channel_ids)
print('cached')
print(cached)


elif format == 'zarr':
# TODO implement a format based on zarr
raise NotImplementedError

elif format == 'nwb':
# TODO implement a format based on nwb
raise NotImplementedError

else:
raise ValueError(f'format {format} not supported')

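The new `_job_keys` class attribute filters chunking/parallelisation options out of `save_kwargs` before they reach the writers. A hedged sketch of passing them through save(); it mirrors the updated test, with illustrative names and values.

# parallel binary cache: chunk_size/n_jobs are picked up via _job_keys and forwarded to write_binary_recording
rec.save(name='simple_recording_2', format='binary', chunk_size=10000, n_jobs=4)

# the same job options apply to the in-memory path (SharedMemory buffers when n_jobs > 1)
rec.save(format='memory', chunk_size=10000, n_jobs=4)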
6 changes: 2 additions & 4 deletions spikeinterface/core/basesorting.py
@@ -59,16 +59,14 @@ def get_unit_spike_train(self,
S = self._sorting_segments[segment_index]
return S.get_unit_spike_train(unit_id=unit_id, start_frame=start_frame, end_frame=end_frame)

def _save_to_folder(self, folder, format='npz', **cache_kargs):
def _save(self, format='npz', **save_kwargs):
"""
This replaces the old CacheSortingExtractor but enables more engines
for caching the results. At the moment only npz is supported.
"""
if format == 'npz':
assert len(cache_kargs) == 0, 'Sorting.save() with npz do not support options'

folder = save_kwargs.pop('folder')
# TODO save properties/features as npz!!!!!

from .npzsortingextractor import NpzSortingExtractor
save_path = folder / 'sorting_cached.npz'
NpzSortingExtractor.write_sorting(self, save_path)
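The sorting counterpart goes through the same save() routing but currently only supports npz. A hedged sketch, assuming `sorting` is an existing BaseSorting; the folder name is illustrative.

from pathlib import Path

folder = Path('my_cache_folder') / 'simple_sorting'
sorting_cached = sorting.save(folder=folder)   # default format='npz' -> writes folder / 'sorting_cached.npz'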
124 changes: 121 additions & 3 deletions spikeinterface/core/core_tools.py
@@ -102,7 +102,6 @@ def read_binary_recording(file, num_chan, dtype, time_axis=0, offset=0):
def _init_binary_worker(recording, rec_memmaps, dtype):
# create a local dict per worker
local_dict = {}
from spikeinterface.core import load_extractor
if isinstance(recording, dict):
from spikeinterface.core import load_extractor
local_dict['recording'] = load_extractor(recording)
@@ -126,8 +125,6 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, local_dict):
traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
traces = traces.astype(dtype)
rec_memmap[start_frame:end_frame, :] = traces




def write_binary_recording(recording, files_path=None, dtype=None,
@@ -198,6 +195,7 @@ def write_binary_recording_file_handle(recording, file_handle=None,
dtype = recording.get_dtype()

chunk_size = ensure_chunk_size(recording, **job_kwargs)


if chunk_size is not None and time_axis == 1:
print("Chunking disabled due to 'time_axis' == 1")
@@ -228,6 +226,126 @@



# used by write_memory_recording
def _init_memory_worker(recording, arrays, shm_names, shapes, dtype):
# create a local dict per worker
local_dict = {}
if isinstance(recording, dict):
from spikeinterface.core import load_extractor
local_dict['recording'] = load_extractor(recording)
else:
local_dict['recording'] = recording

local_dict['dtype'] = np.dtype(dtype)

if arrays is None:
# create them from the shared memory names
from multiprocessing.shared_memory import SharedMemory
arrays = []
# keep shm alive
local_dict['shms'] = []
for i in range(len(shm_names)):
shm = SharedMemory(shm_names[i])
local_dict['shms'].append(shm)
arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf)
arrays.append(arr)

local_dict['arrays'] = arrays

return local_dict


# used by write_memory_recording
def _write_memory_chunk(segment_index, start_frame, end_frame, local_dict):
# recover variables of the worker
recording = local_dict['recording']
dtype = local_dict['dtype']
arr = local_dict['arrays'][segment_index]

# apply function
traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
traces = traces.astype(dtype)
arr[start_frame:end_frame, :] = traces
#~ print('yep')


def make_shared_array(shape, dtype):
# https://docs.python.org/3/library/multiprocessing.shared_memory.html
try:
from multiprocessing.shared_memory import SharedMemory
except Exception as e:
raise Exception('SharedMemory is available only for python>=3.8')

dtype = np.dtype(dtype)
nbytes = shape[0] * shape[1] * dtype.itemsize
shm = SharedMemory(name=None, create=True, size=nbytes)
arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
arr[:] = 0

return arr, shm


def write_memory_recording(recording, dtype=None, verbose=False, **job_kwargs):
'''
Save the traces into numpy arrays (in memory).
Tries to use the SharedMemory introduced in Python 3.8 when n_jobs > 1.
Parameters
----------
recording: RecordingExtractor
The recording extractor object to be saved to memory
dtype: dtype
Type of the saved data. If None, the recording dtype is used.
verbose: bool
If True, output is verbose (when chunks are used)
**job_kwargs:
Used by the job_tools module to set:
* chunk_size or chunk_memory, or total_memory
* n_jobs
* progress_bar
Returns
-------
arrays: list of np.array, one array per segment
'''

chunk_size = ensure_chunk_size(recording, **job_kwargs)
n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get('n_jobs', 1))

if dtype is None:
dtype = recording.get_dtype()

# create shared memory arrays (plain numpy arrays when n_jobs == 1)
arrays = []
shm_names = []
shapes = []
for segment_index in range(recording.get_num_segments()):
num_frames = recording.get_num_samples(segment_index)
num_channels = recording.get_num_channels()
shape = (num_frames, num_channels)
shapes.append(shape)
if n_jobs >1:
arr, shm = make_shared_array(shape, dtype)
shm_names.append(shm.name)
else:
arr = np.zeros(shape, dtype=dtype)
arrays.append(arr)

# use executor (loop or workers)
func = _write_memory_chunk
init_func = _init_memory_worker
if n_jobs >1:
init_args = (recording.to_dict(), None, shm_names, shapes, dtype)
else:
init_args = (recording.to_dict(), arrays, None, None, dtype)

executor = ChunkRecordingExecutor(recording, func, init_func, init_args, verbose=verbose,
job_name='write_memory_recording', **job_kwargs)
executor.run()

return arrays
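
A hedged usage sketch (not part of the diff) for the two helpers added above; the toy NumpyRecording and the sizes are illustrative, and SharedMemory needs Python >= 3.8.

import numpy as np
from spikeinterface.core.numpyextractors import NumpyRecording
from spikeinterface.core.core_tools import write_memory_recording, make_shared_array

traces = np.random.randn(20000, 8).astype('float32')
rec = NumpyRecording([traces], sampling_frequency=30000.)

# single job -> plain numpy buffers; with n_jobs > 1 the buffers become SharedMemory-backed
arrays = write_memory_recording(rec, dtype='float32', chunk_size=1000)
assert np.array_equal(arrays[0], traces)   # one array per segment

# make_shared_array on its own: a zero-initialised array backed by SharedMemory
arr, shm = make_shared_array((1000, 4), 'float32')
shm.close()
shm.unlink()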



def write_to_h5_dataset_format(recording, dataset_path, segment_index, save_path=None, file_handle=None,
6 changes: 1 addition & 5 deletions spikeinterface/core/numpyextractors.py
@@ -20,7 +20,7 @@ class NumpyRecording(BaseRecording):
"""
is_writable = False

def __init__(self, traces_list, sampling_frequency, channel_locations=None, channel_ids=None):
def __init__(self, traces_list, sampling_frequency, channel_ids=None):
if isinstance(traces_list, list):
assert all(isinstance(e, np.ndarray) for e in traces_list), 'must give a list of numpy array'
else:
@@ -43,13 +43,9 @@ def __init__(self, traces_list, sampling_frequency, channel_locations=None, channel_ids=None):
rec_segment = NumpyRecordingSegment(traces)
self.add_recording_segment(rec_segment)

# not sure that this is relevant!!!
if channel_locations is not None:
self.set_channel_locations(channel_locations)

self._kwargs = {'traces_list': traces_list,
'sampling_frequency': sampling_frequency,
'channel_locations': channel_locations
}


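With this change the constructor no longer takes channel_locations. A hedged construction sketch (shapes, ids and locations are illustrative); setting locations afterwards via set_channel_locations() is an assumption based on the call removed above.

import numpy as np
from spikeinterface.core.numpyextractors import NumpyRecording

traces_list = [np.zeros((1000, 4), dtype='float32'), np.zeros((2000, 4), dtype='float32')]
rec = NumpyRecording(traces_list, sampling_frequency=30000., channel_ids=[0, 1, 2, 3])

# locations can still be attached after construction
rec.set_channel_locations(np.array([[0., 0.], [0., 20.], [0., 40.], [0., 60.]]))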
14 changes: 11 additions & 3 deletions spikeinterface/core/tests/test_baserecording.py
@@ -37,7 +37,8 @@ def test_BaseRecording():

files_path = [f'test_base_recording_{i}.raw' for i in range(num_seg)]
for i in range(num_seg):
np.memmap(files_path[i], dtype=dtype, mode='w+', shape=(num_samples, num_chan))
a = np.memmap(files_path[i], dtype=dtype, mode='w+', shape=(num_samples, num_chan))
a[:] = np.random.randn(*a.shape).astype(dtype)

rec = BinaryRecordingExtractor(files_path, sampling_frequency, num_chan, dtype)
print(rec)
@@ -71,14 +72,21 @@ def test_BaseRecording():
rec2 = BaseExtractor.load('test_BaseRecording.pkl')
rec3 = load_extractor('test_BaseRecording.pkl')

# cache
# cache to binary
cache_folder = Path('./my_cache_folder')
folder = cache_folder / 'simple_recording'
rec.save(folder=folder)
rec.save(format='binary', folder=folder)
rec2 = BaseExtractor.load_from_folder(folder)
# but also possible
rec3 = BaseExtractor.load('./my_cache_folder/simple_recording')

# cache to memory
rec4 = rec3.save(format='memory')

traces4 = rec4.get_traces(segment_index=0)
traces = rec.get_traces(segment_index=0)
assert np.array_equal(traces4, traces)

# cache joblib several jobs
rec.save(name='simple_recording_2', chunk_size=10, n_jobs=4)
