forked from scverse/spatialdata
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'scverse:main' into main
- Loading branch information
Showing
8 changed files
with
317 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from functools import singledispatch | ||
|
||
import dask.array as da | ||
import pandas as pd | ||
import xarray as xr | ||
from dask.dataframe.core import DataFrame as DaskDataFrame | ||
from geopandas import GeoDataFrame | ||
from multiscale_spatial_image import MultiscaleSpatialImage | ||
from shapely import MultiPolygon, Point, Polygon | ||
from spatial_image import SpatialImage | ||
|
||
from spatialdata._core.operations.transform import transform | ||
from spatialdata.models import get_axes_names | ||
from spatialdata.models._utils import SpatialElement | ||
from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model | ||
from spatialdata.transformations.operations import get_transformation | ||
from spatialdata.transformations.transformations import BaseTransformation | ||
|
||
BoundingBoxDescription = dict[str, tuple[float, float]] | ||
|
||
|
||
def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: | ||
d = get_transformation(e, get_all=True) | ||
assert isinstance(d, dict) | ||
assert coordinate_system in d, ( | ||
f"No transformation to coordinate system {coordinate_system} is available for the given element.\n" | ||
f"Available coordinate systems: {list(d.keys())}" | ||
) | ||
|
||
|
||
@singledispatch | ||
def get_centroids( | ||
e: SpatialElement, | ||
coordinate_system: str = "global", | ||
) -> DaskDataFrame: | ||
""" | ||
Get the centroids of the geometries contained in a SpatialElement, as a new Points element. | ||
Parameters | ||
---------- | ||
e | ||
The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. | ||
coordinate_system | ||
The coordinate system in which the centroids are computed. | ||
Notes | ||
----- | ||
For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute | ||
each :class:`~shapely.Multipolygon`. | ||
""" | ||
raise ValueError(f"The object type {type(e)} is not supported.") | ||
|
||
|
||
def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame: | ||
""" | ||
Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates. | ||
Parameters | ||
---------- | ||
xdata | ||
The xarray DataArray containing the labels. | ||
axis | ||
The axis for which the centroids are computed. | ||
Returns | ||
------- | ||
pd.DataFrame | ||
A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis. | ||
The index of the DataFrame is the collection of label values, sorted ascendingly. | ||
""" | ||
centroids: dict[int, float] = defaultdict(float) | ||
for i in xdata[axis]: | ||
portion = xdata.sel(**{axis: i}).data | ||
u = da.unique(portion, return_counts=True) | ||
labels_values = u[0].compute() | ||
counts = u[1].compute() | ||
for j in range(len(labels_values)): | ||
label_value = labels_values[j] | ||
count = counts[j] | ||
centroids[label_value] += count * i.values.item() | ||
|
||
all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True) | ||
all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute())) | ||
for label_value in centroids: | ||
centroids[label_value] /= all_labels[label_value] | ||
centroids = dict(sorted(centroids.items(), key=lambda x: x[0])) | ||
return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys())) | ||
|
||
|
||
@get_centroids.register(SpatialImage) | ||
@get_centroids.register(MultiscaleSpatialImage) | ||
def _( | ||
e: SpatialImage | MultiscaleSpatialImage, | ||
coordinate_system: str = "global", | ||
) -> DaskDataFrame: | ||
"""Get the centroids of a Labels element (2D or 3D).""" | ||
model = get_model(e) | ||
if model in [Image2DModel, Image3DModel]: | ||
raise ValueError("Cannot compute centroids for images.") | ||
assert model in [Labels2DModel, Labels3DModel] | ||
_validate_coordinate_system(e, coordinate_system) | ||
|
||
if isinstance(e, MultiscaleSpatialImage): | ||
assert len(e["scale0"]) == 1 | ||
e = SpatialImage(next(iter(e["scale0"].values()))) | ||
|
||
dfs = [] | ||
for axis in get_axes_names(e): | ||
dfs.append(_get_centroids_for_axis(e, axis)) | ||
df = pd.concat(dfs, axis=1) | ||
t = get_transformation(e, coordinate_system) | ||
centroids = PointsModel.parse(df, transformations={coordinate_system: t}) | ||
return transform(centroids, to_coordinate_system=coordinate_system) | ||
|
||
|
||
@get_centroids.register(GeoDataFrame) | ||
def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: | ||
"""Get the centroids of a Shapes element (circles or polygons/multipolygons).""" | ||
_validate_coordinate_system(e, coordinate_system) | ||
t = get_transformation(e, coordinate_system) | ||
assert isinstance(t, BaseTransformation) | ||
# separate points from (multi-)polygons | ||
first_geometry = e["geometry"].iloc[0] | ||
if isinstance(first_geometry, Point): | ||
xy = e.geometry.get_coordinates().values | ||
else: | ||
assert isinstance(first_geometry, (Polygon, MultiPolygon)), ( | ||
f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or" | ||
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead." | ||
) | ||
xy = e.centroid.get_coordinates().values | ||
points = PointsModel.parse(xy, transformations={coordinate_system: t}) | ||
return transform(points, to_coordinate_system=coordinate_system) | ||
|
||
|
||
@get_centroids.register(DaskDataFrame) | ||
def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: | ||
"""Get the centroids of a Points element.""" | ||
_validate_coordinate_system(e, coordinate_system) | ||
axes = get_axes_names(e) | ||
assert axes in [("x", "y"), ("x", "y", "z")] | ||
coords = e[list(axes)].compute().values | ||
t = get_transformation(e, coordinate_system) | ||
assert isinstance(t, BaseTransformation) | ||
centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) | ||
return transform(centroids, to_coordinate_system=coordinate_system) | ||
|
||
|
||
## |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from anndata import AnnData | ||
from numpy.random import default_rng | ||
from spatialdata._core.centroids import get_centroids | ||
from spatialdata.models import Labels2DModel, Labels3DModel, TableModel, get_axes_names | ||
from spatialdata.transformations import Identity, get_transformation, set_transformation | ||
|
||
from tests.core.operations.test_transform import _get_affine | ||
|
||
RNG = default_rng(42) | ||
|
||
affine = _get_affine() | ||
|
||
|
||
@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) | ||
@pytest.mark.parametrize("is_3d", [False, True]) | ||
def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): | ||
element = points["points_0"] | ||
|
||
# by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the | ||
# test arguments | ||
if coordinate_system == "aligned": | ||
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) | ||
if is_3d: | ||
element["z"] = element["x"] | ||
|
||
axes = get_axes_names(element) | ||
centroids = get_centroids(element, coordinate_system=coordinate_system) | ||
|
||
# the axes of the centroids should be the same as the axes of the element | ||
assert centroids.columns.tolist() == list(axes) | ||
|
||
# the centroids should not contain extra columns | ||
assert "genes" in element.columns and "genes" not in centroids.columns | ||
|
||
# the centroids transformation to the target coordinate system should be an Identity because the transformation has | ||
# already been applied | ||
assert get_transformation(centroids, to_coordinate_system=coordinate_system) == Identity() | ||
|
||
# let's check the values | ||
if coordinate_system == "global": | ||
assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values) | ||
else: | ||
matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes) | ||
centroids_untransformed = element[list(axes)].compute().values | ||
n = len(axes) | ||
centroids_transformed = np.dot(centroids_untransformed, matrix[:n, :n].T) + matrix[:n, n] | ||
assert np.allclose(centroids.compute().values, centroids_transformed) | ||
|
||
|
||
@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) | ||
@pytest.mark.parametrize("shapes_name", ["circles", "poly", "multipoly"]) | ||
def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str): | ||
element = shapes[shapes_name] | ||
if coordinate_system == "aligned": | ||
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) | ||
centroids = get_centroids(element, coordinate_system=coordinate_system) | ||
|
||
if shapes_name == "circles": | ||
xy = element.geometry.get_coordinates().values | ||
else: | ||
assert shapes_name in ["poly", "multipoly"] | ||
xy = element.geometry.centroid.get_coordinates().values | ||
|
||
if coordinate_system == "global": | ||
assert np.array_equal(centroids.compute().values, xy) | ||
else: | ||
matrix = affine.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) | ||
centroids_transformed = np.dot(xy, matrix[:2, :2].T) + matrix[:2, 2] | ||
assert np.allclose(centroids.compute().values, centroids_transformed) | ||
|
||
|
||
@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) | ||
@pytest.mark.parametrize("is_multiscale", [False, True]) | ||
@pytest.mark.parametrize("is_3d", [False, True]) | ||
def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: bool, is_3d: bool): | ||
scale_factors = [2] if is_multiscale else None | ||
if is_3d: | ||
model = Labels3DModel | ||
array = np.array( | ||
[ | ||
[ | ||
[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
], | ||
[ | ||
[2, 2, 1, 1], | ||
[2, 2, 1, 1], | ||
], | ||
] | ||
) | ||
expected_centroids = pd.DataFrame( | ||
{ | ||
"x": [1, 3, 1], | ||
"y": [1, 1.0, 1], | ||
"z": [0.5, 1, 1.5], | ||
}, | ||
index=[0, 1, 2], | ||
) | ||
else: | ||
array = np.array( | ||
[ | ||
[1, 1, 1, 1], | ||
[2, 2, 2, 2], | ||
[2, 2, 2, 2], | ||
[2, 2, 2, 2], | ||
] | ||
) | ||
model = Labels2DModel | ||
expected_centroids = pd.DataFrame( | ||
{ | ||
"x": [2, 2], | ||
"y": [0.5, 2.5], | ||
}, | ||
index=[1, 2], | ||
) | ||
element = model.parse(array, scale_factors=scale_factors) | ||
|
||
if coordinate_system == "aligned": | ||
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) | ||
centroids = get_centroids(element, coordinate_system=coordinate_system) | ||
|
||
if coordinate_system == "global": | ||
assert np.array_equal(centroids.compute().values, expected_centroids.values) | ||
else: | ||
axes = get_axes_names(element) | ||
n = len(axes) | ||
# the axes from the labels have 'x' last, but we want it first to manually transform the points, so we sort | ||
matrix = affine.to_affine_matrix(input_axes=sorted(axes), output_axes=sorted(axes)) | ||
centroids_transformed = np.dot(expected_centroids.values, matrix[:n, :n].T) + matrix[:n, n] | ||
assert np.allclose(centroids.compute().values, centroids_transformed) | ||
|
||
|
||
def test_get_centroids_invalid_element(images): | ||
# cannot compute centroids for images | ||
with pytest.raises(ValueError, match="Cannot compute centroids for images."): | ||
get_centroids(images["image2d"]) | ||
|
||
# cannot compute centroids for tables | ||
N = 10 | ||
adata = TableModel.parse( | ||
AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}), | ||
region="dummy", | ||
region_key="region", | ||
instance_key="instance_id", | ||
) | ||
with pytest.raises(ValueError, match="The object type <class 'anndata._core.anndata.AnnData'> is not supported."): | ||
get_centroids(adata) | ||
|
||
|
||
def test_get_centroids_invalid_coordinate_system(points): | ||
with pytest.raises(AssertionError, match="No transformation to coordinate system"): | ||
get_centroids(points["points_0"], coordinate_system="invalid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters