Skip to content

Commit

Permalink
Merge pull request #433 from mgxd/fix/composite-xfms
Browse files Browse the repository at this point in the history
FIX: Multi-step-reg / composite transform misformation
  • Loading branch information
mgxd authored Jan 22, 2025
2 parents e5ed154 + 29c87b6 commit bfa91a1
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 99 deletions.
40 changes: 1 addition & 39 deletions nibabies/interfaces/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from nipype.interfaces.ants.registration import (
CompositeTransformUtil as _CompositeTransformUtil,
)
from nipype.interfaces.ants.registration import (
CompositeTransformUtilInputSpec as _CompositeTransformUtilInputSpec,
)
from nipype.interfaces.ants.registration import (
CompositeTransformUtilOutputSpec as _CompositeTransformUtilOutputSpec,
)
Expand Down Expand Up @@ -116,22 +113,13 @@ def _list_outputs(self):
return outputs


class CompositeTransformUtilInputSpec(_CompositeTransformUtilInputSpec):
order_transforms = traits.Bool(
True,
usedefault=True,
desc='Order disassembled transforms into [Affine, Displacement] pairs.',
)


class CompositeTransformUtilOutputSpec(_CompositeTransformUtilOutputSpec):
out_transforms = traits.List(desc='list of transform components')
out_transforms = traits.List(desc='list of ordered transform components')


class CompositeTransformUtil(_CompositeTransformUtil):
"""Outputs have changed in newer versions of ANTs."""

input_spec = CompositeTransformUtilInputSpec
output_spec = CompositeTransformUtilOutputSpec

def _list_outputs(self):
Expand All @@ -145,9 +133,6 @@ def _list_outputs(self):
str(Path(x).absolute())
for x in sorted(Path().glob(f'{self.inputs.output_prefix}_*'))
]

if self.inputs.order_transforms:
transforms = _order_xfms(transforms)
outputs['out_transforms'] = transforms

# Potentially could be more than one affine / displacement per composite transform...
Expand All @@ -160,26 +145,3 @@ def _list_outputs(self):
elif self.inputs.process == 'assemble':
outputs['out_file'] = Path(self.inputs.out_file).absolute()
return outputs


def _order_xfms(vals):
"""
Assumes [affine, displacement] or [displacement, affine] transform pairs.
>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']
>>> _order_xfms(['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']
>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat', \
'AffineTransform.mat'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz', 'AffineTransform.mat']
"""
for i in range(0, len(vals) - 1, 2):
if (
'DisplacementFieldTransform' in Path(vals[i]).name
and 'AffineTransform' in Path(vals[i + 1]).name
):
vals[i], vals[i + 1] = vals[i + 1], vals[i]
return vals
19 changes: 0 additions & 19 deletions nibabies/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
if path.suffix == '.h5':
# Load as a TransformChain
xfm = nt.manip.load(path)
if len(xfm.transforms) == 4:
# MG: This behavior should be ported to nitransforms
xfm = nt.manip.TransformChain(reverse_pairs(xfm.transforms))
else:
xfm = nt.linear.load(path)
if inv:
Expand All @@ -35,19 +32,3 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
if chain is None:
chain = nt.Affine() # Identity
return chain


