From 14c756fdd8ec49af17b083f07deae482c1013f6b Mon Sep 17 00:00:00 2001 From: Reza Asakereh Date: Fri, 4 Oct 2024 10:03:17 -0400 Subject: [PATCH] Medficient Added --- MedSAMLite/MedSAMLite.py | 20 +- .../medsam_interface/engines/__init__.py | 2 +- .../medsam_interface/engines/medficientsam.py | 239 +++ .../medsam_interface/engines/src/__init__.py | 0 .../engines/src/data/__init__.py | 0 .../engines/src/data/components/__init__.py | 0 .../src/data/components/medsam_dataset.py | 370 ++++ .../engines/src/data/medsam_datamodule.py | 94 + .../engines/src/export_onnx.py | 189 ++ .../engines/src/export_torch.py | 50 + .../medsam_interface/engines/src/infer.py | 188 ++ .../engines/src/losses/SAMLoss.py | 39 + .../engines/src/losses/__init__.py | 1 + .../engines/src/losses/components/IoULoss.py | 28 + .../engines/src/metrics/__init__.py | 0 .../engines/src/metrics/generalized_dice.py | 18 + .../engines/src/models/__init__.py | 0 .../engines/src/models/base_sam/__init__.py | 1 + .../engines/src/models/base_sam/sam.py | 113 ++ .../engines/src/models/distill_module.py | 83 + .../engines/src/models/efficientvit/README.md | 1 + .../src/models/efficientvit/__init__.py | 0 .../src/models/efficientvit/apps/__init__.py | 0 .../apps/data_provider/__init__.py | 7 + .../apps/data_provider/augment/__init__.py | 6 + .../apps/data_provider/augment/bbox.py | 30 + .../apps/data_provider/augment/color_aug.py | 78 + .../efficientvit/apps/data_provider/base.py | 199 +++ .../random_resolution/__init__.py | 7 + .../random_resolution/_data_loader.py | 1538 +++++++++++++++++ .../random_resolution/_data_worker.py | 358 ++++ .../random_resolution/controller.py | 92 + .../src/models/efficientvit/apps/setup.py | 135 ++ .../efficientvit/apps/trainer/__init__.py | 6 + .../models/efficientvit/apps/trainer/base.py | 299 ++++ .../efficientvit/apps/trainer/run_config.py | 115 ++ .../efficientvit/apps/utils/__init__.py | 12 + .../models/efficientvit/apps/utils/dist.py | 71 + .../src/models/efficientvit/apps/utils/ema.py | 42 + .../models/efficientvit/apps/utils/export.py | 45 + .../models/efficientvit/apps/utils/init.py | 66 + .../src/models/efficientvit/apps/utils/lr.py | 44 + .../models/efficientvit/apps/utils/metric.py | 33 + .../models/efficientvit/apps/utils/misc.py | 101 ++ .../src/models/efficientvit/apps/utils/opt.py | 28 + .../src/models/efficientvit/cls_model_zoo.py | 79 + .../models/efficientvit/clscore/__init__.py | 0 .../clscore/data_provider/__init__.py | 5 + .../clscore/data_provider/imagenet.py | 123 ++ .../efficientvit/clscore/trainer/__init__.py | 6 + .../clscore/trainer/cls_run_config.py | 18 + .../clscore/trainer/cls_trainer.py | 233 +++ .../clscore/trainer/utils/__init__.py | 7 + .../clscore/trainer/utils/label_smooth.py | 18 + .../clscore/trainer/utils/metric.py | 23 + .../clscore/trainer/utils/mixup.py | 65 + .../models/efficientvit/models/__init__.py | 0 .../models/efficientvit/__init__.py | 8 + .../models/efficientvit/backbone.py | 376 ++++ .../efficientvit/models/efficientvit/cls.py | 162 ++ .../efficientvit/models/efficientvit/sam.py | 664 +++++++ .../efficientvit/models/efficientvit/seg.py | 343 ++++ .../models/efficientvit/models/nn/__init__.py | 8 + .../src/models/efficientvit/models/nn/act.py | 30 + .../src/models/efficientvit/models/nn/drop.py | 88 + .../src/models/efficientvit/models/nn/norm.py | 137 ++ .../src/models/efficientvit/models/nn/ops.py | 614 +++++++ .../efficientvit/models/utils/__init__.py | 7 + .../models/efficientvit/models/utils/list.py | 53 + .../efficientvit/models/utils/network.py | 
73 + .../efficientvit/models/utils/random.py | 65 + .../src/models/efficientvit/sam_model_zoo.py | 51 + .../samcore/data_provider/__init__.py | 1 + .../efficientvit/samcore/data_provider/sam.py | 169 ++ .../samcore/data_provider/utils.py | 194 +++ .../efficientvit/samcore/trainer/__init__.py | 2 + .../samcore/trainer/sam_run_config.py | 9 + .../samcore/trainer/sam_trainer.py | 302 ++++ .../efficientvit/samcore/trainer/utils.py | 318 ++++ .../src/models/efficientvit/seg_model_zoo.py | 70 + .../engines/src/models/finetune_module.py | 113 ++ .../src/models/lite_medsam/__init__.py | 1 + .../engines/src/models/lite_medsam/sam.py | 64 + .../src/models/lite_medsam/tiny_vit.py | 706 ++++++++ .../engines/src/models/onnx/__init__.py | 2 + .../engines/src/models/onnx/decoder.py | 55 + .../engines/src/models/onnx/encoder.py | 93 + .../src/models/segment_anything/README.md | 1 + .../src/models/segment_anything/__init__.py | 15 + .../automatic_mask_generator.py | 372 ++++ .../src/models/segment_anything/build_sam.py | 107 ++ .../segment_anything/modeling/__init__.py | 11 + .../segment_anything/modeling/common.py | 43 + .../modeling/image_encoder.py | 395 +++++ .../segment_anything/modeling/mask_decoder.py | 176 ++ .../modeling/prompt_encoder.py | 214 +++ .../models/segment_anything/modeling/sam.py | 174 ++ .../segment_anything/modeling/transformer.py | 240 +++ .../src/models/segment_anything/predictor.py | 269 +++ .../models/segment_anything/utils/__init__.py | 5 + .../src/models/segment_anything/utils/amg.py | 346 ++++ .../src/models/segment_anything/utils/onnx.py | 144 ++ .../segment_anything/utils/transforms.py | 102 ++ .../engines/src/schedulers/__init__.py | 122 ++ .../medsam_interface/engines/src/train.py | 126 ++ .../engines/src/utils/__init__.py | 6 + .../engines/src/utils/instantiators.py | 56 + .../engines/src/utils/logging_utils.py | 57 + .../engines/src/utils/multiprocessing.py | 36 + .../engines/src/utils/pylogger.py | 51 + .../engines/src/utils/rich_utils.py | 99 ++ .../engines/src/utils/transforms.py | 155 ++ .../engines/src/utils/utils.py | 92 + .../engines/src/utils/visualize.py | 73 + .../medsam_interface/interface_impl.py | 2 +- 115 files changed, 13580 insertions(+), 12 deletions(-) create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/medficientsam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/medsam_dataset.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/medsam_datamodule.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_onnx.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_torch.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/infer.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/SAMLoss.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/components/IoULoss.py create mode 100644 
MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/generalized_dice.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/distill_module.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/README.md create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/bbox.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/color_aug.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/base.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_loader.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_worker.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/controller.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/setup.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/base.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/run_config.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/dist.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/ema.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/export.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/init.py create mode 100644 
MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/lr.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/metric.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/misc.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/opt.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/cls_model_zoo.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/imagenet.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_run_config.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_trainer.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/label_smooth.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/metric.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/mixup.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/backbone.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/cls.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/seg.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/act.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/drop.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/norm.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/ops.py create mode 100644 
MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/list.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/network.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/random.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/sam_model_zoo.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/utils.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_run_config.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_trainer.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/utils.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/seg_model_zoo.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/finetune_module.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/tiny_vit.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/decoder.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/encoder.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/README.md create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/automatic_mask_generator.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/build_sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/common.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/image_encoder.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/mask_decoder.py create 
mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/prompt_encoder.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/sam.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/transformer.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/predictor.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/amg.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/onnx.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/transforms.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/schedulers/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/train.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/__init__.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/instantiators.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/logging_utils.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/multiprocessing.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/pylogger.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/rich_utils.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/transforms.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/utils.py create mode 100644 MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/visualize.py diff --git a/MedSAMLite/MedSAMLite.py b/MedSAMLite/MedSAMLite.py index 7e86529..99ce208 100755 --- a/MedSAMLite/MedSAMLite.py +++ b/MedSAMLite/MedSAMLite.py @@ -37,7 +37,7 @@ except: pass # no installation anymore, shorter plugin load -MEDSAMLITE_VERSION = 'v0.12' +MEDSAMLITE_VERSION = 'v0.13' # # MedSAMLite @@ -279,6 +279,15 @@ def setup(self) -> None: 'url': 'https://drive.google.com/drive/folders/1FTwy6uOUFIrWnrkBbTNufv8N9r34hmeG?usp=sharing', 'submodels': {} }, + { + 'name': 'Medficient SAM', + 'description': 'Medficient SAM is an efficient and high accuracy alternative to classic MedSAMLite that can benefit from an existing NVIDIA GPU for faster segmentations. No approximate segmentation support.', + 'default checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/medficient/model.pth'), + 'controls to hide': [self.ui.cmbSpeed, self.ui.lblSubModel, self.ui.cmbSubModel], + 'controls to show': [], + 'url': 'https://drive.google.com/drive/folders/1gzNPIEe9NX444EaFEHw58Wt23q_5OyNJ?usp=sharing', + 'submodels': {} + }, { 'name': 'DAFT MedSAM', 'description': 'DAFT MedSAM is one of the fastest engines as it uses a relatively smaller data-specific model and OpenVINO backend. 
No approximate segmentation nor GPU support and need for user\'s mindful model selection are the cons.', @@ -299,15 +308,6 @@ def setup(self) -> None: 'XRay': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/XRay/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/120gqhi-psC0c1W-D18iXiya9zuH2a9nX?usp=drive_link'}, } }, - # { - # 'name': 'Medficient SAM', - # 'description': 'Medficient SAM [.... placeholder ....]', - # 'default checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/medficient/model.pth'), - # 'controls to hide': [self.ui.cmbSpeed, self.ui.lblSubModel, self.ui.cmbSubModel], - # 'controls to show': [], - # 'url': '', - # 'submodels': {} - # }, ] self.ui.cmbEngine.addItems(list(map(lambda x: x['name'], self.engine_list))) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/__init__.py index a59c355..6106dba 100644 --- a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/__init__.py +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/__init__.py @@ -1,4 +1,4 @@ from .classicmedsam import ClassicMedSAM from .ovmedsam import OVMedSAMCore from .DAFTsam import DAFTSAMCore -# from .medficientsam2 import MedficientSAMCore +from .medficientsam import MedficientSAMCore diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/medficientsam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/medficientsam.py new file mode 100644 index 0000000..2569f64 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/medficientsam.py @@ -0,0 +1,239 @@ +### Code and models are adopted from https://github.com/hieplpvip/medficientsam +### `engines.src` is almost an identical copy of https://github.com/hieplpvip/medficientsam/tree/59504938bb37ab7e2832ede358051976e740efe5/src + +import argparse +import sys +from datetime import datetime +from glob import glob +from os.path import join, basename +from pathlib import Path +from time import time +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.v2 as transforms +from collections import OrderedDict + +sys.path.append(str(Path(__file__).parent)) + +class ResizeLongestSide(torch.nn.Module): + def __init__( + self, + long_side_length: int, + interpolation: str, + ) -> None: + super().__init__() + self.long_side_length = long_side_length + self.interpolation = interpolation + + def forward(self, image: torch.Tensor) -> torch.Tensor: + oldh, oldw = image.shape[-2:] + if max(oldh, oldw) == self.long_side_length: + return image + newh, neww = self.get_preprocess_shape(oldh, oldw, self.long_side_length) + return F.interpolate( + image, (newh, neww), mode=self.interpolation, align_corners=False + ) + + @staticmethod + def get_preprocess_shape( + oldh: int, + oldw: int, + long_side_length: int, + ) -> Tuple[int, int]: + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + +class MinMaxScale(torch.nn.Module): + def forward(self, image: torch.Tensor) -> torch.Tensor: + assert len(image.shape) >= 3 and image.shape[-3] == 3 + min_val = image.amin((-3, -2, -1), keepdim=True) + max_val = image.amax((-3, -2, -1), keepdim=True) + return (image - min_val) / torch.clip(max_val - min_val, min=1e-8, max=None) + + 
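The transforms defined in this new engine file (ResizeLongestSide and MinMaxScale above, PadToSquare just below) compose into the fixed 512x512 encoder input used by MedficientSAMCore: the long side of each slice is rescaled to 512, intensities are min-max scaled to [0, 1], and the short side is zero-padded on the bottom/right. A minimal sketch of the resulting shape arithmetic, assuming a hypothetical 256x384 slice (the example size is illustrative, not taken from the patch):

# Sketch only: mirrors ResizeLongestSide.get_preprocess_shape with a 512-pixel
# long side, followed by the bottom/right zero padding applied by PadToSquare.
def preprocess_shape(oldh: int, oldw: int, long_side: int = 512) -> tuple:
    scale = long_side / max(oldh, oldw)
    return int(oldh * scale + 0.5), int(oldw * scale + 0.5)

newh, neww = preprocess_shape(256, 384)          # (341, 512): long side scaled to 512
pad_bottom, pad_right = 512 - newh, 512 - neww   # (171, 0): padding added by PadToSquare

The same arithmetic is reused at inference time, where get_preprocess_shape recomputes the pre-padding size so postprocess_masks can crop the prediction and resize it back to the original slice geometry.
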
+class PadToSquare(torch.nn.Module): + def __init__(self, target_size: int) -> None: + super().__init__() + self.target_size = target_size + + def forward(self, image: torch.Tensor) -> torch.Tensor: + h, w = image.shape[-2:] + return F.pad(image, (0, self.target_size - w, 0, self.target_size - h), value=0) + + +def get_bbox(mask: np.ndarray) -> np.ndarray: + y_indices, x_indices = np.where(mask > 0) + x_min, x_max = np.min(x_indices), np.max(x_indices) + y_min, y_max = np.min(y_indices), np.max(y_indices) + bboxes = np.array([x_min, y_min, x_max, y_max]) + return bboxes + + +def resize_box( + box: np.ndarray, + original_size: Tuple[int, int], + prompt_encoder_input_size: int, +) -> np.ndarray: + new_box = np.zeros_like(box) + ratio = prompt_encoder_input_size / max(original_size) + for i in range(len(box)): + new_box[i] = int(box[i] * ratio) + + return new_box + + +def get_image_transform( + long_side_length: int, + min_max_scale: bool = True, + normalize: bool = False, + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + interpolation: str = "bilinear", +) -> transforms.Transform: + tsfm = [ + ResizeLongestSide(long_side_length, interpolation), + transforms.ToDtype(dtype=torch.float32, scale=False), + ] + if min_max_scale: + tsfm.append(MinMaxScale()) + if normalize: + tsfm.append(transforms.Normalize(pixel_mean, pixel_std)) + tsfm.append(PadToSquare(long_side_length)) + return transforms.Compose(tsfm) + + +class MedficientSAMCore: + model = None + device = None + MedSAM_CKPT_PATH = None + + H = None + W = None + image_shape = None + embeddings = None + + def __init__(self): + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def load_model(self): + self.model = torch.load(self.MedSAM_CKPT_PATH, map_location="cpu") + self.model.to(self.device) + self.model.eval() + + def get_progress(self): + return {'layers': 100 if self.image_shape is None else self.image_shape[0], 'generated_embeds': len(self.embeddings)} + + def set_image(self, image_data, wmin, wmax, zmin, zmax, recurrent_func): + self.embeddings = [] + + self.image_shape = image_data.shape + self.original_size = image_data.shape[-2:] + + if len(image_data.shape) == 3: + # gray: (D, H, W) -> (D, 3, H, W) + tsfm_img_3D = np.repeat(image_data[:, None, ...], 3, axis=1) + else: + # rgb: (D, H, W, 3) -> (D, 3, H, W) + tsfm_img_3D = np.transpose(image_data, (0, 3, 1, 2)) + + transform_image = get_image_transform(long_side_length=512) + tsfm_img_3D = transform_image(torch.tensor(tsfm_img_3D, dtype=torch.uint8)) + + for z in range(image_data.shape[0]): + if recurrent_func is not None: + recurrent_func() + image_embedding = None + calculation_condition = (zmax == -1) or ((zmin-1) <= z <= (zmax+1)) # Full embedding or partial embedding that lies between slices + if calculation_condition: + img_2d = tsfm_img_3D[z, :, :, :].unsqueeze(0).to(self.device) # (1, 3, H, W) + image_embedding = self.model.image_encoder(img_2d).detach() # (1, 256, 64, 64) + else: + image_embedding = None + if image_embedding is not None: + print(image_embedding.shape, image_embedding.dtype) + self.embeddings.append(image_embedding)#.detach().cpu().numpy()) + + @torch.no_grad() + def infer(self, slice_idx, bbox, zrange): + res = {} + + new_size = ResizeLongestSide.get_preprocess_shape( + self.original_size[0], self.original_size[1], 512 + ) + prompt_encoder_input_size = self.model.prompt_encoder.input_image_size[0] + + z_min, z_max = zrange + z_max = min(z_max+1, len(self.embeddings)) + z_min = 
max(z_min-1, 0) + x_min, y_min, x_max, y_max = bbox + + box2D = np.array([x_min, y_min, x_max, y_max]) + box2D = resize_box( + box2D, + original_size=self.original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + box3D = torch.tensor(np.array([box2D[0], box2D[1], z_min, box2D[2], box2D[3], z_max]), dtype=torch.float32) + + segs_i = np.zeros(self.image_shape[:3], dtype=np.uint16) + x_min, y_min, z_min, x_max, y_max, z_max = box3D + box_default = np.array([x_min, y_min, x_max, y_max]) + z_middle = (z_max + z_min) // 2 + + # infer from middle slice to the z_max + box_2D = box_default + for z in range(int(z_middle), int(z_max)): + box_torch = torch.as_tensor(box_2D[None, ...], dtype=torch.float).to(self.device) # (B, 4) + mask, _ = self.model.prompt_and_decoder(self.embeddings[z], box_torch) + mask = self.model.postprocess_masks(mask, new_size, self.original_size) + mask = mask.squeeze().cpu().numpy() + if np.max(mask) > 0: + box_2D = get_bbox(mask) + box_2D = resize_box( + box=box_2D, + original_size=self.original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + segs_i[z, mask > 0] = 1 + res[z] = segs_i[z] + else: + box_2D = box_default + + # infer from middle slice to the z_min + if np.max(segs_i[int(z_middle), :, :]) == 0: + box_2D = box_default + else: + box_2D = get_bbox(segs_i[int(z_middle), :, :]) + box_2D = resize_box( + box=box_2D, + original_size=self.original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + + for z in range(int(z_middle - 1), int(z_min - 1), -1): + box_torch = torch.as_tensor(box_2D[None, ...], dtype=torch.float).to(self.device) # (B, 4) + mask, _ = self.model.prompt_and_decoder(self.embeddings[z], box_torch) + mask = self.model.postprocess_masks(mask, new_size, self.original_size) + mask = mask.squeeze().cpu().numpy() + if np.max(mask) > 0: + box_2D = get_bbox(mask) + box_2D = resize_box( + box=box_2D, + original_size=self.original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + segs_i[z, mask > 0] = 1 + res[z] = segs_i[z] + else: + box_2D = box_default + + return res + + diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/medsam_dataset.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/medsam_dataset.py new file mode 100644 index 0000000..cadb7cb --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/components/medsam_dataset.py @@ -0,0 +1,370 @@ +import itertools +import os +import random +import zipfile +from glob import glob +from time import time +from typing import List, Optional + +import albumentations as A +import numpy as np +import torch +from torch.utils.data import Dataset + +from src.utils.multiprocessing import parmap +from src.utils.transforms import ( + ResizeLongestSide, 
+ get_bbox, + get_image_transform, + resize_box, + transform_gt, +) + + +class MedSAMBaseDataset(Dataset): + def __init__( + self, + data_dir: str, + image_encoder_input_size: int = 512, + prompt_encoder_input_size: Optional[int] = None, + scale_image: bool = True, + normalize_image: bool = False, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + interpolation: str = "bilinear", + ): + self.data_dir = data_dir + self.image_encoder_input_size = image_encoder_input_size + self.prompt_encoder_input_size = ( + prompt_encoder_input_size + if prompt_encoder_input_size is not None + else image_encoder_input_size + ) + self.scale_image = scale_image + self.normalize_image = normalize_image + self.pixel_mean = pixel_mean + self.pixel_std = pixel_std + self.interpolation = interpolation + self.transform_image = get_image_transform( + long_side_length=self.image_encoder_input_size, + min_max_scale=self.scale_image, + normalize=self.normalize_image, + pixel_mean=self.pixel_mean, + pixel_std=self.pixel_std, + interpolation=self.interpolation, + ) + + +class MedSAMTrainDataset(MedSAMBaseDataset): + def __init__( + self, + bbox_random_shift: int = 5, + mask_num: int = 5, + data_aug: bool = True, + num_workers: int = 8, + glob_pattern: str = "**/*.npz", + limit_npz: Optional[int] = None, + limit_sample: Optional[int] = None, + aug_transform: Optional[A.TransformType] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.bbox_random_shift = bbox_random_shift + self.mask_num = mask_num + + self.npz_file_paths = sorted( + glob(os.path.join(self.data_dir, glob_pattern), recursive=True) + ) + if limit_npz is not None: + self.npz_file_paths = self.npz_file_paths[:limit_npz] + + self.items = list( + itertools.chain.from_iterable( + parmap(self.__flatten_npz, self.npz_file_paths, nprocs=num_workers) + ) + ) + if limit_sample is not None: + rng = random.Random(42) + self.items = rng.sample(self.items, limit_sample) + + print("Number of samples:", len(self.items)) + + if not data_aug: + self.aug_transform = A.NoOp() + elif aug_transform is not None: + self.aug_transform = aug_transform + else: + self.aug_transform = A.Compose( + [ + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + ] + ) + + def __flatten_npz(self, npz_file_path): + try: + data = np.load(npz_file_path, "r") + except zipfile.BadZipFile: + return [] + + gts = data["gts"] + assert len(gts.shape) == 2 or len(gts.shape) == 3 + if len(gts.shape) > 2: # 3D + return [ + (npz_file_path, slice_index) + for slice_index in gts.max(axis=(1, 2)).nonzero()[0] + ] + else: # 2D + return [(npz_file_path, -1)] if gts.max() > 0 else [] + + def get_name(self, item): + name = os.path.basename(item[0]).split(".")[0] + return name + f"_{item[1]:03d}" if item[1] != -1 else name + + def __len__(self): + return len(self.items) + + def __getitem__(self, index): + item = self.items[index] + data = np.load(item[0], "r") + img = data["imgs"] + gt = data["gts"] # multiple labels [0, 1, 4, 5, ...], (H, W) + + if item[1] != -1: # 3D + img = img[item[1], :, :] + gt = gt[item[1], :, :] + + # duplicate channel if the image is grayscale + if len(img.shape) < 3: + img = np.repeat(img[..., None], 3, axis=-1) # (H, W, 3) + + labels = np.unique(gt[gt > 0]) + assert len(labels) > 0, f"No label found in {item[0]}" + labels = random.choices(labels, k=self.mask_num) + + # augmentation + all_masks = [np.array(gt == label, dtype=np.uint8) for label in labels] + augmented = self.aug_transform(image=img, masks=all_masks) + img, 
all_masks = augmented["image"], augmented["masks"] + original_size = img.shape[:2] + + # Extract boxes and masks from ground-truths + masks_list = [] + boxes_list = [] + for mask in all_masks: + mask = torch.from_numpy(mask).type(torch.uint8) + mask = transform_gt(mask, self.image_encoder_input_size) + if mask.max() == 0: + H, W = mask.shape + x_min = random.randint(0, W - 1) + x_max = random.randint(0, W - 1) + y_min = random.randint(0, H - 1) + y_max = random.randint(0, H - 1) + if x_min > x_max: + x_min, x_max = x_max, x_min + if y_min > y_max: + y_min, y_max = y_max, y_min + + bbox_shift = 1 + x_min = max(0, x_min - bbox_shift) + x_max = min(W - 1, x_max + bbox_shift) + y_min = max(0, y_min - bbox_shift) + y_max = min(H - 1, y_max + bbox_shift) + + box = np.array([x_min, y_min, x_max, y_max]) + else: + box = get_bbox(mask, random.randint(0, self.bbox_random_shift)) + box = resize_box(box, mask.shape, self.prompt_encoder_input_size) + box = torch.tensor(box, dtype=torch.float32) + masks_list.append(mask) + boxes_list.append(box) + + tsfm_img = torch.tensor(np.transpose(img, (2, 0, 1)), dtype=torch.uint8) + tsfm_img = self.transform_image(tsfm_img.unsqueeze(0)).squeeze(0) + + return { + "image": tsfm_img, # (3, H, W) + "masks": torch.stack(masks_list).unsqueeze(1), # (N, H, W) + "boxes": torch.stack(boxes_list), # (N, 4) + "original_size": torch.tensor(original_size, dtype=torch.int32), + } + + +class MedSAMDistillDataset(MedSAMTrainDataset): + def __init__( + self, + teacher_image_encoder_input_size: Optional[int] = 1024, + teacher_scale_image: bool = True, + teacher_normalize_image: bool = False, + embedding_dir=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.teacher_image_encoder_input_size = teacher_image_encoder_input_size + self.teacher_scale_image = teacher_scale_image + self.teacher_normalize_image = teacher_normalize_image + if teacher_image_encoder_input_size is not None: + self.transform_teacher_image = get_image_transform( + long_side_length=self.teacher_image_encoder_input_size, + min_max_scale=self.teacher_scale_image, + normalize=self.teacher_normalize_image, + pixel_mean=self.pixel_mean, + pixel_std=self.pixel_std, + interpolation=self.interpolation, + ) + + self.embedding_dir = embedding_dir + if self.embedding_dir is not None: + self.items = self.__filter_valid_embs(self.items, embedding_dir) + + def __filter_valid_embs(self, items, embedding_dir): + """ + Filter the npz_file_paths, ignore file that does not have image embedding + Some embedding maybe missed during feature extraction process + """ + + valid = [] + for item in items: + name = self.get_name(item) + npy_file_path = os.path.join(embedding_dir, name + ".npy") + if os.path.exists(npy_file_path): + valid.append(item) + print(f"Found {len(valid)} image embeddings.") + return valid + + def __getitem__(self, index): + item = self.items[index] + data = np.load(item[0], "r") + img = data["imgs"] + + if item[1] != -1: # 3D + img = img[item[1], :, :] + + # duplicate channel if the image is grayscale + if len(img.shape) < 3: + img = np.repeat(img[..., None], 3, axis=-1) # (H, W, 3) + + # augmentation + tsfm_img = self.aug_transform(image=img)["image"] + tsfm_img = torch.tensor(np.transpose(tsfm_img, (2, 0, 1)), dtype=torch.uint8) + + items = {"image": self.transform_image(tsfm_img.unsqueeze(0)).squeeze(0)} + + if self.teacher_image_encoder_input_size is not None: + # Transform image + items["teacher_image"] = self.transform_teacher_image( + tsfm_img.unsqueeze(0) + ).squeeze(0) + elif self.embedding_dir is 
not None: + img_name = self.get_name(item) + emb_file = os.path.join(self.embedding_dir, img_name + ".npy") + items["embedding"] = np.load(emb_file, "r", allow_pickle=True) + + return items + + +class MedSAMInferDataset(MedSAMBaseDataset): + def __init__(self, glob_pattern: str = "**/*.npz", **kwargs): + super().__init__(**kwargs) + self.npz_file_paths = sorted( + glob(os.path.join(self.data_dir, glob_pattern), recursive=True) + ) + + def __len__(self): + return len(self.npz_file_paths) + + def __getitem__(self, index): + start_time = time() + + npz_file_path = self.npz_file_paths[index] + npz_name = os.path.basename(npz_file_path) + data = np.load(npz_file_path, "r") + img = data["imgs"] + boxes = data["boxes"] + + if os.path.basename(npz_file_path).startswith("2D"): + if len(img.shape) < 3: + img = np.repeat(img[..., None], 3, axis=-1) # (H, W, 3) + + original_size = img.shape[:2] + new_size = ResizeLongestSide.get_preprocess_shape( + original_size[0], original_size[1], self.image_encoder_input_size + ) + tsfm_img = torch.tensor(np.transpose(img, (2, 0, 1)), dtype=torch.uint8) + tsfm_img = self.transform_image(tsfm_img.unsqueeze(0)).squeeze(0) + + # Transform box + tsfm_boxes = [] + for box in boxes: + box = resize_box( + box, + original_size=original_size, + prompt_encoder_input_size=self.prompt_encoder_input_size, + ) + tsfm_boxes.append(box) + + end_time = time() + print(f"Processed {npz_name} in {end_time - start_time:.2f}s") + + return { + "image": tsfm_img, # (3, H, W) + "boxes": torch.tensor( + np.array(tsfm_boxes), dtype=torch.float32 + ), # (N, 4) + "npz_name": npz_name, + "new_size": torch.tensor(new_size, dtype=torch.int32), + "original_size": torch.tensor(original_size, dtype=torch.int32), + "image_type": "2D", + "original_image": img, + "original_boxes": boxes, + } + + elif os.path.basename(npz_file_path).startswith("3D"): + if len(img.shape) == 3: + # gray: (D, H, W) -> (D, 3, H, W) + tsfm_imgs = np.repeat(img[:, None, ...], 3, axis=1) + else: + # rbg: (D, H, W, 3) -> (D, 3, H, W) + tsfm_imgs = np.transpose(img, (0, 3, 1, 2)) + + original_size = img.shape[-2:] + new_size = ResizeLongestSide.get_preprocess_shape( + original_size[0], original_size[1], self.image_encoder_input_size + ) + tsfm_imgs = self.transform_image(torch.tensor(tsfm_imgs, dtype=torch.uint8)) + + # Transform box + tsfm_boxes = [] + for box3D in boxes: + x_min, y_min, z_min, x_max, y_max, z_max = box3D + box2D = np.array([x_min, y_min, x_max, y_max]) + box2D = resize_box( + box2D, + original_size=original_size, + prompt_encoder_input_size=self.prompt_encoder_input_size, + ) + box3D = np.array([box2D[0], box2D[1], z_min, box2D[2], box2D[3], z_max]) + tsfm_boxes.append(box3D) + + end_time = time() + print(f"Processed {npz_name} in {end_time - start_time:.2f}s") + + return { + "image": tsfm_imgs, # (D, 3, H, W) + "boxes": torch.tensor( + np.array(tsfm_boxes), dtype=torch.float32 + ), # (N, 6) + "npz_name": npz_name, + "new_size": torch.tensor(new_size, dtype=torch.int32), + "original_size": torch.tensor(original_size, dtype=torch.int32), + "prompt_encoder_input_size": self.prompt_encoder_input_size, + "image_type": "3D", + "original_image": img, + "original_boxes": boxes, + } + + raise Exception( + f"Unexpected input type for file {npz_file_path}, only allow 3D- and 2D- prefix" + ) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/medsam_datamodule.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/medsam_datamodule.py new file mode 100644 index 
0000000..b2931be --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/data/medsam_datamodule.py @@ -0,0 +1,94 @@ +import copy +from typing import Any, Optional, Tuple + +import albumentations as A +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset, random_split + + +class MedSAMDataModule(LightningDataModule): + def __init__( + self, + dataset: Dataset, + train_val_test_split: Tuple[int, int, int] = (0.9, 0.05, 0.05), + batch_size: int = 16, + num_workers: int = 16, + pin_memory: bool = False, + ) -> None: + super().__init__() + + self.save_hyperparameters(logger=False) + + self.dataset = dataset + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + self.batch_size_per_device = batch_size + + def setup(self, stage: Optional[str] = None) -> None: + # Divide batch size by the number of devices. + if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." + ) + self.batch_size_per_device = ( + self.hparams.batch_size // self.trainer.world_size + ) + + # load and split datasets only if not loaded already + if not self.data_train and not self.data_val and not self.data_test: + train_dataset = self.dataset(num_workers=self.hparams.num_workers) + self.data_train, self.data_val, self.data_test = random_split( + dataset=train_dataset, + lengths=self.hparams.train_val_test_split, + generator=torch.Generator().manual_seed(42), + ) + + val_dataset = copy.deepcopy(train_dataset) + val_dataset.aug_transform = A.NoOp() + self.data_val.dataset = self.data_test.dataset = val_dataset + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + return DataLoader( + dataset=self.data_train, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + return DataLoader( + dataset=self.data_val, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. 
+ """ + return DataLoader( + dataset=self.data_test, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + +if __name__ == "__main__": + _ = MedSAMDataModule(None) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_onnx.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_onnx.py new file mode 100644 index 0000000..cda0667 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_onnx.py @@ -0,0 +1,189 @@ +import io +from pathlib import Path +from typing import Dict, List, Optional + +import onnx +import torch +import torch.nn as nn +import hydra +import rootutils +import onnxruntime as ort +import openvino as ov +from onnxruntime.quantization import QuantType +from onnxruntime.quantization.quantize import quantize_dynamic +from onnxsim import simplify +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from src.models.base_sam import BaseSAM +from src.models.onnx import EncoderOnnxModel, DecoderOnnxModel +from src.utils import ( + RankedLogger, + extras, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +def convert_to_openvino(onnx_model: Path): + model = ov.convert_model(onnx_model) + ov.save_model(model, onnx_model.with_suffix(".xml"), compress_to_fp16=False) + + +def export_onnx( + cfg: DictConfig, + onnx_model: nn.Module, + dummy_inputs: Dict[str, torch.Tensor], + dynamic_axes: Optional[Dict[str, Dict[int, str]]], + output_names: List[str], + output_file: Path, +): + _ = onnx_model(**dummy_inputs) + + buffer = io.BytesIO() + torch.onnx.export( + onnx_model, + tuple(dummy_inputs.values()), + buffer, + export_params=True, + verbose=False, + opset_version=cfg.opset, + do_constant_folding=True, + input_names=list(dummy_inputs.keys()), + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + buffer.seek(0, 0) + + if cfg.simplify: + onnx_model = onnx.load_model(buffer) + onnx_model, success = simplify(onnx_model) + assert success + new_buffer = io.BytesIO() + onnx.save(onnx_model, new_buffer) + buffer = new_buffer + buffer.seek(0, 0) + + with open(output_file, "wb") as f: + f.write(buffer.read()) + + optimized_output_file = output_file.with_suffix(".optimized.onnx") + opt = ort.SessionOptions() + opt.optimized_model_filepath = optimized_output_file.as_posix() + opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + _ = ort.InferenceSession(output_file, opt, providers=["CPUExecutionProvider"]) + + quantized_output_file = output_file.with_suffix(".quantized.onnx") + quantize_dynamic( + model_input=output_file, + model_output=quantized_output_file, + per_channel=False, + reduce_range=False, + weight_type=QuantType.QUInt8, + ) + + convert_to_openvino(output_file) + + +def export_encoder(sam_model: BaseSAM, cfg: DictConfig): + onnx_model = EncoderOnnxModel( + image_encoder=sam_model.image_encoder, + preprocess_image=cfg.encoder_config.preprocess_image, + image_encoder_input_size=cfg.encoder_config.image_encoder_input_size, + scale_image=cfg.encoder_config.scale_image, + normalize_image=cfg.encoder_config.normalize_image, + ) + + if cfg.encoder_config.preprocess_image: + dummy_inputs = { + "image": torch.randint(0, 256, (256, 384, 3), dtype=torch.uint8), + "original_size": torch.tensor([256, 384], dtype=torch.int16), + } + dynamic_axes = {"image": {0: "image_height", 1: "image_width"}} + 
output_names = ["image_embeddings"] + else: + dummy_inputs = { + "image": torch.randn( + ( + cfg.encoder_config.image_encoder_input_size, + cfg.encoder_config.image_encoder_input_size, + 3, + ), + dtype=torch.float32, + ), + } + dynamic_axes = None + output_names = ["image_embeddings"] + + export_onnx( + cfg=cfg, + onnx_model=onnx_model, + dummy_inputs=dummy_inputs, + dynamic_axes=dynamic_axes, + output_names=output_names, + output_file=Path(cfg.output_dir) / "encoder.onnx", + ) + + +def export_decoder(sam_model: BaseSAM, cfg: DictConfig): + onnx_model = DecoderOnnxModel( + mask_decoder=sam_model.mask_decoder, + prompt_encoder=sam_model.prompt_encoder, + image_encoder_input_size=cfg.encoder_config.image_encoder_input_size, + ) + + embed_dim = onnx_model.prompt_encoder.embed_dim + embed_size = onnx_model.prompt_encoder.image_embedding_size + + dummy_inputs = { + "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float32), + "boxes": torch.rand(4, dtype=torch.float32), + } + output_names = ["masks"] + + export_onnx( + cfg=cfg, + onnx_model=onnx_model, + dummy_inputs=dummy_inputs, + dynamic_axes=None, + output_names=output_names, + output_file=Path(cfg.output_dir) / "decoder.onnx", + ) + + +@task_wrapper +@torch.no_grad() +def export(cfg: DictConfig): + log.info(f"Instantiating model <{cfg.model._target_}>") + model: BaseSAM = hydra.utils.instantiate(cfg.model).to(cfg.device) + model.eval() + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + log.info("Exporting encoder to ONNX") + export_encoder(model, cfg) + + log.info("Exporting decoder to ONNX") + export_decoder(model, cfg) + + +@hydra.main( + version_base="1.3", config_path="../configs", config_name="export_onnx.yaml" +) +def main(cfg: DictConfig): + """Main entry point for exporting. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + export(cfg) + + +if __name__ == "__main__": + main() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_torch.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_torch.py new file mode 100644 index 0000000..26e4f99 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/export_torch.py @@ -0,0 +1,50 @@ +from pathlib import Path + +import torch +import hydra +import rootutils +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + + +from src.models.base_sam import BaseSAM +from src.utils import ( + RankedLogger, + extras, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def export(cfg: DictConfig): + log.info(f"Instantiating model <{cfg.model._target_}>") + model: BaseSAM = hydra.utils.instantiate(cfg.model).to("cpu") + model.eval() + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + model.prompt_encoder.input_image_size = (512, 512) + torch.save(model, output_dir / "model.pth") + + +@hydra.main( + version_base="1.3", config_path="../configs", config_name="export_torch.yaml" +) +def main(cfg: DictConfig): + """Main entry point for exporting. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + export(cfg) + + +if __name__ == "__main__": + main() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/infer.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/infer.py new file mode 100644 index 0000000..29663c4 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/infer.py @@ -0,0 +1,188 @@ +from pathlib import Path +from time import time +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import hydra +import rootutils +from omegaconf import DictConfig +from tqdm import tqdm +from matplotlib import pyplot as plt + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from torch.utils.data import DataLoader + +from src.models.base_sam import BaseSAM +from src.utils import ( + RankedLogger, + extras, + task_wrapper, +) +from src.utils.transforms import get_bbox, resize_box +from src.utils.visualize import visualize_output + + +log = RankedLogger(__name__, rank_zero_only=True) + + +def infer_2D(model: BaseSAM, data, device, output_dir: Path, save_overlay: bool): + npz_name = data["npz_name"] + img = data["image"].unsqueeze(0).to(device) + boxes = data["boxes"].to(device) + original_size = data["original_size"].tolist() + new_size = data["new_size"].tolist() + + image_embedding = model.image_encoder(img) + masks, _ = model.prompt_and_decoder(image_embedding, boxes) + masks = model.postprocess_masks(masks, new_size, original_size) + masks = masks.squeeze(1).cpu().numpy() + + segs = np.zeros(original_size, dtype=np.uint16) + for idx in range(len(boxes)): + segs[masks[idx] > 0] = idx + 1 + + np.savez_compressed(output_dir / "npz" / npz_name, segs=segs) + + # visualize image, mask and bounding box + if save_overlay: + visualize_output( + img=data["original_image"], + boxes=data["original_boxes"], + segs=segs, + save_file=(output_dir / "png" / npz_name).with_suffix(".png"), + ) + + +def infer_3D(model: BaseSAM, data, device, output_dir: Path, save_overlay: bool): + npz_name = data["npz_name"] + imgs = data["image"].to(device) # (D, 3, H, W) + boxes = data["boxes"] # (N, 6), [[x_min, y_min, z_min, x_max, y_max, z_max]] + original_size = data["original_size"].tolist() # (2) + new_size = data["new_size"].tolist() # (2) + prompt_encoder_input_size = data["prompt_encoder_input_size"] + + segs = np.zeros((imgs.shape[0], *original_size), dtype=np.uint16) + + for idx, box3D in enumerate(boxes, start=1): + segs_i = np.zeros_like(segs, dtype=np.uint16) + x_min, y_min, z_min, x_max, y_max, z_max = box3D + z_min = max(z_min, 0) + z_max = min(z_max, imgs.shape[0]) + box_default = np.array([x_min, y_min, x_max, y_max]) + z_middle = (z_max + z_min) // 2 + + # infer from middle slice to the z_max + box_2D = box_default + for z in range(int(z_middle), int(z_max)): + img_2d = imgs[z, :, :, :].unsqueeze(0) # (1, 3, H, W) + image_embedding = model.image_encoder(img_2d) # (1, 256, 64, 64) + + box_torch = torch.as_tensor( + box_2D[None, ...], dtype=torch.float, device=device + ) # (B, 4) + mask, _ = model.prompt_and_decoder(image_embedding, box_torch) + mask = model.postprocess_masks(mask, new_size, original_size) + mask = mask.squeeze().cpu().numpy() + if np.max(mask) > 0: + box_2D = get_bbox(mask) + box_2D = resize_box( + box=box_2D, + original_size=original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + segs_i[z, mask > 0] = 1 + else: + box_2D = box_default + + # infer from middle slice to the z_min + if np.max(segs_i[int(z_middle), :, 
:]) == 0: + box_2D = box_default + else: + box_2D = get_bbox(segs_i[int(z_middle), :, :]) + box_2D = resize_box( + box=box_2D, + original_size=original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + + for z in range(int(z_middle - 1), int(z_min - 1), -1): + img_2d = imgs[z, :, :, :].unsqueeze(0) # (1, 3, H, W) + image_embedding = model.image_encoder(img_2d) # (1, 256, 64, 64) + + box_torch = torch.as_tensor( + box_2D[None, ...], dtype=torch.float, device=device + ) # (B, 4) + mask, _ = model.prompt_and_decoder(image_embedding, box_torch) + mask = model.postprocess_masks(mask, new_size, original_size) + mask = mask.squeeze().cpu().numpy() + if np.max(mask) > 0: + box_2D = get_bbox(mask) + box_2D = resize_box( + box=box_2D, + original_size=original_size, + prompt_encoder_input_size=prompt_encoder_input_size, + ) + segs_i[z, mask > 0] = 1 + else: + box_2D = box_default + + segs[segs_i > 0] = idx + + np.savez_compressed(output_dir / "npz" / npz_name, segs=segs) + + # visualize image, mask and bounding box + if save_overlay: + z = segs.shape[0] // 2 + visualize_output( + img=data["original_image"][z], + boxes=data["original_boxes"][:, [0, 1, 3, 4]], + segs=segs[z], + save_file=(output_dir / "png" / npz_name).with_suffix(".png"), + ) + + +@task_wrapper +@torch.no_grad() +def infer(cfg: DictConfig): + log.info(f"Instantiating dataloader <{cfg.data._target_}>") + dataloader: DataLoader = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: BaseSAM = hydra.utils.instantiate(cfg.model).to(cfg.device) + model = torch.compile(model) + model.eval() + + output_dir = Path(cfg.output_dir) + (output_dir / "npz").mkdir(parents=True, exist_ok=True) + if cfg.save_overlay: + (output_dir / "png").mkdir(parents=True, exist_ok=True) + + for data in tqdm(dataloader): + start_time = time() + if data["image_type"] == "2D": + infer_2D(model, data, cfg.device, output_dir, cfg.save_overlay) + elif data["image_type"] == "3D": + infer_3D(model, data, cfg.device, output_dir, cfg.save_overlay) + else: + raise NotImplementedError("Only support 2D and 3D image") + end_time = time() + print(f"Predicted {data['npz_name']} in {end_time - start_time:.2f}s") + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="infer.yaml") +def main(cfg: DictConfig): + """Main entry point for inference. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + infer(cfg) + + +if __name__ == "__main__": + main() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/SAMLoss.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/SAMLoss.py new file mode 100644 index 0000000..d8aa8cf --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/SAMLoss.py @@ -0,0 +1,39 @@ +from monai.losses import FocalLoss, DiceLoss +from torch import nn + +from .components.IoULoss import IoULoss + + +class SAMLoss(nn.Module): + """ + Loss function used in Segment Anything paper. 
+ """ + + def __init__( + self, + dice_loss_weight=1.0, + focal_loss_weight=20.0, + iou_loss_weight=1.0, + ): + super().__init__() + self.dice_loss_weight = dice_loss_weight + self.focal_loss_weight = focal_loss_weight + self.iou_loss_weight = iou_loss_weight + self.dice_loss = DiceLoss(sigmoid=True, squared_pred=True, reduction="mean") + self.focal_loss = FocalLoss(use_softmax=False, reduction="mean") + self.iou_loss = IoULoss() + + def forward(self, pred_logits, pred_iou, gt_mask): + """ + pred_logits: [B, 1, H, W] + gt_mask: [B, 1, H, W] + pred_iou: [B, 1] + """ + dice_loss = self.dice_loss(pred_logits, gt_mask) + focal_loss = self.focal_loss(pred_logits, gt_mask) + iou_loss = self.iou_loss(pred_logits, pred_iou, gt_mask) + return ( + self.dice_loss_weight * dice_loss + + self.focal_loss_weight * focal_loss + + self.iou_loss_weight * iou_loss + ) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/__init__.py new file mode 100644 index 0000000..178d5a6 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/__init__.py @@ -0,0 +1 @@ +from .SAMLoss import SAMLoss diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/components/IoULoss.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/components/IoULoss.py new file mode 100644 index 0000000..b705c5c --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/losses/components/IoULoss.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + + +class IoULoss(nn.Module): + def __init__(self, eps=1e-7): + super().__init__() + self.eps = eps + + def forward(self, pred_logits, pred_iou, gt_mask): + """ + pred_mask: [B, 1, H, W] + gt_mask: [B, 1, H, W] + pred_iou: [B, 1] + """ + assert pred_logits.shape == gt_mask.shape + assert pred_logits.shape[1] == 1 + + pred_mask = pred_logits > 0 + reduce_axis = list(range(2, len(pred_logits.shape))) + intersection = torch.sum(pred_mask * gt_mask, dim=reduce_axis) + union = ( + torch.sum(pred_mask, dim=reduce_axis) + + torch.sum(gt_mask, dim=reduce_axis) + - intersection + ) + iou = intersection / (union + self.eps) + return torch.mean((iou - pred_iou) ** 2) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/generalized_dice.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/generalized_dice.py new file mode 100644 index 0000000..0f109e2 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/metrics/generalized_dice.py @@ -0,0 +1,18 @@ +import torch +from torchmetrics import Metric +from monai.metrics import compute_generalized_dice + + +class GeneralizedDiceMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("dsc", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, gts: torch.Tensor): + dsc = compute_generalized_dice(preds, gts) + self.dsc += dsc.sum() + self.total += dsc.numel() + + def compute(self) -> torch.Tensor: + return self.dsc.float() / self.total diff --git 
a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/__init__.py new file mode 100644 index 0000000..d119491 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/__init__.py @@ -0,0 +1 @@ +from .sam import BaseSAM diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/sam.py new file mode 100644 index 0000000..2fa7fa6 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/base_sam/sam.py @@ -0,0 +1,113 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.segment_anything.modeling import MaskDecoder, PromptEncoder + + +class BaseSAM(nn.Module): + + def __init__( + self, + image_encoder: nn.Module, + mask_decoder: MaskDecoder, + prompt_encoder: PromptEncoder, + multimask_output: bool = False, + return_best_mask: bool = True, + ): + super().__init__() + self.image_encoder = image_encoder + self.mask_decoder = mask_decoder + self.prompt_encoder = prompt_encoder + self.multimask_output = multimask_output + self.return_best_mask = return_best_mask + + def prompt_and_decoder( + self, image_embedding: torch.Tensor, boxes: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=None, + boxes=boxes, + masks=None, + ) + + # I: number of image embeddings + # B: number of boxes + # Assume that each image has the same number of boxes (= B / I) + # M: number of multimask outputs (default is 3 if multimask_output is True, otherwise 1) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=image_embedding, # (I, 256, 64, 64) + image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) + sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) + dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) + multimask_output=self.multimask_output, + ) # (B, M, 256, 256) + + if self.multimask_output and self.return_best_mask: + max_values, max_indices = torch.max(iou_predictions, dim=1) + iou_predictions = max_values.unsqueeze(1) + low_res_masks = torch.take_along_dim( + low_res_masks, indices=max_indices.view(-1, 1, 1, 1), dim=1 + ) # (B, 1, 256, 256) + + return low_res_masks, iou_predictions + + def forward(self, image: torch.Tensor, boxes: torch.Tensor): + image_embedding = self.image_encoder(image) + return self.prompt_and_decoder(image_embedding, boxes) + + @torch.no_grad() + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, int], + original_size: Tuple[int, int], + return_with_image_encoder_size: bool = False, + ) -> torch.Tensor: + masks = F.interpolate( + masks, + (max(input_size), max(input_size)), + mode="bilinear", + align_corners=False, + ) + if return_with_image_encoder_size: + return masks + + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate( + masks, + original_size, + mode="bilinear", + align_corners=False, + ) + return masks + + @classmethod + def construct_from( + cls, + original_sam: 
Optional[nn.Module] = None, + distill_lit_module=None, + finetune_lit_module=None, + multimask_output: bool = False, + return_best_mask: bool = True, + ): + if finetune_lit_module is not None: + if isinstance(finetune_lit_module.model, cls): + return finetune_lit_module.model + original_sam = finetune_lit_module.model + + assert original_sam is not None + image_encoder = original_sam.image_encoder + mask_decoder = original_sam.mask_decoder + prompt_encoder = original_sam.prompt_encoder + if distill_lit_module is not None: + image_encoder = distill_lit_module.student_encoder + return cls( + image_encoder=image_encoder, + mask_decoder=mask_decoder, + prompt_encoder=prompt_encoder, + multimask_output=multimask_output, + return_best_mask=return_best_mask, + ) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/distill_module.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/distill_module.py new file mode 100644 index 0000000..eefbe91 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/distill_module.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, Optional + +import torch +from lightning import LightningModule +from torchmetrics import MeanMetric + + +class DistillLitModule(LightningModule): + + def __init__( + self, + student_net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + teacher_net: Optional[torch.nn.Module] = None, + scheduler_interval: str = "epoch", + ) -> None: + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["student_net", "teacher_net"]) + + self.student_encoder = ( + student_net.image_encoder + if "image_encoder" in dir(student_net) + else student_net + ) + if teacher_net is None: + self.teacher_encoder = None + else: + self.teacher_encoder = ( + teacher_net.image_encoder + if "image_encoder" in dir(teacher_net) + else teacher_net + ) + self.teacher_encoder.requires_grad_(False) + self.teacher_encoder.eval() + + self.criterion = torch.nn.MSELoss() + self.train_loss = MeanMetric() + + def model_step(self, batch): + student_embeddings = self.student_encoder(batch["image"]) + if self.teacher_encoder is not None: + teacher_embeddings = self.teacher_encoder(batch["teacher_image"]) + else: + teacher_embeddings = batch["embedding"] + loss = self.criterion(student_embeddings, teacher_embeddings) + return loss + + def training_step(self, batch, batch_idx) -> torch.Tensor: + loss = self.model_step(batch) + self.train_loss(loss) + self.log( + "train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True + ) + return loss + + def validation_step(self, batch, batch_idx) -> None: + pass + + def test_step(self, batch, batch_idx) -> None: + pass + + def configure_optimizers(self) -> Dict[str, Any]: + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "train/loss", + "interval": self.hparams.scheduler_interval, + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +if __name__ == "__main__": + _ = DistillLitModule(None, None, None) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/README.md 
b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/README.md new file mode 100644 index 0000000..2ac24f9 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/README.md @@ -0,0 +1 @@ +Copied from https://github.com/mit-han-lab/efficientvit/tree/bd2f02695c7c6da942a1a3177bdc89a417286fcf/efficientvit diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/__init__.py new file mode 100644 index 0000000..2c9a5df --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/__init__.py @@ -0,0 +1,7 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .augment import * +from .base import * +from .random_resolution import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/__init__.py new file mode 100644 index 0000000..b9ea4d6 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/__init__.py @@ -0,0 +1,6 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .bbox import * +from .color_aug import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/bbox.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/bbox.py new file mode 100644 index 0000000..b9f089a --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/bbox.py @@ -0,0 +1,30 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np + +__all__ = ["rand_bbox"] + + +def rand_bbox( + h: int, + w: int, + lam: float, + rand_func: callable = np.random.uniform, +) -> tuple[int, int, int, int]: + """randomly sample bbox, used in cutmix""" + cut_rat = np.sqrt(1.0 - lam) + cut_w = w * cut_rat + cut_h = h * cut_rat + + # uniform + cx = rand_func(0, w) + cy = rand_func(0, h) + + bbx1 = int(np.clip(cx - cut_w / 2, 0, w)) + bby1 = int(np.clip(cy - cut_h / 2, 0, h)) + bbx2 = int(np.clip(cx + cut_w / 2, 0, w)) + bby2 = int(np.clip(cy + 
cut_h / 2, 0, h)) + + return bbx1, bby1, bbx2, bby2 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/color_aug.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/color_aug.py new file mode 100644 index 0000000..e5462ac --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/augment/color_aug.py @@ -0,0 +1,78 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torchvision.transforms as transforms +from PIL import Image +from timm.data.auto_augment import rand_augment_transform + +__all__ = ["ColorAug", "RandAug"] + + +class ImageAug: + def aug_image(self, image: Image.Image) -> Image.Image: + raise NotImplementedError + + def __call__(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image: + if isinstance(feed_dict, dict): + output_dict = feed_dict + image = feed_dict[self.key] + else: + output_dict = None + image = feed_dict + is_ndarray = isinstance(image, np.ndarray) + if is_ndarray: + image = Image.fromarray(image) + + image = self.aug_image(image) + + if is_ndarray: + image = np.array(image) + + if output_dict is None: + return image + else: + output_dict[self.key] = image + return output_dict + + +class ColorAug(transforms.ColorJitter, ImageAug): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"): + super().__init__( + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue, + ) + self.key = key + + def aug_image(self, image: Image.Image) -> Image.Image: + return transforms.ColorJitter.forward(self, image) + + def forward(self, feed_dict: dict or np.ndarray or Image.Image) -> dict or np.ndarray or Image.Image: + return ImageAug.__call__(self, feed_dict) + + +class RandAug(ImageAug): + def __init__(self, config: dict[str, any], mean: tuple[float, float, float], key="data"): + n = config.get("n", 2) + m = config.get("m", 9) + mstd = config.get("mstd", 1.0) + inc = config.get("inc", 1) + tpct = config.get("tpct", 0.45) + config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}" + + aa_params = dict( + translate_pct=tpct, + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + interpolation=Image.BICUBIC, + ) + self.aug_op = rand_augment_transform(config_str, aa_params) + self.key = key + + def aug_image(self, image: Image.Image) -> Image.Image: + return self.aug_op(image) + + def __repr__(self): + return self.aug_op.__repr__() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/base.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/base.py new file mode 100644 index 0000000..e6b3c07 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/base.py @@ -0,0 +1,199 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +import warnings + +import torch.utils.data +from torch.utils.data.distributed import DistributedSampler + +from 
src.models.efficientvit.apps.data_provider.random_resolution import RRSController +from src.models.efficientvit.models.utils import val2tuple + +__all__ = ["parse_image_size", "random_drop_data", "DataProvider"] + + +def parse_image_size(size: int or str) -> tuple[int, int]: + if isinstance(size, str): + size = [int(val) for val in size.split("-")] + return size[0], size[1] + else: + return val2tuple(size, 2) + + +def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)): + g = torch.Generator() + g.manual_seed(seed) # set random seed before sampling validation set + rand_indexes = torch.randperm(len(dataset), generator=g).tolist() + + dropped_indexes = rand_indexes[:drop_size] + remaining_indexes = rand_indexes[drop_size:] + + dropped_dataset = copy.deepcopy(dataset) + for key in keys: + setattr(dropped_dataset, key, [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes]) + setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes]) + return dataset, dropped_dataset + + +class DataProvider: + data_keys = ("samples",) + mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} + SUB_SEED = 937162211 # random seed for sampling subset + VALID_SEED = 2147483647 # random seed for the validation set + + name: str + + def __init__( + self, + train_batch_size: int, + test_batch_size: int or None, + valid_size: int or float or None, + n_worker: int, + image_size: int or list[int] or str or list[str], + num_replicas: int or None = None, + rank: int or None = None, + train_ratio: float or None = None, + drop_last: bool = False, + ): + warnings.filterwarnings("ignore") + super().__init__() + + # batch_size & valid_size + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size or self.train_batch_size + self.valid_size = valid_size + + # image size + if isinstance(image_size, list): + self.image_size = [parse_image_size(size) for size in image_size] + self.image_size.sort() # e.g., 160 -> 224 + RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size) + self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1] + else: + self.image_size = parse_image_size(image_size) + RRSController.IMAGE_SIZE_LIST = [self.image_size] + self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size + + # distributed configs + self.num_replicas = num_replicas + self.rank = rank + + # build datasets + train_dataset, val_dataset, test_dataset = self.build_datasets() + + if train_ratio is not None and train_ratio < 1.0: + assert 0 < train_ratio < 1 + _, train_dataset = random_drop_data( + train_dataset, + int(train_ratio * len(train_dataset)), + self.SUB_SEED, + self.data_keys, + ) + + # build data loader + self.train = self.build_dataloader(train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True) + self.valid = self.build_dataloader(val_dataset, test_batch_size, n_worker, drop_last=False, train=False) + self.test = self.build_dataloader(test_dataset, test_batch_size, n_worker, drop_last=False, train=False) + if self.valid is None: + self.valid = self.test + self.sub_train = None + + @property + def data_shape(self) -> tuple[int, ...]: + return 3, self.active_image_size[0], self.active_image_size[1] + + def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any: + raise NotImplementedError + + def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any: + raise NotImplementedError + + def build_datasets(self) -> tuple[any, any, any]: + raise 
NotImplementedError + + def build_dataloader(self, dataset: any or None, batch_size: int, n_worker: int, drop_last: bool, train: bool): + if dataset is None: + return None + if isinstance(self.image_size, list) and train: + from src.models.efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader + + dataloader_class = RRSDataLoader + else: + dataloader_class = torch.utils.data.DataLoader + if self.num_replicas is None: + return dataloader_class( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + num_workers=n_worker, + pin_memory=True, + drop_last=drop_last, + ) + else: + sampler = DistributedSampler(dataset, self.num_replicas, self.rank) + return dataloader_class( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=n_worker, + pin_memory=True, + drop_last=drop_last, + ) + + def set_epoch(self, epoch: int) -> None: + RRSController.set_epoch(epoch, len(self.train)) + if isinstance(self.train.sampler, DistributedSampler): + self.train.sampler.set_epoch(epoch) + + def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None: + self.active_image_size = val2tuple(new_size, 2) + new_transform = self.build_valid_transform(self.active_image_size) + # change the transform of the valid and test set + self.valid.dataset.transform = self.test.dataset.transform = new_transform + + def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]: + if self.valid_size is not None: + if 0 < self.valid_size < 1: + valid_size = int(self.valid_size * len(train_dataset)) + else: + assert self.valid_size >= 1 + valid_size = int(self.valid_size) + train_dataset, val_dataset = random_drop_data( + train_dataset, + valid_size, + self.VALID_SEED, + self.data_keys, + ) + val_dataset.transform = valid_transform + else: + val_dataset = None + return train_dataset, val_dataset + + def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any: + # used for resetting BN running statistics + if self.sub_train is None: + self.sub_train = {} + if self.active_image_size in self.sub_train: + return self.sub_train[self.active_image_size] + + # construct dataset and dataloader + train_dataset = copy.deepcopy(self.train.dataset) + if n_samples < len(train_dataset): + _, train_dataset = random_drop_data( + train_dataset, + n_samples, + self.SUB_SEED, + self.data_keys, + ) + RRSController.ACTIVE_SIZE = self.active_image_size + train_dataset.transform = self.build_train_transform(image_size=self.active_image_size) + data_loader = self.build_dataloader(train_dataset, batch_size, self.train.num_workers, True, False) + + # pre-fetch data + self.sub_train[self.active_image_size] = [ + data for data in data_loader for _ in range(max(1, n_samples // len(train_dataset))) + ] + + return self.sub_train[self.active_image_size] diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/__init__.py new file mode 100644 index 0000000..b831fa9 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/__init__.py @@ -0,0 +1,7 @@ +"""Random resolution data loader compatible with multi-processing and distributed training. 
+ +Replace Pytorch's DataLoader with RRSDataLoader to support random resolution +at the training time, resolution sampling is controlled by RRSController +""" + +from .controller import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_loader.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_loader.py new file mode 100644 index 0000000..b19092a --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_loader.py @@ -0,0 +1,1538 @@ +r"""This file is based on torch/utils/data/data_loader.py + +Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union + +import torch +import torch.distributed as dist +import torch.multiprocessing as multiprocessing +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import ( + BatchSampler, + Dataset, + IterableDataset, + IterDataPipe, + MapDataPipe, + RandomSampler, + Sampler, + SequentialSampler, + _utils, +) +from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper + +from ._data_worker import _worker_loop + +__all__ = ["RRSDataLoader"] + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[List[T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) + else: + return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + Used as sampler for :class:`~torch.utils.data.IterableDataset`. 
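+    It simply yields ``None`` forever; with an :class:`~torch.utils.data.IterableDataset`
+    the sample order comes from the dataset itself, so the yielded index is never used.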
+ + Args: + data_source (Dataset): dataset to sample from + """ + + def __init__(self): + super().__init__(None) + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + assert info is not None + total_workers = info.num_workers + datapipe = info.dataset + assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding(datapipe, total_workers, global_worker_id) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class RRSDataLoader(Generic[T_co]): + r""" + Data loader. Combines a dataset and a sampler, and provides an iterable over + the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. 
(default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + `base_seed` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise if value of num_workers>0 default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shutdown + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): the data loader will copy Tensors + into device pinned memory before returning them if pin_memory is set to true. + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. 
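+
+    Example (a minimal sketch; ``train_dataset`` stands in for any map-style dataset)::
+
+        >>> loader = RRSDataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
+        >>> for batch in loader:
+        ...     pass  # per-batch resolution sampling is handled by RRSController (see package docstring)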
+ """ + + dataset: Dataset[T_co] + batch_size: Optional[int] + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Union[Sampler, Iterable] + pin_memory_device: str + prefetch_factor: Optional[int] + _iterator: Optional["_BaseDataLoaderIter"] + __initialized = False + + def __init__( + self, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = None, + sampler: Union[Sampler, Iterable, None] = None, + batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "" + ): + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). 
Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "shuffle option, but got shuffle={}".format(shuffle) + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "sampler option, but got sampler={}".format(sampler) + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + "batch_sampler option, but got batch_sampler={}".format(batch_sampler) + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ] + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if 
self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + ( + "multiprocessing_context option " + "should specify a valid start method in {!r}, but got " + "multiprocessing_context={!r}" + ).format(valid_start_methods, multiprocessing_context) + ) + multiprocessing_context = multiprocessing.get_context(multiprocessing_context) + + if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext): + raise TypeError( + ( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + "multiprocessing_context={}" + ).format(multiprocessing_context) + ) + else: + raise ValueError( + ( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + "num_workers={}" + ).format(self.num_workers) + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + "{} attribute should not be set after {} is " "initialized".format(attr, self.__class__.__name__) + ) + + super().__setattr__(attr, val) + + # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up + # since '_BaseDataLoaderIter' references 'DataLoader'. + def __iter__(self) -> "_BaseDataLoaderIter": + # When using a single worker the returned iterator should be + # created everytime to avoid reseting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. 
We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] + if self.batch_size is not None: # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + suggested_max_worker_msg = ( + ( + ( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create." + ).format( + num_worker_suggest, + ("" if cpuset_checked else " (`cpuset` is not taken into account)"), + ) + ) + if num_worker_suggest is not None + else ("DataLoader is not able to compute a suggested max number of worker in current system.") + ) + + warn_msg = ( + "This DataLoader will create {} worker processes in total. {} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary." 
+ ).format(num_worker_created, suggested_max_worker_msg) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satify mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked)) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn(_create_warning_msg(max_num_worker_suggest, self.num_workers, cpuset_checked)) + + +class _BaseDataLoaderIter: + def __init__(self, loader: RRSDataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + # for other backends, pin_memory_device need to set. if not set + # default behaviour is CUDA device. if pin_memory_device is selected + # and pin_memory is not set, the default behaviour false. 
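+        # In short: memory pinning is active only when pin_memory=True; if no
+        # pin_memory_device is given, CUDA pinned memory is assumed (and requires
+        # torch.cuda.is_available()), as handled by the branches below.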
+ if len(loader.pin_memory_device) == 0: + self._pin_memory = loader.pin_memory and torch.cuda.is_available() + self._pin_memory_device = None + else: + if not loader.pin_memory: + warn_msg = ( + "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used" + "please set pin_memory to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item() + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__) + + def __iter__(self) -> "_BaseDataLoaderIter": + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + # TODO(https://github.com/pytorch/pytorch/issues/76750) + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} " + "samples have been fetched. " + ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. 
+ # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super().__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may alreay be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. 
Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. 
All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. 
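Point (3b) above revolves around `multiprocessing.Queue`'s background feeder thread, which is normally joined at process exit and can block forever if the reader is gone. A small standalone sketch of the `cancel_join_thread` workaround described there (illustrative only, not part of the patch):

    import multiprocessing as mp

    def producer(q):
        for i in range(1000):
            q.put(i)
        # Tell the queue's feeder thread not to be joined at process exit.
        # Only safe because this process will not touch the queue again; a
        # reader that died mid-stream can then no longer block our exit.
        q.cancel_join_thread()

    if __name__ == "__main__":
        q = mp.Queue()
        p = mp.Process(target=producer, args=(q,), daemon=True)
        p.start()
        print(q.get())      # the main process consumes at least one item
        p.join(timeout=5)
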
+ # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `index_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. 
+ # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timeing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. 
+ # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + if self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() # type: ignore[attr-defined] + else: + current_device = torch.cuda.current_device() # choose cuda for default + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False): + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. 
+ self._workers_status = [True for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed)) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError("DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)) from e + if isinstance(e, queue.Empty): + return (False, None) + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. + # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. 
When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. + # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. 
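The `_try_get_data`/`_get_data` pair documented above reduces to a generic pattern: block on the queue with a short timeout, and use every timeout (or error) to check whether the sender is still alive. A hedged, self-contained sketch of that polling pattern (the names are illustrative):

    import queue
    import threading

    STATUS_CHECK_INTERVAL = 5.0  # seconds; plays the role of MP_STATUS_CHECK_INTERVAL

    def get_or_fail(result_queue, sender: threading.Thread):
        # Block on the queue, but wake up periodically to verify that the
        # sender is still alive, so a dead sender raises instead of hanging.
        while True:
            try:
                return result_queue.get(timeout=STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if not sender.is_alive():
                    raise RuntimeError("sender thread exited unexpectedly")
                # Timed out but the sender is healthy: keep waiting.
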
+ if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError("DataLoader timed out after {} seconds".format(self._timeout)) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info[self._rcvd_idx] + worker_id = info[0] + if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + data = self._task_info.pop(self._rcvd_idx)[1] + return self._process_data(data) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + # store out-of-order samples + self._task_info[idx] += (data,) + else: + del self._task_info[idx] + return self._process_data(data) + + def _try_put_index(self): + assert self._tasks_outstanding < self._prefetch_factor * self._num_workers + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data): + self._rcvd_idx += 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. + + assert self._workers_status[worker_id] or (self._persistent_workers and shutdown) + + # Signal termination to that specific worker. 
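The `_send_idx`/`_rcvd_idx`/`_task_info` bookkeeping in `_next_data` above is essentially a reorder buffer: results can arrive out of order from different workers but must be yielded in submission order. A stripped-down sketch of that idea (illustrative only):

    def reorder(results):
        # `results` yields (task_idx, data) pairs in arrival order; yield data
        # in task_idx order, buffering anything that arrives early.
        buffered = {}   # task_idx -> data, like the 2-tuples kept in _task_info
        next_idx = 0    # analogous to _rcvd_idx
        for idx, data in results:
            if idx != next_idx:
                buffered[idx] = data          # out-of-order arrival
                continue
            yield data
            next_idx += 1
            while next_idx in buffered:       # flush consecutive buffered results
                yield buffered.pop(next_idx)
                next_idx += 1

    # e.g. list(reorder([(1, "b"), (0, "a"), (2, "c")])) == ["a", "b", "c"]
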
+        q = self._index_queues[worker_id]
+        # Indicate that no more data will be put on this queue by the current
+        # process.
+        q.put(None)
+
+        # Note that we don't actually join the worker here, nor do we remove the
+        # worker's pid from C side struct because (1) joining may be slow, and
+        # (2) since we don't join, the worker may still raise an error, and we
+        # prefer capturing those, rather than ignoring them, even though they
+        # are raised after the worker has finished its job.
+        # Joining is deferred to `_shutdown_workers`, which is called when
+        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
+        # when this iterator is garbage collected.
+
+        self._workers_status[worker_id] = False
+
+        assert self._workers_done_event.is_set() == shutdown
+
+    def _shutdown_workers(self):
+        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
+        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+        # the logic of this function.
+        if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
+            # See (2) of the note. If Python is shutting down, do no-op.
+            return
+        # Normal exit when last reference is gone / iterator is depleted.
+        # See (1) and the second half of the note.
+        if not self._shutdown:
+            self._shutdown = True
+            try:
+                # Normal exit when last reference is gone / iterator is depleted.
+                # See (1) and the second half of the note.
+
+                # Exit `pin_memory_thread` first because exiting workers may leave
+                # corrupted data in `worker_result_queue` which `pin_memory_thread`
+                # reads from.
+                if hasattr(self, "_pin_memory_thread"):
+                    # Use hasattr in case error happens before we set the attribute.
+                    self._pin_memory_thread_done_event.set()
+                    # Send something to pin_memory_thread in case it is waiting
+                    # so that it can wake up and check `pin_memory_thread_done_event`
+                    self._worker_result_queue.put((None, None))
+                    self._pin_memory_thread.join()
+                    self._worker_result_queue.cancel_join_thread()
+                    self._worker_result_queue.close()
+
+                # Exit workers now.
+                self._workers_done_event.set()
+                for worker_id in range(len(self._workers)):
+                    # Get number of workers from `len(self._workers)` instead of
+                    # `self._num_workers` in case we error before starting all
+                    # workers.
+                    # If we are using workers_status with persistent_workers
+                    # we have to shut it down because the worker is paused.
+                    if self._persistent_workers or self._workers_status[worker_id]:
+                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
+                for w in self._workers:
+                    # We should be able to join here, but in case anything went
+                    # wrong, we set a timeout and if the workers fail to join,
+                    # they are killed in the `finally` block.
+                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+                for q in self._index_queues:
+                    q.cancel_join_thread()
+                    q.close()
+            finally:
+                # Even though all this function does is putting into queues that
+                # we have called `cancel_join_thread` on, weird things can
+                # happen when a worker is killed by a signal, e.g., hanging in
+                # `Event.set()`. So we need to guard this with SIGCHLD handler,
+                # and remove pids from the C side data structure only at the
+                # end.
+                #
+                # FIXME: Unfortunately, for Windows, we are missing a worker
+                #        error detection mechanism here in this function, as it
+                #        doesn't provide a SIGCHLD handler.
+                if self._worker_pids_set:
+                    _utils.signal_handling._remove_worker_pids(id(self))
+                    self._worker_pids_set = False
+                for w in self._workers:
+                    if w.is_alive():
+                        # Existing mechanisms try to make the workers exit
+                        # peacefully, but in case that we unfortunately reach
+                        # here, which we shouldn't (e.g., pytorch/pytorch#39570),
+                        # we kill the worker.
+                        w.terminate()
+
+    # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
+    @staticmethod
+    def _clean_up_worker(w):
+        try:
+            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+        finally:
+            if w.is_alive():
+                w.terminate()
+
+    def __del__(self):
+        self._shutdown_workers()
diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_worker.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_worker.py
new file mode 100644
index 0000000..49d2956
--- /dev/null
+++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/_data_worker.py
@@ -0,0 +1,358 @@
+r"""This file is based on torch/utils/data/_utils/worker.py
+
+Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import os
+import queue
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from torch._utils import ExceptionWrapper
+from torch.utils.data._utils import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
+
+if TYPE_CHECKING:
+    from torch.utils.data import Dataset
+
+from .controller import RRSController
+
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import BOOL, DWORD, HANDLE
+
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through the OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+ class ManagerWatchdog: + def __init__(self): + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__)) + return super().__setattr__(key, val) + + def __repr__(self): + items = [] + for k in self.__keys: + items.append("{}={}".format(k, getattr(self, k))) + return "{}({})".format(self.__class__.__name__, ", ".join(items)) + + +def get_worker_info() -> Optional[WorkerInfo]: + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. 
+ """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: Optional[int] = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +# TODO: Implement `SeedSequence` like object for `torch.random` +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. 
+ for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + torch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + from torch.utils.data import IterDataPipe + from torch.utils.data.graph_settings import apply_random_seed + + shared_rng = torch.Generator() + if isinstance(dataset, IterDataPipe): + assert shared_seed is not None + shared_rng.manual_seed(shared_seed) + dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) + except Exception: + init_exception = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id)) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. 
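The situation described above -- workers exhausting an `IterableDataset` at different times -- arises naturally when each worker reads its own shard. A hedged sketch of the standard `get_worker_info()` sharding pattern (ordinary PyTorch usage, not code from this patch):

    from torch.utils.data import IterableDataset, get_worker_info

    class RangeShards(IterableDataset):
        # Yields 0 .. n-1, split across DataLoader workers by worker id.
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            info = get_worker_info()
            if info is None:          # single-process loading
                worker_id, num_workers = 0, 1
            else:
                worker_id, num_workers = info.id, info.num_workers
            # Strided shards can have unequal lengths, so some workers raise
            # StopIteration earlier than others -- exactly the case handled by
            # `_IterableDatasetStopIteration` in the worker loop.
            return iter(range(worker_id, self.n, num_workers))

    # DataLoader(RangeShards(10), num_workers=3) yields each number exactly once.
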
+ iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + if isinstance(dataset, IterDataPipe): + assert r.seed is not None + shared_rng.manual_seed(r.seed) + dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) + continue + elif r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + """ Added """ + RRSController.sample_resolution(batch_id=idx) + """ Added """ + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) + except Exception as e: + if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable: + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper(where="in DataLoader worker process {}".format(worker_id)) + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. 
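The `RRSController.sample_resolution(batch_id=idx)` call marked "Added" above is the reason this worker loop is vendored at all: just before each fetch, the worker activates the resolution pre-sampled for that batch, and the crop transform then reads it. A condensed usage sketch based on `controller.py` later in this patch (the import path is assumed from the patch's file layout):

    from src.models.efficientvit.apps.data_provider.random_resolution.controller import RRSController

    RRSController.set_epoch(epoch=3, batch_per_epoch=100)  # main process: pre-sample one size per batch
    RRSController.sample_resolution(batch_id=0)            # worker: activate the size for batch 0
    print(RRSController.ACTIVE_SIZE)                       # read by MyRandomResizedCrop during the fetch
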
+ pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/controller.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/controller.py new file mode 100644 index 0000000..077b735 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/data_provider/random_resolution/controller.py @@ -0,0 +1,92 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy + +import torch +import torchvision.transforms as transforms +import torchvision.transforms.functional as F + +from src.models.efficientvit.models.utils import torch_random_choices + +__all__ = [ + "RRSController", + "get_interpolate", + "MyRandomResizedCrop", +] + + +class RRSController: + ACTIVE_SIZE = (224, 224) + IMAGE_SIZE_LIST = [(224, 224)] + + CHOICE_LIST = None + + @staticmethod + def get_candidates() -> list[tuple[int, int]]: + return copy.deepcopy(RRSController.IMAGE_SIZE_LIST) + + @staticmethod + def sample_resolution(batch_id: int) -> None: + RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id] + + @staticmethod + def set_epoch(epoch: int, batch_per_epoch: int) -> None: + g = torch.Generator() + g.manual_seed(epoch) + RRSController.CHOICE_LIST = torch_random_choices( + RRSController.get_candidates(), + g, + batch_per_epoch, + ) + + +def get_interpolate(name: str) -> F.InterpolationMode: + mapping = { + "nearest": F.InterpolationMode.NEAREST, + "bilinear": F.InterpolationMode.BILINEAR, + "bicubic": F.InterpolationMode.BICUBIC, + "box": F.InterpolationMode.BOX, + "hamming": F.InterpolationMode.HAMMING, + "lanczos": F.InterpolationMode.LANCZOS, + } + if name in mapping: + return mapping[name] + elif name == "random": + return torch_random_choices( + [ + F.InterpolationMode.NEAREST, + F.InterpolationMode.BILINEAR, + F.InterpolationMode.BICUBIC, + F.InterpolationMode.BOX, + F.InterpolationMode.HAMMING, + F.InterpolationMode.LANCZOS, + ], + ) + else: + raise NotImplementedError + + +class MyRandomResizedCrop(transforms.RandomResizedCrop): + def __init__( + self, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation: str = "random", + ): + super(MyRandomResizedCrop, self).__init__(224, scale, ratio) + self.interpolation = interpolation + + def forward(self, img: torch.Tensor) -> torch.Tensor: + i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio)) + target_size = RRSController.ACTIVE_SIZE + return F.resized_crop(img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + format_string += f"(\n\tsize={RRSController.get_candidates()},\n" + format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n" + format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n" + format_string += f"\tinterpolation={self.interpolation})" + return format_string diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/setup.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/setup.py new file mode 100644 index 0000000..3bdaf67 --- /dev/null +++ 
b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/setup.py @@ -0,0 +1,135 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os +import time +from copy import deepcopy + +import torch.backends.cudnn +import torch.distributed +import torch.nn as nn + +from src.models.efficientvit.apps.data_provider import DataProvider +from src.models.efficientvit.apps.trainer.run_config import RunConfig +from src.models.efficientvit.apps.utils import ( + dist_init, + dump_config, + get_dist_local_rank, + get_dist_rank, + get_dist_size, + init_modules, + is_master, + load_config, + partial_update_config, + zero_last_gamma, +) +from src.models.efficientvit.models.utils import build_kwargs_from_config, load_state_dict_from_file + +__all__ = [ + "save_exp_config", + "setup_dist_env", + "setup_seed", + "setup_exp_config", + "setup_data_provider", + "setup_run_config", + "init_model", +] + + +def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None: + if not is_master(): + return + dump_config(exp_config, os.path.join(path, name)) + + +def setup_dist_env(gpu: str or None = None) -> None: + if gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = gpu + if not torch.distributed.is_initialized(): + dist_init() + torch.backends.cudnn.benchmark = True + torch.cuda.set_device(get_dist_local_rank()) + + +def setup_seed(manual_seed: int, resume: bool) -> None: + if resume: + manual_seed = int(time.time()) + manual_seed = get_dist_rank() + manual_seed + torch.manual_seed(manual_seed) + torch.cuda.manual_seed_all(manual_seed) + + +def setup_exp_config(config_path: str, recursive=True, opt_args: dict or None = None) -> dict: + # load config + if not os.path.isfile(config_path): + raise ValueError(config_path) + + fpaths = [config_path] + if recursive: + extension = os.path.splitext(config_path)[1] + while os.path.dirname(config_path) != config_path: + config_path = os.path.dirname(config_path) + fpath = os.path.join(config_path, "default" + extension) + if os.path.isfile(fpath): + fpaths.append(fpath) + fpaths = fpaths[::-1] + + default_config = load_config(fpaths[0]) + exp_config = deepcopy(default_config) + for fpath in fpaths[1:]: + partial_update_config(exp_config, load_config(fpath)) + # update config via args + if opt_args is not None: + partial_update_config(exp_config, opt_args) + + return exp_config + + +def setup_data_provider( + exp_config: dict, data_provider_classes: list[type[DataProvider]], is_distributed: bool = True +) -> DataProvider: + dp_config = exp_config["data_provider"] + dp_config["num_replicas"] = get_dist_size() if is_distributed else None + dp_config["rank"] = get_dist_rank() if is_distributed else None + dp_config["test_batch_size"] = dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2 + dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config["base_batch_size"] + + data_provider_lookup = {provider.name: provider for provider in data_provider_classes} + data_provider_class = data_provider_lookup[dp_config["dataset"]] + + data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class) + data_provider = data_provider_class(**data_provider_kwargs) + return data_provider + + +def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig: + exp_config["run_config"]["init_lr"] = 
exp_config["run_config"]["base_lr"] * get_dist_size() + + run_config = run_config_cls(**exp_config["run_config"]) + + return run_config + + +def init_model( + network: nn.Module, + init_from: str or None = None, + backbone_init_from: str or None = None, + rand_init="trunc_normal", + last_gamma=None, +) -> None: + # initialization + init_modules(network, init_type=rand_init) + # zero gamma of last bn in each block + if last_gamma is not None: + zero_last_gamma(network, last_gamma) + + # load weight + if init_from is not None and os.path.isfile(init_from): + network.load_state_dict(load_state_dict_from_file(init_from)) + print(f"Loaded init from {init_from}") + elif backbone_init_from is not None and os.path.isfile(backbone_init_from): + network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from)) + print(f"Loaded backbone init from {backbone_init_from}") + else: + print(f"Random init ({rand_init}) with last gamma {last_gamma}") diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/__init__.py new file mode 100644 index 0000000..6b9219c --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/__init__.py @@ -0,0 +1,6 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .base import * +from .run_config import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/base.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/base.py new file mode 100644 index 0000000..67b281b --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/base.py @@ -0,0 +1,299 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os + +import torch +import torch.nn as nn + +from src.models.efficientvit.apps.data_provider import DataProvider, parse_image_size +from src.models.efficientvit.apps.trainer.run_config import RunConfig +from src.models.efficientvit.apps.utils import EMA, dist_barrier, get_dist_local_rank, is_master +from src.models.efficientvit.models.nn.norm import reset_bn +from src.models.efficientvit.models.utils import is_parallel, load_state_dict_from_file + +__all__ = ["Trainer"] + + +class Trainer: + def __init__(self, path: str, model: nn.Module, data_provider: DataProvider): + self.path = os.path.realpath(os.path.expanduser(path)) + self.model = model.cuda() + self.data_provider = data_provider + + self.ema = None + + self.checkpoint_path = os.path.join(self.path, "checkpoint") + self.logs_path = os.path.join(self.path, "logs") + for path in [self.path, self.checkpoint_path, self.logs_path]: + os.makedirs(path, exist_ok=True) + + self.best_val = 0.0 + self.start_epoch = 0 + + @property + def network(self) -> nn.Module: + return self.model.module if is_parallel(self.model) else self.model + + @property + def eval_network(self) -> nn.Module: + if self.ema is None: + model = self.model + else: + model = self.ema.shadows + model = model.module if is_parallel(model) else 
model + return model + + def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None: + if is_master(): + fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode) + fout.write(log_str + "\n") + fout.flush() + fout.close() + if print_log: + print(log_str) + + def save_model( + self, + checkpoint=None, + only_state_dict=True, + epoch=0, + model_name=None, + ) -> None: + if is_master(): + if checkpoint is None: + if only_state_dict: + checkpoint = {"state_dict": self.network.state_dict()} + else: + checkpoint = { + "state_dict": self.network.state_dict(), + "epoch": epoch, + "best_val": self.best_val, + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "ema": self.ema.state_dict() if self.ema is not None else None, + "scaler": self.scaler.state_dict() if self.enable_amp else None, + } + + model_name = model_name or "checkpoint.pt" + + latest_fname = os.path.join(self.checkpoint_path, "latest.txt") + model_path = os.path.join(self.checkpoint_path, model_name) + with open(latest_fname, "w") as _fout: + _fout.write(model_path + "\n") + torch.save(checkpoint, model_path) + + def load_model(self, model_fname=None) -> None: + latest_fname = os.path.join(self.checkpoint_path, "latest.txt") + if model_fname is None and os.path.exists(latest_fname): + with open(latest_fname, "r") as fin: + model_fname = fin.readline() + if len(model_fname) > 0 and model_fname[-1] == "\n": + model_fname = model_fname[:-1] + try: + if model_fname is None: + model_fname = f"{self.checkpoint_path}/checkpoint.pt" + elif not os.path.exists(model_fname): + model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}" + if not os.path.exists(model_fname): + model_fname = f"{self.checkpoint_path}/checkpoint.pt" + print(f"=> loading checkpoint {model_fname}") + checkpoint = load_state_dict_from_file(model_fname, False) + except Exception: + self.write_log(f"fail to load checkpoint from {self.checkpoint_path}") + return + + # load checkpoint + self.network.load_state_dict(checkpoint["state_dict"], strict=False) + log = [] + if "epoch" in checkpoint: + self.start_epoch = checkpoint["epoch"] + 1 + self.run_config.update_global_step(self.start_epoch) + log.append(f"epoch={self.start_epoch - 1}") + if "best_val" in checkpoint: + self.best_val = checkpoint["best_val"] + log.append(f"best_val={self.best_val:.2f}") + if "optimizer" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + log.append("optimizer") + if "lr_scheduler" in checkpoint: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + log.append("lr_scheduler") + if "ema" in checkpoint and self.ema is not None: + self.ema.load_state_dict(checkpoint["ema"]) + log.append("ema") + if "scaler" in checkpoint and self.enable_amp: + self.scaler.load_state_dict(checkpoint["scaler"]) + log.append("scaler") + self.write_log("Loaded: " + ", ".join(log)) + + """ validate """ + + def reset_bn( + self, + network: nn.Module or None = None, + subset_size: int = 16000, + subset_batch_size: int = 100, + data_loader=None, + progress_bar=False, + ) -> None: + network = network or self.network + if data_loader is None: + data_loader = [] + for data in self.data_provider.build_sub_train_loader(subset_size, subset_batch_size): + if isinstance(data, list): + data_loader.append(data[0]) + elif isinstance(data, dict): + data_loader.append(data["data"]) + elif isinstance(data, torch.Tensor): + data_loader.append(data) + else: + raise NotImplementedError + + network.eval() + reset_bn( + 
network, + data_loader, + sync=True, + progress_bar=progress_bar, + ) + + def _validate(self, model, data_loader, epoch) -> dict[str, any]: + raise NotImplementedError + + def validate(self, model=None, data_loader=None, is_test=True, epoch=0) -> dict[str, any]: + model = model or self.eval_network + if data_loader is None: + if is_test: + data_loader = self.data_provider.test + else: + data_loader = self.data_provider.valid + + model.eval() + return self._validate(model, data_loader, epoch) + + def multires_validate( + self, + model=None, + data_loader=None, + is_test=True, + epoch=0, + eval_image_size=None, + ) -> dict[str, dict[str, any]]: + eval_image_size = eval_image_size or self.run_config.eval_image_size + eval_image_size = eval_image_size or self.data_provider.image_size + model = model or self.eval_network + + if not isinstance(eval_image_size, list): + eval_image_size = [eval_image_size] + + output_dict = {} + for r in eval_image_size: + self.data_provider.assign_active_image_size(parse_image_size(r)) + if self.run_config.reset_bn: + self.reset_bn( + network=model, + subset_size=self.run_config.reset_bn_size, + subset_batch_size=self.run_config.reset_bn_batch_size, + progress_bar=True, + ) + output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch) + return output_dict + + """ training """ + + def prep_for_training(self, run_config: RunConfig, ema_decay: float or None = None, amp="fp32") -> None: + self.run_config = run_config + self.model = nn.parallel.DistributedDataParallel( + self.model.cuda(), + device_ids=[get_dist_local_rank()], + static_graph=True, + ) + + self.run_config.global_step = 0 + self.run_config.batch_per_epoch = len(self.data_provider.train) + assert self.run_config.batch_per_epoch > 0, "Training set is empty" + + # build optimizer + self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model) + + if ema_decay is not None: + self.ema = EMA(self.network, ema_decay) + + # amp + self.amp = amp + self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp) + + @property + def enable_amp(self) -> bool: + return self.amp != "fp32" + + @property + def amp_dtype(self) -> torch.dtype: + if self.amp == "fp16": + return torch.float16 + elif self.amp == "bf16": + return torch.bfloat16 + else: + return torch.float32 + + def sync_model(self): + print("Sync model") + self.save_model(model_name="sync.pt") + dist_barrier() + checkpoint = torch.load(os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu") + dist_barrier() + if is_master(): + os.remove(os.path.join(self.checkpoint_path, "sync.pt")) + dist_barrier() + + # load checkpoint + self.network.load_state_dict(checkpoint["state_dict"], strict=False) + if "optimizer" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if "lr_scheduler" in checkpoint: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + if "ema" in checkpoint and self.ema is not None: + self.ema.load_state_dict(checkpoint["ema"]) + if "scaler" in checkpoint and self.enable_amp: + self.scaler.load_state_dict(checkpoint["scaler"]) + + def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + for key in feed_dict: + if isinstance(feed_dict[key], torch.Tensor): + feed_dict[key] = feed_dict[key].cuda() + return feed_dict + + def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + raise NotImplementedError + + def after_step(self) -> None: + self.scaler.unscale_(self.optimizer) + # gradient clip + if self.run_config.grad_clip is not None: + 
torch.nn.utils.clip_grad_value_(self.model.parameters(), self.run_config.grad_clip) + # update + self.scaler.step(self.optimizer) + self.scaler.update() + + self.lr_scheduler.step() + self.run_config.step() + # update ema + if self.ema is not None: + self.ema.step(self.network, self.run_config.global_step) + + def _train_one_epoch(self, epoch: int) -> dict[str, any]: + raise NotImplementedError + + def train_one_epoch(self, epoch: int) -> dict[str, any]: + self.model.train() + + self.data_provider.set_epoch(epoch) + + train_info_dict = self._train_one_epoch(epoch) + + return train_info_dict + + def train(self) -> None: + raise NotImplementedError diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/run_config.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/run_config.py new file mode 100644 index 0000000..6da12ca --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/trainer/run_config.py @@ -0,0 +1,115 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import json + +import numpy as np +import torch.nn as nn + +from src.models.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer + +__all__ = ["Scheduler", "RunConfig"] + + +class Scheduler: + PROGRESS = 0 + + +class RunConfig: + n_epochs: int + init_lr: float + warmup_epochs: int + warmup_lr: float + lr_schedule_name: str + lr_schedule_param: dict + optimizer_name: str + optimizer_params: dict + weight_decay: float + no_wd_keys: list + grad_clip: float # allow none to turn off grad clipping + reset_bn: bool + reset_bn_size: int + reset_bn_batch_size: int + eval_image_size: list # allow none to use image_size in data_provider + + @property + def none_allowed(self): + return ["grad_clip", "eval_image_size"] + + def __init__(self, **kwargs): # arguments must be passed as kwargs + for k, val in kwargs.items(): + setattr(self, k, val) + + # check that all relevant configs are there + annotations = {} + for clas in type(self).mro(): + if hasattr(clas, "__annotations__"): + annotations.update(clas.__annotations__) + for k, k_type in annotations.items(): + assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization." + attr = getattr(self, k) + if k in self.none_allowed: + k_type = (k_type, type(None)) + assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}." 
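+ # global_step and batch_per_epoch are placeholders here; the trainer overwrites
+ # batch_per_epoch with the length of the train loader (see prep_for_training)
+ # before build_optimizer is called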
+ + self.global_step = 0 + self.batch_per_epoch = 1 + + def build_optimizer(self, network: nn.Module) -> tuple[any, any]: + r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler""" + param_dict = {} + for name, param in network.named_parameters(): + if param.requires_grad: + opt_config = [self.weight_decay, self.init_lr] + if self.no_wd_keys is not None and len(self.no_wd_keys) > 0: + if np.any([key in name for key in self.no_wd_keys]): + opt_config[0] = 0 + opt_key = json.dumps(opt_config) + param_dict[opt_key] = param_dict.get(opt_key, []) + [param] + + net_params = [] + for opt_key, param_list in param_dict.items(): + wd, lr = json.loads(opt_key) + net_params.append({"params": param_list, "weight_decay": wd, "lr": lr}) + + optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr) + # build lr scheduler + if self.lr_schedule_name == "cosine": + decay_steps = [] + for epoch in self.lr_schedule_param.get("step", []): + decay_steps.append(epoch * self.batch_per_epoch) + decay_steps.append(self.n_epochs * self.batch_per_epoch) + decay_steps.sort() + lr_scheduler = CosineLRwithWarmup( + optimizer, + self.warmup_epochs * self.batch_per_epoch, + self.warmup_lr, + decay_steps, + ) + else: + raise NotImplementedError + return optimizer, lr_scheduler + + def update_global_step(self, epoch, batch_id=0) -> None: + self.global_step = epoch * self.batch_per_epoch + batch_id + Scheduler.PROGRESS = self.progress + + @property + def progress(self) -> float: + warmup_steps = self.warmup_epochs * self.batch_per_epoch + steps = max(0, self.global_step - warmup_steps) + return steps / (self.n_epochs * self.batch_per_epoch) + + def step(self) -> None: + self.global_step += 1 + Scheduler.PROGRESS = self.progress + + def get_remaining_epoch(self, epoch, post=True) -> int: + return self.n_epochs + self.warmup_epochs - epoch - int(post) + + def epoch_format(self, epoch: int) -> str: + epoch_format = f"%.{len(str(self.n_epochs))}d" + epoch_format = f"[{epoch_format}/{epoch_format}]" + epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs) + return epoch_format diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/__init__.py new file mode 100644 index 0000000..c826a22 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/__init__.py @@ -0,0 +1,12 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .dist import * +from .ema import * +from .export import * +from .init import * +from .lr import * +from .metric import * +from .misc import * +from .opt import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/dist.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/dist.py new file mode 100644 index 0000000..cbedea9 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/dist.py @@ -0,0 +1,71 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer 
Vision (ICCV), 2023 + +import os + +import torch +import torch.distributed + +from src.models.efficientvit.models.utils.list import list_mean, list_sum + +__all__ = [ + "dist_init", + "get_dist_rank", + "get_dist_size", + "is_master", + "dist_barrier", + "get_dist_local_rank", + "sync_tensor", +] + + +def dist_init() -> None: + try: + torch.distributed.init_process_group(backend="nccl") + assert torch.distributed.is_initialized() + except Exception: + # use torchpack + from torchpack import distributed as dist + + dist.init() + os.environ["RANK"] = f"{dist.rank()}" + os.environ["WORLD_SIZE"] = f"{dist.size()}" + os.environ["LOCAL_RANK"] = f"{dist.local_rank()}" + + +def get_dist_rank() -> int: + return int(os.environ["RANK"]) + + +def get_dist_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def is_master() -> bool: + return get_dist_rank() == 0 + + +def dist_barrier() -> None: + torch.distributed.barrier() + + +def get_dist_local_rank() -> int: + return int(os.environ["LOCAL_RANK"]) + + +def sync_tensor(tensor: torch.Tensor or float, reduce="mean") -> torch.Tensor or list[torch.Tensor]: + if not isinstance(tensor, torch.Tensor): + tensor = torch.Tensor(1).fill_(tensor).cuda() + tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())] + torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) + if reduce == "mean": + return list_mean(tensor_list) + elif reduce == "sum": + return list_sum(tensor_list) + elif reduce == "cat": + return torch.cat(tensor_list, dim=0) + elif reduce == "root": + return tensor_list[0] + else: + return tensor_list diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/ema.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/ema.py new file mode 100644 index 0000000..5d55a14 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/ema.py @@ -0,0 +1,42 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +import math + +import torch +import torch.nn as nn + +from src.models.efficientvit.models.utils import is_parallel + +__all__ = ["EMA"] + + +def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None: + for k, v in ema.state_dict().items(): + if v.dtype.is_floating_point: + v -= (1.0 - decay) * (v - new_state_dict[k].detach()) + + +class EMA: + def __init__(self, model: nn.Module, decay: float, warmup_steps=2000): + self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval() + self.decay = decay + self.warmup_steps = warmup_steps + + for p in self.shadows.parameters(): + p.requires_grad = False + + def step(self, model: nn.Module, global_step: int) -> None: + with torch.no_grad(): + msd = (model.module if is_parallel(model) else model).state_dict() + update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps))) + + def state_dict(self) -> dict[float, dict[str, torch.Tensor]]: + return {self.decay: self.shadows.state_dict()} + + def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None: + for decay in state_dict: + if decay == self.decay: + self.shadows.load_state_dict(state_dict[decay]) diff --git 
a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/export.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/export.py new file mode 100644 index 0000000..85a03c7 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/export.py @@ -0,0 +1,45 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import io +import os + +import onnx +import torch +import torch.nn as nn +from onnxsim import simplify as simplify_func + +__all__ = ["export_onnx"] + + +def export_onnx(model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11) -> None: + """Export a model to a platform-specific onnx format. + + Args: + model: a torch.nn.Module object. + export_path: export location. + sample_inputs: Any. + simplify: a flag to turn on onnx-simplifier + opset: int + """ + model.eval() + + buffer = io.BytesIO() + with torch.no_grad(): + torch.onnx.export(model, sample_inputs, buffer, opset_version=opset) + buffer.seek(0, 0) + if simplify: + onnx_model = onnx.load_model(buffer) + onnx_model, success = simplify_func(onnx_model) + assert success + new_buffer = io.BytesIO() + onnx.save(onnx_model, new_buffer) + buffer = new_buffer + buffer.seek(0, 0) + + if buffer.getbuffer().nbytes > 0: + save_dir = os.path.dirname(export_path) + os.makedirs(save_dir, exist_ok=True) + with open(export_path, "wb") as f: + f.write(buffer.read()) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/init.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/init.py new file mode 100644 index 0000000..da1af28 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/init.py @@ -0,0 +1,66 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +__all__ = ["init_modules", "zero_last_gamma"] + + +def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None: + _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02} + + if isinstance(model, list): + for sub_module in model: + init_modules(sub_module, init_type) + else: + init_params = init_type.split("@") + init_params = float(init_params[1]) if len(init_params) > 1 else None + + if init_type.startswith("trunc_normal"): + init_func = lambda param: nn.init.trunc_normal_( + param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"]) + ) + else: + raise NotImplementedError + + for m in model.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)): + init_func(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Embedding): + init_func(m.weight) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + m.weight.data.fill_(1) + m.bias.data.zero_() + else: + weight = getattr(m, "weight", None) + bias = getattr(m, "bias", None) + if isinstance(weight, torch.nn.Parameter): + init_func(weight) + if isinstance(bias, torch.nn.Parameter): + bias.data.zero_() + + +def zero_last_gamma(model: nn.Module, 
init_val=0) -> None: + import efficientvit.models.nn.ops as ops + + for m in model.modules(): + if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer): + if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)): + parent_module = m.main.point_conv + elif isinstance(m.main, ops.ResBlock): + parent_module = m.main.conv2 + elif isinstance(m.main, ops.ConvLayer): + parent_module = m.main + elif isinstance(m.main, (ops.LiteMLA)): + parent_module = m.main.proj + else: + parent_module = None + if parent_module is not None: + norm = getattr(parent_module, "norm", None) + if norm is not None: + nn.init.constant_(norm.weight, init_val) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/lr.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/lr.py new file mode 100644 index 0000000..dbb134f --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/lr.py @@ -0,0 +1,44 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import math + +import torch + +from src.models.efficientvit.models.utils.list import val2list + +__all__ = ["CosineLRwithWarmup"] + + +class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: int, + warmup_lr: float, + decay_steps: int or list[int], + last_epoch: int = -1, + ) -> None: + self.warmup_steps = warmup_steps + self.warmup_lr = warmup_lr + self.decay_steps = val2list(decay_steps) + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + if self.last_epoch < self.warmup_steps: + return [ + (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr + for base_lr in self.base_lrs + ] + else: + current_steps = self.last_epoch - self.warmup_steps + decay_steps = [0] + self.decay_steps + idx = len(decay_steps) - 2 + for i, decay_step in enumerate(decay_steps[:-1]): + if decay_step <= current_steps < decay_steps[i + 1]: + idx = i + break + current_steps -= decay_steps[idx] + decay_step = decay_steps[idx + 1] - decay_steps[idx] + return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs] diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/metric.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/metric.py new file mode 100644 index 0000000..0f1b154 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/metric.py @@ -0,0 +1,33 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +from src.models.efficientvit.apps.utils.dist import sync_tensor + +__all__ = ["AverageMeter"] + + +class AverageMeter: + """Computes and stores the average and current value.""" + + def __init__(self, is_distributed=True): + self.is_distributed = is_distributed + self.sum = 0 + self.count = 0 + + def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float: + return sync_tensor(val, reduce="sum") if self.is_distributed 
else val + + def update(self, val: torch.Tensor or int or float, delta_n=1): + self.count += self._sync(delta_n) + self.sum += self._sync(val * delta_n) + + def get_count(self) -> torch.Tensor or int or float: + return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count + + @property + def avg(self): + avg = -1 if self.count == 0 else self.sum / self.count + return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/misc.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/misc.py new file mode 100644 index 0000000..c72a22f --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/misc.py @@ -0,0 +1,101 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os + +import yaml + +__all__ = [ + "parse_with_yaml", + "parse_unknown_args", + "partial_update_config", + "resolve_and_load_config", + "load_config", + "dump_config", +] + + +def parse_with_yaml(config_str: str) -> str or dict: + try: + # add space manually for dict + if "{" in config_str and "}" in config_str and ":" in config_str: + out_str = config_str.replace(":", ": ") + else: + out_str = config_str + return yaml.safe_load(out_str) + except ValueError: + # return raw string if parsing fails + return config_str + + +def parse_unknown_args(unknown: list) -> dict: + """Parse unknown args.""" + index = 0 + parsed_dict = {} + while index < len(unknown): + key, val = unknown[index], unknown[index + 1] + index += 2 + if not key.startswith("--"): + continue + key = key[2:] + + # try parsing with either dot notation or full yaml notation + # Note that the vanilla case "--key value" will be parsed the same + if "." in key: + # key == a.b.c, val == val --> parsed_dict[a][b][c] = val + keys = key.split(".") + dict_to_update = parsed_dict + for key in keys[:-1]: + if not (key in dict_to_update and isinstance(dict_to_update[key], dict)): + dict_to_update[key] = {} + dict_to_update = dict_to_update[key] + dict_to_update[keys[-1]] = parse_with_yaml(val) # so we can parse lists, bools, etc... 
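+ # e.g. "--data_provider.image_size 512" becomes {"data_provider": {"image_size": 512}}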
+ else: + parsed_dict[key] = parse_with_yaml(val) + return parsed_dict + + +def partial_update_config(config: dict, partial_config: dict) -> dict: + for key in partial_config: + if key in config and isinstance(partial_config[key], dict) and isinstance(config[key], dict): + partial_update_config(config[key], partial_config[key]) + else: + config[key] = partial_config[key] + return config + + +def resolve_and_load_config(path: str, config_name="config.yaml") -> dict: + path = os.path.realpath(os.path.expanduser(path)) + if os.path.isdir(path): + config_path = os.path.join(path, config_name) + else: + config_path = path + if os.path.isfile(config_path): + pass + else: + raise Exception(f"Cannot find a valid config at {path}") + config = load_config(config_path) + return config + + +class SafeLoaderWithTuple(yaml.SafeLoader): + """A yaml safe loader with python tuple loading capabilities.""" + + def construct_python_tuple(self, node): + return tuple(self.construct_sequence(node)) + + +SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple) + + +def load_config(filename: str) -> dict: + """Load a yaml file.""" + filename = os.path.realpath(os.path.expanduser(filename)) + return yaml.load(open(filename), Loader=SafeLoaderWithTuple) + + +def dump_config(config: dict, filename: str) -> None: + """Dump a config file""" + filename = os.path.realpath(os.path.expanduser(filename)) + yaml.dump(config, open(filename, "w"), sort_keys=False) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/opt.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/opt.py new file mode 100644 index 0000000..54f8e2e --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/apps/utils/opt.py @@ -0,0 +1,28 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"] + +# register optimizer here +# name: optimizer, kwargs with default values +REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = { + "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}), + "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), + "adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}), +} + + +def build_optimizer( + net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float +) -> torch.optim.Optimizer: + optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name] + optimizer_params = optimizer_params or {} + + for key in default_params: + if key in optimizer_params: + default_params[key] = optimizer_params[key] + optimizer = optimizer_class(net_params, init_lr, **default_params) + return optimizer diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/cls_model_zoo.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/cls_model_zoo.py new file mode 100644 index 0000000..28f3771 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/cls_model_zoo.py @@ -0,0 +1,79 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan 
Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from src.models.efficientvit.models.efficientvit import ( + EfficientViTCls, + efficientvit_cls_b0, + efficientvit_cls_b1, + efficientvit_cls_b2, + efficientvit_cls_b3, + efficientvit_cls_l1, + efficientvit_cls_l2, + efficientvit_cls_l3, +) +from src.models.efficientvit.models.nn.norm import set_norm_eps +from src.models.efficientvit.models.utils import load_state_dict_from_file + +__all__ = ["create_cls_model"] + + +REGISTERED_CLS_MODEL: dict[str, str] = { + "b0-r224": "assets/checkpoints/cls/b0-r224.pt", + ############################################### + "b1-r224": "assets/checkpoints/cls/b1-r224.pt", + "b1-r256": "assets/checkpoints/cls/b1-r256.pt", + "b1-r288": "assets/checkpoints/cls/b1-r288.pt", + ############################################### + "b2-r224": "assets/checkpoints/cls/b2-r224.pt", + "b2-r256": "assets/checkpoints/cls/b2-r256.pt", + "b2-r288": "assets/checkpoints/cls/b2-r288.pt", + ############################################### + "b3-r224": "assets/checkpoints/cls/b3-r224.pt", + "b3-r256": "assets/checkpoints/cls/b3-r256.pt", + "b3-r288": "assets/checkpoints/cls/b3-r288.pt", + ############################################### + "l1-r224": "assets/checkpoints/cls/l1-r224.pt", + ############################################### + "l2-r224": "assets/checkpoints/cls/l2-r224.pt", + "l2-r256": "assets/checkpoints/cls/l2-r256.pt", + "l2-r288": "assets/checkpoints/cls/l2-r288.pt", + "l2-r320": "assets/checkpoints/cls/l2-r320.pt", + "l2-r384": "assets/checkpoints/cls/l2-r384.pt", + ############################################### + "l3-r224": "assets/checkpoints/cls/l3-r224.pt", + "l3-r256": "assets/checkpoints/cls/l3-r256.pt", + "l3-r288": "assets/checkpoints/cls/l3-r288.pt", + "l3-r320": "assets/checkpoints/cls/l3-r320.pt", + "l3-r384": "assets/checkpoints/cls/l3-r384.pt", +} + + +def create_cls_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTCls: + model_dict = { + "b0": efficientvit_cls_b0, + "b1": efficientvit_cls_b1, + "b2": efficientvit_cls_b2, + "b3": efficientvit_cls_b3, + ######################### + "l1": efficientvit_cls_l1, + "l2": efficientvit_cls_l2, + "l3": efficientvit_cls_l3, + } + + model_id = name.split("-")[0] + if model_id not in model_dict: + raise ValueError(f"Do not find {name} in the model zoo. 
List of models: {list(model_dict.keys())}") + else: + model = model_dict[model_id](**kwargs) + if model_id in ["l1", "l2", "l3"]: + set_norm_eps(model, 1e-7) + + if pretrained: + weight_url = weight_url or REGISTERED_CLS_MODEL.get(name, None) + if weight_url is None: + raise ValueError(f"Do not find the pretrained weight of {name}.") + else: + weight = load_state_dict_from_file(weight_url) + model.load_state_dict(weight) + return model diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/__init__.py new file mode 100644 index 0000000..8c803ca --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/__init__.py @@ -0,0 +1,5 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .imagenet import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/imagenet.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/imagenet.py new file mode 100644 index 0000000..3832677 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/data_provider/imagenet.py @@ -0,0 +1,123 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +import math +import os + +import torchvision.transforms as transforms +from torchvision.datasets import ImageFolder + +from src.models.efficientvit.apps.data_provider import DataProvider +from src.models.efficientvit.apps.data_provider.augment import RandAug +from src.models.efficientvit.apps.data_provider.random_resolution import MyRandomResizedCrop, get_interpolate +from src.models.efficientvit.apps.utils import partial_update_config +from src.models.efficientvit.models.utils import val2list + +__all__ = ["ImageNetDataProvider"] + + +class ImageNetDataProvider(DataProvider): + name = "imagenet" + + data_dir = "/dataset/imagenet" + n_classes = 1000 + _DEFAULT_RRC_CONFIG = { + "train_interpolate": "random", + "test_interpolate": "bicubic", + "test_crop_ratio": 1.0, + } + + def __init__( + self, + data_dir: str or None = None, + rrc_config: dict or None = None, + data_aug: dict or list[dict] or None = None, + ########################################### + train_batch_size=128, + test_batch_size=128, + valid_size: int or float or None = None, + n_worker=8, + image_size: int or list[int] = 224, + num_replicas: int or None = None, + rank: int or None = None, + train_ratio: float or None = None, + drop_last: bool = False, + ): + self.data_dir = data_dir or self.data_dir + self.rrc_config = partial_update_config( + copy.deepcopy(self._DEFAULT_RRC_CONFIG), + rrc_config or {}, + ) + self.data_aug = data_aug + + super().__init__( + train_batch_size, + 
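+ # the remaining dataset/loader setup is handled by the base DataProvider
+ # (src.models.efficientvit.apps.data_provider)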
test_batch_size, + valid_size, + n_worker, + image_size, + num_replicas, + rank, + train_ratio, + drop_last, + ) + + def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any: + image_size = (image_size or self.active_image_size)[0] + crop_size = int(math.ceil(image_size / self.rrc_config["test_crop_ratio"])) + return transforms.Compose( + [ + transforms.Resize( + crop_size, + interpolation=get_interpolate(self.rrc_config["test_interpolate"]), + ), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize(**self.mean_std), + ] + ) + + def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any: + image_size = image_size or self.image_size + + # random_resize_crop -> random_horizontal_flip + train_transforms = [ + MyRandomResizedCrop(interpolation=self.rrc_config["train_interpolate"]), + transforms.RandomHorizontalFlip(), + ] + + # data augmentation + post_aug = [] + if self.data_aug is not None: + for aug_op in val2list(self.data_aug): + if aug_op["name"] == "randaug": + data_aug = RandAug(aug_op, mean=self.mean_std["mean"]) + elif aug_op["name"] == "erase": + from timm.data.random_erasing import RandomErasing + + random_erase = RandomErasing(aug_op["p"], device="cpu") + post_aug.append(random_erase) + data_aug = None + else: + raise NotImplementedError + if data_aug is not None: + train_transforms.append(data_aug) + train_transforms = [ + *train_transforms, + transforms.ToTensor(), + transforms.Normalize(**self.mean_std), + *post_aug, + ] + return transforms.Compose(train_transforms) + + def build_datasets(self) -> tuple[any, any, any]: + train_transform = self.build_train_transform() + valid_transform = self.build_valid_transform() + + train_dataset = ImageFolder(os.path.join(self.data_dir, "train"), train_transform) + test_dataset = ImageFolder(os.path.join(self.data_dir, "val"), valid_transform) + + train_dataset, val_dataset = self.sample_val_dataset(train_dataset, valid_transform) + return train_dataset, val_dataset, test_dataset diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/__init__.py new file mode 100644 index 0000000..b7887dd --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/__init__.py @@ -0,0 +1,6 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .cls_run_config import * +from .cls_trainer import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_run_config.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_run_config.py new file mode 100644 index 0000000..aa20c31 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_run_config.py @@ -0,0 +1,18 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from src.models.efficientvit.apps.trainer.run_config import RunConfig + +__all__ = ["ClsRunConfig"] + + +class ClsRunConfig(RunConfig): + 
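+ # classification-specific options on top of the base RunConfig; mixup_config and mesa
+ # may be None (see none_allowed below) to disable mixup/cutmix and the MESA auxiliary
+ # loss computed against the EMA model's outputs
+ # illustrative format only (not taken from the config files):
+ #   mixup_config = {"op": [["mixup", 0.8, 0.5], ["cutmix", 1.0, 0.5]]}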
label_smooth: float + mixup_config: dict # allow none to turn off mixup + bce: bool + mesa: dict + + @property + def none_allowed(self): + return ["mixup_config", "mesa"] + super().none_allowed diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_trainer.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_trainer.py new file mode 100644 index 0000000..7b8a256 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/cls_trainer.py @@ -0,0 +1,233 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os +import sys + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from src.models.efficientvit.apps.trainer import Trainer +from src.models.efficientvit.apps.utils import AverageMeter, is_master, sync_tensor +from src.models.efficientvit.clscore.trainer.utils import accuracy, apply_mixup, label_smooth +from src.models.efficientvit.models.utils import list_join, list_mean, torch_random_choices + +__all__ = ["ClsTrainer"] + + +class ClsTrainer(Trainer): + def __init__( + self, + path: str, + model: nn.Module, + data_provider, + auto_restart_thresh: float or None = None, + ) -> None: + super().__init__( + path=path, + model=model, + data_provider=data_provider, + ) + self.auto_restart_thresh = auto_restart_thresh + self.test_criterion = nn.CrossEntropyLoss() + + def _validate(self, model, data_loader, epoch) -> dict[str, any]: + val_loss = AverageMeter() + val_top1 = AverageMeter() + val_top5 = AverageMeter() + + with torch.no_grad(): + with tqdm( + total=len(data_loader), + desc=f"Validate Epoch #{epoch + 1}", + disable=not is_master(), + file=sys.stdout, + ) as t: + for images, labels in data_loader: + images, labels = images.cuda(), labels.cuda() + # compute output + output = model(images) + loss = self.test_criterion(output, labels) + val_loss.update(loss, images.shape[0]) + if self.data_provider.n_classes >= 100: + acc1, acc5 = accuracy(output, labels, topk=(1, 5)) + val_top5.update(acc5[0], images.shape[0]) + else: + acc1 = accuracy(output, labels, topk=(1,))[0] + val_top1.update(acc1[0], images.shape[0]) + + t.set_postfix( + { + "loss": val_loss.avg, + "top1": val_top1.avg, + "top5": val_top5.avg, + "#samples": val_top1.get_count(), + "bs": images.shape[0], + "res": images.shape[2], + } + ) + t.update() + return { + "val_top1": val_top1.avg, + "val_loss": val_loss.avg, + **({"val_top5": val_top5.avg} if val_top5.count > 0 else {}), + } + + def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + images = feed_dict["data"].cuda() + labels = feed_dict["label"].cuda() + + # label smooth + labels = label_smooth(labels, self.data_provider.n_classes, self.run_config.label_smooth) + + # mixup + if self.run_config.mixup_config is not None: + # choose active mixup config + mix_weight_list = [mix_list[2] for mix_list in self.run_config.mixup_config["op"]] + active_id = torch_random_choices( + list(range(len(self.run_config.mixup_config["op"]))), + weight_list=mix_weight_list, + ) + active_id = int(sync_tensor(active_id, reduce="root")) + active_mixup_config = self.run_config.mixup_config["op"][active_id] + mixup_type, mixup_alpha = active_mixup_config[:2] + + lam = 
float(torch.distributions.beta.Beta(mixup_alpha, mixup_alpha).sample()) + lam = float(np.clip(lam, 0, 1)) + lam = float(sync_tensor(lam, reduce="root")) + + images, labels = apply_mixup(images, labels, lam, mixup_type) + + return { + "data": images, + "label": labels, + } + + def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + images = feed_dict["data"] + labels = feed_dict["label"] + + # setup mesa + if self.run_config.mesa is not None and self.run_config.mesa["thresh"] <= self.run_config.progress: + ema_model = self.ema.shadows + with torch.inference_mode(): + ema_output = ema_model(images).detach() + ema_output = torch.clone(ema_output) + ema_output = F.sigmoid(ema_output).detach() + else: + ema_output = None + + with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp): + output = self.model(images) + loss = self.train_criterion(output, labels) + # mesa loss + if ema_output is not None: + mesa_loss = self.train_criterion(output, ema_output) + loss = loss + self.run_config.mesa["ratio"] * mesa_loss + self.scaler.scale(loss).backward() + + # calc train top1 acc + if self.run_config.mixup_config is None: + top1 = accuracy(output, torch.argmax(labels, dim=1), topk=(1,))[0][0] + else: + top1 = None + + return { + "loss": loss, + "top1": top1, + } + + def _train_one_epoch(self, epoch: int) -> dict[str, any]: + train_loss = AverageMeter() + train_top1 = AverageMeter() + + with tqdm( + total=len(self.data_provider.train), + desc="Train Epoch #{}".format(epoch + 1), + disable=not is_master(), + file=sys.stdout, + ) as t: + for images, labels in self.data_provider.train: + feed_dict = {"data": images, "label": labels} + + # preprocessing + feed_dict = self.before_step(feed_dict) + # clear gradient + self.optimizer.zero_grad() + # forward & backward + output_dict = self.run_step(feed_dict) + # update: optimizer, lr_scheduler + self.after_step() + + # update train metrics + train_loss.update(output_dict["loss"], images.shape[0]) + if output_dict["top1"] is not None: + train_top1.update(output_dict["top1"], images.shape[0]) + + # tqdm + postfix_dict = { + "loss": train_loss.avg, + "top1": train_top1.avg, + "bs": images.shape[0], + "res": images.shape[2], + "lr": list_join( + sorted(set([group["lr"] for group in self.optimizer.param_groups])), + "#", + "%.1E", + ), + "progress": self.run_config.progress, + } + t.set_postfix(postfix_dict) + t.update() + return { + **({"train_top1": train_top1.avg} if train_top1.count > 0 else {}), + "train_loss": train_loss.avg, + } + + def train(self, trials=0, save_freq=1) -> None: + if self.run_config.bce: + self.train_criterion = nn.BCEWithLogitsLoss() + else: + self.train_criterion = nn.CrossEntropyLoss() + + for epoch in range(self.start_epoch, self.run_config.n_epochs + self.run_config.warmup_epochs): + train_info_dict = self.train_one_epoch(epoch) + # eval + val_info_dict = self.multires_validate(epoch=epoch) + avg_top1 = list_mean([info_dict["val_top1"] for info_dict in val_info_dict.values()]) + is_best = avg_top1 > self.best_val + self.best_val = max(avg_top1, self.best_val) + + if self.auto_restart_thresh is not None: + if self.best_val - avg_top1 > self.auto_restart_thresh: + self.write_log(f"Abnormal accuracy drop: {self.best_val} -> {avg_top1}") + self.load_model(os.path.join(self.checkpoint_path, "model_best.pt")) + return self.train(trials + 1, save_freq) + + # log + val_log = self.run_config.epoch_format(epoch) + val_log += f"\tval_top1={avg_top1:.2f}({self.best_val:.2f})" + val_log += "\tVal(" + for key 
in list(val_info_dict.values())[0]: + if key == "val_top1": + continue + val_log += f"{key}={list_mean([info_dict[key] for info_dict in val_info_dict.values()]):.2f}," + val_log += ")\tTrain(" + for key, val in train_info_dict.items(): + val_log += f"{key}={val:.2E}," + val_log += ( + f'lr={list_join(sorted(set([group["lr"] for group in self.optimizer.param_groups])), "#", "%.1E")})' + ) + self.write_log(val_log, prefix="valid", print_log=False) + + # save model + if (epoch + 1) % save_freq == 0 or (is_best and self.run_config.progress > 0.8): + self.save_model( + only_state_dict=False, + epoch=epoch, + model_name="model_best.pt" if is_best else "checkpoint.pt", + ) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/__init__.py new file mode 100644 index 0000000..b11c938 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/__init__.py @@ -0,0 +1,7 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .label_smooth import * +from .metric import * +from .mixup import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/label_smooth.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/label_smooth.py new file mode 100644 index 0000000..d7f1fab --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/label_smooth.py @@ -0,0 +1,18 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +__all__ = ["label_smooth"] + + +def label_smooth(target: torch.Tensor, n_classes: int, smooth_factor=0.1) -> torch.Tensor: + # convert to one-hot + batch_size = target.shape[0] + target = torch.unsqueeze(target, 1) + soft_target = torch.zeros((batch_size, n_classes), device=target.device) + soft_target.scatter_(1, target, 1) + # label smoothing + soft_target = torch.add(soft_target * (1 - smooth_factor), smooth_factor / n_classes) + return soft_target diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/metric.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/metric.py new file mode 100644 index 0000000..6a5d908 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/metric.py @@ -0,0 +1,23 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch + +__all__ = ["accuracy"] + + +def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> list[torch.Tensor]: + """Computes the precision@k for the specified values of k.""" + maxk = max(topk) + batch_size = target.shape[0] + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = 
pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/mixup.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/mixup.py new file mode 100644 index 0000000..6bd01e5 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/clscore/trainer/utils/mixup.py @@ -0,0 +1,65 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch.distributions + +from src.models.efficientvit.apps.data_provider.augment import rand_bbox +from src.models.efficientvit.models.utils.random import torch_randint, torch_shuffle + +__all__ = ["apply_mixup", "mixup", "cutmix"] + + +def apply_mixup( + images: torch.Tensor, + labels: torch.Tensor, + lam: float, + mix_type="mixup", +) -> tuple[torch.Tensor, torch.Tensor]: + if mix_type == "mixup": + return mixup(images, labels, lam) + elif mix_type == "cutmix": + return cutmix(images, labels, lam) + else: + raise NotImplementedError + + +def mixup( + images: torch.Tensor, + target: torch.Tensor, + lam: float, +) -> tuple[torch.Tensor, torch.Tensor]: + rand_index = torch_shuffle(list(range(0, images.shape[0]))) + + flipped_images = images[rand_index] + flipped_target = target[rand_index] + + return ( + lam * images + (1 - lam) * flipped_images, + lam * target + (1 - lam) * flipped_target, + ) + + +def cutmix( + images: torch.Tensor, + target: torch.Tensor, + lam: float, +) -> tuple[torch.Tensor, torch.Tensor]: + rand_index = torch_shuffle(list(range(0, images.shape[0]))) + + flipped_images = images[rand_index] + flipped_target = target[rand_index] + + b, _, h, w = images.shape + lam_list = [] + for i in range(b): + bbx1, bby1, bbx2, bby2 = rand_bbox( + h=h, + w=w, + lam=lam, + rand_func=torch_randint, + ) + images[i, :, bby1:bby2, bbx1:bbx2] = flipped_images[i, :, bby1:bby2, bbx1:bbx2] + lam_list.append(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (h * w))) + lam = torch.Tensor(lam_list).to(images.device).view(b, 1) + return images, lam * target + (1 - lam) * flipped_target diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/__init__.py new file mode 100644 index 0000000..cea677f --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/__init__.py @@ -0,0 +1,8 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .backbone import * +from .cls import * +from .sam import * +from .seg import * diff --git 
a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/backbone.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/backbone.py new file mode 100644 index 0000000..3b92dfc --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/backbone.py @@ -0,0 +1,376 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn + +from src.models.efficientvit.models.nn import ( + ConvLayer, + DSConv, + EfficientViTBlock, + FusedMBConv, + IdentityLayer, + MBConv, + OpSequential, + ResBlock, + ResidualBlock, +) +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTBackbone", + "efficientvit_backbone_b0", + "efficientvit_backbone_b1", + "efficientvit_backbone_b2", + "efficientvit_backbone_b3", + "EfficientViTLargeBackbone", + "efficientvit_backbone_l0", + "efficientvit_backbone_l1", + "efficientvit_backbone_l2", + "efficientvit_backbone_l3", +] + + +class EfficientViTBackbone(nn.Module): + def __init__( + self, + width_list: list[int], + depth_list: list[int], + in_channels=3, + dim=32, + expand_ratio=4, + norm="bn2d", + act_func="hswish", + ) -> None: + super().__init__() + + self.width_list = [] + # input stem + self.input_stem = [ + ConvLayer( + in_channels=in_channels, + out_channels=width_list[0], + stride=2, + norm=norm, + act_func=act_func, + ) + ] + for _ in range(depth_list[0]): + block = self.build_local_block( + in_channels=width_list[0], + out_channels=width_list[0], + stride=1, + expand_ratio=1, + norm=norm, + act_func=act_func, + ) + self.input_stem.append(ResidualBlock(block, IdentityLayer())) + in_channels = width_list[0] + self.input_stem = OpSequential(self.input_stem) + self.width_list.append(in_channels) + + # stages + self.stages = [] + for w, d in zip(width_list[1:3], depth_list[1:3]): + stage = [] + for i in range(d): + stride = 2 if i == 0 else 1 + block = self.build_local_block( + in_channels=in_channels, + out_channels=w, + stride=stride, + expand_ratio=expand_ratio, + norm=norm, + act_func=act_func, + ) + block = ResidualBlock(block, IdentityLayer() if stride == 1 else None) + stage.append(block) + in_channels = w + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + + for w, d in zip(width_list[3:], depth_list[3:]): + stage = [] + block = self.build_local_block( + in_channels=in_channels, + out_channels=w, + stride=2, + expand_ratio=expand_ratio, + norm=norm, + act_func=act_func, + fewer_norm=True, + ) + stage.append(ResidualBlock(block, None)) + in_channels = w + + for _ in range(d): + stage.append( + EfficientViTBlock( + in_channels=in_channels, + dim=dim, + expand_ratio=expand_ratio, + norm=norm, + act_func=act_func, + ) + ) + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + self.stages = nn.ModuleList(self.stages) + + @staticmethod + def build_local_block( + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm: str, + act_func: str, + fewer_norm: bool = False, + ) -> nn.Module: + if expand_ratio == 1: + block = DSConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm 
else norm, + act_func=(act_func, None), + ) + else: + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm=(None, None, norm) if fewer_norm else norm, + act_func=(act_func, act_func, None), + ) + return block + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + output_dict = {"input": x} + output_dict["stage0"] = x = self.input_stem(x) + for stage_id, stage in enumerate(self.stages, 1): + output_dict["stage%d" % stage_id] = x = stage(x) + output_dict["stage_final"] = x + return output_dict + + +def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[8, 16, 32, 64, 128], + depth_list=[1, 2, 2, 2, 2], + dim=16, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[16, 32, 64, 128, 256], + depth_list=[1, 2, 3, 3, 4], + dim=16, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[24, 48, 96, 192, 384], + depth_list=[1, 3, 4, 4, 6], + dim=32, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone: + backbone = EfficientViTBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 4, 6, 6, 9], + dim=32, + **build_kwargs_from_config(kwargs, EfficientViTBackbone), + ) + return backbone + + +class EfficientViTLargeBackbone(nn.Module): + def __init__( + self, + width_list: list[int], + depth_list: list[int], + block_list: list[str] or None = None, + expand_list: list[float] or None = None, + fewer_norm_list: list[bool] or None = None, + in_channels=3, + qkv_dim=32, + norm="bn2d", + act_func="gelu", + ) -> None: + super().__init__() + block_list = block_list or ["res", "fmb", "fmb", "mb", "att"] + expand_list = expand_list or [1, 4, 4, 4, 6] + fewer_norm_list = fewer_norm_list or [False, False, False, True, True] + + self.width_list = [] + self.stages = [] + # stage 0 + stage0 = [ + ConvLayer( + in_channels=3, + out_channels=width_list[0], + stride=2, + norm=norm, + act_func=act_func, + ) + ] + for _ in range(depth_list[0]): + block = self.build_local_block( + block=block_list[0], + in_channels=width_list[0], + out_channels=width_list[0], + stride=1, + expand_ratio=expand_list[0], + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[0], + ) + stage0.append(ResidualBlock(block, IdentityLayer())) + in_channels = width_list[0] + self.stages.append(OpSequential(stage0)) + self.width_list.append(in_channels) + + for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1): + stage = [] + block = self.build_local_block( + block="mb" if block_list[stage_id] not in ["mb", "fmb"] else block_list[stage_id], + in_channels=in_channels, + out_channels=w, + stride=2, + expand_ratio=expand_list[stage_id] * 4, + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[stage_id], + ) + stage.append(ResidualBlock(block, None)) + in_channels = w + + for _ in range(d): + if block_list[stage_id].startswith("att"): + stage.append( + EfficientViTBlock( + in_channels=in_channels, + dim=qkv_dim, + expand_ratio=expand_list[stage_id], + scales=(3,) if block_list[stage_id] == "att@3" else (5,), + norm=norm, + 
act_func=act_func, + ) + ) + else: + block = self.build_local_block( + block=block_list[stage_id], + in_channels=in_channels, + out_channels=in_channels, + stride=1, + expand_ratio=expand_list[stage_id], + norm=norm, + act_func=act_func, + fewer_norm=fewer_norm_list[stage_id], + ) + block = ResidualBlock(block, IdentityLayer()) + stage.append(block) + self.stages.append(OpSequential(stage)) + self.width_list.append(in_channels) + self.stages = nn.ModuleList(self.stages) + + @staticmethod + def build_local_block( + block: str, + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: float, + norm: str, + act_func: str, + fewer_norm: bool = False, + ) -> nn.Module: + if block == "res": + block = ResBlock( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm else norm, + act_func=(act_func, None), + ) + elif block == "fmb": + block = FusedMBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, False) if fewer_norm else False, + norm=(None, norm) if fewer_norm else norm, + act_func=(act_func, None), + ) + elif block == "mb": + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm=(None, None, norm) if fewer_norm else norm, + act_func=(act_func, act_func, None), + ) + else: + raise ValueError(block) + return block + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + output_dict = {"input": x} + for stage_id, stage in enumerate(self.stages): + output_dict["stage%d" % stage_id] = x = stage(x) + output_dict["stage_final"] = x + return output_dict + + +def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 1, 1, 4, 4], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 1, 1, 6, 6], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512], + depth_list=[1, 2, 2, 8, 8], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone + + +def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone: + backbone = EfficientViTLargeBackbone( + width_list=[64, 128, 256, 512, 1024], + depth_list=[1, 2, 2, 8, 8], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + return backbone diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/cls.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/cls.py new file mode 100644 index 0000000..ac7e3ba --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/cls.py @@ -0,0 +1,162 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn 
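+# EfficientViTCls (below) pairs an EfficientViT backbone with a small classification
+# head: 1x1 conv -> global average pooling -> two linear layers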
+ +from src.models.efficientvit.models.efficientvit.backbone import EfficientViTBackbone, EfficientViTLargeBackbone +from src.models.efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTCls", + ###################### + "efficientvit_cls_b0", + "efficientvit_cls_b1", + "efficientvit_cls_b2", + "efficientvit_cls_b3", + ###################### + "efficientvit_cls_l1", + "efficientvit_cls_l2", + "efficientvit_cls_l3", +] + + +class ClsHead(OpSequential): + def __init__( + self, + in_channels: int, + width_list: list[int], + n_classes=1000, + dropout=0.0, + norm="bn2d", + act_func="hswish", + fid="stage_final", + ): + ops = [ + ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func), + nn.AdaptiveAvgPool2d(output_size=1), + LinearLayer(width_list[0], width_list[1], False, norm="ln", act_func=act_func), + LinearLayer(width_list[1], n_classes, True, dropout, None, None), + ] + super().__init__(ops) + + self.fid = fid + + def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor: + x = feed_dict[self.fid] + return OpSequential.forward(self, x) + + +class EfficientViTCls(nn.Module): + def __init__(self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: ClsHead) -> None: + super().__init__() + self.backbone = backbone + self.head = head + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + output = self.head(feed_dict) + return output + + +def efficientvit_cls_b0(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b0 + + backbone = efficientvit_backbone_b0(**kwargs) + + head = ClsHead( + in_channels=128, + width_list=[1024, 1280], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b1(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b1 + + backbone = efficientvit_backbone_b1(**kwargs) + + head = ClsHead( + in_channels=256, + width_list=[1536, 1600], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b2(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b2 + + backbone = efficientvit_backbone_b2(**kwargs) + + head = ClsHead( + in_channels=384, + width_list=[2304, 2560], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_b3(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b3 + + backbone = efficientvit_backbone_b3(**kwargs) + + head = ClsHead( + in_channels=512, + width_list=[2304, 2560], + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l1(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + head = ClsHead( + in_channels=512, + width_list=[3072, 3200], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l2(**kwargs) -> EfficientViTCls: + from 
src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + head = ClsHead( + in_channels=512, + width_list=[3072, 3200], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model + + +def efficientvit_cls_l3(**kwargs) -> EfficientViTCls: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l3 + + backbone = efficientvit_backbone_l3(**kwargs) + + head = ClsHead( + in_channels=1024, + width_list=[6144, 6400], + act_func="gelu", + **build_kwargs_from_config(kwargs, ClsHead), + ) + model = EfficientViTCls(backbone, head) + return model diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/sam.py new file mode 100644 index 0000000..373a304 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/sam.py @@ -0,0 +1,664 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import copy +from typing import Any, Dict, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +from src.models.segment_anything import SamAutomaticMaskGenerator +from src.models.segment_anything.modeling import MaskDecoder, PromptEncoder, TwoWayTransformer +from src.models.segment_anything.modeling.mask_decoder import MaskDecoder +from src.models.segment_anything.modeling.prompt_encoder import PromptEncoder +from src.models.segment_anything.utils.amg import build_all_layer_point_grids +from src.models.segment_anything.utils.transforms import ResizeLongestSide +from torchvision.transforms.functional import resize, to_pil_image + +from src.models.efficientvit.models.efficientvit.backbone import EfficientViTBackbone, EfficientViTLargeBackbone +from src.models.efficientvit.models.nn import ( + ConvLayer, + DAGBlock, + FusedMBConv, + IdentityLayer, + MBConv, + OpSequential, + ResBlock, + ResidualBlock, + UpSampleLayer, + build_norm, +) +from src.models.efficientvit.models.utils import build_kwargs_from_config, get_device + +__all__ = [ + "SamPad", + "SamResize", + "SamNeck", + "EfficientViTSamImageEncoder", + "EfficientViTSam", + "EfficientViTSamPredictor", + "EfficientViTSamAutomaticMaskGenerator", + "efficientvit_sam_l0", + "efficientvit_sam_l1", + "efficientvit_sam_l2", + "efficientvit_sam_xl0", + "efficientvit_sam_xl1", +] + + +class SamPad: + def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None: + self.size = size + self.fill = fill + self.pad_mode = pad_mode + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + h, w = image.shape[-2:] + th, tw = self.size, self.size + assert th >= h and tw >= w + if self.pad_mode == "corner": + image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill) + else: + raise NotImplementedError + return image + + def __repr__(self) -> str: + return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})" + + +class SamResize: + def __init__(self, size: int) -> None: + self.size = size + + def __call__(self, image: np.ndarray) -> np.ndarray: + h, w, _ = image.shape + long_side = 
max(h, w) + if long_side != self.size: + return self.apply_image(image) + else: + return image + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.size) + return np.array(resize(to_pil_image(image), target_size)) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + def __repr__(self) -> str: + return f"{type(self).__name__}(size={self.size})" + + +class SamNeck(DAGBlock): + def __init__( + self, + fid_list: list[str], + in_channel_list: list[int], + head_width: int, + head_depth: int, + expand_ratio: float, + middle_op: str, + out_dim: int = 256, + norm="bn2d", + act_func="gelu", + ): + inputs = {} + for fid, in_channel in zip(fid_list, in_channel_list): + inputs[fid] = OpSequential( + [ + ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None), + UpSampleLayer(size=(64, 64)), + ] + ) + + middle = [] + for _ in range(head_depth): + if middle_op == "mb": + block = MBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, act_func, None), + ) + elif middle_op == "fmb": + block = FusedMBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + elif middle_op == "res": + block = ResBlock( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + else: + raise NotImplementedError + middle.append(ResidualBlock(block, IdentityLayer())) + middle = OpSequential(middle) + + outputs = { + "sam_encoder": OpSequential( + [ + ConvLayer( + head_width, + out_dim, + 1, + use_bias=True, + norm=None, + act_func=None, + ), + ] + ) + } + + super(SamNeck, self).__init__(inputs, "add", None, middle=middle, outputs=outputs) + + +class EfficientViTSamImageEncoder(nn.Module): + def __init__(self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, neck: SamNeck): + super().__init__() + self.backbone = backbone + self.neck = neck + + self.norm = build_norm("ln2d", 256) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + feed_dict = self.neck(feed_dict) + + output = feed_dict["sam_encoder"] + output = self.norm(output) + return output + + +class EfficientViTSam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: EfficientViTSamImageEncoder, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + image_size: tuple[int, int] = (1024, 512), + ) -> None: + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + + self.image_size = image_size + + self.transform = transforms.Compose( + [ + SamResize(self.image_size[1]), + transforms.ToTensor(), + transforms.Normalize( + mean=[123.675 / 255, 116.28 / 255, 103.53 / 255], + std=[58.395 / 255, 57.12 / 255, 57.375 / 255], + ), + SamPad(self.image_size[1]), + ] + ) + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: tuple[int, ...], + original_size: tuple[int, ...], + ) -> torch.Tensor: + masks = F.interpolate( + masks, + (self.image_size[0], self.image_size[0]), + 
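+        # (editor's note) the low-res masks are first upsampled to the square,
+        # padded model resolution given by image_size, then cropped back to the
+        # un-padded input_size and finally resized to the original image size below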
mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ): + input_images = torch.stack([x["image"] for x in batched_input], dim=0) + + image_embeddings = self.image_encoder(input_images) + + outputs = [] + iou_outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + outputs.append(low_res_masks) + iou_outputs.append(iou_predictions) + + outputs = torch.stack([out for out in outputs], dim=0) + iou_outputs = torch.stack(iou_outputs, dim=0) + + return outputs, iou_outputs + + +class EfficientViTSamPredictor: + def __init__(self, sam_model: EfficientViTSam) -> None: + self.model = sam_model + self.reset_image() + + @property + def transform(self): + return self + + @property + def device(self): + return get_device(self.model) + + def reset_image(self) -> None: + self.is_image_set = False + self.features = None + self.original_size = None + self.input_size = None + + def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray: + old_h, old_w = self.original_size + new_h, new_w = self.input_size + coords = copy.deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray: + boxes = self.apply_coords(boxes.reshape(-1, 2, 2)) + return boxes.reshape(-1, 4) + + @torch.inference_mode() + def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None: + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + self.reset_image() + + self.original_size = image.shape[:2] + self.input_size = ResizeLongestSide.get_preprocess_shape( + *self.original_size, long_side_length=self.model.image_size[0] + ) + + torch_data = self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model)) + self.features = self.model.image_encoder(torch_data) + self.is_image_set = True + + def predict( + self, + point_coords: np.ndarray or None = None, + point_labels: np.ndarray or None = None, + box: np.ndarray or None = None, + mask_input: np.ndarray or None = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. 
+ box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + device = get_device(self.model) + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert point_labels is not None, "point_labels must be supplied if point_coords is supplied." + point_coords = self.apply_coords(point_coords) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.apply_boxes(box) + box_torch = torch.as_tensor(box, dtype=torch.float, device=device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.inference_mode() + def predict_torch( + self, + point_coords: torch.Tensor or None = None, + point_labels: torch.Tensor or None = None, + boxes: torch.Tensor or None = None, + mask_input: torch.Tensor or None = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. 
+ mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + +class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator): + def __init__( + self, + model: EfficientViTSam, + points_per_side: int or None = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: list[np.ndarray] or None = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
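+        # (editor's sketch, not part of this patch) typical single-box usage of the
+        # EfficientViTSamPredictor defined above, assuming `image` is an HxWx3 uint8
+        # RGB numpy array and efficientvit_sam_l0 (defined later in this file) is used:
+        #
+        #     sam = efficientvit_sam_l0()
+        #     predictor = EfficientViTSamPredictor(sam)
+        #     predictor.set_image(image)
+        #     masks, iou_predictions, low_res = predictor.predict(
+        #         box=np.array([x0, y0, x1, y1]), multimask_output=False
+        #     )
+        #
+        # masks comes back as a CxHxW boolean array at the original image resolution.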
+ if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = EfficientViTSamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + +def build_efficientvit_sam(image_encoder: EfficientViTSamImageEncoder, image_size: int) -> EfficientViTSam: + return EfficientViTSam( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=256, + image_embedding_size=(64, 64), + input_image_size=(image_size, image_size), # Modified from (1024, 1024) + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=256, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=256, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + image_size=(image_size, image_size), # Modified from (1024, image_size) + ) + + +def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l0 + + backbone = efficientvit_backbone_l0(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=4, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=8, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + neck = SamNeck( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + head_width=256, + head_depth=12, + expand_ratio=1, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam: + from src.models.efficientvit.models.efficientvit.backbone import EfficientViTLargeBackbone + + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512, 1024], + depth_list=[0, 1, 1, 2, 3, 3], + block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"], + expand_list=[1, 4, 4, 4, 4, 6], + fewer_norm_list=[False, False, False, False, True, True], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + + neck = SamNeck( + fid_list=["stage5", "stage4", "stage3"], + 
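+        # (editor's note) the neck fuses the three deepest backbone stages: each
+        # input is projected to head_width channels and upsampled to 64x64, and the
+        # output head emits the 256-channel "sam_encoder" embedding expected by the
+        # 64x64 PromptEncoder grid configured in build_efficientvit_sam above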
in_channel_list=[1024, 512, 256], + head_width=256, + head_depth=6, + expand_ratio=4, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) + + +def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam: + from src.models.efficientvit.models.efficientvit.backbone import EfficientViTLargeBackbone + + backbone = EfficientViTLargeBackbone( + width_list=[32, 64, 128, 256, 512, 1024], + depth_list=[1, 2, 2, 4, 6, 6], + block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"], + expand_list=[1, 4, 4, 4, 4, 6], + fewer_norm_list=[False, False, False, False, True, True], + **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone), + ) + + neck = SamNeck( + fid_list=["stage5", "stage4", "stage3"], + in_channel_list=[1024, 512, 256], + head_width=256, + head_depth=12, + expand_ratio=4, + middle_op="fmb", + ) + + image_encoder = EfficientViTSamImageEncoder(backbone, neck) + return build_efficientvit_sam(image_encoder, image_size) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/seg.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/seg.py new file mode 100644 index 0000000..3de5697 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/efficientvit/seg.py @@ -0,0 +1,343 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn + +from src.models.efficientvit.models.efficientvit.backbone import EfficientViTBackbone, EfficientViTLargeBackbone +from src.models.efficientvit.models.nn import ( + ConvLayer, + DAGBlock, + FusedMBConv, + IdentityLayer, + MBConv, + OpSequential, + ResidualBlock, + UpSampleLayer, +) +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = [ + "EfficientViTSeg", + "efficientvit_seg_b0", + "efficientvit_seg_b1", + "efficientvit_seg_b2", + "efficientvit_seg_b3", + "efficientvit_seg_l1", + "efficientvit_seg_l2", +] + + +class SegHead(DAGBlock): + def __init__( + self, + fid_list: list[str], + in_channel_list: list[int], + stride_list: list[int], + head_stride: int, + head_width: int, + head_depth: int, + expand_ratio: float, + middle_op: str, + final_expand: float or None, + n_classes: int, + dropout=0, + norm="bn2d", + act_func="hswish", + ): + inputs = {} + for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list): + factor = stride // head_stride + if factor == 1: + inputs[fid] = ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None) + else: + inputs[fid] = OpSequential( + [ + ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None), + UpSampleLayer(factor=factor), + ] + ) + + middle = [] + for _ in range(head_depth): + if middle_op == "mbconv": + block = MBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, act_func, None), + ) + elif middle_op == "fmbconv": + block = FusedMBConv( + head_width, + head_width, + expand_ratio=expand_ratio, + norm=norm, + act_func=(act_func, None), + ) + else: + raise NotImplementedError + middle.append(ResidualBlock(block, IdentityLayer())) + middle = OpSequential(middle) + + outputs = { + "segout": OpSequential( + [ + ( + None + if final_expand is None + 
else ConvLayer(head_width, head_width * final_expand, 1, norm=norm, act_func=act_func) + ), + ConvLayer( + head_width * (final_expand or 1), + n_classes, + 1, + use_bias=True, + dropout=dropout, + norm=None, + act_func=None, + ), + ] + ) + } + + super(SegHead, self).__init__(inputs, "add", None, middle=middle, outputs=outputs) + + +class EfficientViTSeg(nn.Module): + def __init__(self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: SegHead) -> None: + super().__init__() + self.backbone = backbone + self.head = head + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feed_dict = self.backbone(x) + feed_dict = self.head(feed_dict) + + return feed_dict["segout"] + + +def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b0 + + backbone = efficientvit_backbone_b0(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[128, 64, 32], + stride_list=[32, 16, 8], + head_stride=8, + head_width=32, + head_depth=1, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b1 + + backbone = efficientvit_backbone_b1(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[256, 128, 64], + stride_list=[32, 16, 8], + head_stride=8, + head_width=64, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[256, 128, 64], + stride_list=[32, 16, 8], + head_stride=8, + head_width=64, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b2 + + backbone = efficientvit_backbone_b2(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[384, 192, 96], + stride_list=[32, 16, 8], + head_stride=8, + head_width=96, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[384, 192, 96], + stride_list=[32, 16, 8], + head_stride=8, + head_width=96, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b3 + + backbone = efficientvit_backbone_b3(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", 
"stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=4, + n_classes=19, + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="mbconv", + final_expand=None, + n_classes=150, + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l1 + + backbone = efficientvit_backbone_l1(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=256, + head_depth=3, + expand_ratio=1, + middle_op="fmbconv", + final_expand=None, + n_classes=19, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="fmbconv", + final_expand=8, + n_classes=150, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model + + +def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg: + from src.models.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l2 + + backbone = efficientvit_backbone_l2(**kwargs) + + if dataset == "cityscapes": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=256, + head_depth=5, + expand_ratio=1, + middle_op="fmbconv", + final_expand=None, + n_classes=19, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + elif dataset == "ade20k": + head = SegHead( + fid_list=["stage4", "stage3", "stage2"], + in_channel_list=[512, 256, 128], + stride_list=[32, 16, 8], + head_stride=8, + head_width=128, + head_depth=3, + expand_ratio=4, + middle_op="fmbconv", + final_expand=8, + n_classes=150, + act_func="gelu", + **build_kwargs_from_config(kwargs, SegHead), + ) + else: + raise NotImplementedError + model = EfficientViTSeg(backbone, head) + return model diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/__init__.py new file mode 100644 index 0000000..d615215 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/__init__.py @@ -0,0 +1,8 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .act import * +from .drop import * +from .norm import * +from .ops import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/act.py 
b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/act.py new file mode 100644 index 0000000..42e4405 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/act.py @@ -0,0 +1,30 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from functools import partial + +import torch.nn as nn + +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = ["build_act"] + + +# register activation function here +REGISTERED_ACT_DICT: dict[str, type] = { + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "hswish": nn.Hardswish, + "silu": nn.SiLU, + "gelu": partial(nn.GELU, approximate="tanh"), +} + + +def build_act(name: str, **kwargs) -> nn.Module or None: + if name in REGISTERED_ACT_DICT: + act_cls = REGISTERED_ACT_DICT[name] + args = build_kwargs_from_config(kwargs, act_cls) + return act_cls(**args) + else: + return None diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/drop.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/drop.py new file mode 100644 index 0000000..4f78c77 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/drop.py @@ -0,0 +1,88 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torch +import torch.nn as nn + +# from src.models.efficientvit.apps.trainer.run_config import Scheduler +from src.models.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = ["apply_drop_func"] + + +def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None: + if drop_config is None: + return + + drop_lookup_table = { + "droppath": apply_droppath, + } + + drop_func = drop_lookup_table[drop_config["name"]] + drop_kwargs = build_kwargs_from_config(drop_config, drop_func) + + drop_func(network, **drop_kwargs) + + +def apply_droppath( + network: nn.Module, + drop_prob: float, + linear_decay=True, + scheduled=True, + skip=0, +) -> None: + all_valid_blocks = [] + for m in network.modules(): + for name, sub_module in m.named_children(): + if isinstance(sub_module, ResidualBlock) and isinstance(sub_module.shortcut, IdentityLayer): + all_valid_blocks.append((m, name, sub_module)) + all_valid_blocks = all_valid_blocks[skip:] + for i, (m, name, sub_module) in enumerate(all_valid_blocks): + prob = drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob + new_module = DropPathResidualBlock( + sub_module.main, + sub_module.shortcut, + sub_module.post_act, + sub_module.pre_norm, + prob, + scheduled, + ) + m._modules[name] = new_module + + +class DropPathResidualBlock(ResidualBlock): + def __init__( + self, + main: nn.Module, + shortcut: nn.Module or None, + post_act=None, + pre_norm: nn.Module or None = None, + ###################################### + drop_prob: float = 0, + scheduled=True, + ): + super().__init__(main, shortcut, post_act, pre_norm) + + self.drop_prob = drop_prob + self.scheduled = scheduled + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + if not self.training or self.drop_prob == 0 or not isinstance(self.shortcut, IdentityLayer): + return ResidualBlock.forward(self, x) + else: + drop_prob = self.drop_prob + if self.scheduled: + drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1) + keep_prob = 1 - drop_prob + + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + + res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x) + if self.post_act: + res = self.post_act(res) + return res diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/norm.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/norm.py new file mode 100644 index 0000000..63cec74 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/norm.py @@ -0,0 +1,137 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from src.models.efficientvit.models.utils import build_kwargs_from_config + +__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"] + + +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = x - torch.mean(x, dim=1, keepdim=True) + out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) + if self.elementwise_affine: + out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + return out + + +# register normalization function here +REGISTERED_NORM_DICT: dict[str, type] = { + "bn2d": nn.BatchNorm2d, + "ln": nn.LayerNorm, + "ln2d": LayerNorm2d, +} + + +def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module or None: + if name in ["ln", "ln2d"]: + kwargs["normalized_shape"] = num_features + else: + kwargs["num_features"] = num_features + if name in REGISTERED_NORM_DICT: + norm_cls = REGISTERED_NORM_DICT[name] + args = build_kwargs_from_config(kwargs, norm_cls) + return norm_cls(**args) + else: + return None + + +def reset_bn( + model: nn.Module, + data_loader: list, + sync=True, + progress_bar=False, +) -> None: + import copy + + import torch.nn.functional as F + from tqdm import tqdm + + from src.models.efficientvit.apps.utils import AverageMeter, is_master, sync_tensor + from src.models.efficientvit.models.utils import get_device, list_join + + bn_mean = {} + bn_var = {} + + tmp_model = copy.deepcopy(model) + for name, m in tmp_model.named_modules(): + if isinstance(m, _BatchNorm): + bn_mean[name] = AverageMeter(is_distributed=False) + bn_var[name] = AverageMeter(is_distributed=False) + + def new_forward(bn, mean_est, var_est): + def lambda_forward(x): + x = x.contiguous() + if sync: + batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 + batch_mean = sync_tensor(batch_mean, reduce="cat") + batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) + + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + batch_var = sync_tensor(batch_var, reduce="cat") + batch_var = torch.mean(batch_var, dim=0, keepdim=True) + else: + batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 
# 1, C, 1, 1 + batch_var = (x - batch_mean) * (x - batch_mean) + batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + + batch_mean = torch.squeeze(batch_mean) + batch_var = torch.squeeze(batch_var) + + mean_est.update(batch_mean.data, x.size(0)) + var_est.update(batch_var.data, x.size(0)) + + # bn forward using calculated mean & var + _feature_dim = batch_mean.shape[0] + return F.batch_norm( + x, + batch_mean, + batch_var, + bn.weight[:_feature_dim], + bn.bias[:_feature_dim], + False, + 0.0, + bn.eps, + ) + + return lambda_forward + + m.forward = new_forward(m, bn_mean[name], bn_var[name]) + + # skip if there is no batch normalization layers in the network + if len(bn_mean) == 0: + return + + tmp_model.eval() + with torch.no_grad(): + with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t: + for images in data_loader: + images = images.to(get_device(tmp_model)) + tmp_model(images) + t.set_postfix( + { + "bs": images.size(0), + "res": list_join(images.shape[-2:], "x"), + } + ) + t.update() + + for name, m in model.named_modules(): + if name in bn_mean and bn_mean[name].count > 0: + feature_dim = bn_mean[name].avg.size(0) + assert isinstance(m, _BatchNorm) + m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) + m.running_var.data[:feature_dim].copy_(bn_var[name].avg) + + +def set_norm_eps(model: nn.Module, eps: float or None = None) -> None: + for m in model.modules(): + if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): + if eps is not None: + m.eps = eps diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/ops.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/ops.py new file mode 100644 index 0000000..742a8b5 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/nn/ops.py @@ -0,0 +1,614 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from src.models.efficientvit.models.nn.act import build_act +from src.models.efficientvit.models.nn.norm import build_norm +from src.models.efficientvit.models.utils import get_same_padding, list_sum, resize, val2list, val2tuple + +__all__ = [ + "ConvLayer", + "UpSampleLayer", + "LinearLayer", + "IdentityLayer", + "DSConv", + "MBConv", + "FusedMBConv", + "ResBlock", + "LiteMLA", + "EfficientViTBlock", + "ResidualBlock", + "DAGBlock", + "OpSequential", +] + + +################################################################################# +# Basic Layers # +################################################################################# + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + use_bias=False, + dropout=0, + norm="bn2d", + act_func="relu", + ): + super(ConvLayer, self).__init__() + + padding = get_same_padding(kernel_size) + padding *= dilation + + self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(kernel_size, kernel_size), + stride=(stride, stride), + padding=padding, + dilation=(dilation, dilation), + groups=groups, + 
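+            # (editor's note) bias is normally left disabled here because a norm
+            # layer (built below via build_norm) immediately follows the convolution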
bias=use_bias, + ) + self.norm = build_norm(norm, num_features=out_channels) + self.act = build_act(act_func) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dropout is not None: + x = self.dropout(x) + x = self.conv(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class UpSampleLayer(nn.Module): + def __init__( + self, + mode="bicubic", + size: int or tuple[int, int] or list[int] or None = None, + factor=2, + align_corners=False, + ): + super(UpSampleLayer, self).__init__() + self.mode = mode + self.size = val2list(size, 2) if size is not None else None + self.factor = None if self.size is not None else factor + self.align_corners = align_corners + + @autocast(enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if (self.size is not None and tuple(x.shape[-2:]) == self.size) or self.factor == 1: + return x + if x.dtype in [torch.float16, torch.bfloat16]: + x = x.float() + return resize(x, self.size, self.factor, self.mode, self.align_corners) + + +class LinearLayer(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + use_bias=True, + dropout=0, + norm=None, + act_func=None, + ): + super(LinearLayer, self).__init__() + + self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None + self.linear = nn.Linear(in_features, out_features, use_bias) + self.norm = build_norm(norm, num_features=out_features) + self.act = build_act(act_func) + + def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._try_squeeze(x) + if self.dropout: + x = self.dropout(x) + x = self.linear(x) + if self.norm: + x = self.norm(x) + if self.act: + x = self.act(x) + return x + + +class IdentityLayer(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +################################################################################# +# Basic Blocks # +################################################################################# + + +class DSConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super(DSConv, self).__init__() + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.depth_conv = ConvLayer( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.point_conv = ConvLayer( + in_channels, + out_channels, + 1, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm=("bn2d", "bn2d", "bn2d"), + act_func=("relu6", "relu6", None), + ): + super(MBConv, self).__init__() + + use_bias = val2tuple(use_bias, 3) + norm = val2tuple(norm, 3) + act_func = val2tuple(act_func, 3) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.inverted_conv = ConvLayer( + in_channels, + mid_channels, + 1, + stride=1, + norm=norm[0], + act_func=act_func[0], + use_bias=use_bias[0], + ) + self.depth_conv = ConvLayer( + mid_channels, + mid_channels, + 
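+            # (editor's note) depthwise stage: groups=mid_channels below makes this a
+            # per-channel (depthwise) convolution between the expand and project 1x1 convs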
kernel_size, + stride=stride, + groups=mid_channels, + norm=norm[1], + act_func=act_func[1], + use_bias=use_bias[1], + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + norm=norm[2], + act_func=act_func[2], + use_bias=use_bias[2], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class FusedMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + groups=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.spatial_conv = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + groups=groups, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.point_conv = ConvLayer( + mid_channels, + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.spatial_conv(x) + x = self.point_conv(x) + return x + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=1, + use_bias=False, + norm=("bn2d", "bn2d"), + act_func=("relu6", None), + ): + super().__init__() + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.conv1 = ConvLayer( + in_channels, + mid_channels, + kernel_size, + stride, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.conv2 = ConvLayer( + mid_channels, + out_channels, + kernel_size, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + return x + + +class LiteMLA(nn.Module): + r"""Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int or None = None, + heads_ratio: float = 1.0, + dim=8, + use_bias=False, + norm=(None, "bn2d"), + act_func=(None, None), + kernel_func="relu", + scales: tuple[int, ...] 
= (5,), + eps=1.0e-15, + ): + super(LiteMLA, self).__init__() + self.eps = eps + heads = heads or int(in_channels // dim * heads_ratio) + + total_dim = heads * dim + + use_bias = val2tuple(use_bias, 2) + norm = val2tuple(norm, 2) + act_func = val2tuple(act_func, 2) + + self.dim = dim + self.qkv = ConvLayer( + in_channels, + 3 * total_dim, + 1, + use_bias=use_bias[0], + norm=norm[0], + act_func=act_func[0], + ) + self.aggreg = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + ) + for scale in scales + ] + ) + self.kernel_func = build_act(kernel_func, inplace=False) + + self.proj = ConvLayer( + total_dim * (1 + len(scales)), + out_channels, + 1, + use_bias=use_bias[1], + norm=norm[1], + act_func=act_func[1], + ) + + @autocast(enabled=False) + def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor: + B, _, H, W = list(qkv.size()) + + if qkv.dtype == torch.float16: + qkv = qkv.float() + + qkv = torch.reshape( + qkv, + ( + B, + -1, + 3 * self.dim, + H * W, + ), + ) + q, k, v = ( + qkv[:, :, 0 : self.dim], + qkv[:, :, self.dim : 2 * self.dim], + qkv[:, :, 2 * self.dim :], + ) + + # lightweight linear attention + q = self.kernel_func(q) + k = self.kernel_func(k) + + # linear matmul + trans_k = k.transpose(-1, -2) + + v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1) + vk = torch.matmul(v, trans_k) + out = torch.matmul(vk, q) + if out.dtype == torch.bfloat16: + out = out.float() + out = out[:, :, :-1] / (out[:, :, -1:] + self.eps) + + out = torch.reshape(out, (B, -1, H, W)) + return out + + @autocast(enabled=False) + def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor: + B, _, H, W = list(qkv.size()) + + qkv = torch.reshape( + qkv, + ( + B, + -1, + 3 * self.dim, + H * W, + ), + ) + q, k, v = ( + qkv[:, :, 0 : self.dim], + qkv[:, :, self.dim : 2 * self.dim], + qkv[:, :, 2 * self.dim :], + ) + + q = self.kernel_func(q) + k = self.kernel_func(k) + + att_map = torch.matmul(k.transpose(-1, -2), q) # b h n n + original_dtype = att_map.dtype + if original_dtype in [torch.float16, torch.bfloat16]: + att_map = att_map.float() + att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps) # b h n n + att_map = att_map.to(original_dtype) + out = torch.matmul(v, att_map) # b h d n + + out = torch.reshape(out, (B, -1, H, W)) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + qkv = torch.cat(multi_scale_qkv, dim=1) + + H, W = list(qkv.size())[-2:] + if H * W > self.dim: + out = self.relu_linear_att(qkv) + else: + out = self.relu_quadratic_att(qkv) + out = self.proj(out) + + return out + + +class EfficientViTBlock(nn.Module): + def __init__( + self, + in_channels: int, + heads_ratio: float = 1.0, + dim=32, + expand_ratio: float = 4, + scales=(5,), + norm="bn2d", + act_func="hswish", + ): + super(EfficientViTBlock, self).__init__() + self.context_module = ResidualBlock( + LiteMLA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=dim, + norm=(None, norm), + scales=scales, + ), + IdentityLayer(), + ) + local_module = MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm=(None, None, norm), + 
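+            # (editor's note) normalization is applied only after the final pointwise
+            # projection; the first two convs use bias instead, mirroring the
+            # fewer_norm pattern used in the backbone's build_local_block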
act_func=(act_func, act_func, None), + ) + self.local_module = ResidualBlock(local_module, IdentityLayer()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.context_module(x) + x = self.local_module(x) + return x + + +################################################################################# +# Functional Blocks # +################################################################################# + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: nn.Module or None, + shortcut: nn.Module or None, + post_act=None, + pre_norm: nn.Module or None = None, + ): + super(ResidualBlock, self).__init__() + + self.pre_norm = pre_norm + self.main = main + self.shortcut = shortcut + self.post_act = build_act(post_act) + + def forward_main(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm is None: + return self.main(x) + else: + return self.main(self.pre_norm(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.main is None: + res = x + elif self.shortcut is None: + res = self.forward_main(x) + else: + res = self.forward_main(x) + self.shortcut(x) + if self.post_act: + res = self.post_act(res) + return res + + +class DAGBlock(nn.Module): + def __init__( + self, + inputs: dict[str, nn.Module], + merge: str, + post_input: nn.Module or None, + middle: nn.Module, + outputs: dict[str, nn.Module], + ): + super(DAGBlock, self).__init__() + + self.input_keys = list(inputs.keys()) + self.input_ops = nn.ModuleList(list(inputs.values())) + self.merge = merge + self.post_input = post_input + + self.middle = middle + + self.output_keys = list(outputs.keys()) + self.output_ops = nn.ModuleList(list(outputs.values())) + + def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + feat = [op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)] + if self.merge == "add": + feat = list_sum(feat) + elif self.merge == "cat": + feat = torch.concat(feat, dim=1) + else: + raise NotImplementedError + if self.post_input is not None: + feat = self.post_input(feat) + feat = self.middle(feat) + for key, op in zip(self.output_keys, self.output_ops): + feature_dict[key] = op(feat) + return feature_dict + + +class OpSequential(nn.Module): + def __init__(self, op_list: list[nn.Module or None]): + super(OpSequential, self).__init__() + valid_op_list = [] + for op in op_list: + if op is not None: + valid_op_list.append(op) + self.op_list = nn.ModuleList(valid_op_list) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for op in self.op_list: + x = op(x) + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/__init__.py new file mode 100644 index 0000000..0aab6b0 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/__init__.py @@ -0,0 +1,7 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from .list import * +from .network import * +from .random import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/list.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/list.py new file mode 100644 index 
0000000..496a032 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/list.py @@ -0,0 +1,53 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +__all__ = [ + "list_sum", + "list_mean", + "weighted_list_sum", + "list_join", + "val2list", + "val2tuple", + "squeeze_list", +] + + +def list_sum(x: list) -> any: + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def list_mean(x: list) -> any: + return list_sum(x) / len(x) + + +def weighted_list_sum(x: list, weights: list) -> any: + assert len(x) == len(weights) + return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) + + +def list_join(x: list, sep="\t", format_str="%s") -> str: + return sep.join([format_str % val for val in x]) + + +def val2list(x: list or tuple or any, repeat_time=1) -> list: + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: + x = val2list(x) + + # repeat elements if necessary + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def squeeze_list(x: list or None) -> list or any: + if x is not None and len(x) == 1: + return x[0] + else: + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/network.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/network.py new file mode 100644 index 0000000..ee85826 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/network.py @@ -0,0 +1,73 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import os +from inspect import signature + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "is_parallel", + "get_device", + "get_same_padding", + "resize", + "build_kwargs_from_config", + "load_state_dict_from_file", +] + + +def is_parallel(model: nn.Module) -> bool: + return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) + + +def get_device(model: nn.Module) -> torch.device: + return model.parameters().__next__().device + + +def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: + if isinstance(kernel_size, tuple): + return tuple([get_same_padding(ks) for ks in kernel_size]) + else: + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def resize( + x: torch.Tensor, + size: any or None = None, + scale_factor: list[float] or None = None, + mode: str = "bicubic", + align_corners: bool or None = False, +) -> torch.Tensor: + if mode in {"bilinear", "bicubic"}: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) + elif mode in {"nearest", "area"}: + return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) + else: + raise NotImplementedError(f"resize(mode={mode}) not implemented.") + + +def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]: + 
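val2list/val2tuple let callers pass either a scalar or a per-position sequence for options such as use_bias and norm, repeating the last element up to the requested length, and get_same_padding returns the padding that keeps spatial size for odd kernels. A short usage sketch (import path as added by this patch, run from the repository root):

from src.models.efficientvit.models.utils import get_same_padding, val2tuple

print(val2tuple(False, 2))           # (False, False): scalar broadcast to both positions
print(val2tuple(["bn2d", None], 3))  # ('bn2d', None, None): last element repeated
print(get_same_padding(5))           # 2, so a stride-1 conv with kernel 5 keeps H and W
print(get_same_padding((7, 3)))      # (3, 1)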
valid_keys = list(signature(target_func).parameters) + kwargs = {} + for key in config: + if key in valid_keys: + kwargs[key] = config[key] + return kwargs + + +def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, torch.Tensor]: + file = os.path.realpath(os.path.expanduser(file)) + checkpoint = torch.load(file, map_location="cpu") + if only_state_dict and "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + return checkpoint diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/random.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/random.py new file mode 100644 index 0000000..ee207f5 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/models/utils/random.py @@ -0,0 +1,65 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +import numpy as np +import torch + +__all__ = [ + "torch_randint", + "torch_random", + "torch_shuffle", + "torch_uniform", + "torch_random_choices", +] + + +def torch_randint(low: int, high: int, generator: torch.Generator or None = None) -> int: + """uniform: [low, high)""" + if low == high: + return low + else: + assert low < high + return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) + + +def torch_random(generator: torch.Generator or None = None) -> float: + """uniform distribution on the interval [0, 1)""" + return float(torch.rand(1, generator=generator)) + + +def torch_shuffle(src_list: list[any], generator: torch.Generator or None = None) -> list[any]: + rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() + return [src_list[i] for i in rand_indexes] + + +def torch_uniform(low: float, high: float, generator: torch.Generator or None = None) -> float: + """uniform distribution on the interval [low, high)""" + rand_val = torch_random(generator) + return (high - low) * rand_val + low + + +def torch_random_choices( + src_list: list[any], + generator: torch.Generator or None = None, + k=1, + weight_list: list[float] or None = None, +) -> any or list: + if weight_list is None: + rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,)) + out_list = [src_list[i] for i in rand_idx] + else: + assert len(weight_list) == len(src_list) + accumulate_weight_list = np.cumsum(weight_list) + + out_list = [] + for _ in range(k): + val = torch_uniform(0, accumulate_weight_list[-1], generator) + active_id = 0 + for i, weight_val in enumerate(accumulate_weight_list): + active_id = i + if weight_val > val: + break + out_list.append(src_list[active_id]) + + return out_list[0] if k == 1 else out_list diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/sam_model_zoo.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/sam_model_zoo.py new file mode 100644 index 0000000..f0d1043 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/sam_model_zoo.py @@ -0,0 +1,51 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from src.models.efficientvit.models.efficientvit import ( + 
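torch_random_choices implements weighted sampling by inverse-CDF lookup: draw u uniformly in [0, sum(weights)) and return the first item whose cumulative weight exceeds u. A quick sanity sketch with hypothetical weights (the frequencies are approximate by construction):

import collections

from src.models.efficientvit.models.utils import torch_random_choices

items = ["a", "b", "c"]
weights = [0.1, 0.3, 0.6]
draws = torch_random_choices(items, generator=None, k=10_000, weight_list=weights)

freq = collections.Counter(draws)
print({key: round(freq[key] / 10_000, 2) for key in items})  # roughly {'a': 0.1, 'b': 0.3, 'c': 0.6}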
EfficientViTSam, + efficientvit_sam_l0, + efficientvit_sam_l1, + efficientvit_sam_l2, + efficientvit_sam_xl0, + efficientvit_sam_xl1, +) +from src.models.efficientvit.models.nn.norm import set_norm_eps +from src.models.efficientvit.models.utils import load_state_dict_from_file + +__all__ = ["create_sam_model"] + + +REGISTERED_SAM_MODEL: dict[str, str] = { + "l0": "assets/checkpoints/sam/l0.pt", + "l1": "assets/checkpoints/sam/l1.pt", + "l2": "assets/checkpoints/sam/l2.pt", + "xl0": "assets/checkpoints/sam/xl0.pt", + "xl1": "assets/checkpoints/sam/xl1.pt", +} + + +def create_sam_model(name: str, pretrained=True, weight_url: str or None = None, **kwargs) -> EfficientViTSam: + model_dict = { + "l0": efficientvit_sam_l0, + "l1": efficientvit_sam_l1, + "l2": efficientvit_sam_l2, + "xl0": efficientvit_sam_xl0, + "xl1": efficientvit_sam_xl1, + } + + model_id = name.split("-")[0] + if model_id not in model_dict: + raise ValueError(f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}") + else: + model = model_dict[model_id](**kwargs) + set_norm_eps(model, 1e-6) + + if pretrained: + weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None) + if weight_url is None: + raise ValueError(f"Do not find the pretrained weight of {name}.") + else: + weight = load_state_dict_from_file(weight_url) + model.load_state_dict(weight) + return model diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/__init__.py new file mode 100644 index 0000000..4067e8b --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/__init__.py @@ -0,0 +1 @@ +from .sam import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/sam.py new file mode 100644 index 0000000..f92750d --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/sam.py @@ -0,0 +1,169 @@ +import json + +import numpy as np +import torch +import torchvision.transforms as transforms +from pycocotools import mask as mask_utils +from skimage import io +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +from src.models.efficientvit.apps.data_provider import DataProvider +from src.models.efficientvit.samcore.data_provider.utils import ( + Normalize_and_Pad, + RandomHFlip, + ResizeLongestSide, + SAMDistributedSampler, +) + +__all__ = ["SAMDataProvider"] + + +class OnlineDataset(Dataset): + def __init__(self, root, train=True, num_masks=64, transform=None): + self.root = root + self.train = train + self.num_masks = num_masks + self.transform = transform + + self.data = open(f"{self.root}/sa_images_ids.txt", "r").read().splitlines() + + if self.train: + self.data = self.data[: int(len(self.data) * 0.99)] + else: + self.data = self.data[int(len(self.data) * 0.99) :] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + """ + Note: We provide the simplest data organization here. You can modify the code according to your data organization. 
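create_sam_model resolves the builder from the prefix before the first "-", so a name such as "l0-finetuned" still maps to efficientvit_sam_l0, but REGISTERED_SAM_MODEL is keyed by the full name, so suffixed names need an explicit weight_url. A usage sketch, assuming the repository's dependencies are installed; the checkpoint path shown is hypothetical:

from src.models.efficientvit.sam_model_zoo import create_sam_model

# randomly initialized backbone, no checkpoint required
model = create_sam_model("l0", pretrained=False)

# or load weights from an explicit local path
# model = create_sam_model("l0", pretrained=True, weight_url="/path/to/l0.pt")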
+ """ + + index = int(self.data[idx]) + + image_path = f"{self.root}/images/sa_{index}.jpg" + image = io.imread(image_path) + + json_path = f"{self.root}/masks/sa_{index}.json" + annotations = json.load(open(json_path))["annotations"] + + if self.train: + if len(annotations) > self.num_masks: + r = np.random.choice(len(annotations), size=self.num_masks, replace=False) + else: + repeat, residue = self.num_masks // len(annotations), self.num_masks % len(annotations) + r = np.random.choice(len(annotations), size=residue, replace=False) + r = np.concatenate([np.arange(len(annotations)) for _ in range(repeat)] + [r], axis=0) + + else: + if len(annotations) > self.num_masks: + r = np.arange(self.num_masks) + else: + repeat, residue = self.num_masks // len(annotations), self.num_masks % len(annotations) + r = np.arange(residue) + r = np.concatenate([np.arange(len(annotations)) for _ in range(repeat)] + [r], axis=0) + + masks = np.stack([mask_utils.decode(annotations[i]["segmentation"]) for i in r]) + points = np.stack([annotations[i]["point_coords"][0] for i in r]) + bboxs = np.stack([annotations[i]["bbox"] for i in r]) + + image = torch.tensor(image, dtype=torch.float32) + image = torch.transpose(torch.transpose(image, 1, 2), 0, 1) + masks = torch.tensor(masks, dtype=torch.float32) + points = torch.tensor(points, dtype=torch.float32) + bboxs = torch.tensor(bboxs, dtype=torch.float32) + + sample = { + "image": image, + "masks": masks, + "points": points, + "bboxs": bboxs, + "shape": torch.tensor(image.shape[-2:]), + } + + if self.transform: + sample = self.transform(sample) + + return sample + + +class SAMDataProvider(DataProvider): + name = "sam" + + def __init__( + self, + root: str, + sub_epochs_per_epoch: int, + num_masks: int, + train_batch_size: int, + test_batch_size: int, + valid_size: int or float or None = None, + n_worker=8, + image_size: int = 1024, + num_replicas: int or None = None, + rank: int or None = None, + train_ratio: float or None = None, + drop_last: bool = False, + ): + self.root = root + self.num_masks = num_masks + self.sub_epochs_per_epoch = sub_epochs_per_epoch + + super().__init__( + train_batch_size, + test_batch_size, + valid_size, + n_worker, + image_size, + num_replicas, + rank, + train_ratio, + drop_last, + ) + + def build_train_transform(self): + train_transforms = [ + RandomHFlip(), + ResizeLongestSide(target_length=self.image_size[0]), + Normalize_and_Pad(target_length=self.image_size[0]), + ] + + return transforms.Compose(train_transforms) + + def build_valid_transform(self): + valid_transforms = [ + ResizeLongestSide(target_length=self.image_size[0]), + Normalize_and_Pad(target_length=self.image_size[0]), + ] + + return transforms.Compose(valid_transforms) + + def build_datasets(self) -> tuple[any, any, any]: + train_transform = self.build_train_transform() + valid_transform = self.build_valid_transform() + + train_dataset = OnlineDataset(root=self.root, train=True, num_masks=self.num_masks, transform=train_transform) + + val_dataset = OnlineDataset(root=self.root, train=False, num_masks=2, transform=valid_transform) + + test_dataset = None + + return train_dataset, val_dataset, test_dataset + + def build_dataloader(self, dataset: any or None, batch_size: int, n_worker: int, drop_last: bool, train: bool): + if dataset is None: + return None + if train: + sampler = SAMDistributedSampler(dataset, sub_epochs_per_epoch=self.sub_epochs_per_epoch) + dataloader = DataLoader(dataset, batch_size, sampler=sampler, drop_last=True, num_workers=n_worker) + return 
dataloader + else: + sampler = DistributedSampler(dataset, shuffle=False) + dataloader = DataLoader(dataset, batch_size, sampler=sampler, drop_last=False, num_workers=n_worker) + return dataloader + + def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None: + if isinstance(self.train.sampler, SAMDistributedSampler): + self.train.sampler.set_epoch_and_sub_epoch(epoch, sub_epoch) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/utils.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/utils.py new file mode 100644 index 0000000..5505f1b --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/data_provider/utils.py @@ -0,0 +1,194 @@ +import random +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler + + +class SAMDistributedSampler(DistributedSampler): + """ + Modified from https://github.com/pytorch/pytorch/blob/97261be0a8f09bed9ab95d0cee82e75eebd249c3/torch/utils/data/distributed.py. + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + sub_epochs_per_epoch: int = 1, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + + self.sub_epoch = 0 + self.sub_epochs_per_epoch = sub_epochs_per_epoch + self.set_sub_num_samples() + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + indices = indices[(self.sub_epoch % self.sub_epochs_per_epoch) :: self.sub_epochs_per_epoch] + + return iter(indices) + + def __len__(self) -> int: + return self.sub_num_samples + + def set_sub_num_samples(self) -> int: + self.sub_num_samples = self.num_samples // self.sub_epochs_per_epoch + if self.sub_num_samples % self.sub_epochs_per_epoch > self.sub_epoch: + self.sub_num_samples += 1 + + def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + sub_epoch (int): Sub epoch number. 
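SAMDistributedSampler splits each rank's shard of an epoch into sub_epochs_per_epoch interleaved slices, and set_epoch_and_sub_epoch selects which slice the next pass consumes, so one data epoch is spread over several shorter train/validate cycles. A standalone sketch of the slicing (plain lists, no distributed setup):

indices = list(range(12))                  # one rank's shard for the epoch
sub_epochs_per_epoch = 3

slices = [indices[s::sub_epochs_per_epoch] for s in range(sub_epochs_per_epoch)]
print(slices)                              # [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]
assert sorted(sum(slices, [])) == indices  # the sub-epochs cover the shard exactly once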
+ """ + self.epoch = epoch + self.sub_epoch = sub_epoch + self.set_sub_num_samples() + + +class RandomHFlip(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, sample): + image, masks, points, bboxs, shape = ( + sample["image"], + sample["masks"], + sample["points"], + sample["bboxs"], + sample["shape"], + ) + + if random.random() >= self.prob: + image = torch.flip(image, dims=[2]) + masks = torch.flip(masks, dims=[2]) + points = deepcopy(points).to(torch.float) + bboxs = deepcopy(bboxs).to(torch.float) + points[:, 0] = shape[-1] - points[:, 0] + bboxs[:, 0] = shape[-1] - bboxs[:, 2] - bboxs[:, 0] + + return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape} + + +class ResizeLongestSide(object): + """ + Modified from https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/transforms.py. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: + target_size = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) + return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) + + def apply_boxes(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_coords(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. 
+ """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + def __call__(self, sample): + image, masks, points, bboxs, shape = ( + sample["image"], + sample["masks"], + sample["points"], + sample["bboxs"], + sample["shape"], + ) + + image = self.apply_image(image.unsqueeze(0), shape).squeeze(0) + masks = self.apply_image(masks.unsqueeze(1), shape).squeeze(1) + points = self.apply_coords(points, shape) + bboxs = self.apply_boxes(bboxs, shape) + + return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape} + + +class Normalize_and_Pad(object): + def __init__(self, target_length: int) -> None: + self.target_length = target_length + self.transform = transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) + + def __call__(self, sample): + image, masks, points, bboxs, shape = ( + sample["image"], + sample["masks"], + sample["points"], + sample["bboxs"], + sample["shape"], + ) + + h, w = image.shape[-2:] + image = self.transform(image) + + padh = self.target_length - h + padw = self.target_length - w + + image = F.pad(image.unsqueeze(0), (0, padw, 0, padh), value=0).squeeze(0) + masks = F.pad(masks.unsqueeze(1), (0, padw, 0, padh), value=0).squeeze(1) + + return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape} diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/__init__.py new file mode 100644 index 0000000..8fa54be --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/__init__.py @@ -0,0 +1,2 @@ +from .sam_run_config import * +from .sam_trainer import * diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_run_config.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_run_config.py new file mode 100644 index 0000000..107fb95 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_run_config.py @@ -0,0 +1,9 @@ +from src.models.efficientvit.apps.trainer.run_config import RunConfig + +__all__ = ["SAMRunConfig"] + + +class SAMRunConfig(RunConfig): + @property + def none_allowed(self): + return ["reset_bn", "reset_bn_size", "reset_bn_batch_size"] + super().none_allowed diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_trainer.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_trainer.py new file mode 100644 index 0000000..961ae1d --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/sam_trainer.py @@ -0,0 +1,302 @@ +import random +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import wandb +from PIL import Image +from tqdm import tqdm + +from src.models.efficientvit.apps.trainer import Trainer +from src.models.efficientvit.apps.utils import AverageMeter, get_dist_local_rank, get_dist_size, is_master, sync_tensor +from src.models.efficientvit.models.utils import list_join +from 
src.models.efficientvit.samcore.data_provider import SAMDataProvider +from src.models.efficientvit.samcore.trainer import SAMRunConfig +from src.models.efficientvit.samcore.trainer.utils import ( + compute_boundary_iou, + compute_iou, + loss_masks, + mask_iou_batch, + masks_sample_points, +) + +__all__ = ["SAMTrainer"] + + +class SAMTrainer(Trainer): + def __init__( + self, + path: str, + model: nn.Module, + data_provider: SAMDataProvider, + ) -> None: + super().__init__( + path=path, + model=model, + data_provider=data_provider, + ) + + if is_master(): + self.wandb_log = wandb.init(project="efficientvit-sam") + + def _validate(self, model, data_loader, epoch: int, sub_epoch: int) -> dict[str, any]: + val_loss = AverageMeter() + val_iou = AverageMeter() + val_iou_boundary = AverageMeter() + + with torch.no_grad(): + with tqdm( + total=len(data_loader), + desc=f"Validate Epoch #{epoch + 1}, Sub Epoch #{sub_epoch+1}", + disable=not is_master(), + file=sys.stdout, + ) as t: + for i, data in enumerate(data_loader): + image = data["image"].cuda() + masks = data["masks"].cuda() + bboxs = data["bboxs"].cuda() * 2 if image.shape[2] == 512 else data["bboxs"].cuda() + points = data["points"].cuda() * 2 if image.shape[2] == 512 else data["points"].cuda() + + bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2] + bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3] + + batched_input = [] + for b_i in range(len(image)): + dict_input = dict() + + dict_input["image"] = image[b_i] + dict_input["boxes"] = bboxs[b_i] + + batched_input.append(dict_input) + + output, iou_predictions = model(batched_input, True) + + B, M, N, H, W = output.shape + output = torch.stack( + [ + output[k][torch.arange(M), iou_predictions[k].argmax(-1).squeeze()] + for k in range(len(output)) + ], + dim=0, + ) + output = ( + F.interpolate(output, size=(image.shape[2], image.shape[3]), mode="bilinear") + .reshape(-1, image.shape[2], image.shape[3]) + .unsqueeze(1) + ) + masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1) + + loss_mask, loss_dice = loss_masks(output, masks, len(output)) + loss = loss_mask * 20 + loss_dice + + iou = compute_iou(output, masks * 255) + boundary_iou = compute_boundary_iou(output, masks * 255) + + loss = sync_tensor(loss) + iou = sync_tensor(iou) + boundary_iou = sync_tensor(boundary_iou) + + val_loss.update(loss, image.shape[0] * get_dist_size()) + val_iou.update(iou, image.shape[0] * get_dist_size()) + val_iou_boundary.update(boundary_iou, image.shape[0] * get_dist_size()) + + t.set_postfix( + { + "loss": val_loss.avg, + "iou": val_iou.avg, + "boundary_iou": val_iou_boundary.avg, + "bs": image.shape[0] * get_dist_size(), + } + ) + t.update() + + if is_master(): + self.wandb_log.log( + {"val_loss": val_loss.avg, "val_iou": val_iou.avg, "val_boundary_iou": val_iou_boundary.avg} + ) + + return { + "val_loss": val_loss.avg, + "val_iou": val_iou.avg, + "val_boundary_iou": val_iou_boundary.avg, + } + + def validate(self, model=None, data_loader=None, epoch=0, sub_epoch=0) -> dict[str, any]: + model = model or self.eval_network + if data_loader is None: + data_loader = self.data_provider.valid + + model.eval() + return self._validate(model, data_loader, epoch, sub_epoch) + + def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + image = feed_dict["image"].cuda() + masks = feed_dict["masks"].cuda() + bboxs = feed_dict["bboxs"].cuda() * 2 if image.shape[2] == 512 else feed_dict["bboxs"].cuda() + points = feed_dict["points"].cuda() * 2 if image.shape[2] == 512 else feed_dict["points"].cuda() + 
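During validation above, the decoder returns N candidate masks per prompt plus predicted IoUs, and the candidate with the highest predicted IoU is kept. A standalone sketch of that selection, assuming iou_predictions has shape (B, M, N) after the squeeze:

import torch

B, M, N, H, W = 2, 5, 3, 4, 4             # batch, prompts per image, candidates, mask size
output = torch.randn(B, M, N, H, W)
iou_predictions = torch.rand(B, M, N)

best = torch.stack(
    [output[b][torch.arange(M), iou_predictions[b].argmax(-1)] for b in range(B)],
    dim=0,
)
print(best.shape)                          # torch.Size([2, 5, 4, 4]): one mask per prompt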
+ bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2] + bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3] + + return { + "image": image, + "masks": masks, + "points": points, + "bboxs": bboxs, + } + + def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]: + image = feed_dict["image"] + masks = feed_dict["masks"] + bboxs = feed_dict["bboxs"] + points = feed_dict["points"] + + batched_input = [] + for b_i in range(len(image)): + dict_input = dict() + dict_input["image"] = image[b_i] + + if random.random() >= 0.5: + dict_input["boxes"] = bboxs[b_i] + else: + try: + n_p = int(random.random() * 10 + 1) + dict_input["point_coords"] = masks_sample_points(masks[b_i], k=n_p) + if image.shape[2] == 512: + dict_input["point_coords"] = dict_input["point_coords"] * 2 + dict_input["point_labels"] = torch.ones((points[b_i].shape[0], n_p), device=image.device) + except: + dict_input["boxes"] = bboxs[b_i] + + batched_input.append(dict_input) + + with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp): + if random.random() >= 0.5: + output, iou_predictions = self.model(batched_input, multimask_output=True) + else: + output, iou_predictions = self.model(batched_input, multimask_output=False) + + masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1) + + loss_list = [] + for i in range(output.shape[2]): + output_i = ( + F.interpolate(output[:, :, i], size=(image.shape[2], image.shape[3]), mode="bilinear") + .reshape(-1, image.shape[2], image.shape[3]) + .unsqueeze(1) + ) + loss_mask_i, loss_dice_i = loss_masks(output_i, masks, len(output_i), mode="none") + loss_i = loss_mask_i * 20 + loss_dice_i + loss_list.append(loss_i) + loss = torch.stack(loss_list, -1) + + min_indices = torch.argmin(loss, dim=1) + mask = torch.zeros_like(loss, device=loss.device) + mask.scatter_(1, min_indices.unsqueeze(1), 1) + + loss = (loss * mask).mean() * loss.shape[-1] + + self.scaler.scale(loss).backward() + + return {"loss": loss, "output": output} + + def _train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, any]: + train_loss = AverageMeter() + + with tqdm( + total=len(self.data_provider.train), + desc=f"Train Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}", + disable=not is_master(), + file=sys.stdout, + ) as t: + for i, data in enumerate(self.data_provider.train): + feed_dict = data + + # preprocessing + feed_dict = self.before_step(feed_dict) + # clear gradient + self.optimizer.zero_grad() + # forward & backward + output_dict = self.run_step(feed_dict) + # update: optimizer, lr_scheduler + self.after_step() + + loss = output_dict["loss"] + loss = sync_tensor(loss) + train_loss.update(loss, data["image"].shape[0] * get_dist_size()) + + if is_master(): + self.wandb_log.log( + { + "train_loss": train_loss.avg, + "epoch": epoch, + "sub_epoch": sub_epoch, + "learning_rate": sorted(set([group["lr"] for group in self.optimizer.param_groups]))[0], + } + ) + + t.set_postfix( + { + "loss": train_loss.avg, + "bs": data["image"].shape[0] * get_dist_size(), + "res": data["image"].shape[2], + "lr": list_join( + sorted(set([group["lr"] for group in self.optimizer.param_groups])), + "#", + "%.1E", + ), + "progress": self.run_config.progress, + } + ) + t.update() + + return { + "train_loss": train_loss.avg, + } + + def train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, any]: + self.model.train() + + self.data_provider.set_epoch_and_sub_epoch(epoch, sub_epoch) + + train_info_dict = self._train_one_sub_epoch(epoch, sub_epoch) + + return train_info_dict + + def 
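run_step above supervises only the best of the decoder's candidate masks: it computes a per-candidate loss, builds a one-hot mask at the argmin, and rescales so the result equals the mean of the per-prompt minima. A tiny numeric check of that reduction:

import torch

loss = torch.tensor([[0.9, 0.2, 0.5],
                     [0.3, 0.8, 0.4]])            # (prompts, candidates)

min_indices = torch.argmin(loss, dim=1)           # tensor([1, 0])
mask = torch.zeros_like(loss)
mask.scatter_(1, min_indices.unsqueeze(1), 1)

selected = (loss * mask).mean() * loss.shape[-1]  # mean of the per-prompt minima
print(selected)                                   # tensor(0.2500) == (0.2 + 0.3) / 2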
train(self) -> None: + for sub_epoch in range(self.start_epoch, self.run_config.n_epochs): + epoch = sub_epoch // self.data_provider.sub_epochs_per_epoch + + train_info_dict = self.train_one_sub_epoch(epoch, sub_epoch) + + val_info_dict = self.validate(epoch=epoch, sub_epoch=sub_epoch) + + val_iou = val_info_dict["val_iou"] + is_best = val_iou > self.best_val + self.best_val = max(val_iou, self.best_val) + + self.save_model( + only_state_dict=False, + epoch=sub_epoch, + model_name=f"checkpoint_{epoch}_{sub_epoch}.pt", + ) + + def prep_for_training(self, run_config: SAMRunConfig, amp="fp32") -> None: + self.run_config = run_config + self.model = nn.parallel.DistributedDataParallel( + self.model.cuda(), + device_ids=[get_dist_local_rank()], + find_unused_parameters=True, + ) + + self.run_config.global_step = 0 + self.run_config.batch_per_epoch = len(self.data_provider.train) + assert self.run_config.batch_per_epoch > 0, "Training set is empty" + + # build optimizer + self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model) + + # amp + self.amp = amp + self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/utils.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/utils.py new file mode 100644 index 0000000..cc09d72 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/samcore/trainer/utils.py @@ -0,0 +1,318 @@ +import io +from typing import List + +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + +""" + Some functions in this file are modified from https://github.com/SysCV/sam-hq/blob/main/train/utils/misc.py. +""" + + +def point_sample(input, point_coords, **kwargs): + """ + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +def cat(tensors: List[torch.Tensor], dim: int = 0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list. + """ + + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + + +def get_uncertain_point_coords_with_randomness( + coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio +): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. 
+ Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) + point_logits = point_sample(coarse_logits, point_coords, align_corners=False) + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + if num_random_points > 0: + point_coords = cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), + ], + dim=1, + ) + return point_coords + + +def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if mode == "none": + return loss + else: + return loss.sum() / num_masks + + +dice_loss_jit = torch.jit.script(dice_loss) # type: torch.jit.ScriptModule + + +def sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float, mode: str): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + if mode == "none": + return loss.mean(1) + else: + return loss.mean(1).sum() / num_masks + + +sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss) # type: torch.jit.ScriptModule + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. 
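Both loss terms above are evaluated only at roughly 112*112 uncertainty-sampled points rather than on full masks (point_sample maps [0, 1] coordinates to grid_sample's [-1, 1] range via 2*coords - 1), and the dice term is 1 - (2*sum(p*t) + 1) / (sum(p) + sum(t) + 1) with p = sigmoid(logits). A tiny numeric check of the dice formula:

import torch

logits = torch.tensor([[10.0, 10.0, -10.0, -10.0]])   # p is roughly [1, 1, 0, 0]
targets = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

p = logits.sigmoid().flatten(1)
numerator = 2 * (p * targets).sum(-1)                  # ~2
denominator = p.sum(-1) + targets.sum(-1)              # ~3
print(1 - (numerator + 1) / (denominator + 1))         # ~0.25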
The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + + +def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0, mode="mean"): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + + with torch.no_grad(): + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks, + lambda logits: calculate_uncertainty(logits), + 112 * 112, + oversample_ratio, + 0.75, + ) + # get gt labels + point_labels = point_sample( + target_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + point_logits = point_sample( + src_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks, mode) + loss_dice = dice_loss_jit(point_logits, point_labels, num_masks, mode) + + del src_masks + del target_masks + return loss_mask, loss_dice + + +def mask_iou(pred_label, label): + """ + calculate mask iou for pred_label and gt_label. + """ + + pred_label = (pred_label > 0)[0].int() + label = (label > 128)[0].int() + + intersection = ((label * pred_label) > 0).sum() + union = ((label + pred_label) > 0).sum() + return intersection / (union + 1e-6) + + +def compute_iou(preds, target): + if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]: + postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False) + else: + postprocess_preds = preds + iou = 0 + for i in range(0, len(preds)): + iou = iou + mask_iou(postprocess_preds[i], target[i]) + return iou / len(preds) + + +def mask_to_boundary(mask, dilation_ratio=0.02): + """ + Convert binary mask to boundary mask. + :param mask (numpy array, uint8): binary mask + :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal + :return: boundary mask (numpy array) + """ + + h, w = mask.shape + img_diag = np.sqrt(h**2 + w**2) + dilation = int(round(dilation_ratio * img_diag)) + if dilation < 1: + dilation = 1 + # Pad image so mask truncated by the image border is also considered as boundary. + new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) + kernel = np.ones((3, 3), dtype=np.uint8) + new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) + mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1] + # G_d intersects G in the paper. + return mask - mask_erode + + +def boundary_iou(gt, dt, dilation_ratio=0.02): + """ + Compute boundary iou between two binary masks. 
+ :param gt (numpy array, uint8): binary mask + :param dt (numpy array, uint8): binary mask + :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal + :return: boundary iou (float) + """ + + device = gt.device + dt = (dt > 0)[0].cpu().byte().numpy() + gt = (gt > 128)[0].cpu().byte().numpy() + + gt_boundary = mask_to_boundary(gt, dilation_ratio) + dt_boundary = mask_to_boundary(dt, dilation_ratio) + intersection = ((gt_boundary * dt_boundary) > 0).sum() + union = ((gt_boundary + dt_boundary) > 0).sum() + boundary_iou = intersection / (union + 1e-6) + return torch.tensor(boundary_iou).float().to(device) + + +def compute_boundary_iou(preds, target): + if preds.shape[2] != target.shape[2] or preds.shape[3] != target.shape[3]: + postprocess_preds = F.interpolate(preds, size=target.size()[2:], mode="bilinear", align_corners=False) + else: + postprocess_preds = preds + iou = 0 + for i in range(0, len(preds)): + iou = iou + boundary_iou(target[i], postprocess_preds[i]) + return iou / len(preds) + + +def masks_sample_points(masks, k=10): + """Sample points on mask""" + + if masks.numel() == 0: + return torch.zeros((0, 2), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + y = y.to(masks) + x = x.to(masks) + + # k = 10 + samples = [] + for b_i in range(len(masks)): + select_mask = masks[b_i].bool() + x_idx = torch.masked_select(x, select_mask) + y_idx = torch.masked_select(y, select_mask) + + perm = torch.randperm(x_idx.size(0)) + idx = perm[:k] + samples_x = x_idx[idx] + samples_y = y_idx[idx] + samples_xy = torch.cat((samples_x[:, None], samples_y[:, None]), dim=1) + samples.append(samples_xy) + + samples = torch.stack(samples) + + return samples + + +def mask_iou_batch(pred_label, label): + """ + calculate mask iou for pred_label and gt_label. 
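mask_to_boundary above defines the boundary as the mask minus its erosion, with the erosion depth proportional to the image diagonal; boundary_iou then compares those rings instead of the full masks. A standalone sketch with a single erosion step:

import cv2
import numpy as np

mask = np.zeros((9, 9), dtype=np.uint8)
mask[2:7, 2:7] = 1                                     # a 5x5 square

kernel = np.ones((3, 3), dtype=np.uint8)
padded = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
eroded = cv2.erode(padded, kernel, iterations=1)[1:-1, 1:-1]
boundary = mask - eroded

print(boundary.sum())                                  # 16: the one-pixel ring of the square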
+ """ + + pred_label = (pred_label > 0).int() + label = (label > 128).int() + + intersection = ((label * pred_label) > 0).sum(dim=(-1, -2)) + union = ((label + pred_label) > 0).sum(dim=(-1, -2)) + return intersection / (union + 1e-6) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/seg_model_zoo.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/seg_model_zoo.py new file mode 100644 index 0000000..fefa1ff --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/efficientvit/seg_model_zoo.py @@ -0,0 +1,70 @@ +# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction +# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han +# International Conference on Computer Vision (ICCV), 2023 + +from src.models.efficientvit.models.efficientvit import ( + EfficientViTSeg, + efficientvit_seg_b0, + efficientvit_seg_b1, + efficientvit_seg_b2, + efficientvit_seg_b3, + efficientvit_seg_l1, + efficientvit_seg_l2, +) +from src.models.efficientvit.models.nn.norm import set_norm_eps +from src.models.efficientvit.models.utils import load_state_dict_from_file + +__all__ = ["create_seg_model"] + + +REGISTERED_SEG_MODEL: dict[str, dict[str, str]] = { + "cityscapes": { + "b0": "assets/checkpoints/seg/cityscapes/b0.pt", + "b1": "assets/checkpoints/seg/cityscapes/b1.pt", + "b2": "assets/checkpoints/seg/cityscapes/b2.pt", + "b3": "assets/checkpoints/seg/cityscapes/b3.pt", + ################################################ + "l1": "assets/checkpoints/seg/cityscapes/l1.pt", + "l2": "assets/checkpoints/seg/cityscapes/l2.pt", + }, + "ade20k": { + "b1": "assets/checkpoints/seg/ade20k/b1.pt", + "b2": "assets/checkpoints/seg/ade20k/b2.pt", + "b3": "assets/checkpoints/seg/ade20k/b3.pt", + ################################################ + "l1": "assets/checkpoints/seg/ade20k/l1.pt", + "l2": "assets/checkpoints/seg/ade20k/l2.pt", + }, +} + + +def create_seg_model( + name: str, dataset: str, pretrained=True, weight_url: str or None = None, **kwargs +) -> EfficientViTSeg: + model_dict = { + "b0": efficientvit_seg_b0, + "b1": efficientvit_seg_b1, + "b2": efficientvit_seg_b2, + "b3": efficientvit_seg_b3, + ######################### + "l1": efficientvit_seg_l1, + "l2": efficientvit_seg_l2, + } + + model_id = name.split("-")[0] + if model_id not in model_dict: + raise ValueError(f"Do not find {name} in the model zoo. 
List of models: {list(model_dict.keys())}") + else: + model = model_dict[model_id](dataset=dataset, **kwargs) + + if model_id in ["l1", "l2"]: + set_norm_eps(model, 1e-7) + + if pretrained: + weight_url = weight_url or REGISTERED_SEG_MODEL[dataset].get(name, None) + if weight_url is None: + raise ValueError(f"Do not find the pretrained weight of {name}.") + else: + weight = load_state_dict_from_file(weight_url) + model.load_state_dict(weight) + return model diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/finetune_module.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/finetune_module.py new file mode 100644 index 0000000..8f6c551 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/finetune_module.py @@ -0,0 +1,113 @@ +from typing import Any, Dict + +import torch +from lightning import LightningModule +from torchmetrics import MeanMetric + +from src.losses import SAMLoss +from src.metrics.generalized_dice import GeneralizedDiceMetric +from src.models.base_sam import BaseSAM + + +class FinetuneLitModule(LightningModule): + + def __init__( + self, + model: BaseSAM, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + criterion: torch.nn.Module = None, + scheduler_interval: str = "step", + freeze_image_encoder: bool = False, + freeze_prompt_encoder: bool = False, + ) -> None: + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["model", "criterion"]) + + self.model = model + self.criterion = criterion if criterion is not None else SAMLoss() + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + self.val_acc = GeneralizedDiceMetric() + self.test_acc = GeneralizedDiceMetric() + + if freeze_image_encoder: + self.model.image_encoder.requires_grad_(False) + self.model.image_encoder.eval() + if freeze_prompt_encoder: + self.model.prompt_encoder.requires_grad_(False) + self.model.prompt_encoder.eval() + + def model_step(self, batch, metric=None): + imgs = batch["image"] # (B, 3, H, W) + image_encoder_input_size = max(imgs.shape[-2:]) + target_masks = ( + batch["masks"] + .view(-1, 1, image_encoder_input_size, image_encoder_input_size) + .float() + ) # (B * N, 1, H, W) + boxes = batch["boxes"].view(-1, 4) # (B * N, 4) + + image_embeddings = self.model.image_encoder(imgs) # (B, 256, 64, 64) + masks, iou_preds = self.model.prompt_and_decoder(image_embeddings, boxes) + masks = self.model.postprocess_masks( + masks=masks, + input_size=(image_encoder_input_size, image_encoder_input_size), + original_size=(-1, -1), + return_with_image_encoder_size=True, + ) # (B * N, 1, H, W) + loss = self.criterion( + pred_logits=masks, + pred_iou=iou_preds, + gt_mask=target_masks, + ) + + if metric is not None: + metric.update(preds=(masks > 0), gts=target_masks) + + return loss + + def training_step(self, batch, batch_idx) -> torch.Tensor: + loss = self.model_step(batch) + self.train_loss(loss) + self.log( + "train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True + ) + return loss + + def validation_step(self, batch, batch_idx) -> None: + loss = self.model_step(batch, metric=self.val_acc) + self.val_loss(loss) + metrics = {"val/loss": self.val_loss, "val/acc": self.val_acc} + self.log_dict(metrics, on_step=True, on_epoch=True) + return loss + + def test_step(self, batch, batch_idx) 
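model_step above flattens the N prompts per image so the prompt encoder and mask decoder see one box (and one target mask) per row, while the image encoder still runs once per image. A shape-only sketch of that bookkeeping, with hypothetical sizes:

import torch

B, N, S = 2, 4, 256                       # images, boxes per image, image encoder input size
masks = torch.zeros(B, N, S, S)
boxes = torch.zeros(B, N, 4)

target_masks = masks.view(-1, 1, S, S)    # (B*N, 1, S, S)
flat_boxes = boxes.view(-1, 4)            # (B*N, 4)
print(target_masks.shape, flat_boxes.shape)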
-> None: + loss = self.model_step(batch, metric=self.test_acc) + self.test_loss(loss) + metrics = {"test/loss": self.test_loss, "test/acc": self.test_acc} + self.log_dict(metrics, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self) -> Dict[str, Any]: + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "train/loss", + "interval": self.hparams.scheduler_interval, + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +if __name__ == "__main__": + _ = FinetuneLitModule(None, None, None) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/__init__.py new file mode 100644 index 0000000..3724c67 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/__init__.py @@ -0,0 +1 @@ +from .sam import build_lite_medsam diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/sam.py new file mode 100644 index 0000000..563740c --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/sam.py @@ -0,0 +1,64 @@ +import torch + +from src.models.segment_anything.modeling import ( + MaskDecoder, + PromptEncoder, + TwoWayTransformer, +) + +from .tiny_vit import TinyViT +from ..base_sam import BaseSAM + + +def build_lite_medsam(checkpoint=None): + medsam_lite_image_encoder = TinyViT( + img_size=256, + in_chans=3, + embed_dims=[ + 64, ## (64, 256, 256) + 128, ## (128, 128, 128) + 160, ## (160, 64, 64) + 320, ## (320, 64, 64) + ], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) + + medsam_lite_prompt_encoder = PromptEncoder( + embed_dim=256, + image_embedding_size=(64, 64), + input_image_size=(256, 256), + mask_in_chans=16, + ) + + medsam_lite_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=256, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=256, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + + sam = BaseSAM( + image_encoder=medsam_lite_image_encoder, + mask_decoder=medsam_lite_mask_decoder, + prompt_encoder=medsam_lite_prompt_encoder, + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/tiny_vit.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/tiny_vit.py new file mode 100644 index 0000000..fed7d14 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/lite_medsam/tiny_vit.py @@ -0,0 +1,706 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: 
(https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- +# The TinyViT model is adapted from MobileSAM's variant. +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1 + ): + super().__init__() + self.add_module( + "c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False) + ) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module("bn", bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d( + w.size(1) * self.c.groups, + w.size(0), + w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, + groups=self.c.groups, + ) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f"(drop_prob={self.drop_prob})" + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + # self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.patches_resolution = img_size + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + # self.seq = nn.Sequential( + # Conv2d_BN(in_chans, n // 2, 3, 2, 1), + # activation(), + # Conv2d_BN(n // 2, n, 3, 2, 1), + # ) + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 1, 1, 0), + activation(), + Conv2d_BN(n // 2, n, 1, 1, 0), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN( + self.hidden_chans, + self.hidden_chans, + ks=3, + stride=1, + pad=1, + groups=self.hidden_chans, + ) + self.act2 = activation() + + self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + 
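Conv2d_BN.fuse() above folds an inference-mode BatchNorm into the preceding bias-free convolution: w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps). A standalone numeric check of that identity for the groups=1 case, with hypothetical statistics:

import torch

conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False)
bn = torch.nn.BatchNorm2d(8).eval()
fused = torch.nn.Conv2d(4, 8, 3, padding=1, bias=True)
x = torch.randn(2, 4, 16, 16)

with torch.no_grad():
    # pretend the BN has seen data (hypothetical statistics and affine parameters)
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

    print((bn(conv(x)) - fused(x)).abs().max())   # float rounding only: the fused conv matches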
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c = 2 + if out_dim == 320 or out_dim == 448 or out_dim == 576: + stride_c = 1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__( + self, + dim, + input_resolution, + depth, + activation, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4.0, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets)) + ) + self.register_buffer( + "attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False + ) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, "ab"): + del self.ab + else: + self.register_buffer( + "ab", + self.attention_biases[:, self.attention_bias_idxs], + persistent=False, + ) + 
+ def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.d], dim=3 + ) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] + if self.training + else self.ab + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r"""TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, "window_size must be greater than 0" + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + assert dim % num_heads == 0, "dim must be divisible by num_heads" + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention( + dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution + ) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, + drop=drop, + ) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim + ) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_r = (self.window_size - W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = ( + x.view(B, nH, self.window_size, nW, self.window_size, C) + .transpose(2, 3) + .reshape(B * nH * nW, self.window_size * self.window_size, C) + ) + x = self.attn(x) + # window reverse + x = ( + x.view(B, nH, nW, self.window_size, self.window_size, C) + .transpose(2, 3) + .reshape(B, pH, pW, C) + ) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, 
num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class BasicLayer(nn.Module): + """A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=( + drop_path[i] if isinstance(drop_path, list) else drop_path + ), + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation + ) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class TinyViT(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + # num_classes=1000, + embed_dims=[96, 192, 384, 768], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size = img_size + # self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed( + in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation, + ) + + patches_resolution = self.patch_embed.patches_resolution + 
self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict( + dim=embed_dims[i_layer], + input_resolution=( + patches_resolution[0] + // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] + // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) + self.layers.append(layer) + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, "lr_scale"), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"attention_biases"} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + + B, _, C = x.size() + x = x.view(B, 64, 64, C) + x = x.permute(0, 3, 1, 2) + x = self.neck(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/__init__.py new file mode 100644 index 0000000..351ee17 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/__init__.py @@ -0,0 +1,2 @@ +from .encoder import EncoderOnnxModel +from .decoder import DecoderOnnxModel 
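
The two wrappers re-exported above split the model into a traceable image encoder and a box-prompt decoder (the patch also ships its own export scripts under src/). A minimal export sketch follows; this is editorial and not part of the patch, it uses a random-weight ViT-B SAM purely to stay self-contained, and the 64x64 embedding grid, file names, and 1024-pixel input size are assumptions tied to that choice:

# Editorial sketch only; not part of the patch.
import torch
from src.models.onnx import EncoderOnnxModel, DecoderOnnxModel
from src.models.segment_anything import sam_model_registry

# Random-weight ViT-B keeps the sketch runnable; the patch's own encoder
# (512-pixel input) would be wired in the same way with its matching sizes.
sam = sam_model_registry["vit_b"](checkpoint=None)

encoder = EncoderOnnxModel(image_encoder=sam.image_encoder,
                           image_encoder_input_size=1024).eval()
decoder = DecoderOnnxModel(mask_decoder=sam.mask_decoder,
                           prompt_encoder=sam.prompt_encoder,
                           image_encoder_input_size=1024).eval()

# Encoder takes an (H, W, 3) float image plus its original (H, W) size.
dummy_image = torch.rand(256, 320, 3) * 255.0
dummy_size = torch.tensor([256, 320])
torch.onnx.export(encoder, (dummy_image, dummy_size), "encoder.onnx",
                  input_names=["image", "original_size"],
                  output_names=["image_embeddings"],
                  # Declarative only; the traced resize may still specialise
                  # to the example shape.
                  dynamic_axes={"image": {0: "height", 1: "width"}})

# Decoder takes image embeddings and box prompts; boxes are fed straight to
# pe_layer._pe_encoding, which expects coordinates normalized to [0, 1].
dummy_embeddings = torch.randn(1, 256, 64, 64)  # must match prompt_encoder.image_embedding_size
dummy_boxes = torch.tensor([[0.1, 0.1, 0.6, 0.6]])
torch.onnx.export(decoder, (dummy_embeddings, dummy_boxes), "decoder.onnx",
                  input_names=["image_embeddings", "boxes"],
                  output_names=["masks"])
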
diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/decoder.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/decoder.py new file mode 100644 index 0000000..a7eb14b --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/decoder.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.segment_anything.modeling import MaskDecoder, PromptEncoder + + +class DecoderOnnxModel(nn.Module): + def __init__( + self, + mask_decoder: MaskDecoder, + prompt_encoder: PromptEncoder, + image_encoder_input_size: int = 512, + ): + super().__init__() + self.mask_decoder = mask_decoder + self.prompt_encoder = prompt_encoder + self.image_encoder_input_size = image_encoder_input_size + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + boxes: torch.Tensor, + ): + coords = boxes.reshape(-1, 2, 2) + sparse_embeddings = self.prompt_encoder.pe_layer._pe_encoding(coords) + sparse_embeddings[:, 0, :] += self.prompt_encoder.point_embeddings[2].weight + sparse_embeddings[:, 1, :] += self.prompt_encoder.point_embeddings[3].weight + + dense_embeddings = self.prompt_encoder.no_mask_embed.weight.reshape( + 1, -1, 1, 1 + ).expand( + 1, + -1, + self.prompt_encoder.image_embedding_size[0], + self.prompt_encoder.image_embedding_size[1], + ) + + masks, _ = self.mask_decoder( + image_embeddings=image_embeddings, + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + + masks = F.interpolate( + masks, + (self.image_encoder_input_size, self.image_encoder_input_size), + mode="bilinear", + align_corners=False, + ) + + return masks diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/encoder.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/encoder.py new file mode 100644 index 0000000..73e47a1 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/onnx/encoder.py @@ -0,0 +1,93 @@ +from typing import List, Optional + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torchvision.transforms.v2 import functional as F2 + + +class EncoderOnnxModel(nn.Module): + def __init__( + self, + image_encoder: nn.Module, + preprocess_image: bool = True, + image_encoder_input_size: int = 512, + scale_image: bool = True, + normalize_image: bool = False, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + interpolation: str = "bilinear", + ): + super().__init__() + self.image_encoder = image_encoder + self.preprocess_image = preprocess_image + self.image_encoder_input_size = image_encoder_input_size + self.scale_image = scale_image + self.normalize_image = normalize_image + self.pixel_mean = pixel_mean + self.pixel_std = pixel_std + self.interpolation = interpolation + + @torch.no_grad() + def forward( + self, + image: torch.Tensor, + original_size: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + image: (H, W, 3) + """ + image = image.permute(2, 0, 1).unsqueeze(0) # (1, 3, H, W) + + if not self.preprocess_image: + return self.image_encoder(image) + + # Resize longest side + new_size = self.get_preprocess_shape( + original_size, self.image_encoder_input_size + ) + image = F.interpolate( + image, + (new_size[0], new_size[1]), + 
mode=self.interpolation, + align_corners=False, + ) + + image = image.to(torch.float32) + + # Min max scale + if self.scale_image: + min_val = image.amin((-3, -2, -1), keepdim=True) + max_val = image.amax((-3, -2, -1), keepdim=True) + image = (image - min_val) / torch.clip( + max_val - min_val, min=1e-8, max=None + ) + + # Normalize + if self.normalize_image: + image = F2.normalize(image, self.pixel_mean, self.pixel_std) + + # Pad + image = F.pad( + image, + ( + 0, + self.image_encoder_input_size - image.shape[-1], + 0, + self.image_encoder_input_size - image.shape[-2], + ), + value=0, + ) + + return self.image_encoder(image) + + @staticmethod + def get_preprocess_shape( + original_size: torch.Tensor, + long_side_length: int, + ) -> torch.Tensor: + original_size = original_size.to(torch.float32) + scale = long_side_length / torch.max(original_size) + new_size = scale * original_size + new_size = torch.floor(new_size + 0.5).to(torch.int16) + return new_size diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/README.md b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/README.md new file mode 100644 index 0000000..45405ae --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/README.md @@ -0,0 +1 @@ +Copied from https://github.com/facebookresearch/segment-anything/tree/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/__init__.py new file mode 100644 index 0000000..34383d8 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/automatic_mask_generator.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000..d5a8c96 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
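
The encoder preprocessing above scales the longest image side down to image_encoder_input_size, rounds to the nearest integer, min-max scales the result to [0, 1], and zero-pads on the right and bottom into a square. A small worked sketch of the resize arithmetic, editorial only, assuming a 512-pixel target and a 768x1024 input:

# Editorial sketch only; mirrors get_preprocess_shape above.
import torch

original_size = torch.tensor([768, 1024])             # (H, W)
long_side = 512
scale = long_side / torch.max(original_size)          # 512 / 1024 = 0.5
new_size = torch.floor(scale * original_size + 0.5).to(torch.int16)
print(new_size.tolist())                              # [384, 512]
# After resizing to 384x512, F.pad adds 512 - 384 = 128 rows at the bottom
# (and 0 columns on the right) to reach the square 512x512 encoder input.
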
+ +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 
+ For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/build_sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/build_sam.py new file mode 100644 index 0000000..37cd245 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/build_sam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
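
build_sam wires the encoder, prompt encoder, and mask decoder together, and the registry above maps model names to the builders. For single-prompt use, the box-driven workflow this patch targets, SamPredictor is the intended entry point; a short editorial sketch, again with placeholder weights and image, complementing the previous one:

# Editorial sketch only; not part of the patch.
import numpy as np
from src.models.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint=None)    # random weights; pass a real checkpoint in practice
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)       # stand-in for a real HWC uint8 image
predictor.set_image(image)                            # computes the image embedding once

box = np.array([100, 100, 300, 280])                  # XYXY box in original image coordinates
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores)                            # (1, 480, 640) mask and its predicted IoU
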
+ +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/common.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/image_encoder.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000..66351d9 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] 
= (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/mask_decoder.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000..5592d7f --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0] // image_embeddings.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/prompt_encoder.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000..c3143f4 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). 
+ mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. 
+ """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/sam.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/sam.py new file mode 100644 index 0000000..8074cff --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. 
A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/transformer.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/transformer.py new file mode 100644 index 0000000..28fafea --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. 
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. 
+ """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/predictor.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/predictor.py new file mode 100644 index 0000000..1f026eb --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from src.models.segment_anything.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
+ if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. 
+ """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/amg.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/amg.py new file mode 100644 index 0000000..be06407 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." 
+ self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. 
+ """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. 
+ """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
+ """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/onnx.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/onnx.py new file mode 100644 index 0000000..3196bdf --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. 
+ """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. 
+ score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/transforms.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/transforms.py new file mode 100644 index 0000000..c08ba1e --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/models/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. 
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. 
+ """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/schedulers/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/schedulers/__init__.py new file mode 100644 index 0000000..244b260 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/schedulers/__init__.py @@ -0,0 +1,122 @@ +import math +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyLR(_LRScheduler): + def __init__( + self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1 + ) -> None: + self.decay_iter = decay_iter + self.max_iter = max_iter + self.power = power + super().__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: + return self.base_lrs + else: + factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power + return [factor * lr for lr in self.base_lrs] + + +class WarmupLR(_LRScheduler): + def __init__( + self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1 + ) -> None: + self.warmup_iter = warmup_iter + self.warmup_ratio = warmup_ratio + self.warmup = warmup + super().__init__(optimizer, last_epoch) + + def get_lr(self): + ratio = self.get_lr_ratio() + return [ratio * lr for lr in self.base_lrs] + + def get_lr_ratio(self): + return ( + self.get_warmup_ratio() + if self.last_epoch < self.warmup_iter + else self.get_main_ratio() + ) + + def get_main_ratio(self): + raise NotImplementedError + + def get_warmup_ratio(self): + assert self.warmup in ["linear", "exp"] + alpha = self.last_epoch / self.warmup_iter + + return ( + self.warmup_ratio + (1.0 - self.warmup_ratio) * alpha + if self.warmup == "linear" + else self.warmup_ratio ** (1.0 - alpha) + ) + + +class WarmupPolyLR(WarmupLR): + def __init__( + self, + optimizer, + power, + max_iter, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ) -> None: + self.power = power + self.max_iter = max_iter + super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + real_max_iter = self.max_iter - self.warmup_iter + alpha = real_iter / real_max_iter + + return (1 - alpha) ** self.power + + +class WarmupExpLR(WarmupLR): + def __init__( + self, + optimizer, + gamma, + interval=1, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ) -> None: + self.gamma = gamma + self.interval = interval + super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + return self.gamma ** (real_iter // self.interval) + + +class WarmupCosineLR(WarmupLR): + def __init__( + self, + optimizer, + max_iter, + eta_ratio=0, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ) -> None: + self.eta_ratio = eta_ratio + self.max_iter = max_iter + super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) + + def get_main_ratio(self): + real_max_iter = self.max_iter - self.warmup_iter + + return ( + self.eta_ratio + + (1 - self.eta_ratio) + * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) + / 2 + ) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/train.py 
b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/train.py new file mode 100644 index 0000000..b868c15 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/train.py @@ -0,0 +1,126 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.tuner import Tuner +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from src.utils import ( + RankedLogger, + extras, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def train(cfg: DictConfig): + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("find_lr"): + model.learning_rate = None # Tuner needs a learning_rate attribute + tuner = Tuner(trainer) + lr_finder = tuner.lr_find(model=model, datamodule=datamodule, attr_name=None) + fig = lr_finder.plot(suggest=True) + fig.savefig( + Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) + / "lr_finder.png" + ) + return + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig): + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + # train the model + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/__init__.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/__init__.py new file mode 100644 index 0000000..fd40a10 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/__init__.py @@ -0,0 +1,6 @@ +from src.utils.instantiators import instantiate_callbacks, instantiate_loggers +from src.utils.logging_utils import log_hyperparameters +from src.utils.multiprocessing import parmap +from src.utils.pylogger import RankedLogger +from src.utils.rich_utils import enforce_tags, print_config_tree +from src.utils.utils import extras, task_wrapper diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/instantiators.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/instantiators.py new file mode 100644 index 0000000..82b9278 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/logging_utils.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/logging_utils.py new file mode 100644 index 0000000..360abcd --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/logging_utils.py @@ -0,0 +1,57 @@ +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/multiprocessing.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/multiprocessing.py new file mode 100644 index 0000000..d1a762c --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/multiprocessing.py @@ -0,0 +1,36 @@ +import multiprocessing + +from tqdm import tqdm + + +def fun(f, q_in, q_out): + while True: + i, x = q_in.get() + if i is None: + break + q_out.put((i, f(x))) + + +def parmap(f, X, nprocs=multiprocessing.cpu_count()): + q_in = multiprocessing.Queue(1) + q_out = multiprocessing.Queue() + + proc = [ + multiprocessing.Process(target=fun, args=(f, q_in, q_out)) + for _ in range(nprocs) + ] + for p in proc: + p.daemon = True + p.start() + + sent = [ + q_in.put((i, x)) for i, x in enumerate(tqdm(X, position=0, desc="Queue In")) + ] + for _ in range(nprocs): + q_in.put((None, None)) + res = [q_out.get() for _ in tqdm(range(len(sent)), position=1, desc="Queue Out")] + res = [x for _, x in sorted(res)] + for p in proc: + p.join() + + return res diff --git 
a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/pylogger.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/pylogger.py new file mode 100644 index 0000000..c4ee867 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/pylogger.py @@ -0,0 +1,51 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. 
+ """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/rich_utils.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/rich_utils.py new file mode 100644 index 0000000..aeec680 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/rich_utils.py @@ -0,0 +1,99 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 
+ """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/transforms.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/transforms.py new file mode 100644 index 0000000..19486b2 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/transforms.py @@ -0,0 +1,155 @@ +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torchvision.transforms.v2 as transforms +from torch.nn import functional as F + + +class ResizeLongestSide(torch.nn.Module): + def __init__( + self, + long_side_length: int, + interpolation: str, + ) -> None: + super().__init__() + self.long_side_length = long_side_length + self.interpolation = interpolation + + def forward(self, image: torch.Tensor) -> torch.Tensor: + oldh, oldw = image.shape[-2:] + if max(oldh, oldw) == self.long_side_length: + return image + newh, neww = self.get_preprocess_shape(oldh, oldw, self.long_side_length) + return F.interpolate( + image, (newh, neww), mode=self.interpolation, align_corners=False + ) + + @staticmethod + def get_preprocess_shape( + oldh: int, + oldw: int, + long_side_length: int, + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. 
+ """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + +class MinMaxScale(torch.nn.Module): + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + image should have shape (..., 3, H, W) + """ + assert len(image.shape) >= 3 and image.shape[-3] == 3 + min_val = image.amin((-3, -2, -1), keepdim=True) + max_val = image.amax((-3, -2, -1), keepdim=True) + return (image - min_val) / torch.clip(max_val - min_val, min=1e-8, max=None) + + +class PadToSquare(torch.nn.Module): + def __init__(self, target_size: int) -> None: + super().__init__() + self.target_size = target_size + + def forward(self, image: torch.Tensor) -> torch.Tensor: + h, w = image.shape[-2:] + return F.pad(image, (0, self.target_size - w, 0, self.target_size - h), value=0) + + +def get_bbox(mask: np.ndarray, bbox_shift: int = 0) -> np.ndarray: + """ + Get the bounding box coordinates from the mask + + Parameters + ---------- + mask : numpy.ndarray + the mask of the resized image + + bbox_shift : int + Add perturbation to the bounding box coordinates + + Returns + ------- + numpy.ndarray + bounding box coordinates in the resized image + """ + y_indices, x_indices = np.where(mask > 0) + x_min, x_max = np.min(x_indices), np.max(x_indices) + y_min, y_max = np.min(y_indices), np.max(y_indices) + # add perturbation to bounding box coordinates and test the robustness + # this can be removed if you do not want to test the robustness + H, W = mask.shape + x_min = max(0, x_min - bbox_shift) + x_max = min(W - 1, x_max + bbox_shift) + y_min = max(0, y_min - bbox_shift) + y_max = min(H - 1, y_max + bbox_shift) + + bboxes = np.array([x_min, y_min, x_max, y_max]) + + return bboxes + + +def resize_box( + box: np.ndarray, + original_size: Tuple[int, int], + prompt_encoder_input_size: int, +) -> np.ndarray: + """ + the input bounding box is obtained from the original image + here, we rescale it to the coordinates of the resized image + + Parameters + ---------- + box : numpy.ndarray + bounding box coordinates in the original image + original_size : tuple + the original size of the image + prompt_encoder_input_size : int + the target size of the image + + Returns + ------- + numpy.ndarray + bounding box coordinates in the resized image + """ + new_box = np.zeros_like(box) + ratio = prompt_encoder_input_size / max(original_size) + for i in range(len(box)): + new_box[i] = int(box[i] * ratio) + + return new_box + + +def get_image_transform( + long_side_length: int, + min_max_scale: bool = True, + normalize: bool = False, + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + interpolation: str = "bilinear", +) -> transforms.Transform: + tsfm = [ + ResizeLongestSide(long_side_length, interpolation), + transforms.ToDtype(dtype=torch.float32, scale=False), + ] + if min_max_scale: + tsfm.append(MinMaxScale()) + if normalize: + tsfm.append(transforms.Normalize(pixel_mean, pixel_std)) + tsfm.append(PadToSquare(long_side_length)) + return transforms.Compose(tsfm) + + +def transform_gt(gt: torch.Tensor, long_side_length: int): + gt = gt[None, None, ...] 
+ oldh, oldw = gt.shape[-2:] + newh, neww = ResizeLongestSide.get_preprocess_shape(oldh, oldw, long_side_length) + gt = F.interpolate(gt, (newh, neww), mode="nearest-exact") + gt = F.pad(gt, (0, long_side_length - neww, 0, long_side_length - newh), value=0) + return gt.squeeze((0, 1)) diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/utils.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/utils.py new file mode 100644 index 0000000..731eb1a --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/utils.py @@ -0,0 +1,92 @@ +import warnings +from importlib.util import find_spec +from typing import Callable + +from omegaconf import DictConfig + +from src.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig): + ... + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. 
+ """ + + def wrap(cfg: DictConfig): + # execute the task + try: + task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return wrap diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/visualize.py b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/visualize.py new file mode 100644 index 0000000..70ce103 --- /dev/null +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/engines/src/utils/visualize.py @@ -0,0 +1,73 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def show_mask(mask, ax, mask_color=None, alpha=0.5): + """ + show mask on the image + + Parameters + ---------- + mask : numpy.ndarray + mask of the image + ax : matplotlib.axes.Axes + axes to plot the mask + mask_color : numpy.ndarray + color of the mask + alpha : float + transparency of the mask + """ + if mask_color is not None: + color = np.concatenate([mask_color, np.array([alpha])], axis=0) + else: + color = np.array([251 / 255, 252 / 255, 30 / 255, alpha]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + +def show_box(box, ax, edgecolor="blue"): + """ + show bounding box on the image + + Parameters + ---------- + box : numpy.ndarray + bounding box coordinates in the original image + ax : matplotlib.axes.Axes + axes to plot the bounding box + edgecolor : str + color of the bounding box + """ + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch( + plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0, 0, 0, 0), lw=2) + ) + + +def visualize_output(img, boxes, segs, save_file=None): + _, ax = plt.subplots(1, 2, figsize=(10, 5)) + ax[0].imshow(img) + ax[1].imshow(img) + ax[0].set_title("Input") + ax[1].set_title("Segmentation") + ax[0].axis("off") + ax[1].axis("off") + + for i, box in enumerate(boxes): + color = np.random.rand(3) + box_viz = box + mask = (segs == i + 1).astype(np.uint16) + show_box(box_viz, ax[0], edgecolor=color) + + if np.max(mask) > 0: + show_box(box_viz, ax[1], edgecolor=color) + show_mask(mask, ax[1], mask_color=color) + + plt.tight_layout() + if save_file is not None: + plt.savefig(save_file, dpi=300) + plt.close() + else: + plt.show() diff --git a/MedSAMLite/Resources/server_essentials/medsam_interface/interface_impl.py b/MedSAMLite/Resources/server_essentials/medsam_interface/interface_impl.py index 23b132a..b4e2d52 100644 --- a/MedSAMLite/Resources/server_essentials/medsam_interface/interface_impl.py +++ b/MedSAMLite/Resources/server_essentials/medsam_interface/interface_impl.py @@ -4,7 +4,7 @@ 'Classic MedSAM': ClassicMedSAM, 'OpenVino MedSAM': OVMedSAMCore, 'DAFT MedSAM': DAFTSAMCore, - # 'Medficient SAM': MedficientSAMCore, + 'Medficient SAM': MedficientSAMCore, } class MedSAM_Interface: