Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prune the kernel weights matrix using a threshold #36

Merged
merged 9 commits into from
Aug 13, 2024
7 changes: 5 additions & 2 deletions healpix_convolution/kernels/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import sparse


def create_sparse(cell_ids, neighbours, weights):
def create_sparse(cell_ids, neighbours, weights, weights_threshold=None):
neighbours_ = np.reshape(neighbours, (-1,))
reshaped_weights = np.reshape(weights, (-1,))

all_cell_ids = np.unique(neighbours_)
if all_cell_ids[0] == -1:
Expand All @@ -20,8 +21,10 @@ def create_sparse(cell_ids, neighbours, weights):
coords = np.stack([row_indices, column_indices], axis=0)

mask = neighbours_ != -1
if weights_threshold is not None:
mask = np.logical_and(mask, np.abs(reshaped_weights) >= weights_threshold)

weights_ = np.reshape(weights, (-1,))[mask]
weights_ = reshaped_weights[mask]
coords_ = coords[..., mask]

if isinstance(weights_, da.Array):
Expand Down
9 changes: 6 additions & 3 deletions healpix_convolution/kernels/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def gaussian_kernel(
sigma: float,
truncate: float = 4.0,
kernel_size: int | None = None,
weights_threshold: float | None = None,
):
"""construct a gaussian kernel on the healpix grid

Expand All @@ -48,11 +49,13 @@ def gaussian_kernel(
indexing_scheme : {"nested", "ring"}
The healpix indexing scheme
sigma : float
The standard deviation of the gaussian kernel
The standard deviation of the gaussian function in radians.
truncate : float, default: 4.0
Truncate the kernel after this many multiples of ``sigma``.
kernel_size : int, optional
If given, determines the size of the kernel. In that case, ``truncate`` is ignored.
If given, will be used instead of ``truncate`` to determine the size of the kernel.
weights_threshold : float, optional
If given, drop all kernel weights whose absolute value is smaller than this threshold.

Returns
-------
Expand All @@ -78,4 +81,4 @@ def gaussian_kernel(
d = angular_distances(nb, resolution=resolution, indexing_scheme=indexing_scheme)
weights = gaussian_function(d, sigma, mask=nb == -1)

return create_sparse(cell_ids, nb, weights)
return create_sparse(cell_ids, nb, weights, weights_threshold)
49 changes: 49 additions & 0 deletions healpix_convolution/tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,55 @@ def test_create_sparse(cell_ids, neighbours, weights):
assert actual.shape == expected_shape


@pytest.mark.parametrize(
["cell_ids", "neighbours", "threshold"],
(
pytest.param(
np.array([0, 1]),
np.array(
[[0, 17, 19, 2, 3, 1, 23, 22, 35], [3, 2, 13, 15, 11, 7, 6, 1, 0]]
),
None,
),
pytest.param(
np.array([0, 1]),
np.array(
[[0, 17, 19, 2, 3, 1, 23, 22, 35], [3, 2, 13, 15, 11, 7, 6, 1, 0]]
),
0.1,
),
pytest.param(
np.array([3, 2]),
np.array(
[[3, 2, 13, 15, 11, 7, 6, 1, 0], [2, 19, -1, 13, 15, 3, 1, 0, 17]]
),
0.1,
),
),
)
def test_create_sparse_threshold(cell_ids, neighbours, threshold):
expected_cell_ids = np.unique(neighbours)
if expected_cell_ids[0] == -1:
expected_cell_ids = expected_cell_ids[1:]

weights = np.reshape(neighbours / np.sum(neighbours, axis=1, keepdims=True), (-1,))

actual_cell_ids, actual = np_kernels.common.create_sparse(
cell_ids, neighbours, weights, weights_threshold=threshold
)

expected_nnz = (
np.sum(abs(weights) >= threshold) if threshold is not None else weights.size
)

np.testing.assert_equal(actual_cell_ids, expected_cell_ids)

expected_shape = (cell_ids.size, expected_cell_ids.size)
assert hasattr(actual, "nnz"), "not a sparse matrix"
assert actual.shape == expected_shape
assert actual.nnz == expected_nnz, "non-zero entries don't match"


def fit_polynomial(x, y, deg):
mask = y > 0
x_ = x[mask]
Expand Down
12 changes: 10 additions & 2 deletions healpix_convolution/xarray/kernels/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def compute_ring(grid_info, sigma, kernel_size, truncate):


def gaussian_kernel(
cell_ids, sigma: float, truncate: float = 4.0, kernel_size: int | None = None
cell_ids,
sigma: float,
truncate: float = 4.0,
kernel_size: int | None = None,
weights_threshold: float | None = None,
):
"""Create a symmetric gaussian kernel for the given cell ids

Expand All @@ -26,9 +30,12 @@ def gaussian_kernel(
sigma : float
The standard deviation of the gaussian function in radians.
truncate : float, default: 4.0
Truncate the kernel after this many multiples of sigma.
Truncate the kernel after this many multiples of ``sigma``.
kernel_size : int, optional
If given, will be used instead of ``truncate`` to determine the size of the kernel.
weights_threshold : float, optional
If given, drop all kernel weights whose absolute value is smaller than this threshold.


Returns
-------
Expand All @@ -44,6 +51,7 @@ def gaussian_kernel(
sigma=sigma,
truncate=truncate,
kernel_size=kernel_size,
weights_threshold=weights_threshold,
)

if kernel_size is not None:
Expand Down