def reverse_pairs(arr: list) -> list:
"""
Reverse the order of pairs in a list.
>>> reverse_pairs([1, 2, 3, 4])
[3, 4, 1, 2]
>>> reverse_pairs([1, 2, 3, 4, 5, 6])
[5, 6, 3, 4, 1, 2]
"""
rev = []
for i in range(len(arr), 0, -2):
rev.extend(arr[i - 2 : i])
return rev
8 changes: 4 additions & 4 deletions nibabies/workflows/anatomical/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,8 +988,8 @@ def init_infant_anat_fit_wf(
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('key', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
('anat2std_xfm', 'inputnode.anat2int_xfm'),
('std2anat_xfm', 'inputnode.int2anat_xfm'),
]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1905,8 +1905,8 @@ def init_infant_single_anat_fit_wf(
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('key', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
('anat2std_xfm', 'inputnode.anat2int_xfm'),
('std2anat_xfm', 'inputnode.int2anat_xfm'),
]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down
134 changes: 97 additions & 37 deletions nibabies/workflows/anatomical/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,14 @@ def init_concat_registrations_wf(
workflow.__desc__ += '.\n' if template == templates[-1] else ', '

inputnode = pe.Node(
niu.IdentityInterface(fields=['template', 'intermediate', 'anat2std_xfm', 'std2anat_xfm']),
niu.IdentityInterface(
fields=[
'template', # template identifier (name[+cohort])
'intermediate', # intermediate space (name[+cohort])
'anat2int_xfm', # anatomical -> intermediate
'int2anat_xfm', # intermediate -> anatomical
]
),
name='inputnode',
)
inputnode.inputs.template = templates
Expand All @@ -401,68 +408,83 @@ def init_concat_registrations_wf(
),
name='intermed_xfms',
iterfield=['std'],
overwrite=True, # otherwise, cache hits but not guarantee files are present on reruns
run_without_submitting=True,
)

split_desc = pe.MapNode(
TemplateDesc(), run_without_submitting=True, iterfield='template', name='split_desc'
)

merge_anat2std = pe.Node(niu.Merge(2), name='merge_anat2std', run_without_submitting=True)
merge_std2anat = merge_anat2std.clone('merge_std2anat')
fmt_cohort = pe.MapNode(
niu.Function(function=_fmt_cohort, output_names=['template', 'spec']),
name='fmt_cohort',
run_without_submitting=True,
iterfield=['template', 'spec'],
)

disassemble_anat2std = pe.MapNode(
CompositeTransformUtil(process='disassemble', output_prefix='anat2std'),
iterfield=['in_file'],
name='disassemble_anat2std',
# Disassemble each composite transform individually for readability
dis_anat2int = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='anat2int'),
name='dis_anat2int',
)

disassemble_std2anat = pe.MapNode(
CompositeTransformUtil(process='disassemble', output_prefix='std2anat'),
iterfield=['in_file'],
name='disassemble_std2anat',
dis_int2std = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='int2std'),
name='dis_int2std',
)

merge_anat2std_composites = pe.Node(
niu.Merge(1, ravel_inputs=True),
name='merge_anat2std_composites',
dis_std2int = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='std2int'),
name='dis_std2int',
)
merge_std2anat_composites = pe.Node(
niu.Merge(1, ravel_inputs=True),
name='merge_std2anat_composites',

dis_int2anat = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='int2anat'),
name='dis_int2anat',
)

order_anat2std = pe.Node(niu.Merge(4), name='order_anat2std', run_without_submitting=True)
order_std2anat = pe.Node(niu.Merge(4), name='order_std2anat', run_without_submitting=True)

assemble_anat2std = pe.Node(
CompositeTransformUtil(process='assemble', out_file='anat2std.h5'),
name='assemble_anat2std',
)
# https://github.com/ANTsX/ANTs/issues/1827
# Until CompositeTransformUtil accepts warps as first transform,
# Use SimpleITK to concatenate
assemble_std2anat = pe.Node(
CompositeTransformUtil(process='assemble', out_file='std2anat.h5'),
niu.Function(function=_create_inverse_composite, output_names=['out_file']),
name='assemble_std2anat',
)

fmt_cohort = pe.MapNode(
niu.Function(function=_fmt_cohort, output_names=['template', 'spec']),
name='fmt_cohort',
run_without_submitting=True,
iterfield=['template', 'spec'],
)

