From 6b013529c2580e8183c48ca91dc5658c8d88186e Mon Sep 17 00:00:00 2001
From: Bouwe Andela
Date: Tue, 23 Apr 2024 15:52:42 +0200
Subject: [PATCH] First attempt at parallel concatenate

---
 lib/iris/_concatenate.py | 174 ++++++++++++++++++++++++++-------------
 lib/iris/_lazy_data.py   |  11 ++-
 2 files changed, 125 insertions(+), 60 deletions(-)

diff --git a/lib/iris/_concatenate.py b/lib/iris/_concatenate.py
index be953f3437..824ec66406 100644
--- a/lib/iris/_concatenate.py
+++ b/lib/iris/_concatenate.py
@@ -5,11 +5,16 @@
 """Automatic concatenation of multiple cubes over one or more existing dimensions."""
 
 from collections import namedtuple
+import itertools
 import warnings
 
+import dask
 import dask.array as da
+from dask.base import tokenize
 import numpy as np
+from xxhash import xxh3_64
 
+from iris._lazy_data import is_masked_data
 import iris.coords
 import iris.cube
 import iris.exceptions
@@ -34,6 +39,35 @@
 _INCREASING = 1
 
 
+def hash_array(a: da.Array | np.ndarray) -> da.Array:
+    def arrayhash(x):
+        value = xxh3_64(x.data.tobytes())
+        if is_masked_data(x):
+            value.update(x.mask.tobytes())
+        return np.frombuffer(value.digest(), dtype=np.int64)
+
+    return da.reduction(
+        a,
+        chunk=lambda x, axis, keepdims: arrayhash(x).reshape((1,) * a.ndim),
+        combine=lambda x, axis, keepdims: arrayhash(x).reshape((1,) * a.ndim),
+        aggregate=lambda x, axis, keepdims: arrayhash(x)[0],
+        keepdims=False,
+        meta=np.empty(tuple(), dtype=np.int64),
+        dtype=np.int64,
+    )
+
+
+class ArrayHash:
+    def __init__(self, value: np.int64, chunks: tuple) -> None:
+        self.value = value
+        self.chunks = chunks
+
+    def __eq__(self, other: "ArrayHash") -> bool:
+        if self.chunks != other.chunks:
+            raise ValueError("Unable to compare arrays with different chunks.")
+        return self.value == other.value
+
+
 class _CoordAndDims(namedtuple("CoordAndDims", ["coord", "dims"])):
     """Container for a coordinate and the associated data dimension(s).
 
@@ -332,6 +366,49 @@ def concatenate(
     axis = None
 
     # Register each cube with its appropriate proto-cube.
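+    # Pre-compute a single 64-bit hash per auxiliary array: collecting all
+    # the arrays into one dask graph and evaluating it with a single
+    # dask.compute() call lets dask parallelise the work, instead of
+    # comparing coordinate arrays pairwise (and repeatedly) during the
+    # registration loop below.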
+    arrays = []
+
+    # 1 collect list of arrays
+    for cube in cubes:
+        if check_aux_coords:
+            for coord in cube.aux_coords:
+                arrays.append(coord.core_points())
+                if coord.has_bounds():
+                    arrays.append(coord.core_bounds())
+        if check_derived_coords:
+            for coord in cube.derived_coords:
+                arrays.append(coord.core_points())
+                if coord.has_bounds():
+                    arrays.append(coord.core_bounds())
+        if check_cell_measures:
+            for var in cube.cell_measures():
+                arrays.append(var.core_data())
+        if check_ancils:
+            for var in cube.ancillary_variables():
+                arrays.append(var.core_data())
+
+    # 2 unify chunks of arrays that have matching shape
+    hashes = {}
+
+    def grouper(a):
+        return a.shape
+
+    arrays.sort(key=grouper)
+    for _, group in itertools.groupby(arrays, key=grouper):
+        group = list(group)
+        indices = tuple(range(group[0].ndim))[::-1]
+        argpairs = [(a, indices) for a in group]
+        _, rechunked_group = da.core.unify_chunks(*itertools.chain(*argpairs))
+        for array, rechunked in zip(group, rechunked_group):
+            hashes[dask.base.tokenize(array)] = (
+                hash_array(rechunked),
+                rechunked.chunks,
+            )
+
+    # 3 compute hashes
+    (hashes,) = dask.compute(hashes)
+    hashes = {k: ArrayHash(*v) for k, v in hashes.items()}
+
     for cube in cubes:
         registered = False
 
@@ -339,6 +416,7 @@ def concatenate(
         for proto_cube in proto_cubes:
             registered = proto_cube.register(
                 cube,
+                hashes,
                 axis,
                 error_on_mismatch,
                 check_aux_coords,
@@ -380,7 +458,7 @@ class _CubeSignature:
 
     """
 
-    def __init__(self, cube):
+    def __init__(self, cube: iris.cube.Cube) -> None:
         """Represent the cube metadata and associated coordinate metadata.
 
         Parameters
@@ -413,7 +491,7 @@ def __init__(self, cube):
         #
         # Collate the dimension coordinate metadata.
         #
-        for ind, coord in enumerate(self.dim_coords):
+        for coord in self.dim_coords:
             dims = cube.coord_dims(coord)
             metadata = _CoordMetaData(coord, dims)
             self.dim_metadata.append(metadata)
@@ -836,6 +914,7 @@ def concatenate(self):
     def register(
         self,
         cube,
+        hashes,
         axis=None,
         error_on_mismatch=False,
         check_aux_coords=False,
@@ -915,73 +994,56 @@ def register(
             msg = f"Found cubes with overlap on concatenate axis {candidate_axis}, skipping concatenation for these cubes"
             warnings.warn(msg, category=iris.warnings.IrisUserWarning)
 
-        # Check for compatible AuxCoords.
-        if match:
-            if check_aux_coords:
-                for coord_a, coord_b in zip(
-                    self._cube_signature.aux_coords_and_dims,
-                    cube_signature.aux_coords_and_dims,
-                ):
-                    # AuxCoords that span the candidate axis can differ
-                    if (
-                        candidate_axis not in coord_a.dims
-                        or candidate_axis not in coord_b.dims
-                    ):
-                        if not coord_a == coord_b:
-                            match = False
+        def get_hash(array):
+            return hashes[tokenize(array)]
+
+        def get_hashes(coord):
+            result = []
+            if hasattr(coord, "core_points"):
+                result.append(get_hash(coord.core_points()))
+                if coord.has_bounds():
+                    result.append(get_hash(coord.core_bounds()))
+            else:
+                result.append(get_hash(coord.core_data()))
+            return tuple(result)
+
+        def check_coord_match(coord_type):
+            for coord_a, coord_b in zip(
+                getattr(self._cube_signature, coord_type),
+                getattr(cube_signature, coord_type),
+            ):
+                # Coordinates that span the candidate axis can differ.
+                if (
+                    candidate_axis not in coord_a.dims
+                    or candidate_axis not in coord_b.dims
+                ):
+                    if coord_a.dims != coord_b.dims:
+                        return False
+                    if get_hashes(coord_a.coord) != get_hashes(coord_b.coord):
+                        return False
+            return True
+
+        # Check for compatible AuxCoords.
+        if match and check_aux_coords:
+            match = check_coord_match("aux_coords_and_dims")
 
         # Check for compatible CellMeasures.
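+        # CellMeasures, AncillaryVariables and derived coordinates are
+        # compared via the same precomputed array hashes as the AuxCoords
+        # above.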
-        if match:
-            if check_cell_measures:
-                for coord_a, coord_b in zip(
-                    self._cube_signature.cell_measures_and_dims,
-                    cube_signature.cell_measures_and_dims,
-                ):
-                    # CellMeasures that span the candidate axis can differ
-                    if (
-                        candidate_axis not in coord_a.dims
-                        or candidate_axis not in coord_b.dims
-                    ):
-                        if not coord_a == coord_b:
-                            match = False
+        if match and check_cell_measures:
+            match = check_coord_match("cell_measures_and_dims")
 
         # Check for compatible AncillaryVariables.
-        if match:
-            if check_ancils:
-                for coord_a, coord_b in zip(
-                    self._cube_signature.ancillary_variables_and_dims,
-                    cube_signature.ancillary_variables_and_dims,
-                ):
-                    # AncillaryVariables that span the candidate axis can differ
-                    if (
-                        candidate_axis not in coord_a.dims
-                        or candidate_axis not in coord_b.dims
-                    ):
-                        if not coord_a == coord_b:
-                            match = False
+        if match and check_ancils:
+            match = check_coord_match("ancillary_variables_and_dims")
 
         # Check for compatible derived coordinates.
-        if match:
-            if check_derived_coords:
-                for coord_a, coord_b in zip(
-                    self._cube_signature.derived_coords_and_dims,
-                    cube_signature.derived_coords_and_dims,
-                ):
-                    # Derived coords that span the candidate axis can differ
-                    if (
-                        candidate_axis not in coord_a.dims
-                        or candidate_axis not in coord_b.dims
-                    ):
-                        if not coord_a == coord_b:
-                            match = False
+        if match and check_derived_coords:
+            match = check_coord_match("derived_coords_and_dims")
 
         if match:
             # Register the cube as a source-cube for this proto-cube.
             self._add_skeleton(coord_signature, cube.lazy_data())
             # Declare the nominated axis of concatenation.
             self._axis = candidate_axis
-
-        if match:
             # If the protocube dimension order is constant (indicating it was
             # created from a cube with a length 1 dimension coordinate) but
             # a subsequently registered cube has a non-constant dimension
diff --git a/lib/iris/_lazy_data.py b/lib/iris/_lazy_data.py
index 40984248d1..66701871c0 100644
--- a/lib/iris/_lazy_data.py
+++ b/lib/iris/_lazy_data.py
@@ -34,11 +34,14 @@ def is_lazy_data(data):
     """Return whether the argument is an Iris 'lazy' data array.
 
     At present, this means simply a :class:`dask.array.Array`.
-    We determine this by checking for a "compute" property.
 
     """
-    result = hasattr(data, "compute")
-    return result
+    return isinstance(data, da.Array)
+
+
+def is_masked_data(a):
+    """Determine whether the argument is a masked array (real or lazy)."""
+    return isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray)
 
 
 def is_lazy_masked_data(data):
@@ -48,7 +51,7 @@ def is_lazy_masked_data(data):
     underlying array is of masked type. Otherwise return False.
 
     """
-    return is_lazy_data(data) and ma.isMA(da.utils.meta_from_array(data))
+    return is_lazy_data(data) and is_masked_data(data)
 
 
 @lru_cache
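
Note: a minimal sketch of the property the new hash_array() relies on,
assuming the patch is applied; the array values and chunk sizes here are
illustrative only, not part of the patch:

    import dask.array as da
    import numpy as np

    from iris._concatenate import hash_array

    a = da.from_array(np.arange(10), chunks=5)
    b = da.from_array(np.arange(10), chunks=5)

    # Arrays with equal values, masks and chunks hash equal, so register()
    # can compare coordinate arrays by comparing two 64-bit integers.
    assert hash_array(a).compute() == hash_array(b).compute()

Because the hash is built chunk-by-chunk, the same values chunked
differently would hash differently; that is why step 2 above unifies chunks
within each group of same-shaped arrays, and why ArrayHash refuses to
compare hashes computed from different chunkings.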