Skip to content

Commit

Permalink
Merge pull request #3693 from samuelgarcia/motion_svd
Browse files Browse the repository at this point in the history
Peak SVD motion extraction
  • Loading branch information
samuelgarcia authored Feb 21, 2025
2 parents 46f0c8d + 54c12ce commit 64d253c
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 40 deletions.
13 changes: 12 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@


class PipelineNode:

# If False (general case) then compute(traces_chunk, *node_input_args)
# If True then compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, *node_input_args)
_compute_has_extended_signature = False

def __init__(
self,
recording: BaseRecording,
Expand Down Expand Up @@ -684,7 +689,13 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
node_output = node.compute(traces_chunk, *node_input_args)
if not node._compute_has_extended_signature:
node_output = node.compute(traces_chunk, *node_input_args)
else:
node_output = node.compute(
traces_chunk, start_frame, end_frame, segment_index, max_margin, *node_input_args
)

pipeline_outputs[node] = node_output

if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource):
Expand Down
137 changes: 137 additions & 0 deletions src/spikeinterface/sortingcomponents/clustering/peak_svd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from pathlib import Path
import pickle
import json

import numpy as np

from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.waveforms.temporal_pca import (
TemporalPCAProjection,
MotionAwareTemporalPCAProjection,
)
from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever


def extract_peaks_svd(
    recording,
    peaks,
    ms_before=0.5,
    ms_after=1.5,
    svd_model=None,
    n_components=5,
    radius_um=120.0,
    motion_aware=False,
    motion=None,
    folder=None,
    **job_kwargs,
):
    """
    Extract sparse waveforms compressed with an SVD (PCA) on a local set of channels per peak.

    Importantly, the output buffer has unaligned channels on shape[2]: each peak
    uses its own local channel neighbourhood.

    This is done in 2 steps:
      * fit a TruncatedSVD model on a few peaks, using the waveform at the max channel
        (skipped when `svd_model` is given)
      * transform every peak in parallel on a sparse channel set with this model

    When the recording has drift, the motion object can optionally be given.
    In that case the SVD features are moved back using interpolation, which
    avoids interpolating the traces themselves (with kriging).

    Parameters
    ----------
    recording : Recording
        The recording object.
    peaks : np.array
        The peaks vector.
    ms_before : float, default: 0.5
        Waveform duration before the peak, in ms.
    ms_after : float, default: 1.5
        Waveform duration after the peak, in ms.
    svd_model : object | None, default: None
        An already fitted model. If None, a TruncatedSVD is fitted here on a
        uniform selection of peaks (n_peaks=5000).
    n_components : int, default: 5
        Number of SVD components, used only when the model is fitted here.
    radius_um : float, default: 120.0
        Radius of the sparse channel neighbourhood. When `motion_aware` is True it
        is enlarged by the largest absolute value of `motion.get_boundaries()`.
    motion_aware : bool, default: False
        If True, use the motion aware projection node; `motion` is then mandatory.
    motion : Motion | None, default: None
        The motion object, required when `motion_aware` is True.
    folder : str | Path | None, default: None
        If None, features are gathered in memory. Otherwise features are gathered
        as npy in `folder / "features"`, and a model fitted here is pickled in
        `folder / "svd_model"` together with a small `params.json`.
    **job_kwargs
        Forwarded to the waveform extraction and to the node pipeline.

    Returns
    -------
    peaks_svd : np.array
        The SVD features, shape (num_peaks, n_components, max_sparse_channel).
    sparse_mask : np.array
        The `neighbours_mask` of the sparse waveform extraction node.
    svd_model : object
        The model used for the projection (fitted here or the one given).

    Raises
    ------
    ValueError
        If `motion_aware` is True and `motion` is None.
    """
    # Validate before any work: `motion` is dereferenced below to enlarge the radius,
    # so checking later would surface an AttributeError after the (costly) SVD fit.
    if motion_aware and motion is None:
        raise ValueError("For motion aware PCA motion must provided")

    nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
    nafter = int(ms_after * recording.sampling_frequency / 1000.0)

    # Step 1 : select a few peaks to fit the SVD
    need_save_model = svd_model is None
    if need_save_model:
        few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=5000, margin=(nbefore, nafter))
        few_wfs = extract_waveform_at_max_channel(
            recording, few_peaks, ms_before=ms_before, ms_after=ms_after, job_name="Fit peaks svd", **job_kwargs
        )

        # keep only the single (max) channel axis
        wfs = few_wfs[:, :, 0]
        from sklearn.decomposition import TruncatedSVD

        svd_model = TruncatedSVD(n_components=n_components)
        svd_model.fit(wfs)

    if folder is None:
        gather_mode = "memory"
        features_folder = None
        gather_kwargs = dict()
    else:
        gather_mode = "npy"
        folder = Path(folder)

        # save the model (only when it was fitted here)
        if need_save_model:
            model_folder = folder / "svd_model"
            model_folder.mkdir(exist_ok=True, parents=True)
            with open(model_folder / "pca_model.pkl", "wb") as f:
                pickle.dump(svd_model, f)
            model_params = {
                "ms_before": ms_before,
                "ms_after": ms_after,
                "sampling_frequency": float(recording.sampling_frequency),
            }
            with open(model_folder / "params.json", "w") as f:
                json.dump(model_params, f)

        # save the features
        features_folder = folder / "features"
        gather_kwargs = dict(exist_ok=True)

    node0 = PeakRetriever(recording, peaks)

    if motion_aware:
        # we need to increase the radius by the max motion
        max_motion = max(abs(e) for e in motion.get_boundaries())
        radius_um = radius_um + max_motion

    node1 = ExtractSparseWaveforms(
        recording,
        parents=[node0],
        return_output=False,
        ms_before=ms_before,
        ms_after=ms_after,
        radius_um=radius_um,
    )

    if motion_aware:
        node2 = MotionAwareTemporalPCAProjection(
            recording, parents=[node0, node1], return_output=True, pca_model=svd_model, motion=motion
        )
    else:
        node2 = TemporalPCAProjection(
            recording,
            parents=[node0, node1],
            return_output=True,
            pca_model=svd_model,
        )

    pipeline_nodes = [node0, node1, node2]

    peaks_svd = run_node_pipeline(
        recording,
        pipeline_nodes,
        job_kwargs,
        gather_mode=gather_mode,
        gather_kwargs=gather_kwargs,
        folder=features_folder,
        names=["sparse_svd"],
        job_name="Transform peaks svd",
    )

    sparse_mask = node1.neighbours_mask

    return peaks_svd, sparse_mask, svd_model
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/motion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .motion_estimation import estimate_motion
from .motion_interpolation import (
compute_peak_displacements,
correct_motion_on_peaks,
interpolate_motion_on_traces,
InterpolateMotionRecording,
Expand Down
56 changes: 45 additions & 11 deletions src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,49 @@
from spikeinterface.preprocessing.filter import fix_dtype


def compute_peak_displacements(peaks, motion, recording, peak_locations=None):
    """
    Evaluate the motion at the time and depth of every peak.

    Parameters
    ----------
    peaks : np.array
        The peaks vector. The fields "segment_index", "sample_index" and
        (when `peak_locations` is None) "channel_index" are read.
    motion : Motion
        The motion object.
    recording : Recording
        The recording object. Used to convert sample indices to times and
        to get the channel locations.
    peak_locations : np.array | None
        Optional peak location vector. When None, the location of the
        channel given by "channel_index" is used as the peak depth.

    Returns
    -------
    peak_displacements : np.array
        One float32 displacement per peak, as returned by
        `motion.get_displacement_at_time_and_depth()`.
    """
    if recording is None:
        raise ValueError("compute_peak_displacements need recording to be not None")

    use_channel_depth = peak_locations is None
    contact_positions = recording.get_channel_locations()
    segment_of_peak = peaks["segment_index"]

    displacements = np.zeros(peaks.size, dtype="float32")

    for seg in range(motion.num_segments):
        # locate the contiguous run of peaks belonging to this segment
        start, stop = np.searchsorted(segment_of_peak, [seg, seg + 1])
        in_segment = slice(start, stop)

        times = recording.sample_index_to_time(peaks["sample_index"][in_segment], segment_index=seg)

        if use_channel_depth:
            depths = contact_positions[peaks["channel_index"][in_segment], motion.dim]
        else:
            depths = peak_locations[motion.direction][in_segment]

        displacements[in_segment] = motion.get_displacement_at_time_and_depth(times, depths, segment_index=seg)

    return displacements


def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray:
"""
Given the output of estimate_motion(), apply inverse motion on peak locations.
Expand All @@ -31,18 +74,9 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndar
if recording is None:
raise ValueError("correct_motion_on_peaks need recording to be not None")

peak_displacements = compute_peak_displacements(peaks, motion, recording, peak_locations=peak_locations)
corrected_peak_locations = peak_locations.copy()

for segment_index in range(motion.num_segments):
times_s = recording.sample_index_to_time(peaks["sample_index"], segment_index=segment_index)
i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1])

spike_times = times_s[i0:i1]
spike_locs = peak_locations[motion.direction][i0:i1]
spike_displacement = motion.get_displacement_at_time_and_depth(
spike_times, spike_locs, segment_index=segment_index
)
corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement
corrected_peak_locations[motion.direction] -= peak_displacements

return corrected_peak_locations

Expand Down
16 changes: 13 additions & 3 deletions src/spikeinterface/sortingcomponents/tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks, clustering_methods
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd

from spikeinterface.core import get_noise_levels

Expand Down Expand Up @@ -71,14 +72,23 @@ def test_find_cluster_from_peaks(clustering_method, recording, peaks, peak_locat
print(clustering_method, "found", len(labels), "clusters in ", t1 - t0)


def test_extract_peaks_svd(recording, peaks, job_kwargs):
    """Check the shape of the features returned by extract_peaks_svd with 5 components."""
    features, channel_mask, _model = extract_peaks_svd(recording, peaks, n_components=5, **job_kwargs)

    # one row of features per input peak
    assert features.shape[0] == peaks.shape[0]
    # second axis holds the SVD components
    assert features.shape[1] == 5
    # last axis is as wide as the largest sparse channel neighbourhood
    assert features.shape[2] == np.max(np.sum(channel_mask, axis=1))


if __name__ == "__main__":
job_kwargs = dict(n_jobs=1, chunk_size=10000, progress_bar=True)
recording, sorting = make_dataset()
peaks = run_peaks(recording, job_kwargs)
peak_locations = run_peak_locations(recording, peaks, job_kwargs)
# peak_locations = run_peak_locations(recording, peaks, job_kwargs)
# method = "position_and_pca"
# method = "circus"
# method = "tdc_clustering"
method = "random_projections"
# method = "random_projections"

# test_find_cluster_from_peaks(method, recording, peaks, peak_locations)

test_find_cluster_from_peaks(method, recording, peaks, peak_locations)
test_extract_peaks_svd(recording, peaks, job_kwargs)
3 changes: 2 additions & 1 deletion src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def make_multi_method_doc(methods, ident=" "):
return doc


def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs):
def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, job_name=None, **job_kwargs):
"""
Helper function to extract waveforms at the max channel from a peak list
Expand Down Expand Up @@ -63,6 +63,7 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j
sparsity_mask=sparsity_mask,
copy=True,
verbose=False,
job_name=job_name,
**job_kwargs,
)

Expand Down
Loading

0 comments on commit 64d253c

Please sign in to comment.