Skip to content

Commit

Permalink
Add z-score threshold and raster
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Feb 7, 2024
1 parent 9fb27f2 commit d917bce
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 30 deletions.
84 changes: 54 additions & 30 deletions src/spyglass/mua/v1/mua.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,10 @@ def get_speed(key):
def create_figurl(
self,
zscore_mua=True,
plot_speed=True,
mua_times_color="red",
speed_color="black",
mua_color="black",
view_height=200,
view_height=800,
):
key = self.fetch1("KEY")

Expand All @@ -159,36 +158,61 @@ def create_figurl(
color=mua_color,
width=1,
)
if not plot_speed:
return multiunit_firing_rate_view
else:
speed_view = vv.TimeseriesGraph().add_line_series(
name="Speed [cm/s]",
t=np.asarray(time),
y=np.asarray(self.get_speed(key), dtype=np.float32),
color=speed_color,
if zscore_mua:
mua_params = (MuaEventsParameters & key).fetch1("mua_param_dict")
zscore_threshold = mua_params.get("zscore_threshold")
multiunit_firing_rate_view.add_line_series(
name="Z-Score Threshold",
t=np.asarray(time).squeeze(),
y=np.ones_like(
multiunit_firing_rate, dtype=np.float32
).squeeze()
* zscore_threshold,
color="red",
width=1,
dash=1,
)
vertical_panel_content = [
spike_times = SortedSpikesGroup.get_spike_times(key)
raster_view = vv.RasterPlot(
start_time_sec=time[0],
end_time_sec=time[-1],
plots=[
vv.RasterPlotItem(
unit_id=str(i),
spike_times_sec=np.asarray(times, dtype=np.float32),
)
for i, times in enumerate(spike_times)
],
)

speed_view = vv.TimeseriesGraph().add_line_series(
name="Speed [cm/s]",
t=np.asarray(time),
y=np.asarray(self.get_speed(key), dtype=np.float32),
color=speed_color,
width=1,
)
vertical_panel_content = [
vv.LayoutItem(
multiunit_firing_rate_view, stretch=2, title="Multiunit"
),
vv.LayoutItem(raster_view, stretch=8, title="Raster"),
vv.LayoutItem(speed_view, stretch=2, title="Speed"),
]

view = vv.Box(
direction="horizontal",
show_titles=True,
height=view_height,
items=[
vv.LayoutItem(
multiunit_firing_rate_view, stretch=1, title="Multiunit"
vv.Box(
direction="vertical",
show_titles=True,
items=vertical_panel_content,
)
),
vv.LayoutItem(speed_view, stretch=1, title="Speed"),
]

view = vv.Box(
direction="horizontal",
show_titles=True,
height=view_height,
items=[
vv.LayoutItem(
vv.Box(
direction="vertical",
show_titles=True,
items=vertical_panel_content,
)
),
],
)
],
)

return view.url(label="Multiunit Detection")
return view.url(label="Multiunit Detection")
20 changes: 20 additions & 0 deletions src/spyglass/ripple/v1/ripple.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,26 @@ def create_figurl(
color=consensus_color,
width=1,
)
if zscore_ripple:
ripple_params = (
RippleParameters
& {"ripple_param_name": key["ripple_param_name"]}
).fetch1("ripple_param_dict")

zscore_threshold = ripple_params["ripple_detection_params"].get(
"zscore_threshold"
)
consensus_view.add_line_series(
name="Z-Score Threshold",
t=np.asarray(ripple_consensus_trace.index).squeeze(),
y=np.ones_like(
ripple_consensus_trace, dtype=np.float32
).squeeze()
* zscore_threshold,
color="red",
width=1,
dash=1,
)

if use_ripple_filtered_lfps:
interval_ripple_lfps = ripple_filtered_lfps
Expand Down

0 comments on commit d917bce

Please sign in to comment.