diff --git a/nibabies/workflows/anatomical/fit.py b/nibabies/workflows/anatomical/fit.py index 1f59dc3b..444ddd62 100644 --- a/nibabies/workflows/anatomical/fit.py +++ b/nibabies/workflows/anatomical/fit.py @@ -1030,7 +1030,6 @@ def init_infant_anat_fit_wf( surface_recon_wf = init_mcribs_surface_recon_wf( omp_nthreads=omp_nthreads, use_aseg=bool(anat_aseg), - use_mask=True, precomputed=precomputed, mcribs_dir=str(config.execution.mcribs_dir), ) @@ -1040,10 +1039,8 @@ def init_infant_anat_fit_wf( ('subject_id', 'inputnode.subject_id'), ('subjects_dir', 'inputnode.subjects_dir'), ]), - (t2w_buffer, surface_recon_wf, [ - ('t2w_preproc', 'inputnode.t2w'), - ('t2w_mask', 'inputnode.in_mask'), - ]), + (t2w_validate, surface_recon_wf, [('out_file', 'inputnode.t2w')]), + (t2w_buffer, surface_recon_wf, [('t2w_mask', 'inputnode.in_mask'),]), (aseg_buffer, surface_recon_wf, [ ('anat_aseg', 'inputnode.in_aseg'), ]), @@ -1950,7 +1947,6 @@ def init_infant_single_anat_fit_wf( surface_recon_wf = init_mcribs_surface_recon_wf( omp_nthreads=omp_nthreads, use_aseg=bool(anat_aseg), - use_mask=True, precomputed=precomputed, mcribs_dir=str(config.execution.mcribs_dir), ) @@ -1960,10 +1956,8 @@ def init_infant_single_anat_fit_wf( ('subject_id', 'inputnode.subject_id'), ('subjects_dir', 'inputnode.subjects_dir'), ]), - (anat_buffer, surface_recon_wf, [ - ('anat_preproc', 'inputnode.t2w'), - ('anat_mask', 'inputnode.in_mask'), - ]), + (anat_validate, surface_recon_wf, [('out_file', 'inputnode.t2w')]), + (anat_buffer, surface_recon_wf, [('anat_mask', 'inputnode.in_mask')]), (aseg_buffer, surface_recon_wf, [ ('anat_aseg', 'inputnode.in_aseg'), ]), diff --git a/nibabies/workflows/anatomical/surfaces.py b/nibabies/workflows/anatomical/surfaces.py index cec467db..8cb94993 100644 --- a/nibabies/workflows/anatomical/surfaces.py +++ b/nibabies/workflows/anatomical/surfaces.py @@ -6,6 +6,7 @@ from nipype.interfaces import freesurfer as fs from nipype.interfaces import io as nio from nipype.interfaces import utility as niu +from nipype.interfaces.ants import N4BiasFieldCorrection from nipype.pipeline import engine as pe from niworkflows.engine.workflows import LiterateWorkflow from niworkflows.interfaces.freesurfer import ( @@ -14,6 +15,7 @@ from niworkflows.interfaces.freesurfer import ( PatchedRobustRegister as RobustRegister, ) +from niworkflows.interfaces.morphology import BinaryDilation from niworkflows.interfaces.patches import FreeSurferSource from smriprep.interfaces.freesurfer import MakeMidthickness from smriprep.interfaces.workbench import SurfaceResample @@ -42,7 +44,6 @@ def init_mcribs_surface_recon_wf( *, omp_nthreads: int, use_aseg: bool, - use_mask: bool, precomputed: dict, mcribs_dir: str | None = None, name: str = 'mcribs_surface_recon_wf', @@ -119,7 +120,7 @@ def init_mcribs_surface_recon_wf( fs_to_mcribs = pe.Node(MapLabels(mappings=fs2mcribs), name='fs_to_mcribs') t2w_las = pe.Node(ReorientImage(target_orientation='LAS'), name='t2w_las') - seg_las = t2w_las.clone(name='seg_las') + seg_las = pe.Node(ReorientImage(target_orientation='LAS'), name='seg_las') mcribs_recon = pe.Node( MCRIBReconAll( @@ -136,17 +137,25 @@ def init_mcribs_surface_recon_wf( mcribs_recon.inputs.outdir = mcribs_dir mcribs_recon.config = {'execution': {'remove_unnecessary_outputs': False}} - if use_mask: - # If available, dilated mask and use in recon-neonatal-cortex - from niworkflows.interfaces.morphology import BinaryDilation - - mask_dil = pe.Node(BinaryDilation(radius=3), name='mask_dil') - mask_las = t2w_las.clone(name='mask_las') - workflow.connect([ - (inputnode, mask_dil, [('in_mask', 'in_mask')]), - (mask_dil, mask_las, [('out_mask', 'in_file')]), - (mask_las, mcribs_recon, [('out_file', 'mask_file')]), - ]) # fmt:skip + # dilated mask and use in recon-neonatal-cortex + mask_dil = pe.Node(BinaryDilation(radius=3), name='mask_dil') + mask_las = pe.Node(ReorientImage(target_orientation='LAS'), name='mask_las') + + # N4BiasCorrection occurs in MCRIBTissueSegMCRIBS (which is skipped) + # Run it (with mask to rescale intensities) prior injection + n4_mcribs = pe.Node( + N4BiasFieldCorrection( + dimension=3, + bspline_fitting_distance=200, + save_bias=True, + copy_header=True, + n_iterations=[50] * 5, + convergence_threshold=1e-7, + rescale_intensities=True, + shrink_factor=4, + ), + name='n4_mcribs', + ) mcribs_postrecon = pe.Node( MCRIBReconAll(autorecon_after_surf=True, nthreads=omp_nthreads), @@ -160,11 +169,16 @@ def init_mcribs_surface_recon_wf( workflow.connect([ (inputnode, t2w_las, [('t2w', 'in_file')]), (inputnode, fs_to_mcribs, [('in_aseg', 'in_file')]), + (inputnode, mask_dil, [('in_mask', 'in_mask')]), + (mask_dil, mask_las, [('out_mask', 'in_file')]), + (mask_las, mcribs_recon, [('out_file', 'mask_file')]), (fs_to_mcribs, seg_las, [('out_file', 'in_file')]), (inputnode, mcribs_recon, [ ('subjects_dir', 'subjects_dir'), ('subject_id', 'subject_id')]), - (t2w_las, mcribs_recon, [('out_file', 't2w_file')]), + (t2w_las, n4_mcribs, [('out_file', 'input_image')]), + (mask_las, n4_mcribs, [('out_file', 'mask_image')]), + (n4_mcribs, mcribs_recon, [('output_image', 't2w_file')]), (seg_las, mcribs_recon, [('out_file', 'segmentation_file')]), (inputnode, mcribs_postrecon, [ ('subjects_dir', 'subjects_dir'),