Re #50: Added support for scan conversion to TorchLiveUs module
ungi committed Jun 9, 2023
1 parent 8bdb8e6 commit 395ff27
Showing 5 changed files with 163 additions and 54 deletions.
@@ -6,7 +6,7 @@
    <rect>
     <x>0</x>
     <y>0</y>
-    <width>308</width>
+    <width>350</width>
     <height>624</height>
    </rect>
   </property>
@@ -37,14 +37,14 @@
       </property>
      </widget>
     </item>
-    <item row="4" column="0">
+    <item row="5" column="0">
      <widget class="QLabel" name="label">
       <property name="text">
        <string>Input volume:</string>
       </property>
      </widget>
     </item>
-    <item row="4" column="1">
+    <item row="5" column="1">
      <widget class="qMRMLNodeComboBox" name="inputSelector">
       <property name="toolTip">
        <string>Pick the input to the algorithm.</string>
@@ -65,14 +65,14 @@
       </property>
      </widget>
     </item>
-    <item row="5" column="0">
+    <item row="6" column="0">
      <widget class="QLabel" name="label_7">
       <property name="text">
        <string>Output volume: </string>
       </property>
      </widget>
     </item>
-    <item row="5" column="1">
+    <item row="6" column="1">
      <widget class="qMRMLNodeComboBox" name="outputSelector">
       <property name="toolTip">
        <string>Pick the output to the algorithm.</string>
@@ -99,14 +99,14 @@
       </property>
      </widget>
     </item>
-    <item row="6" column="0">
+    <item row="7" column="0">
      <widget class="QLabel" name="label_3">
       <property name="text">
        <string>Output transform (optional):</string>
       </property>
      </widget>
     </item>
-    <item row="6" column="1">
+    <item row="7" column="1">
      <widget class="qMRMLNodeComboBox" name="outputTransformSelector">
       <property name="nodeTypes">
        <stringlist notr="true"/>
@@ -176,6 +176,16 @@
       </property>
      </widget>
     </item>
+    <item row="1" column="0">
+     <widget class="QLabel" name="label_2">
+      <property name="text">
+       <string>Scan conversion yaml (optional):</string>
+      </property>
+     </widget>
+    </item>
+    <item row="1" column="1">
+     <widget class="ctkPathLineEdit" name="scanConversionPathLineEdit"/>
+    </item>
    </layout>
   </widget>
  </item>
128 changes: 114 additions & 14 deletions SlicerExtension/LiveUltrasoundAi/TorchLiveUs/TorchLiveUs.py
@@ -4,6 +4,24 @@
 import time
 import vtk, qt, ctk, slicer
 
+# Import yaml. If the import fails, install the pyyaml package.
+
+try:
+    import yaml
+except ImportError:
+    slicer.util.pip_install('pyyaml')
+    import yaml
+
+# Import scipy. If the import fails, install the scipy package.
+
+try:
+    from scipy.ndimage import map_coordinates
+    from scipy.interpolate import griddata
+except ImportError:
+    slicer.util.pip_install('scipy')
+    from scipy.ndimage import map_coordinates
+    from scipy.interpolate import griddata
+
 import slicer
 from slicer.ScriptedLoadableModule import *
 from slicer.util import VTKObservationMixin
@@ -151,11 +169,12 @@ def setup(self):
         # These connections ensure that whenever user changes some settings on the GUI, that is saved in either the MRML scene
         # (in the selected parameter node), or in the application settings (independent of the scene).
 
-        self.ui.modelPathLineEdit.connect("currentPathChanged(QString)", self.updateParameterNodeFromGUI)
+        self.ui.modelPathLineEdit.connect("currentPathChanged(QString)", self.updateSettingsFromGUI)
         self.ui.inputSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.updateParameterNodeFromGUI)
         self.ui.outputSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.updateParameterNodeFromGUI)
         self.ui.outputTransformSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.updateParameterNodeFromGUI)
         self.ui.verticalFlipCheckBox.connect("toggled(bool)", self.updateParameterNodeFromGUI)
+        self.ui.scanConversionPathLineEdit.connect("currentPathChanged(QString)", self.updateSettingsFromGUI)
 
         # Buttons
 
@@ -181,6 +200,13 @@ def enter(self):
         lastModelPath = self.logic.getLastModelPath()
         if lastModelPath is not None:
             self.ui.modelPathLineEdit.currentPath = lastModelPath
+            self.logic.loadModel(lastModelPath)
+
+        lastScanConversionPath = slicer.util.settingsValue(self.logic.LAST_SCANCONVERSION_PATH_SETTING, None)
+        if lastScanConversionPath is not None:
+            self.ui.scanConversionPathLineEdit.currentPath = lastScanConversionPath
+
+        self.updateSettingsFromGUI()
 
     def exit(self):
         """
@@ -259,7 +285,6 @@ def updateGUIFromParameterNode(self, caller=None, event=None):
 
         self.ui.applyButton.checked = (self._parameterNode.GetParameter(self.logic.PREDICTION_ACTIVE).lower() == "true")
 
-        self.ui.modelPathLineEdit.setCurrentPath(self._parameterNode.GetParameter(self.logic.MODEL_PATH))
         self.ui.verticalFlipCheckBox.checked = (self._parameterNode.GetParameter(self.logic.VERTICAL_FLIP).lower() == "true")
 
         # All the GUI updates are done
@@ -281,17 +306,36 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
         self._parameterNode.SetNodeReferenceID(self.logic.OUTPUT_TRANSFORM, self.ui.outputTransformSelector.currentNodeID)
         self._parameterNode.SetParameter(self.logic.VERTICAL_FLIP, "true" if self.ui.verticalFlipCheckBox.checked else "false")
 
+        self._parameterNode.EndModify(wasModified)
+
+    def updateSettingsFromGUI(self, caller=None, event=None):
+
+        settings = qt.QSettings()
+
+        # Update model path and load model if changed
+
         modelPath = self.ui.modelPathLineEdit.currentPath
         if modelPath is None or modelPath == "":
             self._parameterNode.SetParameter(self.logic.MODEL_PATH, "")
+            settings.setValue(self.logic.LAST_AI_MODEL_PATH_SETTING, "")
         else:
             if modelPath != self._parameterNode.GetParameter(self.logic.MODEL_PATH):
                 self._parameterNode.SetParameter(self.logic.MODEL_PATH, modelPath)
+            if modelPath != slicer.util.settingsValue(self.logic.LAST_AI_MODEL_PATH_SETTING, None):
+                settings.setValue(self.logic.LAST_AI_MODEL_PATH_SETTING, modelPath)
                 self.logic.loadModel(modelPath)
 
-        self._parameterNode.EndModify(wasModified)
+        # Update scan conversion file path if changed
+
+        scanConversionPath = self.ui.scanConversionPathLineEdit.currentPath
+        if scanConversionPath is None or scanConversionPath == "":
+            settings.setValue(self.logic.LAST_SCANCONVERSION_PATH_SETTING, "")
+        else:
+            if scanConversionPath != slicer.util.settingsValue(self.logic.LAST_SCANCONVERSION_PATH_SETTING, None):
+                settings.setValue(self.logic.LAST_SCANCONVERSION_PATH_SETTING, scanConversionPath)
+                self.logic.loadScanConversion(scanConversionPath)
+                scanConversionDict = self.logic.scanConversionDict
+                if scanConversionDict is not None:
+                    self.ui.statusLabel.text = "Scan conversion loaded"
+                    logging.info(f"Scan conversion loaded from {scanConversionPath}")
+                    logging.info(f"Scan conversion: {scanConversionDict}")
 
     def onApplyButton(self, toggled):
         """
@@ -300,27 +344,33 @@ def onApplyButton(self, toggled):
 
         if self._parameterNode.GetNodeReference(self.logic.INPUT_IMAGE) is None:
             self.ui.statusLabel.text = "Input volume is required"
+            self.ui.applyButton.checked = False
             return
 
         if self._parameterNode.GetNodeReference(self.logic.OUTPUT_IMAGE) is None:
             self.ui.statusLabel.text = "Output volume is required"
+            self.ui.applyButton.checked = False
             return
 
-        if self._parameterNode.GetParameter(self.logic.MODEL_PATH) == "":
+        modelPath = slicer.util.settingsValue(self.logic.LAST_AI_MODEL_PATH_SETTING, None)
+        if modelPath is None or modelPath == "":
             self.ui.statusLabel.text = "Model path is required"
+            self.ui.applyButton.checked = False
             return
 
         try:
             if toggled:
                 self.ui.inputSelector.enabled = False
                 self.ui.outputSelector.enabled = False
                 self.ui.modelPathLineEdit.enabled = False
+                self.ui.scanConversionPathLineEdit.enabled = False
                 self.ui.applyButton.text = "Stop processing"
                 self.ui.statusLabel.text = "Running"
             else:
                 self.ui.inputSelector.enabled = True
                 self.ui.outputSelector.enabled = True
                 self.ui.modelPathLineEdit.enabled = True
+                self.ui.scanConversionPathLineEdit.enabled = True
                 self.ui.applyButton.text = "Start processing"
                 self.ui.statusLabel.text = "Stopped"
 
@@ -359,6 +409,7 @@ class TorchLiveUsLogic(ScriptedLoadableModuleLogic, VTKObservationMixin):
     # Settings
 
     LAST_AI_MODEL_PATH_SETTING = "TorchLiveUs/LastModelPath"
+    LAST_SCANCONVERSION_PATH_SETTING = "TorchLiveUs/LastScanConvertPath"
 
 
     def __init__(self):
@@ -369,6 +420,9 @@
         VTKObservationMixin.__init__(self)
 
         self.model = None
+        self.scanConversionDict = None
+        self.x_cart = None
+        self.y_cart = None
 
     def setDefaultParameters(self, parameterNode):
         """
@@ -396,8 +450,39 @@ def loadModel(self, modelPath):
             self.model = None
         else:
             self.model = torch.jit.load(modelPath)
-            settings = qt.QSettings()
-            settings.setValue(self.LAST_AI_MODEL_PATH_SETTING, modelPath)
             logging.info(f"Model loaded from {modelPath}")
 
+    def loadScanConversion(self, scanConversionPath):
+        if scanConversionPath is None or scanConversionPath == "":
+            logging.warning("Scan conversion path is empty")
+            self.scanConversionDict = None
+            self.x_cart = None
+            self.y_cart = None
+        elif not os.path.isfile(scanConversionPath):
+            logging.error("Scan conversion file does not exist: " + scanConversionPath)
+            self.scanConversionDict = None
+            self.x_cart = None
+            self.y_cart = None
+        else:
+            with open(scanConversionPath, "r") as f:
+                self.scanConversionDict = yaml.safe_load(f)
+
+        if self.scanConversionDict is not None:
+            initial_radius = np.deg2rad(self.scanConversionDict["angle_min_degrees"])
+            final_radius = np.deg2rad(self.scanConversionDict["angle_max_degrees"])
+            radius_start_px = self.scanConversionDict["radius_start_pixels"]
+            radius_end_px = self.scanConversionDict["radius_end_pixels"]
+            num_samples_along_lines = self.scanConversionDict["num_samples_along_lines"]
+            num_lines = self.scanConversionDict["num_lines"]
+            center_coordinate_pixel = self.scanConversionDict["center_coordinate_pixel"]
+
+            theta, r = np.meshgrid(np.linspace(initial_radius, final_radius, num_samples_along_lines),
+                                   np.linspace(radius_start_px, radius_end_px, num_lines))
+
+            # Convert the polar coordinates to cartesian coordinates
+
+            self.x_cart = r * np.cos(theta) + center_coordinate_pixel[0]
+            self.y_cart = r * np.sin(theta) + center_coordinate_pixel[1]
+
     def togglePrediction(self, toggled):
         """
@@ -448,26 +533,41 @@ def onInputVolumeModified(self, inputVolume, event):
         if parameterNode.GetParameter(self.VERTICAL_FLIP) == "true":
             input_array = np.flip(input_array, axis=0)
 
-        # Resize input using opencv with linear interpolation to match model input size
-        resized_array = cv2.resize(input_array[0, :, :], (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)
+        # If scan conversion given, map image accordingly.
+        # Otherwise, resize input using opencv with linear interpolation to match model input size
+
+        if self.scanConversionDict is not None:
+            resized_array = map_coordinates(input_array[0, :, :], [self.x_cart, self.y_cart], order=1, mode='constant', cval=0.0)
+        else:
+            resized_array = cv2.resize(input_array[0, :, :], (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)
 
         # Convert to tensor and add batch dimension
 
         input_tensor = torch.from_numpy(resized_array).unsqueeze(0).unsqueeze(0).float()
 
         # Run inference
 
         with torch.inference_mode():
             output_logits = self.model(input_tensor)
             output_tensor = torch.sigmoid(output_logits)
 
         # Convert output to numpy array
 
         output_array = output_tensor.squeeze().numpy() * 255
 
-        # Resize output to match input size
-        output_array = cv2.resize(output_array, (input_array.shape[2], input_array.shape[1]), interpolation=cv2.INTER_LINEAR)
+        # If scan conversion given, map image accordingly. Otherwise, resize output to match input size
+
+        if self.scanConversionDict is not None:
+            grid_x, grid_y = np.mgrid[0:input_array.shape[1], 0:input_array.shape[2]]
+            resized_output_array = griddata((self.x_cart.flatten(), self.y_cart.flatten()), output_array[1, :, :].flatten(),
+                                            (grid_x, grid_y), method="linear", fill_value=0)
+        else:
+            resized_output_array = cv2.resize(output_array, (input_array.shape[2], input_array.shape[1]), interpolation=cv2.INTER_LINEAR)
 
         # Set output volume image data
 
         outputVolume = parameterNode.GetNodeReference(self.OUTPUT_IMAGE)
-        slicer.util.updateVolumeFromArray(outputVolume, output_array.astype(np.uint8)[np.newaxis, ...])
+        slicer.util.updateVolumeFromArray(outputVolume, resized_output_array.astype(np.uint8)[np.newaxis, ...])



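Editor's note: for experimenting with the scan-conversion logic outside Slicer, here is a minimal standalone sketch of the same round trip. It assumes a config file with the keys read by loadScanConversion above (see scanconvert_config.yaml in the next file), uses a random array as a stand-in for an ultrasound frame, and replaces the model with an identity function, so it illustrates only the coordinate mapping, not the module's actual pipeline.

# Standalone sketch (illustrative, not part of the commit): scan-conversion
# round trip using the same scipy calls as onInputVolumeModified above.
import numpy as np
import yaml
from scipy.ndimage import map_coordinates
from scipy.interpolate import griddata

with open("scanconvert_config.yaml", "r") as f:  # assumed file name
    cfg = yaml.safe_load(f)

# Sample positions of each scan-line sample in the curvilinear image,
# computed exactly as in loadScanConversion.
theta, r = np.meshgrid(
    np.linspace(np.deg2rad(cfg["angle_min_degrees"]),
                np.deg2rad(cfg["angle_max_degrees"]),
                cfg["num_samples_along_lines"]),
    np.linspace(cfg["radius_start_pixels"], cfg["radius_end_pixels"], cfg["num_lines"]))
x_cart = r * np.cos(theta) + cfg["center_coordinate_pixel"][0]
y_cart = r * np.sin(theta) + cfg["center_coordinate_pixel"][1]

size = cfg["curvilinear_image_size"]
image = np.random.rand(size, size).astype(np.float32)  # stand-in for one frame

# Curvilinear image -> scan-line representation (what the model would see).
scanlines = map_coordinates(image, [x_cart, y_cart], order=1, mode="constant", cval=0.0)

prediction = scanlines  # identity placeholder for the model

# Scan-line representation -> back onto the cartesian pixel grid.
grid_x, grid_y = np.mgrid[0:size, 0:size]
restored = griddata((x_cart.flatten(), y_cart.flatten()), prediction.flatten(),
                    (grid_x, grid_y), method="linear", fill_value=0)
print(scanlines.shape, restored.shape)

Pixels outside the fan fall outside the convex hull of the sample points, so griddata fills them with fill_value 0, matching the black background of a scan-converted ultrasound image.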
8 changes: 4 additions & 4 deletions UltrasoundSegmentation/scanconvert_config.yaml
@@ -2,7 +2,7 @@ num_lines: !!int 128
 num_samples_along_lines: !!int 128
 curvilinear_image_size: !!int 512
 center_coordinate_pixel: [0, 256]
-radius_start_pixels: 95
-radius_end_pixels: 422
-angle_min_degrees: -37
-angle_max_degrees: 37
+radius_start_pixels: 100
+radius_end_pixels: 420
+angle_min_degrees: -36
+angle_max_degrees: 36
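A quick way to sanity-check the new values: with the fan apex at center_coordinate_pixel and the conventions used in loadScanConversion above (x = r*cos(theta) + center[0], y = r*sin(theta) + center[1]), the outermost fan corner must stay inside the 512-pixel curvilinear image. A minimal check, assuming those conventions:

import numpy as np

angle = np.deg2rad(36)   # angle_max_degrees
r_end = 420              # radius_end_pixels
center = [0, 256]        # center_coordinate_pixel

corner_x = center[0] + r_end * np.cos(angle)  # ~339.8
corner_y = center[1] + r_end * np.sin(angle)  # ~502.9
print(corner_x < 512 and corner_y < 512)      # True: the fan fits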
23 changes: 11 additions & 12 deletions UltrasoundSegmentation/test_dataloader.ipynb

Large diffs are not rendered by default.
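Finally, one way to exercise the new option by hand is from the Slicer Python console. The module access pattern below is the standard one for Slicer scripted modules; the yaml path is a placeholder:

# Hypothetical smoke test from the Slicer Python console.
import slicer
moduleWidget = slicer.modules.torchliveus.widgetRepresentation().self()
moduleWidget.ui.scanConversionPathLineEdit.currentPath = "/path/to/scanconvert_config.yaml"
moduleWidget.updateSettingsFromGUI()
print(moduleWidget.logic.scanConversionDict)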

