From 8c1bd5e70e16ef1e2e0d21c4995abf588cd1efbc Mon Sep 17 00:00:00 2001 From: pierre-nedelec Date: Fri, 23 Dec 2022 07:30:56 -0800 Subject: [PATCH] initial commit --- .gitignore | 129 ++++++ LICENSE | 28 ++ README.md | 267 ++++++++++++ example_split.csv | 3 + unet.yml | 181 ++++++++ unet/START_HERE.py | 90 ++++ unet/__init__.py | 0 unet/analysis/__init__.py | 0 unet/analysis/utils_data.py | 42 ++ unet/analysis/utils_plot.py | 146 +++++++ unet/analysis/utils_stats.py | 41 ++ unet/image/LinearInterpolation.py | 366 ++++++++++++++++ unet/image/__init__.py | 0 unet/image/elastic_transform_tf.py | 200 +++++++++ unet/image/fix_orientation.py | 69 +++ unet/image/get_main_component.sh | 15 + unet/image/image_plot.py | 127 ++++++ unet/image/patch_util.py | 181 ++++++++ unet/image/postprocess_deepmedic.py | 70 +++ unet/image/prepare_predict.py | 57 +++ unet/image/preprocess_Unet.py | 255 +++++++++++ unet/image/preprocess_images.py | 197 +++++++++ unet/image/process_image.py | 170 ++++++++ unet/image/resample_util.py | 120 ++++++ unet/model/__init__.py | 0 unet/model/model_definition.py | 292 +++++++++++++ unet/model/train_test.py | 394 +++++++++++++++++ unet/utils/__init__.py | 0 unet/utils/calculateDice_MICCAI.py | 125 ++++++ unet/utils/config.py | 643 ++++++++++++++++++++++++++++ unet/utils/datacheck.py | 27 ++ unet/utils/datasplit.py | 153 +++++++ unet/utils/dice.py | 38 ++ unet/utils/multiprocessing.py | 35 ++ unet/utils/tic_toc.py | 18 + 35 files changed, 4479 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 example_split.csv create mode 100644 unet.yml create mode 100755 unet/START_HERE.py create mode 100644 unet/__init__.py create mode 100755 unet/analysis/__init__.py create mode 100644 unet/analysis/utils_data.py create mode 100644 unet/analysis/utils_plot.py create mode 100644 unet/analysis/utils_stats.py create mode 100755 unet/image/LinearInterpolation.py create mode 100644 unet/image/__init__.py create mode 100755 unet/image/elastic_transform_tf.py create mode 100755 unet/image/fix_orientation.py create mode 100755 unet/image/get_main_component.sh create mode 100644 unet/image/image_plot.py create mode 100755 unet/image/patch_util.py create mode 100755 unet/image/postprocess_deepmedic.py create mode 100644 unet/image/prepare_predict.py create mode 100755 unet/image/preprocess_Unet.py create mode 100755 unet/image/preprocess_images.py create mode 100755 unet/image/process_image.py create mode 100755 unet/image/resample_util.py create mode 100644 unet/model/__init__.py create mode 100755 unet/model/model_definition.py create mode 100755 unet/model/train_test.py create mode 100644 unet/utils/__init__.py create mode 100755 unet/utils/calculateDice_MICCAI.py create mode 100755 unet/utils/config.py create mode 100644 unet/utils/datacheck.py create mode 100755 unet/utils/datasplit.py create mode 100755 unet/utils/dice.py create mode 100644 unet/utils/multiprocessing.py create mode 100755 unet/utils/tic_toc.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b6e4761 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python 
script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8ce2793 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2022, The Regents of the University of California + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5b30a92
--- /dev/null
+++ b/README.md
@@ -0,0 +1,267 @@
+# 3D U-Net for Fetal Brain Segmentation
+
+Code used for the following paper: [Development of Gestational Age–Based Fetal Brain and Intracranial Volume Reference Norms Using Deep Learning](http://www.ajnr.org/content/early/2022/12/21/ajnr.A7747). See [Citation](#citation).
+
+Table of Contents
+=================
+
+   * [3D U-Net](#3d-u-net)
+      * [Description](#description)
+      * [Installation](#installation)
+         * [Code](#code)
+         * [Python Environment](#python-environment)
+         * [Directory structure](#directory-structure)
+      * [Usage](#usage)
+         * [Tensorflow 2 version](#tensorflow-2-version)
+         * [Configuration](#configuration)
+         * [Split file](#split-file)
+         * [Start training & testing](#start-training--testing)
+         * [Outputs](#outputs)
+      * [Others](#others)
+         * [Tensorboard](#tensorboard)
+            * [Launch](#launch)
+            * [Read](#read)
+            * [Setup](#setup)
+      * [Contributing](#contributing)
+      * [Citation](#citation)
+
+## Description
+
+This U-Net automatically segments the fetal brain from 3D MRI images. This repository contains code to train a new model, continue a previous training, fine-tune a model with frozen layers, and test a model on new images. Training uses a cross-entropy loss and monitors a soft Dice score; testing reports the standard Dice score. A sketch of both scores follows.
+
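+The following is a minimal illustrative sketch of the two scores, assuming binary label masks; it is not the repository's implementation, which lives in `unet/utils/dice.py` and `unet/model/train_test.py`:
+```py
+import numpy as np
+
+def soft_dice(probs, labels, eps=1e-7):
+    # differentiable: computed on predicted probabilities in [0, 1] during training
+    return 2.0 * np.sum(probs * labels) / (np.sum(probs) + np.sum(labels) + eps)
+
+def dice(pred_mask, labels, eps=1e-7):
+    # standard Dice: computed on the binarized prediction at test time
+    return soft_dice((pred_mask > 0.5).astype(float), labels, eps)
+```
+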
+## Installation
+### Code
+Download this code directly, or use the following:
+```sh
+cd /directory/where/you/want/the/code
+# with UCSF GitHub, you have to use SSH connection
+git clone git@github.com:rauschecker-sugrue-labs/fetal-brain-segmentation.git
+# to update later, simply use:
+git pull origin
+```
+The code is accessible inside the `unet` directory, with a `START_HERE.py` file to help you get started. [See more here](#usage).
+### Python Environment
+A valid `conda` install is needed (Anaconda, Miniconda).
+Install the environment from the `unet.yml` file.
+```sh
+# Standard install:
+conda env create -f unet.yml
+# Specify where to install:
+conda env create --prefix /path/to/install -f unet.yml
+```
+### Directory structure
+The code and configuration are set up to work best with a specific directory structure, described in [Configuration](#configuration).
+
+## Usage
+### Tensorflow 2 version
+
+Everything can be done from the `START_HERE.py` file. There are two actions to perform: load a `Config` object, then call `Config.start()`.
+
+1. Load config: `c = Config(your_parameters)`. Specify the run parameters here, cf. [config](#configuration) section.
+2. Launch training or testing: `c.start()`
+
+### Configuration
+
+Most configuration happens in the file `config.py`, through the `Config` object. It contains the train/test setup (paths, *etc.*), along with most model parameters (batch size, # epochs, learning rate, *etc.*).
+
+Before running any experiment, a csv split file *must* exist in the main folder (`models/` or `predict/`).
+
+The directories are organized in the following way:
+```
+rootdir/ ## referred to as *root*
+├── Data
+│   ├── my_image_data
+│   │   └── raw
+│   └── my_image_data2
+│       └── raw
+├── models
+│   ├── my_train_set
+│   ├── my_train_set.csv
+│   └── etc.
+├── predict
+│   └── etc.
+└── fetal-brain-segmentation
+```
+A call to the configuration object could look like the following:
+```py
+# Training
+c = Config(root='path/to/root',
+           datasource='my_image_data',
+           experiment='my_train_set',
+           train=True,
+           GPU='5,7')
+c.print_recap()
+
+# Testing
+c = Config(root='path/to/root',
+           datasource='my_image_data',
+           experiment='my_test_set',
+           train=False,
+           model='my_train_set',
+           which_model='latest',
+           GPU='5,7')
+c.print_recap()
+```
+Configuration call options:
+
+| Variable    | Type         | Default  | Description | Usage | Required |
+|-------------|--------------|----------|-------------|-------|----------|
+| root        | str          |          | Path to root directory (above Data/, models/) |  | ✓ |
+| datasource  | str          |          | Directory name for the source of data |  | ✓ |
+| experiment  | str          |          | Name of experiment (if train: model name, if predict: prediction name) |  | ✓ |
+| train       | bool         |          | True for train, False for predict |  | ✓ |
+| model       | str          | None     | If train = False, this is the model name used to predict the data |  |  |
+| move_images | bool         | True     | For predict only: move images into directory to allow processing (needs to be run at least once per predict experiment). |  |  |
+| tmp_storage | str          | None     | Location of tfrecords | None=in main exp directory. 'TMP'=$TMPDIR. |  |
+| which_model | str          | 'latest' | Chooses which model to load for predict | 'latest' -> latest \| 'root' if checkpoint is at the root \| 'your_name' -> this name |  |
+| read_only   | bool         | False    | If set to True, folders won't be created on init. Use only when the goal is to read data from a previous run |  |  |
+| GPU         | int/list int | None     | Chooses which GPU(s) will be used for this session | If None, all available GPUs will be used. GPU=0 \| GPU=[0,1] \| GPU='1' accepted |  |
+
+### Split file
+The split file is a csv file that specifies filenames (e.g. 10101010_FLAIR.nii.gz) and whether each file represents a training, validation, or test case. The split file must be saved in either the models/ or the predict/ directory. It must have the same name as the `experiment` argument of the `Config` call.
+It has no header row and consists of 3 columns:
+* first: image file name without extension
+* second: disease class (not currently used, so it can be NA)
+* third: train, val, or test (for model training, validation during training, or testing)
+
+A sketch of generating such a file is shown below.
+
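+A minimal sketch of writing a split file (the case names and the 80/10/10 proportions are hypothetical examples; adapt the output path to your `experiment` name):
+```py
+import csv, random
+
+cases = [("10101010_FLAIR", "NA"), ("10101011_FLAIR", "NA")]  # (name, disease class)
+random.shuffle(cases)
+n_train, n_val = int(0.8 * len(cases)), int(0.1 * len(cases))
+with open("models/my_train_set.csv", "w", newline="") as f:
+    writer = csv.writer(f)
+    for i, (name, disease) in enumerate(cases):
+        split = "train" if i < n_train else "val" if i < n_train + n_val else "test"
+        writer.writerow([name, disease, split])
+```
+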
+### Start training & testing
+Make your own copy of [START_HERE.py](START_HERE.py).
+```py
+## Load module
+from utils.config import Config
+```
+#### Training
+```py
+## Config
+# ... for training (OR)
+c = Config(root='path/to/root',
+           datasource='my_image_data',
+           experiment='my_train_set',
+           train=True,
+           GPU='5,7')
+
+# (OR) ... for continued training
+c = Config(root='path/to/root',
+           datasource='my_image_data',
+           experiment='my_train_set',
+           train=True,
+           model='my_train_set',
+           which_model='latest',
+           GPU='5,7')
+
+## Check that everything is set up correctly
+c.print_recap(notes='notes_run')
+
+## Start training
+c.start(notes_run='notes_run')
+```
+#### Testing
+If you plan to simply apply a pre-trained model (such as the latest version of the FLAIR U-net) to a new dataset, you can do so as follows:
+```py
+## Config for testing
+c = Config(root='path/to/root',
+           datasource='my_image_data',
+           experiment='my_test_set',
+           train=False,
+           model='my_train_set',
+           which_model='latest',
+           GPU='5,7')
+
+## Check that everything is set up correctly
+c.print_recap(notes='notes_run')
+
+## Start testing
+c.start(notes_run='notes_run')
+```
+
+### Outputs
+**Training output:**
+```
+rootdir/models/model_name
+├── config.txt
+├── model
+│   ├── date1
+│   ├── date2
+│   ├── checkpoint
+│   ├── cp-030.ckpt.data-00000-of-00001
+│   └── cp-030.ckpt.index
+├── tfboard
+│   ├── date1
+│   ├── date2
+│   └── date3
+├── tf_records
+└── validation_output
+    ├── binarized_masks
+    ├── Dice_score
+    ├── predictions
+    └── resampled_to_originalspacing
+```
+**Prediction outputs:**
+```
+rootdir/predict/Tr-model_name_Tt-pred_name/
+├── binarized_masks
+├── config.txt
+├── Dice_score
+├── predictions
+├── preprocessed
+├── raw
+├── resampled_to_originalspacing
+└── tf_records
+```
+**Tensorboard output:**
+See the [Tensorboard](#tensorboard) section.
+
+## Others
+### Tensorboard
+#### Launch
+```sh
+source activate unet
+tensorboard --logdir /path/to/model_dir/tfboard
+```
+It serves on a local port, to be opened in a browser:
+```
+Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+TensorBoard 2.4.0 at http://localhost:6006/ (Press CTRL+C to quit)
+```
+
+If the data is on a server (e.g. an SCS server), there are a few options:
+* open a graphical interface to the server, launch a terminal, run the above commands, and open a web browser at the indicated address,
+* if using `vscode` while connected to the server through `ssh`, run the above commands, and `vscode` will automatically forward the open port to a local port on your computer,
+* if using a local terminal, follow these steps after having run the above commands *via* `ssh`:
+```sh
+# XX = port provided after running the $tensorboard command (usually 06, 07, etc.)
+# server = server on which the model is training (e.g. callosum.radiology.ucsf.edu)
+ssh -L 160XX:127.0.0.1:60XX server
+# In your browser, go to: http://127.0.0.1:160XX/
+```
+
+#### Read
+* Scalars: graphs for loss and metrics, across batches, epochs, and time.
+* Graphs: model graph.
+* Distributions & histograms: evolution of the model weights, layer by layer, across epochs. See more here: https://github.com/tensorflow/tensorboard/blob/master/docs/r1/histograms.md
+
+
+#### Setup
+The *Tensorboard callback* `tb_callback` is defined in `train_test.py`, in the `train()` function.
+
+
+## Contributing
+Edits and suggestions are welcome; please use GitHub branches and pull requests.
+
+## Citation
+If you use our code, please cite [our paper](http://www.ajnr.org/content/early/2022/12/21/ajnr.A7747):
+
+C.B.N. Tran, P. Nedelec, D.A. Weiss, J.D. Rudie, L. Kini, L.P. Sugrue, O.A. Glenn, C.P. Hess, A.M. Rauschecker, [Development of Gestational Age–Based Fetal Brain and Intracranial Volume Reference Norms Using Deep Learning](http://www.ajnr.org/content/early/2022/12/21/ajnr.A7747). DOI:10.3174/ajnr.A7747.
+ +``` +@article {Tran, + author = {Tran, C.B.N. and Nedelec, P. and Weiss, D.A. and Rudie, J.D. and Kini, L. and Sugrue, L.P. and Glenn, O.A. and Hess, C.P. and Rauschecker, A.M.}, + title = {Development of Gestational Age{\textendash}Based Fetal Brain and Intracranial Volume Reference Norms Using Deep Learning}, + year = {2022}, + doi = {10.3174/ajnr.A7747}, + publisher = {American Journal of Neuroradiology}, + issn = {0195-6108}, + URL = {http://www.ajnr.org/content/early/2022/12/21/ajnr.A7747}, + eprint = {http://www.ajnr.org/content/early/2022/12/21/ajnr.A7747.full.pdf}, + journal = {American Journal of Neuroradiology} +} +``` diff --git a/example_split.csv b/example_split.csv new file mode 100644 index 0000000..0d79b8f --- /dev/null +++ b/example_split.csv @@ -0,0 +1,3 @@ +123456789,Disease_class,train +234567891,disease2,val +345678912,disease3,test \ No newline at end of file diff --git a/unet.yml b/unet.yml new file mode 100644 index 0000000..c994d39 --- /dev/null +++ b/unet.yml @@ -0,0 +1,181 @@ +channels: + - simpleitk + - conda-forge + - anaconda + - defaults +dependencies: + - ca-certificates=2020.10.14=0 + - certifi=2020.6.20=py38_0 + - cudatoolkit=10.1.243=h6bb024c_0 + - cudnn=7.6.5=cuda10.1_0 + - cupti=10.1.168=0 + - openssl=1.1.1h=h7b6447c_0 + - argon2-cffi=20.1.0=py38h1e0a361_0 + - ipympl=0.6.2=pyhd8ed1ab_0 + - ipywidgets=7.6.3=pyhd3deb0d_0 + - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 + - python_abi=3.8=1_cp38 + - widgetsnbextension=3.5.1=py38h578d9bd_4 + - _libgcc_mutex=0.1=main + - astroid=2.4.2=py38_0 + - attrs=20.3.0=pyhd3eb1b0_0 + - autopep8=1.5.4=py_0 + - backcall=0.2.0=py_0 + - binutils_impl_linux-64=2.33.1=he6710b0_7 + - binutils_linux-64=2.33.1=h9595d00_15 + - blas=1.0=mkl + - bleach=3.2.1=py_0 + - cairo=1.14.12=h8948797_3 + - cffi=1.14.0=py38h2e261b9_0 + - cycler=0.10.0=py38_0 + - dbus=1.13.18=hb2f20db_0 + - decorator=4.4.2=py_0 + - defusedxml=0.6.0=py_0 + - entrypoints=0.3=py38_0 + - expat=2.2.10=he6710b0_2 + - fontconfig=2.13.0=h9420a91_0 + - freetype=2.10.4=h5ab3b9f_0 + - fribidi=1.0.10=h7b6447c_0 + - gcc_impl_linux-64=7.3.0=habb00fd_1 + - gcc_linux-64=7.3.0=h553295d_15 + - glib=2.63.1=h5a9c865_0 + - graphite2=1.3.14=h23475e2_0 + - graphviz=2.40.1=h21bd128_2 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - harfbuzz=1.8.8=hffaf4a1_0 + - icu=58.2=he6710b0_3 + - importlib-metadata=2.0.0=py_1 + - importlib_metadata=2.0.0=1 + - intel-openmp=2020.2=254 + - ipykernel=5.3.4=py38h5ca1d4c_0 + - ipython=7.18.1=py38h5ca1d4c_0 + - ipython_genutils=0.2.0=py38_0 + - isort=5.6.4=py_0 + - jedi=0.17.2=py38_0 + - jinja2=2.11.2=py_0 + - jpeg=9b=h024ee3a_2 + - jsonschema=3.2.0=py_2 + - jupyter_client=6.1.7=py_0 + - jupyter_core=4.6.3=py38_0 + - kiwisolver=1.2.0=py38hfd86e86_0 + - lazy-object-proxy=1.4.3=py38h7b6447c_0 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libedit=3.1.20191231=h14c3975_1 + - libffi=3.2.1=hf484d3e_1007 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.18=h7b6447c_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - libtool=2.4.6=h7b6447c_1005 + - libuuid=1.0.3=h1bed415_2 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.10=hb55368b_3 + - lz4-c=1.9.2=heb0550a_3 + - markupsafe=1.1.1=py38h7b6447c_0 + - matplotlib=3.3.2=0 + - matplotlib-base=3.3.2=py38h817c723_0 + - mccabe=0.6.1=py38_1 + - mistune=0.8.4=py38h7b6447c_1000 + - mkl=2020.2=256 + - mkl-service=2.3.0=py38he904b0f_0 + - mkl_fft=1.2.0=py38h23d657b_0 + - mkl_random=1.1.1=py38h0573a6f_0 
+ - nbconvert=5.6.1=py38_0 + - nbformat=5.1.1=pyhd3eb1b0_1 + - ncurses=6.2=he6710b0_1 + - notebook=6.1.6=py38h06a4308_0 + - numpy=1.19.1=py38hbc911f0_0 + - numpy-base=1.19.1=py38hfa32c7d_0 + - olefile=0.46=py_0 + - packaging=20.8=pyhd3eb1b0_0 + - pandas=1.1.3=py38he6710b0_0 + - pandoc=2.11=hb0f4dca_0 + - pandocfilters=1.4.3=py38h06a4308_1 + - pango=1.42.4=h049681c_0 + - parso=0.7.0=py_0 + - pcre=8.44=he6710b0_0 + - pexpect=4.8.0=py38_0 + - pickleshare=0.7.5=py38_1000 + - pillow=8.0.0=py38h9a89aac_0 + - pip=20.2.3=py38_0 + - pixman=0.40.0=h7b6447c_0 + - prometheus_client=0.9.0=pyhd3eb1b0_0 + - prompt-toolkit=3.0.8=py_0 + - ptyprocess=0.6.0=py38_0 + - pycodestyle=2.6.0=py_0 + - pycparser=2.20=py_2 + - pydot=1.4.1=py38_0 + - pygments=2.7.1=py_0 + - pylint=2.6.0=py38_0 + - pyparsing=2.4.7=py_0 + - pyqt=5.9.2=py38h05f1152_4 + - pyrsistent=0.17.3=py38h7b6447c_0 + - python=3.8.2=hcf32534_0 + - python-dateutil=2.8.1=py_0 + - pytz=2020.1=py_0 + - pyzmq=19.0.2=py38he6710b0_1 + - qt=5.9.7=h5867ecd_1 + - readline=8.0=h7b6447c_0 + - rope=0.18.0=py_0 + - scipy=1.5.2=py38h0b6359f_0 + - seaborn=0.11.1=pyhd3eb1b0_0 + - send2trash=1.5.0=pyhd3eb1b0_1 + - setuptools=50.3.0=py38hb0f4dca_1 + - sip=4.19.13=py38he6710b0_0 + - six=1.15.0=py_0 + - sqlite=3.33.0=h62c20be_0 + - terminado=0.9.2=py38h06a4308_0 + - testpath=0.4.4=py_0 + - tk=8.6.10=hbc83047_0 + - toml=0.10.1=py_0 + - tornado=6.0.4=py38h7b6447c_1 + - tqdm=4.50.2=py_0 + - traitlets=5.0.5=py_0 + - wcwidth=0.2.5=py_0 + - webencodings=0.5.1=py38_1 + - wheel=0.35.1=py_0 + - wrapt=1.11.2=py38h7b6447c_0 + - xz=5.2.5=h7b6447c_0 + - zeromq=4.3.3=he6710b0_3 + - zipp=3.4.0=pyhd3eb1b0_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - simpleitk=2.0.1=py38hf484d3e_0 + - pip: + - absl-py==0.11.0 + - astunparse==1.6.3 + - cachetools==4.1.1 + - chardet==3.0.4 + - gast==0.3.3 + - google-auth==1.23.0 + - google-auth-oauthlib==0.4.2 + - google-pasta==0.2.0 + - grpcio==1.34.0 + - gviz-api==1.9.0 + - h5py==2.10.0 + - idna==2.10 + - keras-preprocessing==1.1.2 + - markdown==3.3.3 + - oauthlib==3.1.0 + - opt-einsum==3.3.0 + - protobuf==3.14.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - requests==2.25.0 + - requests-oauthlib==1.3.0 + - rsa==4.6 + - tensorboard==2.4.0 + - tensorboard-plugin-profile==2.4.0 + - tensorboard-plugin-wit==1.7.0 + - tensorflow==2.3.1 + - tensorflow-addons==0.12.0 + - tensorflow-estimator==2.3.0 + - termcolor==1.1.0 + - typeguard==2.10.0 + - urllib3==1.26.2 + - werkzeug==1.0.1 + diff --git a/unet/START_HERE.py b/unet/START_HERE.py new file mode 100755 index 0000000..085030f --- /dev/null +++ b/unet/START_HERE.py @@ -0,0 +1,90 @@ +#%% Imports +from utils.config import Config, ModelParameters + +# Add a few notes +# used in both the config.txt file for records (under `models/my_train_set/config.txt` or `predict/TRmy_train_set_TTmy_test_set`) +# and in tfboard name for easy filtering of experiments in Tensorboard +notes_run = f"my_notes" +model_params = ModelParameters(learning_rate=1e-4, num_epochs=10) + +### Choose *one* of the following Config call +#%% Config for normal training +c = Config(root='path/to/root/directory', + datasource='my_image_data', + experiment='my_train_set', + train=True, + nickname='my_unique_name', + model_params=model_params, + GPU='5,7', + ) +# Check that everything is setup correctly +c.print_recap(notes=notes_run) +# Start! 
+c.preprocess(force_preprocess=False)
+c.start(notes_run=notes_run)
+
+#%% Config for cross-validation training
+fold = 1 # choose fold number
+notes_run = f"fetal-brain-segmentation_fold-{fold}"
+c = Config(root='path/to/root/directory',
+           datasource='my_image_data',
+           experiment='my_train_set',
+           train=True,
+           nickname=f'CV-{fold}',
+           model_params=model_params,
+           GPU='5,7',
+           )
+# Check that everything is set up correctly
+c.update_for_CV(fold)
+c.print_recap(notes=notes_run)
+# c.preprocess(CV=(6,None), create_csv=True) # first time only, to automatically create folds
+c.preprocess()
+
+#%% Config for continued training
+c = Config(root='', datasource='my_image_data',
+           experiment='my_train_set',
+           train=True,
+           model='my_train_set',
+           which_model='latest',
+           nickname='my_unique_name',
+           model_params=model_params,
+           GPU='5,7')
+# Check that everything is set up correctly
+c.print_recap(notes=notes_run)
+# Start!
+c.preprocess(force_preprocess=False)
+c.start(notes_run=notes_run)
+
+#%% Config for testing
+c = Config(root='', datasource='my_image_data',
+           experiment='my_test_set',
+           train=False,
+           model='my_train_set',
+           which_model='latest',
+           nickname='my_unique_name_for_these_predictions',
+           model_params=model_params,
+           read_existing=False,
+           GPU='5,7')
+# Check that everything is set up correctly
+c.print_recap(notes=notes_run)
+c.preprocess(force_preprocess=False)
+c.start(notes_run=notes_run)
+
+# Get main component from masks
+import subprocess
+original_images_dir = c.thresholdedDir / str(c.model_params.bin_threshold)
+main_component_images_dir = original_images_dir.with_name(f'{c.model_params.bin_threshold}_MC')
+main_component_images_dir.mkdir(exist_ok=True)
+sp = subprocess.run(
+    ['image/get_main_component.sh', original_images_dir, main_component_images_dir],
+    capture_output=True,
+    encoding='utf-8')  # decode stdout/stderr as text
+print(sp.stdout)
+# Compute stats on main component masks
+from utils.calculateDice_MICCAI import compute_stats_dir
+stats_file_name = f'{c.diceDir}/stats_{c.experiment}_th_{c.model_params.bin_threshold}_MC.csv'
+stats_df = compute_stats_dir(c.data_raw,
+                             main_component_images_dir,
+                             stats_file_name,
+                             num_processes=12,
+                             inference_only=False)
diff --git a/unet/__init__.py b/unet/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/unet/analysis/__init__.py b/unet/analysis/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/unet/analysis/utils_data.py b/unet/analysis/utils_data.py
new file mode 100644
index 0000000..5874ee0
--- /dev/null
+++ b/unet/analysis/utils_data.py
@@ -0,0 +1,42 @@
+import pandas as pd
+import numpy as np
+import os
+from pathlib import Path
+import re
+
+def missing_columns(li1, li2, name=None):
+    """ Prints the names of data points missing from one list relative to the other. `name` is a label to distinguish outputs when this function is called several times.
+ """ + l1 = len(li1) + l2 = len(li2) + if l2 < l1: + if name: + print(f'Missing {l1-l2} columns in {name}!\t{diff_2lists(li1, li2)}') + else: + print(f'Missing {l1-l2} columns!\t{diff_2lists(li1, li2)}') + + +def diff_2lists(li1, li2): + return list(list(set(li2)-set(li1)) + list(set(li1)-set(li2))) + +def exp_df(csv_path, exp_name, sep=' '): + """ Return DataFrame with exp_name suffixed to columns and 'Patient' as index + """ + df = pd.read_csv(csv_path) + # df = df[['Patient', 'd']] + df.set_index('Patient', inplace=True) + return df.add_prefix(exp_name + sep) + +def filter_metric(df, metric): + """ Return DataFrame with only the columns with appropriate metric + """ + col_metric = [col for col in df.columns if col.endswith(metric)] + df_filter = df[col_metric] + return remove_metric_colnames(df_filter, metric) + +def remove_metric_colnames(df, metric): + new_cols = {} + for col in df.columns: + if col.endswith(metric): + new_cols[col] = col.split(metric)[0][:-1] + return df.rename(columns=new_cols) \ No newline at end of file diff --git a/unet/analysis/utils_plot.py b/unet/analysis/utils_plot.py new file mode 100644 index 0000000..d95bd96 --- /dev/null +++ b/unet/analysis/utils_plot.py @@ -0,0 +1,146 @@ +import os +import numpy as np +from pathlib import Path +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.offsetbox import OffsetImage,AnnotationBbox + +try: + iroot = Path('Icons') + figures_dir = Path('Figures') / os.environ["USER"] + figures_dir.mkdir(exist_ok=True) +except: + pass + +def multi_plot(plot_type, dfs, metric=None, save=None, right_axis=True, icons=False, rotate_labels=False, colors='#4472C4', **plt_kwargs): + """ Plots several graphs on the same y axis + Example + ------- + multi_plot(sns.boxplot, dfs=[df1,df2], save='fig_name', colors=['#4472C4','#70AD47'], **{'width': 0.6}); + """ + n = len(dfs) + if colors is None: + colors = ['#4472C4'] * n + elif type(colors) is str: + colors = [colors] * n + elif len(colors) != n: + print('Wrong colors argument: does not match len(dfs).\nUsing default.') + colors = ['#4472C4'] * n + cols_length = [len(df.columns) for df in dfs] + total_col_length = np.sum(cols_length) + + f = plt.figure(constrained_layout=True, figsize=(total_col_length+1,5)) + gs = f.add_gridspec(1,total_col_length, wspace=0) + + fs = [] + start = 0 + for inum, df in enumerate(dfs): + end = start + cols_length[inum] + fi = f.add_subplot(gs[0,start:end]) + start = end + plot_type(data=df, color=colors[inum], ax=fi, **plt_kwargs) # green: #70AD47 + # fi.yaxis.set_visible(False) + if inum >0: fi.set_yticks([]) + fs.append(fi) + + # Set left axis label + if metric and metric != 'd': + fs[0].set_ylabel(metric, fontsize=12) + else: + fs[0].set_ylabel('Dice score', fontsize=12) + + # Set right axis label + if right_axis: + fs[-1].yaxis.tick_right() + fs[-1].yaxis.set_label_position("right") + + for ax in f.axes: + plt.sca(ax) + if rotate_labels: plt.xticks(rotation=40, horizontalalignment='right'); + plt.ylim(0,1) + plt.tick_params(axis='y', which='both', right=False) + + if icons: + for inum, df in enumerate(dfs): + show_icons(df, fs[inum]) + + sns.despine(left=True, right=True) + f.tight_layout() + + if save: + savefig(f, save) + + return [f, fs] + + +def get_icon(name): + + if '.' in name: # takes care of the 'exp.metric' (e.g. 102P.FDR) #TODO changed that to space... maybe better way than just one character? 
+ name = name.split('.')[0] + if 'block' in name or name == 'all': + icon_path = iroot / 'ft.png' + else: + icon_path = iroot / (name+'.png') + try: + im = plt.imread(icon_path) + except: + return None + return im + +def offset_image(coord, name, ax): + img = get_icon(name) + if img is None: + return + im = OffsetImage(img, zoom=0.04) + im.image.axes = ax + ab = AnnotationBbox(im, (coord, 0), xybox=(0., -35.), frameon=False, xycoords='data', boxcoords="offset points", pad=0) + ax.add_artist(ab) + +def show_icons(df, ax): + for i, c in enumerate(df.columns): + offset_image(i, c, ax) + +def savefig(fig, name, transparent=False): + if not isinstance(fig, plt.Figure): + fig.get_figure().savefig(figures_dir/(name+'.png'), dpi=300, transparent=True, bbox_inches="tight") + else: + fig.savefig(figures_dir/(name+'.png'), dpi=300, transparent=transparent, bbox_inches="tight") + + +class IndexTracker(object): + def __init__(self, ax, X): + self.ax = ax + ax.set_title('use scroll wheel to navigate images') + + self.X = X + rows, cols, self.slices = X.shape + self.ind = self.slices//2 + + self.im = ax.imshow(self.X[:, :, self.ind]) + self.update() + + def onscroll(self, event): + print("%s %s" % (event.button, event.step)) + if event.button == 'up': + self.ind = (self.ind + 1) % self.slices + else: + self.ind = (self.ind - 1) % self.slices + self.update() + + def update(self): + self.im.set_data(self.X[:, :, self.ind]) + self.ax.set_ylabel('slice %s' % self.ind) + self.im.axes.figure.canvas.draw() + + +def slices_viewer(X): + """ Scroll through 2D image slices of a 3D array. + ex: slices_viewer(np.random.rand(20, 20, 40)) + NB: only works in notebook environment. Run `%matplotlib widget` first. + """ + fig, ax = plt.subplots(1, 1) + tracker = IndexTracker(ax, X) + fig.canvas.mpl_connect('scroll_event', tracker.onscroll) + # plt.show() + return fig \ No newline at end of file diff --git a/unet/analysis/utils_stats.py b/unet/analysis/utils_stats.py new file mode 100644 index 0000000..4b4c286 --- /dev/null +++ b/unet/analysis/utils_stats.py @@ -0,0 +1,41 @@ +from scipy.stats import wilcoxon, ranksums, mannwhitneyu, ttest_ind +import pandas as pd +import numpy as np + +def stats(func, base, compare_to, **stats_kwargs): + """ + Args: + func: stat function to use + base: tuple (dataframe, column_names) + compare_to: dataframe + **stats_kwargs: additional arguments for stat function (see examples below) + Examples: + stats(mannwhitneyu, (penn1, 'col1'), ucsf.columns, **{'alternative':'greater'}) + stats(wilcoxon, (ucsf, 'col2'), ucsf.columns) + stats(wilcoxon, (ucsf, ['col2', 'col3']), ucsf.columns) # for several analyses + """ + stats_name = type(func([1],[2])).__name__[:-6] + sig_label = 'Sig (<5%: ***, <10%: *)' + list_results = [] + base_df = base[0] + base_col_names = base[1] + if type(base_col_names) is str: base_col_names=[base_col_names] + for base_col_name in base_col_names: + results = {} + base_col_name = base_col_name + base_vals = base_df[base_col_name] + for col in compare_to.columns: + if col == base_col_name and base_df is compare_to: + results[col] = {'Median': np.median(base_vals), f'{stats_name} p-value': 0, sig_label: 'I'} + continue + res = func(base_vals, compare_to[col], **stats_kwargs)[1] + if res <= 0.05: + sig = '***' + elif res <= 0.1: + sig = '*' + else: + sig = '' + results[col] = {'Median': np.median(compare_to[col]), f'{stats_name} p-value': res, sig_label: sig} + list_results.append(pd.DataFrame(results)) + + return pd.concat(list_results) \ No newline at end of file diff --git 
a/unet/image/LinearInterpolation.py b/unet/image/LinearInterpolation.py
new file mode 100755
index 0000000..710abf0
--- /dev/null
+++ b/unet/image/LinearInterpolation.py
@@ -0,0 +1,366 @@
+#%% This is differentiable linear interpolation code implemented in Tensorflow.
+# It is based on the implementation by Kevin Zakka at Google, the author of a
+# spatial transformer network implementation on GitHub:
+# https://github.com/kevinzakka/spatial-transformer-network
+
+#%%
+import tensorflow as tf
+
+#%%
+
+def get_pixel_value_2D(img, x, y):
+    """
+    Utility function to get pixel value for coordinate
+    vectors x and y from a 4D tensor image.
+
+    Input
+    -----
+    - img: tensor of shape (B, H, W, C)
+    - x: tensor of shape (B, H, W)
+    - y: tensor of shape (B, H, W)
+
+    Returns
+    -------
+    - output: tensor of shape (B, H, W, C)
+    """
+    shape = tf.shape(x)
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+
+    batch_idx = tf.range(0, batch_size)
+    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
+    b = tf.tile(batch_idx, (1, height, width))
+
+    indices = tf.stack([b, y, x], axis = 3)
+
+    return tf.gather_nd(img, indices)
+
+def get_pixel_value_3D(img, x, y, z):
+    """
+    Utility function to get pixel value for coordinate
+    vectors x, y and z from a 5D tensor image.
+
+    Input
+    -----
+    - img: tensor of shape (B, H, W, D, C)
+    - x: tensor of shape (B, H, W, D)
+    - y: tensor of shape (B, H, W, D)
+    - z: tensor of shape (B, H, W, D)
+
+    Returns
+    -------
+    - output: tensor of shape (B, H, W, D, C)
+    """
+    shape = tf.shape(x)
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    depth = shape[3]
+
+    batch_idx = tf.range(0, batch_size)
+    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1, 1))
+    b = tf.tile(batch_idx, (1, height, width, depth))
+
+    indices = tf.stack([b, y, x, z], axis = 4)
+
+    return tf.gather_nd(img, indices)
+
+
+def bilinear_sampler(img, x, y, normalized_coordinate = False):
+    """
+    Performs bilinear sampling of the input images. Note that the
+    sampling is done identically for each channel of the input.
+
+    To test if the function works properly, the output image should be
+    identical to the input image when the sampling grid is the identity
+    transform (a sketch of such a check follows this function).
+
+    Input
+    -----
+    - img: batch of images in (B, H, W, C) layout.
+    - grid: x, y which is the output of affine_grid_generator.
+
+    Returns
+    -------
+    - interpolated images according to grids. Same size as grid.
+
+    """
+    # prepare useful params
+    B = tf.shape(img)[0]
+    H = tf.shape(img)[1]
+    W = tf.shape(img)[2]
+    C = tf.shape(img)[3]
+
+    max_y = tf.cast(H - 1, 'int32')
+    max_x = tf.cast(W - 1, 'int32')
+    zero = tf.zeros([], dtype='int32')
+
+    # cast indices as float32 (for rescaling)
+    x = tf.cast(x, 'float32')
+    y = tf.cast(y, 'float32')
+
+    # rescale x and y to [0, W/H]
+    if normalized_coordinate:
+        x = 0.5 * ((x + 1.0) * tf.cast(W, 'float32'))
+        y = 0.5 * ((y + 1.0) * tf.cast(H, 'float32'))
+
+    # grab 4 nearest corner points for each (x_i, y_i)
+    # i.e. we need a rectangle around the point of interest
+    x0 = tf.cast(tf.floor(x), 'int32')
+    x1 = x0 + 1
+    y0 = tf.cast(tf.floor(y), 'int32')
+    y1 = y0 + 1
+
+    # clip to range [0, H/W] to not violate img boundaries
+    x0 = tf.clip_by_value(x0, zero, max_x)
+    x1 = tf.clip_by_value(x1, zero, max_x)
+    y0 = tf.clip_by_value(y0, zero, max_y)
+    y1 = tf.clip_by_value(y1, zero, max_y)
+
+    # get pixel value at corner coords
+    Ia = get_pixel_value_2D(img, x0, y0)
+    Ib = get_pixel_value_2D(img, x0, y1)
+    Ic = get_pixel_value_2D(img, x1, y0)
+    Id = get_pixel_value_2D(img, x1, y1)
+
+    # recast as float for delta calculation
+    x0 = tf.cast(x0, 'float32')
+    x1 = tf.cast(x1, 'float32')
+    y0 = tf.cast(y0, 'float32')
+    y1 = tf.cast(y1, 'float32')
+
+    # calculate deltas
+    wa = (x1-x) * (y1-y)
+    wb = (x1-x) * (y-y0)
+    wc = (x-x0) * (y1-y)
+    wd = (x-x0) * (y-y0)
+
+    # add dimension for addition
+    wa = tf.expand_dims(wa, axis=3)
+    wb = tf.expand_dims(wb, axis=3)
+    wc = tf.expand_dims(wc, axis=3)
+    wd = tf.expand_dims(wd, axis=3)
+
+    # compute output
+    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+
+    return out
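+
+# Minimal sanity-check sketch for the sampler above (illustrative only, not used
+# by the training code; assumes TF2 eager execution): on the identity grid,
+# interior pixels are reproduced exactly, while the last row/column is zeroed
+# because the clipped corner weights vanish there.
+def _bilinear_identity_check():
+    img = tf.reshape(tf.range(16, dtype=tf.float32), (1, 4, 4, 1))
+    x, y = tf.meshgrid(tf.range(4, dtype=tf.float32),
+                       tf.range(4, dtype=tf.float32))
+    out = bilinear_sampler(img, tf.expand_dims(x, 0), tf.expand_dims(y, 0))
+    # maximum interior error, expected to be 0.0
+    return tf.reduce_max(tf.abs(out[:, :3, :3] - img[:, :3, :3]))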
+
+def trilinear_sampler(img, x, y, z, normalized_coordinate = False):
+    """
+    Performs trilinear sampling of the input images. Note that the
+    sampling is done identically for each channel of the input.
+
+    To test if the function works properly, the output image should be
+    identical to the input image when the sampling grid is the identity
+    transform.
+
+    Input
+    -----
+    - img: batch of images in (B, H, W, D, C) layout.
+    - grid: x, y, z the sampling grid on the image
+
+    Returns
+    -------
+    - interpolated images according to grids. Same size as grid.
+
+    """
+    # prepare useful params
+    B = tf.shape(img)[0]
+    H = tf.shape(img)[1]
+    W = tf.shape(img)[2]
+    D = tf.shape(img)[3]
+    C = tf.shape(img)[4]
+
+    max_y = tf.cast(H - 1, tf.int32)
+    max_x = tf.cast(W - 1, tf.int32)
+    max_z = tf.cast(D - 1, tf.int32)
+    zero = tf.zeros([], dtype=tf.int32)
+
+    # rescale x, y and z to [0, W/H/D]
+    if normalized_coordinate:
+        # cast indices as float32 (for rescaling)
+        x = tf.cast(x, tf.float32)
+        y = tf.cast(y, tf.float32)
+        z = tf.cast(z, tf.float32)
+        x = 0.5 * ((x + 1.0) * tf.cast(W, tf.float32))
+        y = 0.5 * ((y + 1.0) * tf.cast(H, tf.float32))
+        z = 0.5 * ((z + 1.0) * tf.cast(D, tf.float32))
+
+    # grab the 8 nearest corner points for each (x_i, y_i, z_i)
+    # i.e. we need a cuboid around the point of interest
+    x0 = tf.cast(tf.floor(x), tf.int32)
+    x1 = x0 + 1
+    y0 = tf.cast(tf.floor(y), tf.int32)
+    y1 = y0 + 1
+    z0 = tf.cast(tf.floor(z), tf.int32)
+    z1 = z0 + 1
+
+    # clip to range [0, H/W/D] to not violate img boundaries
+    x0 = tf.clip_by_value(x0, zero, max_x)
+    x1 = tf.clip_by_value(x1, zero, max_x)
+    y0 = tf.clip_by_value(y0, zero, max_y)
+    y1 = tf.clip_by_value(y1, zero, max_y)
+    z0 = tf.clip_by_value(z0, zero, max_z)
+    z1 = tf.clip_by_value(z1, zero, max_z)
+
+    # get pixel value at corner coords
+    Ia_0 = get_pixel_value_3D(img, x0, y0, z0)
+    Ia_1 = get_pixel_value_3D(img, x0, y0, z1)
+    Ib_0 = get_pixel_value_3D(img, x0, y1, z0)
+    Ib_1 = get_pixel_value_3D(img, x0, y1, z1)
+    Ic_0 = get_pixel_value_3D(img, x1, y0, z0)
+    Ic_1 = get_pixel_value_3D(img, x1, y0, z1)
+    Id_0 = get_pixel_value_3D(img, x1, y1, z0)
+    Id_1 = get_pixel_value_3D(img, x1, y1, z1)
+
+    # recast as float for delta calculation
+    x0 = tf.cast(x0, tf.float32)
+    x1 = tf.cast(x1, tf.float32)
+    y0 = tf.cast(y0, tf.float32)
+    y1 = tf.cast(y1, tf.float32)
+    z0 = tf.cast(z0, tf.float32)
+    z1 = tf.cast(z1, tf.float32)
+
+    # calculate deltas
+    wa_0 = (x1-x) * (y1-y) * (z1-z)
+    wa_1 = (x1-x) * (y1-y) * (z-z0)
+    wb_0 = (x1-x) * (y-y0) * (z1-z)
+    wb_1 = (x1-x) * (y-y0) * (z-z0)
+    wc_0 = (x-x0) * (y1-y) * (z1-z)
+    wc_1 = (x-x0) * (y1-y) * (z-z0)
+    wd_0 = (x-x0) * (y-y0) * (z1-z)
+    wd_1 = (x-x0) * (y-y0) * (z-z0)
+
+    # add dimension for addition
+    wa_0 = tf.expand_dims(wa_0, axis=4)
+    wa_1 = tf.expand_dims(wa_1, axis=4)
+    wb_0 = tf.expand_dims(wb_0, axis=4)
+    wb_1 = tf.expand_dims(wb_1, axis=4)
+    wc_0 = tf.expand_dims(wc_0, axis=4)
+    wc_1 = tf.expand_dims(wc_1, axis=4)
+    wd_0 = tf.expand_dims(wd_0, axis=4)
+    wd_1 = tf.expand_dims(wd_1, axis=4)
+
+    # compute output
+    out = tf.add_n([wa_0*Ia_0, wb_0*Ib_0, wc_0*Ic_0, wd_0*Id_0, \
+                    wa_1*Ia_1, wb_1*Ib_1, wc_1*Ic_1, wd_1*Id_1])
+
+    return out
+
+# Nearest neighbor sampler
+def binnsampler(img, x, y, normalized_coordinate = False):
+    """
+    Performs 2D nearest neighbour sampling of the input images. Note
+    that the sampling is done identically for each channel of the
+    input.
+
+    To test if the function works properly, the output image should be
+    identical to the input image when the sampling grid is the identity
+    transform.
+
+    Input
+    -----
+    - img: batch of images in (B, H, W, C) layout.
+    - grid: x, y which is the output of affine_grid_generator.
+
+    Returns
+    -------
+    - interpolated images according to grids. Same size as grid.
+
+    """
+    # prepare useful params
+    B = tf.shape(img)[0]
+    H = tf.shape(img)[1]
+    W = tf.shape(img)[2]
+    C = tf.shape(img)[3]
+
+    max_y = tf.cast(H - 1, 'int32')
+    max_x = tf.cast(W - 1, 'int32')
+    zero = tf.zeros([], dtype='int32')
+
+    # cast indices as float32 (for rescaling)
+    x = tf.cast(x, 'float32')
+    y = tf.cast(y, 'float32')
+
+    # rescale x and y to [0, W/H]
+    if normalized_coordinate:
+        x = 0.5 * ((x + 1.0) * tf.cast(W, 'float32'))
+        y = 0.5 * ((y + 1.0) * tf.cast(H, 'float32'))
+
+    # grab 4 nearest corner points for each (x_i, y_i)
+    # i.e.
we need a rectangle around the point of interest + x0 = tf.cast(x, 'int32') + y0 = tf.cast(y, 'int32') + + # clip to range [0, H/W] to not violate img boundaries + x0 = tf.clip_by_value(x0, zero, max_x) + y0 = tf.clip_by_value(y0, zero, max_y) + + # get pixel value at corner coords + out = get_pixel_value_2D(img, x0, y0) + + return out + +# Nearest neighbor sampler +def trinnsampler(img, x, y, z, normalized_coordinate = False): + """ + Performs 3D nearest neigbour sampling of the input images. Note + that the sampling is done identically for each channel of the + input. + + To test if the function works properly, output image should be + identical to input image when theta is initialized to identity + transform. + + Input + ----- + - img: batch of images in (B, H, W, C) layout. + - grid: x, y which is the output of affine_grid_generator. + + Returns + ------- + - interpolated images according to grids. Same size as grid. + + """ + # prepare useful params + B = tf.shape(img)[0] + H = tf.shape(img)[1] + W = tf.shape(img)[2] + D = tf.shape(img)[3] + C = tf.shape(img)[4] + + max_y = tf.cast(H - 1, tf.int32) + max_x = tf.cast(W - 1, tf.int32) + max_z = tf.cast(D - 1, tf.int32) + zero = tf.zeros([], dtype=tf.int32) + + # rescale x and y to [0, W/H] + if normalized_coordinate: + # cast indices as float32 (for rescaling) + x = tf.cast(x, tf.float32) + y = tf.cast(y, tf.float32) + z = tf.cast(z, tf.float32) + x = 0.5 * ((x + 1.0) * tf.cast(W, tf.float32)) + y = 0.5 * ((y + 1.0) * tf.cast(H, tf.float32)) + z = 0.5 * ((z + 1.0) * tf.cast(D, tf.float32)) + + # grab 4 nearest corner points for each (x_i, y_i. z_i) + # i.e. we need a rectangle around the point of interest + x0 = tf.cast(x, tf.int32) + y0 = tf.cast(y, tf.int32) + z0 = tf.cast(z, tf.int32) + + # clip to range [0, H/W] to not violate img boundaries + x0 = tf.clip_by_value(x0, zero, max_x) + y0 = tf.clip_by_value(y0, zero, max_y) + z0 = tf.clip_by_value(z0, zero, max_z) + + # get pixel value at corner coords + out = get_pixel_value_3D(img, x0, y0, z0) + + return out + + diff --git a/unet/image/__init__.py b/unet/image/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/unet/image/elastic_transform_tf.py b/unet/image/elastic_transform_tf.py new file mode 100755 index 0000000..53aa012 --- /dev/null +++ b/unet/image/elastic_transform_tf.py @@ -0,0 +1,200 @@ +# This is a tensorflow implementation of elastic transformation for data augmentation +import numpy as np +import tensorflow as tf +from .LinearInterpolation import trilinear_sampler, trinnsampler + +#%% + +class elastic_param(): + def __init__(self): + # Affine parameters + # Rotation, specify the maximum random rotation in radians + self.rotation_x = 0.0 + self.rotation_y = 0.0 + self.rotation_z = 0.0 + + # Translation, specify the maximum random translation in normalized distance (0-1 to 0-X/Y/Z) + self.trans_x = 0 + self.trans_y = 0 + self.trans_z = 0 + + # scaling, specify the maximum random scaling in normalized size (1 + scale) + self.scale_x = 1.0 + self.scale_y = 1.0 + self.scale_z = 1.0 + + # shearing, unimeplemented due to ins medical image the organ is usually kept unsheared + + # deformation parameters + # Voxel shifting, normalized by scale + self.df_x = 0.0 + self.df_y = 0.0 + self.df_z = 0.0 + +# input: +# img: image, tensor of size (H, W, D, C) +# seg: segmentation, tensor of size (H, W, D) +# affine_param: the affine transformation parameter +# elastic_param: the elastic transformation parameter +# output: +# out: the augmented image + +def 
elastic_transform_3D(img, seg, elastic_param): + # Compose Rotation, scaling, translation + H = tf.shape(img)[0] + W = tf.shape(img)[1] + D = tf.shape(img)[2] + C = tf.shape(img)[3] + + Hf = tf.cast(H, dtype = tf.float32) + Wf = tf.cast(W, dtype = tf.float32) + Df = tf.cast(D, dtype = tf.float32) + Cf = tf.cast(C, dtype = tf.float32) + + # Generate homogenous grid + y = tf.linspace(0.0, Hf-1, H) + x = tf.linspace(0.0, Wf-1, W) + z = tf.linspace(0.0, Df-1, D) + x_t, y_t, z_t = tf.meshgrid(x, y, z) + + x_t_flat = tf.reshape(x_t, [-1]) + y_t_flat = tf.reshape(y_t, [-1]) + z_t_flat = tf.reshape(z_t, [-1]) + + # reshape to [x_t, y_t , 1] - (homogeneous form) + ones = tf.ones_like(x_t_flat) + sampling_grid = tf.stack([y_t_flat, x_t_flat, z_t_flat, ones]) + sampling_grid = tf.cast(sampling_grid, tf.float32) # cast to float32 (required for matmul) + + # Random rotation, random scaling, random translation + rx = tf.random_uniform(shape = [], minval=-elastic_param.rotation_x, maxval= elastic_param.rotation_x, dtype=tf.float32) + ry = tf.random_uniform(shape = [], minval=-elastic_param.rotation_y, maxval= elastic_param.rotation_y, dtype=tf.float32) + rz = tf.random_uniform(shape = [], minval=-elastic_param.rotation_z, maxval= elastic_param.rotation_z, dtype=tf.float32) + + tx = Wf * tf.random_uniform(shape = [], minval=-elastic_param.trans_x, maxval= elastic_param.trans_x, dtype=tf.float32) + ty = Hf * tf.random_uniform(shape = [], minval=-elastic_param.trans_y, maxval= elastic_param.trans_y, dtype=tf.float32) + tz = Df * tf.random_uniform(shape = [], minval=-elastic_param.trans_z, maxval= elastic_param.trans_z, dtype=tf.float32) + + sx = tf.random_uniform(shape = [], minval=1-elastic_param.scale_x, maxval= 1+elastic_param.scale_x, dtype=tf.float32) + sy = tf.random_uniform(shape = [], minval=1-elastic_param.scale_y, maxval= 1+elastic_param.scale_y, dtype=tf.float32) + sz = tf.random_uniform(shape = [], minval=1-elastic_param.scale_z, maxval= 1+elastic_param.scale_z, dtype=tf.float32) + + # Form the affine matrix + Ry = tf.stack([ (1.0, 0.0, 0.0, 0.0), (0.0, tf.cos(ry), -tf.sin(ry), 0.0), (0.0, tf.sin(ry), tf.cos(ry), 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + Rx = tf.stack([ (tf.cos(rx), 0.0, tf.sin(rx), 0.0), (0.0, 1, 0.0, 0.0), (-tf.sin(rx), 0.0, tf.cos(rx), 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + Rz = tf.stack([ (tf.cos(rz), -tf.sin(rz), 0.0, 0.0), (tf.sin(rz), tf.cos(rz), 0.0, 0.0), (0.0, 0.0, 1.0, 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + R = tf.matmul(Rz, tf.matmul(Rx, Ry)) + S = tf.stack( [(sy, 0.0, 0.0 , 0.0), (0.0, sx, 0.0, 0.0), (0.0, 0.0, sz, 0.0), (0.0, 0.0, 0.0, 1.0)], axis = 0) + T = tf.stack( [(1.0, 0.0, 0.0, ty), (0.0, 1.0, 0.0, tx), (0.0, 0.0, 1.0, tz), (0.0, 0.0, 0.0, 1.0)], axis = 0) + + # Affine transform + A = tf.matmul(R, S) + A = tf.matmul(A, T) + yt_A, xt_A, zt_A, _ = tf.unstack(tf.matmul(A, sampling_grid)) + + # Elastic transform + xt_A = xt_A + tf.random_uniform( [H * W * D] , minval = -elastic_param.df_x, maxval = elastic_param.df_x) + yt_A = yt_A + tf.random_uniform( [H * W * D] , minval = -elastic_param.df_y, maxval = elastic_param.df_y) + zt_A = zt_A + tf.random_uniform( [H * W * D] , minval = -elastic_param.df_z, maxval = elastic_param.df_z) + + xt_A = tf.reshape(xt_A, [H, W, D]) + yt_A = tf.reshape(yt_A, [H, W, D]) + zt_A = tf.reshape(zt_A, [H, W, D]) + + xt_A = tf.expand_dims(xt_A, axis = 0) + yt_A = tf.expand_dims(yt_A, axis = 0) + zt_A = tf.expand_dims(zt_A, axis = 0) + + # Interpolation + img_ex = tf.expand_dims(img, axis = 0) + img_A = trilinear_sampler(img_ex, xt_A, 
yt_A, zt_A, normalized_coordinate = False) + img_A = tf.squeeze(img_A) + + seg_ex = tf.expand_dims(seg, axis = 0) + seg_ex = tf.expand_dims(seg_ex, axis = -1) + seg_A = trinnsampler(seg_ex, xt_A, yt_A, zt_A, normalized_coordinate = False) + seg_A = tf.squeeze(seg_A) + + return img_A, seg_A, y_t, x_t, z_t + +def elastic_transform_3D_tf2(img, seg, elastic_param): + # Compose Rotation, scaling, translation + H = tf.shape(input=img)[0] + W = tf.shape(input=img)[1] + D = tf.shape(input=img)[2] + C = tf.shape(input=img)[3] + + Hf = tf.cast(H, dtype = tf.float32) + Wf = tf.cast(W, dtype = tf.float32) + Df = tf.cast(D, dtype = tf.float32) + Cf = tf.cast(C, dtype = tf.float32) + + # Generate homogenous grid + y = tf.linspace(0.0, Hf-1, H) + x = tf.linspace(0.0, Wf-1, W) + z = tf.linspace(0.0, Df-1, D) + x_t, y_t, z_t = tf.meshgrid(x, y, z) + + x_t_flat = tf.reshape(x_t, [-1]) + y_t_flat = tf.reshape(y_t, [-1]) + z_t_flat = tf.reshape(z_t, [-1]) + + # reshape to [x_t, y_t , 1] - (homogeneous form) + ones = tf.ones_like(x_t_flat) + sampling_grid = tf.stack([y_t_flat, x_t_flat, z_t_flat, ones]) + sampling_grid = tf.cast(sampling_grid, tf.float32) # cast to float32 (required for matmul) + + # Random rotation, random scaling, random translation + rx = tf.random.uniform(shape = [], minval=-elastic_param.rotation_x, maxval= elastic_param.rotation_x, dtype=tf.float32) + ry = tf.random.uniform(shape = [], minval=-elastic_param.rotation_y, maxval= elastic_param.rotation_y, dtype=tf.float32) + rz = tf.random.uniform(shape = [], minval=-elastic_param.rotation_z, maxval= elastic_param.rotation_z, dtype=tf.float32) + + tx = Wf * tf.random.uniform(shape = [], minval=-elastic_param.trans_x, maxval= elastic_param.trans_x, dtype=tf.float32) + ty = Hf * tf.random.uniform(shape = [], minval=-elastic_param.trans_y, maxval= elastic_param.trans_y, dtype=tf.float32) + tz = Df * tf.random.uniform(shape = [], minval=-elastic_param.trans_z, maxval= elastic_param.trans_z, dtype=tf.float32) + + sx = tf.random.uniform(shape = [], minval=1-elastic_param.scale_x, maxval= 1+elastic_param.scale_x, dtype=tf.float32) + sy = tf.random.uniform(shape = [], minval=1-elastic_param.scale_y, maxval= 1+elastic_param.scale_y, dtype=tf.float32) + sz = tf.random.uniform(shape = [], minval=1-elastic_param.scale_z, maxval= 1+elastic_param.scale_z, dtype=tf.float32) + + # Form the affine matrix + Ry = tf.stack([ (1.0, 0.0, 0.0, 0.0), (0.0, tf.cos(ry), -tf.sin(ry), 0.0), (0.0, tf.sin(ry), tf.cos(ry), 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + Rx = tf.stack([ (tf.cos(rx), 0.0, tf.sin(rx), 0.0), (0.0, 1, 0.0, 0.0), (-tf.sin(rx), 0.0, tf.cos(rx), 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + Rz = tf.stack([ (tf.cos(rz), -tf.sin(rz), 0.0, 0.0), (tf.sin(rz), tf.cos(rz), 0.0, 0.0), (0.0, 0.0, 1.0, 0.0), (0.0, 0.0, 0.0, 1.0)], axis=0) + R = tf.matmul(Rz, tf.matmul(Rx, Ry)) + S = tf.stack( [(sy, 0.0, 0.0 , 0.0), (0.0, sx, 0.0, 0.0), (0.0, 0.0, sz, 0.0), (0.0, 0.0, 0.0, 1.0)], axis = 0) + T = tf.stack( [(1.0, 0.0, 0.0, ty), (0.0, 1.0, 0.0, tx), (0.0, 0.0, 1.0, tz), (0.0, 0.0, 0.0, 1.0)], axis = 0) + + # Affine transform + A = tf.matmul(R, S) + A = tf.matmul(A, T) + yt_A, xt_A, zt_A, _ = tf.unstack(tf.matmul(A, sampling_grid)) + + # Elastic transform + xt_A = xt_A + tf.random.uniform( [H * W * D] , minval = -elastic_param.df_x, maxval = elastic_param.df_x) + yt_A = yt_A + tf.random.uniform( [H * W * D] , minval = -elastic_param.df_y, maxval = elastic_param.df_y) + zt_A = zt_A + tf.random.uniform( [H * W * D] , minval = -elastic_param.df_z, maxval = elastic_param.df_z) 
+
+    xt_A = tf.reshape(xt_A, [H, W, D])
+    yt_A = tf.reshape(yt_A, [H, W, D])
+    zt_A = tf.reshape(zt_A, [H, W, D])
+
+    xt_A = tf.expand_dims(xt_A, axis = 0)
+    yt_A = tf.expand_dims(yt_A, axis = 0)
+    zt_A = tf.expand_dims(zt_A, axis = 0)
+
+    # Interpolation
+    img_ex = tf.expand_dims(img, axis = 0)
+    img_A = trilinear_sampler(img_ex, xt_A, yt_A, zt_A, normalized_coordinate = False)
+    img_A = tf.squeeze(img_A)
+
+    seg_ex = tf.expand_dims(seg, axis = 0)
+    seg_ex = tf.expand_dims(seg_ex, axis = -1)
+    seg_A = trinnsampler(seg_ex, xt_A, yt_A, zt_A, normalized_coordinate = False)
+    seg_A = tf.squeeze(seg_A)
+
+    return img_A, seg_A
+
+def elastic_transform_2D(img, affine_param, elastic_param):
+    pass
+# %%
diff --git a/unet/image/fix_orientation.py b/unet/image/fix_orientation.py
new file mode 100755
index 0000000..0d12658
--- /dev/null
+++ b/unet/image/fix_orientation.py
@@ -0,0 +1,69 @@
+#%% This script fixes the orientation of NIfTI images
+import SimpleITK as sitk
+import numpy as np
+
+# This helper function gets the direction, spacing and numpy array from a SimpleITK image.
+# input:
+#   im: a SimpleITK image
+# output:
+#   im_np, spacing, direction: the image, spacing and direction, all converted to numpy arrays. Note that
+#   the axes of the image array are in reversed (Z, Y, X) order.
+def getImSpDirct(im):
+    direction = np.array(im.GetDirection()).reshape((3, 3))
+    spacing = np.array(im.GetSpacing())
+    im_np = sitk.GetArrayFromImage(im)
+
+    return im_np, spacing, direction
+
+
+# This function fixes the image orientation and swaps the spacing according to the direction matrix.
+# input:
+#   im_np, spacing, direction: the image, spacing and direction as numpy arrays, with the image axes
+#   in reversed (Z, Y, X) order.
+# output:
+#   im_np, spacing: the image and spacing fixed by the direction. These two can be used directly to
+#   create a SimpleITK image with identity direction that will look the same as the original. Note that
+#   the axes of im_np are again in reversed (Z, Y, X) order.
+def fixOrientation(im_np, direction, spacing):
+    im_np = np.swapaxes(im_np, axis1 = 0, axis2 = 2) # from (Z, Y, X) to (X, Y, Z)
+    ax_phy = np.argmax(np.abs(direction), axis = 1)
+    sign_phy = np.sign(np.array( [ direction[0, ax_phy[0]], direction[1, ax_phy[1]], direction[2, ax_phy[2]] ]))
+
+    # Transpose axis
+    im_np = np.transpose(im_np, axes = ax_phy)
+    sign_phy = sign_phy[ax_phy]
+    spacing = spacing[ax_phy]
+
+    # Flip axis
+    for ax, sn in enumerate(list(sign_phy)):
+        if sn == -1:
+            im_np = np.flip(im_np, axis = ax)
+
+    im_np = np.swapaxes(im_np, axis1 = 0, axis2 = 2)
+
+    return im_np, spacing
+
+# This function does the opposite of fixOrientation: given an image array without its header and the
+# direction matrix from an image that has one, it reverses the reorientation of the headerless image.
+def reverseOrientation(im_np, direction):
+    im_np = np.swapaxes(im_np, axis1 = 0, axis2 = 2) # from (Z, Y, X) to (X, Y, Z)
+    ax_phy = np.argmax(np.abs(direction), axis = 1)
+    sign_phy = np.sign(np.array( [ direction[0, ax_phy[0]], direction[1, ax_phy[1]], direction[2, ax_phy[2]] ]))
+    sign_phy = sign_phy[ax_phy]
+
+    # Reverse the conversion
+    ax_im = np.array([np.argwhere(ax_phy==0)[0, 0], np.argwhere(ax_phy==1)[0, 0], np.argwhere(ax_phy==2)[0, 0]])
+    sign_im = sign_phy[ax_im]
+
+    # Transpose axis
+    im_np = np.transpose(im_np, axes = ax_im)
+
+    # Flip axis
+    for ax, sn in enumerate(list(sign_im)):
+        if sn == -1:
+            im_np = np.flip(im_np, axis = ax)
+
+    im_np = np.swapaxes(im_np, axis1 = 0, axis2 = 2)
+
+    return im_np
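+
+# Minimal round-trip sketch (illustrative only, not used elsewhere in the
+# repository; the file path is hypothetical): reorient an image to identity
+# direction, then undo it.
+if __name__ == "__main__":
+    im = sitk.ReadImage("example.nii.gz")  # hypothetical input file
+    im_np, spacing, direction = getImSpDirct(im)
+    fixed_np, fixed_spacing = fixOrientation(im_np, direction, spacing)
+    restored_np = reverseOrientation(fixed_np, direction)
+    # the round trip should restore the original array
+    print(np.array_equal(restored_np, im_np))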
diff --git a/unet/image/get_main_component.sh b/unet/image/get_main_component.sh
new file mode 100755
index 0000000..9030f14
--- /dev/null
+++ b/unet/image/get_main_component.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+rootdir=$1
+newdir=$2
+
+for f in $rootdir/*; do
+    echo $f
+    filename=$(basename "$f")
+    c3d \
+        $f -popas S \
+        -push S -thresh 1 inf 1 0 -comp -popas C \
+        -push C -thresh 1 1 1 0 \
+        -push S -multiply \
+        -o $newdir/$filename
+done
diff --git a/unet/image/image_plot.py b/unet/image/image_plot.py
new file mode 100644
index 0000000..ce04bd9
--- /dev/null
+++ b/unet/image/image_plot.py
@@ -0,0 +1,127 @@
+from tensorflow.keras.callbacks import Callback
+import io
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import random
+
+class ImageHistory(Callback):
+    def __init__(self, tensorboard_dir, data, num_images_to_show=10):
+        super(ImageHistory, self).__init__()
+        self.tensorboard_dir = str(tensorboard_dir)
+        self.batches_for_plot = self.get_random_batches(data, num_images_to_show)
+        # get indices for a few batches, and keep them so we can see the progression on the same images
+        ind_bright, ind_dimm, ind_random = [], [], []
+        for _, batch_seg in self.batches_for_plot:
+            ind_bright.append(find_best_indices(find_brightest, batch_seg))
+            ind_dimm.append(find_best_indices(find_dimmest, batch_seg))
+            ind_random.append(find_best_indices(find_random, batch_seg))
+        self.indices = ind_bright, ind_dimm, ind_random # axes: (num_func, num_batch, ind_tuple)
+        self.indices = np.moveaxis(self.indices, 1, 0) # axes: (num_batch, num_func, ind_tuple)
+
+    def on_epoch_end(self, epoch, logs={}):
+        recap_images = []
+        for bb, batch_imgseg in enumerate(self.batches_for_plot):
+            batch_img, batch_seg = batch_imgseg
+            batch_pred = self.model.predict(batch_imgseg) #TODO batch_size?
+ # Get `best` 2D slices + b_ind = self.indices[bb] + bright = batch_img[b_ind[0][0], b_ind[0][1], :, :], batch_seg[b_ind[0][0], b_ind[0][1], :, :], batch_pred[b_ind[0][0], b_ind[0][1], :, :] + dimm = batch_img[b_ind[1][0], b_ind[1][1], :, :], batch_seg[b_ind[1][0], b_ind[1][1], :, :], batch_pred[b_ind[1][0], b_ind[1][1], :, :] + rand = batch_img[b_ind[2][0], b_ind[2][1], :, :], batch_seg[b_ind[2][0], b_ind[2][1], :, :], batch_pred[b_ind[2][0], b_ind[2][1], :, :] + # Display them in a grid + figure = image_grid([bright, dimm, rand], title=['Bright', 'Dimm', 'Random'], figsize=(4.7,5)) + # Transforms figure into Tensor + recap_image = plot_to_image(figure) + recap_images.append(recap_image) + _, h, w, c = recap_image.shape + recap_images = np.reshape(recap_images, (-1, h, w, c)) + writer = tf.summary.create_file_writer(self.tensorboard_dir) + with writer.as_default(): + tf.summary.image("Images and segmentations after each epoch", recap_images, max_outputs=len(recap_images), step=epoch) #(value=[tf.Summary.Value(tag='Images and segmentations', image=recap_images)]) + return + + def get_random_batches(self, dataset, num_image_to_show, rand_span=10): + range_batch = rand_span * num_image_to_show + rand_batches = list(range(range_batch)) + random.shuffle(rand_batches) + rand_batches = rand_batches[:num_image_to_show] + batches_for_plot = [] + for bb, batch in enumerate(dataset.repeat().take(range_batch)): + if bb in rand_batches: + batches_for_plot.append(batch) + return batches_for_plot + + +def image_grid(imgsegpred_list, title=None, figsize=(4.7,5)): + """Creates figure with list of triples (img2d, seg2d, pred2) + """ + n = len(imgsegpred_list) + if not title: title=['']*n + # Create a figure to contain the plot. + fig, axes = plt.subplots(n,3, sharex=True, sharey=True, figsize=figsize) + for ax, triple, ylabel in zip(axes, imgsegpred_list, title): + ax[0].imshow(triple[0], cmap=plt.cm.binary, aspect='equal') + if ylabel!='': ax[0].set_ylabel(ylabel, fontsize=10) + ax[1].imshow(triple[1], cmap=plt.cm.binary, aspect='equal') + ax[2].imshow(triple[2], cmap=plt.cm.binary, aspect='equal') + ax[0].set_title('Image', y=-0.3, fontsize=10) + ax[1].set_title('Segmentation', y=-0.3, fontsize=10) + ax[2].set_title('Prediction', y=-0.3, fontsize=10) + fig.tight_layout() + fig.subplots_adjust(hspace=0.1, wspace=0.1) + plt.xticks([]) + plt.yticks([]) + return fig + +def plot_to_image(figure): + """Converts the matplotlib plot specified by 'figure' to a PNG image and + returns it. The supplied figure is closed and inaccessible after this call.""" + # Save the plot to a PNG in memory. + buf = io.BytesIO() + plt.savefig(buf, format='png') + # Closing the figure prevents it from being displayed directly inside the notebook. 
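+    # (plt.savefig writes the *current* figure; this assumes plot_to_image is
+    # called right after image_grid, so `figure` is still the active one)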
+ plt.close(figure) + buf.seek(0) + # Convert PNG buffer to TF image + image = tf.image.decode_png(buf.getvalue(), channels=4) + # Add the batch dimension + image = tf.expand_dims(image, 0) + return image + +def find_brightest(img): + img = np.squeeze(img) + brightest = 0 + i_brightest = 0 + for i in range(img.shape[0]): + bright = np.sum(img[i]) + if bright > brightest: + brightest = bright + i_brightest = i + return i_brightest, brightest + +def find_dimmest(img): + img = np.squeeze(img) + dimmest = float('inf') + i_dimmest = 0 + for i in range(img.shape[0]): + bright = np.sum(img[i]) + if bright < dimmest: + dimmest = bright + i_dimmest = i + return i_dimmest, dimmest + +def find_random(img): + img = np.squeeze(img) + i_rand = random.randint(0, img.shape[0]-1) + return i_rand, np.sum(img[i_rand]) + +def find_best_indices(func, batch_segs): + """ Return the batch and 3D indices corresponding to the brightest/dimmest (whatever func chooses) true segmentation + """ + # find image in batch + best_img_from_batch_i, _ = func(batch_segs) + # find slice in 3D + best_slice_i, slice_brightness = func(batch_segs[best_img_from_batch_i]) + return best_img_from_batch_i, best_slice_i + diff --git a/unet/image/patch_util.py b/unet/image/patch_util.py new file mode 100755 index 0000000..5cc47d2 --- /dev/null +++ b/unet/image/patch_util.py @@ -0,0 +1,181 @@ +# This contains utility that generate patches from 3D volume +#%% Import libraries +import numpy as np +from .resample_util import resample_by_resolution + +#%% Utility functions + + +# This function sample 3D patches from 3D volume +# input: +# image: the image in numpy array, dimension [H, W, D, C] +# seg: segmentation of the image, dimension [H, W, D], right now assuming this is binary +# patch_size: the size of patch +# num_pos: number of positive patches that contains lesion to sample. If there is no enough patches to sample, +# it will return all indexes that contains lesion +# num_negative: number of negative background patches that doesn't contain lesion to sample. +# output: +# patches_pos, patches_neg: list of (img_patch, seg_patch, cpt) +def single_resolution_patcher_3D(image, seg, patch_size, is_training = True, num_pos = 10, num_neg = 10, spacing = [1, 1, 1]): + if is_training: + # Randomly sample center points + cpts_pos_sampled, cpts_neg_sampled = sample_center_points(seg, num_pos, num_neg) + # Crop patches around center points + patches_pos = crop_patch_by_cpts(image, seg, cpts_pos_sampled, patch_size) + patches_neg = crop_patch_by_cpts(image, seg, cpts_neg_sampled, patch_size) + + return patches_pos, patches_neg + else: + # Regularly grid center points + cpts = grid_center_points(image.shape, spacing) + # Crop patches around center points + patches = crop_patch_by_cpts(image, seg, cpts, patch_size) + return patches + + +# This function sample 3D patches from 3D volume in multiple resolution around same center, used deepmedic or +# similar style network +# input: +# image: the image in numpy array, dimension [H, W, D, C] +# seg: segmentation of the image, dimension [H, W, D], right now assuming this is binary +# patchsize_multi_res: this is the patch size in multi-resolution [(1, (25, 25, 25)), (0.33, (19, 19, 19))] +# this means it will sample patch size (25, 25, 25) in resolution 1x, patch size (19, 19, 19) in resolution 0.33x etc +# num_pos: number of positive patches that contains lesion to sample. 
If there are not enough patches to sample,
+#   it will return all indices that contain lesion
+#   num_neg: number of negative background patches (containing no lesion) to sample.
+def multi_resolution_patcher_3D(image, seg, patchsize_multi_res, is_training = True, num_pos = 10, num_neg = 10, spacing = [1, 1, 1]):
+    # Sample center points
+    import tensorflow as tf
+    seg = tf.cast(seg, dtype='int16')
+    if is_training:
+        cpts_pos_sampled, cpts_neg_sampled = sample_center_points(seg, num_pos, num_neg)
+
+        # Get center points in each resolution
+        cpts_pos_multi_res = multiple_resolution_cpts(cpts_pos_sampled, patchsize_multi_res)
+        cpts_neg_multi_res = multiple_resolution_cpts(cpts_neg_sampled, patchsize_multi_res)
+
+        patches_pos_multi_res = []
+        patches_neg_multi_res = []
+        for idx, pr in enumerate(patchsize_multi_res):
+            res, patch_size = pr
+            # Downsample the image and segmentation
+            image_resize, seg_resize = resample_by_resolution(image, seg, res)
+
+            cpts_max = np.array(image_resize.shape[:3]) - 1
+            cpts_max = cpts_max[:, None]
+
+            # Fetch positive patches
+            cpts_pos = cpts_pos_multi_res[idx]
+            cpts_pos = np.minimum(cpts_max, cpts_pos) # Limit the range:
+            # due to numerical rounding, the center points in different resolutions may not match the
+            # resized image exactly, so they have to be hard-constrained here
+
+            patches = crop_patch_by_cpts(image_resize, seg_resize, cpts_pos, patch_size)
+            patches_pos_multi_res.append([patches, res])
+
+            # Fetch negative patches
+            cpts_neg = cpts_neg_multi_res[idx]
+            cpts_neg = np.minimum(cpts_max, cpts_neg) # Limit the range.
+            patches = crop_patch_by_cpts(image_resize, seg_resize, cpts_neg, patch_size)
+            patches_neg_multi_res.append([patches, res])
+
+        return patches_pos_multi_res, patches_neg_multi_res
+    else:
+        # Regularly grid center points
+        cpts = grid_center_points(image.shape, spacing)
+        cpts_multi_res = multiple_resolution_cpts(cpts, patchsize_multi_res)
+        patches_multi_res = []
+
+        for idx, pr in enumerate(patchsize_multi_res):
+            res, patch_size = pr
+            # Downsample the image and segmentation
+            image_resize, seg_resize = resample_by_resolution(image, seg, res)
+
+            # Fetch patches
+            cpts_res = cpts_multi_res[idx]
+            patches_res = crop_patch_by_cpts(image_resize, seg_resize, cpts_res, patch_size)
+            patches_multi_res.append([patches_res, res])
+
+        return patches_multi_res
+
+# This function samples center points from the segmentation for patching.
+# Implement all patch selection in this function; keep the other functions clean.
+# input:
+#   seg: segmentation of the image, dimension [H, W, D], right now assumed to be binary
+#   num_pos: number of positive patches that contain lesion to sample. If there are not enough patches to sample,
+#   it will return all indices that contain lesion
+#   num_neg: number of negative background patches (containing no lesion) to sample. 
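+# Illustrative usage (hypothetical mask; the (3, N) center-point layout is inferred
+# from how the callers index these arrays):
+#   seg = np.zeros((64, 64, 64), dtype=np.uint8)
+#   seg[30:34, 30:34, 30:34] = 1
+#   cpts_pos, cpts_neg = sample_center_points(seg, num_pos=10, num_neg=10)
+#   cpts_pos[:, i]  # center index of the i-th lesion-containing patch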
+def sample_center_points(seg, num_pos, num_neg): + idx_pos = np.stack(np.where(seg>0), axis = 0) + + if idx_pos[0].shape[0] 1e308] = 0 + + # Normalize the image + image_voxels = image_np[image_np!=0] # Get rid of the background + image_np_norm = (image_np - np.mean(image_voxels)) / np.std(image_voxels) + image_np_norm[image_np==0] = 0 + + return image_np_norm + +def tf_elastic_define_graph(): + image_ph = tf.compat.v1.placeholder(tf.float32, shape = [None, None, None, None]) + seg_ph = tf.compat.v1.placeholder(tf.float32, shape = [None, None, None]) + image_aug, seg_aug, _, _, _ = ett.elastic_transform_3D(image_ph, seg_ph, c.ep) + return image_ph, seg_ph, image_aug, seg_aug + +# This function is a wrapper function for the tensorflow elastic transformation for data augmentation +# input: +# image_np, seg_np: input image and segmentation, numpy array +def tf_elastic_wrapper(image_np, seg_np, sess, ops): + image_ph, seg_ph, image_aug, seg_aug = ops + image_aug_np, seg_aug_np = sess.run([image_aug, seg_aug], feed_dict = {image_ph:image_np[:,:,:,None], seg_ph:seg_np}) + return image_aug_np, seg_aug_np + +# Auxiliary function for loading image and segmentation +def loadImageSegPair(dirDiseaseTuple): + patientDir, patient, disease = dirDiseaseTuple + + # Load the image and segmentation + #imageDir = patientDir + "FLAIR/FLAIR_1x1x1.nii.gz" + #segDir = patientDir + "FLAIR/ManSeg_1x1x1.nii.gz" + imageDir = patientDir + "%s_1x1x1.nii.gz" % patient + segDir = patientDir + "%s_seg_1x1x1.nii.gz" % patient + + + imageDir = imageDir.replace(' ', '') + segDir = segDir.replace(' ', '') + + + # Read in image and segmentation + image_np_orig = sitk.GetArrayFromImage(sitk.ReadImage(imageDir)) + + # Handle cases where the segmentation image is missing + if os.path.exists(segDir): + seg_np_orig = sitk.GetArrayFromImage(sitk.ReadImage(segDir)) + else: + seg_np_orig= np.zeros_like(image_np_orig, dtype = np.uint8) + + return image_np_orig, seg_np_orig, patient, disease + +# Auxiliary function for cropping patches +def cropPatches(ispairAndParam): + image_np, seg_np, patient, disease, is_training, num_pos, num_neg = ispairAndParam + + # Crop patches + if is_training: + patches_pos_multi_res, patches_neg_multi_res =\ + multi_resolution_patcher_3D(image_np[:, :, :, None], seg_np, c.model_params.patchsize_multi_res, is_training = is_training, num_pos = num_pos, num_neg = num_neg) + + # Fit the patch to deepmedic format + if disease == "BG_normal" or disease == "normal": + patch_pos = [] + else: + patch_pos = patch_to_Unet_format(patches_pos_multi_res, c.model_params.segsize, disease, patient) + patch_negative = patch_to_Unet_format(patches_neg_multi_res, c.model_params.segsize, disease, patient) + + return patch_pos + patch_negative + else: + # Fit the patch to deepmedic format + patches_multi_res = multi_resolution_patcher_3D(image_np[:, :, :, None], seg_np, c.model_params.patchsize_multi_res, is_training = is_training, spacing = c.model_params.test_patch_spacing) + patches_multi_res = patch_to_Unet_format(patches_multi_res, c.model_params.segsize, disease, patient) + return patches_multi_res + +# This function generate patches according to deep medic format given list of directory. +# This is a very time consuming/very memory heavy function to run. 
So a thread pool is used to improve
+# performance.
+# For each batch of directories:
+#   load image (multi-thread) -> augmentation (single thread) -> patching (multi-thread)
+def generate_deepmedic_patches(directories, is_training = True, num_pos = 100, num_neg = 100, aug = 1, ops = None, sess = None, num_thread = 1):
+    p = Pool(num_thread)
+    num_im = len(directories)
+    num_batch = int(np.ceil(num_im / num_thread))
+
+    # Generate patches
+    patches = []
+    loop_patch = tqdm(range(num_batch), desc=" Generating patches", position=1, leave=False)
+    for i in loop_patch:
+        directories_batch = directories[i * num_thread : min((i+1) * num_thread, num_im)]
+
+        # Parallelize the loading of images
+        loop_in_patch = tqdm(total=0, desc=" Loading image", position=2, leave=False)
+        image_seg_pairs = p.map(loadImageSegPair, directories_batch)
+        image_seg_pairs_aug = []
+
+        # No augmentation in the test case
+        loop_in_patch = tqdm(image_seg_pairs, desc=" Augmenting", position=2, leave=False)
+        if not is_training:
+            for ispair in loop_in_patch:
+                image_np, seg_np, patient, disease = ispair
+                image_np = image_np.astype(np.float32)
+                image_np = normalize_image(image_np)
+                image_seg_pairs_aug.append((image_np, seg_np, patient, disease, is_training, num_pos, num_neg))
+        # Augment the image in the training case - since it uses the gpu, this runs serially
+        else:
+            for ispair in loop_in_patch:
+                image_np_orig, seg_np_orig, patient, disease = ispair
+
+                for j in range(aug):
+                    if aug == 1: # No augmentation
+                        image_np, seg_np = image_np_orig, seg_np_orig
+                    else:
+                        image_np, seg_np = tf_elastic_wrapper(image_np_orig, seg_np_orig, sess, ops)
+                    # Normalize the image
+                    image_np = image_np.astype(np.float32)
+                    image_np = normalize_image(image_np)
+                    image_seg_pairs_aug.append((image_np, seg_np, patient, disease, is_training, num_pos, num_neg))
+
+        # Parallelize the generation of patches
+        if False:
+        #if is_training:
+            patches_batch = p.map(cropPatches, image_seg_pairs_aug)
+        else:
+            # The pool uses pickle to pass results back, which limits how large a
+            # returned object can be. So in the test case it is safer to use a for loop instead of the pool.
+            patches_batch = []
+            loop_in_patch = tqdm(image_seg_pairs_aug, desc=" Cropping patches", position=2, leave=False)
+            for ispair in loop_in_patch:
+                patches_batch.append(cropPatches(ispair))
+
+        # Append to the global list
+        for pat in patches_batch:
+            patches.extend(pat)
+
+    # Close the thread pool
+    p.close()
+    p.join()
+
+    return patches
+
+
+# This function combines image normalization, patch generation and saving to a tfrecord.
+# The shuffle needs to happen at the patch level to ensure it is truly random,
+# which costs memory.
+def save_patches_to_tfrecord(patches, tfrecordDir):
+    # shuffle the patches #TODO shouldn't we only do that for training? 
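+    # (shuffling here mixes patches from different patients within one record; at
+    # test time the order is irrelevant anyway, since each patch carries its own
+    # (x, y, z) center point, written below)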
+ trIdx = list(range(len(patches))) + np.random.shuffle(trIdx) + + f = tfrecordDir + writer = tf.io.TFRecordWriter(f) + + # Save the patches to tfrecord + # iterate over each example + # wrap with tqdm for a progress bar + loop_trIdx = tqdm(trIdx, desc=" Saving patches to tfrecord", position=1, leave=False) + for example_idx in loop_trIdx: + p = patches[example_idx] + cpts = p[2] + shape = p[3] + disease = p[4] + patient = p[5] + + # construct the Example proto boject + example = tf.train.Example( + # Example contains a Features proto object + features=tf.train.Features( + # Features contains a map of string to Feature proto objects + feature={ + # A Feature contains one of either a int64_list, + # float_list, or bytes_list + 'patch_high_res': tf.train.Feature( + bytes_list= tf.train.BytesList(value=[p[0].astype(np.float32).tostring()])), + 'seg': tf.train.Feature( + bytes_list= tf.train.BytesList(value=[p[1].astype(np.uint8).tostring()])), + 'phr_x': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[0].shape[0]])), + 'phr_y': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[0].shape[1]])), + 'phr_z': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[0].shape[2]])), + 'phr_c': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[0].shape[3]])), + 'seg_x': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[1].shape[0]])), + 'seg_y': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[1].shape[1]])), + 'seg_z': tf.train.Feature( + int64_list=tf.train.Int64List(value=[p[1].shape[2]])), + 'x': tf.train.Feature( + int64_list=tf.train.Int64List(value=[cpts[0]])), + 'y': tf.train.Feature( + int64_list=tf.train.Int64List(value=[cpts[1]])), + 'z': tf.train.Feature( + int64_list=tf.train.Int64List(value=[cpts[2]])), + 'h': tf.train.Feature( + int64_list=tf.train.Int64List(value=[shape[0]])), + 'w': tf.train.Feature( + int64_list=tf.train.Int64List(value=[shape[1]])), + 'd': tf.train.Feature( + int64_list=tf.train.Int64List(value=[shape[2]])), + # 'disease': tf.train.Feature( + # int64_list=tf.train.Int64List(value=[c.diseaseCode[disease]])), + + 'patient': _bytes_feature(bytes(patient, encoding='ascii')), # python 3 + # tf.train.Feature( + # int64_list=tf.train.Int64List(value=[int(patient)])) + })) + # use the proto object to serialize the example to a string + serialized = example.SerializeToString() + # write the serialized object to disk + writer.write(serialized) + + writer.close() + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + diff --git a/unet/image/preprocess_images.py b/unet/image/preprocess_images.py new file mode 100755 index 0000000..da16b8d --- /dev/null +++ b/unet/image/preprocess_images.py @@ -0,0 +1,197 @@ +##### Preprocess images ##### +import os +from tqdm.auto import tqdm +import tensorflow as tf +import numpy as np +import SimpleITK as sitk +from multiprocessing import Pool, cpu_count + +from utils.datasplit import read_split +from .elastic_transform_tf import elastic_transform_3D_tf2, elastic_param +from .patch_util import multi_resolution_patcher_3D +from .preprocess_Unet import save_patches_to_tfrecord + +#TODO: watch resources used here by different steps +tf.debugging.set_log_device_placement(False) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # only tf ERRORS are logged + +#what can be done faster for inference only? --> no tfrecord maybe? 
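+# The `config` object consumed below is assumed to expose at least the following
+# attributes (names inferred from their usage in this module, not a full contract):
+#   c.ep                      -- dict of elastic-transform parameters
+#   c.train_test_csv          -- csv file defining the train/val/test split
+#   c.data_nh_1mm             -- directory of header-stripped, 1x1x1 mm images
+#   c.tfrDir                  -- output directory for the tfrecords
+#   c.num_cpu                 -- number of worker processes
+#   c.model_params            -- nTrainPerTfrecord, num_pos, num_neg, aug,
+#                                patchsize_multi_res, test_patch_spacing, segsize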
+ +def img_to_tfrecord(config, train=True, val=True, test=False): + global c + c = config + + # Load elastic-transform parameters + ep = elastic_transform_load(c.ep) + trainDirs, valDirs, testDirs = read_split(split_file = c.train_test_csv, im_dir=c.data_nh_1mm) + + if train and len(trainDirs)>0: process_and_save(trainDirs, 'train', ep, c) + if val and len(valDirs)>0: process_and_save(valDirs, 'val', ep, c) + if test and len(testDirs)>0: process_and_save(testDirs, 'test', ep, c) + +def process_and_save(images_info, TVT, elastic_param, c): + train = TVT=='train' + num_images = len(images_info) + # Shuffle images + if TVT in ['train', 'val']: + trIdx = list(range(num_images)) + np.random.shuffle(trIdx) + images_info = [ images_info[idx] for idx in trIdx ] + num_images_batch = int(np.ceil(num_images / c.model_params.nTrainPerTfrecord)) + loop_tb = tqdm(range(num_images_batch), desc=f"Processing batch of {TVT} images", position=0) + for tb in loop_tb: + # Taking c.model_params.nTrainPerTfrecord images + ia = tb * c.model_params.nTrainPerTfrecord + ib = min((tb + 1) * c.model_params.nTrainPerTfrecord, num_images) + # Generate patches #TODO: how to optimize numthreads here? + patches = generate_patches( + images_info[ia:ib], + is_training = train, + num_pos=c.model_params.num_pos, + num_neg=c.model_params.num_neg, + augmentation=c.model_params.aug, + elastic_param=elastic_param, + num_processes=c.num_cpu) + # Save patches to TfRecords + if c.model_params.nTrainPerTfrecord > 1: + tfrecord_path = f"{c.tfrDir}/{TVT}_{tb:03d}.tfrecords" + else: + accession = images_info[ia][1] + tfrecord_path = f"{c.tfrDir}/{TVT}_{accession}.tfrecords" + save_patches_to_tfrecord(patches, tfrecord_path) + + + +def generate_patches( + images_info, + is_training = True, + num_pos=30, + num_neg=30, + augmentation=1, + elastic_param=None, + num_processes = cpu_count()): + """ + Returns: patches list (image, seg, cpts, shape, disease, patient ID) + """ + # Set multi-threading + p = Pool(num_processes) + num_images = len(images_info) + num_images_thread = int(np.ceil(num_images/num_processes)) + patch_list = [] + loop_batch = tqdm(range(num_images_thread), desc=f" Generating patches for batch of images ({num_images})", position=1, leave=False) + for i in loop_batch: + images_batch_thread = images_info[i * num_processes : min((i+1) * num_processes, num_images)] + # Parallelize the process of loading images + loop_in_batch = tqdm(total=0, desc=" Loading images", position=2, leave=False) + image_seg_pairs = p.map(load_image_seg_pair, images_batch_thread) + image_seg_pairs_aug = [] + # Image augmentation (none in case of testing) + loop_in_batch = tqdm(image_seg_pairs, desc=f" Augmenting (x{augmentation})", position=2, leave=False) + if not is_training: + # Normalization #TODO: how does that affect volume calculation? 
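+        # (normalize_image() standardizes only the nonzero voxels to zero mean /
+        # unit std and leaves the background at exactly 0 -- see its definition below)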
+ for pair in loop_in_batch: + image_np, seg_np, patient, disease = pair + image_np = image_np.astype(np.float32) + image_np = normalize_image(image_np) + image_seg_pairs_aug.append((image_np, seg_np, patient, disease, is_training, num_pos, num_neg)) + else: + for pair in loop_in_batch: + image_np_orig, seg_np_orig, patient, disease = pair + # Augmentation + for i in range(augmentation): + if augmentation == 1: # No augmentation + image_np, seg_np = image_np_orig, seg_np_orig + else: + image_np_orig = tf.cast(image_np_orig, tf.float32) + image_np, seg_np = elastic_transform_3D_tf2(image_np_orig[:,:,:,None], seg_np_orig, elastic_param) #TODO: remove tf2 + # Normalization + image_np = tf.cast(image_np, tf.float32).numpy() + image_np = normalize_image(image_np) + image_seg_pairs_aug.append((image_np, seg_np, patient, disease, is_training, num_pos, num_neg)) + # Generate patches + # if is_training: + # patches_batch = p.map(crop_patches, image_seg_pairs_aug) + # else: + patches_batch = [] + loop_img_aug = tqdm(image_seg_pairs_aug, desc=" Cropping patches", position=2, leave=False) + for pair in loop_img_aug: + patches_batch.append(crop_patches(pair, c.model_params.patchsize_multi_res, c.model_params.test_patch_spacing, c.model_params.segsize)) + # Append to the global patch list + for pat in patches_batch: + patch_list.extend(pat) + # Close thread pool + p.close() + p.join() + return patch_list + +def load_image_seg_pair(image_info): + """ + Input: List with (image dir path, patient ID, disease) + Returns: SimpleITK image, segmentation, patient ID, disease + """ + image_dir, patient, disease = image_info + image_path = f"{image_dir/patient}_1x1x1.nii.gz".replace(' ','') + seg_path = f"{image_dir/patient}_seg_1x1x1.nii.gz".replace(' ','') + # Read image and segmentation + image = sitk.GetArrayFromImage(sitk.ReadImage(image_path)) + # Handle cases where the segmentation image is missing by inputing 0s + if os.path.exists(seg_path): + seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_path)) + else: + seg= np.zeros_like(image, dtype = np.uint8) + return image, seg, patient, disease + +def normalize_image(image): + """ + Normalizes an image + """ + # Remove all possible artifact + image[np.isnan(image)] = 0 + image[np.abs(image) > 1e308] = 0 + # Normalize the image + image_voxels = image[image>0] # Get rid of the background before normalization + image_norm = (image - np.mean(image_voxels)) / np.std(image_voxels) + image_norm[image==0] = 0 # Put the background back where it was + return image_norm + +def crop_patches(image_info, patchsize_multi_res, test_patch_spacing, segsize): + """ + Input: List with (image, seg, patient ID, disease, is_training, num_pos, num_neg), patchsize_multi_res, test_patch_spacing, segsize + Returns: patches for one image (image, seg, cpts, shape, disease, patient ID) + """ + image_np, seg_np, patient, disease, is_training, num_pos, num_neg = image_info + # Crop patches + if is_training: + patches_pos_multi_res, patches_neg_multi_res =\ + multi_resolution_patcher_3D(image_np[:, :, :, None], seg_np, patchsize_multi_res, is_training = is_training, num_pos = num_pos, num_neg = num_neg) + # Fit the patch to deepmedic format + if disease == "BG_normal" or disease == "normal": + patch_pos = [] + else: + patch_pos = patch_to_Unet_format(patches_pos_multi_res, segsize, disease, patient) + patch_negative = patch_to_Unet_format(patches_neg_multi_res, segsize, disease, patient) + return patch_pos + patch_negative + else: + # Fit the patch to deepmedic format + patches_multi_res = 
multi_resolution_patcher_3D(image_np[:, :, :, None], seg_np, patchsize_multi_res, is_training = is_training, spacing = test_patch_spacing) + patches_multi_res = patch_to_Unet_format(patches_multi_res, segsize, disease, patient) + return patches_multi_res + +def patch_to_Unet_format(patches_multi_res, seg_size, disease, patient): + """ + This function converts the patch from patch generator function to one that deepmedic needs + """ + patches = [] + patches_high_res = patches_multi_res[0][0] + for ph in patches_high_res: + seg = ph[1] + cpts = ph[2] + shape = ph[3] + patches.append((ph[0], seg, cpts, shape, disease, patient)) + return patches + +def elastic_transform_load(param_dict): + ep = elastic_param() + for key, value in param_dict.items(): + setattr(ep, key, value) + return ep diff --git a/unet/image/process_image.py b/unet/image/process_image.py new file mode 100755 index 0000000..07b8794 --- /dev/null +++ b/unet/image/process_image.py @@ -0,0 +1,170 @@ +### Process images (both preprocess and postprocess) +import os +import shutil + +import numpy as np +import scipy +import SimpleITK as sitk +from tqdm.auto import tqdm +from utils.multiprocessing import run_multiprocessing + +from .fix_orientation import fixOrientation, getImSpDirct, reverseOrientation +from .resample_util import resample_im_by_spacing, resample_seg_by_spacing + + +def image_already_processed(image, output_dir): #TODO it doesn't work + # don't process images that have already been! + # check output directory for existence + if image in os.listdir(output_dir): + return True + return False + +def create_image_list(input_dir, image_list): + if image_list is None: + images = [p for p in os.listdir(input_dir) if p.endswith(".nii.gz")] + else: + image_list = list(map(str, image_list)) + images = [] + for im in os.listdir(input_dir): + # extract patient ID from image name + if im.endswith("_seg.nii.gz"): + p = im.split("_seg.nii.gz")[0] + elif im.endswith(".nii.gz"): + p = im.split(".nii.gz")[0] + # check if it's in the list provided + if p in image_list: + images.append(im) + return images + +##Strip out the header +def strip_header_dir(data_raw_dir, data_nh_dir, image_list=None, num_processes=1, force_preprocess=False): + images = create_image_list(data_raw_dir, image_list) + run_multiprocessing(strip_header, + images, + fixed_arguments={'input_dir':data_raw_dir, + 'output_dir':data_nh_dir, + 'force_preprocess':force_preprocess}, + num_processes=num_processes, + title="Stripping header") + return + +def strip_header(im_name, input_dir, output_dir, force_preprocess): + if not force_preprocess and image_already_processed(im_name, output_dir): return + im_in = input_dir / im_name + im_out = output_dir / im_name + try: + im = sitk.ReadImage(str(im_in)) + im_np, spacing, direction = getImSpDirct(im) + im_np, spacing = fixOrientation(im_np, direction, spacing) + im = sitk.GetImageFromArray(im_np) + im.SetSpacing(list(spacing)) + sitk.WriteImage(im, str(im_out)) + except Exception as err: + print(f'Error in strip_header for {im_name}.\n {type(err)}\n {err.args}') + # except ValueError as err: + # print(f'Error in strip_header for {im_name}.\n {type(err)}\n {err.args}') + return + +##Resample the image and segmentation into 1x1x1 mm +def resample_to_1mm_dir(data_nh_dir, data_nh_1mm, image_list=None, num_processes=1, force_preprocess=False): + images = create_image_list(data_nh_dir, image_list) + run_multiprocessing(resample_to_1mm, + images, + fixed_arguments={'input_dir':data_nh_dir, + 'output_dir':data_nh_1mm, + 
'force_preprocess':force_preprocess}, + num_processes=num_processes, + title="Resampling images") + +def resample_to_1mm(im, input_dir, output_dir, force_preprocess): + if "native" in im: + return + if not force_preprocess and image_already_processed(im.split(".nii.gz")[0]+"_1x1x1.nii.gz", output_dir): return + im_in = input_dir / im + im_out = output_dir / im + + # Read the image and segmentation + if "seg" in im: + seg_sitk = sitk.ReadImage(str(im_in)) + seg_np = sitk.GetArrayFromImage(seg_sitk) + spacing = seg_sitk.GetSpacing()[::-1] + seg_re_np = resample_seg_by_spacing(seg_np, spacing) + sitk.WriteImage( sitk.GetImageFromArray(seg_re_np), str(im_out).split(".nii.gz")[0] + "_1x1x1.nii.gz") + else: + image_sitk = sitk.ReadImage(str(im_in)) + image_np = sitk.GetArrayFromImage(image_sitk) + image_np = image_np[:, :, :, None] + spacing = image_sitk.GetSpacing()[::-1] + image_re_np = resample_im_by_spacing(image_np, spacing) + image_re_np = np.squeeze(image_re_np) + sitk.WriteImage( sitk.GetImageFromArray(image_re_np), str(im_out).split(".nii.gz")[0] + "_1x1x1.nii.gz") + return + +##Resample the validation output to original space +# Resize the segmentation to original space and reverse the orientation + +def resample_from_1mm_dir(c, num_processes=1): + patients = [p.split(".nii.gz")[0] for p in os.listdir(c.valoutDir) if p.endswith(".nii.gz")] + run_multiprocessing(resample_from_1mm, + patients, + fixed_arguments={'input_dir':c.valoutDir, + 'output_dir':c.resampledDir, + 'noheader_dir':c.data_nh, + 'raw_dir':c.data_raw}, + num_processes=num_processes, + title="Resampling images") + +def resample_from_1mm(accession, input_dir, output_dir, noheader_dir, raw_dir): + pred = sitk.GetArrayFromImage(sitk.ReadImage(str(input_dir / (accession + ".nii.gz")))) + noheader = sitk.ReadImage(str(noheader_dir / (accession + ".nii.gz"))) + header = sitk.ReadImage(str(raw_dir / (accession + ".nii.gz"))) + + # Resampling + old_size = noheader.GetSize()[::-1] + pred_old = scipy.ndimage.zoom(pred, (old_size[0]/pred.shape[0], old_size[1]/pred.shape[1], old_size[2]/pred.shape[2] ), order = 1) + + if (pred_old.shape[0]!=old_size[0]) or (pred_old.shape[1]!=old_size[1]) or (pred_old.shape[2]!=old_size[2]): + print(accession) + + # Reverse direction + direction = np.array(header.GetDirection()).reshape((3, 3)) + pred_old_reverse = reverseOrientation(pred_old, direction) + + # Copy header and spacing + out = sitk.GetImageFromArray(pred_old_reverse) + out.SetSpacing(noheader.GetSpacing()) + out.CopyInformation(header) + + sitk.WriteImage(out, str(output_dir / (accession + ".nii.gz"))) + +#%% Run binarization of images + +def binarize(patient, prob_seg_dir, bin_seg_dir, threshold, fill_holes=False): + import nibabel as nib + prob_seg_file = prob_seg_dir / (str(patient) + ".nii.gz") + md = bin_seg_dir / str(threshold) + md.mkdir(parents=True, exist_ok=True) + bin_seg_file = md / (str(patient) + "_binary.nii.gz") + # im_sitk = sitk.ReadImage(str(prob_seg_file)) + # im_np = sitk.GetArrayFromImage(im_sitk) + img_nib = nib.load(prob_seg_file) + img_np = img_nib.get_fdata() + seg_np = np.where(img_np > threshold, 1, 0) + if fill_holes: + seg_np = scipy.ndimage.binary_fill_holes(seg_np).astype(int) + nib.save(nib.Nifti1Image(seg_np,img_nib.affine,img_nib.header), bin_seg_file) + # sitk.WriteImage(sitk.GetImageFromArray(seg_np).CopyInformation(im_sitk), bin_seg_file) + # command = f"fslmaths {prob_seg_file} -thr {threshold} -bin {bin_seg_file}" + # os.system(command) + +def binarize_dir(input_dir, output_dir, threshold, 
num_processes=1, fill_holes=False):
+    patients = [p.split(".nii.gz")[0] for p in os.listdir(input_dir) if p.endswith(".nii.gz")]
+    # if type(threshold) is not list: threshold = [threshold]
+    run_multiprocessing(binarize,
+                        patients,
+                        fixed_arguments={'prob_seg_dir':input_dir,
+                                         'bin_seg_dir':output_dir,
+                                         'threshold':threshold,
+                                         'fill_holes':fill_holes},
+                        num_processes=num_processes,
+                        title="Binarization")
diff --git a/unet/image/resample_util.py b/unet/image/resample_util.py
new file mode 100755
index 0000000..137f96d
--- /dev/null
+++ b/unet/image/resample_util.py
@@ -0,0 +1,120 @@
+#%% This script defines resampling utilities
+import numpy as np
+from scipy import ndimage
+import tensorflow as tf
+
+from .LinearInterpolation import trilinear_sampler, trinnsampler
+
+# This function uses the scipy ndimage implementation of tri-linear interpolation to resample an image
+# input:
+#   image: input image, [H, W, D, C]
+#   seg: segmentation, [H, W, D]
+#   res: 0-1, relative resolution of the output to the original image
+# outputs:
+#   image_resize, seg_resize: the resampled image and segmentation
+def resample_by_resolution(image, seg, res):
+    image_resize = ndimage.zoom(image, (res, res, res, 1.0), order = 1)
+    seg_resize = ndimage.zoom(seg, (res, res, res), order = 0)
+
+    return image_resize, seg_resize
+
+# This function generates a sampling grid for a given resolution
+# input:
+#   height, width, depth: Tensor, size of the image being resampled
+#   res: list of length 3, each 0-1, relative resolution of the output to the original image in each dimension
+def grid_generator_tf(height, width, depth, res):
+    y = tf.linspace(0.0, tf.cast(height, tf.float32) - 1, res[0] * height)
+    x = tf.linspace(0.0, tf.cast(width, tf.float32) - 1, res[1] * width)
+    z = tf.linspace(0.0, tf.cast(depth, tf.float32) - 1, res[2] * depth)
+    x_t, y_t, z_t = tf.meshgrid(x, y, z)
+
+    x_t = tf.expand_dims(x_t, axis=0)
+    y_t = tf.expand_dims(y_t, axis=0)
+    z_t = tf.expand_dims(z_t, axis=0)
+
+    return x_t, y_t, z_t
+
+
+# This function uses the tensorflow implementation of tri-linear interpolation to resample an image
+# input:
+#   image: input image, Tensor, size [H, W, D, C]
+#   seg: segmentation, Tensor, [H, W, D]
+#   res: list of length 3, each 0-1, relative resolution of the output to the original image in each dimension
+# outputs:
+#   image_resize, seg_resize: Tensor, the resampled image and segmentation
+def resample_by_resolution_tf(image, seg, res):
+    height, width, depth, C = tf.shape(image)
+    x_t, y_t, z_t = grid_generator_tf(height, width, depth, res)
+
+    image = tf.expand_dims(image, axis = 0)
+    image_resize = trilinear_sampler(image, x_t, y_t, z_t)
+    image_resize = tf.squeeze(image_resize)
+
+    seg = tf.expand_dims(seg, axis = 0)
+    seg = tf.expand_dims(seg, axis = -1)
+    seg_resize = trinnsampler(seg, x_t, y_t, z_t)
+    seg_resize = tf.squeeze(seg_resize)
+
+    return image_resize, seg_resize
+
+# This function uses the tensorflow samplers to resample just a set of labels
+# input:
+#   labels: labels in 3D, Tensor, [N, H, W, D, C]
+#   res: resample resolution factor.
+#   interp: interpolation method, either nn or linear. Default nn. 
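+#           ("nn" keeps integer label values intact, whereas "linear" would blend
+#           neighboring labels -- a note based on the two samplers imported above)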
+# outputs:
+#   labels_resize: Tensor, [N, H, W, D, C], the resampled labels
+def resample_labels_tf(labels, res, interp = "nn"):
+    N, height, width, depth, C = tf.shape(labels)
+    x_t, y_t, z_t = grid_generator_tf(height, width, depth, res)
+
+    if interp == "nn":
+        labels_resize = trinnsampler(labels, x_t, y_t, z_t)
+    elif interp == "linear":
+        labels_resize = trilinear_sampler(labels, x_t, y_t, z_t)
+    else:
+        raise ValueError("Invalid interp type in resamples")
+
+    return labels_resize
+
+# This function resamples img and seg to 1x1x1 mm.
+# inputs:
+#   img: the image, [H, W, D, C]
+#   seg: the segmentation, [H, W, D]
+#   spacing: tuple, voxel spacing in 3D before resampling
+#   new_spacing: tuple, voxel spacing in 3D after resampling
+# outputs:
+#   image_resize, seg_resize: the resampled image and segmentation
+def resample_by_spacing(image, seg, spacing, new_spacing = (1.0, 1.0, 1.0)):
+    # calculate resize factor
+    rf = np.array(spacing).astype(np.float64) / np.array(new_spacing).astype(np.float64)
+    image_resize = ndimage.zoom(image, (rf[0], rf[1], rf[2], 1.0), order = 1)
+    seg_resize = ndimage.zoom(seg, rf, order = 0)
+    return image_resize, seg_resize
+
+def resample_im_by_spacing(image, spacing, new_spacing = (1.0, 1.0, 1.0)):
+    # calculate resize factor
+    rf = np.array(spacing).astype(np.float64) / np.array(new_spacing).astype(np.float64)
+    image_resize = ndimage.zoom(image, (rf[0], rf[1], rf[2], 1.0), order = 1)
+    return image_resize
+
+def resample_seg_by_spacing(seg, spacing, new_spacing = (1.0, 1.0, 1.0)):
+    # calculate resize factor
+    rf = np.array(spacing).astype(np.float64) / np.array(new_spacing).astype(np.float64)
+    seg_resize = ndimage.zoom(seg, rf, order = 0)
+    return seg_resize
+
+
+# This function resamples img and seg to 1x1x1 mm using the tensorflow trilinear interpolation. 
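+# (thin wrapper around resample_by_resolution_tf with rf = spacing / new_spacing)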
+# inputs: +# image: input image, Tensor, size [H, W, D, C] +# seg: segmentation, Tensor, [H, W, D] +# spacing: tuple, voxel spacing in 3D before resampled +# new_spacing: tuple, voxel spacing in 3D after resampled +# outputs: +# image_resize, seg_resize: Tensor, image, segmentation that get resampled +def resample_by_spacing_tf(image, seg, spacing, new_spacing = (1.0, 1.0, 1.0)): + # calculate resize factor + rf = np.array(spacing).astype(np.float64) / np.array(new_spacing).astype(np.float64) + image_resize, seg_resize = resample_by_resolution_tf(image, seg, rf) + return image_resize, seg_resize \ No newline at end of file diff --git a/unet/model/__init__.py b/unet/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/unet/model/model_definition.py b/unet/model/model_definition.py new file mode 100755 index 0000000..c0e1d30 --- /dev/null +++ b/unet/model/model_definition.py @@ -0,0 +1,292 @@ +###### Model definition ###### +import tensorflow as tf +from tensorflow.keras import backend as K +import numpy as np +import sys + +default_kernel_size = (3, 3, 3) +OUTPUT_CHANNELS = 1 + +down_res_block_filters = [(16, 16, 64, 2), (64, 64, 128, 2), (128, 128, 256, 3), (256, 256, 256, 3)] +up_res_block_filters = [(256, 256, 128, 1), (128, 128, 64, 1), (64, 64, 16, 1), (16, 16, 16, 1)] +final_filters = (16, 16, 1) + +#### Model Functions #### + +# A wrapper around the normalization layer +def norm_layer(name): + return tf.keras.layers.BatchNormalization(name = name + "_bn") + +# Convolution sandwich building block +def sandwich(filters, kernel_size, dilation, name): + l = tf.keras.Sequential(name=name) + l.add(tf.keras.layers.Conv3D(filters, kernel_size, dilation_rate=dilation, activation= tf.nn.relu, name=name + "_conv", padding='same')) + l.add(norm_layer(name)) + return l + + +# The residual block +def resblock(X, filters, filters2, kernel_size, dilation, name): + # The long way + l_lw = tf.keras.Sequential(name=name+"_lgWay") + l_lw.add(tf.keras.layers.Conv3D(filters, kernel_size, dilation_rate=dilation, activation= tf.nn.relu, name=name + "_sand1_conv", padding='same')) + l_lw.add(norm_layer(name+"_sand1")) + # l_lw.add(sandwich(filters, default_kernel_size, dilation, name = name + "_sand1")) + l_lw.add(tf.keras.layers.Conv3D(filters2, kernel_size, dilation_rate=dilation, activation= tf.nn.relu, name=name + "_sand2_conv", padding='same')) + l_lw.add(norm_layer(name+"_sand2")) + # l_lw.add(sandwich(filters2, default_kernel_size, dilation, name = name + "_sand2")) + + # The short way + l_sw = tf.keras.Sequential(name=name+"_shWay") + l_sw.add(tf.keras.layers.Conv3D(filters2, (1, 1, 1), name=name + "_short_conv", padding='same')) + l_sw.add(norm_layer(name = name + "_short")) + + l = tf.keras.layers.Add(name=name+"_resAdd")([l_lw(X), l_sw(X)]) + return l + +# Downsampling block +def downsampling_block(filters, name): + l = tf.keras.Sequential(name=name) + l.add(tf.keras.layers.Conv3D(filters = filters, strides = (2, 2, 2), kernel_size = default_kernel_size, padding = 'same', activation = tf.nn.relu)) + l.add(norm_layer(name)) + return l + +# Upsampling block +def upsampling_block(filters, name): + l = tf.keras.Sequential(name=name) + l.add(tf.keras.layers.Conv3DTranspose(filters = filters, strides=(2, 2, 2), kernel_size = default_kernel_size, padding = 'same', activation = tf.nn.relu)) + l.add(norm_layer(name)) + return l + + + +#### Metric and losses #### + +class DiceLoss(tf.keras.metrics.Metric): + """tried to define a DiceLoss class -- haven't got it to work yet + """ + def 
__init__(self, name="dice_loss", eps=1e-6, **kwargs): + super(DiceLoss, self).__init__(name=name, **kwargs) + self.__name__ = name + self.eps = eps + self.dice_loss = self.add_weight(name="dice", initializer="zeros") + + def update_state(self, label, logits, sample_weight=None): + label = tf.cast(label, tf.float32) + logits = tf.cast(logits, tf.float32) + intersection = tf.reduce_sum(logits * label, axis = (1, 2, 3)) + union = tf.reduce_sum(logits, axis = (1, 2, 3)) + tf.reduce_sum(label, axis = (1, 2, 3)) + loss = 2.0*(intersection + self.eps)/ (union + 2.0*self.eps) + self.dice_loss.assign_add(tf.reduce_mean(-loss)) + + def result(self): + return self.dice_loss + + def reset_states(self): + # The state of the metric will be reset at the start of each epoch. + self.dice_loss.assign(0.0) + +class DiceScore(tf.keras.metrics.Metric): + def __init__(self, name="dice_score", eps=1e-6, **kwargs): + super(DiceScore, self).__init__(name=name, **kwargs) + self.__name__ = name + self.eps = eps + self.dice_score = self.add_weight(name="dice", initializer="zeros") + + def update_state(self, label, logits, sample_weight=None): + label = tf.cast(label, tf.float32) + logits = tf.cast(logits, tf.float32) + intersection = tf.reduce_sum(logits * label, axis = (1, 2, 3)) + union = tf.reduce_sum(logits, axis = (1, 2, 3)) + tf.reduce_sum(label, axis = (1, 2, 3)) + score = 2.0*(intersection + self.eps)/ (union + 2.0*self.eps) + self.dice_score.assign_add(tf.reduce_mean(score)) + + def result(self): + return self.dice_score + + def reset_states(self): + # The state of the metric will be reset at the start of each epoch. + self.dice_score.assign(0.0) + + +def soft_dice(_type="score", smooth = 1e-5): + """ Computes Dice + Args: + _type: "score" or "loss" + smooth: smoothing coefficient + Returns: + either Dice score or Dice loss (1-score) + """ + def dice_score(y_true, y_pred): + # Flatten + y_true_f = K.cast(K.flatten(y_true), y_pred.dtype) + y_pred_f = K.flatten(y_pred) + # Sum + im_sum = K.sum(y_true_f) + K.sum(y_pred_f) + im_sum = K.cast(im_sum, tf.float32) + # Intersection + intersection = K.sum(y_true_f * y_pred_f) + intersection = K.cast(intersection, tf.float32) + # Return Dice coefficient + return (2. * intersection + smooth) / (im_sum + smooth) + + def dice_loss(y_true, y_pred): + return 1-dice_score(y_true, y_pred) + + if _type == "score": + return dice_score + elif _type == "loss": + return dice_loss + +def weighted_cross_entropy(weight_alpha=0.9, binary=True): + def _loss(y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + weights = y_true * (weight_alpha/(1.-weight_alpha)) + 1. 
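+        # e.g. with the default weight_alpha=0.9, each foreground voxel gets weight
+        # 1 + 0.9/0.1 = 10 while background voxels keep weight 1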
+ bce = K.binary_crossentropy(y_true, y_pred, from_logits=False) + # axis = (1, 2, 3, 4) if binary else (1, 2, 3) + weighted_loss = K.mean(bce * weights) + return weighted_loss + return _loss + +def focal_loss(alpha=0.25, gamma=2.0): + from tensorflow_addons.losses import SigmoidFocalCrossEntropy + return SigmoidFocalCrossEntropy(alpha=alpha, gamma=gamma) + + +#### Model architecture definition #### + +def UNet(kernel_size=default_kernel_size, + down_res_block_filters=down_res_block_filters, + up_res_block_filters=up_res_block_filters, + final_filters=final_filters, input_shape = (96,96,96,1), training=True): #TODO: check training=True/False + X_input = tf.keras.Input(input_shape) + X = X_input + + # Down + X_inter = [] + for k, filters_tuple in enumerate(down_res_block_filters): + filters, filters2, filter_downsample, dilation = filters_tuple + X = resblock(X, filters, filters2, kernel_size, dilation, name = "dblock_%d_res_down"%k) + X_inter.append(X) + X = downsampling_block(filter_downsample, name = "dblock_%d_down"%k)(X) + + # Reverse the intermediary blocks to send to upsampling + X_inter = X_inter[::-1] + + # Up + K = len(up_res_block_filters) + for k, (filters_tuple, inter) in enumerate(zip(up_res_block_filters, X_inter)): + filters, filters2, filter_upsample, dilation = filters_tuple + X = resblock(X, filters, filters2, kernel_size, dilation, name = "ublock_%d_res_up"%(K-k-1)) + X = upsampling_block(filter_upsample, name = "ublock_%d_up"%(K-k-1))(X) + X = tf.keras.layers.Concatenate(axis = -1, name="ublock_%d_concat"%(K-k-1))([X, inter]) + + # Final layer + for k, filters in enumerate(final_filters): + X = sandwich(filters, (1, 1, 1), 1, name = "final_sandwich_%d"%k)(X) + X_output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='prediction')(X) + + # Create model + model = tf.keras.Model(inputs = X_input, outputs = X_output, name = "UNet") + + return model + +class UNetModel(tf.keras.Model): + """ Class defining the training steps for the model + Args: + x, y: model input tensor, model output + """ + def train_step(self, data): + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. 
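+        # (with the tfrecord pipeline in train_test.py, `data` is a (patch, seg)
+        # pair produced by load_image_train)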
+ x, y = data + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute the loss value + # (the loss function is configured in `compile()`) + loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + # Update weights + self.optimizer.apply_gradients(zip(gradients, trainable_vars)) + # Update metrics (includes the metric that tracks the loss) + self.compiled_metrics.update_state(y, y_pred) + # Return a dict mapping metric names to current value + return {m.name: m.result() for m in self.metrics} + +def UNet_custom_train(kernel_size=default_kernel_size, + down_res_block_filters=down_res_block_filters, + up_res_block_filters=up_res_block_filters, + final_filters=final_filters, + input_shape = (96,96,96,1), + training=True): #TODO: check training=True/False + + X_input = tf.keras.Input(input_shape) + X = X_input + + # Down + X_inter = [] + for k, filters_tuple in enumerate(down_res_block_filters): + filters, filters2, filter_downsample, dilation = filters_tuple + X = resblock(X, filters, filters2, kernel_size, dilation, name = "dblock_%d_res_down"%k) + X_inter.append(X) + X = downsampling_block(filter_downsample, name = "dblock_%d_down"%k)(X) + + # Reverse the intermediary blocks to send to upsampling + X_inter = X_inter[::-1] + + # Up + K = len(up_res_block_filters) + for k, (filters_tuple, inter) in enumerate(zip(up_res_block_filters, X_inter)): + filters, filters2, filter_upsample, dilation = filters_tuple + X = resblock(X, filters, filters2, kernel_size, dilation, name = "ublock_%d_res_up"%(K-k-1)) + X = upsampling_block(filter_upsample, name = "ublock_%d_up"%(K-k-1))(X) + X = tf.keras.layers.Concatenate(axis = -1, name="ublock_%d_concat"%(K-k-1))([X, inter]) + + # Final layer + for k, filters in enumerate(final_filters): + X = sandwich(filters, (1, 1, 1), 1, name = "final_sandwich_%d"%k)(X) + X = tf.keras.activations.sigmoid(X) + + # Create model, using the custom training loop + # model = UNetModel(X_input, X) + + return X_input, X + +#### Freezing layers #### +def freeze_layers(model, instring, all_but=True, debug=True): + """ + Freeze layers that contain `instring` in their name (or all but that layer with `all_but` argument) + for example, instring='ublock' freezes all up blocks. 
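+    (with all_but=True, every layer *except* those matching instring is frozen)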
+ see model.summary() for layer names + """ + def freeze(elem, all_but): + for key, layer in layer_dict.items(): + if all_but: + freeze_layer = elem not in key + else: + freeze_layer = elem in key + if freeze_layer: + layer.trainable = False + if debug: print(f"Froze: {key}") + + layer_dict = {l.name: model.get_layer(l.name) for l in model.layers} + + if type(instring) == list: + if all_but: + sys.exit("Freezing layers currently doesn't support freezing all but a list of layers.") + else: + for elem in instring: + freeze(elem, all_but) + elif type(instring) == str: + freeze(instring, all_but) + else: + print("wrong input") + exit + + if debug: model.summary() + + diff --git a/unet/model/train_test.py b/unet/model/train_test.py new file mode 100755 index 0000000..0878def --- /dev/null +++ b/unet/model/train_test.py @@ -0,0 +1,394 @@ +##### train/test/predict ##### +import os +from pathlib import Path +import logging +from multiprocessing import Pool +from functools import partial +import tensorflow as tf +tf.get_logger().setLevel(logging.WARN) +import numpy as np +import SimpleITK as sitk +from tqdm.auto import tqdm +from utils.tic_toc import tic, toc +from .model_definition import UNet, soft_dice, weighted_cross_entropy, freeze_layers +from image.postprocess_deepmedic import assemSegFromPatches_dir +from image.preprocess_images import load_image_seg_pair, normalize_image, crop_patches +from image.image_plot import ImageHistory + + +def dataset_from_tfr(tfrDir, buffer_size, batch_size, test_train, img_list=None): + """Load dataset from tensorflow records + Args: + tfrDir: directory for Tensorflow records + buffer_size: buffer size for element shuffling + batch_size: batch size + test_train: "test" or "train" (different behaviors depending on which one is chosen - i.e. no shuffle for test time) + img_list: if set, chooses a subsample of training records that matches the names provided in that list + Returns: + tf.Dataset + """ + train = test_train=="train" + if img_list is not None: + tfrecords = [str(tr) for tr in tfrDir.iterdir() if tr.name.endswith(".tfrecords") and test_train in tr.stem and tr.stem.split('_',1)[1] in img_list] + buffer_size = len(tfrecords) + else: + tfrecords = [str(tfrDir / tr) for tr in os.listdir(tfrDir) if tr.endswith(".tfrecords") and test_train in tr] + try: + ds = tf.data.TFRecordDataset(tfrecords, num_parallel_reads=len(tfrecords)) + except: + ds = tf.data.TFRecordDataset(tfrecords) + if train: + ds = ds.shuffle(buffer_size, reshuffle_each_iteration=True) #seed=2020 TODO: check if loss function is repeating itself? 
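+        # (reshuffle_each_iteration re-randomizes the order every epoch; the second
+        # shuffle below mixes the parsed examples again, since one tfrecord holds
+        # many patches)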
+ ds = ds.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if img_list is None: ds = ds.shuffle(buffer_size, reshuffle_each_iteration=True) # shuffle twice since several examples per tfrecord + elif test_train == "val": + ds = ds.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) + else: + ds = ds.map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.batch(batch_size) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) + return ds + +def load_image_train(tfrecord): + tfrecord_features = tf.io.parse_single_example( + tfrecord, + features={ + 'patch_high_res': tf.io.FixedLenFeature([], tf.string), + 'seg': tf.io.FixedLenFeature([], tf.string), + 'phr_x': tf.io.FixedLenFeature([], tf.int64), + 'phr_y': tf.io.FixedLenFeature([], tf.int64), + 'phr_z': tf.io.FixedLenFeature([], tf.int64), + 'phr_c': tf.io.FixedLenFeature([], tf.int64), + 'seg_x': tf.io.FixedLenFeature([], tf.int64), + 'seg_y': tf.io.FixedLenFeature([], tf.int64), + 'seg_z': tf.io.FixedLenFeature([], tf.int64), + 'x': tf.io.FixedLenFeature([], tf.int64), + 'y': tf.io.FixedLenFeature([], tf.int64), + 'z': tf.io.FixedLenFeature([], tf.int64), + #'disease': tf.io.FixedLenFeature([], tf.int64), + 'patient': tf.io.FixedLenFeature([], tf.string) + } + ) + + phr_x, phr_y, phr_z, phr_c = tfrecord_features['phr_x'], tfrecord_features['phr_y'], tfrecord_features['phr_z'], tfrecord_features['phr_c'] + phr = tf.io.decode_raw(tfrecord_features['patch_high_res'], tf.float32) + phr = tf.reshape(phr, shape = tf.stack([phr_x, phr_y, phr_z, phr_c])) + + seg_x, seg_y, seg_z = tfrecord_features['seg_x'], tfrecord_features['seg_y'], tfrecord_features['seg_z'] + seg = tf.io.decode_raw(tfrecord_features['seg'], tf.uint8) + seg = tf.reshape(seg, shape = tf.stack([seg_x, seg_y, seg_z])) + seg = tf.expand_dims(seg, axis = -1) + + # disease = tfrecord_features['disease'] + patient = tf.io.decode_raw(tfrecord_features['patient'], tf.uint8) + + sz = tf.size(patient) + patient = tf.pad(patient, [[0, 100 - sz]]) + + x, y, z = tfrecord_features['x'], tfrecord_features['y'], tfrecord_features['z'] + + # Random flip left and right + flip = tf.random.uniform([1], minval=0.0, maxval=1.0) + flip = tf.squeeze(flip) + phr = tf.cond(flip < 0.5, lambda: tf.reverse(phr, axis = [2]), lambda: phr) + seg = tf.cond(flip < 0.5, lambda: tf.reverse(seg, axis = [2]), lambda: seg) + + return phr, seg #, x, y, z, patient + +def load_image_test(tfrecord): + tfrecord_features = tf.io.parse_single_example( + tfrecord, + features={ + 'patch_high_res': tf.io.FixedLenFeature([], tf.string), + 'seg': tf.io.FixedLenFeature([], tf.string), + 'phr_x': tf.io.FixedLenFeature([], tf.int64), + 'phr_y': tf.io.FixedLenFeature([], tf.int64), + 'phr_z': tf.io.FixedLenFeature([], tf.int64), + 'phr_c': tf.io.FixedLenFeature([], tf.int64), + 'seg_x': tf.io.FixedLenFeature([], tf.int64), + 'seg_y': tf.io.FixedLenFeature([], tf.int64), + 'seg_z': tf.io.FixedLenFeature([], tf.int64), + 'x': tf.io.FixedLenFeature([], tf.int64), + 'y': tf.io.FixedLenFeature([], tf.int64), + 'z': tf.io.FixedLenFeature([], tf.int64), + 'h': tf.io.FixedLenFeature([], tf.int64), + 'w': tf.io.FixedLenFeature([], tf.int64), + 'd': tf.io.FixedLenFeature([], tf.int64), + # 'disease': tf.io.FixedLenFeature([], tf.int64), + 'patient': tf.io.FixedLenFeature([], tf.string) + } + ) + + phr_x, phr_y, phr_z, phr_c = tfrecord_features['phr_x'], tfrecord_features['phr_y'], tfrecord_features['phr_z'], tfrecord_features['phr_c'] + phr = 
tf.io.decode_raw(tfrecord_features['patch_high_res'], tf.float32) + phr = tf.reshape(phr, shape = tf.stack([phr_x, phr_y, phr_z, phr_c])) + + seg_x, seg_y, seg_z = \ + tfrecord_features['seg_x'], tfrecord_features['seg_y'], tfrecord_features['seg_z'] + seg = tf.io.decode_raw(tfrecord_features['seg'], tf.uint8) + seg = tf.reshape(seg, shape = tf.stack([seg_x, seg_y, seg_z])) + seg = tf.expand_dims(seg, axis = -1) + + # disease = tfrecord_features['disease'] + # patient = tfrecord_features['patient'] + patient = tf.io.decode_raw(tfrecord_features['patient'], tf.uint8) + sz = tf.size(patient) + patient = tf.pad(patient, [[0, 100 - sz]]) + + x, y, z = tfrecord_features['x'], tfrecord_features['y'], tfrecord_features['z'] + h, w, d = tfrecord_features['h'], tfrecord_features['w'], tfrecord_features['d'] + + return phr, seg, x, y, z, h, w, d, patient + +def loss_choice(loss_name='WCE', **loss_options): + """ Returns the loss function based on string name + Supported loss names: + WCE: weighted cross entropy loss + focal: focal loss ## not yet implemented + Supported loss options: + see individual losses arguments + """ + if loss_name == 'WCE': + if 'weight_alpha' in loss_options.keys(): + return weighted_cross_entropy(weight_alpha=loss_options['weight_alpha']) + else: + return weighted_cross_entropy() + elif loss_name == 'focal': + print('not yet implemented') + exit + +def create_model(c, freeze=None, cpt=None): + # Mixed precision training and dynamic loss scaling + tf.keras.mixed_precision.experimental.set_policy('mixed_float16') + optimizer = tf.keras.optimizers.Adam(learning_rate=c.model_params.learning_rate) + # optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic") ## no need for that when using model.fit() + + # Enabling Accelerated Linear Algebra + # tf.config.optimizer.set_jit(True) + + mirrored_strategy = tf.distribute.MirroredStrategy() + + with mirrored_strategy.scope(): + model = UNet() + METRICS = [ + # tf.keras.metrics.Precision(name='precision'), + # tf.keras.metrics.BinaryAccuracy(name='accuracy'), + # tf.keras.metrics.Recall(name='recall'), + soft_dice() + ] + if freeze is not None: + freeze_layers(model, instring=freeze) + model.compile(optimizer=optimizer, + loss=loss_choice(c.model_params.training_loss, **c.model_params.loss_options),#soft_dice(_type="loss"), #tf.keras.losses.BinaryCrossentropy(from_logits=True), #weightedLoss(tf.keras.losses.BinaryCrossentropy(from_logits=True), {0:0.95, 1:0.05}), + metrics=METRICS) + + # Loading model within the scope of Distributed Strategy if applicable + if c.load_model is not None: + model.load_weights(c.load_model) + + return model + +def train(train_data, c, notes=None, validation_data=None, freeze=None): + # Create model + model = create_model(c, freeze=freeze) + + # Model saving - checkpoint definition + checkpoint_path = str(c.modelDir / "cp-{epoch:03d}.ckpt") + # Create a callback that saves the model's weights + cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, + save_weights_only=True, + verbose=1, + save_freq='epoch', + monitor='loss', + save_best_only=True) + + # TensorBoard setup + logDir = c.tfboardDir / (c.timestamp + "_" + notes) + print(f"\ntensorboard --logdir '{logDir}'\n") + tb_callback = tf.keras.callbacks.TensorBoard(log_dir=logDir, + # histogram_freq=1, # How often to log histogram visualizations + # embeddings_freq=1, # How often to log embedding visualizations + # update_freq=30, + # profile_batch=10 + ) + + # Show images on TensorBoard + 
+    image_history_callback = ImageHistory(tensorboard_dir=logDir/'images', data=train_data, num_images_to_show=c.model_params.num_images_to_show)
+
+    progress_bar = 1 if c.interactive else 2
+    # Start training
+    model_history = model.fit(train_data,
+                              initial_epoch=c.starting_epoch,
+                              epochs=c.model_params.num_epochs+c.starting_epoch,
+                              # batch_size=c.model_params.batch_size,
+                              # steps_per_epoch=1000, #TODO what does that mean? -- previously while True loop to go through every example
+                              callbacks=[cp_callback, tb_callback, ], #image_history_callback
+                              validation_data=validation_data,
+                              verbose=progress_bar) # progress bar (1), epoch only (2)
+
+    # Save model
+    tf.saved_model.save(model, str(c.modelDir))
+    return model_history
+
+def predict(test_dataset, c, presize=False):
+    """ Predict the given dataset
+    Args:
+        test_dataset: tf.Dataset object
+        c: config file
+        presize: precompute the size of the dataset to display the progress bar correctly (adds an upfront cost)
+    """
+    # Load model
+    print("Load model...")
+    model = create_model(c)
+
+    # Loop over the dataset to run prediction and save info
+    if presize:
+        num_batches=0
+        print("Evaluating dataset size...")
+        for batch in test_dataset:
+            num_batches+=1
+    else:
+        num_batches=None
+    pred_list=[]
+    from utils.tic_toc import tic, toc
+    c.logging = {f'Predict_batch({c.model_params.batch_size})': []}
+    loop_dataset = tqdm(test_dataset, total=num_batches, desc="Predicting batches", position=0)
+    for batch in loop_dataset:
+        tic()
+        batch_size = batch[0].shape[0]
+        pred_batch = model.predict(batch, batch_size=1)
+        # loop over the examples in the batch
+        for id_b in range(batch_size):
+            img, seg_ini = batch[0][id_b], batch[1][id_b]
+            x, y, z, h, w, d, patient = batch[2][id_b].numpy(), batch[3][id_b].numpy(), batch[4][id_b].numpy(), batch[5][id_b].numpy(), batch[6][id_b].numpy(), batch[7][id_b].numpy(), batch[8][id_b].numpy()
+            pred = pred_batch[id_b]
+            pred_list.append((x, y, z, h, w, d, patient, pred))
+        c.logging[f'Predict_batch({c.model_params.batch_size})'].append(toc(print_to_screen=False))
+    # Assemble the prediction into complete output files and write those to a folder
+    X = [p[0] for p in pred_list]
+    Y = [p[1] for p in pred_list]
+    Z = [p[2] for p in pred_list]
+    cpts = np.stack([X, Y, Z], axis = 1)
+    H = [p[3] for p in pred_list]
+    W = [p[4] for p in pred_list]
+    D = [p[5] for p in pred_list]
+    shape = np.stack([H, W, D], axis = 1)
+    patients = [p[6] for p in pred_list]
+    patches_pred = [p[7] for p in pred_list]
+
+    patients_str = []
+    loop = tqdm(list(patients), desc="Cleaning up patient IDs", position=0)
+    for p in loop:
+        pstr = "".join([chr(c) for c in p]).replace("\x00", "")
+        patients_str.append(pstr)
+
+    patients_str = np.array(patients_str)
+    patients_list = list(np.unique(patients_str))
+    # convert once so the patches can be fancy-indexed per patient below
+    patches_pred = np.array(patches_pred)
+
+    # Stitching patches and saving images
+    c.logging['Stitch_preprocess'] = []
+    c.logging['Stitch_assemble'] = []
+    c.logging['Stitch_writetodisk'] = []
+    loop = tqdm(patients_list, desc="Stitching patches and saving segmentation", position=0)
+    for p in loop:
+        tic()
+        # Extract patches for a given patient
+        idx = np.where(patients_str == p)[0]
+        cpts_p = cpts[idx]
+        shape_p = shape[idx]
+        patches_pred_p = patches_pred[idx]
+        c.logging['Stitch_preprocess'].append(toc(print_to_screen=False))
+        # Assemble into a segmentation and save to a file
+        tic()
+        seg = assemSegFromPatches_dir(shape_p[0], cpts_p, patches_pred_p)
+        c.logging['Stitch_assemble'].append(toc(print_to_screen=False))
+        tic()
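+        # sitk.GetImageFromArray assumes (z, y, x) index order, and the stitched
+        # volume is written with default spacing/origin; restoring the original
+        # header is left to the later resample-to-original-spacing step.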
+        sitk.WriteImage(sitk.GetImageFromArray(seg), str(c.valoutDir / (p + ".nii.gz")))
+        c.logging['Stitch_writetodisk'].append(toc(print_to_screen=False))
+
+
+def predict_batch(c):
+    from utils.datasplit import read_split
+    train_list, val_list, test_list = read_split(split_file = c.train_test_csv, im_dir=c.data_nh_1mm)
+    if c.predict_on == 'all':
+        img_info = train_list + val_list + test_list
+    elif c.predict_on == 'train':
+        img_info = train_list
+    elif c.predict_on == 'val':
+        img_info = val_list
+    elif c.predict_on == 'test':
+        img_info = test_list
+    else:
+        raise ValueError(f"Unknown predict_on value: {c.predict_on}. Use 'all', 'train', 'val' or 'test'.")
+    img_paths = [k[0]/(k[1]+'_1x1x1.nii.gz') for k in img_info]
+    # Instantiate model if not already
+    tic()
+    model = create_model(c)
+    c.logging['Model loading'] = [toc(print_to_screen=False, restart=True)]
+    pred_list = []
+    seg_list = []
+    c.logging['Load image'] = []
+    c.logging['Normalize image'] = []
+    c.logging['Create patches'] = []
+    c.logging['Predict image'] = []
+    c.logging['Assemble patches'] = []
+    c.logging['Save image to disk'] = []
+    for img_path in tqdm(img_paths, desc='Predicting images', position=0):
+        try:
+            pred, seg = predict_single(img_path, c, model)
+            pred_list.append(pred)
+            seg_list.append(seg)
+        except Exception as err:
+            print('Issue in prediction with', img_path)
+            print(err)
+            continue
+    return pred_list, seg_list
+
+def predict_single(img_path, c, model=None, condition='NA', ignore_existing=False):
+    img_meta = (img_path.parent, str(img_path.stem)[:-10], condition) # strip the trailing '_1x1x1.nii' from the stem
+    # Load image based on path
+    img, seg, patient, disease = load_image_seg_pair(img_meta)
+    c.logging['Load image'].append(toc(print_to_screen=False, restart=True))
+    out_path = c.valoutDir / (patient + ".nii.gz")
+    if ignore_existing and out_path.is_file(): return
+
+    # Instantiate model if not already
+    if model is None: model = create_model(c)
+    img = normalize_image(img)
+    c.logging['Normalize image'].append(toc(print_to_screen=False, restart=True))
+
+    # seg = normalize_image(seg)
+    # Crop patches
+    patches = crop_patches((img, seg, patient, disease, False, c.model_params.num_pos, c.model_params.num_neg), c.model_params.patchsize_multi_res, c.model_params.test_patch_spacing, c.model_params.segsize)
+    patches = img_to_model_format(patches)
+    img_to_predict = np.array(patches[0])
+    c.logging['Create patches'].append(toc(print_to_screen=False, restart=True))
+
+    # Predict
+    patient_patch_prediction = model.predict(img_to_predict, batch_size=c.model_params.batch_size)
+    c.logging['Predict image'].append(toc(print_to_screen=False, restart=True))
+
+    # Reassemble the prediction
+    pred = assemSegFromPatches_dir(patches[3][0][:3], patches[2], patient_patch_prediction, position=-1)
+    pred = np.squeeze(pred)
+    c.logging['Assemble patches'].append(toc(print_to_screen=False, restart=True))
+    sitk.WriteImage(sitk.GetImageFromArray(pred), str(out_path))
+    c.logging['Save image to disk'].append(toc(print_to_screen=False, restart=True))
+
+    return pred, seg
+
+def img_to_model_format(patches):
+    img, seg, cpts, shape, disease, patient = [], [], [], [], [], []
+    for patch in patches:
+        img.append(patch[0])
+        seg.append(patch[1])
+        cpts.append(patch[2])
+        shape.append(patch[3])
+        disease.append(patch[4])
+        patient.append(patch[5])
+    p_img = list(img)
+    p_seg = [tf.cast(tf.expand_dims(s, axis = -1), dtype=tf.uint8) for s in seg] # change type to uint8
+    p_cpts = list(cpts)
+    p_shape = list(shape)
+    p_disease = list(disease)
+    p_patient = list(patient)
+    return p_img, p_seg, p_cpts, p_shape, p_disease, p_patient
\ No newline at end of file
diff --git a/unet/utils/__init__.py b/unet/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/unet/utils/calculateDice_MICCAI.py b/unet/utils/calculateDice_MICCAI.py
new file mode 100755
index 0000000..897675a
--- /dev/null
+++ b/unet/utils/calculateDice_MICCAI.py
@@ -0,0 +1,125 @@
+### This script calculates Dice scores and related overlap statistics
+import numpy as np
+import SimpleITK as sitk
+import os
+import pandas as pd
+pd.set_option('display.max_rows', None)
+import csv
+import shutil
+from .dice import dice
+from utils.multiprocessing import run_multiprocessing
+from pathlib import Path
+from scipy import ndimage
+
+def compute_volume(sitk_img:sitk.Image) -> float:
+    """ Computes the volume of a binary image
+    Arguments:
+        sitk_img: sitk binary image or path to it
+    Returns:
+        float: computed volume
+    """
+    if isinstance(sitk_img, str) or isinstance(sitk_img, Path):
+        sitk_img = sitk.ReadImage(str(sitk_img))
+    size_vox = float(sitk_img.GetMetaData('pixdim[1]')) * float(sitk_img.GetMetaData('pixdim[2]')) * float(sitk_img.GetMetaData('pixdim[3]'))
+    num_vox = np.sum(sitk.GetArrayFromImage(sitk_img))
+    return num_vox*size_vox
+
+def compute_volume_batch(images, num_processes=12):
+    """ Compute volumes for a list of binary images
+    Arguments:
+        images: list(Path)
+        num_processes: int. Number of parallel processes to launch
+    """
+    # paths are handed straight to the workers; each one reads its own image in compute_volume
+    volumes = run_multiprocessing(compute_volume, images, title="Computing volumes", num_processes=num_processes)
+    return volumes
+
+def ignore_zeros(a, b, c):
+    """ Safe ratio a / (b + c): returns NaN instead of raising when the denominator is zero """
+    try:
+        result = a / (b + c)
+    except ZeroDivisionError:
+        result = np.nan
+    return result
+
+def compute_stats(p, gt_dir, pred_dir, inference_only=False):
+    lesion_pred = pred_dir / (p + "_binary.nii.gz")
+    lesion_gt = gt_dir / (p + "_seg.nii.gz")
+
+    if not os.path.exists(lesion_pred):
+        print(f"non-existent: {lesion_pred}")
+        return
+
+    # Load prediction and manual segmentation
+    pred_sitk = sitk.ReadImage(str(lesion_pred))
+    pred = sitk.GetArrayFromImage(pred_sitk)
+
+    # Compute lesion volume
+    lesion_volume_pred = compute_volume(pred_sitk)
+    if inference_only:
+        return (p, lesion_volume_pred)
+    else:
+        if os.path.exists(lesion_gt):
+            gt_sitk = sitk.ReadImage(str(lesion_gt))
+            gt = sitk.GetArrayFromImage(gt_sitk)
+            if gt.shape != pred.shape:
+                print(
+                    f'WARNING: ground truth segmentation has a different '
                    f'shape than image: {lesion_gt}.\nAdjusting from '
                    f'{gt.shape} to {pred.shape}')
+                gt = ndimage.zoom(
+                    gt,
+                    tuple(d1 / d2 for d1, d2 in zip(pred.shape, gt.shape)),
+                    order=0
+                )
+                print(gt.shape, np.unique(gt))
+        else:
+            print(f"WARNING: non-existent ground truth file: {lesion_gt}")
+            gt = np.zeros_like(pred)
+
+        # Compute lesion volume
+        lesion_volume_true = compute_volume(gt_sitk) if os.path.exists(lesion_gt) else np.nan
+
+        # Calculate dice
+        try:
+            d = dice(gt, pred)
+            # Calculate TP, FP, TN, FN
+            TP = float(np.sum(np.logical_and(pred == 1, gt == 1)))
+            TN = float(np.sum(np.logical_and(pred == 0, gt == 0)))
+            FP = float(np.sum(np.logical_and(pred == 1, gt == 0)))
+            FN = float(np.sum(np.logical_and(pred == 0, gt == 1)))
+
+            FPR = ignore_zeros(FP, FP, TN)
+            FNR = ignore_zeros(FN, FN, TP)
+            FDR = ignore_zeros(FP, FP, TP) # this is false discovery rate, and I think this is what bianca calculates and calls FPR (in bianca_overlap_measures)
+
+            NPV = ignore_zeros(TN, TN, FN)
+            PPV = ignore_zeros(TP, TP, FP)
+            sens = ignore_zeros(TP, TP, FN)
+            spec = ignore_zeros(TN, TN, FP)
+
+            FPR_biancaStyle = FP/(np.sum(pred == 1))
+            FNR_biancaStyle = FN/(np.sum(gt == 1))
+        except ValueError as ve:
+            print(f'For {p}: {ve}\n{lesion_pred}\n{lesion_gt}')
+            d = 'shape_error'
+            FPR, FNR, FDR, NPV, PPV, sens, spec = np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
+
+        # Print results to screen and return them
+        return (p, d, FPR, FNR, FDR, NPV, PPV, sens, spec, lesion_volume_pred, lesion_volume_true)
+
+def compute_stats_dir(gt_dir, pred_dir, outfile, num_processes=1, inference_only=False):
+    patients = [p.split(".nii.gz")[0].split('_binary')[0] for p in os.listdir(pred_dir) if p.endswith(".nii.gz")]
+    stats = run_multiprocessing(
+        compute_stats,
+        patients,
+        fixed_arguments={'gt_dir':gt_dir, 'pred_dir':pred_dir, 'inference_only': inference_only},
+        title="Computing stats",
+        num_processes=num_processes)
+    if inference_only:
+        columns=('Patient', 'Predicted lesion volume')
+    else:
+        columns=('Patient', 'd', 'FPR', 'FNR', 'FDR', 'NPV', 'PPV', 'Sensitivity', 'Specificity', 'Predicted lesion volume', 'True lesion volume')
+    stats = pd.DataFrame(stats, columns=columns).set_index('Patient')
+    stats.to_csv(outfile)
+    print("\nStats calculations completed.")
+    return stats
diff --git a/unet/utils/config.py b/unet/utils/config.py
new file mode 100755
index 0000000..8fc29ab
--- /dev/null
+++ b/unet/utils/config.py
@@ -0,0 +1,643 @@
+# This script houses some usual config and constants used in the network
+import os
+import subprocess
+import shutil
+import sys
+import pandas as pd
+from pathlib import Path
+from multiprocessing import cpu_count
+from socket import gethostname
+from datetime import datetime
+from .tic_toc import tic, toc
+
+class ModelParameters():
+    """ Class defining the model's parameters
+    """
+    def __init__(self,
+                 learning_rate = 1e-4,
+                 num_epochs = 30,
+                 bin_threshold = 0.7,
+                 training_loss='WCE',
+                 aug = 3,
+                 segsize = (96, 96, 96)):
+        # Number of images per tfrecord in the train set
+        self.nTrainPerTfrecord = 1
+        # Multi resolution patch size and spacing setting
+        self.segsize = segsize
+        self.patchsize_multi_res = [(1, segsize)] # [(resolution, patchsize)]
+        self.test_patch_spacing = tuple(x//2 for x in segsize)
+        # Visualization params
+        self.num_images_to_show = 4
+        # Training patch params
+        self.num_pos = 30
+        self.num_neg = 30
+        # Number of augmentations
+        self.aug = aug
+        # Learning parameters
+        self.batch_size = 12
+        self.shuffle_buffer = 100
+        self.learning_rate = learning_rate
+        # Training parameters
+        self.num_epochs = num_epochs
+        # Binarization threshold
+        self.bin_threshold = bin_threshold
+        # Training loss
+        self.training_loss = training_loss
+        self.loss_options = {'weight_alpha': 0.9}
+
+    def print_recap(self):
+        return
+
+
+
+class Config(object):
+    """
+    Parameters
+    ----------
+    root: string, required
+        Path to the directory root (subdivided into Data/ and models/)
+    datasource: string or Path, required
+        if string: self.root / "Data" / datasource
+        if Path: datasource
+        Directory name of the data source.
+    experiment: string, required
+        Name of experiment (if train, will be the main model name). Needs to match the csv split file.
+    train: boolean
+        True for train, False for predict.
+    model: string
+        If train = False, this is the model name used to predict the data.
+    tmp_storage: string
+        Location of tfrecords. *None=in main exp directory. 'TMP'=$TMPDIR.
+    which_model: string
+        Chooses which model to load for predict. *'latest' -> latest | 'root' if checkpoint is at the root | 'your_name' -> this name.
+    nickname: string
+        Name of a particular model or prediction directory. If it already exists, a timestamp is appended to it (unless read_existing=True).
+    read_existing: boolean
+        *False. If set to True and new files are created, they might overwrite existing files. Example use: prediction has already been run, and you want to access the existing outputs to run further binarization or statistics.
+    GPU: int, str, or list of. Default None
+        Chooses which GPU(s) will be used for this session. If None, all available GPUs will be used. GPU=0 | GPU=[0,1] | GPU='1' accepted.
+    predict_on: string
+        Which images to run the prediction on without modifying the split file. *'test', 'all', 'train', 'val'.
+    * = default value
+
+    Main functions
+    --------------
+    self.print_recap(): print a summary of run parameters to screen
+    self.preprocess(force_preprocess): preprocess images (resample, tf records for train/val)
+    self.start(): start the run (train or test)
+    self.show_imgseg(): return the command to run itksnap with image, segmentation and prediction
+    self.binarize_and_stats(bin_threshold): binarize already existing predictions and compute stats
+    self.lesionwise(bin_threshold): compute lesion by lesion statistics
+    """
+    def __init__(self, root, datasource, experiment, train,
+                 model_params=None,
+                 model=None,
+                 tmp_storage=None,
+                 which_model='latest',
+                 nickname='',
+                 GPU=None,
+                 continued_training=False,
+                 read_existing=False,
+                 predictions_output_dir=None,
+                 predict_on='test'):
+        self.GPU = self.set_GPU(GPU)
+        self.host = gethostname()
+        self.list_folders = []
+        self.interactive = bool(getattr(sys, 'ps1', sys.flags.interactive))
+        self.user = os.environ.get('USER') if os.environ.get('USER') else 'default'
+        self.logging = {}
+        self.type = 'training' if train else 'testing' # more precise definition further down
+        self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+        if not train and model is None:
+            sys.exit("In testing mode, please define the model variable.")
+        predict = not train
+        self.train = train
+        self.read_existing = read_existing
+        self.num_cpu = self.set_CPU()
+        # instantiate per Config to avoid sharing one mutable default instance
+        self.model_params = model_params if model_params is not None else ModelParameters()
+        self.model_params.batch_size = self.set_batch_size()
+
+        self.root = Path(root)
+        self.add_dir(self.root)
+        # The name of the specific run
+        self.experiment = experiment
+        # CSV file that stores the train/test split
+        self.train_test_csv = self.root / "models" / (self.experiment + ".csv")
+
+        # Directory where config.txt is saved
+        if continued_training:
+            self.main_dir = self.root / "models" / model
+            self.experiment = model
+        # fine-tuning conditions
+        elif train and model is not None and model!=experiment:
+            self.experiment = model + "." + experiment # e.g. 293P.34U
+            self.main_dir = self.root / "models" / self.experiment
+        # normal training conditions, or further training of the same model
+        elif train:
+            self.main_dir = self.root / "models" / self.experiment
+        # testing conditions
+        else:
+            self.main_dir = self.root / "models" / model
+
+
+        ### IMAGES ###
+        # Root directory for the images
+        if isinstance(datasource, Path):
+            data_dir = datasource
+        else:
+            data_dir = self.root / "Data" / datasource
+        # Directory that contains the raw images and the segmented masks (patientid.nii.gz & patientid_seg.nii.gz)
+        self.data_raw = data_dir / "raw"
+        self.add_dir(self.data_raw)
+        # Directory that contains the preprocessed images
+        self.data_nh = data_dir / "preprocessed/noheader/"
+        self.add_dir(self.data_nh)
+        # Directory that contains the preprocessed images used for model training
+        self.data_nh_1mm = data_dir / "preprocessed/noheader_1mm/"
+        self.add_dir(self.data_nh_1mm)
+
+
+        ### MODEL ###
+        # Model
+        if train:
+            model_dir = self.root / "models" / self.experiment
+            # Directory that contains the model
+            if nickname:
+                if self.read_existing or not (model_dir / "model" / nickname).exists() or continued_training:
+                    self.modelDir = model_dir / "model" / nickname
+                else:
+                    self.modelDir = model_dir / "model" / (nickname + '_' + self.timestamp)
+            else:
+                self.modelDir = model_dir / "model" / self.timestamp
+            self.add_dir(self.modelDir)
+            if model and model!=experiment:
+                self.type = 'fine-tuning'
+            elif model:
+                self.type = 'continued_training'
+            else:
+                self.type = 'training'
+        else:
+            model_dir = self.root / "models" / model
+            # Directory that contains the model to load:
+            # will be the main directory, not one with the timestamps,
+            # so copy the model to use for predictions into the main directory
+            self.modelDir = self.choose_model_dir(model_dir / "model", which_model)
+            # self.add_dir(self.modelDir) #TODO this should exist!
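+        # At this point modelDir points at the directory holding the checkpoints,
+        # resolved for testing by choose_model_dir() (latest timestamped run, the
+        # root, or a named run); load_saved_model() below picks the newest
+        # cp-XXX.ckpt inside it and derives the starting epoch from its file name.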
+        # Load model
+        if model is not None:
+            self.load_model, self.starting_epoch = self.load_saved_model(model, which_model)
+        else:
+            self.load_model = None
+            self.starting_epoch = 0
+        # tfrecords
+        if train:
+            if tmp_storage:
+                if tmp_storage=='TMP' and os.getenv('TMPDIR'):
+                    tmp_dir = Path(os.getenv('TMPDIR'))
+                elif tmp_storage=='scratch' and os.getenv('USER'):
+                    tmp_dir = Path('/scratch') / os.getenv('USER') / 'tfrecords'
+                else:
+                    print(f'tmp_storage value is unknown, reverting to None behavior.\nYou passed {tmp_storage}.\nPass either "TMP" or "scratch".\n')
+                    tmp_storage = None
+            if tmp_storage and tmp_dir is not None:
+                self.tfrDir = tmp_dir / experiment / "tf_records"
+            else:
+                self.tfrDir = model_dir / "tf_records"
+            self.add_dir(self.tfrDir)
+        # Tensorboard
+        if train:
+            self.tfboardDir = model_dir / "tfboard"
+            self.add_dir(self.tfboardDir)
+
+        ### PREDICTIONS ###
+        if predict:
+            self.predict_on = predict_on
+            if predictions_output_dir is None:
+                if nickname:
+                    if self.read_existing or not (model_dir / "predictions_output" / nickname).exists():
+                        predictions_output_dir = model_dir / "predictions_output" / nickname
+                    else:
+                        predictions_output_dir = model_dir / "predictions_output" / (nickname + '_' + self.timestamp)
+                else:
+                    predictions_output_dir = model_dir / "predictions_output" / self.timestamp
+            else:
+                predictions_output_dir = Path(predictions_output_dir)
+            self.add_dir(predictions_output_dir)
+            # Model output data in:
+            self.valoutDir = predictions_output_dir / "predictions"
+            self.add_dir(self.valoutDir)
+            # Resampled output data in:
+            self.resampledDir = predictions_output_dir / "resampled_to_originalspacing"
+            self.add_dir(self.resampledDir)
+            # Thresholded output data in:
+            self.thresholdedDir = predictions_output_dir / "binarized_masks"
+            self.add_dir(self.thresholdedDir)
+            # Dice scores and other analytics in:
+            self.diceDir = predictions_output_dir / "Analytics"
+            self.add_dir(self.diceDir)
+
+        # disease to code
+        self.diseaseCode = {
+            "multiple_sclerosis_active": 0,
+            "ADEM": 1,
+            "adrenoleukodystrophy":2,
+            "BG_normal":3,
+            "normal":4,
+            "Normal": 4,
+            "CADASIL":56,
+            "CNS_lymphoma":5,
+            "High_Grade_Glioma":6,
+            "HIV_Encephalopathy":7,
+            "Low_Grade_Glioma":8,
+            "metastatic_disease":9,
+            "migraine":10,
+            "multiple_sclerosis_inactive":11,
+            "neuromyelitis_optica":12,
+            "PML":13,
+            "PRES":14,
+            "Susac_syndrome":15,
+            "SVID":16,
+            "toxic_leukoencephalopathy":17,
+            "multiple_sclerosis_tumefactive":18,
+            "vascular":19,
+            "Hypoxic_Ischemic_Encephalopathy_acute":20,
+            "Hypoxic_Ischemic_Encephalopathy_chronic":22,
+            "Carbon_Monoxide_acute":23,
+            "hemorrhage_chronic":24,
+            "Lymphoma":25,
+            "Hemorrhage_subacute":26,
+            "Nonketotic_Hyperglycemia":27,
+            "Seizures":28,
+            "Toxoplasmosis":29,
+            "Cryptococcus":30,
+            "Wilsons_Disease":31,
+            "Artery_of_Percheron_acute":32,
+            "infarct_chronic":33,
+            "Infarct_chronic":33,
+            "Deep_Vein_Thrombosis_subacute":34,
+            "Deep_Vein_Thrombosis_chronic":35,
+            "infarct_acute":36,
+            "Infarct_acute":36,
+            "Metastases":37,
+            "Hypoxic_Ischemic_Encephalopathy_subacute":38,
+            "Encephalitis":39,
+            "Hemorrhage_chronic":41,
+            "Deep_Vein_Thrombosis_acute":42,
+            "Creutzfeldt_Jakob":43,
+            "Wernicke_Encephalopathy":44,
+            "Manganese_Deposition":45,
+            "Carbon_Monoxide_subacute":46,
+            "Calcium_Deposition":47,
+            "Bilateral_Thalamic_Glioma":48,
+            "Hemorrhage_acute":49,
+            "Neuro_Behcet_Disease":50,
+            "Sarcoidosis":51,
+            "Neurofibromatosis":52,
+            "Abscess":53,
+            "infarct_subacute":54,
+            "Infarct_subacute":54,
+            "Carbon_Monoxide_chronic":55
+        }
+
+        # Elastic transform parameters
+        self.ep = {
+            'rotation_x': 0.001,
+            'rotation_y': 0.001,
+            'rotation_z': 0.05,
+            'trans_x': 0.01,
+            'trans_y': 0.01,
+            'trans_z': 0.01,
+            'scale_x': 0.1,
+            'scale_y': 0.1,
+            'scale_z': 0.1,
+            'df_x': 0.1,
+            'df_y': 0.1,
+            'df_z': 0.1
+        }
+
+
+    ### HELPER FUNCTIONS ###
+    def preprocess(self, force_preprocess=False, CV=None, create_csv=False):
+        """ Run preprocessing for the images (resampling, train to tfrecords)
+        Parameters
+        ----------
+        force_preprocess: boolean
+            whether to force reprocessing even if it appears to have been done before
+        CV: tuple (num_folds: int, test_size: float)
+            whether to run cross-validation. Parameters: number of folds, test size (as a fraction - if None, set automatically)
+        create_csv: boolean
+            whether to create the split csv
+        """
+        from image.process_image import strip_header_dir, resample_to_1mm_dir
+        if create_csv:
+            image_list = [im.name.split('.nii')[0] for im in self.data_raw.iterdir()
+                          if im.name.endswith(('.nii.gz', '.nii'))
+                          and not im.name.endswith(('_seg.nii.gz', '_seg.nii'))]
+            n = len(image_list)
+            _ = pd.DataFrame(
+                {
+                    'subject': image_list,
+                    'condition': ['NA']*n,
+                    'train_test': ['train' if self.train else 'test']*n
+                }
+            ).to_csv(self.train_test_csv, header=False, index=False)
+        else:
+            image_list = pd.read_csv(self.train_test_csv, header=None)[0].to_numpy()
+        self.data_nh.mkdir(parents=True, exist_ok=True)
+        self.data_nh_1mm.mkdir(parents=True, exist_ok=True)
+        strip_header_dir(data_raw_dir=self.data_raw, data_nh_dir=self.data_nh, image_list=image_list, num_processes=self.num_cpu, force_preprocess=force_preprocess)
+        resample_to_1mm_dir(data_nh_dir=self.data_nh, data_nh_1mm=self.data_nh_1mm, image_list=image_list, num_processes=self.num_cpu, force_preprocess=force_preprocess)
+        if self.train:
+            if CV:
+                from .datasplit import create_cv_splits
+                create_cv_splits(num_folds=CV[0], images_info_csv=self.train_test_csv, test_size=CV[1])
+                self.model_params.nTrainPerTfrecord = 1
+            preprocess = not self.check_for_tfrecords()
+            if force_preprocess or preprocess:
+                from image.preprocess_images import img_to_tfrecord
+                # Creation of tfrecords
+                self.make_dirs([self.tfrDir])
+                img_to_tfrecord(self)
+            if CV: print(f"Next step is to rerun Config with experiment set to {self.experiment}_k where k is the fold to run.")
+        print('Preprocessing completed.')
+
+    def start(self, notes_run='', freeze=None):
+        """ Run training or testing
+        Parameters
+        ----------
+        notes_run: string
+            Add a few notes used in both the config.txt file for records (under `models/my_train_set/config.txt` or `predict/TRmy_train_set_TTmy_test_set`) and in the tfboard name, for easy filtering of experiments in Tensorboard
+        freeze: string
+            freeze all layers but those containing `freeze` in their name
+        """
+        self.make_dirs(self.list_folders)
+        self.print_recap(notes=notes_run, print_=False, save=True)
+
+        from model.train_test import dataset_from_tfr, create_model, train, predict, predict_batch
+        from .datasplit import img_list_from_split
+        if self.train:
+            # Load datasets
+            img_list = img_list_from_split(self.train_test_csv, 'train') if self.model_params.nTrainPerTfrecord == 1 else None
+            train_dataset = dataset_from_tfr(self.tfrDir, self.model_params.shuffle_buffer, self.model_params.batch_size, "train", img_list=img_list)
+            img_list = img_list_from_split(self.train_test_csv, 'val') if self.model_params.nTrainPerTfrecord == 1 else None
+            val_dataset = dataset_from_tfr(self.tfrDir, self.model_params.shuffle_buffer, self.model_params.batch_size, "val", img_list=img_list)
+            # Train model
+            train(train_dataset, self, notes=notes_run, validation_data=val_dataset, freeze=freeze)
+        else:
+            from image.process_image import resample_from_1mm_dir, binarize_dir
+            from utils.calculateDice_MICCAI import compute_stats_dir
+            # Run predictions
+            predict_batch(self)
+            # Process predictions
+            resample_from_1mm_dir(self, num_processes=self.num_cpu)
+            binarize_dir(self.resampledDir, self.thresholdedDir, threshold=self.model_params.bin_threshold, num_processes=self.num_cpu)
+            # Compute Dice score
+            stats_file_name = f'{self.diceDir}/stats_{self.experiment}_th_{self.model_params.bin_threshold}.csv'
+            compute_stats_dir(self.data_raw, self.thresholdedDir / str(self.model_params.bin_threshold), stats_file_name, num_processes=self.num_cpu)
+            print(stats_file_name)
+        print("Done!")
+
+    def update_for_CV(self, fold:int):
+        """Updates `self.train_test_csv` to match the cross-validation experiment
+
+        Args:
+            fold (int): fold ID
+
+        Raises:
+            FileNotFoundError: no CV folds csv available
+            FileNotFoundError: fold ID doesn't exist
+        """
+        new_csv = self.train_test_csv.with_name(
+            f'{self.train_test_csv.stem}_{fold}.csv')
+        if not new_csv.exists():
+            folds = [f.stem.split('_')[-1] for f in self.train_test_csv.parent.iterdir()
+                     if f.suffix == '.csv' and f.name.startswith(self.train_test_csv.stem + '_')]
+            if folds == []:
+                raise FileNotFoundError(f'No CV folds csv available. Create one by running '
+                                        f'`c.preprocess(CV=(n_folds,None))`')
+            raise FileNotFoundError(f'No csv found for {fold=}. Available {folds=}.')
+        self.train_test_csv = new_csv
+
+    def binarize_and_stats(self, bin_threshold):
+        from image.process_image import binarize_dir
+        from utils.calculateDice_MICCAI import compute_stats_dir
+        binarize_dir(self.resampledDir, self.thresholdedDir, threshold=bin_threshold, num_processes=self.num_cpu)
+        # Compute Dice score
+        stats_file_name = f'{self.diceDir}/stats_{self.experiment}_th_{bin_threshold}.csv'
+        stats = compute_stats_dir(self.data_raw, self.thresholdedDir / str(bin_threshold), stats_file_name, num_processes=self.num_cpu)
+        print(stats_file_name)
+        return stats
+
+    def check_for_tfrecords(self):
+        """ Returns True if tfrDir exists and is not empty, False otherwise
+        """
+        try:
+            n = len(list(self.tfrDir.iterdir()))
+        except FileNotFoundError:
+            print("The tfrecords directory does not exist. Starting to preprocess.")
+            return False
+        if n == 0:
+            print("The tfrecords directory is empty. Starting to preprocess.")
+            return False
+        else:
+            return True
+
+    def force_create(self, folder):
+        """Recreate directory, deleting the previous one"""
+        if os.path.exists(folder) and os.path.isdir(folder):
+            shutil.rmtree(folder)
+        os.mkdir(folder)
+
+    def add_dir(self, folder):
+        # if not self.read_existing:
+        self.list_folders.append(folder)
+
+    def make_dirs(self, list_folders):
+        for folder in list_folders:
+            folder.mkdir(parents=True, exist_ok=True)
+
+    def splitfile_exist(self):
+        if not os.path.isfile(self.train_test_csv):
+            sys.exit(f"The split file does not exist at this location:\n{self.train_test_csv}")
+
+    def choose_model_dir(self, model_dir, which_dir):
+        if which_dir == 'latest':
+            list_tmp = [directory for directory in model_dir.iterdir() if not directory.is_file()]
+            latest_dir = max(list_tmp, key=os.path.getctime)
+            # removes empty dirs, until the last one
+            while len(list_tmp)>1 and not any(latest_dir.iterdir()):
+                list_tmp.remove(latest_dir)
+                latest_dir = max(list_tmp, key=os.path.getctime)
+            return latest_dir
+        elif which_dir == 'root':
+            return model_dir
+        else:
+            return model_dir / which_dir
+
+    def load_saved_model(self, model, which_model):
+        """ Find the model checkpoint and its corresponding epoch
+        Returns:
+            (latest_model, starting_epoch)
+            where: latest_model=path(str), starting_epoch=int
+        """
+        from tensorflow.train import latest_checkpoint
+        model_dir = self.choose_model_dir(self.root / "models" / model / "model", which_model)
+        latest_model = latest_checkpoint(model_dir)
+        # if no checkpoint is found, raise with diagnostic information
+        if latest_model is None:
+            err_message = (
+                f"There is an issue loading the model specified:\n"
+                f"model name specified: {model}\n"
+                f"which model specified: {which_model}\n"
+                f"model directory inferred: {model_dir}\n"
+                f"Check if the above directory has a `checkpoint` file, a `.ckpt.data` and `.ckpt.index`."
+            )
+            raise FileNotFoundError(err_message)
+        else:
+            epoch = int(Path(latest_model).stem[3:]) # assuming the checkpoint name follows: cp-123.ckpt
+            return latest_model, epoch
+
+    def set_CPU(self, CPU:int=6):
+        """Sets the number of CPUs to be used"""
+        CPU_env = os.environ.get('SLURM_CPUS_ON_NODE')
+        if CPU_env is not None:
+            CPU_env = int(CPU_env)
+            if CPU is not None:
+                if CPU != CPU_env:
+                    max_CPU = max(CPU, CPU_env)
+                    print(f'CPU set in config ({CPU}) and in environment ({CPU_env}) differ. Defaulting to {max_CPU}.')
+                    return max_CPU
+                return CPU
+            return CPU_env
+        return CPU
+
+
+    def set_GPU(self, GPU=None):
+        """ Returns a tuple (num_gpus, gpu_names)
+        """
+        if GPU:
+            if type(GPU) is str:
+                GPU = list(map(int, GPU.split(',')))
+            elif type(GPU) is int:
+                GPU = [GPU]
+            GPU_str = ','.join(list(map(str, GPU)))
+            os.environ['CUDA_VISIBLE_DEVICES'] = GPU_str
+        else:
+            GPU_str = None
+        from tensorflow.config.experimental import list_physical_devices, set_visible_devices, list_logical_devices, set_memory_growth
+        gpus = list_physical_devices('GPU')
+        if not gpus and GPU is not None:
+            sys.exit(f"GPU specified: {GPU}. This device doesn't have any gpus.\n")
+        elif not gpus and GPU is None:
+            print("CPU job\n")
+            # only using CPUs here
+            return (0, 'CPU job')
+        elif gpus and GPU is not None:
+            if len(GPU) > len(gpus):
+                sys.exit(f"GPU specified: {GPU}. Available GPUs: {list(range(len(gpus)))}. Please choose another GPU.\n")
+        for gpu in gpus:
+            set_memory_growth(gpu, True)
+        if GPU_str:
+            prettyprint = f"CUDA ({GPU_str}) - TF ({','.join([gpu.name for gpu in gpus])})"
+        else:
+            prettyprint = ','.join([gpu.name for gpu in gpus])
+        num_gpu = len(gpus)
+        return (num_gpu, prettyprint)
+
+    def print_recap(self, notes="NA", print_=True, save=False):
+        line = "-"*12
+        recap_dic = {
+            "Datetime:": self.timestamp,
+            "Experiment:": self.experiment,
+            "Type:": self.type,
+            "Split file:": self.train_test_csv,
+            "Model dir:": self.modelDir,
+            "Data dir:": self.data_nh_1mm,
+            f"GPU ({self.GPU[0]}):": self.GPU[1],
+            "CPU usage:": f"{self.num_cpu}/{cpu_count()}",
+            "Host:": self.host,
+            "User:": self.user,
+            "Notes:": notes,
+        }
+        if self.type in ['fine-tuning', 'continued_training', 'testing']:
+            recap_dic["Last checkpoint:"] = self.load_model
+        if self.type == 'testing':
+            recap_dic["Prediction dir:"] = self.valoutDir.parent
+        if self.train:
+            recap_dic["TfRecords:"] = self.tfrDir
+            recap_dic["Epochs:"] = self.model_params.num_epochs
+            recap_dic["Batch size:"] = self.model_params.batch_size
+            recap_dic["Learning rate:"] = self.model_params.learning_rate
+
+        recap = f"\n{line*2} Parameters {line*2}\n"
+        for key, value in recap_dic.items():
+            recap += f"{key:<15}{value}\n"
+        recap += f"{line*5}\n"
+
+        if print_:
+            print(recap)
+        if save:
+            with open(self.main_dir / "config.txt", 'a') as config_file:
+                config_file.write(recap)
+
+    def plot_logging(self):
+        """ Return the figure and the pandas dataframe used to make it
+        """
+        if self.logging == {}: return
+        import seaborn as sns
+        import matplotlib.pyplot as plt
+        sns.set_theme()
+        sns.set_context("paper")
+        sns.set_style('whitegrid')
+        times = pd.DataFrame({key: pd.Series(value) for key, value in self.logging.items()})
+        colors = sns.color_palette('Set3', len(times.columns))
+        pp = sns.boxplot(data=times, palette=colors)
+        pp.set_title('Processing times')
+        pp.set_ylabel('Time per iteration (s)')
+        plt.yscale('log')
+        plt.grid(which='major', linestyle='-')
+        plt.grid(which='minor', linestyle='--')
+        plt.xticks(rotation=30, horizontalalignment='right')
+        return pp, times
+
+    def show_imgseg(self, accession=None, no_true_seg=False, bin_threshold=None):
+        """Return the command to load the original image, segmentation, and the prediction in itksnap. If no accession is provided, one is picked at random within the prediction directory.
+        """
+        if bin_threshold is None: bin_threshold = self.model_params.bin_threshold
+        if accession is None:
+            import random
+            # get a random accession number
+            accessions = [img.stem[:-4] for img in self.valoutDir.iterdir() if img.name.endswith('.nii.gz')]
+            accession = random.choice(accessions)
+        if no_true_seg:
+            COMMAND = f'itksnap -g {self.data_raw}/{accession}.nii.gz -s {self.thresholdedDir}/{bin_threshold}/{accession}_binary.nii.gz'
+        else:
+            COMMAND = f'itksnap -g {self.data_raw}/{accession}.nii.gz -s {self.data_raw}/{accession}_seg.nii.gz -o {self.thresholdedDir}/{bin_threshold}/{accession}_binary.nii.gz'
+        return COMMAND
+
+    def set_batch_size(self):
+        """Return an automatically computed batch size to maximize GPU(s) usage.
+        """
+        if self.GPU[0] == 0: return 3 # if CPU job, default to a batch size of 3
+        gpu_memory = min(self.get_gpu_memory())
+        num_GPUs = self.GPU[0]
+        voxels = self.model_params.patchsize_multi_res[0][1][0] * self.model_params.patchsize_multi_res[0][1][1] * self.model_params.patchsize_multi_res[0][1][2]
+        # Formula below:
+        # 0.80: observed that peak memory was capped somewhere between 75% and 89% (75% worked, the 89% went OOM)
+        # 700 MiBs: weight of model, gradients, etc. (experimental)
+        # 1.299e-3: slope, experimental
+        # x0.5 if not using mixed precision (GPU Compute Capability < 7)
+        optimal_batch_size_per_gpu = 0.80 * (gpu_memory - 700) / (1.299e-3 * voxels) # * 1.7
+        if self.host in ['titan.radiology.ucsf.edu', 'cronus.radiology.ucsf.edu', 'titan', 'cronus']: optimal_batch_size_per_gpu = optimal_batch_size_per_gpu / 2
+        optimal_batch_size = int(optimal_batch_size_per_gpu * num_GPUs)
+        return optimal_batch_size
+
+
+    def get_gpu_memory(self):
+        """Total memory of each visible GPU, in MiB, parsed from nvidia-smi"""
+        _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
+        COMMAND = "nvidia-smi --query-gpu=memory.total --format=csv"
+        memory_total_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
+        memory_total_values = [int(x.split()[0]) for x in memory_total_info]
+        return memory_total_values
\ No newline at end of file
diff --git a/unet/utils/datacheck.py b/unet/utils/datacheck.py
new file mode 100644
index 0000000..c92ea52
--- /dev/null
+++ b/unet/utils/datacheck.py
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+def diff_2lists(li1: list, li2: list) -> list:
+    return list(list(set(li2)-set(li1)) + list(set(li1)-set(li2)))
+
+def check_seg(data_dir: Path) -> dict:
+    """ Check whether every image has a matching segmentation file
+    Return a dictionary with the list of missing segmentations and the list of extra segmentations
+    """
+    image_names = [f.stem[:-4] for f in data_dir.iterdir() if f.name.endswith('.nii.gz') and 'seg' not in f.name]
+    seg_names = [f.stem[:-8] for f in data_dir.iterdir() if f.name.endswith('_seg.nii.gz')]
+
+    diff = diff_2lists(image_names, seg_names)
+    res = {'missing seg': [], 'extra seg': []}
+    for d in diff:
+        if d in image_names:
+            res['missing seg'].append(d)
+        else:
+            res['extra seg'].append(d)
+    return res
+
+def complete_images(data_dir: Path) -> list:
+    """ Return the list of images that have a segmentation
+    """
+    image_names = [f.stem[:-4] for f in data_dir.iterdir() if f.name.endswith('.nii.gz') and 'seg' not in f.name]
+    missing_data = check_seg(data_dir)
+    return diff_2lists(image_names, missing_data['missing seg'])
\ No newline at end of file
diff --git a/unet/utils/datasplit.py b/unet/utils/datasplit.py
new file mode 100755
index 0000000..eae0ef8
--- /dev/null
+++ b/unet/utils/datasplit.py
@@ -0,0 +1,153 @@
+#%% This script splits the dataset and saves the split into a csv file. It should be run before the
+# preprocessing script so that the split can be read by the preprocessing script
+import numpy as np
+import csv
+import os
+import pandas as pd
+from pathlib import Path
+
+#%%
+def create_cv_splits(num_folds: int, images_info_csv: Path, test_size=None) -> None:
+    """ Create csv files for the number of folds with train/test examples chosen at random
+    Parameters
+    ----------
+    num_folds: int
+        number of folds for the cross validation. Corresponds to the number of csv files that are created.
+    images_info_csv: pathlib.Path
+        csv file path containing the list of images. The fold csvs will be stored in the same location.
+    test_size: float
+        fraction of images to reserve for the test set (the test count is floored). If None, set to 1/num_folds.
+    """
+    images_info = pd.read_csv(images_info_csv, names=['Patient', 'Condition', 'TT'])
+    num_images = len(images_info)
+    if test_size is None:
+        num_test = num_images // num_folds
+    elif not 0 <= test_size <= 1:
+        print(f'Error: test_size must be between 0 and 1. {test_size} given.')
+        return
+    else:
+        num_test = int(test_size*num_images)
+    # Shuffle list
+    images_info = images_info.sample(frac=1)
+    tt_col = images_info.columns.get_loc('TT')
+    for k in range(num_folds):
+        # assign by position with .iloc (chained indexing does not reliably write back)
+        images_info.iloc[:, tt_col] = 'train'
+        images_info.iloc[k*num_test:(k+1)*num_test, tt_col] = 'test'
+        file_name = images_info_csv.parent / (images_info_csv.stem + f'_{k}.csv')
+        images_info.to_csv(file_name, index=False, header=False)
+        print(f'Saving to: {file_name}')
+
+#%%
+def img_list_from_split(split_file, filterby):
+    images_info = pd.read_csv(split_file, names=['Patient', 'Condition', 'TT'], dtype={'Patient': str, 'Condition': str, 'TT': str})
+    return list(images_info[images_info['TT'] == filterby]['Patient'])
+
+
+# This function gives lists of train and test directories, and optionally saves the
+# train and test split into a csv file
+def divide_train_test(rootDir, trainRatio = 0.7, outDir = None):
+    # Divide train and test set
+    trainDirs = []
+    testDirs = []
+
+    diseases = os.listdir(rootDir)
+
+    for d in diseases:
+        diseaseDir = rootDir + d + "/"
+        patientDirs = [ (diseaseDir + p + "/", p, d) for p in os.listdir(diseaseDir)]
+
+        num_patients = len(patientDirs)
+        idx = np.random.permutation(num_patients)
+
+        trainDirs.extend([patientDirs[i] for i in idx[:int(trainRatio*num_patients)]])
+        testDirs.extend([patientDirs[i] for i in idx[int(trainRatio*num_patients):]])
+
+    # Write out the train test split
+    if outDir is not None:
+        with open(outDir, 'w') as csvfile:
+            writer = csv.writer(csvfile)
+            for tr in trainDirs:
+                writer.writerow(list(tr) + ['train'])
+            for tr in testDirs:
+                writer.writerow(list(tr) + ['test'])
+    return trainDirs, testDirs
+#%%
+def create_train_test_split(inputdir, outDir, trainRatio = 0.8):
+    """ Create a random split from the images in the given directory
+    """
+    inputdir = Path(inputdir)
+    img_list = [nii for nii in inputdir.iterdir() if nii.is_file() and not nii.name.endswith("_seg.nii.gz")]
+    num_img = len(img_list)
+    idx = np.random.permutation(num_img)
+    train_img = [img_list[i].stem[:-4] for i in idx[:int(trainRatio*num_img)]]
+    test_img = [img_list[i].stem[:-4] for i in idx[int(trainRatio*num_img):]]
+    with open(outDir, 'w') as csvfile:
+        writer = csv.writer(csvfile)
+        for img in train_img:
+            writer.writerow([img, 'NA', 'train'])
+        for img in test_img:
+            writer.writerow([img, 'NA', 'test'])
+
+#%%
+# This function reads in the split file created by divide_train_test
+def read_split(split_file, im_dir):
+    """
+    Returns a (train, val, test) tuple; each element is a list of
+    [image_dir, patient ID, condition] triples
+    """
+    trainDirs = []
+    valDirs = []
+    testDirs = []
+
+    # split = pd.read_csv(c.train_test_csv, names=["patient", "condition", "train_test"])
+    # split.insert(0, "path", c.data_nh_1mm)
+    # split.drop(columns=["condition"], inplace=True)
+    # split_train = split[split["train_test"]=="train"]
+    # split_test = split[split["train_test"]=="test"]
+    # trainDirs = split_train.to_numpy()
+    # testDirs = split_test.to_numpy()
+    with open(split_file, 'r') as csvfile:
+        reader = csv.reader(csvfile)
+
+        for row in reader: #TODO: currently includes disease (:2)??
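+            # Each row is [patient, condition, 'train'|'val'|'test']. The membership
+            # test below matches any field, so a patient ID literally named 'train'
+            # would land in the train list; checking row[-1] explicitly would be stricter.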
+ if "train" in row: + trainDirs.append([im_dir] + row[:2]) + elif "val" in row: + valDirs.append([im_dir] + row[:2]) + elif "test" in row: + testDirs.append([im_dir] + row[:2]) + return trainDirs, valDirs, testDirs + +def create_split(csv_file, origin="UCSF"): + df = pd.read_csv(csv_file, header="infer") + df = df[df['Origin']==origin] + return df + + +def create_3x_val(csv_file, origin): + df = create_split(csv_file, origin) + df1 = df.copy(deep=True) + df2 = df.copy(deep=True) + df3 = df.copy(deep=True) + n = len(df) + for i in range(n): + if i%3 == 0: + df1.iloc[i, df1.columns.get_loc('Origin')] = "train" + df2.iloc[i, df2.columns.get_loc('Origin')] = "train" + df3.iloc[i, df3.columns.get_loc('Origin')] = "test" + elif i%3 == 1: + df1.iloc[i, df1.columns.get_loc('Origin')] = "train" + df2.iloc[i, df2.columns.get_loc('Origin')] = "test" + df3.iloc[i, df3.columns.get_loc('Origin')] = "train" + elif i%3 == 2: + df1.iloc[i, df1.columns.get_loc('Origin')] = "test" + df2.iloc[i, df2.columns.get_loc('Origin')] = "train" + df3.iloc[i, df3.columns.get_loc('Origin')] = "train" + print(df1.head()) + print(df2.head()) + print(df3.head()) + + return (df1, df2, df3) + diff --git a/unet/utils/dice.py b/unet/utils/dice.py new file mode 100755 index 0000000..d0600a1 --- /dev/null +++ b/unet/utils/dice.py @@ -0,0 +1,38 @@ +import numpy as np + +def dice(im1, im2, empty_score=1.0): + """ + Computes the Dice coefficient, a measure of set similarity. + Parameters + ---------- + im1 : array-like, bool + Any array of arbitrary size. If not boolean, will be converted. + im2 : array-like, bool + Any other array of identical size. If not boolean, will be converted. + Returns + ------- + dice : float + Dice coefficient as a float on range [0,1]. + Maximum similarity = 1 + No similarity = 0 + Both are empty (sum eq to zero) = empty_score + + Notes + ----- + The order of inputs for `dice` is irrelevant. The result will be + identical if `im1` and `im2` are switched. + """ + im1 = np.asarray(im1).astype(np.bool) + im2 = np.asarray(im2).astype(np.bool) + + if im1.shape != im2.shape: + raise ValueError("Shape mismatch: im1 and im2 must have the same shape.") + + im_sum = im1.sum() + im2.sum() + if im_sum == 0: + return empty_score + + # Compute Dice coefficient + intersection = np.logical_and(im1, im2) + + return 2. * intersection.sum() / im_sum \ No newline at end of file diff --git a/unet/utils/multiprocessing.py b/unet/utils/multiprocessing.py new file mode 100644 index 0000000..e19597b --- /dev/null +++ b/unet/utils/multiprocessing.py @@ -0,0 +1,35 @@ +from tqdm.auto import tqdm +from multiprocessing import Pool, get_context +from functools import partial + + +def run_multiprocessing(func, iterable_arguments, fixed_arguments={}, num_processes=12, title="", position=0): + """ Performs the task indicated in `func` in parallel on `num_processes` CPUs + Arguments: + func: function to be executed in parallel + iterable_arguments: list of inputs to be served to func on different processes + fixed_arguments: dictionary containing other arguments necessary for func + num_processes: number of CPU to use + title: string to name the progress bar + position: position of tqdm bar. If set to -1, no progress bar. 
+ Returns: + the results of func as a list + Example: + run_multiprocessing(binarize, images, {'prob_seg_dir':input_dir, 'bin_seg_dir':output_dir, 'threshold':0.6}) + """ + func = partial(func, **fixed_arguments) + leave_bar = position==0 + result_list_tqdm = [] + if num_processes > 1: + with get_context('fork').Pool(processes=num_processes) as pool: + if position==-1: + for result in pool.imap(func=func, iterable=iterable_arguments): + result_list_tqdm.append(result) + else: + for result in tqdm(pool.imap(func=func, iterable=iterable_arguments), total=len(iterable_arguments), desc=title, position=position, leave=leave_bar): + result_list_tqdm.append(result) + return result_list_tqdm + else: + for item in tqdm(iterable_arguments, desc=title, position=position, leave=leave_bar): + result_list_tqdm.append(func(item)) + return result_list_tqdm \ No newline at end of file diff --git a/unet/utils/tic_toc.py b/unet/utils/tic_toc.py new file mode 100755 index 0000000..f1ad293 --- /dev/null +++ b/unet/utils/tic_toc.py @@ -0,0 +1,18 @@ +# To measure timing +#' http://stackoverflow.com/questions/5849800/tic-toc-functions-analog-in-python +def tic(): + # Homemade version of matlab tic and toc functions + import time + global startTime_for_tictoc + startTime_for_tictoc = time.time() + +def toc(note='', print_to_screen=True, restart=False): + import time + if 'startTime_for_tictoc' in globals(): + elapsed_time = (time.time() - startTime_for_tictoc) + if print_to_screen: print(f"{note}. Elapsed time is {elapsed_time:.3f} seconds.") + if restart: tic() + return elapsed_time + else: + print("Toc: start time not set") + return
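
A minimal usage sketch combining utils/dice.py, utils/multiprocessing.py and utils/tic_toc.py from this patch, assuming a POSIX system (run_multiprocessing uses the 'fork' start method); the masks and the dice_pair worker are made-up illustrations, not part of the committed code:

    import numpy as np
    from utils.dice import dice
    from utils.multiprocessing import run_multiprocessing
    from utils.tic_toc import tic, toc

    def dice_pair(pair):
        # top-level function so the multiprocessing pool can pickle it
        return dice(*pair)

    if __name__ == "__main__":
        # two 4x4 masks, 4 foreground pixels each, overlapping in exactly 1 pixel:
        # dice = 2*overlap / (|A| + |B|) = 2*1 / (4+4) = 0.25
        a = np.zeros((4, 4)); a[1:3, 1:3] = 1
        b = np.zeros((4, 4)); b[2:4, 2:4] = 1
        print(dice(a, b))  # 0.25

        pairs = [(a, b), (a, a), (b, b)]  # made-up inputs
        tic()
        scores = run_multiprocessing(dice_pair, pairs, num_processes=2, title="Dice")
        toc(note="Batch dice")  # prints the elapsed time and returns it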