diff --git a/healpix_convolution/kernels/common.py b/healpix_convolution/kernels/common.py index 9f4d548..04bc30d 100644 --- a/healpix_convolution/kernels/common.py +++ b/healpix_convolution/kernels/common.py @@ -4,8 +4,9 @@ import sparse -def create_sparse(cell_ids, neighbours, weights): +def create_sparse(cell_ids, neighbours, weights, weights_threshold=None): neighbours_ = np.reshape(neighbours, (-1,)) + reshaped_weights = np.reshape(weights, (-1,)) all_cell_ids = np.unique(neighbours_) if all_cell_ids[0] == -1: @@ -20,8 +21,10 @@ def create_sparse(cell_ids, neighbours, weights): coords = np.stack([row_indices, column_indices], axis=0) mask = neighbours_ != -1 + if weights_threshold is not None: + mask = np.logical_and(mask, np.abs(reshaped_weights) >= weights_threshold) - weights_ = np.reshape(weights, (-1,))[mask] + weights_ = reshaped_weights[mask] coords_ = coords[..., mask] if isinstance(weights_, da.Array): diff --git a/healpix_convolution/kernels/gaussian.py b/healpix_convolution/kernels/gaussian.py index 03867e3..f77c662 100644 --- a/healpix_convolution/kernels/gaussian.py +++ b/healpix_convolution/kernels/gaussian.py @@ -36,6 +36,7 @@ def gaussian_kernel( sigma: float, truncate: float = 4.0, kernel_size: int | None = None, + weights_threshold: float | None = None, ): """construct a gaussian kernel on the healpix grid @@ -48,11 +49,13 @@ def gaussian_kernel( indexing_scheme : {"nested", "ring"} The healpix indexing scheme sigma : float - The standard deviation of the gaussian kernel + The standard deviation of the gaussian function in radians. truncate : float, default: 4.0 Truncate the kernel after this many multiples of ``sigma``. kernel_size : int, optional - If given, determines the size of the kernel. In that case, ``truncate`` is ignored. + If given, will be used instead of ``truncate`` to determine the size of the kernel. + weights_threshold : float, optional + If given, drop all kernel weights whose absolute value is smaller than this threshold. Returns ------- @@ -78,4 +81,4 @@ def gaussian_kernel( d = angular_distances(nb, resolution=resolution, indexing_scheme=indexing_scheme) weights = gaussian_function(d, sigma, mask=nb == -1) - return create_sparse(cell_ids, nb, weights) + return create_sparse(cell_ids, nb, weights, weights_threshold) diff --git a/healpix_convolution/tests/test_kernels.py b/healpix_convolution/tests/test_kernels.py index b70184c..fa4de35 100644 --- a/healpix_convolution/tests/test_kernels.py +++ b/healpix_convolution/tests/test_kernels.py @@ -48,6 +48,55 @@ def test_create_sparse(cell_ids, neighbours, weights): assert actual.shape == expected_shape +@pytest.mark.parametrize( + ["cell_ids", "neighbours", "threshold"], + ( + pytest.param( + np.array([0, 1]), + np.array( + [[0, 17, 19, 2, 3, 1, 23, 22, 35], [3, 2, 13, 15, 11, 7, 6, 1, 0]] + ), + None, + ), + pytest.param( + np.array([0, 1]), + np.array( + [[0, 17, 19, 2, 3, 1, 23, 22, 35], [3, 2, 13, 15, 11, 7, 6, 1, 0]] + ), + 0.1, + ), + pytest.param( + np.array([3, 2]), + np.array( + [[3, 2, 13, 15, 11, 7, 6, 1, 0], [2, 19, -1, 13, 15, 3, 1, 0, 17]] + ), + 0.1, + ), + ), +) +def test_create_sparse_threshold(cell_ids, neighbours, threshold): + expected_cell_ids = np.unique(neighbours) + if expected_cell_ids[0] == -1: + expected_cell_ids = expected_cell_ids[1:] + + weights = np.reshape(neighbours / np.sum(neighbours, axis=1, keepdims=True), (-1,)) + + actual_cell_ids, actual = np_kernels.common.create_sparse( + cell_ids, neighbours, weights, weights_threshold=threshold + ) + + expected_nnz = ( + np.sum(abs(weights) >= threshold) if threshold is not None else weights.size + ) + + np.testing.assert_equal(actual_cell_ids, expected_cell_ids) + + expected_shape = (cell_ids.size, expected_cell_ids.size) + assert hasattr(actual, "nnz"), "not a sparse matrix" + assert actual.shape == expected_shape + assert actual.nnz == expected_nnz, "non-zero entries don't match" + + def fit_polynomial(x, y, deg): mask = y > 0 x_ = x[mask] diff --git a/healpix_convolution/xarray/kernels/gaussian.py b/healpix_convolution/xarray/kernels/gaussian.py index f58ddb8..e50f252 100644 --- a/healpix_convolution/xarray/kernels/gaussian.py +++ b/healpix_convolution/xarray/kernels/gaussian.py @@ -15,7 +15,11 @@ def compute_ring(grid_info, sigma, kernel_size, truncate): def gaussian_kernel( - cell_ids, sigma: float, truncate: float = 4.0, kernel_size: int | None = None + cell_ids, + sigma: float, + truncate: float = 4.0, + kernel_size: int | None = None, + weights_threshold: float | None = None, ): """Create a symmetric gaussian kernel for the given cell ids @@ -26,9 +30,12 @@ def gaussian_kernel( sigma : float The standard deviation of the gaussian function in radians. truncate : float, default: 4.0 - Truncate the kernel after this many multiples of sigma. + Truncate the kernel after this many multiples of ``sigma``. kernel_size : int, optional If given, will be used instead of ``truncate`` to determine the size of the kernel. + weights_threshold : float, optional + If given, drop all kernel weights whose absolute value is smaller than this threshold. + Returns ------- @@ -44,6 +51,7 @@ def gaussian_kernel( sigma=sigma, truncate=truncate, kernel_size=kernel_size, + weights_threshold=weights_threshold, ) if kernel_size is not None: