Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Initial implementation of Hilbert detector #23

Merged
merged 31 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9ed7d32
initial implementation of Hilbert
pemyers27 Mar 12, 2021
b79cefa
Merge branch 'master' into Hilbert
pemyers27 Mar 12, 2021
beeac09
modify into three-step scheme
pemyers27 Mar 19, 2021
49e532f
move frq bands for hilbert
pemyers27 Mar 25, 2021
30c1425
broken tests
pemyers27 Mar 28, 2021
a534081
add frq based detector func'
pmyers16 Mar 28, 2021
f9587ec
fix dimensions in hilbert detector
pemyers27 Mar 28, 2021
245fcd5
fix RMS and Linelength
pemyers27 Mar 28, 2021
c1ed904
run flake
pmyers16 Mar 28, 2021
2ae3f41
fix flake
pmyers16 Mar 28, 2021
8d1d79a
run pydocstyle
pmyers16 Mar 28, 2021
36a1410
rerun pydocstring
pmyers16 Mar 28, 2021
20699a5
change test utils function names
pmyers16 Mar 28, 2021
6673869
add minimal notebook
pmyers16 Mar 30, 2021
a8467e9
Fixing reqs files.
adam2392 Mar 30, 2021
d592456
chunck the hilbert calculation
pmyers16 Mar 31, 2021
55fa472
skip Hilbert test
pemyers27 Mar 31, 2021
ffec966
running flake
pemyers27 Apr 1, 2021
048fcec
run flake again
pemyers27 Apr 1, 2021
7df11cd
add example to toctree
pemyers27 Apr 1, 2021
5c95c39
fix docstrings
pemyers27 Apr 1, 2021
97995d7
add band detect test
pmyers16 Apr 6, 2021
7905395
modify testing data
pemyers27 Apr 6, 2021
af78401
Merge branch 'master' into Hilbert
pmyers16 Apr 6, 2021
e11c6d7
update whats_new
pemyers27 Apr 6, 2021
9c5dca0
fix checks
pemyers27 Apr 6, 2021
f255b7e
modify whats_new
pemyers27 Apr 6, 2021
4524bfe
Apply suggestions from code review
adam2392 Apr 6, 2021
85026fc
Merging readme.
adam2392 Apr 6, 2021
5b88cda
Merging.
adam2392 Apr 6, 2021
4f73dcf
run flake
pemyers27 Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 132 additions & 72 deletions mne_hfo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
true_positive_rate, precision, false_discovery_rate
from mne_hfo.sklearn import _make_ydf_sklearn
from mne_hfo.utils import (threshold_std, compute_rms,
compute_line_length)
compute_line_length, compute_hilbert, apply_hilbert,
merge_contiguous_freq_bands)

ACCEPTED_THRESHOLD_METHODS = ['std']
ACCEPTED_HFO_METHODS = ['line_length', 'rms']
ACCEPTED_THRESHOLD_METHODS = ['std', 'hilbert']
ACCEPTED_MERGE_METHODS = ['time-windows', 'freq-bands']
ACCEPTED_HFO_METHODS = ['line_length', 'rms', 'hilbert']


class Detector(BaseEstimator):
Expand Down Expand Up @@ -49,8 +51,8 @@ def __init__(self, threshold: Union[int, float],
self.verbose = verbose
self.n_jobs = n_jobs

def _compute_hfo(self, X, picks):
"""Compute HFO event array.
def _compute_hfo_statistic(self, X, picks):
"""Compute HFO statistic.

Takes a sliding window approach and computes the existence
of an HFO defined by algorithm parameters. If an HFO is
Expand All @@ -73,6 +75,45 @@ def _compute_hfo(self, X, picks):
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')


def _threshold_statistic(self, hfo_statistic_arr):
"""Apply threshold(s) to the calculated statistic to generate hfo events.

Parameters
----------
hfo_statistic_arr: np.ndarray
The output of _compute_hfo_statistic

Returns
-------
hfo_event_array: np.ndarray
HFO event array that contains (at minimum) a series of start
and stop times.
"""
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _post_process_ch_hfos(self, hfo_event_array, idx):
"""Post process one channel's HFO events.
adam2392 marked this conversation as resolved.
Show resolved Hide resolved

Joins contiguously detected HFOs as one event.

Parameters
----------
hfo_event_array : np.ndarray
List of HFO metric values (e.g. Line Length, or RMS) over windows.
idx : int
Index of the channel being analyzed

Returns
-------
output : List[Tuple[int, int]]
A list of tuples, storing the event start and stop sample index
for the detected HFO.
"""
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _compute_n_wins(self, win_size, step_size, n_times):
n_windows = int(np.ceil((n_times - win_size) / step_size)) + 1
return n_windows
Expand Down Expand Up @@ -275,21 +316,70 @@ def fit(self, X, y=None):
f'below the suggested rate of {MINIMUM_SUGGESTED_SFREQ}. '
f'Please use with caution.')

# compute HFOs as a binary occurrence array over time
hfo_event_arr = self._compute_hfo(X)
# compute HFO related statistic for the detector
hfo_statistic_arr = self._compute_hfo_statistic(X)

# post-process hfo events
# apply the threshold(s) to the statistic to get detections
hfo_detection_arr = self._threshold_statistic(hfo_statistic_arr)

# merge contiguous detections into discrete hfo events
# store hfo event endpoints per channel
chs_hfos = {ch_name: self._post_process_ch_hfos(
hfo_event_arr[idx, :], n_times=self.n_times,
threshold_method='std'
hfo_detection_arr[idx, :], idx
) for idx, ch_name in enumerate(self.ch_names)}

