Skip to content

Commit

Permalink
Sparse array bitmask
Browse files (browse the repository at this point in the history)
| Before [e082cbd] <speedup-cohorts~1>   | After [3fa4d7a8] <speedup-cohorts>   |   Ratio | Benchmark (Parameter)                                   |
|-----------------------------------------|--------------------------------------|---------|---------------------------------------------------------|
| 1.02±0.01ms                             | 7.68±0.06ms                          |    7.53 | cohorts.ERA5Google.time_find_group_cohorts              |
| 233±0.5μs                               | 1.63±0ms                             |    6.97 | cohorts.PerfectMonthlyRechunked.time_find_group_cohorts |
| 234±0.6μs                               | 1.63±0ms                             |    6.96 | cohorts.PerfectMonthly.time_find_group_cohorts          |
| 1.50±0ms                                | 6.37±0.02ms                          |    4.25 | cohorts.ERA5MonthHourRechunked.time_find_group_cohorts  |
| 1.41±0ms                                | 5.88±0.01ms                          |    4.17 | cohorts.ERA5MonthHour.time_find_group_cohorts           |
| 4.97±0.01ms                             | 9.37±0.04ms                          |    1.88 | cohorts.ERA5DayOfYearRechunked.time_find_group_cohorts  |
| 6.23±0.01ms                             | 10.2±0.02ms                          |    1.63 | cohorts.ERA5DayOfYear.time_find_group_cohorts           |
| 2.66±0.01ms                             | 4.09±0.02ms                          |    1.54 | cohorts.PerfectMonthly.time_graph_construct             |
| 2.66±0.01ms                             | 4.09±0ms                             |    1.54 | cohorts.PerfectMonthlyRechunked.time_graph_construct    |
| 10.6±0.04ms                             | 15.5±0.05ms                          |    1.47 | cohorts.ERA5MonthHourRechunked.time_graph_construct     |
| 10.1±0.04ms                             | 14.8±0.03ms                          |    1.45 | cohorts.ERA5MonthHour.time_graph_construct              |
| 19.5±0.08ms                             | 26.5±0.1ms                           |    1.36 | cohorts.ERA5Google.time_graph_construct                 |
| 21.5±0.05ms                             | 28.4±0.6ms                           |    1.33 | cohorts.NWMMidwest.time_find_group_cohorts              |
  • Loading branch information
dcherian committed Nov 27, 2023
1 parent 97ce15f commit 1b79831
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 10 deletions.
1 change: 1 addition & 0 deletions ci/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dependencies:
- numpy_groupies>=0.9.19
- numbagg>=0.3
- wheel
- scipy
1 change: 1 addition & 0 deletions ci/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies:
- pip
- xarray
- numpy>=1.22
- scipy
- numpydoc
- numpy_groupies>=0.9.19
- toolz
Expand Down
2 changes: 1 addition & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- netcdf4
- pandas
- numpy>=1.22
- scipy
- lxml # for mypy coverage report
- matplotlib
- pip
Expand All @@ -24,4 +25,3 @@ dependencies:
- toolz
- numba
- numbagg>=0.3
- scipy
1 change: 1 addition & 0 deletions ci/no-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies:
- netcdf4
- pandas
- numpy>=1.22
- scipy
- pip
- pytest
- pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion ci/no-numba.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- netcdf4
- pandas
- numpy>=1.22
- scipy
- lxml # for mypy coverage report
- matplotlib
- pip
Expand All @@ -21,4 +22,3 @@ dependencies:
- numpy_groupies>=0.9.19
- pooch
- toolz
- scipy
1 change: 1 addition & 0 deletions ci/no-xarray.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies:
- netcdf4
- pandas
- numpy>=1.22
- scipy
- pip
- pytest
- pytest-cov
Expand Down
48 changes: 40 additions & 8 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ def _unique(a: np.ndarray) -> np.ndarray:
def slices_from_chunks(chunks):
"""slightly modified from dask.array.core.slices_from_chunks to be lazy"""
cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks]
slices = [
[slice(s, s + dim) for s, dim in zip(starts, shapes)]
slices = (
(slice(s, s + dim) for s, dim in zip(starts, shapes))
for starts, shapes in zip(cumdims, chunks)
]
)
return product(*slices)


Expand Down Expand Up @@ -247,12 +247,44 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:

labels = np.broadcast_to(labels, shape[-labels.ndim :])

bitmask = np.zeros((nchunks, nlabels), dtype=bool)
from scipy.sparse import csc_array

rows = []
cols = []
out = np.zeros((nlabels,), dtype=bool)
ilabels = np.arange(nlabels - 1)
for idx, region in enumerate(slices_from_chunks(chunks)):
bitmask[idx, labels[region]] = True
bitmask = bitmask[:, :-1]
chunk = np.arange(nchunks) # [:, np.newaxis] * bitmask
label_chunks = {lab: chunk[bitmask[:, lab]] for lab in range(nlabels - 1)}
# This is a fairly fast way to find uniques,
# inspired by a similar idea in numpy_groupies:
# instead of explicitly finding uniques, repeatedly write to the same location.
subset = labels[region]
# The reshape is not strictly necessary but is about 100ms faster on a test problem.
out[subset.reshape(-1)] = True
# skip the -1 sentinel by slicing
uniques = ilabels[out[:-1]]
rows.append([idx] * len(uniques))
cols.append(uniques)
out[:] = False
rows_array = np.concatenate(rows)
cols_array = np.concatenate(cols)
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
bitmask = csc_array(
(data, (rows_array, cols_array)),
dtype=bool,
shape=(nchunks, nlabels - 1),
)
# chunk = np.arange(nchunks) # [:, np.newaxis] * bitmask
label_chunks = {
lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
for lab in range(nlabels - 1)
}

# bitmask = np.zeros((nchunks, nlabels), dtype=bool)
# for idx, region in enumerate(slices_from_chunks(chunks)):
# bitmask[idx, labels[region]] = True
# bitmask = bitmask[:, :-1]
# chunk = np.arange(nchunks) # [:, np.newaxis] * bitmask
# label_chunks = {lab: chunk[bitmask[:, lab]] for lab in range(nlabels - 1)}

# which_chunk = np.empty(shape, dtype=np.int64)
# for idx, region in enumerate(slices_from_chunks(chunks)):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ requires = [
"pandas",
"numpy>=1.22",
"numpy_groupies>=0.9.19",
"scipy",
"toolz",
"setuptools>=61.0.0",
"setuptools_scm[toml]>=7.0",
Expand Down Expand Up @@ -101,6 +102,7 @@ known-third-party = [
"pkg_resources",
"pytest",
"setuptools",
"scipy",
"xarray"
]

Expand Down

0 comments on commit 1b79831

Please sign in to comment.