diff --git a/nibabies/interfaces/patches.py b/nibabies/interfaces/patches.py index 913af5c1..8ae8ac3a 100644 --- a/nibabies/interfaces/patches.py +++ b/nibabies/interfaces/patches.py @@ -7,9 +7,6 @@ from nipype.interfaces.ants.registration import ( CompositeTransformUtil as _CompositeTransformUtil, ) -from nipype.interfaces.ants.registration import ( - CompositeTransformUtilInputSpec as _CompositeTransformUtilInputSpec, -) from nipype.interfaces.ants.registration import ( CompositeTransformUtilOutputSpec as _CompositeTransformUtilOutputSpec, ) @@ -116,22 +113,13 @@ def _list_outputs(self): return outputs -class CompositeTransformUtilInputSpec(_CompositeTransformUtilInputSpec): - order_transforms = traits.Bool( - True, - usedefault=True, - desc='Order disassembled transforms into [Affine, Displacement] pairs.', - ) - - class CompositeTransformUtilOutputSpec(_CompositeTransformUtilOutputSpec): - out_transforms = traits.List(desc='list of transform components') + out_transforms = traits.List(desc='list of ordered transform components') class CompositeTransformUtil(_CompositeTransformUtil): """Outputs have changed in newer versions of ANTs.""" - input_spec = CompositeTransformUtilInputSpec output_spec = CompositeTransformUtilOutputSpec def _list_outputs(self): @@ -145,9 +133,6 @@ def _list_outputs(self): str(Path(x).absolute()) for x in sorted(Path().glob(f'{self.inputs.output_prefix}_*')) ] - - if self.inputs.order_transforms: - transforms = _order_xfms(transforms) outputs['out_transforms'] = transforms # Potentially could be more than one affine / displacement per composite transform... @@ -160,26 +145,3 @@ def _list_outputs(self): elif self.inputs.process == 'assemble': outputs['out_file'] = Path(self.inputs.out_file).absolute() return outputs - - -def _order_xfms(vals): - """ - Assumes [affine, displacement] or [displacement, affine] transform pairs. - - >>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat']) - ['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz'] - - >>> _order_xfms(['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']) - ['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz'] - - >>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat', \ - 'AffineTransform.mat']) - ['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz', 'AffineTransform.mat'] - """ - for i in range(0, len(vals) - 1, 2): - if ( - 'DisplacementFieldTransform' in Path(vals[i]).name - and 'AffineTransform' in Path(vals[i + 1]).name - ): - vals[i], vals[i + 1] = vals[i + 1], vals[i] - return vals diff --git a/nibabies/utils/transforms.py b/nibabies/utils/transforms.py index bb030fb4..b1049ebd 100644 --- a/nibabies/utils/transforms.py +++ b/nibabies/utils/transforms.py @@ -21,9 +21,6 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans if path.suffix == '.h5': # Load as a TransformChain xfm = nt.manip.load(path) - if len(xfm.transforms) == 4: - # MG: This behavior should be ported to nitransforms - xfm = nt.manip.TransformChain(reverse_pairs(xfm.transforms)) else: xfm = nt.linear.load(path) if inv: @@ -35,19 +32,3 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans if chain is None: chain = nt.Affine() # Identity return chain - - -def reverse_pairs(arr: list) -> list: - """ - Reverse the order of pairs in a list. - - >>> reverse_pairs([1, 2, 3, 4]) - [3, 4, 1, 2] - - >>> reverse_pairs([1, 2, 3, 4, 5, 6]) - [5, 6, 3, 4, 1, 2] - """ - rev = [] - for i in range(len(arr), 0, -2): - rev.extend(arr[i - 2 : i]) - return rev diff --git a/nibabies/workflows/anatomical/fit.py b/nibabies/workflows/anatomical/fit.py index b634b4e8..c692863c 100644 --- a/nibabies/workflows/anatomical/fit.py +++ b/nibabies/workflows/anatomical/fit.py @@ -988,8 +988,8 @@ def init_infant_anat_fit_wf( (concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]), (select_infant_mni, concat_reg_wf, [ ('key', 'inputnode.intermediate'), - ('anat2std_xfm', 'inputnode.anat2std_xfm'), - ('std2anat_xfm', 'inputnode.std2anat_xfm'), + ('anat2std_xfm', 'inputnode.anat2int_xfm'), + ('std2anat_xfm', 'inputnode.int2anat_xfm'), ]), (sourcefile_buffer, ds_concat_reg_wf, [ ('anat_source_files', 'inputnode.source_files') @@ -1905,8 +1905,8 @@ def init_infant_single_anat_fit_wf( (concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]), (select_infant_mni, concat_reg_wf, [ ('key', 'inputnode.intermediate'), - ('anat2std_xfm', 'inputnode.anat2std_xfm'), - ('std2anat_xfm', 'inputnode.std2anat_xfm'), + ('anat2std_xfm', 'inputnode.anat2int_xfm'), + ('std2anat_xfm', 'inputnode.int2anat_xfm'), ]), (sourcefile_buffer, ds_concat_reg_wf, [ ('anat_source_files', 'inputnode.source_files') diff --git a/nibabies/workflows/anatomical/registration.py b/nibabies/workflows/anatomical/registration.py index 506cbd25..1a0aa8e4 100644 --- a/nibabies/workflows/anatomical/registration.py +++ b/nibabies/workflows/anatomical/registration.py @@ -382,7 +382,14 @@ def init_concat_registrations_wf( workflow.__desc__ += '.\n' if template == templates[-1] else ', ' inputnode = pe.Node( - niu.IdentityInterface(fields=['template', 'intermediate', 'anat2std_xfm', 'std2anat_xfm']), + niu.IdentityInterface( + fields=[ + 'template', # template identifier (name[+cohort]) + 'intermediate', # intermediate space (name[+cohort]) + 'anat2int_xfm', # anatomical -> intermediate + 'int2anat_xfm', # intermediate -> anatomical + ] + ), name='inputnode', ) inputnode.inputs.template = templates @@ -401,7 +408,6 @@ def init_concat_registrations_wf( ), name='intermed_xfms', iterfield=['std'], - overwrite=True, # otherwise, cache hits but not guarantee files are present on reruns run_without_submitting=True, ) @@ -409,60 +415,76 @@ def init_concat_registrations_wf( TemplateDesc(), run_without_submitting=True, iterfield='template', name='split_desc' ) - merge_anat2std = pe.Node(niu.Merge(2), name='merge_anat2std', run_without_submitting=True) - merge_std2anat = merge_anat2std.clone('merge_std2anat') + fmt_cohort = pe.MapNode( + niu.Function(function=_fmt_cohort, output_names=['template', 'spec']), + name='fmt_cohort', + run_without_submitting=True, + iterfield=['template', 'spec'], + ) - disassemble_anat2std = pe.MapNode( - CompositeTransformUtil(process='disassemble', output_prefix='anat2std'), - iterfield=['in_file'], - name='disassemble_anat2std', + # Disassemble each composite transform individually for readability + dis_anat2int = pe.Node( + CompositeTransformUtil(process='disassemble', output_prefix='anat2int'), + name='dis_anat2int', ) - disassemble_std2anat = pe.MapNode( - CompositeTransformUtil(process='disassemble', output_prefix='std2anat'), - iterfield=['in_file'], - name='disassemble_std2anat', + dis_int2std = pe.Node( + CompositeTransformUtil(process='disassemble', output_prefix='int2std'), + name='dis_int2std', ) - merge_anat2std_composites = pe.Node( - niu.Merge(1, ravel_inputs=True), - name='merge_anat2std_composites', + dis_std2int = pe.Node( + CompositeTransformUtil(process='disassemble', output_prefix='std2int'), + name='dis_std2int', ) - merge_std2anat_composites = pe.Node( - niu.Merge(1, ravel_inputs=True), - name='merge_std2anat_composites', + + dis_int2anat = pe.Node( + CompositeTransformUtil(process='disassemble', output_prefix='int2anat'), + name='dis_int2anat', ) + order_anat2std = pe.Node(niu.Merge(4), name='order_anat2std', run_without_submitting=True) + order_std2anat = pe.Node(niu.Merge(4), name='order_std2anat', run_without_submitting=True) + assemble_anat2std = pe.Node( CompositeTransformUtil(process='assemble', out_file='anat2std.h5'), name='assemble_anat2std', ) + # https://github.com/ANTsX/ANTs/issues/1827 + # Until CompositeTransformUtil accepts warps as first transform, + # Use SimpleITK to concatenate assemble_std2anat = pe.Node( - CompositeTransformUtil(process='assemble', out_file='std2anat.h5'), + niu.Function(function=_create_inverse_composite, output_names=['out_file']), name='assemble_std2anat', ) - fmt_cohort = pe.MapNode( - niu.Function(function=_fmt_cohort, output_names=['template', 'spec']), - name='fmt_cohort', - run_without_submitting=True, - iterfield=['template', 'spec'], - ) - workflow.connect([ - # Template concatenation - (inputnode, merge_anat2std, [('anat2std_xfm', 'in2')]), - (inputnode, merge_std2anat, [('std2anat_xfm', 'in2')]), + # Transform concatenation + (inputnode, dis_anat2int, [('anat2int_xfm', 'in_file')]), + (inputnode, dis_int2anat, [('int2anat_xfm', 'in_file')]), (inputnode, intermed_xfms, [('intermediate', 'intermediate')]), (inputnode, intermed_xfms, [('template', 'std')]), - (intermed_xfms, merge_anat2std, [('int2std_xfm', 'in1')]), - (intermed_xfms, merge_std2anat, [('std2int_xfm', 'in1')]), - (merge_anat2std, disassemble_anat2std, [('out', 'in_file')]), - (merge_std2anat, disassemble_std2anat, [('out', 'in_file')]), - (disassemble_anat2std, merge_anat2std_composites, [('out_transforms', 'in1')]), - (disassemble_std2anat, merge_std2anat_composites, [('out_transforms', 'in1')]), - (merge_anat2std_composites, assemble_anat2std, [('out', 'in_file')]), - (merge_std2anat_composites, assemble_std2anat, [('out', 'in_file')]), + (intermed_xfms, dis_int2std, [('int2std_xfm', 'in_file')]), + (intermed_xfms, dis_std2int, [('std2int_xfm', 'in_file')]), + (dis_anat2int, order_anat2std, [ + ('affine_transform', 'in1'), + ('displacement_field', 'in2'), + ]), + (dis_int2std, order_anat2std, [ + ('affine_transform', 'in3'), + ('displacement_field', 'in4'), + ]), + # Because std2anat are inverse transforms, warp is first + (dis_std2int, order_std2anat, [ + ('affine_transform', 'in2'), + ('displacement_field', 'in1'), + ]), + (dis_int2anat, order_std2anat, [ + ('affine_transform', 'in4'), + ('displacement_field', 'in3'), + ]), + (order_anat2std, assemble_anat2std, [('out', 'in_file')]), + (order_std2anat, assemble_std2anat, [('out', 'in_file')]), (assemble_anat2std, outputnode, [('out_file', 'anat2std_xfm')]), (assemble_std2anat, outputnode, [('out_file', 'std2anat_xfm')]), @@ -483,6 +505,7 @@ def init_concat_registrations_wf( def _load_intermediate_xfms(intermediate, std): import json + from pathlib import Path import pooch @@ -496,6 +519,7 @@ def _load_intermediate_xfms(intermediate, std): int2std_meta = xfms[int2std_name] int2std = pooch.retrieve( url=int2std_meta['url'], + path=Path.cwd(), known_hash=int2std_meta['hash'], fname=int2std_name, ) @@ -504,8 +528,44 @@ def _load_intermediate_xfms(intermediate, std): std2int_meta = xfms[std2int_name] std2int = pooch.retrieve( url=std2int_meta['url'], + path=Path.cwd(), known_hash=std2int_meta['hash'], fname=std2int_name, ) return int2std, std2int + + +def _create_inverse_composite(in_file, out_file='inverse_composite.h5'): + """Build a composite transform with SimpleITK. + + This serves as a workaround for a bug in ANTs's CompositeTransformUtil + where composite transforms cannot be created with a displacement field placed first. + + Parameters + ---------- + in_file : list of str + List of input transforms to concatenate into a composite transform. + out_file : str, optional + File to write the composite transform to. + + Returns + ------- + out_file : str + Absolute path to the composite transform. + from pathlib import Path + + import SimpleITK as sitk + + composite = sitk.CompositeTransform(3) + for xfm_file in in_file: + if xfm_file.endswith('mat'): + xfm = sitk.ReadTransform(xfm_file) + else: + xfm = sitk.DisplacementFieldTransform(sitk.ReadImage(xfm_file)) + + composite.AddTransform(xfm) + + out_file = str(Path(out_file).absolute()) + sitk.WriteTransform(composite, out_file) + return out_file diff --git a/nibabies/workflows/base.py b/nibabies/workflows/base.py index 9a4a236c..488162ba 100644 --- a/nibabies/workflows/base.py +++ b/nibabies/workflows/base.py @@ -846,6 +846,13 @@ def init_workflow_spaces(execution_spaces: SpatialReferences, age_months: int): if not spaces.is_cached(): spaces.checkpoint() + # Ensure one cohort of MNIInfant is always available as an internal space + if not any( + space.startswith('MNIInfant') for space in spaces.get_spaces(nonstandard=False, dim=(3,)) + ): + cohort = cohort_by_months('MNIInfant', age_months) + spaces.add(Reference('MNIInfant', {'cohort': cohort})) + if config.workflow.cifti_output: # CIFTI grayordinates to corresponding FSL-MNI resolutions. vol_res = '2' if config.workflow.cifti_output == '91k' else '1' diff --git a/pyproject.toml b/pyproject.toml index 07fca109..75625b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "psutil >= 5.4", "pybids >= 0.15.0", "requests", + "SimpleITK", "sdcflows >= 2.10.0", "smriprep >= 0.17.0", "tedana >= 23.0.2",