self.chs_hfos_ = chs_hfos
self.hfo_event_arr_ = hfo_event_arr
self.hfo_event_arr_ = hfo_statistic_arr
self._create_annotation_df(self.chs_hfos_dict, self.hfo_name)
return self


def _apply_threshold(self, metric, threshold_method):
adam2392 marked this conversation as resolved.
Show resolved Hide resolved
"""Apply the threshold(s) to the calculated metric

Parameters
----------
metric : np.ndarray
The values to check against a threshold
threshold_method : str
The type of threshold to use
Returns
-------
thresholded_metric: np.ndarray
Metric values that pass the given threshold

"""
if threshold_method not in ACCEPTED_THRESHOLD_METHODS:
raise ValueError(f'Threshold method {threshold_method} '
f'is not an implemented threshold method. '
f'Please use one of {ACCEPTED_THRESHOLD_METHODS} '
f'methods.')
if threshold_method == 'std':
threshold_func = apply_std
threshold_dict = dict(thresh=self.threshold)
kwargs = dict(step_size=self.step_size,
win_size=self.win_size,
n_times=self.n_times)
elif threshold_method == 'hilbert':
threshold_func = apply_hilbert
threshold_dict = dict(z_score=self.threshold,
cycles=self.cycle_threshold,
gaps=self.gap_threshold)
kwargs = dict(n_times=n_times,
sfreq=self.sfreq,
filter_band=self.filter_band,
freq_cutoffs=self.freq_cutoffs,
freq_span=self.freq_span,
n_jobs=self.n_jobs)

if self.verbose:
print(f'Using {threshold_method} to perform HFO '
f'thresholding.')

thresholded_metric = threshold_func(metric, *threshold_dict, **kwargs)
return thresholded_metric


def predict(self, X):
"""Scikit-learn override predict function.

Expand Down Expand Up @@ -328,8 +418,15 @@ def _compute_sliding_window_detection(self, sig, method):

if method == 'rms':
hfo_detect_func = compute_rms
extra_params = dict(win_size=self.win_size)
elif method == 'line_length':
hfo_detect_func = compute_line_length
extra_params = dict(win_size=self.win_size)
elif method == 'hilbert':
hfo_detect_func = compute_hilbert
extra_params=dict(freq_cutoffs=self.freq_cutoffs,
freq_span=self.freq_span,
sfreq=self.sfreq)

# Overlapping window
win_start = 0
Expand All @@ -347,7 +444,7 @@ def _compute_sliding_window_detection(self, sig, method):

# compute the RMS of filtered signal in this window
signal_win_rms[win_idx] = hfo_detect_func(
sig[int(win_start):int(win_stop)], win_size=self.win_size)[0]
sig[int(win_start):int(win_stop)], extra_params=extra_params)[0]

if win_stop == self.n_times:
break
Expand All @@ -357,74 +454,37 @@ def _compute_sliding_window_detection(self, sig, method):
win_idx += 1
return signal_win_rms

def _post_process_ch_hfos(self, metric_vals_list, n_times,
threshold_method='std'):
"""Post process one channel's HFO events.

Joins contiguously detected HFOs as one event, and applies
the threshold based on number of stdev above baseline on the
RMS of the bandpass filtered signal.
def _merge_contiguous_ch_detections(self, detections, method):
"""Merge contiguous hfo detections into distinct events.

Parameters
----------
metric_vals_list : list
List of HFO metric values (e.g. Line Length, or RMS) over windows.
n_times : int
The number of time points in the original data matrix fed in.
threshold_method : str
The threshold method to use.
detections : List(tuples)
List of raw hfo detected events
method : str
Method to use to merge the detections.
adam2392 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
output : List[Tuple[int, int]]
A list of tuples, storing the event start and stop sample index
for the detected HFO.
events: List(tuples)
List of start and stop times of the distinct HFO events.

"""
if threshold_method not in ACCEPTED_THRESHOLD_METHODS:
raise ValueError(f'Threshold method {threshold_method} '
f'is not an implemented threshold method. '
f'Please use one of {ACCEPTED_THRESHOLD_METHODS} '
if method not in ACCEPTED_MERGE_METHODS:
raise ValueError(f'Merging method {method} '
f'is not an implemented merging method. '
f'Please use one of {ACCEPTED_MERGE_METHODS} '
f'methods.')
if threshold_method == 'std':
threshold_func = threshold_std

if self.verbose:
print(f'Using {threshold_method} to perform HFO '
f'thresholding.')
if method == "time-windows":
return detections
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't merge time-windows?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do, but we do it upstream. I left a TODO where its done to refactor it here if you wanted. I think refactoring will slow it down a bit, which is why I held off

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm okay I will take a look

elif method == "freq-bands":
merge_func = merge_contiguous_freq_bands

n_windows = len(metric_vals_list)
events = merge_func(detections)
if method == "freq-bands":
events, max_amplitude, freq_bands = events
self.hfo_max_amplitudes_ = max_amplitude
self.hfo_freq_bands_ = freq_bands

# store post-processed hfo events as a list
output = []
return events

# only keep RMS values above a certain number
# stdevs above baseline (threshold)
det_th = threshold_func(metric_vals_list, self.threshold)

# Detect and now group events if they are within a
# step size of each other
win_idx = 0
while win_idx < n_windows:
# log events if they pass our threshold criterion
if metric_vals_list[win_idx] >= det_th:
event_start = win_idx * self.step_size

# group events together if they occur in
# contiguous windows
while win_idx < n_windows and \
metric_vals_list[win_idx] >= det_th:
win_idx += 1
event_stop = (win_idx * self.step_size) + self.win_size

if event_stop > n_times:
event_stop = n_times

# TODO: Optional feature calculations

# Write into output
output.append((event_start, event_stop))
win_idx += 1
else:
win_idx += 1

return output
Loading