diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..baa4b2f --- /dev/null +++ b/.flake8 @@ -0,0 +1,10 @@ +[flake8] +max-line-length = 120 +# E122 - continuation line missing indentation or outdented +# E127 - continuation line over-indented for visual indent +# E201 - whitespace after '(' +# E202 - whitespace before ')' +# E225 - missing whitespace around operator +# E501 - line too long (managed by pylint, which ignores when line is too long because of directives) +# E731 - do not assign a lambda expression, use a def +extend-ignore = E127,E201,E202,E225,E501,E731 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9974e60 --- /dev/null +++ b/.gitignore @@ -0,0 +1,172 @@ +# project specific +.vscode/ +*.bin +results/ +outputs/ +drafts/ +temp.* +*.tmp + +# MacOS +.DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000..83f179d --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,4 @@ +[mypy] +# allow_redefinition = True +# local_partial_types +disable_error_code = import-untyped, str-bytes-safe diff --git a/README.md b/README.md new file mode 100644 index 0000000..75c796a --- /dev/null +++ b/README.md @@ -0,0 +1,495 @@ +# Working with the bravo_toolkit package + +## Preparing the environment + +Clone the repository to `local_repo_root` and create an enviroment with the requirements + +```bash +cd bravo_toolkit +conda create -n bravo python=3.9 +conda activate bravo +python -m pip install -r requirements.txt +``` + +Remember to always activate the environment and add `local_repo_root` to `PYTHONPATH` before running the commands in the sections below. +```bash +conda activate bravo +export PYTHONPATH= +``` + +## Encoding the submission files to the submission format + +The submission files must be in a directory tree or in a .tar file. Use one of the commands below: + +```bash +python -m bravo_toolkit.util.encode_submission +``` + +or + +```bash +python -m bravo_toolkit.util.encode_submission +``` + +## Expected format for the raw input images + +For the class prediction files (`_pred.png`): PNG format, 8-bits, grayscale, with each pixel with a value from 0 to 19 corresponding to the 19 classes of Cityscapes. + +For the confidence files (`_conf.png`): PNG format, 16-bits, grayscale, with each pixel with a value from 0 to 65535 corresponding to the confidence on the prediction (for the predicted class). For confidences originally computed on a continuous [0.0, 1.0] interval, we suggest discretizing them using the formula: `min(floor(conf*65536), 65535)` + +## Expected input directory tree for the submission + +The submission directory, or raw input tar file expected by `encode_submission` should have the following structure: + +``` +submission_directory_root or submission_raw.tar +├── bravo_ACDC +│   ├── fog +│   │   └── test +│   │   ├── GOPR0475 +│   │   │   ├── GOPR0475_frame_000247_rgb_anon_conf.png +│   │   │   ├── GOPR0475_frame_000247_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0475_frame_001060_rgb_anon_conf.png +│   │   │   └── GOPR0475_frame_001060_rgb_anon_pred.png +│   │   ├── GOPR0477 +│   │   │   ├── GOPR0477_frame_000794_rgb_anon_conf.png +│   │   │   ├── GOPR0477_frame_000794_rgb_anon_pred.png +│   │   │   ├── ... 
+│   │   │   ├── GOPR0477_frame_001032_rgb_anon_conf.png +│   │   │   └── GOPR0477_frame_001032_rgb_anon_pred.png +│   │   ├── GOPR0478 +│   │   │   ├── GOPR0478_frame_000259_rgb_anon_conf.png +│   │   │   ├── GOPR0478_frame_000259_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0478_frame_001023_rgb_anon_conf.png +│   │   │   └── GOPR0478_frame_001023_rgb_anon_pred.png +│   │   ├── GP010475 +│   │   │   ├── GP010475_frame_000006_rgb_anon_conf.png +│   │   │   ├── GP010475_frame_000006_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010475_frame_000831_rgb_anon_conf.png +│   │   │   └── GP010475_frame_000831_rgb_anon_pred.png +│   │   ├── GP010477 +│   │   │   ├── GP010477_frame_000001_rgb_anon_conf.png +│   │   │   ├── GP010477_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010477_frame_000224_rgb_anon_conf.png +│   │   │   └── GP010477_frame_000224_rgb_anon_pred.png +│   │   ├── GP010478 +│   │   │   ├── GP010478_frame_000032_rgb_anon_conf.png +│   │   │   ├── GP010478_frame_000032_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010478_frame_001061_rgb_anon_conf.png +│   │   │   └── GP010478_frame_001061_rgb_anon_pred.png +│   │   └── GP020478 +│   │   ├── GP020478_frame_000001_rgb_anon_conf.png +│   │   ├── GP020478_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP020478_frame_000042_rgb_anon_conf.png +│   │   └── GP020478_frame_000042_rgb_anon_pred.png +│   ├── night +│   │   └── test +│   │   ├── GOPR0355 +│   │   │   ├── GOPR0355_frame_000138_rgb_anon_conf.png +│   │   │   ├── GOPR0355_frame_000138_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0355_frame_000214_rgb_anon_conf.png +│   │   │   └── GOPR0355_frame_000214_rgb_anon_pred.png +│   │   ├── GOPR0356 +│   │   │   ├── GOPR0356_frame_000065_rgb_anon_conf.png +│   │   │   ├── GOPR0356_frame_000065_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0356_frame_001001_rgb_anon_conf.png +│   │   │   └── GOPR0356_frame_001001_rgb_anon_pred.png +│   │   ├── GOPR0364 +│   │   │   ├── GOPR0364_frame_000001_rgb_anon_conf.png +│   │   │   ├── GOPR0364_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0364_frame_001053_rgb_anon_conf.png +│   │   │   └── GOPR0364_frame_001053_rgb_anon_pred.png +│   │   ├── GOPR0594 +│   │   │   ├── GOPR0594_frame_000114_rgb_anon_conf.png +│   │   │   ├── GOPR0594_frame_000114_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0594_frame_001060_rgb_anon_conf.png +│   │   │   └── GOPR0594_frame_001060_rgb_anon_pred.png +│   │   ├── GP010364 +│   │   │   ├── GP010364_frame_000009_rgb_anon_conf.png +│   │   │   ├── GP010364_frame_000009_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010364_frame_000443_rgb_anon_conf.png +│   │   │   └── GP010364_frame_000443_rgb_anon_pred.png +│   │   └── GP010594 +│   │   ├── GP010594_frame_000003_rgb_anon_conf.png +│   │   ├── GP010594_frame_000003_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010594_frame_000087_rgb_anon_conf.png +│   │   └── GP010594_frame_000087_rgb_anon_pred.png +│   ├── rain +│   │   └── test +│   │   ├── GOPR0572 +│   │   │   ├── GOPR0572_frame_000145_rgb_anon_conf.png +│   │   │   ├── GOPR0572_frame_000145_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GOPR0572_frame_001035_rgb_anon_conf.png +│   │   │   └── GOPR0572_frame_001035_rgb_anon_pred.png +│   │   ├── GOPR0573 +│   │   │   ├── GOPR0573_frame_000180_rgb_anon_conf.png +│   │   │   ├── GOPR0573_frame_000180_rgb_anon_pred.png +│   │   │   ├── ... 
+│   │   │   ├── GOPR0573_frame_001046_rgb_anon_conf.png +│   │   │   └── GOPR0573_frame_001046_rgb_anon_pred.png +│   │   ├── GP010400 +│   │   │   ├── GP010400_frame_000616_rgb_anon_conf.png +│   │   │   ├── GP010400_frame_000616_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010400_frame_001057_rgb_anon_conf.png +│   │   │   └── GP010400_frame_001057_rgb_anon_pred.png +│   │   ├── GP010402 +│   │   │   ├── GP010402_frame_000326_rgb_anon_conf.png +│   │   │   ├── GP010402_frame_000326_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010402_frame_001046_rgb_anon_conf.png +│   │   │   └── GP010402_frame_001046_rgb_anon_pred.png +│   │   ├── GP010571 +│   │   │   ├── GP010571_frame_000077_rgb_anon_conf.png +│   │   │   ├── GP010571_frame_000077_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010571_frame_001050_rgb_anon_conf.png +│   │   │   └── GP010571_frame_001050_rgb_anon_pred.png +│   │   ├── GP010572 +│   │   │   ├── GP010572_frame_000027_rgb_anon_conf.png +│   │   │   ├── GP010572_frame_000027_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010572_frame_000916_rgb_anon_conf.png +│   │   │   └── GP010572_frame_000916_rgb_anon_pred.png +│   │   ├── GP010573 +│   │   │   ├── GP010573_frame_000001_rgb_anon_conf.png +│   │   │   ├── GP010573_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP010573_frame_001056_rgb_anon_conf.png +│   │   │   └── GP010573_frame_001056_rgb_anon_pred.png +│   │   ├── GP020400 +│   │   │   ├── GP020400_frame_000001_rgb_anon_conf.png +│   │   │   ├── GP020400_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP020400_frame_000142_rgb_anon_conf.png +│   │   │   └── GP020400_frame_000142_rgb_anon_pred.png +│   │   ├── GP020571 +│   │   │   ├── GP020571_frame_000001_rgb_anon_conf.png +│   │   │   ├── GP020571_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP020571_frame_000248_rgb_anon_conf.png +│   │   │   └── GP020571_frame_000248_rgb_anon_pred.png +│   │   ├── GP020573 +│   │   │   ├── GP020573_frame_000001_rgb_anon_conf.png +│   │   │   ├── GP020573_frame_000001_rgb_anon_pred.png +│   │   │   ├── ... +│   │   │   ├── GP020573_frame_000887_rgb_anon_conf.png +│   │   │   └── GP020573_frame_000887_rgb_anon_pred.png +│   │   └── GP030573 +│   │   ├── GP030573_frame_000073_rgb_anon_conf.png +│   │   ├── GP030573_frame_000073_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP030573_frame_000914_rgb_anon_conf.png +│   │   └── GP030573_frame_000914_rgb_anon_pred.png +│   └── snow +│   └── test +│   ├── GOPR0122 +│   │   ├── GOPR0122_frame_000651_rgb_anon_conf.png +│   │   ├── GOPR0122_frame_000651_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GOPR0122_frame_001054_rgb_anon_conf.png +│   │   └── GOPR0122_frame_001054_rgb_anon_pred.png +│   ├── GOPR0176 +│   │   ├── GOPR0176_frame_000394_rgb_anon_conf.png +│   │   ├── GOPR0176_frame_000394_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GOPR0176_frame_000884_rgb_anon_conf.png +│   │   └── GOPR0176_frame_000884_rgb_anon_pred.png +│   ├── GOPR0494 +│   │   ├── GOPR0494_frame_000020_rgb_anon_conf.png +│   │   ├── GOPR0494_frame_000020_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GOPR0494_frame_001056_rgb_anon_conf.png +│   │   └── GOPR0494_frame_001056_rgb_anon_pred.png +│   ├── GOPR0496 +│   │   ├── GOPR0496_frame_000663_rgb_anon_conf.png +│   │   ├── GOPR0496_frame_000663_rgb_anon_pred.png +│   │   ├── ... 
+│   │   ├── GOPR0496_frame_001033_rgb_anon_conf.png +│   │   └── GOPR0496_frame_001033_rgb_anon_pred.png +│   ├── GP010122 +│   │   ├── GP010122_frame_000001_rgb_anon_conf.png +│   │   ├── GP010122_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010122_frame_000223_rgb_anon_conf.png +│   │   └── GP010122_frame_000223_rgb_anon_pred.png +│   ├── GP010176 +│   │   ├── GP010176_frame_000001_rgb_anon_conf.png +│   │   ├── GP010176_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010176_frame_001057_rgb_anon_conf.png +│   │   └── GP010176_frame_001057_rgb_anon_pred.png +│   ├── GP010494 +│   │   ├── GP010494_frame_000001_rgb_anon_conf.png +│   │   ├── GP010494_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010494_frame_000242_rgb_anon_conf.png +│   │   └── GP010494_frame_000242_rgb_anon_pred.png +│   ├── GP010496 +│   │   ├── GP010496_frame_000001_rgb_anon_conf.png +│   │   ├── GP010496_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010496_frame_000883_rgb_anon_conf.png +│   │   └── GP010496_frame_000883_rgb_anon_pred.png +│   ├── GP010606 +│   │   ├── GP010606_frame_000001_rgb_anon_conf.png +│   │   ├── GP010606_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP010606_frame_001054_rgb_anon_conf.png +│   │   └── GP010606_frame_001054_rgb_anon_pred.png +│   ├── GP020176 +│   │   ├── GP020176_frame_000001_rgb_anon_conf.png +│   │   ├── GP020176_frame_000001_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP020176_frame_001060_rgb_anon_conf.png +│   │   └── GP020176_frame_001060_rgb_anon_pred.png +│   ├── GP020606 +│   │   ├── GP020606_frame_000021_rgb_anon_conf.png +│   │   ├── GP020606_frame_000021_rgb_anon_pred.png +│   │   ├── ... +│   │   ├── GP020606_frame_000558_rgb_anon_conf.png +│   │   └── GP020606_frame_000558_rgb_anon_pred.png +│   └── GP030176 +│   ├── GP030176_frame_000001_rgb_anon_conf.png +│   ├── GP030176_frame_000001_rgb_anon_pred.png +│   ├── ... +│   ├── GP030176_frame_000369_rgb_anon_conf.png +│   └── GP030176_frame_000369_rgb_anon_pred.png +├── bravo_SMIYC +│   └── RoadAnomaly21 +│   └── images +│   ├── airplane0000_conf.png +│   ├── airplane0000_pred.png +│   ├── ... +│   ├── zebra0001_conf.png +│   └── zebra0001_pred.png +├── bravo_outofcontext +│   ├── frankfurt +│   │   ├── frankfurt_000000_000576_leftImg8bit_conf.png +│   │   ├── frankfurt_000000_000576_leftImg8bit_pred.png +│   │   ├── ... +│   │   ├── frankfurt_000001_082466_leftImg8bit_conf.png +│   │   └── frankfurt_000001_082466_leftImg8bit_pred.png +│   ├── lindau +│   │   ├── lindau_000000_000019_leftImg8bit_conf.png +│   │   ├── lindau_000000_000019_leftImg8bit_pred.png +│   │   ├── ... +│   │   ├── lindau_000058_000019_leftImg8bit_conf.png +│   │   └── lindau_000058_000019_leftImg8bit_pred.png +│   └── munster +│   ├── munster_000000_000019_leftImg8bit_conf.png +│   ├── munster_000000_000019_leftImg8bit_pred.png +│   ├── ... +│   ├── munster_000172_000019_leftImg8bit_conf.png +│   └── munster_000172_000019_leftImg8bit_pred.png +├── bravo_synflare +│   ├── frankfurt +│   │   ├── frankfurt_000000_000294_leftImg8bit_conf.png +│   │   ├── frankfurt_000000_000294_leftImg8bit_pred.png +│   │   ├── ... +│   │   ├── frankfurt_000001_082466_leftImg8bit_conf.png +│   │   └── frankfurt_000001_082466_leftImg8bit_pred.png +│   ├── lindau +│   │   ├── lindau_000000_000019_leftImg8bit_conf.png +│   │   ├── lindau_000000_000019_leftImg8bit_pred.png +│   │   ├── ... 
+│   │   ├── lindau_000058_000019_leftImg8bit_conf.png +│   │   └── lindau_000058_000019_leftImg8bit_pred.png +│   └── munster +│   ├── munster_000000_000019_leftImg8bit_conf.png +│   ├── munster_000000_000019_leftImg8bit_pred.png +│   ├── ... +│   ├── munster_000172_000019_leftImg8bit_conf.png +│   └── munster_000172_000019_leftImg8bit_pred.png +├── bravo_synobjs +│   ├── armchair +│   │   ├── 1_conf.png +│   │   ├── 1_pred.png +│   │   ├── ... +│   │   ├── 504_conf.png +│   │   ├── 504_pred.png +│   ├── baby +│   │   ├── 49_conf.png +│   │   ├── 49_pred.png +│   │   ├── ... +│   │   ├── 421_conf.png +│   │   ├── 421_pred.png +│   ├── bathtub +│   │   ├── 16_conf.png +│   │   ├── 16_pred.png +│   │   ├── ... +│   │   ├── 501_conf.png +│   │   ├── 501_pred.png +│   ├── bench +│   │   ├── 0_conf.png +│   │   ├── 0_pred.png +│   │   ├── ... +│   │   ├── 423_conf.png +│   │   ├── 423_pred.png +│   ├── billboard +│   │   ├── 134_conf.png +│   │   ├── 134_pred.png +│   │   ├── ... +│   │   ├── 461_conf.png +│   │   └── 461_pred.png +│   ├── box +│   │   ├── 58_conf.png +│   │   ├── 58_pred.png +│   │   ├── ... +│   │   ├── 381_conf.png +│   │   ├── 381_pred.png +│   ├── cheetah +│   │   ├── 14_conf.png +│   │   ├── 14_pred.png +│   │   ├── ... +│   │   ├── 500_conf.png +│   │   ├── 500_pred.png +│   ├── chimpanzee +│   │   ├── 0_conf.png +│   │   ├── 0_pred.png +│   │   ├── ... +│   │   ├── 468_conf.png +│   │   ├── 468_pred.png +│   ├── elephant +│   │   ├── 9_conf.png +│   │   └── 9_pred.png +│   │   ├── ... +│   │   ├── 441_conf.png +│   │   ├── 441_pred.png +│   ├── flamingo +│   │   ├── 5_conf.png +│   │   ├── 5_pred.png +│   │   ├── ... +│   │   ├── 482_conf.png +│   │   ├── 482_pred.png +│   ├── giraffe +│   │   ├── 8_conf.png +│   │   └── 8_pred.png +│   │   ├── ... +│   │   ├── 510_conf.png +│   │   ├── 510_pred.png +│   ├── gorilla +│   │   ├── 4_conf.png +│   │   ├── 4_pred.png +│   │   ├── ... +│   │   ├── 493_conf.png +│   │   ├── 493_pred.png +│   ├── hippopotamus +│   │   ├── 29_conf.png +│   │   ├── 29_pred.png +│   │   ├── ... +│   │   ├── 442_conf.png +│   │   ├── 442_pred.png +│   ├── kangaroo +│   │   ├── 6_conf.png +│   │   ├── 6_pred.png +│   │   ├── ... +│   │   ├── 495_conf.png +│   │   ├── 495_pred.png +│   ├── koala +│   │   ├── 0_conf.png +│   │   ├── 0_pred.png +│   │   ├── ... +│   │   ├── 489_conf.png +│   │   ├── 489_pred.png +│   ├── lion +│   │   ├── 7_conf.png +│   │   ├── 7_pred.png +│   │   ├── ... +│   │   ├── 503_conf.png +│   │   ├── 503_pred.png +│   ├── panda +│   │   ├── 5_conf.png +│   │   ├── 5_pred.png +│   │   ├── ... +│   │   ├── 494_conf.png +│   │   ├── 494_pred.png +│   ├── penguin +│   │   ├── 5_conf.png +│   │   ├── 5_pred.png +│   │   ├── ... +│   │   ├── 465_conf.png +│   │   ├── 465_pred.png +│   ├── plant +│   │   ├── 3_conf.png +│   │   ├── 3_pred.png +│   │   ├── ... +│   │   ├── 400_conf.png +│   │   ├── 400_pred.png +│   ├── polar bear +│   │   ├── 4_conf.png +│   │   ├── 4_pred.png +│   │   ├── ... +│   │   ├── 501_conf.png +│   │   ├── 501_pred.png +│   ├── sofa +│   │   ├── 3_conf.png +│   │   ├── 3_pred.png +│   │   ├── ... +│   │   ├── 453_conf.png +│   │   ├── 453_pred.png +│   ├── table +│   │   ├── 0_conf.png +│   │   ├── 0_pred.png +│   │   ├── ... +│   │   ├── 461_conf.png +│   │   └── 461_pred.png +│   ├── tiger +│   │   ├── 28_conf.png +│   │   ├── 28_pred.png +│   │   ├── ... +│   │   ├── 450_conf.png +│   │   ├── 450_pred.png +│   ├── toilet +│   │   ├── 15_conf.png +│   │   ├── 15_pred.png +│   │   ├── ... 
+│   │   ├── 504_conf.png +│   │   ├── 504_pred.png +│   ├── vase +│   │   ├── 3_conf.png +│   │   ├── 3_pred.png +│   │   ├── ... +│   │   ├── 506_conf.png +│   │   ├── 506_pred.png +│   └── zebra +│   ├── 5_conf.png +│   ├── 5_pred.png +│   ├── ... +│   ├── 499_conf.png +│   ├── 499_pred.png +└── bravo_synrain + ├── frankfurt + │   ├── frankfurt_000000_000294_leftImg8bit_conf.png + │   ├── frankfurt_000000_000294_leftImg8bit_pred.png + │   ├── ... + │   ├── frankfurt_000001_083852_leftImg8bit_conf.png + │   └── frankfurt_000001_083852_leftImg8bit_pred.png + ├── lindau + │   ├── lindau_000000_000019_leftImg8bit_conf.png + │   ├── lindau_000000_000019_leftImg8bit_pred.png + │   ├── ... + │   ├── lindau_000058_000019_leftImg8bit_conf.png + │   └── lindau_000058_000019_leftImg8bit_pred.png + └── munster + ├── munster_000000_000019_leftImg8bit_conf.png + ├── munster_000000_000019_leftImg8bit_pred.png + ├── ... + ├── munster_000173_000019_leftImg8bit_conf.png + └── munster_000173_000019_leftImg8bit_pred.png +88 directories, 7802 files +``` + diff --git a/bravo_toolkit/__init__.py b/bravo_toolkit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bravo_toolkit/codec/__init__.py b/bravo_toolkit/codec/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bravo_toolkit/codec/bravo_codec.py b/bravo_toolkit/codec/bravo_codec.py new file mode 100644 index 0000000..2bb401d --- /dev/null +++ b/bravo_toolkit/codec/bravo_codec.py @@ -0,0 +1,239 @@ +""" +Bravo CODEC: Compression scheme for segmentation maps with confidence values + +This module provides functionalities for compressing and decompressing 2D arrays +representing segmentation maps with confidence values. The input arrays are +assumed to be 2D arrays. The scheme is intended to be used on the BRAVO Challenge. + +The main functions in this module are `bravo_encode` and `bravo_decode`. + +Functions +--------- +- `bravo_encode(class_array, confidence_array, ...)` + Encode a 2D array of class labels and a 2D array of confidence values into + a compressed byte-string. + +- `bravo_decode(encoded_bytes)` + Decode a BRAVO compressed byte-string back into a 2D array of class labels + and a 2D array of confidence values. + +Usage +----- + from bravo_codec import bravo_encode, bravo_decode + + class_array, confidence_array = your_segmentation_method(input_image) + + # Encoding + encoded_bytes = bravo_encode(class_array, confidence_array) + + # Decoding + decoded_class_array, decoded_confidence_array, header = bravo_decode(encoded_bytes) + +Notes +----- +- The class_array compression is lossless. +- The confidence_array compression is lossy, but the loss is controlled by the + quantization parameters. Use the default values for the BRAVO Challenge. 
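+- The encoded byte-string layout is: a fixed-size header (HEADER_FORMAT), then the compressed class and confidence bytes, then a trailing CRC32 checksum computed over the header and compressed payload.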
+""" +import struct +from typing import Tuple +import zlib + +import numpy as np +import zstandard as zstd + + +HEADER_MAGIC = b"BV23" +HEADER_VERSION = 2 +HEADER_FORMAT = "<4sIIIIII" +COMPRESS_TECHNIQUE = 2 +COMPRESS_LEVEL = 9 # Compression level for zlib and zstd, from 1 to 9, -1 for default + + +def _compress(data: bytes) -> bytes: + if COMPRESS_TECHNIQUE == 1: + data = zlib.compress(data, level=COMPRESS_LEVEL) + elif COMPRESS_TECHNIQUE == 2: + cctx = zstd.ZstdCompressor(level=COMPRESS_LEVEL) + data = cctx.compress(data) + else: + assert False, "Invalid compression technique" + return data + + +def _decompress(data: bytes) -> bytes: + if COMPRESS_TECHNIQUE == 1: + data = zlib.decompress(data) + elif COMPRESS_TECHNIQUE == 2: + dctx = zstd.ZstdDecompressor() + data = dctx.decompress(data) + else: + assert False, "Invalid compression technique" + return data + + +def bravo_encode(class_array: np.ndarray[np.uint8], + confidence_array: np.ndarray[np.uint16], + confidence_indices: np.ndarray[np.uint32]) -> bytes: + """ + Encode a class array and confidence array into a BRAVO compressed byte-string. + + Parameters + ---------- + class_array : np.ndarray[np.uint8] + Array with class labels. Must be 2D. + confidence_array : np.ndarray[np.uint16] + confidence_indices : np.ndarray[np.uint32] + Array with the indices of the sampled confidence values. Must be 1D. Assumed to be sorted. + + Returns + ------- + bytes + Compressed byte-string + """ + # Checks input + if class_array.ndim != 2: + raise ValueError("class_array must be 2D") + if class_array.dtype != np.uint8: + raise ValueError("class_array must be of dtype np.uint8") + if confidence_array.ndim != 2: + raise ValueError("confidence_array must be 2D") + if confidence_array.dtype != np.uint16: + raise ValueError("confidence_array must be of dtype np.uint16") + if confidence_indices is not None: + if confidence_indices.ndim != 1: + raise ValueError("confidence_indices must be 1D") + if confidence_indices.dtype not in (np.uint32, np.int32, np.uint64, np.int64): + raise ValueError("confidence_indices must be of a large integral type") + if class_array.shape != confidence_array.shape: + raise ValueError("class_array and confidence_array must have the same shape") + + # Downsamples the confidence array if necessary + class_rows = class_array.shape[0] + class_cols = class_array.shape[1] + + # Gets class array bytes + class_array = class_array.ravel() + class_bytes = class_array.tobytes() + + # Gets confidence_array bytes + confidence_array = confidence_array.ravel() + if confidence_indices is not None: + confidence_array = confidence_array[confidence_indices] + confidence_bytes = confidence_array.tobytes() + + # Compresses both arrays + data = class_bytes + confidence_bytes + data = _compress(data) + + # Assembles the header with struct + header = struct.pack( + HEADER_FORMAT, + HEADER_MAGIC, + HEADER_VERSION, + class_rows, + class_cols, + confidence_indices.size if confidence_indices is not None else 0, + len(class_bytes), + len(confidence_bytes) + ) + + data = header + data + crc32 = zlib.crc32(data) + + # Returns the compressed byte-string + return data + struct.pack(" Tuple[np.ndarray[np.uint8], np.ndarray, dict]: + """ + Decode a BRAVO compressed byte-string into a class array and confidence array. The confidence array is NOT upsampled + to the original size, and downstream processing should take care of this if needed. + + Parameters + ---------- + encoded_bytes : bytes + The compressed byte-string. 
+ dequantize : bool, default = False + If True, the confidence array is dequantized to np.float32. If False, the confidence array is kept as np.uint16. + + Returns + ------- + np.ndarray[np.uint8] + The class array. + np.ndarray + The confidence array. + If dequantize=True and quantize_levels > 0, the confidence array is restored to its original values as + np.float32. + If quantize_levels == 0, the confidence array is kept as the original np.uint8. + dict(str, Any) + The header information. + """ + + # Parse the header + header_size = struct.calcsize(HEADER_FORMAT) + header_bytes = encoded_bytes[:header_size] + header = struct.unpack(HEADER_FORMAT, header_bytes) + signature, version, class_rows, class_cols, confidence_size, class_len, confidence_len = header + + # Check the signature and version + if signature != HEADER_MAGIC: + raise ValueError("Invalid magic number in header") + if version != HEADER_VERSION: + raise ValueError("Invalid version number in header") + + # Check the CRC32 + crc32 = struct.unpack(" np.ndarray: + ''' + Get the float values of the quantization levels for the confidence array. + + Args: + - quantize_levels: int, default = 128 + Number of quantization levels. + - quantize_classes: int, default = 19 + Number of classes, ignored, included for backwards compatibility. + - **_kwargs: dict + Ignored keyword arguments, included for convenience for allowing calling this function with the **header + dictionary returned by bravo_decode. + + Returns: + - np.ndarray + The quantization levels as float values. + ''' + levels = (np.linspace(0., quantize_levels-1, quantize_levels, dtype=np.float32) + 0.5) / quantize_levels + return levels diff --git a/bravo_toolkit/codec/bravo_tarfile.py b/bravo_toolkit/codec/bravo_tarfile.py new file mode 100644 index 0000000..caf7ee8 --- /dev/null +++ b/bravo_toolkit/codec/bravo_tarfile.py @@ -0,0 +1,123 @@ +import logging +from contextlib import closing + +import cv2 +import numpy as np + + +logger = logging.getLogger('bravo_toolkit') + + +# Suffixes for the ground-truth tar files + +SPLIT_TO_GT_SUFFIX = { + 'ACDC': '_gt_labelTrainIds.png', + 'SMIYC': '_labels_semantic_fake.png', + 'outofcontext': '_gt_labelTrainIds.png', + 'synobjs': '_gt.png', + 'synflare': '_gt_labelTrainIds.png', + 'synrain': '_gt_labelTrainIds.png', +} +SPLIT_TO_MASK_SUFFIX = { + 'ACDC': '_gt_invIds.png', + 'SMIYC': '_labels_semantic.png', + 'outofcontext': '_gt_invIds.png', + 'synobjs': '_mask.png', + 'synflare': '_gt_invIds.png', + 'synrain': '_gt_invIds.png', +} + +# Suffixes for the new-style submission tar files + +SUBMISSION_SUFFIX = '_encoded.bin' +SAMPLES_SUFFIX = '_samples.bin' + +# Suffixes for the old-style submission tar files + +SPLIT_TO_PRED_SUFFIX = { + 'ACDC': '_rgb_anon_pred.png', + 'SMIYC': '_pred.png', + 'outofcontext': '_leftImg8bit_pred.png', + 'synobjs': '_pred.png', + 'synflare': '_leftImg8bit_pred.png', + 'synrain': '_leftImg8bit_pred.png', +} + +SPLIT_TO_CONF_SUFFIX = { + 'ACDC': '_rgb_anon_conf.png', + 'SMIYC': '_conf.png', + 'outofcontext': '_leftImg8bit_conf.png', + 'synobjs': '_conf.png', + 'synflare': '_leftImg8bit_conf.png', + 'synrain': '_leftImg8bit_conf.png', +} + +# Directories inside the tar files + +SPLIT_PREFIX = 'bravo_{split}/' + + +# Helper functions for extracting images from tar files + +def tar_extract_grayscale(tar, member, image_type='image', flag=cv2.IMREAD_GRAYSCALE): + '''Helper function for `tar_extract_image` with default flag=cv2.IMREAD_GRAYSCALE.''' + return tar_extract_image(tar, member, image_type, flag) + + +def 
tar_extract_image(tar, member, image_type='image', flag=cv2.IMREAD_UNCHANGED): + ''' + Extracts an image from a tar file member. + + Args: + tar (tarfile): tar file object + member (tarfile.TarInfo): tar file member + image_type (str): type of image to extract, for logging purposes, does not affect the extraction + flag (int): flag for cv2.imdecode, default is cv2.IMREAD_UNCHANGED + + Returns: + np.ndarray: image data + ''' + with closing(tar.extractfile(member)) as f: + img = extract_image(f, image_type=image_type, flag=flag) + return img + + +def extract_grayscale(reader, image_type='image', flag=cv2.IMREAD_GRAYSCALE): + '''Helper function for `extract_image` with default flag=cv2.IMREAD_GRAYSCALE.''' + return extract_image(reader, image_type, flag) + + +def extract_image(reader, image_type='image', flag=cv2.IMREAD_UNCHANGED): + ''' + Extracts an image from a reader object. + + Args: + reader (io.BufferedReader): reader object + image_type (str): type of image to extract, for logging purposes, does not affect the extraction + flag (int): flag for cv2.imdecode, default is cv2.IMREAD_UNCHANGED + + Returns: + np.ndarray: image data + ''' + content = reader.read() + file_bytes = np.asarray(bytearray(content), dtype=np.uint8) + img = cv2.imdecode(file_bytes, flag) + if img is None: + raise ValueError(f'Failed to decode {image_type} image') + return img + + +def tar_extract_file(tar, member): + ''' + Extracts a file from a tar file member. + + Args: + tar (tarfile): tar file object + member (tarfile.TarInfo): tar file member + + Returns: + bytes: file data + ''' + with closing(tar.extractfile(member)) as f: + file_data = f.read() + return file_data diff --git a/bravo_toolkit/eval/__init__.py b/bravo_toolkit/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bravo_toolkit/eval/eval_script.py b/bravo_toolkit/eval/eval_script.py new file mode 100644 index 0000000..caede85 --- /dev/null +++ b/bravo_toolkit/eval/eval_script.py @@ -0,0 +1,714 @@ +import argparse +from contextlib import closing +from functools import cache +import json +import logging +import os +import sys +import tarfile + +import numpy as np +from tqdm import tqdm +import scipy.stats as stats + +from bravo_toolkit.codec.bravo_codec import bravo_decode +from bravo_toolkit.codec.bravo_tarfile import (SAMPLES_SUFFIX, SPLIT_PREFIX, SPLIT_TO_CONF_SUFFIX, SPLIT_TO_GT_SUFFIX, + SPLIT_TO_MASK_SUFFIX, SPLIT_TO_PRED_SUFFIX, SUBMISSION_SUFFIX, + tar_extract_file, tar_extract_grayscale, tar_extract_image) +from bravo_toolkit.eval.metrics import fast_cm, get_auprc, get_auroc, get_ece, get_tp_fp_counts, per_class_iou +from bravo_toolkit.util.sample_gt_pixels import SAMPLES_PER_IMG, decode_indices + + +DEBUG = False + +NUM_CLASSES = 19 +CONF_N_LEVELS = 65536 +eps = 1. / CONF_N_LEVELS +bias = eps / 2. 
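+# CONF_VALUES[i] maps the i-th of the CONF_N_LEVELS quantized confidence codes back to a float in [0, 1):
+# the midpoint of its quantization bin. It is passed as confidence_values to get_ece() in get_curve_metrics().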
+CONF_VALUES = np.linspace(0., 1.-eps, CONF_N_LEVELS) + bias + +logger = logging.getLogger('bravo_toolkit') + + +def tqqdm(iterable, *args, **kwargs): + if logger.getEffectiveLevel() <= logging.INFO: + return tqdm(iterable, *args, **kwargs) + return iterable + + +@cache +def get_ground_truth_info(gt_path): + '''Loads information from the ground-truth tar file.''' + gt_truths = [] + gt_invalids = [] + with closing(tarfile.open(gt_path, 'r')) as gt_data: + split = prefix = '//////////' + for member in gt_data.getmembers(): + if not member.isfile(): + continue + name = member.name + + # Find the split name + if not name.startswith(prefix): # Prefixes tend to occur in runs, so makes a quick check + for split in SPLIT_TO_GT_SUFFIX: + prefix = SPLIT_PREFIX.format(split=split) + if name.startswith(prefix): + break + else: + split = prefix = name + logger.warning('Unexpected file prefix in ground-truth tarfile: %s', name) + continue + + # Find the file type + if name.endswith(SPLIT_TO_GT_SUFFIX[split]): + gt_truths.append(member) + elif name.endswith(SPLIT_TO_MASK_SUFFIX[split]): + gt_invalids.append(member) + else: + logger.warning('Unexpected file in ground-truth tarfile: %s', name) + + if len(gt_truths) != len(gt_invalids): + raise RuntimeError(f'Invalid ground-truth tarfile at `{gt_path}`: # of labelTrainIds files ({len(gt_truths)}) ' + f'!= invIds files {len(gt_invalids)}.') + return gt_truths, gt_invalids + + +def validate_data(gt_path, submission_path, _extra_params): + ''' + Validates API call for the ELSA Challenges server. + + Because the submission files are so large, pre-validating the data would be more expensive than just doing it in + the evaluation, so this function is a no-op. + + Args: + gt_path (str): path to the ground-truth tar file + submission_path (str): path to the submission tar file + _extra_params (dict): additional parameters, not used + + Raises: + ValueError: if the submission is invalid + ''' + logger.info('validate_data started') + + gt_truths, _ = get_ground_truth_info(gt_path) + submissions = { member.name: False for member in gt_truths } + + logger.info('validate_data ground truth read') + + with closing(tarfile.open(submission_path, 'r')) as submission_data: + split = prefix = '//////////' + for member in submission_data.getmembers(): + if not member.isfile(): + continue + name = member.name + + # Find the split name + if not name.startswith(prefix): # Prefixes tend to occur in runs, so makes a quick check + for split in SPLIT_TO_GT_SUFFIX: + prefix = SPLIT_PREFIX.format(split=split) + if name.startswith(prefix): + break + else: + split = prefix = name + logger.warning('Unexpected file prefix in submissions tarfile: %s', name) + continue + + # Maps submission file to a ground-truth file + if name.endswith(SUBMISSION_SUFFIX): + gt_name = name[:-len(SUBMISSION_SUFFIX)] + SPLIT_TO_GT_SUFFIX[split] + else: + logger.warning('Unexpected file suffix in submissions tarfile: %s', name) + continue + + # Checks if ground-truth file exists and marks it as found + if gt_name not in submissions: + logger.warning('Submission file name has no matches in ground-truth tarfile: %s -> %s', name, gt_name) + else: + submissions[gt_name] = True + + # Checks if all ground-truth files have been found + missing = [name for name, found in submissions.items() if not found] + if missing: + first_missing = ', '.join(f'`{mf}`' for mf in missing[:3]) + error_msg = f'Missing {len(missing)} files in submission. First missing files: {first_missing}...' 
+ logger.error('validate_data error: %s', error_msg) + raise ValueError(error_msg) + + logger.info('validate_data passed') + + +def get_curve_metrics(tp_counts, fp_counts, tpr_th=0.95): + ''' + Computes the curve-based scores for the given ground-truth and confidence values. + + Args: + tp_counts (np.ndarray): raw counts of true positives for each confidence level + fp_counts (np.ndarray): raw counts of false positives for each confidence level + tpr_th (float): threshold for TPR, default is 0.95 + + Returns: + auroc (float): Area Under the Receiver Operating Characteristic curve for the positive class + fpr_at_tpr_th (float): False Positive Rate at the given True Positive Rate threshold in the ROC curve + auprc_pos (float): Area Under the Precision-Recall curve for the positive class + auprc_neg (float): Area Under the Precision-Recall curve for the negative class + ece (float): Expected Calibration Error for positive class + ''' + auroc, tprs, fprs = get_auroc(tp_counts, fp_counts) + tpr_th_i = np.searchsorted(tprs, tpr_th, 'left') # index of the first element >= tpr_th, assumes tprs is sorted + fpr_at_tpr_th = fprs[tpr_th_i] + auprc_pos, _, _ = get_auprc(tp_counts, fp_counts) + ece = get_ece(tp_counts, tp_counts+fp_counts, confidence_values=CONF_VALUES, bins=15) + # To obtain the negative counts, we have to integrate the curve, invert it, and differentiate it from the other end + tp_cumm = np.concatenate(([0], np.cumsum(tp_counts))) + fp_cumm = np.concatenate(([0], np.cumsum(fp_counts))) + tn_cumm = fp_cumm[-1] - fp_cumm # Total of negatives - false positives + fn_cumm = tp_cumm[-1] - tp_cumm # Total of positives - true positives + tn_counts = np.diff(tn_cumm[::-1]) # We would reverse the arrays again to get the negative counts parallel to the + fn_counts = np.diff(fn_cumm[::-1]) # positive counts, but in this case we actually need the reversed values... + auprc_neg, _, _ = get_auprc(tn_counts, fn_counts) # ...for the reversed arrays act as if negating confidence values + return auroc, fpr_at_tpr_th, auprc_pos, auprc_neg, ece + + +def evaluate_bravo(*, + gt_data, + samples_data, + submission_data, + split_name, + other_names=tuple(), + gt_suffix=None, + mask_suffix=None, + submission_suffix=SUBMISSION_SUFFIX, + samples_suffix=SAMPLES_SUFFIX, + semantic_metrics=True, + invalid_metrics=False, + ood_scores=False, + compare_old=False, + compare_old_seed=1, + show_counters=False, + ): + ''' + Evaluate submission_data against gt_data for the given split_name and additional conditions. 
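+
+    For each image, the routine reads the ground-truth labels, invalid-pixel mask, sampled pixel indices and the
+    submitted prediction/confidence pair, accumulates confusion matrices and per-confidence-level true/false-positive
+    counts, and derives the requested semantic, calibration and OOD metrics from those accumulators.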
+ + Args: + gt_data (tarfile): ground truth data + samples_data (tarfile): sample data (ignored if compare_old is True) + submission_data (tarfile): submission data + split_name (str): name of the split (ACDC, SMIYC, outofcontext, synflare, synobjs, synrain) + gt_suffix (str): suffix for loading ground truths, default is derived from split_name + mask_suffix (str): suffix for loading invalid masks, default is derived from split_name + submission_suffix (str): suffix for loading submission files, default is `SUBMISSION_SUFFIX` + samples_suffix (str): suffix for loading sample files, default is `SAMPLES_SUFFIX` + other_names (iterable of str): iterable of additional name substrings that must be present in ground truths + semantic_metrics (bool): compute semantic metrics, default is True + invalid_metrics (bool): compute invalid AUC metrics, default is False + ood_scores (bool): compute OOD scores, default is False + compare_old (bool): comparison mode (old submission format and 100k pixels for curve-based metrics, used for + test purposes only, do not use in production), default is False + compare_old_seed (int): seed for the comparison mode, default is 1 + show_counters (bool): show debug counters, default is False + Returns: + dict: evaluation results + ''' + # Get suffixes from the split name if not provided + gt_suffix = gt_suffix or SPLIT_TO_GT_SUFFIX[split_name] + mask_suffix = mask_suffix or SPLIT_TO_MASK_SUFFIX[split_name] + + def substring_in_name(name, *conditions): + for c in conditions: + if c is not None and c not in name: + return False + return True + + gts = [mem for mem in gt_data.getmembers() if substring_in_name(mem.name, split_name, gt_suffix, *other_names)] + n_images = len(gts) + + other_names_str = ' '.join(other_names) + logger.info('%s-%s: evaluation on %d images', split_name, other_names_str, n_images) + + # Acquire data from the tar files... + # ...ground truth files + logger.info('evaluate_bravo - reading ground truth files...') + gt_files = {} + mask_files = {} + samples_files = {} + sub_files = {} + evaluation_tuples = [] + for idx, gt_mem in enumerate(tqqdm(gts)): + if DEBUG and idx >= 2: + break + # The ground truth files are in the "right" order and may be read directly + gt_name = gt_mem.name + gt_files[gt_name] = tar_extract_grayscale(gt_data, gt_mem, 'ground-truth') + # The submission files are not necessarily in the order: store them to read sequentially in the next loop + base_name = gt_name[:-len(gt_suffix)] + sub_name = base_name + submission_suffix + mask_name = base_name + mask_suffix + samples_name = base_name + samples_suffix + sub_files[sub_name] = None + mask_files[mask_name] = None + samples_files[samples_name] = None + evaluation_tuples.append((gt_name, sub_name, mask_name, samples_name)) + + # ...mask ground truth files + logger.info('evaluate_bravo - reading ground-truth mask files...') + for mask_mem in tqqdm(gt_data.getmembers()): + sub_name = mask_mem.name + if sub_name in mask_files: + mask_files[sub_name] = tar_extract_grayscale(gt_data, mask_mem, 'mask') + missing_mask_files = [n for n, f in mask_files.items() if f is None] + if missing_mask_files: + error_msg = f'{len(missing_mask_files)} mask files not found in ground-truth tar: ' + \ + ', '.join(missing_mask_files[:5]) + '...' 
+ logger.error('evaluate_bravo - %s', error_msg) + raise ValueError(error_msg) + + # ...samples files + if not compare_old: + logger.info('evaluate_bravo - reading samples files...') + for samp_mem in tqqdm(samples_data.getmembers()): + samp_name = samp_mem.name + if samp_name in samples_files: + samples_files[samp_name] = decode_indices(tar_extract_file(samples_data, samp_mem)) + + # ...submission files + pred_files = {} + conf_files = {} + if not compare_old: + logger.info('evaluate_bravo - reading submission files...') + for sub_mem in tqqdm(submission_data.getmembers()): + sub_name = sub_mem.name + if sub_name in sub_files: + f = submission_data.extractfile(sub_mem) + submission_raw = f.read() + f.close() + pred_files[sub_name], conf_files[sub_name], header = bravo_decode(submission_raw, dequantize=False) + if header['quantize_levels'] != CONF_N_LEVELS: + logger.error('evaluate_bravo - invalid header in submission file `%s`: %s', sub_name, header) + raise ValueError(f'Invalid header in submission file `{sub_name}`: {header}') + + # Performs evaluation + + # ...accumulated confusion matrices for class labels, valid pixels, and invalid pixels (in pixel counts) + cm = np.zeros((NUM_CLASSES, NUM_CLASSES)) + cm_valid = np.zeros((NUM_CLASSES, NUM_CLASSES)) + cm_invalid = np.zeros((NUM_CLASSES, NUM_CLASSES)) + + # ...accumulated true positive/false positive pixel counts for semantic curve-based metrics + # ......class labels, all pixels + gt_tp = np.zeros(CONF_N_LEVELS) # Prediction == Ground-truth + gt_fp = np.zeros(CONF_N_LEVELS) + # ......class labels, valid and invalid pixels + gt_tp_valid = np.zeros(CONF_N_LEVELS) + gt_fp_valid = np.zeros(CONF_N_LEVELS) + gt_tp_invalid = np.zeros(CONF_N_LEVELS) + gt_fp_invalid = np.zeros(CONF_N_LEVELS) + # ......ood detection + ood_tp = np.zeros(CONF_N_LEVELS) # Prediction == OOD (invalid zone) + ood_fp = np.zeros(CONF_N_LEVELS) + + log_p = log_mv = log_nv = log_mvnv = log_i = 0 # DEBUG counters, do not affect the results + all_invalid = 0 + + logger.info('evaluate_bravo - accumulating statistics...') + for evaluation_tuple in tqqdm(evaluation_tuples): + # Gets the data + gt_name, sub_name, mask_name, samples_name = evaluation_tuple + logger.debug('evaluate_bravo - processing files of `%s`', gt_name) + gt_file = gt_files[gt_name] + mask_file = mask_files[mask_name] + if compare_old: + pred_name = gt_name[:-len(gt_suffix)] + SPLIT_TO_PRED_SUFFIX[split_name] + conf_name = gt_name[:-len(gt_suffix)] + SPLIT_TO_CONF_SUFFIX[split_name] + pred_member = submission_data.getmember(pred_name) + conf_member = submission_data.getmember(conf_name) + pred_file = tar_extract_grayscale(submission_data, pred_member) + conf_file = tar_extract_image(submission_data, conf_member) + conf_indices = None + else: + pred_file = pred_files[sub_name] + conf_file = conf_files[sub_name] + conf_indices = samples_files[samples_name] + + # Check the dimensions of the images + if gt_file.shape != mask_file.shape: + logger.error('evaluate_bravo - ground-truth and mask dimensions mismatch for file `%s`: %s vs %s', + gt_name, gt_file.shape, mask_file.shape) + raise ValueError(f'Ground-truth and mask dimensions mismatch for file `{gt_name}`: ' + f'{gt_file.shape} vs {mask_file.shape}') + if gt_file.shape != pred_file.shape: + logger.error('evaluate_bravo - ground-truth and prediction dimensions mismatch for file `%s`: %s vs %s', + gt_name, gt_file.shape, pred_file.shape) + raise ValueError(f'Ground-truth and prediction dimensions mismatch for file `{gt_name}`: ' + f'{gt_file.shape} vs 
{pred_file.shape}') + + # Converts everything to 1D arrays + gt_file = gt_file.ravel() + mask_valid = mask_file.ravel() == 0 + pred_file = pred_file.ravel() + conf_file = conf_file.ravel() if compare_old else conf_file # already 1D in new format + + # Computes and accumulates the per-pixel predicted class-labels confusion matrices + # ...filters-out void class + non_void = gt_file != 255 + gt_file_nv = gt_file[non_void] + pred_file_nv = pred_file[non_void] + mask_valid_nv = mask_valid[non_void] + mask_invalid_nv = ~mask_valid_nv + # ...computes matrices + cm += fast_cm(gt_file_nv, pred_file_nv, NUM_CLASSES) + cm_valid += fast_cm(gt_file_nv[mask_valid_nv], pred_file_nv[mask_valid_nv], NUM_CLASSES) + cm_invalid += fast_cm(gt_file_nv[mask_invalid_nv], pred_file_nv[mask_invalid_nv], NUM_CLASSES) + + # Debug counters, no effect on metrics + if show_counters: + log_p += gt_file.size + log_mv += np.sum(mask_valid) + log_nv += np.sum(non_void) + log_mvnv += np.sum(mask_valid_nv) + log_i += 1 + + # Computes and accumulates the counts for curve-based metrics... + # ...subsample arrays + if compare_old: + # ...nothing is subsampled: samples everything + all_indices = np.arange(gt_file_nv.size) + if gt_file_nv.size > SAMPLES_PER_IMG: + # For the comparison with the old script, we need to subsample the data in this order + if log_i == 1: + logger.info('evaluate_bravo - compare_old with %s samples per image and seed %d', + f'{SAMPLES_PER_IMG:_}', compare_old_seed) + np.random.seed(compare_old_seed) + all_indices = np.random.choice(all_indices, SAMPLES_PER_IMG, replace=False) + valid_indices = np.nonzero(mask_valid_nv)[0] + if valid_indices.size > SAMPLES_PER_IMG: + valid_indices = np.random.choice(valid_indices, SAMPLES_PER_IMG, replace=False) + invalid_indices = np.nonzero(mask_invalid_nv)[0] + if invalid_indices.size > SAMPLES_PER_IMG: + invalid_indices = np.random.choice(invalid_indices, SAMPLES_PER_IMG, replace=False) + + # ...gets derived subsampled arrays (all, valid, and invalid independent) + conf_file_nv = conf_file[non_void] + class_right_all = gt_file_nv[all_indices] == pred_file_nv[all_indices] + class_right_valid = gt_file_nv[valid_indices] == pred_file_nv[valid_indices] + class_right_invalid = gt_file_nv[invalid_indices] == pred_file_nv[invalid_indices] + conf_all = conf_file_nv[all_indices] + conf_valid = conf_file_nv[valid_indices] + conf_invalid = conf_file_nv[invalid_indices] + gt_ood = mask_invalid_nv[all_indices] + + else: + # ...confidences are already subsampled: subsamples other data on the same indices + assert conf_indices is not None + gt_file = gt_file[conf_indices] + pred_file = pred_file[conf_indices] + mask_valid = mask_valid[conf_indices] + mask_invalid = ~mask_valid + + assert gt_file.shape == conf_file.shape, f'{gt_file.shape} != {conf_file.shape} ' \ + f'({conf_indices.size if conf_indices is not None else None})' + + # ...gets derived subsampled arrays (all aligned to the confidences) + class_right_all = gt_file == pred_file + class_right_valid = class_right_all[mask_valid] + class_right_invalid = class_right_all[mask_invalid] + conf_all = conf_file + conf_valid = conf_file[mask_valid] + conf_invalid = conf_file[mask_invalid] + gt_ood = mask_invalid + + all_invalid += np.sum(gt_ood) + + # ...gets cummulative true positives and false positives pixel counts for each confidence level + # ......class labels, all pixels + get_tp_fp_counts(class_right_all, conf_all, gt_tp, gt_fp, score_levels=CONF_N_LEVELS) + # ......class labels, valid pixels + 
get_tp_fp_counts(class_right_valid, conf_valid, gt_tp_valid, gt_fp_valid, score_levels=CONF_N_LEVELS) + # ......class labels, invalid pixels + get_tp_fp_counts(class_right_invalid, conf_invalid, gt_tp_invalid, gt_fp_invalid, score_levels=CONF_N_LEVELS) + # ......ood detection + doubt = CONF_N_LEVELS - 1 - conf_all + get_tp_fp_counts(gt_ood, doubt, ood_tp, ood_fp, score_levels=CONF_N_LEVELS) + + logger.log(logging.INFO if show_counters else logging.DEBUG, + 'log_p: %d, log_mv: %d, log_nv: %d, log_mvnv: %d, log_i: %d', log_p, log_mv, log_nv, log_mvnv, log_i) + + logger.info('evaluate_bravo - computing metrics...') + + miou = iou = miou_valid = miou_invalid = iou_valid = iou_invalid = None + auroc = fpr95 = auprc_success = auprc_error = ece = None + auroc_valid = fpr95_valid = auprc_success_valid = auprc_error_valid = ece_valid = None + auroc_invalid = fpr95_invalid = auprc_success_invalid = auprc_error_invalid = ece_invalid = None + auroc_ood = fpr95_ood = auprc_ood = None + + if semantic_metrics: + # Metrics based on the ground-truth class labels <= cm, cm_valid, cm_invalid + # ...mean intersection over union (mIoU) for all pixels, valid pixels, and invalid pixels + iou = per_class_iou(cm).tolist() + miou = np.nanmean(iou) + iou_valid = per_class_iou(cm_valid).tolist() + miou_valid = np.nanmean(iou_valid) + if all_invalid > 0: + iou_invalid = per_class_iou(cm_invalid).tolist() + miou_invalid = np.nanmean(iou_invalid) + # ...curve-based metrics <= gt_tp, gt_fp, gt_tp_valid, gt_fp_valid, gt_tp_invalid, gt_fp_invalid + logger.debug('all - tp: %s, fp: %s', np.sum(gt_tp), np.sum(gt_fp)) + logger.debug('valid - tp: %s, fp: %s', np.sum(gt_tp_valid), np.sum(gt_fp_valid)) + logger.debug('invalid - tp: %s, fp: %s', np.sum(gt_tp_invalid), np.sum(gt_fp_invalid)) + auroc, fpr95, auprc_success, auprc_error, ece = get_curve_metrics(gt_tp, gt_fp) + auroc_valid, fpr95_valid, auprc_success_valid, auprc_error_valid, ece_valid = \ + get_curve_metrics(gt_tp_valid, gt_fp_valid) + if invalid_metrics: + auroc_invalid, fpr95_invalid, auprc_success_invalid, auprc_error_invalid, ece_invalid = \ + get_curve_metrics(gt_tp_invalid, gt_fp_invalid) + + if ood_scores: + if all_invalid == 0: + logger.error('evaluate_bravo - no invalid pixels found for OOD detection') + raise ValueError('No invalid pixels found for OOD detection') + + # Curve-metrics based on ood detection <= ood_tp, ood_fp + logger.debug('ood - tp: %s, fp: %s', np.sum(ood_tp), np.sum(ood_fp)) + auroc_ood, fpr95_ood, auprc_ood, _, _ = get_curve_metrics(ood_tp, ood_fp) + + computed_metrics = { + 'miou': miou, + 'iou': iou, + 'ece': ece, + 'auroc': auroc, + 'auprc_success': auprc_success, + 'auprc_error': auprc_error, + 'fpr95': fpr95, + 'miou_valid': miou_valid, + 'iou_valid': iou_valid, + 'ece_valid': ece_valid, + 'auroc_valid': auroc_valid, + 'auprc_success_valid': auprc_success_valid, + 'auprc_error_valid': auprc_error_valid, + 'fpr95_valid': fpr95_valid, + 'miou_invalid': miou_invalid, + 'iou_invalid': iou_invalid, + 'ece_invalid': ece_invalid, + 'auroc_invalid': auroc_invalid, + 'auprc_success_invalid': auprc_success_invalid, + 'auprc_error_invalid': auprc_error_invalid, + 'fpr95_invalid': fpr95_invalid, + 'auroc_ood': auroc_ood, + 'auprc_ood': auprc_ood, + 'fpr95_ood': fpr95_ood, + } + return computed_metrics + + +BRAVO_SUBSETS = ['ACDCfog', 'ACDCrain', 'ACDCnight', 'ACDCsnow', 'synrain', 'SMIYC', 'synobjs', 'synflare', + 'outofcontext'] + + +def update_results(all_results, new_results, new_key): + all_results.update((f'{new_key}_{k}', v) for k, v in 
new_results.items() if v is not None) + + +def summarize_results(all_results, subsets): + scalars = {k: v for k, v in all_results.items() if np.isscalar(v)} + for s in subsets: + subset_scalars = np.array([v for k, v in scalars.items() if k.startswith(f'{s}_')]) + all_results[f'{s}_amean'] = np.mean(subset_scalars) + all_results[f'{s}_gmean'] = stats.gmean(subset_scalars) + all_results[f'{s}_hmean'] = stats.hmean(subset_scalars) + all_scalars = np.array([v for v in scalars.values()]) + all_results['bravo_amean'] = np.mean(all_scalars) + all_results['bravo_gmean'] = stats.gmean(all_scalars) + all_results['bravo_hmean'] = stats.hmean(all_scalars) + + +def default_evaluation_params(): + return { + 'subsets': BRAVO_SUBSETS, + 'compare_old': False, + 'compare_old_seed': 1, + 'samples_path': None, + 'show_counters': False, + } + + +def evaluate_method(gt_path, submission_path, extra_params=None): + logger.info('evaluate_method - computing metrics...') + extra_params = extra_params or {} + subsets = extra_params.pop('subsets', BRAVO_SUBSETS) + compare_old = extra_params.get('compare_old', False) + gt_data = tarfile.open(gt_path, 'r') + submission_data = tarfile.open(submission_path, 'r') + samples_path = extra_params.pop('samples_path') + if samples_path == '': + samples_path = os.path.join(os.path.dirname(gt_path), 'bravo_SAMPLING.tar') + if compare_old: + samples_data = None + else: + samples_data = tarfile.open(samples_path, 'r') + extra_params['samples_data'] = samples_data + all_results = {} + + if compare_old: + logger.warning('COMPARISON MODE: using old submission format and %s pixels for curve-based metrics.', + f'{SAMPLES_PER_IMG:_}') + + logger.info('...1 of 9 - ACDCfog') + if 'ACDCfog' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + other_names=['/fog/'], + **extra_params) + update_results(all_results, results, 'ACDCfog') + + logger.info('...2 of 9 - ACDCrain') + if 'ACDCrain' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + other_names=['/rain/'], + **extra_params) + update_results(all_results, results, 'ACDCrain') + + logger.info('...3 of 9 - ACDCnight') + if 'ACDCnight' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + other_names=['/night/'], + **extra_params) + update_results(all_results, results, 'ACDCnight') + + logger.info('...4 of 9 - ACDCsnow') + if 'ACDCsnow' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + other_names=['/snow/'], + **extra_params) + update_results(all_results, results, 'ACDCsnow') + + logger.info('...5 of 9 - synrain') + if 'synrain' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synrain', + invalid_metrics=True, + ood_scores=True, + **extra_params) + update_results(all_results, results, 'synrain') + + logger.info('...6 of 9 - SMIYC') + if 'SMIYC' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='SMIYC', + semantic_metrics=False, + ood_scores=True, + **extra_params) + update_results(all_results, results, 'SMIYC') + + logger.info('...7 of 9 - synobjs') + if 'synobjs' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synobjs', + ood_scores=True, + **extra_params) + update_results(all_results, results, 'synobjs') + + logger.info('...8 of 9 
- synflare') + if 'synflare' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synflare', + ood_scores=True, + **extra_params) + update_results(all_results, results, 'synflare') + + logger.info('...9 of 9 - outofcontext') + if 'outofcontext' in subsets: + results = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='outofcontext', + **extra_params) + update_results(all_results, results, 'outofcontext') + + summarize_results(all_results, subsets) + + # This format is expected by the ELSA BRAVO Challenge server + return { + 'method': all_results, + 'result': True, + } + + +def main(): + parser = argparse.ArgumentParser( + description='Evaluates submissions for the ELSA BRAVO Challenge.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('submission', help='path to submission tar file') + parser.add_argument('--gt', required=True, help='path to ground-truth tar file') + parser.add_argument('--samples', help='path to pixel samples tar file, required unles --compare_old is set') + parser.add_argument('--results', default='results.json', help='JSON file to store the computed metrics') + parser.add_argument('--skip_val', action='store_true', help='skips validation of the submission') + parser.add_argument('--skip_eval', action='store_true', help='skips computing the metrics') + parser.add_argument('--debug', action='store_true', help='enables extra verbose debug output') + parser.add_argument('--quiet', action='store_true', help='prints only errors and warnings') + parser.add_argument('--show_counters', action='store_true', help='shows debug counters out of debug mode') + parser.add_argument('--compare_old', nargs='?', default=False, const=True, help='enables comparison mode, where ' + 'the script will expect submissions on the old format, and subsample 100 000 pixels per image ' + 'for the curve-based metrics. 
It accepts an optional argument as the seed for the random ' + 'sampling.') + parser.add_argument('--subsets', default=['ALL'], nargs='+', choices=BRAVO_SUBSETS+['ALL'], metavar='SUBSETS', + help='tasks to evaluate: ALL for all tasks, or one or more items from this list, separated ' + 'by spaces: ' + ' '.join(BRAVO_SUBSETS)) + args = parser.parse_args() + + level = logging.WARNING if args.quiet else (logging.DEBUG if args.debug else logging.INFO) + logging.basicConfig(level=level, format='%(levelname)s: %(message)s') + + if args.skip_val and args.skip_eval: + logger.error('both --skip_val and --skip_eval are set, nothing to do.') + sys.exit(1) + + seed = 1 + if args.compare_old: + logger.warning('comparison mode is enabled, this is for testing purposes only, do not use in production') + if args.compare_old is not True: + seed = int(args.compare_old) + args.compare_old = True + else: + args.compare_old = False + if args.samples is None: + logger.error('--samples is required unless --compare_old is set') + sys.exit(1) + + if not (args.skip_val or args.compare_old): + logger.info('validating...') + validate_data(args.gt, args.submission, None) + + if not args.skip_eval: + logger.info('evaluating...') + if args.subsets == ['ALL']: + args.subsets = BRAVO_SUBSETS + res = evaluate_method(args.gt, args.submission, + extra_params={'subsets': args.subsets, + 'compare_old': args.compare_old, + 'compare_old_seed': seed, + 'samples_path': args.samples, + 'show_counters': args.show_counters}) + res = res['method'] + + if not args.quiet: + logger.info('results:') + print(json.dumps(res, indent=4)) + + logger.info('saving results...') + with open(args.results, 'wt', encoding='utf-8') as json_file: + json.dump(res, json_file, indent=4) + + logger.info('done.') + + +if __name__ == '__main__': + main() diff --git a/bravo_toolkit/eval/metrics.py b/bravo_toolkit/eval/metrics.py new file mode 100644 index 0000000..63abbeb --- /dev/null +++ b/bravo_toolkit/eval/metrics.py @@ -0,0 +1,344 @@ +import numpy as np +from numpy import ndarray + + +def _get_ece_original(*, conf, pred, label, ECE_NUM_BINS=15, CONF_NUM_BINS=65535, DEBIAS=False): + ''' + Original implementation of the Expected Calibration Error (ECE) metric, before refactoring. + + Used for test purposes only. Use _get_ece_reference() instead for an equivalent reference implementation. + Use get_ece() in production code. + + The default parameters are set to match the original implementation. 
Set CONF_NUM_BINS=65536, DEBIAS=True for a + closer match to the new implementations. + ''' + conf = conf.astype(np.float32) + # normalize conf to [0, 1] + if DEBIAS: + conf = (conf + 0.5) / CONF_NUM_BINS + else: + conf = conf / CONF_NUM_BINS + tau_tab = np.linspace(0, 1, ECE_NUM_BINS + 1) # Confidence bins + nb_items_bin = np.zeros(ECE_NUM_BINS) + acc_tab = np.zeros(ECE_NUM_BINS) # Empirical (true) confidence + mean_conf = np.zeros(ECE_NUM_BINS) # Predicted confidence + for i in np.arange(ECE_NUM_BINS): # Iterates over the bins + # Selects the items where the predicted max probability falls in the bin + # [tau_tab[i], tau_tab[i + 1)] + sec = (tau_tab[i + 1] > conf) & (conf >= tau_tab[i]) + nb_items_bin[i] = np.sum(sec) # Number of items in the bin + # Selects the predicted classes, and the true classes + class_pred_sec, y_sec = pred[sec], label[sec] + # Averages of the predicted max probabilities + mean_conf[i] = np.mean(conf[sec]) if nb_items_bin[i] > 0 else np.nan + # Computes the empirical confidence + acc_tab[i] = np.mean(class_pred_sec == y_sec) if nb_items_bin[i] > 0 else np.nan + mean_conf = mean_conf[nb_items_bin > 0] + acc_tab = acc_tab[nb_items_bin > 0] + nb_items_bin = nb_items_bin[nb_items_bin > 0] + if sum(nb_items_bin) != 0: + ece = np.average(np.absolute(mean_conf-acc_tab), weights=nb_items_bin.astype(np.float32)/np.sum(nb_items_bin)) + else: + raise ValueError('No samples found for ECE calculation.') + return ece + + +def _get_ece_reference(y_true, y_pred, y_conf, bin_min=0., bin_max=1., ece_bins=15): + ''' + Evaluates the Expected Calibration Error (ECE) - Reference implementation for testing and debugging. + + The data is split into `ece_bins` bins based on their confidence values. + + For each bin, the average confidence and accuracy are computed. The ECE is then computed by: + ECE = sum_i (|avg_acc_i - avg_conf_i| * bin_count_i/n_samples) + + In short, the ECE checks if the confidence values are well-calibrated in comparison to the accuracies, in an + average weighted by the observations. In this implementation, the bins are equally spaced in the confidence range + [bin_min, bin_max]. + + The input data is not verified, and is assumed to be valid. + + Args: + y_true (np.ndarray): ground-truth class labels + y_pred (np.ndarray): predicted class labels + y_conf (np.ndarray): confidence for predicted class labels, in the range [1./num_classes, 1.] 
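+ bin_min (float): lower limit of the confidence range used to build the bins, default is 0. + bin_max (float): upper limit of the confidence range used to build the bins, default is 1.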
+ ece_bins (int): number of bins for ECE metric, default is 15 + + Returns: + float: ECE + ''' + # Convert to 1D arrays + y_true = y_true.ravel() + y_pred = y_pred.ravel() + y_conf = y_conf.ravel() + + # Quantize confidence values into bins + bin_limits = np.linspace(bin_min, bin_max, ece_bins + 1) + bin_counts = np.zeros(ece_bins) + bin_indices = np.digitize(y_conf, bin_limits) - 1 # Allocate each sample to its bin + bin_indices = np.clip(bin_indices, 0, ece_bins - 1) # Clip the samples to valid bins + bin_counts = np.bincount(bin_indices, minlength=ece_bins) + + # Compute mean accuracy and confidence for each bin + bin_accs = np.zeros(ece_bins) # Empirical accuracy + bin_confs = np.zeros(ece_bins) # Predicted confidence + for i in range(ece_bins): + if bin_counts[i] == 0: + continue + bin_indices_i = bin_indices == i + bin_accs[i] = np.mean((y_pred == y_true)[bin_indices_i]) + bin_confs[i] = np.mean(y_conf[bin_indices_i]) + + # Gets average difference between confidence and accuracy, weighted by bin counts + has_data = bin_counts > 0 + bin_confs = bin_confs[has_data] + bin_accs = bin_accs[has_data] + bin_counts = bin_counts[has_data] + n_samples = np.sum(bin_counts) + if n_samples != 0: + ece_value = np.average(np.absolute(bin_confs - bin_accs), weights=bin_counts.astype(np.float32) / n_samples) + else: + raise ValueError('No samples found for ECE calculation.') + + return ece_value + + +def _ece_bin_subsample(all_counts, bins=15): + ''' + Subsamples the counts of samples in each bin to a smaller number of bins. + + If the target number of bins is not a divisor of the original number of bins, the bin boundaries are spread as + evenly as possible (rounded linear spacing), so the resulting bins may aggregate numbers of original bins that + differ by one. + + Args: + all_counts (tuple of np.ndarray): Raw counts to be subsampled (see `get_ece` for details on the counts format) + All arrays will be raveled to 1D and must have the same size. + bins (int): The target number of bins. Default is 15. + + Returns: + list of np.ndarray: The subsampled counts, in the same order as `all_counts`. + ''' + + # Get the number of bins in the original data and ensure that all counts have the same length + n = all_counts[0].size + if not all(c.size == n for c in all_counts): + raise ValueError('All arrays must have the same size.') + if bins <= 0 or bins > n: + raise ValueError('The number of bins must be a positive integer <= the size of the arrays.') + + # Calculate the bin boundaries + bin_limits = np.round(np.linspace(0, n, bins + 1)).astype(np.int64) + + # Use np.add.reduceat to compute the sums for each bin + binned = [np.add.reduceat(c.ravel(), bin_limits[:-1]) for c in all_counts] + + return binned + + +def get_ece(d_counts, t_counts, confidence_values, *, bins=15): + ''' + Calculate the Expected Calibration Error (ECE) for a classifier given the counts of correctly classified and of + total samples at different levels of confidence for the classifier. The counts are assumed to be ordered by + increasing confidence level and to be for exact values of confidence (i.e. not cumulative counts). + + Args: + d_counts (np.ndarray): The number of correctly classified (diagonal) samples at each confidence level. + t_counts (np.ndarray): The number of total samples at each confidence level. + confidence_values (np.ndarray): The confidence values corresponding to each level, in the range [0, 1]. + bins (int): Desired number of bins for the ECE calculation. Default is 15. + + Returns: + float: The ECE of the classifier. 
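+ + Example (an illustrative sketch with hand-picked values, not taken from the original tests): + d = np.array([10, 20, 60, 90]) # correctly classified samples per confidence level + t = np.array([100, 100, 100, 100]) # total samples per confidence level + c = np.array([0.125, 0.375, 0.625, 0.875]) # confidence value of each level + get_ece(d, t, c, bins=2) # ~0.05: the bins compare accuracy 0.15 vs confidence 0.25, and 0.75 vs 0.75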
+ ''' + N = np.sum(t_counts) + if N == 0: + raise ValueError('No samples found for ECE calculation.') + + weighted_confidences = t_counts * confidence_values + d_counts, t_counts, weighted_confidences = _ece_bin_subsample((d_counts, t_counts, weighted_confidences), bins=bins) + + # Computes statistics and groups them into bins + weighted_confidences = weighted_confidences / t_counts + weights = t_counts.astype(np.float32) / N + accuracies = d_counts / t_counts + + # Remove levels with no samples + has_data = t_counts > 0 + weighted_confidences = weighted_confidences[has_data] + weights = weights[has_data] + accuracies = accuracies[has_data] + + # Compute ECE + ece_value = np.average(np.absolute(weighted_confidences - accuracies), weights=weights) + return ece_value + + +def get_auroc(tp_counts: ndarray, fp_counts: ndarray) -> tuple[float, ndarray, ndarray]: + ''' + Calculate the ROC curve and AUC for a binary classifier given the counts of true and false positive samples at + different levels of confidence for the classifier. The counts are assumed to be ordered by increasing confidence + level and to be for exact values of confidence (i.e. not cumulative counts). + + The counts are assumed to be non-negative: behavior is undefined if this is not the case. + + Args: + tp_counts (np.ndarray): The number of true positive samples at each confidence level. + fp_counts (np.ndarray): The number of false positive samples at each confidence level. + + Raise: + ValueError: If the counts have different sizes, if they have less than two elements, or if they are entirely + zero. + + Returns: + float: The AUC of the ROC curve. + np.ndarray: The true positive rates at each confidence level, starting from 0 and not decreasing. + np.ndarray: The false positive rates at each confidence level, starting from 0 and not decreasing. + ''' + if tp_counts.size != fp_counts.size: + raise ValueError('tp_counts and fp_counts must have the same length.') + if tp_counts.size <= 1: + raise ValueError('tp_counts and fp_counts must have at least two elements.') + + # Reverse cumulative sums get the counts up-to-and-above each confidence level + tp_counts = tp_counts.ravel() + fp_counts = fp_counts.ravel() + tp_cumsum = np.cumsum(tp_counts[::-1]) + fp_cumsum = np.cumsum(fp_counts[::-1]) + + # The last element of the cumulative sums is the total + p_total = tp_cumsum[-1] + n_total = fp_cumsum[-1] + if p_total == 0 or n_total == 0: + raise ValueError('tp_counts and fp_counts must have at least one non-zero element.') + + # Calculate TPR and FPR + tpr = tp_cumsum / p_total + fpr = fp_cumsum / n_total + + # Starts from an implicit (0, 0) point + tpr = np.concatenate(([0.], tpr)) + fpr = np.concatenate(([0.], fpr)) + + # Calculate AUC using the trapezoidal rule + auc = np.trapz(y=tpr, x=fpr) + + return auc, tpr, fpr + + +def get_auprc(tp_counts, fp_counts): + ''' + Calculate the PR curve and AUC for a binary classifier given the counts of true and false positive samples at + different levels of confidence for the classifier. The counts are assumed to be ordered by increasing confidence + level and to be for exact values of confidence (i.e. not cumulative counts). + + Args: + tp_counts (np.ndarray): The number of true positive samples at each confidence level. + fp_counts (np.ndarray): The number of false positive samples at each confidence level. + + Returns: + float: The AUC of the PR curve. + np.ndarray: The precision values at each confidence level, starting from 1. The values are not monotonic. 
+ np.ndarray: The recall values at each confidence level, starting from 0 and not decreasing. + ''' + if tp_counts.size != fp_counts.size: + raise ValueError('tp_counts and fp_counts must have the same length.') + if tp_counts.size <= 1: + raise ValueError('tp_counts and fp_counts must have at least two elements.') + + # Reverse cumulative sums get the counts up-to-and-above each confidence level + tp_counts = tp_counts.ravel() + fp_counts = fp_counts.ravel() + tp_cumsum = np.cumsum(tp_counts[::-1]) + fp_cumsum = np.cumsum(fp_counts[::-1]) + + # The last element of the cumulative sums is the total + p_total = tp_cumsum[-1] + + # Calculate precision and recall + pp_cumsum = tp_cumsum + fp_cumsum # Cumulative sum of positive predictions + precision = np.where(pp_cumsum == 0, 1., tp_cumsum / pp_cumsum) + recall = (tp_cumsum / p_total) if p_total > 0 else tp_cumsum + + # Starts from an implicit (0, 1) point + precision = np.concatenate(([1.], precision)) + recall = np.concatenate(([0.], recall)) + + # Calculate AUC using the step-function integral (recommended instead of the trapezoidal rule for PR curves in Scikit-learn). + # See https://github.com/scikit-learn/scikit-learn/blob/8721245511de2f225ff5f9aa5f5fadce663cd4a3/sklearn/metrics/_ranking.py#L236C9-L236C67 + auc = np.sum(np.diff(recall) * precision[1:]) # precision[0] is always 1, making the integration formula correct + + return auc, precision, recall + + +def get_tp_fp_counts(y_true, y_score, tp_counts=None, fp_counts=None, *, score_levels=128): + ''' + Gets the true and false positive counts from ground-truth labels and scores. + + Args: + y_true (np.ndarray): Ground-truth labels, False for negatives and True for positives. + y_score (np.ndarray): Predicted integer scores, each in the interval [0, score_levels-1], used as indices into + the count arrays. + tp_counts (np.ndarray): Initial True positive counts at each score level. Default is None, which creates new + array initialized to zeros. + fp_counts (np.ndarray): Same, for False positive counts. + score_levels (int): Number of quantized score levels to use. Default is 128. + + Returns: + np.ndarray: True positive counts at each score level. Same as input tp_counts if not None. + np.ndarray: False positive counts at each score level. Same as input fp_counts if not None. + ''' + if y_true.shape != y_score.shape: + raise ValueError('y_true and y_score must have the same shape.') + y_true = y_true.ravel() + y_score = y_score.ravel() + # Sorts data by score + sorted_indices = np.argsort(y_score) + y_true = y_true[sorted_indices] + y_score = y_score[sorted_indices] + # Computes true and false positive counts + tp_counts = np.zeros(score_levels) if tp_counts is None else tp_counts + fp_counts = np.zeros(score_levels) if fp_counts is None else fp_counts + np.add.at(tp_counts, y_score, y_true) + np.add.at(fp_counts, y_score, ~y_true) + return tp_counts, fp_counts + + +def fast_cm(y_true, y_pred, n): + ''' + Fast computation of a confusion matrix from two arrays of labels. 
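+ + For example (illustrative values), fast_cm(np.array([0, 1, 1]), np.array([0, 1, 2]), 3) returns a 3x3 matrix whose + only non-zero entries are cm[0, 0] == 1, cm[1, 1] == 1 and cm[1, 2] == 1.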
+ + Args: + y_true (np.ndarray): array of true labels + y_pred (np.ndarray): array of predicted labels + n (int): number of classes + + Returns: + np.ndarray: confusion matrix, where rows are true labels and columns are predicted labels + ''' + y_true = y_true.ravel().astype(int) + y_pred = y_pred.ravel().astype(int) + k = (y_true < 0) | (y_true >= n) | (y_pred < 0) | (y_pred >= n) + if np.any(k): + raise ValueError('Invalid class values in ground-truth or prediction: ' + f'{np.unique(np.concatenate((y_true[k], y_pred[k])))}') + # Convert class numbers into indices of a simulated 2D array of shape (n, n) flattened into 1D, row-major + effective_indices = n * y_true + y_pred + # Count the occurrences of each index, reshaping the 1D array into a 2D array + return np.bincount(effective_indices, minlength=n ** 2).reshape(n, n) + + +def per_class_iou(cm): + ''' + Compute the Intersection over Union (IoU) for each class from a confusion matrix. + + Args: + cm (np.ndarray): n x n 2D confusion matrix (the orientation is not important, as the formula is symmetric) + + Returns: + np.ndarray: 1D array of IoU values for each of the n classes + ''' + # The diagonal contains the intersection of predicted and true labels + # The sum of rows (columns) is the union of predicted (true) labels (or vice-versa, depending on the orientation) + return np.diag(cm) / (cm.sum(1) + cm.sum(0) - np.diag(cm)) diff --git a/bravo_toolkit/util/__init__.py b/bravo_toolkit/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bravo_toolkit/util/compare_results.py b/bravo_toolkit/util/compare_results.py new file mode 100644 index 0000000..e69c0cd --- /dev/null +++ b/bravo_toolkit/util/compare_results.py @@ -0,0 +1,65 @@ +import argparse +import math +import json + + +def compare_dicts(dict1, dict2, key_path='', *, tolerance=0.001, decimals=4): + if not isinstance(dict1, dict) or not isinstance(dict2, dict): + raise ValueError("Both inputs should be dictionaries.") + + keys1, keys2 = set(dict1.keys()), set(dict2.keys()) + missing_in_dict2 = keys1 - keys2 + missing_in_dict1 = keys2 - keys1 + union_keys = keys1 | keys2 + + formatter = "{:.%df}" % decimals + + def f(value): + if isinstance(value, float): + return formatter.format(value) + if isinstance(value, str): + return repr(value) + return value + + for key in sorted(union_keys): + new_path = f"{key_path}.{key}" if key_path else key + + if key in missing_in_dict1: + print(f"{new_path}: missing in first") + continue + + if key in missing_in_dict2: + print(f"{new_path}: missing in second") + continue + + value1, value2 = dict1[key], dict2[key] + if isinstance(value1, dict) and isinstance(value2, dict): + compare_dicts(value1, value2, new_path, tolerance=tolerance, decimals=decimals) + elif isinstance(value1, float) and isinstance(value2, float): + if math.isnan(value1) and math.isnan(value2): + continue # Both are NaN, considered equal. 
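+ # NaN compares unequal to everything (including itself), so without the check above two NaN entries would be + # reported as different by the isnan(value1 - value2) test below.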
+ elif math.fabs(value1-value2) > tolerance or math.isnan(value1-value2): + print(f"{new_path}: {f(value1)} vs {f(value2)}, difference: {f(value1 - value2)}") + elif value1 != value2: + print(f"{new_path}: {f(value1)} vs {f(value2)}") + + +def main(): + parser = argparse.ArgumentParser(description="Compare two results JSON files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("json_file1", metavar='results1.json', type=str, help="Path to the first results JSON file") + parser.add_argument("json_file2", metavar='results2.json', type=str, help="Path to the second results JSON file") + parser.add_argument("--tolerance", type=float, default=0.001, help="tolerance for comparing floats") + parser.add_argument("--decimals", type=int, default=4, help="number of decimals for displaying floats") + args = parser.parse_args() + + with open(args.json_file1, 'rt', encoding='utf-8') as file1: + dict1 = json.load(file1) + with open(args.json_file2, 'rt', encoding='utf-8') as file2: + dict2 = json.load(file2) + + compare_dicts(dict1, dict2, tolerance=args.tolerance, decimals=args.decimals) + + +if __name__ == "__main__": + main() diff --git a/bravo_toolkit/util/encode_submission.py b/bravo_toolkit/util/encode_submission.py new file mode 100644 index 0000000..221fcab --- /dev/null +++ b/bravo_toolkit/util/encode_submission.py @@ -0,0 +1,163 @@ +import argparse +from contextlib import closing +import glob +import io +import logging +import os +import sys +import tarfile +import time + +import numpy as np +from tqdm import tqdm + +from bravo_toolkit.codec.bravo_codec import bravo_encode +from bravo_toolkit.codec.bravo_tarfile import (SAMPLES_SUFFIX, SPLIT_PREFIX, SPLIT_TO_CONF_SUFFIX, SPLIT_TO_PRED_SUFFIX, + SUBMISSION_SUFFIX, tar_extract_file, tar_extract_grayscale, + tar_extract_image) +from bravo_toolkit.util.sample_gt_pixels import decode_indices + + +logger = logging.getLogger('bravo_toolkit') + + +def tqqdm(iterable, *args, **kwargs): + if logger.getEffectiveLevel() <= logging.INFO: + return tqdm(iterable, *args, **kwargs) + return iterable + + +class DirectoryAsTarMember: + def __init__(self, name, parent): + self.name = name + self.parent = parent + + def isfile(self): + return os.path.isfile(os.path.join(self.parent.root, self.name)) + + +class DirectoryAsTar: + def __init__(self, root): + self.root = os.path.abspath(root) + + def close(self): + pass + + def getmembers(self): + members = [] + for file_path in glob.iglob(self.root + '/**/*', recursive=True): + inside_path = os.path.relpath(file_path, self.root) + members.append(DirectoryAsTarMember(inside_path, self)) + return members + + def getmember(self, name): + if os.path.exists(os.path.join(self.root, name)): + return DirectoryAsTarMember(name, self) + + def extractfile(self, member): + if member.parent is not self: + raise ValueError("The member does not belong to this instance.") + file_path = os.path.join(self.root, member.name) + if not os.path.isfile(file_path): + raise ValueError(f"The member {member.name} is not a file.") + return open(file_path, 'rb') + + +def process_tar_files(input_path, output_tar_path, *, samples_tar_path): + """ + Converts all pairs of prediction and confidence images in input tar file (old submission format) into a single + encoded binary file in output tar file (new submission format) with `bravo_codec.bravo_encode`. 
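+ + Typical command-line use (a sketch based on the argument parser in main() below; the paths are placeholders): + python -m bravo_toolkit.util.encode_submission <submission_dir_or_tar> <encoded_submission_tar> --samples <samples_tar>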
+ """ + # current_uid = os.getuid() + # current_gid = os.getgid() + + def open_input_path(): + if os.path.isdir(input_path): + return DirectoryAsTar(input_path) + return tarfile.open(input_path, "r") + + logger.info("Opening tar files...") + with closing(open_input_path()) as input_tar, \ + closing(tarfile.open(output_tar_path, "w")) as output_tar, \ + closing(tarfile.open(samples_tar_path, "r")) as samples_tar: + + logger.info("Listing input files...") + input_tar_members = input_tar.getmembers() + + logger.info("Reencoding submission...") + for member in tqqdm(input_tar_members): + # Skip non-files (directories, etc.) + if not member.isfile(): + continue + + # Determine the split and base path for prediction files + base_split = "" + base_path = "" + for split, pred_suffix in SPLIT_TO_PRED_SUFFIX.items(): + if member.name.startswith(SPLIT_PREFIX.format(split=split)): + base_split = split + if member.name.endswith(pred_suffix): + base_path = member.name[:-len(pred_suffix)] + break + if not base_path: + if not base_split or not member.name.endswith(SPLIT_TO_CONF_SUFFIX[base_split]): + logger.warning("Unexpected file: `%s`", member.name) + continue + + # Determine the file base and corresponding CONF file + conf_filename = base_path + SPLIT_TO_CONF_SUFFIX[base_split] + conf_member = input_tar.getmember(conf_filename) + + # Extract PRED and CONF images + pred_image = tar_extract_grayscale(input_tar, member, "prediction") + conf_image = tar_extract_image(input_tar, conf_member, "confidence") + if conf_image.dtype != np.uint16: + logger.error("Confidence image is not uint16: `%s`", conf_filename) + sys.exit(1) + if pred_image.shape != conf_image.shape: + logger.error("Prediction and confidence images have different shapes: `%s`, `%s`", + member.name, conf_filename) + sys.exit(1) + + # Extract samples index + confidence_indices_bytes = tar_extract_file(samples_tar, base_path + SAMPLES_SUFFIX) + confidence_indices = decode_indices(confidence_indices_bytes) + + # Encode the images + encoded_data = bravo_encode(pred_image, conf_image, confidence_indices=confidence_indices) + + # Add to output tarfile + encoded_filename = base_path + SUBMISSION_SUFFIX + tarinfo = tarfile.TarInfo(name=encoded_filename) + tarinfo.size = len(encoded_data) + # tarinfo.uid = current_uid + # tarinfo.gid = current_gid + tarinfo.mtime = time.time() + output_tar.addfile(tarinfo, io.BytesIO(encoded_data)) + + logger.info("Done!") + + +def main(): + parser = argparse.ArgumentParser(description="Process TAR files for encoding.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("input_path", help="Path to the input directory or TAR file.") + parser.add_argument("output_tar_path", help="Path to the output TAR file.") + parser.add_argument("--samples", help="Path to the pixel samples TAR file.") + parser.add_argument('--debug', action='store_true', help='enables extra verbose debug output') + parser.add_argument('--quiet', action='store_true', help='prints only errors and warnings') + args = parser.parse_args() + + level = logging.WARNING if args.quiet else (logging.DEBUG if args.debug else logging.INFO) + logging.basicConfig(level=level, format='%(levelname)s: %(message)s') + + try: + process_tar_files(args.input_path, args.output_tar_path, samples_tar_path=args.samples) + except Exception as e: + if os.path.exists(args.output_tar_path): + os.remove(args.output_tar_path) + raise e + + +if __name__ == "__main__": + main() diff --git a/bravo_toolkit/util/export_eval_script.py 
b/bravo_toolkit/util/export_eval_script.py new file mode 100644 index 0000000..6986153 --- /dev/null +++ b/bravo_toolkit/util/export_eval_script.py @@ -0,0 +1,118 @@ +import argparse +import ast +import base64 +import graphlib +import os +import sys + + +class ImportCollector(ast.NodeVisitor): + def __init__(self, base_name): + self.base_name = base_name + self.modules = set() + + def visit_Import(self, node): + for alias in node.names: + if alias.name.startswith(f'{self.base_name}.'): + self.modules.add(alias.name) + + def visit_ImportFrom(self, node): + if node.level == 0 and node.module and node.module.startswith(f'{self.base_name}.'): + self.modules.add(node.module) + + +def import_to_file_path(import_name, base_path): + partial_path = os.path.sep.join(import_name.split('.')) + if base_path: + return os.path.join(base_path, partial_path + '.py') + else: + return partial_path + '.py' + + +def encode_module_contents(file_path): + with open(file_path, 'r', encoding='utf-8') as file: + contents = file.read() + encoded = base64.b64encode(contents.encode('utf-8')).decode('utf-8') + return encoded + + +def find_dependencies(base_path, base_name, module_import, import_graph=None): + if import_graph is None: + import_graph = {} + elif module_import in import_graph: + assert False, f'invalid recursive call: {module_import} in {import_graph}' # this should never happen + + # Gets all immediate dependencies of the module in module_import + collector = ImportCollector(base_name) + module_path = import_to_file_path(module_import, base_path) + with open(module_path, 'r', encoding='utf-8') as file: + module_text = file.read() + node = ast.parse(module_text, filename=module_path) + collector.visit(node) + + # Adds the module to the import graph + import_graph[module_import] = set(collector.modules) + + # Gets recursive dependencies + for dep in collector.modules: + if dep not in import_graph: + find_dependencies(base_path, base_name, dep, import_graph) + + # Returns all dependencies + return import_graph + + +def compile_to_single_script(base_path, base_name, entry_module): + import_graph = find_dependencies(base_path, base_name, entry_module) + import_order = graphlib.TopologicalSorter(import_graph).static_order() + import_order = [mod for mod in import_order if mod != entry_module] + [entry_module] + + output_script = [] + output_script.append('import base64') + output_script.append('import sys') + output_script.append('\n') + output_script.append('modules_data = {}') + output_script.append('modules_path = {}') + output_script.append('\n') + + for module in import_order: + module_path = import_to_file_path(module, base_path) + friendly_path = import_to_file_path(module, None) + module_data = encode_module_contents(module_path) + output_script.append(f"modules_path['{module}'] = '''{friendly_path}'''") + output_script.append(f"modules_data['{module}'] = '''{module_data}'''\n") + + output_script.append(f"modules_order = [{', '.join([f'\"{mod}\"' for mod in import_order])}]\n") + + output_script.append(''' +def load_module(module_name, main=False): + if module_name not in modules_data: + raise ImportError(f"No module named '{module_name}'") + code = base64.b64decode(modules_data[module_name]) + module_namespace = {'__name__': '__main__' if main else module_name, + '__file__': modules_path[module_name]} + exec(compile(code, module_namespace['__file__'], 'exec'), module_namespace) + sys.modules[module_name] = type(sys)('module') + sys.modules[module_name].__dict__.update(module_namespace) + +if __name__ 
== '__main__': + for mod in modules_order[:-1]: + load_module(mod) + load_module(modules_order[-1], main=True) +else: + for mod in modules_order: + load_module(mod) +''') + + return '\n'.join(output_script) + + +def main(): + parser = argparse.ArgumentParser( + description='Exports a script with all internal dependencies resolved.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('entry_module', help='fully-qualified name of the module to use as the entry point of the exported script') + parser.add_argument('output_script', help='path of the single-file script to write') + parser.add_argument('--base-path', help='path to the base directory containing the package sources', default='.') + parser.add_argument('--base-name', help='name of the base package whose internal imports are inlined', default='bravo_toolkit') + args = parser.parse_args() + + script_text = compile_to_single_script(args.base_path, args.base_name, args.entry_module) + with open(args.output_script, 'wt', encoding='utf-8') as script_file: + script_file.write(script_text) + + +if __name__ == '__main__': + main() diff --git a/bravo_toolkit/util/inspect_submission.py b/bravo_toolkit/util/inspect_submission.py new file mode 100644 index 0000000..eed81be --- /dev/null +++ b/bravo_toolkit/util/inspect_submission.py @@ -0,0 +1,93 @@ +import argparse +import logging +import sys +import tarfile + +import numpy as np +from tqdm import tqdm + +from bravo_toolkit.codec.bravo_codec import bravo_encode +from bravo_toolkit.codec.bravo_tarfile import (SPLIT_PREFIX, SPLIT_TO_CONF_SUFFIX, SPLIT_TO_PRED_SUFFIX, + tar_extract_grayscale, tar_extract_image) + + +def process_tar_files(input_tar_path, rounded_quantization=False): + ''' + Inspects all pairs of prediction and confidence images in the input tar file (old submission format) and prints + statistics about the image shapes found in the submission. + ''' + # current_uid = os.getuid() + # current_gid = os.getgid() + + print("Opening tar files...") + all_shapes = [] + with tarfile.open(input_tar_path, 'r') as input_tar: + + print("Listing input files...") + input_tar_members = input_tar.getmembers() + + print("Inspecting submission...") + for member in tqdm(input_tar_members): + # Skip non-files (directories, etc.) 
+ if not member.isfile(): + continue + + # Determine the split and base path for prediction files + base_split = '' + base_path = '' + for split, pred_suffix in SPLIT_TO_PRED_SUFFIX.items(): + if member.name.startswith(SPLIT_PREFIX.format(split=split)): + base_split = split + if member.name.endswith(pred_suffix): + base_path = member.name[:-len(pred_suffix)] + break + if not base_path: + if not base_split or not member.name.endswith(SPLIT_TO_CONF_SUFFIX[base_split]): + logging.warning('Unexpected file: `%s`', member.name) + continue + + # Determine the file base and corresponding CONF file + conf_filename = base_path + SPLIT_TO_CONF_SUFFIX[base_split] + conf_member = input_tar.getmember(conf_filename) + + # Extract PRED and CONF images + pred_image = tar_extract_grayscale(input_tar, member, 'prediction') + conf_image = tar_extract_image(input_tar, conf_member, 'confidence') + if conf_image.dtype != np.uint16: + logging.error('Confidence image is not uint16: `%s`', conf_filename) + sys.exit(1) + conf_image = conf_image.astype(np.float32) / 65536 # This preserves the mantissa and shifts the exponent + + if pred_image.shape != conf_image.shape: + logging.error('Prediction and confidence images have different shapes: `%s` and `%s`', + member.name, conf_filename) + sys.exit(1) + + all_shapes.append(pred_image.shape) + + np.set_printoptions(suppress=True, precision=1) + all_shapes = np.array(all_shapes) + print('Image shape statistics:') + print('Avg: ', np.mean(all_shapes, axis=0)) + print('Max: ', np.max(all_shapes, axis=0)) + print('75p: ', np.percentile(all_shapes, 75, axis=0)) + print('Med: ', np.median(all_shapes, axis=0)) + print('25p: ', np.percentile(all_shapes, 25, axis=0)) + print('Min: ', np.min(all_shapes, axis=0)) + + + print("Done!") + + +def main(): + parser = argparse.ArgumentParser(description="Process TAR files for encoding.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("input_tar_path", help="Path to the input TAR file.") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') + process_tar_files(args.input_tar_path) + + +if __name__ == "__main__": + main() diff --git a/bravo_toolkit/util/old_eval_script_reference.py b/bravo_toolkit/util/old_eval_script_reference.py new file mode 100644 index 0000000..9553144 --- /dev/null +++ b/bravo_toolkit/util/old_eval_script_reference.py @@ -0,0 +1,786 @@ +# pylint: skip-file +# flake8 : noqa +import argparse +import json +import logging +import tarfile + +import cv2 +import numpy as np +from sklearn import metrics +from tqdm import tqdm + + +DEBUG = False +CS_STUFF_CLASSES = [11, 12, 13, 14, 15, 16, 17, 18, 19] + +logger = logging.getLogger('bravo_toolkit') + + +def tqqdm(iterable, *args, **kwargs): + if logger.getEffectiveLevel() <= logging.INFO: + return tqdm(iterable, *args, **kwargs) + return iterable + + +def load_tar(tar_path): + return tarfile.open(tar_path) + + +def read_image_from_tar(tar, member, flag=cv2.IMREAD_GRAYSCALE): + try: + f = tar.extractfile(member) + content = f.read() + f.close() + file_bytes = np.asarray(bytearray(content), dtype=np.uint8) + img = cv2.imdecode(file_bytes, flag) + if img is None: + raise ValueError('error decoding image') + return img + except: + logger.error('Unable to load file {}'.format(member.name)) + raise IOError('Unable to load file {}'.format(member.name)) + + +def read_image_from_tar_conf(tar, member, flag=cv2.IMREAD_UNCHANGED): + try: + f = tar.extractfile(member) + content = f.read() 
+ f.close() + file_bytes = np.asarray(bytearray(content), dtype=np.uint8) + img = cv2.imdecode(file_bytes, flag) + if img is None: + raise ValueError('error decoding image') + return img + except: + logger.error('Unable to load file {}'.format(member.name)) + raise IOError('Unable to load file {}'.format(member.name)) + + +def fast_hist(a, b, n): + # Fast calculation of the confusion matrix per frame + k = (a >= 0) & (a < n) + return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) + + +def per_class_iu(hist): + return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + + +def check_conds(split_name, gt_suffix, additional_conds, name): + if split_name is not None and split_name not in name: + return False + if gt_suffix is not None and gt_suffix not in name: + return False + if len(additional_conds) == 0: + return True + else: + for cond in additional_conds: + if cond not in name: + return False + return True + + +def semantic_filtering(seg_pred, ori_ood_conf): + conf = np.zeros_like(ori_ood_conf) + for c in CS_STUFF_CLASSES: + mask = np.where(seg_pred == c, 1, 0) + conf += mask * ori_ood_conf + conf += 0.01 * ori_ood_conf + return np.clip(conf, 0, 0.99) + + +# This is a slow alternative implementation of compute_ood_scores2, kept as reference for testing +# It is not used in the final evaluation +def compute_ood_scores(flat_labels, flat_pred, num_points=50): + # From fishycapes code + pos = flat_labels == 1 + valid = flat_labels <= 1 # filter out void + gt = pos[valid] + del pos + uncertainty = flat_pred[valid].reshape(-1).astype(np.float32, copy=False) + del valid + + # Sort the classifier scores (uncertainties) + sorted_indices = np.argsort(uncertainty, kind='mergesort')[::-1] + uncertainty, gt = uncertainty[sorted_indices], gt[sorted_indices] + del sorted_indices + + # Remove duplicates along the curve + distinct_value_indices = np.where(np.diff(uncertainty))[0] + threshold_idxs = np.r_[distinct_value_indices, gt.size - 1] + del distinct_value_indices, uncertainty + + # Accumulate TPs and FPs + tps = np.cumsum(gt, dtype=np.uint64)[threshold_idxs] + fps = 1 + threshold_idxs - tps + del threshold_idxs + + # Compute Precision and Recall + precision = tps / (tps + fps) + precision[np.isnan(precision)] = 0 + recall = tps / tps[-1] + # stop when full recall attained and reverse the outputs so recall is decreasing + sl = slice(tps.searchsorted(tps[-1]), None, -1) + precision = np.r_[precision[sl], 1] + recall = np.r_[recall[sl], 0] + average_precision = -np.sum(np.diff(recall) * precision[:-1]) + + # select num_points values for a plotted curve + interval = 1.0 / num_points + curve_precision = [precision[-1]] + curve_recall = [recall[-1]] + idx = recall.size - 1 + for p in range(1, num_points): + while recall[idx] < p * interval: + idx -= 1 + curve_precision.append(precision[idx]) + curve_recall.append(recall[idx]) + curve_precision.append(precision[0]) + curve_recall.append(recall[0]) + del precision, recall + + if tps.size == 0 or fps[0] != 0 or tps[0] != 0: + # Add an extra threshold position if necessary + # to make sure that the curve starts at (0, 0) + tps = np.r_[0., tps] + fps = np.r_[0., fps] + + # Compute TPR and FPR + tpr = tps / tps[-1] + del tps + fpr = fps / fps[-1] + del fps + + # Compute AUROC + auroc = np.trapz(tpr, fpr) + + # Compute FPR@95%TPR + fpr_tpr95 = fpr[np.searchsorted(tpr, 0.95)] + results = { + 'auroc': auroc, + 'AP': average_precision, + 'FPR@95%TPR': fpr_tpr95, + # 'recall': np.array(curve_recall), + # 'precision': 
np.array(curve_precision), + # 'fpr': fpr, + # 'tpr': tpr + } + + return results + + +def compute_scores(conf, pred, label, ECE_NUM_BINS=15, CONF_NUM_BINS=65535, tpr_th=0.95): + conf = conf.astype(np.float32) + # normalize conf to [0, 1] + conf = conf / CONF_NUM_BINS + tau_tab = np.linspace(0, 1, ECE_NUM_BINS + 1) # Confidence bins + nb_items_bin = np.zeros(ECE_NUM_BINS) + acc_tab = np.zeros(ECE_NUM_BINS) # Empirical (true) confidence + mean_conf = np.zeros(ECE_NUM_BINS) # Predicted confidence + for i in np.arange(ECE_NUM_BINS): # Iterates over the bins + # Selects the items where the predicted max probability falls in the bin + # [tau_tab[i], tau_tab[i + 1)] + sec = (tau_tab[i + 1] > conf) & (conf >= tau_tab[i]) + nb_items_bin[i] = np.sum(sec) # Number of items in the bin + # Selects the predicted classes, and the true classes + class_pred_sec, y_sec = pred[sec], label[sec] + # Averages of the predicted max probabilities + mean_conf[i] = np.mean(conf[sec]) if nb_items_bin[i] > 0 else np.nan + # Computes the empirical confidence + acc_tab[i] = np.mean(class_pred_sec == y_sec) if nb_items_bin[i] > 0 else np.nan + mean_conf = mean_conf[nb_items_bin > 0] + acc_tab = acc_tab[nb_items_bin > 0] + nb_items_bin = nb_items_bin[nb_items_bin > 0] + if sum(nb_items_bin) != 0: + ece = np.average(np.absolute(mean_conf - acc_tab), weights=nb_items_bin.astype(np.float32) / np.sum(nb_items_bin)) + else: + ece = 0.0 + # TODO: wouldn't it be better to call compute_ood_scores2 here? + # auroc, aupr_success, aupr_error, fpr95 = compute_ood_scores2(-conf, pred == label) + # In this postprocessing we assume ID samples have larger "conf" values than OOD samples => we negate "conf" values + # such that higher (signed) values correspond to detecting OOD samples + fpr_list, tpr_list, thresholds = metrics.roc_curve(pred == label, conf) + fpr = fpr_list[np.argmax(tpr_list >= tpr_th)] + precision_in, recall_in, thresholds_in = metrics.precision_recall_curve(pred == label, conf) + precision_out, recall_out, thresholds_out = metrics.precision_recall_curve(pred != label, -conf) # TODO: this formula differs from the one in compute_ood_scores2? Is this intended? 
=> possibly, this came as this, to double-check + auroc = metrics.auc(fpr_list, tpr_list) + aupr_success = metrics.auc(recall_in, precision_in) + aupr_error = metrics.auc(recall_out, precision_out) + return ece, auroc, aupr_success, aupr_error, fpr + + +def compute_ood_scores2(conf, ood_label, tpr_th=0.95): + valid = ood_label <= 1 # filter out void + ood_label = ood_label[valid] == 1 + conf = conf[valid] + # TODO: double-check the logic of negating the confs, especially because the confs are already being negated by the caller + # In this postprocessing we assume ID samples have larger "conf" values than OOD samples => we negate "conf" values + # such that higher (signed) values correspond to detecting OOD samples + fpr_list, tpr_list, thresholds = metrics.roc_curve(ood_label, conf, drop_intermediate=False) + fpr = fpr_list[np.argmax(tpr_list >= tpr_th)] + precision_in, recall_in, thresholds_in = metrics.precision_recall_curve(ood_label, conf) + precision_out, recall_out, thresholds_out = metrics.precision_recall_curve(ood_label, -conf) # TODO: see above + auroc = metrics.auc(fpr_list, tpr_list) + aupr_success = metrics.auc(recall_in, precision_in) + aupr_error = metrics.auc(recall_out, precision_out) + return auroc, aupr_success, aupr_error, fpr + + +def evaluate_bravo(gt_data, + submission_data, + split_name, # ACDC, SMIYC, outofcontext, synflare, synobjs, synrain + additional_conds, # list of naming conditions for loading ground-truths + gt_suffix, # suffix of ground-truths + pred_suffix, # suffix for loading predictions + conf_suffix, # suffix for loading confidences + invalid_mask_suffix=None, # suffix for loading valid masks + compute_semantic_metrics=True, # compute semantic metrics + compute_semantic_metrics_invalid_area=False, + compute_OOD=False, # compute OOD scores + NUM_CLASSES=19, + ECE_NUM_BINS=15, + CONF_NUM_BINS=65535, + SAMPLES_PER_IMG=20000, + strict=False + ): + gts = [mem for mem in gt_data.getmembers() if check_conds(split_name, gt_suffix, additional_conds, mem.name)] + n_images = len(gts) + str_conds = ' '.join(additional_conds) + logger.info(f'{split_name}-{str_conds}: evaluation on {n_images} images') + hist = np.zeros((NUM_CLASSES, NUM_CLASSES)) + hist_valid = np.zeros((NUM_CLASSES, NUM_CLASSES)) + hist_invalid = np.zeros((NUM_CLASSES, NUM_CLASSES)) + all_gt_labels = [] + all_preds = [] + all_confs = [] + all_valid_indices = [] + all_invalid_indices = [] + + all_gt_labels_valid = [] + all_preds_valid = [] + all_confs_valid = [] + all_gt_labels_invalid = [] + all_preds_invalid = [] + all_confs_invalid = [] + + for idx, gt_mem in enumerate(tqqdm(gts)): + + np.random.seed(42) # CHANGED: Takes the same samples from all images + + if DEBUG and idx >= 2: + break + + gt_name = gt_mem.name + pred_name = gt_name.replace(gt_suffix, pred_suffix) + conf_name = gt_name.replace(gt_suffix, conf_suffix) + + # Read current ground truth and prediction files. 
+ try: + img_gt_label_np = read_image_from_tar(gt_data, gt_mem) + except: + logger.error(f'Unable to load ground truth file {gt_name}') + raise IOError(f'Unable to load ground truth file {gt_name}') + + try: + pred = submission_data.getmember(pred_name) + img_pred_label_np = read_image_from_tar(submission_data, pred) + except: + logger.error(f'Unable to load prediction file {pred_name}') + raise IOError(f'Unable to load prediction file {pred_name}') + + try: + conf = submission_data.getmember(conf_name) + img_conf_np = read_image_from_tar_conf(submission_data, conf) + except: + logger.error(f'Unable to load prediction file {conf_name}') + raise IOError(f'Unable to load prediction file {conf_name}') + + # TODO: clarify the resizing rules in the submission instructions + # Ensures that dimensions of the two images match exactly. + img_gt_label_shape = img_gt_label_np.shape + img_pred_label_shape = img_pred_label_np.shape + if len(img_pred_label_shape) != 2: + logger.error(f'Prediction is not a proper 2D matrix for file {gt_name}') + raise ValueError(f'Prediction is not a proper 2D matrix for file {gt_name}') + if img_pred_label_shape != img_gt_label_shape: + # resize img_gt_label_np to img_pred_label_shape + img_pred_label_np = cv2.resize(img_pred_label_np, img_gt_label_shape[::-1], interpolation=cv2.INTER_LINEAR) + img_conf_np = cv2.resize(img_conf_np, img_gt_label_shape[::-1], interpolation=cv2.INTER_LINEAR) + img_pred_label_shape = img_pred_label_np.shape + logger.warning(f'Resized prediction to match ground truth dimensions for file {pred_name}') + # raise ValueError(f'Image dimensions mismatch for file {pred_name}') + + if strict: + # STRICTLY SIMILAR TO THE OLD SCRIPT + if invalid_mask_suffix is not None: + mask_name = gt_name.replace(gt_suffix, invalid_mask_suffix) + try: + img_invalid_mask_np = read_image_from_tar(gt_data, mask_name) + except: + logger.error(f'Unable to load invalid mask file {mask_name}') + raise IOError(f'Unable to load invalid mask file {mask_name}') + + # TODO: can we do this once outside the server? => probably, let's check + # resize img_invalid_mask_np to img_gt_label_shape + img_invalid_mask_np = cv2.resize(img_invalid_mask_np, img_gt_label_shape[::-1], interpolation=cv2.INTER_NEAREST) + + # convert the invalid binary mask to indices of invalid and valid pixels + img_invalid_indices = np.where(img_invalid_mask_np == 1) + img_valid_indices = np.where(img_invalid_mask_np == 0) + # get 1D indices of invalid and valid pixels + img_invalid_indices_1D = np.ravel_multi_index(img_invalid_indices, img_invalid_mask_np.shape) + img_valid_indices_1D = np.ravel_multi_index(img_valid_indices, img_invalid_mask_np.shape) + else: + mask_name = img_invalid_mask_np = img_valid_indices = img_invalid_indices = img_valid_indices_1D = img_invalid_indices_1D = None + + # Calculate the intersection and union counts per class for the two images. 
+ img_gt_label_np = img_gt_label_np.flatten() + img_pred_label_np = img_pred_label_np.flatten() + img_conf_np = img_conf_np.flatten() + hist += fast_hist(img_gt_label_np, img_pred_label_np, NUM_CLASSES) + + # Sample pixels for ECE and AUROC + nsamples = min(SAMPLES_PER_IMG, img_gt_label_np.shape[0]) + samples_indices = np.random.choice(range(img_gt_label_np.shape[0]), nsamples, replace=False) + + valid_indices = np.zeros_like(img_gt_label_np) + valid_indices[img_valid_indices_1D] = 1 # all indices are valid if img_valid_indices_1D is None + + invalid_indices = np.zeros_like(img_gt_label_np) + invalid_indices[img_invalid_indices_1D] = 1 + + all_gt_labels.append(img_gt_label_np[samples_indices]) + all_preds.append(img_pred_label_np[samples_indices]) + all_confs.append(img_conf_np[samples_indices]) + all_valid_indices.append(valid_indices[samples_indices]) + all_invalid_indices.append(invalid_indices[samples_indices]) + + # Compute scores for valid/invalid pixels + if invalid_mask_suffix is not None: + hist_valid += fast_hist(img_gt_label_np[img_valid_indices_1D], img_pred_label_np[img_valid_indices_1D], NUM_CLASSES) + hist_invalid += fast_hist(img_gt_label_np[img_invalid_indices_1D], img_pred_label_np[img_invalid_indices_1D], NUM_CLASSES) + nsamples_valid = min(SAMPLES_PER_IMG, len(img_valid_indices_1D)) + samples_indices_valid = np.random.choice(range(len(img_valid_indices_1D)), nsamples_valid, replace=False) + samples_indices_valid = img_valid_indices_1D[samples_indices_valid] + all_gt_labels_valid.append(img_gt_label_np[samples_indices_valid]) + all_preds_valid.append(img_pred_label_np[samples_indices_valid]) + all_confs_valid.append(img_conf_np[samples_indices_valid]) + + if compute_semantic_metrics_invalid_area: + nsamples_invalid = min(SAMPLES_PER_IMG, len(img_invalid_indices_1D)) + samples_indices_invalid = np.random.choice(range(len(img_invalid_indices_1D)), nsamples_invalid, replace=False) + samples_indices_invalid = img_invalid_indices_1D[samples_indices_invalid] + all_gt_labels_invalid.append(img_gt_label_np[samples_indices_invalid]) + all_preds_invalid.append(img_pred_label_np[samples_indices_invalid]) + all_confs_invalid.append(img_conf_np[samples_indices_invalid]) + else: + # COMPATIBILITY MODE WITH NEW SCRIPT + # Removes all pixels of class 255 + # Because of that, the treatment of indices is modified, with all arrays being flattened from start + + # Calculate the intersection and union counts per class for the two images. + img_gt_label_np = img_gt_label_np.flatten() + img_pred_label_np = img_pred_label_np.flatten() + img_conf_np = img_conf_np.flatten() + + non_void = img_gt_label_np != 255 + img_gt_label_np = img_gt_label_np[non_void] + img_pred_label_np = img_pred_label_np[non_void] + img_conf_np = img_conf_np[non_void] + + if invalid_mask_suffix is not None: + mask_name = gt_name.replace(gt_suffix, invalid_mask_suffix) + try: + img_invalid_mask_np = read_image_from_tar(gt_data, mask_name) + except: + logger.error(f'Unable to load invalid mask file {mask_name}') + raise IOError(f'Unable to load invalid mask file {mask_name}') + + # TODO: can we do this once outside the server? 
=> probably, let's check + # resize img_invalid_mask_np to img_gt_label_shape + img_invalid_mask_np = cv2.resize(img_invalid_mask_np, img_gt_label_shape[::-1], interpolation=cv2.INTER_NEAREST) + + img_invalid_mask_np = img_invalid_mask_np.flatten() + img_invalid_mask_np = img_invalid_mask_np[non_void] + + # convert the invalid binary mask to indices of invalid and valid pixels + img_invalid_indices_1D = np.nonzero(img_invalid_mask_np == 1)[0] + img_valid_indices_1D = np.nonzero(img_invalid_mask_np == 0)[0] + else: + mask_name = img_invalid_mask_np = img_valid_indices_1D = img_invalid_indices_1D = None + + hist += fast_hist(img_gt_label_np, img_pred_label_np, NUM_CLASSES) + + # Sample pixels for ECE and AUROC + nsamples = min(SAMPLES_PER_IMG, img_gt_label_np.shape[0]) + samples_indices = np.random.choice(range(img_gt_label_np.shape[0]), nsamples, replace=False) + + valid_indices = np.zeros_like(img_gt_label_np) + valid_indices[img_valid_indices_1D] = 1 # all indices are valid if img_valid_indices_1D is None + + invalid_indices = np.zeros_like(img_gt_label_np) + invalid_indices[img_invalid_indices_1D] = 1 + + all_gt_labels.append(img_gt_label_np[samples_indices]) + all_preds.append(img_pred_label_np[samples_indices]) + all_confs.append(img_conf_np[samples_indices]) + all_valid_indices.append(valid_indices[samples_indices]) + all_invalid_indices.append(invalid_indices[samples_indices]) + + # Compute scores for valid/invalid pixels + if invalid_mask_suffix is not None: + hist_valid += fast_hist(img_gt_label_np[img_valid_indices_1D], img_pred_label_np[img_valid_indices_1D], NUM_CLASSES) + hist_invalid += fast_hist(img_gt_label_np[img_invalid_indices_1D], img_pred_label_np[img_invalid_indices_1D], NUM_CLASSES) + nsamples_valid = min(SAMPLES_PER_IMG, len(img_valid_indices_1D)) + samples_indices_valid = np.random.choice(range(len(img_valid_indices_1D)), nsamples_valid, replace=False) + samples_indices_valid = img_valid_indices_1D[samples_indices_valid] + all_gt_labels_valid.append(img_gt_label_np[samples_indices_valid]) + all_preds_valid.append(img_pred_label_np[samples_indices_valid]) + all_confs_valid.append(img_conf_np[samples_indices_valid]) + + if compute_semantic_metrics_invalid_area: + nsamples_invalid = min(SAMPLES_PER_IMG, len(img_invalid_indices_1D)) + samples_indices_invalid = np.random.choice(range(len(img_invalid_indices_1D)), nsamples_invalid, replace=False) + samples_indices_invalid = img_invalid_indices_1D[samples_indices_invalid] + all_gt_labels_invalid.append(img_gt_label_np[samples_indices_invalid]) + all_preds_invalid.append(img_pred_label_np[samples_indices_invalid]) + all_confs_invalid.append(img_conf_np[samples_indices_invalid]) + + # calculate mIoU + if compute_semantic_metrics: + perclass_iou = per_class_iu(hist).tolist() + miou = np.nanmean(perclass_iou) + logger.info(f'{split_name}-{str_conds}: mIoU = {miou}') + else: + miou = None + perclass_iou = None + # calculate mIoU for valid/invalid pixels + if compute_semantic_metrics and invalid_mask_suffix is not None: + perclass_iou_valid = per_class_iu(hist_valid).tolist() + miou_valid = np.nanmean(perclass_iou_valid) + perclass_iou_invalid = per_class_iu(hist_invalid).tolist() + miou_invalid = np.nanmean(perclass_iou_invalid) + logger.info(f'{split_name}-{str_conds}: mIoU-valid = {miou_valid}') + logger.info(f'{split_name}-{str_conds}: mIoU-invalid = {miou_invalid}') + logger.info('====================================================') + else: + miou_valid = miou_invalid = perclass_iou_valid = perclass_iou_invalid = None + + 
# calculate ECE, AUROC, AUPR_success, AUPR_error, FPR_at_95TPR + + all_gt_labels = np.concatenate(all_gt_labels) + all_preds = np.concatenate(all_preds) + all_confs = np.concatenate(all_confs).astype(np.float32) + all_valid_indices = np.concatenate(all_valid_indices) + all_invalid_indices = np.concatenate(all_invalid_indices) + + if compute_semantic_metrics: + ece, auroc, aupr_success, aupr_error, fpr95 = compute_scores(all_confs, all_preds, all_gt_labels, ECE_NUM_BINS=ECE_NUM_BINS, tpr_th=0.95) + logger.info(f'{split_name}-{str_conds}: ECE = {ece}') + logger.info(f'{split_name}-{str_conds}: AUROC = {round(auroc * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_success = {round(aupr_success * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_error = {round(aupr_error * 100,2)}') + logger.info(f'{split_name}-{str_conds}: FPR_at_95TPR = {round(fpr95 * 100,2)}') + logger.info('====================================================') + else: + ece = auroc = aupr_success = aupr_error = fpr95 = None + + if compute_semantic_metrics and invalid_mask_suffix is not None: + all_gt_labels_valid = np.concatenate(all_gt_labels_valid) + all_preds_valid = np.concatenate(all_preds_valid) + all_confs_valid = np.concatenate(all_confs_valid).astype(np.float32) + ece_valid, auroc_valid, aupr_success_valid, aupr_error_valid, fpr95_valid = compute_scores(all_confs_valid, all_preds_valid, all_gt_labels_valid, ECE_NUM_BINS=ECE_NUM_BINS, tpr_th=0.95) + logger.info(f'{split_name}-{str_conds}: ECE_valid = {ece_valid}') + logger.info(f'{split_name}-{str_conds}: AUROC_valid = {round(auroc_valid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_success_valid = {round(aupr_success_valid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_error_valid = {round(aupr_error_valid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: FPR_at_95TPR_valid = {round(fpr95_valid * 100,2)}') + logger.info('====================================================') + + if compute_semantic_metrics_invalid_area: + all_gt_labels_invalid = np.concatenate(all_gt_labels_invalid) + all_preds_invalid = np.concatenate(all_preds_invalid) + all_confs_invalid = np.concatenate(all_confs_invalid).astype(np.float32) + ece_invalid, auroc_invalid, aupr_success_invalid, aupr_error_invalid, fpr95_invalid = compute_scores(all_confs_invalid, all_preds_invalid, all_gt_labels_invalid, ECE_NUM_BINS=ECE_NUM_BINS, tpr_th=0.95) + logger.info(f'{split_name}-{str_conds}: ECE_invalid = {ece_invalid}') + logger.info(f'{split_name}-{str_conds}: AUROC_invalid = {round(auroc_invalid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_success_invalid = {round(aupr_success_invalid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_error_invalid = {round(aupr_error_invalid * 100,2)}') + logger.info(f'{split_name}-{str_conds}: FPR_at_95TPR_invalid = {round(fpr95_invalid * 100,2)}') + logger.info('====================================================') + else: + ece_invalid = auroc_invalid = aupr_success_invalid = aupr_error_invalid = fpr95_invalid = None + + else: + ece_valid = auroc_valid = aupr_success_valid = aupr_error_valid = fpr95_valid = None + ece_invalid = auroc_invalid = aupr_success_invalid = aupr_error_invalid = fpr95_invalid = None + + # calculate OOD detection scores + if compute_OOD and invalid_mask_suffix is not None: + ood_labels = np.zeros_like(all_confs) + ood_labels[all_invalid_indices == 1] = 1 + ood_labels[all_gt_labels == 255] = 255 + + auroc_ood, aupr_ood, aupr_ood_error, fpr95_ood = 
compute_ood_scores2(-all_confs, ood_labels) + logger.info(f'{split_name}-{str_conds}: AUROC_ood = {round(auroc_ood * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_ood = {round(aupr_ood * 100,2)}') + logger.info(f'{split_name}-{str_conds}: FPR_at_95TPR_ood = {round(fpr95_ood * 100,2)}') + + # apply semantic filtering + all_confs_semfilt = semantic_filtering(all_preds, 1-all_confs/CONF_NUM_BINS) + auroc_ood_semfilt, aupr_ood_semfilt, aupr_ood_error_semfilt, fpr95_ood_semfilt = compute_ood_scores2(all_confs_semfilt, ood_labels) + logger.info(f'{split_name}-{str_conds}: AUROC_ood_semfilt = {round(auroc_ood_semfilt * 100,2)}') + logger.info(f'{split_name}-{str_conds}: AUPR_ood_semfilt = {round(aupr_ood_semfilt * 100,2)}') + logger.info(f'{split_name}-{str_conds}: FPR_at_95TPR_ood_semfilt = {round(fpr95_ood_semfilt * 100,2)}') + else: + auroc_ood = aupr_ood = fpr95_ood = auroc_ood_semfilt = aupr_ood_semfilt = fpr95_ood_semfilt = None + + metrics = { + 'miou': miou, + 'iou': perclass_iou, + 'ece': ece, + 'auroc': auroc, + 'auprc_success': aupr_success, + 'auprc_error': aupr_error, + 'fpr95': fpr95, + 'miou_valid': miou_valid, + 'iou_valid': perclass_iou_valid, + 'ece_valid': ece_valid, + 'auroc_valid': auroc_valid, + 'auprc_success_valid': aupr_success_valid, + 'auprc_error_valid': aupr_error_valid, + 'fpr95_valid': fpr95_valid, + 'miou_invalid': miou_invalid, + 'iou_invalid': perclass_iou_invalid, + 'ece_invalid': ece_invalid, + 'auroc_invalid': auroc_invalid, + 'auprc_success_invalid': aupr_success_invalid, + 'auprc_error_invalid': aupr_error_invalid, + 'fpr95_invalid': fpr95_invalid, + 'auroc_ood': auroc_ood, + 'auprc_ood': aupr_ood, + 'fpr95_ood': fpr95_ood, + } + if strict: + metrics.update({ + 'auroc_ood_semfilt': auroc_ood_semfilt, + 'auprc_ood_semfilt': aupr_ood_semfilt, + 'fpr95_ood_semfilt': fpr95_ood_semfilt + }) + return metrics + + +BRAVO_SUBSETS = ['ACDC_fog', 'ACDC_rain', 'ACDC_night', 'ACDC_snow', 'synrain', 'SMIYC', 'synobjs', 'synflare', + 'outofcontext'] + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + gt_data = load_tar(gtFilePath) + submission_data = load_tar(submFilePath) + + NUM_CLASSES = 19 + ECE_NUM_BINS = 15 + CONF_NUM_BINS = 65535 + SAMPLES_PER_IMG = 100000 + # SAMPLES_PER_IMG = np.inf + + subsets = evaluationParams.get('subsets', BRAVO_SUBSETS) + strict = evaluationParams.get('strict', False) + results_dict = {} + + if 'ACDC_fog' in subsets: + acdcFogMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + additional_conds=['/fog/'], + gt_suffix='_gt_labelTrainIds.png', + invalid_mask_suffix='_gt_invIds.png', + pred_suffix='_rgb_anon_pred.png', + conf_suffix='_rgb_anon_conf.png', + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['ACDC_fog'] = acdcFogMetrics + + if 'ACDC_rain' in subsets: + acdcRainMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + additional_conds=['/rain/'], + gt_suffix='_gt_labelTrainIds.png', + invalid_mask_suffix='_gt_invIds.png', + pred_suffix='_rgb_anon_pred.png', + conf_suffix='_rgb_anon_conf.png', + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['ACDC_rain'] = acdcRainMetrics + + if 'ACDC_night' in subsets: + acdcNightMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + 
additional_conds=['/night/'], + gt_suffix='_gt_labelTrainIds.png', + invalid_mask_suffix='_gt_invIds.png', + pred_suffix='_rgb_anon_pred.png', + conf_suffix='_rgb_anon_conf.png', + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['ACDC_night'] = acdcNightMetrics + + if 'ACDC_snow' in subsets: + acdcSnowMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='ACDC', + additional_conds=['/snow/'], + gt_suffix='_gt_labelTrainIds.png', + invalid_mask_suffix='_gt_invIds.png', + pred_suffix='_rgb_anon_pred.png', + conf_suffix='_rgb_anon_conf.png', + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['ACDC_snow'] = acdcSnowMetrics + + if 'synrain' in subsets: + synrainMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synrain', + additional_conds=[], + gt_suffix='_gt_labelTrainIds.png', + pred_suffix='_leftImg8bit_pred.png', + conf_suffix='_leftImg8bit_conf.png', + invalid_mask_suffix='_gt_invIds.png', + compute_semantic_metrics=True, + compute_semantic_metrics_invalid_area=True, + compute_OOD=True, + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['synrain'] = synrainMetrics + + if 'SMIYC' in subsets: + smiycMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='SMIYC', + additional_conds=[], + gt_suffix='_labels_semantic_fake.png', + pred_suffix='_pred.png', + conf_suffix='_conf.png', + invalid_mask_suffix='_labels_semantic.png', + compute_semantic_metrics=False, + compute_OOD=True, + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=np.inf, + strict=strict) + results_dict['SMIYC'] = smiycMetrics + + if 'synobjs' in subsets: + synobjsMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synobjs', + additional_conds=[], + gt_suffix='_gt.png', + pred_suffix='_pred.png', + conf_suffix='_conf.png', + invalid_mask_suffix='_mask.png', + compute_semantic_metrics=True, + compute_OOD=True, + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['synobjs'] = synobjsMetrics + + if 'synflare' in subsets: + synflareMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='synflare', + additional_conds=[], + gt_suffix='_gt_labelTrainIds.png', + pred_suffix='_leftImg8bit_pred.png', + conf_suffix='_leftImg8bit_conf.png', + invalid_mask_suffix='_gt_invIds.png', + compute_semantic_metrics=True, + compute_OOD=True, + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['synflare'] = synflareMetrics + + if 'outofcontext' in subsets: + outofcontextMetrics = evaluate_bravo(gt_data=gt_data, + submission_data=submission_data, + split_name='outofcontext', + additional_conds=[], + gt_suffix='_gt_labelTrainIds.png', + pred_suffix='_leftImg8bit_pred.png', + conf_suffix='_leftImg8bit_conf.png', + invalid_mask_suffix='_gt_invIds.png', + compute_semantic_metrics=True, + compute_OOD=False, + NUM_CLASSES=NUM_CLASSES, + ECE_NUM_BINS=ECE_NUM_BINS, + CONF_NUM_BINS=CONF_NUM_BINS, + 
SAMPLES_PER_IMG=SAMPLES_PER_IMG, + strict=strict) + results_dict['outofcontext'] = outofcontextMetrics + + return results_dict + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Evaluates submissions for the ELSA BRAVO Challenge.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('submission', help='path to submission tar file') + parser.add_argument('--gt', required=True, help='path to ground-truth tar file') + parser.add_argument('--results', default='results_old.json', help='JSON file to store the computed metrics') + parser.add_argument('--subsets', default=['ALL'], nargs='+', choices=BRAVO_SUBSETS+['ALL'], metavar='SUBSETS', + help='tasks to evaluate: ALL for all tasks, or one or more items from this list, separated ' + 'by spaces: ' + ' '.join(BRAVO_SUBSETS)) + parser.add_argument('--strict', action='store_true', help='disables comparison mode, where the script will ' + 'process the void class (label 255) like the new script.') + parser.add_argument('--debug', action='store_true', help='enables extra verbose debug output') + parser.add_argument('--quiet', action='store_true', help='prints only errors and warnings') + args = parser.parse_args() + + level = logging.WARNING if args.quiet else (logging.DEBUG if args.debug else logging.INFO) + logging.basicConfig(level=level, format='%(levelname)s: %(message)s') + + if args.strict: + logger.warning('comparison mode is disabled, results will not be comparable to the new script.') + + logger.info('evaluating...') + if args.subsets == ['ALL']: + args.subsets = BRAVO_SUBSETS + results = evaluate_method(args.gt, args.submission, + evaluationParams={'subsets': args.subsets, 'strict': args.strict}) + + if not args.quiet: + logger.info('results:') + print(json.dumps(results, indent=4)) + + logger.info('saving results...') + with open(args.results, 'wt', encoding='utf-8') as json_file: + json.dump(results, json_file, indent=4) + + logger.info('done.') diff --git a/bravo_toolkit/util/sample_gt_pixels.py b/bravo_toolkit/util/sample_gt_pixels.py new file mode 100644 index 0000000..4b5c3ba --- /dev/null +++ b/bravo_toolkit/util/sample_gt_pixels.py @@ -0,0 +1,221 @@ +import argparse +from contextlib import closing +import io +import logging +import os +import sys +import tarfile +import time + +import numpy as np +from tqdm import tqdm + +from bravo_toolkit.codec.bravo_codec import _compress, _decompress +from bravo_toolkit.codec.bravo_tarfile import (SAMPLES_SUFFIX, SPLIT_PREFIX, SPLIT_TO_GT_SUFFIX, SPLIT_TO_MASK_SUFFIX, + tar_extract_file, tar_extract_grayscale) + + +SAMPLES_PER_IMG = 100_000 + + +logger = logging.getLogger('bravo_toolkit') + + +def tqqdm(iterable, *args, **kwargs): + if logger.getEffectiveLevel() <= logging.INFO: + return tqdm(iterable, *args, **kwargs) + return iterable + + +def sample_gt_pixels(gt_file, samples_per_image, seed): + ''' + Sample the ground-truth pixels from a single image. The samples are taken after filtering out void pixels. 
+ + Args: + gt_file: 2D numpy array of ground-truth pixel values + samples_per_image: number of samples to take + seed: random seed for sampling + + Returns: + sampled_indices: sorted 1D numpy array of sampled indices + ''' + gt_file = gt_file.ravel() + non_void = np.nonzero(gt_file != 255)[0] + if non_void.size > samples_per_image: + np.random.seed(seed) + non_void = np.random.choice(non_void, samples_per_image, replace=False) + non_void = np.sort(non_void) + return non_void + + +def encode_indices(indices): + ''' + Encode the sampled 1D indices into a byte array. + - The indices are assumed to be sorted + - The number of indices and the first index will each be stored as a 3-byte unsigned integer (little-endian) + - The differences between consecutive indices will be stored as 8-bit unsigned integers + - If a difference is greater than 255, it will be stored as a zero byte followed by a 3-byte unsigned integer (little-endian) + ''' + # Encode the differences + differences = np.diff(indices) + # logger.debug('encode_indices - differences: first=%d, min=%d, max=%d', + # indices[0], differences.min(), differences.max()) + encoded = bytearray() + encoded.extend(int(indices.size).to_bytes(3, byteorder='little')) + encoded.extend(int(indices[0]).to_bytes(3, byteorder='little')) + for diff in differences: + if diff == 0: + raise ValueError('repeated values in input array `indices` are not allowed') + elif diff <= 255: + encoded.append(diff) + else: + # Escape encoding for differences greater than 255: a zero byte followed by the difference as a 3-byte little-endian integer + encoded.append(0) + encoded.extend(int(diff).to_bytes(3, byteorder='little')) + encoded = _compress(encoded) + return encoded + + +def decode_indices(encoded): + ''' + Decode a byte array into the sampled 1D indices, using the encoding scheme in `encode_indices`. + The number of indices is read from the first 3 bytes of the decompressed data. + ''' + encoded = _decompress(encoded) + length = int.from_bytes(encoded[:3], byteorder='little') + indices = np.empty(length, dtype=np.int32) + indices[0] = index = int.from_bytes(encoded[3:6], byteorder='little') + encoded_len = len(encoded) + i = 1 + e = 6 + while e < encoded_len: + diff = encoded[e] + if diff == 0: + diff = int.from_bytes(encoded[e+1:e+4], byteorder='little') + e += 4 + else: + e += 1 + index += diff + indices[i] = index + i += 1 + if i < length: + raise ValueError('decode_indices - unexpected end of encoded data ' + f'(encoded_len={encoded_len}, i={i}, length={length})') + return indices + + +def sample_all_gt_pixels(gt_tar_path, sample_tar_path, samples_per_image, seed, check=False): + # Acquire data from the tar files... + with closing(tarfile.open(gt_tar_path, "r")) as gt_tar: + logger.info("Listing input files...") + gt_tar_members = gt_tar.getmembers() + + logger.info('sample_gt_pixels - reading ground truth files...') + gt_files = [] + for member in tqqdm(gt_tar_members): + # Skip non-files (directories, etc.)
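For reference, a small round-trip of the index codec defined above (a usage sketch only: the exact bytes returned by `encode_indices` depend on the `_compress` backend, so the comments spell out the layout before compression):

```python
import numpy as np

from bravo_toolkit.util.sample_gt_pixels import decode_indices, encode_indices

indices = np.array([7, 12, 300, 1000])  # sorted, strictly increasing sample positions
payload = encode_indices(indices)       # compressed byte string
assert np.array_equal(decode_indices(payload), indices)

# Byte stream handed to _compress for this example:
#   04 00 00     -> number of indices (3-byte little-endian)
#   07 00 00     -> first index
#   05           -> difference 12 - 7 (fits in a single byte)
#   00 20 01 00  -> zero escape byte, then difference 300 - 12 = 288 as 3 bytes
#   00 bc 02 00  -> zero escape byte, then difference 1000 - 300 = 700 as 3 bytes
```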
+ if not member.isfile(): + continue + + # Determine the split and base path for prediction files + base_split = "" + base_path = "" + for split, gt_suffix in SPLIT_TO_GT_SUFFIX.items(): + if member.name.startswith(SPLIT_PREFIX.format(split=split)): + base_split = split + if member.name.endswith(gt_suffix): + base_path = member.name[:-len(gt_suffix)] + break + if not base_path: + if not base_split or not member.name.endswith(SPLIT_TO_MASK_SUFFIX[base_split]): + logger.error("Unexpected file: `%s`", member.name) + sys.exit(1) + continue + + # Determine the file base and corresponding ground-truth file + gt_name = base_path + SPLIT_TO_GT_SUFFIX[base_split] + gt_member = gt_tar.getmember(gt_name) + gt_file = tar_extract_grayscale(gt_tar, gt_member, 'ground-truth') + gt_files.append((gt_name, gt_file, base_path)) + + if check: + logger.info('sample_gt_pixels - checking %s pixels with seed %d...', f'{samples_per_image:_}', seed) + else: + logger.info('sample_gt_pixels - sampling %s pixels with seed %d...', f'{samples_per_image:_}', seed) + all_nan = all_filled = 0 + with tarfile.open(sample_tar_path, "r" if check else "w") as sample_tar: + for gt_name, gt_file, base_path in tqqdm(gt_files): + assert base_path.startswith('bravo_') + # Sample and compress the indexes + gt_samples = sample_gt_pixels(gt_file, samples_per_image, seed) + # gt_samples = gt_samples.astype(np.int32).tobytes() + gt_encoded = encode_indices(gt_samples) + + encoded_filename = base_path + SAMPLES_SUFFIX + if check: + # Read and decode the samples from the tar file + sample_member = sample_tar.getmember(encoded_filename) + sample_data = tar_extract_file(sample_tar, sample_member) + if gt_file.size > samples_per_image: + assert gt_encoded == sample_data + sample_decoded = decode_indices(sample_data) + assert np.array_equal(gt_samples, sample_decoded) + all_filled += 1 + else: + assert np.isnan(gt_encoded) + assert np.isnan(sample_data) + all_nan += 1 + else: + # Write the samples to the tar file + tarinfo = tarfile.TarInfo(name=encoded_filename) + tarinfo.size = len(gt_encoded) + # tarinfo.uid = current_uid + # tarinfo.gid = current_gid + tarinfo.mtime = time.time() + sample_tar.addfile(tarinfo, io.BytesIO(gt_encoded)) + if check: + logger.info('sample_gt_pixels - %d small images, %d large images, %d total images - all matched', + all_nan, all_filled, all_nan + all_filled) + + +def main(): + parser = argparse.ArgumentParser( + description='Evaluates submissions for the ELSA BRAVO Challenge.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + # parser.add_argument('submission', default=default_submission, help='path to submission tar file') + # parser.add_argument('--gt', default=default_gt, help='path to ground-truth tar file') + parser.add_argument('--gt', help='path to ground-truth tar file') + parser.add_argument('--results', help='tar file to store the samples') + parser.add_argument('--check', help='checks the samples in the tar file') + parser.add_argument('--samples_per_image', type=int, default=SAMPLES_PER_IMG, help='number of samples per image') + parser.add_argument('--seed', type=int, default=1, help='seed for the random sampling') + parser.add_argument('--debug', action='store_true', help='enables extra verbose debug output') + parser.add_argument('--quiet', action='store_true', help='prints only errors and warnings') + args = parser.parse_args() + + level = logging.WARNING if args.quiet else (logging.DEBUG if args.debug else logging.INFO) + logging.basicConfig(level=level, format='%(levelname)s: 
%(message)s') + + if args.gt is None or (args.results is None and args.check is None): + logger.error('--gt is required, and either --results or --check must be specified') + sys.exit(1) + if args.results is not None and args.check is not None: + logger.error('--results and --check are mutually exclusive') + sys.exit(1) + if args.results is not None: + try: + sample_all_gt_pixels(args.gt, args.results, args.samples_per_image, args.seed) + except Exception: + # Remove the partially written tar file before re-raising + if os.path.exists(args.results): + os.remove(args.results) + raise + elif args.check is not None: + sample_all_gt_pixels(args.gt, args.check, args.samples_per_image, args.seed, check=True) + else: + assert False, 'unreachable code' + + logger.info('done.') + + +if __name__ == '__main__': + main() diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..bbbdd05 --- /dev/null +++ b/pylintrc @@ -0,0 +1,592 @@ +# pylint: skip-file +# type: ignore +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +init-hook= + from pylint.config import find_default_config_files + import os, sys + rcfile = next(find_default_config_files(), None) + if rcfile is not None: sys.path.append(os.path.dirname(rcfile)) + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins=pylint.extensions.no_self_use + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W".
+disable=missing-docstring, # Those are general disables + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + raising-format-tuple, # Those are project-specific disables + too-few-public-methods, + unnecessary-pass, + not-callable, # Those are due to bugs in current version of pylint +# import-error, +# no-name-in-module, +# +# print-statement, # Those are potential disables +# parameter-unpacking, +# unpacking-in-except, +# old-raise-syntax, +# backtick, +# long-suffix, +# old-ne-operator, +# old-octal-literal, +# import-star-module-level, +# non-ascii-bytes-literal, +# raw-checker-failed, +# bad-inline-option, +# locally-disabled, +# file-ignored, +# suppressed-message, +# useless-suppression, +# deprecated-pragma, +# use-symbolic-message-instead, +# apply-builtin, +# basestring-builtin, +# buffer-builtin, +# cmp-builtin, +# coerce-builtin, +# execfile-builtin, +# file-builtin, +# long-builtin, +# raw_input-builtin, +# reduce-builtin, +# standarderror-builtin, +# unicode-builtin, +# xrange-builtin, +# coerce-method, +# delslice-method, +# getslice-method, +# setslice-method, +# no-absolute-import, +# old-division, +# dict-iter-method, +# dict-view-method, +# next-method-called, +# metaclass-assignment, +# indexing-exception, +# raising-string, +# reload-builtin, +# oct-method, +# hex-method, +# nonzero-method, +# cmp-method, +# input-builtin, +# round-builtin, +# intern-builtin, +# unichr-builtin, +# map-builtin-not-iterating, +# zip-builtin-not-iterating, +# range-builtin-not-iterating, +# filter-builtin-not-iterating, +# using-cmp-argument, +# eq-without-hash, +# div-method, +# idiv-method, +# rdiv-method, +# exception-message-attribute, +# invalid-str-codec, +# sys-max-int, +# bad-python3-import, +# deprecated-string-function, +# deprecated-str-translate-call, +# deprecated-itertools-function, +# deprecated-types-field, +# next-method-defined, +# dict-items-not-iterating, +# dict-keys-not-iterating, +# dict-values-not-iterating, +# deprecated-operator-function, +# deprecated-urllib-function, +# xreadlines-attribute, +# deprecated-sys-function, +# exception-escape, +# comprehension-escape, + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. 
+score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, `new` is for `{}` formatting,and `fstr` is for f-strings. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members=numpy.*,torch.*, torchvision.*,cv2.*,cv.* + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,pytorch_lightning.trainer.properties.TrainerProperties + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. 
+missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +argument-rgx=[_]?[a-z][A-Z0-9_]* + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +attr-rgx=[a-z_][a-z0-9_]* + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. 
+#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +const-rgx=[_]?[A-Za-z][a-z0-9_]* + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +method-rgx=(_?[a-z][a-z0-9_]*)|(t_[A-Za-z0-9_]+)|(__[a-z0-9_]+__) + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +variable-rgx=[A-Za-z][a-z0-9_]* + + +[STRING] + +# This flag controls whether the implicit-str-concat-in-sequence should +# generate a warning on implicit string concatenation in sequences defined over +# several lines. +check-str-concat-over-line-jumps=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. 
+known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=15 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=15 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=10 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=10 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/requirements-all.txt b/requirements-all.txt new file mode 100644 index 0000000..77e7452 --- /dev/null +++ b/requirements-all.txt @@ -0,0 +1,11 @@ +numpy +opencv-python-headless +scipy +tqdm +zstandard +# For development and testing +scikit-learn +matplotlib +mypy +pillow +pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ba70685 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy==1.26.* +opencv-python-headless==4.9.* +scipy==1.13.* +tqdm==4.66.* +zstandard==0.22.* diff --git a/tests/bravo_test_images.json b/tests/bravo_test_images.json new file mode 100644 index 0000000..a744995 --- /dev/null +++ b/tests/bravo_test_images.json @@ -0,0 +1,28 @@ +{ + "ACDC_snow_test_GP010122_GP010122_frame_000085_pred.png": "ACDC", + "ACDC_fog_test_GOPR0478_GOPR0478_frame_000451_pred.png": "ACDC", + "ACDC_night_test_GOPR0594_GOPR0594_frame_000715_pred.png": "ACDC", + "ACDC_rain_test_GOPR0572_GOPR0572_frame_000692_pred.png": "ACDC", + "SMIYC_RoadAnomaly21_images_airplane0001_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_boat_trailer0004_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_carriage0001_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_cow0013_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_tent0000_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_validation0007_pred.png": "SMIYC", + "SMIYC_RoadAnomaly21_images_zebra0000_pred.png": "SMIYC", + "outofcontext_munster_munster_000026_000019_pred.png": "outofcontext", + "outofcontext_frankfurt_frankfurt_000001_048654_pred.png": "outofcontext", + "outofcontext_lindau_lindau_000024_000019_pred.png": "outofcontext", + "synflare_munster_munster_000111_000019_pred.png": "synflare", + "synflare_frankfurt_frankfurt_000001_014565_pred.png": "synflare", + "synflare_lindau_lindau_000015_000019_pred.png": "synflare", + "synobjs_elephant_127_pred.png": "synobjs", + 
"synobjs_toilet_487_pred.png": "synobjs", + "synobjs_tiger_437_pred.png": "synobjs", + "synobjs_flamingo_9_pred.png": "synobjs", + "synobjs_sofa_299_pred.png": "synobjs", + "synobjs_billboard_212_pred.png": "synobjs", + "synrain_munster_munster_000018_000019_pred.png": "synrain", + "synrain_frankfurt_frankfurt_000001_046272_pred.png": "synrain", + "synrain_lindau_lindau_000023_000019_pred.png": "synrain" +} \ No newline at end of file diff --git a/tests/extract_test_files.py b/tests/extract_test_files.py new file mode 100644 index 0000000..9d55e6a --- /dev/null +++ b/tests/extract_test_files.py @@ -0,0 +1,118 @@ +import os +import json +import tarfile + + +# Configuration variables +script_dir = os.path.dirname(__file__) +submission_dir = os.path.join(script_dir, 'bravo_test_images') +gt_dir = os.path.join(script_dir, 'bravo_test_images') +metadata_file = os.path.join(script_dir, 'bravo_test_images.json') + +submission_tarfile = os.path.expanduser('~/shared/thvu/BRAVO/challenge/toolkit/submissions_dgssclip/bravo_submission.tar') +gt_tarfile = os.path.expanduser('~/shared/thvu/BRAVO/challenge/toolkit/bravo_GT.tar') + +# Suffix and prefix mappings +SPLIT_TO_GT_SUFFIX = { + 'ACDC': '_gt_labelTrainIds.png', + 'SMIYC': '_labels_semantic_fake.png', + 'outofcontext': '_gt_labelTrainIds.png', + 'synobjs': '_gt.png', + 'synflare': '_gt_labelTrainIds.png', + 'synrain': '_gt_labelTrainIds.png', +} +SPLIT_TO_MASK_SUFFIX = { + 'ACDC': '_gt_invIds.png', + 'SMIYC': '_labels_semantic.png', + 'outofcontext': '_gt_invIds.png', + 'synobjs': '_mask.png', + 'synflare': '_gt_invIds.png', + 'synrain': '_gt_invIds.png', +} +SPLIT_TO_PRED_SUFFIX = { + 'ACDC': '_rgb_anon_pred.png', + 'SMIYC': '_pred.png', + 'outofcontext': '_leftImg8bit_pred.png', + 'synobjs': '_pred.png', + 'synflare': '_leftImg8bit_pred.png', + 'synrain': '_leftImg8bit_pred.png', +} +SPLIT_PREFIX = 'bravo_{split}/' + + +BRAVO_CODEC_TEST_IMAGES = [ + 'bravo_ACDC/snow/test/GP010122/GP010122_frame_000085_rgb_anon_pred.png', + 'bravo_ACDC/fog/test/GOPR0478/GOPR0478_frame_000451_rgb_anon_pred.png', + 'bravo_ACDC/night/test/GOPR0594/GOPR0594_frame_000715_rgb_anon_pred.png', + 'bravo_ACDC/rain/test/GOPR0572/GOPR0572_frame_000692_rgb_anon_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/airplane0001_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/boat_trailer0004_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/carriage0001_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/cow0013_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/tent0000_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/validation0007_pred.png', + 'bravo_SMIYC/RoadAnomaly21/images/zebra0000_pred.png', + 'bravo_outofcontext/munster/munster_000026_000019_leftImg8bit_pred.png', + 'bravo_outofcontext/frankfurt/frankfurt_000001_048654_leftImg8bit_pred.png', + 'bravo_outofcontext/lindau/lindau_000024_000019_leftImg8bit_pred.png', + 'bravo_synflare/munster/munster_000111_000019_leftImg8bit_pred.png', + 'bravo_synflare/frankfurt/frankfurt_000001_014565_leftImg8bit_pred.png', + 'bravo_synflare/lindau/lindau_000015_000019_leftImg8bit_pred.png', + 'bravo_synobjs/elephant/127_pred.png', + 'bravo_synobjs/toilet/487_pred.png', + 'bravo_synobjs/tiger/437_pred.png', + 'bravo_synobjs/flamingo/9_pred.png', + 'bravo_synobjs/sofa/299_pred.png', + 'bravo_synobjs/billboard/212_pred.png', + 'bravo_synrain/munster/munster_000018_000019_leftImg8bit_pred.png', + 'bravo_synrain/frankfurt/frankfurt_000001_046272_leftImg8bit_pred.png', + 'bravo_synrain/lindau/lindau_000023_000019_leftImg8bit_pred.png', +] + + +def 
extract_file(tar, path_in_tar, target_path): + """ Helper function to extract a file from a tar file """ + member = tar.getmember(path_in_tar) + f = tar.extractfile(member) + contents = f.read() + f.close() + with open(target_path, 'wb') as f: + f.write(contents) + print(f" {member.name} -> {target_path}") + + +print("Extracting files...") +metadata = {} +with tarfile.open(submission_tarfile, 'r') as sub_tar, tarfile.open(gt_tarfile, 'r') as gt_tar: + for pred_path in BRAVO_CODEC_TEST_IMAGES: + print(f"Processing {pred_path}...") + for s in SPLIT_TO_GT_SUFFIX: + if pred_path.startswith(SPLIT_PREFIX.format(split=s)): + split = s + break + else: + raise ValueError(f"Could not determine split for file {pred_path}") + + conf_path = pred_path.replace('_pred.png', '_conf.png') + + # Extract prediction and confidence files + pred_target_file = (pred_path[6:-len(SPLIT_TO_PRED_SUFFIX[split])] + '_pred.png').replace('/', '_') + conf_target_file = pred_target_file.replace('_pred.png', '_conf.png') + extract_file(sub_tar, pred_path, os.path.join(submission_dir, pred_target_file)) + extract_file(sub_tar, conf_path, os.path.join(submission_dir, conf_target_file)) + + # Extracts ground-truth files + gt_file = pred_path[:-len(SPLIT_TO_PRED_SUFFIX[split])] + SPLIT_TO_GT_SUFFIX[split] + mask_file = pred_path[:-len(SPLIT_TO_PRED_SUFFIX[split])] + SPLIT_TO_MASK_SUFFIX[split] + gt_target_file = pred_target_file.replace('_pred.png', '_gt.png') + mask_target_file = pred_target_file.replace('_pred.png', '_mask.png') + extract_file(gt_tar, gt_file, os.path.join(gt_dir, gt_target_file)) + extract_file(gt_tar, mask_file, os.path.join(gt_dir, mask_target_file)) + + metadata[pred_target_file] = split + +print("Writing metadata file...") +with open(metadata_file, 'wt', encoding='utf-8') as meta: + json.dump(metadata, meta, indent=4) + +print("Done.") diff --git a/tests/test_bravo_codec.py b/tests/test_bravo_codec.py new file mode 100644 index 0000000..e8d63b6 --- /dev/null +++ b/tests/test_bravo_codec.py @@ -0,0 +1,173 @@ +import json +import os + +import numpy as np +import pytest + +from bravo_toolkit.codec.bravo_codec import bravo_decode, bravo_encode +from bravo_toolkit.codec.bravo_tarfile import extract_grayscale, extract_image +from bravo_toolkit.util.sample_gt_pixels import SAMPLES_PER_IMG, sample_gt_pixels + + +# --------- Utilities --------- + +def bravo_simulation_test(*, seed=42, array_shape=(1000, 2000), n_classes=19, n_regions=50, n_indices=SAMPLES_PER_IMG, + void_chance=0.2, void_class=255, input_images=None): + np.random.seed(seed) + + if input_images is None: + # Creates a random but "realistic" class array with a Voronoi tessellation + n_rows, n_cols = array_shape + seeds = np.column_stack([ + np.random.randint(0, n_cols, n_regions), + np.random.randint(0, n_rows, n_regions) + ]) + classes = np.random.randint(0, n_classes, n_regions) + classes = np.where(np.random.rand(n_regions) < void_chance, void_class, classes) + rows = np.arange(n_rows) + cols = np.arange(n_cols) + # ...computes the distances of coordinates to each seed and finds the closest one + row_distances = (rows[:, None] - seeds[:, 0])**2 + col_distances = (cols[:, None] - seeds[:, 1])**2 + distances = row_distances[:, None, :] + col_distances[None, :, :] # Squared Euclidean distance + voronoi = np.argmin(distances, axis=2) + # ...assigns the class to each region + class_array = classes[voronoi].astype(np.uint8) + + # Generate a somewhat "realistic" confidence array: random but smooth + confidences = np.random.rand(n_regions) * (1. 
- 1./n_classes) + 1./n_classes + confidence_array = confidences[voronoi] + confidence_array += np.random.normal(0, 0.02, size=confidence_array.shape) + confidence_array = np.clip(confidence_array, 1./n_classes, 1.) + confidence_array = confidence_array.astype(np.float32) + confidence_indices = sample_gt_pixels(confidence_array, n_indices, seed=seed) + else: + class_array, confidence_array, confidence_indices = input_images + confidence_array = confidence_array.astype(np.float32) + confidence_array = np.clip(confidence_array, 1./n_classes, 1.) + + # Quantize as min(floor(conf*65536), 65535) to avoid uint16 overflow when conf == 1.0 + confidence_array = np.minimum(np.floor(confidence_array * 65536), 65535).astype(np.uint16) + + # Encode the arrays + encoded_bytes = bravo_encode(class_array, confidence_array, confidence_indices=confidence_indices) + confidence_slice = slice(None) if confidence_indices is None else confidence_indices + confidence_sample = confidence_array.ravel()[confidence_slice] + + # The computations below have to be reverified if the data types change + assert confidence_array.dtype == np.uint16 and confidence_sample.dtype == np.uint16 + original_size = class_array.nbytes + confidence_array.nbytes + raw_size = class_array.nbytes + confidence_array.nbytes + sampled_size = class_array.nbytes + confidence_sample.nbytes + encoded_size = len(encoded_bytes) + + results = { + "original size": original_size, + "raw size": raw_size, + "sampled size": sampled_size, + "encoded size": encoded_size, + "original/encoded ratio": original_size / encoded_size, + "raw/encoded ratio": raw_size / encoded_size, + "sampled/encoded ratio": sampled_size / encoded_size, + } + + # Decode the arrays + decoded_class_array, decoded_confidence_array, _ = bravo_decode(encoded_bytes) + + # Verify that the decoded class array matches the original + assert np.all(decoded_class_array == class_array), "Class arrays do not match" + + confidence_sample = confidence_sample / 65536.0 + decoded_confidence_array = decoded_confidence_array / 65536.0 + + # Verify that the decoded confidence array is close to the original within the quantization tolerance + tolerance = 1 / 65536.0 + results["tolerance"] = tolerance + results["max_abs_diff"] = np.max(np.abs(decoded_confidence_array - confidence_sample)) + results["match_within_tolerance"] = np.allclose(decoded_confidence_array, confidence_sample, atol=tolerance) + + assert results["match_within_tolerance"], f"codec failed within tolerance of {results['tolerance']}: " \ + f"max.
diff: {results['max_abs_diff']}" + return results + +# --------- Utilities --------- + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +IMAGES_DIR = os.path.join(SCRIPT_DIR, "bravo_test_images") +IMAGES_METADATA_FILE = os.path.join(SCRIPT_DIR, 'bravo_test_images.json') +with open(IMAGES_METADATA_FILE, 'rt', encoding='utf-8') as imff: + IMAGES_METADATA = json.load(imff) + IMAGES_LIST = list(IMAGES_METADATA.keys()) + + +def get_real_data(pred_file): + '''Gets the ground-truth labels and scores from a real data case.''' + conf_file = pred_file.replace('_pred.png', '_conf.png') + + # Loads ground-truth and confidence images + with open(os.path.join(IMAGES_DIR, pred_file), 'rb') as f: + pred = extract_grayscale(f, 'class prediction') + with open(os.path.join(IMAGES_DIR, conf_file), 'rb') as f: + conf = extract_image(f, 'confidence') + + # Compare toolkit auc with reference auc + return pred, conf + + +# --------- BRAVO_CODEC --------- + +def test_bravo_codec_default_test(): + results = bravo_simulation_test() + assert results["original/encoded ratio"] > 10 + assert results["sampled/encoded ratio"] > 5 + + +bravo_codec_test_cases = [ + # seed, array_shape, n_classes, n_regions, sample_size, + # Default test case + (42, (1000, 2000), 19, 50, 100_000), + # Different seeds and sizes + (43, (256, 128), 19, 50, 100_000), + (44, (1001, 64), 19, 50, 100_000), + (45, (317, 2030), 19, 50, 100_000), + (46, (31, 510), 19, 50, 100_000), + # Different number of classes, regions, and sample sizes + (47, (1000, 2000), 13, 50, 100_000), + (48, (1000, 2000), 167, 50, 100_000), + (49, (1000, 2000), 19, 17, 100_000), + (50, (1000, 2000), 19, 99, 100_000), + (51, (1000, 2000), 19, 99, 10_000), + (52, (1000, 2000), 19, 99, 1000_000), + # Extreme cases + (53, (1, 1), 19, 50, 100_000), + (54, (1, 1024), 19, 50, 100_000), + (55, (1024, 1), 19, 50, 100_000), + (56, (2, 2), 19, 50, 100_000), + (57, (1024, 1024), 2, 50, 100_000), + (58, (1024, 1024), 19, 1, 100_000), + (59, (1024, 1024), 19, 500, 100_000), +] + + +@pytest.mark.parametrize('seed, array_shape, n_classes, n_regions, sample_size', bravo_codec_test_cases) +def test_bravo_codec_test_cases(seed, array_shape, n_classes, n_regions, sample_size): + results = bravo_simulation_test(seed=seed, array_shape=array_shape, n_classes=n_classes, n_regions=n_regions, + n_indices=sample_size) + if array_shape[0] * array_shape[1] > 100: + assert results["original/encoded ratio"] > array_shape[0] * array_shape[1] / sample_size / 2 + assert results["sampled/encoded ratio"] > 1 + + +def test_bravo_codec_large_array(): + results = bravo_simulation_test(array_shape=(4096, 4097), n_classes=19, n_regions=50) + assert results["original/encoded ratio"] > 10 + assert results["sampled/encoded ratio"] > 5 + + +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_bravo_codec_true_images(pred_file): + pred_image, conf_image = get_real_data(pred_file) + confidence_indices = sample_gt_pixels(conf_image, SAMPLES_PER_IMG, seed=1) + results = bravo_simulation_test(input_images=(pred_image, conf_image, confidence_indices)) + assert results["original/encoded ratio"] > 10 + assert results["sampled/encoded ratio"] > 5 diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..b56069f --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,604 @@ +import json +import os + +import numpy as np +from numpy.testing import assert_almost_equal, assert_array_almost_equal, assert_equal +import pytest +from sklearn.metrics import average_precision_score, 
roc_auc_score, roc_curve + +from bravo_toolkit.codec.bravo_tarfile import extract_grayscale, extract_image +from bravo_toolkit.eval.metrics import (_get_ece_original, _get_ece_reference, get_auprc, get_auroc, get_ece, + get_tp_fp_counts) + + +# --------- Utilities --------- + +def get_random_data(seed, allow_duplicates=False): + '''Simulates array of ground-truth labels and scores.''' + # Samples simulated data parameters + np.random.seed(seed) + while True: + sample_size = np.random.randint(2, 100) + alpha = np.random.randint(1, 10) + beta = np.random.randint(1, 10) + # Samples ground truth labels, and ensures that at least one positive and one negative example are present + y_true = np.random.randint(0, 2, size=sample_size) + y_true[np.random.randint(sample_size//2)] = 0 + y_true[np.random.randint(sample_size//2) + sample_size//2] = 1 + # Samples scores + y_alphas = np.where(y_true == 1, alpha, beta) + y_betas = np.where(y_true == 1, beta, alpha) + y_score = np.random.beta(y_alphas, y_betas) + if not allow_duplicates: + # Eliminates entries with duplicate scores + y_score, indices = np.unique(y_score, return_index=True) + y_true = y_true[indices] + if y_score.size >= 2: + break + return y_true, y_score + + +def get_raw_counts(y_true, y_score, deduplicate=False): + '''Gets the raw true and false positive counts from ground-truth labels and scores.''' + # Sorts data by score + sorted_indices = np.argsort(y_score) + y_true = y_true[sorted_indices] + y_score = y_score[sorted_indices] + # Computes true and false positive counts + tp_counts = y_true + fp_counts = 1 - y_true + if deduplicate: + # Merge counts for equal scores + _y_unique_scores, unique_indices = np.unique(y_score, return_index=True) + tp_counts = np.add.reduceat(tp_counts, unique_indices) + fp_counts = np.add.reduceat(fp_counts, unique_indices) + return tp_counts, fp_counts + + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +IMAGES_DIR = os.path.join(SCRIPT_DIR, "bravo_test_images") +IMAGES_METADATA_FILE = os.path.join(SCRIPT_DIR, 'bravo_test_images.json') +with open(IMAGES_METADATA_FILE, 'rt', encoding='utf-8') as imff: + IMAGES_METADATA = json.load(imff) + IMAGES_LIST = list(IMAGES_METADATA.keys()) + + +def get_real_data_raw(pred_file): + '''Gets the ground-truth labels, class predictions, and confidence scores from a real data case.''' + conf_file = pred_file.replace('_pred.png', '_conf.png') + gt_file = pred_file.replace('_pred.png', '_gt.png') + + # Loads ground-truth and confidence images + with open(os.path.join(IMAGES_DIR, pred_file), 'rb') as f: + pred = extract_grayscale(f, 'class prediction') + with open(os.path.join(IMAGES_DIR, conf_file), 'rb') as f: + conf = extract_image(f, 'confidence') + with open(os.path.join(IMAGES_DIR, gt_file), 'rb') as f: + gt = extract_grayscale(f, 'gt') + + return gt, pred, conf + + +def get_real_data(pred_file): + '''Gets the binarized labels (hit/miss) and confidence scores from a real data case.''' + gt, pred, conf = get_real_data_raw(pred_file) + return gt.ravel() == pred.ravel(), conf.ravel() + + +BIG = 1000000000 # Coefficient for big numbers + +# --------- AUROC --------- + +auroc_test_cases = [ + # tp_counts, fp_counts, expected_auc + # tp_counts => exact (not cummulative) counts of true positives for each increasing confidence level + # fp_counts => exact (not cummulative) counts of false positives for each increasing confidence level + # Perfectly wrong predictions + (1, np.array([100, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 100]), 0.0), + (2, np.array([13, 1, 45, 57, 
0]), np.array([0, 0, 0, 0, 100]), 0.0), + (3, np.array([13, 1, 45, 1, 57, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 0, 30, 2, 4, 16]), 0.0), + # Mostly wrong predictions + (4, np.array([20, 20, 20, 20, 20]), np.array([0, 0, 0, 0, 100]), 0.1), # 20/100/2 (/2 due to trapezoidal rule) + (5, np.array([13, 1, 45, 1, 57]), np.array([0, 0, 0, 0, 100]), 0.24358974358974358), # 57/sum(tp)/2 + (6, np.array([100, 0, 0, 0, 0]), np.array([20, 20, 20, 20, 20]), 0.1), # 1/5/2 + # Perfectly correct predictions + (7, np.array([0, 0, 0, 0, 100]), np.array([100, 0, 0, 0, 0]), 1.0), + (8, np.array([0, 0, 0, 0, 0, 30, 2, 4, 16]), np.array([13, 1, 45, 1, 57, 0, 0, 0, 0]), 1.0), + (9, np.array([0, 0, 0, 0, 100]), np.array([20, 20, 20, 20, 0]), 1.0), + (10, np.array([0, 0, 0, 0, 100]), np.array([13, 1, 45, 57, 0]), 1.0), + # Mostly correct predictions + (11, np.array([0, 0, 0, 0, 100]), np.array([20, 20, 20, 20, 20]), 0.9), # 1 - 1/5/2 + (12, np.array([0, 0, 0, 0, 100]), np.array([13, 1, 45, 1, 57]), 0.7564102564102564), # 1 - 57/sum(tp)/2 + # Corner cases + (13, np.array([100, 0, 0, 0, 0]), np.array([100, 0, 0, 0, 0]), 0.5), # Everything happens at the last step + (14, np.array([0, 0, 0, 0, 100]), np.array([0, 0, 0, 0, 100]), 0.5), # Everything happens at the first step + # Perfect ignorance + (15, np.array([517, 517, 517, 517, 517]), np.array([517, 517, 517, 517, 517]), 0.5), + # Basic test case + (16, np.array([40, 10, 30, 20]), np.array([15, 25, 5, 35]), 0.3875), + # Basic test case - different shapes + (17, np.array([[40, 10, 30, 20]]), np.array([[15, 25, 5, 35]]), 0.3875), + (18, np.array([[40, 10], [30, 20]]), np.array([[15, 25], [5, 35]]), 0.3875), + # Basic test case - heterogeneos shapes + (19, np.array([[40, 10], [30, 20]]), np.array([15, 25, 5, 35]), 0.3875), + # Basic test case - large numbers + (20, np.array([40*BIG, 10*BIG, 30*BIG, 20*BIG]), np.array([15*BIG, 25*BIG, 5*BIG, 35*BIG]), 0.3875), + ] + + +@pytest.mark.parametrize('_n, tp, fp, expected_auc', auroc_test_cases) +def test_metrics_get_auroc_basic(_n, tp, fp, expected_auc): + auc, tpr, fpr = get_auroc(tp, fp) + expected_tpr = np.cumsum(tp.ravel()[::-1]) / np.sum(tp) + expected_fpr = np.cumsum(fp.ravel()[::-1]) / np.sum(fp) + expected_tpr = np.concatenate(([0.], expected_tpr)) + expected_fpr = np.concatenate(([0.], expected_fpr)) + assert_almost_equal(auc, expected_auc, decimal=5) + assert_array_almost_equal(tpr, expected_tpr) + assert_array_almost_equal(fpr, expected_fpr) + + +@pytest.mark.parametrize('seed', range(100)) +def test_metrics_get_auroc_reference(seed): + y_true, y_score = get_random_data(seed) + # Computes reference AUC + expected_auc = roc_auc_score(y_true, y_score) + # Computes toolkit AUC + tp_counts, fp_counts = get_raw_counts(y_true, y_score) + auc, tpr, fpr = get_auroc(tp_counts, fp_counts) + assert_almost_equal(auc, expected_auc, decimal=5) + assert_equal(tpr[0], 0.) + assert_equal(fpr[0], 0.) + assert_equal(tpr[-1], 1.) + assert_equal(fpr[-1], 1.) 
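The expected constants in the test-case tables (0.3875 for the basic AUROC cases above, and the `AUPRC` constant used further below) can be reproduced by hand from the per-level counts: accumulate the counts from the highest confidence level downwards, normalize to rates, then integrate with the trapezoidal rule for AUROC and with the step-function (average-precision) rule for AUPRC. A standalone sketch of that hand computation, independent of the toolkit's `get_auroc`/`get_auprc` implementations:

```python
import numpy as np

# Per-level counts from the basic test case (case 16 above)
tp = np.array([40, 10, 30, 20])   # true positives per increasing confidence level
fp = np.array([15, 25, 5, 35])    # false positives per increasing confidence level

# Sweep the decision threshold from the highest level down: cumulative counts
tp_cum = np.cumsum(tp[::-1])      # [20, 50, 60, 100]
fp_cum = np.cumsum(fp[::-1])      # [35, 40, 65, 80]

# ROC curve with the (0, 0) point prepended, integrated with the trapezoidal rule
tpr = np.concatenate(([0.], tp_cum / tp_cum[-1]))
fpr = np.concatenate(([0.], fp_cum / fp_cum[-1]))
auroc = np.trapz(tpr, fpr)        # 0.3875

# Precision-recall curve, integrated as a step function (average precision)
precision = tp_cum / (tp_cum + fp_cum)
recall = tp_cum / tp_cum[-1]
auprc = np.sum(np.diff(np.concatenate(([0.], recall))) * precision)  # 0.5096161616...
```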
+ assert_equal(tpr.size, fpr.size) + assert_equal(tpr, np.sort(tpr)) + assert_equal(fpr, np.sort(fpr)) + + +def test_metrics_get_auroc_empty_arrays(): + tp = np.array([]) + fp = np.array([]) + with pytest.raises(ValueError): + get_auroc(tp, fp) + + +def test_metrics_get_auroc_single_element(): + tp = np.array([1]) + fp = np.array([1]) + with pytest.raises(ValueError): + get_auroc(tp, fp) + + +def test_metrics_get_auroc_mismatched_lengths(): + np.random.seed(0) + tp = np.random.randint(1, 1000, size=100) + fp = np.random.randint(1, 1000, size=99) + with pytest.raises(ValueError): + get_auroc(tp, fp) + + +def test_metrics_get_auroc_zero_arrays(): + tp = np.zeros(17) + fp = np.zeros(17) + with pytest.raises(ValueError): + get_auroc(tp, fp) + + +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_auroc_real_data(pred_file): + y_true, y_score = get_real_data(pred_file) + # Computes reference AUC + auc_ref = roc_auc_score(y_true, y_score) + # Gets toolkit AUC with reference counts + tp_counts, fp_counts = get_raw_counts(y_true, y_score, deduplicate=True) + auc, _, _ = get_auroc(tp_counts, fp_counts) + # Gets toolkit AUC with toolkit counts + tp_counts2, fp_counts2 = get_tp_fp_counts(y_true, y_score, score_levels=65536) + auc2, _, _ = get_auroc(tp_counts2, fp_counts2) + # Compares values + assert_almost_equal(auc, auc_ref, decimal=5) + assert_almost_equal(auc, auc2, decimal=6) + + +def test_metrics_get_auroc_aggregated(): + '''Check that computing the metric pixel-wise over many images and aggregating the counts gives the same result.''' + score_levels = 65536 + tp_counts = np.zeros(score_levels) + fp_counts = np.zeros(score_levels) + y_trues = [] + y_scores = [] + for pred_file in IMAGES_LIST: + y_true, y_score = get_real_data(pred_file) + get_tp_fp_counts(y_true, y_score, tp_counts, fp_counts, score_levels=score_levels) + y_trues.append(y_true) + y_scores.append(y_score) + # Computes reference AUC using all data + y_trues = np.concatenate(y_trues) + y_scores = np.concatenate(y_scores) + auc_ref = roc_auc_score(y_trues, y_scores) + # Gets toolkit AUC with aggregated reference counts + auc, _, _ = get_auroc(tp_counts, fp_counts) + # Compares values + assert_almost_equal(auc, auc_ref, decimal=5) + + +@pytest.mark.skip(reason="this test currently fails, because aggressively subsampling individual images is not stable") +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_auroc_subsamples(pred_file): + sample_conf = 8 # This picks only 1/(sample_conf*sample_conf) of the data + sample_offsets = (0, sample_conf//2) + + y_true, y_score = get_real_data(pred_file) + aucs = [] + for sample_offset in sample_offsets: + # Subsamples data + y_true = y_true[sample_offset::sample_conf] + y_score = y_score[sample_offset::sample_conf] + # Gets toolkit AUC with toolkit counts + tp_counts, fp_counts = get_tp_fp_counts(y_true, y_score, score_levels=65536) + auc, _, _ = get_auroc(tp_counts, fp_counts) + aucs.append(auc) + # Compares values + assert_almost_equal(aucs[0], aucs[1], decimal=3) + + +def test_metrics_get_auroc_subsamples_aggregated(): + sample_conf = 8 + sample_offsets = (0, sample_conf//2) + score_levels = 65536 + + tp_counts_aggregated = [np.zeros(score_levels) for _ in sample_offsets] + fp_counts_aggregated = [np.zeros(score_levels) for _ in sample_offsets] + aucs = [] + for pred_file in IMAGES_LIST: + y_true, y_score = get_real_data(pred_file) + # Aggregation for AUROC + for s, sample_offset in enumerate(sample_offsets): + # Subsamples for AUROC + 
y_true_subsample = y_true[sample_offset::sample_conf]
+            y_score_subsample = y_score[sample_offset::sample_conf]
+            tp_counts_subsample, fp_counts_subsample = get_tp_fp_counts(y_true_subsample, y_score_subsample,
+                                                                        score_levels=score_levels)
+            tp_counts_aggregated[s] += tp_counts_subsample
+            fp_counts_aggregated[s] += fp_counts_subsample
+
+    # Aggregated AUROC Assertions
+    aucs = []
+    for s, _ in enumerate(sample_offsets):
+        auc, _, _ = get_auroc(tp_counts_aggregated[s], fp_counts_aggregated[s])
+        aucs.append(auc)
+
+    # Subsampled AUROC Assertions
+    assert_almost_equal(aucs[0], aucs[1], decimal=2)
+
+
+# --------- FPR@95 ---------
+
+@pytest.mark.parametrize('pred_file', IMAGES_LIST)
+def test_metrics_get_auroc_fpr_at_tpr_th(pred_file, tpr_th=0.95):
+    y_true, y_score = get_real_data(pred_file)
+    # Computes reference AUC
+    fpr_ref, tpr_ref, _ = roc_curve(y_true, y_score)
+    fpr_at_th_ref = fpr_ref[np.argmax(tpr_ref >= tpr_th)]
+    # Gets toolkit AUC with reference counts
+    tp_counts, fp_counts = get_raw_counts(y_true, y_score, deduplicate=True)
+    _auc, tpr, fpr = get_auroc(tp_counts, fp_counts)
+    tpr_th_i = np.searchsorted(tpr, tpr_th, 'left')
+    fpr_at_th = fpr[tpr_th_i]
+    # Compares values
+    assert_almost_equal(fpr_at_th, fpr_at_th_ref, decimal=5)
+
+
+# --------- AUPRC ---------
+
+AUPRC = 0.5096161616161616
+
+
+auprc_test_cases = [
+    # tp_counts, fp_counts, expected_auprc
+    # tp_counts => exact (not cumulative) counts of true positives for each increasing confidence level
+    # fp_counts => exact (not cumulative) counts of false positives for each increasing confidence level
+    # Mostly wrong predictions
+    (1, np.array([10, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 10]), 0.5),  # AUPRC uses step-function integral
+    (2, np.array([40, 0, 0, 0, 0]), np.array([0, 10, 10, 10, 10]), 0.5),
+    # Perfectly correct predictions
+    (3, np.array([0, 0, 0, 0, 100]), np.array([100, 0, 0, 0, 0]), 1.0),
+    (4, np.array([0, 0, 0, 0, 0, 30, 2, 4, 16]), np.array([13, 1, 45, 1, 57, 0, 0, 0, 0]), 1.0),
+    (5, np.array([0, 0, 0, 0, 100]), np.array([20, 20, 20, 20, 0]), 1.0),
+    (6, np.array([0, 0, 0, 0, 100]), np.array([13, 1, 45, 57, 0]), 1.0),
+    # Mostly correct predictions
+    (7, np.array([0, 0, 0, 40, 0]), np.array([30, 0, 0, 0, 10]), 0.8),
+    # Corner cases
+    (8, np.array([100, 0, 0, 0, 0]), np.array([100, 0, 0, 0, 0]), 0.5),  # Everything happens at the last step
+    (9, np.array([0, 0, 0, 0, 100]), np.array([0, 0, 0, 0, 100]), 0.5),  # Everything happens at the first step
+    (10, np.array([0, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 100]), 0.),  # No true positives
+    (11, np.array([0, 0, 0, 0, 100]), np.array([0, 0, 0, 0, 0]), 1.),  # No false positives
+    (12, np.array([0, 0, 0, 0, 0]), np.array([0, 0, 0, 0, 0]), 0.),  # No samples
+    # Perfect ignorance
+    (13, np.array([517, 517, 517, 517, 517]), np.array([517, 517, 517, 517, 517]), 0.5),
+    # Basic test case
+    (14, np.array([40, 10, 30, 20]), np.array([15, 25, 5, 35]), AUPRC),
+    # Basic test case - different shapes
+    (15, np.array([[40, 10, 30, 20]]), np.array([[15, 25, 5, 35]]), AUPRC),
+    (16, np.array([[40, 10], [30, 20]]), np.array([[15, 25], [5, 35]]), AUPRC),
+    # Basic test case - heterogeneous shapes
+    (17, np.array([[40, 10], [30, 20]]), np.array([15, 25, 5, 35]), AUPRC),
+    # Basic test case - large numbers
+    (18, np.array([40*BIG, 10*BIG, 30*BIG, 20*BIG]), np.array([15*BIG, 25*BIG, 5*BIG, 35*BIG]), AUPRC),
+    ]
+
+
+@pytest.mark.filterwarnings("ignore:.*divide:RuntimeWarning")
+@pytest.mark.parametrize('_n, tp, fp, expected_auprc', auprc_test_cases)
+def 
test_metrics_get_auprc_basic(_n, tp, fp, expected_auprc): + auprc, precision, recall = get_auprc(tp, fp) + tp_counts = np.cumsum(tp.ravel()[::-1]) + fp_counts = np.cumsum(fp.ravel()[::-1]) + pp_counts = tp_counts + fp_counts + expected_precision = np.where(pp_counts == 0, 1., tp_counts / pp_counts) + expected_recall = (tp_counts / tp_counts[-1]) if tp_counts[-1] > 0 else tp_counts + expected_precision = np.concatenate(([1.], expected_precision)) + expected_recall = np.concatenate(([0.], expected_recall)) + assert_almost_equal(auprc, expected_auprc, decimal=5) + assert_array_almost_equal(precision, expected_precision) + assert_array_almost_equal(recall, expected_recall) + + +@pytest.mark.parametrize('seed', range(100)) +def test_metrics_get_auprc_reference(seed): + y_true, y_score = get_random_data(seed) + # Computes reference auprc + expected_auprc = average_precision_score(y_true, y_score) + # Computes toolkit auprc + tp_counts, fp_counts = get_raw_counts(y_true, y_score) + auprc, precision, recall = get_auprc(tp_counts, fp_counts) + assert_almost_equal(auprc, expected_auprc, decimal=5) + assert_equal(precision[0], 1.) + assert_equal(recall[0], 0.) + assert_equal(recall[-1], 1.) + assert_equal(precision.size, recall.size) + assert_equal(recall, np.sort(recall)) + + +def test_metrics_get_auprc_empty_arrays(): + tp = np.array([]) + fp = np.array([]) + with pytest.raises(ValueError): + get_auprc(tp, fp) + + +def test_metrics_get_auprc_single_element(): + tp = np.array([1]) + fp = np.array([1]) + with pytest.raises(ValueError): + get_auprc(tp, fp) + + +def test_metrics_get_auprc_mismatched_lengths(): + np.random.seed(0) + tp = np.random.randint(1, 1000, size=100) + fp = np.random.randint(1, 1000, size=99) + with pytest.raises(ValueError): + get_auprc(tp, fp) + + +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_auprc_real_data(pred_file): + y_true, y_score = get_real_data(pred_file) + # Computes reference auprc + auprc_ref = average_precision_score(y_true, y_score) + # Gets toolkit auprc + tp_counts, fp_counts = get_raw_counts(y_true, y_score, deduplicate=True) + auprc, _, _ = get_auprc(tp_counts, fp_counts) + # Compares values + assert_almost_equal(auprc, auprc_ref, decimal=5) + + +def test_metrics_get_auprc_aggregated(): + '''Check that computing the metric pixel-wise over many images and aggregating the counts gives the same result.''' + score_levels = 65536 + tp_counts = np.zeros(score_levels) + fp_counts = np.zeros(score_levels) + y_trues = [] + y_scores = [] + for pred_file in IMAGES_LIST: + y_true, y_score = get_real_data(pred_file) + get_tp_fp_counts(y_true, y_score, tp_counts, fp_counts, score_levels=score_levels) + y_trues.append(y_true) + y_scores.append(y_score) + # Computes reference auprc using all data + y_trues = np.concatenate(y_trues) + y_scores = np.concatenate(y_scores) + auprc_ref = average_precision_score(y_trues, y_scores) + # Gets toolkit auprc with aggregated reference counts + auprc, _, _ = get_auprc(tp_counts, fp_counts) + # Compares values + assert_almost_equal(auprc, auprc_ref, decimal=5) + + +@pytest.mark.skip(reason="this test currently fails, because aggressively subsampling individual images is not stable") +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_auprc_subsamples(pred_file): + sample_conf = 8 # This picks only 1/(sample_conf*sample_conf) of the data + sample_offsets = (0, sample_conf//2) + + y_true, y_score = get_real_data(pred_file) + auprcs = [] + for sample_offset in sample_offsets: + # Subsamples 
data + y_true = y_true[sample_offset::sample_conf] + y_score = y_score[sample_offset::sample_conf] + # Gets toolkit PRC with toolkit counts + tp_counts, fp_counts = get_tp_fp_counts(y_true, y_score, score_levels=65536) + auprc, _, _ = get_auprc(tp_counts, fp_counts) + auprcs.append(auprc) + # Compares values + assert_almost_equal(auprcs[0], auprcs[1], decimal=4) + + +def test_metrics_get_auprc_subsamples_aggregated(): + sample_conf = 8 + sample_offsets = (0, sample_conf//2) + score_levels = 65536 + + tp_counts_aggregated = [np.zeros(score_levels) for _ in sample_offsets] + fp_counts_aggregated = [np.zeros(score_levels) for _ in sample_offsets] + auprcs = [] + + for pred_file in IMAGES_LIST: + y_true, y_score = get_real_data(pred_file) + # Aggregation for AUPRC + for s, sample_offset in enumerate(sample_offsets): + # Subsamples for AUPRC + y_true_subsample = y_true[sample_offset::sample_conf] + y_score_subsample = y_score[sample_offset::sample_conf] + tp_counts_subsample, fp_counts_subsample = get_tp_fp_counts(y_true_subsample, y_score_subsample, + score_levels=score_levels) + tp_counts_aggregated[s] += tp_counts_subsample + fp_counts_aggregated[s] += fp_counts_subsample + + # Aggregated AUPRC Assertions + for s, _ in enumerate(sample_offsets): + auprc, _, _ = get_auprc(tp_counts_aggregated[s], fp_counts_aggregated[s]) + auprcs.append(auprc) + + # Subsampled AUPRC Assertions + assert_almost_equal(auprcs[0], auprcs[1], decimal=2) + + +# --------- ECE --------- + +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_ece_real_data(pred_file): + score_levels = 65536 + y_true, y_pred, y_score = get_real_data_raw(pred_file) + + # Computes original ece + ece_original = _get_ece_original(label=y_true, pred=y_pred, conf=y_score, ECE_NUM_BINS=15, + CONF_NUM_BINS=score_levels, DEBIAS=True) + # Computes reference ece + y_score_continuous = (y_score.astype(np.float32) + 0.5) / score_levels + ece_ref15 = _get_ece_reference(y_true, y_pred, y_score_continuous, ece_bins=15) + ece_ref32 = _get_ece_reference(y_true, y_pred, y_score_continuous, ece_bins=32) + + # Convert to 1D + y_true = y_true.ravel() + y_pred = y_pred.ravel() + y_score = y_score.ravel() + + # Gets toolkit ece + y_score = y_score.ravel() + d_counts = np.zeros(score_levels, dtype=np.int64) + t_counts = np.zeros(score_levels, dtype=np.int64) + np.add.at(d_counts, y_score, y_true == y_pred) + np.add.at(t_counts, y_score, 1) + confidence_values = (np.linspace(0, score_levels-1, score_levels) + 0.5) / score_levels + ece15 = get_ece(d_counts, t_counts, confidence_values, bins=15) + ece32 = get_ece(d_counts, t_counts, confidence_values, bins=32) + + # Compares values + assert_almost_equal(ece_ref15, ece_original, decimal=6) + assert_almost_equal(ece15, ece_ref15, decimal=6) + assert_almost_equal(ece32, ece_ref32, decimal=6) + + +def test_metrics_get_ece_aggregated(): + '''Check that computing the metric pixel-wise over many images and aggregating the counts gives the same result.''' + score_levels = 65536 + + y_trues = [] + y_preds = [] + y_scores = [] + for pred_file in IMAGES_LIST: + y_t, y_p, y_s = get_real_data_raw(pred_file) + y_t = y_t.ravel() + y_p = y_p.ravel() + y_s = y_s.ravel() + y_trues.append(y_t) + y_preds.append(y_p) + y_scores.append(y_s) + y_trues = np.concatenate(y_trues) + y_preds = np.concatenate(y_preds) + y_scores = np.concatenate(y_scores) + + # Computes original ece using all data + ece_original = _get_ece_original(label=y_trues, pred=y_preds, conf=y_scores, ECE_NUM_BINS=15, + CONF_NUM_BINS=score_levels, 
DEBIAS=True) + + # Computes reference ece using all data + y_scores_continuous = (y_scores.astype(np.float32) + 0.5) / score_levels + ece_ref15 = _get_ece_reference(y_trues, y_preds, y_scores_continuous, ece_bins=15) + ece_ref32 = _get_ece_reference(y_trues, y_preds, y_scores_continuous, ece_bins=32) + + # Gets toolkit ece with aggregated reference counts + y_scores = y_scores.ravel() + d_counts = np.zeros(score_levels, dtype=np.int64) + t_counts = np.zeros(score_levels, dtype=np.int64) + np.add.at(d_counts, y_scores, y_trues == y_preds) + np.add.at(t_counts, y_scores, 1) + confidence_values = (np.linspace(0, score_levels-1, score_levels) + 0.5) / score_levels + ece15 = get_ece(d_counts, t_counts, confidence_values, bins=15) + ece32 = get_ece(d_counts, t_counts, confidence_values, bins=32) + + # Compares values + assert_almost_equal(ece_ref15, ece_original, decimal=6) + assert_almost_equal(ece15, ece_ref15, decimal=6) + assert_almost_equal(ece32, ece_ref32, decimal=6) + + +@pytest.mark.skip(reason="this test currently fails, because aggressively subsampling individual images is not stable") +@pytest.mark.parametrize('pred_file', IMAGES_LIST) +def test_metrics_get_ece_subsamples(pred_file): + sample_conf = 8 # This picks only 1/(sample_conf*sample_conf) of the data + sample_offsets = (0, sample_conf//2) + score_levels = 65536 + + y_true, y_pred, y_score = get_real_data_raw(pred_file) + eces = [] + for sample_offset in sample_offsets: + # Subsamples data + y_true_sample = y_true[sample_offset::sample_conf].ravel() + y_pred_sample = y_pred[sample_offset::sample_conf].ravel() + y_score_sample = y_score[sample_offset::sample_conf].ravel() + # Computes toolkit ECE with subsampled data + y_score_sample = y_score_sample.ravel() + d_counts = np.zeros(score_levels, dtype=np.int64) + t_counts = np.zeros(score_levels, dtype=np.int64) + np.add.at(d_counts, y_score_sample, y_true_sample == y_pred_sample) + np.add.at(t_counts, y_score_sample, 1) + confidence_values = (np.linspace(0., score_levels-1, score_levels) + 0.5) / score_levels + ece = get_ece(d_counts, t_counts, confidence_values, bins=15) + eces.append(ece) + + # Compares values + assert_almost_equal(eces[0], eces[1], decimal=4) + + +def test_metrics_get_ece_subsamples_aggregated(): + sample_conf = 8 + sample_offsets = (0, sample_conf//2) + score_levels = 65536 + + d_counts_aggregated = [np.zeros(score_levels, dtype=np.int64) for _ in sample_offsets] + t_counts_aggregated = [np.zeros(score_levels, dtype=np.int64) for _ in sample_offsets] + eces = [] + for pred_file in IMAGES_LIST: + y_true, y_pred, y_score = get_real_data_raw(pred_file) + # Aggregation for ECE + for s, sample_offset in enumerate(sample_offsets): + # Subsamples for ECE + y_true_subsample = y_true[sample_offset::sample_conf].ravel() + y_pred_subsample = y_pred[sample_offset::sample_conf].ravel() + y_score_subsample = y_score[sample_offset::sample_conf].ravel() + np.add.at(d_counts_aggregated[s], y_score_subsample, y_true_subsample == y_pred_subsample) + np.add.at(t_counts_aggregated[s], y_score_subsample, 1) + + # Aggregated ECE Assertions + confidence_values = (np.linspace(0., score_levels-1, score_levels) + 0.5) / score_levels + for s, _ in enumerate(sample_offsets): + ece = get_ece(d_counts_aggregated[s], t_counts_aggregated[s], confidence_values) + eces.append(ece) + + # Subsampled ECE Assertions + assert_almost_equal(eces[0], eces[1], decimal=2) diff --git a/tests/test_sample_gt_pixels.py b/tests/test_sample_gt_pixels.py new file mode 100644 index 0000000..b1a6ea2 --- 
/dev/null
+++ b/tests/test_sample_gt_pixels.py
@@ -0,0 +1,42 @@
+from itertools import chain, product
+
+import numpy as np
+from numpy.testing import assert_equal
+import pytest
+
+from bravo_toolkit.util.sample_gt_pixels import decode_indices, encode_indices, sample_gt_pixels
+
+
+@pytest.mark.parametrize('samples_per_image, seed1, seed2',
+                         chain(product([50_000, 100_000, 500_000], [1, 2, 3], [1, 2, 3])))
+def test_sample_gt_pixels_cases(samples_per_image, seed1, seed2):
+    np.random.seed(seed1)
+    gt_file = np.random.randint(0, 254, 2_000_000).astype(np.uint8)
+    gt_file[np.random.choice(2_000_000, np.random.randint(50_000, 200_000), replace=False)] = 255
+    sampled_indices = sample_gt_pixels(gt_file, samples_per_image, seed2)
+    assert sampled_indices.size == samples_per_image
+    encoded = encode_indices(sampled_indices)
+    decoded = decode_indices(encoded)
+    assert_equal(sampled_indices, decoded)
+
+
+@pytest.mark.skip(reason="with the 3-byte encoding (instead of 2 bytes) the indices no longer overflow")
+def test_sample_gt_pixels_overflow():
+    np.random.seed(1)
+    gt_file = np.random.randint(0, 256, 10_000_000).astype(np.uint8)
+    samples_per_image = 100
+    sampled_indices = sample_gt_pixels(gt_file, samples_per_image, seed=1)
+    with pytest.raises(OverflowError):
+        _ = encode_indices(sampled_indices)
+
+
+@pytest.mark.parametrize('fraction, seed1, seed2',
+                         chain(product([0.55, 0.75], [1, 2, 3], [1, 2, 3])))
+def test_sample_gt_pixels_small(fraction, seed1, seed2):
+    np.random.seed(seed1)
+    gt_file = np.random.randint(0, 2, 100*100).astype(np.uint8) * 255
+    samples_per_image = int(100*100*fraction)
+    sampled_indices = sample_gt_pixels(gt_file, samples_per_image, seed2)
+    encoded = encode_indices(sampled_indices)
+    decoded = decode_indices(encoded)
+    assert_equal(sampled_indices, decoded)