Skip to content

Commit

Permalink
Flexible coordinate transform (pydata#9543)
Browse files Browse the repository at this point in the history
* Add coordinate transform classes from prototype

* lint, public API and docstrings

* missing import

* sel: convert inverse transform results to ints

* sel: add todo note about rounding decimal pos

* rename create_coordinates -> create_coords

More consistent with the rest of Xarray API where `coords` is used
everywhere.

* add a Coordinates.from_transform convenient method

* fix repr (extract subset values of any n-d array)

* Apply suggestions from code review

Co-authored-by: Max Jones <[email protected]>

* remove specific create coordinates methods

In favor of the more generic `Coordinates.from_xindex()`.

* fix more typing issues

* remove public imports: not ready yet for public use

* add experimental notice in docstrings

* add coordinate transform tests

* typing fixes

* update what's new

---------

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Max Jones <[email protected]>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent 84e81bc commit 4bbab48
Show file tree
Hide file tree
Showing 7 changed files with 579 additions and 2 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ New Features
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`)
By `Justus Magin <https://github.com/keewis>`_.
- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`)
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
84 changes: 84 additions & 0 deletions xarray/core/coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from collections.abc import Hashable, Iterable, Mapping
from typing import Any

import numpy as np


class CoordinateTransform:
"""Abstract coordinate transform with dimension & coordinate names.
EXPERIMENTAL (not ready for public use yet).
"""

coord_names: tuple[Hashable, ...]
dims: tuple[str, ...]
dim_size: dict[str, int]
dtype: Any

def __init__(
self,
coord_names: Iterable[Hashable],
dim_size: Mapping[str, int],
dtype: Any = None,
):
self.coord_names = tuple(coord_names)
self.dims = tuple(dim_size)
self.dim_size = dict(dim_size)

if dtype is None:
dtype = np.dtype(np.float64)
self.dtype = dtype

def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
"""Perform grid -> world coordinate transformation.
Parameters
----------
dim_positions : dict
Grid location(s) along each dimension (axis).
Returns
-------
coord_labels : dict
World coordinate labels.
"""
# TODO: cache the results in order to avoid re-computing
# all labels when accessing the values of each coordinate one at a time
raise NotImplementedError

def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
"""Perform world -> grid coordinate reverse transformation.
Parameters
----------
labels : dict
World coordinate labels.
Returns
-------
dim_positions : dict
Grid relative location(s) along each dimension (axis).
"""
raise NotImplementedError

def equals(self, other: "CoordinateTransform") -> bool:
"""Check equality with another CoordinateTransform of the same kind."""
raise NotImplementedError

def generate_coords(
self, dims: tuple[str, ...] | None = None
) -> dict[Hashable, Any]:
"""Compute all coordinate labels at once."""
if dims is None:
dims = self.dims

positions = np.meshgrid(
*[np.arange(self.dim_size[d]) for d in dims],
indexing="ij",
)
dim_positions = {dim: positions[i] for i, dim in enumerate(dims)}

return self.forward(dim_positions)
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def from_xindex(cls, index: Index) -> Self:
def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self:
"""Wrap a pandas multi-index as Xarray coordinates (dimension + levels).
The returned coordinates can be directly assigned to a
The returned coordinate variables can be directly assigned to a
:py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the
``coords`` argument of their constructor.
Expand Down
121 changes: 121 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import pandas as pd

from xarray.core import formatting, nputils, utils
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.indexing import (
CoordinateTransformIndexingAdapter,
IndexSelResult,
PandasIndexingAdapter,
PandasMultiIndexingAdapter,
Expand Down Expand Up @@ -1377,6 +1379,125 @@ def rename(self, name_dict, dims_dict):
)


class CoordinateTransformIndex(Index):
"""Helper class for creating Xarray indexes based on coordinate transforms.
EXPERIMENTAL (not ready for public use yet).
- wraps a :py:class:`CoordinateTransform` instance
- takes care of creating the index (lazy) coordinates
- supports point-wise label-based selection
- supports exact alignment only, by comparing indexes based on their transform
(not on their explicit coordinate labels)
"""

transform: CoordinateTransform

def __init__(
self,
transform: CoordinateTransform,
):
self.transform = transform

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
) -> IndexVars:
from xarray.core.variable import Variable

new_variables = {}

for name in self.transform.coord_names:
# copy attributes, if any
attrs: Mapping[Hashable, Any] | None

if variables is not None and name in variables:
var = variables[name]
attrs = var.attrs
else:
attrs = None

data = CoordinateTransformIndexingAdapter(self.transform, name)
new_variables[name] = Variable(self.transform.dims, data, attrs=attrs)

return new_variables

def isel(
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
) -> Self | None:
# TODO: support returning a new index (e.g., possible to re-calculate the
# the transform or calculate another transform on a reduced dimension space)
return None

def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
) -> IndexSelResult:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable

if method != "nearest":
raise ValueError(
"CoordinateTransformIndex only supports selection with method='nearest'"
)

labels_set = set(labels)
coord_names_set = set(self.transform.coord_names)

missing_labels = coord_names_set - labels_set
if missing_labels:
missing_labels_str = ",".join([f"{name}" for name in missing_labels])
raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")

label0_obj = next(iter(labels.values()))
dim_size0 = getattr(label0_obj, "sizes", {})

is_xr_obj = [
isinstance(label, DataArray | Variable) for label in labels.values()
]
if not all(is_xr_obj):
raise TypeError(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
"with either xarray.DataArray or xarray.Variable objects."
)
dim_size = [getattr(label, "sizes", {}) for label in labels.values()]
if any(ds != dim_size0 for ds in dim_size):
raise ValueError(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
"with xarray.DataArray or xarray.Variable objects of macthing dimensions."
)

coord_labels = {
name: labels[name].values for name in self.transform.coord_names
}
dim_positions = self.transform.reverse(coord_labels)

results: dict[str, Variable | DataArray] = {}
dims0 = tuple(dim_size0)
for dim, pos in dim_positions.items():
# TODO: rounding the decimal positions is not always the behavior we expect
# (there are different ways to represent implicit intervals)
# we should probably make this customizable.
pos = np.round(pos).astype("int")
if isinstance(label0_obj, Variable):
results[dim] = Variable(dims0, pos)
else:
# dataarray
results[dim] = DataArray(pos, dims=dims0)

return IndexSelResult(results)

def equals(self, other: Self) -> bool:
return self.transform.equals(other.transform)

def rename(
self,
name_dict: Mapping[Any, Hashable],
dims_dict: Mapping[Any, Hashable],
) -> Self:
# TODO: maybe update self.transform coord_names, dim_size and dims attributes
return self


def create_default_index_implicit(
dim_variable: Variable,
all_variables: Mapping | Iterable[Hashable] | None = None,
Expand Down
Loading

0 comments on commit 4bbab48

Please sign in to comment.