Skip to content

Commit

Permalink
Allow each metric to return multiple values instead of only a single …
Browse files Browse the repository at this point in the history
…scalar
  • Loading branch information
MaxvandenBoom committed Sep 13, 2023
1 parent 708d0b4 commit 778da01
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions ieegprep/bids/data_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,10 @@ def _load_data_epoch_averages__by_condition_trials(data_reader, retrieve_channel
if metric_callbacks is not None:
if callable(metric_callbacks):
metric_values = allocate_array((len(retrieve_channels), len(conditions_onsets)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
elif type(metric_callbacks) is tuple and len(metric_callbacks) > 0:
metric_values = allocate_array((len(retrieve_channels), len(conditions_onsets), len(metric_callbacks)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
except MemoryError:
raise MemoryError('Not enough memory create metric output matrix')

Expand Down Expand Up @@ -811,9 +811,6 @@ def _load_data_epoch_averages__by_condition_trials(data_reader, retrieve_channel
None if baseline_data is None else baseline_data[channel_idx, :, :])

if metric_value is not None:
if not np.isscalar(metric_value):
logging.error('Return metric is not scalar')
raise RuntimeError('Return metric is not scalar')
metric_values[channel_idx, condition_idx] = metric_value

elif type(metric_callbacks) is tuple and len(metric_callbacks) > 0:
Expand All @@ -825,9 +822,6 @@ def _load_data_epoch_averages__by_condition_trials(data_reader, retrieve_channel
condition_data[channel_idx, :, :],
None if baseline_data is None else baseline_data[channel_idx, :, :])
if metric_value is not None:
if not np.isscalar(metric_value):
logging.error('Return metric is not scalar')
raise RuntimeError('Return metric is not scalar')
metric_values[channel_idx, condition_idx, iCallback] = metric_value

# the callback has been made, perform -if needed- the (postponed) normalization with the baseline values
Expand Down Expand Up @@ -1223,10 +1217,10 @@ def _load_data_epoch_averages__by_channel_condition_trial(data_reader, channels,
if metric_callbacks is not None:
if callable(metric_callbacks):
metric_values = allocate_array((len(channels), len(conditions_onsets)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
elif type(metric_callbacks) is tuple and len(metric_callbacks) > 0:
metric_values = allocate_array((len(channels), len(conditions_onsets), len(metric_callbacks)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
except MemoryError:
raise MemoryError('Not enough memory create a metric output matrix')

Expand Down Expand Up @@ -1321,10 +1315,10 @@ def _load_data_epochs__by_channels__withPrep(average, data_reader, retrieve_chan
if metric_callbacks is not None:
if callable(metric_callbacks):
metric_values = allocate_array((len(retrieve_channels), len(onsets)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
elif type(metric_callbacks) is tuple and len(metric_callbacks) > 0:
metric_values = allocate_array((len(retrieve_channels), len(onsets), len(metric_callbacks)),
fill_value=np.nan, dtype=np.float64)
fill_value=np.nan, dtype=np.ndarray)
except MemoryError:
raise MemoryError('Not enough memory create a metric output matrix')

Expand Down

0 comments on commit 778da01

Please sign in to comment.