Optimize quantile. (#409)
dcherian authored Jan 7, 2025
1 parent 3853101 commit cdb2417
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions flox/aggregate_flox.py
@@ -50,36 +50,39 @@ def _lerp(a, b, *, t, dtype, out=None):
 def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
     inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
 
-    array_nanmask = isnull(array)
-    actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
+    array_validmask = notnull(array)
+    actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
     newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
     full_sizes = np.reshape(np.diff(inv_idx), newshape)
     nanmask = full_sizes != actual_sizes
 
     # The approach here is to use (complex_array.partition) because
     # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
     # 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
-    # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
-    # So we set
+    # 3. For complex arrays, partition will first sort by real part, then by imaginary part, so it is a two element
+    #    lex-partition.
+    # Therefore we use approach (3) and set
     #     complex_array = group_idx + 1j * array
-    # group_idx is an integer (guaranteed), but array can have NaNs. Now,
-    #     1 + 1j*NaN = NaN + 1j * NaN
-    # so we must replace all NaNs with the maximum array value in the group so these NaNs
-    # get sorted to the end.
-    # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
-    # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
-    array = np.where(array_nanmask, -np.inf, array)
-    maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis)
-    replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis)
-    array[array_nanmask] = replacement[array_nanmask]
-
+    # group_idx is an integer (guaranteed), but array can have NaNs.
+    # Now the sort order of np.nan is bigger than np.inf
+    # >>> c = (np.array([0, 1, 2, np.nan]) + np.array([np.nan, 2, 3, 4]) * 1j)
+    # >>> c.partition(2)
+    # >>> c
+    # array([ 1. +2.j,  2. +3.j, nan +4.j, nan+nanj])
+    # So we determine which indices we need using the fact that NaNs get sorted to the end.
+    # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
+    # but not any more now that I use partition and avoid replacing NaNs
     qin = q
     q = np.atleast_1d(qin)
     q = np.reshape(q, (len(q),) + (1,) * array.ndim)
 
     # This is numpy's method="linear"
     # TODO: could support all the interpolations here
-    virtual_index = q * (actual_sizes - 1) + inv_idx[:-1]
+    offset = actual_sizes.cumsum(axis=-1)
+    actual_sizes -= 1
+    virtual_index = q * actual_sizes
+    # virtual_index is relative to group starts, so now offset that
+    virtual_index[..., 1:] += offset[..., :-1]
 
     is_scalar_q = is_scalar(qin)
     if is_scalar_q:
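The rewritten index arithmetic in this hunk is worth unpacking: each group's method="linear" position is q * (n_valid - 1), and the running sum of valid counts shifts every later group into its slice of the partitioned array. A minimal sketch in plain NumPy (the sizes and q below are invented for illustration, not taken from flox):

```python
import numpy as np

actual_sizes = np.array([3, 2, 4])  # valid (non-NaN) counts per group, assumed
q = 0.5

offset = actual_sizes.cumsum()            # cumulative end of each group: [3 5 9]
virtual_index = q * (actual_sizes - 1.0)  # position within each group: [1. 0.5 1.5]
virtual_index[1:] += offset[:-1]          # offset groups 1..n past the earlier groups
print(virtual_index)                      # [1.  3.5 6.5]
```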
@@ -103,7 +106,10 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
     # partition the complex array in-place
     labels_broadcast = np.broadcast_to(group_idx, array.shape)
     with np.errstate(invalid="ignore"):
-        cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
+        cmplx = 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
+    # This is a very intentional way of handling `array` with -inf/+inf values :/
+    # a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
+    cmplx.real = labels_broadcast
     cmplx.partition(kth=kth, axis=-1)
     if is_scalar_q:
         a_ = cmplx.imag
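The two-step construction of cmplx is the subtle part of this hunk. A standalone repro of the -inf/+inf edge case the new comment describes (plain NumPy with made-up values, not flox code):

```python
import numpy as np

labels = np.array([0.0, 0.0])
array = np.array([1.0, np.inf])

with np.errstate(invalid="ignore"):
    # 1j * inf multiplies out to nan + inf*1j (its real part is 0 * inf = nan),
    # and adding the label afterwards cannot repair that nan:
    naive = labels + 1j * array
    print(naive)   # [ 0.+1.j nan+infj]

    # Assigning .real separately overwrites the nan with the label:
    cmplx = 1j * array
    cmplx.real = labels
    print(cmplx)   # [0.+1.j 0.+infj]
```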
@@ -145,7 +151,8 @@ def _np_grouped_op(
     (inv_idx,) = flag.nonzero()
 
     if size is None:
-        size = np.max(uniques) + 1
+        # This is sorted, so the last value is the largest label
+        size = uniques[-1] + 1
     if dtype is None:
         dtype = array.dtype
 
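The final hunk, in _np_grouped_op, trades a full scan for constant-time indexing: as the added comment notes, uniques is produced sorted, so its last element is already the largest label. A quick sketch of the equivalence (illustrative labels, not flox's internals):

```python
import numpy as np

uniques = np.array([0, 1, 4, 7])  # sorted unique group labels, assumed input

# np.max rescans every element; indexing a sorted array is O(1).
assert uniques[-1] + 1 == np.max(uniques) + 1
size = uniques[-1] + 1            # number of output groups
```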
