xarray API (#17)
* xarray wrapper for `gaussian_kernel`

* get the xarray API for gaussian to work properly

* make the `reshape` calls compatible with `dask`

* manually create the grid object

* partially allow dask arrays

* attach coordinates to the kernel's dimensions

* refactor `gaussian_kernel` to outsource the computation of weights

* compute explicitly

* configure `ruff` to disallow relative imports

* move `hc.kernels.xarray` to `hc.xarray.kernels`

* basic implementation of `convolve` for `Dataset` objects

* docstring for `convolve`

* expose `convolve` and the kernels module in the top-level `xarray` module

* use the accessor to fetch the grid info object

* move the rolling mean kernel into the function itself

* rename the tested module to `np_kernels`

* return the new cell ids from the kernel functions

* use the cell ids returned by the low-level kernel function as input cell ids

* check that the kernels are constructed properly with the xarray API

* fix the cell ids

* check that full-sphere convolution kernels still work

* add back the extended deadline, `numba` might take a while to compile

* support rings greater than 127

Not sure if we would ever need those, though.

* manually construct the `DataArray` that wraps the sparse matrix

* make `dim` a keyword-only parameter

* add padding parameters to `convolve`

* pad the input if it is not a global map

* add a simple wrapper for the padding function

* properly use the padding function

* refactor

* attach the ring to the kernel's metadata

* actually get the padding to work

* `pre-commit`

* expose `pad`
keewis authored Jul 25, 2024
1 parent 79d11b7 commit c726356
Showing 12 changed files with 430 additions and 51 deletions.
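
Taken together, these commits add an `xarray` layer on top of the existing numpy kernels: a `gaussian_kernel` wrapper, a `convolve` function for `Dataset` objects, and a `pad` helper. A minimal sketch of how the pieces might fit together, inferred from the tests and docstrings in the diffs below rather than from documented API (the `t` variable is made up, and whether the kernel builder accepts a whole `Dataset` is not shown, so the `DataArray` form from the tests is used):

```python
import numpy as np
import xarray as xr
import xdggs

import healpix_convolution.xarray as hcx

# Hypothetical global healpix map at resolution 1, mirroring the test data
# further down this page.
cell_ids = np.arange(12 * 4**1, dtype="int64")
ds = xr.Dataset(
    {"t": ("cells", np.linspace(0, 1, cell_ids.size))},
    coords={
        "cell_ids": (
            "cells",
            cell_ids,
            {"grid_name": "healpix", "resolution": 1, "indexing_scheme": "nested"},
        )
    },
).pipe(xdggs.decode)

kernel = hcx.kernels.gaussian_kernel(ds["t"], sigma=0.1, kernel_size=3)
smoothed = hcx.convolve(ds, kernel, dim="cells")
```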
9 changes: 8 additions & 1 deletion healpix_convolution/kernels/common.py
@@ -1,3 +1,5 @@
import dask
import dask.array as da
import numpy as np
import sparse

@@ -22,5 +24,10 @@ def create_sparse(cell_ids, neighbours, weights):
weights_ = np.reshape(weights, (-1,))[mask]
coords_ = coords[..., mask]

if isinstance(weights_, da.Array):
coords_, weights_ = dask.compute(coords_, weights_)

shape = (cell_ids.size, all_cell_ids.size)
return sparse.COO(coords=coords_, data=weights_, shape=shape, fill_value=0)
return all_cell_ids, sparse.COO(
coords=coords_, data=weights_, shape=shape, fill_value=0
)
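
`create_sparse` now also returns the cell ids that the matrix columns refer to, so callers unpack two values. A small sketch with made-up inputs (not taken from the test suite):

```python
import numpy as np

from healpix_convolution.kernels.common import create_sparse

# Made-up two-cell example: each row holds one cell's neighbours and the
# matching weights; the first return value is the sorted, de-duplicated set
# of cell ids that the matrix columns refer to.
cell_ids = np.array([1, 2])
neighbours = np.array([[0, 1, 2], [1, 2, 3]])
weights = np.full((2, 3), 1 / 3)

new_cell_ids, matrix = create_sparse(cell_ids, neighbours, weights)
assert matrix.shape == (cell_ids.size, new_cell_ids.size)
```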
20 changes: 14 additions & 6 deletions healpix_convolution/kernels/gaussian.py
@@ -6,6 +6,18 @@
from healpix_convolution.neighbours import neighbours


def gaussian_function(distances, sigma, *, mask=None):
sigma2 = sigma * sigma
phi_x = np.exp(-0.5 / sigma2 * distances**2)

if mask is not None:
masked = np.where(mask, 0, phi_x)
else:
masked = phi_x

return masked / np.sum(masked, axis=1, keepdims=True)


def gaussian_kernel(
cell_ids,
*,
@@ -59,10 +71,6 @@ def gaussian_kernel(
cell_ids, resolution=resolution, indexing_scheme=indexing_scheme, ring=ring
)
d = angular_distances(nb, resolution=resolution, indexing_scheme=indexing_scheme)
weights = gaussian_function(d, sigma, mask=nb == -1)

sigma2 = sigma * sigma
phi_x = np.exp(-0.5 / sigma2 * d**2)
masked = np.where(nb == -1, 0, phi_x)
normalized = masked / np.sum(masked, axis=1, keepdims=True)

return create_sparse(cell_ids, nb, normalized)
return create_sparse(cell_ids, nb, weights)
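
The weight computation that used to live inline in `gaussian_kernel` is now the standalone `gaussian_function` above. Written out with plain numpy for a single row of distances, the behaviour looks like this (a standalone check, not part of the diff):

```python
import numpy as np

# Mirrors gaussian_function for a single row of distances: a masked neighbour
# (e.g. one reported as -1) gets zero weight, and the row is re-normalised so
# that the remaining weights sum to 1.
sigma = 0.1
distances = np.array([[0.0, 0.05, 0.1]])
mask = np.array([[False, False, True]])

phi = np.exp(-0.5 / sigma**2 * distances**2)
weights = np.where(mask, 0, phi)
weights = weights / np.sum(weights, axis=1, keepdims=True)

assert np.allclose(np.sum(weights, axis=1), 1.0)
```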
17 changes: 15 additions & 2 deletions healpix_convolution/neighbours.py
@@ -1,7 +1,7 @@
import healpy as hp
import numba
import numpy as np
from numba import int8, int32, int64
from numba import int8, int16, int32, int64

try:
import dask.array as da
@@ -138,6 +138,8 @@ def adjust_xyf(cx, cy, cf, nside):
[
(int32, int32, int32, int32, int8[:, :], int32[:], int32[:], int32[:]),
(int64, int64, int64, int32, int8[:, :], int64[:], int64[:], int64[:]),
(int32, int32, int32, int32, int16[:, :], int32[:], int32[:], int32[:]),
(int64, int64, int64, int32, int16[:, :], int64[:], int64[:], int64[:]),
],
"(),(),(),(),(n,m)->(n),(n),(n)",
)
@@ -166,6 +168,17 @@ def _neighbours(cell_ids, *, offsets, nside, indexing_scheme):
return np.where(neighbour_face == -1, -1, n_)


def minimum_dtype(value):
if value < np.iinfo("int8").max:
return "int8"
elif value < np.iinfo("int16").max:
return "int16"
elif value < np.iinfo("int32").max:
return "int32"
elif value < np.iinfo("int64").max:
return "int64"


def neighbours(cell_ids, *, resolution, indexing_scheme, ring=1):
"""determine the neighbours within the nth ring around the center pixel
@@ -188,7 +201,7 @@ def neighbours(cell_ids, *, resolution, indexing_scheme, ring=1):
"rings containing more than the neighbouring base pixels are not supported"
)

offsets = np.asarray(list(generate_offsets(ring=ring)), dtype="int8")
offsets = np.asarray(list(generate_offsets(ring=ring)), dtype=minimum_dtype(ring))

if isinstance(cell_ids, dask_array_type):
n_neighbours = (2 * ring + 1) ** 2
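
`minimum_dtype` returns the name of the smallest integer dtype able to hold the ring radius, and the two new `int16` gufunc signatures above accept the resulting offsets. A quick illustration, assuming `generate_offsets` is importable from the same module as the diff suggests:

```python
import numpy as np

from healpix_convolution.neighbours import generate_offsets, minimum_dtype

# Small rings keep int8 offsets; anything past the int8 range falls back to
# int16, matching the extra gufunc signatures registered above.
assert minimum_dtype(1) == "int8"
assert minimum_dtype(200) == "int16"

offsets = np.asarray(list(generate_offsets(ring=2)), dtype=minimum_dtype(2))
```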
32 changes: 11 additions & 21 deletions healpix_convolution/tests/test_convolution.py
@@ -1,17 +1,22 @@
import hypothesis
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
import pytest
import sparse
from hypothesis import given, settings

from healpix_convolution import convolution


@pytest.fixture
def rolling_mean_kernel():
kernel = (
@given(
data=npst.arrays(
shape=st.sampled_from([(5,), (10, 5)]),
# TODO: figure out how to deal with floating point values
dtype=st.sampled_from(["int16", "int32", "int64"]),
),
)
@settings(deadline=1000)
def test_numpy_convolve(data):
dense_kernel = (
np.array(
[
[1, 1, 0, 0, 1],
@@ -24,22 +29,7 @@ def rolling_mean_kernel():
/ 3
)

return sparse.COO.from_numpy(kernel, fill_value=0)


@given(
data=npst.arrays(
shape=st.sampled_from([(5,), (10, 5)]),
# TODO: figure out how to deal with floating point values
dtype=st.sampled_from(["int16", "int32", "int64"]),
),
)
@settings(
deadline=1000,
suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
)
def test_numpy_convolve(data, rolling_mean_kernel):
kernel = rolling_mean_kernel
kernel = sparse.COO.from_numpy(dense_kernel, fill_value=0)
actual = convolution.convolve(data, kernel)

padding = [(0, 0)] * (data.ndim - 1) + [(1, 1)]
147 changes: 139 additions & 8 deletions healpix_convolution/tests/test_kernels.py
@@ -1,7 +1,10 @@
import numpy as np
import pytest
import xarray as xr
import xdggs

from healpix_convolution import kernels
from healpix_convolution import kernels as np_kernels
from healpix_convolution.xarray import kernels as xr_kernels


@pytest.mark.parametrize(
@@ -24,16 +27,20 @@
),
)
def test_create_sparse(cell_ids, neighbours, weights):
input_cell_ids = np.unique(neighbours)
if input_cell_ids[0] == -1:
input_cell_ids = input_cell_ids[1:]
expected_cell_ids = np.unique(neighbours)
if expected_cell_ids[0] == -1:
expected_cell_ids = expected_cell_ids[1:]

actual = kernels.common.create_sparse(cell_ids, neighbours, weights)
actual_cell_ids, actual = np_kernels.common.create_sparse(
cell_ids, neighbours, weights
)

nnz = np.sum(neighbours != -1, axis=1)
value = nnz * weights[0]

expected_shape = (cell_ids.size, input_cell_ids.size)
np.testing.assert_equal(actual_cell_ids, expected_cell_ids)

expected_shape = (cell_ids.size, expected_cell_ids.size)
assert hasattr(actual, "nnz"), "not a sparse matrix"
assert np.allclose(
np.sum(actual, axis=1).todense(), value
@@ -78,7 +85,7 @@ class TestGaussian:
),
)
def test_gaussian_kernel(self, cell_ids, kwargs):
actual = kernels.gaussian_kernel(cell_ids, **kwargs)
_, actual = np_kernels.gaussian_kernel(cell_ids, **kwargs)

kernel_sum = np.sum(actual, axis=1)

@@ -116,4 +123,128 @@ def test_gaussian_kernel(self, cell_ids, kwargs):
)
def test_gaussian_kernel_errors(self, cell_ids, kwargs, error, pattern):
with pytest.raises(error, match=pattern):
kernels.gaussian_kernel(cell_ids, **kwargs)
np_kernels.gaussian_kernel(cell_ids, **kwargs)


class TestXarray:
@pytest.mark.parametrize(
["obj", "kwargs"],
(
(
xr.DataArray(
[1, 2],
coords={
"cell_ids": (
"cells",
np.array([1, 2]),
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "nested",
},
)
},
dims="cells",
),
{"sigma": 0.1},
),
(
xr.DataArray(
[1, 2],
coords={
"cell_ids": (
"cells",
np.array([1, 2]),
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "ring",
},
)
},
dims="cells",
),
{"sigma": 0.1},
),
(
xr.DataArray(
[0, 2],
coords={
"cell_ids": (
"cells",
np.array([0, 2]),
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "nested",
},
)
},
dims="cells",
),
{"sigma": 0.2},
),
(
xr.DataArray(
[1, 2],
coords={
"cell_ids": (
"cells",
np.array([1, 2]),
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "nested",
},
)
},
dims="cells",
),
{"sigma": 0.1, "kernel_size": 5},
),
(
xr.DataArray(
[0, 3],
coords={
"cell_ids": (
"cells",
np.array([0, 3]),
{
"grid_name": "healpix",
"resolution": 1,
"indexing_scheme": "ring",
},
)
},
dims="cells",
),
{"sigma": 0.1, "kernel_size": 3},
),
(
xr.DataArray(
np.arange(12 * 4**2, dtype="int64"),
coords={
"cell_ids": (
"cells",
np.arange(12 * 4**2, dtype="int64"),
{
"grid_name": "healpix",
"resolution": 2,
"indexing_scheme": "ring",
},
)
},
dims="cells",
),
{"sigma": 0.1, "kernel_size": 3},
),
),
)
def test_gaussian_kernel(self, obj, kwargs):
obj_ = obj.pipe(xdggs.decode)
actual = xr_kernels.gaussian_kernel(obj_, **kwargs)

kernel_sum = actual.sum(dim="input_cells")

assert actual.count() == actual.size
np.testing.assert_allclose(kernel_sum.data.todense(), 1)
3 changes: 3 additions & 0 deletions healpix_convolution/xarray/__init__.py
@@ -0,0 +1,3 @@
from healpix_convolution.xarray import kernels # noqa: F401
from healpix_convolution.xarray.convolution import convolve # noqa: F401
from healpix_convolution.xarray.padding import pad # noqa: F401
67 changes: 67 additions & 0 deletions healpix_convolution/xarray/convolution.py
@@ -0,0 +1,67 @@
import xarray as xr

from healpix_convolution.xarray.padding import pad


def convolve(
ds, kernel, *, dim="cells", mode: str = "constant", constant_value: int | float = 0
):
"""convolve data on a DGGS grid
Parameters
----------
ds : xarray.Dataset
The input data. Must be compatible with `xdggs`.
kernel : xarray.DataArray
The sparse kernel matrix.
dim : str, default: "cells"
The input dimension name. Will also be the output dimension name.
mode : str, default: "constant"
Mode used to pad the array before convolving. Only used for regional maps. See
:py:func:`pad` for a list of available modes.
constant_value : int or float, default: 0
The constant value to pad with when mode is ``"constant"``.
Returns
-------
convolved : xarray.Dataset
The result of the convolution.
See Also
--------
healpix_convolution.xarray.pad
healpix_convolution.pad
"""
if ds.chunksizes:
kernel = kernel.chunk()

def _convolve(arr, weights):
src_dims = ["input_cells"]

if not set(src_dims).issubset(arr.dims):
return arr

return xr.dot(
# drop all input coords, as those would most likely be broadcast
arr.variable,
weights,
# This dimension will be "contracted"
# or summed over after multiplying by the weights
dims=src_dims,
)

if ds.sizes["cells"] != kernel.sizes["input_cells"]:
padder = pad(
ds["cell_ids"],
ring=kernel.attrs["ring"],
mode=mode,
constant_value=constant_value,
)
ds = padder.apply(ds)

return (
ds.rename_dims({dim: "input_cells"})
.map(_convolve, weights=kernel)
.rename_dims({"output_cells": dim})
.rename_vars({"output_cell_ids": "cell_ids"})
)
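
The docstring above describes the padding path for regional maps: when the input covers only part of the sphere, `convolve` pads it first, using the `ring` stored in the kernel's attrs. A hedged sketch under those assumptions (the regional dataset and its values are made up):

```python
import numpy as np
import xarray as xr
import xdggs

import healpix_convolution.xarray as hcx

# Made-up regional map: just the four resolution-1 cells of the first base
# pixel, so the kernel's "input_cells" dimension ends up larger than "cells"
# and convolve is expected to pad before applying the weights.
cell_ids = np.arange(4, dtype="int64")
ds = xr.Dataset(
    {"t": ("cells", np.ones(cell_ids.size))},
    coords={
        "cell_ids": (
            "cells",
            cell_ids,
            {"grid_name": "healpix", "resolution": 1, "indexing_scheme": "nested"},
        )
    },
).pipe(xdggs.decode)

kernel = hcx.kernels.gaussian_kernel(ds["t"], sigma=0.1, kernel_size=3)
smoothed = hcx.convolve(ds, kernel, mode="constant", constant_value=0)
```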
1 change: 1 addition & 0 deletions healpix_convolution/xarray/kernels/__init__.py
@@ -0,0 +1 @@
from healpix_convolution.xarray.kernels.gaussian import gaussian_kernel # noqa: F401