Re #50: add scan conversion to TorchSequenceSegmentation
chriscyyeung committed Jun 13, 2023
1 parent 424132b commit ae49685
Showing 2 changed files with 159 additions and 20 deletions.
@@ -356,6 +356,25 @@
      </property>
     </widget>
    </item>
    <item row="4" column="0">
     <widget class="QLabel" name="label_4">
      <property name="text">
       <string>Scan conversion config:</string>
      </property>
     </widget>
    </item>
    <item row="4" column="1">
     <widget class="ctkPathLineEdit" name="scanConversionPathLineEdit">
      <property name="filters">
       <set>ctkPathLineEdit::AllEntries|ctkPathLineEdit::Dirs|ctkPathLineEdit::Drives|ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable</set>
      </property>
      <property name="nameFilters">
       <stringlist>
        <string>Configs (*.yaml)</string>
       </stringlist>
      </property>
     </widget>
    </item>
   </layout>
  </widget>
 </item>
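For reference, the scan conversion config selected by this widget is a small YAML file whose keys match those read by loadScanConversion below. A hypothetical example (all values invented for illustration):

# Hypothetical scan conversion config; values for illustration only
curvilinear_image_size: 512        # scan converted image is 512 x 512 px
angle_min_degrees: -45             # left edge of the fan
angle_max_degrees: 45              # right edge of the fan
radius_start_pixels: 100           # distance from fan origin to first sample
radius_end_pixels: 480             # distance from fan origin to last sample
num_samples_along_lines: 128       # samples per scan line
num_lines: 128                     # number of scan lines
center_coordinate_pixel: [0, 256]  # fan origin as (row, column), judging by how the code indexes it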
@@ -9,6 +9,28 @@
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin

try:
    import yaml
except ImportError:
    slicer.util.pip_install("pyyaml")
    import yaml

try:
    import cv2
except ImportError:
    slicer.util.pip_install("opencv-python")
    import cv2

try:
    from scipy.ndimage import map_coordinates
    from scipy.interpolate import griddata
    from scipy.spatial import Delaunay
except ImportError:
    slicer.util.pip_install("scipy")
    from scipy.ndimage import map_coordinates
    from scipy.interpolate import griddata
    from scipy.spatial import Delaunay

INSTALL_PYTORCHUTILS = False
try:
    import torch
@@ -193,6 +215,7 @@ def setup(self):
        self.ui.outputTransformSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.updateParameterNodeFromGUI)
        self.ui.verticalFlipCheckbox.connect("toggled(bool)", self.updateParameterNodeFromGUI)
        self.ui.modelInputSizeSpinbox.connect("valueChanged(int)", self.updateParameterNodeFromGUI)
        self.ui.scanConversionPathLineEdit.connect("currentPathChanged(const QString)", self.updateParameterNodeFromGUI)

        # Buttons
        self.ui.inputResliceButton.connect("clicked(bool)", self.onInputResliceButton)
@@ -254,8 +277,13 @@ def enter(self):

        # Set last model path in UI
        lastModelPath = slicer.util.settingsValue(self.logic.LAST_MODEL_PATH_SETTING, "")
        if lastModelPath is not None:
        if lastModelPath:
            self.ui.modelPathLineEdit.currentPath = lastModelPath

        # Set last scan conversion path in UI
        lastScanConversionPath = slicer.util.settingsValue(self.logic.LAST_SCAN_CONVERSION_PATH_SETTING, "")
        if lastScanConversionPath:
            self.ui.scanConversionPathLineEdit.currentPath = lastScanConversionPath

        # Create and select volume reconstruction node, if not done yet
        if not self.ui.volumeReconstructionSelector.currentNode():
@@ -361,6 +389,8 @@ def updateGUIFromParameterNode(self, caller=None, event=None):
        wasBlocked = self.ui.outputTransformSelector.blockSignals(True)
        self.ui.outputTransformSelector.setCurrentNode(inputVolumeParent)
        self.ui.outputTransformSelector.blockSignals(wasBlocked)

        self.ui.scanConversionPathLineEdit.setCurrentPath(self._parameterNode.GetParameter("ScanConversionPath"))

        # Enable/disable buttons
        self.ui.segmentButton.setEnabled(sequenceBrowser and inputVolume and not self.logic.isProcessing)
@@ -399,6 +429,15 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
self._parameterNode.SetParameter("ModelPath", modelPath)
self.logic.loadModel(modelPath)

# Update scan conversion path
scanConversionPath = self.ui.scanConversionPathLineEdit.currentPath
if not scanConversionPath:
self._parameterNode.SetParameter("ScanConversionPath", "")
else:
if scanConversionPath != self._parameterNode.GetParameter("ScanConversionPath"):
self._parameterNode.SetParameter("ScanConversionPath", scanConversionPath)
self.logic.loadScanConversion(scanConversionPath)

self._parameterNode.EndModify(wasModified)

def onInputResliceButton(self):
@@ -538,6 +577,7 @@ class TorchSequenceSegmentationLogic(ScriptedLoadableModuleLogic):
"""

LAST_MODEL_PATH_SETTING = "TorchSequenceSegmentation/LastModelPath"
LAST_SCAN_CONVERSION_PATH_SETTING = "TorchSequenceSegmentation/LastScanConversionPath"

def __init__(self):
"""
@@ -548,6 +588,15 @@ def __init__(self):
        self.progressCallback = None
        self.isProcessing = False
        self.model = None
        self.scanConversionDict = None
        self.cart_x = None
        self.cart_y = None
        self.grid_x = None
        self.grid_y = None
        self.vertices = None
        self.weights = None
        self.curvilinear_size = None
        self.curvilinear_mask = None
        self.volRecLogic = slicer.modules.volumereconstruction.logic()

    def setDefaultParameters(self, parameterNode):
@@ -567,7 +616,7 @@ def loadModel(self, modelPath):
            logging.warning("Model path is empty")
            self.model = None
        elif not os.path.isfile(modelPath):
            logging.error("Model file does not exist: "+ modelPath)
            logging.error("Model file does not exist: " + modelPath)
            self.model = None
        else:
            extra_files = {"config.json": ""}
@@ -581,6 +630,71 @@

        settings = qt.QSettings()
        settings.setValue(self.LAST_MODEL_PATH_SETTING, modelPath)

    def loadScanConversion(self, scanConversionPath):
        if not scanConversionPath:
            logging.warning("Scan conversion path is empty")
            self.scanConversionDict = None
            self.cart_x = None
            self.cart_y = None
        elif not os.path.isfile(scanConversionPath):
            logging.error("Scan conversion file does not exist: " + scanConversionPath)
            self.scanConversionDict = None
            self.cart_x = None
            self.cart_y = None
        else:
            with open(scanConversionPath, "r") as f:
                self.scanConversionDict = yaml.safe_load(f)

        if self.scanConversionDict:
            # Load scan conversion parameters
            self.curvilinear_size = self.scanConversionDict["curvilinear_image_size"]
            initial_angle = np.deg2rad(self.scanConversionDict["angle_min_degrees"])
            final_angle = np.deg2rad(self.scanConversionDict["angle_max_degrees"])
            radius_start_px = self.scanConversionDict["radius_start_pixels"]
            radius_end_px = self.scanConversionDict["radius_end_pixels"]
            num_samples_along_lines = self.scanConversionDict["num_samples_along_lines"]
            num_lines = self.scanConversionDict["num_lines"]
            center_coordinate_pixel = self.scanConversionDict["center_coordinate_pixel"]

            theta, r = np.meshgrid(np.linspace(initial_angle, final_angle, num_samples_along_lines),
                                   np.linspace(radius_start_px, radius_end_px, num_lines))

            # Precompute mapping parameters between scan converted and curvilinear images
            self.cart_x = r * np.cos(theta) + center_coordinate_pixel[0]
            self.cart_y = r * np.sin(theta) + center_coordinate_pixel[1]

            self.grid_x, self.grid_y = np.mgrid[0:self.curvilinear_size, 0:self.curvilinear_size]

            triangulation = Delaunay(np.vstack((self.cart_x.flatten(), self.cart_y.flatten())).T)

            simplices = triangulation.find_simplex(np.vstack((self.grid_x.flatten(), self.grid_y.flatten())).T)
            self.vertices = triangulation.simplices[simplices]

            X = triangulation.transform[simplices, :2]
            Y = np.vstack((self.grid_x.flatten(), self.grid_y.flatten())).T - triangulation.transform[simplices, 2]
            b = np.einsum('ijk,ik->ij', X, Y)
            self.weights = np.c_[b, 1 - b.sum(axis=1)]

            # Compute curvilinear mask, slightly tighter than the fan (1 degree, 1 pixel) to avoid edge artifacts
            angle1 = 90.0 + (self.scanConversionDict["angle_min_degrees"] + 1)
            angle2 = 90.0 + (self.scanConversionDict["angle_max_degrees"] - 1)

            self.curvilinear_mask = np.zeros((self.curvilinear_size, self.curvilinear_size), dtype=np.int8)
            self.curvilinear_mask = cv2.ellipse(self.curvilinear_mask,
                                                (center_coordinate_pixel[1], center_coordinate_pixel[0]),
                                                (radius_end_px - 1, radius_end_px - 1), 0.0, angle1, angle2, 1, -1)
            self.curvilinear_mask = cv2.circle(self.curvilinear_mask,
                                               (center_coordinate_pixel[1], center_coordinate_pixel[0]),
                                               radius_start_px + 1, 0, -1)

        settings = qt.QSettings()
        settings.setValue(self.LAST_SCAN_CONVERSION_PATH_SETTING, scanConversionPath)

    def scanConvert(self, linearArray):
        z = linearArray.flatten()
        zi = np.einsum("ij,ij->i", np.take(z, self.vertices), self.weights)
        return zi.reshape(self.curvilinear_size, self.curvilinear_size)

    def getUniqueName(self, node, baseName):
        newName = baseName
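The Delaunay block in loadScanConversion precomputes, once per config, the triangulation and barycentric weights that scipy.interpolate.griddata would otherwise rebuild on every frame, so scanConvert reduces per-frame interpolation to a gather and an einsum. A self-contained sketch of that equivalence on synthetic points (not the module's real fan geometry):

import numpy as np
from scipy.interpolate import griddata
from scipy.spatial import Delaunay

rng = np.random.default_rng(0)
points = rng.random((200, 2)) * 100       # scattered sample positions (stand-ins for cart_x, cart_y)
values = rng.random(200)                  # one frame's sample values
grid_x, grid_y = np.mgrid[0:100, 0:100]   # target pixel grid

# One-time precomputation, mirroring loadScanConversion
tri = Delaunay(points)
simplices = tri.find_simplex(np.vstack((grid_x.flatten(), grid_y.flatten())).T)
vertices = tri.simplices[simplices]
X = tri.transform[simplices, :2]
Y = np.vstack((grid_x.flatten(), grid_y.flatten())).T - tri.transform[simplices, 2]
b = np.einsum("ijk,ik->ij", X, Y)
weights = np.c_[b, 1 - b.sum(axis=1)]

# Per-frame work, mirroring scanConvert: a gather plus a weighted sum
fast = np.einsum("ij,ij->i", np.take(values, vertices), weights).reshape(100, 100)

# Same result inside the convex hull as re-triangulating every frame
slow = griddata(points, values, (grid_x, grid_y), method="linear")
inside = simplices.reshape(100, 100) >= 0
assert np.allclose(fast[inside], slow[inside])

Pixels outside the fan fall outside the convex hull, where find_simplex returns -1 and the weights are meaningless; that is why getPrediction multiplies the scan converted output by curvilinear_mask.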
Expand Down Expand Up @@ -634,38 +748,44 @@ def getPrediction(self, image):
            return

        imageArray = slicer.util.arrayFromVolume(image)
        imageArray = torch.from_numpy(imageArray).float()  # convert to tensor

        # Flip image vertically if specified by user
        parameterNode = self.getParameterNode()
        toFlip = parameterNode.GetParameter("FlipVertical").lower() == "true"
        if toFlip:
            imageArray = torch.flip(imageArray, dims=[1])  # axis 0 is channel dimension
            imageArray = np.flip(imageArray, axis=1)  # axis 0 is channel dimension

        # Use inverse scan conversion if specified by user, otherwise resize
        if self.scanConversionDict:
            inputArray = map_coordinates(imageArray[0, :, :], [self.cart_x, self.cart_y], order=1)
        else:
            inputSize = int(parameterNode.GetParameter("ModelInputSize"))
            inputArray = cv2.resize(np.ascontiguousarray(imageArray[0, :, :]), (inputSize, inputSize))  # default interpolation is bilinear

        # Resize input to match model input size
        inputSize = int(parameterNode.GetParameter("ModelInputSize"))
        inputTensor = torchvision.transforms.functional.resize(imageArray, (inputSize, inputSize), antialias=True)  # default is bilinear
        inputTensor = inputTensor.unsqueeze(0).to(DEVICE)  # add batch dimension
        # Convert to tensor and add batch dimension
        inputTensor = torch.from_numpy(inputArray).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

        # Run prediction
        with torch.inference_mode():
            output = self.model(inputTensor)
            output = torch.nn.functional.softmax(output, dim=1).detach().cpu()
            output = torch.nn.functional.softmax(output, dim=1)

        # Scan convert or resize
        if self.scanConversionDict:
            outputArray = output.detach().cpu().numpy() * 255
            outputArray = self.scanConvert(outputArray[0, 1, :, :])
            outputArray *= self.curvilinear_mask
        else:
            outputArray = output.detach().cpu().numpy()[0, 1, :, :] * 255
            outputArray = cv2.resize(outputArray, (imageArray.shape[2], imageArray.shape[1]))

        # Resize output to match original image size
        output = torchvision.transforms.functional.resize(
            output,
            (imageArray.shape[1], imageArray.shape[2]),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            antialias=True
        )
        output = (output.numpy()[:, 1, :, :] * 255).astype(np.uint8)
        outputArray = outputArray.astype(np.uint8)[np.newaxis, ...]

        # Flip output back if needed
        if toFlip:
            output = np.flip(output, axis=1)
            outputArray = np.flip(outputArray, axis=1)

        return output
        return outputArray
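
For the input direction, getPrediction uses scipy.ndimage.map_coordinates as the inverse mapping: it samples the curvilinear frame at the precomputed Cartesian coordinates of every (line, sample) pair, producing a beam-space array of shape (num_lines, num_samples_along_lines) for the model. A minimal sketch with hypothetical fan geometry:

import numpy as np
from scipy.ndimage import map_coordinates

frame = np.random.rand(512, 512)  # stand-in for one curvilinear ultrasound frame

# Hypothetical geometry; the real values come from the YAML config
theta, r = np.meshgrid(np.linspace(np.deg2rad(-45), np.deg2rad(45), 128),
                       np.linspace(100, 300, 128))
cart_x = r * np.cos(theta)        # row coordinate of each sample (fan origin at row 0)
cart_y = r * np.sin(theta) + 256  # column coordinate of each sample

# Bilinear lookup (order=1), as in the scanConversionDict branch of getPrediction
beams = map_coordinates(frame, [cart_x, cart_y], order=1)
print(beams.shape)  # (128, 128)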

    def segmentSequence(self):
        self.isProcessing = True
Expand Down Expand Up @@ -755,7 +875,7 @@ def runVolumeReconstruction(self):
        volRenDisplayNode = volRenLogic.CreateDefaultVolumeRenderingNodes(reconstructionVolume)
        volRenDisplayNode.SetAndObserveROINodeID(roiNode.GetID())
        volPropertyNode = volRenDisplayNode.GetVolumePropertyNode()
        volPropertyNode.Copy(volRenLogic.GetPresetByName("US-Fetal"))
        volPropertyNode.Copy(volRenLogic.GetPresetByName("MR-Default"))

        # Run volume reconstruction
        self.volRecLogic.ReconstructVolumeFromSequence(reconstructionNode)
