diff --git a/kwave/utils/matlab.py b/kwave/utils/matlab.py index bfefce97..2a426f40 100644 --- a/kwave/utils/matlab.py +++ b/kwave/utils/matlab.py @@ -90,13 +90,11 @@ def matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) - """ - if mask.dtype == "uint8": - mask = mask.astype(np.int16) - if diff is None: - return np.expand_dims(arr.ravel(order="F")[mask.ravel(order="F")], axis=-1) # compatibility, n => [n, 1] + flat_mask = mask.ravel(order="F") else: - return np.expand_dims(arr.ravel(order="F")[mask.ravel(order="F") + diff], axis=-1) # compatibility, n => [n, 1] + flat_mask = mask.ravel(order="F") + diff + return np.expand_dims(arr.ravel(order="F")[flat_mask], axis=-1) # compatibility, n => [n, 1] def unflatten_matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> Tuple[Union[int, np.ndarray], ...]: