This repo develops a domain adaptation training method for radiology report summarization. It has two parts:
- re-training a BART model on a Masked Language Modeling (MLM) task using medical entities
- fine-tuning the model on text summarization data (radiology reports)
For retraining, we used the full radiology reports from the MIMIC-III Clinical Database CareVue subset. This avoids information leakage when fine-tuning the model on MIMIC-CXR (MIMIC-IV) data.
For fine-tuning, we used the dataset from MEDIQA 2021. Please ask the organizers if you want to use it.
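As a rough illustration of the entity-based MLM step, masking only tokens from a medical entity vocabulary can be sketched as below. This is a minimal sketch with a hypothetical `mask_medical_entities` helper; the repo's actual masking logic lives in pipeline/MLM/run_mlm_RM.py and run_mlm_SM.py and may differ.

```python
import random

MASK_TOKEN = "<mask>"

def mask_medical_entities(tokens, entity_vocab, mask_prob=0.15, seed=0):
    """Replace tokens found in entity_vocab with MASK_TOKEN with
    probability mask_prob; all other tokens are left untouched."""
    rng = random.Random(seed)
    out = []
    for tok in tokens:
        if tok.lower() in entity_vocab and rng.random() < mask_prob:
            out.append(MASK_TOKEN)
        else:
            out.append(tok)
    return out
```

Restricting the mask candidates to medical entities focuses the MLM loss on domain vocabulary instead of generic English words.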
You'll need to obtain access to MIMIC. Since we only use the radiology report text, you do NOT need to download the entire release. For the MIMIC-III Clinical Database CareVue subset, you only need to download NOTEEVENTS.csv. For MIMIC-CXR (MIMIC-IV), the only file you'll need is the compressed report archive (mimic-cxr-reports.zip).
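As a sketch, the radiology reports can be pulled out of NOTEEVENTS.csv by filtering on the CATEGORY column. This assumes the standard MIMIC-III NOTEEVENTS schema (CATEGORY and TEXT columns); the repo's own preprocessing may differ.

```python
import csv

def extract_radiology_reports(noteevents_path):
    """Collect the TEXT field of NOTEEVENTS rows whose CATEGORY is 'Radiology'."""
    reports = []
    with open(noteevents_path, newline="") as f:
        for row in csv.DictReader(f):
            if row["CATEGORY"].strip() == "Radiology":
                reports.append(row["TEXT"])
    return reports
```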
Before running anything, make sure the following files are in the correct folders:
.
├── data
│ ├── fine_tune
│ │ ├── CXR_test.json
│ │ ├── CXR_val.json
│ │ ├── MEDIQA2021_RRS_Test_Set_Full.json
│ │ └── indiana_dev.json
│ └── retrain
│ ├── MIMIC_test_full.txt
│ ├── MIMIC_train_full.txt
│ └── MIMIC_val_full.txt
├── model
│ ├── RM
│ │ ├── README.md
│ │ ├── config.json
│ │ ├── flax_model.msgpack
│ │ ├── gitattributes.txt
│ │ ├── merges.txt
│ │ ├── model.safetensors
│ │ ├── pytorch_model.bin
│ │ ├── tf_model.h5
│ │ ├── tokenizer.json
│ │ └── vocab.json
│ └── SM
│ ├── phrase_SM
│ │ ├── README.md
│ │ ├── added_tokens.json
│ │ ├── config.json
│ │ ├── flax_model.msgpack
│ │ ├── gitattributes.txt
│ │ ├── merges.txt
│ │ ├── model.safetensors
│ │ ├── pytorch_model.bin
│ │ ├── tf_model.h5
│ │ ├── tokenizer.json
│ │ └── vocab.json
│ └── word_SM
│ ├── README.md
│ ├── added_tokens.json
│ ├── config.json
│ ├── flax_model.msgpack
│ ├── gitattributes.txt
│ ├── merges.txt
│ ├── model.safetensors
│ ├── pytorch_model.bin
│ ├── tf_model.h5
│ ├── tokenizer.json
│ └── vocab.json
├── pipeline
│ ├── MLM
│ │ ├── plot_result.ipynb
│ │ ├── run_mlm_RM.py
│ │ └── run_mlm_SM.py
│ └── fine_tune
│ ├── fine_tune_test.py
│ ├── fine_tune_train.py
│ └── utils.py
├── README.md
├── requirements.txt
- to obtain the fine-tune data, see here
- to obtain the re-train data, please contact me once you have access to the MIMIC data
- for the models, you need to download the original model files from Hugging Face
Once you have all the files ready, you can run the MLM retraining like this:
python MLM/run_mlm_RM.py \
    --model_name_or_path ../model/RM \
    --line_by_line \
    --num_train_epochs 20 \
    --train_file ../data/retrain/MIMIC_train_full.txt \
    --validation_file ../data/retrain/MIMIC_test_full.txt \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --do_train \
    --do_eval \
    --output_dir ../checkpoint
For fine-tuning, you can use a pre-trained model from Hugging Face:
python fine_tune/fine_tune_train.py \
    --model_checkpoint facebook/bart-base \
    --model_save ../fine_tune_results \
    --epoch_num 20 \
    --max_dataset_size 20000000
and test:
python fine_tune/fine_tune_test.py \
    --model_save fine_tune_results \
    --model_checkpoint_path ../fine_tune_results_epoch_2_valid_rouge_22.22_model_weights.bin \
    --model_checkpoint facebook/bart-base
Or you can use the retrained model from the MLM step:
python fine_tune/fine_tune_train.py \
    --model_checkpoint ../checkpoint/checkpoint-1000 \
    --model_save ../fine_tune_results \
    --epoch_num 20 \
    --max_dataset_size 20000000
and test:
python fine_tune/fine_tune_test.py \
    --model_save fine_tune_results \
    --model_checkpoint_path ../fine_tune_results_epoch_2_valid_rouge_22.22_model_weights.bin \
    --model_checkpoint ../checkpoint/checkpoint-1000
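The saved weight files encode the epoch and validation ROUGE in their names (e.g. fine_tune_results_epoch_2_valid_rouge_22.22_model_weights.bin), so a small helper (hypothetical, not part of the repo) can pick the best checkpoint to pass to --model_checkpoint_path:

```python
import re

def best_checkpoint(filenames):
    """Return the filename with the highest validation ROUGE, or None."""
    pattern = re.compile(r"epoch_(\d+)_valid_rouge_([\d.]+)_model_weights\.bin$")
    scored = []
    for name in filenames:
        m = pattern.search(name)
        if m:
            scored.append((float(m.group(2)), name))
    return max(scored)[1] if scored else None
```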
- Python==3.8
- datasets==2.7.0
- evaluate==0.4.0
- lp2==1.8.48
- numpy==1.23.5
- rouge==1.0.1
- torch==1.12.0+cu113
- torchmetrics==0.10.3
- tqdm==4.64.1
- transformers==4.25.1
- scikit-learn==1.1.3
- scipy==1.9.3
- wandb==0.13.7
Jinge Wu: [email protected]