Skip to content

Commit

Permalink
Merge branch 'scverse:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau authored Mar 14, 2024
2 parents 4c150c5 + cdb3d45 commit 7bb87ba
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning][].
- added utils function: are_extents_equal()
- added utils function: postpone_transformation()
- added utils function: remove_transformations_to_coordinate_system()
- added utils function: get_centroids()

### Minor

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Operations on `SpatialData` objects.
polygon_query
get_values
get_extent
get_centroids
match_table_to_element
concatenate
transform
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"xarray-spatial>=0.3.5",
"tqdm",
"fsspec<=2023.6",
"dask<=2024.2.1"
]

[project.optional-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"match_table_to_element",
"SpatialData",
"get_extent",
"get_centroids",
"read_zarr",
"unpad_raster",
"save_transformations",
Expand All @@ -33,6 +34,7 @@
]

from spatialdata import dataloader, models, transformations
from spatialdata._core.centroids import get_centroids
from spatialdata._core.concatenate import concatenate
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations.aggregate import aggregate
Expand Down
152 changes: 152 additions & 0 deletions src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from __future__ import annotations

from collections import defaultdict
from functools import singledispatch

import dask.array as da
import pandas as pd
import xarray as xr
from dask.dataframe.core import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from multiscale_spatial_image import MultiscaleSpatialImage
from shapely import MultiPolygon, Point, Polygon
from spatial_image import SpatialImage

from spatialdata._core.operations.transform import transform
from spatialdata.models import get_axes_names
from spatialdata.models._utils import SpatialElement
from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import BaseTransformation

BoundingBoxDescription = dict[str, tuple[float, float]]


def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None:
d = get_transformation(e, get_all=True)
assert isinstance(d, dict)
assert coordinate_system in d, (
f"No transformation to coordinate system {coordinate_system} is available for the given element.\n"
f"Available coordinate systems: {list(d.keys())}"
)


@singledispatch
def get_centroids(
e: SpatialElement,
coordinate_system: str = "global",
) -> DaskDataFrame:
"""
Get the centroids of the geometries contained in a SpatialElement, as a new Points element.
Parameters
----------
e
The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported.
coordinate_system
The coordinate system in which the centroids are computed.
Notes
-----
For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute
each :class:`~shapely.Multipolygon`.
"""
raise ValueError(f"The object type {type(e)} is not supported.")


def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
"""
Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates.
Parameters
----------
xdata
The xarray DataArray containing the labels.
axis
The axis for which the centroids are computed.
Returns
-------
pd.DataFrame
A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis.
The index of the DataFrame is the collection of label values, sorted ascendingly.
"""
centroids: dict[int, float] = defaultdict(float)
for i in xdata[axis]:
portion = xdata.sel(**{axis: i}).data
u = da.unique(portion, return_counts=True)
labels_values = u[0].compute()
counts = u[1].compute()
for j in range(len(labels_values)):
label_value = labels_values[j]
count = counts[j]
centroids[label_value] += count * i.values.item()

all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True)
all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute()))
for label_value in centroids:
centroids[label_value] /= all_labels[label_value]
centroids = dict(sorted(centroids.items(), key=lambda x: x[0]))
return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys()))


@get_centroids.register(SpatialImage)
@get_centroids.register(MultiscaleSpatialImage)
def _(
e: SpatialImage | MultiscaleSpatialImage,
coordinate_system: str = "global",
) -> DaskDataFrame:
"""Get the centroids of a Labels element (2D or 3D)."""
model = get_model(e)
if model in [Image2DModel, Image3DModel]:
raise ValueError("Cannot compute centroids for images.")
assert model in [Labels2DModel, Labels3DModel]
_validate_coordinate_system(e, coordinate_system)

if isinstance(e, MultiscaleSpatialImage):
assert len(e["scale0"]) == 1
e = SpatialImage(next(iter(e["scale0"].values())))

dfs = []
for axis in get_axes_names(e):
dfs.append(_get_centroids_for_axis(e, axis))
df = pd.concat(dfs, axis=1)
t = get_transformation(e, coordinate_system)
centroids = PointsModel.parse(df, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)


@get_centroids.register(GeoDataFrame)
def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Shapes element (circles or polygons/multipolygons)."""
_validate_coordinate_system(e, coordinate_system)
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
# separate points from (multi-)polygons
first_geometry = e["geometry"].iloc[0]
if isinstance(first_geometry, Point):
xy = e.geometry.get_coordinates().values
else:
assert isinstance(first_geometry, (Polygon, MultiPolygon)), (
f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
)
xy = e.centroid.get_coordinates().values
points = PointsModel.parse(xy, transformations={coordinate_system: t})
return transform(points, to_coordinate_system=coordinate_system)


@get_centroids.register(DaskDataFrame)
def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Points element."""
_validate_coordinate_system(e, coordinate_system)
axes = get_axes_names(e)
assert axes in [("x", "y"), ("x", "y", "z")]
coords = e[list(axes)].compute().values
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
centroids = PointsModel.parse(coords, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)


