From d629e1abef2347e0b2ceb8a81150f9db66646922 Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Thu, 12 Dec 2024 07:04:48 -0500 Subject: [PATCH 1/4] Fix installation issues on Windows and improve GUI/logic separation --- MedSAMLite/MedSAMLite.py | 318 ++++++++++++-------------- MedSAMLite/Resources/UI/MedSAMLite.ui | 266 ++++++++++++++++++--- 2 files changed, 378 insertions(+), 206 deletions(-) diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index 99ce208..a35c42a 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -1,6 +1,7 @@ import logging import os from typing import Annotated, Optional +import pathlib import time import json @@ -16,7 +17,7 @@ import zipfile import slicer -from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode +from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode, vtkMRMLScalarVolumeNode from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin from slicer.parameterNodeWrapper import ( @@ -31,11 +32,6 @@ from PythonQt.QtCore import QTimer, QByteArray, Qt from PythonQt.QtGui import QIcon, QPixmap, QMessageBox -try: - import gdown - from medsam_interface import MedSAM_Interface # FIXME -except: - pass # no installation anymore, shorter plugin load MEDSAMLITE_VERSION = 'v0.13' @@ -124,12 +120,17 @@ def registerSampleData(): @parameterNodeWrapper class MedSAMLiteParameterNode: + volumeNode: vtkMRMLScalarVolumeNode + prepMethod: str = 'Manual' roiNode: vtkMRMLMarkupsROINode segmentationNode: vtkMRMLSegmentationNode - modelPath: str - prepMethod: str - prepWinLevel: float = 40.0 - prepWinWidth: float = 400.0 + modelPath: pathlib.Path + prepWinLevel: Annotated[float, WithinRange(-2000, 2000)] = 40.0 + prepWinWidth: Annotated[float, WithinRange(0, 2000)] = 400.0 + engine: str + submodel: str + speed: str + embeddingState: str # @@ -166,7 +167,6 @@ def setup(self) -> None: # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MedSAMLiteLogic() - self.logic.widget = self self.logic.server_dir = os.path.join(os.path.dirname(__file__), 'Resources/server_essentials') @@ -200,41 +200,39 @@ def setup(self) -> None: self.layout.addWidget(uiWidget) self.ui = slicer.util.childWidgetVariables(uiWidget) + # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's + # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's. + # "setMRMLScene(vtkMRMLScene*)" slot. + uiWidget.setMRMLScene(slicer.mrmlScene) + ############################################################################ # Model Selection - self.model_path_widget = self.ui.ctkPathModel - self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth') - self.logic.new_model_loaded = True + self.ui.ctkPathModel.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth') ############################################################################ ############################################################################ - # Segmentation Module - import qSlicerSegmentationsModuleWidgetsPythonQt - self.editor = qSlicerSegmentationsModuleWidgetsPythonQt.qMRMLSegmentEditorWidget() - self.editor.setMaximumNumberOfUndoStates(10) - self.selectParameterNode() - self.editor.setMRMLScene(slicer.mrmlScene) - self.ui.clbtnOperation.layout().addWidget(self.editor, 5, 0, 1, 2) + # Segmentation + self.ui.segmentationNodeSelector.currentNodeChanged.connect(self.segmentationNodeChanged) + self.ui.editor.setMaximumNumberOfUndoStates(10) + + # Select parameter set node if one is found in the scene, and create one otherwise + segmentEditorSingletonTag = "SegmentEditor" + segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode") + if segmentEditorNode is None: + segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode") + segmentEditorNode.UnRegister(None) + segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag) + segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode) + self.ui.editor.setMRMLSegmentEditorNode(segmentEditorNode) + ############################################################################ ########################################################################### # Volume load/close tracker - from PythonQt.qMRMLWidgets import qMRMLNodeComboBox - - self.qNodeSelect = qMRMLNodeComboBox() - self.qNodeSelect.addEnabled = False - self.qNodeSelect.removeEnabled = False - self.qNodeSelect.nodeTypes = ['vtkMRMLScalarVolumeNode'] - self.qNodeSelect.setMRMLScene(slicer.mrmlScene) - self.qNodeSelect.currentNodeChanged.connect(self.logic.volumeChanged) - self.logic.volumeChanged() - ########################################################################### + self.ui.qNodeSelect.currentNodeChanged.connect(self.volumeChanged) - # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's - # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's. - # "setMRMLScene(vtkMRMLScene*)" slot. - uiWidget.setMRMLScene(slicer.mrmlScene) + ########################################################################### # Connections @@ -244,14 +242,14 @@ def setup(self) -> None: # Buttons self.ui.pbUpgrade.setVisible(False) # it's gliching so let's hide it :D - self.ui.pbUpgrade.connect('clicked(bool)', lambda: self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...')) - self.ui.pbSendImage.connect('clicked(bool)', lambda: self.logic.sendImage(partial=False)) - self.ui.pbSegment.connect('clicked(bool)', lambda: self.logic.applySegmentation()) + self.ui.pbUpgrade.connect('clicked(bool)', self.upgrade) + self.ui.pbSendImage.connect('clicked(bool)', self.sendImage) + self.ui.pbSegment.connect('clicked(bool)', self.applySegmentation) # Preprocessing self.ui.cmbPrepOptions.addItems(['Manual', 'Abdominal CT', 'Lung CT', 'Brain CT', 'Mediastinum CT', 'MR']) - self.ui.cmbPrepOptions.currentTextChanged.connect(lambda new_text: self.setManualPreprocessVis(new_text == 'Manual')) - self.ui.pbApplyPrep.connect('clicked(bool)', lambda: self.logic.applyPreprocess(self.ui.cmbPrepOptions.currentText, self.ui.sldWinLevel.value, self.ui.sldWinWidth.value)) + self.ui.cmbPrepOptions.currentTextChanged.connect(self.setPrepMethod) + self.ui.pbApplyPrep.connect('clicked(bool)', self.preprocess) # Hide unnecessary ROI controls self.ui.widgetROI.findChild("QLabel", "label").hide() @@ -259,7 +257,7 @@ def setup(self) -> None: self.ui.widgetROI.findChild("QLabel", "label_10").hide() self.ui.widgetROI.findChild("QComboBox", "roiTypeComboBox").hide() - # Segmentation Engine + # Segmentation Engine self.engine_list = [ { 'name': 'Classic MedSAM', @@ -318,18 +316,50 @@ def setup(self) -> None: # Segmentation Speed self.ui.cmbSpeed.addItems(['Normal Speed - Highest Quality', 'Faster Segmentation - High Quality', 'Fastest Segmentation - Moderate Quality']) - + self.ui.cmbSpeed.currentTextChanged.connect(self.setSpeed) + self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI()) self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D()) self.ui.pbLowerSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=True)) self.ui.pbUpperSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=False)) - self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True)) - # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() self.newEngineSelected('Classic MedSAM') - + + self.updateGUIFromParameterNode() + + def upgrade(self): + self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...') + + def sendImage(self): + self.logic.sendImage(partial=False) + + def applySegmentation(self): + if not self._parameterNode.segmentationNode: + self._parameterNode.segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") + self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) + + self.logic.applySegmentation() + + def volumeChanged(self, volumeNode): + self._parameterNode.embeddingState = 'SINGLE' + if self._parameterNode.segmentationNode and volumeNode: + self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(volumeNode) + + def segmentationNodeChanged(self, segmentationNode): + if segmentationNode and self._parameterNode.volumeNode: + segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) + + def setPrepMethod(self, method): + self._parameterNode.prepMethod = method + + def setSpeed(self, speed): + self._parameterNode.speed = speed + + def preprocess(self): + self.logic.applyPreprocess(self._parameterNode.prepMethod, self._parameterNode.prepWinLevel, self._parameterNode.prepWinWidth) + def setManualPreprocessVis(self, visible): self.ui.lblLevel.setVisible(visible) self.ui.lblWidth.setVisible(visible) @@ -337,9 +367,8 @@ def setManualPreprocessVis(self, visible): self.ui.sldWinWidth.setVisible(visible) def newEngineSelected(self, new_engine): + self._parameterNode.engine = new_engine current_engine = list(filter(lambda x: x['name'] == new_engine, self.engine_list))[0] - # inform logic object - self.logic.new_model_loaded = True # load list of submodels self.dont_invoke_submodel_change = True # prevent onchange event to happen self.ui.cmbSubModel.clear() @@ -352,8 +381,7 @@ def newEngineSelected(self, new_engine): ctrl.setVisible(False) # change engine-related paths - self.model_path_widget.currentPath = current_engine['default checkpoint'] - self.updateAllParameters() + self._parameterNode.modelPath = pathlib.Path(current_engine['default checkpoint']) # if there is a submodel, choose the first one if len(current_engine['submodels']) > 0: @@ -364,73 +392,55 @@ def newEngineSelected(self, new_engine): self.logic.download_if_necessary(current_engine['url'], current_engine['default checkpoint']) def newSubmodelSelected(self, new_submodel): + self._parameterNode.submodel = new_submodel if self.dont_invoke_submodel_change: return - current_submodel = list(filter(lambda x: x['name'] == self.ui.cmbEngine.currentText, self.engine_list))[0]['submodels'][new_submodel] - # inform logic object - self.logic.new_model_loaded = True + self._parameterNode.submodel = new_submodel + current_submodel = list(filter(lambda x: x['name'] == new_submodel, self.engine_list))[0]['submodels'][new_submodel] # change submodel-related paths - self.model_path_widget.currentPath = current_submodel['checkpoint'] - self.updateAllParameters() + self._parameterNode.modelPath = pathlib.Path(current_submodel['checkpoint']) # download checkpoints if necessary self.logic.download_if_necessary(current_submodel['url'], current_submodel['checkpoint']) - def selectParameterNode(self): - # Select parameter set node if one is found in the scene, and create one otherwise - segmentEditorSingletonTag = "SegmentEditor" - segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode") - if segmentEditorNode is None: - segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode") - segmentEditorNode.UnRegister(None) - segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag) - segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode) - if self.parameterSetNode == segmentEditorNode: - # nothing changed - return - self.parameterSetNode = segmentEditorNode - self.editor.setMRMLSegmentEditorNode(self.parameterSetNode) - - def _createAndAttachROI(self): # Make sure there is only one 'R' - if self.logic.volume_node is None: - self.logic.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0] - volumeNode = self.logic.volume_node + if self._parameterNode.volumeNode is None: + slicer.util.errorDisplay("Select a source volume first") + return # Create a new ROI that will be fit to volumeNode - roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLMarkupsROINode") + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") cropVolumeParameters = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLCropVolumeParametersNode") - cropVolumeParameters.SetInputVolumeNodeID(volumeNode.GetID()) - cropVolumeParameters.SetROINodeID(roiNode.GetID()) + cropVolumeParameters.SetInputVolumeNodeID(self._parameterNode.volumeNode.GetID()) + cropVolumeParameters.SetROINodeID(self._parameterNode.roiNode.GetID()) slicer.modules.cropvolume.logic().SnapROIToVoxelGrid(cropVolumeParameters) # optional (rotates the ROI to match the volume axis directions) slicer.modules.cropvolume.logic().FitROIToInputVolume(cropVolumeParameters) slicer.mrmlScene.RemoveNode(cropVolumeParameters) self.scaleROI(.85) - self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R")) - self.updateAllParameters() - - def scaleROI(self, ratio): # Make sure there is exactly one 'R' - roiNode = slicer.util.getNode('R') + roiNode = self._parameterNode.roiNode roi_size = roiNode.GetSize() roiNode.SetSize(int(roi_size[0] * ratio), int(roi_size[1] * ratio), int(roi_size[2] * ratio)) def makeROI2D(self): # Make sure there is exactly one 'R' - roiNode = slicer.util.getNode('R') + roiNode = self._parameterNode.roiNode roi_size = roiNode.GetSize() roiNode.SetSize(roi_size[0], roi_size[1], 1) roi_center = np.array(roiNode.GetCenter()) roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()]) def setROIboundary(self, lower): - roiNode = slicer.util.getNode('R') + roiNode = self._parameterNode.roiNode bounds = np.zeros(6) roiNode.GetBounds(bounds) @@ -505,6 +515,12 @@ def initializeParameterNode(self) -> None: self.setParameterNode(self.logic.getParameterNode()) + # Select default input nodes if nothing is selected yet to save a few clicks for the user + if not self._parameterNode.volumeNode: + firstVolumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode") + if firstVolumeNode: + self._parameterNode.volumeNode = firstVolumeNode + def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]) -> None: """ Set and observe parameter node. @@ -517,42 +533,33 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode] if self._parameterNode: # Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each # ui element that needs connection. - try: - self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui) - except: - pass #this part might be invoked before UI is loaded. does not cause any issues but might be confusing for users - self.renderAllParameters() + self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui) + self.addObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode) + self.updateGUIFromParameterNode() - def renderAllParameters(self): - if self._parameterNode.modelPath: - self.model_path_widget.currentPath = self._parameterNode.modelPath - if self._parameterNode.roiNode: - self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R")) + def updateGUIFromParameterNode(self, unused1=None, unused2=None): + self.ui.editor.setSegmentationNode(self._parameterNode.segmentationNode) if self._parameterNode.segmentationNode: - self.logic.segment_res_group = self._parameterNode.segmentationNode - + self.ui.editor.setSourceVolumeNode(self._parameterNode.volumeNode) + self.ui.widgetROI.setMRMLMarkupsNode(self._parameterNode.roiNode) - if self._parameterNode.prepMethod: - self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod - if self._parameterNode.prepWinLevel: - self.ui.sldWinLevel.value = self._parameterNode.prepWinLevel - if self._parameterNode.prepWinWidth: - self.ui.sldWinWidth.value = self._parameterNode.prepWinWidth - + if self._parameterNode.embeddingState == 'IN_PROGRESS': + self.ui.pbSegment.setEnabled(False) + self.ui.pbSegment.setText('Processing, please wait...') + else: + self.ui.pbSegment.setEnabled(True) + if self._parameterNode.embeddingState == 'SINGLE': + self.ui.pbSegment.setText('Segment single') + elif self._parameterNode.embeddingState == 'FULL': + self.ui.pbSegment.setText('Segment') - def updateAllParameters(self): - self._parameterNode.modelPath = self.model_path_widget.currentPath - try: - self._parameterNode.roiNode = slicer.util.getNode('R') - except: - pass - self._parameterNode.segmentationNode = self.logic.segment_res_group - self._parameterNode.prepMethod = self.ui.cmbPrepOptions.currentText - self._parameterNode.prepWinLevel = self.ui.sldWinLevel.value - self._parameterNode.prepWinWidth = self.ui.sldWinWidth.value - + self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod + self.ui.cmbEngine.currentText = self._parameterNode.engine + self.ui.cmbSubModel.currentText = self._parameterNode.submodel + self.ui.cmbSpeed.currentText = self._parameterNode.speed + self.setManualPreprocessVis(self._parameterNode.prepMethod == 'Manual') # # MedSAMLiteLogic @@ -568,15 +575,11 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic): https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py """ image_data = None - segment_res_group = None - server_ready = False server_process = None - volume_node = None timer = None progressbar = None server_dir = None - widget = None - new_model_loaded = True + loaded_model_path = None backend = None test_mode = False @@ -586,6 +589,7 @@ def __init__(self) -> None: """ ScriptedLoadableModuleLogic.__init__(self) try: # In case the dependencies are not installed, an error will raise + from medsam_interface import MedSAM_Interface # FIXME self.backend = MedSAM_Interface() except: pass @@ -720,6 +724,7 @@ def run_on_background(self, target, args, title, progress_check=None): self.progressbar.close() def download_model(self, url, model_path, event): + import gdown gdown.download_folder(url=url, output=model_path) event.set() @@ -733,12 +738,11 @@ def download_if_necessary(self, model_url, model_path): def run_server(self): #FIXME show that 'Backend is loading...' - self.widget.updateAllParameters() - self.backend.set_engine(self.widget.ui.cmbEngine.currentText) - self.widget.renderAllParameters() - self.backend.MedSAM_CKPT_PATH = self.widget.model_path_widget.currentPath + self.backend.set_engine(self.getParameterNode().engine) + modelPath = self.getParameterNode().modelPath + self.backend.MedSAM_CKPT_PATH = modelPath self.backend.load_model() - self.server_ready = True + self.loaded_model_path = modelPath def progressCheck(self, partial=False): slicer.app.processEvents() @@ -748,32 +752,28 @@ def progressCheck(self, partial=False): if progress_data['layers'] <= progress_data['generated_embeds']: self.progressbar.close() self.timer.stop() - self.widget.ui.pbSegment.setEnabled(True) if partial: segmentation_mask = self.inferSegmentation() self.showSegmentation(segmentation_mask) - self.widget.ui.pbSegment.setText('Single Segmentation') + self.getParameterNode().embeddingState = 'SINGLE' else: - self.widget.ui.pbSegment.setText('Segmentation') - - def volumeChanged(self, node=None): - self.widget.ui.pbSegment.setText('Single Segmentation') + self.getParameterNode().embeddingState = 'FULL' + def _getVolumeNode(self): + return self.getParameterNode().volumeNode def captureImage(self): ######## Set your image path here - self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0] - self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node? + self.image_data = slicer.util.arrayFromVolume(self._getVolumeNode()) if len(self.image_data.shape) == 4 and self.image_data.shape[-1] == 4: # colored image, it can have 4 channels (r,g,b,a) so we remove the last one self.image_data = self.image_data[:,:,:,:3] def sendImage(self, partial=False): - self.widget.ui.pbSegment.setEnabled(False) - self.widget.ui.pbSegment.setText('Sending image, please wait...') + self.getParameterNode().embeddingState = 'IN_PROGRESS' - if self.new_model_loaded or not self.server_ready: + if self.loaded_model_path != self.getParameterNode().modelPath: self.run_server() - self.new_model_loaded = False + ############ Partial segmentation if partial: @@ -793,10 +793,10 @@ def sendImage(self, partial=False): self.timer.timeout.connect(lambda: self.progressCheck(partial)) self.timer.start(1000) - self.backend.speed_level = 1 if 'Normal' in self.widget.ui.cmbSpeed.currentText else 2 if 'Faster' in self.widget.ui.cmbSpeed.currentText else 3 + speed = self.getParameterNode().speed + self.backend.speed_level = 1 if 'Normal' in speed else 2 if 'Faster' in speed else 3 self.backend.set_image(self.image_data, -160, 240, zmin, zmax, recurrent_func=slicer.app.processEvents) - self.widget.updateAllParameters() # self.run_on_background(self.embedding_prep_wrapper, (self.image_data, -160, 240, zmin, zmax), "Preparing image embeddings...", lambda: self.progressCheck(partial)) def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event): @@ -808,7 +808,7 @@ def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event): def inferSegmentation(self): print('sending infer request...') ################ DEBUG MODE ################ - if self.volume_node is None: + if self._getVolumeNode() is None: self.captureImage() ################ DEBUG MODE ################ @@ -823,37 +823,18 @@ def inferSegmentation(self): def showSegmentation(self, segmentation_mask): segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time()))) - slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask) - - current_seg_group = self.widget.editor.segmentationNode() - if current_seg_group is None: - if self.segment_res_group is None: - self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") - self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - current_seg_group = self.segment_res_group - - try: - check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error - except: - self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") - self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - current_seg_group = self.segment_res_group - - - current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group) - slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node) - + segmentation_mask_int = segmentation_mask.astype(np.int16) + slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask_int) + slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask_int, self.getParameterNode().segmentationNode, segment_volume.GetName(), self._getVolumeNode()) slicer.mrmlScene.RemoveNode(segment_volume) - self.widget.updateAllParameters() def singleSegmentation(self): self.sendImage(partial=True) def applySegmentation(self): - if self.widget.ui.pbSegment.text == 'Single Segmentation': + if self.getParameterNode().embeddingState == 'SINGLE': continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Send Image' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No) if continueSingle == QMessageBox.No: return self.singleSegmentation() @@ -868,7 +849,7 @@ def get_bounding_box(self): # If volume node is transformed, apply that transform to get volume's RAS coordinates transformRasToVolumeRas = vtk.vtkGeneralTransform() - slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self.volume_node.GetParentTransformNode(), transformRasToVolumeRas) + slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self._getVolumeNode().GetParentTransformNode(), transformRasToVolumeRas) bounds = np.zeros(6) roiNode.GetBounds(bounds) @@ -883,7 +864,7 @@ def get_bounding_box(self): # Get voxel coordinates from physical coordinates volumeRasToIjk = vtk.vtkMatrix4x4() - self.volume_node.GetRASToIJKMatrix(volumeRasToIjk) + self._getVolumeNode().GetRASToIJKMatrix(volumeRasToIjk) point_Ijk = [0, 0, 0, 1] volumeRasToIjk.MultiplyPoint(np.append(point_VolumeRas,1.0), point_Ijk) point_Ijk = [ int(round(c)) for c in point_Ijk[0:3] ] @@ -906,8 +887,8 @@ def preprocess_CT(self, win_level=40.0, win_width=400.0): image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 image_data_pre = np.uint8(image_data_pre) - self.volume_node.GetDisplayNode().SetAutoWindowLevel(False) - self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255) + self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False) + self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255) return image_data_pre @@ -919,14 +900,14 @@ def preprocess_MR(self, lower_percent=0.5, upper_percent=99.5): image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 image_data_pre = np.uint8(image_data_pre) - self.volume_node.GetDisplayNode().SetAutoWindowLevel(False) - self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255) + self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False) + self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255) return image_data_pre def updateImage(self, new_image): self.image_data[:,:,:] = new_image - slicer.util.arrayFromVolumeModified(self.volume_node) + slicer.util.arrayFromVolumeModified(self._getVolumeNode()) def applyPreprocess(self, method, win_level, win_width): if method == 'MR': @@ -945,7 +926,6 @@ def applyPreprocess(self, method, win_level, win_width): self.updateImage(prep_img) - self.widget.updateAllParameters() # diff --git a/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAMLite/Resources/UI/MedSAMLite.ui index 9237051..cba498e 100644 --- a/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -6,8 +6,8 @@ 0 0 - 739 - 1206 + 428 + 864 @@ -24,7 +24,14 @@ Prepare Data - + + + + Apply + + + + 1 @@ -38,9 +45,26 @@ 40.000000000000000 + + prepWinLevel + - + + + + Preprocessing Options: + + + + + + + Window Level: + + + + 1 @@ -51,36 +75,53 @@ 400.000000000000000 + + prepWinWidth + - - - - - - - Window Level: + + + + - + Window Width: - - - - Preprocessing Options: + + + + + vtkMRMLScalarVolumeNode + + + + + + + false + + + false + + + + + + volumeNode - - + + - Apply + Source Volume: @@ -133,7 +174,7 @@ Start Segmentation - + MedSAM Model: @@ -141,13 +182,40 @@ + + + Engine: + + + + + + + false + + + false + + + + + + + + + + Quality: + + + + Submodel: - + ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable @@ -160,41 +228,90 @@ *.ckpt + + modelPath + - + + + + Preprocess + + + + - + + + + - - + + - Send Image + Segment - - - - Select Engine: + + + + + vtkMRMLScalarVolumeNode + + + + + + + true + + + false + + + true + + + true + + + + + + (Create new) + + + segmentationNode - - + + - Segmentation + Output segmentation: - - - + + + + Qt::Vertical + + + + 20 + 40 + + + + @@ -214,6 +331,11 @@ QWidget
ctkSliderWidget.h
+ + qMRMLNodeComboBox + QWidget +
qMRMLNodeComboBox.h
+
qMRMLWidget QWidget @@ -225,7 +347,77 @@ QWidget
qMRMLMarkupsROIWidget.h
+ + qMRMLSegmentEditorWidget + qMRMLWidget +
qMRMLSegmentEditorWidget.h
+
- + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + qNodeSelect + setMRMLScene(vtkMRMLScene*) + + + 423 + 859 + + + 340 + 87 + + + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + editor + setMRMLScene(vtkMRMLScene*) + + + 312 + 859 + + + 87 + 729 + + + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + segmentationNodeSelector + setMRMLScene(vtkMRMLScene*) + + + 368 + 857 + + + 252 + 561 + + + + + segmentationNodeSelector + currentNodeChanged(vtkMRMLNode*) + editor + setSegmentationNode(vtkMRMLNode*) + + + 216 + 557 + + + 236 + 726 + + + + From 501ea3c0aecea0fb22adab00458e897b5c961203 Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Thu, 12 Dec 2024 12:42:46 -0500 Subject: [PATCH 2/4] Simplify GUI by hiding not frequently needed options --- MedSAMLite/MedSAMLite.py | 43 ++- MedSAMLite/Resources/UI/MedSAMLite.ui | 422 +++++++++++++++----------- 2 files changed, 286 insertions(+), 179 deletions(-) diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index a35c42a..48220f0 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -210,9 +210,11 @@ def setup(self) -> None: self.ui.ctkPathModel.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth') ############################################################################ - ############################################################################ # Segmentation + self.ui.roiOptionsFrame.setVisible(False) + self.ui.segmentationOptionsFrame.setVisible(False) + self.ui.segmentationNodeSelector.currentNodeChanged.connect(self.segmentationNodeChanged) self.ui.editor.setMaximumNumberOfUndoStates(10) @@ -319,6 +321,7 @@ def setup(self) -> None: self.ui.cmbSpeed.currentTextChanged.connect(self.setSpeed) self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI()) + self.ui.pbPlaceROI.connect('clicked(bool)', lambda: self._placeROI()) self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D()) self.ui.pbLowerSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=True)) self.ui.pbUpperSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=False)) @@ -403,6 +406,13 @@ def newSubmodelSelected(self, new_submodel): # download checkpoints if necessary self.logic.download_if_necessary(current_submodel['url'], current_submodel['checkpoint']) + def _getOrCreateROI(self): + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLMarkupsROINode") + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") + return self._parameterNode.roiNode + def _createAndAttachROI(self): # Make sure there is only one 'R' if self._parameterNode.volumeNode is None: @@ -410,20 +420,29 @@ def _createAndAttachROI(self): return # Create a new ROI that will be fit to volumeNode - if not self._parameterNode.roiNode: - self._parameterNode.roiNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLMarkupsROINode") - if not self._parameterNode.roiNode: - self._parameterNode.roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") + roiNode = self._getOrCreateROI() cropVolumeParameters = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLCropVolumeParametersNode") cropVolumeParameters.SetInputVolumeNodeID(self._parameterNode.volumeNode.GetID()) - cropVolumeParameters.SetROINodeID(self._parameterNode.roiNode.GetID()) + cropVolumeParameters.SetROINodeID(roiNode.GetID()) slicer.modules.cropvolume.logic().SnapROIToVoxelGrid(cropVolumeParameters) # optional (rotates the ROI to match the volume axis directions) slicer.modules.cropvolume.logic().FitROIToInputVolume(cropVolumeParameters) slicer.mrmlScene.RemoveNode(cropVolumeParameters) self.scaleROI(.85) + def _placeROI(self): + # Make sure there is exactly one 'R' + roiNode = self._getOrCreateROI() + roiNode.RemoveAllControlPoints() + + selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton") + selectionNode.SetReferenceActivePlaceNodeClassName("vtkMRMLMarkupsROINode") + selectionNode.SetReferenceActivePlaceNodeID(roiNode.GetID()) + + interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton") + interactionNode.SetPlaceModePersistence(False) # stop placement mode after placing one ROI + interactionNode.SetCurrentInteractionMode(slicer.vtkMRMLInteractionNode().Place) def scaleROI(self, ratio): # Make sure there is exactly one 'R' @@ -433,14 +452,14 @@ def scaleROI(self, ratio): def makeROI2D(self): # Make sure there is exactly one 'R' - roiNode = self._parameterNode.roiNode + roiNode = self._getOrCreateROI() roi_size = roiNode.GetSize() roiNode.SetSize(roi_size[0], roi_size[1], 1) roi_center = np.array(roiNode.GetCenter()) roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()]) def setROIboundary(self, lower): - roiNode = self._parameterNode.roiNode + roiNode = roiNode = self._getOrCreateROI() bounds = np.zeros(6) roiNode.GetBounds(bounds) @@ -539,6 +558,12 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode] def updateGUIFromParameterNode(self, unused1=None, unused2=None): + if self._parameterNode is None: + self.ui.editor.setSegmentationNode(None) + self.ui.widgetROI.setMRMLMarkupsNode(None) + self.ui.pbSegment.setEnabled(False) + return + self.ui.editor.setSegmentationNode(self._parameterNode.segmentationNode) if self._parameterNode.segmentationNode: self.ui.editor.setSourceVolumeNode(self._parameterNode.volumeNode) @@ -548,7 +573,7 @@ def updateGUIFromParameterNode(self, unused1=None, unused2=None): self.ui.pbSegment.setEnabled(False) self.ui.pbSegment.setText('Processing, please wait...') else: - self.ui.pbSegment.setEnabled(True) + self.ui.pbSegment.setEnabled((self._parameterNode.volumeNode is not None) and (self._parameterNode.roiNode is not None)) if self._parameterNode.embeddingState == 'SINGLE': self.ui.pbSegment.setText('Segment single') elif self._parameterNode.embeddingState == 'FULL': diff --git a/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAMLite/Resources/UI/MedSAMLite.ui index cba498e..5b3f9e6 100644 --- a/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -6,19 +6,50 @@ 0 0 - 428 - 864 + 474 + 1012 - - + + Upgrade Module - + + + + Source Volume: + + + + + + + + vtkMRMLScalarVolumeNode + + + + + + + false + + + false + + + + + + volumeNode + + + + Prepare Data @@ -94,101 +125,16 @@ - - - - - vtkMRMLScalarVolumeNode - - - - - - - false - - - false - - - - - - volumeNode - - - - - - - Source Volume: - - - - - - - - - - Select the Region of Interest - - - - - - Attach ROI - - - - - - - 2D Selection - - - - - - - Set Current Frame As Selection's Start - - - - - - - Set Current Frame As Selection's End - - - - - - - + - Start Segmentation + Segmentation - - - - MedSAM Model: - - - - - - - Engine: - - - - + false @@ -201,105 +147,204 @@ - - - - Quality: + + + + Qt::Vertical - - - - Submodel: + + + + + 1 + 0 + - - - - - - ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable - - - - *.pth - *.xml - *.onnx - *.ckpt - - - - modelPath - - - - - - Preprocess + Place ROI - - - - - - - - - - + Segment - - - - - vtkMRMLScalarVolumeNode - - - - - - - true - - - false - - - true - - - true - - - - - - (Create new) + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + + Fit to volume + + + + + + + End at current slice + + + + + + + Start at current slice + + + + + + + Fit to current slice + + + + + + + + + + + + + Precompute segmentation for the entire image for faster segmentation - - segmentationNode + + Precompute - - - - Output segmentation: + + + + Qt::Vertical + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + + Output: + + + + + + + + vtkMRMLScalarVolumeNode + + + + + + + true + + + false + + + true + + + true + + + + + + (Create new) + + + segmentationNode + + + + + + + Engine: + + + + + + + Submodel: + + + + + + + MedSAM Model: + + + + + + + Quality: + + + + + + + + + + + + + ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable + + + + *.pth + *.xml + *.onnx + *.ckpt + + + + modelPath + + + + + + + + + - + Qt::Vertical @@ -321,6 +366,11 @@
ctkCollapsibleButton.h
1 + + ctkExpandButton + QToolButton +
ctkExpandButton.h
+
ctkPathLineEdit QWidget @@ -366,8 +416,8 @@ 859 - 340 - 87 + 432 + 58 @@ -419,5 +469,37 @@ + + Button + toggled(bool) + roiOptionsFrame + setVisible(bool) + + + 657 + 261 + + + 656 + 296 + + + + + Button_2 + toggled(bool) + segmentationOptionsFrame + setVisible(bool) + + + 660 + 529 + + + 657 + 574 + + +
From ec32dc611cd888eb5920c3133ac80390638116bd Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Thu, 12 Dec 2024 12:52:09 -0500 Subject: [PATCH 3/4] Do not overwrite the input volume --- MedSAMLite/MedSAMLite.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index 48220f0..ec30181 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -343,6 +343,13 @@ def applySegmentation(self): self._parameterNode.segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) + if self._parameterNode.embeddingState == 'SINGLE': + continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Precompute' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No) + if continueSingle == QMessageBox.No: + return + self.logic.singleSegmentation() + return + self.logic.applySegmentation() def volumeChanged(self, volumeNode): @@ -789,7 +796,7 @@ def _getVolumeNode(self): def captureImage(self): ######## Set your image path here - self.image_data = slicer.util.arrayFromVolume(self._getVolumeNode()) + self.image_data = slicer.util.arrayFromVolume(self._getVolumeNode()).copy() if len(self.image_data.shape) == 4 and self.image_data.shape[-1] == 4: # colored image, it can have 4 channels (r,g,b,a) so we remove the last one self.image_data = self.image_data[:,:,:,:3] @@ -859,12 +866,6 @@ def singleSegmentation(self): def applySegmentation(self): - if self.getParameterNode().embeddingState == 'SINGLE': - continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Send Image' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No) - if continueSingle == QMessageBox.No: return - self.singleSegmentation() - - return segmentation_mask = self.inferSegmentation() self.showSegmentation(segmentation_mask) From da174c1ec2028d5e57060c6944cff2777bfc9b39 Mon Sep 17 00:00:00 2001 From: Andras Lasso Date: Thu, 12 Dec 2024 13:00:00 -0500 Subject: [PATCH 4/4] Automatically preprocess The slight reduction in computation time is not worth the extra complexity for the user. --- MedSAMLite/MedSAMLite.py | 19 +++---------------- MedSAMLite/Resources/UI/MedSAMLite.ui | 24 +++++++----------------- 2 files changed, 10 insertions(+), 33 deletions(-) diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index ec30181..83d85cb 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -245,7 +245,6 @@ def setup(self) -> None: # Buttons self.ui.pbUpgrade.setVisible(False) # it's gliching so let's hide it :D self.ui.pbUpgrade.connect('clicked(bool)', self.upgrade) - self.ui.pbSendImage.connect('clicked(bool)', self.sendImage) self.ui.pbSegment.connect('clicked(bool)', self.applySegmentation) # Preprocessing @@ -333,10 +332,7 @@ def setup(self) -> None: self.updateGUIFromParameterNode() def upgrade(self): - self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...') - - def sendImage(self): - self.logic.sendImage(partial=False) + self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...') def applySegmentation(self): if not self._parameterNode.segmentationNode: @@ -344,11 +340,7 @@ def applySegmentation(self): self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) if self._parameterNode.embeddingState == 'SINGLE': - continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Precompute' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No) - if continueSingle == QMessageBox.No: - return - self.logic.singleSegmentation() - return + self.logic.sendImage(partial=False) self.logic.applySegmentation() @@ -582,7 +574,7 @@ def updateGUIFromParameterNode(self, unused1=None, unused2=None): else: self.ui.pbSegment.setEnabled((self._parameterNode.volumeNode is not None) and (self._parameterNode.roiNode is not None)) if self._parameterNode.embeddingState == 'SINGLE': - self.ui.pbSegment.setText('Segment single') + self.ui.pbSegment.setText('Preprocess and Segment') elif self._parameterNode.embeddingState == 'FULL': self.ui.pbSegment.setText('Segment') @@ -859,11 +851,6 @@ def showSegmentation(self, segmentation_mask): slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask_int) slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask_int, self.getParameterNode().segmentationNode, segment_volume.GetName(), self._getVolumeNode()) slicer.mrmlScene.RemoveNode(segment_volume) - - - def singleSegmentation(self): - self.sendImage(partial=True) - def applySegmentation(self): segmentation_mask = self.inferSegmentation() diff --git a/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAMLite/Resources/UI/MedSAMLite.ui index 5b3f9e6..44861b3 100644 --- a/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -167,13 +167,6 @@ - - - - Segment - - - @@ -217,16 +210,6 @@
- - - - Precompute segmentation for the entire image for faster segmentation - - - Precompute - - - @@ -341,6 +324,13 @@ + + + + Segment + + +