forked from pydata/xarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Flexible coordinate transform (pydata#9543)
* Add coordinate transform classes from prototype * lint, public API and docstrings * missing import * sel: convert inverse transform results to ints * sel: add todo note about rounding decimal pos * rename create_coordinates -> create_coords More consistent with the rest of Xarray API where `coords` is used everywhere. * add a Coordinates.from_transform convenient method * fix repr (extract subset values of any n-d array) * Apply suggestions from code review Co-authored-by: Max Jones <[email protected]> * remove specific create coordinates methods In favor of the more generic `Coordinates.from_xindex()`. * fix more typing issues * remove public imports: not ready yet for public use * add experimental notice in docstrings * add coordinate transform tests * typing fixes * update what's new --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Max Jones <[email protected]>
- Loading branch information
1 parent
84e81bc
commit 4bbab48
Showing
7 changed files
with
579 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from collections.abc import Hashable, Iterable, Mapping | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
|
||
class CoordinateTransform: | ||
"""Abstract coordinate transform with dimension & coordinate names. | ||
EXPERIMENTAL (not ready for public use yet). | ||
""" | ||
|
||
coord_names: tuple[Hashable, ...] | ||
dims: tuple[str, ...] | ||
dim_size: dict[str, int] | ||
dtype: Any | ||
|
||
def __init__( | ||
self, | ||
coord_names: Iterable[Hashable], | ||
dim_size: Mapping[str, int], | ||
dtype: Any = None, | ||
): | ||
self.coord_names = tuple(coord_names) | ||
self.dims = tuple(dim_size) | ||
self.dim_size = dict(dim_size) | ||
|
||
if dtype is None: | ||
dtype = np.dtype(np.float64) | ||
self.dtype = dtype | ||
|
||
def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: | ||
"""Perform grid -> world coordinate transformation. | ||
Parameters | ||
---------- | ||
dim_positions : dict | ||
Grid location(s) along each dimension (axis). | ||
Returns | ||
------- | ||
coord_labels : dict | ||
World coordinate labels. | ||
""" | ||
# TODO: cache the results in order to avoid re-computing | ||
# all labels when accessing the values of each coordinate one at a time | ||
raise NotImplementedError | ||
|
||
def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: | ||
"""Perform world -> grid coordinate reverse transformation. | ||
Parameters | ||
---------- | ||
labels : dict | ||
World coordinate labels. | ||
Returns | ||
------- | ||
dim_positions : dict | ||
Grid relative location(s) along each dimension (axis). | ||
""" | ||
raise NotImplementedError | ||
|
||
def equals(self, other: "CoordinateTransform") -> bool: | ||
"""Check equality with another CoordinateTransform of the same kind.""" | ||
raise NotImplementedError | ||
|
||
def generate_coords( | ||
self, dims: tuple[str, ...] | None = None | ||
) -> dict[Hashable, Any]: | ||
"""Compute all coordinate labels at once.""" | ||
if dims is None: | ||
dims = self.dims | ||
|
||
positions = np.meshgrid( | ||
*[np.arange(self.dim_size[d]) for d in dims], | ||
indexing="ij", | ||
) | ||
dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} | ||
|
||
return self.forward(dim_positions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.