[MRG] Initial implementation of Hilbert detector #23

Merged
merged 31 commits into from
Apr 6, 2021
Changes from 13 commits
Commits
9ed7d32
initial implementation of Hilbert
pemyers27 Mar 12, 2021
b79cefa
Merge branch 'master' into Hilbert
pemyers27 Mar 12, 2021
beeac09
modify into three-step scheme
pemyers27 Mar 19, 2021
49e532f
move frq bands for hilbert
pemyers27 Mar 25, 2021
30c1425
broken tests
pemyers27 Mar 28, 2021
a534081
add frq based detector func'
pmyers16 Mar 28, 2021
f9587ec
fix dimensions in hilbert detector
pemyers27 Mar 28, 2021
245fcd5
fix RMS and Linelength
pemyers27 Mar 28, 2021
c1ed904
run flake
pmyers16 Mar 28, 2021
2ae3f41
fix flake
pmyers16 Mar 28, 2021
8d1d79a
run pydocstyle
pmyers16 Mar 28, 2021
36a1410
rerun pydocstring
pmyers16 Mar 28, 2021
20699a5
change test utils function names
pmyers16 Mar 28, 2021
6673869
add minimal notebook
pmyers16 Mar 30, 2021
a8467e9
Fixing reqs files.
adam2392 Mar 30, 2021
d592456
chunck the hilbert calculation
pmyers16 Mar 31, 2021
55fa472
skip Hilbert test
pemyers27 Mar 31, 2021
ffec966
running flake
pemyers27 Apr 1, 2021
048fcec
run flake again
pemyers27 Apr 1, 2021
7df11cd
add example to toctree
pemyers27 Apr 1, 2021
5c95c39
fix docstrings
pemyers27 Apr 1, 2021
97995d7
add band detect test
pmyers16 Apr 6, 2021
7905395
modify testing data
pemyers27 Apr 6, 2021
af78401
Merge branch 'master' into Hilbert
pmyers16 Apr 6, 2021
e11c6d7
update whats_new
pemyers27 Apr 6, 2021
9c5dca0
fix checks
pemyers27 Apr 6, 2021
f255b7e
modify whats_new
pemyers27 Apr 6, 2021
4524bfe
Apply suggestions from code review
adam2392 Apr 6, 2021
85026fc
Merging readme.
adam2392 Apr 6, 2021
5b88cda
Merging.
adam2392 Apr 6, 2021
4f73dcf
run flake
pemyers27 Apr 6, 2021
247 changes: 168 additions & 79 deletions mne_hfo/base.py
@@ -12,11 +12,13 @@
from mne_hfo.score import accuracy, false_negative_rate, \
true_positive_rate, precision, false_discovery_rate
from mne_hfo.sklearn import _make_ydf_sklearn
from mne_hfo.utils import (threshold_std, compute_rms,
compute_line_length)
from mne_hfo.utils import (apply_std, compute_rms,
compute_line_length, compute_hilbert, apply_hilbert,
merge_contiguous_freq_bands)

ACCEPTED_THRESHOLD_METHODS = ['std']
ACCEPTED_HFO_METHODS = ['line_length', 'rms']
ACCEPTED_THRESHOLD_METHODS = ['std', 'hilbert']
ACCEPTED_MERGE_METHODS = ['time-windows', 'freq-bands']
ACCEPTED_HFO_METHODS = ['line_length', 'rms', 'hilbert']


class Detector(BaseEstimator):
@@ -25,6 +27,15 @@ class Detector(BaseEstimator):
Note: Detection will occur on all channels present.
Subset your dataset before detecting.

A detector's fit follows this general flow, implemented in
private functions:
1. Compute a statistic on the raw data in _compute_hfo_statistic,
e.g. the LineLength of a time-window.
2. Apply a threshold to the statistic computed in (1) in
_threshold_statistic, e.g. std of LineLength.
3. Merge contiguous/overlapping events into unique detections
in _post_process_ch_hfos, e.g. contiguous time windows.

Parameters
----------
threshold: float
@@ -49,8 +60,8 @@ def __init__(self, threshold: Union[int, float],
self.verbose = verbose
self.n_jobs = n_jobs

def _compute_hfo(self, X, picks):
"""Compute HFO event array.
def _compute_hfo_statistic(self, X, picks):
"""Compute HFO statistic.

Takes a sliding window approach and computes the existence
of an HFO defined by algorithm parameters. If an HFO is
@@ -73,6 +84,44 @@ def _compute_hfo(self, X, picks):
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _threshold_statistic(self, hfo_statistic_arr):
"""Apply threshold(s) to the calculated statistic to generate hfo events.

Parameters
----------
hfo_statistic_arr: np.ndarray
The output of _compute_hfo_statistic

Returns
-------
hfo_event_array: np.ndarray
HFO event array that contains (at minimum) a series of start
and stop times.
"""
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _post_process_ch_hfos(self, hfo_event_array, idx):
"""Post process one channel's HFO events.

Joins contiguously detected HFOs as one event.

Parameters
----------
hfo_event_array : np.ndarray
List of HFO metric values (e.g. Line Length, or RMS) over windows.
idx : int
Index of the channel being analyzed

Returns
-------
output : List[Tuple[int, int]]
A list of tuples, storing the event start and stop sample index
for the detected HFO.
"""
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _compute_n_wins(self, win_size, step_size, n_times):
n_windows = int(np.ceil((n_times - win_size) / step_size)) + 1
return n_windows
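
The window arithmetic in `_compute_n_wins` can be checked with a small standalone sketch. The names `compute_n_wins` and `window_bounds` are illustrative and not part of mne_hfo, and the starting values `win_start = 0` and `win_stop = win_size` are an assumption, since the diff does not show how the loop is initialised.

```python
import numpy as np


def compute_n_wins(win_size, step_size, n_times):
    """Number of sliding windows, using the same formula as the diff."""
    return int(np.ceil((n_times - win_size) / step_size)) + 1


def window_bounds(win_size, step_size, n_times):
    """Yield (start, stop) sample indices; the last window is clipped.

    Assumes the first window starts at sample 0 (not shown in the diff).
    """
    win_start, win_stop = 0, win_size
    while win_start < n_times:
        # Clip the final window so it never runs past the data.
        win_stop = min(win_stop, n_times)
        yield win_start, win_stop
        if win_stop == n_times:
            break
        win_start += step_size
        win_stop += step_size


bounds = list(window_bounds(win_size=4, step_size=2, n_times=11))
n_wins = compute_n_wins(win_size=4, step_size=2, n_times=11)
```

With these inputs the generator produces five windows, the last one clipped to `(8, 11)`, which matches the count returned by the formula.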
@@ -275,21 +324,68 @@ def fit(self, X, y=None):
f'below the suggested rate of {MINIMUM_SUGGESTED_SFREQ}. '
f'Please use with caution.')

# compute HFOs as a binary occurrence array over time
hfo_event_arr = self._compute_hfo(X)
# compute HFO related statistic for the detector
hfo_statistic_arr = self._compute_hfo_statistic(X)

# apply the threshold(s) to the statistic to get detections
hfo_detection_arr = self._threshold_statistic(hfo_statistic_arr)

# post-process hfo events
# merge contiguous detections into discrete hfo events
# store hfo event endpoints per channel
chs_hfos = {ch_name: self._post_process_ch_hfos(
hfo_event_arr[idx, :], n_times=self.n_times,
threshold_method='std'
hfo_detection_arr[idx], idx
) for idx, ch_name in enumerate(self.ch_names)}

self.chs_hfos_ = chs_hfos
self.hfo_event_arr_ = hfo_event_arr
self.hfo_event_arr_ = hfo_statistic_arr
self._create_annotation_df(self.chs_hfos_dict, self.hfo_name)
return self

def _apply_threshold(self, metric, threshold_method):
"""Apply the threshold(s) to the calculated metric for a single channel.

Parameters
----------
metric : np.ndarray
The single channel values to check against a threshold
threshold_method : str
The type of threshold to use
Returns
-------
thresholded_metric: np.ndarray
Metric values that pass the given threshold

"""
if threshold_method not in ACCEPTED_THRESHOLD_METHODS:
raise ValueError(f'Threshold method {threshold_method} '
f'is not an implemented threshold method. '
f'Please use one of {ACCEPTED_THRESHOLD_METHODS} '
f'methods.')
if threshold_method == 'std':
threshold_func = apply_std
threshold_dict = dict(thresh=self.threshold)
kwargs = dict(step_size=self.step_size,
win_size=self.win_size,
n_times=self.n_times)
elif threshold_method == 'hilbert':
threshold_func = apply_hilbert
threshold_dict = dict(zscore=self.threshold,
cycles=self.cycle_threshold,
gap=self.gap_threshold)
kwargs = dict(n_times=self.n_times,
sfreq=self.sfreq,
filter_band=self.filter_band,
freq_cutoffs=self.freq_cutoffs,
freq_span=self.freq_span,
n_jobs=self.n_jobs)

if self.verbose:
print(f'Using {threshold_method} to perform HFO '
f'thresholding.')

thresholded_metric = threshold_func(metric, threshold_dict, kwargs)
return thresholded_metric

def predict(self, X):
"""Scikit-learn override predict function.

@@ -321,6 +417,24 @@ def _create_annotation_df(self, chs_hfos_list, hfo_name):
self.df_ = annot_df

def _compute_sliding_window_detection(self, sig, method):
"""Compute detections on an individual channel data.

If the method does not use sliding windows, make win_size
equal to the length of the dataset.

Parameters
----------
sig: np.array
Data from a single channel
method: str
Method used to compute the detection

Returns
-------
signal_win_stat: np.ndarray
Statistic calculated per window

"""
if method not in ACCEPTED_HFO_METHODS:
raise ValueError(f'Sliding window HFO detection method '
f'{method} is not implemented. Please '
@@ -338,93 +452,68 @@ def _compute_sliding_window_detection(self, sig, method):
self.step_size,
self.n_times)

# store the RMS of each window
signal_win_rms = np.empty(n_windows)
# store the statistic of each window
signal_win_stat = np.empty(n_windows)
win_idx = 0
while win_start < self.n_times:
if win_stop > self.n_times:
win_stop = self.n_times

# compute the RMS of filtered signal in this window
signal_win_rms[win_idx] = hfo_detect_func(
# compute the statistic based on 'method' on filtered signal
# in this window
stat = hfo_detect_func(
sig[int(win_start):int(win_stop)], win_size=self.win_size)[0]
signal_win_stat[win_idx] = stat

if win_stop == self.n_times:
break

win_start += self.step_size
win_stop += self.step_size
win_idx += 1
return signal_win_rms
return signal_win_stat

def _post_process_ch_hfos(self, metric_vals_list, n_times,
threshold_method='std'):
"""Post process one channel's HFO events.
def _compute_frq_band_detection(self, sig, method):
if method not in ACCEPTED_HFO_METHODS:
raise ValueError(f'Sliding window HFO detection method '
f'{method} is not implemented. Please '
f'use one of {ACCEPTED_HFO_METHODS}.')
if method == 'hilbert':
hfo_detect_func = compute_hilbert
signal_stat = hfo_detect_func(sig, self.freq_cutoffs,
self.freq_span, self.sfreq)
return signal_stat

Joins contiguously detected HFOs as one event, and applies
the threshold based on number of stdev above baseline on the
RMS of the bandpass filtered signal.
def _merge_contiguous_ch_detections(self, detections, method):
"""Merge contiguous hfo detections into distinct events.

Parameters
----------
metric_vals_list : list
List of HFO metric values (e.g. Line Length, or RMS) over windows.
n_times : int
The number of time points in the original data matrix fed in.
threshold_method : str
The threshold method to use.
detections : List(tuples)
List of raw hfo detected events
method : str
Method to use to merge the detections.

Returns
-------
output : List[Tuple[int, int]]
A list of tuples, storing the event start and stop sample index
for the detected HFO.
events: List(tuples)
List of start and stop times of the distinct HFO events.

"""
if threshold_method not in ACCEPTED_THRESHOLD_METHODS:
raise ValueError(f'Threshold method {threshold_method} '
f'is not an implemented threshold method. '
f'Please use one of {ACCEPTED_THRESHOLD_METHODS} '
if method not in ACCEPTED_MERGE_METHODS:
raise ValueError(f'Merging method {method} '
f'is not an implemented merging method. '
f'Please use one of {ACCEPTED_MERGE_METHODS} '
f'methods.')
if threshold_method == 'std':
threshold_func = threshold_std

if self.verbose:
print(f'Using {threshold_method} to perform HFO '
f'thresholding.')

n_windows = len(metric_vals_list)

# store post-processed hfo events as a list
output = []

# only keep RMS values above a certain number
# stdevs above baseline (threshold)
det_th = threshold_func(metric_vals_list, self.threshold)

# Detect and now group events if they are within a
# step size of each other
win_idx = 0
while win_idx < n_windows:
# log events if they pass our threshold criterion
if metric_vals_list[win_idx] >= det_th:
event_start = win_idx * self.step_size

# group events together if they occur in
# contiguous windows
while win_idx < n_windows and \
metric_vals_list[win_idx] >= det_th:
win_idx += 1
event_stop = (win_idx * self.step_size) + self.win_size

if event_stop > n_times:
event_stop = n_times

# TODO: Optional feature calculations

# Write into output
output.append((event_start, event_stop))
win_idx += 1
else:
win_idx += 1

return output
if method == "time-windows":
return detections
Member

we don't merge time-windows?

Collaborator Author

We do, but we do it upstream. I left a TODO where it's done, in case you want to refactor it into here. I think refactoring would slow it down a bit, which is why I held off.

Member

Hmm okay I will take a look
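
For context on what merging time windows involves, the loop this PR removes from `_post_process_ch_hfos` can be restated as a standalone function. This is a sketch only: `merge_contiguous_windows` is a hypothetical name, the threshold `det_th` is passed in directly rather than computed, and the stop index follows the removed code as written.

```python
def merge_contiguous_windows(stat, det_th, step_size, win_size, n_times):
    """Group contiguous above-threshold windows into (start, stop) events.

    Mirrors the loop removed from _post_process_ch_hfos in this PR;
    the function name and signature are illustrative.
    """
    events = []
    n_windows = len(stat)
    win_idx = 0
    while win_idx < n_windows:
        if stat[win_idx] >= det_th:
            event_start = win_idx * step_size
            # Advance past every contiguous window that passes the threshold.
            while win_idx < n_windows and stat[win_idx] >= det_th:
                win_idx += 1
            # Same stop computation as the removed code, clipped to the data.
            event_stop = min(win_idx * step_size + win_size, n_times)
            events.append((event_start, event_stop))
        win_idx += 1
    return events


events = merge_contiguous_windows(
    stat=[0, 5, 6, 0, 0, 7], det_th=4, step_size=2, win_size=4, n_times=14)
# events == [(2, 10), (10, 14)]
```

Windows 1 and 2 are joined into a single event, and the final event is clipped to `n_times`.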

elif method == "freq-bands":
merge_func = merge_contiguous_freq_bands

events = merge_func(detections)
if method == "freq-bands":
events, max_amplitude, freq_bands = events
self.hfo_max_amplitudes_ = max_amplitude
self.hfo_freq_bands_ = freq_bands

return events
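
The three-step flow described in the `Detector` docstring (compute a statistic, threshold it, merge detections) can be sketched end to end with a toy class. `ToyDetector` is hypothetical and not mne_hfo code: it uses RMS over non-clipped windows, assumes a threshold rule of mean plus `threshold` standard deviations per channel (the internals of `apply_std` are not shown in this diff), and ends each event at the last passing window, which is simpler than the stop computation in the removed loop.

```python
import numpy as np


class ToyDetector:
    """Hypothetical stand-in for a Detector subclass; not mne_hfo code."""

    def __init__(self, threshold, win_size, step_size):
        self.threshold = threshold
        self.win_size = win_size
        self.step_size = step_size

    def _compute_hfo_statistic(self, X):
        # Step 1: RMS per sliding window, per channel (no clipped tail).
        n_chs, n_times = X.shape
        starts = np.arange(0, n_times - self.win_size + 1, self.step_size)
        return np.array([[np.sqrt(np.mean(ch[s:s + self.win_size] ** 2))
                          for s in starts] for ch in X])

    def _threshold_statistic(self, stat):
        # Step 2 (assumed rule): keep windows >= mean + threshold * std.
        cutoff = (stat.mean(axis=1, keepdims=True)
                  + self.threshold * stat.std(axis=1, keepdims=True))
        return stat >= cutoff

    def _post_process_ch_hfos(self, mask, idx):
        # Step 3: convert runs of True windows into (start, stop) samples.
        padded = np.concatenate(([False], mask, [False])).astype(int)
        edges = np.diff(padded)
        run_starts = np.flatnonzero(edges == 1)
        run_stops = np.flatnonzero(edges == -1)  # exclusive window index
        return [(int(a * self.step_size),
                 int((b - 1) * self.step_size + self.win_size))
                for a, b in zip(run_starts, run_stops)]

    def fit(self, X):
        stat = self._compute_hfo_statistic(X)
        detections = self._threshold_statistic(stat)
        self.chs_hfos_ = {
            idx: self._post_process_ch_hfos(detections[idx], idx)
            for idx in range(X.shape[0])}
        return self


# One channel, 100 samples, with a burst in samples 40..49.
X = np.zeros((1, 100))
X[0, 40:50] = 10.0
det = ToyDetector(threshold=2, win_size=10, step_size=10).fit(X)
# det.chs_hfos_ == {0: [(40, 50)]}
```

Keeping the three steps as separate private methods is what lets a subclass swap in a different statistic (for example a Hilbert envelope) without touching `fit`.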