From a9a4fcbc3d6da7f18a7a81d4dae1c52d0199bd55 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Sun, 7 May 2023 19:56:19 -0500 Subject: [PATCH 01/12] Add sample_rate property to BaseDataSource --- ephyviewer/datasource/sourcebase.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ephyviewer/datasource/sourcebase.py b/ephyviewer/datasource/sourcebase.py index 5bcc885..8daddfc 100644 --- a/ephyviewer/datasource/sourcebase.py +++ b/ephyviewer/datasource/sourcebase.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) - +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) class BaseDataSource: @@ -9,15 +8,19 @@ def __init__(self): @property def nb_channel(self): - raise(NotImplementedError) + raise (NotImplementedError) def get_channel_name(self, chan=0): - raise(NotImplementedError) + raise (NotImplementedError) @property def t_start(self): - raise(NotImplementedError) + raise (NotImplementedError) @property def t_stop(self): - raise(NotImplementedError) + raise (NotImplementedError) + + @property + def sample_rate(self): + raise (NotImplementedError) From aecd9b975dfb238e048f0834e47188902e0f87c4 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 20:23:41 -0500 Subject: [PATCH 02/12] Add basic black and white TraceImageViewer --- ephyviewer/__init__.py | 8 +- ephyviewer/traceimageviewer.py | 387 +++++++++++++++++++++++++++++++++ 2 files changed, 391 insertions(+), 4 deletions(-) create mode 100644 ephyviewer/traceimageviewer.py diff --git a/ephyviewer/__init__.py b/ephyviewer/__init__.py index 06b37eb..c31b575 100644 --- a/ephyviewer/__init__.py +++ b/ephyviewer/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from .version import version as __version__ -#common tools +# common tools from .myqt import * from .icons import * from .datasource import * @@ -9,8 +9,7 @@ from .navigation import NavigationToolBar - -#Viewers +# Viewers from .traceviewer import TraceViewer from .videoviewer import VideoViewer from .eventlist import EventList @@ -19,7 +18,8 @@ from .spectrogramviewer import SpectrogramViewer from .dataframeview import DataFrameView from .spiketrainviewer import SpikeTrainViewer +from .traceimageviewer import TraceImageViewer -#Encoders +# Encoders from .epochencoder import EpochEncoder diff --git a/ephyviewer/traceimageviewer.py b/ephyviewer/traceimageviewer.py new file mode 100644 index 0000000..135b453 --- /dev/null +++ b/ephyviewer/traceimageviewer.py @@ -0,0 +1,387 @@ +import numpy as np + +from .myqt import QT +import pyqtgraph as pg + +from .base import BaseMultiChannelViewer, Base_MultiChannel_ParamController +from .datasource import ( + AnalogSignalFromNeoRawIOSource, +) + +default_params = [ + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + {"name": "label_fill_color", "type": "color", "value": "#222222DD"}, + {"name": "label_size", "type": "int", "value": 8, "limits": (1, np.inf)}, + {"name": "display_labels", "type": "bool", "value": False}, + { + "name": "decimation_method", + "type": "list", + "value": "min_max", + "limits": [ + "min_max", + "mean", + "pure_decimate", + ], + }, +] + +default_by_channel_params = [ + {"name": "color", "type": "color", "value": "#55FF00"}, + {"name": "gain", "type": "float", "value": 1, "step": 0.1, "decimals": 8}, + {"name": "visible", "type": "bool", "value": True}, +] + + +class TraceImageViewer_ParamController(Base_MultiChannel_ParamController): + def __init__(self, parent=None, viewer=None): + Base_MultiChannel_ParamController.__init__( + self, parent=parent, viewer=viewer, with_visible=True, with_color=False + ) + + # raw_gains and raw_offsets are distinguished from adjustable gains and + # offsets associated with this viewer because it makes placement of the + # baselines and labels very easy for both raw and in-memory sources + if isinstance(self.viewer.source, AnalogSignalFromNeoRawIOSource): + # use raw_gains and raw_offsets from the raw source + self.raw_gains = self.viewer.source.get_gains() + else: + # use 1 and 0 for in-memory sources, which have already been scaled + # properly + self.raw_gains = np.ones(self.viewer.source.nb_channel) + + @property + def selected(self): # Is this ever used? Safe to remove? + selected = np.ones(self.viewer.source.nb_channel, dtype=bool) + if self.viewer.source.nb_channel > 1: + selected[:] = False + selected[[ind.row() for ind in self.qlist.selectedIndexes()]] = True + return selected + + @property + def visible_channels(self): + visible = [ + self.viewer.by_channel_params["ch{}".format(i), "visible"] + for i in range(self.source.nb_channel) + ] + return np.array(visible, dtype="bool") + + @property + def gains(self): + gains = [ + self.viewer.by_channel_params["ch{}".format(i), "gain"] + for i in range(self.source.nb_channel) + ] + return np.array(gains) + + @gains.setter + def gains(self, val): + for c, v in enumerate(val): + self.viewer.by_channel_params["ch{}".format(c), "gain"] = v + + @property + def total_gains(self): + # compute_rescale sets adjustable gains and offsets such that + # data_curves = (chunk * raw_gains + raw_offsets) * gains + offsets + # = chunk * (raw_gains * gains) + (raw_offsets * gains + offsets) + # = chunk * total_gains + total_offsets + return self.raw_gains * self.gains + + def on_channel_visibility_changed(self): + pass + self.viewer.refresh() + + def on_but_ygain_zoom(self): + factor = self.sender().factor + self.apply_ygain_zoom(factor) + + def apply_ygain_zoom(self, factor_ratio): + self.viewer.all_params.blockSignals(True) + self.gains = self.gains * factor_ratio + self.viewer.all_params.blockSignals(False) + self.viewer.refresh() + + +class DataGrabber(QT.QObject): + data_ready = QT.pyqtSignal(float, float, float, object, object) + + def __init__(self, source, viewer, parent=None): + QT.QObject.__init__(self, parent) + self.source = source + self.viewer = viewer + self._max_point = 3000 + + def get_data( + self, + t, + t_start, + t_stop, + total_gains, + visibles, + decimation_method, + ): + i_start, i_stop = ( + self.source.time_to_index(t_start), + self.source.time_to_index(t_stop) + 2, + ) + + ds_ratio = (i_stop - i_start) // self._max_point + 1 + + if ds_ratio > 1: + i_start = i_start - (i_start % ds_ratio) + i_stop = i_stop - (i_stop % ds_ratio) + + # clip it + i_start = max(0, i_start) + i_start = min(i_start, self.source.get_length()) + i_stop = max(0, i_stop) + i_stop = min(i_stop, self.source.get_length()) + if ds_ratio > 1: + # after clip + i_start = i_start - (i_start % ds_ratio) + i_stop = i_stop - (i_stop % ds_ratio) + + sigs_chunk = self.source.get_chunk(i_start=i_start, i_stop=i_stop) + + data_curves = sigs_chunk[:, visibles].T.copy() + if data_curves.dtype != "float32": + data_curves = data_curves.astype("float32") + + if ds_ratio > 1: + small_size = data_curves.shape[1] // ds_ratio + if decimation_method == "min_max": + small_size *= 2 + + small_arr = np.empty( + (data_curves.shape[0], small_size), dtype=data_curves.dtype + ) + + if decimation_method == "min_max" and data_curves.size > 0: + full_arr = data_curves.reshape(data_curves.shape[0], -1, ds_ratio) + small_arr[:, ::2] = full_arr.max(axis=2) + small_arr[:, 1::2] = full_arr.min(axis=2) + elif decimation_method == "mean" and data_curves.size > 0: + full_arr = data_curves.reshape(data_curves.shape[0], -1, ds_ratio) + small_arr[:, :] = full_arr.mean(axis=2) + elif decimation_method == "pure_decimate": + small_arr[:, :] = data_curves[:, ::ds_ratio] + elif data_curves.size == 0: + pass + + data_curves = small_arr + + data_curves *= total_gains[visibles, None] + + return ( + t, + t_start, + t_stop, + visibles, + data_curves, + ) + + def on_request_data( + self, + t, + t_start, + t_stop, + total_gains, + visibles, + decimation_method, + ): + if self.viewer.t != t: + return + + ( + t, + t_start, + t_stop, + visibles, + data_curves, + ) = self.get_data(t, t_start, t_stop, total_gains, visibles, decimation_method) + + self.data_ready.emit( + t, + t_start, + t_stop, + visibles, + data_curves, + ) + + +class TraceImageLabelItem(pg.TextItem): + label_dragged = QT.pyqtSignal(float) + label_ygain_zoom = QT.pyqtSignal(float) + + def __init__(self, **kwargs): + pg.TextItem.__init__(self, **kwargs) + + self.dragOffset = None + + def mouseDragEvent(self, ev): + """Emit the new y-coord of the label as it is dragged""" + + if ev.button() != QT.LeftButton: + ev.ignore() + return + else: + ev.accept() + + if ev.isStart(): + # To avoid snapping the label to the mouse cursor when the drag + # starts, we determine the offset of the position where the button + # was first pressed down relative to the label's origin/anchor, in + # plot coordinates + self.dragOffset = self.mapToParent(ev.buttonDownPos()) - self.pos() + + # The new y-coord for the label is the mouse's current position during + # the drag with the initial offset removed + new_y = (self.mapToParent(ev.pos()) - self.dragOffset).y() + self.label_dragged.emit(new_y) + + def wheelEvent(self, ev): + """Emit a yzoom factor for the associated trace""" + if ev.modifiers() == QT.Qt.ControlModifier: + z = 5.0 if ev.delta() > 0 else 1 / 5.0 + else: + z = 1.1 if ev.delta() > 0 else 1 / 1.1 + self.label_ygain_zoom.emit(z) + ev.accept() + + +class TraceImageViewer(BaseMultiChannelViewer): + _default_params = default_params + _default_by_channel_params = default_by_channel_params + + _ControllerClass = TraceImageViewer_ParamController + + request_data = QT.pyqtSignal(float, float, float, object, object, object) + + def __init__(self, useOpenGL=None, **kargs): + BaseMultiChannelViewer.__init__(self, **kargs) + + self.make_params() + + # Is there any advantage to using OpenGL for this viewer? + self.set_layout(useOpenGL=useOpenGL) + + self.make_param_controller() + + self.viewBox.doubleclicked.connect(self.show_params_controller) + + self.initialize_plot() + + self.thread = QT.QThread(parent=self) + self.datagrabber = DataGrabber(source=self.source, viewer=self) + self.datagrabber.moveToThread(self.thread) + self.thread.start() + + self.datagrabber.data_ready.connect(self.on_data_ready) + self.request_data.connect(self.datagrabber.on_request_data) + + self.params.param("xsize").setLimits((0, np.inf)) + + def closeEvent(self, event): + event.accept() + self.thread.quit() + self.thread.wait() + + def initialize_plot(self): + self.vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + self.vline.setZValue(1) # ensure vline is above plot elements + self.plot.addItem(self.vline) + + self.image = pg.ImageItem() + self.plot.addItem(self.image) + + self.channel_labels = [] + for c in range(self.source.nb_channel): + color = self.by_channel_params["ch{}".format(c), "color"] + ch_name = "{}: {}".format(c, self.source.get_channel_name(chan=c)) + label = TraceImageLabelItem( + text=ch_name, + color=color, + anchor=(0, 0.5), + border=None, + fill=self.params["label_fill_color"], + ) + label.setZValue(2) # ensure labels are drawn above scatter + font = label.textItem.font() + font.setPointSize(self.params["label_size"]) + label.setFont(font) + + self.plot.addItem(label) + self.channel_labels.append(label) + + self.viewBox.xsize_zoom.connect(self.params_controller.apply_xsize_zoom) + self.viewBox.ygain_zoom.connect(self.params_controller.apply_ygain_zoom) + + def on_param_change(self, params=None, changes=None): + for param, change, data in changes: + if change != "value": + continue + if param.name() == "vline_color": + self.vline.setPen(self.params["vline_color"]) + if param.name() == "label_fill_color": + for label in self.channel_labels: + label.fill = pg.mkBrush(self.params["label_fill_color"]) + if param.name() == "label_size": + for label in self.channel_labels: + font = label.textItem.font() + font.setPointSize(self.params["label_size"]) + label.setFont(font) + + self.refresh() + + def refresh(self): + # ~ print('TraceViewer.refresh', 't', self.t) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + (visibles,) = np.nonzero(self.params_controller.visible_channels) + total_gains = self.params_controller.total_gains + + self.request_data.emit( + self.t, + t_start, + t_stop, + total_gains, + visibles, + self.params["decimation_method"], + ) + + def on_data_ready( + self, + t, + t_start, + t_stop, + visibles, + data_curves, + ): + if self.t != t: # Under what circumstances can this happen? + return + + self.image.show() # Why does this happen before the data are set? + self.image.setImage(data_curves.T) # Need to set lut and clims? + n_ch = data_curves.shape[0] + self.image.setRect(QT.QRectF(t_start, 0, t_stop - t_start, n_ch)) + + for i, c in enumerate(visibles): + color = self.by_channel_params["ch{}".format(c), "color"] + if self.params["display_labels"]: + self.channel_labels[c].show() + self.channel_labels[c].setPos(t_start, c) + self.channel_labels[c].setColor(color) + else: + self.channel_labels[c].hide() + + for c in range(self.source.nb_channel): + if c not in visibles: + self.channel_labels[c].hide() + + self.vline.setPos(self.t) + self.plot.setXRange(t_start, t_stop, padding=0.0) + self.plot.setYRange(0, n_ch, padding=0.0) From d71deca274a8e8e676e35b2e240c4458ea1204f1 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 20:35:25 -0500 Subject: [PATCH 03/12] Add colormap options to TraceImageViewer --- ephyviewer/traceimageviewer.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/ephyviewer/traceimageviewer.py b/ephyviewer/traceimageviewer.py index 135b453..7f80871 100644 --- a/ephyviewer/traceimageviewer.py +++ b/ephyviewer/traceimageviewer.py @@ -1,12 +1,13 @@ +import matplotlib.cm +import matplotlib.colors import numpy as np - -from .myqt import QT import pyqtgraph as pg from .base import BaseMultiChannelViewer, Base_MultiChannel_ParamController from .datasource import ( AnalogSignalFromNeoRawIOSource, ) +from .myqt import QT default_params = [ {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, @@ -25,6 +26,18 @@ "pure_decimate", ], }, + { + "name": "colormap", + "type": "list", + "value": "viridis", + "limits": [ + "inferno", + "viridis", + "jet", + "gray", + "hot", + ], + }, ] default_by_channel_params = [ @@ -270,6 +283,7 @@ def __init__(self, useOpenGL=None, **kargs): self.viewBox.doubleclicked.connect(self.show_params_controller) + self.change_color_scale() self.initialize_plot() self.thread = QT.QThread(parent=self) @@ -287,6 +301,17 @@ def closeEvent(self, event): self.thread.quit() self.thread.wait() + def change_color_scale(self): + N = 512 + cmap_name = self.params["colormap"] + cmap = matplotlib.cm.get_cmap(cmap_name, N) + + lut = [] + for i in range(N): + r, g, b, _ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) + lut.append([r * 255, g * 255, b * 255]) + self.lut = np.array(lut, dtype="uint8") + def initialize_plot(self): self.vline = pg.InfiniteLine( angle=90, movable=False, pen=self.params["vline_color"] @@ -333,6 +358,8 @@ def on_param_change(self, params=None, changes=None): font = label.textItem.font() font.setPointSize(self.params["label_size"]) label.setFont(font) + if param.name() == "colormap": + self.change_color_scale() self.refresh() @@ -365,7 +392,7 @@ def on_data_ready( return self.image.show() # Why does this happen before the data are set? - self.image.setImage(data_curves.T) # Need to set lut and clims? + self.image.setImage(data_curves.T, lut=self.lut) # Ought to set clims? n_ch = data_curves.shape[0] self.image.setRect(QT.QRectF(t_start, 0, t_stop - t_start, n_ch)) From da6752982d4a87a5a2a53fd9f580d341fe6ae385 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 21:13:47 -0500 Subject: [PATCH 04/12] Move sample_rate property from BaseDataSource to BaseAnalogSignalSource, and define where needed. --- ephyviewer/datasource/epochs.py | 326 +++++++++++------- ephyviewer/datasource/neosource.py | 275 ++++++++------- ephyviewer/datasource/signals.py | 60 ++-- ephyviewer/datasource/sourcebase.py | 4 - .../datasource/spikeinterfacesources.py | 47 +-- ephyviewer/datasource/spikes.py | 29 +- ephyviewer/datasource/video.py | 158 +++++---- 7 files changed, 513 insertions(+), 386 deletions(-) diff --git a/ephyviewer/datasource/epochs.py b/ephyviewer/datasource/epochs.py index 27e4da8..8effad5 100644 --- a/ephyviewer/datasource/epochs.py +++ b/ephyviewer/datasource/epochs.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import os import numpy as np @@ -8,46 +8,49 @@ try: import pandas as pd + HAVE_PANDAS = True except ImportError: HAVE_PANDAS = False -from .sourcebase import BaseDataSource from .events import BaseEventAndEpoch - class InMemoryEpochSource(BaseEventAndEpoch): - type = 'Epoch' + type = "Epoch" def __init__(self, all_epochs=[]): BaseEventAndEpoch.__init__(self, all=all_epochs) - s = [ np.max(e['time']+e['duration']) for e in self.all if len(e['time'])>0] - self._t_stop = max(s) if len(s)>0 else 0 + s = [np.max(e["time"] + e["duration"]) for e in self.all if len(e["time"]) > 0] + self._t_stop = max(s) if len(s) > 0 else 0 - def get_chunk(self, chan=0, i_start=None, i_stop=None): - ep_times = self.all[chan]['time'][i_start:i_stop] - ep_durations = self.all[chan]['duration'][i_start:i_stop] - ep_labels = self.all[chan]['label'][i_start:i_stop] + def get_chunk(self, chan=0, i_start=None, i_stop=None): + ep_times = self.all[chan]["time"][i_start:i_stop] + ep_durations = self.all[chan]["duration"][i_start:i_stop] + ep_labels = self.all[chan]["label"][i_start:i_stop] return ep_times, ep_durations, ep_labels - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): - ep_times = self.all[chan]['time'] - ep_durations = self.all[chan]['duration'] - ep_labels = self.all[chan]['label'] - - keep1 = (ep_times>=t_start) & (ep_times<=t_stop) # epochs that start inside range - keep2 = (ep_times+ep_durations>=t_start) & (ep_times+ep_durations<=t_stop) # epochs that end inside range - keep3 = (ep_times<=t_start) & (ep_times+ep_durations>=t_stop) # epochs that span the range + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + ep_times = self.all[chan]["time"] + ep_durations = self.all[chan]["duration"] + ep_labels = self.all[chan]["label"] + + keep1 = (ep_times >= t_start) & ( + ep_times <= t_stop + ) # epochs that start inside range + keep2 = (ep_times + ep_durations >= t_start) & ( + ep_times + ep_durations <= t_stop + ) # epochs that end inside range + keep3 = (ep_times <= t_start) & ( + ep_times + ep_durations >= t_stop + ) # epochs that span the range keep = keep1 | keep2 | keep3 return ep_times[keep], ep_durations[keep], ep_labels[keep] - - class WritableEpochSource(InMemoryEpochSource): """ Identique to EpochSource but onlye one channel that can be persisently saved. @@ -55,8 +58,15 @@ class WritableEpochSource(InMemoryEpochSource): epoch is dict { 'time':np.array, 'duration':np.array, 'label':np.array, 'name':' ''} """ - def __init__(self, epoch=None, possible_labels=[], color_labels=None, channel_name='', restrict_to_possible_labels=False): + def __init__( + self, + epoch=None, + possible_labels=[], + color_labels=None, + channel_name="", + restrict_to_possible_labels=False, + ): self.possible_labels = possible_labels self.channel_name = channel_name @@ -68,61 +78,70 @@ def __init__(self, epoch=None, possible_labels=[], color_labels=None, channel_na # assign each epoch a fixed, unique integer id self._next_id = 0 for chan in self.all: - chan['id'] = np.arange(self._next_id, self._next_id + len(chan['time'])) - self._next_id += len(chan['time']) + chan["id"] = np.arange(self._next_id, self._next_id + len(chan["time"])) + self._next_id += len(chan["time"]) - assert self.all[0]['time'].dtype.kind=='f' - assert self.all[0]['duration'].dtype.kind=='f' + assert self.all[0]["time"].dtype.kind == "f" + assert self.all[0]["duration"].dtype.kind == "f" # add labels missing from possible_labels but found in epoch data - new_labels_from_data = list(set(epoch['label'])-set(self.possible_labels)) + new_labels_from_data = list(set(epoch["label"]) - set(self.possible_labels)) if restrict_to_possible_labels: - assert len(new_labels_from_data)==0, f'epoch data contains labels not found in possible_labels: {new_labels_from_data}' + assert ( + len(new_labels_from_data) == 0 + ), f"epoch data contains labels not found in possible_labels: {new_labels_from_data}" self.possible_labels += new_labels_from_data # put the epochs into a canonical order after loading - self._clean_and_set(self.all[0]['time'], self.all[0]['duration'], self.all[0]['label'], self.all[0]['id']) + self._clean_and_set( + self.all[0]["time"], + self.all[0]["duration"], + self.all[0]["label"], + self.all[0]["id"], + ) # TODO: colors should be managed directly by EpochEncoder if color_labels is None: n = len(self.possible_labels) - cmap = matplotlib.cm.get_cmap('Dark2' , n) - color_labels = [ matplotlib.colors.ColorConverter().to_rgb(cmap(i)) for i in range(n)] - color_labels = (np.array(color_labels)*255).astype(int) + cmap = matplotlib.cm.get_cmap("Dark2", n) + color_labels = [ + matplotlib.colors.ColorConverter().to_rgb(cmap(i)) for i in range(n) + ] + color_labels = (np.array(color_labels) * 255).astype(int) color_labels = color_labels.tolist() self.color_labels = color_labels @property def ep_times(self): - return self.all[0]['time'] + return self.all[0]["time"] @ep_times.setter def ep_times(self, arr): - self.all[0]['time'] = arr + self.all[0]["time"] = arr @property def ep_durations(self): - return self.all[0]['duration'] + return self.all[0]["duration"] @ep_durations.setter def ep_durations(self, arr): - self.all[0]['duration'] = arr + self.all[0]["duration"] = arr @property def ep_labels(self): - return self.all[0]['label'] + return self.all[0]["label"] @ep_labels.setter def ep_labels(self, arr): - self.all[0]['label'] = arr + self.all[0]["label"] = arr @property def ep_ids(self): - return self.all[0]['id'] + return self.all[0]["id"] @ep_ids.setter def ep_ids(self, arr): - self.all[0]['id'] = arr + self.all[0]["id"] = arr @property def ep_stops(self): @@ -130,27 +149,32 @@ def ep_stops(self): @property def id_to_ind(self): - return dict((id,ind) for ind,id in enumerate(self.ep_ids)) - - def get_chunk(self, chan=0, i_start=None, i_stop=None): - assert chan==0 - ep_times = self.all[chan]['time'][i_start:i_stop] - ep_durations = self.all[chan]['duration'][i_start:i_stop] - ep_labels = self.all[chan]['label'][i_start:i_stop] - ep_ids = self.all[chan]['id'][i_start:i_stop] + return dict((id, ind) for ind, id in enumerate(self.ep_ids)) + + def get_chunk(self, chan=0, i_start=None, i_stop=None): + assert chan == 0 + ep_times = self.all[chan]["time"][i_start:i_stop] + ep_durations = self.all[chan]["duration"][i_start:i_stop] + ep_labels = self.all[chan]["label"][i_start:i_stop] + ep_ids = self.all[chan]["id"][i_start:i_stop] return ep_times, ep_durations, ep_labels, ep_ids - - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): - assert chan==0 - ep_times = self.all[chan]['time'] - ep_durations = self.all[chan]['duration'] - ep_labels = self.all[chan]['label'] - ep_ids = self.all[chan]['id'] - - keep1 = (ep_times>=t_start) & (ep_times<=t_stop) # epochs that start inside range - keep2 = (ep_times+ep_durations>=t_start) & (ep_times+ep_durations<=t_stop) # epochs that end inside range - keep3 = (ep_times<=t_start) & (ep_times+ep_durations>=t_stop) # epochs that span the range + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + assert chan == 0 + ep_times = self.all[chan]["time"] + ep_durations = self.all[chan]["duration"] + ep_labels = self.all[chan]["label"] + ep_ids = self.all[chan]["id"] + + keep1 = (ep_times >= t_start) & ( + ep_times <= t_stop + ) # epochs that start inside range + keep2 = (ep_times + ep_durations >= t_start) & ( + ep_times + ep_durations <= t_stop + ) # epochs that end inside range + keep3 = (ep_times <= t_start) & ( + ep_times + ep_durations >= t_stop + ) # epochs that span the range keep = keep1 | keep2 | keep3 return ep_times[keep], ep_durations[keep], ep_labels[keep], ep_ids[keep] @@ -163,9 +187,10 @@ def color_by_label(self, label): return self.label_to_color[label] def _clean_and_set(self, ep_times, ep_durations, ep_labels, ep_ids): - # remove bad epochs - keep = ep_durations >= 1e-6 # discard epochs shorter than 1 microsecond or with negative duration + keep = ( + ep_durations >= 1e-6 + ) # discard epochs shorter than 1 microsecond or with negative duration ep_times = ep_times[keep] ep_durations = ep_durations[keep] ep_labels = ep_labels[keep] @@ -185,8 +210,12 @@ def _clean_and_set(self, ep_times, ep_durations, ep_labels, ep_ids): self.ep_ids = ep_ids def add_epoch(self, t1, duration, label): - - ep_times, ep_durations, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_labels, self.ep_ids + ep_times, ep_durations, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_labels, + self.ep_ids, + ) ep_times = np.append(ep_times, t1) ep_durations = np.append(ep_durations, duration) ep_labels = np.append(ep_labels, label) @@ -196,8 +225,12 @@ def add_epoch(self, t1, duration, label): self._clean_and_set(ep_times, ep_durations, ep_labels, ep_ids) def delete_epoch(self, ind): - - ep_times, ep_durations, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_labels, self.ep_ids + ep_times, ep_durations, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_labels, + self.ep_ids, + ) ep_times = np.delete(ep_times, ind) ep_durations = np.delete(ep_durations, ind) ep_labels = np.delete(ep_labels, ind) @@ -206,30 +239,36 @@ def delete_epoch(self, ind): self._clean_and_set(ep_times, ep_durations, ep_labels, ep_ids) def delete_in_between(self, t1, t2): - - ep_times, ep_durations, ep_stops, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_stops, self.ep_labels, self.ep_ids + ep_times, ep_durations, ep_stops, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_stops, + self.ep_labels, + self.ep_ids, + ) for i in range(len(ep_times)): - # if epoch starts and ends inside range, delete it - if ep_times[i]>=t1 and ep_stops[i]<=t2: - ep_durations[i] = -1 # non-positive duration flags this epoch for clean up + if ep_times[i] >= t1 and ep_stops[i] <= t2: + ep_durations[ + i + ] = -1 # non-positive duration flags this epoch for clean up # if epoch starts before and ends inside range, truncate it - elif ep_times[i]t2: + elif (t1 <= ep_times[i] < t2) and ep_stops[i] > t2: ep_durations[i] = ep_stops[i] - t2 ep_times[i] = t2 # if epoch starts before and ends after range, # truncate the first part and add a new epoch for the end part - elif ep_times[i]<=t1 and ep_stops[i]>=t2: + elif ep_times[i] <= t1 and ep_stops[i] >= t2: ep_durations[i] = t1 - ep_times[i] ep_times = np.append(ep_times, t2) - ep_durations = np.append(ep_durations, ep_stops[i]-t2) + ep_durations = np.append(ep_durations, ep_stops[i] - t2) ep_labels = np.append(ep_labels, ep_labels[i]) ep_ids = np.append(ep_ids, self._next_id) self._next_id += 1 @@ -237,67 +276,86 @@ def delete_in_between(self, t1, t2): self._clean_and_set(ep_times, ep_durations, ep_labels, ep_ids) def merge_neighbors(self): - - ep_times, ep_durations, ep_stops, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_stops, self.ep_labels, self.ep_ids + ep_times, ep_durations, ep_stops, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_stops, + self.ep_labels, + self.ep_ids, + ) for label in self.possible_labels: - inds, = np.nonzero(ep_labels == label) - for i in range(len(inds)-1): - + (inds,) = np.nonzero(ep_labels == label) + for i in range(len(inds) - 1): # if two sequentially adjacent epochs with the same label # overlap or have less than 1 microsecond separation, merge them - if ep_times[inds[i+1]] - ep_stops[inds[i]] < 1e-6: - + if ep_times[inds[i + 1]] - ep_stops[inds[i]] < 1e-6: # stretch the second epoch to cover the range of both epochs - ep_times[inds[i+1]] = min(ep_times[inds[i]], ep_times[inds[i+1]]) - ep_stops[inds[i+1]] = max(ep_stops[inds[i]], ep_stops[inds[i+1]]) - ep_durations[inds[i+1]] = ep_stops[inds[i+1]] - ep_times[inds[i+1]] + ep_times[inds[i + 1]] = min( + ep_times[inds[i]], ep_times[inds[i + 1]] + ) + ep_stops[inds[i + 1]] = max( + ep_stops[inds[i]], ep_stops[inds[i + 1]] + ) + ep_durations[inds[i + 1]] = ( + ep_stops[inds[i + 1]] - ep_times[inds[i + 1]] + ) # delete the first epoch - ep_durations[inds[i]] = -1 # non-positive duration flags this epoch for clean up + ep_durations[ + inds[i] + ] = -1 # non-positive duration flags this epoch for clean up self._clean_and_set(ep_times, ep_durations, ep_labels, ep_ids) def split_epoch(self, ind, t_split): - - ep_times, ep_durations, ep_stops, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_stops, self.ep_labels, self.ep_ids + ep_times, ep_durations, ep_stops, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_stops, + self.ep_labels, + self.ep_ids, + ) if t_split <= ep_times[ind] or ep_stops[ind] <= t_split: return ep_durations[ind] = t_split - ep_times[ind] ep_times = np.append(ep_times, t_split) - ep_durations = np.append(ep_durations, ep_stops[ind]-t_split) + ep_durations = np.append(ep_durations, ep_stops[ind] - t_split) ep_labels = np.append(ep_labels, ep_labels[ind]) ep_ids = np.append(ep_ids, self._next_id) self._next_id += 1 self._clean_and_set(ep_times, ep_durations, ep_labels, ep_ids) - def fill_blank(self, method='from_left'): - - ep_times, ep_durations, ep_labels, ep_ids = self.ep_times, self.ep_durations, self.ep_labels, self.ep_ids + def fill_blank(self, method="from_left"): + ep_times, ep_durations, ep_labels, ep_ids = ( + self.ep_times, + self.ep_durations, + self.ep_labels, + self.ep_ids, + ) - mask = ((ep_times[:-1] + ep_durations[:-1])='0.6.0': + + if V(neo.__version__) >= "0.6.0": HAVE_NEO = True from neo.rawio.baserawio import BaseRawIO else: HAVE_NEO = False - #~ print('neo version is too old', neo.__version__) + # ~ print('neo version is too old', neo.__version__) except ImportError: HAVE_NEO = False @@ -30,22 +28,23 @@ from .epochs import InMemoryEpochSource - logger = logging.getLogger() -#~ print('HAVE_NEO', HAVE_NEO) - +# ~ print('HAVE_NEO', HAVE_NEO) ## neo.core stuff + class NeoAnalogSignalSource(InMemoryAnalogSignalSource): def __init__(self, neo_sig): signals = neo_sig.magnitude - sample_rate = float(neo_sig.sampling_rate.rescale('Hz').magnitude) - t_start = float(neo_sig.t_start.rescale('s').magnitude) + sample_rate = float(neo_sig.sampling_rate.rescale("Hz").magnitude) + t_start = float(neo_sig.t_start.rescale("s").magnitude) - InMemoryAnalogSignalSource.__init__(self, signals, sample_rate, t_start, channel_names=None) + InMemoryAnalogSignalSource.__init__( + self, signals, sample_rate, t_start, channel_names=None + ) class NeoSpikeTrainSource(InMemorySpikeSource): @@ -54,62 +53,68 @@ def __init__(self, neo_spiketrains=[]): for neo_spiketrain in neo_spiketrains: name = neo_spiketrain.name if name is None: - name = '' - all_spikes.append({'time' : neo_spiketrain.times.rescale('s').magnitude, - 'name' : name}) + name = "" + all_spikes.append( + {"time": neo_spiketrain.times.rescale("s").magnitude, "name": name} + ) InMemorySpikeSource.__init__(self, all_spikes=all_spikes) + class NeoEventSource(InMemoryEventSource): def __init__(self, neo_events=[]): all_events = [] for neo_event in neo_events: - all_events.append({ - 'name': neo_event.name, - 'time': neo_event.times.rescale('s').magnitude, - 'label': np.array(neo_event.labels), - }) - InMemoryEventSource.__init__(self, all_events = all_events) + all_events.append( + { + "name": neo_event.name, + "time": neo_event.times.rescale("s").magnitude, + "label": np.array(neo_event.labels), + } + ) + InMemoryEventSource.__init__(self, all_events=all_events) + class NeoEpochSource(InMemoryEpochSource): def __init__(self, neo_epochs=[]): all_epochs = [] for neo_epoch in neo_epochs: - all_epochs.append({ - 'name': neo_epoch.name, - 'time': neo_epoch.times.rescale('s').magnitude, - 'duration': neo_epoch.durations.rescale('s').magnitude, - 'label': np.array(neo_epoch.labels), - }) - epoch_source = InMemoryEpochSource.__init__(self, all_epochs = all_epochs) - - - + all_epochs.append( + { + "name": neo_epoch.name, + "time": neo_epoch.times.rescale("s").magnitude, + "duration": neo_epoch.durations.rescale("s").magnitude, + "label": np.array(neo_epoch.labels), + } + ) + epoch_source = InMemoryEpochSource.__init__(self, all_epochs=all_epochs) def get_sources_from_neo_segment(neo_seg): assert HAVE_NEO assert isinstance(neo_seg, neo.Segment) - sources = {'signal':[], 'epoch':[], 'spike':[],'event':[],} + sources = { + "signal": [], + "epoch": [], + "spike": [], + "event": [], + } for neo_sig in neo_seg.analogsignals: # normally neo signals are grouped by same sampling rate in one AnalogSignal # with shape (nb_channel, nb_sample) - sources['signal'].append(NeoAnalogSignalSource(neo_sig)) - - sources['spike'].append(NeoSpikeTrainSource(neo_seg.spiketrains)) - sources['event'].append(NeoEventSource(neo_seg.events)) - sources['epoch'].append(NeoEpochSource(neo_seg.epochs)) + sources["signal"].append(NeoAnalogSignalSource(neo_sig)) + sources["spike"].append(NeoSpikeTrainSource(neo_seg.spiketrains)) + sources["event"].append(NeoEventSource(neo_seg.events)) + sources["epoch"].append(NeoEpochSource(neo_seg.epochs)) return sources - - - ## neo.rawio stuff + class AnalogSignalFromNeoRawIOSource(BaseAnalogSignalSource): def __init__(self, neorawio, channel_indexes=None, stream_index=None): """ @@ -142,7 +147,7 @@ def __init__(self, neorawio, channel_indexes=None, stream_index=None): channel_indexes = slice(None) self.channel_indexes = channel_indexes - if V(neo.__version__)>='0.10.0': + if V(neo.__version__) >= "0.10.0": # Neo >= 0.10 # - versions 0.10+ index channels within a stream if stream_index is not None: @@ -150,32 +155,45 @@ def __init__(self, neorawio, channel_indexes=None, stream_index=None): elif self.neorawio.signal_streams_count() == 1: self.stream_index = 0 else: - raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided') - self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id'] - signal_channels = self.neorawio.header['signal_channels'] - mask = signal_channels['stream_id'] == self.stream_id + raise ValueError( + f"Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided" + ) + self.stream_id = self.neorawio.header["signal_streams"][self.stream_index][ + "id" + ] + signal_channels = self.neorawio.header["signal_channels"] + mask = signal_channels["stream_id"] == self.stream_id self.channels = signal_channels[mask][self.channel_indexes] else: # Neo < 0.10 # - versions 0.6-0.9 index channels globally (ignoring signal group) - assert stream_index is None, f'Neo version {neo.__version__} is installed, but only Neo>=0.10 uses stream_index' - self.channels = self.neorawio.header['signal_channels'][self.channel_indexes] - - if V(neo.__version__)>='0.10.0': + assert ( + stream_index is None + ), f"Neo version {neo.__version__} is installed, but only Neo>=0.10 uses stream_index" + self.channels = self.neorawio.header["signal_channels"][ + self.channel_indexes + ] + + if V(neo.__version__) >= "0.10.0": # Neo >= 0.10 # - versions 0.10+ use stream_index as an argument often, # but also require channel_indexes for get_chunk - self.signal_indexing_kwarg = {'stream_index': self.stream_index} - self.get_chunk_kwargs = {'stream_index': self.stream_index, 'channel_indexes': self.channel_indexes} + self.signal_indexing_kwarg = {"stream_index": self.stream_index} + self.get_chunk_kwargs = { + "stream_index": self.stream_index, + "channel_indexes": self.channel_indexes, + } else: # Neo < 0.10 # - versions 0.6-0.9 use channel_indexes as an argument often - self.signal_indexing_kwarg = {'channel_indexes': self.channel_indexes} - self.get_chunk_kwargs = {'channel_indexes': self.channel_indexes} + self.signal_indexing_kwarg = {"channel_indexes": self.channel_indexes} + self.get_chunk_kwargs = {"channel_indexes": self.channel_indexes} - self.sample_rate = self.neorawio.get_signal_sampling_rate(**self.signal_indexing_kwarg) + self.sample_rate = self.neorawio.get_signal_sampling_rate( + **self.signal_indexing_kwarg + ) - #TODO: something for multi segment + # TODO: something for multi segment self.block_index = 0 self.seg_index = 0 @@ -184,60 +202,65 @@ def nb_channel(self): return len(self.channels) def get_channel_name(self, chan=0): - return self.channels[chan]['name'] + return self.channels[chan]["name"] @property def t_start(self): - t_start = self.neorawio.get_signal_t_start(self.block_index, self.seg_index, - **self.signal_indexing_kwarg) + t_start = self.neorawio.get_signal_t_start( + self.block_index, self.seg_index, **self.signal_indexing_kwarg + ) return t_start @property def t_stop(self): - t_stop = self.t_start + self.get_length()/self.sample_rate + t_stop = self.t_start + self.get_length() / self.sample_rate return t_stop def get_length(self): - length = self.neorawio.get_signal_size(self.block_index, self.seg_index, - **self.signal_indexing_kwarg) + length = self.neorawio.get_signal_size( + self.block_index, self.seg_index, **self.signal_indexing_kwarg + ) return length def get_gains(self): - return self.channels['gain'] + return self.channels["gain"] def get_offsets(self): - return self.channels['offset'] + return self.channels["offset"] def get_shape(self): return (self.get_length(), self.nb_channel) def get_chunk(self, i_start=None, i_stop=None): - sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index, - i_start=i_start, i_stop=i_stop, **self.get_chunk_kwargs) + sigs = self.neorawio.get_analogsignal_chunk( + block_index=self.block_index, + seg_index=self.seg_index, + i_start=i_start, + i_stop=i_stop, + **self.get_chunk_kwargs, + ) return sigs - - class SpikeFromNeoRawIOSource(BaseSpikeSource): def __init__(self, neorawio, channel_indexes=None): - self.neorawio =neorawio + self.neorawio = neorawio if channel_indexes is None: channel_indexes = slice(None) self.channel_indexes = channel_indexes - if V(neo.__version__)>='0.10.0': + if V(neo.__version__) >= "0.10.0": # Neo >= 0.10 # - versions 0.10+ have spike_channels - self.channels = self.neorawio.header['spike_channels'][channel_indexes] - self.get_chunk_kwarg = 'spike_channel_index' + self.channels = self.neorawio.header["spike_channels"][channel_indexes] + self.get_chunk_kwarg = "spike_channel_index" else: # Neo < 0.10 # - versions 0.6-0.9 have unit_channels - self.channels = self.neorawio.header['unit_channels'][channel_indexes] - self.get_chunk_kwarg = 'unit_index' + self.channels = self.neorawio.header["unit_channels"][channel_indexes] + self.get_chunk_kwarg = "unit_index" - #TODO: something for multi segment + # TODO: something for multi segment self.block_index = 0 self.seg_index = 0 @@ -246,7 +269,7 @@ def nb_channel(self): return len(self.channels) def get_channel_name(self, chan=0): - return self.channels[chan]['name'] + return self.channels[chan]["name"] @property def t_start(self): @@ -258,29 +281,35 @@ def t_stop(self): t_stop = self.neorawio.segment_t_stop(self.block_index, self.seg_index) return t_stop - def get_chunk(self, chan=0, i_start=None, i_stop=None): - raise(NotImplementedError) + def get_chunk(self, chan=0, i_start=None, i_stop=None): + raise (NotImplementedError) - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): - spike_timestamp = self.neorawio.get_spike_timestamps(block_index=self.block_index, - seg_index=self.seg_index, **{self.get_chunk_kwarg: chan}, t_start=t_start, t_stop=t_stop) + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + spike_timestamp = self.neorawio.get_spike_timestamps( + block_index=self.block_index, + seg_index=self.seg_index, + **{self.get_chunk_kwarg: chan}, + t_start=t_start, + t_stop=t_stop, + ) - spike_times = self.neorawio.rescale_spike_timestamp(spike_timestamp, dtype='float64') + spike_times = self.neorawio.rescale_spike_timestamp( + spike_timestamp, dtype="float64" + ) return spike_times - class EpochFromNeoRawIOSource(BaseEventAndEpoch): def __init__(self, neorawio, channel_indexes=None): - self.neorawio =neorawio + self.neorawio = neorawio if channel_indexes is None: channel_indexes = slice(None) self.channel_indexes = channel_indexes - self.channels = self.neorawio.header['event_channels'][channel_indexes] + self.channels = self.neorawio.header["event_channels"][channel_indexes] - #TODO: something for multi segment + # TODO: something for multi segment self.block_index = 0 self.seg_index = 0 @@ -291,7 +320,7 @@ def nb_channel(self): return len(self.channels) def get_channel_name(self, chan=0): - return self.channels[chan]['name'] + return self.channels[chan]["name"] @property def t_start(self): @@ -303,10 +332,12 @@ def t_stop(self): t_stop = self.neorawio.segment_t_stop(self.block_index, self.seg_index) return t_stop - def get_chunk(self, chan=0, i_start=None, i_stop=None): - k = (self.block_index , self.seg_index, chan) + def get_chunk(self, chan=0, i_start=None, i_stop=None): + k = (self.block_index, self.seg_index, chan) if k not in self._cache_event: - self._cache_event[k] = self.get_chunk_by_time(chan=chan, t_start=None, t_stop=None) + self._cache_event[k] = self.get_chunk_by_time( + chan=chan, t_start=None, t_stop=None + ) ep_times, ep_durations, ep_labels = self._cache_event[k] @@ -317,74 +348,82 @@ def get_chunk(self, chan=0, i_start=None, i_stop=None): return ep_times, ep_durations, ep_labels - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + ep_timestamps, ep_durations, ep_labels = self.neorawio.get_event_timestamps( + block_index=self.block_index, + seg_index=self.seg_index, + event_channel_index=chan, + t_start=t_start, + t_stop=t_stop, + ) - ep_timestamps, ep_durations, ep_labels = self.neorawio.get_event_timestamps(block_index=self.block_index, - seg_index=self.seg_index, event_channel_index=chan, t_start=t_start, t_stop=t_stop) - - ep_times = self.neorawio.rescale_event_timestamp(ep_timestamps, dtype='float64') + ep_times = self.neorawio.rescale_event_timestamp(ep_timestamps, dtype="float64") if ep_durations is not None: - ep_durations = self.neorawio.rescale_epoch_duration(ep_durations, dtype='float64') + ep_durations = self.neorawio.rescale_epoch_duration( + ep_durations, dtype="float64" + ) else: ep_durations = np.zeros_like(ep_times) return ep_times, ep_durations, ep_labels - - - def get_sources_from_neo_rawio(neorawio): assert HAVE_NEO assert isinstance(neorawio, BaseRawIO) if neorawio.header is None: - logger.info('parse header') + logger.info("parse header") neorawio.parse_header() - sources = {'signal':[], 'epoch':[], 'spike':[]} + sources = {"signal": [], "epoch": [], "spike": []} - - if hasattr(neorawio, 'signal_streams_count'): + if hasattr(neorawio, "signal_streams_count"): # Neo >= 0.10.0 # - version 0.10 replaced signal groups with signal streams for stream_index in range(neorawio.signal_streams_count()): # one source per signal stream - sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index)) - elif hasattr(neorawio, 'get_group_signal_channel_indexes'): + sources["signal"].append( + AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index) + ) + elif hasattr(neorawio, "get_group_signal_channel_indexes"): # Neo >= 0.9.0 and < 0.10 # - version 0.9 renamed BaseRawIO.get_group_channel_indexes() to BaseRawIO.get_group_signal_channel_indexes() if neorawio.signal_channels_count() > 0: channel_indexes_list = neorawio.get_group_signal_channel_indexes() for channel_indexes in channel_indexes_list: # one source per channel group - sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes)) - elif hasattr(neorawio, 'get_group_channel_indexes'): + sources["signal"].append( + AnalogSignalFromNeoRawIOSource( + neorawio, channel_indexes=channel_indexes + ) + ) + elif hasattr(neorawio, "get_group_channel_indexes"): # Neo < 0.9.0 # - versions 0.6-0.8 have BaseRawIO.get_group_channel_indexes() if neorawio.signal_channels_count() > 0: channel_indexes_list = neorawio.get_group_channel_indexes() for channel_indexes in channel_indexes_list: # one source per channel group - sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes)) - + sources["signal"].append( + AnalogSignalFromNeoRawIOSource( + neorawio, channel_indexes=channel_indexes + ) + ) - if hasattr(neorawio, 'spike_channels_count'): + if hasattr(neorawio, "spike_channels_count"): # Neo >= 0.10 # - version 0.10 renamed BaseRawIO.unit_channels_count() to BaseRawIO.spike_channels_count() - if neorawio.spike_channels_count()>0: - sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None)) - elif hasattr(neorawio, 'unit_channels_count'): + if neorawio.spike_channels_count() > 0: + sources["spike"].append(SpikeFromNeoRawIOSource(neorawio, None)) + elif hasattr(neorawio, "unit_channels_count"): # Neo < 0.10 # - versions 0.6-0.9 have BaseRawIO.unit_channels_count() - if neorawio.unit_channels_count()>0: - sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None)) - - - if neorawio.event_channels_count()>0: - sources['epoch'].append(EpochFromNeoRawIOSource(neorawio, None)) - + if neorawio.unit_channels_count() > 0: + sources["spike"].append(SpikeFromNeoRawIOSource(neorawio, None)) + if neorawio.event_channels_count() > 0: + sources["epoch"].append(EpochFromNeoRawIOSource(neorawio, None)) return sources diff --git a/ephyviewer/datasource/signals.py b/ephyviewer/datasource/signals.py index 0540bf7..1908bb8 100644 --- a/ephyviewer/datasource/signals.py +++ b/ephyviewer/datasource/signals.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np @@ -12,25 +12,29 @@ class BaseAnalogSignalSource(BaseDataSource): - type = 'AnalogSignal' + type = "AnalogSignal" def __init__(self): self.with_scatter = False - + def get_length(self): - raise(NotImplementedError) + raise (NotImplementedError) def get_shape(self): return (self.get_length(), self.nb_channel) def get_chunk(self, i_start=None, i_stop=None): - raise(NotImplementedError) + raise (NotImplementedError) def time_to_index(self, t): - return int((t-self.t_start)*self.sample_rate) + return int((t - self.t_start) * self.sample_rate) def index_to_time(self, ind): - return float(ind/self.sample_rate) + self.t_start + return float(ind / self.sample_rate) + self.t_start + + @property + def sample_rate(self): + raise (NotImplementedError) class InMemoryAnalogSignalSource(BaseAnalogSignalSource): @@ -38,13 +42,14 @@ def __init__(self, signals, sample_rate, t_start, channel_names=None): BaseAnalogSignalSource.__init__(self) self.signals = signals - self.sample_rate = float(sample_rate) + self._sample_rate = float(sample_rate) self._t_start = float(t_start) - self._t_stop = self.signals.shape[0]/self.sample_rate + float(t_start) + self._t_stop = self.signals.shape[0] / self.sample_rate + float(t_start) self.channel_names = channel_names if channel_names is None: - self.channel_names = ['Channel {:3}'.format(c) for c in range(self.signals.shape[1])] - + self.channel_names = [ + "Channel {:3}".format(c) for c in range(self.signals.shape[1]) + ] @property def nb_channel(self): @@ -53,6 +58,10 @@ def nb_channel(self): def get_channel_name(self, chan=0): return self.channel_names[chan] + @property + def sample_rate(self): + return self._sample_rate + @property def t_start(self): return self._t_start @@ -68,14 +77,23 @@ def get_chunk(self, i_start=None, i_stop=None): return self.signals[i_start:i_stop, :] - - class AnalogSignalSourceWithScatter(InMemoryAnalogSignalSource): - def __init__(self, signals, sample_rate, t_start, scatter_indexes, scatter_channels, scatter_colors=None, channel_names=None): - InMemoryAnalogSignalSource.__init__(self, signals, sample_rate, t_start, channel_names=channel_names) + def __init__( + self, + signals, + sample_rate, + t_start, + scatter_indexes, + scatter_channels, + scatter_colors=None, + channel_names=None, + ): + InMemoryAnalogSignalSource.__init__( + self, signals, sample_rate, t_start, channel_names=channel_names + ) self.with_scatter = True - #todo test and assert self.scatter_indexes sorted for eack k + # todo test and assert self.scatter_indexes sorted for eack k self.scatter_indexes = scatter_indexes self.scatter_channels = scatter_channels self.scatter_colors = scatter_colors @@ -85,18 +103,18 @@ def __init__(self, signals, sample_rate, t_start, scatter_indexes, scatter_chann if self.scatter_colors is None: self.scatter_colors = {} n = len(self._labels) - colors = matplotlib.cm.get_cmap('Accent', n) - for i,k in enumerate(self._labels): + colors = matplotlib.cm.get_cmap("Accent", n) + for i, k in enumerate(self._labels): self.scatter_colors[k] = matplotlib.colors.to_hex(colors(i)) def get_scatter_babels(self): return self._labels - def get_scatter(self, i_start=None, i_stop=None, chan=None, label=None): + def get_scatter(self, i_start=None, i_stop=None, chan=None, label=None): if chan not in self.scatter_channels[label]: return None inds = self.scatter_indexes[label] - i1 = np.searchsorted(inds, i_start, side='left') - i2 = np.searchsorted(inds, i_stop, side='left') + i1 = np.searchsorted(inds, i_start, side="left") + i2 = np.searchsorted(inds, i_stop, side="left") return inds[i1:i2] diff --git a/ephyviewer/datasource/sourcebase.py b/ephyviewer/datasource/sourcebase.py index 8daddfc..62916ec 100644 --- a/ephyviewer/datasource/sourcebase.py +++ b/ephyviewer/datasource/sourcebase.py @@ -20,7 +20,3 @@ def t_start(self): @property def t_stop(self): raise (NotImplementedError) - - @property - def sample_rate(self): - raise (NotImplementedError) diff --git a/ephyviewer/datasource/spikeinterfacesources.py b/ephyviewer/datasource/spikeinterfacesources.py index d2327cc..7748dba 100644 --- a/ephyviewer/datasource/spikeinterfacesources.py +++ b/ephyviewer/datasource/spikeinterfacesources.py @@ -1,17 +1,13 @@ """ Data sources for SpikeInterface """ - -from .sourcebase import BaseDataSource -import sys -import logging - import numpy as np try: from distutils.version import LooseVersion as V import spikeinterface - if V(spikeinterface.__version__)>='0.90.1': + + if V(spikeinterface.__version__) >= "0.90.1": HAVE_SI = True else: HAVE_SI = False @@ -31,8 +27,12 @@ def __init__(self, recording, segment_index=0): self.segment_index = segment_index self._nb_channel = self.recording.get_num_channels() - self.sample_rate = self.recording.get_sampling_frequency() - self._t_start = 0. + self._sample_rate = self.recording.get_sampling_frequency() + self._t_start = 0.0 + + @property + def sample_rate(self): + return self._sample_rate @property def nb_channel(self): @@ -53,10 +53,12 @@ def get_length(self): return self.recording.get_num_samples(segment_index=self.segment_index) def get_shape(self): - return (self.get_length(),self.nb_channel) + return (self.get_length(), self.nb_channel) def get_chunk(self, i_start=None, i_stop=None): - traces = self.recording.get_traces(segment_index=self.segment_index, start_frame=i_start, end_frame=i_stop) + traces = self.recording.get_traces( + segment_index=self.segment_index, start_frame=i_start, end_frame=i_stop + ) return traces def time_to_index(self, t): @@ -66,7 +68,6 @@ def index_to_time(self, ind): return float(ind / self.sample_rate) - class SpikeInterfaceSortingSource(BaseSpikeSource): def __init__(self, sorting, segment_index=0): BaseSpikeSource.__init__(self) @@ -74,8 +75,8 @@ def __init__(self, sorting, segment_index=0): self.sorting = sorting self.segment_index = segment_index - #TODO - self._t_stop = 10. + # TODO + self._t_stop = 10.0 @property def nb_channel(self): @@ -86,23 +87,27 @@ def get_channel_name(self, chan=0): @property def t_start(self): - return 0. + return 0.0 @property def t_stop(self): return self._t_stop - def get_chunk(self, chan=0, i_start=None, i_stop=None): + def get_chunk(self, chan=0, i_start=None, i_stop=None): unit_id = self.sorting.unit_ids[chan] - spike_frames = self.sorting.get_unit_spike_train(unit_id, - segment_index=self.segment_index, start_frame=i_start, end_frame=i_stop) + spike_frames = self.sorting.get_unit_spike_train( + unit_id, + segment_index=self.segment_index, + start_frame=i_start, + end_frame=i_stop, + ) spike_frames = spike_frames[i_start:i_stop] spike_times = spike_frames / self.sorting.get_sampling_frequency() return spike_times - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): spike_times = self.get_chunk(chan=chan) - i1 = np.searchsorted(spike_times, t_start, side='left') - i2 = np.searchsorted(spike_times, t_stop, side='left') - sl = slice(i1, i2+1) + i1 = np.searchsorted(spike_times, t_start, side="left") + i2 = np.searchsorted(spike_times, t_stop, side="left") + sl = slice(i1, i2 + 1) return spike_times[sl] diff --git a/ephyviewer/datasource/spikes.py b/ephyviewer/datasource/spikes.py index fd7955c..9033e96 100644 --- a/ephyviewer/datasource/spikes.py +++ b/ephyviewer/datasource/spikes.py @@ -1,34 +1,29 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) - - +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np -from .sourcebase import BaseDataSource -from .events import BaseEventAndEpoch, InMemoryEventSource +from .events import BaseEventAndEpoch class BaseSpikeSource(BaseEventAndEpoch): - type = 'Spike' + type = "Spike" class InMemorySpikeSource(BaseSpikeSource): - def __init__(self, all_spikes=[]): BaseSpikeSource.__init__(self, all=all_spikes) - s = [ np.max(e['time']) for e in self.all if len(e['time'])>0] - self._t_stop = max(s) if len(s)>0 else 0 - + s = [np.max(e["time"]) for e in self.all if len(e["time"]) > 0] + self._t_stop = max(s) if len(s) > 0 else 0 - def get_chunk(self, chan=0, i_start=None, i_stop=None): - spike_times = self.all[chan]['time'][i_start:i_stop] + def get_chunk(self, chan=0, i_start=None, i_stop=None): + spike_times = self.all[chan]["time"][i_start:i_stop] return spike_times - def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): - spike_times = self.all[chan]['time'] + def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None): + spike_times = self.all[chan]["time"] - i1 = np.searchsorted(spike_times, t_start, side='left') - i2 = np.searchsorted(spike_times, t_stop, side='left') - sl = slice(i1, i2+1) + i1 = np.searchsorted(spike_times, t_start, side="left") + i2 = np.searchsorted(spike_times, t_stop, side="left") + sl = slice(i1, i2 + 1) return spike_times[sl] diff --git a/ephyviewer/datasource/video.py b/ephyviewer/datasource/video.py index 8784fa2..19898e4 100644 --- a/ephyviewer/datasource/video.py +++ b/ephyviewer/datasource/video.py @@ -1,14 +1,13 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) from .sourcebase import BaseDataSource -import sys - import numpy as np try: import av + HAVE_AV = True except ImportError: HAVE_AV = False @@ -16,30 +15,31 @@ AV_TIME_BASE = 1000000 + def pts_to_frame(pts, time_base, frame_rate, start_time): return int(pts * time_base * frame_rate) - int(start_time * time_base * frame_rate) -def get_frame_rate(stream): +def get_frame_rate(stream): if stream.average_rate.denominator and stream.average_rate.numerator: return float(stream.average_rate) if stream.time_base.denominator and stream.time_base.numerator: - return 1.0/float(stream.time_base) + return 1.0 / float(stream.time_base) else: raise ValueError("Unable to determine FPS") -#~ def get_frame_count(f, stream): - #~ if stream.frames: - #~ return stream.frames - #~ elif stream.duration: - #~ return pts_to_frame(stream.duration, float(stream.time_base), get_frame_rate(stream), 0) - #~ elif f.duration: - #~ return pts_to_frame(f.duration, 1/float(AV_TIME_BASE), get_frame_rate(stream), 0) +# ~ def get_frame_count(f, stream): - #~ else: - #~ raise ValueError("Unable to determine number for frames") +# ~ if stream.frames: +# ~ return stream.frames +# ~ elif stream.duration: +# ~ return pts_to_frame(stream.duration, float(stream.time_base), get_frame_rate(stream), 0) +# ~ elif f.duration: +# ~ return pts_to_frame(f.duration, 1/float(AV_TIME_BASE), get_frame_rate(stream), 0) +# ~ else: +# ~ raise ValueError("Unable to determine number for frames") class FrameGrabber: @@ -65,7 +65,6 @@ def __init__(self): self.last_frame_index = None def next_frame(self): - frame_index = None rate = self.rate @@ -74,43 +73,45 @@ def next_frame(self): self.pts_seen = False for packet in self.file.demux(self.stream): - #~ print(" pkt", packet.pts, packet.dts, packet) + # ~ print(" pkt", packet.pts, packet.dts, packet) if packet.pts: self.pts_seen = True for frame in packet.decode(): - #~ print(' frame', frame) + # ~ print(' frame', frame) if frame_index is None: - if self.pts_seen: pts = frame.pts else: pts = frame.dts - #~ print(' pts',pts) + # ~ print(' pts',pts) if not pts is None: - frame_index = pts_to_frame(pts, time_base, rate, self.start_time) + frame_index = pts_to_frame( + pts, time_base, rate, self.start_time + ) elif not frame_index is None: frame_index += 1 - yield frame_index, frame def get_frame(self, target_frame): - #~ print('get_frame', target_frame) + # ~ print('get_frame', target_frame) if target_frame == self.last_frame_index: frame = self.last_frame elif target_frame < 0 or target_frame >= self.nb_frames: frame = None - elif self.last_frame_index is None or \ - (target_frame < self.last_frame_index) or \ - (target_frame > self.last_frame_index + 300): + elif ( + self.last_frame_index is None + or (target_frame < self.last_frame_index) + or (target_frame > self.last_frame_index + 300) + ): frame = self.get_frame_absolut_seek(target_frame) else: frame = None for i, (frame_index, next_frame) in enumerate(self.next_frame()): - #~ print(" ", i, "NEXT at frame", next_frame, "at ts:", next_frame.pts,next_frame.dts) + # ~ print(" ", i, "NEXT at frame", next_frame, "at ts:", next_frame.pts,next_frame.dts) if frame_index is None or frame_index >= target_frame: frame = next_frame break @@ -121,22 +122,20 @@ def get_frame(self, target_frame): return frame - - def get_frame_absolut_seek(self, target_frame): - #~ print('get_frame_absolut_seek', target_frame) - #~ print('self.active_frame', self.active_frame) + # ~ print('get_frame_absolut_seek', target_frame) + # ~ print('self.active_frame', self.active_frame) - #~ if target_frame != self.active_frame: - #~ print('YEP') - #~ return - #~ print 'seeking to', target_frame + # ~ if target_frame != self.active_frame: + # ~ print('YEP') + # ~ return + # ~ print 'seeking to', target_frame seek_frame = target_frame rate = self.rate time_base = self.time_base - #~ print 'ici', rate, time_base, 'target_frame', target_frame + # ~ print 'ici', rate, time_base, 'target_frame', target_frame frame = None reseek = 250 @@ -144,12 +143,11 @@ def get_frame_absolut_seek(self, target_frame): original_target_frame_pts = None while reseek >= 0: - # convert seek_frame to pts - target_sec = seek_frame * 1/rate + target_sec = seek_frame * 1 / rate target_pts = int(target_sec / time_base) + self.start_time - #~ print 'la', 'target_sec', target_sec, 'target_pts', target_pts + # ~ print 'la', 'target_sec', target_sec, 'target_pts', target_pts if original_target_frame_pts is None: original_target_frame_pts = target_pts @@ -166,13 +164,12 @@ def get_frame_absolut_seek(self, target_frame): frame_cache = [] for i, (frame_index, frame) in enumerate(self.next_frame()): + # ~ # optimization if the time slider has changed, the requested frame no longer valid + # ~ if target_frame != self.active_frame: + # ~ print('YEP0 target_frame != self.active_frame', target_frame, self.active_frame) + # ~ return - #~ # optimization if the time slider has changed, the requested frame no longer valid - #~ if target_frame != self.active_frame: - #~ print('YEP0 target_frame != self.active_frame', target_frame, self.active_frame) - #~ return - - #~ print(" ", i, "at frame", frame_index, "at ts:", frame.pts,frame.dts,"target:", target_pts, 'orig', original_target_frame_pts) + # ~ print(" ", i, "at frame", frame_index, "at ts:", frame.pts,frame.dts,"target:", target_pts, 'orig', original_target_frame_pts) if frame_index is None: pass @@ -185,53 +182,58 @@ def get_frame_absolut_seek(self, target_frame): # Check if we over seeked, if we over seekd we need to seek to a earlier time # but still looking for the target frame if frame_index != target_frame: - if frame_index is None: - over_seek = '?' + over_seek = "?" else: over_seek = frame_index - target_frame if frame_index > target_frame: - - #~ print over_seek, frame_cache + # ~ print over_seek, frame_cache if over_seek <= len(frame_cache): - #~ print "over seeked by %i, using cache" % over_seek + # ~ print "over seeked by %i, using cache" % over_seek frame = frame_cache[-over_seek] break - seek_frame -= 1 reseek -= 1 - #~ print "over seeked by %s, backtracking.. seeking: %i target: %i retry: %i" % (str(over_seek), seek_frame, target_frame, reseek) + # ~ print "over seeked by %s, backtracking.. seeking: %i target: %i retry: %i" % (str(over_seek), seek_frame, target_frame, reseek) else: break - #~ print('ici frame', frame) + # ~ print('ici frame', frame) if reseek < 0: - #~ print('YEP reseek < 0') - #~ raise ValueError("seeking failed %i" % frame_index) + # ~ print('YEP reseek < 0') + # ~ raise ValueError("seeking failed %i" % frame_index) return None # frame at this point should be the correct frame if frame: - return frame else: return None - #~ raise ValueError("seeking failed %i" % target_frame) + # ~ raise ValueError("seeking failed %i" % target_frame) def get_frame_count(self): - frame_count = None if self.stream.frames: frame_count = self.stream.frames elif self.stream.duration: - frame_count = pts_to_frame(self.stream.duration, float(self.stream.time_base), get_frame_rate(self.stream), 0) + frame_count = pts_to_frame( + self.stream.duration, + float(self.stream.time_base), + get_frame_rate(self.stream), + 0, + ) elif self.file.duration: - frame_count = pts_to_frame(self.file.duration, 1/float(AV_TIME_BASE), get_frame_rate(self.stream), 0) + frame_count = pts_to_frame( + self.file.duration, + 1 / float(AV_TIME_BASE), + get_frame_rate(self.stream), + 0, + ) else: raise ValueError("Unable to determine number for frames") @@ -240,7 +242,7 @@ def get_frame_count(self): retry = 100 while retry: - target_sec = seek_frame * 1/ self.rate + target_sec = seek_frame * 1 / self.rate target_pts = int(target_sec / self.time_base) + self.start_time try: @@ -253,7 +255,7 @@ def get_frame_count(self): frame_index = None for frame_index, frame in self.next_frame(): - #~ print frame_index, frame + # ~ print frame_index, frame continue if not frame_index is None: @@ -262,23 +264,21 @@ def get_frame_count(self): seek_frame -= 1 retry -= 1 - - #~ print "frame count seeked", frame_index, "container frame count", frame_count + # ~ print "frame count seeked", frame_index, "container frame count", frame_count return frame_index or frame_count def set_file(self, path): - #~ print(path, type(path)) + # ~ print(path, type(path)) self.file = av.open(path) - #~ for s in self.file.streams: - #~ print(s.type) - #~ self.stream = next(s for s in self.file.streams if s.type == b'video') #py2 - self.stream = next(s for s in self.file.streams if s.type == 'video') + # ~ for s in self.file.streams: + # ~ print(s.type) + # ~ self.stream = next(s for s in self.file.streams if s.type == b'video') #py2 + self.stream = next(s for s in self.file.streams if s.type == "video") self.rate = get_frame_rate(self.stream) self.time_base = float(self.stream.time_base) - index, first_frame = next(self.next_frame()) try: @@ -298,18 +298,17 @@ def set_file(self, path): self.start_time = pts or first_frame.dts - #~ print("First pts", pts, self.stream.start_time, first_frame) + # ~ print("First pts", pts, self.stream.start_time, first_frame) - #self.nb_frames = get_frame_count(self.file, self.stream) + # self.nb_frames = get_frame_count(self.file, self.stream) self.nb_frames = self.get_frame_count() +class MultiVideoFileSource(BaseDataSource): + type = "video" - -class MultiVideoFileSource( BaseDataSource): - type = 'video' def __init__(self, video_filenames, video_times=None): - assert HAVE_AV, 'PyAv is not installed' + assert HAVE_AV, "PyAv is not installed" self.video_filenames = video_filenames self.video_times = video_times @@ -326,7 +325,7 @@ def __init__(self, video_filenames, video_times=None): self.nb_frames.append(fg.get_frame_count()) self.rates.append(fg.rate) self.t_starts.append(fg.start_time) - self.t_stops.append(fg.start_time+fg.stream.duration*fg.time_base) + self.t_stops.append(fg.start_time + fg.stream.duration * fg.time_base) self._t_start = min(self.t_starts) self._t_stop = max(self.t_stops) @@ -336,7 +335,7 @@ def nb_channel(self): return len(self.video_filenames) def get_channel_name(self, chan=0): - return 'video {}'.format(chan) + return "video {}".format(chan) @property def t_start(self): @@ -347,12 +346,11 @@ def t_stop(self): return self._t_stop def time_to_frame_index(self, i, t): - # if t is between frames, both methods return # the index of the frame *preceding* t if self.video_times is None: - frame_index = int((t-self.t_starts[i])*self.rates[i]) + frame_index = int((t - self.t_starts[i]) * self.rates[i]) else: - frame_index = np.searchsorted(self.video_times[i], t, side='right') - 1 + frame_index = np.searchsorted(self.video_times[i], t, side="right") - 1 return frame_index From e0d6d61639a13320a9599f7c332b94cab7c74d80 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 22:39:44 -0500 Subject: [PATCH 05/12] Update deprecated usage of matplotlib.cm.get_cmap --- ephyviewer/base.py | 193 +++++----- ephyviewer/datasource/epochs.py | 4 +- ephyviewer/datasource/signals.py | 4 +- ephyviewer/epochencoder.py | 533 +++++++++++++++----------- ephyviewer/epochviewer.py | 101 ++--- ephyviewer/spectrogramviewer.py | 354 +++++++++-------- ephyviewer/spiketrainviewer.py | 125 +++--- ephyviewer/timefreqviewer.py | 355 +++++++++-------- ephyviewer/traceimageviewer.py | 5 +- ephyviewer/traceviewer.py | 635 +++++++++++++++++++------------ ephyviewer/videoviewer.py | 192 ++++------ 11 files changed, 1426 insertions(+), 1075 deletions(-) diff --git a/ephyviewer/base.py b/ephyviewer/base.py index 6888f22..2ccb668 100644 --- a/ephyviewer/base.py +++ b/ephyviewer/base.py @@ -1,37 +1,34 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) from .myqt import QT import pyqtgraph as pg import numpy as np -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors import weakref - - class ViewerBase(QT.QWidget): - time_changed = QT.pyqtSignal(float) - def __init__(self, name='', source=None, **kargs): + def __init__(self, name="", source=None, **kargs): QT.QWidget.__init__(self, **kargs) self.name = name self.source = source - self.t = 0. + self.t = 0.0 def seek(self, t): self.t = t self.refresh() def refresh(self): - #overwrite this one - raise(NotImplementedError) + # overwrite this one + raise (NotImplementedError) def set_settings(self, value): pass @@ -43,37 +40,39 @@ def auto_scale(self): pass - class MyViewBox(pg.ViewBox): doubleclicked = QT.pyqtSignal() ygain_zoom = QT.pyqtSignal(float) xsize_zoom = QT.pyqtSignal(float) + def __init__(self, *args, **kwds): pg.ViewBox.__init__(self, *args, **kwds) self.disableAutoRange() + def mouseClickEvent(self, ev): if ev.double(): ev.accept() self.doubleclicked.emit() else: ev.ignore() + def wheelEvent(self, ev, axis=None): if ev.modifiers() == QT.Qt.ControlModifier: - z = 5. if ev.delta()>0 else 1/5. + z = 5.0 if ev.delta() > 0 else 1 / 5.0 else: - z = 1.1 if ev.delta()>0 else 1/1.1 + z = 1.1 if ev.delta() > 0 else 1 / 1.1 self.ygain_zoom.emit(z) ev.accept() + def mouseDragEvent(self, ev, axis=None): - if ev.button()== QT.RightButton: - self.xsize_zoom.emit((ev.pos()-ev.lastPos()).x()) + if ev.button() == QT.RightButton: + self.xsize_zoom.emit((ev.pos() - ev.lastPos()).x()) else: pass ev.accept() class BaseMultiChannelViewer(ViewerBase): - _default_params = None _default_by_channel_params = None _ControllerClass = None @@ -87,16 +86,29 @@ def make_params(self): # Create parameters all = [] for i in range(self.source.nb_channel): - #TODO add name, hadrware index, id - name = 'ch{}'.format(i) - children =[{'name': 'name', 'type': 'str', 'value': self.source.get_channel_name(i), 'readonly':True}] + # TODO add name, hadrware index, id + name = "ch{}".format(i) + children = [ + { + "name": "name", + "type": "str", + "value": self.source.get_channel_name(i), + "readonly": True, + } + ] children += self._default_by_channel_params - all.append({'name': name, 'type': 'group', 'children': children}) - self.by_channel_params = pg.parametertree.Parameter.create(name='Channels', type='group', children=all) - self.params = pg.parametertree.Parameter.create(name='Global options', - type='group', children=self._default_params) - self.all_params = pg.parametertree.Parameter.create(name='all param', - type='group', children=[self.params, self.by_channel_params]) + all.append({"name": name, "type": "group", "children": children}) + self.by_channel_params = pg.parametertree.Parameter.create( + name="Channels", type="group", children=all + ) + self.params = pg.parametertree.Parameter.create( + name="Global options", type="group", children=self._default_params + ) + self.all_params = pg.parametertree.Parameter.create( + name="all param", + type="group", + children=[self.params, self.by_channel_params], + ) self.all_params.sigTreeStateChanged.connect(self.on_param_change) def set_layout(self, useOpenGL=None): @@ -106,7 +118,7 @@ def set_layout(self, useOpenGL=None): self.viewBox = MyViewBox() - self.graphicsview = pg.GraphicsView(useOpenGL=useOpenGL) + self.graphicsview = pg.GraphicsView(useOpenGL=useOpenGL) self.mainlayout.addWidget(self.graphicsview) self.plot = pg.PlotItem(viewBox=self.viewBox) @@ -129,13 +141,13 @@ def on_param_change(self): self.refresh() def set_xsize(self, xsize): - #~ print(self.__class__.__name__, 'set_xsize', xsize) - if 'xsize' in [p.name() for p in self.params.children()]: - self.params['xsize'] = xsize + # ~ print(self.__class__.__name__, 'set_xsize', xsize) + if "xsize" in [p.name() for p in self.params.children()]: + self.params["xsize"] = xsize def set_settings(self, value): actual_value = self.all_params.saveState() - #~ print('same tree', same_param_tree(actual_value, value)) + # ~ print('same tree', same_param_tree(actual_value, value)) if same_param_tree(actual_value, value): # this prevent restore something that is not same tree # as actual. Possible when new features. @@ -143,35 +155,34 @@ def set_settings(self, value): self.all_params.restoreState(value) self.all_params.blockSignals(False) else: - print('Not possible to restore setiings') + print("Not possible to restore setiings") def get_settings(self): return self.all_params.saveState() def same_param_tree(tree1, tree2): - children1 = list(tree1['children'].keys()) - children2 = list(tree2['children'].keys()) + children1 = list(tree1["children"].keys()) + children2 = list(tree2["children"].keys()) if len(children1) != len(children2): return False for k1, k2 in zip(children1, children2): - if k1!=k2: + if k1 != k2: return False - if 'children' in tree1['children'][k1]: - if not 'children' in tree2['children'][k2]: + if "children" in tree1["children"][k1]: + if not "children" in tree2["children"][k2]: return False - #~ print('*'*5) - #~ print('Recursif', k1) - if not same_param_tree(tree1['children'][k1], tree2['children'][k2]): + # ~ print('*'*5) + # ~ print('Recursif', k1) + if not same_param_tree(tree1["children"][k1], tree2["children"][k2]): return False return True class Base_ParamController(QT.QWidget): - xsize_zoomed = QT.pyqtSignal(float) def __init__(self, parent=None, viewer=None): @@ -184,10 +195,9 @@ def __init__(self, parent=None, viewer=None): # layout self.mainlayout = QT.QVBoxLayout() self.setLayout(self.mainlayout) - t = 'Options for {}'.format(self.viewer.name) + t = "Options for {}".format(self.viewer.name) self.setWindowTitle(t) - self.mainlayout.addWidget(QT.QLabel(''+t+'<\b>')) - + self.mainlayout.addWidget(QT.QLabel("" + t + "<\b>")) @property def viewer(self): @@ -199,13 +209,12 @@ def source(self): def apply_xsize_zoom(self, xmove): MIN_XSIZE = 1e-6 - factor = xmove/100. + factor = xmove / 100.0 factor = max(factor, -0.5) factor = min(factor, 1) - newsize = self.viewer.params['xsize']*(factor+1.) - self.viewer.params['xsize'] = max(newsize, MIN_XSIZE) - self.xsize_zoomed.emit(self.viewer.params['xsize']) - + newsize = self.viewer.params["xsize"] * (factor + 1.0) + self.viewer.params["xsize"] = max(newsize, MIN_XSIZE) + self.xsize_zoomed.emit(self.viewer.params["xsize"]) class Base_MultiChannel_ParamController(Base_ParamController): @@ -215,7 +224,6 @@ class Base_MultiChannel_ParamController(Base_ParamController): def __init__(self, parent=None, viewer=None, with_visible=True, with_color=True): Base_ParamController.__init__(self, parent=parent, viewer=viewer) - h = QT.QHBoxLayout() self.mainlayout.addLayout(h) @@ -229,24 +237,27 @@ def __init__(self, parent=None, viewer=None, with_visible=True, with_color=True) self.tree_by_channel_params = pg.parametertree.ParameterTree() self.tree_by_channel_params.header().hide() h.addWidget(self.tree_by_channel_params) - self.tree_by_channel_params.setParameters(self.viewer.by_channel_params, showTop=True) + self.tree_by_channel_params.setParameters( + self.viewer.by_channel_params, showTop=True + ) v = QT.QVBoxLayout() h.addLayout(v) - #~ but = QT.PushButton('default params') - #~ v.addWidget(but) - #~ but.clicked.connect(self.reset_to_default) + # ~ but = QT.PushButton('default params') + # ~ v.addWidget(but) + # ~ but.clicked.connect(self.reset_to_default) - if hasattr(self.viewer, 'auto_scale'): - but = QT.PushButton('Auto scale') + if hasattr(self.viewer, "auto_scale"): + but = QT.PushButton("Auto scale") v.addWidget(but) but.clicked.connect(self.auto_scale_viewer) - if with_visible: - if self.source.nb_channel>1: - v.addWidget(QT.QLabel('Select channel...')) - names = [p.name() + ': '+p['name'] for p in self.viewer.by_channel_params] + if self.source.nb_channel > 1: + v.addWidget(QT.QLabel("Select channel...")) + names = [ + p.name() + ": " + p["name"] for p in self.viewer.by_channel_params + ] self.qlist = QT.QListWidget() v.addWidget(self.qlist, 2) self.qlist.addItems(names) @@ -255,34 +266,41 @@ def __init__(self, parent=None, viewer=None, with_visible=True, with_color=True) for i in range(len(names)): self.qlist.item(i).setSelected(True) - v.addWidget(QT.QLabel('and apply...<\b>')) + v.addWidget(QT.QLabel("and apply...<\b>")) - - but = QT.QPushButton('set visble') + but = QT.QPushButton("set visble") v.addWidget(but) but.clicked.connect(self.on_set_visible) self.channel_visibility_changed.connect(self.on_channel_visibility_changed) if with_color: - v.addWidget(QT.QLabel('Set color<\b>')) + v.addWidget(QT.QLabel("Set color<\b>")) h = QT.QHBoxLayout() - but = QT.QPushButton('Progressive') + but = QT.QPushButton("Progressive") but.clicked.connect(self.on_automatic_color) - h.addWidget(but,4) + h.addWidget(but, 4) self.combo_cmap = QT.QComboBox() - self.combo_cmap.addItems(['Accent', 'Dark2','jet', 'prism', 'hsv', ]) - h.addWidget(self.combo_cmap,1) + self.combo_cmap.addItems( + [ + "Accent", + "Dark2", + "jet", + "prism", + "hsv", + ] + ) + h.addWidget(self.combo_cmap, 1) v.addLayout(h) self.channel_color_changed.connect(self.on_channel_color_changed) - #~ def reset_to_default(self): - #~ self.viewer.make_params() - #~ self.tree_params.setParameters(self.viewer.params, showTop=True) - #~ self.tree_by_channel_params.setParameters(self.viewer.by_channel_params, showTop=True) - #~ ## self.viewer.on_param_change() - #~ self.viewer.refresh() + # ~ def reset_to_default(self): + # ~ self.viewer.make_params() + # ~ self.tree_params.setParameters(self.viewer.params, showTop=True) + # ~ self.tree_by_channel_params.setParameters(self.viewer.by_channel_params, showTop=True) + # ~ ## self.viewer.on_param_change() + # ~ self.viewer.refresh() def auto_scale_viewer(self): self.viewer.auto_scale() @@ -290,24 +308,26 @@ def auto_scale_viewer(self): @property def selected(self): selected = np.ones(self.viewer.source.nb_channel, dtype=bool) - if self.viewer.source.nb_channel>1: + if self.viewer.source.nb_channel > 1: selected[:] = False selected[[ind.row() for ind in self.qlist.selectedIndexes()]] = True return selected - @property def visible_channels(self): - visible = [self.viewer.by_channel_params['ch{}'.format(i), 'visible'] for i in range(self.source.nb_channel)] - return np.array(visible, dtype='bool') + visible = [ + self.viewer.by_channel_params["ch{}".format(i), "visible"] + for i in range(self.source.nb_channel) + ] + return np.array(visible, dtype="bool") def on_set_visible(self): # apply self.viewer.by_channel_params.blockSignals(True) visibles = self.selected - for i,param in enumerate(self.viewer.by_channel_params.children()): - param['visible'] = visibles[i] + for i, param in enumerate(self.viewer.by_channel_params.children()): + param["visible"] = visibles[i] self.viewer.by_channel_params.blockSignals(False) self.channel_visibility_changed.emit() @@ -315,21 +335,24 @@ def on_set_visible(self): def on_double_clicked(self, index): self.viewer.by_channel_params.blockSignals(True) visibles = self.selected - for i,param in enumerate(self.viewer.by_channel_params.children()): - param['visible'] = (i==index.row()) + for i, param in enumerate(self.viewer.by_channel_params.children()): + param["visible"] = i == index.row() self.viewer.by_channel_params.blockSignals(False) self.channel_visibility_changed.emit() - def on_automatic_color(self, cmap_name = None): + def on_automatic_color(self, cmap_name=None): cmap_name = str(self.combo_cmap.currentText()) n = np.sum(self.selected) - if n==0: return - cmap = matplotlib.cm.get_cmap(cmap_name , n) + if n == 0: + return + cmap = colormaps.get_cmap(cmap_name).resampled(n) self.viewer.by_channel_params.blockSignals(True) for i, c in enumerate(np.nonzero(self.selected)[0]): - color = [ int(e*255) for e in matplotlib.colors.ColorConverter().to_rgb(cmap(i)) ] - self.viewer.by_channel_params['ch{}'.format(c), 'color'] = color + color = [ + int(e * 255) for e in matplotlib.colors.ColorConverter().to_rgb(cmap(i)) + ] + self.viewer.by_channel_params["ch{}".format(c), "color"] = color self.viewer.all_params.sigTreeStateChanged.connect(self.viewer.on_param_change) self.viewer.by_channel_params.blockSignals(False) self.channel_color_changed.emit() diff --git a/ephyviewer/datasource/epochs.py b/ephyviewer/datasource/epochs.py index 8effad5..c095f50 100644 --- a/ephyviewer/datasource/epochs.py +++ b/ephyviewer/datasource/epochs.py @@ -3,7 +3,7 @@ import os import numpy as np -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors try: @@ -103,7 +103,7 @@ def __init__( # TODO: colors should be managed directly by EpochEncoder if color_labels is None: n = len(self.possible_labels) - cmap = matplotlib.cm.get_cmap("Dark2", n) + cmap = colormaps.get_cmap("Dark2").resampled(n) color_labels = [ matplotlib.colors.ColorConverter().to_rgb(cmap(i)) for i in range(n) ] diff --git a/ephyviewer/datasource/signals.py b/ephyviewer/datasource/signals.py index 1908bb8..10006cd 100644 --- a/ephyviewer/datasource/signals.py +++ b/ephyviewer/datasource/signals.py @@ -7,7 +7,7 @@ from .sourcebase import BaseDataSource -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors @@ -103,7 +103,7 @@ def __init__( if self.scatter_colors is None: self.scatter_colors = {} n = len(self._labels) - colors = matplotlib.cm.get_cmap("Accent", n) + colors = colormaps.get_cmap("Accent").resampled(n) for i, k in enumerate(self._labels): self.scatter_colors[k] = matplotlib.colors.to_hex(colors(i)) diff --git a/ephyviewer/epochencoder.py b/ephyviewer/epochencoder.py index fc8c63d..e738028 100644 --- a/ephyviewer/epochencoder.py +++ b/ephyviewer/epochencoder.py @@ -1,14 +1,11 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) from copy import deepcopy from collections import deque import numpy as np -import matplotlib.cm -import matplotlib.colors - from .myqt import QT from . import tools @@ -20,20 +17,30 @@ default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 3., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'label_fill_color', 'type': 'color', 'value': '#222222DD'}, - {'name': 'label_size', 'type': 'int', 'value': 8, 'limits': (1,np.inf)}, - {'name': 'new_epoch_step', 'type': 'float', 'value': .1, 'step': 0.1, 'limits':(0,np.inf)}, - {'name': 'exclusive_mode', 'type': 'bool', 'value': True}, - {'name': 'view_mode', 'type': 'list', 'value':'stacked', 'limits' : ['stacked', 'flat']}, - {'name': 'keys_as_ticks', 'type': 'bool', 'value': True}, - {'name': 'undo_history_size', 'type': 'int', 'value': 500, 'limits': (1, np.inf)}, - - #~ {'name': 'display_labels', 'type': 'bool', 'value': True}, - ] + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + {"name": "label_fill_color", "type": "color", "value": "#222222DD"}, + {"name": "label_size", "type": "int", "value": 8, "limits": (1, np.inf)}, + { + "name": "new_epoch_step", + "type": "float", + "value": 0.1, + "step": 0.1, + "limits": (0, np.inf), + }, + {"name": "exclusive_mode", "type": "bool", "value": True}, + { + "name": "view_mode", + "type": "list", + "value": "stacked", + "limits": ["stacked", "flat"], + }, + {"name": "keys_as_ticks", "type": "bool", "value": True}, + {"name": "undo_history_size", "type": "int", "value": 500, "limits": (1, np.inf)}, + # ~ {'name': 'display_labels', 'type': 'bool', 'value': True}, +] SEEK_COL = 0 START_COL = 1 @@ -45,7 +52,6 @@ DELETE_COL = 7 - class EpochEncoder_ParamController(Base_ParamController): def __init__(self, parent=None, viewer=None, with_visible=True, with_color=True): Base_ParamController.__init__(self, parent=parent, viewer=viewer) @@ -60,49 +66,57 @@ def __init__(self, parent=None, viewer=None, with_visible=True, with_color=True) self.tree_params = pg.parametertree.ParameterTree() self.tree_params.setParameters(self.viewer.params, showTop=True) self.tree_params.header().hide() - self.v1.addWidget(self.tree_params, stretch = 1) + self.v1.addWidget(self.tree_params, stretch=1) self.tree_label_params = pg.parametertree.ParameterTree() self.tree_label_params.setParameters(self.viewer.by_label_params, showTop=True) self.tree_label_params.header().hide() - self.v1.addWidget(self.tree_label_params, stretch = 3) + self.v1.addWidget(self.tree_label_params, stretch=3) - self.btn_new_label = QT.PushButton('New label') + self.btn_new_label = QT.PushButton("New label") self.btn_new_label.clicked.connect(self.create_new_label) self.v1.addWidget(self.btn_new_label) def create_new_label(self): - label, ok = QT.QInputDialog.getText(self, 'New label', 'Enter a new label:') + label, ok = QT.QInputDialog.getText(self, "New label", "Enter a new label:") # abort if user cancelled or typed nothing - if not ok or not label: return + if not ok or not label: + return # abort if label already exists - if label in self.viewer.source.possible_labels: return + if label in self.viewer.source.possible_labels: + return # TODO: determine color from a color map? - color = (120, 120, 120) # gray + color = (120, 120, 120) # gray # add new label to WritableEpochSource self.viewer.source.possible_labels.append(label) self.viewer.source.color_labels.append(color) # get next unused shortcut key - used_keys = [p['key'] for p in self.viewer.by_label_params] - unused_keys = [k for k in '1234567890' if k not in used_keys] - key = unused_keys[0] if len(unused_keys)>0 else '' + used_keys = [p["key"] for p in self.viewer.by_label_params] + unused_keys = [k for k in "1234567890" if k not in used_keys] + key = unused_keys[0] if len(unused_keys) > 0 else "" # assign shortcuts without and with modifier key self.viewer.assign_label_shortcuts(label, key) # add new label to params - name = 'label{}'.format(len(self.viewer.source.possible_labels)-1) + name = "label{}".format(len(self.viewer.source.possible_labels) - 1) children = [ - {'name': 'name', 'type': 'str', 'value': label, 'readonly':True}, - {'name': 'color', 'type': 'color', 'value': self.source.color_by_label(label)}, - {'name': 'key', 'type': 'str', 'value': key}, + {"name": "name", "type": "str", "value": label, "readonly": True}, + { + "name": "color", + "type": "color", + "value": self.source.color_by_label(label), + }, + {"name": "key", "type": "str", "value": key}, ] - self.viewer.by_label_params.addChild({'name': name, 'type': 'group', 'children': children}) + self.viewer.by_label_params.addChild( + {"name": name, "type": "group", "children": children} + ) # clear and redraw plot to update labels and plot range self.viewer.plot.clear() @@ -116,7 +130,6 @@ def create_new_label(self): self.viewer.refresh_table() - class EpochEncoder(ViewerBase): _default_params = default_params @@ -139,21 +152,18 @@ def __init__(self, **kargs): self.on_range_visibility_changed(refresh=False) - - self.thread = QT.QThread(parent=self) self.datagrabber = DataGrabber(source=self.source) self.datagrabber.moveToThread(self.thread) self.thread.start() - self.datagrabber.data_ready.connect(self.on_data_ready) self.request_data.connect(self.datagrabber.on_request_data) self.table_button_fixed_width = 32 self.refresh_table() - self.history = deque(maxlen=self.params['undo_history_size']) + self.history = deque(maxlen=self.params["undo_history_size"]) self.history_position = 0 self.changes_since_save = 0 self.append_history() @@ -163,41 +173,45 @@ def __init__(self, **kargs): self.changes_since_save = 0 self.save_action.setEnabled(False) - def make_params(self): # Create parameters - self.params = pg.parametertree.Parameter.create(name='Global options', - type='group', children=self._default_params) - self.params.param('xsize').setLimits((0, np.inf)) + self.params = pg.parametertree.Parameter.create( + name="Global options", type="group", children=self._default_params + ) + self.params.param("xsize").setLimits((0, np.inf)) - - keys = '1234567890' + keys = "1234567890" all = [] for i, label in enumerate(self.source.possible_labels): # get string for shortcut key - key = keys[i] if i', 'Set stop >']): + for i, but_text in enumerate(["Set start >", "Set stop >"]): but = QT.QPushButton(but_text) buts.append(but) range_group_box_layout.addWidget(but, i, 0) - spinbox = pg.SpinBox(value=float(i), decimals = 8, bounds = (-np.inf, np.inf),step = 0.05, siPrefix=False, int=False) - if 'compactHeight' in spinbox.opts: # pyqtgraph >= 0.11.0 + spinbox = pg.SpinBox( + value=float(i), + decimals=8, + bounds=(-np.inf, np.inf), + step=0.05, + siPrefix=False, + int=False, + ) + if "compactHeight" in spinbox.opts: # pyqtgraph >= 0.11.0 spinbox.setOpts(compactHeight=False) range_group_box_layout.addWidget(spinbox, i, 1) - spinbox.setSizePolicy(QT.QSizePolicy.Preferred, QT.QSizePolicy.Preferred, ) + spinbox.setSizePolicy( + QT.QSizePolicy.Preferred, + QT.QSizePolicy.Preferred, + ) spinbox.valueChanged.connect(self.on_spin_limit_changed) spinboxs.append(spinbox) self.spin_limit1, self.spin_limit2 = spinboxs @@ -254,25 +280,27 @@ def set_layout(self): buts[1].clicked.connect(self.set_limit2) limit1_shortcut = QT.QShortcut(self) - limit1_shortcut.setKey(QT.QKeySequence('[')) + limit1_shortcut.setKey(QT.QKeySequence("[")) limit1_shortcut.activated.connect(buts[0].click) - buts[0].setToolTip('Set start with shortcut: [') + buts[0].setToolTip("Set start with shortcut: [") limit2_shortcut = QT.QShortcut(self) - limit2_shortcut.setKey(QT.QKeySequence(']')) + limit2_shortcut.setKey(QT.QKeySequence("]")) limit2_shortcut.activated.connect(buts[1].click) - buts[1].setToolTip('Set stop with shortcut: ]') + buts[1].setToolTip("Set stop with shortcut: ]") self.combo_labels = QT.QComboBox() self.combo_labels.addItems(self.source.possible_labels) range_group_box_layout.addWidget(self.combo_labels, 2, 0, 1, 2) - self.but_apply_region = QT.PushButton('Insert within range') + self.but_apply_region = QT.PushButton("Insert within range") range_group_box_layout.addWidget(self.but_apply_region, 3, 0, 1, 2) self.but_apply_region.clicked.connect(self.apply_region) - self.but_apply_region.setToolTip('Insert with customizable shortcuts (see options)') + self.but_apply_region.setToolTip( + "Insert with customizable shortcuts (see options)" + ) - self.but_del_region = QT.PushButton('Clear within range') + self.but_del_region = QT.PushButton("Clear within range") range_group_box_layout.addWidget(self.but_del_region, 4, 0, 1, 2) self.but_del_region.clicked.connect(self.delete_region) @@ -285,10 +313,10 @@ def set_layout(self): self.table_widget.cellClicked.connect(self.on_table_cell_click) self.table_widget.cellChanged.connect(self.on_table_cell_change) self.table_widget_icons = { - 'seek': QT.QIcon(':/epoch-encoder-seek.svg'), - 'split': QT.QIcon(':/epoch-encoder-split.svg'), - 'duplicate': QT.QIcon(':/epoch-encoder-duplicate.svg'), - 'delete': QT.QIcon(':/epoch-encoder-delete.svg'), + "seek": QT.QIcon(":/epoch-encoder-seek.svg"), + "split": QT.QIcon(":/epoch-encoder-split.svg"), + "duplicate": QT.QIcon(":/epoch-encoder-duplicate.svg"), + "delete": QT.QIcon(":/epoch-encoder-delete.svg"), } # Toolbar @@ -298,113 +326,131 @@ def set_layout(self): self.toolbar.setIconSize(QT.QSize(16, 16)) self.mainlayout.addWidget(self.toolbar) - self.toggle_controls_visibility_action = self.toolbar.addAction('Hide controls', self.on_controls_visibility_changed) - self.toggle_controls_visibility_button = self.toolbar.widgetForAction(self.toggle_controls_visibility_action) + self.toggle_controls_visibility_action = self.toolbar.addAction( + "Hide controls", self.on_controls_visibility_changed + ) + self.toggle_controls_visibility_button = self.toolbar.widgetForAction( + self.toggle_controls_visibility_action + ) self.toggle_controls_visibility_button.setArrowType(QT.UpArrow) self.toolbar.addSeparator() - self.save_action = self.toolbar.addAction('Save', self.on_save) - self.save_action.setShortcut('Ctrl+s') # automatically converted to Cmd+s on Mac - self.save_action.setToolTip('Save with shortcut: Ctrl/Cmd+s') - self.save_action.setIcon(QT.QIcon(':/epoch-encoder-save.svg')) + self.save_action = self.toolbar.addAction("Save", self.on_save) + self.save_action.setShortcut( + "Ctrl+s" + ) # automatically converted to Cmd+s on Mac + self.save_action.setToolTip("Save with shortcut: Ctrl/Cmd+s") + self.save_action.setIcon(QT.QIcon(":/epoch-encoder-save.svg")) - self.undo_action = self.toolbar.addAction('Undo', self.on_undo) - self.undo_action.setShortcut('Ctrl+z') # automatically converted to Cmd+z on Mac - self.undo_action.setToolTip('Undo with shortcut: Ctrl/Cmd+z') - self.undo_action.setIcon(QT.QIcon(':/epoch-encoder-undo.svg')) + self.undo_action = self.toolbar.addAction("Undo", self.on_undo) + self.undo_action.setShortcut( + "Ctrl+z" + ) # automatically converted to Cmd+z on Mac + self.undo_action.setToolTip("Undo with shortcut: Ctrl/Cmd+z") + self.undo_action.setIcon(QT.QIcon(":/epoch-encoder-undo.svg")) - self.redo_action = self.toolbar.addAction('Redo', self.on_redo) - self.redo_action.setShortcut('Ctrl+y') # automatically converted to Cmd+y on Mac - self.redo_action.setToolTip('Redo with shortcut: Ctrl/Cmd+y') - self.redo_action.setIcon(QT.QIcon(':/epoch-encoder-redo.svg')) + self.redo_action = self.toolbar.addAction("Redo", self.on_redo) + self.redo_action.setShortcut( + "Ctrl+y" + ) # automatically converted to Cmd+y on Mac + self.redo_action.setToolTip("Redo with shortcut: Ctrl/Cmd+y") + self.redo_action.setIcon(QT.QIcon(":/epoch-encoder-redo.svg")) - self.toolbar.addAction('Options', self.show_params_controller) + self.toolbar.addAction("Options", self.show_params_controller) - self.toolbar.addAction('New label', self.params_controller.create_new_label) + self.toolbar.addAction("New label", self.params_controller.create_new_label) - self.toolbar.addAction('Merge neighbors', self.on_merge_neighbors) + self.toolbar.addAction("Merge neighbors", self.on_merge_neighbors) - self.toolbar.addAction('Fill blank', self.on_fill_blank) + self.toolbar.addAction("Fill blank", self.on_fill_blank) - self.allow_overlap_action = self.toolbar.addAction('Allow overlap') - self.allow_overlap_action.setToolTip('Hold Shift when using shortcut keys to temporarily switch modes') + self.allow_overlap_action = self.toolbar.addAction("Allow overlap") + self.allow_overlap_action.setToolTip( + "Hold Shift when using shortcut keys to temporarily switch modes" + ) self.allow_overlap_action.setCheckable(True) - self.allow_overlap_action.setChecked(not self.params['exclusive_mode']) - self.allow_overlap_action.toggled.connect(lambda checked: self.params.param('exclusive_mode').setValue(not checked)) - + self.allow_overlap_action.setChecked(not self.params["exclusive_mode"]) + self.allow_overlap_action.toggled.connect( + lambda checked: self.params.param("exclusive_mode").setValue(not checked) + ) def make_param_controller(self): self.params_controller = EpochEncoder_ParamController(parent=self, viewer=self) self.params_controller.setWindowFlags(QT.Qt.Window) - def closeEvent(self, event): - if self.changes_since_save != 0: - text = 'Do you want to save epoch encoder changes before closing?' - title = 'Save?' - mb = QT.QMessageBox.question(self, title,text, - QT.QMessageBox.Ok , QT.QMessageBox.Discard) - if mb==QT.QMessageBox.Ok: + text = "Do you want to save epoch encoder changes before closing?" + title = "Save?" + mb = QT.QMessageBox.question( + self, title, text, QT.QMessageBox.Ok, QT.QMessageBox.Discard + ) + if mb == QT.QMessageBox.Ok: self.source.save() self.thread.quit() self.thread.wait() event.accept() - def initialize_plot(self): - self.region = pg.LinearRegionItem(brush='#FF00FF20') + self.region = pg.LinearRegionItem(brush="#FF00FF20") self.region.setZValue(10) self.region.setRegion((self.spin_limit1.value(), self.spin_limit2.value())) self.plot.addItem(self.region, ignoreBounds=True) self.region.sigRegionChanged.connect(self.on_region_changed) - self.vline = pg.InfiniteLine(angle=90, movable=False, pen=self.params['vline_color']) - self.vline.setZValue(1) # ensure vline is above plot elements + self.vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + self.vline.setZValue(1) # ensure vline is above plot elements self.plot.addItem(self.vline) self.rect_items = [] self.label_items = [] for i, label in enumerate(self.source.possible_labels): - color = self.by_label_params['label'+str(i), 'color'] - label_item = pg.TextItem(label, color=color, anchor=(0, 0.5), border=None, fill=self.params['label_fill_color']) + color = self.by_label_params["label" + str(i), "color"] + label_item = pg.TextItem( + label, + color=color, + anchor=(0, 0.5), + border=None, + fill=self.params["label_fill_color"], + ) label_item.setZValue(11) font = label_item.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label_item.setFont(font) self.plot.addItem(label_item) self.label_items.append(label_item) self.viewBox.xsize_zoom.connect(self.params_controller.apply_xsize_zoom) - def on_controls_visibility_changed(self): if self.controls.isVisible(): self.controls.hide() - self.toggle_controls_visibility_action.setText('Show controls') + self.toggle_controls_visibility_action.setText("Show controls") self.toggle_controls_visibility_button.setArrowType(QT.RightArrow) else: self.controls.show() - self.toggle_controls_visibility_action.setText('Hide controls') + self.toggle_controls_visibility_action.setText("Hide controls") self.toggle_controls_visibility_button.setArrowType(QT.UpArrow) def show_params_controller(self): self.params_controller.show() def on_param_change(self): - self.allow_overlap_action.setChecked(not self.params['exclusive_mode']) - self.vline.setPen(self.params['vline_color']) + self.allow_overlap_action.setChecked(not self.params["exclusive_mode"]) + self.vline.setPen(self.params["vline_color"]) for label_item in self.label_items: font = label_item.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label_item.setFont(font) self.refresh() - if self.params['undo_history_size'] != self.history.maxlen: - while len(self.history) > self.params['undo_history_size']: + if self.params["undo_history_size"] != self.history.maxlen: + while len(self.history) > self.params["undo_history_size"]: if self.history_position > 0: # preferentially remove old states self.history.popleft() @@ -412,39 +458,50 @@ def on_param_change(self): else: # otherwise remove newer, undone states self.history.pop() - self.history = deque(self.history, maxlen=self.params['undo_history_size']) + self.history = deque(self.history, maxlen=self.params["undo_history_size"]) # if the saved state was bumped off either end of the history # queue, it will no longer be reachable via undo/redo, so set # changes_since_save to None if self.changes_since_save is not None: - if not (0 <= self.history_position - self.changes_since_save < len(self.history)): + if not ( + 0 + <= self.history_position - self.changes_since_save + < len(self.history) + ): self.changes_since_save = None self.refresh_toolbar() def assign_label_shortcuts(self, label, key): - if label not in self.label_shortcuts: # create new shortcuts shortcut_without_modifier = QT.QShortcut(self) - shortcut_with_modifier = QT.QShortcut(self) - shortcut_without_modifier.activated.connect(lambda: self.on_label_shortcut(label, False)) - shortcut_with_modifier .activated.connect(lambda: self.on_label_shortcut(label, True)) - self.label_shortcuts[label] = (shortcut_without_modifier, shortcut_with_modifier) + shortcut_with_modifier = QT.QShortcut(self) + shortcut_without_modifier.activated.connect( + lambda: self.on_label_shortcut(label, False) + ) + shortcut_with_modifier.activated.connect( + lambda: self.on_label_shortcut(label, True) + ) + self.label_shortcuts[label] = ( + shortcut_without_modifier, + shortcut_with_modifier, + ) else: # get existing shortcuts - shortcut_without_modifier, shortcut_with_modifier = self.label_shortcuts[label] + shortcut_without_modifier, shortcut_with_modifier = self.label_shortcuts[ + label + ] # set/change the shortcut keys shortcut_without_modifier.setKey(QT.QKeySequence(key)) - shortcut_with_modifier .setKey(QT.QKeySequence('Shift+' + key)) + shortcut_with_modifier.setKey(QT.QKeySequence("Shift+" + key)) def on_change_keys(self, refresh=True): - for i, label in enumerate(self.source.possible_labels): # get string for shortcut key - key = self.by_label_params['label'+str(i), 'key'] + key = self.by_label_params["label" + str(i), "key"] # assign shortcuts without and with modifier key self.assign_label_shortcuts(label, key) @@ -452,13 +509,12 @@ def on_change_keys(self, refresh=True): self.refresh() def set_xsize(self, xsize): - self.params['xsize'] = xsize - + self.params["xsize"] = xsize def set_settings(self, value): - #~ print('set_settings') + # ~ print('set_settings') actual_value = self.all_params.saveState() - #~ print('same tree', same_param_tree(actual_value, value)) + # ~ print('same tree', same_param_tree(actual_value, value)) if same_param_tree(actual_value, value): # this prevent restore something that is not same tree # as actual. Possible when new features. @@ -471,67 +527,71 @@ def set_settings(self, value): self.by_label_params.blockSignals(False) self.on_change_keys(refresh=False) else: - print('Not possible to restore setiings') + print("Not possible to restore setiings") def get_settings(self): return self.all_params.saveState() def refresh(self): - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) self.request_data.emit(t_start, t_stop, [0]) def on_data_ready(self, t_start, t_stop, visibles, data): - #~ print('on_data_ready', self, t_start, t_stop, visibles, data) - #~ self.plot.clear() + # ~ print('on_data_ready', self, t_start, t_stop, visibles, data) + # ~ self.plot.clear() for rect_item in self.rect_items: self.plot.removeItem(rect_item) self.rect_items = [] - self.graphicsview.setBackground(self.params['background_color']) + self.graphicsview.setBackground(self.params["background_color"]) times, durations, labels, ids = data[0] - #~ print(data) + # ~ print(data) n = len(self.source.possible_labels) for i, label in enumerate(labels): ind = self.source.possible_labels.index(label) - color = self.by_label_params['label'+str(ind), 'color'] + color = self.by_label_params["label" + str(ind), "color"] color2 = QT.QColor(color) color2.setAlpha(130) - if self.params['view_mode']=='stacked': - ypos = n-ind-1 + if self.params["view_mode"] == "stacked": + ypos = n - ind - 1 else: ypos = 0 - item = RectItem([times[i], ypos, durations[i], .9], border=color, fill=color2, id=ids[i]) + item = RectItem( + [times[i], ypos, durations[i], 0.9], + border=color, + fill=color2, + id=ids[i], + ) item.clicked.connect(self.on_rect_clicked) item.doubleclicked.connect(self.on_rect_doubleclicked) - item.setPos(times[i], ypos) + item.setPos(times[i], ypos) self.plot.addItem(item) self.rect_items.append(item) - ticks = [] for i, label_item in enumerate(self.label_items): - if self.params['view_mode']=='stacked': - color = self.by_label_params['label'+str(i), 'color'] - #~ label_item = pg.TextItem(label, color=color, anchor=(0, 0.5), border=None, fill=pg.mkColor((128,128,128, 120))) + if self.params["view_mode"] == "stacked": + color = self.by_label_params["label" + str(i), "color"] + # ~ label_item = pg.TextItem(label, color=color, anchor=(0, 0.5), border=None, fill=pg.mkColor((128,128,128, 120))) label_item.setColor(color) - label_item.fill = pg.mkBrush(self.params['label_fill_color']) - ypos = n-i-0.55 + label_item.fill = pg.mkBrush(self.params["label_fill_color"]) + ypos = n - i - 0.55 label_item.setPos(t_start, ypos) - ticks.append((ypos, self.by_label_params['label'+str(i), 'key'])) + ticks.append((ypos, self.by_label_params["label" + str(i), "key"])) label_item.show() - #~ self.plot.addItem(label_item) + # ~ self.plot.addItem(label_item) else: label_item.hide() - if self.params['keys_as_ticks'] and self.params['view_mode']=='stacked': - self.plot.getAxis('left').setTicks([ticks, []]) + if self.params["keys_as_ticks"] and self.params["view_mode"] == "stacked": + self.plot.getAxis("left").setTicks([ticks, []]) else: - self.plot.getAxis('left').setTicks([]) + self.plot.getAxis("left").setTicks([]) if self.range_group_box.isChecked(): self.region.show() @@ -539,14 +599,13 @@ def on_data_ready(self, t_start, t_stop, visibles, data): self.region.hide() self.vline.setPos(self.t) - self.plot.setXRange( t_start, t_stop, padding = 0.0) - if self.params['view_mode']=='stacked': - self.plot.setYRange( 0, n) + self.plot.setXRange(t_start, t_stop, padding=0.0) + if self.params["view_mode"] == "stacked": + self.plot.setYRange(0, n) else: - self.plot.setYRange( 0, 1) + self.plot.setYRange(0, 1) def on_label_shortcut(self, label, modifier_used): - range_selection_is_enabled = self.range_group_box.isChecked() if range_selection_is_enabled: @@ -556,12 +615,14 @@ def on_label_shortcut(self, label, modifier_used): duration = t_stop - t_start else: # use current time and step size to get end of new epoch - duration = self.params['new_epoch_step'] + duration = self.params["new_epoch_step"] t_start = self.t t_stop = self.t + duration # delete existing epochs in the region where the new epoch will be inserted - if (self.params['exclusive_mode'] and not modifier_used) or (not self.params['exclusive_mode'] and modifier_used): + if (self.params["exclusive_mode"] and not modifier_used) or ( + not self.params["exclusive_mode"] and modifier_used + ): self.source.delete_in_between(t_start, t_stop) # create the new epoch @@ -583,18 +644,24 @@ def on_merge_neighbors(self): self.refresh_table() def on_fill_blank(self): - params = [{'name': 'method', 'type': 'list', 'value':'from_left', 'limits' : ['from_left', 'from_right', 'from_nearest']}] - dia = tools.ParamDialog(params, title='Fill blank method', parent=self) + params = [ + { + "name": "method", + "type": "list", + "value": "from_left", + "limits": ["from_left", "from_right", "from_nearest"], + } + ] + dia = tools.ParamDialog(params, title="Fill blank method", parent=self) dia.resize(300, 100) if dia.exec_(): d = dia.get() - method = d['method'] + method = d["method"] self.source.fill_blank(method=method) self.append_history() self.refresh() self.refresh_table() - def on_save(self): self.source.save() self.changes_since_save = 0 @@ -670,14 +737,13 @@ def on_region_changed(self): self.spin_limit2.setMinimum(self.spin_limit1.value()) def apply_region(self): - rgn = self.region.getRegion() t = rgn[0] duration = rgn[1] - rgn[0] label = str(self.combo_labels.currentText()) # delete existing epochs in the region where the new epoch will be inserted - if self.params['exclusive_mode']: + if self.params["exclusive_mode"]: self.source.delete_in_between(rgn[0], rgn[1]) # create the new epoch @@ -688,7 +754,6 @@ def apply_region(self): self.refresh_table() def delete_region(self): - rgn = self.region.getRegion() self.source.delete_in_between(rgn[0], rgn[1]) @@ -704,7 +769,13 @@ def range_group_toggle(self): def on_range_visibility_changed(self, flag=None, refresh=True, shift_region=True): enabled = self.range_group_box.isChecked() - for w in (self.spin_limit1, self.spin_limit2, self.combo_labels, self.but_apply_region, self.but_del_region): + for w in ( + self.spin_limit1, + self.spin_limit2, + self.combo_labels, + self.but_apply_region, + self.but_del_region, + ): w.setEnabled(enabled) if enabled and shift_region: rgn = self.region.getRegion() @@ -713,12 +784,12 @@ def on_range_visibility_changed(self, flag=None, refresh=True, shift_region=True self.refresh() def set_limit1(self): - if self.tself.spin_limit1.value(): + if self.t > self.spin_limit1.value(): self.spin_limit2.setValue(self.t) self.spin_limit2.repaint() # needed on macOS @@ -726,37 +797,46 @@ def refresh_table(self): self.table_widget.blockSignals(True) self.table_widget.clear() - times, durations, labels = self.source.ep_times, self.source.ep_durations, self.source.ep_labels + times, durations, labels = ( + self.source.ep_times, + self.source.ep_durations, + self.source.ep_labels, + ) self.table_widget.setColumnCount(8) self.table_widget.setRowCount(times.size) - self.table_widget.setHorizontalHeaderLabels(['', 'Start', 'Stop', 'Duration', 'Label', '', '', '']) + self.table_widget.setHorizontalHeaderLabels( + ["", "Start", "Stop", "Duration", "Label", "", "", ""] + ) # lock column widths for buttons to fixed button width - self.table_widget.horizontalHeader().setMinimumSectionSize(self.table_button_fixed_width) + self.table_widget.horizontalHeader().setMinimumSectionSize( + self.table_button_fixed_width + ) for col in [SEEK_COL, SPLIT_COL, DUPLICATE_COL, DELETE_COL]: - self.table_widget.horizontalHeader().setSectionResizeMode(col, QT.QHeaderView.Fixed) + self.table_widget.horizontalHeader().setSectionResizeMode( + col, QT.QHeaderView.Fixed + ) self.table_widget.setColumnWidth(col, self.table_button_fixed_width) for r in range(times.size): - # seek button - item = QT.QTableWidgetItem(self.table_widget_icons['seek'], '') - item.setToolTip('Jump to epoch') + item = QT.QTableWidgetItem(self.table_widget_icons["seek"], "") + item.setToolTip("Jump to epoch") self.table_widget.setItem(r, SEEK_COL, item) # start - value = np.round(times[r], 6) # round to nearest microsecond - item = QT.QTableWidgetItem('{}'.format(value)) + value = np.round(times[r], 6) # round to nearest microsecond + item = QT.QTableWidgetItem("{}".format(value)) self.table_widget.setItem(r, START_COL, item) # stop - value = np.round(times[r]+durations[r], 6) # round to nearest microsecond - item = QT.QTableWidgetItem('{}'.format(value)) + value = np.round(times[r] + durations[r], 6) # round to nearest microsecond + item = QT.QTableWidgetItem("{}".format(value)) self.table_widget.setItem(r, STOP_COL, item) # duration - value = np.round(durations[r], 6) # round to nearest microsecond - item = QT.QTableWidgetItem('{}'.format(value)) + value = np.round(durations[r], 6) # round to nearest microsecond + item = QT.QTableWidgetItem("{}".format(value)) self.table_widget.setItem(r, DURATION_COL, item) # label @@ -764,34 +844,32 @@ def refresh_table(self): self.table_widget.setItem(r, LABEL_COL, item) # split button - item = QT.QTableWidgetItem(self.table_widget_icons['split'], '') - item.setToolTip('Split epoch at current time') + item = QT.QTableWidgetItem(self.table_widget_icons["split"], "") + item.setToolTip("Split epoch at current time") self.table_widget.setItem(r, SPLIT_COL, item) # duplicate button - item = QT.QTableWidgetItem(self.table_widget_icons['duplicate'], '') - item.setToolTip('Duplicate epoch') + item = QT.QTableWidgetItem(self.table_widget_icons["duplicate"], "") + item.setToolTip("Duplicate epoch") self.table_widget.setItem(r, DUPLICATE_COL, item) # delete button - item = QT.QTableWidgetItem(self.table_widget_icons['delete'], '') - item.setToolTip('Delete epoch') + item = QT.QTableWidgetItem(self.table_widget_icons["delete"], "") + item.setToolTip("Delete epoch") self.table_widget.setItem(r, DELETE_COL, item) self.table_widget.blockSignals(False) def on_rect_clicked(self, id): - # get index corresponding to epoch id ind = self.source.id_to_ind[id] # select the epoch in the data table self.table_widget.blockSignals(True) - self.table_widget.setCurrentCell(ind, LABEL_COL) # select the label combo box + self.table_widget.setCurrentCell(ind, LABEL_COL) # select the label combo box self.table_widget.blockSignals(False) def on_rect_doubleclicked(self, id): - # get index corresponding to epoch id ind = self.source.id_to_ind[id] @@ -803,11 +881,11 @@ def on_rect_doubleclicked(self, id): self.on_range_visibility_changed(shift_region=False) def on_seek_table(self, ind=None): - if self.table_widget.rowCount()==0: + if self.table_widget.rowCount() == 0: return if ind is None: selected_ind = self.table_widget.selectedIndexes() - if len(selected_ind)==0: + if len(selected_ind) == 0: return ind = selected_ind[0].row() self.t = self.source.ep_times[ind] @@ -829,7 +907,9 @@ def on_table_cell_click(self, row, col): combo_labels.addItems(self.source.possible_labels) combo_labels.setCurrentText(self.source.ep_labels[row]) combo_labels.currentIndexChanged.connect( - lambda label_index, row=row: self.on_change_label(row, self.source.possible_labels[label_index]) + lambda label_index, row=row: self.on_change_label( + row, self.source.possible_labels[label_index] + ) ) self.table_widget.setCellWidget(row, LABEL_COL, combo_labels) else: @@ -837,7 +917,8 @@ def on_table_cell_click(self, row, col): def on_table_cell_change(self, row, col): line_edit = self.table_widget.cellWidget(row, col) - if not isinstance(line_edit, QT.QLineEdit): return + if not isinstance(line_edit, QT.QLineEdit): + return new_text = line_edit.text() try: @@ -853,9 +934,9 @@ def on_table_cell_change(self, row, col): elif col == DURATION_COL: old_number = self.source.ep_durations[row] else: - print('unexpected column changed') + print("unexpected column changed") return - old_number = np.round(old_number, 6) # round to nearest microsecond + old_number = np.round(old_number, 6) # round to nearest microsecond self.table_widget.blockSignals(True) self.table_widget.item(row, col).setText(str(old_number)) self.table_widget.blockSignals(False) @@ -864,11 +945,10 @@ def on_table_cell_change(self, row, col): self.table_widget.blockSignals(True) # round and copy rounded number to table - new_number = np.round(new_number, 6) # round to nearest microsecond + new_number = np.round(new_number, 6) # round to nearest microsecond self.table_widget.item(row, col).setText(str(new_number)) if col == START_COL: - # get the epoch stop time stop_time = self.source.ep_times[row] + self.source.ep_durations[row] @@ -880,32 +960,38 @@ def on_table_cell_change(self, row, col): # update epoch duration in table duration_line_edit = self.table_widget.item(row, DURATION_COL) - new_duration = np.round(self.source.ep_durations[row], 6) # round to nearest microsecond + new_duration = np.round( + self.source.ep_durations[row], 6 + ) # round to nearest microsecond duration_line_edit.setText(str(new_duration)) elif col == STOP_COL: - # change epoch duration for corresponding stop time - self.source.ep_durations[row] = new_number-self.source.ep_times[row] + self.source.ep_durations[row] = new_number - self.source.ep_times[row] # update epoch duration in table duration_line_edit = self.table_widget.item(row, DURATION_COL) - new_duration = np.round(self.source.ep_durations[row], 6) # round to nearest microsecond + new_duration = np.round( + self.source.ep_durations[row], 6 + ) # round to nearest microsecond duration_line_edit.setText(str(new_duration)) elif col == DURATION_COL: - # change epoch duration self.source.ep_durations[row] = new_number # update epoch stop time in table stop_line_edit = self.table_widget.item(row, STOP_COL) - new_stop_time = self.source.ep_times[row] + self.source.ep_durations[row] - new_stop_time = np.round(new_stop_time, 6) # round to nearest microsecond + new_stop_time = ( + self.source.ep_times[row] + self.source.ep_durations[row] + ) + new_stop_time = np.round( + new_stop_time, 6 + ) # round to nearest microsecond stop_line_edit.setText(str(new_stop_time)) else: - print('unexpected column changed') + print("unexpected column changed") self.table_widget.blockSignals(False) @@ -916,7 +1002,6 @@ def on_table_cell_change(self, row, col): # refresh_table is not called to avoid deselecting table cell def on_change_label(self, ind, new_label): - # change epoch label self.source.ep_labels[ind] = new_label @@ -927,11 +1012,11 @@ def on_change_label(self, ind, new_label): # refresh_table is not called to avoid deselecting table cell def delete_selected_epoch(self, ind=None): - if self.table_widget.rowCount()==0: + if self.table_widget.rowCount() == 0: return if ind is None: selected_ind = self.table_widget.selectedIndexes() - if len(selected_ind)==0: + if len(selected_ind) == 0: return ind = selected_ind[0].row() self.source.delete_epoch(ind) @@ -940,24 +1025,28 @@ def delete_selected_epoch(self, ind=None): self.refresh_table() def duplicate_selected_epoch(self, ind=None): - if self.table_widget.rowCount()==0: + if self.table_widget.rowCount() == 0: return if ind is None: selected_ind = self.table_widget.selectedIndexes() - if len(selected_ind)==0: + if len(selected_ind) == 0: return ind = selected_ind[0].row() - self.source.add_epoch(self.source.ep_times[ind], self.source.ep_durations[ind], self.source.ep_labels[ind]) + self.source.add_epoch( + self.source.ep_times[ind], + self.source.ep_durations[ind], + self.source.ep_labels[ind], + ) self.append_history() self.refresh() self.refresh_table() def split_selected_epoch(self, ind=None): - if self.table_widget.rowCount()==0: + if self.table_widget.rowCount() == 0: return if ind is None: selected_ind = self.table_widget.selectedIndexes() - if len(selected_ind)==0: + if len(selected_ind) == 0: return ind = selected_ind[0].row() if self.t <= self.source.ep_times[ind] or self.source.ep_stops[ind] <= self.t: diff --git a/ephyviewer/epochviewer.py b/ephyviewer/epochviewer.py index fd2d3bb..96326b3 100644 --- a/ephyviewer/epochviewer.py +++ b/ephyviewer/epochviewer.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np -import matplotlib.cm -import matplotlib.colors - from .myqt import QT import pyqtgraph as pg @@ -15,39 +12,34 @@ default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 3., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'label_fill_color', 'type': 'color', 'value': '#222222DD'}, - {'name': 'label_size', 'type': 'int', 'value': 8, 'limits': (1,np.inf)}, - {'name': 'display_labels', 'type': 'bool', 'value': True}, - ] + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + {"name": "label_fill_color", "type": "color", "value": "#222222DD"}, + {"name": "label_size", "type": "int", "value": 8, "limits": (1, np.inf)}, + {"name": "display_labels", "type": "bool", "value": True}, +] default_by_channel_params = [ - {'name': 'color', 'type': 'color', 'value': "#55FF00"}, - {'name': 'visible', 'type': 'bool', 'value': True}, - ] - - + {"name": "color", "type": "color", "value": "#55FF00"}, + {"name": "visible", "type": "bool", "value": True}, +] class EpochViewer_ParamController(Base_MultiChannel_ParamController): pass - - class RectItem(pg.GraphicsWidget): - clicked = QT.pyqtSignal(int) doubleclicked = QT.pyqtSignal(int) - def __init__(self, rect, border = 'r', fill = 'g', id = -1): + def __init__(self, rect, border="r", fill="g", id=-1): pg.GraphicsWidget.__init__(self) self.rect = rect - self.border= border - self.fill= fill + self.border = border + self.fill = fill self.id = id def boundingRect(self): @@ -59,7 +51,7 @@ def paint(self, p, *args): p.drawRect(self.boundingRect()) def mouseClickEvent(self, event): - if event.button()== QT.LeftButton: + if event.button() == QT.LeftButton: event.accept() if event.double(): self.doubleclicked.emit(self.id) @@ -79,7 +71,9 @@ def __init__(self, source, parent=None): def on_request_data(self, t_start, t_stop, visibles): data = {} for e, chan in enumerate(visibles): - data[chan] = self.source.get_chunk_by_time(chan=chan, t_start=t_start, t_stop=t_stop) + data[chan] = self.source.get_chunk_by_time( + chan=chan, t_start=t_start, t_stop=t_stop + ) self.data_ready.emit(t_start, t_stop, visibles, data) @@ -105,11 +99,10 @@ def __init__(self, **kargs): self.datagrabber.moveToThread(self.thread) self.thread.start() - self.datagrabber.data_ready.connect(self.on_data_ready) self.request_data.connect(self.datagrabber.on_request_data) - self.params.param('xsize').setLimits((0, np.inf)) + self.params.param("xsize").setLimits((0, np.inf)) @classmethod def from_numpy(cls, all_epochs, name): @@ -128,54 +121,64 @@ def closeEvent(self, event): self.thread.quit() self.thread.wait() - def initialize_plot(self): self.viewBox.xsize_zoom.connect(self.params_controller.apply_xsize_zoom) def refresh(self): - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) - visibles, = np.nonzero(self.params_controller.visible_channels) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + (visibles,) = np.nonzero(self.params_controller.visible_channels) self.request_data.emit(t_start, t_stop, visibles) def on_data_ready(self, t_start, t_stop, visibles, data): self.plot.clear() - self.graphicsview.setBackground(self.params['background_color']) + self.graphicsview.setBackground(self.params["background_color"]) for e, chan in enumerate(visibles): - - if len(data[chan])==3: + if len(data[chan]) == 3: times, durations, labels = data[chan] - elif len(data[chan])==4: + elif len(data[chan]) == 4: times, durations, labels, _ = data[chan] else: raise ValueError("data has unexpected dimensions") - color = self.by_channel_params.children()[chan].param('color').value() + color = self.by_channel_params.children()[chan].param("color").value() color2 = QT.QColor(color) color2.setAlpha(130) - ypos = visibles.size-e-1 + ypos = visibles.size - e - 1 for i in range(times.size): - item = RectItem([times[i], ypos,durations[i], .9], border = color, fill = color2) - item.setPos(times[i], visibles.size-e-1) + item = RectItem( + [times[i], ypos, durations[i], 0.9], border=color, fill=color2 + ) + item.setPos(times[i], visibles.size - e - 1) self.plot.addItem(item) - if self.params['display_labels']: - label_name = '{}: {}'.format(chan, self.source.get_channel_name(chan=chan)) - label = pg.TextItem(label_name, color=color, anchor=(0, 0.5), border=None, fill=self.params['label_fill_color']) + if self.params["display_labels"]: + label_name = "{}: {}".format( + chan, self.source.get_channel_name(chan=chan) + ) + label = pg.TextItem( + label_name, + color=color, + anchor=(0, 0.5), + border=None, + fill=self.params["label_fill_color"], + ) font = label.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label.setFont(font) self.plot.addItem(label) - label.setPos(t_start, ypos+0.45) + label.setPos(t_start, ypos + 0.45) - self.vline = pg.InfiniteLine(angle = 90, movable = False, pen = self.params['vline_color']) - self.vline.setZValue(1) # ensure vline is above plot elements + self.vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + self.vline.setZValue(1) # ensure vline is above plot elements self.plot.addItem(self.vline) self.vline.setPos(self.t) - self.plot.setXRange( t_start, t_stop, padding = 0.0) - self.plot.setYRange( 0, visibles.size) + self.plot.setXRange(t_start, t_stop, padding=0.0) + self.plot.setYRange(0, visibles.size) diff --git a/ephyviewer/spectrogramviewer.py b/ephyviewer/spectrogramviewer.py index c8914dd..fa20353 100644 --- a/ephyviewer/spectrogramviewer.py +++ b/ephyviewer/spectrogramviewer.py @@ -2,7 +2,7 @@ import scipy.fftpack import scipy.signal -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors from .myqt import QT @@ -14,108 +14,157 @@ from .tools import create_plot_grid, get_dict_from_group_param -#todo remove this +# todo remove this import time import threading - default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 10., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'nb_column', 'type': 'int', 'value': 4}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'colormap', 'type': 'list', 'value': 'viridis', 'limits' : ['inferno', 'viridis', 'jet', 'gray', 'hot', ] }, - {'name': 'scale_mode', 'type': 'list', 'value': 'same_for_all', 'limits' : ['by_channel', 'same_for_all', ] }, - {'name': 'display_labels', 'type': 'bool', 'value': True}, - {'name': 'show_axis', 'type': 'bool', 'value': True}, - {'name': 'scalogram', 'type': 'group', 'children': [ - {'name': 'binsize', 'type': 'float', 'value': 0.01, 'step': .01, 'limits': (0,np.inf)}, - {'name': 'overlapratio', 'type': 'float', 'value': 0., 'step': .05, 'limits': (0., 0.95)}, - {'name': 'scaling', 'type': 'list', 'value': 'density', 'limits' : ['density', 'spectrum'] }, - {'name': 'mode', 'type': 'list', 'value': 'psd', 'limits' : ['psd'] }, - {'name': 'detrend', 'type': 'list', 'value': 'constant', 'limits' : ['constant'] }, - {'name': 'scale', 'type': 'list', 'value': 'dB', 'limits' : ['dB', 'linear'] }, - - - ] - } - - ] + {"name": "xsize", "type": "float", "value": 10.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "nb_column", "type": "int", "value": 4}, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + { + "name": "colormap", + "type": "list", + "value": "viridis", + "limits": [ + "inferno", + "viridis", + "jet", + "gray", + "hot", + ], + }, + { + "name": "scale_mode", + "type": "list", + "value": "same_for_all", + "limits": [ + "by_channel", + "same_for_all", + ], + }, + {"name": "display_labels", "type": "bool", "value": True}, + {"name": "show_axis", "type": "bool", "value": True}, + { + "name": "scalogram", + "type": "group", + "children": [ + { + "name": "binsize", + "type": "float", + "value": 0.01, + "step": 0.01, + "limits": (0, np.inf), + }, + { + "name": "overlapratio", + "type": "float", + "value": 0.0, + "step": 0.05, + "limits": (0.0, 0.95), + }, + { + "name": "scaling", + "type": "list", + "value": "density", + "limits": ["density", "spectrum"], + }, + {"name": "mode", "type": "list", "value": "psd", "limits": ["psd"]}, + { + "name": "detrend", + "type": "list", + "value": "constant", + "limits": ["constant"], + }, + { + "name": "scale", + "type": "list", + "value": "dB", + "limits": ["dB", "linear"], + }, + ], + }, +] default_by_channel_params = [ - {'name': 'visible', 'type': 'bool', 'value': True}, - {'name': 'clim_min', 'type': 'float', 'value': -10}, - {'name': 'clim_max', 'type': 'float', 'value': 0.}, - ] - - - + {"name": "visible", "type": "bool", "value": True}, + {"name": "clim_min", "type": "float", "value": -10}, + {"name": "clim_max", "type": "float", "value": 0.0}, +] + + class SpectrogramViewer_ParamController(Base_MultiChannel_ParamController): some_clim_changed = QT.pyqtSignal() def on_channel_visibility_changed(self): - #~ print('SpectrogramViewer_ParamController.on_channel_visibility_changed') + # ~ print('SpectrogramViewer_ParamController.on_channel_visibility_changed') self.viewer.create_grid() self.viewer.refresh() def clim_zoom(self, factor): - print('clim_zoom factor', factor) + print("clim_zoom factor", factor) self.viewer.by_channel_params.blockSignals(True) for i, p in enumerate(self.viewer.by_channel_params.children()): - # p.param('clim').setValue(p.param('clim').value()*factor) - min_ = p['clim_min'] - max_ = p['clim_max'] + # p.param('clim').setValue(p.param('clim').value()*factor) + min_ = p["clim_min"] + max_ = p["clim_max"] d = max_ - min_ - m = (min_ + max_) / 2. - p['clim_min'] = m - d/2. * factor - p['clim_max'] = m + d/2. * factor + m = (min_ + max_) / 2.0 + p["clim_min"] = m - d / 2.0 * factor + p["clim_max"] = m + d / 2.0 * factor self.viewer.by_channel_params.blockSignals(False) self.some_clim_changed.emit() def compute_auto_clim(self): - #~ print('compute_auto_clim') - #~ print(self.visible_channels) + # ~ print('compute_auto_clim') + # ~ print(self.visible_channels) self.viewer.by_channel_params.blockSignals(True) mins = [] maxs = [] - visibles, = np.nonzero(self.visible_channels) + (visibles,) = np.nonzero(self.visible_channels) for chan in visibles: if chan in self.viewer.last_Sxx.keys(): min_ = np.min(self.viewer.last_Sxx[chan]) max_ = np.max(self.viewer.last_Sxx[chan]) - if self.viewer.params['scale_mode'] == 'by_channel': - self.viewer.by_channel_params['ch'+str(chan), 'clim_min'] = min_ - self.viewer.by_channel_params['ch'+str(chan), 'clim_max'] = max_ - + if self.viewer.params["scale_mode"] == "by_channel": + self.viewer.by_channel_params["ch" + str(chan), "clim_min"] = min_ + self.viewer.by_channel_params["ch" + str(chan), "clim_max"] = max_ + else: mins.append(min_) maxs.append(max_) - - if self.viewer.params['scale_mode'] == 'same_for_all' and len(maxs)>0: + + if self.viewer.params["scale_mode"] == "same_for_all" and len(maxs) > 0: for chan in visibles: - self.viewer.by_channel_params['ch'+str(chan), 'clim_min'] = np.min(mins) - self.viewer.by_channel_params['ch'+str(chan), 'clim_max'] = np.max(maxs) + self.viewer.by_channel_params["ch" + str(chan), "clim_min"] = np.min( + mins + ) + self.viewer.by_channel_params["ch" + str(chan), "clim_max"] = np.max( + maxs + ) self.viewer.by_channel_params.blockSignals(False) self.some_clim_changed.emit() - class SpectrogramWorker(QT.QObject): - data_ready = QT.pyqtSignal(int, float, float, float, float, float, object) + data_ready = QT.pyqtSignal(int, float, float, float, float, float, object) - def __init__(self, source,viewer, chan, parent=None): + def __init__(self, source, viewer, chan, parent=None): QT.QObject.__init__(self, parent) self.source = source self.viewer = viewer self.chan = chan - def on_request_data(self, chan, t, t_start, t_stop, visible_channels, worker_params): + def on_request_data( + self, chan, t, t_start, t_stop, visible_channels, worker_params + ): if chan != self.chan: return @@ -123,65 +172,65 @@ def on_request_data(self, chan, t, t_start, t_stop, visible_channels, worker_par return if self.viewer.t != t: - print('viewer has moved already', chan, self.viewer.t, t) + print("viewer has moved already", chan, self.viewer.t, t) # viewer has moved already return - binsize = worker_params['binsize'] - overlapratio = worker_params['overlapratio'] - scaling = worker_params['scaling'] - detrend = worker_params['detrend'] - mode = worker_params['mode'] - + binsize = worker_params["binsize"] + overlapratio = worker_params["overlapratio"] + scaling = worker_params["scaling"] + detrend = worker_params["detrend"] + mode = worker_params["mode"] i_start = self.source.time_to_index(t_start) i_stop = self.source.time_to_index(t_stop) # clip i_start = min(max(0, i_start), self.source.get_length()) i_stop = min(max(0, i_stop), self.source.get_length()) - sr = self.source.sample_rate nperseg = int(binsize * sr) noverlap = int(overlapratio * nperseg) - + if noverlap >= nperseg: noverlap = noverlap - 1 print(nperseg, noverlap) - if nperseg== 0 or (i_stop - i_start) < nperseg: + if nperseg == 0 or (i_stop - i_start) < nperseg: # too short t1, t2 = t_start, t_stop Sxx = None - self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) + self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) else: - sigs_chunk = self.source.get_chunk(i_start=i_start, i_stop=i_stop) sig = sigs_chunk[:, chan] - freqs, times, Sxx = scipy.signal.spectrogram(sig, fs=sr,nperseg=nperseg, noverlap=noverlap, - detrend=detrend, scaling=scaling, mode=mode) - - if worker_params['scale'] == 'dB': - if mode == 'psd': - Sxx = 10. * np.log10(Sxx) - - if len(times) >=2: + freqs, times, Sxx = scipy.signal.spectrogram( + sig, + fs=sr, + nperseg=nperseg, + noverlap=noverlap, + detrend=detrend, + scaling=scaling, + mode=mode, + ) + + if worker_params["scale"] == "dB": + if mode == "psd": + Sxx = 10.0 * np.log10(Sxx) + + if len(times) >= 2: bin_slide = times[1] - times[0] - t1 = self.source.index_to_time(i_start) + times[0] - bin_slide /2. - t2 = self.source.index_to_time(i_start) + times[-1] + bin_slide /2. - self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) + t1 = self.source.index_to_time(i_start) + times[0] - bin_slide / 2.0 + t2 = self.source.index_to_time(i_start) + times[-1] + bin_slide / 2.0 + self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) else: t1, t2 = t_start, t_stop Sxx = None - self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) - - - + self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, Sxx) class SpectrogramViewer(BaseMultiChannelViewer): - _default_params = default_params _default_by_channel_params = default_by_channel_params @@ -197,7 +246,7 @@ def __init__(self, **kargs): # make all not visible self.by_channel_params.blockSignals(True) for c in range(self.source.nb_channel): - self.by_channel_params['ch'+str(c), 'visible'] = c==0 + self.by_channel_params["ch" + str(c), "visible"] = c == 0 self.by_channel_params.blockSignals(False) self.make_param_controller() @@ -208,7 +257,6 @@ def __init__(self, **kargs): self.change_color_scale() self.create_grid() - self.last_Sxx = {} self.threads = [] @@ -220,18 +268,19 @@ def __init__(self, **kargs): self.timefreq_makers.append(worker) worker.moveToThread(thread) thread.start() - - self.last_Sxx[c] = None + self.last_Sxx[c] = None worker.data_ready.connect(self.on_data_ready) self.request_data.connect(worker.on_request_data) - self.params.param('xsize').setLimits((0, np.inf)) + self.params.param("xsize").setLimits((0, np.inf)) @classmethod def from_numpy(cls, sigs, sample_rate, t_start, name, channel_names=None): - source = InMemoryAnalogSignalSource(sigs, sample_rate, t_start, channel_names=channel_names) + source = InMemoryAnalogSignalSource( + sigs, sample_rate, t_start, channel_names=channel_names + ) view = cls(source=source, name=name) return view @@ -249,14 +298,14 @@ def set_layout(self): self.mainlayout.addWidget(self.graphiclayout) def on_param_change(self, params=None, changes=None): - #~ print('on_param_change') - #track if new scale mode - #~ for param, change, data in changes: - #~ if change != 'value': continue - #~ if param.name()=='scale_mode': - #~ self.params_controller.compute_rescale() - - #for simplification everything is recompute + # ~ print('on_param_change') + # track if new scale mode + # ~ for param, change, data in changes: + # ~ if change != 'value': continue + # ~ if param.name()=='scale_mode': + # ~ self.params_controller.compute_rescale() + + # for simplification everything is recompute self.change_color_scale() self.create_grid() self.refresh() @@ -264,8 +313,13 @@ def on_param_change(self, params=None, changes=None): def create_grid(self): visible_channels = self.params_controller.visible_channels - self.plots = create_plot_grid(self.graphiclayout, self.params['nb_column'], visible_channels, - ViewBoxClass=MyViewBox, vb_params={}) + self.plots = create_plot_grid( + self.graphiclayout, + self.params["nb_column"], + visible_channels, + ViewBoxClass=MyViewBox, + vb_params={}, + ) for plot in self.plots: if plot is not None: @@ -281,8 +335,10 @@ def create_grid(self): self.plots[c].addItem(image) self.images.append(image) - vline = pg.InfiniteLine(angle = 90, movable = False, pen = self.params['vline_color']) - vline.setZValue(1) # ensure vline is above plot elements + vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + vline.setZValue(1) # ensure vline is above plot elements self.plots[c].addItem(vline) self.vlines.append(vline) @@ -292,91 +348,85 @@ def create_grid(self): def change_color_scale(self): N = 512 - cmap_name = self.params['colormap'] - cmap = matplotlib.cm.get_cmap(cmap_name , N) - + cmap_name = self.params["colormap"] + cmap = colormaps.get_cmap(cmap_name).resampled(N) lut = [] for i in range(N): - r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) - lut.append([r*255,g*255,b*255]) - self.lut = np.array(lut, dtype='uint8') + r, g, b, _ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) + lut.append([r * 255, g * 255, b * 255]) + self.lut = np.array(lut, dtype="uint8") def auto_scale(self): - #~ print('auto_scale', self.params['scale_mode']) + # ~ print('auto_scale', self.params['scale_mode']) self.params_controller.compute_auto_clim() self.refresh() def refresh(self): - #~ print('TimeFreqViewer.refresh', self.t) + # ~ print('TimeFreqViewer.refresh', self.t) visible_channels = self.params_controller.visible_channels - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) - - worker_params = get_dict_from_group_param(self.params.param('scalogram')) - #~ print('worker_params', worker_params) - - + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + + worker_params = get_dict_from_group_param(self.params.param("scalogram")) + # ~ print('worker_params', worker_params) + for c in range(self.source.nb_channel): if visible_channels[c]: - self.request_data.emit(c, self.t, t_start, t_stop, visible_channels, worker_params) + self.request_data.emit( + c, self.t, t_start, t_stop, visible_channels, worker_params + ) - self.graphiclayout.setBackground(self.params['background_color']) + self.graphiclayout.setBackground(self.params["background_color"]) - def on_data_ready(self, chan, t, t_start, t_stop, t1,t2, Sxx): + def on_data_ready(self, chan, t, t_start, t_stop, t1, t2, Sxx): if not self.params_controller.visible_channels[chan]: return if self.images[chan] is None: return - - + self.last_Sxx[chan] = Sxx - + image = self.images[chan] - + if Sxx is None: image.hide() return - + image.show() - + f_start = 0 - f_stop = self.source.sample_rate / 2. - #~ f_start = self.params['timefreq', 'f_start'] - #~ f_stop = self.params['timefreq', 'f_stop'] - - - #~ print(t_start, f_start,self.worker_params['wanted_size'], f_stop-f_start) - - #~ image.updateImage(wt_map) - clim_min = self.by_channel_params['ch{}'.format(chan), 'clim_min'] - clim_max = self.by_channel_params['ch{}'.format(chan), 'clim_max'] - - #~ clim = np.max(Sxx) + f_stop = self.source.sample_rate / 2.0 + # ~ f_start = self.params['timefreq', 'f_start'] + # ~ f_stop = self.params['timefreq', 'f_stop'] + + # ~ print(t_start, f_start,self.worker_params['wanted_size'], f_stop-f_start) + + # ~ image.updateImage(wt_map) + clim_min = self.by_channel_params["ch{}".format(chan), "clim_min"] + clim_max = self.by_channel_params["ch{}".format(chan), "clim_max"] + + # ~ clim = np.max(Sxx) image.setImage(Sxx.T, lut=self.lut, levels=[clim_min, clim_max]) - #~ image.setImage(Sxx.T, lut=self.lut, levels=[ np.min(Sxx), np.max(Sxx)]) - image.setRect(QT.QRectF(t1, f_start,t2-t1, f_stop-f_start)) + # ~ image.setImage(Sxx.T, lut=self.lut, levels=[ np.min(Sxx), np.max(Sxx)]) + image.setRect(QT.QRectF(t1, f_start, t2 - t1, f_stop - f_start)) - #TODO + # TODO # display_labels self.vlines[chan].setPos(t) - self.vlines[chan].setPen(self.params['vline_color']) + self.vlines[chan].setPen(self.params["vline_color"]) plot = self.plots[chan] - plot.setXRange(t_start, t_stop, padding = 0.0) - plot.setYRange(f_start, f_stop, padding = 0.0) + plot.setXRange(t_start, t_stop, padding=0.0) + plot.setYRange(f_start, f_stop, padding=0.0) - if self.params['display_labels']: - ch_name = '{}: {}'.format(chan, self.source.get_channel_name(chan=chan)) + if self.params["display_labels"]: + ch_name = "{}: {}".format(chan, self.source.get_channel_name(chan=chan)) self.plots[chan].setTitle(ch_name) else: self.plots[chan].setTitle(None) - - self.plots[chan].showAxis('left', self.params['show_axis']) - self.plots[chan].showAxis('bottom', self.params['show_axis']) - - - + self.plots[chan].showAxis("left", self.params["show_axis"]) + self.plots[chan].showAxis("bottom", self.params["show_axis"]) diff --git a/ephyviewer/spiketrainviewer.py b/ephyviewer/spiketrainviewer.py index 9257516..6f4f940 100644 --- a/ephyviewer/spiketrainviewer.py +++ b/ephyviewer/spiketrainviewer.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np -import matplotlib.cm -import matplotlib.colors - from .myqt import QT import pyqtgraph as pg @@ -14,41 +11,36 @@ from .datasource import InMemorySpikeSource, NeoSpikeTrainSource - -#make symbol for spikes +# make symbol for spikes from pyqtgraph.graphicsItems.ScatterPlotItem import Symbols -Symbols['|'] = QT.QPainterPath() -Symbols['|'].moveTo(0, -0.5) -Symbols['|'].lineTo(0, 0.5) -Symbols['|'].closeSubpath() - +Symbols["|"] = QT.QPainterPath() +Symbols["|"].moveTo(0, -0.5) +Symbols["|"].lineTo(0, 0.5) +Symbols["|"].closeSubpath() default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 3., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'scatter_size', 'type': 'float', 'value': 0.8, 'limits': (0,np.inf)}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'label_fill_color', 'type': 'color', 'value': '#222222DD'}, - {'name': 'label_size', 'type': 'int', 'value': 8, 'limits': (1,np.inf)}, - {'name': 'display_labels', 'type': 'bool', 'value': True}, - ] + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "scatter_size", "type": "float", "value": 0.8, "limits": (0, np.inf)}, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + {"name": "label_fill_color", "type": "color", "value": "#222222DD"}, + {"name": "label_size", "type": "int", "value": 8, "limits": (1, np.inf)}, + {"name": "display_labels", "type": "bool", "value": True}, +] default_by_channel_params = [ - {'name': 'color', 'type': 'color', 'value': "#55FF00"}, - {'name': 'visible', 'type': 'bool', 'value': True}, - ] - - + {"name": "color", "type": "color", "value": "#55FF00"}, + {"name": "visible", "type": "bool", "value": True}, +] class SpikeTrainViewer_ParamController(Base_MultiChannel_ParamController): pass - class DataGrabber(QT.QObject): data_ready = QT.pyqtSignal(float, float, object, object) @@ -59,7 +51,9 @@ def __init__(self, source, parent=None): def on_request_data(self, t_start, t_stop, visibles): data = {} for e, chan in enumerate(visibles): - times = self.source.get_chunk_by_time(chan=chan, t_start=t_start, t_stop=t_stop) + times = self.source.get_chunk_by_time( + chan=chan, t_start=t_start, t_stop=t_stop + ) data[chan] = times self.data_ready.emit(t_start, t_stop, visibles, data) @@ -83,18 +77,15 @@ def __init__(self, **kargs): self.initialize_plot() - - self.thread = QT.QThread(parent=self) self.datagrabber = DataGrabber(source=self.source) self.datagrabber.moveToThread(self.thread) self.thread.start() - self.datagrabber.data_ready.connect(self.on_data_ready) self.request_data.connect(self.datagrabber.on_request_data) - self.params.param('xsize').setLimits((0, np.inf)) + self.params.param("xsize").setLimits((0, np.inf)) @classmethod def from_numpy(cls, all_epochs, name): @@ -113,39 +104,46 @@ def closeEvent(self, event): self.thread.quit() self.thread.wait() - def initialize_plot(self): - self.vline = pg.InfiniteLine(angle = 90, movable = False, pen = self.params['vline_color']) - self.vline.setZValue(1) # ensure vline is above plot elements + self.vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + self.vline.setZValue(1) # ensure vline is above plot elements self.plot.addItem(self.vline) - self.scatter = pg.ScatterPlotItem(size=self.params['scatter_size'], pxMode = False, symbol='|') + self.scatter = pg.ScatterPlotItem( + size=self.params["scatter_size"], pxMode=False, symbol="|" + ) self.plot.addItem(self.scatter) - self.labels = [] for c in range(self.source.nb_channel): - label_name = '{}: {}'.format(c, self.source.get_channel_name(chan=c)) - color = self.by_channel_params.children()[c].param('color').value() - label = pg.TextItem(label_name, color=color, anchor=(0, 0.5), border=None, fill=self.params['label_fill_color']) + label_name = "{}: {}".format(c, self.source.get_channel_name(chan=c)) + color = self.by_channel_params.children()[c].param("color").value() + label = pg.TextItem( + label_name, + color=color, + anchor=(0, 0.5), + border=None, + fill=self.params["label_fill_color"], + ) font = label.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label.setFont(font) self.plot.addItem(label) self.labels.append(label) self.viewBox.xsize_zoom.connect(self.params_controller.apply_xsize_zoom) - def refresh(self): - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) - visibles, = np.nonzero(self.params_controller.visible_channels) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + (visibles,) = np.nonzero(self.params_controller.visible_channels) self.request_data.emit(t_start, t_stop, visibles) def on_data_ready(self, t_start, t_stop, visibles, data): - self.graphicsview.setBackground(self.params['background_color']) + self.graphicsview.setBackground(self.params["background_color"]) self.scatter.clear() all_x = [] @@ -155,25 +153,24 @@ def on_data_ready(self, t_start, t_stop, visibles, data): for e, chan in enumerate(visibles): times = data[chan] - ypos = visibles.size-e-1 + ypos = visibles.size - e - 1 all_x.append(times) - all_y.append(np.ones(times.size)*ypos) - color = self.by_channel_params.children()[chan].param('color').value() - all_brush.append(np.array([pg.mkPen(color)]*len(times))) - + all_y.append(np.ones(times.size) * ypos) + color = self.by_channel_params.children()[chan].param("color").value() + all_brush.append(np.array([pg.mkPen(color)] * len(times))) - if self.params['display_labels']: + if self.params["display_labels"]: self.labels[chan].setPos(t_start, ypos) self.labels[chan].show() self.labels[chan].setColor(color) for chan in range(self.source.nb_channel): - if not self.params['display_labels'] or chan not in visibles: + if not self.params["display_labels"] or chan not in visibles: self.labels[chan].hide() for label in self.labels: - label.fill = pg.mkBrush(self.params['label_fill_color']) + label.fill = pg.mkBrush(self.params["label_fill_color"]) if len(all_x): all_x = np.concatenate(all_x) @@ -181,21 +178,25 @@ def on_data_ready(self, t_start, t_stop, visibles, data): all_brush = np.concatenate(all_brush) self.scatter.setData(x=all_x, y=all_y, pen=all_brush) - self.vline.setPen(self.params['vline_color']) + self.vline.setPen(self.params["vline_color"]) self.vline.setPos(self.t) - self.plot.setXRange( t_start, t_stop, padding = 0.0) - self.plot.setYRange(-self.params['scatter_size']/2, self.params['scatter_size']/2 + visibles.size - 1) + self.plot.setXRange(t_start, t_stop, padding=0.0) + self.plot.setYRange( + -self.params["scatter_size"] / 2, + self.params["scatter_size"] / 2 + visibles.size - 1, + ) def on_param_change(self, params=None, changes=None): for param, change, data in changes: - if change != 'value': continue - if param.name()=='scatter_size': - self.scatter.setSize(self.params['scatter_size']) - if param.name()=='label_size': + if change != "value": + continue + if param.name() == "scatter_size": + self.scatter.setSize(self.params["scatter_size"]) + if param.name() == "label_size": for label in self.labels: font = label.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label.setFont(font) self.refresh() diff --git a/ephyviewer/timefreqviewer.py b/ephyviewer/timefreqviewer.py index 204e395..65f817f 100644 --- a/ephyviewer/timefreqviewer.py +++ b/ephyviewer/timefreqviewer.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np import scipy.fftpack import scipy.signal -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors from .myqt import QT @@ -18,37 +18,67 @@ from .tools import create_plot_grid -#todo remove this +# todo remove this import time import threading default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 3., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'nb_column', 'type': 'int', 'value': 4}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'colormap', 'type': 'list', 'value': 'viridis', 'limits' : ['viridis', 'jet', 'gray', 'hot', ] }, - {'name': 'display_labels', 'type': 'bool', 'value': True}, - {'name': 'show_axis', 'type': 'bool', 'value': True}, - {'name': 'scale_mode', 'type': 'list', 'value': 'same_for_all', 'limits' : ['by_channel', 'same_for_all', ] }, - {'name': 'timefreq', 'type': 'group', 'children': [ - {'name': 'f_start', 'type': 'float', 'value': 3., 'step': 1.}, - {'name': 'f_stop', 'type': 'float', 'value': 90., 'step': 1.}, - {'name': 'deltafreq', 'type': 'float', 'value': 3., 'step': 1., 'limits': [0.1, 1.e6]}, - {'name': 'f0', 'type': 'float', 'value': 2.5, 'step': 0.1}, - {'name': 'normalisation', 'type': 'float', 'value': 0., 'step': 0.1},]} - - ] + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "nb_column", "type": "int", "value": 4}, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + { + "name": "colormap", + "type": "list", + "value": "viridis", + "limits": [ + "viridis", + "jet", + "gray", + "hot", + ], + }, + {"name": "display_labels", "type": "bool", "value": True}, + {"name": "show_axis", "type": "bool", "value": True}, + { + "name": "scale_mode", + "type": "list", + "value": "same_for_all", + "limits": [ + "by_channel", + "same_for_all", + ], + }, + { + "name": "timefreq", + "type": "group", + "children": [ + {"name": "f_start", "type": "float", "value": 3.0, "step": 1.0}, + {"name": "f_stop", "type": "float", "value": 90.0, "step": 1.0}, + { + "name": "deltafreq", + "type": "float", + "value": 3.0, + "step": 1.0, + "limits": [0.1, 1.0e6], + }, + {"name": "f0", "type": "float", "value": 2.5, "step": 0.1}, + {"name": "normalisation", "type": "float", "value": 0.0, "step": 0.1}, + ], + }, +] default_by_channel_params = [ - {'name': 'visible', 'type': 'bool', 'value': True}, - {'name': 'clim', 'type': 'float', 'value': .1}, - ] + {"name": "visible", "type": "bool", "value": True}, + {"name": "clim", "type": "float", "value": 0.1}, +] -def generate_wavelet_fourier(len_wavelet, f_start, f_stop, deltafreq, sample_rate, f0, normalisation): +def generate_wavelet_fourier( + len_wavelet, f_start, f_stop, deltafreq, sample_rate, f0, normalisation +): """ Compute the wavelet coefficients at all scales and compute its Fourier transform. @@ -75,75 +105,78 @@ def generate_wavelet_fourier(len_wavelet, f_start, f_stop, deltafreq, sample_rat Axis 0 is time; axis 1 is frequency. """ # compute final map scales - scales = f0/np.arange(f_start,f_stop,deltafreq)*sample_rate + scales = f0 / np.arange(f_start, f_stop, deltafreq) * sample_rate # compute wavelet coeffs at all scales - xi=np.arange(-len_wavelet/2.,len_wavelet/2.) - xsd = xi[:,np.newaxis] / scales - wavelet_coefs=np.exp(complex(1j)*2.*np.pi*f0*xsd)*np.exp(-np.power(xsd,2)/2.) + xi = np.arange(-len_wavelet / 2.0, len_wavelet / 2.0) + xsd = xi[:, np.newaxis] / scales + wavelet_coefs = np.exp(complex(1j) * 2.0 * np.pi * f0 * xsd) * np.exp( + -np.power(xsd, 2) / 2.0 + ) - weighting_function = lambda x: x**(-(1.0+normalisation)) - wavelet_coefs = wavelet_coefs*weighting_function(scales[np.newaxis,:]) + weighting_function = lambda x: x ** (-(1.0 + normalisation)) + wavelet_coefs = wavelet_coefs * weighting_function(scales[np.newaxis, :]) # Transform the wavelet into the Fourier domain - wf=scipy.fftpack.fft(wavelet_coefs,axis=0) - wf=wf.conj() + wf = scipy.fftpack.fft(wavelet_coefs, axis=0) + wf = wf.conj() return wf - class TimeFreqViewer_ParamController(Base_MultiChannel_ParamController): some_clim_changed = QT.pyqtSignal() def on_channel_visibility_changed(self): - print('TimeFreqViewer_ParamController.on_channel_visibility_changed') + print("TimeFreqViewer_ParamController.on_channel_visibility_changed") self.viewer.create_grid() self.viewer.initialize_time_freq() self.viewer.refresh() def clim_zoom(self, factor): - #~ print('clim_zoom factor', factor) + # ~ print('clim_zoom factor', factor) self.viewer.by_channel_params.blockSignals(True) for i, p in enumerate(self.viewer.by_channel_params.children()): - p.param('clim').setValue(p.param('clim').value()*factor) + p.param("clim").setValue(p.param("clim").value() * factor) self.viewer.by_channel_params.blockSignals(False) self.some_clim_changed.emit() def compute_auto_clim(self): - print('compute_auto_clim') + print("compute_auto_clim") print(self.visible_channels) self.viewer.by_channel_params.blockSignals(True) maxs = [] - visibles, = np.nonzero(self.visible_channels) + (visibles,) = np.nonzero(self.visible_channels) for chan in visibles: if chan in self.viewer.last_wt_maps.keys(): m = np.max(self.viewer.last_wt_maps[chan]) - if self.viewer.params['scale_mode'] == 'by_channel': - self.viewer.by_channel_params['ch'+str(chan), 'clim'] = m + if self.viewer.params["scale_mode"] == "by_channel": + self.viewer.by_channel_params["ch" + str(chan), "clim"] = m else: maxs.append(m) - if self.viewer.params['scale_mode'] == 'same_for_all' and len(maxs)>0: + if self.viewer.params["scale_mode"] == "same_for_all" and len(maxs) > 0: for chan in visibles: - self.viewer.by_channel_params['ch'+str(chan), 'clim'] = max(maxs) + self.viewer.by_channel_params["ch" + str(chan), "clim"] = max(maxs) self.viewer.by_channel_params.blockSignals(False) self.some_clim_changed.emit() class TimeFreqWorker(QT.QObject): - data_ready = QT.pyqtSignal(int, float, float, float, float, float, object) + data_ready = QT.pyqtSignal(int, float, float, float, float, float, object) - def __init__(self, source,viewer, chan, parent=None): + def __init__(self, source, viewer, chan, parent=None): QT.QObject.__init__(self, parent) self.source = source self.viewer = viewer self.chan = chan - def on_request_data(self, chan, t, t_start, t_stop, visible_channels, worker_params): + def on_request_data( + self, chan, t, t_start, t_stop, visible_channels, worker_params + ): if chan != self.chan: return @@ -151,90 +184,88 @@ def on_request_data(self, chan, t, t_start, t_stop, visible_channels, worker_par return if self.viewer.t != t: - print('viewer has moved already', chan, self.viewer.t, t) + print("viewer has moved already", chan, self.viewer.t, t) # viewer has moved already return - ds_ratio = worker_params['downsample_ratio'] - sig_chunk_size = worker_params['sig_chunk_size'] - filter_sos = worker_params['filter_sos'] - wavelet_fourrier = worker_params['wavelet_fourrier'] - plot_length = worker_params['plot_length'] + ds_ratio = worker_params["downsample_ratio"] + sig_chunk_size = worker_params["sig_chunk_size"] + filter_sos = worker_params["filter_sos"] + wavelet_fourrier = worker_params["wavelet_fourrier"] + plot_length = worker_params["plot_length"] i_start = self.source.time_to_index(t_start) - #~ print('ds_ratio', ds_ratio) - #~ print('start', t_start, i_start) + # ~ print('ds_ratio', ds_ratio) + # ~ print('start', t_start, i_start) - if ds_ratio>1: - i_start = i_start - (i_start%ds_ratio) - #~ print('start', t_start, i_start) + if ds_ratio > 1: + i_start = i_start - (i_start % ds_ratio) + # ~ print('start', t_start, i_start) - #clip it + # clip it i_start = max(0, i_start) i_start = min(i_start, self.source.get_length()) - if ds_ratio>1: - #after clip - i_start = i_start - (i_start%ds_ratio) - #~ print('start', t_start, i_start) + if ds_ratio > 1: + # after clip + i_start = i_start - (i_start % ds_ratio) + # ~ print('start', t_start, i_start) i_stop = i_start + sig_chunk_size i_stop = min(i_stop, self.source.get_length()) - sigs_chunk = self.source.get_chunk(i_start=i_start, i_stop=i_stop) sig = sigs_chunk[:, chan] - if ds_ratio>1: + if ds_ratio > 1: small_sig = scipy.signal.sosfiltfilt(filter_sos, sig) - small_sig =small_sig[::ds_ratio].copy() # to ensure continuity + small_sig = small_sig[::ds_ratio].copy() # to ensure continuity else: - small_sig = sig.copy()# to ensure continuity + small_sig = sig.copy() # to ensure continuity left_pad = 0 if small_sig.shape[0] != wavelet_fourrier.shape[0]: - #Pad it + # Pad it z = np.zeros(wavelet_fourrier.shape[0], dtype=small_sig.dtype) left_pad = wavelet_fourrier.shape[0] - small_sig.shape[0] - z[:small_sig.shape[0]] = small_sig + z[: small_sig.shape[0]] = small_sig small_sig = z - - #avoid border effect + # avoid border effect small_sig -= small_sig.mean() - #~ print('sig', sig.shape, 'small_sig', small_sig.shape) + # ~ print('sig', sig.shape, 'small_sig', small_sig.shape) small_sig_f = scipy.fftpack.fft(small_sig) if small_sig_f.shape[0] != wavelet_fourrier.shape[0]: - print('oulala', small_sig_f.shape, wavelet_fourrier.shape) - #TODO pad with zeros somewhere + print("oulala", small_sig_f.shape, wavelet_fourrier.shape) + # TODO pad with zeros somewhere return - wt_tmp=scipy.fftpack.ifft(small_sig_f[:,np.newaxis]*wavelet_fourrier,axis=0) - wt = scipy.fftpack.fftshift(wt_tmp,axes=[0]) - wt = np.abs(wt).astype('float32') - if left_pad>0: + wt_tmp = scipy.fftpack.ifft( + small_sig_f[:, np.newaxis] * wavelet_fourrier, axis=0 + ) + wt = scipy.fftpack.fftshift(wt_tmp, axes=[0]) + wt = np.abs(wt).astype("float32") + if left_pad > 0: wt = wt[:-left_pad] wt_map = wt[:plot_length] - #~ wt_map =wt - #~ print('wt_map', wt_map.shape) + # ~ wt_map =wt + # ~ print('wt_map', wt_map.shape) + # ~ print('sleep', chan) + # ~ time.sleep(2.) - #~ print('sleep', chan) - #~ time.sleep(2.) - - #TODO t_start and t_stop wrong - #~ print('sub_sample_rate', worker_params['sub_sample_rate']) - #~ print('wanted_size', worker_params['wanted_size']) - #~ print('plot_length', plot_length) - #~ print(i_start, i_stop) + # TODO t_start and t_stop wrong + # ~ print('sub_sample_rate', worker_params['sub_sample_rate']) + # ~ print('wanted_size', worker_params['wanted_size']) + # ~ print('plot_length', plot_length) + # ~ print(i_start, i_stop) t1 = self.source.index_to_time(i_start) - t2 = self.source.index_to_time(i_start+wt_map.shape[0]*ds_ratio) - #~ t2 = self.source.index_to_time(i_stop) - self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, wt_map) + t2 = self.source.index_to_time(i_start + wt_map.shape[0] * ds_ratio) + # ~ t2 = self.source.index_to_time(i_stop) + self.data_ready.emit(chan, t, t_start, t_stop, t1, t2, wt_map) class TimeFreqViewer(BaseMultiChannelViewer): - _default_params = default_params _default_by_channel_params = default_by_channel_params @@ -250,7 +281,7 @@ def __init__(self, **kargs): # make all not visible self.by_channel_params.blockSignals(True) for c in range(self.source.nb_channel): - self.by_channel_params['ch'+str(c), 'visible'] = c==0 + self.by_channel_params["ch" + str(c), "visible"] = c == 0 self.by_channel_params.blockSignals(False) self.make_param_controller() @@ -262,7 +293,6 @@ def __init__(self, **kargs): self.create_grid() self.initialize_time_freq() - self.last_wt_maps = {} self.threads = [] @@ -275,15 +305,16 @@ def __init__(self, **kargs): worker.moveToThread(thread) thread.start() - worker.data_ready.connect(self.on_data_ready) self.request_data.connect(worker.on_request_data) - self.params.param('xsize').setLimits((0, np.inf)) + self.params.param("xsize").setLimits((0, np.inf)) @classmethod def from_numpy(cls, sigs, sample_rate, t_start, name, channel_names=None): - source = InMemoryAnalogSignalSource(sigs, sample_rate, t_start, channel_names=channel_names) + source = InMemoryAnalogSignalSource( + sigs, sample_rate, t_start, channel_names=channel_names + ) view = cls(source=source, name=name) return view @@ -301,14 +332,14 @@ def set_layout(self): self.mainlayout.addWidget(self.graphiclayout) def on_param_change(self, params=None, changes=None): - #~ print('on_param_change') - #track if new scale mode - #~ for param, change, data in changes: - #~ if change != 'value': continue - #~ if param.name()=='scale_mode': - #~ self.params_controller.compute_rescale() - - #for simplification everything is recompute + # ~ print('on_param_change') + # track if new scale mode + # ~ for param, change, data in changes: + # ~ if change != 'value': continue + # ~ if param.name()=='scale_mode': + # ~ self.params_controller.compute_rescale() + + # for simplification everything is recompute self.change_color_scale() self.create_grid() self.initialize_time_freq() @@ -317,8 +348,13 @@ def on_param_change(self, params=None, changes=None): def create_grid(self): visible_channels = self.params_controller.visible_channels - self.plots = create_plot_grid(self.graphiclayout, self.params['nb_column'], visible_channels, - ViewBoxClass=MyViewBox, vb_params={}) + self.plots = create_plot_grid( + self.graphiclayout, + self.params["nb_column"], + visible_channels, + ViewBoxClass=MyViewBox, + vb_params={}, + ) for plot in self.plots: if plot is not None: @@ -334,8 +370,10 @@ def create_grid(self): self.plots[c].addItem(image) self.images.append(image) - vline = pg.InfiniteLine(angle = 90, movable = False, pen = self.params['vline_color']) - vline.setZValue(1) # ensure vline is above plot elements + vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + vline.setZValue(1) # ensure vline is above plot elements self.plots[c].addItem(vline) self.vlines.append(vline) @@ -344,67 +382,77 @@ def create_grid(self): self.vlines.append(None) def initialize_time_freq(self): - tfr_params = self.params.param('timefreq') + tfr_params = self.params.param("timefreq") sample_rate = self.source.sample_rate # we take sample_rate = f_stop*4 or (original sample_rate) - if tfr_params['f_stop']*4 < sample_rate: - wanted_sub_sample_rate = tfr_params['f_stop']*4 + if tfr_params["f_stop"] * 4 < sample_rate: + wanted_sub_sample_rate = tfr_params["f_stop"] * 4 else: wanted_sub_sample_rate = sample_rate # this try to find the best size to get a timefreq of 2**N by changing # the sub_sample_rate and the sig_chunk_size d = self.worker_params = {} - d['wanted_size'] = self.params['xsize'] - l = d['len_wavelet'] = int(2**np.ceil(np.log(d['wanted_size']*wanted_sub_sample_rate)/np.log(2))) - d['sig_chunk_size'] = d['wanted_size']*self.source.sample_rate - d['downsample_ratio'] = int(np.ceil(d['sig_chunk_size']/l)) - d['sig_chunk_size'] = d['downsample_ratio']*l - d['sub_sample_rate'] = self.source.sample_rate/d['downsample_ratio'] - d['plot_length'] = int(d['wanted_size']*d['sub_sample_rate']) - - d['wavelet_fourrier'] = generate_wavelet_fourier(d['len_wavelet'], tfr_params['f_start'], tfr_params['f_stop'], - tfr_params['deltafreq'], d['sub_sample_rate'], tfr_params['f0'], tfr_params['normalisation']) - - if d['downsample_ratio']>1: + d["wanted_size"] = self.params["xsize"] + l = d["len_wavelet"] = int( + 2 ** np.ceil(np.log(d["wanted_size"] * wanted_sub_sample_rate) / np.log(2)) + ) + d["sig_chunk_size"] = d["wanted_size"] * self.source.sample_rate + d["downsample_ratio"] = int(np.ceil(d["sig_chunk_size"] / l)) + d["sig_chunk_size"] = d["downsample_ratio"] * l + d["sub_sample_rate"] = self.source.sample_rate / d["downsample_ratio"] + d["plot_length"] = int(d["wanted_size"] * d["sub_sample_rate"]) + + d["wavelet_fourrier"] = generate_wavelet_fourier( + d["len_wavelet"], + tfr_params["f_start"], + tfr_params["f_stop"], + tfr_params["deltafreq"], + d["sub_sample_rate"], + tfr_params["f0"], + tfr_params["normalisation"], + ) + + if d["downsample_ratio"] > 1: n = 8 - q = d['downsample_ratio'] - d['filter_sos'] = scipy.signal.cheby1(n, 0.05, 0.8 / q, output='sos') + q = d["downsample_ratio"] + d["filter_sos"] = scipy.signal.cheby1(n, 0.05, 0.8 / q, output="sos") else: - d['filter_sos'] = None + d["filter_sos"] = None def change_color_scale(self): N = 512 - cmap_name = self.params['colormap'] - cmap = matplotlib.cm.get_cmap(cmap_name , N) - + cmap_name = self.params["colormap"] + cmap = colormaps.get_cmap(cmap_name).resampled(N) lut = [] for i in range(N): - r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) - lut.append([r*255,g*255,b*255]) - self.lut = np.array(lut, dtype='uint8') + r, g, b, _ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) + lut.append([r * 255, g * 255, b * 255]) + self.lut = np.array(lut, dtype="uint8") def auto_scale(self): - print('auto_scale', self.params['scale_mode']) + print("auto_scale", self.params["scale_mode"]) self.params_controller.compute_auto_clim() self.refresh() def refresh(self): - #~ print('TimeFreqViewer.refresh', self.t) + # ~ print('TimeFreqViewer.refresh', self.t) visible_channels = self.params_controller.visible_channels - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) for c in range(self.source.nb_channel): if visible_channels[c]: - self.request_data.emit(c, self.t, t_start, t_stop, visible_channels, self.worker_params) + self.request_data.emit( + c, self.t, t_start, t_stop, visible_channels, self.worker_params + ) - self.graphiclayout.setBackground(self.params['background_color']) + self.graphiclayout.setBackground(self.params["background_color"]) - def on_data_ready(self, chan, t, t_start, t_stop, t1,t2, wt_map): + def on_data_ready(self, chan, t, t_start, t_stop, t1, t2, wt_map): if not self.params_controller.visible_channels[chan]: return @@ -412,32 +460,31 @@ def on_data_ready(self, chan, t, t_start, t_stop, t1,t2, wt_map): return self.last_wt_maps[chan] = wt_map - f_start = self.params['timefreq', 'f_start'] - f_stop = self.params['timefreq', 'f_stop'] + f_start = self.params["timefreq", "f_start"] + f_stop = self.params["timefreq", "f_stop"] image = self.images[chan] - #~ print(t_start, f_start,self.worker_params['wanted_size'], f_stop-f_start) + # ~ print(t_start, f_start,self.worker_params['wanted_size'], f_stop-f_start) - #~ image.updateImage(wt_map) - clim = self.by_channel_params['ch{}'.format(chan), 'clim'] + # ~ image.updateImage(wt_map) + clim = self.by_channel_params["ch{}".format(chan), "clim"] image.setImage(wt_map, lut=self.lut, levels=[0, clim]) - image.setRect(QT.QRectF(t1, f_start,t2-t1, f_stop-f_start)) + image.setRect(QT.QRectF(t1, f_start, t2 - t1, f_stop - f_start)) - #TODO + # TODO # display_labels self.vlines[chan].setPos(t) - self.vlines[chan].setPen(self.params['vline_color']) + self.vlines[chan].setPen(self.params["vline_color"]) plot = self.plots[chan] - plot.setXRange(t_start, t_stop, padding = 0.0) - plot.setYRange(f_start, f_stop, padding = 0.0) + plot.setXRange(t_start, t_stop, padding=0.0) + plot.setYRange(f_start, f_stop, padding=0.0) - if self.params['display_labels']: - ch_name = '{}: {}'.format(chan, self.source.get_channel_name(chan=chan)) + if self.params["display_labels"]: + ch_name = "{}: {}".format(chan, self.source.get_channel_name(chan=chan)) self.plots[chan].setTitle(ch_name) else: self.plots[chan].setTitle(None) - - self.plots[chan].showAxis('left', self.params['show_axis']) - self.plots[chan].showAxis('bottom', self.params['show_axis']) + self.plots[chan].showAxis("left", self.params["show_axis"]) + self.plots[chan].showAxis("bottom", self.params["show_axis"]) diff --git a/ephyviewer/traceimageviewer.py b/ephyviewer/traceimageviewer.py index 7f80871..8c375b3 100644 --- a/ephyviewer/traceimageviewer.py +++ b/ephyviewer/traceimageviewer.py @@ -1,4 +1,4 @@ -import matplotlib.cm +from matplotlib import colormaps import matplotlib.colors import numpy as np import pyqtgraph as pg @@ -304,8 +304,7 @@ def closeEvent(self, event): def change_color_scale(self): N = 512 cmap_name = self.params["colormap"] - cmap = matplotlib.cm.get_cmap(cmap_name, N) - + cmap = colormaps.get_cmap(cmap_name).resampled(N) lut = [] for i in range(N): r, g, b, _ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) diff --git a/ephyviewer/traceviewer.py b/ephyviewer/traceviewer.py index 532bbce..b998777 100644 --- a/ephyviewer/traceviewer.py +++ b/ephyviewer/traceviewer.py @@ -1,68 +1,83 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np -#~ import matplotlib.cm -#~ import matplotlib.colors - from .myqt import QT import pyqtgraph as pg from .base import BaseMultiChannelViewer, Base_MultiChannel_ParamController -from .datasource import InMemoryAnalogSignalSource, AnalogSignalSourceWithScatter, NeoAnalogSignalSource, AnalogSignalFromNeoRawIOSource +from .datasource import ( + InMemoryAnalogSignalSource, + AnalogSignalSourceWithScatter, + NeoAnalogSignalSource, + AnalogSignalFromNeoRawIOSource, +) from .tools import mkCachedBrush - - - -#todo remove this +# todo remove this import time import threading - default_params = [ - {'name': 'xsize', 'type': 'float', 'value': 3., 'step': 0.1}, - {'name': 'xratio', 'type': 'float', 'value': 0.3, 'step': 0.1, 'limits': (0,1)}, - {'name': 'ylim_max', 'type': 'float', 'value': 10.}, - {'name': 'ylim_min', 'type': 'float', 'value': -10.}, - {'name': 'scatter_size', 'type': 'float', 'value': 10., 'limits': (0,np.inf)}, - {'name': 'scale_mode', 'type': 'list', 'value': 'real_scale', - 'limits':['real_scale', 'same_for_all', 'by_channel'] }, - {'name': 'auto_scale_factor', 'type': 'float', 'value': 0.1, 'step': 0.01, 'limits': (0,np.inf)}, - {'name': 'background_color', 'type': 'color', 'value': 'k'}, - {'name': 'vline_color', 'type': 'color', 'value': '#FFFFFFAA'}, - {'name': 'label_fill_color', 'type': 'color', 'value': '#222222DD'}, - {'name': 'label_size', 'type': 'int', 'value': 8, 'limits': (1,np.inf)}, - {'name': 'display_labels', 'type': 'bool', 'value': False}, - {'name': 'display_offset', 'type': 'bool', 'value': False}, - {'name': 'antialias', 'type': 'bool', 'value': False}, - {'name': 'decimation_method', 'type': 'list', 'value': 'min_max', 'limits': ['min_max', 'mean', 'pure_decimate', ]}, - {'name': 'line_width', 'type': 'float', 'value': 1., 'limits': (0, np.inf)}, - ] + {"name": "xsize", "type": "float", "value": 3.0, "step": 0.1}, + {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, + {"name": "ylim_max", "type": "float", "value": 10.0}, + {"name": "ylim_min", "type": "float", "value": -10.0}, + {"name": "scatter_size", "type": "float", "value": 10.0, "limits": (0, np.inf)}, + { + "name": "scale_mode", + "type": "list", + "value": "real_scale", + "limits": ["real_scale", "same_for_all", "by_channel"], + }, + { + "name": "auto_scale_factor", + "type": "float", + "value": 0.1, + "step": 0.01, + "limits": (0, np.inf), + }, + {"name": "background_color", "type": "color", "value": "k"}, + {"name": "vline_color", "type": "color", "value": "#FFFFFFAA"}, + {"name": "label_fill_color", "type": "color", "value": "#222222DD"}, + {"name": "label_size", "type": "int", "value": 8, "limits": (1, np.inf)}, + {"name": "display_labels", "type": "bool", "value": False}, + {"name": "display_offset", "type": "bool", "value": False}, + {"name": "antialias", "type": "bool", "value": False}, + { + "name": "decimation_method", + "type": "list", + "value": "min_max", + "limits": [ + "min_max", + "mean", + "pure_decimate", + ], + }, + {"name": "line_width", "type": "float", "value": 1.0, "limits": (0, np.inf)}, +] default_by_channel_params = [ - {'name': 'color', 'type': 'color', 'value': "#55FF00"}, - {'name': 'gain', 'type': 'float', 'value': 1, 'step': 0.1, 'decimals': 8}, - {'name': 'offset', 'type': 'float', 'value': 0., 'step': 0.1}, - {'name': 'visible', 'type': 'bool', 'value': True}, - ] + {"name": "color", "type": "color", "value": "#55FF00"}, + {"name": "gain", "type": "float", "value": 1, "step": 0.1, "decimals": 8}, + {"name": "offset", "type": "float", "value": 0.0, "step": 0.1}, + {"name": "visible", "type": "bool", "value": True}, +] - - -#TODO use Base_MultiChannel_ParamController instead of Base_ParamController -#to avoid code duplication - +# TODO use Base_MultiChannel_ParamController instead of Base_ParamController +# to avoid code duplication class TraceViewer_ParamController(Base_MultiChannel_ParamController): - def __init__(self, parent=None, viewer=None): - Base_MultiChannel_ParamController.__init__(self, parent=parent, viewer=viewer, with_visible=True, with_color=True) + Base_MultiChannel_ParamController.__init__( + self, parent=parent, viewer=viewer, with_visible=True, with_color=True + ) # raw_gains and raw_offsets are distinguished from adjustable gains and # offsets associated with this viewer because it makes placement of the @@ -77,51 +92,60 @@ def __init__(self, parent=None, viewer=None): self.raw_gains = np.ones(self.viewer.source.nb_channel) self.raw_offsets = np.zeros(self.viewer.source.nb_channel) - #TODO put this somewhere + # TODO put this somewhere - #~ v.addWidget(QT.QLabel(self.tr('Gain zoom (mouse wheel on graph):'),self)) - #~ h = QT.QHBoxLayout() - #~ v.addLayout(h) - #~ for label, factor in [('--', 1./5.), ('-', 1./1.1), ('+', 1.1), ('++', 5.),]: - #~ but = QT.QPushButton(label) - #~ but.factor = factor - #~ but.clicked.connect(self.on_but_ygain_zoom) - #~ h.addWidget(but) + # ~ v.addWidget(QT.QLabel(self.tr('Gain zoom (mouse wheel on graph):'),self)) + # ~ h = QT.QHBoxLayout() + # ~ v.addLayout(h) + # ~ for label, factor in [('--', 1./5.), ('-', 1./1.1), ('+', 1.1), ('++', 5.),]: + # ~ but = QT.QPushButton(label) + # ~ but.factor = factor + # ~ but.clicked.connect(self.on_but_ygain_zoom) + # ~ h.addWidget(but) - #~ self.ygain_factor = 1. + # ~ self.ygain_factor = 1. @property def selected(self): selected = np.ones(self.viewer.source.nb_channel, dtype=bool) - if self.viewer.source.nb_channel>1: + if self.viewer.source.nb_channel > 1: selected[:] = False selected[[ind.row() for ind in self.qlist.selectedIndexes()]] = True return selected @property def visible_channels(self): - visible = [self.viewer.by_channel_params['ch{}'.format(i), 'visible'] for i in range(self.source.nb_channel)] - return np.array(visible, dtype='bool') + visible = [ + self.viewer.by_channel_params["ch{}".format(i), "visible"] + for i in range(self.source.nb_channel) + ] + return np.array(visible, dtype="bool") @property def gains(self): - gains = [self.viewer.by_channel_params['ch{}'.format(i), 'gain'] for i in range(self.source.nb_channel)] + gains = [ + self.viewer.by_channel_params["ch{}".format(i), "gain"] + for i in range(self.source.nb_channel) + ] return np.array(gains) @gains.setter def gains(self, val): for c, v in enumerate(val): - self.viewer.by_channel_params['ch{}'.format(c), 'gain'] = v + self.viewer.by_channel_params["ch{}".format(c), "gain"] = v @property def offsets(self): - offsets = [self.viewer.by_channel_params['ch{}'.format(i), 'offset'] for i in range(self.source.nb_channel)] + offsets = [ + self.viewer.by_channel_params["ch{}".format(i), "offset"] + for i in range(self.source.nb_channel) + ] return np.array(offsets) @offsets.setter def offsets(self, val): for c, v in enumerate(val): - self.viewer.by_channel_params['ch{}'.format(c), 'offset'] = v + self.viewer.by_channel_params["ch{}".format(c), "offset"] = v @property def total_gains(self): @@ -139,21 +163,19 @@ def total_offsets(self): # = chunk * total_gains + total_offsets return (self.raw_offsets * self.gains) + self.offsets - - def estimate_median_mad(self): # Estimates are performed on real values for both raw and in-memory # sources, i.e., on sigs = chunk * raw_gains + raw_offsets, where # raw_gains = 1 and raw_offsets = 0 for in-memory sources. - #~ print('estimate_median_mad') - #~ t0 = time.perf_counter() + # ~ print('estimate_median_mad') + # ~ t0 = time.perf_counter() sigs = self.viewer.last_sigs_chunk - assert sigs is not None, 'Need to debug this' - #~ print(sigs) - #~ print(sigs.shape) + assert sigs is not None, "Need to debug this" + # ~ print(sigs) + # ~ print(sigs.shape) - if sigs.shape[0]>1000: + if sigs.shape[0] > 1000: # to fast auto scale on long signal ind = np.random.randint(0, sigs.shape[0], size=1000) sigs = sigs[ind, :] @@ -161,7 +183,7 @@ def estimate_median_mad(self): if sigs.shape[0] > 0: sigs = sigs * self.raw_gains + self.raw_offsets # calculate on real values self.signals_med = med = np.median(sigs, axis=0) - self.signals_mad = np.median(np.abs(sigs-med),axis=0)*1.4826 + self.signals_mad = np.median(np.abs(sigs - med), axis=0) * 1.4826 self.signals_min = np.min(sigs, axis=0) self.signals_max = np.max(sigs, axis=0) else: @@ -172,9 +194,9 @@ def estimate_median_mad(self): self.signals_min = -np.ones(n) self.signals_max = np.ones(n) - #~ t1 = time.perf_counter() - #~ print('estimate_median_mad DONE', t1-t0) - #~ print('self.signals_med', self.signals_med) + # ~ t1 = time.perf_counter() + # ~ print('estimate_median_mad DONE', t1-t0) + # ~ print('self.signals_med', self.signals_med) def compute_rescale(self): # estimate_median_mad operates on real values, i.e., on @@ -184,36 +206,55 @@ def compute_rescale(self): # = chunk * (raw_gains * gains) + (raw_offsets * gains + offsets) # = chunk * total_gains + total_offsets - scale_mode = self.viewer.params['scale_mode'] - #~ print('compute_rescale', scale_mode) + scale_mode = self.viewer.params["scale_mode"] + # ~ print('compute_rescale', scale_mode) self.viewer.all_params.blockSignals(True) gains = np.ones(self.viewer.source.nb_channel) offsets = np.zeros(self.viewer.source.nb_channel) nb_visible = np.sum(self.visible_channels) - #~ self.ygain_factor = 1 - if self.viewer.last_sigs_chunk is not None and self.viewer.last_sigs_chunk is not []: + # ~ self.ygain_factor = 1 + if ( + self.viewer.last_sigs_chunk is not None + and self.viewer.last_sigs_chunk is not [] + ): self.estimate_median_mad() - if scale_mode=='real_scale': - self.viewer.params['ylim_min'] = np.nanmin(self.signals_min[self.visible_channels]) - self.viewer.params['ylim_max'] = np.nanmax(self.signals_max[self.visible_channels]) + if scale_mode == "real_scale": + self.viewer.params["ylim_min"] = np.nanmin( + self.signals_min[self.visible_channels] + ) + self.viewer.params["ylim_max"] = np.nanmax( + self.signals_max[self.visible_channels] + ) else: - if scale_mode=='same_for_all': - gains[self.visible_channels] = np.ones(nb_visible, dtype=float) / max(self.signals_mad[self.visible_channels]) * self.viewer.params['auto_scale_factor'] - elif scale_mode=='by_channel': - gains[self.visible_channels] = np.ones(nb_visible, dtype=float) / self.signals_mad[self.visible_channels] * self.viewer.params['auto_scale_factor'] - offsets[self.visible_channels] = np.arange(nb_visible)[::-1] - self.signals_med[self.visible_channels]*gains[self.visible_channels] - self.viewer.params['ylim_min'] = -0.5 - self.viewer.params['ylim_max'] = nb_visible-0.5 + if scale_mode == "same_for_all": + gains[self.visible_channels] = ( + np.ones(nb_visible, dtype=float) + / max(self.signals_mad[self.visible_channels]) + * self.viewer.params["auto_scale_factor"] + ) + elif scale_mode == "by_channel": + gains[self.visible_channels] = ( + np.ones(nb_visible, dtype=float) + / self.signals_mad[self.visible_channels] + * self.viewer.params["auto_scale_factor"] + ) + offsets[self.visible_channels] = ( + np.arange(nb_visible)[::-1] + - self.signals_med[self.visible_channels] + * gains[self.visible_channels] + ) + self.viewer.params["ylim_min"] = -0.5 + self.viewer.params["ylim_max"] = nb_visible - 0.5 self.gains = gains self.offsets = offsets self.viewer.all_params.blockSignals(False) def on_channel_visibility_changed(self): - #~ print('on_channel_visibility_changed') + # ~ print('on_channel_visibility_changed') self.compute_rescale() self.viewer.refresh() @@ -222,45 +263,52 @@ def on_but_ygain_zoom(self): self.apply_ygain_zoom(factor) def apply_ygain_zoom(self, factor_ratio, chan_index=None): - - scale_mode = self.viewer.params['scale_mode'] + scale_mode = self.viewer.params["scale_mode"] self.viewer.all_params.blockSignals(True) - if scale_mode=='real_scale': - #~ self.ygain_factor *= factor_ratio + if scale_mode == "real_scale": + # ~ self.ygain_factor *= factor_ratio - self.viewer.params['ylim_max'] = self.viewer.params['ylim_max']*factor_ratio - self.viewer.params['ylim_min'] = self.viewer.params['ylim_min']*factor_ratio + self.viewer.params["ylim_max"] = ( + self.viewer.params["ylim_max"] * factor_ratio + ) + self.viewer.params["ylim_min"] = ( + self.viewer.params["ylim_min"] * factor_ratio + ) pass - #TODO ylims - else : - #~ self.ygain_factor *= factor_ratio - if not hasattr(self, 'signals_med'): + # TODO ylims + else: + # ~ self.ygain_factor *= factor_ratio + if not hasattr(self, "signals_med"): self.estimate_median_mad() - if scale_mode=='by_channel' and chan_index is not None: + if scale_mode == "by_channel" and chan_index is not None: # factor_ratio should be applied to only the desired channel, # so turn the scalar factor into a vector of ones everywhere # except at chan_index factor_ratio_vector = np.ones(self.source.nb_channel) factor_ratio_vector[chan_index] = factor_ratio factor_ratio = factor_ratio_vector - self.offsets = self.offsets + self.signals_med*self.gains * (1-factor_ratio) + self.offsets = self.offsets + self.signals_med * self.gains * ( + 1 - factor_ratio + ) self.gains = self.gains * factor_ratio self.viewer.all_params.blockSignals(False) self.viewer.refresh() - #~ print('apply_ygain_zoom', factor_ratio)#, 'self.ygain_factor', self.ygain_factor) + # ~ print('apply_ygain_zoom', factor_ratio)#, 'self.ygain_factor', self.ygain_factor) def apply_label_drag(self, label_y, chan_index): - self.viewer.by_channel_params['ch{}'.format(chan_index), 'offset'] = label_y - self.signals_med[chan_index]*self.gains[chan_index] - - + self.viewer.by_channel_params["ch{}".format(chan_index), "offset"] = ( + label_y - self.signals_med[chan_index] * self.gains[chan_index] + ) class DataGrabber(QT.QObject): - data_ready = QT.pyqtSignal(float, float, float, object, object, object, object, object) + data_ready = QT.pyqtSignal( + float, float, float, object, object, object, object, object + ) def __init__(self, source, viewer, parent=None): QT.QObject.__init__(self, parent) @@ -268,79 +316,89 @@ def __init__(self, source, viewer, parent=None): self.viewer = viewer self._max_point = 3000 - def get_data(self, t, t_start, t_stop, total_gains, total_offsets, visibles, decimation_method): - - i_start, i_stop = self.source.time_to_index(t_start), self.source.time_to_index(t_stop) + 2 - #~ print(t_start, t_stop, i_start, i_stop) - - ds_ratio = (i_stop - i_start)//self._max_point + 1 - #~ print() - #~ print('ds_ratio', ds_ratio, 'i_start i_stop', i_start, i_stop ) - - if ds_ratio>1: - i_start = i_start - (i_start%ds_ratio) - i_stop = i_stop - (i_stop%ds_ratio) - #~ print('i_start, i_stop', i_start, i_stop) - - #clip it + def get_data( + self, + t, + t_start, + t_stop, + total_gains, + total_offsets, + visibles, + decimation_method, + ): + i_start, i_stop = ( + self.source.time_to_index(t_start), + self.source.time_to_index(t_stop) + 2, + ) + # ~ print(t_start, t_stop, i_start, i_stop) + + ds_ratio = (i_stop - i_start) // self._max_point + 1 + # ~ print() + # ~ print('ds_ratio', ds_ratio, 'i_start i_stop', i_start, i_stop ) + + if ds_ratio > 1: + i_start = i_start - (i_start % ds_ratio) + i_stop = i_stop - (i_stop % ds_ratio) + # ~ print('i_start, i_stop', i_start, i_stop) + + # clip it i_start = max(0, i_start) i_start = min(i_start, self.source.get_length()) i_stop = max(0, i_stop) i_stop = min(i_stop, self.source.get_length()) - if ds_ratio>1: - #after clip - i_start = i_start - (i_start%ds_ratio) - i_stop = i_stop - (i_stop%ds_ratio) + if ds_ratio > 1: + # after clip + i_start = i_start - (i_start % ds_ratio) + i_stop = i_stop - (i_stop % ds_ratio) - #~ print('final i_start i_stop', i_start, i_stop ) + # ~ print('final i_start i_stop', i_start, i_stop ) sigs_chunk = self.source.get_chunk(i_start=i_start, i_stop=i_stop) - - - #~ print('sigs_chunk.shape', sigs_chunk.shape) + # ~ print('sigs_chunk.shape', sigs_chunk.shape) data_curves = sigs_chunk[:, visibles].T.copy() - if data_curves.dtype!='float32': - data_curves = data_curves.astype('float32') - - if ds_ratio>1: - + if data_curves.dtype != "float32": + data_curves = data_curves.astype("float32") - small_size = (data_curves.shape[1]//ds_ratio) - if decimation_method == 'min_max': + if ds_ratio > 1: + small_size = data_curves.shape[1] // ds_ratio + if decimation_method == "min_max": small_size *= 2 - small_arr = np.empty((data_curves.shape[0], small_size), dtype=data_curves.dtype) + small_arr = np.empty( + (data_curves.shape[0], small_size), dtype=data_curves.dtype + ) - if decimation_method == 'min_max' and data_curves.size>0: + if decimation_method == "min_max" and data_curves.size > 0: full_arr = data_curves.reshape(data_curves.shape[0], -1, ds_ratio) small_arr[:, ::2] = full_arr.max(axis=2) small_arr[:, 1::2] = full_arr.min(axis=2) - elif decimation_method == 'mean' and data_curves.size>0: + elif decimation_method == "mean" and data_curves.size > 0: full_arr = data_curves.reshape(data_curves.shape[0], -1, ds_ratio) small_arr[:, :] = full_arr.mean(axis=2) - elif decimation_method == 'pure_decimate': + elif decimation_method == "pure_decimate": small_arr[:, :] = data_curves[:, ::ds_ratio] elif data_curves.size == 0: pass - data_curves = small_arr - #~ print(data_curves.shape) + # ~ print(data_curves.shape) data_curves *= total_gains[visibles, None] data_curves += total_offsets[visibles, None] - dict_curves ={} + dict_curves = {} for i, c in enumerate(visibles): dict_curves[c] = data_curves[i, :] - #~ print(ds_ratio) + # ~ print(ds_ratio) t_start2 = self.source.index_to_time(i_start) - times_curves = np.arange(data_curves.shape[1], dtype='float64') # ensure high temporal precision (see issue #28) - times_curves /= self.source.sample_rate/ds_ratio - if ds_ratio>1 and decimation_method == 'min_max': - times_curves /=2 + times_curves = np.arange( + data_curves.shape[1], dtype="float64" + ) # ensure high temporal precision (see issue #28) + times_curves /= self.source.sample_rate / ds_ratio + if ds_ratio > 1 and decimation_method == "min_max": + times_curves /= 2 times_curves += t_start2 dict_scatter = None @@ -350,34 +408,76 @@ def get_data(self, t, t_start, t_stop, total_gains, total_offsets, visibles, dec for k in self.source.get_scatter_babels(): x, y = [[]], [[]] for i, c in enumerate(visibles): - scatter_inds = self.source.get_scatter(i_start=i_start, i_stop=i_stop, chan=c, label=k) - if scatter_inds is None: continue - x.append((scatter_inds-i_start)/self.source.sample_rate+t_start2) - y.append(sigs_chunk[scatter_inds-i_start, c]*total_gains[c]+total_offsets[c]) + scatter_inds = self.source.get_scatter( + i_start=i_start, i_stop=i_stop, chan=c, label=k + ) + if scatter_inds is None: + continue + x.append( + (scatter_inds - i_start) / self.source.sample_rate + t_start2 + ) + y.append( + sigs_chunk[scatter_inds - i_start, c] * total_gains[c] + + total_offsets[c] + ) dict_scatter[k] = (np.concatenate(x), np.concatenate(y)) - return t, t_start, t_stop, visibles, dict_curves, times_curves, sigs_chunk, dict_scatter - - def on_request_data(self, t, t_start, t_stop, total_gains, total_offsets, visibles, decimation_method): - #~ print('on_request_data', t_start, t_stop) + return ( + t, + t_start, + t_stop, + visibles, + dict_curves, + times_curves, + sigs_chunk, + dict_scatter, + ) + + def on_request_data( + self, + t, + t_start, + t_stop, + total_gains, + total_offsets, + visibles, + decimation_method, + ): + # ~ print('on_request_data', t_start, t_stop) if self.viewer.t != t: - #~ print('on_request_data not same t') + # ~ print('on_request_data not same t') return - t, t_start, t_stop, visibles, dict_curves, times_curves,\ - sigs_chunk, dict_scatter = self.get_data(t, t_start, t_stop, total_gains, total_offsets, visibles, decimation_method) - - - #~ print('on_request_data', threading.get_ident()) - #~ time.sleep(1.) - self.data_ready.emit(t, t_start, t_stop, visibles, dict_curves, times_curves, sigs_chunk, dict_scatter) - + ( + t, + t_start, + t_stop, + visibles, + dict_curves, + times_curves, + sigs_chunk, + dict_scatter, + ) = self.get_data( + t, t_start, t_stop, total_gains, total_offsets, visibles, decimation_method + ) + + # ~ print('on_request_data', threading.get_ident()) + # ~ time.sleep(1.) + self.data_ready.emit( + t, + t_start, + t_stop, + visibles, + dict_curves, + times_curves, + sigs_chunk, + dict_scatter, + ) class TraceLabelItem(pg.TextItem): - label_dragged = QT.pyqtSignal(float) label_ygain_zoom = QT.pyqtSignal(float) @@ -387,7 +487,7 @@ def __init__(self, **kwargs): self.dragOffset = None def mouseDragEvent(self, ev): - '''Emit the new y-coord of the label as it is dragged''' + """Emit the new y-coord of the label as it is dragged""" if ev.button() != QT.LeftButton: ev.ignore() @@ -408,11 +508,11 @@ def mouseDragEvent(self, ev): self.label_dragged.emit(new_y) def wheelEvent(self, ev): - '''Emit a yzoom factor for the associated trace''' + """Emit a yzoom factor for the associated trace""" if ev.modifiers() == QT.Qt.ControlModifier: - z = 5. if ev.delta()>0 else 1/5. + z = 5.0 if ev.delta() > 0 else 1 / 5.0 else: - z = 1.1 if ev.delta()>0 else 1/1.1 + z = 1.1 if ev.delta() > 0 else 1 / 1.1 self.label_ygain_zoom.emit(z) ev.accept() @@ -448,22 +548,37 @@ def __init__(self, useOpenGL=None, **kargs): self.datagrabber.moveToThread(self.thread) self.thread.start() - self.datagrabber.data_ready.connect(self.on_data_ready) self.request_data.connect(self.datagrabber.on_request_data) - self.params.param('xsize').setLimits((0, np.inf)) - + self.params.param("xsize").setLimits((0, np.inf)) @classmethod - def from_numpy(cls, sigs, sample_rate, t_start, name, channel_names=None, - scatter_indexes=None, scatter_channels=None, scatter_colors=None): - + def from_numpy( + cls, + sigs, + sample_rate, + t_start, + name, + channel_names=None, + scatter_indexes=None, + scatter_channels=None, + scatter_colors=None, + ): if scatter_indexes is None: - source = InMemoryAnalogSignalSource(sigs, sample_rate, t_start, channel_names=channel_names) + source = InMemoryAnalogSignalSource( + sigs, sample_rate, t_start, channel_names=channel_names + ) else: - source = AnalogSignalSourceWithScatter(sigs, sample_rate, t_start, channel_names=channel_names, - scatter_indexes=scatter_indexes, scatter_channels=scatter_channels, scatter_colors=scatter_colors) + source = AnalogSignalSourceWithScatter( + sigs, + sample_rate, + t_start, + channel_names=channel_names, + scatter_indexes=scatter_indexes, + scatter_channels=scatter_channels, + scatter_colors=scatter_colors, + ) view = cls(source=source, name=name) return view @@ -480,134 +595,180 @@ def closeEvent(self, event): self.thread.wait() def initialize_plot(self): - - self.vline = pg.InfiniteLine(angle = 90, movable = False, pen = self.params['vline_color']) - self.vline.setZValue(1) # ensure vline is above plot elements + self.vline = pg.InfiniteLine( + angle=90, movable=False, pen=self.params["vline_color"] + ) + self.vline.setZValue(1) # ensure vline is above plot elements self.plot.addItem(self.vline) self.curves = [] self.channel_labels = [] self.channel_offsets_line = [] for c in range(self.source.nb_channel): - color = self.by_channel_params['ch{}'.format(c), 'color'] - curve = pg.PlotCurveItem(pen='#7FFF00', downsampleMethod='peak', downsample=1, - autoDownsample=False, clipToView=True, antialias=False)#, connect='finite') + color = self.by_channel_params["ch{}".format(c), "color"] + curve = pg.PlotCurveItem( + pen="#7FFF00", + downsampleMethod="peak", + downsample=1, + autoDownsample=False, + clipToView=True, + antialias=False, + ) # , connect='finite') self.plot.addItem(curve) self.curves.append(curve) - ch_name = '{}: {}'.format(c, self.source.get_channel_name(chan=c)) - label = TraceLabelItem(text=ch_name, color=color, anchor=(0, 0.5), border=None, fill=self.params['label_fill_color']) - label.setZValue(2) # ensure labels are drawn above scatter + ch_name = "{}: {}".format(c, self.source.get_channel_name(chan=c)) + label = TraceLabelItem( + text=ch_name, + color=color, + anchor=(0, 0.5), + border=None, + fill=self.params["label_fill_color"], + ) + label.setZValue(2) # ensure labels are drawn above scatter font = label.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label.setFont(font) - label.label_dragged.connect(lambda label_y, chan_index=c: self.params_controller.apply_label_drag(label_y, chan_index)) - label.label_ygain_zoom.connect(lambda factor_ratio, chan_index=c: self.params_controller.apply_ygain_zoom(factor_ratio, chan_index)) + label.label_dragged.connect( + lambda label_y, chan_index=c: self.params_controller.apply_label_drag( + label_y, chan_index + ) + ) + label.label_ygain_zoom.connect( + lambda factor_ratio, chan_index=c: self.params_controller.apply_ygain_zoom( + factor_ratio, chan_index + ) + ) self.plot.addItem(label) self.channel_labels.append(label) - offset_line = pg.InfiniteLine(angle = 0, movable = False, pen = '#7FFF00') + offset_line = pg.InfiniteLine(angle=0, movable=False, pen="#7FFF00") self.plot.addItem(offset_line) self.channel_offsets_line.append(offset_line) if self.source.with_scatter: - self.scatter = pg.ScatterPlotItem(size=self.params['scatter_size'], pxMode = True) + self.scatter = pg.ScatterPlotItem( + size=self.params["scatter_size"], pxMode=True + ) self.plot.addItem(self.scatter) - - self.viewBox.xsize_zoom.connect(self.params_controller.apply_xsize_zoom) self.viewBox.ygain_zoom.connect(self.params_controller.apply_ygain_zoom) def on_param_change(self, params=None, changes=None): - #~ print('on_param_change') - #track if new scale mode + # ~ print('on_param_change') + # track if new scale mode for param, change, data in changes: - if change != 'value': continue - if param.name()=='scale_mode': + if change != "value": + continue + if param.name() == "scale_mode": self.params_controller.compute_rescale() - if param.name()=='antialias': + if param.name() == "antialias": for curve in self.curves: - curve.updateData(antialias=self.params['antialias']) - if param.name()=='scatter_size': + curve.updateData(antialias=self.params["antialias"]) + if param.name() == "scatter_size": if self.source.with_scatter: - self.scatter.setSize(self.params['scatter_size']) - if param.name()=='vline_color': - self.vline.setPen(self.params['vline_color']) - if param.name()=='label_fill_color': + self.scatter.setSize(self.params["scatter_size"]) + if param.name() == "vline_color": + self.vline.setPen(self.params["vline_color"]) + if param.name() == "label_fill_color": for label in self.channel_labels: - label.fill = pg.mkBrush(self.params['label_fill_color']) - if param.name()=='label_size': + label.fill = pg.mkBrush(self.params["label_fill_color"]) + if param.name() == "label_size": for label in self.channel_labels: font = label.textItem.font() - font.setPointSize(self.params['label_size']) + font.setPointSize(self.params["label_size"]) label.setFont(font) - self.refresh() def auto_scale(self): - #~ print('auto_scale', self.last_sigs_chunk) + # ~ print('auto_scale', self.last_sigs_chunk) if self.last_sigs_chunk is None: - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) - visibles, = np.nonzero(self.params_controller.visible_channels) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + (visibles,) = np.nonzero(self.params_controller.visible_channels) total_gains = self.params_controller.total_gains total_offsets = self.params_controller.total_offsets - _, _, _, _, _, _,sigs_chunk, _ = self.datagrabber.get_data(self.t, t_start, t_stop, total_gains, - total_offsets, visibles, self.params['decimation_method']) + _, _, _, _, _, _, sigs_chunk, _ = self.datagrabber.get_data( + self.t, + t_start, + t_stop, + total_gains, + total_offsets, + visibles, + self.params["decimation_method"], + ) self.last_sigs_chunk = sigs_chunk self.params_controller.compute_rescale() self.refresh() def refresh(self): - #~ print('TraceViewer.refresh', 't', self.t) - xsize = self.params['xsize'] - xratio = self.params['xratio'] - t_start, t_stop = self.t-xsize*xratio , self.t+xsize*(1-xratio) - visibles, = np.nonzero(self.params_controller.visible_channels) + # ~ print('TraceViewer.refresh', 't', self.t) + xsize = self.params["xsize"] + xratio = self.params["xratio"] + t_start, t_stop = self.t - xsize * xratio, self.t + xsize * (1 - xratio) + (visibles,) = np.nonzero(self.params_controller.visible_channels) total_gains = self.params_controller.total_gains total_offsets = self.params_controller.total_offsets - self.request_data.emit(self.t, t_start, t_stop, total_gains, total_offsets, visibles, self.params['decimation_method']) - - - def on_data_ready(self, t, t_start, t_stop, visibles, dict_curves, times_curves, sigs_chunk, dict_scatter): - #~ print('on_data_ready', t, t_start, t_stop) + self.request_data.emit( + self.t, + t_start, + t_stop, + total_gains, + total_offsets, + visibles, + self.params["decimation_method"], + ) + + def on_data_ready( + self, + t, + t_start, + t_stop, + visibles, + dict_curves, + times_curves, + sigs_chunk, + dict_scatter, + ): + # ~ print('on_data_ready', t, t_start, t_stop) if self.t != t: - #~ print('on_data_ready not same t') + # ~ print('on_data_ready not same t') return - self.graphicsview.setBackground(self.params['background_color']) + self.graphicsview.setBackground(self.params["background_color"]) self.last_sigs_chunk = sigs_chunk offsets = self.params_controller.offsets gains = self.params_controller.gains - if not hasattr(self.params_controller, 'signals_med'): + if not hasattr(self.params_controller, "signals_med"): self.params_controller.estimate_median_mad() signals_med = self.params_controller.signals_med - for i, c in enumerate(visibles): self.curves[c].show() self.curves[c].setData(times_curves, dict_curves[c]) - color = self.by_channel_params['ch{}'.format(c), 'color'] - self.curves[c].setPen(color, width=self.params['line_width']) + color = self.by_channel_params["ch{}".format(c), "color"] + self.curves[c].setPen(color, width=self.params["line_width"]) - if self.params['display_labels']: + if self.params["display_labels"]: self.channel_labels[c].show() - self.channel_labels[c].setPos(t_start, offsets[c] + signals_med[c]*gains[c]) + self.channel_labels[c].setPos( + t_start, offsets[c] + signals_med[c] * gains[c] + ) self.channel_labels[c].setColor(color) else: self.channel_labels[c].hide() - if self.params['display_offset']: + if self.params["display_offset"]: self.channel_offsets_line[c].show() self.channel_offsets_line[c].setPos(offsets[c]) self.channel_offsets_line[c].setPen(color) @@ -632,8 +793,8 @@ def on_data_ready(self, t, t_start, t_stop, visibles, dict_curves, times_curve # here we must use cached brushes to avoid issues with # the SymbolAtlas in pyqtgraph >= 0.11.1. # see https://github.com/NeuralEnsemble/ephyviewer/issues/132 - color = self.source.scatter_colors.get(k, '#FFFFFF') - all_brush.append(np.array([mkCachedBrush(color)]*len(x))) + color = self.source.scatter_colors.get(k, "#FFFFFF") + all_brush.append(np.array([mkCachedBrush(color)] * len(x))) if len(all_x): all_x = np.concatenate(all_x) @@ -642,7 +803,9 @@ def on_data_ready(self, t, t_start, t_stop, visibles, dict_curves, times_curve self.scatter.setData(x=all_x, y=all_y, brush=all_brush) self.vline.setPos(self.t) - self.plot.setXRange( t_start, t_stop, padding = 0.0) - self.plot.setYRange(self.params['ylim_min'], self.params['ylim_max'], padding = 0.0) + self.plot.setXRange(t_start, t_stop, padding=0.0) + self.plot.setYRange( + self.params["ylim_min"], self.params["ylim_max"], padding=0.0 + ) - #~ self.graphicsview.repaint() + # ~ self.graphicsview.repaint() diff --git a/ephyviewer/videoviewer.py b/ephyviewer/videoviewer.py index 0ee5cd1..98a3ca0 100644 --- a/ephyviewer/videoviewer.py +++ b/ephyviewer/videoviewer.py @@ -1,42 +1,37 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np -import matplotlib.cm -import matplotlib.colors - from .myqt import QT import pyqtgraph as pg from pyqtgraph.util.mutex import Mutex -from .base import ViewerBase, BaseMultiChannelViewer, Base_MultiChannel_ParamController +from .base import ViewerBase, BaseMultiChannelViewer, Base_MultiChannel_ParamController from .datasource import FrameGrabber, MultiVideoFileSource from .tools import create_plot_grid import threading default_params = [ - {'name': 'nb_column', 'type': 'int', 'value': 4}, - ] + {"name": "nb_column", "type": "int", "value": 4}, +] default_by_channel_params = [ - {'name': 'visible', 'type': 'bool', 'value': True}, - ] - + {"name": "visible", "type": "bool", "value": True}, +] - -#~ class QFrameGrabber(QT.QObject, FrameGrabber): - #~ frame_ready = QT.pyqtSignal(int, object) - #~ def on_request_frame(self, video_index, target_frame): - #~ if self.video_index!=video_index: - #~ return - #~ frame = self.get_frame(target_frame) - #~ if not frame: - #~ return - #~ self.frame_ready.emit(self.video_index, frame) +# ~ class QFrameGrabber(QT.QObject, FrameGrabber): +# ~ frame_ready = QT.pyqtSignal(int, object) +# ~ def on_request_frame(self, video_index, target_frame): +# ~ if self.video_index!=video_index: +# ~ return +# ~ frame = self.get_frame(target_frame) +# ~ if not frame: +# ~ return +# ~ self.frame_ready.emit(self.video_index, frame) class QFrameGrabber(QT.QObject): @@ -52,43 +47,43 @@ def __init__(self, frame_grabber, video_index, parent=None): self.request_list = [] self._last_frame = None - #~ def on_request_frame(self, video_index, target_frame): - #~ if self.video_index!=video_index: - #~ return - #~ frame = self.fg.get_frame(target_frame) - #~ if frame is None: - #~ return - #~ self.frame_ready.emit(self.video_index, frame) + # ~ def on_request_frame(self, video_index, target_frame): + # ~ if self.video_index!=video_index: + # ~ return + # ~ frame = self.fg.get_frame(target_frame) + # ~ if frame is None: + # ~ return + # ~ self.frame_ready.emit(self.video_index, frame) @property def last_frame(self): with self.mutex: return self._last_frame - #~ @property - #~ def active_frame(self): - #~ return self.fg.active_frame + # ~ @property + # ~ def active_frame(self): + # ~ return self.fg.active_frame - #~ @active_frame.setter - #~ def active_frame(self, value): - #~ self.fg.active_frame = value + # ~ @active_frame.setter + # ~ def active_frame(self, value): + # ~ self.fg.active_frame = value - #~ def queue_request_frame(self, target_frame): - #~ print('queue_request_frame', threading.get_ident()) - #~ with self.mutex: - #~ self.request_list.append(target_frame) - #~ if len(self.request_list)>1: - #~ self.request_frame.emit() + # ~ def queue_request_frame(self, target_frame): + # ~ print('queue_request_frame', threading.get_ident()) + # ~ with self.mutex: + # ~ self.request_list.append(target_frame) + # ~ if len(self.request_list)>1: + # ~ self.request_frame.emit() def on_request_frame(self, video_index): - if self.video_index!=video_index: + if self.video_index != video_index: return - #~ print('on_request_frame', threading.get_ident()) + # ~ print('on_request_frame', threading.get_ident()) with self.mutex: - #~ print('len(self.request_list)', len(self.request_list)) - if len(self.request_list)==0: + # ~ print('len(self.request_list)', len(self.request_list)) + if len(self.request_list) == 0: return target_frame = self.request_list[-1] self.request_list = [] @@ -99,19 +94,17 @@ def on_request_frame(self, video_index): frame = self.fg.get_frame(target_frame) with self.mutex: self._last_frame = target_frame - #~ print('new self._last_frame', self._last_frame) - #~ self.fg.active_frame = target_frame + # ~ print('new self._last_frame', self._last_frame) + # ~ self.fg.active_frame = target_frame self.frame_ready.emit(self.video_index, frame) - - class VideoViewer_ParamController(Base_MultiChannel_ParamController): pass -#~ class VideoViewer(ViewerBase): -class VideoViewer(BaseMultiChannelViewer): +# ~ class VideoViewer(ViewerBase): +class VideoViewer(BaseMultiChannelViewer): _default_params = default_params _default_by_channel_params = default_by_channel_params @@ -128,12 +121,12 @@ def __init__(self, **kargs): self.qframe_grabbers = [] self.threads = [] - #~ self.actual_frames = [] + # ~ self.actual_frames = [] for i, video_filename in enumerate(self.source.video_filenames): fg = QFrameGrabber(self.source.frame_grabbers[i], i) self.qframe_grabbers.append(fg) - #~ fg.set_file(video_filename) - #~ fg.video_index = i + # ~ fg.set_file(video_filename) + # ~ fg.video_index = i fg.frame_ready.connect(self.update_frame) thread = QT.QThread(parent=self) @@ -149,14 +142,12 @@ def from_filenames(cls, video_filenames, video_times, name): view = cls(source=source, name=name) return view - def closeEvent(self, event): for i, thread in enumerate(self.threads): thread.quit() thread.wait() event.accept() - def set_layout(self): self.mainlayout = QT.QVBoxLayout() self.setLayout(self.mainlayout) @@ -171,13 +162,17 @@ def on_param_change(self): def create_grid(self): visible_channels = self.params_controller.visible_channels - self.plots = create_plot_grid(self.graphiclayout, self.params['nb_column'], visible_channels, - ViewBoxClass=pg.ViewBox, vb_params={'lockAspect':True}) + self.plots = create_plot_grid( + self.graphiclayout, + self.params["nb_column"], + visible_channels, + ViewBoxClass=pg.ViewBox, + vb_params={"lockAspect": True}, + ) for plot in self.plots: - plot.showAxis('left', False) - plot.showAxis('bottom', False) - + plot.showAxis("left", False) + plot.showAxis("bottom", False) self.images = [] for c in range(self.source.nb_channel): @@ -189,80 +184,61 @@ def create_grid(self): self.images.append(None) def refresh(self): - #~ print('videoviewer.refresh', self.t) + # ~ print('videoviewer.refresh', self.t) visible_channels = self.params_controller.visible_channels - #~ print() - #~ print('refresh t=', self.t) + # ~ print() + # ~ print('refresh t=', self.t) for c in range(self.source.nb_channel): if visible_channels[c]: frame_index = self.source.time_to_frame_index(c, self.t) - #~ print( 'c', c, 'frame_index', frame_index, 'self.qframe_grabbers[c].last_frame', self.qframe_grabbers[c].last_frame) + # ~ print( 'c', c, 'frame_index', frame_index, 'self.qframe_grabbers[c].last_frame', self.qframe_grabbers[c].last_frame) - #~ if self.qframe_grabbers[c].active_frame != frame_index: - #~ print('self.qframe_grabbers[c].last_frame != frame_index', self.qframe_grabbers[c].last_frame != frame_index) + # ~ if self.qframe_grabbers[c].active_frame != frame_index: + # ~ print('self.qframe_grabbers[c].last_frame != frame_index', self.qframe_grabbers[c].last_frame != frame_index) if self.qframe_grabbers[c].last_frame != frame_index: + # ~ self.qframe_grabbers[c].active_frame = frame_index + # ~ self.request_frame.emit(c, frame_index) - #~ self.qframe_grabbers[c].active_frame = frame_index - #~ self.request_frame.emit(c, frame_index) - - #~ self.qframe_grabbers[c].queue_request_frame(frame_index) - #~ print('enque frame', threading.get_ident(), 'frame_index', frame_index) + # ~ self.qframe_grabbers[c].queue_request_frame(frame_index) + # ~ print('enque frame', threading.get_ident(), 'frame_index', frame_index) with self.qframe_grabbers[c].mutex: - self.qframe_grabbers[c].request_list.append(frame_index) - if len(self.qframe_grabbers[c].request_list)>=1: + if len(self.qframe_grabbers[c].request_list) >= 1: self.request_frame.emit(c) - #~ print('EMIT!!!!!') - - - - - - + # ~ print('EMIT!!!!!') def update_frame(self, video_index, frame): - #~ print('update_frame', video_index, frame) + # ~ print('update_frame', video_index, frame) if frame is None: self.images[video_index].clear() else: - #TODO : find better solution!!!! to avoid copy + # TODO : find better solution!!!! to avoid copy try: # PyAV >= 0.5.3 - img = frame.to_ndarray(format='rgb24') + img = frame.to_ndarray(format="rgb24") except AttributeError: # PyAV < 0.5.3 - img = frame.to_nd_array(format='rgb24') - img = img.swapaxes(0,1)[:,::-1,:] - #~ print(img.shape, img.dtype) + img = frame.to_nd_array(format="rgb24") + img = img.swapaxes(0, 1)[:, ::-1, :] + # ~ print(img.shape, img.dtype) self.images[video_index].setImage(img) - - #~ rgba = frame.reformat(frame.width, frame.height, "rgb24", 'itu709') - #print rgba.to_image().save("test.png") + # ~ rgba = frame.reformat(frame.width, frame.height, "rgb24", 'itu709') + # print rgba.to_image().save("test.png") # could use the buffer interface here instead, some versions of PyQt don't support it for some reason # need to track down which version they added support for it - #~ bytearray(rgba.planes[0]) - #~ bytesPerPixel =3 - #~ img = QT.QImage(bytearray(rgba.planes[0]), rgba.width, rgba.height, rgba.width * bytesPerPixel, QT.QImage.Format_RGB888) - #~ self.images[video_index].setImage(img) - - #img = QtGui.QImage(rgba.planes[0], rgba.width, rgba.height, QtGui.QImage.Format_RGB888) - - - - - - - - #~ if self.actual_frames[i] != new_frame: - - + # ~ bytearray(rgba.planes[0]) + # ~ bytesPerPixel =3 + # ~ img = QT.QImage(bytearray(rgba.planes[0]), rgba.width, rgba.height, rgba.width * bytesPerPixel, QT.QImage.Format_RGB888) + # ~ self.images[video_index].setImage(img) + # img = QtGui.QImage(rgba.planes[0], rgba.width, rgba.height, QtGui.QImage.Format_RGB888) + # ~ if self.actual_frames[i] != new_frame: - #~ frame = self.source.get_frame(t=t,chan=c) - #~ self.images[c].setImage(frame) - #~ else: - #~ pass + # ~ frame = self.source.get_frame(t=t,chan=c) + # ~ self.images[c].setImage(frame) + # ~ else: + # ~ pass From 00789e3eaede09f73d17890ba390282a00324adf Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 22:40:02 -0500 Subject: [PATCH 06/12] Add TraceImageViewer example --- examples/trace_image_viewer_datasource.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 examples/trace_image_viewer_datasource.py diff --git a/examples/trace_image_viewer_datasource.py b/examples/trace_image_viewer_datasource.py new file mode 100644 index 0000000..529bcf5 --- /dev/null +++ b/examples/trace_image_viewer_datasource.py @@ -0,0 +1,22 @@ +import ephyviewer +from ephyviewer.tests.testing_tools import make_fake_signals + + +# Create the main Qt application (for event loop) +app = ephyviewer.mkQApp() + +# Create the main window that can contain several viewers +win = ephyviewer.MainViewer(debug=True, show_auto_scale=True) + +# Create a TraceView and add it to the main window +view1 = ephyviewer.TraceViewer(source=make_fake_signals(), name="LFPs") +view1.params["scale_mode"] = "same_for_all" +win.add_view(view1) + +# Create a TraceImageView and add it to the main window +view2 = ephyviewer.TraceImageViewer(source=make_fake_signals(), name="CSDs") +win.add_view(view2) + +# show main window and run Qapp +win.show() +app.exec() From 9fdcbae83a738e0b5455073ea759fe2bb32d2d11 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Tue, 9 May 2023 22:44:18 -0500 Subject: [PATCH 07/12] Remove unused imports --- ephyviewer/dataframeview.py | 51 +++--- ephyviewer/eventlist.py | 47 +++-- ephyviewer/navigation.py | 250 +++++++++++++++------------ ephyviewer/scripts.py | 64 ++++--- ephyviewer/spectrogramviewer.py | 5 - examples/trace_viewer_datasource.py | 22 ++- examples/trace_viewer_with_marker.py | 51 +++--- 7 files changed, 264 insertions(+), 226 deletions(-) diff --git a/ephyviewer/dataframeview.py b/ephyviewer/dataframeview.py index 7fdeb7e..f8ac7e5 100644 --- a/ephyviewer/dataframeview.py +++ b/ephyviewer/dataframeview.py @@ -1,47 +1,44 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) import numpy as np from .myqt import QT -import pyqtgraph as pg from .base import ViewerBase -from .datasource import InMemoryEventSource def dataframe_on_qtable(qtable, df): - qtable.clear() - qtable.setColumnCount(len(df.columns)) - qtable.setRowCount(len(df.index)) + qtable.clear() + qtable.setColumnCount(len(df.columns)) + qtable.setRowCount(len(df.index)) - qtable.setHorizontalHeaderLabels(['{}'.format(c) for c in df.columns]) - qtable.setVerticalHeaderLabels(['{}'.format(c) for c in df.index]) - - for r, row in enumerate(df.index): - for c, col in enumerate(df.columns): - try: - v = u'{}'.format(df.loc[row,col]) - qtable.setItem(r, c, QT.QTableWidgetItem(v)) - except Exception as e: - print('erreur', e) - print(r, row) - print(c, col) + qtable.setHorizontalHeaderLabels(["{}".format(c) for c in df.columns]) + qtable.setVerticalHeaderLabels(["{}".format(c) for c in df.index]) + for r, row in enumerate(df.index): + for c, col in enumerate(df.columns): + try: + v = "{}".format(df.loc[row, col]) + qtable.setItem(r, c, QT.QTableWidgetItem(v)) + except Exception as e: + print("erreur", e) + print(r, row) + print(c, col) class DataFrameView(ViewerBase): - - def __init__(self, **kargs): + def __init__(self, **kargs): ViewerBase.__init__(self, **kargs) self.mainlayout = QT.QVBoxLayout() self.setLayout(self.mainlayout) - - self.qtable = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, - selectionBehavior=QT.QAbstractItemView.SelectRows) + self.qtable = QT.QTableWidget( + selectionMode=QT.QAbstractItemView.SingleSelection, + selectionBehavior=QT.QAbstractItemView.SelectRows, + ) self.qtable.itemClicked.connect(self.on_selection_changed) self.mainlayout.addWidget(self.qtable) @@ -51,11 +48,11 @@ def refresh(self): pass def on_selection_changed(self): - if 'time' not in self.source.columns: + if "time" not in self.source.columns: return - ind = [e.row() for e in self.qtable.selectedIndexes() if e.column()==0] - if len(ind)==1: - t = self.source['time'].iloc[ind[0]] + ind = [e.row() for e in self.qtable.selectedIndexes() if e.column() == 0] + if len(ind) == 1: + t = self.source["time"].iloc[ind[0]] if t is not None: t = float(t) if not np.isnan(t): diff --git a/ephyviewer/eventlist.py b/ephyviewer/eventlist.py index 7119e9d..281c5e3 100644 --- a/ephyviewer/eventlist.py +++ b/ephyviewer/eventlist.py @@ -1,34 +1,28 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) - - -import numpy as np +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) from .myqt import QT -import pyqtgraph as pg from .base import ViewerBase from .datasource import InMemoryEventSource, NeoEventSource - - class EventList(ViewerBase): - - def __init__(self, **kargs): + def __init__(self, **kargs): ViewerBase.__init__(self, **kargs) self.mainlayout = QT.QVBoxLayout() self.setLayout(self.mainlayout) - self.combo = QT.QComboBox() self.mainlayout.addWidget(self.combo) self.list_widget = QT.QListWidget() self.mainlayout.addWidget(self.list_widget) self.combo.currentIndexChanged.connect(self.refresh_list) - self.combo.addItems([self.source.get_channel_name(i) for i in range(self.source.nb_channel) ]) + self.combo.addItems( + [self.source.get_channel_name(i) for i in range(self.source.nb_channel)] + ) self.list_widget.itemClicked.connect(self.select_event) @@ -50,41 +44,42 @@ def refresh(self): def refresh_list(self, ind): self.ind = ind self.list_widget.clear() - #~ ev = self.source.all_events[ind] - data = self.source.get_chunk(chan=ind, i_start=None, i_stop=None) + # ~ ev = self.source.all_events[ind] + data = self.source.get_chunk(chan=ind, i_start=None, i_stop=None) - if len(data)==2: + if len(data) == 2: times, labels = data - elif len(data)==3: + elif len(data) == 3: times, _, labels = data - elif len(data)==4: + elif len(data) == 4: times, _, labels, _ = data else: raise ValueError("data has unexpected dimensions") for i in range(times.size): if labels is None: - self.list_widget.addItem('{} : {:.3f}'.format(i, times[i]) ) + self.list_widget.addItem("{} : {:.3f}".format(i, times[i])) else: - self.list_widget.addItem('{} : {:.3f} {}'.format(i, times[i], labels[i]) ) - + self.list_widget.addItem( + "{} : {:.3f} {}".format(i, times[i], labels[i]) + ) def select_event(self): i = self.list_widget.currentRow() - #~ ev = self.source.all_events[self.ind] - #~ t = ev['time'][i] - data = self.source.get_chunk(chan=self.ind, i_start=i, i_stop=i+1) + # ~ ev = self.source.all_events[self.ind] + # ~ t = ev['time'][i] + data = self.source.get_chunk(chan=self.ind, i_start=i, i_stop=i + 1) - if len(data)==2: + if len(data) == 2: times, labels = data - elif len(data)==3: + elif len(data) == 3: times, _, labels = data - elif len(data)==4: + elif len(data) == 4: times, _, labels, _ = data else: raise ValueError("data has unexpected dimensions") - if len(times)>0: + if len(times) > 0: t = float(times[0]) self.time_changed.emit(t) diff --git a/ephyviewer/navigation.py b/ephyviewer/navigation.py index 33fa2e0..f2313b8 100644 --- a/ephyviewer/navigation.py +++ b/ephyviewer/navigation.py @@ -1,34 +1,40 @@ # -*- coding: utf-8 -*- -#~ from __future__ import (unicode_literals, print_function, division, absolute_import) +# ~ from __future__ import (unicode_literals, print_function, division, absolute_import) from .myqt import QT import pyqtgraph as pg import numpy as np -from collections import OrderedDict - import time import datetime -#TODO: +# TODO: # * xsize in navigation # * real time when possible -class NavigationToolBar(QT.QWidget) : - """ - """ + +class NavigationToolBar(QT.QWidget): + """ """ + time_changed = QT.pyqtSignal(float) xsize_changed = QT.pyqtSignal(float) auto_scale_requested = QT.pyqtSignal() - def __init__(self, parent=None, show_play=True, show_step=True, - show_scroll_time=True, show_spinbox=True, - show_label_datetime=False, datetime0=None, - datetime_format='%Y-%m-%d %H:%M:%S', - show_global_xsize=True, show_auto_scale=True, - play_interval = 0.1) : - + def __init__( + self, + parent=None, + show_play=True, + show_step=True, + show_scroll_time=True, + show_spinbox=True, + show_label_datetime=False, + datetime0=None, + datetime_format="%Y-%m-%d %H:%M:%S", + show_global_xsize=True, + show_auto_scale=True, + play_interval=0.1, + ): QT.QWidget.__init__(self, parent) self.setSizePolicy(QT.QSizePolicy.Minimum, QT.QSizePolicy.Maximum) @@ -36,11 +42,9 @@ def __init__(self, parent=None, show_play=True, show_step=True, self.mainlayout = QT.QVBoxLayout() self.setLayout(self.mainlayout) - - - #~ self.toolbar = QT.QToolBar() - #~ self.mainlayout.addWidget(self.toolbar) - #~ t = self.toolbar + # ~ self.toolbar = QT.QToolBar() + # ~ self.mainlayout.addWidget(self.toolbar) + # ~ t = self.toolbar self.show_play = show_play self.show_step = show_step @@ -53,56 +57,68 @@ def __init__(self, parent=None, show_play=True, show_step=True, self.datetime0 = datetime0 self.datetime_format = datetime_format - if show_scroll_time: - #~ self.slider = QSlider() - self.scroll_time = QT.QScrollBar(orientation=QT.Horizontal, minimum=0, maximum=1000) + # ~ self.slider = QSlider() + self.scroll_time = QT.QScrollBar( + orientation=QT.Horizontal, minimum=0, maximum=1000 + ) self.mainlayout.addWidget(self.scroll_time) self.scroll_time.valueChanged.connect(self.on_scroll_time_changed) - #TODO min/max/step - #~ self.scroll_time.valueChanged.disconnect(self.on_scroll_time_changed) - #~ self.scroll_time.setValue(int(sr*t)) - #~ self.scroll_time.setPageStep(int(sr*self.xsize)) - #~ self.scroll_time.valueChanged.connect(self.on_scroll_time_changed) - #~ self.scroll_time.setMinimum(0) - #~ self.scroll_time.setMaximum(length) + # TODO min/max/step + # ~ self.scroll_time.valueChanged.disconnect(self.on_scroll_time_changed) + # ~ self.scroll_time.setValue(int(sr*t)) + # ~ self.scroll_time.setPageStep(int(sr*self.xsize)) + # ~ self.scroll_time.valueChanged.connect(self.on_scroll_time_changed) + # ~ self.scroll_time.setMinimum(0) + # ~ self.scroll_time.setMaximum(length) h = QT.QHBoxLayout() h.addStretch() self.mainlayout.addLayout(h) if show_play: - but = QT.QPushButton(icon=QT.QIcon(':/media-playback-start.svg')) + but = QT.QPushButton(icon=QT.QIcon(":/media-playback-start.svg")) but.clicked.connect(self.on_play) h.addWidget(but) - but = QT.QPushButton(icon=QT.QIcon(':/media-playback-stop.svg')) - #~ but = QT.QPushButton(QT.QIcon(':/media-playback-stop.png'), '') + but = QT.QPushButton(icon=QT.QIcon(":/media-playback-stop.svg")) + # ~ but = QT.QPushButton(QT.QIcon(':/media-playback-stop.png'), '') but.clicked.connect(self.on_stop_pause) h.addWidget(but) - h.addWidget(QT.QLabel('Speed:')) - self.speedSpin = pg.SpinBox(bounds=(0.01, 100.), step=0.1, value=1.) - if 'compactHeight' in self.speedSpin.opts: # pyqtgraph >= 0.11.0 + h.addWidget(QT.QLabel("Speed:")) + self.speedSpin = pg.SpinBox(bounds=(0.01, 100.0), step=0.1, value=1.0) + if "compactHeight" in self.speedSpin.opts: # pyqtgraph >= 0.11.0 self.speedSpin.setOpts(compactHeight=False) h.addWidget(self.speedSpin) self.speedSpin.valueChanged.connect(self.on_change_speed) - self.speed = 1. + self.speed = 1.0 - #trick for separator - h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # trick for separator + h.addWidget( + QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) + ) # add spacebar shortcut for play/pause play_pause_shortcut = QT.QShortcut(self) - play_pause_shortcut.setKey(QT.QKeySequence(' ')) + play_pause_shortcut.setKey(QT.QKeySequence(" ")) play_pause_shortcut.activated.connect(self.on_play_pause_shortcut) - self.steps = ['60 s', '10 s', '1 s', '100 ms', '50 ms', '5 ms', '1 ms', '200 us'] + self.steps = [ + "60 s", + "10 s", + "1 s", + "100 ms", + "50 ms", + "5 ms", + "1 ms", + "200 us", + ] if show_step: - but = QT.QPushButton('<') + but = QT.QPushButton("<") but.clicked.connect(self.prev_step) h.addWidget(but) @@ -114,77 +130,98 @@ def __init__(self, parent=None, show_play=True, show_step=True, self.on_change_step(None) self.combo_step.currentIndexChanged.connect(self.on_change_step) - but = QT.QPushButton('>') + but = QT.QPushButton(">") but.clicked.connect(self.next_step) h.addWidget(but) - #trick for separator - h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # trick for separator + h.addWidget( + QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) + ) # add shortcuts for stepping through time and changing step size shortcuts = [ - {'key': QT.Qt.Key_Left, 'callback': self.prev_step}, - {'key': QT.Qt.Key_Right, 'callback': self.next_step}, - {'key': QT.Qt.Key_Up, 'callback': self.increase_step}, - {'key': QT.Qt.Key_Down, 'callback': self.decrease_step}, - {'key': 'a', 'callback': self.prev_step}, - {'key': 'd', 'callback': self.next_step}, - {'key': 'w', 'callback': self.increase_step}, - {'key': 's', 'callback': self.decrease_step}, + {"key": QT.Qt.Key_Left, "callback": self.prev_step}, + {"key": QT.Qt.Key_Right, "callback": self.next_step}, + {"key": QT.Qt.Key_Up, "callback": self.increase_step}, + {"key": QT.Qt.Key_Down, "callback": self.decrease_step}, + {"key": "a", "callback": self.prev_step}, + {"key": "d", "callback": self.next_step}, + {"key": "w", "callback": self.increase_step}, + {"key": "s", "callback": self.decrease_step}, ] for s in shortcuts: shortcut = QT.QShortcut(self) - shortcut.setKey(QT.QKeySequence(s['key'])) - shortcut.activated.connect(s['callback']) - + shortcut.setKey(QT.QKeySequence(s["key"])) + shortcut.activated.connect(s["callback"]) if show_spinbox: - h.addWidget(QT.QLabel('Time (s):')) - self.spinbox_time =pg.SpinBox(decimals = 8, bounds = (-np.inf, np.inf),step = 0.05, siPrefix=False, suffix='', int=False) - if 'compactHeight' in self.spinbox_time.opts: # pyqtgraph >= 0.11.0 + h.addWidget(QT.QLabel("Time (s):")) + self.spinbox_time = pg.SpinBox( + decimals=8, + bounds=(-np.inf, np.inf), + step=0.05, + siPrefix=False, + suffix="", + int=False, + ) + if "compactHeight" in self.spinbox_time.opts: # pyqtgraph >= 0.11.0 self.spinbox_time.setOpts(compactHeight=False) h.addWidget(self.spinbox_time) - #trick for separator - h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # trick for separator + h.addWidget( + QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) + ) - - #~ h.addSeparator() + # ~ h.addSeparator() self.spinbox_time.valueChanged.connect(self.on_spinbox_time_changed) if show_label_datetime: assert self.datetime0 is not None - self.label_datetime = QT.QLabel('') + self.label_datetime = QT.QLabel("") h.addWidget(self.label_datetime) - #trick for separator - h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # trick for separator + h.addWidget( + QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) + ) if show_global_xsize: - h.addWidget(QT.QLabel('Time width (s):')) - self.spinbox_xsize =pg.SpinBox(value=3., decimals = 8, bounds = (0.001, np.inf),step = 0.1, siPrefix=False, suffix='', int=False) - if 'compactHeight' in self.spinbox_xsize.opts: # pyqtgraph >= 0.11.0 + h.addWidget(QT.QLabel("Time width (s):")) + self.spinbox_xsize = pg.SpinBox( + value=3.0, + decimals=8, + bounds=(0.001, np.inf), + step=0.1, + siPrefix=False, + suffix="", + int=False, + ) + if "compactHeight" in self.spinbox_xsize.opts: # pyqtgraph >= 0.11.0 self.spinbox_xsize.setOpts(compactHeight=False) h.addWidget(self.spinbox_xsize) - #~ self.spinbox_xsize.valueChanged.connect(self.on_spinbox_xsize_changed) + # ~ self.spinbox_xsize.valueChanged.connect(self.on_spinbox_xsize_changed) self.spinbox_xsize.valueChanged.connect(self.xsize_changed.emit) - #trick for separator - h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # trick for separator + h.addWidget( + QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) + ) if show_auto_scale: - but = QT.PushButton('Auto scale') + but = QT.PushButton("Auto scale") h.addWidget(but) but.clicked.connect(self.auto_scale_requested.emit) - #~ h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) + # ~ h.addWidget(QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken)) h.addStretch() - self.timer_play = QT.QTimer(parent=self, interval=int(play_interval*1000)) + self.timer_play = QT.QTimer(parent=self, interval=int(play_interval * 1000)) self.timer_play.timeout.connect(self.on_timer_play_interval) self.timer_delay = None # all in s - self.t = 0 # s - self.set_start_stop(0., 0.1) + self.t = 0 # s + self.set_start_stop(0.0, 0.1) self.last_time = None @@ -204,14 +241,14 @@ def on_play_pause_shortcut(self): def on_timer_play_interval(self): actual_time = time.time() - #~ t = self.t + self.play_interval*self.speed - t = self.t + (actual_time-self.last_time)*self.speed + # ~ t = self.t + self.play_interval*self.speed + t = self.t + (actual_time - self.last_time) * self.speed self.seek(t) self.last_time = actual_time - def set_start_stop(self, t_start, t_stop, seek = True): - #~ print 't_start', t_start, 't_stop', t_stop - assert t_stop>t_start + def set_start_stop(self, t_start, t_stop, seek=True): + # ~ print 't_start', t_start, 't_stop', t_stop + assert t_stop > t_start self.t_start = t_start self.t_stop = t_stop if seek: @@ -221,55 +258,55 @@ def set_start_stop(self, t_start, t_stop, seek = True): self.spinbox_time.setMaximum(t_stop) def on_change_step(self, val): - text = str(self.combo_step.currentText ()) + text = str(self.combo_step.currentText()) - if text.endswith('ms'): - self.step_size = float(text[:-2])*1e-3 - elif text.endswith('us'): - self.step_size = float(text[:-2])*1e-6 + if text.endswith("ms"): + self.step_size = float(text[:-2]) * 1e-3 + elif text.endswith("us"): + self.step_size = float(text[:-2]) * 1e-6 else: self.step_size = float(text[:-1]) - #~ print('self.step_size', self.step_size) + # ~ print('self.step_size', self.step_size) def prev_step(self): - t = self.t - self.step_size + t = self.t - self.step_size self.seek(t) def next_step(self): - t = self.t + self.step_size + t = self.t + self.step_size self.seek(t) def increase_step(self): - new_index = max(self.combo_step.currentIndex()-1, 0) + new_index = max(self.combo_step.currentIndex() - 1, 0) self.combo_step.setCurrentIndex(new_index) def decrease_step(self): - new_index = min(self.combo_step.currentIndex()+1, self.combo_step.count()-1) + new_index = min(self.combo_step.currentIndex() + 1, self.combo_step.count() - 1) self.combo_step.setCurrentIndex(new_index) def on_scroll_time_changed(self, pos): - t = pos/1000.*(self.t_stop - self.t_start)+self.t_start - self.seek(t, refresh_scroll = False) + t = pos / 1000.0 * (self.t_stop - self.t_start) + self.t_start + self.seek(t, refresh_scroll=False) def on_spinbox_time_changed(self, val): - self.seek(val, refresh_spinbox = False) + self.seek(val, refresh_spinbox=False) - #~ def on_spinbox_xsize_changed(self, val): - #~ print('xsize', val) + # ~ def on_spinbox_xsize_changed(self, val): + # ~ print('xsize', val) - def seek(self , t, refresh_scroll = True, refresh_spinbox = True, emit=True): + def seek(self, t, refresh_scroll=True, refresh_spinbox=True, emit=True): self.t = t - if (self.tself.t_stop): + if self.t > self.t_stop: self.t = self.t_stop if self.timer_play.isActive(): self.timer_play.stop() - #~ self.stop_pause() + # ~ self.stop_pause() if refresh_scroll and self.show_scroll_time: self.scroll_time.valueChanged.disconnect(self.on_scroll_time_changed) - pos = int((self.t - self.t_start)/(self.t_stop - self.t_start)*1000.) + pos = int((self.t - self.t_start) / (self.t_stop - self.t_start) * 1000.0) self.scroll_time.setValue(pos) self.scroll_time.valueChanged.connect(self.on_scroll_time_changed) @@ -282,19 +319,18 @@ def seek(self , t, refresh_scroll = True, refresh_spinbox = True, emit=True): dt = self.datetime0 + datetime.timedelta(seconds=self.t) self.label_datetime.setText(dt.strftime(self.datetime_format)) - if emit: self.time_changed.emit(self.t) - def on_change_speed(self , speed): + def on_change_speed(self, speed): self.speed = speed def set_settings(self, d): - if hasattr(self, 'spinbox_xsize') and 'xsize' in d: - self.spinbox_xsize.setValue(d['xsize']) + if hasattr(self, "spinbox_xsize") and "xsize" in d: + self.spinbox_xsize.setValue(d["xsize"]) def get_settings(self): d = {} - if hasattr(self, 'spinbox_xsize'): - d['xsize'] = float(self.spinbox_xsize.value()) + if hasattr(self, "spinbox_xsize"): + d["xsize"] = float(self.spinbox_xsize.value()) return d diff --git a/ephyviewer/scripts.py b/ephyviewer/scripts.py index fedda1e..a2129df 100644 --- a/ephyviewer/scripts.py +++ b/ephyviewer/scripts.py @@ -1,17 +1,16 @@ - import sys -import os import argparse from ephyviewer.datasource import HAVE_NEO from ephyviewer.standalone import all_neo_rawio_dict, rawio_gui_params from ephyviewer import __version__ + def launch_standalone_ephyviewer(): from ephyviewer.standalone import WindowManager import pyqtgraph as pg - assert HAVE_NEO, 'Must have Neo >= 0.6.0' - import neo + assert HAVE_NEO, "Must have Neo >= 0.6.0" + import neo argv = sys.argv[1:] @@ -20,41 +19,57 @@ def launch_standalone_ephyviewer(): RawIO classes """ parser = argparse.ArgumentParser(description=description) - parser.add_argument('file_or_dir', default=None, nargs='?', - help='an optional path to a data file or directory, ' - 'which will be opened immediately (the file ' - 'format will be inferred from the file ' - 'extension, if possible; otherwise, --format is ' - 'required)') - parser.add_argument('-V', '--version', action='version', - version='ephyviewer {}'.format(__version__)) - parser.add_argument('-f', '--format', default=None, - help='specify one of the following formats to ' - 'override the format detected automatically for ' - 'file_or_dir: {}'.format( - ', '.join(all_neo_rawio_dict.keys()))) + parser.add_argument( + "file_or_dir", + default=None, + nargs="?", + help="an optional path to a data file or directory, " + "which will be opened immediately (the file " + "format will be inferred from the file " + "extension, if possible; otherwise, --format is " + "required)", + ) + parser.add_argument( + "-V", "--version", action="version", version="ephyviewer {}".format(__version__) + ) + parser.add_argument( + "-f", + "--format", + default=None, + help="specify one of the following formats to " + "override the format detected automatically for " + "file_or_dir: {}".format(", ".join(all_neo_rawio_dict.keys())), + ) app = pg.mkQApp() manager = WindowManager() - if len(argv)>=1: + if len(argv) >= 1: args = parser.parse_args(argv) file_or_dir_name = args.file_or_dir if args.format is None: - #autoguess from extension + # autoguess from extension neo_rawio_class = neo.rawio.get_rawio_class(file_or_dir_name) else: neo_rawio_class = all_neo_rawio_dict.get(args.format, None) - assert neo_rawio_class is not None, 'Unknown format. Format list: {}'.format(', '.join(all_neo_rawio_dict.keys())) + assert neo_rawio_class is not None, "Unknown format. Format list: {}".format( + ", ".join(all_neo_rawio_dict.keys()) + ) - name = neo_rawio_class.__name__.replace('RawIO', '') + name = neo_rawio_class.__name__.replace("RawIO", "") if name in rawio_gui_params: - raise(Exception('This IO requires additional parameters. Run ephyviewer without arguments to input these via the GUI.')) + raise ( + Exception( + "This IO requires additional parameters. Run ephyviewer without arguments to input these via the GUI." + ) + ) - manager.load_dataset(neo_rawio_class=neo_rawio_class, file_or_dir_names=[file_or_dir_name]) + manager.load_dataset( + neo_rawio_class=neo_rawio_class, file_or_dir_names=[file_or_dir_name] + ) else: manager.open_dialog() @@ -62,5 +77,6 @@ def launch_standalone_ephyviewer(): if manager.windows: app.exec() -if __name__=='__main__': + +if __name__ == "__main__": launch_standalone_ephyviewer() diff --git a/ephyviewer/spectrogramviewer.py b/ephyviewer/spectrogramviewer.py index fa20353..49cda03 100644 --- a/ephyviewer/spectrogramviewer.py +++ b/ephyviewer/spectrogramviewer.py @@ -14,11 +14,6 @@ from .tools import create_plot_grid, get_dict_from_group_param -# todo remove this -import time -import threading - - default_params = [ {"name": "xsize", "type": "float", "value": 10.0, "step": 0.1}, {"name": "xratio", "type": "float", "value": 0.3, "step": 0.1, "limits": (0, 1)}, diff --git a/examples/trace_viewer_datasource.py b/examples/trace_viewer_datasource.py index 7acdf2c..2d9cf60 100644 --- a/examples/trace_viewer_datasource.py +++ b/examples/trace_viewer_datasource.py @@ -1,35 +1,33 @@ from ephyviewer import mkQApp, MainViewer, TraceViewer from ephyviewer import InMemoryAnalogSignalSource -import ephyviewer import numpy as np - -#you must first create a main Qt application (for event loop) +# you must first create a main Qt application (for event loop) app = mkQApp() -#create fake 16 signals with 100000 at 10kHz -sigs = np.random.rand(100000,16) -sample_rate = 1000. -t_start = 0. +# create fake 16 signals with 100000 at 10kHz +sigs = np.random.rand(100000, 16) +sample_rate = 1000.0 +t_start = 0.0 -#Create the main window that can contain several viewers +# Create the main window that can contain several viewers win = MainViewer(debug=True, show_auto_scale=True) -#Create a datasource for the viewer +# Create a datasource for the viewer # here we use InMemoryAnalogSignalSource but # you can alose use your custum datasource by inheritance source = InMemoryAnalogSignalSource(sigs, sample_rate, t_start) -#create a viewer for signal with TraceViewer +# create a viewer for signal with TraceViewer # TraceViewer normally accept a AnalogSignalSource but # TraceViewer.from_numpy is facitilty function to bypass that view1 = TraceViewer(source=source) -#put this veiwer in the main window +# put this veiwer in the main window win.add_view(view1) -#show main window and run Qapp +# show main window and run Qapp win.show() app.exec() diff --git a/examples/trace_viewer_with_marker.py b/examples/trace_viewer_with_marker.py index b8ae925..db91a96 100644 --- a/examples/trace_viewer_with_marker.py +++ b/examples/trace_viewer_with_marker.py @@ -1,51 +1,52 @@ from ephyviewer import mkQApp, MainViewer, TraceViewer from ephyviewer import AnalogSignalSourceWithScatter -import ephyviewer import numpy as np -#you must first create a main Qt application (for event loop) +# you must first create a main Qt application (for event loop) app = mkQApp() -#create 16 signals with 100000 at 10kHz -sigs = np.random.rand(100000,16) -sample_rate = 1000. -t_start = 0. +# create 16 signals with 100000 at 10kHz +sigs = np.random.rand(100000, 16) +sample_rate = 1000.0 +t_start = 0.0 -#create fake 16 signals with sinus -sample_rate = 1000. -t_start = 0. -times = np.arange(1000000)/sample_rate -signals = np.sin(times*2*np.pi*5)[:, None] +# create fake 16 signals with sinus +sample_rate = 1000.0 +t_start = 0.0 +times = np.arange(1000000) / sample_rate +signals = np.sin(times * 2 * np.pi * 5)[:, None] signals = np.tile(signals, (1, 16)) -#detect some crossing zeros +# detect some crossing zeros s0 = signals[:-2, 0] -s1 = signals[1:-1,0] -s2 = signals[2:,0] -peaks0, = np.nonzero((s0s1) & (s2>s1)) +s1 = signals[1:-1, 0] +s2 = signals[2:, 0] +(peaks0,) = np.nonzero((s0 < s1) & (s2 < s1)) +(peaks1,) = np.nonzero((s0 > s1) & (s2 > s1)) -#create 2 familly scatters from theses 2 indexes +# create 2 familly scatters from theses 2 indexes scatter_indexes = {0: peaks0, 1: peaks1} -#and asign them to some channels each +# and asign them to some channels each scatter_channels = {0: [0, 5, 8], 1: [0, 5, 10]} -source = AnalogSignalSourceWithScatter(signals, sample_rate, t_start, scatter_indexes, scatter_channels) +source = AnalogSignalSourceWithScatter( + signals, sample_rate, t_start, scatter_indexes, scatter_channels +) -#Create the main window that can contain several viewers +# Create the main window that can contain several viewers win = MainViewer(debug=True, show_auto_scale=True) -#create a viewer for signal with TraceViewer -#connected to the signal source +# create a viewer for signal with TraceViewer +# connected to the signal source view1 = TraceViewer(source=source) -view1.params['scale_mode'] = 'same_for_all' +view1.params["scale_mode"] = "same_for_all" view1.auto_scale() -#put this veiwer in the main window +# put this veiwer in the main window win.add_view(view1) -#show main window and run Qapp +# show main window and run Qapp win.show() app.exec() From 02a8a314f3adc799d7666f382233c2eb002da372 Mon Sep 17 00:00:00 2001 From: Tom Bugnon Date: Fri, 12 May 2023 08:38:17 -0500 Subject: [PATCH 08/12] Replace by selectable spinbox for step selection --- ephyviewer/navigation.py | 68 ++++++++++------------------------------ 1 file changed, 17 insertions(+), 51 deletions(-) diff --git a/ephyviewer/navigation.py b/ephyviewer/navigation.py index f2313b8..f14f913 100644 --- a/ephyviewer/navigation.py +++ b/ephyviewer/navigation.py @@ -106,29 +106,27 @@ def __init__( play_pause_shortcut.setKey(QT.QKeySequence(" ")) play_pause_shortcut.activated.connect(self.on_play_pause_shortcut) - self.steps = [ - "60 s", - "10 s", - "1 s", - "100 ms", - "50 ms", - "5 ms", - "1 ms", - "200 us", - ] - if show_step: + but = QT.QPushButton("<") but.clicked.connect(self.prev_step) h.addWidget(but) - self.combo_step = QT.QComboBox() - self.combo_step.addItems(self.steps) - self.combo_step.setCurrentIndex(2) - h.addWidget(self.combo_step) - - self.on_change_step(None) - self.combo_step.currentIndexChanged.connect(self.on_change_step) + h.addWidget(QT.QLabel("Step (s):")) + self.spinbox_step = pg.SpinBox( + value=3, + decimals=3, + bounds=(0, np.inf), + step=1, + siPrefix=False, + suffix="", + int=False, + ) + if "compactHeight" in self.spinbox_step.opts: # pyqtgraph >= 0.11.0 + self.spinbox_step.setOpts(compactHeight=False) + h.addWidget(self.spinbox_step) + self.spinbox_step.valueChanged.connect(self.on_change_step) + self.on_change_step(self.spinbox_step.value()) # Sets self.step_size but = QT.QPushButton(">") but.clicked.connect(self.next_step) @@ -139,22 +137,6 @@ def __init__( QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) ) - # add shortcuts for stepping through time and changing step size - shortcuts = [ - {"key": QT.Qt.Key_Left, "callback": self.prev_step}, - {"key": QT.Qt.Key_Right, "callback": self.next_step}, - {"key": QT.Qt.Key_Up, "callback": self.increase_step}, - {"key": QT.Qt.Key_Down, "callback": self.decrease_step}, - {"key": "a", "callback": self.prev_step}, - {"key": "d", "callback": self.next_step}, - {"key": "w", "callback": self.increase_step}, - {"key": "s", "callback": self.decrease_step}, - ] - for s in shortcuts: - shortcut = QT.QShortcut(self) - shortcut.setKey(QT.QKeySequence(s["key"])) - shortcut.activated.connect(s["callback"]) - if show_spinbox: h.addWidget(QT.QLabel("Time (s):")) self.spinbox_time = pg.SpinBox( @@ -258,15 +240,7 @@ def set_start_stop(self, t_start, t_stop, seek=True): self.spinbox_time.setMaximum(t_stop) def on_change_step(self, val): - text = str(self.combo_step.currentText()) - - if text.endswith("ms"): - self.step_size = float(text[:-2]) * 1e-3 - elif text.endswith("us"): - self.step_size = float(text[:-2]) * 1e-6 - else: - self.step_size = float(text[:-1]) - # ~ print('self.step_size', self.step_size) + self.step_size = val def prev_step(self): t = self.t - self.step_size @@ -276,14 +250,6 @@ def next_step(self): t = self.t + self.step_size self.seek(t) - def increase_step(self): - new_index = max(self.combo_step.currentIndex() - 1, 0) - self.combo_step.setCurrentIndex(new_index) - - def decrease_step(self): - new_index = min(self.combo_step.currentIndex() + 1, self.combo_step.count() - 1) - self.combo_step.setCurrentIndex(new_index) - def on_scroll_time_changed(self, pos): t = pos / 1000.0 * (self.t_stop - self.t_start) + self.t_start self.seek(t, refresh_scroll=False) From cfdaa42bb92382c7996026493a6c25c611692ddd Mon Sep 17 00:00:00 2001 From: Tom Bugnon Date: Fri, 12 May 2023 08:39:42 -0500 Subject: [PATCH 09/12] 1s (rather than 0.1s) steps for xsize spinbox --- ephyviewer/navigation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ephyviewer/navigation.py b/ephyviewer/navigation.py index f14f913..dda332b 100644 --- a/ephyviewer/navigation.py +++ b/ephyviewer/navigation.py @@ -173,7 +173,7 @@ def __init__( value=3.0, decimals=8, bounds=(0.001, np.inf), - step=0.1, + step=1, siPrefix=False, suffix="", int=False, From d713200d58960411a4718de48cbff47c31dad63f Mon Sep 17 00:00:00 2001 From: Tom Bugnon Date: Fri, 12 May 2023 08:46:55 -0500 Subject: [PATCH 10/12] Step follows when xsize value changed --- ephyviewer/navigation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ephyviewer/navigation.py b/ephyviewer/navigation.py index dda332b..561221b 100644 --- a/ephyviewer/navigation.py +++ b/ephyviewer/navigation.py @@ -181,8 +181,8 @@ def __init__( if "compactHeight" in self.spinbox_xsize.opts: # pyqtgraph >= 0.11.0 self.spinbox_xsize.setOpts(compactHeight=False) h.addWidget(self.spinbox_xsize) - # ~ self.spinbox_xsize.valueChanged.connect(self.on_spinbox_xsize_changed) - self.spinbox_xsize.valueChanged.connect(self.xsize_changed.emit) + # self.spinbox_xsize.valueChanged.connect(self.xsize_changed.emit) + self.spinbox_xsize.valueChanged.connect(self.on_spinbox_xsize_changed) # trick for separator h.addWidget( QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) @@ -257,8 +257,9 @@ def on_scroll_time_changed(self, pos): def on_spinbox_time_changed(self, val): self.seek(val, refresh_spinbox=False) - # ~ def on_spinbox_xsize_changed(self, val): - # ~ print('xsize', val) + def on_spinbox_xsize_changed(self, val): + self.spinbox_step.setValue(val) + self.xsize_changed.emit(val) def seek(self, t, refresh_scroll=True, refresh_spinbox=True, emit=True): self.t = t From b0f7a183b5fcb17195efddc8198cc908ec9eecb2 Mon Sep 17 00:00:00 2001 From: Tom Bugnon Date: Fri, 12 May 2023 09:14:06 -0500 Subject: [PATCH 11/12] Allow add_view(split_with/tabify_with='navigation') --- ephyviewer/mainviewer.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/ephyviewer/mainviewer.py b/ephyviewer/mainviewer.py index 30f043b..b29a02a 100644 --- a/ephyviewer/mainviewer.py +++ b/ephyviewer/mainviewer.py @@ -70,9 +70,23 @@ def __init__(self, debug=False, settings_name=None, parent=None, global_xsize_zo def add_view(self, widget, location='bottom', orientation='vertical', tabify_with=None, split_with=None): + """ + Add view to window. + + Parameters: + =========== + widget: QT.QWidget + location: str + One of "left", "right", "top", "bottom" (default "bottom") + orientation: str + One of "horizontal", "vertical" (default "vertical") + tabify_with, split_with: str + Either of "navigation" or the name of an existing viewer. + """ name = widget.name assert name not in self.viewers, 'Viewer already in MainViewer' + assert name != "navigation", "Viewer cannot be named 'navigation'." dock = QT.QDockWidget(name) dock.setObjectName(name) @@ -81,18 +95,32 @@ def add_view(self, widget, location='bottom', orientation='vertical', #TODO chustum titlebar #~ dock.setTitleBarWidget(titlebar) + other_docks = { + "navigation": self.navigation_dock, + **{ + vname: vvalue["dock"] + for vname, vvalue in self.viewers.items() + } + } + if tabify_with is not None: - assert tabify_with in self.viewers, '{} no exists'.format(tabify_with) + assert tabify_with in other_docks.keys(), ( + f"Invalid value for 'tabify_with' kwarg (={tabify_with}). " + f"Expected one of: `{list(other_docks.keys())}`" + ) #~ raise(NotImplementedError) #tabifyDockWidget ( QDockWidget * first, QDockWidget * second ) - other_dock = self.viewers[tabify_with]['dock'] + other_dock = other_docks[tabify_with] self.tabifyDockWidget(other_dock, dock) elif split_with is not None: - assert split_with in self.viewers, '{} no exists'.format(split_with) + assert split_with in other_docks.keys(), ( + f"Invalid value for 'tabify_with' kwarg (={split_with}). " + f"Expected one of: `{list(other_docks.keys())}`" + ) #~ raise(NotImplementedError) orien = orientation_to_qt[orientation] - other_dock = self.viewers[split_with]['dock'] + other_dock = other_docks[split_with] self.splitDockWidget(other_dock, dock, orien) #splitDockWidget ( QDockWidget * first, QDockWidget * second, Qt::Orientation orientation ) else: From c7f1dbe57be80789a4c5f2b97e3068586e027abf Mon Sep 17 00:00:00 2001 From: Tom Bugnon Date: Fri, 12 May 2023 13:40:28 -0500 Subject: [PATCH 12/12] Restore step-related shortcuts --- ephyviewer/navigation.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ephyviewer/navigation.py b/ephyviewer/navigation.py index 561221b..e81d98d 100644 --- a/ephyviewer/navigation.py +++ b/ephyviewer/navigation.py @@ -137,6 +137,22 @@ def __init__( QT.QFrame(frameShape=QT.QFrame.VLine, frameShadow=QT.QFrame.Sunken) ) + # add shortcuts for stepping through time and changing step size + shortcuts = [ + {"key": QT.Qt.Key_Left, "callback": self.prev_step}, + {"key": QT.Qt.Key_Right, "callback": self.next_step}, + {"key": QT.Qt.Key_Up, "callback": self.increase_step}, + {"key": QT.Qt.Key_Down, "callback": self.decrease_step}, + {"key": "a", "callback": self.prev_step}, + {"key": "d", "callback": self.next_step}, + {"key": "w", "callback": self.increase_step}, + {"key": "s", "callback": self.decrease_step}, + ] + for s in shortcuts: + shortcut = QT.QShortcut(self) + shortcut.setKey(QT.QKeySequence(s["key"])) + shortcut.activated.connect(s["callback"]) + if show_spinbox: h.addWidget(QT.QLabel("Time (s):")) self.spinbox_time = pg.SpinBox( @@ -250,6 +266,12 @@ def next_step(self): t = self.t + self.step_size self.seek(t) + def increase_step(self): + self.spinbox_step.setValue(self.spinbox_step.value() + float(self.spinbox_step.opts["step"])) + + def decrease_step(self): + self.spinbox_step.setValue(self.spinbox_step.value() - float(self.spinbox_step.opts["step"])) + def on_scroll_time_changed(self, pos): t = pos / 1000.0 * (self.t_stop - self.t_start) + self.t_start self.seek(t, refresh_scroll=False)