diff --git a/src/patch_denoise/_docs.py b/src/patch_denoise/_docs.py index 980389e..539142e 100644 --- a/src/patch_denoise/_docs.py +++ b/src/patch_denoise/_docs.py @@ -5,7 +5,6 @@ Docstring can then use templated argument such as ``$patch_config`` that will be substitute by their definition (see docdict items). -source: """ import inspect diff --git a/src/patch_denoise/bindings/cli.py b/src/patch_denoise/bindings/cli.py index 63e6a34..7325041 100644 --- a/src/patch_denoise/bindings/cli.py +++ b/src/patch_denoise/bindings/cli.py @@ -43,6 +43,19 @@ def parse_args(): ), default="optimal-fro_11_5_weighted", ) + parser.add_argument( + "--time-slice", + help=( + "Slice across time. \n" + "If x the patch will be N times longer in space than in time \n" + "If int, this is the size of the time dimension patch. \n" + "If not specified, the whole time serie is used. \n" + "Note: setting a low aspect ratio will increase the number of patch to be" + "processed, and will increase memory usage and computation times." + ), + default=None, + type=str, + ) parser.add_argument( "--mask", default=None, @@ -145,6 +158,13 @@ def main(): ) d_par = DenoiseParameters.from_str(args.conf) + if isinstance(args.time_slice, str): + if args.time_slice.endswith("x"): + t = float(args.time_slice[:-1]) + t = int(d_par ** (input_data.ndim - 1) / t) + else: + t = int(args.time_slice) + d_par.patch_shape = (d_par.patch_shape,) * (input_data.ndim - 1) + (t,) print(d_par) denoise_func = DENOISER_MAP[d_par.method] extra_kwargs = dict() diff --git a/src/patch_denoise/bindings/utils.py b/src/patch_denoise/bindings/utils.py index 5eb2ded..c303b6e 100644 --- a/src/patch_denoise/bindings/utils.py +++ b/src/patch_denoise/bindings/utils.py @@ -45,8 +45,8 @@ class DenoiseParameters: """Denoise Parameters data structure.""" method: str = None - patch_shape: int = 11 - patch_overlap: int = 0 + patch_shape: int | tuple[int, ...] = 11 + patch_overlap: int | tuple[int, ...] = 0 recombination: str = "weighted" # "center" is also available mask_threshold: int = 10 diff --git a/src/patch_denoise/space_time/base.py b/src/patch_denoise/space_time/base.py index 3b0a660..e11a916 100644 --- a/src/patch_denoise/space_time/base.py +++ b/src/patch_denoise/space_time/base.py @@ -1,12 +1,118 @@ """Base Structure for patch-based denoising on spatio-temporal dimension.""" import abc +from functools import partial, cached_property import logging import numpy as np from tqdm.auto import tqdm from .._docs import fill_doc -from .utils import get_patch_locs + +class PatchedArray: + """A container for accessing custom view of array easily. + + Parameters + ---------- + array: np.ndarray + patch_shape: tuple + patch_overlap: tuple + + """ + + def __init__( + self, + array, + patch_shape, + patch_overlap, + dtype=None, + padding_mode="edge", + **kwargs, + ): + if isinstance(array, tuple): + array = np.zeros(array, dtype=dtype) + self._arr = array + + self._ps = np.asarray(patch_shape) + self._po = np.asarray(patch_overlap) + self._po = patch_overlap + + dimensions = self._arr.ndim + step = self._ps - self._po + if np.any(step < 0): + raise ValueError("overlap should be smaller than patch on every dimension.") + + if self._ps.size != dimensions or step.size != dimensions: + raise ValueError( + "self._ps and step must have the same number of dimensions as the input self._array." + ) + + # Ensure patch size is not larger than self._array size along each axis + self._ps = np.minimum(self._ps, self._arr.shape) + + # Calculate the shape and strides of the sliding view + grid_shape = tuple( + ((self._arr.shape[i] - self._ps[i]) // step[i] + 1) + if self._ps[i] < self._arr.shape[i] + else 1 + for i in range(dimensions) + ) + shape = grid_shape + tuple(self._ps) + strides = ( + tuple( + self._arr.strides[i] * step[i] + if self._ps[i] < self._arr.shape[i] + else 0 + for i in range(dimensions) + ) + + self._arr.strides + ) + + # Create the sliding view + self.sliding_view = np.lib.stride_tricks.as_strided( + self._arr, shape=shape, strides=strides + ) + + self._grid_shape = grid_shape + + @property + def n_patches(self): + """Get number of patches.""" + return np.prod(self._grid_shape) + + def get_patch(self, idx): + """Get patch at linear index ``idx``.""" + return self.sliding_view[np.unravel_index(idx, self._grid_shape)] + + def set_patch(self, idx, value): + """Set patch at linear index ``idx`` with value.""" + self.sliding_view[np.unravel_index(idx, self._grid_shape)] + + def add2patch(self, idx, value): + """Add to patch, in place.""" + patch = self.get_patch(idx) + # self.set_patch(idx, patch + value) + patch += value + + # def sync(self): + # """Apply the padded value to the array back.""" + # np.copyto( + # self._array, + # self._padded_array[ + # tuple( + # np.s_[: (s + 1 - ps) if (s - ps) else s] + # for ps, s in zip(self._ps, self._padded_array.shape) + # ) + # ], + # ) + + # def get(self): + # """Return the regular array, after applying the padded values.""" + # self.sync() + # return self._array + + def __getattr__(self, name): + """Get attribute of underlying array.""" + return getattr(self._arr, name) @fill_doc @@ -50,96 +156,157 @@ def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None): $denoise_return """ data_shape = input_data.shape - output_data = np.zeros_like(input_data) - rank_map = np.zeros(data_shape[:-1], dtype=np.int32) + p_s, p_o = self._get_patch_param(data_shape) + + input_data = PatchedArray(input_data, p_s, p_o) + output_data = PatchedArray(data_shape, p_s, p_o, dtype=input_data.dtype) + patch_weights = PatchedArray(data_shape, p_s, p_o, dtype=np.float32) + rank_map = PatchedArray(data_shape, p_s, p_o, dtype=np.int32) + noise_std_estimate = PatchedArray(data_shape, p_s, p_o, dtype=np.float32) # Create Default mask if mask is None: - process_mask = np.full(data_shape[:-1], True) - else: - process_mask = np.copy(mask) - - patch_shape, patch_overlap = self.__get_patch_param(data_shape) - patch_size = np.prod(patch_shape) + process_mask = np.full(data_shape, True) + elif mask.shape == input_data.shape[:-1]: + process_mask = np.broadcast_to(mask, input_data.shape) - if self.recombination == "center": - patch_center = ( - *(slice(ps // 2, ps // 2 + 1) for ps in patch_shape), - slice(None, None, None), - ) - patchs_weight = np.zeros(data_shape[:-1], np.float32) - noise_std_estimate = np.zeros(data_shape[:-1], dtype=np.float32) + process_mask = PatchedArray( + process_mask, p_s, p_o, padding_mode="constant", constant_values=0 + ) - # discard useless patches - patch_locs = get_patch_locs(patch_shape, patch_overlap, data_shape[:-1]) - get_it = np.zeros(len(patch_locs), dtype=bool) + center_pos = tuple(p // 2 for p in p_s) + patch_space_size = np.prod(p_s[:-1]) + # select only queue index where process_mask is valid. + get_it = np.zeros(input_data.n_patches, dtype=bool) - for i, patch_tl in enumerate(patch_locs): - patch_slice = tuple( - slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape) - ) - if 100 * np.sum(process_mask[patch_slice]) / patch_size > mask_threshold: + for i in range(len(get_it)): + pm = process_mask.get_patch(i) + if 100 * np.sum(pm) / pm.size > mask_threshold: get_it[i] = True - logging.info(f"Denoise {100 * np.sum(get_it) / len(patch_locs):.2f}% patches") - patch_locs = np.ascontiguousarray(patch_locs[get_it]) + select_patches = np.nonzero(get_it)[0] + del get_it if progbar is None: - progbar = tqdm(total=len(patch_locs)) + progbar = tqdm(total=len(select_patches)) elif progbar is not False: - progbar.reset(total=len(patch_locs)) - - for patch_tl in patch_locs: - patch_slice = tuple( - slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape) - ) - process_mask[patch_slice] = 1 - # building the casoratti matrix - patch = np.reshape(input_data[patch_slice], (-1, input_data.shape[-1])) - - # Replace all nan by mean value of patch. - # FIXME this behaviour should be documented - # And ideally choosen by the user. + progbar.reset(total=len(select_patches)) - patch[np.isnan(patch)] = np.mean(patch) + for i in select_patches: + input_patch_casorati = input_data.get_patch(i).reshape(patch_space_size, -1) p_denoise, maxidx, noise_var = self._patch_processing( - patch, - patch_slice=patch_slice, + input_patch_casorati, + patch_idx=i, **self.input_denoising_kwargs, ) - p_denoise = np.reshape(p_denoise, (*patch_shape, -1)) - patch_center_img = tuple( - ptl + ps // 2 for ptl, ps in zip(patch_tl, patch_shape) - ) + p_denoise = np.reshape(p_denoise, p_s) if self.recombination == "center": - output_data[patch_center_img] = p_denoise[patch_center] - noise_std_estimate[patch_center_img] += noise_var + output_data.get_patch(i)[center_pos] = p_denoise[center_pos] elif self.recombination == "weighted": theta = 1 / (2 + maxidx) - output_data[patch_slice] += p_denoise * theta - patchs_weight[patch_slice] += theta + output_data.add2patch(i, p_denoise * theta) + patch_weights.add2patch(i, theta) elif self.recombination == "average": - output_data[patch_slice] += p_denoise - patchs_weight[patch_slice] += 1 + output_data.add2patch(i, p_denoise) + patch_weights.add2patch(i, 1) else: raise ValueError( "recombination must be one of 'weighted', 'average', 'center'" ) - if not np.isnan(noise_var): - noise_std_estimate[patch_slice] += noise_var - # the top left corner of the patch is used as id for the patch. - rank_map[patch_center_img] = maxidx if progbar: progbar.update() # Averaging the overlapping pixels. # this is only required for averaging recombinations. + + output_data = output_data._arr + patch_weights = patch_weights._arr + if self.recombination in ["average", "weighted"]: - output_data /= patchs_weight[..., None] - noise_std_estimate /= patchs_weight + output_data /= patch_weights + + output_data[~process_mask._arr] = 0 - output_data[~process_mask] = 0 + return output_data, patch_weights, noise_std_estimate, rank_map - return output_data, patchs_weight, noise_std_estimate, rank_map + # if self.recombination == "center": + # patch_center = ( + # *(slice(ps // 2, ps // 2 + 1) for ps in patch_shape), + # slice(None, None, None), + # ) + # patchs_weight = np.zeros(data_shape[:-1], np.float32) + # noise_std_estimate = np.zeros(data_shape[:-1], dtype=np.float32) + + # # discard useless patches + # patch_locs = get_patch_locs(patch_shape, patch_overlap, data_shape) + # get_it = np.zeros(len(patch_locs), dtype=bool) + + # for i, patch_tl in enumerate(patch_locs): + # patch_slice = tuple( + # slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape) + # ) + # if 100 * np.sum(process_mask[patch_slice]) / patch_size > mask_threshold: + # get_it[i] = True + + # logging.info(f"Denoise {100 * np.sum(get_it) / len(patch_locs):.2f}% patches") + # patch_locs = np.ascontiguousarray(patch_locs[get_it]) + + # if progbar is None: + # progbar = tqdm(total=len(patch_locs)) + # elif progbar is not False: + # progbar.reset(total=len(patch_locs)) + + # for patch_tl in patch_locs: + # patch_slice = tuple( + # slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape) + # ) + # process_mask[patch_slice] = 1 + # # building the casoratti matrix + # patch = np.reshape(input_data[patch_slice], (-1, input_data.shape[-1])) + + # # Replace all nan by mean value of patch. + # # FIXME this behaviour should be documented + # # And ideally choosen by the user. + + # patch[np.isnan(patch)] = np.mean(patch) + # p_denoise, maxidx, noise_var = self._patch_processing( + # patch, + # patch_slice=patch_slice, + # **self.input_denoising_kwargs, + # ) + + # p_denoise = np.reshape(p_denoise, (*patch_shape, -1)) + # patch_center_img = tuple( + # ptl + ps // 2 for ptl, ps in zip(patch_tl, patch_shape) + # ) + # if self.recombination == "center": + # output_data[patch_center_img] = p_denoise[patch_center] + # noise_std_estimate[patch_center_img] += noise_var + # elif self.recombination == "weighted": + # theta = 1 / (2 + maxidx) + # output_data[patch_slice] += p_denoise * theta + # patchs_weight[patch_slice] += theta + # elif self.recombination == "average": + # output_data[patch_slice] += p_denoise + # patchs_weight[patch_slice] += 1 + # else: + # raise ValueError( + # "recombination must be one of 'weighted', 'average', 'center'" + # ) + # if not np.isnan(noise_var): + # noise_std_estimate[patch_slice] += noise_var + # # the top left corner of the patch is used as id for the patch. + # rank_map[patch_center_img] = maxidx + # if progbar: + # progbar.update() + # # Averaging the overlapping pixels. + # # this is only required for averaging recombinations. + # if self.recombination in ["average", "weighted"]: + # output_data /= patchs_weight[..., None] + # noise_std_estimate /= patchs_weight + + # output_data[~process_mask] = 0 + + # return output_data, patchs_weight, noise_std_estimate, rank_map @abc.abstractmethod def _patch_processing(self, patch, patch_slice=None, **kwargs): @@ -148,7 +315,7 @@ def _patch_processing(self, patch, patch_slice=None, **kwargs): Implemented by child classes. """ - def __get_patch_param(self, data_shape): + def _get_patch_param(self, data_shape): """Return tuple for patch_shape and patch_overlap. It works from whatever the input format was (int or list). @@ -161,8 +328,12 @@ def __get_patch_param(self, data_shape): p = tuple(p) elif isinstance(p, (int, np.integer)): p = (p,) * (len(data_shape) - 1) - pp[i] = p - if np.prod(pp[0]) < data_shape[-1]: + if len(p) == len(data_shape) - 1: + # add the time dimension + p = (*p, data_shape[-1]) + pp[i] = p + + if np.prod(pp[0][:-1]) < data_shape[-1]: logging.warning( f"the number of voxel in patch ({np.prod(pp[0])}) is smaller than the" f" last dimension ({data_shape[-1]}), this makes an ill-conditioned" diff --git a/src/patch_denoise/space_time/lowrank.py b/src/patch_denoise/space_time/lowrank.py index 69d6b26..345f39e 100644 --- a/src/patch_denoise/space_time/lowrank.py +++ b/src/patch_denoise/space_time/lowrank.py @@ -5,7 +5,7 @@ from scipy.linalg import svd from scipy.optimize import minimize -from .base import BaseSpaceTimeDenoiser +from .base import BaseSpaceTimeDenoiser, PatchedArray from .utils import ( eig_analysis, eig_synthesis, @@ -38,19 +38,19 @@ def __init__(self, patch_shape, patch_overlap, threshold_scale, **kwargs): super().__init__(patch_shape, patch_overlap, **kwargs) self.input_denoising_kwargs["threshold_scale"] = threshold_scale - def _patch_processing(self, patch, patch_slice=None, threshold_scale=1.0): + def _patch_processing(self, patch, patch_idx=None, threshold_scale=1.0): """Process a pach with the MP-PCA method.""" p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch) maxidx = 0 meanvar = np.mean(eig_vals) meanvar *= 4 * np.sqrt((len(eig_vals) - maxidx + 1) / len(patch)) - while meanvar < eig_vals[~maxidx] - eig_vals[0]: + while maxidx < len(eig_vals) and meanvar < eig_vals[~maxidx] - eig_vals[0]: maxidx += 1 meanvar = np.mean(eig_vals[:-maxidx]) meanvar *= 4 * np.sqrt((len(eig_vec) - maxidx + 1) / len(patch)) var_noise = np.mean(eig_vals[: len(eig_vals) - maxidx]) - maxidx = np.sum(eig_vals > (var_noise * threshold_scale**2)) + maxidx = np.sum(eig_vals > (var_noise * threshold_scale ** 2)) if maxidx == 0: patch_new = np.zeros_like(patch) + p_tmean @@ -89,18 +89,20 @@ def denoise( ------- $denoise_return """ + p_s, p_o = self._get_patch_param(input_data.shape) if isinstance(noise_std, (float, np.floating)): - self.input_denoising_kwargs["var_apriori"] = noise_std**2 * np.ones( - input_data.shape[:-1] - ) + var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1]) else: - self.input_denoising_kwargs["var_apriori"] = noise_std**2 - + var_apriori = noise_std ** 2 + var_apriori = PatchedArray( + np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o + ) + self.input_denoising_kwargs["var_apriori"] = var_apriori return super().denoise(input_data, mask, mask_threshold, progbar=progbar) - def _patch_processing(self, patch, patch_slice=None, var_apriori=None): + def _patch_processing(self, patch, patch_idx=None, var_apriori=None): """Process a pach with the Hybrid-PCA method.""" - varest = np.mean(var_apriori[patch_slice]) + varest = np.mean(var_apriori.get_patch(patch_idx)) p_center, eig_vals, eig_vec, p_tmean = eig_analysis(patch) maxidx = 0 var_noise = np.mean(eig_vals) @@ -164,7 +166,7 @@ def denoise( self._threshold = self._threshold_val * threshold_scale return super().denoise(input_data, mask, mask_threshold, progbar=progbar) - def _patch_processing(self, patch, patch_slice=None, **kwargs): + def _patch_processing(self, patch, patch_idx=None, **kwargs): """Process a pach with the simple SVT method.""" # Centering for better precision in SVD u_vec, s_values, v_vec, p_tmean = svd_analysis(patch) @@ -215,7 +217,7 @@ def denoise( $denoise_return """ - patch_shape, _ = self._BaseSpaceTimeDenoiser__get_patch_param(input_data.shape) + patch_shape, _ = self._get_patch_param(input_data.shape) # compute the threshold using Monte-Carlo Simulations. max_sval = sum( max( @@ -246,8 +248,8 @@ def denoise( # From MATLAB implementation def _opt_loss_x(y, beta): """Compute (8) of donoho2017.""" - tmp = y**2 - beta - 1 - return np.sqrt(0.5 * (tmp + np.sqrt((tmp**2) - (4 * beta)))) * ( + tmp = y ** 2 - beta - 1 + return np.sqrt(0.5 * (tmp + np.sqrt((tmp ** 2) - (4 * beta)))) * ( y >= (1 + np.sqrt(beta)) ) @@ -260,17 +262,20 @@ def _opt_ope_shrink(singvals, beta=1): def _opt_nuc_shrink(singvals, beta=1): """Perform optimal threshold of singular values for nuclear norm.""" tmp = _opt_loss_x(singvals, beta) - return np.maximum( - 0, - (tmp**4 - (np.sqrt(beta) * tmp * singvals) - beta), - ) / ((tmp**2) * singvals) + return ( + np.maximum( + 0, + (tmp ** 4 - (np.sqrt(beta) * tmp * singvals) - beta), + ) + / ((tmp ** 2) * singvals) + ) def _opt_fro_shrink(singvals, beta=1): """Perform optimal threshold of singular values for frobenius norm.""" return np.sqrt( np.maximum( - (((singvals**2) - beta - 1) ** 2 - 4 * beta), + (((singvals ** 2) - beta - 1) ** 2 - 4 * beta), 0, ) / singvals @@ -349,34 +354,37 @@ def denoise( IEEE Transactions on Information Theory 63, no. 4 (April 2017): 2137–52. https://doi.org/10.1109/TIT.2017.2653801. """ - patch_shape, _ = self._BaseSpaceTimeDenoiser__get_patch_param(input_data.shape) + p_s, p_o = self._get_patch_param(input_data.shape) self.input_denoising_kwargs["mp_median"] = marshenko_pastur_median( - beta=input_data.shape[-1] / np.prod(patch_shape), + beta=input_data.shape[-1] / np.prod(p_s), eps=eps_marshenko_pastur, ) + if noise_std is None: self.input_denoising_kwargs["var_apriori"] = None - elif isinstance(noise_std, (float, np.floating)): - self.input_denoising_kwargs["var_apriori"] = noise_std**2 * np.ones( - input_data.shape[:-1] - ) else: - self.input_denoising_kwargs["var_apriori"] = noise_std**2 - + if isinstance(noise_std, (float, np.floating)): + var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1]) + else: + var_apriori = noise_std ** 2 + var_apriori = PatchedArray( + np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o + ) + self.input_denoising_kwargs["var_apriori"] = var_apriori return super().denoise(input_data, mask, mask_threshold, progbar=progbar) def _patch_processing( self, patch, - patch_slice=None, + patch_idx=None, shrink_func=None, mp_median=None, var_apriori=None, ): u_vec, s_values, v_vec, p_tmean = svd_analysis(patch) if var_apriori is not None: - sigma = np.mean(np.sqrt(var_apriori[patch_slice])) + sigma = np.mean(np.sqrt(var_apriori.get_patch(patch_idx))) else: sigma = np.median(s_values) / np.sqrt(patch.shape[1] * mp_median) @@ -415,7 +423,7 @@ def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None): else: tau = np.exp(tau) - sing_vals2 = sing_vals**2 + sing_vals2 = sing_vals ** 2 n_vals = len(sing_vals) D = np.zeros((n_vals, n_vals), dtype=np.float32) dhat = sing_vals * np.maximum(1 - ((tau / sing_vals) ** gamma), 0) @@ -431,7 +439,7 @@ def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None): rss = np.sum((dhat - sing_vals) ** 2) if method == "gsure": return rss / (1 - div / n / p) ** 2 - return (sigma**2) * ((-n * p) + (2 * div)) + rss + return (sigma ** 2) * ((-n * p) + (2 * div)) + rss if NUMBA_AVAILABLE: @@ -582,25 +590,28 @@ def denoise( self.input_denoising_kwargs["gamma0"] = gamma0 self.input_denoising_kwargs["tau0"] = tau0 + p_s, p_o = self._get_patch_param(input_data.shape) if isinstance(noise_std, (float, np.floating)): - self.input_denoising_kwargs["var_apriori"] = noise_std**2 * np.ones( - input_data.shape[:-1] - ) + var_apriori = noise_std ** 2 * np.ones(input_data.shape[:-1]) else: - self.input_denoising_kwargs["var_apriori"] = noise_std**2 + var_apriori = noise_std ** 2 + var_apriori = PatchedArray( + np.broadcast_to(var_apriori[..., None], input_data.shape), p_s, p_o + ) + self.input_denoising_kwargs["var_apriori"] = var_apriori return super().denoise(input_data, mask, mask_threshold, progbar=progbar) def _patch_processing( self, patch, - patch_slice=None, + patch_idx=None, gamma0=None, tau0=None, var_apriori=None, method=None, nbsim=None, ): - stdest = np.sqrt(np.mean(var_apriori[patch_slice])) + stdest = np.sqrt(np.mean(var_apriori.get_patch(patch_idx))) u_vec, sing_vals, v_vec, p_tmean = svd_analysis(patch) diff --git a/src/patch_denoise/space_time/utils.py b/src/patch_denoise/space_time/utils.py index 0ad13aa..50a2e41 100644 --- a/src/patch_denoise/space_time/utils.py +++ b/src/patch_denoise/space_time/utils.py @@ -4,6 +4,56 @@ from scipy.linalg import eigh, svd +def get_patch_locs(p_shape, p_ovl, v_shape): + """ + Get all the patch top-left corner locations. + + Parameters + ---------- + vol_shape : tuple + The volume shape + patch_shape : tuple + The patch shape + patch_overlap : tuple + The overlap of patch for each dimension. + + Returns + ------- + numpy.ndarray + All the patch top-left corner locations. + + Notes + ----- + This is a legacy function, you probably want to use the PatchedArray class. + """ + # Create an iterator for all the possible patches top-left corner location. + if len(v_shape) != len(p_shape) or len(v_shape) != len(p_ovl): + raise ValueError( + f"Dimension mismatch between the arguments. {p_shape}{p_ovl}, {v_shape}" + ) + + ranges = [] + for v_s, p_s, p_o in zip(v_shape, p_shape, p_ovl): + if p_o >= p_s: + raise ValueError( + "Overlap should be a non-negative integer smaller than patch_size", + ) + last_idx = v_s - p_s + range_ = np.arange(0, last_idx, p_s - p_o, dtype=np.int32) + if range_[-1] < last_idx: + range_ = np.append(range_, last_idx) + ranges.append(range_) + # fast ND-Cartesian product from https://stackoverflow.com/a/11146645 + patch_locs = np.empty( + [len(arr) for arr in ranges] + [len(p_shape)], + dtype=np.int32, + ) + for idx, coords in enumerate(np.ix_(*ranges)): + patch_locs[..., idx] = coords + + return patch_locs.reshape(-1, len(p_shape)) + + def svd_analysis(input_data): """Return the centered SVD decomposition. @@ -153,50 +203,6 @@ def mp_pdf(x): return (lobnd + hibnd) / 2 -def get_patch_locs(p_shape, p_ovl, v_shape): - """ - Get all the patch top-left corner locations. - - Parameters - ---------- - vol_shape : tuple - The volume shape - patch_shape : tuple - The patch shape - patch_overlap : tuple - The overlap of patch for each dimension. - - Returns - ------- - numpy.ndarray - All the patch top-left corner locations. - """ - # Create an iterator for all the possible patches top-left corner location. - if len(v_shape) != len(p_shape) or len(v_shape) != len(p_ovl): - raise ValueError("Dimension mismatch between the arguments.") - - ranges = [] - for v_s, p_s, p_o in zip(v_shape, p_shape, p_ovl): - if p_o >= p_s: - raise ValueError( - "Overlap should be a non-negative integer smaller than patch_size", - ) - last_idx = v_s - p_s - range_ = np.arange(0, last_idx, p_s - p_o, dtype=np.int32) - if range_[-1] < last_idx: - range_ = np.append(range_, last_idx) - ranges.append(range_) - # fast ND-Cartesian product from https://stackoverflow.com/a/11146645 - patch_locs = np.empty( - [len(arr) for arr in ranges] + [len(p_shape)], - dtype=np.int32, - ) - for idx, coords in enumerate(np.ix_(*ranges)): - patch_locs[..., idx] = coords - - return patch_locs.reshape(-1, len(p_shape)) - - def estimate_noise(noise_sequence, block_size=1): """Estimate the temporal noise standard deviation of a noise only sequence.""" volume_shape = noise_sequence.shape[:-1]