This repository implements BrainDiffU-Net, a decentralized learning framework based on the U-Net architecture for segmenting brain MRI images. It integrates privacy-preserving diffusion techniques to ensure compliance with data privacy regulations, making it suitable for medical image analysis tasks such as brain tumor segmentation.
The project is organized into the following main scripts:
- Purpose: Handles dataset distribution across decentralized nodes.
- Key Features:
  - `split_data_into_nodes`: Splits the dataset into subsets for each node to simulate a decentralized environment.
  - Ensures balanced data distribution while redistributing leftover samples to the last node.
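The repository's exact implementation isn't shown here, but the splitting logic described above can be sketched as follows (function and parameter names are assumptions based on the description):

```python
import numpy as np

def split_data_into_nodes(samples, num_nodes, seed=42):
    """Hypothetical sketch: shuffle the samples, give each node an equal
    share, and assign any leftover samples to the last node."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(samples))
    base = len(samples) // num_nodes
    splits = []
    for node in range(num_nodes):
        start = node * base
        # the last node also absorbs the remainder
        end = start + base if node < num_nodes - 1 else len(samples)
        splits.append([samples[i] for i in indices[start:end]])
    return splits
```

With 10 samples and 3 nodes, the first two nodes receive 3 samples each and the last node receives 4.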
- Purpose: Defines the `BrainMRIDataset` class for loading and preprocessing MRI images and segmentation masks.
- Key Features:
  - `__getitem__`: Loads an MRI image and its mask, applies preprocessing (resizing and normalization), and converts them into PyTorch tensors.
  - `__len__`: Returns the total number of samples in the dataset.
  - Supports augmentations using `transforms`.
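A minimal sketch of such a dataset class, assuming images and masks arrive as NumPy arrays (the real class likely reads image files from disk instead):

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class BrainMRIDataset(Dataset):
    """Hypothetical sketch of the dataset class described above."""

    def __init__(self, images, masks, size=(256, 256), transform=None):
        self.images, self.masks = images, masks
        self.size = size
        self.transform = transform

    def __len__(self):
        # total number of samples in the dataset
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx].astype(np.float32)
        mask = self.masks[idx].astype(np.float32)
        # normalize image intensities to [0, 1]
        image = image / max(float(image.max()), 1e-8)
        image = torch.from_numpy(image).unsqueeze(0)  # (1, H, W)
        mask = torch.from_numpy(mask).unsqueeze(0)
        # resize both to the target resolution
        image = F.interpolate(image.unsqueeze(0), size=self.size,
                              mode="bilinear", align_corners=False).squeeze(0)
        mask = F.interpolate(mask.unsqueeze(0), size=self.size,
                             mode="nearest").squeeze(0)
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask
```

Bilinear interpolation is used for images but nearest-neighbor for masks, so mask labels stay binary after resizing.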
- Purpose: Implements the U-Net architecture for segmentation tasks.
- Key Features:
  - Encoder-decoder structure with skip connections to preserve spatial details.
  - `conv_block` for convolutional operations and `forward()` for the model's forward pass.
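To make the structure concrete, here is a minimal two-level U-Net sketch with a `conv_block` helper and skip connections; the repository's model is almost certainly deeper, so treat the channel counts and depth here as illustrative assumptions:

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """Two 3x3 convolutions with ReLU, the basic U-Net building block."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
    )

class UNet(nn.Module):
    """Minimal encoder-decoder with skip connections (illustrative depth)."""
    def __init__(self, in_ch=1, out_ch=1, base=16):
        super().__init__()
        self.enc1 = conv_block(in_ch, base)
        self.enc2 = conv_block(base, base * 2)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base * 2, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = conv_block(base * 2, base)
        self.head = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        b = self.bottleneck(self.pool(e2))
        # skip connections: concatenate encoder features with upsampled ones
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return torch.sigmoid(self.head(d1))
```

The concatenations along the channel dimension are what "preserve spatial details": high-resolution encoder features bypass the bottleneck and feed the decoder directly.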
- Purpose: Contains the `train_node_model` function to train individual node models in a decentralized setup.
- Key Features:
  - Implements the training loop with loss calculation, backpropagation, and optimizer updates.
  - Tracks metrics like Dice Coefficient and IoU during training.
  - Supports learning rate scheduling for improved convergence.
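A sketch of one epoch of such a training loop, with the loss/backprop/step cycle and optional scheduler described above (the actual `train_node_model` signature in the repository may differ):

```python
import torch

def train_node_model(model, loader, optimizer, loss_fn, device, scheduler=None):
    """Hypothetical sketch of one training epoch for a single node."""
    model.train()
    epoch_loss = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(images)
        loss = loss_fn(preds, masks)
        loss.backward()          # backpropagation
        optimizer.step()         # optimizer update
        epoch_loss += loss.item()
    if scheduler is not None:
        scheduler.step()         # learning-rate scheduling
    return epoch_loss / max(len(loader), 1)
```

Metrics such as Dice and IoU would be computed on `preds` inside the loop; they are omitted here to keep the skeleton short.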
- Purpose: Defines metrics for evaluating segmentation quality.
- Key Features:
  - `dice_coefficient`: Computes the Dice Similarity Coefficient to measure overlap between predicted and ground truth masks.
  - `iou`: Calculates Intersection over Union for segmentation accuracy.
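Both metrics follow standard definitions; a sketch with a small smoothing constant to avoid division by zero (the repository's versions may differ in thresholding details):

```python
import torch

def dice_coefficient(pred, target, eps=1e-6):
    """Dice = 2|A∩B| / (|A| + |B|) for binary masks."""
    pred, target = pred.float().flatten(), target.float().flatten()
    intersection = (pred * target).sum()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

def iou(pred, target, eps=1e-6):
    """IoU = |A∩B| / |A∪B| for binary masks."""
    pred, target = pred.float().flatten(), target.float().flatten()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + eps) / (union + eps)
```

For a prediction `[1, 1, 0, 0]` against ground truth `[1, 0, 0, 0]`, Dice is 2·1/(2+1) ≈ 0.667 and IoU is 1/2 = 0.5.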
- Purpose: Provides utility functions for reproducibility and device initialization.
- Key Features:
  - `set_seed`: Ensures reproducibility across different runs by setting seeds for NumPy and PyTorch.
  - `initialize_device`: Detects available GPUs or defaults to CPU for training.
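These utilities typically look like the following sketch (names taken from the description above; the repository's versions may seed additional libraries):

```python
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def initialize_device():
    """Prefer a GPU when available, otherwise fall back to CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```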
- Purpose: Serves as the main script to integrate all components of the framework.
- Key Features:
  - Loads and preprocesses the dataset, splitting it across nodes.
  - Initializes U-Net models, optimizers, schedulers, and DataLoaders.
  - Orchestrates decentralized training with a diffusion process for collaborative learning.
  - Supports multi-GPU training via `DataParallel`.
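The README does not spell out the diffusion step. One common way to realize diffusion-style collaborative learning is to periodically move every node's parameters toward the average of all nodes; the sketch below shows that idea only, and the repository's actual privacy-preserving scheme (e.g. noisy or neighbor-restricted averaging) may differ substantially:

```python
import copy
import torch

@torch.no_grad()
def diffuse_parameters(models):
    """Illustrative diffusion step: replace each node model's parameters
    with the element-wise average across all node models."""
    state_dicts = [m.state_dict() for m in models]
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        avg[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
    for m in models:
        m.load_state_dict(avg)
    return models
```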
The dataset used in this project comes from two key sources:
- Mateusz Buda, Ashirbani Saha, Maciej A. Mazurowski. "Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm." Computers in Biology and Medicine, 2019.
- Maciej A. Mazurowski, Kal Clark, Nicholas M. Czarnek, Parisa Shamsesfandabadi, Katherine B. Peters, Ashirbani Saha. "Radiogenomics of lower-grade glioma: algorithmically-assessed tumor shape is associated with tumor genomic subtypes and patient outcomes in a multi-institutional study with The Cancer Genome Atlas data." Journal of Neuro-Oncology, 2017.
This dataset contains brain MR images together with manual FLAIR abnormality segmentation masks. The images were obtained from The Cancer Imaging Archive (TCIA) and correspond to 110 patients included in The Cancer Genome Atlas (TCGA) lower-grade glioma collection. Each patient has at least one fluid-attenuated inversion recovery (FLAIR) sequence and genomic cluster data available.
- Tumor genomic clusters and patient data are provided in the `data.csv` file.
- For more information on the genomic data, refer to the publication "Comprehensive, Integrative Genomic Analysis of Diffuse Lower-Grade Gliomas" and its supplementary material.
1. Clone the repository:

   ```bash
   git clone https://github.com/your-username/BrainDiffU-Net.git
   cd BrainDiffU-Net
   ```

2. Create a virtual environment and activate it:

   ```bash
   python -m venv braindiffunet-env
   source braindiffunet-env/bin/activate  # On Windows use `braindiffunet-env\Scripts\activate`
   ```

3. Install the required packages:

   ```bash
   pip install -r requirements.txt
   ```
Run the `main.py` script to start the decentralized training process:

```bash
python main.py
```
After training, update the evaluation section in `main.py` to visualize predictions:
- Input MRI images
- Ground truth masks
- Predicted masks
If you encounter any issues or errors while running the project, please check the following:
- Ensure all dependencies are installed correctly by running `pip install -r requirements.txt`.
- Make sure you are using a compatible version of Python (e.g., Python 3.6 or higher).
- Verify that the dataset paths in `main.py` are correct.
If problems persist, feel free to open an issue on GitHub.
Contributions are welcome! If you have suggestions for improvements or bug fixes, please follow these steps:
1. Fork the repository.
2. Create a new branch (`git checkout -b feature-branch`).
3. Make your changes and commit them (`git commit -m 'Add some feature'`).
4. Push to the branch (`git push origin feature-branch`).
5. Open a pull request.
Please ensure your code follows the existing style and includes appropriate tests.
This project is licensed under the MIT License. See the LICENSE file for details.