Shahina Kunhimon,
Abdelrahman Shaker,
Muzammal Naseer,
Salman Khan,
and Fahad Shahbaz Khan
Abstract: Hybrid volumetric medical image segmentation models, which combine the advantages of local convolution and global attention, have recently received considerable interest. While focusing mainly on architectural modifications, most existing hybrid approaches still use conventional, data-independent weight-initialization schemes, which restrict their performance because they ignore the inherent volumetric nature of medical data. To address this issue, we propose a learnable weight initialization approach that utilizes the available medical training data to effectively learn contextual and structural cues via the proposed self-supervised objectives. Our approach is easy to integrate into any hybrid model and requires no external training data. Experiments on multi-organ and lung cancer segmentation tasks demonstrate the effectiveness of our approach, leading to state-of-the-art segmentation performance.
If you find our work, this repository, or pretrained models useful, please consider giving a star ⭐ and citation.
@article{kunhimon2024learnable,
  title={Learnable weight initialization for volumetric medical image segmentation},
  author={Kunhimon, Shahina and Shaker, Abdelrahman and Naseer, Muzammal and Khan, Salman and Khan, Fahad Shahbaz},
  journal={Artificial Intelligence in Medicine},
  volume={151},
  pages={102863},
  year={2024},
  publisher={Elsevier}
}
- We propose a learnable weight initialization method that can be integrated into any hybrid volumetric medical segmentation model.
- To learn such a weight initialization, we propose data-dependent self-supervised objectives tailored to learn the structural and contextual cues from the volumetric medical data.
- We demonstrate the effectiveness of our approach by conducting experiments for multi-organ and tumor segmentation tasks, achieving superior segmentation performance without requiring additional external training data.
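To give a rough intuition for the two-stage idea, the toy sketch below is our own illustration and not the paper's actual objectives or code: stage 1 fits a tiny model to a self-supervised reconstruction proxy on (unlabeled) training data, and the resulting weights would then replace a random initialization for stage 2 supervised training. All names and the proxy objective here are hypothetical stand-ins.

```python
import random

random.seed(0)
# Toy unlabeled "volumes": scalar intensities sampled from the training data.
data = [random.gauss(2.0, 0.5) for _ in range(256)]

# Conventional data-independent initialization (here: zeros for simplicity).
w, b = 0.0, 0.0

def ssl_loss(w, b):
    # Hypothetical self-supervised proxy: reconstruct each intensity from a
    # corrupted (scaled) copy of itself -- the simplest possible stand-in
    # for the paper's structural/contextual objectives.
    return sum((w * (0.5 * x) + b - x) ** 2 for x in data) / len(data)

# Stage 1: learn a data-dependent initialization by gradient descent
# on the self-supervised loss.
lr = 0.05
loss_before = ssl_loss(w, b)
for _ in range(500):
    gw = sum(2 * (w * (0.5 * x) + b - x) * (0.5 * x) for x in data) / len(data)
    gb = sum(2 * (w * (0.5 * x) + b - x) for x in data) / len(data)
    w, b = w - lr * gw, b - lr * gb
loss_after = ssl_loss(w, b)

# Stage 2 (not shown): (w, b) would now seed the segmentation model,
# which is trained with the usual supervised loss on labeled data.
print(f"SSL loss: {loss_before:.3f} -> {loss_after:.3f}")
```

The point of the sketch is only the control flow: the initialization itself is learned from the training data before any supervised training begins.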
- Create and activate conda environment
conda create --name unetr_pp python=3.8
conda activate unetr_pp
- Install PyTorch and torchvision
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
- Install other dependencies
pip install -r requirements.txt
- Download the Synapse dataset
- Download the JSON file for the training/validation data split: Synapse_split.
- Follow the instructions in the UNETR++ repo to download the preprocessed data and organise the dataset folders.
- To train UNETR on Synapse using our proposed two-stage framework:
bash UNETR/init_train_test.sh init_run1
- To train UNETR++ on Synapse using our proposed two-stage framework:
bash UNETR_PP/init_train_val_synapse.sh init_run1
Download the pretrained model weights for Synapse_UNETR_Ours and place Synapse_UNETR_Ours_final.pt in the path UNETR/BTCV/pretrained_models/. Then, run
python BTCV/test.py --infer_overlap=0.5 \
--pretrained_dir='./pretrained_models/' \
--pretrained_model_name=Synapse_UNETR_Ours_final.pt \
--saved_checkpoint=ckpt
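The `--infer_overlap=0.5` flag controls how much neighbouring sliding-window patches overlap during inference. The 1-D sketch below is our own illustration of the striding arithmetic, not the script's actual implementation; the function and parameter names are hypothetical.

```python
def sliding_windows(length, window, overlap):
    """Start offsets of overlapping inference windows along one axis.

    overlap=0.5 means each window shares half its extent with the next,
    mirroring the --infer_overlap=0.5 flag above.
    """
    stride = max(1, int(window * (1 - overlap)))
    starts = list(range(0, max(length - window, 0) + 1, stride))
    # Ensure the final window reaches the end of the volume.
    if starts[-1] + window < length:
        starts.append(length - window)
    return starts

print(sliding_windows(96, 32, 0.5))  # → [0, 16, 32, 48, 64]
```

Higher overlap means more windows (slower inference) but smoother predictions where window outputs are averaged.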
Download the pretrained model weights for Synapse_UNETR++ Ours and place model_final_checkpoint.model in the following path:
UNETR_PP/unetr_pp/evaluation/unetr_pp_ours_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/
Then, run
bash evaluation_scripts/run_evaluation_synapse.sh init_run1
Our approach improves the state of the art on Synapse. We observe large variance in the performance of existing methods across different organs; in comparison, our approach performs consistently well across organs while also improving the overall performance.
The proposed data-dependent initialization scheme, when integrated with UNETR++, improves the overall segmentation performance by accurately segmenting the organs and delineating organ boundaries.
Should you have any questions, please create an issue in this repository or contact [email protected]
Our code builds on the UNETR and UNETR++ repositories. We thank the authors for releasing their code.