diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py
index 99ce208..83d85cb 100755
--- a/MedSAMLite/MedSAMLite.py
+++ b/MedSAMLite/MedSAMLite.py
@@ -1,6 +1,7 @@
import logging
import os
from typing import Annotated, Optional
+import pathlib
import time
import json
@@ -16,7 +17,7 @@
import zipfile
import slicer
-from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode
+from slicer import vtkMRMLMarkupsROINode, vtkMRMLSegmentationNode, vtkMRMLScalarVolumeNode
from slicer.ScriptedLoadableModule import *
from slicer.util import VTKObservationMixin
from slicer.parameterNodeWrapper import (
@@ -31,11 +32,6 @@
from PythonQt.QtCore import QTimer, QByteArray, Qt
from PythonQt.QtGui import QIcon, QPixmap, QMessageBox
-try:
- import gdown
- from medsam_interface import MedSAM_Interface # FIXME
-except:
- pass # no installation anymore, shorter plugin load
MEDSAMLITE_VERSION = 'v0.13'
@@ -124,12 +120,17 @@ def registerSampleData():
@parameterNodeWrapper
class MedSAMLiteParameterNode:
+ volumeNode: vtkMRMLScalarVolumeNode
+ prepMethod: str = 'Manual'
roiNode: vtkMRMLMarkupsROINode
segmentationNode: vtkMRMLSegmentationNode
- modelPath: str
- prepMethod: str
- prepWinLevel: float = 40.0
- prepWinWidth: float = 400.0
+ modelPath: pathlib.Path
+ prepWinLevel: Annotated[float, WithinRange(-2000, 2000)] = 40.0
+ prepWinWidth: Annotated[float, WithinRange(0, 2000)] = 400.0
+ engine: str
+ submodel: str
+ speed: str
+ embeddingState: str
#
@@ -166,7 +167,6 @@ def setup(self) -> None:
# Create logic class. Logic implements all computations that should be possible to run
# in batch mode, without a graphical user interface.
self.logic = MedSAMLiteLogic()
- self.logic.widget = self
self.logic.server_dir = os.path.join(os.path.dirname(__file__), 'Resources/server_essentials')
@@ -200,41 +200,41 @@ def setup(self) -> None:
self.layout.addWidget(uiWidget)
self.ui = slicer.util.childWidgetVariables(uiWidget)
+ # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's
+ # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's.
+ # "setMRMLScene(vtkMRMLScene*)" slot.
+ uiWidget.setMRMLScene(slicer.mrmlScene)
+
############################################################################
# Model Selection
- self.model_path_widget = self.ui.ctkPathModel
- self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth')
- self.logic.new_model_loaded = True
+ self.ui.ctkPathModel.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth')
############################################################################
-
############################################################################
- # Segmentation Module
- import qSlicerSegmentationsModuleWidgetsPythonQt
- self.editor = qSlicerSegmentationsModuleWidgetsPythonQt.qMRMLSegmentEditorWidget()
- self.editor.setMaximumNumberOfUndoStates(10)
- self.selectParameterNode()
- self.editor.setMRMLScene(slicer.mrmlScene)
- self.ui.clbtnOperation.layout().addWidget(self.editor, 5, 0, 1, 2)
+ # Segmentation
+ self.ui.roiOptionsFrame.setVisible(False)
+ self.ui.segmentationOptionsFrame.setVisible(False)
+
+ self.ui.segmentationNodeSelector.currentNodeChanged.connect(self.segmentationNodeChanged)
+ self.ui.editor.setMaximumNumberOfUndoStates(10)
+
+ # Select parameter set node if one is found in the scene, and create one otherwise
+ segmentEditorSingletonTag = "SegmentEditor"
+ segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode")
+ if segmentEditorNode is None:
+ segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode")
+ segmentEditorNode.UnRegister(None)
+ segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag)
+ segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode)
+ self.ui.editor.setMRMLSegmentEditorNode(segmentEditorNode)
+
############################################################################
###########################################################################
# Volume load/close tracker
- from PythonQt.qMRMLWidgets import qMRMLNodeComboBox
-
- self.qNodeSelect = qMRMLNodeComboBox()
- self.qNodeSelect.addEnabled = False
- self.qNodeSelect.removeEnabled = False
- self.qNodeSelect.nodeTypes = ['vtkMRMLScalarVolumeNode']
- self.qNodeSelect.setMRMLScene(slicer.mrmlScene)
- self.qNodeSelect.currentNodeChanged.connect(self.logic.volumeChanged)
- self.logic.volumeChanged()
- ###########################################################################
+ self.ui.qNodeSelect.currentNodeChanged.connect(self.volumeChanged)
- # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's
- # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's.
- # "setMRMLScene(vtkMRMLScene*)" slot.
- uiWidget.setMRMLScene(slicer.mrmlScene)
+ ###########################################################################
# Connections
@@ -244,14 +244,13 @@ def setup(self) -> None:
# Buttons
self.ui.pbUpgrade.setVisible(False) # it's gliching so let's hide it :D
- self.ui.pbUpgrade.connect('clicked(bool)', lambda: self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...'))
- self.ui.pbSendImage.connect('clicked(bool)', lambda: self.logic.sendImage(partial=False))
- self.ui.pbSegment.connect('clicked(bool)', lambda: self.logic.applySegmentation())
+ self.ui.pbUpgrade.connect('clicked(bool)', self.upgrade)
+ self.ui.pbSegment.connect('clicked(bool)', self.applySegmentation)
# Preprocessing
self.ui.cmbPrepOptions.addItems(['Manual', 'Abdominal CT', 'Lung CT', 'Brain CT', 'Mediastinum CT', 'MR'])
- self.ui.cmbPrepOptions.currentTextChanged.connect(lambda new_text: self.setManualPreprocessVis(new_text == 'Manual'))
- self.ui.pbApplyPrep.connect('clicked(bool)', lambda: self.logic.applyPreprocess(self.ui.cmbPrepOptions.currentText, self.ui.sldWinLevel.value, self.ui.sldWinWidth.value))
+ self.ui.cmbPrepOptions.currentTextChanged.connect(self.setPrepMethod)
+ self.ui.pbApplyPrep.connect('clicked(bool)', self.preprocess)
# Hide unnecessary ROI controls
self.ui.widgetROI.findChild("QLabel", "label").hide()
@@ -259,7 +258,7 @@ def setup(self) -> None:
self.ui.widgetROI.findChild("QLabel", "label_10").hide()
self.ui.widgetROI.findChild("QComboBox", "roiTypeComboBox").hide()
- # Segmentation Engine
+ # Segmentation Engine
self.engine_list = [
{
'name': 'Classic MedSAM',
@@ -318,18 +317,51 @@ def setup(self) -> None:
# Segmentation Speed
self.ui.cmbSpeed.addItems(['Normal Speed - Highest Quality', 'Faster Segmentation - High Quality', 'Fastest Segmentation - Moderate Quality'])
-
+ self.ui.cmbSpeed.currentTextChanged.connect(self.setSpeed)
+
self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI())
+ self.ui.pbPlaceROI.connect('clicked(bool)', lambda: self._placeROI())
self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D())
self.ui.pbLowerSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=True))
self.ui.pbUpperSelection.connect('clicked(bool)', lambda: self.setROIboundary(lower=False))
- self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True))
-
# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()
self.newEngineSelected('Classic MedSAM')
-
+
+ self.updateGUIFromParameterNode()
+
+ def upgrade(self):
+ self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...')
+
+ def applySegmentation(self):
+ if not self._parameterNode.segmentationNode:
+ self._parameterNode.segmentationNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
+ self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode)
+
+ if self._parameterNode.embeddingState == 'SINGLE':
+ self.logic.sendImage(partial=False)
+
+ self.logic.applySegmentation()
+
+ def volumeChanged(self, volumeNode):
+ self._parameterNode.embeddingState = 'SINGLE'
+ if self._parameterNode.segmentationNode and volumeNode:
+ self._parameterNode.segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(volumeNode)
+
+ def segmentationNodeChanged(self, segmentationNode):
+ if segmentationNode and self._parameterNode.volumeNode:
+ segmentationNode.SetReferenceImageGeometryParameterFromVolumeNode(self._parameterNode.volumeNode)
+
+ def setPrepMethod(self, method):
+ self._parameterNode.prepMethod = method
+
+ def setSpeed(self, speed):
+ self._parameterNode.speed = speed
+
+ def preprocess(self):
+ self.logic.applyPreprocess(self._parameterNode.prepMethod, self._parameterNode.prepWinLevel, self._parameterNode.prepWinWidth)
+
def setManualPreprocessVis(self, visible):
self.ui.lblLevel.setVisible(visible)
self.ui.lblWidth.setVisible(visible)
@@ -337,9 +369,8 @@ def setManualPreprocessVis(self, visible):
self.ui.sldWinWidth.setVisible(visible)
def newEngineSelected(self, new_engine):
+ self._parameterNode.engine = new_engine
current_engine = list(filter(lambda x: x['name'] == new_engine, self.engine_list))[0]
- # inform logic object
- self.logic.new_model_loaded = True
# load list of submodels
self.dont_invoke_submodel_change = True # prevent onchange event to happen
self.ui.cmbSubModel.clear()
@@ -352,8 +383,7 @@ def newEngineSelected(self, new_engine):
ctrl.setVisible(False)
# change engine-related paths
- self.model_path_widget.currentPath = current_engine['default checkpoint']
- self.updateAllParameters()
+ self._parameterNode.modelPath = pathlib.Path(current_engine['default checkpoint'])
# if there is a submodel, choose the first one
if len(current_engine['submodels']) > 0:
@@ -364,45 +394,35 @@ def newEngineSelected(self, new_engine):
self.logic.download_if_necessary(current_engine['url'], current_engine['default checkpoint'])
def newSubmodelSelected(self, new_submodel):
+ self._parameterNode.submodel = new_submodel
if self.dont_invoke_submodel_change: return
- current_submodel = list(filter(lambda x: x['name'] == self.ui.cmbEngine.currentText, self.engine_list))[0]['submodels'][new_submodel]
- # inform logic object
- self.logic.new_model_loaded = True
+ self._parameterNode.submodel = new_submodel
+ current_submodel = list(filter(lambda x: x['name'] == new_submodel, self.engine_list))[0]['submodels'][new_submodel]
# change submodel-related paths
- self.model_path_widget.currentPath = current_submodel['checkpoint']
- self.updateAllParameters()
+ self._parameterNode.modelPath = pathlib.Path(current_submodel['checkpoint'])
# download checkpoints if necessary
self.logic.download_if_necessary(current_submodel['url'], current_submodel['checkpoint'])
- def selectParameterNode(self):
- # Select parameter set node if one is found in the scene, and create one otherwise
- segmentEditorSingletonTag = "SegmentEditor"
- segmentEditorNode = slicer.mrmlScene.GetSingletonNode(segmentEditorSingletonTag, "vtkMRMLSegmentEditorNode")
- if segmentEditorNode is None:
- segmentEditorNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLSegmentEditorNode")
- segmentEditorNode.UnRegister(None)
- segmentEditorNode.SetSingletonTag(segmentEditorSingletonTag)
- segmentEditorNode = slicer.mrmlScene.AddNode(segmentEditorNode)
- if self.parameterSetNode == segmentEditorNode:
- # nothing changed
- return
- self.parameterSetNode = segmentEditorNode
- self.editor.setMRMLSegmentEditorNode(self.parameterSetNode)
-
+ def _getOrCreateROI(self):
+ if not self._parameterNode.roiNode:
+ self._parameterNode.roiNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLMarkupsROINode")
+ if not self._parameterNode.roiNode:
+ self._parameterNode.roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R")
+ return self._parameterNode.roiNode
def _createAndAttachROI(self):
# Make sure there is only one 'R'
- if self.logic.volume_node is None:
- self.logic.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0]
- volumeNode = self.logic.volume_node
+ if self._parameterNode.volumeNode is None:
+ slicer.util.errorDisplay("Select a source volume first")
+ return
# Create a new ROI that will be fit to volumeNode
- roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode", "R")
+ roiNode = self._getOrCreateROI()
cropVolumeParameters = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLCropVolumeParametersNode")
- cropVolumeParameters.SetInputVolumeNodeID(volumeNode.GetID())
+ cropVolumeParameters.SetInputVolumeNodeID(self._parameterNode.volumeNode.GetID())
cropVolumeParameters.SetROINodeID(roiNode.GetID())
slicer.modules.cropvolume.logic().SnapROIToVoxelGrid(cropVolumeParameters) # optional (rotates the ROI to match the volume axis directions)
slicer.modules.cropvolume.logic().FitROIToInputVolume(cropVolumeParameters)
@@ -410,27 +430,35 @@ def _createAndAttachROI(self):
self.scaleROI(.85)
- self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R"))
- self.updateAllParameters()
+ def _placeROI(self):
+ # Make sure there is exactly one 'R'
+ roiNode = self._getOrCreateROI()
+ roiNode.RemoveAllControlPoints()
+ selectionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLSelectionNodeSingleton")
+ selectionNode.SetReferenceActivePlaceNodeClassName("vtkMRMLMarkupsROINode")
+ selectionNode.SetReferenceActivePlaceNodeID(roiNode.GetID())
+ interactionNode = slicer.mrmlScene.GetNodeByID("vtkMRMLInteractionNodeSingleton")
+ interactionNode.SetPlaceModePersistence(False) # stop placement mode after placing one ROI
+ interactionNode.SetCurrentInteractionMode(slicer.vtkMRMLInteractionNode().Place)
def scaleROI(self, ratio):
# Make sure there is exactly one 'R'
- roiNode = slicer.util.getNode('R')
+ roiNode = self._parameterNode.roiNode
roi_size = roiNode.GetSize()
roiNode.SetSize(int(roi_size[0] * ratio), int(roi_size[1] * ratio), int(roi_size[2] * ratio))
def makeROI2D(self):
# Make sure there is exactly one 'R'
- roiNode = slicer.util.getNode('R')
+ roiNode = self._getOrCreateROI()
roi_size = roiNode.GetSize()
roiNode.SetSize(roi_size[0], roi_size[1], 1)
roi_center = np.array(roiNode.GetCenter())
roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()])
def setROIboundary(self, lower):
- roiNode = slicer.util.getNode('R')
+ roiNode = roiNode = self._getOrCreateROI()
bounds = np.zeros(6)
roiNode.GetBounds(bounds)
@@ -505,6 +533,12 @@ def initializeParameterNode(self) -> None:
self.setParameterNode(self.logic.getParameterNode())
+ # Select default input nodes if nothing is selected yet to save a few clicks for the user
+ if not self._parameterNode.volumeNode:
+ firstVolumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode")
+ if firstVolumeNode:
+ self._parameterNode.volumeNode = firstVolumeNode
+
def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]) -> None:
"""
Set and observe parameter node.
@@ -517,42 +551,39 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]
if self._parameterNode:
# Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each
# ui element that needs connection.
- try:
- self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui)
- except:
- pass #this part might be invoked before UI is loaded. does not cause any issues but might be confusing for users
- self.renderAllParameters()
+ self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui)
+ self.addObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self.updateGUIFromParameterNode)
+ self.updateGUIFromParameterNode()
- def renderAllParameters(self):
- if self._parameterNode.modelPath:
- self.model_path_widget.currentPath = self._parameterNode.modelPath
- if self._parameterNode.roiNode:
- self.ui.widgetROI.setMRMLMarkupsNode(slicer.util.getNode("R"))
+ def updateGUIFromParameterNode(self, unused1=None, unused2=None):
+ if self._parameterNode is None:
+ self.ui.editor.setSegmentationNode(None)
+ self.ui.widgetROI.setMRMLMarkupsNode(None)
+ self.ui.pbSegment.setEnabled(False)
+ return
+
+ self.ui.editor.setSegmentationNode(self._parameterNode.segmentationNode)
if self._parameterNode.segmentationNode:
- self.logic.segment_res_group = self._parameterNode.segmentationNode
-
+ self.ui.editor.setSourceVolumeNode(self._parameterNode.volumeNode)
+ self.ui.widgetROI.setMRMLMarkupsNode(self._parameterNode.roiNode)
- if self._parameterNode.prepMethod:
- self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod
- if self._parameterNode.prepWinLevel:
- self.ui.sldWinLevel.value = self._parameterNode.prepWinLevel
- if self._parameterNode.prepWinWidth:
- self.ui.sldWinWidth.value = self._parameterNode.prepWinWidth
-
+ if self._parameterNode.embeddingState == 'IN_PROGRESS':
+ self.ui.pbSegment.setEnabled(False)
+ self.ui.pbSegment.setText('Processing, please wait...')
+ else:
+ self.ui.pbSegment.setEnabled((self._parameterNode.volumeNode is not None) and (self._parameterNode.roiNode is not None))
+ if self._parameterNode.embeddingState == 'SINGLE':
+ self.ui.pbSegment.setText('Preprocess and Segment')
+ elif self._parameterNode.embeddingState == 'FULL':
+ self.ui.pbSegment.setText('Segment')
- def updateAllParameters(self):
- self._parameterNode.modelPath = self.model_path_widget.currentPath
- try:
- self._parameterNode.roiNode = slicer.util.getNode('R')
- except:
- pass
- self._parameterNode.segmentationNode = self.logic.segment_res_group
- self._parameterNode.prepMethod = self.ui.cmbPrepOptions.currentText
- self._parameterNode.prepWinLevel = self.ui.sldWinLevel.value
- self._parameterNode.prepWinWidth = self.ui.sldWinWidth.value
-
+ self.ui.cmbPrepOptions.currentText = self._parameterNode.prepMethod
+ self.ui.cmbEngine.currentText = self._parameterNode.engine
+ self.ui.cmbSubModel.currentText = self._parameterNode.submodel
+ self.ui.cmbSpeed.currentText = self._parameterNode.speed
+ self.setManualPreprocessVis(self._parameterNode.prepMethod == 'Manual')
#
# MedSAMLiteLogic
@@ -568,15 +599,11 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic):
https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py
"""
image_data = None
- segment_res_group = None
- server_ready = False
server_process = None
- volume_node = None
timer = None
progressbar = None
server_dir = None
- widget = None
- new_model_loaded = True
+ loaded_model_path = None
backend = None
test_mode = False
@@ -586,6 +613,7 @@ def __init__(self) -> None:
"""
ScriptedLoadableModuleLogic.__init__(self)
try: # In case the dependencies are not installed, an error will raise
+ from medsam_interface import MedSAM_Interface # FIXME
self.backend = MedSAM_Interface()
except:
pass
@@ -720,6 +748,7 @@ def run_on_background(self, target, args, title, progress_check=None):
self.progressbar.close()
def download_model(self, url, model_path, event):
+ import gdown
gdown.download_folder(url=url, output=model_path)
event.set()
@@ -733,12 +762,11 @@ def download_if_necessary(self, model_url, model_path):
def run_server(self):
#FIXME show that 'Backend is loading...'
- self.widget.updateAllParameters()
- self.backend.set_engine(self.widget.ui.cmbEngine.currentText)
- self.widget.renderAllParameters()
- self.backend.MedSAM_CKPT_PATH = self.widget.model_path_widget.currentPath
+ self.backend.set_engine(self.getParameterNode().engine)
+ modelPath = self.getParameterNode().modelPath
+ self.backend.MedSAM_CKPT_PATH = modelPath
self.backend.load_model()
- self.server_ready = True
+ self.loaded_model_path = modelPath
def progressCheck(self, partial=False):
slicer.app.processEvents()
@@ -748,32 +776,28 @@ def progressCheck(self, partial=False):
if progress_data['layers'] <= progress_data['generated_embeds']:
self.progressbar.close()
self.timer.stop()
- self.widget.ui.pbSegment.setEnabled(True)
if partial:
segmentation_mask = self.inferSegmentation()
self.showSegmentation(segmentation_mask)
- self.widget.ui.pbSegment.setText('Single Segmentation')
+ self.getParameterNode().embeddingState = 'SINGLE'
else:
- self.widget.ui.pbSegment.setText('Segmentation')
-
- def volumeChanged(self, node=None):
- self.widget.ui.pbSegment.setText('Single Segmentation')
+ self.getParameterNode().embeddingState = 'FULL'
+ def _getVolumeNode(self):
+ return self.getParameterNode().volumeNode
def captureImage(self):
######## Set your image path here
- self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0]
- self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node?
+ self.image_data = slicer.util.arrayFromVolume(self._getVolumeNode()).copy()
if len(self.image_data.shape) == 4 and self.image_data.shape[-1] == 4: # colored image, it can have 4 channels (r,g,b,a) so we remove the last one
self.image_data = self.image_data[:,:,:,:3]
def sendImage(self, partial=False):
- self.widget.ui.pbSegment.setEnabled(False)
- self.widget.ui.pbSegment.setText('Sending image, please wait...')
+ self.getParameterNode().embeddingState = 'IN_PROGRESS'
- if self.new_model_loaded or not self.server_ready:
+ if self.loaded_model_path != self.getParameterNode().modelPath:
self.run_server()
- self.new_model_loaded = False
+
############ Partial segmentation
if partial:
@@ -793,10 +817,10 @@ def sendImage(self, partial=False):
self.timer.timeout.connect(lambda: self.progressCheck(partial))
self.timer.start(1000)
- self.backend.speed_level = 1 if 'Normal' in self.widget.ui.cmbSpeed.currentText else 2 if 'Faster' in self.widget.ui.cmbSpeed.currentText else 3
+ speed = self.getParameterNode().speed
+ self.backend.speed_level = 1 if 'Normal' in speed else 2 if 'Faster' in speed else 3
self.backend.set_image(self.image_data, -160, 240, zmin, zmax, recurrent_func=slicer.app.processEvents)
- self.widget.updateAllParameters()
# self.run_on_background(self.embedding_prep_wrapper, (self.image_data, -160, 240, zmin, zmax), "Preparing image embeddings...", lambda: self.progressCheck(partial))
def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event):
@@ -808,7 +832,7 @@ def embedding_prep_wrapper(self, arr, wmin, wmax, zmin, zmax, event):
def inferSegmentation(self):
print('sending infer request...')
################ DEBUG MODE ################
- if self.volume_node is None:
+ if self._getVolumeNode() is None:
self.captureImage()
################ DEBUG MODE ################
@@ -823,42 +847,12 @@ def inferSegmentation(self):
def showSegmentation(self, segmentation_mask):
segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time())))
- slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask)
-
- current_seg_group = self.widget.editor.segmentationNode()
- if current_seg_group is None:
- if self.segment_res_group is None:
- self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
- self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
- current_seg_group = self.segment_res_group
-
- try:
- check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error
- except:
- self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
- self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
- current_seg_group = self.segment_res_group
-
-
- current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
- slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
- slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node)
-
+ segmentation_mask_int = segmentation_mask.astype(np.int16)
+ slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask_int)
+ slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask_int, self.getParameterNode().segmentationNode, segment_volume.GetName(), self._getVolumeNode())
slicer.mrmlScene.RemoveNode(segment_volume)
-
- self.widget.updateAllParameters()
-
- def singleSegmentation(self):
- self.sendImage(partial=True)
-
def applySegmentation(self):
- if self.widget.ui.pbSegment.text == 'Single Segmentation':
- continueSingle = QMessageBox.question(None,'', "You are using single segmentation option which is faster but is not advised if you want large or multiple regions be segmented in one image. In that case click 'Send Image' button. Do you wish to continue with single segmentation?", QMessageBox.Yes | QMessageBox.No)
- if continueSingle == QMessageBox.No: return
- self.singleSegmentation()
-
- return
segmentation_mask = self.inferSegmentation()
self.showSegmentation(segmentation_mask)
@@ -868,7 +862,7 @@ def get_bounding_box(self):
# If volume node is transformed, apply that transform to get volume's RAS coordinates
transformRasToVolumeRas = vtk.vtkGeneralTransform()
- slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self.volume_node.GetParentTransformNode(), transformRasToVolumeRas)
+ slicer.vtkMRMLTransformNode.GetTransformBetweenNodes(None, self._getVolumeNode().GetParentTransformNode(), transformRasToVolumeRas)
bounds = np.zeros(6)
roiNode.GetBounds(bounds)
@@ -883,7 +877,7 @@ def get_bounding_box(self):
# Get voxel coordinates from physical coordinates
volumeRasToIjk = vtk.vtkMatrix4x4()
- self.volume_node.GetRASToIJKMatrix(volumeRasToIjk)
+ self._getVolumeNode().GetRASToIJKMatrix(volumeRasToIjk)
point_Ijk = [0, 0, 0, 1]
volumeRasToIjk.MultiplyPoint(np.append(point_VolumeRas,1.0), point_Ijk)
point_Ijk = [ int(round(c)) for c in point_Ijk[0:3] ]
@@ -906,8 +900,8 @@ def preprocess_CT(self, win_level=40.0, win_width=400.0):
image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0
image_data_pre = np.uint8(image_data_pre)
- self.volume_node.GetDisplayNode().SetAutoWindowLevel(False)
- self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255)
+ self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False)
+ self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255)
return image_data_pre
@@ -919,14 +913,14 @@ def preprocess_MR(self, lower_percent=0.5, upper_percent=99.5):
image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0
image_data_pre = np.uint8(image_data_pre)
- self.volume_node.GetDisplayNode().SetAutoWindowLevel(False)
- self.volume_node.GetDisplayNode().SetWindowLevelMinMax(0, 255)
+ self._getVolumeNode().GetDisplayNode().SetAutoWindowLevel(False)
+ self._getVolumeNode().GetDisplayNode().SetWindowLevelMinMax(0, 255)
return image_data_pre
def updateImage(self, new_image):
self.image_data[:,:,:] = new_image
- slicer.util.arrayFromVolumeModified(self.volume_node)
+ slicer.util.arrayFromVolumeModified(self._getVolumeNode())
def applyPreprocess(self, method, win_level, win_width):
if method == 'MR':
@@ -945,7 +939,6 @@ def applyPreprocess(self, method, win_level, win_width):
self.updateImage(prep_img)
- self.widget.updateAllParameters()
#
diff --git a/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAMLite/Resources/UI/MedSAMLite.ui
index 9237051..44861b3 100644
--- a/MedSAMLite/Resources/UI/MedSAMLite.ui
+++ b/MedSAMLite/Resources/UI/MedSAMLite.ui
@@ -6,25 +6,63 @@
0
0
- 739
- 1206
+ 474
+ 1012
-
- -
+
+
-
Upgrade Module
- -
+
-
+
+
+ Source Volume:
+
+
+
+ -
+
+
+
+ vtkMRMLScalarVolumeNode
+
+
+
+
+
+
+ false
+
+
+ false
+
+
+
+
+
+ volumeNode
+
+
+
+ -
Prepare Data
-
-
+
-
+
+
+ Apply
+
+
+
+ -
1
@@ -38,9 +76,26 @@
40.000000000000000
+
+ prepWinLevel
+
- -
+
-
+
+
+ Preprocessing Options:
+
+
+
+ -
+
+
+ Window Level:
+
+
+
+ -
1
@@ -51,150 +106,247 @@
400.000000000000000
-
-
- -
-
-
- -
-
-
- Window Level:
+
+ prepWinWidth
-
-
-
- Window Width:
-
-
-
- -
-
-
- Preprocessing Options:
+
+
+
- -
-
+
-
+
- Apply
+ Window Width:
- -
-
+
-
+
- Select the Region of Interest
+ Segmentation
-
-
-
-
-
- Attach ROI
+
+
-
+
+
+ false
-
-
- -
-
-
- 2D Selection
+
+ false
-
-
- -
-
-
- Set Current Frame As Selection's Start
+
+
- -
-
-
- Set Current Frame As Selection's End
+
-
+
+
+ Qt::Vertical
- -
-
-
-
-
-
- -
-
-
- Start Segmentation
-
-
-
-
-
-
- MedSAM Model:
+
-
+
+
+
+ 1
+ 0
+
-
-
- -
-
- Submodel:
+ Place ROI
- -
-
-
- ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable
-
-
-
- *.pth
- *.xml
- *.onnx
- *.ckpt
-
+
-
+
+
+ QFrame::StyledPanel
+
+ QFrame::Raised
+
+
+
-
+
+
+ Fit to volume
+
+
+
+ -
+
+
+ End at current slice
+
+
+
+ -
+
+
+ Start at current slice
+
+
+
+ -
+
+
+ Fit to current slice
+
+
+
+ -
+
+
+
- -
-
-
- -
-
-
- -
-
-
- Send Image
+
-
+
+
+ Qt::Vertical
- -
-
-
- Select Engine:
+
-
+
+
+ QFrame::StyledPanel
+
+
+ QFrame::Raised
+
+
-
+
+
+ Output:
+
+
+
+ -
+
+
+
+ vtkMRMLScalarVolumeNode
+
+
+
+
+
+
+ true
+
+
+ false
+
+
+ true
+
+
+ true
+
+
+
+
+
+ (Create new)
+
+
+ segmentationNode
+
+
+
+ -
+
+
+ Engine:
+
+
+
+ -
+
+
+ Submodel:
+
+
+
+ -
+
+
+ MedSAM Model:
+
+
+
+ -
+
+
+ Quality:
+
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable
+
+
+
+ *.pth
+ *.xml
+ *.onnx
+ *.ckpt
+
+
+
+ modelPath
+
+
+
+ -
+
+
+
- -
+
-
- Segmentation
+ Segment
- -
-
-
+ -
+
+
+ Qt::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
@@ -204,6 +356,11 @@
1
+
+ ctkExpandButton
+ QToolButton
+
+
ctkPathLineEdit
QWidget
@@ -214,6 +371,11 @@
QWidget
+
+ qMRMLNodeComboBox
+ QWidget
+
+
qMRMLWidget
QWidget
@@ -225,7 +387,109 @@
QWidget
+
+ qMRMLSegmentEditorWidget
+ qMRMLWidget
+ qMRMLSegmentEditorWidget.h
+
-
+
+
+ t_SegmentData
+ mrmlSceneChanged(vtkMRMLScene*)
+ qNodeSelect
+ setMRMLScene(vtkMRMLScene*)
+
+
+ 423
+ 859
+
+
+ 432
+ 58
+
+
+
+
+ t_SegmentData
+ mrmlSceneChanged(vtkMRMLScene*)
+ editor
+ setMRMLScene(vtkMRMLScene*)
+
+
+ 312
+ 859
+
+
+ 87
+ 729
+
+
+
+
+ t_SegmentData
+ mrmlSceneChanged(vtkMRMLScene*)
+ segmentationNodeSelector
+ setMRMLScene(vtkMRMLScene*)
+
+
+ 368
+ 857
+
+
+ 252
+ 561
+
+
+
+
+ segmentationNodeSelector
+ currentNodeChanged(vtkMRMLNode*)
+ editor
+ setSegmentationNode(vtkMRMLNode*)
+
+
+ 216
+ 557
+
+
+ 236
+ 726
+
+
+
+
+ Button
+ toggled(bool)
+ roiOptionsFrame
+ setVisible(bool)
+
+
+ 657
+ 261
+
+
+ 656
+ 296
+
+
+
+
+ Button_2
+ toggled(bool)
+ segmentationOptionsFrame
+ setVisible(bool)
+
+
+ 660
+ 529
+
+
+ 657
+ 574
+
+
+
+