Skip to content

Commit

Permalink
Add v1s
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 8, 2023
1 parent 4d959e2 commit 20d753a
Show file tree
Hide file tree
Showing 9 changed files with 2,405 additions and 1,995 deletions.
315 changes: 187 additions & 128 deletions src/spyglass/common/common_interval.py

Large diffs are not rendered by default.

392 changes: 219 additions & 173 deletions src/spyglass/spikesorting/v1/artifact.py

Large diffs are not rendered by default.

1,302 changes: 312 additions & 990 deletions src/spyglass/spikesorting/v1/curation.py

Large diffs are not rendered by default.

223 changes: 131 additions & 92 deletions src/spyglass/spikesorting/v1/figurl_curation.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,169 @@
import datajoint as dj

from typing import Any, Union, List, Dict

from .spikesorting_curation import Curation
from .spikesorting_recording import SpikeSortingRecording
from .spikesorting_sorting import SpikeSorting
import datajoint as dj
import pynwb

import spikeinterface as si

from sortingview.SpikeSortingView import SpikeSortingView
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.spikesorting.v1.sorting import SpikeSorting
from spyglass.spikesorting.v1.curation import CurationV1, _merge_dict_to_list

import kachery_cloud as kcl
import sortingview.views as vv
from sortingview.SpikeSortingView import SpikeSortingView

schema = dj.schema("spikesorting_curation_figurl")

# A curation figURL is a link to a visualization of a curation.
# Optionally you can specify a new_curation_uri which will be
# the location of the new manually-edited curation. The
# new_curation_uri should be a github uri of the form
# gh://user/repo/branch/path/to/curation.json
# and ideally the path should be determined by the primary key
# of the curation. The new_curation_uri can also be blank if no
# further manual curation is planned.
schema = dj.schema("spikesorting_v1_figurl_curation")


@schema
class CurationFigurlSelection(dj.Manual):
class FigURLCurationSelection(dj.Manual):
definition = """
-> Curation
-> CurationV1
curation_uri: varchar(1000) # GitHub-based URI to a file to which the manual curation will be saved
---
new_curation_uri: varchar(2000)
metrics_figurl: blob # metrics to display in the figURL
"""

@staticmethod
def generate_curation_uri(key: Dict) -> str:
"""Generates a kachery-cloud URI containing curation info from a row in CurationV1 table
Parameters
----------
key : dict
primary key from CurationV1
"""
curation_key = (CurationV1 & key).fetch1()
analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
curation_key["analysis_file_name"]
)
with pynwb.NWBHDF5IO(
analysis_file_abs_path, "r", load_namespaces=True
) as io:
nwbfile = io.read()
nwb_sorting = nwbfile.objects[curation_key["object_id"]]
unit_ids = nwb_sorting["id"][:]
labels = nwb_sorting["labels"][:]
merge_groups = nwb_sorting["merge_groups"][:]

unit_ids = [str(unit_id) for unit_id in unit_ids]

if labels:
labels_dict = dict(zip(unit_ids, labels))
else:
labels_dict = {}

if merge_groups:
merge_groups_list = _merge_dict_to_list(merge_groups)
merge_groups_list = [
[str(unit_id) for unit_id in merge_group]
for merge_group in merge_groups_list
]
else:
merge_groups_list = []

curation_dict = {
"labelsByUnit": labels_dict,
"mergeGroups": merge_groups_list,
}
curation_uri = kcl.store_json(curation_dict)

return curation_uri


@schema
class CurationFigurl(dj.Computed):
class FigURLCuration(dj.Computed):
definition = """
-> CurationFigurlSelection
-> FigURLCurationSelection
---
url: varchar(2000)
initial_curation_uri: varchar(2000)
new_curation_uri: varchar(2000)
url: varchar(1000)
"""

def make(self, key: dict):
"""Create a Curation Figurl
Parameters
----------
key : dict
primary key of an entry from CurationFigurlSelection table
"""

# get new_curation_uri from selection table
new_curation_uri = (CurationFigurlSelection & key).fetch1(
"new_curation_uri"
# FETCH
sorting_analysis_file_name = (CurationV1 & key).fetch1(
"analysis_file_name"
)
object_id = (CurationV1 & key).fetch1("object_id")
recording_label = (SpikeSorting & key).fetch1("recording_id")

# fetch
recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
sorting_path = (SpikeSorting & key).fetch1("sorting_path")
recording_label = SpikeSortingRecording._get_recording_name(key)
sorting_label = SpikeSorting._get_sorting_name(key)
unit_metrics = _reformat_metrics(
(Curation & key).fetch1("quality_metrics")
# DO
sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
sorting_analysis_file_name
)
initial_labels = (Curation & key).fetch1("curation_labels")
initial_merge_groups = (Curation & key).fetch1("merge_groups")

# new_curation_uri = key["new_curation_uri"]

# Create the initial curation and store it in kachery
for k, v in initial_labels.items():
new_list = []
for item in v:
if item not in new_list:
new_list.append(item)
initial_labels[k] = new_list
initial_curation = {
"labelsByUnit": initial_labels,
"mergeGroups": initial_merge_groups,
}
initial_curation_uri = kcl.store_json(initial_curation)
recording = CurationV1.get_recording(key)
sorting = CurationV1.get_sorting(key)
sorting_label = key["sorting_id"]
curation_uri = key["curation_uri"]

metric_dict = {}
with pynwb.NWBHDF5IO(
sorting_analysis_file_abs_path, "r", load_namespaces=True
) as io:
nwbf = io.read()
nwb_sorting = nwbf.objects[object_id]
unit_ids = nwb_sorting["id"][:]
for metric in key["metrics_figurl"]:
metric_dict[metric] = dict(
zip(unit_ids, nwb_sorting[metric][:])
)

# Get the recording/sorting extractors
R = si.load_extractor(recording_path)
if R.get_num_segments() > 1:
R = si.concatenate_recordings([R])
S = si.load_extractor(sorting_path)
unit_metrics = _reformat_metrics(metric_dict)

# TODO: figure out a way to specify the similarity metrics

# Generate the figURL
url = _generate_the_figurl(
R=R,
S=S,
initial_curation_uri=initial_curation_uri,
new_curation_uri=new_curation_uri,
key["url"] = _generate_figurl(
R=recording,
S=sorting,
initial_curation_uri=curation_uri,
recording_label=recording_label,
sorting_label=sorting_label,
unit_metrics=unit_metrics,
)

# insert
key["url"] = url
key["initial_curation_uri"] = initial_curation_uri
key["new_curation_uri"] = new_curation_uri
self.insert1(key)
# INSERT
self.insert1(key, skip_duplicates=True)

@classmethod
def get_labels(cls):
return NotImplementedError

@classmethod
def get_merge_groups(cls):
return NotImplementedError

def _generate_the_figurl(
*,

def _generate_figurl(
R: si.BaseRecording,
S: si.BaseSorting,
unit_metrics: Union[List[Any], None] = None,
initial_curation_uri: str,
recording_label: str,
sorting_label: str,
new_curation_uri: str,
unit_metrics: Union[List[Any], None] = None,
segment_duration_sec=1200,
snippet_ms_before=1,
snippet_ms_after=1,
max_num_snippets_per_segment=1000,
channel_neighborhood_size=5,
raster_plot_subsample_max_firing_rate=50,
spike_amplitudes_subsample_max_firing_rate=50,
):
print("Preparing spikesortingview data")
sampling_frequency = R.get_sampling_frequency()
X = SpikeSortingView.create(
recording=R,
sorting=S,
segment_duration_sec=60 * 20,
snippet_len=(20, 20),
max_num_snippets_per_segment=100,
channel_neighborhood_size=7,
segment_duration_sec=segment_duration_sec,
snippet_len=(
int(snippet_ms_before * sampling_frequency / 1000),
int(snippet_ms_after * sampling_frequency / 1000),
),
max_num_snippets_per_segment=max_num_snippets_per_segment,
channel_neighborhood_size=channel_neighborhood_size,
)

# create a fake unit similarity matrix (for future reference)
# similarity_scores = []
# for u1 in X.unit_ids:
Expand All @@ -136,16 +175,20 @@ def _generate_the_figurl(
# similarity=similarity_matrix[(X.unit_ids==u1),(X.unit_ids==u2)]
# )
# )
# Create the similarity matrix view
# # Create the similarity matrix view
# unit_similarity_matrix_view = vv.UnitSimilarityMatrix(
# unit_ids=X.unit_ids,
# similarity_scores=similarity_scores
# )

# Assemble the views in a layout
# You can replace this with other layouts
raster_plot_subsample_max_firing_rate = 50
spike_amplitudes_subsample_max_firing_rate = 50
raster_plot_subsample_max_firing_rate = (
raster_plot_subsample_max_firing_rate
)
spike_amplitudes_subsample_max_firing_rate = (
spike_amplitudes_subsample_max_firing_rate
)
view = vv.MountainLayout(
items=[
vv.MountainLayoutItem(
Expand Down Expand Up @@ -195,14 +238,10 @@ def _generate_the_figurl(
),
]
)
url_state = (
{
"initialSortingCuration": initial_curation_uri,
"sortingCuration": new_curation_uri,
}
if new_curation_uri
else {"sortingCuration": initial_curation_uri}
)
url_state = {
"initialSortingCuration": initial_curation_uri,
"sortingCuration": initial_curation_uri,
}
label = f"{recording_label} {sorting_label}"
url = view.url(label=label, state=url_state)
return url
Expand Down
Loading

0 comments on commit 20d753a

Please sign in to comment.