workflow.connect([
# Template concatenation
(inputnode, merge_anat2std, [('anat2std_xfm', 'in2')]),
(inputnode, merge_std2anat, [('std2anat_xfm', 'in2')]),
# Transform concatenation
(inputnode, dis_anat2int, [('anat2int_xfm', 'in_file')]),
(inputnode, dis_int2anat, [('int2anat_xfm', 'in_file')]),
(inputnode, intermed_xfms, [('intermediate', 'intermediate')]),
(inputnode, intermed_xfms, [('template', 'std')]),
(intermed_xfms, merge_anat2std, [('int2std_xfm', 'in1')]),
(intermed_xfms, merge_std2anat, [('std2int_xfm', 'in1')]),
(merge_anat2std, disassemble_anat2std, [('out', 'in_file')]),
(merge_std2anat, disassemble_std2anat, [('out', 'in_file')]),
(disassemble_anat2std, merge_anat2std_composites, [('out_transforms', 'in1')]),
(disassemble_std2anat, merge_std2anat_composites, [('out_transforms', 'in1')]),
(merge_anat2std_composites, assemble_anat2std, [('out', 'in_file')]),
(merge_std2anat_composites, assemble_std2anat, [('out', 'in_file')]),
(intermed_xfms, dis_int2std, [('int2std_xfm', 'in_file')]),
(intermed_xfms, dis_std2int, [('std2int_xfm', 'in_file')]),
(dis_anat2int, order_anat2std, [
('affine_transform', 'in1'),
('displacement_field', 'in2'),
]),
(dis_int2std, order_anat2std, [
('affine_transform', 'in3'),
('displacement_field', 'in4'),
]),
# Because std2anat are inverse transforms, warp is first
(dis_std2int, order_std2anat, [
('affine_transform', 'in2'),
('displacement_field', 'in1'),
]),
(dis_int2anat, order_std2anat, [
('affine_transform', 'in4'),
('displacement_field', 'in3'),
]),
(order_anat2std, assemble_anat2std, [('out', 'in_file')]),
(order_std2anat, assemble_std2anat, [('out', 'in_file')]),
(assemble_anat2std, outputnode, [('out_file', 'anat2std_xfm')]),
(assemble_std2anat, outputnode, [('out_file', 'std2anat_xfm')]),

Expand All @@ -483,6 +505,7 @@ def init_concat_registrations_wf(

def _load_intermediate_xfms(intermediate, std):
import json
from pathlib import Path

import pooch

Expand All @@ -496,6 +519,7 @@ def _load_intermediate_xfms(intermediate, std):
int2std_meta = xfms[int2std_name]
int2std = pooch.retrieve(
url=int2std_meta['url'],
path=Path.cwd(),
known_hash=int2std_meta['hash'],
fname=int2std_name,
)
Expand All @@ -504,8 +528,44 @@ def _load_intermediate_xfms(intermediate, std):
std2int_meta = xfms[std2int_name]
std2int = pooch.retrieve(
url=std2int_meta['url'],
path=Path.cwd(),
known_hash=std2int_meta['hash'],
fname=std2int_name,
)

return int2std, std2int


def _create_inverse_composite(in_file, out_file='inverse_composite.h5'):
"""Build a composite transform with SimpleITK.

This serves as a workaround for a bug in ANTs's CompositeTransformUtil
where composite transforms cannot be created with a displacement field placed first.

Parameters
----------
in_file : list of str
List of input transforms to concatenate into a composite transform.
out_file : str, optional
File to write the composite transform to.

Returns
-------
out_file : str
Absolute path to the composite transform.
from pathlib import Path

import SimpleITK as sitk

composite = sitk.CompositeTransform(3)
for xfm_file in in_file:
if xfm_file.endswith('mat'):
xfm = sitk.ReadTransform(xfm_file)
else:
xfm = sitk.DisplacementFieldTransform(sitk.ReadImage(xfm_file))

composite.AddTransform(xfm)

out_file = str(Path(out_file).absolute())
sitk.WriteTransform(composite, out_file)
return out_file
7 changes: 7 additions & 0 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,13 @@ def init_workflow_spaces(execution_spaces: SpatialReferences, age_months: int):
if not spaces.is_cached():
spaces.checkpoint()

# Ensure one cohort of MNIInfant is always available as an internal space
if not any(
space.startswith('MNIInfant') for space in spaces.get_spaces(nonstandard=False, dim=(3,))
):
cohort = cohort_by_months('MNIInfant', age_months)
spaces.add(Reference('MNIInfant', {'cohort': cohort}))

if config.workflow.cifti_output:
# CIFTI grayordinates to corresponding FSL-MNI resolutions.
vol_res = '2' if config.workflow.cifti_output == '91k' else '1'
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"psutil >= 5.4",
"pybids >= 0.15.0",
"requests",
"SimpleITK",
"sdcflows >= 2.10.0",
"smriprep >= 0.17.0",
"tedana >= 23.0.2",
Expand Down

0 comments on commit bfa91a1

Please sign in to comment.