diff --git a/ASO/ASO.py b/ASO/ASO.py index 8385e40..9d31990 100644 --- a/ASO/ASO.py +++ b/ASO/ASO.py @@ -15,15 +15,48 @@ QGridLayout, QMediaPlayer, ) +import pkg_resources from slicer.ScriptedLoadableModule import * -from slicer.util import VTKObservationMixin +from slicer.util import VTKObservationMixin,pip_install from functools import partial +import platform from ASO_Method.IOS import Auto_IOS, Semi_IOS from ASO_Method.CBCT import Semi_CBCT, Auto_CBCT from ASO_Method.Method import Method from ASO_Method.Progress import Display +def check_lib_installed(lib_name, required_version=None): + try: + installed_version = pkg_resources.get_distribution(lib_name).version + if required_version and installed_version != required_version: + return False + return True + except pkg_resources.DistributionNotFound: + return False + +# import csv + +def install_function(): + libs = [('vtk', None), ('torch', None), ('monai', None),('pytorch_lightning',None),('dicom2nifti',None)] + libs_to_install = [] + for lib, version in libs: + if not check_lib_installed(lib, version): + libs_to_install.append((lib, version)) + + if libs_to_install: + message = "The following libraries are not installed or need updating:\n" + message += "\n".join([f"{lib}=={version}" if version else lib for lib, version in libs_to_install]) + message += "\n\nDo you want to install/update these libraries?\n Doing it could break other modules" + user_choice = slicer.util.confirmYesNoDisplay(message) + + if user_choice: + for lib, version in libs_to_install: + lib_version = f'{lib}=={version}' if version else lib + pip_install(lib_version) + else : + return False + return True class ASO(ScriptedLoadableModule): """Uses ScriptedLoadableModule base class, available at: @@ -817,6 +850,7 @@ def enableCheckbox(self): """ def onPredictButton(self): + install_function() """Function to launch the prediction""" error = self.ActualMeth.TestProcess( input_folder=self.ui.lineEditScanLmPath.text, diff --git a/ASO_CBCT/ASO_CBCT_utils/Net.py b/ASO_CBCT/ASO_CBCT_utils/Net.py index 474e706..56622b8 100644 --- a/ASO_CBCT/ASO_CBCT_utils/Net.py +++ b/ASO_CBCT/ASO_CBCT_utils/Net.py @@ -1,29 +1,21 @@ from slicer.util import pip_install, pip_uninstall #try to upgrade pip -try: - pip_install("pip -q --upgrade") -except: - pass -try: - import torch -except ImportError: - pip_install('torch') - import torch + + +import torch + import torch.nn as nn import torch.optim as optim -try: - import pytorch_lightning as pl -except ImportError: - pip_install("pytorch_lightning -q") - import pytorch_lightning as pl -pip_uninstall("monai -q") -pip_install("monai -q") +import pytorch_lightning as pl + + + from monai.networks.nets.densenet import DenseNet169 # Different Network diff --git a/ASO_CBCT/ASO_CBCT_utils/utils.py b/ASO_CBCT/ASO_CBCT_utils/utils.py index 313d6b8..e8b3708 100644 --- a/ASO_CBCT/ASO_CBCT_utils/utils.py +++ b/ASO_CBCT/ASO_CBCT_utils/utils.py @@ -26,11 +26,9 @@ ) from vtkmodules.vtkFiltersGeneral import vtkTransformPolyDataFilter -try: - import dicom2nifti -except ImportError: - pip_install("dicom2nifti -q") - import dicom2nifti + +import dicom2nifti + cross = lambda x, y: np.cross( x, y