From cdb24170b420822daa197ae0de411ff5b7df353a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 7 Jan 2025 09:19:25 -0700 Subject: [PATCH] Optimize quantile. (#409) --- flox/aggregate_flox.py | 43 ++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 1e7d330a..938bd6fc 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -50,8 +50,8 @@ def _lerp(a, b, *, t, dtype, out=None): def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None): inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) - array_nanmask = isnull(array) - actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis) + array_validmask = notnull(array) + actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis) newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) full_sizes = np.reshape(np.diff(inv_idx), newshape) nanmask = full_sizes != actual_sizes @@ -59,27 +59,30 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non # The approach here is to use (complex_array.partition) because # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary # 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow. - # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition. - # So we set + # 3. For complex arrays, partition will first sort by real part, then by imaginary part, so it is a two element + # lex-partition. + # Therefore we use approach (3) and set # complex_array = group_idx + 1j * array - # group_idx is an integer (guaranteed), but array can have NaNs. Now, - # 1 + 1j*NaN = NaN + 1j * NaN - # so we must replace all NaNs with the maximum array value in the group so these NaNs - # get sorted to the end. - # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ - # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful - array = np.where(array_nanmask, -np.inf, array) - maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis) - replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis) - array[array_nanmask] = replacement[array_nanmask] - + # group_idx is an integer (guaranteed), but array can have NaNs. + # Now the sort order of np.nan is bigger than np.inf + # >>> c = (np.array([0, 1, 2, np.nan]) + np.array([np.nan, 2, 3, 4]) * 1j) + # >>> c.partition(2) + # >>> c + # array([ 1. +2.j, 2. +3.j, nan +4.j, nan+nanj]) + # So we determine which indices we need using the fact that NaNs get sorted to the end. + # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ + # but not any more now that I use partition and avoid replacing NaNs qin = q q = np.atleast_1d(qin) q = np.reshape(q, (len(q),) + (1,) * array.ndim) # This is numpy's method="linear" # TODO: could support all the interpolations here - virtual_index = q * (actual_sizes - 1) + inv_idx[:-1] + offset = actual_sizes.cumsum(axis=-1) + actual_sizes -= 1 + virtual_index = q * actual_sizes + # virtual_index is relative to group starts, so now offset that + virtual_index[..., 1:] += offset[..., :-1] is_scalar_q = is_scalar(qin) if is_scalar_q: @@ -103,7 +106,10 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) with np.errstate(invalid="ignore"): - cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array) + cmplx = 1j * (array.view(int) if array.dtype.kind in "Mm" else array) + # This is a very intentional way of handling `array` with -inf/+inf values :/ + # a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j` + cmplx.real = labels_broadcast cmplx.partition(kth=kth, axis=-1) if is_scalar_q: a_ = cmplx.imag @@ -145,7 +151,8 @@ def _np_grouped_op( (inv_idx,) = flag.nonzero() if size is None: - size = np.max(uniques) + 1 + # This is sorted, so the last value is the largest label + size = uniques[-1] + 1 if dtype is None: dtype = array.dtype