##
4 changes: 4 additions & 0 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _(
transformation,
raster_translation=raster_translation,
maintain_positioning=maintain_positioning,
to_coordinate_system=to_coordinate_system,
)
transformed_data = compute_coordinates(transformed_data)
schema().validate(transformed_data)
Expand Down Expand Up @@ -404,6 +405,7 @@ def _(
transformation,
raster_translation=raster_translation,
maintain_positioning=maintain_positioning,
to_coordinate_system=to_coordinate_system,
)
transformed_data = compute_coordinates(transformed_data)
schema().validate(transformed_data)
Expand Down Expand Up @@ -447,6 +449,7 @@ def _(
transformation,
raster_translation=None,
maintain_positioning=maintain_positioning,
to_coordinate_system=to_coordinate_system,
)
PointsModel.validate(transformed)
return transformed
Expand Down Expand Up @@ -490,6 +493,7 @@ def _(
transformation,
raster_translation=None,
maintain_positioning=maintain_positioning,
to_coordinate_system=to_coordinate_system,
)
ShapesModel.validate(transformed_data)
return transformed_data
155 changes: 155 additions & 0 deletions tests/core/test_centroids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from numpy.random import default_rng
from spatialdata._core.centroids import get_centroids
from spatialdata.models import Labels2DModel, Labels3DModel, TableModel, get_axes_names
from spatialdata.transformations import Identity, get_transformation, set_transformation

from tests.core.operations.test_transform import _get_affine

RNG = default_rng(42)

affine = _get_affine()


@pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
@pytest.mark.parametrize("is_3d", [False, True])
def test_get_centroids_points(points, coordinate_system: str, is_3d: bool):
element = points["points_0"]

# by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the
# test arguments
if coordinate_system == "aligned":
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
if is_3d:
element["z"] = element["x"]

axes = get_axes_names(element)
centroids = get_centroids(element, coordinate_system=coordinate_system)

# the axes of the centroids should be the same as the axes of the element
assert centroids.columns.tolist() == list(axes)

# the centroids should not contain extra columns
assert "genes" in element.columns and "genes" not in centroids.columns

# the centroids transformation to the target coordinate system should be an Identity because the transformation has
# already been applied
assert get_transformation(centroids, to_coordinate_system=coordinate_system) == Identity()

# let's check the values
if coordinate_system == "global":
assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values)
else:
matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes)
centroids_untransformed = element[list(axes)].compute().values
n = len(axes)
centroids_transformed = np.dot(centroids_untransformed, matrix[:n, :n].T) + matrix[:n, n]
assert np.allclose(centroids.compute().values, centroids_transformed)


@pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
@pytest.mark.parametrize("shapes_name", ["circles", "poly", "multipoly"])
def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str):
element = shapes[shapes_name]
if coordinate_system == "aligned":
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system)

if shapes_name == "circles":
xy = element.geometry.get_coordinates().values
else:
assert shapes_name in ["poly", "multipoly"]
xy = element.geometry.centroid.get_coordinates().values

if coordinate_system == "global":
assert np.array_equal(centroids.compute().values, xy)
else:
matrix = affine.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
centroids_transformed = np.dot(xy, matrix[:2, :2].T) + matrix[:2, 2]
assert np.allclose(centroids.compute().values, centroids_transformed)


@pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
@pytest.mark.parametrize("is_multiscale", [False, True])
@pytest.mark.parametrize("is_3d", [False, True])
def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: bool, is_3d: bool):
scale_factors = [2] if is_multiscale else None
if is_3d:
model = Labels3DModel
array = np.array(
[
[
[0, 0, 1, 1],
[0, 0, 1, 1],
],
[
[2, 2, 1, 1],
[2, 2, 1, 1],
],
]
)
expected_centroids = pd.DataFrame(
{
"x": [1, 3, 1],
"y": [1, 1.0, 1],
"z": [0.5, 1, 1.5],
},
index=[0, 1, 2],
)
else:
array = np.array(
[
[1, 1, 1, 1],
[2, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 2, 2],
]
)
model = Labels2DModel
expected_centroids = pd.DataFrame(
{
"x": [2, 2],
"y": [0.5, 2.5],
},
index=[1, 2],
)
element = model.parse(array, scale_factors=scale_factors)

if coordinate_system == "aligned":
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system)

if coordinate_system == "global":
assert np.array_equal(centroids.compute().values, expected_centroids.values)
else:
axes = get_axes_names(element)
n = len(axes)
# the axes from the labels have 'x' last, but we want it first to manually transform the points, so we sort
matrix = affine.to_affine_matrix(input_axes=sorted(axes), output_axes=sorted(axes))
centroids_transformed = np.dot(expected_centroids.values, matrix[:n, :n].T) + matrix[:n, n]
assert np.allclose(centroids.compute().values, centroids_transformed)


def test_get_centroids_invalid_element(images):
# cannot compute centroids for images
with pytest.raises(ValueError, match="Cannot compute centroids for images."):
get_centroids(images["image2d"])

# cannot compute centroids for tables
N = 10
adata = TableModel.parse(
AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}),
region="dummy",
region_key="region",
instance_key="instance_id",
)
with pytest.raises(ValueError, match="The object type <class 'anndata._core.anndata.AnnData'> is not supported."):
get_centroids(adata)


def test_get_centroids_invalid_coordinate_system(points):
with pytest.raises(AssertionError, match="No transformation to coordinate system"):
get_centroids(points["points_0"], coordinate_system="invalid")
2 changes: 1 addition & 1 deletion tests/core/test_data_extent.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_get_extent_affine_circles():
gdf = ShapesModel.parse(gdf, transformations={"transformed": affine})
transformed_bounding_box = transform(gdf, to_coordinate_system="transformed")

transformed_bounding_box_extent = get_extent(transformed_bounding_box)
transformed_bounding_box_extent = get_extent(transformed_bounding_box, coordinate_system="transformed")

assert transformed_axes == list(transformed_bounding_box_extent.keys())
for ax in transformed_axes:
Expand Down

0 comments on commit 7bb87ba

Please sign in to comment.