diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index 99ce208..83d85cb 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -1,6 +1,7 @@ import logging import os from typing import Annotated, Optional +import pathlib import time import json @@ -16,7 +17,7 @@ import zipfile import slicer -from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode +from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode, vtkMRMLScalarVolumeNode from slicer.ScriptedLoadableModule import * from slicer.util import VTKObservationMixin from slicer.parameterNodeWrapper import ( @@ -31,11 +32,6 @@ from PythonQt.QtCore import QTimer, QByteArray, Qt from PythonQt.QtGui import QIcon, QPixmap, QMessageBox -try: - import gdown - from medsam_interface import MedSAM_Interface # FIXME -except: - pass # no installation anymore, shorter plugin load MEDSAMLITE_VERSION = 'v0.13' @@ -124,12 +120,17 @@ def registerSampleData(): @parameterNodeWrapper class MedSAMLiteParameterNode: + volumeNode: vtkMRMLScalarVolumeNode + prepMethod: str = 'Manual' roiNode: vtkMRMLMarkupsROINode segmentationNode: vtkMRMLSegmentationNode - modelPath: str - prepMethod: str - prepWinLevel: float = 40.0 - prepWinWidth: float = 400.0 + modelPath: pathlib.Path + prepWinLevel: Annotated[float, WithinRange(-2000, 2000)] = 40.0 + prepWinWidth: Annotated[float, WithinRange(0, 2000)] = 400.0 + engine: str + submodel: str + speed: str + embeddingState: str # @@ -166,7 +167,6 @@ def setup(self) -> None: # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MedSAMLiteLogic() - self.logic.widget = self self.logic.server_dir = os.path.join(os.path.dirname(__file__), 'Resources/server_essentials') @@ -200,41 +200,41 @@ def setup(self) -> None: self.layout.addWidget(uiWidget) self.ui = slicer.util.childWidgetVariables(uiWidget) + # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's + # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's. + # "setMRMLScene(vtkMRMLScene*)" slot. + uiWidget.setMRMLScene(slicer.mrmlScene) + ############################################################################ # Model Selection - self.model_path_widget = self.ui.ctkPathModel - self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth') - self.logic.new_model_loaded = True + self.ui.ctkPathModel.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth') ############################################################################ - ############################################################################ - # Segmentation Module - import qSlicerSegmentationsModuleWidgetsPythonQt - self.editor = qSlicerSegmentationsModuleWidgetsPythonQt.qMRMLSegmentEditorWidget() - self.editor.setMaximumNumberOfUndoStates(10) - self.selectParameterNode() - self.editor.setMRMLScene(slicer.mrmlScene) - self.ui.clbtnOperation.layout().addWidget(self.editor, 5, 0, 1, 2) + # Segmentation + self.ui.roiOptionsFrame.setVisible(False) + self.ui.segmentationOptionsFrame.setVisible(False) + + self.ui.segmentationNodeSelector.currentNodeChanged.connect(self.segmentationNodeChanged) + self.ui.editor.setMaximumNumberOfUndoStates(10) + + # Select parameter set node if one is found in the scene, and create one otherwise + segmentEditorSingletonTag = "SegmentEditor" + segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode") + if segmentEditorNode is None: + segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode") + segmentEditorNode.UnRegister(None) + segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag) + segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode) + self.ui.editor.setMRMLSegmentEditorNode(segmentEditorNode) + ############################################################################ ########################################################################### # Volume load/close tracker - from PythonQt.qMRMLWidgets import qMRMLNodeComboBox - - self.qNodeSelect = qMRMLNodeComboBox() - self.qNodeSelect.addEnabled = False - self.qNodeSelect.removeEnabled = False - self.qNodeSelect.nodeTypes = ['vtkMRMLScalarVolumeNode'] - self.qNodeSelect.setMRMLScene(slicer.mrmlScene) - self.qNodeSelect.currentNodeChanged.connect(self.logic.volumeChanged) - self.logic.volumeChanged() - ########################################################################### + self.ui.qNodeSelect.currentNodeChanged.connect(self.volumeChanged) - # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's - # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's. - # "setMRMLScene(vtkMRMLScene*)" slot. - uiWidget.setMRMLScene(slicer.mrmlScene) + ########################################################################### # Connections @@ -244,14 +244,13 @@ def setup(self) -> None: # Buttons self.ui.pbUpgrade.setVisible(False) # it's gliching so let's hide it :D - self.ui.pbUpgrade.connect('clicked(bool)', lambda: self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...')) - self.ui.pbSendImage.connect('clicked(bool)', lambda: self.logic.sendImage(partial=False)) - self.ui.pbSegment.connect('clicked(bool)', lambda: self.logic.applySegmentation()) + self.ui.pbUpgrade.connect('clicked(bool)', self.upgrade) + self.ui.pbSegment.connect('clicked(bool)', self.applySegmentation) # Preprocessing self.ui.cmbPrepOptions.addItems(['Manual', 'Abdominal CT', 'Lung CT', 'Brain CT', 'Mediastinum CT', 'MR']) - self.ui.cmbPrepOptions.currentTextChanged.connect(lambda new_text: self.setManualPreprocessVis(new_text == 'Manual')) - self.ui.pbApplyPrep.connect('clicked(bool)', lambda: self.logic.applyPreprocess(self.ui.cmbPrepOptions.currentText, self.ui.sldWinLevel.value, self.ui.sldWinWidth.value)) + self.ui.cmbPrepOptions.currentTextChanged.connect(self.setPrepMethod) + self.ui.pbApplyPrep.connect('clicked(bool)', self.preprocess) # Hide unnecessary ROI controls self.ui.widgetROI.findChild("QLabel", "label").hide() @@ -259,7 +258,7 @@ def setup(self) -> None: self.ui.widgetROI.findChild("QLabel", "label_10").hide() self.ui.widgetROI.findChild("QComboBox", "roiTypeComboBox").hide() - # Segmentation Engine + # Segmentation Engine self.engine_list = [ { 'name': 'Classic MedSAM', @@ -318,18 +317,51 @@ def setup(self) -> None: # Segmentation Speed self.ui.cmbSpeed.addItems(['Normal Speed - Highest Quality', 'Faster Segmentation - High Quality', 'Fastest Segmentation - Moderate Quality']) - + self.ui.cmbSpeed.currentTextChanged.connect(self.setSpeed) + self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI()) + self.ui.pbPlaceROI.connect('clicked(bool)', lambda: self._placeROI()) self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D()) self.ui.pbLowerSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=True)) self.ui.pbUpperSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=False)) - self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True)) - # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() self.newEngineSelected('Classic MedSAM') - + + self.updateGUIFromParameterNode() + + def upgrade(self): + self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...') + + def applySegmentation(self): + if not self._parameterNode.segmentationNode: + self._parameterNode.segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") + self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) + + if self._parameterNode.embeddingState == 'SINGLE': + self.logic.sendImage(partial=False) + + self.logic.applySegmentation() + + def volumeChanged(self, volumeNode): + self._parameterNode.embeddingState = 'SINGLE' + if self._parameterNode.segmentationNode and volumeNode: + self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(volumeNode) + + def segmentationNodeChanged(self, segmentationNode): + if segmentationNode and self._parameterNode.volumeNode: + segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode) + + def setPrepMethod(self, method): + self._parameterNode.prepMethod = method + + def setSpeed(self, speed): + self._parameterNode.speed = speed + + def preprocess(self): + self.logic.applyPreprocess(self._parameterNode.prepMethod, self._parameterNode.prepWinLevel, self._parameterNode.prepWinWidth) + def setManualPreprocessVis(self, visible): self.ui.lblLevel.setVisible(visible) self.ui.lblWidth.setVisible(visible) @@ -337,9 +369,8 @@ def setManualPreprocessVis(self, visible): self.ui.sldWinWidth.setVisible(visible) def newEngineSelected(self, new_engine): + self._parameterNode.engine = new_engine current_engine = list(filter(lambda x: x['name'] == new_engine, self.engine_list))[0] - # inform logic object - self.logic.new_model_loaded = True # load list of submodels self.dont_invoke_submodel_change = True # prevent onchange event to happen self.ui.cmbSubModel.clear() @@ -352,8 +383,7 @@ def newEngineSelected(self, new_engine): ctrl.setVisible(False) # change engine-related paths - self.model_path_widget.currentPath = current_engine['default checkpoint'] - self.updateAllParameters() + self._parameterNode.modelPath = pathlib.Path(current_engine['default checkpoint']) # if there is a submodel, choose the first one if len(current_engine['submodels']) > 0: @@ -364,45 +394,35 @@ def newEngineSelected(self, new_engine): self.logic.download_if_necessary(current_engine['url'], current_engine['default checkpoint']) def newSubmodelSelected(self, new_submodel): + self._parameterNode.submodel = new_submodel if self.dont_invoke_submodel_change: return - current_submodel = list(filter(lambda x: x['name'] == self.ui.cmbEngine.currentText, self.engine_list))[0]['submodels'][new_submodel] - # inform logic object - self.logic.new_model_loaded = True + self._parameterNode.submodel = new_submodel + current_submodel = list(filter(lambda x: x['name'] == new_submodel, self.engine_list))[0]['submodels'][new_submodel] # change submodel-related paths - self.model_path_widget.currentPath = current_submodel['checkpoint'] - self.updateAllParameters() + self._parameterNode.modelPath = pathlib.Path(current_submodel['checkpoint']) # download checkpoints if necessary self.logic.download_if_necessary(current_submodel['url'], current_submodel['checkpoint']) - def selectParameterNode(self): - # Select parameter set node if one is found in the scene, and create one otherwise - segmentEditorSingletonTag = "SegmentEditor" - segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode") - if segmentEditorNode is None: - segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode") - segmentEditorNode.UnRegister(None) - segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag) - segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode) - if self.parameterSetNode == segmentEditorNode: - # nothing changed - return - self.parameterSetNode = segmentEditorNode - self.editor.setMRMLSegmentEditorNode(self.parameterSetNode) - + def _getOrCreateROI(self): + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLMarkupsROINode") + if not self._parameterNode.roiNode: + self._parameterNode.roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") + return self._parameterNode.roiNode def _createAndAttachROI(self): # Make sure there is only one 'R' - if self.logic.volume_node is None: - self.logic.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0] - volumeNode = self.logic.volume_node + if self._parameterNode.volumeNode is None: + slicer.util.errorDisplay("Select a source volume first") + return # Create a new ROI that will be fit to volumeNode - roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R") + roiNode = self._getOrCreateROI() cropVolumeParameters = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLCropVolumeParametersNode") - cropVolumeParameters.SetInputVolumeNodeID(volumeNode.GetID()) + cropVolumeParameters.SetInputVolumeNodeID(self._parameterNode.volumeNode.GetID()) cropVolumeParameters.SetROINodeID(roiNode.GetID()) slicer.modules.cropvolume.logic().SnapROIToVoxelGrid(cropVolumeParameters) # optional (rotates the ROI to match the volume axis directions) slicer.modules.cropvolume.logic().FitROIToInputVolume(cropVolumeParameters) @@ -410,27 +430,35 @@ def _createAndAttachROI(self): self.scaleROI(.85) - self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R")) - self.updateAllParameters() + def _placeROI(self): + # Make sure there is exactly one 'R' + roiNode = self._getOrCreateROI() + roiNode.RemoveAllControlPoints() + selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton") + selectionNode.SetReferenceActivePlaceNodeClassName("vtkMRMLMarkupsROINode") + selectionNode.SetReferenceActivePlaceNodeID(roiNode.GetID()) + interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton") + interactionNode.SetPlaceModePersistence(False) # stop placement mode after placing one ROI + interactionNode.SetCurrentInteractionMode(slicer.vtkMRMLInteractionNode().Place) def scaleROI(self, ratio): # Make sure there is exactly one 'R' - roiNode = slicer.util.getNode('R') + roiNode = self._parameterNode.roiNode roi_size = roiNode.GetSize() roiNode.SetSize(int(roi_size[0] * ratio), int(roi_size[1] * ratio), int(roi_size[2] * ratio)) def makeROI2D(self): # Make sure there is exactly one 'R' - roiNode = slicer.util.getNode('R') + roiNode = self._getOrCreateROI() roi_size = roiNode.GetSize() roiNode.SetSize(roi_size[0], roi_size[1], 1) roi_center = np.array(roiNode.GetCenter()) roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()]) def setROIboundary(self, lower): - roiNode = slicer.util.getNode('R') + roiNode = roiNode = self._getOrCreateROI() bounds = np.zeros(6) roiNode.GetBounds(bounds) @@ -505,6 +533,12 @@ def initializeParameterNode(self) -> None: self.setParameterNode(self.logic.getParameterNode()) + # Select default input nodes if nothing is selected yet to save a few clicks for the user + if not self._parameterNode.volumeNode: + firstVolumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode") + if firstVolumeNode: + self._parameterNode.volumeNode = firstVolumeNode + def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]) -> None: """ Set and observe parameter node. @@ -517,42 +551,39 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode] if self._parameterNode: # Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each # ui element that needs connection. - try: - self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui) - except: - pass #this part might be invoked before UI is loaded. does not cause any issues but might be confusing for users - self.renderAllParameters() + self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui) + self.addObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode) + self.updateGUIFromParameterNode() - def renderAllParameters(self): - if self._parameterNode.modelPath: - self.model_path_widget.currentPath = self._parameterNode.modelPath - if self._parameterNode.roiNode: - self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R")) + def updateGUIFromParameterNode(self, unused1=None, unused2=None): + if self._parameterNode is None: + self.ui.editor.setSegmentationNode(None) + self.ui.widgetROI.setMRMLMarkupsNode(None) + self.ui.pbSegment.setEnabled(False) + return + + self.ui.editor.setSegmentationNode(self._parameterNode.segmentationNode) if self._parameterNode.segmentationNode: - self.logic.segment_res_group = self._parameterNode.segmentationNode - + self.ui.editor.setSourceVolumeNode(self._parameterNode.volumeNode) + self.ui.widgetROI.setMRMLMarkupsNode(self._parameterNode.roiNode) - if self._parameterNode.prepMethod: - self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod - if self._parameterNode.prepWinLevel: - self.ui.sldWinLevel.value = self._parameterNode.prepWinLevel - if self._parameterNode.prepWinWidth: - self.ui.sldWinWidth.value = self._parameterNode.prepWinWidth - + if self._parameterNode.embeddingState == 'IN_PROGRESS': + self.ui.pbSegment.setEnabled(False) + self.ui.pbSegment.setText('Processing, please wait...') + else: + self.ui.pbSegment.setEnabled((self._parameterNode.volumeNode is not None) and (self._parameterNode.roiNode is not None)) + if self._parameterNode.embeddingState == 'SINGLE': + self.ui.pbSegment.setText('Preprocess and Segment') + elif self._parameterNode.embeddingState == 'FULL': + self.ui.pbSegment.setText('Segment') - def updateAllParameters(self): - self._parameterNode.modelPath = self.model_path_widget.currentPath - try: - self._parameterNode.roiNode = slicer.util.getNode('R') - except: - pass - self._parameterNode.segmentationNode = self.logic.segment_res_group - self._parameterNode.prepMethod = self.ui.cmbPrepOptions.currentText - self._parameterNode.prepWinLevel = self.ui.sldWinLevel.value - self._parameterNode.prepWinWidth = self.ui.sldWinWidth.value - + self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod + self.ui.cmbEngine.currentText = self._parameterNode.engine + self.ui.cmbSubModel.currentText = self._parameterNode.submodel + self.ui.cmbSpeed.currentText = self._parameterNode.speed + self.setManualPreprocessVis(self._parameterNode.prepMethod == 'Manual') # # MedSAMLiteLogic @@ -568,15 +599,11 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic): https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py """ image_data = None - segment_res_group = None - server_ready = False server_process = None - volume_node = None timer = None progressbar = None server_dir = None - widget = None - new_model_loaded = True + loaded_model_path = None backend = None test_mode = False @@ -586,6 +613,7 @@ def __init__(self) -> None: """ ScriptedLoadableModuleLogic.__init__(self) try: # In case the dependencies are not installed, an error will raise + from medsam_interface import MedSAM_Interface # FIXME self.backend = MedSAM_Interface() except: pass @@ -720,6 +748,7 @@ def run_on_background(self, target, args, title, progress_check=None): self.progressbar.close() def download_model(self, url, model_path, event): + import gdown gdown.download_folder(url=url, output=model_path) event.set() @@ -733,12 +762,11 @@ def download_if_necessary(self, model_url, model_path): def run_server(self): #FIXME show that 'Backend is loading...' - self.widget.updateAllParameters() - self.backend.set_engine(self.widget.ui.cmbEngine.currentText) - self.widget.renderAllParameters() - self.backend.MedSAM_CKPT_PATH = self.widget.model_path_widget.currentPath + self.backend.set_engine(self.getParameterNode().engine) + modelPath = self.getParameterNode().modelPath + self.backend.MedSAM_CKPT_PATH = modelPath self.backend.load_model() - self.server_ready = True + self.loaded_model_path = modelPath def progressCheck(self, partial=False): slicer.app.processEvents() @@ -748,32 +776,28 @@ def progressCheck(self, partial=False): if progress_data['layers'] <= progress_data['generated_embeds']: self.progressbar.close() self.timer.stop() - self.widget.ui.pbSegment.setEnabled(True) if partial: segmentation_mask = self.inferSegmentation() self.showSegmentation(segmentation_mask) - self.widget.ui.pbSegment.setText('Single Segmentation') + self.getParameterNode().embeddingState = 'SINGLE' else: - self.widget.ui.pbSegment.setText('Segmentation') - - def volumeChanged(self, node=None): - self.widget.ui.pbSegment.setText('Single Segmentation') + self.getParameterNode().embeddingState = 'FULL' + def _getVolumeNode(self): + return self.getParameterNode().volumeNode def captureImage(self): ######## Set your image path here - self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0] - self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node? + self.image_data = slicer.util.arrayFromVolume(self._getVolumeNode()).copy() if len(self.image_data.shape) == 4 and self.image_data.shape[-1] == 4: # colored image, it can have 4 channels (r,g,b,a) so we remove the last one self.image_data = self.image_data[:,:,:,:3] def sendImage(self, partial=False): - self.widget.ui.pbSegment.setEnabled(False) - self.widget.ui.pbSegment.setText('Sending image, please wait...') + self.getParameterNode().embeddingState = 'IN_PROGRESS' - if self.new_model_loaded or not self.server_ready: + if self.loaded_model_path != self.getParameterNode().modelPath: self.run_server() - self.new_model_loaded = False + ############ Partial segmentation if partial: @@ -793,10 +817,10 @@ def sendImage(self, partial=False): self.timer.timeout.connect(lambda: self.progressCheck(partial)) self.timer.start(1000) - self.backend.speed_level = 1 if 'Normal' in self.widget.ui.cmbSpeed.currentText else 2 if 'Faster' in self.widget.ui.cmbSpeed.currentText else 3 + speed = self.getParameterNode().speed + self.backend.speed_level = 1 if 'Normal' in speed else 2 if 'Faster' in speed else 3 self.backend.set_image(self.image_data, -160, 240, zmin, zmax, recurrent_func=slicer.app.processEvents) - self.widget.updateAllParameters() # self.run_on_background(self.embedding_prep_wrapper, (self.image_data, -160, 240, zmin, zmax), "Preparing image embeddings...", lambda: self.progressCheck(partial)) def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event): @@ -808,7 +832,7 @@ def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event): def inferSegmentation(self): print('sending infer request...') ################ DEBUG MODE ################ - if self.volume_node is None: + if self._getVolumeNode() is None: self.captureImage() ################ DEBUG MODE ################ @@ -823,42 +847,12 @@ def inferSegmentation(self): def showSegmentation(self, segmentation_mask): segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time()))) - slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask) - - current_seg_group = self.widget.editor.segmentationNode() - if current_seg_group is None: - if self.segment_res_group is None: - self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") - self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - current_seg_group = self.segment_res_group - - try: - check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error - except: - self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") - self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - current_seg_group = self.segment_res_group - - - current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group) - slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node) - + segmentation_mask_int = segmentation_mask.astype(np.int16) + slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask_int) + slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask_int, self.getParameterNode().segmentationNode, segment_volume.GetName(), self._getVolumeNode()) slicer.mrmlScene.RemoveNode(segment_volume) - - self.widget.updateAllParameters() - - def singleSegmentation(self): - self.sendImage(partial=True) - def applySegmentation(self): - if self.widget.ui.pbSegment.text == 'Single Segmentation': - continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Send Image' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No) - if continueSingle == QMessageBox.No: return - self.singleSegmentation() - - return segmentation_mask = self.inferSegmentation() self.showSegmentation(segmentation_mask) @@ -868,7 +862,7 @@ def get_bounding_box(self): # If volume node is transformed, apply that transform to get volume's RAS coordinates transformRasToVolumeRas = vtk.vtkGeneralTransform() - slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self.volume_node.GetParentTransformNode(), transformRasToVolumeRas) + slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self._getVolumeNode().GetParentTransformNode(), transformRasToVolumeRas) bounds = np.zeros(6) roiNode.GetBounds(bounds) @@ -883,7 +877,7 @@ def get_bounding_box(self): # Get voxel coordinates from physical coordinates volumeRasToIjk = vtk.vtkMatrix4x4() - self.volume_node.GetRASToIJKMatrix(volumeRasToIjk) + self._getVolumeNode().GetRASToIJKMatrix(volumeRasToIjk) point_Ijk = [0, 0, 0, 1] volumeRasToIjk.MultiplyPoint(np.append(point_VolumeRas,1.0), point_Ijk) point_Ijk = [ int(round(c)) for c in point_Ijk[0:3] ] @@ -906,8 +900,8 @@ def preprocess_CT(self, win_level=40.0, win_width=400.0): image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 image_data_pre = np.uint8(image_data_pre) - self.volume_node.GetDisplayNode().SetAutoWindowLevel(False) - self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255) + self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False) + self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255) return image_data_pre @@ -919,14 +913,14 @@ def preprocess_MR(self, lower_percent=0.5, upper_percent=99.5): image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 image_data_pre = np.uint8(image_data_pre) - self.volume_node.GetDisplayNode().SetAutoWindowLevel(False) - self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255) + self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False) + self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255) return image_data_pre def updateImage(self, new_image): self.image_data[:,:,:] = new_image - slicer.util.arrayFromVolumeModified(self.volume_node) + slicer.util.arrayFromVolumeModified(self._getVolumeNode()) def applyPreprocess(self, method, win_level, win_width): if method == 'MR': @@ -945,7 +939,6 @@ def applyPreprocess(self, method, win_level, win_width): self.updateImage(prep_img) - self.widget.updateAllParameters() # diff --git a/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAMLite/Resources/UI/MedSAMLite.ui index 9237051..44861b3 100644 --- a/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -6,25 +6,63 @@ 0 0 - 739 - 1206 + 474 + 1012 - - + + Upgrade Module - + + + + Source Volume: + + + + + + + + vtkMRMLScalarVolumeNode + + + + + + + false + + + false + + + + + + volumeNode + + + + Prepare Data - + + + + Apply + + + + 1 @@ -38,9 +76,26 @@ 40.000000000000000 + + prepWinLevel + - + + + + Preprocessing Options: + + + + + + + Window Level: + + + + 1 @@ -51,150 +106,247 @@ 400.000000000000000 - - - - - - - - - Window Level: + + prepWinWidth - - - Window Width: - - - - - - - Preprocessing Options: + + + - - + + - Apply + Window Width: - - + + - Select the Region of Interest + Segmentation - - - - - Attach ROI + + + + + false - - - - - - 2D Selection + + false - - - - - - Set Current Frame As Selection's Start + + - - - - Set Current Frame As Selection's End + + + + Qt::Vertical - - - - - - - - - - Start Segmentation - - - - - - MedSAM Model: + + + + + 1 + 0 + - - - - - Submodel: + Place ROI - - - - ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable - - - - *.pth - *.xml - *.onnx - *.ckpt - + + + + QFrame::StyledPanel + + QFrame::Raised + + + + + + Fit to volume + + + + + + + End at current slice + + + + + + + Start at current slice + + + + + + + Fit to current slice + + + + + + + - - - - - - - - - - Send Image + + + + Qt::Vertical - - - - Select Engine: + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + Output: + + + + + + + + vtkMRMLScalarVolumeNode + + + + + + + true + + + false + + + true + + + true + + + + + + (Create new) + + + segmentationNode + + + + + + + Engine: + + + + + + + Submodel: + + + + + + + MedSAM Model: + + + + + + + Quality: + + + + + + + + + + + + + ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable + + + + *.pth + *.xml + *.onnx + *.ckpt + + + + modelPath + + + + + + + - + - Segmentation + Segment - - - + + + + Qt::Vertical + + + + 20 + 40 + + + + @@ -204,6 +356,11 @@
ctkCollapsibleButton.h
1 + + ctkExpandButton + QToolButton +
ctkExpandButton.h
+
ctkPathLineEdit QWidget @@ -214,6 +371,11 @@ QWidget
ctkSliderWidget.h
+ + qMRMLNodeComboBox + QWidget +
qMRMLNodeComboBox.h
+
qMRMLWidget QWidget @@ -225,7 +387,109 @@ QWidget
qMRMLMarkupsROIWidget.h
+ + qMRMLSegmentEditorWidget + qMRMLWidget +
qMRMLSegmentEditorWidget.h
+
- + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + qNodeSelect + setMRMLScene(vtkMRMLScene*) + + + 423 + 859 + + + 432 + 58 + + + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + editor + setMRMLScene(vtkMRMLScene*) + + + 312 + 859 + + + 87 + 729 + + + + + t_SegmentData + mrmlSceneChanged(vtkMRMLScene*) + segmentationNodeSelector + setMRMLScene(vtkMRMLScene*) + + + 368 + 857 + + + 252 + 561 + + + + + segmentationNodeSelector + currentNodeChanged(vtkMRMLNode*) + editor + setSegmentationNode(vtkMRMLNode*) + + + 216 + 557 + + + 236 + 726 + + + + + Button + toggled(bool) + roiOptionsFrame + setVisible(bool) + + + 657 + 261 + + + 656 + 296 + + + + + Button_2 + toggled(bool) + segmentationOptionsFrame + setVisible(bool) + + + 660 + 529 + + + 657 + 574 + + + +