diff --git a/connectomics/jax/models/convstack.py b/connectomics/jax/models/convstack.py new file mode 100644 index 0000000..6c488ca --- /dev/null +++ b/connectomics/jax/models/convstack.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""2d/3d residual convstack.""" + +import itertools +from typing import Iterable + +from connectomics.common.bounding_box import BoundingBox +from connectomics.jax import parameter_replacement_util as param_util +from connectomics.jax import util +from flax import struct +import flax.linen as nn +import jax.numpy as jnp + + +@struct.dataclass +class ConvstackConfig: + """Config settings for residual convstacks. + + Attributes: + features: number of feature maps + depth: number of residual modules + padding: padding mode to use for convolutions ('same', 'valid') + dim: number of spatial dimensions + num_convs: number of convolutions in the residual module + use_layernorm: whether to use layer normalization; this has been observed to + stabilize the training of FFNs, particularly in the case of deeper models. + out_features: number of output feature maps + enumerate_layers: If true, layer names will be prefixed with their number + within the model. This parameter affects only the way how model params are + names, not the behavior. + kernel_shape: The 3d shape of the convolution kernel + native_input_size: The native spatial size of the model input. The model may + be able to process input of different size, but the configured input is + usually expected to work the best. Changing this parameter does not affect + the inference. + """ + + features: int | Iterable[int] = 32 + depth: int = 12 # number of residual modules + padding: str = 'same' + dim: int = 3 + num_convs: int = 2 + use_layernorm: bool = True + out_features: int = 1 + enumerate_layers: bool = False + kernel_shape: tuple[int, int, int] = (3, 3, 3) + native_input_size: tuple[int, int, int] | None = None + + +class ResConvStack(nn.Module): + """Residual convstack.""" + + config: ConvstackConfig + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies the convstack to the input. + + Args: + x: [batch, z, y, x, channels]-shaped input. + + Returns: + convstack output + """ + cfg = self.config + + layer_naming = param_util.LayerNaming(self.config.enumerate_layers) + + if isinstance(cfg.features, int): + features = itertools.repeat(cfg.features) + else: + features = iter(cfg.features) + + labels = 'abcdefghijklmnopqrstuvwxyz' + + x = nn.Conv( + next(features), + self.config.kernel_shape[: cfg.dim], + padding=cfg.padding, + name=layer_naming.get_name('pre_a'), + )(x) + if cfg.use_layernorm: + x = nn.LayerNorm()(x) + for i in range(1, cfg.num_convs): + x = nn.relu(x) + x = nn.Conv( + next(features), + self.config.kernel_shape[: cfg.dim], + padding=cfg.padding, + name=layer_naming.get_name(f'pre_{labels[i]}'), + )(x) + + for i in range(cfg.depth): + mod_input = x + if cfg.use_layernorm: + x = nn.LayerNorm()(x) + for j in range(0, cfg.num_convs): + x = nn.relu(x) + x = nn.Conv( + next(features), + self.config.kernel_shape[: cfg.dim], + padding=cfg.padding, + name=layer_naming.get_name(f'res{i}{labels[j]}'), + )(x) + + if x.shape != mod_input.shape: + crop_shape_zyx = x.shape[1 : 1 + cfg.dim] + x += util.center_crop(mod_input, crop_shape_zyx) + else: + x += mod_input + + if cfg.use_layernorm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + return nn.Conv( + cfg.out_features, + (1, 1, 1)[: cfg.dim], + name=layer_naming.get_name('output'), + )(x) + + def compute_output_box_from_input_box( + self, input_box: BoundingBox + ) -> BoundingBox: + """Computes the bounding box in the output volume. + + Args: + input_box: The bounding box at the input of the model. + + Returns: + The bounding box in the output volume. + """ + normalized_padding = self.config.padding.lower() + kernel_shape = self.config.kernel_shape + if normalized_padding == 'valid': + # Each layer contract by the (kernel shape - 1) / 2 voxels. + # Each res block contains a number of convs + a skip connection. Only conv + # layers contract. + single_conv_contraction = ( + jnp.asarray(kernel_shape) - jnp.asarray((1, 1, 1)) + ) / 2 + num_contractions = self.config.num_convs * (self.config.depth + 1) + return BoundingBox( + input_box.start + num_contractions * single_conv_contraction, + input_box.size - 2 * num_contractions * single_conv_contraction, + ) + + # When padding, the output of the model results in the same location. + return input_box + + def compute_input_box_from_output_box( + self, output_box: BoundingBox + ) -> BoundingBox: + """Computes the bounding box in the input volume. + + Args: + output_box: The bounding box which should be inferred. + + Returns: + The bounding box in the input volume. + """ + + normalized_padding = self.config.padding.lower() + kernel_shape = self.config.kernel_shape + if normalized_padding == 'valid': + # Each layer contract by the (kernel shape - 1) / 2 voxels. + # Each res block contains a number of convs + a skip connection. Only conv + # layers contract. + single_conv_contraction = ( + jnp.asarray(kernel_shape) - jnp.asarray((1, 1, 1)) + ) / 2 + num_contractions = self.config.num_convs * (self.config.depth + 1) + return BoundingBox( + output_box.start - num_contractions * single_conv_contraction, + output_box.size + 2 * num_contractions * single_conv_contraction, + ) + + # When padding, the output of the model results in the same location. + return output_box + + def get_bounding_box_calculator(self) -> 'ResConvStack': + """Returns the bounding box calculator. + + Returns: + The object capable of transforming bounding boxes between the input and + the output volumes. + """ + return self + + def get_native_output_size(self) -> tuple[int, int, int] | None: + if not self.config.native_input_size: + return None + input_bounding_box = BoundingBox( + start=(0, 0, 0), size=self.config.native_input_size + ) + bbox_calculator = self.get_bounding_box_calculator() + output_box = bbox_calculator.compute_output_box_from_input_box( + input_bounding_box + ) + return output_box.size + + def get_native_input_size(self) -> tuple[int, int, int] | None: + return self.config.native_input_size + + +class ResConvNeXtStack(nn.Module): + """Inspired by ConvNeXt: https://arxiv.org/abs/2201.03545.""" + + config: ConvstackConfig + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies the convstack to the input. + + Args: + x: [batch, z, y, x, channels]-shaped input. + + Returns: + convstack output + """ + cfg = self.config + + if isinstance(cfg.features, int): + features = itertools.repeat(cfg.features) + else: + features = iter(cfg.features) + + point_kernel = (1, 1, 1)[: cfg.dim] + space_kernel = (7, 7, 7)[: cfg.dim] + + feat_out = next(features) + x = nn.Conv(feat_out, space_kernel, padding=cfg.padding, name='pre_a')(x) + x = nn.LayerNorm()(x) + x = nn.Conv(feat_out * 4, point_kernel, padding=cfg.padding, name='pre_b')( + x + ) + x = nn.relu(x) + x = nn.Conv(feat_out, point_kernel, padding=cfg.padding, name='pre_c')(x) + + for i in range(cfg.depth): + mod_input = x + feat_in, feat_out = feat_out, next(features) + x = nn.Conv( + feat_out, + space_kernel, + padding=cfg.padding, + feature_group_count=feat_in, + name=f'res{i}_a', + )(x) + x = nn.LayerNorm()(x) + x = nn.Conv( + feat_out * 4, point_kernel, padding=cfg.padding, name=f'res{i}_b' + )(x) + x = nn.relu(x) + x = nn.Conv( + feat_out, point_kernel, padding=cfg.padding, name=f'res{i}_c' + )(x) + if x.shape != mod_input.shape: + crop_shape_zyx = x.shape[1 : 1 + cfg.dim] + x += util.center_crop(mod_input, crop_shape_zyx) + else: + x += mod_input + + x = nn.relu(x) + return nn.Conv(cfg.out_features, point_kernel, name='output')(x) diff --git a/connectomics/jax/parameter_replacement_util.py b/connectomics/jax/parameter_replacement_util.py new file mode 100644 index 0000000..d53d6fc --- /dev/null +++ b/connectomics/jax/parameter_replacement_util.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A library allowing to manage models with ordered parameter blocks. + +This library defines the ordering on atomic parameter blocks (see below for the +definition and examples) and provides functions for replacing a specified number +of final parameter blocks in one model with parameters from another model. + + +This library assumes that during model evaluation, subcomponents are evaluated +in a sequential order (thus, their parameters can also be ordered). The ordering +must be encoded in the names of parameter keys inside the parameters pytree. + +For each node in the pytree, one of the following must be true: +* All child keys are prefixed with the ordering prefix. +* No child keys are prefixed with the ordering prefix. +The ordering prefix consists of the child numer followed by a colon. + +Examples of correct nodes: + "parent_key": {"child":..., "another child": ...} + "parent_key": {"2:child":..., "2:another child": ...} + +Examples of incorrect pytree nodes: + "parent_key": {"child":..., "1:another child": ...} + +A tree node without child ordering is considered as belonging to an atomic model +component. Parameters of such a component will always be replaced together. +This is true even if children of this node contain parameter block ordering. +This mechanism allows treating small submodels (e.g. convolutional layers having +kernel weights + biases) as undivisible. + +All atomic pytree nodes can be ordered by running a depth first search on the +pytree with the constraint of visiting non-atomic node children in their +declared order. + +For example, let's consider the following tree: +params = { + "1:a": { + "1:c": { + "biases": [1,2,3,4], + "weights": [0, 0.2] + }, + "2:d": { + "biases": [1,2,3,4], + "weights": [0, 0.2] + } + }, + "2:b": { + "biases": [1,2,3,4], + "weights": [0, 0.2] + } +} + +This tree defines the following ordering of atomic tree nodes (here represented +by their key in the parent node, which in this example are unique): + +["1:c", "2:d", "2:b"] + +get_num_atomic_blocks(params) +> 3 + +is_params_block_atomic(params) +> False +is_params_block_atomic(params["2:b"]) +> True + +""" +import itertools +from typing import Any, Optional + +PyTree = Any + + +def get_ordered_child_keys(parameters: PyTree) -> list[str]: + """Returns a list of children node keys sorted in increasing order. + + Args: + parameters: A pytree of model parameters. + + Returns: + List of children keys or empty list if there are no children. + """ + if 'keys' not in dir(parameters): + # There are no children (we are at the leaf node). + return [] + keys_to_sort = [] + + for layer_key in parameters.keys(): + delimiter_position = layer_key.find(':') + if delimiter_position == -1: + continue + keys_to_sort.append((int(layer_key[:delimiter_position]), layer_key)) + return [x[1] for x in sorted(keys_to_sort)] + + +def get_num_atomic_blocks(parameters: Any) -> int: + """Determines the number of parameter blocks which are atomic. + + Args: + parameters: A pytree of model parameters. + + Returns: + The number of atomic parameter blocks. + """ + num_atomic_children = 0 + for child_key in get_ordered_child_keys(parameters): + num_atomic_children += max(1, get_num_atomic_blocks(parameters[child_key])) + return num_atomic_children + + +def is_params_block_atomic(parameters: PyTree) -> bool: + """Checks whether the passed block of parameters is atomic. + + Args: + parameters: pytree of model parameters. + + Returns: + True/False + """ + return not get_ordered_child_keys(parameters) + + +def replace_final_parameters(parameters: PyTree, replacements: PyTree, + num_to_replace: int) -> int: + """Replaces the specified number of final model parameters. + + Replaces last parameters when ordered by the + Args: + parameters: Pytree containing the original parameters. + replacements: Pytree of parameters to replace with. + num_to_replace: The number of atomic parameter blocks to replace. + + Returns: + The number of replaced parameter blocks. + """ + num_replaced = 0 + if num_replaced == num_to_replace: + return num_replaced + + for considered_key in reversed(get_ordered_child_keys(parameters)): + if is_params_block_atomic(parameters[considered_key]): + parameters[considered_key] = replacements[considered_key] + num_replaced += 1 + else: + num_replaced += replace_final_parameters(parameters[considered_key], + replacements[considered_key], + num_to_replace - num_replaced) + if num_replaced == num_to_replace: + return num_replaced + return num_replaced + + +class LayerNaming: + """Generator for model layer names. + + Allows to optionally prefix names with their number. Every time a new name is + requested, the layer number is increased. + """ + + def __init__(self, should_prefix_names: bool): + self.should_prefix_names = should_prefix_names + self.layer_num_iter = itertools.count() + + def get_name(self, + base_name: str, + fallback_name: Optional[str] = None) -> str: + """Generates a layer name. + + If layer numbering is enabled, generates the next layer number and prefixes + basename with this number. Otherwise, returns base name. + + Args: + base_name: The base name of the layer. + fallback_name: Name of the layer to use when name prefixing is disabled. + If specified, takes precedence over 'base_name'. + + Returns: + Generated layer name. + """ + if self.should_prefix_names: + return f'{next(self.layer_num_iter)}:{base_name}' + elif fallback_name is not None: + return fallback_name + else: + return base_name diff --git a/connectomics/jax/parameter_replacement_util_test.py b/connectomics/jax/parameter_replacement_util_test.py new file mode 100644 index 0000000..2514fcc --- /dev/null +++ b/connectomics/jax/parameter_replacement_util_test.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for parameter_swap_util.""" + + +from absl.testing import absltest + +from connectomics.jax import parameter_replacement_util + + +class ParameterSwapUtilTest(absltest.TestCase): + + def test_replacement_works_no_replacement(self): + replacements = { + "0:key1": { + "0:key2": "bbb", + "1:key3": "bbb" + }, + "1:key1": "bbb", + } + parameters = {"0:key1": {"0:key2": "aaa", "1:key3": "aaa"}, "1:key1": "aaa"} + self.assertEqual( + 3, parameter_replacement_util.get_num_atomic_blocks(parameters)) + self.assertEqual( + 0, + parameter_replacement_util.replace_final_parameters( + parameters, replacements, 0)) + self.assertEqual( + { + "0:key1": { + "0:key2": "aaa", + "1:key3": "aaa" + }, + "1:key1": "aaa" + }, parameters) + + def test_replacement_works_some_replacement(self): + replacements = { + "0:key1": { + "0:key2": "bbb", + "1:key3": "bbb" + }, + "1:key1": "bbb", + } + parameters = {"0:key1": {"0:key2": "aaa", "1:key3": "aaa"}, "1:key1": "aaa"} + self.assertEqual( + 2, + parameter_replacement_util.replace_final_parameters( + parameters, replacements, 2)) + self.assertEqual( + { + "0:key1": { + "0:key2": "aaa", + "1:key3": "bbb" + }, + "1:key1": "bbb" + }, parameters) + + def test_replacement_works_too_many_replacements(self): + replacements = { + "0:key1": { + "0:key2": "bbb", + "1:key3": "bbb" + }, + "1:key1": "bbb", + } + parameters = {"0:key1": {"0:key2": "aaa", "1:key3": "aaa"}, "1:key1": "aaa"} + self.assertEqual( + 3, + parameter_replacement_util.replace_final_parameters( + parameters, replacements, 12)) + self.assertEqual( + { + "0:key1": { + "0:key2": "bbb", + "1:key3": "bbb" + }, + "1:key1": "bbb" + }, parameters) + + +if __name__ == "__main__": + absltest.main() diff --git a/connectomics/jax/util.py b/connectomics/jax/util.py new file mode 100644 index 0000000..8a9c230 --- /dev/null +++ b/connectomics/jax/util.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for JAX / FLAX models. + +This file gets imported by XM launch scipts. Please keep dependencies +minimal and ensure that they work with the binary_import mechanism. +""" + +from typing import Sequence + +from connectomics.common import bounding_box +import jax +import jax.numpy as jnp + + +def center_crop_in_all_dimensions( + x: jax.Array, expected_shape: Sequence[int] +) -> jax.Array: + """Extracts a crop of the expected shape from the center of a tensor. + + No distinction is made between spatial, batch and channel dimensions. + If the expected output size is larger than the input size, no cropping will be + performed in the corresponding dimension. + + Args: + x: The tensor which should be cropped. + expected_shape: A sequence of dimensions after cropping. + + Returns: + The cropped tensor. + """ + starts = [max(0, (x - e) // 2) for x, e in zip(x.shape, expected_shape)] + ends = [min(lim, s + e) for s, e, lim in zip(starts, expected_shape, x.shape)] + slices = tuple([slice(s, e) for s, e in zip(starts, ends)]) + return x[slices] + + +def center_crop(x: jax.Array, crop_spatial_shape: Sequence[int]) -> jax.Array: + """Extracts crop_shape from the center of a xZYXx tensor or xYXx tensor. + + Args: + x: The tensor which should be cropped. + crop_spatial_shape: The spatial shape after croppping. + + Returns: + The cropped tensor. + """ + return center_crop_in_all_dimensions( + x, + tuple(x.shape[: -(len(crop_spatial_shape) + 1)]) + + tuple(crop_spatial_shape) + + (x.shape[-1],), + ) + + +def pad_symmetrically_in_all_dimensions( + x: jax.Array, expected_shape: Sequence[int] +) -> jax.Array: + """Symmetrically pads the provided tensor in all dimensions. + + No distinction is made between spatial, batch and channel dimensions. + If the expected output size is smaller than the input size, no padding will be + performed in the corresponding dimension. + + Args: + x: The tensor which should be padded. + expected_shape: A sequence of dimensions after padding. + + Returns: + The cropped tensor. + """ + total_padding = jnp.asarray(expected_shape) - jnp.asarray(x.shape) + total_padding = jnp.clip(total_padding, 0) + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + requested_padding = jnp.concatenate( + (left_padding.reshape((-1, 1)), right_padding.reshape((-1, 1))), axis=1 + ) + return jnp.pad(x, requested_padding) + + +def pad_symmetrically( + x: jax.Array, padded_spatial_shape: Sequence[int] +) -> jax.Array: + """Spatially pads tensors with batch and channel dimensions. + + Batch and channel dimensions are not affected by the padding. + Example correct input tensor shapes: xZYXc, xYXx. + + Args: + x: The tensor which should be padded. + padded_spatial_shape: The spatial shape after padding. + + Returns: + The cropped tensor. + """ + return pad_symmetrically_in_all_dimensions( + x, (1,) + tuple(padded_spatial_shape) + (1,) + ) + + +def center_crop_bounding_box( + original_box: bounding_box.BoundingBox, final_size_zyx: Sequence[int] +) -> bounding_box.BoundingBox: + """Updates the bounding box to match the final tensor spatial size. + + Cropping assumes that the output tensor (of size `final_size`) corresponds to + the center of the original bounding box. + + Args: + original_box: The bounding box to be cropped. + final_size_zyx: The tensor size after cropping. + + Returns: + The cropped bounding box. + """ + final_size_xyz = tuple(reversed(final_size_zyx)) + cropping_offsets = (original_box.size - final_size_xyz) // 2 + new_start = jnp.asarray(original_box.start) + cropping_offsets + return bounding_box.BoundingBox( + start=tuple(new_start), size=tuple(final_size_xyz) + ) diff --git a/connectomics/jax/util_test.py b/connectomics/jax/util_test.py new file mode 100644 index 0000000..23a2230 --- /dev/null +++ b/connectomics/jax/util_test.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for util.""" + +from absl.testing import absltest +from connectomics.common import bounding_box +from connectomics.jax import util +import jax +import jax.numpy as jnp + + +class UtilTest(absltest.TestCase): + + def test_center_crop_in_all_dimensions(self): + original = jnp.arange(0, 1000, 1, dtype=int).reshape((10, 10, 10)) + + # Crop to a larger size in one of dimensions + cropped = util.center_crop_in_all_dimensions(original, (2, 3000, 2)) + self.assertEqual(cropped.shape, (2, 10, 2)) + + # Really cropping in all dimensions + cropped = util.center_crop_in_all_dimensions(original, (2, 2, 2)) + self.assertEqual(cropped.shape, (2, 2, 2)) + self.assertEqual((cropped == jnp.asarray([[[444, 445], [454, 455]], + [[544, 545], [554, 555]]])).all(), + True) + + # No crop in any dimension. + cropped = util.center_crop_in_all_dimensions(cropped, (20, 20, 20)) + self.assertEqual(cropped.shape, (2, 2, 2)) + self.assertEqual((cropped == jnp.asarray([[[444, 445], [454, 455]], + [[544, 545], [554, 555]]])).all(), + True) + + # Crop requested, no symmetric crop possible. We expect less crop on the + # right + original = jnp.arange(0, 125, 1, dtype=int).reshape((5, 5, 5)) + cropped = util.center_crop_in_all_dimensions(original, (2, 2, 2)) + + self.assertEqual(cropped.shape, (2, 2, 2)) + self.assertEqual((cropped == jnp.asarray([[[31, 32], [36, 37]], + [[56, 57], [61, 62]]])).all(), + True) + + def test_center_crop(self): + original = jnp.arange(0, 1000, 1, dtype=int).reshape((10, 10, 10)) + + @jax.jit + def _crop(x): + return util.center_crop(x, (2,)) + + # Crop to a larger size in one of dimensions + cropped = _crop(original) + self.assertEqual(cropped.shape, (10, 2, 10)) + + def test_pad_symmetrically_in_all_dimensions(self): + original = jnp.arange(0, 8, 1, dtype=int).reshape((2, 2, 2)) + + # Need to pad asymmetrically + padded = util.pad_symmetrically_in_all_dimensions(original, (3, 3, 3)) + self.assertEqual(padded.shape, (3, 3, 3)) + self.assertEqual((padded == jnp.asarray([[[0, 1, 0], [2, 3, 0], [0, 0, 0]], + [[4, 5, 0], [6, 7, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], + [0, 0, 0]]])).all(), True) + + # Regular symmetric padding + padded = util.pad_symmetrically_in_all_dimensions(original, (4, 4, 4)) + self.assertEqual(padded.shape, (4, 4, 4)) + self.assertEqual((padded == jnp.asarray([[[0, 0, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 0, 1, 0], + [0, 2, 3, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 4, 5, 0], + [0, 6, 7, 0], [0, 0, 0, 0]], + [[0, 0, 0, 0], [0, 0, 0, 0], + [0, 0, 0, 0], [0, 0, 0, + 0]]])).all(), True) + + def test_crop_bounding_box(self): + original = bounding_box.BoundingBox(start=(10, 10, 10), end=(20, 20, 20)) + cropped = util.center_crop_bounding_box(original, (2, 2, 2)) + self.assertEqual(tuple(cropped.start), (14, 14, 14)) + self.assertEqual(tuple(cropped.size), (2, 2, 2)) + + cropped = util.center_crop_bounding_box(original, (3, 3, 3)) + self.assertEqual(tuple(cropped.start), (13, 13, 13)) + self.assertEqual(tuple(cropped.size), (3, 3, 3)) + + +if __name__ == '__main__': + absltest.main()