Skip to content

Commit

Permalink
checked warp saving to ants format
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitrango committed Oct 22, 2024
1 parent 6702de6 commit d08be9b
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 32 deletions.
34 changes: 34 additions & 0 deletions fireants/registration/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from typing import List, Optional, Union, Callable
from tqdm import tqdm
import SimpleITK as sitk

from fireants.utils.globals import MIN_IMG_SIZE
from fireants.io.image import BatchedImages
Expand Down Expand Up @@ -97,6 +98,39 @@ def get_warped_coordinates(self, fixed_images: BatchedImages, moving_images: Bat
# move these coordinates, and return them
moved_coords = fixed_image_affinecoords + warp_field # affine transform + warp field
return moved_coords

def save_as_ants_transforms(self, filenames: Union[str, List[str]]):
''' given a list of filenames, save the warp field as ants transforms '''
if isinstance(filenames, str):
filenames = [filenames]
assert len(filenames) == self.opt_size, "Number of filenames should match the number of warps"
# get the warp field
fixed_image: BatchedImages = self.fixed_images
moving_image: BatchedImages = self.moving_images
# get the moved coordinates
moved_coords = self.get_warped_coordinates(fixed_image, moving_image) # [B, H, W, [D], dim]
init_grid = F.affine_grid(torch.eye(self.dims, self.dims+1, device=moved_coords.device)[None], \
fixed_image.shape, align_corners=True)
# this is now moved displacements
moved_coords = moved_coords - init_grid

# convert this grid into moving coordinates
moving_t2p = moving_image.get_torch2phy()[:, :self.dims, :self.dims]
moved_coords = torch.einsum('bij, b...j->b...i', moving_t2p, moved_coords)
# save
for i in range(self.opt_size):
moved_disp = moved_coords[i].detach().cpu().numpy() # [H, W, D, 3]
savefile = filenames[i]
# get itk image
if len(fixed_image.images) < i: # this image is probably broadcasted then
itk_data = fixed_image.images[0].itk_image
else:
itk_data = fixed_image.images[i].itk_image
# copy itk data
warp = sitk.GetImageFromArray(moved_disp)
warp.CopyInformation(itk_data)
sitk.WriteImage(warp, savefile)


def optimize(self, save_transformed=False):
''' optimize the warp field to match the two images based on loss function '''
Expand Down
Binary file added tutorials/1000_1001_warp.nii.gz
Binary file not shown.
191 changes: 159 additions & 32 deletions tutorials/[Tutorial 1] Basic Usage.ipynb

Large diffs are not rendered by default.

Binary file added tutorials/antsMoved.nii.gz
Binary file not shown.

0 comments on commit d08be9b

Please sign in to comment.