From ed05f6c56a9a2b39f9eb9daf0c13b2c961ca6727 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 4 Aug 2024 00:21:06 -0400 Subject: [PATCH 01/23] use requirements.txt for all versioning --- docker/Dockerfile-full | 32 ++++++++++---------------------- docker/Dockerfile-gui | 5 +---- docker/Dockerfile-headless | 11 ++--------- docs/docker.md | 2 +- environment.yml | 14 ++------------ requirements.txt | 16 ++++++++++++++++ setup.py | 12 ++++++------ 7 files changed, 38 insertions(+), 54 deletions(-) create mode 100644 requirements.txt diff --git a/docker/Dockerfile-full b/docker/Dockerfile-full index 541ea04..cf8bfe8 100644 --- a/docker/Dockerfile-full +++ b/docker/Dockerfile-full @@ -1,22 +1,18 @@ -FROM --platform=linux/amd64 nvidia/cuda:11.3.0-cudnn8-runtime-ubuntu20.04 +FROM --platform=linux/amd64 nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 # modified from here # https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile # Install some basic utilities -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - git \ - bzip2 \ - libx11-6 \ +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + curl ca-certificates sudo git bzip2 libx11-6 \ + ffmpeg libsm6 libxext6 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-render-util0 libxcb-xinerama0 \ + libxcb-xkb-dev libxkbcommon-x11-0 libpulse-mainloop-glib0 ubuntu-restricted-extras libqt5multimedia5-plugins vlc \ + libkrb5-3 libgssapi-krb5-2 libkrb5support0 \ && rm -rf /var/lib/apt/lists/* # don't ask for location etc user input when building # this is for opencv, apparently -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg libsm6 libxext6 libxcb-icccm4 \ - libxcb-image0 libxcb-keysyms1 libxcb-render-util0 libxcb-xinerama0 libxcb-xkb-dev libxkbcommon-x11-0 \ - libpulse-mainloop-glib0 ubuntu-restricted-extras libqt5multimedia5-plugins vlc +RUN apt-get update && apt-get install -y # Create a working directory and data directory RUN mkdir /app @@ -35,18 +31,10 @@ RUN curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py39 # install RUN conda install python=3.7 -y -RUN pip install setuptools --upgrade && pip install --upgrade pip -RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +RUN pip install setuptools --upgrade && pip install --upgrade "pip<24.0" +RUN pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html -# RUN pip install \ -# vidio - -RUN pip install "chardet<4.0" h5py kornia>=0.5 matplotlib numpy omegaconf>=2 "pandas<1.4" PySide2 \ - "scikit-learn<1.1" "scipy<1.8" tqdm pytorch_lightning>=1.5.10 opencv-python-headless vidio>=0.0.4 pytest \ - opencv-transforms - -# # needed for pandas for some reason ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='full' -RUN pip install -e . --no-dependencies \ No newline at end of file +RUN pip install -e . \ No newline at end of file diff --git a/docker/Dockerfile-gui b/docker/Dockerfile-gui index 748f9cb..60855be 100644 --- a/docker/Dockerfile-gui +++ b/docker/Dockerfile-gui @@ -39,12 +39,9 @@ RUN pip install setuptools --upgrade && pip install --upgrade pip # TODO: REFACTOR CODE SO IT'S POSSIBLE TO RUN GUI WITHOUT TORCH RUN conda install pytorch cpuonly -c pytorch -RUN pip install "chardet<4.0" h5py matplotlib numpy omegaconf>=2 "pandas<1.4" PySide2 \ - "scikit-learn<1.1" "scipy<1.8" tqdm opencv-python-headless vidio>=0.0.4 pytest \ - opencv-transforms # # needed for pandas for some reason ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='gui' -RUN pip install -e . --no-dependencies \ No newline at end of file +RUN pip install -e . \ No newline at end of file diff --git a/docker/Dockerfile-headless b/docker/Dockerfile-headless index 66d6dec..d0830f5 100644 --- a/docker/Dockerfile-headless +++ b/docker/Dockerfile-headless @@ -1,4 +1,4 @@ -FROM --platform=linux/amd64 nvidia/cuda:11.3.0-cudnn8-runtime-ubuntu20.04 +FROM --platform=linux/amd64 nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 # modified from here # https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile @@ -36,15 +36,8 @@ RUN conda install python=3.7 -y RUN pip install setuptools --upgrade && pip install --upgrade pip RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -# RUN pip install \ -# vidio - -RUN pip install "chardet<4.0" h5py kornia>=0.5 matplotlib numpy omegaconf>=2 "pandas<1.4" \ - "scikit-learn<1.1" "scipy<1.8" tqdm pytorch_lightning>=1.5.10 opencv-python-headless vidio>=0.0.4 pytest \ - opencv-transforms - # # needed for pandas for some reason ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='headless' -RUN pip install -e . --no-dependencies \ No newline at end of file +RUN pip install -e . \ No newline at end of file diff --git a/docs/docker.md b/docs/docker.md index 87a7b13..3c7e261 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -37,4 +37,4 @@ Again, change `/media` to your hard drive with your training data # building it yourself To build the container with both GUI and model training support: * `cd` to your `deepethogram` directory -* `nvidia-docker build -t deepethogram:full -f docker/Dockerfile-full . ` +* `docker build -t deepethogram:full -f docker/Dockerfile-full . ` diff --git a/environment.yml b/environment.yml index 335f3bb..ba10d1b 100644 --- a/environment.yml +++ b/environment.yml @@ -6,17 +6,7 @@ channels: dependencies: - pip - conda-forge::pyside2=5.13.2 - - python>3.7, <3.8 + - python>3.7, <3.9 - pytorch::pytorch - pip: - - h5py==2.10.0 - - hydra-core==0.11.3 - - omegaconf==1.4.1 - - matplotlib - - opencv-python-headless - - opencv-transforms==0.0.3.post2 - - pandas - - scikit-learn - - scipy - - tifffile - - tqdm + - -r requirements.txt \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4f3fcb6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +chardet<4.0 +h5py +kornia>=0.5 +matplotlib +numpy +omegaconf>=2 +opencv-python-headless +opencv-transforms +pandas<1.4 +PySide2==5.13.2 +pytest +scikit-learn<1.1 +scipy<1.8 +tqdm +vidio +pytorch_lightning==1.6.5 \ No newline at end of file diff --git a/setup.py b/setup.py index 7fa7687..477584b 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,12 @@ with open("README.md", 'r') as f: long_description = f.read() +def get_requirements(): + with open('requirements.txt') as f: + return f.read().splitlines() + setuptools.setup(name='deepethogram', - version='0.1.4', + version='0.1.5', author='Jim Bohnslav', author_email='jbohnslav@gmail.com', description='Temporal action detection for biology', @@ -15,8 +19,4 @@ classifiers=['Programming Language :: Python :: 3', 'Operating System :: OS Independent'], entry_points={'console_scripts': ['deepethogram = deepethogram.gui.main:entry']}, python_requires='>=3.6', - install_requires=[ - 'chardet<4.0', 'h5py', 'kornia>=0.5', 'matplotlib', 'numpy', 'omegaconf>=2', - 'opencv-python-headless', 'opencv-transforms', 'pandas<1.4', 'PySide2', 'scikit-learn<1.1', - 'scipy<1.8', 'tqdm', 'vidio', 'pytorch_lightning>=1.5.10' - ]) + install_requires=get_requirements()) From e968525152964e3c8805336d7e94b5dc9d4a00d1 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 4 Aug 2024 00:27:28 -0400 Subject: [PATCH 02/23] add codeowners --- CODEOWNERS | 1 + 1 file changed, 1 insertion(+) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..cfa029b --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +@jbohnslav \ No newline at end of file From 215fef58d985dccaa068439c6b1885711c81fab6 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:32:23 -0500 Subject: [PATCH 03/23] fixed tests and add download script --- setup_tests.py | 93 +++++++++++++++++++++++++++++++++++++++++++++++ tests/test_gui.py | 14 +++++-- 2 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 setup_tests.py diff --git a/setup_tests.py b/setup_tests.py new file mode 100644 index 0000000..4d39924 --- /dev/null +++ b/setup_tests.py @@ -0,0 +1,93 @@ +"""This script downloads the test data archive and sets up the testing environment for DeepEthogram. + +For it to work, you need to `pip install gdown` +""" + +import os +import shutil +import sys +import zipfile +from pathlib import Path + +import gdown +import requests + + +def download_file(url, destination): + """Downloads a file from a URL to a destination with progress indication.""" + response = requests.get(url, stream=True) + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + if total_size == 0: + print("Warning: Content length not provided by server") + + print(f"Downloading to: {destination}") + + with open(destination, "wb") as f: + downloaded = 0 + for data in response.iter_content(block_size): + downloaded += len(data) + f.write(data) + + # Print progress + if total_size > 0: + progress = int(50 * downloaded / total_size) + sys.stdout.write( + f"\r[{'=' * progress}{' ' * (50 - progress)}] {downloaded}/{total_size} bytes" + ) + sys.stdout.flush() + print("\nDownload completed!") + + +def setup_tests(): + """Sets up the testing environment for DeepEthogram.""" + + # Create tests/DATA directory if it doesn't exist + tests_dir = Path("tests") + data_dir = tests_dir / "DATA" + data_dir.mkdir(parents=True, exist_ok=True) + + # Download the test archive + zip_url = "https://drive.google.com/uc?export=download&id=1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo" + zip_path = data_dir / "testing_deepethogram_archive.zip" + + try: + print("Downloading test data archive...") + gdown.download( + id="1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo", output=str(zip_path), quiet=False + ) + + print("Extracting archive...") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + # Verify the extraction + archive_path = data_dir / "testing_deepethogram_archive" + required_items = ["DATA", "models", "project_config.yaml"] + + missing_items = [ + item for item in required_items if not (archive_path / item).exists() + ] + + if missing_items: + print(f"Warning: The following items are missing: {missing_items}") + return False + + print("Setup completed successfully!") + print("\nYou can now run the tests using: pytest tests/") + print( + "Note: The zz_commandline test module will take a few minutes to complete." + ) + + # Clean up the zip file + zip_path.unlink() + return True + + except Exception as e: + print(f"Error during setup: {str(e)}") + return False + + +if __name__ == "__main__": + setup_tests() diff --git a/tests/test_gui.py b/tests/test_gui.py index f60b34d..d873998 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -3,14 +3,20 @@ import pytest +DEG_VERSION = os.environ.get("DEG_VERSION", "full") -@pytest.mark.skipif(os.environ['DEG_VERSION'] == 'headless', reason="Dont run GUI tests for headless deepethogram") + +@pytest.mark.skipif( + DEG_VERSION == "headless", + reason="Dont run GUI tests for headless deepethogram", +) def test_setup(): # put imports here so that headless version does not import gui tools - from deepethogram.gui.main import run, setup_gui_cfg, MainWindow + from deepethogram.gui.main import MainWindow, run, setup_gui_cfg + cfg = setup_gui_cfg() - assert cfg.run.type == 'gui' + assert cfg.run.type == "gui" # def test_new_project(): @@ -22,4 +28,4 @@ def test_setup(): # window._new_project() # def test_open(): -# run() \ No newline at end of file +# run() From 83d001521ce22114775f994bb2da90be5e937e9d Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:33:11 -0500 Subject: [PATCH 04/23] add ruff linting --- .style.yapf | 3 --- pyproject.toml | 41 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 ++- setup.py | 43 ++++++++++++++++++++++++++++--------------- 4 files changed, 71 insertions(+), 19 deletions(-) delete mode 100644 .style.yapf create mode 100644 pyproject.toml diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 5e7df5a..0000000 --- a/.style.yapf +++ /dev/null @@ -1,3 +0,0 @@ -[style] -based_on_style = google -column_limit = 120 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f9492d9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[tool.ruff] +# Python version compatibility +target-version = "py37" + +# Ignore specific rules +ignore = [] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] + +# Same as Black. +line-length = 120 + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" diff --git a/requirements.txt b/requirements.txt index 4f3fcb6..a1a4131 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ scikit-learn<1.1 scipy<1.8 tqdm vidio -pytorch_lightning==1.6.5 \ No newline at end of file +pytorch_lightning==1.6.5 +ruff>=0.1.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 477584b..e5c86c6 100644 --- a/setup.py +++ b/setup.py @@ -1,22 +1,35 @@ import setuptools -with open("README.md", 'r') as f: +with open("README.md", "r") as f: long_description = f.read() + def get_requirements(): - with open('requirements.txt') as f: + with open("requirements.txt") as f: return f.read().splitlines() -setuptools.setup(name='deepethogram', - version='0.1.5', - author='Jim Bohnslav', - author_email='jbohnslav@gmail.com', - description='Temporal action detection for biology', - long_description=long_description, - long_description_content_type='text/markdown', - include_package_data=True, - packages=setuptools.find_packages(), - classifiers=['Programming Language :: Python :: 3', 'Operating System :: OS Independent'], - entry_points={'console_scripts': ['deepethogram = deepethogram.gui.main:entry']}, - python_requires='>=3.6', - install_requires=get_requirements()) + +setuptools.setup( + name="deepethogram", + version="0.2.0", + author="Jim Bohnslav", + author_email="jbohnslav@gmail.com", + description="Temporal action detection for biology", + long_description=long_description, + long_description_content_type="text/markdown", + include_package_data=True, + packages=setuptools.find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + entry_points={"console_scripts": ["deepethogram = deepethogram.gui.main:entry"]}, + python_requires=">=3.7,<3.8", + install_requires=get_requirements(), + options={ + "ruff": { + "target-version": "py37", + }, + }, + setup_requires=["setuptools>=61.0.0", "ruff"], +) From 0c7c57e5723c92dbf2db14bb1c9800e35759f562 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:34:07 -0500 Subject: [PATCH 05/23] ruff --- deepethogram/base.py | 3 +-- deepethogram/data/augs.py | 1 - deepethogram/data/dataloaders.py | 2 +- deepethogram/data/datasets.py | 16 ++++--------- deepethogram/debug.py | 2 +- deepethogram/feature_extractor/inference.py | 5 ++-- deepethogram/feature_extractor/losses.py | 2 -- deepethogram/feature_extractor/train.py | 1 - deepethogram/flow_generator/inference.py | 1 - deepethogram/flow_generator/train.py | 3 +-- deepethogram/gui/custom_widgets.py | 2 +- deepethogram/gui/main.py | 1 - deepethogram/gui/mainwindow.py | 2 +- deepethogram/metrics.py | 2 +- deepethogram/postprocessing.py | 2 +- deepethogram/projects.py | 3 +-- deepethogram/sequence/models/sequence.py | 1 - deepethogram/sequence/train.py | 4 ---- deepethogram/tune/feature_extractor.py | 2 +- deepethogram/tune/sequence.py | 2 +- deepethogram/tune/utils.py | 2 +- deepethogram/utils.py | 3 +-- deepethogram/viz.py | 8 +++---- deepethogram/zscore.py | 2 +- docs/docker.md | 2 +- setup_tests.py | 2 -- tests/setup_data.py | 2 +- tests/test_data.py | 7 ------ tests/test_flow_generator.py | 3 +-- tests/test_gui.py | 2 +- tests/test_projects.py | 2 +- tests/test_z_score.py | 2 +- tests/test_zz_commandline.py | 26 ++++++++++----------- 33 files changed, 41 insertions(+), 79 deletions(-) diff --git a/deepethogram/base.py b/deepethogram/base.py index 040a442..7943a5f 100644 --- a/deepethogram/base.py +++ b/deepethogram/base.py @@ -6,7 +6,6 @@ from typing import Tuple import matplotlib.pyplot as plt -import numpy as np from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl try: @@ -22,7 +21,7 @@ from torch.utils.data import DataLoader, WeightedRandomSampler from deepethogram.data.augs import get_gpu_transforms, get_empty_gpu_transforms -from deepethogram.callbacks import FPSCallback, DebugCallback, MetricsCallback, \ +from deepethogram.callbacks import FPSCallback, MetricsCallback, \ ExampleImagesCallback, CheckpointCallback, StopperCallback from deepethogram.metrics import Metrics, EmptyMetrics from deepethogram.schedulers import initialize_scheduler diff --git a/deepethogram/data/augs.py b/deepethogram/data/augs.py index b271bf0..17922a0 100644 --- a/deepethogram/data/augs.py +++ b/deepethogram/data/augs.py @@ -1,5 +1,4 @@ import logging -from pprint import pformat import cv2 import numpy as np diff --git a/deepethogram/data/dataloaders.py b/deepethogram/data/dataloaders.py index b22ddbf..7a8b8ed 100644 --- a/deepethogram/data/dataloaders.py +++ b/deepethogram/data/dataloaders.py @@ -9,7 +9,7 @@ from torch.utils import data from deepethogram import projects -from deepethogram.data.augs import get_transforms, get_cpu_transforms +from deepethogram.data.augs import get_cpu_transforms from deepethogram.data.datasets import SequenceDataset, TwoStreamDataset, VideoDataset, KineticsDataset from deepethogram.data.utils import get_split_from_records, remove_invalid_records_from_split_dictionary, \ make_loss_weight diff --git a/deepethogram/data/datasets.py b/deepethogram/data/datasets.py index ea522c1..49024a3 100644 --- a/deepethogram/data/datasets.py +++ b/deepethogram/data/datasets.py @@ -1,30 +1,22 @@ -import bisect from collections import deque import logging import os -import pprint import random -import warnings -from functools import partial from typing import Union, Tuple import h5py import numpy as np from omegaconf import DictConfig -import pandas as pd import torch -from opencv_transforms import transforms from torch.utils import data from vidio import VideoReader # from deepethogram.dataloaders import log from deepethogram import projects from deepethogram.data.augs import get_cpu_transforms -from deepethogram.data.utils import purge_unlabeled_elements_from_records, get_video_metadata, extract_metadata, \ - find_labelfile, read_all_labels, get_split_from_records, remove_invalid_records_from_split_dictionary, \ +from deepethogram.data.utils import purge_unlabeled_elements_from_records, get_video_metadata, read_all_labels, get_split_from_records, remove_invalid_records_from_split_dictionary, \ make_loss_weight, fix_label -from deepethogram.data.keypoint_utils import load_dlcfile, interpolate_bad_values, stack_features_in_time, \ - expand_features_sturman +from deepethogram.data.keypoint_utils import load_dlcfile, interpolate_bad_values, expand_features_sturman from deepethogram.file_io import read_labels log = logging.getLogger(__name__) @@ -134,7 +126,7 @@ def get_current_item(self): else: try: im = self.readers[worker_id][self.cnt] - except Exception as e: + except Exception: print(f'problem reading frame {self.cnt}') raise im = self.transform(im) @@ -184,7 +176,7 @@ def close(self): continue try: v.close() - except Exception as e: + except Exception: print(f'error destroying reader {k}') else: print(f'destroyed {k}') diff --git a/deepethogram/debug.py b/deepethogram/debug.py index 7749553..1bb300c 100644 --- a/deepethogram/debug.py +++ b/deepethogram/debug.py @@ -75,7 +75,7 @@ def try_load_all_frames(datadir: Union[str, os.PathLike]): for i in tqdm(range(len(reader)), leave=False): try: frame = reader[i] - except Exception as e: + except Exception: had_error = True print('error reading frame {} from video {}'.format(i, record['rgb'])) except KeyboardInterrupt: diff --git a/deepethogram/feature_extractor/inference.py b/deepethogram/feature_extractor/inference.py index aafe32b..1dfbe1e 100644 --- a/deepethogram/feature_extractor/inference.py +++ b/deepethogram/feature_extractor/inference.py @@ -1,4 +1,3 @@ -from collections import defaultdict import logging import os import sys @@ -16,7 +15,7 @@ from deepethogram import utils, projects from deepethogram.configuration import make_feature_extractor_inference_cfg -from deepethogram.data.augs import get_cpu_transforms, get_gpu_transforms_inference, get_gpu_transforms +from deepethogram.data.augs import get_cpu_transforms, get_gpu_transforms from deepethogram.data.datasets import VideoIterable from deepethogram.feature_extractor.train import build_model_from_cfg as build_feature_extractor from deepethogram.file_io import read_labels @@ -423,7 +422,7 @@ def extract(rgbs: list, while not has_worked: try: f = h5py.File(h5file, 'r+') - except OSError as e: + except OSError: log.warning('resource unavailable, waiting 30 seconds...') time.sleep(30) else: diff --git a/deepethogram/feature_extractor/losses.py b/deepethogram/feature_extractor/losses.py index 99dc758..b749eff 100644 --- a/deepethogram/feature_extractor/losses.py +++ b/deepethogram/feature_extractor/losses.py @@ -1,6 +1,4 @@ -import os import logging -import pdb import numpy as np import torch diff --git a/deepethogram/feature_extractor/train.py b/deepethogram/feature_extractor/train.py index 24aab76..f71f8a3 100644 --- a/deepethogram/feature_extractor/train.py +++ b/deepethogram/feature_extractor/train.py @@ -18,7 +18,6 @@ from deepethogram import utils, viz from deepethogram.base import BaseLightningModule, get_trainer_from_cfg from deepethogram.configuration import make_feature_extractor_train_cfg -from deepethogram.data.augs import get_gpu_transforms from deepethogram.data.datasets import get_datasets_from_cfg from deepethogram.feature_extractor.losses import ClassificationLoss, BinaryFocalLoss, CrossEntropyLoss from deepethogram.feature_extractor.models.CNN import get_cnn diff --git a/deepethogram/flow_generator/inference.py b/deepethogram/flow_generator/inference.py index 7f2ba81..43e2cdc 100644 --- a/deepethogram/flow_generator/inference.py +++ b/deepethogram/flow_generator/inference.py @@ -6,7 +6,6 @@ from typing import Union import cv2 -from matplotlib.pyplot import psd import numpy as np from omegaconf import OmegaConf, ListConfig import torch diff --git a/deepethogram/flow_generator/train.py b/deepethogram/flow_generator/train.py index 023d5d6..18c81c8 100644 --- a/deepethogram/flow_generator/train.py +++ b/deepethogram/flow_generator/train.py @@ -14,8 +14,7 @@ import deepethogram.projects from deepethogram import utils, viz, projects from deepethogram.base import BaseLightningModule, get_trainer_from_cfg -from deepethogram.configuration import make_config, make_flow_generator_train_cfg -from deepethogram.data.augs import get_gpu_transforms +from deepethogram.configuration import make_flow_generator_train_cfg from deepethogram.data.datasets import get_datasets_from_cfg from deepethogram.flow_generator import models from deepethogram.flow_generator.losses import MotionNetLoss diff --git a/deepethogram/gui/custom_widgets.py b/deepethogram/gui/custom_widgets.py index be17e9b..4c41f6d 100644 --- a/deepethogram/gui/custom_widgets.py +++ b/deepethogram/gui/custom_widgets.py @@ -415,7 +415,7 @@ def initialize(self, try: self.cmap = Mapper(colormap) - except ValueError as e: + except ValueError: raise ('Colormap not in matplotlib' 's defaults! {}'.format(colormap)) if self.debug: self.make_debug() diff --git a/deepethogram/gui/main.py b/deepethogram/gui/main.py index 4c110ff..32cdf9a 100644 --- a/deepethogram/gui/main.py +++ b/deepethogram/gui/main.py @@ -6,7 +6,6 @@ from functools import partial from typing import Union -import cv2 # import hydra import numpy as np import pandas as pd diff --git a/deepethogram/gui/mainwindow.py b/deepethogram/gui/mainwindow.py index 7600bc3..584dd3d 100644 --- a/deepethogram/gui/mainwindow.py +++ b/deepethogram/gui/mainwindow.py @@ -8,7 +8,7 @@ # # WARNING! All changes made in this file will be lost! -from PySide2 import QtCore, QtGui, QtWidgets +from PySide2 import QtCore, QtWidgets class Ui_MainWindow(object): def setupUi(self, MainWindow): diff --git a/deepethogram/metrics.py b/deepethogram/metrics.py index 5394a76..f94a929 100644 --- a/deepethogram/metrics.py +++ b/deepethogram/metrics.py @@ -8,7 +8,7 @@ import h5py import numpy as np import torch -from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, average_precision_score, auc +from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, auc from deepethogram import utils from deepethogram.postprocessing import remove_low_thresholds diff --git a/deepethogram/postprocessing.py b/deepethogram/postprocessing.py index 4d60feb..e329de6 100644 --- a/deepethogram/postprocessing.py +++ b/deepethogram/postprocessing.py @@ -1,7 +1,7 @@ from collections import defaultdict import logging import os -from typing import Dict, Type, Tuple +from typing import Type, Tuple import h5py import numpy as np diff --git a/deepethogram/projects.py b/deepethogram/projects.py index 988c3ed..3423bbf 100644 --- a/deepethogram/projects.py +++ b/deepethogram/projects.py @@ -2,7 +2,6 @@ import os import re import shutil -import sys import warnings from datetime import datetime from typing import Union @@ -10,7 +9,7 @@ import h5py import numpy as np import pandas as pd -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, OmegaConf from tqdm import tqdm import deepethogram diff --git a/deepethogram/sequence/models/sequence.py b/deepethogram/sequence/models/sequence.py index cef8b60..6c11a11 100644 --- a/deepethogram/sequence/models/sequence.py +++ b/deepethogram/sequence/models/sequence.py @@ -1,4 +1,3 @@ -import torch from torch import nn diff --git a/deepethogram/sequence/train.py b/deepethogram/sequence/train.py index adefbb3..51bbe4f 100644 --- a/deepethogram/sequence/train.py +++ b/deepethogram/sequence/train.py @@ -1,22 +1,18 @@ import logging import os import sys -from typing import Type, Tuple, Union import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn -import torch.optim as optim from omegaconf import DictConfig, OmegaConf -import deepethogram.projects from deepethogram.base import BaseLightningModule, get_trainer_from_cfg from deepethogram import utils, projects, viz from deepethogram.configuration import make_sequence_train_cfg from deepethogram.data.datasets import get_datasets_from_cfg from deepethogram.feature_extractor.train import get_metrics, get_stopper, get_criterion -from deepethogram.schedulers import initialize_scheduler from deepethogram.sequence.models.mlp import MLP from deepethogram.sequence.models.sequence import Linear, Conv_Nonlinear, RNN from deepethogram.sequence.models.tgm import TGM, TGMJ diff --git a/deepethogram/tune/feature_extractor.py b/deepethogram/tune/feature_extractor.py index 87824a0..f2bc8dc 100644 --- a/deepethogram/tune/feature_extractor.py +++ b/deepethogram/tune/feature_extractor.py @@ -12,7 +12,7 @@ print('To use the deepethogram.tune module, you must `pip install \'ray[tune]`') raise -from deepethogram.configuration import make_config, load_config_by_name +from deepethogram.configuration import make_config from deepethogram.feature_extractor.train import feature_extractor_train from deepethogram import projects from deepethogram.tune.utils import dict_to_dotlist, generate_tune_cfg diff --git a/deepethogram/tune/sequence.py b/deepethogram/tune/sequence.py index 40d8d81..6084c22 100644 --- a/deepethogram/tune/sequence.py +++ b/deepethogram/tune/sequence.py @@ -12,7 +12,7 @@ print('To use the deepethogram.tune module, you must `pip install \'ray[tune]`') raise -from deepethogram.configuration import make_config, load_config_by_name +from deepethogram.configuration import make_config from deepethogram import sequence_train from deepethogram import projects from deepethogram.tune.utils import dict_to_dotlist, generate_tune_cfg diff --git a/deepethogram/tune/utils.py b/deepethogram/tune/utils.py index fa99c6b..97c72a7 100644 --- a/deepethogram/tune/utils.py +++ b/deepethogram/tune/utils.py @@ -1,4 +1,4 @@ -from omegaconf import OmegaConf, DictConfig +from omegaconf import OmegaConf try: import ray from ray import tune diff --git a/deepethogram/utils.py b/deepethogram/utils.py index 397d449..c176811 100644 --- a/deepethogram/utils.py +++ b/deepethogram/utils.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping, Container import logging import os import pkgutil @@ -266,7 +265,7 @@ def load_state_dict_from_file(weights_file, distributed: bool = False): # log.info('loading onto cpu...') # state = torch.load(weights_file, map_location='cpu') - is_pure_weights = not 'epoch' in list(state.keys()) + is_pure_weights = 'epoch' not in list(state.keys()) # load params if is_pure_weights: state_dict = state diff --git a/deepethogram/viz.py b/deepethogram/viz.py index c3eb3b8..2314417 100644 --- a/deepethogram/viz.py +++ b/deepethogram/viz.py @@ -12,14 +12,12 @@ # import tifffile as TIFF from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation -from matplotlib.projections import get_projection_class from mpl_toolkits.axes_grid1 import make_axes_locatable, inset_locator -from sklearn.metrics import auc import torch from deepethogram.flow_generator.utils import flow_to_rgb_polar # from deepethogram.metrics import load_threshold_data -from deepethogram.utils import tensor_to_np, print_top_largest_variables +from deepethogram.utils import tensor_to_np log = logging.getLogger(__name__) # override warning level for matplotlib, which outputs a million debugging statements @@ -943,7 +941,7 @@ def get_data_from_file(f, name): data = OrderedDict(train=f['train/fps'][:], val=f['val/fps'][:], speedtest=f['speedtest/fps'][:]) - except Exception as e: + except Exception: # likely don't have speedtest, not too important data = OrderedDict(train=f['train/fps'][:], val=f['val/fps'][:]) @@ -1244,7 +1242,7 @@ def __init__(self, colormap='deepethogram'): else: try: self.cmap = plt.get_cmap(colormap) - except ValueError as e: + except ValueError: raise ('Colormap not in matplotlib''s defaults! {}'.format(colormap)) def init_deepethogram(self): diff --git a/deepethogram/zscore.py b/deepethogram/zscore.py index 0d7d23f..ec655f2 100644 --- a/deepethogram/zscore.py +++ b/deepethogram/zscore.py @@ -100,7 +100,7 @@ def get_video_statistics(videofile, stride): for i in tqdm(range(0, len(reader), stride)): try: image = reader[i] - except Exception as e: + except Exception: log.warning('Error reading frame {} from video {}'.format(i, videofile)) continue image = image.astype(float) / 255 diff --git a/docs/docker.md b/docs/docker.md index 3c7e261..598c073 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -37,4 +37,4 @@ Again, change `/media` to your hard drive with your training data # building it yourself To build the container with both GUI and model training support: * `cd` to your `deepethogram` directory -* `docker build -t deepethogram:full -f docker/Dockerfile-full . ` +* `docker build -t deepethogram:full -f docker/Dockerfile-full .` diff --git a/setup_tests.py b/setup_tests.py index 4d39924..781d3d6 100644 --- a/setup_tests.py +++ b/setup_tests.py @@ -3,8 +3,6 @@ For it to work, you need to `pip install gdown` """ -import os -import shutil import sys import zipfile from pathlib import Path diff --git a/tests/setup_data.py b/tests/setup_data.py index 8dab459..19fb70a 100644 --- a/tests/setup_data.py +++ b/tests/setup_data.py @@ -2,7 +2,7 @@ import shutil # from projects import get_records_from_datadir, fix_config_paths -from deepethogram import projects, utils +from deepethogram import projects this_path = os.path.abspath(__file__) test_path = os.path.dirname(this_path) diff --git a/tests/test_data.py b/tests/test_data.py index a3a29c0..25067b4 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,14 +1,7 @@ -import os -import random -import shutil import numpy as np -import pandas as pd -import pytest -from deepethogram import projects, utils from deepethogram.data import utils as data_utils -from setup_data import make_project_from_archive, project_path, test_data_path, clean_test_data, get_records def test_loss_weight(): diff --git a/tests/test_flow_generator.py b/tests/test_flow_generator.py index 7647b6f..f4f4e12 100644 --- a/tests/test_flow_generator.py +++ b/tests/test_flow_generator.py @@ -4,8 +4,7 @@ from deepethogram.configuration import make_flow_generator_train_cfg from deepethogram.flow_generator.train import (get_datasets_from_cfg, build_model_from_cfg, get_metrics, OpticalFlowLightning) -from setup_data import (make_project_from_archive, project_path, test_data_path, clean_test_data, get_records, - config_path, data_path) +from setup_data import (make_project_from_archive, project_path) def test_metrics(): diff --git a/tests/test_gui.py b/tests/test_gui.py index d873998..74c03fd 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -12,7 +12,7 @@ ) def test_setup(): # put imports here so that headless version does not import gui tools - from deepethogram.gui.main import MainWindow, run, setup_gui_cfg + from deepethogram.gui.main import setup_gui_cfg cfg = setup_gui_cfg() diff --git a/tests/test_projects.py b/tests/test_projects.py index 37888a3..c1cea2c 100644 --- a/tests/test_projects.py +++ b/tests/test_projects.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from deepethogram import projects, utils +from deepethogram import projects from setup_data import make_project_from_archive, project_path, test_data_path, clean_test_data, get_records # make_project_from_archive() diff --git a/tests/test_z_score.py b/tests/test_z_score.py index 5e615da..806923a 100644 --- a/tests/test_z_score.py +++ b/tests/test_z_score.py @@ -2,7 +2,7 @@ from deepethogram import projects from deepethogram.zscore import get_video_statistics -from setup_data import make_project_from_archive, project_path, data_path +from setup_data import make_project_from_archive, data_path make_project_from_archive() diff --git a/tests/test_zz_commandline.py b/tests/test_zz_commandline.py index c3f08cf..4270a60 100644 --- a/tests/test_zz_commandline.py +++ b/tests/test_zz_commandline.py @@ -1,6 +1,4 @@ # this is named test__zz_commandline so that it comes last, after all module-specific tests -import os -import numpy as np import subprocess from deepethogram import utils @@ -55,19 +53,19 @@ def add_default_arguments(string, train=True): def test_flow(): make_project_from_archive() - string = (f'python -m deepethogram.flow_generator.train preset=deg_f ') + string = ('python -m deepethogram.flow_generator.train preset=deg_f ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = (f'python -m deepethogram.flow_generator.train preset=deg_m ') + string = ('python -m deepethogram.flow_generator.train preset=deg_m ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = (f'python -m deepethogram.flow_generator.train preset=deg_s ') + string = ('python -m deepethogram.flow_generator.train preset=deg_s ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -75,21 +73,21 @@ def test_flow(): def test_feature_extractor(): - string = (f'python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest ') + string = ('python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = (f'python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest ') + string = ('python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 # for resnet3d, must specify weights, because we can't just download them from the torchvision repo - string = (f'python -m deepethogram.feature_extractor.train preset=deg_s flow_generator.weights=latest ' - f'feature_extractor.weights=latest ') + string = ('python -m deepethogram.feature_extractor.train preset=deg_s flow_generator.weights=latest ' + 'feature_extractor.weights=latest ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -106,7 +104,7 @@ def test_feature_extractor(): def test_feature_extraction(softmax: bool = False): # the reason for this complexity is that I don't want to run inference on all directories - string = (f'python -m deepethogram.feature_extractor.inference preset=deg_f feature_extractor.weights=latest ' + string = ('python -m deepethogram.feature_extractor.inference preset=deg_f feature_extractor.weights=latest ' 'flow_generator.weights=latest ') if softmax: string += 'feature_extractor.final_activation=softmax ' @@ -125,7 +123,7 @@ def test_feature_extraction(softmax: bool = False): def test_sequence_train(): - string = (f'python -m deepethogram.sequence.train ') + string = ('python -m deepethogram.sequence.train ') string = add_default_arguments(string) command = command_from_string(string) print(command) @@ -133,7 +131,7 @@ def test_sequence_train(): assert ret.returncode == 0 # mutually exclusive - string = (f'python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') + string = ('python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') string = add_default_arguments(string) command = command_from_string(string) print(command) @@ -143,7 +141,7 @@ def test_sequence_train(): def test_softmax(): make_project_from_archive() - string = (f'python -m deepethogram.flow_generator.train preset=deg_f ') + string = ('python -m deepethogram.flow_generator.train preset=deg_f ') string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -158,7 +156,7 @@ def test_softmax(): test_feature_extraction(softmax=True) - string = (f'python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') + string = ('python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') string = add_default_arguments(string) command = command_from_string(string) print(command) From 681524213c55cfa4fc36e0011874eac510ccab5d Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:39:15 -0500 Subject: [PATCH 06/23] pre commit hooks not passing --- .dockerignore | 2 +- .gitignore | 2 +- .pre-commit-config.yaml | 17 + CODEOWNERS | 2 +- MANIFEST.in | 2 +- README.md | 79 +- deepethogram/__init__.py | 4 +- deepethogram/__main__.py | 4 +- deepethogram/base.py | 257 +++--- deepethogram/callbacks.py | 113 ++- deepethogram/conf/augs.yaml | 2 +- deepethogram/conf/config.yaml | 4 +- deepethogram/conf/debug.yaml | 2 +- deepethogram/conf/gui.yaml | 2 +- deepethogram/conf/inference.yaml | 4 +- .../conf/model/feature_extractor.yaml | 2 +- deepethogram/conf/model/flow_generator.yaml | 2 +- deepethogram/conf/model/sequence.yaml | 4 +- deepethogram/conf/postprocessor.yaml | 2 +- deepethogram/conf/preset/deg_f.yaml | 2 +- deepethogram/conf/preset/deg_m.yaml | 2 +- deepethogram/conf/preset/deg_s.yaml | 2 +- deepethogram/conf/train.yaml | 1 - deepethogram/conf/tune/feature_extractor.yaml | 4 +- deepethogram/conf/tune/sequence.yaml | 4 +- deepethogram/conf/tune/tune.yaml | 4 +- deepethogram/configuration.py | 108 +-- deepethogram/data/augs.py | 65 +- deepethogram/data/dali.py | 217 ++--- deepethogram/data/dataloaders.py | 452 ++++++---- deepethogram/data/datasets.py | 603 ++++++------- deepethogram/data/keypoint_utils.py | 76 +- deepethogram/data/utils.py | 189 +++-- deepethogram/debug.py | 98 +-- deepethogram/feature_extractor/inference.py | 356 ++++---- deepethogram/feature_extractor/losses.py | 45 +- deepethogram/feature_extractor/models/CNN.py | 39 +- .../models/classifiers/alexnet.py | 18 +- .../models/classifiers/densenet.py | 143 ++-- .../models/classifiers/inception.py | 35 +- .../models/classifiers/resnet.py | 85 +- .../models/classifiers/resnet3d.py | 105 +-- .../models/classifiers/squeezenet.py | 54 +- .../models/classifiers/vgg.py | 140 +-- .../models/hidden_two_stream.py | 332 ++++---- .../feature_extractor/models/utils.py | 62 +- deepethogram/feature_extractor/train.py | 344 ++++---- deepethogram/file_io.py | 44 +- deepethogram/flow_generator/__init__.py | 2 +- deepethogram/flow_generator/inference.py | 164 ++-- deepethogram/flow_generator/losses.py | 106 ++- .../flow_generator/models/FlowNetS.py | 15 +- .../flow_generator/models/MotionNet.py | 13 +- .../flow_generator/models/TinyMotionNet.py | 3 +- .../flow_generator/models/TinyMotionNet3D.py | 39 +- .../flow_generator/models/__init__.py | 1 - .../flow_generator/models/components.py | 96 ++- deepethogram/flow_generator/train.py | 186 ++-- deepethogram/flow_generator/utils.py | 71 +- deepethogram/gui/custom_widgets.py | 202 ++--- deepethogram/gui/main.py | 574 +++++++------ deepethogram/gui/mainwindow.py | 18 +- deepethogram/gui/menus_and_popups.py | 60 +- deepethogram/losses.py | 119 +-- deepethogram/metrics.py | 318 ++++--- deepethogram/postprocessing.py | 100 +-- deepethogram/projects.py | 629 +++++++------- deepethogram/schedulers.py | 60 +- deepethogram/sequence/__main__.py | 2 +- deepethogram/sequence/inference.py | 225 ++--- deepethogram/sequence/models/mlp.py | 21 +- deepethogram/sequence/models/sequence.py | 53 +- deepethogram/sequence/models/tgm.py | 126 +-- deepethogram/sequence/train.py | 178 ++-- deepethogram/stoppers.py | 57 +- deepethogram/tune/feature_extractor.py | 121 +-- deepethogram/tune/sequence.py | 116 +-- deepethogram/tune/utils.py | 26 +- deepethogram/utils.py | 255 +++--- deepethogram/viz.py | 799 ++++++++++-------- deepethogram/zscore.py | 58 +- docker/Dockerfile-full | 14 +- docker/Dockerfile-gui | 10 +- docker/Dockerfile-headless | 12 +- docs/beta.md | 22 +- docs/code_examples.md | 2 +- docs/docker.md | 4 +- docs/file_structure.md | 12 +- docs/getting_started.md | 82 +- docs/installation.md | 37 +- docs/performance.md | 120 +-- docs/troubleshooting.md | 8 +- docs/using_CLI.md | 50 +- docs/using_code.md | 2 +- docs/using_config_files.md | 36 +- docs/using_gui.md | 120 ++- docs/using_tune.md | 2 +- environment.yml | 2 +- license.txt | 2 +- requirements.txt | 3 +- setup_tests.py | 16 +- tests/setup_data.py | 24 +- tests/test_data.py | 17 +- tests/test_flow_generator.py | 24 +- tests/test_models.py | 4 +- tests/test_projects.py | 72 +- tests/test_z_score.py | 8 +- tests/test_zz_commandline.py | 74 +- 108 files changed, 5158 insertions(+), 4569 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.dockerignore b/.dockerignore index 4b39369..1944b42 100644 --- a/.dockerignore +++ b/.dockerignore @@ -13,4 +13,4 @@ docker-compose.yml build/* dist/* docs/* -docker/ \ No newline at end of file +docker/ diff --git a/.gitignore b/.gitignore index 36dd69b..aea7b48 100644 --- a/.gitignore +++ b/.gitignore @@ -152,4 +152,4 @@ venv.bak/ dmypy.json # Pyre type checker -.pyre/ \ No newline at end of file +.pyre/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3a5071c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: debug-statements + - id: check-case-conflict + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.6 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/CODEOWNERS b/CODEOWNERS index cfa029b..b184419 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -@jbohnslav \ No newline at end of file +@jbohnslav diff --git a/MANIFEST.in b/MANIFEST.in index 39022d6..21985dd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include README.md include deepethogram/gui/icons/*.png -recursive-include deepethogram/conf * \ No newline at end of file +recursive-include deepethogram/conf * diff --git a/README.md b/README.md index 56e8e00..865485d 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ - Written by Jim Bohnslav, except where as noted - JBohnslav@gmail.com -DeepEthogram is an open-source package for automatically classifying each frame of a video into a set of pre-defined -behaviors. Designed for neuroscience research, it could be used in any scenario where you need to detect actions from +DeepEthogram is an open-source package for automatically classifying each frame of a video into a set of pre-defined +behaviors. Designed for neuroscience research, it could be used in any scenario where you need to detect actions from each frame of a video. Example use cases: @@ -12,7 +12,7 @@ Example use cases: * Counting licks from video for appetite measurement * Measuring reach onset times for alignment with neural activity -DeepEthogram uses state-of-the-art algorithms for *temporal action detection*. We build on the following previous machine +DeepEthogram uses state-of-the-art algorithms for *temporal action detection*. We build on the following previous machine learning research into action detection: * [Hidden Two-Stream Convolutional Networks for Action Recognition](https://arxiv.org/abs/1704.00389) * [Temporal Gaussian Mixture Layer for Videos](https://arxiv.org/abs/1803.06316) @@ -20,24 +20,24 @@ learning research into action detection: ![deepethogram schematic](docs/images/deepethogram_schematic.png) ## Installation -For full installation instructions, see [this readme file](docs/installation.md). +For full installation instructions, see [this readme file](docs/installation.md). -In brief: -* [Install PyTorch](https://pytorch.org/) +In brief: +* [Install PyTorch](https://pytorch.org/) * `pip install deepethogram` ## Data -**NEW!** All datasets collected and annotated by the DeepEthogram authors are now available from this DropBox link: +**NEW!** All datasets collected and annotated by the DeepEthogram authors are now available from this DropBox link: https://www.dropbox.com/sh/3lilfob0sz21och/AABv8o8KhhRQhYCMNu0ilR8wa?dl=0 -If you have issues downloading the data, please raise an issue on Github. +If you have issues downloading the data, please raise an issue on Github. ## COLAB -I've written a Colab notebook that shows how to upload your data and train models. You can also use this if you don't -have access to a decent GPU. +I've written a Colab notebook that shows how to upload your data and train models. You can also use this if you don't +have access to a decent GPU. -To use it, please [click this link to the Colab notebook](https://colab.research.google.com/drive/1Nf9FU7FD77wgvbUFc608839v2jPYgDhd?usp=sharing). -Then, click `copy to Drive` at the top. You won't be able to save your changes to the notebook as-is. +To use it, please [click this link to the Colab notebook](https://colab.research.google.com/drive/1Nf9FU7FD77wgvbUFc608839v2jPYgDhd?usp=sharing). +Then, click `copy to Drive` at the top. You won't be able to save your changes to the notebook as-is. ## News @@ -45,23 +45,23 @@ We now support docker! Docker is a way to run `deepethogram` in completely repro with other system dependencies. [See docs/Docker for more information](docs/docker.md) ## Pretrained models -Rather than start from scratch, we will start with model weights pretrained on the Kinetics700 dataset. Go to +Rather than start from scratch, we will start with model weights pretrained on the Kinetics700 dataset. Go to To download the pretrained weights, please use [this Google Drive link](https://drive.google.com/file/d/1ntIZVbOG1UAiFVlsAAuKEBEVCVevyets/view?usp=sharing). -Unzip the files in your `project/models` directory. Make sure that you don't add an extra directory when unzipping! The path should be +Unzip the files in your `project/models` directory. Make sure that you don't add an extra directory when unzipping! The path should be `your_project/models/pretrained_models/{models 1:6}`, not `your_project/models/pretrained_models/pretrained_models/{models1:6}`. ## Licensing Copyright (c) 2020 - President and Fellows of Harvard College. All rights reserved. -This software is free for academic use. For commercial use, please contact the Harvard Office of Technology -Development (hms_otd@harvard.edu) with cc to Dr. Chris Harvey. For details, see [license.txt](license.txt). +This software is free for academic use. For commercial use, please contact the Harvard Office of Technology +Development (hms_otd@harvard.edu) with cc to Dr. Chris Harvey. For details, see [license.txt](license.txt). ## Usage ### [To use the GUI, click](docs/using_gui.md) #### [To use the command line interface, click](docs/using_CLI.md) ## Dependencies -The major dependencies for DeepEthogram are as follows: +The major dependencies for DeepEthogram are as follows: * pytorch, torchvision: all the neural networks, training, and inference pipelines were written in PyTorch * pytorch-lightning: for nice model training base classes * kornia: for GPU-based image augmentations @@ -76,25 +76,48 @@ The major dependencies for DeepEthogram are as follows: * tqdm: for nice progress bars ## Hardware requirements -For GUI usage, we expect that the users will be working on a local workstation with a good NVIDIA graphics card. For training via a cluster, you can use the command line interface. +For GUI usage, we expect that the users will be working on a local workstation with a good NVIDIA graphics card. For training via a cluster, you can use the command line interface. * CPU: 4 cores or more for parallel data loading * Hard Drive: SSD at minimum, NVMe drive is better. -* GPU: DeepEthogram speed is directly related to GPU performance. An NVIDIA GPU is absolutely required, as PyTorch uses -CUDA, while AMD does not. +* GPU: DeepEthogram speed is directly related to GPU performance. An NVIDIA GPU is absolutely required, as PyTorch uses +CUDA, while AMD does not. The more VRAM you have, the more data you can fit in one batch, which generally increases performance. a I'd recommend 6GB VRAM at absolute minimum. 8GB is better, with 10+ GB preferred. -Recommended GPUs: `RTX 3090`, `RTX 3080`, `Titan RTX`, `2080 Ti`, `2080 super`, `2080`, `1080 Ti`, `2070 super`, `2070` -Some older ones might also be fine, like a `1080` or even `1070 Ti`/ `1070`. +Recommended GPUs: `RTX 3090`, `RTX 3080`, `Titan RTX`, `2080 Ti`, `2080 super`, `2080`, `1080 Ti`, `2070 super`, `2070` +Some older ones might also be fine, like a `1080` or even `1070 Ti`/ `1070`. ## testing -Test coverage is still low, but in the future we will be expanding our unit tests. +Test coverage is still low, but in the future we will be expanding our unit tests. First, download a copy of [`testing_deepethogram_archive.zip`](https://drive.google.com/file/d/1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo/view?usp=sharing) - Make a directory in tests called `DATA`. Unzip this and move it to the `deepethogram/tests/DATA` -directory, so that the path is `deepethogram/tests/DATA/testing_deepethogram_archive/{DATA,models,project_config.yaml}`. Then run `pytest tests/` to run. -the `zz_commandline` test module will take a few minutes, as it is an end-to-end test that performs model training -and inference. Its name reflects the fact that it should come last in testing. + Make a directory in tests called `DATA`. Unzip this and move it to the `deepethogram/tests/DATA` +directory, so that the path is `deepethogram/tests/DATA/testing_deepethogram_archive/{DATA,models,project_config.yaml}`. Then run `pytest tests/` to run. +the `zz_commandline` test module will take a few minutes, as it is an end-to-end test that performs model training +and inference. Its name reflects the fact that it should come last in testing. + +## Developer Guide +### Code Style and Pre-commit Hooks +We use pre-commit hooks to maintain code quality and consistency. The hooks include: +- Ruff for Python linting and formatting +- Various file checks (trailing whitespace, YAML validation, etc.) + +To set up the development environment: + +1. Install the development dependencies: +```bash +pip install -r requirements.txt +``` + +2. Install pre-commit hooks: +```bash +pre-commit install +``` + +The hooks will run automatically on every commit. You can also run them manually on all files: +```bash +pre-commit run --all-files +``` ## Changelog * 0.1.4: bugfixes for dependencies; added docker @@ -102,6 +125,6 @@ and inference. Its name reflects the fact that it should come last in testing. * 0.1.1.post1/2: batch prediction * 0.1.1.post0: flow generator metric bug fix * 0.1.1: bug fixes -* 0.1: deepethogram beta! See above for details. +* 0.1: deepethogram beta! See above for details. * 0.0.1.post1: bug fixes and video conversion scripts added * 0.0.1: initial version diff --git a/deepethogram/__init__.py b/deepethogram/__init__.py index d6c61fd..f21a6e6 100644 --- a/deepethogram/__init__.py +++ b/deepethogram/__init__.py @@ -7,9 +7,9 @@ # from deepethogram.sequence.inference import sequence_inference import importlib.util -spec = importlib.util.find_spec('hydra') +spec = importlib.util.find_spec("hydra") if spec is not None: - raise ValueError('Hydra installation found. Please run pip uninstall hydra-core: {}'.format(spec)) + raise ValueError("Hydra installation found. Please run pip uninstall hydra-core: {}".format(spec)) # try: # import hydra # except Exception as e: diff --git a/deepethogram/__main__.py b/deepethogram/__main__.py index 85e8553..d50518a 100644 --- a/deepethogram/__main__.py +++ b/deepethogram/__main__.py @@ -6,5 +6,5 @@ # def main(cfg: DictConfig) -> None: # run(cfg) -if __name__ == '__main__': - entry() \ No newline at end of file +if __name__ == "__main__": + entry() diff --git a/deepethogram/base.py b/deepethogram/base.py index 7943a5f..364a88b 100644 --- a/deepethogram/base.py +++ b/deepethogram/base.py @@ -8,11 +8,12 @@ import matplotlib.pyplot as plt from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl + try: - from ray.tune.integration.pytorch_lightning import TuneReportCallback, \ - TuneReportCheckpointCallback + from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback from ray.tune import get_trial_dir from ray.tune import CLIReporter + ray = True except ImportError: ray = False @@ -21,8 +22,13 @@ from torch.utils.data import DataLoader, WeightedRandomSampler from deepethogram.data.augs import get_gpu_transforms, get_empty_gpu_transforms -from deepethogram.callbacks import FPSCallback, MetricsCallback, \ - ExampleImagesCallback, CheckpointCallback, StopperCallback +from deepethogram.callbacks import ( + FPSCallback, + MetricsCallback, + ExampleImagesCallback, + CheckpointCallback, + StopperCallback, +) from deepethogram.metrics import Metrics, EmptyMetrics from deepethogram.schedulers import initialize_scheduler from deepethogram import viz, utils @@ -31,8 +37,7 @@ class BaseLightningModule(pl.LightningModule): - """Base class for all Lightning modules for training - """ + """Base class for all Lightning modules for training""" def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics: Metrics, visualization_func): """constructor @@ -69,10 +74,10 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics: M self.visualization_func = visualization_func model_type = cfg.run.model - if model_type in ['feature_extractor', 'flow_generator']: + if model_type in ["feature_extractor", "flow_generator"]: arch = self.hparams[model_type].arch - gpu_transforms = get_gpu_transforms(self.hparams.augs, '3d' if '3d' in arch.lower() else '2d') - elif model_type == 'sequence': + gpu_transforms = get_gpu_transforms(self.hparams.augs, "3d" if "3d" in arch.lower() else "2d") + elif model_type == "sequence": gpu_transforms = get_empty_gpu_transforms() else: raise NotImplementedError @@ -82,21 +87,21 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics: M self.optimizer = None # will be overridden in configure_optimizers self.hparams.weight_decay = None - if self.metrics.key_metric == 'loss' or self.metrics.key_metric == 'SSIM': - self.scheduler_mode = 'min' + if self.metrics.key_metric == "loss" or self.metrics.key_metric == "SSIM": + self.scheduler_mode = "min" else: # accuracy, F1, etc. - self.scheduler_mode = 'max' + self.scheduler_mode = "max" # need to move this to top-level for lightning's learning rate finder # don't set it to auto here, so that we can automatically find batch size first - self.lr = self.hparams.train.lr if self.hparams.train.lr != 'auto' else 1e-4 - log.info('scheduler mode: {}'.format(self.scheduler_mode)) + self.lr = self.hparams.train.lr if self.hparams.train.lr != "auto" else 1e-4 + log.info("scheduler mode: {}".format(self.scheduler_mode)) # self.is_key_metric_loss = self.metrics.key_metric == 'loss' self.viz_cnt = defaultdict(int) # for hyperparameter tuning, log specific hyperparameters and metrics for tensorboard - if 'tune' in cfg.keys(): + if "tune" in cfg.keys(): # print('KEYS KEYS KEYS') tune_keys = list(cfg.tune.hparams.keys()) # this function goes takes a list like [`feature_extractor.dropout_p`, `train.loss_weight_exp`], and finds @@ -107,7 +112,7 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics: M self.tune_hparams = {} self.tune_metrics = [] - self.samplers = {'train': self.get_train_sampler(), 'val': self.get_val_sampler(), 'test': None} + self.samplers = {"train": self.get_train_sampler(), "val": self.get_val_sampler(), "test": None} def on_train_epoch_start(self, *args, **kwargs): # I couldn't figure out how to make sure that this is called after BOTH train and validation ends @@ -121,31 +126,34 @@ def on_test_epoch_end(self): def get_dataloader(self, split: str): # for use with auto-batch-sizing. Lightning doesn't expect batch size to be nested, it expects it to be # top-level in self.hparams - batch_size = self.hparams.compute.batch_size if self.hparams.compute.batch_size != 'auto' else \ - self.hparams.batch_size - - shuffles = {'train': self.samplers['train'] is None, 'val': self.samplers['val'] is None, 'test': False} - - dataloader = DataLoader(self.datasets[split], - batch_size=batch_size, - shuffle=shuffles[split], - num_workers=self.hparams.compute.num_workers, - pin_memory=torch.cuda.is_available(), - drop_last=False, - sampler=self.samplers[split]) + batch_size = ( + self.hparams.compute.batch_size if self.hparams.compute.batch_size != "auto" else self.hparams.batch_size + ) + + shuffles = {"train": self.samplers["train"] is None, "val": self.samplers["val"] is None, "test": False} + + dataloader = DataLoader( + self.datasets[split], + batch_size=batch_size, + shuffle=shuffles[split], + num_workers=self.hparams.compute.num_workers, + pin_memory=torch.cuda.is_available(), + drop_last=False, + sampler=self.samplers[split], + ) return dataloader def train_dataloader(self): - return self.get_dataloader('train') + return self.get_dataloader("train") def val_dataloader(self): - return self.get_dataloader('val') + return self.get_dataloader("val") def test_dataloader(self): - if 'test' in self.datasets.keys() and self.datasets['test'] is not None: - return self.get_dataloader('test') + if "test" in self.datasets.keys() and self.datasets["test"] is not None: + return self.get_dataloader("test") else: - raise ValueError('no test set!') + raise ValueError("no test set!") def training_step(self, batch: dict, batch_idx: int): raise NotImplementedError @@ -160,14 +168,13 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError def get_train_sampler(self): - """gets a WeightedRandomSampler for over-sampling rare classes. Not rigorously evaluated - """ - dataset = self.datasets['train'] - if not hasattr(dataset, 'labels') or dataset.labels is None: + """gets a WeightedRandomSampler for over-sampling rare classes. Not rigorously evaluated""" + dataset = self.datasets["train"] + if not hasattr(dataset, "labels") or dataset.labels is None: # self-supervised, e.g. flow generators return if self.hparams.train.oversampling_exp < 1e-4: - log.info('not using oversampling') + log.info("not using oversampling") return # total positive examples of each class in our training set class_counts = dataset.labels.sum(axis=0) @@ -182,15 +189,15 @@ def get_train_sampler(self): sample_weights = dataset.labels @ sampling_ratio replacement = self.hparams.train.oversampling_exp > 1e-4 - log.info('oversampling exp: {}'.format(self.hparams.train.oversampling_exp)) - log.info('oversampling ratio: {}'.format(sampling_ratio)) + log.info("oversampling exp: {}".format(self.hparams.train.oversampling_exp)) + log.info("oversampling ratio: {}".format(sampling_ratio)) sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=replacement) return sampler def get_val_sampler(self): # get sample weights for validation dataset to up-sample rare classes - dataset = self.datasets['val'] + dataset = self.datasets["val"] # if dataset.labels is None: # # self-supervised, e.g. flow generators # return @@ -216,28 +223,26 @@ def apply_gpu_transforms(self, images: torch.Tensor, mode: str) -> torch.Tensor: return images def configure_optimizers(self): - weight_decay = 0 # if self.hparams.weight_decay is None else self.hparams.weight_decay - optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), - lr=self.lr, - weight_decay=weight_decay) + optimizer = optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, weight_decay=weight_decay + ) self.optimizer = optimizer - log.info('learning rate: {}'.format(self.lr)) - scheduler = initialize_scheduler(optimizer, - self.hparams, - mode=self.scheduler_mode, - reduction_factor=self.hparams.train.reduction_factor) - monitor_key = 'val/' + self.metrics.key_metric - return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': monitor_key} + log.info("learning rate: {}".format(self.lr)) + scheduler = initialize_scheduler( + optimizer, self.hparams, mode=self.scheduler_mode, reduction_factor=self.hparams.train.reduction_factor + ) + monitor_key = "val/" + self.metrics.key_metric + return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": monitor_key} # @profile default_tune_dict = { - 'loss': 'val_loss', - 'f1_micro': 'val_f1_class_mean', - 'data_loss': 'val_data_loss', - 'reg_loss': 'val_reg_loss' + "loss": "val_loss", + "f1_micro": "val_f1_class_mean", + "data_loss": "val_data_loss", + "reg_loss": "val_reg_loss", } @@ -267,25 +272,27 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html """ steps_per_epoch = cfg.train.steps_per_epoch - for split in ['train', 'val', 'test']: + for split in ["train", "val", "test"]: steps_per_epoch[split] = steps_per_epoch[split] if steps_per_epoch[split] is not None else 1.0 # reload_dataloaders_every_epoch = True: a bit slower, but enables validation dataloader to get the new, automatic # learning rate schedule. - if cfg.compute.batch_size == 'auto' or cfg.train.lr == 'auto': - trainer = pl.Trainer(gpus=[cfg.compute.gpu_id], - precision=16 if cfg.compute.fp16 else 32, - limit_train_batches=1.0, - limit_val_batches=1.0, - limit_test_batches=1.0, - num_sanity_val_steps=0) + if cfg.compute.batch_size == "auto" or cfg.train.lr == "auto": + trainer = pl.Trainer( + gpus=[cfg.compute.gpu_id], + precision=16 if cfg.compute.fp16 else 32, + limit_train_batches=1.0, + limit_val_batches=1.0, + limit_test_batches=1.0, + num_sanity_val_steps=0, + ) # callbacks=[ExampleImagesCallback()]) tmp_metrics = lightning_module.metrics tmp_workers = lightning_module.hparams.compute.num_workers # visualize_examples = lightning_module.visualize_examples - if lightning_module.model_type != 'sequence': + if lightning_module.model_type != "sequence": # there is a somewhat common error that VRAM will be maximized by the gpu-auto-tuner. # However, during training, we probabilistically sample colorspace transforms; in an "unlucky" # batch, perhaps all of the training samples are converted to HSV, hue and saturation changed, then changed @@ -293,17 +300,17 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s # so, we crank up the colorspace augmentation probability, then pick batch size, then change it back original_gpu_transforms = deepcopy(lightning_module.gpu_transforms) - log.debug('orig: {}'.format(lightning_module.gpu_transforms)) + log.debug("orig: {}".format(lightning_module.gpu_transforms)) original_augs = cfg.augs new_augs = deepcopy(cfg.augs) new_augs.color_p = 1.0 arch = lightning_module.hparams[lightning_module.model_type].arch - mode = '2d' - gpu_transforms = get_gpu_transforms(new_augs, '3d' if '3d' in arch.lower() else '2d') + mode = "2d" + gpu_transforms = get_gpu_transforms(new_augs, "3d" if "3d" in arch.lower() else "2d") lightning_module.gpu_transforms = gpu_transforms - log.debug('new: {}'.format(lightning_module.gpu_transforms)) + log.debug("new: {}".format(lightning_module.gpu_transforms)) tuner = pl.tuner.tuning.Tuner(trainer) # hack for lightning to find the batch size @@ -318,25 +325,27 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s lightning_module.hparams.train.viz_examples = 0 # dramatically reduces RAM usage by this process lightning_module.hparams.compute.num_workers = min(tmp_workers, 1) - if cfg.compute.batch_size == 'auto': + if cfg.compute.batch_size == "auto": max_trials = int(math.log2(cfg.compute.max_batch_size)) - int(math.log2(cfg.compute.min_batch_size)) - log.info('max trials: {}'.format(max_trials)) - new_batch_size = trainer.tuner.scale_batch_size(lightning_module, - mode='power', - steps_per_trial=30, - init_val=cfg.compute.min_batch_size, - max_trials=max_trials) + log.info("max trials: {}".format(max_trials)) + new_batch_size = trainer.tuner.scale_batch_size( + lightning_module, + mode="power", + steps_per_trial=30, + init_val=cfg.compute.min_batch_size, + max_trials=max_trials, + ) cfg.compute.batch_size = new_batch_size - log.info('auto-tuned batch size: {}'.format(new_batch_size)) - if cfg.train.lr == 'auto': + log.info("auto-tuned batch size: {}".format(new_batch_size)) + if cfg.train.lr == "auto": lr_finder = trainer.tuner.lr_find(lightning_module, early_stop_threshold=None, min_lr=1e-6, max_lr=10.0) # log.info(lr_finder.results) - plt.style.use('seaborn') + plt.style.use("seaborn") fig = lr_finder.plot(suggest=True, show=False) - viz.save_figure(fig, 'auto_lr_finder', False, 0, overwrite=False) + viz.save_figure(fig, "auto_lr_finder", False, 0, overwrite=False) plt.close(fig) new_lr = lr_finder.suggestion() - log.info('auto-tuned learning rate: {}'.format(new_lr)) + log.info("auto-tuned learning rate: {}".format(new_lr)) cfg.train.lr = new_lr lightning_module.lr = new_lr lightning_module.hparams.lr = new_lr @@ -345,34 +354,30 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s lightning_module.hparams.train.viz_examples = should_viz lightning_module.metrics = tmp_metrics lightning_module.hparams.compute.num_workers = tmp_workers - if lightning_module.model_type != 'sequence': + if lightning_module.model_type != "sequence": lightning_module.gpu_transforms = original_gpu_transforms - log.debug('reverted: {}'.format(lightning_module.gpu_transforms)) + log.debug("reverted: {}".format(lightning_module.gpu_transforms)) key_metric = lightning_module.metrics.key_metric - mode = 'min' if 'loss' in key_metric else 'max' - monitor = f'val/{key_metric}' - dirpath = os.path.join(cfg.run.dir, 'lightning_checkpoints') + mode = "min" if "loss" in key_metric else "max" + monitor = f"val/{key_metric}" + dirpath = os.path.join(cfg.run.dir, "lightning_checkpoints") callback_list = [ FPSCallback(), MetricsCallback(), ExampleImagesCallback(), CheckpointCallback(), StopperCallback(stopper), - pl.callbacks.ModelCheckpoint(dirpath=dirpath, - save_top_k=1, - save_last=True, - mode=mode, - monitor=monitor, - save_weights_only=True) + pl.callbacks.ModelCheckpoint( + dirpath=dirpath, save_top_k=1, save_last=True, mode=mode, monitor=monitor, save_weights_only=True + ), ] - if 'tune' in cfg and cfg.tune.use and ray: - callback_list.append(TuneReportCallback(OmegaConf.to_container(cfg.tune.metrics), on='validation_end')) + if "tune" in cfg and cfg.tune.use and ray: + callback_list.append(TuneReportCallback(OmegaConf.to_container(cfg.tune.metrics), on="validation_end")) # https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html - tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger(save_dir=get_trial_dir(), - name="", - version=".", - default_hp_metric=False) + tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger( + save_dir=get_trial_dir(), name="", version=".", default_hp_metric=False + ) refresh_rate = 0 else: tensorboard_logger = pl.loggers.tensorboard.TensorBoardLogger(os.getcwd()) @@ -382,34 +387,38 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s try: # will be deprecated in the future; pytorch lightning updated their kwargs for this function # don't like how they keep updating the api without proper deprecation warnings, etc. - trainer = pl.Trainer(gpus=[cfg.compute.gpu_id], - precision=16 if cfg.compute.fp16 else 32, - limit_train_batches=steps_per_epoch['train'], - limit_val_batches=steps_per_epoch['val'], - limit_test_batches=steps_per_epoch['test'], - logger=tensorboard_logger, - max_epochs=cfg.train.num_epochs, - num_sanity_val_steps=0, - callbacks=callback_list, - reload_dataloaders_every_epoch=True, - progress_bar_refresh_rate=refresh_rate, - profiler=profiler, - log_every_n_steps=1) + trainer = pl.Trainer( + gpus=[cfg.compute.gpu_id], + precision=16 if cfg.compute.fp16 else 32, + limit_train_batches=steps_per_epoch["train"], + limit_val_batches=steps_per_epoch["val"], + limit_test_batches=steps_per_epoch["test"], + logger=tensorboard_logger, + max_epochs=cfg.train.num_epochs, + num_sanity_val_steps=0, + callbacks=callback_list, + reload_dataloaders_every_epoch=True, + progress_bar_refresh_rate=refresh_rate, + profiler=profiler, + log_every_n_steps=1, + ) except TypeError: - trainer = pl.Trainer(gpus=[cfg.compute.gpu_id], - precision=16 if cfg.compute.fp16 else 32, - limit_train_batches=steps_per_epoch['train'], - limit_val_batches=steps_per_epoch['val'], - limit_test_batches=steps_per_epoch['test'], - logger=tensorboard_logger, - max_epochs=cfg.train.num_epochs, - num_sanity_val_steps=0, - callbacks=callback_list, - reload_dataloaders_every_n_epochs=1, - progress_bar_refresh_rate=refresh_rate, - profiler=profiler, - log_every_n_steps=1) + trainer = pl.Trainer( + gpus=[cfg.compute.gpu_id], + precision=16 if cfg.compute.fp16 else 32, + limit_train_batches=steps_per_epoch["train"], + limit_val_batches=steps_per_epoch["val"], + limit_test_batches=steps_per_epoch["test"], + logger=tensorboard_logger, + max_epochs=cfg.train.num_epochs, + num_sanity_val_steps=0, + callbacks=callback_list, + reload_dataloaders_every_n_epochs=1, + progress_bar_refresh_rate=refresh_rate, + profiler=profiler, + log_every_n_steps=1, + ) torch.cuda.empty_cache() # gc.collect() diff --git a/deepethogram/callbacks.py b/deepethogram/callbacks.py index 9c2a41c..1c0ff78 100644 --- a/deepethogram/callbacks.py +++ b/deepethogram/callbacks.py @@ -12,69 +12,67 @@ class DebugCallback(Callback): - def __init__(self): super().__init__() - log.info('callback initialized') + log.info("callback initialized") def on_init_end(self, trainer): - log.info('on init start') + log.info("on init start") def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - log.debug('on train batch start') + log.debug("on train batch start") def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - log.debug('on train batch end') + log.debug("on train batch end") def on_train_epoch_start(self, trainer, pl_module): - log.info('on train epoch start') + log.info("on train epoch start") def on_train_epoch_end(self, *args, **kwargs): - log.info('on train epoch end') + log.info("on train epoch end") def on_validation_epoch_start(self, trainer, pl_module): - log.info('on validation epoch start') + log.info("on validation epoch start") def on_validation_epoch_end(self, trainer, pl_module): - log.info('on validation epoch end') + log.info("on validation epoch end") def on_test_epoch_start(self, trainer, pl_module): - log.info('on test epoch start') + log.info("on test epoch start") def on_test_epoch_end(self, trainer, pl_module): - log.info('on test epoch end') + log.info("on test epoch end") def on_epoch_start(self, trainer, pl_module): - log.info('on epoch start') + log.info("on epoch start") def on_epoch_end(self, trainer, pl_module): - log.info('on epoch end') + log.info("on epoch end") def on_train_start(self, trainer, pl_module): - log.info('on train start') + log.info("on train start") def on_train_end(self, trainer, pl_module): - log.info('on train end') + log.info("on train end") def on_validation_start(self, trainer, pl_module): - log.info('on validation start') + log.info("on validation start") def on_validation_end(self, trainer, pl_module): - log.info('on validation end') + log.info("on validation end") def on_keyboard_interrupt(self, trainer, pl_module): - log.info('on keyboard interrupt') + log.info("on keyboard interrupt") class FPSCallback(Callback): - """Measures frames per second in training and inference - """ + """Measures frames per second in training and inference""" def __init__(self): super().__init__() - self.times = {'train': 0.0, 'val': 0.0, 'test': 0.0, 'speedtest': 0.0} - self.n_images = {'train': 0, 'val': 0, 'test': 0, 'speedtest': 0} - self.fps = {'train': 0.0, 'val': 0.0, 'test': 0.0, 'speedtest': 0.0} + self.times = {"train": 0.0, "val": 0.0, "test": 0.0, "speedtest": 0.0} + self.n_images = {"train": 0, "val": 0, "test": 0, "speedtest": 0} + self.fps = {"train": 0.0, "val": 0.0, "test": 0.0, "speedtest": 0.0} def start_timer(self, split): self.times[split] = time.time() @@ -92,25 +90,25 @@ def end_batch(self, split, batch, pl_module, eps: float = 1e-7): n_images = self.get_num_images(batch) fps = n_images / elapsed - pl_module.metrics.buffer.append(split, {'fps': fps}) + pl_module.metrics.buffer.append(split, {"fps": fps}) def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.start_timer('train') + self.start_timer("train") def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.end_batch('train', batch, pl_module) + self.end_batch("train", batch, pl_module) def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.start_timer('val') + self.start_timer("val") def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.end_batch('val', batch, pl_module) + self.end_batch("val", batch, pl_module) def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.start_timer('speedtest') + self.start_timer("speedtest") def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.end_batch('speedtest', batch, pl_module) + self.end_batch("speedtest", batch, pl_module) # class SpeedtestCallback(Callback): @@ -120,7 +118,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal # def on_validation_end(self, trainer, pl_module): # trainer.test(pl_module) def log_metrics(pl_module, split): - assert split in ['train', 'val', 'test'] + assert split in ["train", "val", "test"] metrics, _ = pl_module.metrics.end_epoch(split) scalar_metrics = {} for key, value in metrics.items(): @@ -130,29 +128,28 @@ def log_metrics(pl_module, split): value = value.squeeze()[0] if np.isscalar(value): # print('{}/{}: {:.2f}'.format(split, key, value)) - pl_module.log(split + '/' + key, value, on_epoch=True) - scalar_metrics[split + '/' + key] = value + pl_module.log(split + "/" + key, value, on_epoch=True) + scalar_metrics[split + "/" + key] = value return scalar_metrics class MetricsCallback(Callback): - """Uses the lightning module to log metrics and hyperparameters, e.g. for tensorboard - """ + """Uses the lightning module to log metrics and hyperparameters, e.g. for tensorboard""" def __init__(self): super().__init__() def on_train_epoch_end(self, trainer, pl_module): - pl_module.metrics.buffer.append('train', {'lr': utils.get_minimum_learning_rate(pl_module.optimizer)}) - _ = log_metrics(pl_module, 'train') + pl_module.metrics.buffer.append("train", {"lr": utils.get_minimum_learning_rate(pl_module.optimizer)}) + _ = log_metrics(pl_module, "train") # latest_key = pl_module.metrics.latest_key['train'] # key = 'train_{}'.format(pl_module.metrics.key_metric) # pl_module.log(key, latest_key, on_epoch=True) def on_validation_epoch_end(self, trainer, pl_module): - scalar_metrics = log_metrics(pl_module, 'val') - latest_key = pl_module.metrics.latest_key['val'] + scalar_metrics = log_metrics(pl_module, "val") + latest_key = pl_module.metrics.latest_key["val"] # this logic is to correctly log only important hyperparameters and important metrics to tensorboard's # hyperparameter view. Just using all the parameters in our configuration makes for a huge and ugly tensorboard @@ -163,10 +160,11 @@ def on_validation_epoch_end(self, trainer, pl_module): for key in pl_module.tune_metrics: # have to have a different key, otherwise pytorch lightning will log it twice if key in scalar_metrics.keys(): - hparam_metrics['hp/' + key] = scalar_metrics[key] + hparam_metrics["hp/" + key] = scalar_metrics[key] else: - log.warning('requested hparam metric {} not found in metrics: {}'.format( - key, list(scalar_metrics.keys()))) + log.warning( + "requested hparam metric {} not found in metrics: {}".format(key, list(scalar_metrics.keys())) + ) print(pl_module.tune_hparams, hparam_metrics) pl_module.logger.log_hyperparams(pl_module.tune_hparams, hparam_metrics) @@ -174,7 +172,7 @@ def on_validation_epoch_end(self, trainer, pl_module): # pl_module.log('hp_metric', latest_key, on_epoch=True) def on_test_epoch_end(self, trainer, pl_module): - log_metrics(pl_module, 'test') + log_metrics(pl_module, "test") # pl_module.metrics.end_epoch('speedtest') def on_keyboard_interrupt(self, trainer, pl_module): @@ -182,7 +180,6 @@ def on_keyboard_interrupt(self, trainer, pl_module): class ExampleImagesCallback(Callback): - def __init__(self): super().__init__() @@ -196,26 +193,25 @@ def reset_cnt(self, pl_module, split): pl_module.viz_cnt[split] = 0 def on_train_epoch_end(self, trainer, pl_module): - self.reset_cnt(pl_module, 'train') + self.reset_cnt(pl_module, "train") def on_validation_epoch_end(self, trainer, pl_module): - self.reset_cnt(pl_module, 'val') + self.reset_cnt(pl_module, "val") def on_test_epoch_end(self, trainer, pl_module): - self.reset_cnt(pl_module, 'test') + self.reset_cnt(pl_module, "test") def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - pl_module.viz_cnt['train'] += 1 + pl_module.viz_cnt["train"] += 1 def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - pl_module.viz_cnt['val'] += 1 + pl_module.viz_cnt["val"] += 1 def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - pl_module.viz_cnt['test'] += 1 + pl_module.viz_cnt["test"] += 1 class CheckpointCallback(Callback): - def __init__(self): super().__init__() @@ -230,7 +226,6 @@ def on_keyboard_interrupt(self, trainer, pl_module): class StopperCallback(Callback): - def __init__(self, stopper): super().__init__() self.stopper = stopper @@ -241,19 +236,19 @@ def on_train_epoch_end(self, trainer, pl_module): if pl_module.current_epoch == 0: return - if self.stopper.name == 'early': - _, should_stop = self.stopper(pl_module.metrics.latest_key['val']) - elif self.stopper.name == 'learning_rate': - min_lr = pl_module.metrics[('train', 'lr', -1)] + if self.stopper.name == "early": + _, should_stop = self.stopper(pl_module.metrics.latest_key["val"]) + elif self.stopper.name == "learning_rate": + min_lr = pl_module.metrics[("train", "lr", -1)] # log.info('LR: {}'.format(min_lr)) should_stop = self.stopper(min_lr) - elif self.stopper.name == 'num_epochs': + elif self.stopper.name == "num_epochs": should_stop = self.stopper.step() else: - raise ValueError('invalid stopping name: {}'.format(self.stopper.name)) + raise ValueError("invalid stopping name: {}".format(self.stopper.name)) if should_stop: # log.info('Stopping criterion reached! Raising KeyboardInterrupt to quit') - log.info('Stopping criterion reached! setting trainer.should_stop=True') + log.info("Stopping criterion reached! setting trainer.should_stop=True") trainer.should_stop = True # raise KeyboardInterrupt diff --git a/deepethogram/conf/augs.yaml b/deepethogram/conf/augs.yaml index 9e18928..eb39b67 100644 --- a/deepethogram/conf/augs.yaml +++ b/deepethogram/conf/augs.yaml @@ -48,4 +48,4 @@ augs: std: - 0.5 - 0.5 - - 0.5 \ No newline at end of file + - 0.5 diff --git a/deepethogram/conf/config.yaml b/deepethogram/conf/config.yaml index dc8eb7b..b491c5d 100644 --- a/deepethogram/conf/config.yaml +++ b/deepethogram/conf/config.yaml @@ -46,6 +46,6 @@ log: # project: project_config # hyra configuration: specifies how to create the run directory # hydra: -# run: +# run: # dir: ${project.path}/${project.model_path}/${now:%y%m%d_%H%M%S}_${run.model}_${run.type}_${notes} -# output_subdir: '' \ No newline at end of file +# output_subdir: '' diff --git a/deepethogram/conf/debug.yaml b/deepethogram/conf/debug.yaml index 8df36b4..6a89c2a 100644 --- a/deepethogram/conf/debug.yaml +++ b/deepethogram/conf/debug.yaml @@ -6,4 +6,4 @@ train: num_epochs: 3 tune: num_trials: 3 -debug: True \ No newline at end of file +debug: True diff --git a/deepethogram/conf/gui.yaml b/deepethogram/conf/gui.yaml index 214d711..e68a1fe 100644 --- a/deepethogram/conf/gui.yaml +++ b/deepethogram/conf/gui.yaml @@ -18,4 +18,4 @@ unlabeled_alpha: 0.1 # will have the below alpha value (predictions are always 1) prediction_opacity: 0.2 # notes to add to gui logs -notes: null \ No newline at end of file +notes: null diff --git a/deepethogram/conf/inference.yaml b/deepethogram/conf/inference.yaml index c791085..cfb6f79 100644 --- a/deepethogram/conf/inference.yaml +++ b/deepethogram/conf/inference.yaml @@ -8,6 +8,6 @@ inference: # if the group sequence.latent_name already exists in the HDF5 file, will overwrite it. # if false, that file would be skipped overwrite: false - # if True, overwrite settings in the config file with the one loaded from disk. E.g. if you try to configure + # if True, overwrite settings in the config file with the one loaded from disk. E.g. if you try to configure # inference dropout_p to be 0.9 but in the trained model it was 0.5, it will set dropout to 0.5 - use_loaded_model_cfg: true \ No newline at end of file + use_loaded_model_cfg: true diff --git a/deepethogram/conf/model/feature_extractor.yaml b/deepethogram/conf/model/feature_extractor.yaml index 72308e2..fa56860 100644 --- a/deepethogram/conf/model/feature_extractor.yaml +++ b/deepethogram/conf/model/feature_extractor.yaml @@ -37,4 +37,4 @@ train: train: 1000 val: 1000 test: null - num_epochs: 20 \ No newline at end of file + num_epochs: 20 diff --git a/deepethogram/conf/model/flow_generator.yaml b/deepethogram/conf/model/flow_generator.yaml index e0af0e1..a6fbef9 100644 --- a/deepethogram/conf/model/flow_generator.yaml +++ b/deepethogram/conf/model/flow_generator.yaml @@ -19,7 +19,7 @@ flow_generator: # path to a checkpoint.pt weight file for reloading. weights: pretrained train: - # overwrite default steps per epoch: because we don't care about rare classes for optic flow, don't need so much + # overwrite default steps per epoch: because we don't care about rare classes for optic flow, don't need so much # validation steps_per_epoch: train: 1000 diff --git a/deepethogram/conf/model/sequence.yaml b/deepethogram/conf/model/sequence.yaml index 3b380b5..bfb541e 100644 --- a/deepethogram/conf/model/sequence.yaml +++ b/deepethogram/conf/model/sequence.yaml @@ -63,7 +63,7 @@ train: # overwrite patience: because of Nonoverlapping, train epochs can be very low patience: 5 # overwrite num epochs. due to nonoverlapping, one epoch takes only a minute or two - num_epochs: 100 + num_epochs: 100 compute: min_batch_size: 2 - max_batch_size: 64 # sequence can get weird when batch sizes are too high \ No newline at end of file + max_batch_size: 64 # sequence can get weird when batch sizes are too high diff --git a/deepethogram/conf/postprocessor.yaml b/deepethogram/conf/postprocessor.yaml index 7de3948..10922dd 100644 --- a/deepethogram/conf/postprocessor.yaml +++ b/deepethogram/conf/postprocessor.yaml @@ -8,4 +8,4 @@ postprocessor: # if type is min_bout_per_behavior, this will be the PERCENTILE of each behavior's bout length distribution # if value is 5, then all bouts less than the 5th percentile of the label distribution will be removed # see deepethogram/postprocessing.py for details - min_bout_length: 1 \ No newline at end of file + min_bout_length: 1 diff --git a/deepethogram/conf/preset/deg_f.yaml b/deepethogram/conf/preset/deg_f.yaml index ce781ef..a7ea8bb 100644 --- a/deepethogram/conf/preset/deg_f.yaml +++ b/deepethogram/conf/preset/deg_f.yaml @@ -7,4 +7,4 @@ feature_extractor: n_rgb: 1 flow_generator: arch: TinyMotionNet - n: 10 \ No newline at end of file + n: 10 diff --git a/deepethogram/conf/preset/deg_m.yaml b/deepethogram/conf/preset/deg_m.yaml index 53c5231..a9cf08f 100644 --- a/deepethogram/conf/preset/deg_m.yaml +++ b/deepethogram/conf/preset/deg_m.yaml @@ -7,4 +7,4 @@ feature_extractor: n_rgb: 1 flow_generator: arch: MotionNet - n: 10 \ No newline at end of file + n: 10 diff --git a/deepethogram/conf/preset/deg_s.yaml b/deepethogram/conf/preset/deg_s.yaml index a5aa525..65e4a7d 100644 --- a/deepethogram/conf/preset/deg_s.yaml +++ b/deepethogram/conf/preset/deg_s.yaml @@ -10,4 +10,4 @@ flow_generator: n: 10 flow_sparsity: true sparsity_weight : 0.05 - smooth_weight_multiplier: 0.25 \ No newline at end of file + smooth_weight_multiplier: 0.25 diff --git a/deepethogram/conf/train.yaml b/deepethogram/conf/train.yaml index 837db8a..744d217 100644 --- a/deepethogram/conf/train.yaml +++ b/deepethogram/conf/train.yaml @@ -67,4 +67,3 @@ train: style: l2_sp alpha: 1e-5 beta: 1e-3 - diff --git a/deepethogram/conf/tune/feature_extractor.yaml b/deepethogram/conf/tune/feature_extractor.yaml index fcc4425..b4ab626 100644 --- a/deepethogram/conf/tune/feature_extractor.yaml +++ b/deepethogram/conf/tune/feature_extractor.yaml @@ -7,13 +7,13 @@ tune: # space: uniform # space: how to sample # short: dropout # a shortened version to view in Ray's command line interface # current_best: 0.25 # current best estimate. a moving target. used for initializing search space with hyperopt - # train.regularization.alpha: + # train.regularization.alpha: # min: 1e-7 # max: 1e-1 # space: log # short: reg_alpha # current_best: 1e-5 - # train.regularization.beta: + # train.regularization.beta: # min: 1e-4 # max: 1e-1 # space: log diff --git a/deepethogram/conf/tune/sequence.yaml b/deepethogram/conf/tune/sequence.yaml index d9fee97..547d6b6 100644 --- a/deepethogram/conf/tune/sequence.yaml +++ b/deepethogram/conf/tune/sequence.yaml @@ -7,7 +7,7 @@ tune: # space: uniform # space: how to sample # short: dropout # a shortened version to view in Ray's command line interface # current_best: 0.25 # current best estimate. a moving target. used for initializing search space with hyperopt - train.regularization.alpha: + train.regularization.alpha: min: 1e-2 max: 2 space: log @@ -49,4 +49,4 @@ tune: short: nonlinear_classification # use these to overwrite default configuration parameters only when running tune jobs compute: - max_batch_size: 32 \ No newline at end of file + max_batch_size: 32 diff --git a/deepethogram/conf/tune/tune.yaml b/deepethogram/conf/tune/tune.yaml index a18b572..b040c2c 100644 --- a/deepethogram/conf/tune/tune.yaml +++ b/deepethogram/conf/tune/tune.yaml @@ -14,11 +14,11 @@ tune: gpu: 0.5 cpu: 3 train: - viz_examples: 0 # don't spend time and space making example images + viz_examples: 0 # don't spend time and space making example images steps_per_epoch: train: 1000 val: 1000 num_epochs: 20 compute: metrics_workers: 0 # sometimes has a bug in already multiprocessed jobs - # batch_size: 64 # auto batch sizing takes a long time \ No newline at end of file + # batch_size: 64 # auto batch sizing takes a long time diff --git a/deepethogram/configuration.py b/deepethogram/configuration.py index 0f3a873..4748051 100644 --- a/deepethogram/configuration.py +++ b/deepethogram/configuration.py @@ -15,21 +15,21 @@ def config_string_to_path(config_path: Union[str, os.PathLike], string: str) -> config_path : Union[str, os.PathLike] absolute path to deepethogram/deepethogram/conf directory string : str - name of configuration. + name of configuration. Returns ------- str Absolute path to configuration default file - + Examples -------- >>> config_string_to_path('path/to/deepethogram/deepethogram/conf', 'tune/feature_extractor') 'path/to/deepethogram/deepethogram/conf/tune/feature_extractor.yaml' - + """ - fullpath = os.path.join(config_path, *string.split('/')) + '.yaml' - assert os.path.isfile(fullpath), f'{fullpath} not found' + fullpath = os.path.join(config_path, *string.split("/")) + ".yaml" + assert os.path.isfile(fullpath), f"{fullpath} not found" return fullpath @@ -47,7 +47,7 @@ def load_config_by_name(string: str, config_path: Union[str, os.PathLike] = None ------- DictConfig Configuration loaded from YAML file - + Examples -------- >>> load_config_by_name('model/feature_extractor') @@ -73,50 +73,52 @@ def load_config_by_name(string: str, config_path: Union[str, os.PathLike] = None """ if config_path is None: - config_path = os.path.join(os.path.dirname(deepethogram.__file__), 'conf') + config_path = os.path.join(os.path.dirname(deepethogram.__file__), "conf") path = config_string_to_path(config_path, string) return OmegaConf.load(path) -def make_config(project_path: Union[str, os.PathLike], - config_list: list, - run_type: str, - model: str, - use_command_line: bool = False, - preset: str = None, - debug: bool = False) -> DictConfig: - """Makes a configuration for model training or inference. - +def make_config( + project_path: Union[str, os.PathLike], + config_list: list, + run_type: str, + model: str, + use_command_line: bool = False, + preset: str = None, + debug: bool = False, +) -> DictConfig: + """Makes a configuration for model training or inference. + A list of default configurations are composed into one single cfg. From the project path, the project configuration - is found and loaded. If a preset is specified either in the config_list or in the project config, load "preset" - parameters. + is found and loaded. If a preset is specified either in the config_list or in the project config, load "preset" + parameters. - Order of composition: + Order of composition: 1. Defaults 2. Preset 3. Project configuration 4. Command line - - This means if you specify the value of a parameter (say, dropout probability) in multiple places, the last one + + This means if you specify the value of a parameter (say, dropout probability) in multiple places, the last one (highest number in above list) will be chosen. This means we can specify a default dropout (0.25); for your project, - you can specify a new default in your project_config (e.g. 0.5). For an experiment, you can use the commmand line - to set `feature_extractor.dropout_p=0.75`. If its in all 3 places, the command line "wins" and the actual dropout is - 0.75. + you can specify a new default in your project_config (e.g. 0.5). For an experiment, you can use the commmand line + to set `feature_extractor.dropout_p=0.75`. If its in all 3 places, the command line "wins" and the actual dropout is + 0.75. Parameters ---------- project_path : Union[str, os.PathLike] Path to deepethogram project. Should contain: project_config.yaml, models directory, DATA directory config_list : list - List of string names of default configurations. Each of them is the name of a file or sub-file in the - deepethogram/conf directory. + List of string names of default configurations. Each of them is the name of a file or sub-file in the + deepethogram/conf directory. run_type : str Train, inference, or gui model : str feature_extractor, flow_generator, or sequence use_command_line : bool, optional - If True, command line arguments are parsed and composed into the + If True, command line arguments are parsed and composed into the preset : str, optional One of deg_f, deg_m, deg_s, by default None debug : bool, optional @@ -136,21 +138,21 @@ def make_config(project_path: Union[str, os.PathLike], # then, append the user config # then, the command line args # so if we specify a preset and manually change, say, the feature extractor architecture, we can do that - if 'preset' in user_cfg: - config_list.append('preset/' + user_cfg.preset) + if "preset" in user_cfg: + config_list.append("preset/" + user_cfg.preset) if use_command_line: command_line_cfg = OmegaConf.from_cli() - if 'preset' in command_line_cfg: - config_list.append('preset/' + command_line_cfg.preset) + if "preset" in command_line_cfg: + config_list.append("preset/" + command_line_cfg.preset) # add this option so we can add a preset programmatically if preset is not None: - assert preset in ['deg_f', 'deg_m', 'deg_s'] - config_list.append('preset/' + preset) + assert preset in ["deg_f", "deg_m", "deg_s"] + config_list.append("preset/" + preset) if debug: - config_list.append('debug') + config_list.append("debug") # config_files = [config_string_to_path(config_path, i) for i in config_list] cfgs = [load_config_by_name(i) for i in config_list] @@ -161,7 +163,7 @@ def make_config(project_path: Union[str, os.PathLike], else: cfg = OmegaConf.merge(*cfgs, user_cfg) - cfg.run = {'type': run_type, 'model': model} + cfg.run = {"type": run_type, "model": model} return cfg @@ -178,9 +180,9 @@ def make_flow_generator_train_cfg(project_path: Union[str, os.PathLike], **kwarg DictConfig flow generator config """ - config_list = ['config', 'augs', 'train', 'model/flow_generator'] - run_type = 'train' - model = 'flow_generator' + config_list = ["config", "augs", "train", "model/flow_generator"] + run_type = "train" + model = "flow_generator" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) @@ -200,9 +202,9 @@ def make_feature_extractor_train_cfg(project_path: Union[str, os.PathLike], **kw DictConfig feature extractor train config """ - config_list = ['config', 'augs', 'train', 'model/flow_generator', 'model/feature_extractor'] - run_type = 'train' - model = 'feature_extractor' + config_list = ["config", "augs", "train", "model/flow_generator", "model/feature_extractor"] + run_type = "train" + model = "feature_extractor" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) @@ -222,9 +224,9 @@ def make_feature_extractor_inference_cfg(project_path: Union[str, os.PathLike], DictConfig feature extractor inference config """ - config_list = ['config', 'augs', 'model/feature_extractor', 'model/flow_generator', 'inference', 'postprocessor'] - run_type = 'inference' - model = 'feature_extractor' + config_list = ["config", "augs", "model/feature_extractor", "model/flow_generator", "inference", "postprocessor"] + run_type = "inference" + model = "feature_extractor" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) @@ -244,9 +246,9 @@ def make_sequence_train_cfg(project_path: Union[str, os.PathLike], **kwargs) -> DictConfig sequence train config """ - config_list = ['config', 'model/feature_extractor', 'train', 'model/sequence'] - run_type = 'train' - model = 'sequence' + config_list = ["config", "model/feature_extractor", "train", "model/sequence"] + run_type = "train" + model = "sequence" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) @@ -266,9 +268,9 @@ def make_sequence_inference_cfg(project_path: Union[str, os.PathLike], **kwargs) DictConfig sequence inference config """ - config_list = ['config', 'augs', 'model/feature_extractor', 'model/sequence', 'inference'] - run_type = 'inference' - model = 'sequence' + config_list = ["config", "augs", "model/feature_extractor", "model/sequence", "inference"] + run_type = "inference" + model = "sequence" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) @@ -288,9 +290,9 @@ def make_postprocessing_cfg(project_path: Union[str, os.PathLike], **kwargs) -> DictConfig postprocessing config """ - config_list = ['config', 'model/sequence', 'postprocessor'] - run_type = 'inference' - model = 'sequence' + config_list = ["config", "model/sequence", "postprocessor"] + run_type = "inference" + model = "sequence" cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, **kwargs) diff --git a/deepethogram/data/augs.py b/deepethogram/data/augs.py index 17922a0..f7ee65c 100644 --- a/deepethogram/data/augs.py +++ b/deepethogram/data/augs.py @@ -14,7 +14,7 @@ log = logging.getLogger(__name__) -def get_normalization_layer(mean: list, std: list, num_images: int = 1, mode: str = '2d'): +def get_normalization_layer(mean: list, std: list, num_images: int = 1, mode: str = "2d"): """Get Z-scoring layer from config If RGB frames are stacked into tensor N, num_rgb*3, H, W, we need to repeat the mean and std num_rgb times """ @@ -26,8 +26,7 @@ def get_normalization_layer(mean: list, std: list, num_images: int = 1, mode: st class Transpose: - """Module to transpose image stacks. - """ + """Module to transpose image stacks.""" def __call__(self, images: np.ndarray) -> np.ndarray: shape = images.shape @@ -39,12 +38,11 @@ def __call__(self, images: np.ndarray) -> np.ndarray: return images.transpose(2, 0, 1) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class NormalizeVideo(nn.Module): - """Z-scores input video sequences - """ + """Z-scores input video sequences""" def __init__(self, mean, std): super().__init__() @@ -73,8 +71,7 @@ def forward(self, tensor): class DenormalizeVideo(nn.Module): - """Un-z-scores input video sequences - """ + """Un-z-scores input video sequences""" def __init__(self, mean, std): super().__init__() @@ -102,8 +99,7 @@ def forward(self, tensor): class ToFloat(nn.Module): - """Module for converting input uint8 tensors to floats, dividing by 255 - """ + """Module for converting input uint8 tensors to floats, dividing by 255""" def __init__(self): super().__init__() @@ -112,12 +108,11 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: return tensor.float().div(255) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class StackClipInChannels(nn.Module): - """Module to convert image from N,C,T,H,W -> N,C*T,H,W - """ + """Module to convert image from N,C,T,H,W -> N,C*T,H,W""" def __init__(self): super().__init__() @@ -130,8 +125,7 @@ def forward(self, tensor): class UnstackClip(nn.Module): - """Module to convert image from N,C*T,H,W -> N,C,T,H,W - """ + """Module to convert image from N,C*T,H,W -> N,C,T,H,W""" def __init__(self): super().__init__() @@ -144,7 +138,7 @@ def forward(self, tensor): def get_cpu_transforms(augs: DictConfig) -> dict: - """Makes CPU augmentations from the aug section of a configuration. + """Makes CPU augmentations from the aug section of a configuration. Parameters ---------- @@ -154,7 +148,7 @@ def get_cpu_transforms(augs: DictConfig) -> dict: Returns ------- xform : dict - keys: ['train', 'val', 'test']. Values: a composed OpenCV augmentation pipeline callable. + keys: ['train', 'val', 'test']. Values: a composed OpenCV augmentation pipeline callable. Example: auged_images = xform['train'](images) """ train_transforms = [] @@ -177,12 +171,12 @@ def get_cpu_transforms(augs: DictConfig) -> dict: train_transforms = transforms.Compose(train_transforms) val_transforms = transforms.Compose(val_transforms) - xform = {'train': train_transforms, 'val': val_transforms, 'test': val_transforms} - log.debug('CPU transforms: {}'.format(xform)) + xform = {"train": train_transforms, "val": val_transforms, "test": val_transforms} + log.debug("CPU transforms: {}".format(xform)) return xform -def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict: +def get_gpu_transforms(augs: DictConfig, mode: str = "2d") -> dict: """Makes GPU augmentations from the augs section of a configuration. Parameters @@ -195,7 +189,7 @@ def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict: Returns ------- xform : dict - keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. + keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. Example: auged_images = xform['train'](images) """ # input is a tensor of shape N x C x F x H x W @@ -213,25 +207,28 @@ def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict: if augs.brightness > 0 or augs.contrast > 0 or augs.saturation > 0 or augs.hue > 0: kornia_transforms.append( - K.ColorJitter(brightness=augs.brightness, - contrast=augs.contrast, - saturation=augs.saturation, - hue=augs.hue, - p=augs.color_p, - same_on_batch=False)) + K.ColorJitter( + brightness=augs.brightness, + contrast=augs.contrast, + saturation=augs.saturation, + hue=augs.hue, + p=augs.color_p, + same_on_batch=False, + ) + ) if augs.grayscale > 0: kornia_transforms.append(K.RandomGrayscale(p=augs.grayscale)) norm = NormalizeVideo(mean=augs.normalization.mean, std=augs.normalization.std) # kornia_transforms.append(norm) - kornia_transforms = VideoSequential(*kornia_transforms, data_format='BCTHW', same_on_frame=True) + kornia_transforms = VideoSequential(*kornia_transforms, data_format="BCTHW", same_on_frame=True) train_transforms = [ToFloat(), kornia_transforms, norm] val_transforms = [ToFloat(), norm] denormalize = [] - if mode == '2d': + if mode == "2d": train_transforms.append(StackClipInChannels()) val_transforms.append(StackClipInChannels()) denormalize.append(UnstackClip()) @@ -242,11 +239,11 @@ def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict: denormalize = nn.Sequential(*denormalize) gpu_transforms = dict(train=train_transforms, val=val_transforms, test=val_transforms, denormalize=denormalize) - log.info('GPU transforms: {}'.format(gpu_transforms)) + log.info("GPU transforms: {}".format(gpu_transforms)) return gpu_transforms -def get_gpu_transforms_inference(augs: DictConfig, mode: str = '2d') -> dict: +def get_gpu_transforms_inference(augs: DictConfig, mode: str = "2d") -> dict: """Gets GPU transforms needed for inference Parameters @@ -259,7 +256,7 @@ def get_gpu_transforms_inference(augs: DictConfig, mode: str = '2d') -> dict: Returns ------- xform : dict - keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. + keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. Example: auged_images = xform['val'](images) """ # sequential iterator already handles casting to float, dividing by 255, and stacking in channel dimension @@ -267,7 +264,7 @@ def get_gpu_transforms_inference(augs: DictConfig, mode: str = '2d') -> dict: # norm = get_normalization_layer(np.array(augs.normalization.mean), np.array(augs.normalization.std), # num_images, mode) xform = [NormalizeVideo(mean=augs.normalization.mean, std=augs.normalization.std)] - if mode == '2d': + if mode == "2d": xform.append(StackClipInChannels()) xform = nn.Sequential(*xform) gpu_transforms = dict(val=xform, test=xform) @@ -280,7 +277,7 @@ def get_empty_gpu_transforms() -> dict: Returns ------- xform : dict - keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. + keys: ['train', 'val', 'test']. Values: a nn.Sequential with Kornia augmentations. Example: auged_images = xform['train'](images) """ gpu_transforms = dict(train=nn.Identity(), val=nn.Identity(), test=nn.Identity(), denormalize=nn.Identity()) diff --git a/deepethogram/data/dali.py b/deepethogram/data/dali.py index 0c5bd2e..a044019 100644 --- a/deepethogram/data/dali.py +++ b/deepethogram/data/dali.py @@ -14,49 +14,52 @@ class KineticsDALIPipe(Pipeline): - def __init__(self, directory, - supervised: bool = True, - sequence_length: int = 11, - batch_size: int = 1, - num_workers: int = 1, - gpu_id: int = 0, - shuffle: bool = True, - crop_size: tuple = (256, 256), - resize: tuple = None, - brightness: float = 0.25, - contrast: float = 0.1, - mean: list = [0.5, 0.5, 0.5], - std: list = [0.5, 0.5, 0.5], - conv_mode='3d', - image_shape=(256, 256), - validate: bool = False): + def __init__( + self, + directory, + supervised: bool = True, + sequence_length: int = 11, + batch_size: int = 1, + num_workers: int = 1, + gpu_id: int = 0, + shuffle: bool = True, + crop_size: tuple = (256, 256), + resize: tuple = None, + brightness: float = 0.25, + contrast: float = 0.1, + mean: list = [0.5, 0.5, 0.5], + std: list = [0.5, 0.5, 0.5], + conv_mode="3d", + image_shape=(256, 256), + validate: bool = False, + ): super().__init__(batch_size, num_workers, gpu_id, prefetch_queue_depth=1) - self.input = ops.VideoReader(additional_decode_surfaces=1, - channels=3, - device="gpu", - dtype=types.FLOAT, - enable_frame_num=False, - enable_timestamps=False, - file_root=directory, - image_type=types.RGB, - initial_fill=1, - lazy_init=False, - normalized=True, - num_shards=1, - pad_last_batch=False, - prefetch_queue_depth=1, - random_shuffle=shuffle, - sequence_length=sequence_length, - skip_vfr_check=True, - step=-1, - shard_id=0, - stick_to_shard=False, - stride=1) + self.input = ops.VideoReader( + additional_decode_surfaces=1, + channels=3, + device="gpu", + dtype=types.FLOAT, + enable_frame_num=False, + enable_timestamps=False, + file_root=directory, + image_type=types.RGB, + initial_fill=1, + lazy_init=False, + normalized=True, + num_shards=1, + pad_last_batch=False, + prefetch_queue_depth=1, + random_shuffle=shuffle, + sequence_length=sequence_length, + skip_vfr_check=True, + step=-1, + shard_id=0, + stick_to_shard=False, + stride=1, + ) self.uniform = ops.Uniform(range=(0.0, 1.0)) - self.cmn = ops.CropMirrorNormalize(device='gpu', crop=crop_size, - mean=mean, std=std, - output_layout=types.NFHWC) + self.cmn = ops.CropMirrorNormalize(device="gpu", crop=crop_size, mean=mean, std=std, output_layout=types.NFHWC) self.coin = ops.CoinFlip(probability=0.5) self.brightness_val = ops.Uniform(range=[1 - brightness, 1 + brightness]) @@ -64,19 +67,19 @@ def __init__(self, directory, self.supervised = supervised self.half = ops.Constant(fdata=0.5) self.zero = ops.Constant(idata=0) - self.cast_to_long = ops.Cast(device='gpu', dtype=types.INT64) + self.cast_to_long = ops.Cast(device="gpu", dtype=types.INT64) if crop_size is not None: H, W = crop_size else: # default H, W = image_shape # print('CONV MODE!!! {}'.format(conv_mode)) - if conv_mode == '3d': + if conv_mode == "3d": self.transpose = ops.Transpose(device="gpu", perm=[3, 0, 1, 2]) self.reshape = None - elif conv_mode == '2d': - self.transpose = ops.Transpose(device='gpu', perm=[0, 3, 1, 2]) - self.reshape = ops.Reshape(device='gpu', shape=[-1, H, W]) + elif conv_mode == "2d": + self.transpose = ops.Transpose(device="gpu", perm=[0, 3, 1, 2]) + self.reshape = ops.Reshape(device="gpu", shape=[-1, H, W]) self.validate = validate def define_graph(self): @@ -108,37 +111,39 @@ def define_graph(self): # # https://github.com/NVIDIA/DALI/blob/cde7271a840142221273f8642952087acd919b6e # # /docs/examples/use_cases/video_superres/dataloading/dataloaders.py class DALILoader: - def __init__(self, directory, - supervised: bool = True, - sequence_length: int = 11, - batch_size: int = 1, - num_workers: int = 1, - gpu_id: int = 0, - shuffle: bool = True, - crop_size: tuple = (256, 256), - mean: list = [0.5, 0.5, 0.5], - std: list = [0.5, 0.5, 0.5], - conv_mode: str = '3d', - validate: bool = False, - distributed: bool = False): - self.pipeline = KineticsDALIPipe(directory=directory, - batch_size=batch_size, - supervised=supervised, - sequence_length=sequence_length, - num_workers=num_workers, - gpu_id=gpu_id, - crop_size=crop_size, - mean=mean, - std=std, - conv_mode=conv_mode, - validate=validate) + def __init__( + self, + directory, + supervised: bool = True, + sequence_length: int = 11, + batch_size: int = 1, + num_workers: int = 1, + gpu_id: int = 0, + shuffle: bool = True, + crop_size: tuple = (256, 256), + mean: list = [0.5, 0.5, 0.5], + std: list = [0.5, 0.5, 0.5], + conv_mode: str = "3d", + validate: bool = False, + distributed: bool = False, + ): + self.pipeline = KineticsDALIPipe( + directory=directory, + batch_size=batch_size, + supervised=supervised, + sequence_length=sequence_length, + num_workers=num_workers, + gpu_id=gpu_id, + crop_size=crop_size, + mean=mean, + std=std, + conv_mode=conv_mode, + validate=validate, + ) self.pipeline.build() self.epoch_size = self.pipeline.epoch_size("Reader") - names = ['images', 'labels'] if supervised else ['images'] - self.dali_iterator = pytorch.DALIGenericIterator(self.pipeline, - names, - self.epoch_size, - auto_reset=True) + names = ["images", "labels"] if supervised else ["images"] + self.dali_iterator = pytorch.DALIGenericIterator(self.pipeline, names, self.epoch_size, auto_reset=True) def __len__(self): return int(self.epoch_size) @@ -147,37 +152,41 @@ def __iter__(self): return self.dali_iterator.__iter__() -def get_dataloaders_kinetics_dali(directory, - rgb_frames=1, - batch_size=1, - shuffle=True, - num_workers=0, - supervised=True, - conv_mode='2d', - gpu_id: int = 0, - crop_size: tuple = (256, 256), - mean: list = [0.5, 0.5, 0.5], - std: list = [0.5, 0.5, 0.5], - distributed: bool = False): - shuffles = {'train': shuffle, 'val': True, 'test': False} +def get_dataloaders_kinetics_dali( + directory, + rgb_frames=1, + batch_size=1, + shuffle=True, + num_workers=0, + supervised=True, + conv_mode="2d", + gpu_id: int = 0, + crop_size: tuple = (256, 256), + mean: list = [0.5, 0.5, 0.5], + std: list = [0.5, 0.5, 0.5], + distributed: bool = False, +): + shuffles = {"train": shuffle, "val": True, "test": False} dataloaders = {} - for split in ['train', 'val']: + for split in ["train", "val"]: splitdir = os.path.join(directory, split) - dataloaders[split] = DALILoader(splitdir, - supervised=supervised, - batch_size=batch_size, - gpu_id=gpu_id, - shuffle=shuffles[split], - crop_size=crop_size, - mean=mean, - std=std, - validate=split == 'val', - num_workers=num_workers, - sequence_length=rgb_frames, - conv_mode=conv_mode, - distributed=distributed) - - dataloaders['split'] = None + dataloaders[split] = DALILoader( + splitdir, + supervised=supervised, + batch_size=batch_size, + gpu_id=gpu_id, + shuffle=shuffles[split], + crop_size=crop_size, + mean=mean, + std=std, + validate=split == "val", + num_workers=num_workers, + sequence_length=rgb_frames, + conv_mode=conv_mode, + distributed=distributed, + ) + + dataloaders["split"] = None return dataloaders @@ -186,4 +195,4 @@ def __len__(self): def __iter__(self): - return self.dali_iterator.__iter__() \ No newline at end of file + return self.dali_iterator.__iter__() diff --git a/deepethogram/data/dataloaders.py b/deepethogram/data/dataloaders.py index 7a8b8ed..cd33b51 100644 --- a/deepethogram/data/dataloaders.py +++ b/deepethogram/data/dataloaders.py @@ -11,8 +11,11 @@ from deepethogram import projects from deepethogram.data.augs import get_cpu_transforms from deepethogram.data.datasets import SequenceDataset, TwoStreamDataset, VideoDataset, KineticsDataset -from deepethogram.data.utils import get_split_from_records, remove_invalid_records_from_split_dictionary, \ - make_loss_weight +from deepethogram.data.utils import ( + get_split_from_records, + remove_invalid_records_from_split_dictionary, + make_loss_weight, +) try: from nvidia.dali.pipeline import Pipeline @@ -24,14 +27,29 @@ log = logging.getLogger(__name__) -def get_dataloaders_sequence(datadir: Union[str, os.PathLike], latent_name: str, sequence_length: int = 60, - is_two_stream: bool = True, nonoverlapping: bool = True, splitfile: str = None, - reload_split: bool = True, store_in_ram: bool = True, dimension: int = None, - train_val_test: Union[list, np.ndarray] = [0.8, 0.2, 0.0], weight_exp: float = 1.0, - batch_size=1, shuffle=True, num_workers=0, pin_memory=False, drop_last=False, - supervised=True, reduce=False, valid_splits_only: bool = True, - return_logits=False) -> dict: - """ Gets dataloaders for sequence models assuming DeepEthogram file structure. +def get_dataloaders_sequence( + datadir: Union[str, os.PathLike], + latent_name: str, + sequence_length: int = 60, + is_two_stream: bool = True, + nonoverlapping: bool = True, + splitfile: str = None, + reload_split: bool = True, + store_in_ram: bool = True, + dimension: int = None, + train_val_test: Union[list, np.ndarray] = [0.8, 0.2, 0.0], + weight_exp: float = 1.0, + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=False, + drop_last=False, + supervised=True, + reduce=False, + valid_splits_only: bool = True, + return_logits=False, +) -> dict: + """Gets dataloaders for sequence models assuming DeepEthogram file structure. Parameters ---------- @@ -99,9 +117,9 @@ def get_dataloaders_sequence(datadir: Union[str, os.PathLike], latent_name: str, loss_weight: loss weight for softmax activation / NLL loss """ - return_types = ['output'] + return_types = ["output"] if supervised: - return_types += ['label'] + return_types += ["label"] # records: dictionary of dictionaries. Keys: unique data identifiers # values: a dictionary corresponding to different files. the first record might be: @@ -111,102 +129,156 @@ def get_dataloaders_sequence(datadir: Union[str, os.PathLike], latent_name: str, records = projects.filter_records_for_filetypes(records, return_types) # returns a dictionary, where each split in ['train', 'val', 'test'] as a list of keys # each key corresponds to a unique directory, and has - split_dictionary = get_split_from_records(records, datadir, splitfile, supervised, reload_split, valid_splits_only, - train_val_test) + split_dictionary = get_split_from_records( + records, datadir, splitfile, supervised, reload_split, valid_splits_only, train_val_test + ) # it's possible that your split has records that are invalid for the current task. # e.g.: you've added a video, but not labeled it yet. In that case, it will already be in your split, but it is # invalid for current purposes, because it has no label. Therefore, we want to remove it from the current split split_dictionary = remove_invalid_records_from_split_dictionary(split_dictionary, records) - log.info('~~~~~ train val test split ~~~~~') + log.info("~~~~~ train val test split ~~~~~") pprint.pprint(split_dictionary) datasets = {} - splits = ['train', 'val', 'test'] + splits = ["train", "val", "test"] datasets = {} for split in splits: - outputfiles = [records[i]['output'] for i in split_dictionary[split]] + outputfiles = [records[i]["output"] for i in split_dictionary[split]] - if split == 'test' and len(outputfiles) == 0: + if split == "test" and len(outputfiles) == 0: datasets[split] = None continue # h5file, labelfile = outputs[i] # print('making dataset:{}'.format(split)) if supervised: - labelfiles = [records[i]['label'] for i in split_dictionary[split]] + labelfiles = [records[i]["label"] for i in split_dictionary[split]] else: labelfiles = None - datasets[split] = SequenceDataset(outputfiles, labelfiles, latent_name, sequence_length, - is_two_stream=is_two_stream, nonoverlapping=nonoverlapping, - dimension=dimension, - store_in_ram=store_in_ram, return_logits=return_logits) - - shuffles = {'train': shuffle, 'val': True, 'test': False} - - dataloaders = {split: data.DataLoader(datasets[split], batch_size=batch_size, - shuffle=shuffles[split], num_workers=num_workers, - pin_memory=pin_memory, drop_last=drop_last) - for split in ['train', 'val', 'test'] if datasets[split] is not None} + datasets[split] = SequenceDataset( + outputfiles, + labelfiles, + latent_name, + sequence_length, + is_two_stream=is_two_stream, + nonoverlapping=nonoverlapping, + dimension=dimension, + store_in_ram=store_in_ram, + return_logits=return_logits, + ) + + shuffles = {"train": shuffle, "val": True, "test": False} + + dataloaders = { + split: data.DataLoader( + datasets[split], + batch_size=batch_size, + shuffle=shuffles[split], + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + for split in ["train", "val", "test"] + if datasets[split] is not None + } # figure out what our inputs to our model will be (D dimension) - dataloaders['num_features'] = datasets['train'].num_features + dataloaders["num_features"] = datasets["train"].num_features if supervised: - dataloaders['class_counts'] = datasets['train'].class_counts - dataloaders['num_classes'] = len(dataloaders['class_counts']) - pos_weight, softmax_weight = make_loss_weight(dataloaders['class_counts'], - datasets['train'].num_pos, - datasets['train'].num_neg, - weight_exp=weight_exp) - dataloaders['pos'] = datasets['train'].num_pos - dataloaders['neg'] = datasets['train'].num_neg - dataloaders['pos_weight'] = pos_weight - dataloaders['loss_weight'] = softmax_weight - dataloaders['split'] = split_dictionary + dataloaders["class_counts"] = datasets["train"].class_counts + dataloaders["num_classes"] = len(dataloaders["class_counts"]) + pos_weight, softmax_weight = make_loss_weight( + dataloaders["class_counts"], datasets["train"].num_pos, datasets["train"].num_neg, weight_exp=weight_exp + ) + dataloaders["pos"] = datasets["train"].num_pos + dataloaders["neg"] = datasets["train"].num_neg + dataloaders["pos_weight"] = pos_weight + dataloaders["loss_weight"] = softmax_weight + dataloaders["split"] = split_dictionary return dataloaders -def get_dataloaders_kinetics(directory, mode='both', xform=None, rgb_frames=1, flow_frames=10, - batch_size=1, shuffle=True, - num_workers=0, pin_memory=False, drop_last=False, - supervised=True, - reduce=True, conv_mode='2d'): +def get_dataloaders_kinetics( + directory, + mode="both", + xform=None, + rgb_frames=1, + flow_frames=10, + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=False, + drop_last=False, + supervised=True, + reduce=True, + conv_mode="2d", +): datasets = {} - for split in ['train', 'val', 'test']: + for split in ["train", "val", "test"]: # this is in the two stream case where you can't apply color transforms to an optic flow if type(xform[split]) == dict: - spatial_transform = xform[split]['spatial'] - color_transform = xform[split]['color'] + spatial_transform = xform[split]["spatial"] + color_transform = xform[split]["color"] else: spatial_transform = xform[split] color_transform = None - datasets[split] = KineticsDataset(directory, split, mode, supervised=supervised, - rgb_frames=rgb_frames, flow_frames=flow_frames, - spatial_transform=spatial_transform, - color_transform=color_transform, - reduce=reduce, - flow_style='rgb', - flow_max=10, - conv_mode=conv_mode) - - shuffles = {'train': shuffle, 'val': True, 'test': False} - - dataloaders = {split: data.DataLoader(datasets[split], batch_size=batch_size, - shuffle=shuffles[split], num_workers=num_workers, - pin_memory=pin_memory, drop_last=drop_last) - for split in ['train', 'val', 'test']} - dataloaders['split'] = None + datasets[split] = KineticsDataset( + directory, + split, + mode, + supervised=supervised, + rgb_frames=rgb_frames, + flow_frames=flow_frames, + spatial_transform=spatial_transform, + color_transform=color_transform, + reduce=reduce, + flow_style="rgb", + flow_max=10, + conv_mode=conv_mode, + ) + + shuffles = {"train": shuffle, "val": True, "test": False} + + dataloaders = { + split: data.DataLoader( + datasets[split], + batch_size=batch_size, + shuffle=shuffles[split], + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + for split in ["train", "val", "test"] + } + dataloaders["split"] = None return dataloaders -def get_video_dataloaders(datadir: Union[str, os.PathLike], xform: dict, is_two_stream: bool = False, - reload_split: bool = True, splitfile: Union[str, os.PathLike] = None, - train_val_test: Union[list, np.ndarray] = [0.8, 0.1, 0.1], weight_exp: float = 1.0, - rgb_frames: int = 1, flow_frames: int = 10, batch_size=1, shuffle=True, num_workers=0, - pin_memory=False, drop_last=False, supervised=True, reduce=False, flow_max: int = 5, - flow_style: str = 'linear', valid_splits_only: bool = True, conv_mode: str = '2d'): - """ Gets dataloaders for video-based datasets. +def get_video_dataloaders( + datadir: Union[str, os.PathLike], + xform: dict, + is_two_stream: bool = False, + reload_split: bool = True, + splitfile: Union[str, os.PathLike] = None, + train_val_test: Union[list, np.ndarray] = [0.8, 0.1, 0.1], + weight_exp: float = 1.0, + rgb_frames: int = 1, + flow_frames: int = 10, + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=False, + drop_last=False, + supervised=True, + reduce=False, + flow_max: int = 5, + flow_style: str = "linear", + valid_splits_only: bool = True, + conv_mode: str = "2d", +): + """Gets dataloaders for video-based datasets. Parameters ---------- @@ -264,11 +336,11 @@ def get_video_dataloaders(datadir: Union[str, os.PathLike], xform: dict, is_two_ split contains the split dictionary, for saving keys for loss weighting are also added. see make_loss_weight for explanation """ - return_types = ['rgb'] + return_types = ["rgb"] if is_two_stream: - return_types += ['flow'] + return_types += ["flow"] if supervised: - return_types += ['label'] + return_types += ["label"] # records: dictionary of dictionaries. Keys: unique data identifiers # values: a dictionary corresponding to different files. the first record might be: # {'mouse000': {'rgb': path/to/rgb.avi, 'label':path/to/labels.csv} } @@ -277,71 +349,82 @@ def get_video_dataloaders(datadir: Union[str, os.PathLike], xform: dict, is_two_ records = projects.filter_records_for_filetypes(records, return_types) # returns a dictionary, where each split in ['train', 'val', 'test'] as a list of keys # each key corresponds to a unique directory, and has - split_dictionary = get_split_from_records(records, datadir, splitfile, supervised, reload_split, valid_splits_only, - train_val_test) + split_dictionary = get_split_from_records( + records, datadir, splitfile, supervised, reload_split, valid_splits_only, train_val_test + ) # it's possible that your split has records that are invalid for the current task. # e.g.: you've added a video, but not labeled it yet. In that case, it will already be in your split, but it is # invalid for current purposes, because it has no label. Therefore, we want to remove it from the current split split_dictionary = remove_invalid_records_from_split_dictionary(split_dictionary, records) datasets = {} - for i, split in enumerate(['train', 'val', 'test']): - rgb = [records[i]['rgb'] for i in split_dictionary[split]] - flow = [records[i]['flow'] for i in split_dictionary[split]] + for i, split in enumerate(["train", "val", "test"]): + rgb = [records[i]["rgb"] for i in split_dictionary[split]] + flow = [records[i]["flow"] for i in split_dictionary[split]] - if split == 'test' and len(rgb) == 0: + if split == "test" and len(rgb) == 0: datasets[split] = None continue if supervised: - labelfiles = [records[i]['label'] for i in split_dictionary[split]] + labelfiles = [records[i]["label"] for i in split_dictionary[split]] else: labelfiles = None if is_two_stream: - datasets[split] = TwoStreamDataset(rgb_list=rgb, - flow_list=flow, - rgb_frames=rgb_frames, - flow_frames=flow_frames, - spatial_transform=xform[split]['spatial'], - color_transform=xform[split]['color'], - label_list=labelfiles, - reduce=reduce, - flow_max=flow_max, - flow_style=flow_style - ) + datasets[split] = TwoStreamDataset( + rgb_list=rgb, + flow_list=flow, + rgb_frames=rgb_frames, + flow_frames=flow_frames, + spatial_transform=xform[split]["spatial"], + color_transform=xform[split]["color"], + label_list=labelfiles, + reduce=reduce, + flow_max=flow_max, + flow_style=flow_style, + ) else: - datasets[split] = VideoDataset(rgb, - frames_per_clip=rgb_frames, - label_list=labelfiles, - reduce=reduce, - transform=xform[split], - conv_mode=conv_mode) - - shuffles = {'train': shuffle, 'val': True, 'test': False} - - dataloaders = {split: data.DataLoader(datasets[split], batch_size=batch_size, - shuffle=shuffles[split], num_workers=num_workers, - pin_memory=pin_memory, drop_last=drop_last) - for split in ['train', 'val', 'test'] if datasets[split] is not None} + datasets[split] = VideoDataset( + rgb, + frames_per_clip=rgb_frames, + label_list=labelfiles, + reduce=reduce, + transform=xform[split], + conv_mode=conv_mode, + ) + + shuffles = {"train": shuffle, "val": True, "test": False} + + dataloaders = { + split: data.DataLoader( + datasets[split], + batch_size=batch_size, + shuffle=shuffles[split], + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + for split in ["train", "val", "test"] + if datasets[split] is not None + } if supervised: - dataloaders['class_counts'] = datasets['train'].class_counts - dataloaders['num_classes'] = len(dataloaders['class_counts']) - pos_weight, softmax_weight = make_loss_weight(dataloaders['class_counts'], - datasets['train'].num_pos, - datasets['train'].num_neg, - weight_exp=weight_exp) - dataloaders['pos'] = datasets['train'].num_pos - dataloaders['neg'] = datasets['train'].num_neg - dataloaders['pos_weight'] = pos_weight - dataloaders['loss_weight'] = softmax_weight - dataloaders['split'] = split_dictionary - return (dataloaders) + dataloaders["class_counts"] = datasets["train"].class_counts + dataloaders["num_classes"] = len(dataloaders["class_counts"]) + pos_weight, softmax_weight = make_loss_weight( + dataloaders["class_counts"], datasets["train"].num_pos, datasets["train"].num_neg, weight_exp=weight_exp + ) + dataloaders["pos"] = datasets["train"].num_pos + dataloaders["neg"] = datasets["train"].num_neg + dataloaders["pos_weight"] = pos_weight + dataloaders["loss_weight"] = softmax_weight + dataloaders["split"] = split_dictionary + return dataloaders def get_dataloaders_from_cfg(cfg: DictConfig, model_type: str, input_images: int = 1) -> dict: - """ Returns dataloader objects using a Hydra-generated configuration dictionary. + """Returns dataloader objects using a Hydra-generated configuration dictionary. This is the main entry point for getting dataloaders from the command line. it will return the correct dataloader with given hyperparameters for either flow, feature extractor, or sequence models. @@ -367,67 +450,92 @@ def get_dataloaders_from_cfg(cfg: DictConfig, model_type: str, input_images: int information see the specific dataloader of the model you're training, e.g. get_video_dataloaders """ # - supervised = model_type != 'flow_generator' - batch_size = cfg.compute.batch_size if cfg.compute.batch_size != 'auto' else cfg.batch_size - log.info('batch size: {}'.format(batch_size)) - if model_type == 'feature_extractor' or model_type == 'flow_generator': + supervised = model_type != "flow_generator" + batch_size = cfg.compute.batch_size if cfg.compute.batch_size != "auto" else cfg.batch_size + log.info("batch size: {}".format(batch_size)) + if model_type == "feature_extractor" or model_type == "flow_generator": arch = cfg[model_type].arch - mode = '3d' if '3d' in arch.lower() else '2d' - log.info('getting dataloaders: {} convolution type detected'.format(mode)) + mode = "3d" if "3d" in arch.lower() else "2d" + log.info("getting dataloaders: {} convolution type detected".format(mode)) xform = get_cpu_transforms(cfg.augs) - - if cfg.project.name == 'kinetics': + if cfg.project.name == "kinetics": if cfg.compute.dali: - dataloaders = get_dataloaders_kinetics_dali(cfg.project.data_path, - rgb_frames=input_images, - batch_size=batch_size, - num_workers=cfg.compute.num_workers, - supervised=supervised, - conv_mode=mode, - gpu_id=cfg.compute.gpu_id, - crop_size=cfg.augs.crop_size, - mean=list(cfg.augs.normalization.mean), - std=list(cfg.augs.normalization.std), - distributed=cfg.compute.distributed) + dataloaders = get_dataloaders_kinetics_dali( + cfg.project.data_path, + rgb_frames=input_images, + batch_size=batch_size, + num_workers=cfg.compute.num_workers, + supervised=supervised, + conv_mode=mode, + gpu_id=cfg.compute.gpu_id, + crop_size=cfg.augs.crop_size, + mean=list(cfg.augs.normalization.mean), + std=list(cfg.augs.normalization.std), + distributed=cfg.compute.distributed, + ) # hack, because for DEG projects we'll get the number of positive and negative examples # for kinetics, we don't want to weight the loss at all - dataloaders['pos'] = None - dataloaders['neg'] = None + dataloaders["pos"] = None + dataloaders["neg"] = None else: - dataloaders = get_dataloaders_kinetics(cfg.project.data_path, - mode='rgb', - xform=xform, - rgb_frames=input_images, - batch_size=batch_size, - shuffle=True, - num_workers=cfg.compute.num_workers, - pin_memory=torch.cuda.is_available(), - reduce=True, - supervised=supervised, - conv_mode=mode) + dataloaders = get_dataloaders_kinetics( + cfg.project.data_path, + mode="rgb", + xform=xform, + rgb_frames=input_images, + batch_size=batch_size, + shuffle=True, + num_workers=cfg.compute.num_workers, + pin_memory=torch.cuda.is_available(), + reduce=True, + supervised=supervised, + conv_mode=mode, + ) else: reduce = False - if cfg.run.model == 'feature_extractor': - if cfg.feature_extractor.final_activation == 'softmax': + if cfg.run.model == "feature_extractor": + if cfg.feature_extractor.final_activation == "softmax": reduce = True - dataloaders = get_video_dataloaders(cfg.project.data_path, xform=xform, is_two_stream=False, - splitfile=cfg.split.file, train_val_test=cfg.split.train_val_test, - weight_exp=cfg.train.loss_weight_exp, rgb_frames=input_images, - batch_size=batch_size, num_workers=cfg.compute.num_workers, - pin_memory=torch.cuda.is_available(), drop_last=False, - supervised=supervised, reduce=reduce, conv_mode=mode) - elif model_type == 'sequence': - dataloaders = get_dataloaders_sequence(cfg.project.data_path, cfg.sequence.latent_name, - cfg.sequence.sequence_length, is_two_stream=True, - nonoverlapping=cfg.sequence.nonoverlapping, splitfile=cfg.split.file, - reload_split=True, store_in_ram=False, dimension=None, - train_val_test=cfg.split.train_val_test, - weight_exp=cfg.train.loss_weight_exp, batch_size=batch_size, - shuffle=True, num_workers=cfg.compute.num_workers, - pin_memory=torch.cuda.is_available(), drop_last=False, supervised=True, - reduce=cfg.feature_extractor.final_activation == 'softmax', - valid_splits_only=True, return_logits=False) + dataloaders = get_video_dataloaders( + cfg.project.data_path, + xform=xform, + is_two_stream=False, + splitfile=cfg.split.file, + train_val_test=cfg.split.train_val_test, + weight_exp=cfg.train.loss_weight_exp, + rgb_frames=input_images, + batch_size=batch_size, + num_workers=cfg.compute.num_workers, + pin_memory=torch.cuda.is_available(), + drop_last=False, + supervised=supervised, + reduce=reduce, + conv_mode=mode, + ) + elif model_type == "sequence": + dataloaders = get_dataloaders_sequence( + cfg.project.data_path, + cfg.sequence.latent_name, + cfg.sequence.sequence_length, + is_two_stream=True, + nonoverlapping=cfg.sequence.nonoverlapping, + splitfile=cfg.split.file, + reload_split=True, + store_in_ram=False, + dimension=None, + train_val_test=cfg.split.train_val_test, + weight_exp=cfg.train.loss_weight_exp, + batch_size=batch_size, + shuffle=True, + num_workers=cfg.compute.num_workers, + pin_memory=torch.cuda.is_available(), + drop_last=False, + supervised=True, + reduce=cfg.feature_extractor.final_activation == "softmax", + valid_splits_only=True, + return_logits=False, + ) else: - raise ValueError('Unknown model type: {}'.format(model_type)) + raise ValueError("Unknown model type: {}".format(model_type)) return dataloaders diff --git a/deepethogram/data/datasets.py b/deepethogram/data/datasets.py index 49024a3..3d4aa13 100644 --- a/deepethogram/data/datasets.py +++ b/deepethogram/data/datasets.py @@ -14,8 +14,15 @@ # from deepethogram.dataloaders import log from deepethogram import projects from deepethogram.data.augs import get_cpu_transforms -from deepethogram.data.utils import purge_unlabeled_elements_from_records, get_video_metadata, read_all_labels, get_split_from_records, remove_invalid_records_from_split_dictionary, \ - make_loss_weight, fix_label +from deepethogram.data.utils import ( + purge_unlabeled_elements_from_records, + get_video_metadata, + read_all_labels, + get_split_from_records, + remove_invalid_records_from_split_dictionary, + make_loss_weight, + fix_label, +) from deepethogram.data.keypoint_utils import load_dlcfile, interpolate_bad_values, expand_features_sturman from deepethogram.file_io import read_labels @@ -24,22 +31,24 @@ # https://pytorch.org/docs/stable/data.html class VideoIterable(data.IterableDataset): - """Highly optimized Dataset for running inference on videos. - - Features: + """Highly optimized Dataset for running inference on videos. + + Features: - Data is only read sequentially - Each frame is only read once - The input video is divided into NUM_WORKERS segments. Each worker reads its segment in parallel - - Each clip is read with stride = 1. If sequence_length==3, the first clips would be frames [0, 1, 2], + - Each clip is read with stride = 1. If sequence_length==3, the first clips would be frames [0, 1, 2], [1, 2, 3], [2, 3, 4], ... etc """ - def __init__(self, - videofile: Union[str, os.PathLike], - transform, - sequence_length: int = 11, - num_workers: int = 0, - mean_by_channels: Union[list, np.ndarray] = [0, 0, 0]): + def __init__( + self, + videofile: Union[str, os.PathLike], + transform, + sequence_length: int = 11, + num_workers: int = 0, + mean_by_channels: Union[list, np.ndarray] = [0, 0, 0], + ): """Cosntructor for video iterable Parameters @@ -90,10 +99,12 @@ def get_image_shape(self): im = self.transform(im) self._image_shape = im.shape - def get_zeros_image(self,): + def get_zeros_image( + self, + ): if self._zeros_image is None: if self._image_shape is None: - raise ValueError('must set shape before getting zeros image') + raise ValueError("must set shape before getting zeros image") # ALWAYS ASSUME OUTPUT IS TRANSPOSED self._zeros_image = np.zeros(self._image_shape, dtype=np.uint8) for i in range(3): @@ -107,12 +118,12 @@ def parse_mean_by_channels(self, mean_by_channels): assert np.array_equal(np.clip(mean_by_channels, 0, 255), np.array(mean_by_channels)) return np.array(mean_by_channels).astype(np.uint8) else: - raise ValueError('unexpected type for input channel mean: {}'.format(mean_by_channels)) + raise ValueError("unexpected type for input channel mean: {}".format(mean_by_channels)) def my_iter_func(self, start, end): for i in range(start, end): self.buffer.append(self.get_current_item()) - yield {'images': np.stack(self.buffer, axis=1), 'framenum': self.cnt - 1 - self.sequence_length // 2} + yield {"images": np.stack(self.buffer, axis=1), "framenum": self.cnt - 1 - self.sequence_length // 2} def get_current_item(self): worker_info = data.get_worker_info() @@ -127,7 +138,7 @@ def get_current_item(self): try: im = self.readers[worker_id][self.cnt] except Exception: - print(f'problem reading frame {self.cnt}') + print(f"problem reading frame {self.cnt}") raise im = self.transform(im) # print(im.dtype) @@ -177,9 +188,9 @@ def close(self): try: v.close() except Exception: - print(f'error destroying reader {k}') + print(f"error destroying reader {k}") else: - print(f'destroyed {k}') + print(f"destroyed {k}") def __exit__(self, *args): self.close() @@ -208,15 +219,17 @@ class SingleVideoDataset(data.Dataset): # ~5 x 11 """ - def __init__(self, - videofile: Union[str, os.PathLike], - labelfile: Union[str, os.PathLike] = None, - mean_by_channels: Union[list, np.ndarray] = [0, 0, 0], - frames_per_clip: int = 1, - transform=None, - reduce: bool = True, - conv_mode: str = '2d', - keep_reader_open: bool = False): + def __init__( + self, + videofile: Union[str, os.PathLike], + labelfile: Union[str, os.PathLike] = None, + mean_by_channels: Union[list, np.ndarray] = [0, 0, 0], + frames_per_clip: int = 1, + transform=None, + reduce: bool = True, + conv_mode: str = "2d", + keep_reader_open: bool = False, + ): """Initializes a VideoDataset object. Args: @@ -244,33 +257,33 @@ def __init__(self, self.supervised = self.labelfile is not None assert os.path.isfile(videofile) or os.path.isdir(videofile) - assert self.conv_mode in ['2d', '3d'] + assert self.conv_mode in ["2d", "3d"] # find labels given the filename of a video, load, save as an attribute for fast reading if self.supervised: assert os.path.isfile(labelfile) # self.video_list, self.label_list = purge_unlabeled_videos(self.video_list, self.label_list) - labels, class_counts, num_labels, num_pos, num_neg = read_all_labels([self.labelfile], - True, - multilabel=not self.reduce) + labels, class_counts, num_labels, num_pos, num_neg = read_all_labels( + [self.labelfile], True, multilabel=not self.reduce + ) self.labels = labels self.class_counts = class_counts self.num_labels = num_labels self.num_pos = num_pos self.num_neg = num_neg - log.debug('label shape: {}'.format(self.labels.shape)) + log.debug("label shape: {}".format(self.labels.shape)) metadata = {} ret, width, height, framecount = get_video_metadata(self.videofile) if ret: - metadata['name'] = videofile - metadata['width'] = width - metadata['height'] = height - metadata['framecount'] = framecount + metadata["name"] = videofile + metadata["width"] = width + metadata["height"] = height + metadata["framecount"] = framecount else: - raise ValueError('error loading video: {}'.format(videofile)) + raise ValueError("error loading video: {}".format(videofile)) self.metadata = metadata - self.N = self.metadata['framecount'] + self.N = self.metadata["framecount"] self._zeros_image = None def get_zeros_image(self, c, h, w, channel_first: bool = True): @@ -288,7 +301,7 @@ def parse_mean_by_channels(self, mean_by_channels): assert np.array_equal(np.clip(mean_by_channels, 0, 255), np.array(mean_by_channels)) return np.array(mean_by_channels).astype(np.uint8) else: - raise ValueError('unexpected type for input channel mean: {}'.format(mean_by_channels)) + raise ValueError("unexpected type for input channel mean: {}".format(mean_by_channels)) def __len__(self): return self.N @@ -321,7 +334,7 @@ def __getitem__(self, index: int): # if frames per clip is 11, dataset[0] would have 5 blank frames preceding, with the 6th-11th being real frames blank_start_frames = max(self.frames_per_clip // 2 - index, 0) - framecount = self.metadata['framecount'] + framecount = self.metadata["framecount"] # cap = cv2.VideoCapture(self.movies[style][movie_index]) start_frame = index - self.frames_per_clip // 2 + blank_start_frames blank_end_frames = max(index - framecount + self.frames_per_clip // 2 + 1, 0) @@ -334,8 +347,9 @@ def __getitem__(self, index: int): image = reader[i + start_frame] except Exception as e: image = self._zeros_image.copy().transpose(1, 2, 0) - log.warning('Error {} on frame {} of video {}. Is the video corrupted?'.format( - e, index, self.videofile)) + log.warning( + "Error {} on frame {} of video {}. Is the video corrupted?".format(e, index, self.videofile) + ) if self.transform: random.seed(seed) image = self.transform(image) @@ -345,31 +359,34 @@ def __getitem__(self, index: int): images = self.append_with_zeros(images, blank_end_frames) if log.isEnabledFor(logging.DEBUG): - log.debug('idx: {} st: {} blank_start: {} blank_end: {} real: {} total: {}'.format( - index, start_frame, blank_start_frames, blank_end_frames, real_frames, framecount)) + log.debug( + "idx: {} st: {} blank_start: {} blank_end: {} real: {} total: {}".format( + index, start_frame, blank_start_frames, blank_end_frames, real_frames, framecount + ) + ) # images are now numpy arrays of shape 3, H, W # stacking in the first dimension changes to 3, T, H, W, compatible with Conv3D images = np.stack(images, axis=1) if log.isEnabledFor(logging.DEBUG): - log.debug('images shape: {}'.format(images.shape)) + log.debug("images shape: {}".format(images.shape)) # print(images.shape) - outputs = {'images': images} + outputs = {"images": images} if self.supervised: label = self.labels[index] if self.reduce: try: label = np.where(label)[0][0].astype(np.int64) except IndexError: - logging.error(f'label {index} from video {self.videofile} has no positive labels! {label}') + logging.error(f"label {index} from video {self.videofile} has no positive labels! {label}") raise - outputs['labels'] = label + outputs["labels"] = label return outputs class VideoDataset(data.Dataset): - """ Simple wrapper around SingleVideoDataset for smoothly loading multiple videos """ + """Simple wrapper around SingleVideoDataset for smoothly loading multiple videos""" def __init__(self, videofiles: list, labelfiles: list, *args, **kwargs): datasets, labels = [], [] @@ -410,33 +427,34 @@ def __getitem__(self, index: int): class SingleSequenceDataset(data.Dataset): """PyTorch Dataset for loading a set of saved 1d features and one-hot labels for Action Detection. - Features: - - Loads a set of sequential frames and sequential one-hot labels - - loads by indexing from an HDF5 dataset, given a dataset name (latent_name) - - Pads beginning or end so that every label has a corresponding clip - - Optionally loads two-stream features - - Example: - dataset = SequenceDataset(['features1.h5', 'features2.h5'], label_files=['labels1.csv', 'labels2.csv', - h5_key='CNN_features', sequence_length=180, is_two_stream=True) - features, labels = dataset(np.random.randint(low=0, high=len(dataset)) - print(features.shape) - # 180 x 1024 - print(labels.shape) - # assuming there are 5 classes in dataset - # ~5 x 180 - """ + Features: + - Loads a set of sequential frames and sequential one-hot labels + - loads by indexing from an HDF5 dataset, given a dataset name (latent_name) + - Pads beginning or end so that every label has a corresponding clip + - Optionally loads two-stream features - def __init__(self, - data_file: Union[str, os.PathLike], - labelfile: Union[str, os.PathLike], - N: int, - sequence_length: int = 60, - nonoverlapping: bool = True, - store_in_ram: bool = True, - reduce: bool = False, - stack_in_time: bool = False): + Example: + dataset = SequenceDataset(['features1.h5', 'features2.h5'], label_files=['labels1.csv', 'labels2.csv', + h5_key='CNN_features', sequence_length=180, is_two_stream=True) + features, labels = dataset(np.random.randint(low=0, high=len(dataset)) + print(features.shape) + # 180 x 1024 + print(labels.shape) + # assuming there are 5 classes in dataset + # ~5 x 180 + """ + def __init__( + self, + data_file: Union[str, os.PathLike], + labelfile: Union[str, os.PathLike], + N: int, + sequence_length: int = 60, + nonoverlapping: bool = True, + store_in_ram: bool = True, + reduce: bool = False, + stack_in_time: bool = False, + ): self.reduce = reduce assert os.path.isfile(data_file) @@ -467,7 +485,7 @@ def __init__(self, self.verify_dataset() tmp_sequence = self.__getitem__(0) # self.read_sequence([0, 1]) - self.num_features = tmp_sequence['features'].shape[0] + self.num_features = tmp_sequence["features"].shape[0] def read_sequence(self, indices): raise NotImplementedError @@ -492,8 +510,7 @@ def compute_starts_ends(self): inds = np.arange(self.N) self.starts = inds - self.sequence_length // 2 # if it's odd, should go from - self.ends = inds + self.sequence_length//2 + \ - self.sequence_length % 2 + self.ends = inds + self.sequence_length // 2 + self.sequence_length % 2 def __len__(self): return len(self.starts) @@ -530,20 +547,24 @@ def compute_indices_and_padding(self, index): label_indices = indices label_pad = pad - assert (len(indices) + pad_left + pad_right) == self.sequence_length, \ - 'indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}'.format( - len(indices), pad_left, pad_right, self.sequence_length) + assert ( + len(indices) + pad_left + pad_right + ) == self.sequence_length, "indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( + len(indices), pad_left, pad_right, self.sequence_length + ) # if we are stacking in time, label indices should not be the sequence length if not self.stack_in_time: - assert (len(label_indices) + label_pad[0] + label_pad[1]) == self.sequence_length, \ - 'label indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}'.format( - len(label_indices), label_pad[0], label_pad[1], self.sequence_length) + assert ( + (len(label_indices) + label_pad[0] + label_pad[1]) == self.sequence_length + ), "label indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( + len(label_indices), label_pad[0], label_pad[1], self.sequence_length + ) return indices, label_indices, pad, label_pad def __del__(self): - if hasattr(self, 'sequence'): + if hasattr(self, "sequence"): del self.sequence - if hasattr(self, 'labels'): + if hasattr(self, "labels"): del self.labels def __getitem__(self, index: int) -> dict: @@ -557,7 +578,7 @@ def __getitem__(self, index: int) -> dict: output = {} pad_left, pad_right = pad for key, value in data.items(): - value = np.pad(value, ((0, 0), (pad_left, pad_right)), mode='constant') + value = np.pad(value, ((0, 0), (pad_left, pad_right)), mode="constant") if self.stack_in_time: value = value.flatten() value = torch.from_numpy(value).float() @@ -571,41 +592,43 @@ def __getitem__(self, index: int) -> dict: labels = self.label[:, label_indices].astype(np.int64) if labels.ndim == 1: labels = labels[:, np.newaxis] - labels = np.pad(labels, ((0, 0), (pad_left, pad_right)), mode='constant', constant_values=-1) + labels = np.pad(labels, ((0, 0), (pad_left, pad_right)), mode="constant", constant_values=-1) else: labels = self.label[label_indices].astype(np.int64) - labels = np.pad(labels, (pad_left, pad_right), mode='constant', constant_values=-1) + labels = np.pad(labels, (pad_left, pad_right), mode="constant", constant_values=-1) # if we stack in time, we want to make sure we have labels of shape (N_behaviors,) # not (N_behaviors, 1) labels = labels.squeeze() labels = torch.from_numpy(labels).to(torch.long) - output['labels'] = labels + output["labels"] = labels - if labels.ndim > 1 and labels.shape[1] != output['features'].shape[1]: - out_shape = output['features'].shape - raise ValueError(f'problem in label shape! {labels.shape}, {out_shape}') + if labels.ndim > 1 and labels.shape[1] != output["features"].shape[1]: + out_shape = output["features"].shape + raise ValueError(f"problem in label shape! {labels.shape}, {out_shape}") return output class KeypointDataset(SingleSequenceDataset): - """Dataset for reading keypoints (e.g. from deeplabcut) and performing basis function expansion. - + """Dataset for reading keypoints (e.g. from deeplabcut) and performing basis function expansion. + Currently, only an edited variant of Sturman et al.'s basis expansion is implemented - Sturman, O., von Ziegler, L., Schläppi, C. et al. Deep learning-based behavioral analysis reaches human - accuracy and is capable of outperforming commercial solutions. Neuropsychopharmacol. 45, 1942–1952 (2020). + Sturman, O., von Ziegler, L., Schläppi, C. et al. Deep learning-based behavioral analysis reaches human + accuracy and is capable of outperforming commercial solutions. Neuropsychopharmacol. 45, 1942–1952 (2020). https://doi.org/10.1038/s41386-020-0776-y """ - def __init__(self, - data_file: Union[str, os.PathLike], - labelfile: Union[str, os.PathLike], - videofile: Union[str, os.PathLike], - expansion_method: str = 'sturman', - confidence_threshold: float = 0.9, - *args, - **kwargs): + def __init__( + self, + data_file: Union[str, os.PathLike], + labelfile: Union[str, os.PathLike], + videofile: Union[str, os.PathLike], + expansion_method: str = "sturman", + confidence_threshold: float = 0.9, + *args, + **kwargs, + ): """Constructor Parameters @@ -624,10 +647,10 @@ def __init__(self, Raises ------ NotImplementedError - For basis function expansion. Currently, only 'sturman' is implemented: - + For basis function expansion. Currently, only 'sturman' is implemented: + """ - if expansion_method == 'sturman': + if expansion_method == "sturman": self.expansion_func = expand_features_sturman else: raise NotImplementedError @@ -664,45 +687,39 @@ def __init__(self, def verify_dataset(self): if self.supervised: - assert self.label.shape[1] == self.sequence.shape[1], 'label {} and sequence {} shape do not match!'.format( - self.label.shape, self.sequence.shape) + assert self.label.shape[1] == self.sequence.shape[1], "label {} and sequence {} shape do not match!".format( + self.label.shape, self.sequence.shape + ) assert self.sequence is not None def read_sequence(self, indices): data = {} - data['features'] = self.sequence[:, indices] + data["features"] = self.sequence[:, indices] return data class FeatureVectorDataset(SingleSequenceDataset): - """Reads image and flow feature vectors from HDF5 files. - """ - - def __init__(self, - data_file, - labelfile, - h5_key: str, - store_in_ram=False, - is_two_stream: bool = True, - *args, - **kwargs): + """Reads image and flow feature vectors from HDF5 files.""" + def __init__( + self, data_file, labelfile, h5_key: str, store_in_ram=False, is_two_stream: bool = True, *args, **kwargs + ): self.is_two_stream = is_two_stream self.store_in_ram = store_in_ram assert os.path.isfile(data_file) self.key = h5_key if self.is_two_stream: - self.flow_key = self.key + '/flow_features' - self.image_key = self.key + '/spatial_features' - self.logit_key = self.key + '/logits' + self.flow_key = self.key + "/flow_features" + self.image_key = self.key + "/spatial_features" + self.logit_key = self.key + "/logits" self.data_file = data_file self.verify_dataset() data = self.read_features_from_disk(None, None) - features_shape = data['features'].shape + features_shape = data["features"].shape self.shape = features_shape self.N = self.shape[1] if self.store_in_ram: @@ -714,7 +731,7 @@ def __init__(self, super().__init__(data_file, labelfile, self.N, *args, **kwargs) def verify_dataset(self): - with h5py.File(self.data_file, 'r') as f: + with h5py.File(self.data_file, "r") as f: assert self.logit_key in f if self.is_two_stream: @@ -731,12 +748,12 @@ def verify_dataset(self): def read_features_from_disk(self, start_ind, end_ind): inds = slice(start_ind, end_ind) - with h5py.File(self.data_file, 'r') as f: + with h5py.File(self.data_file, "r") as f: if self.is_two_stream: flow_shape = f[self.flow_key].shape image_shape = f[self.image_key].shape assert len(flow_shape) == 2 - assert (flow_shape == image_shape) + assert flow_shape == image_shape # we want each timepoint to be one COLUMN flow_feats = f[self.flow_key][inds, :].T image_feats = f[self.image_key][inds, :].T @@ -749,7 +766,7 @@ def read_features_from_disk(self, start_ind, end_ind): def read_sequence(self, indices): if self.store_in_ram: - data = {'features': self.data['features'][:, indices], 'logits': self.data['logits'][:, indices]} + data = {"features": self.data["features"][:, indices], "logits": self.data["logits"][:, indices]} else: # assume indices are in order # we use the start and end so that we can slice without knowing the exact size of the dataset @@ -758,15 +775,11 @@ def read_sequence(self, indices): class SequenceDataset(data.Dataset): - """ Simple wrapper around SingleSequenceDataset for smoothly loading multiple sequences """ - - def __init__(self, - datafiles: list, - labelfiles: list, - videofiles: list = None, - is_keypoint: bool = False, - *args, - **kwargs): + """Simple wrapper around SingleSequenceDataset for smoothly loading multiple sequences""" + + def __init__( + self, datafiles: list, labelfiles: list, videofiles: list = None, is_keypoint: bool = False, *args, **kwargs + ): datasets = [] for i, (datafile, labelfile) in enumerate(zip(datafiles, labelfiles)): if is_keypoint: @@ -800,23 +813,25 @@ def __getitem__(self, index: int): return self.dataset[index] -def get_video_datasets(datadir: Union[str, os.PathLike], - xform: dict, - is_two_stream: bool = False, - reload_split: bool = True, - splitfile: Union[str, os.PathLike] = None, - train_val_test: Union[list, np.ndarray] = [0.8, 0.1, 0.1], - weight_exp: float = 1.0, - rgb_frames: int = 1, - flow_frames: int = 10, - supervised=True, - reduce=False, - flow_max: int = 5, - flow_style: str = 'linear', - valid_splits_only: bool = True, - conv_mode: str = '2d', - mean_by_channels: list = [0.5, 0.5, 0.5]): - """ Gets dataloaders for video-based datasets. +def get_video_datasets( + datadir: Union[str, os.PathLike], + xform: dict, + is_two_stream: bool = False, + reload_split: bool = True, + splitfile: Union[str, os.PathLike] = None, + train_val_test: Union[list, np.ndarray] = [0.8, 0.1, 0.1], + weight_exp: float = 1.0, + rgb_frames: int = 1, + flow_frames: int = 10, + supervised=True, + reduce=False, + flow_max: int = 5, + flow_style: str = "linear", + valid_splits_only: bool = True, + conv_mode: str = "2d", + mean_by_channels: list = [0.5, 0.5, 0.5], +): + """Gets dataloaders for video-based datasets. Parameters ---------- @@ -874,11 +889,11 @@ def get_video_datasets(datadir: Union[str, os.PathLike], split contains the split dictionary, for saving keys for loss weighting are also added. see make_loss_weight for explanation """ - return_types = ['rgb'] + return_types = ["rgb"] if is_two_stream: - return_types += ['flow'] + return_types += ["flow"] if supervised: - return_types += ['label'] + return_types += ["label"] # records: dictionary of dictionaries. Keys: unique data identifiers # values: a dictionary corresponding to different files. the first record might be: # {'mouse000': {'rgb': path/to/rgb.avi, 'label':path/to/labels.csv} } @@ -890,73 +905,77 @@ def get_video_datasets(datadir: Union[str, os.PathLike], records = purge_unlabeled_elements_from_records(records) if len(records) < 3: - error_message = 'You only have {} valid videos with file types {}!'.format(len(records), return_types) - error_message += 'You need at least 3 videos in your project to begin training.' + error_message = "You only have {} valid videos with file types {}!".format(len(records), return_types) + error_message += "You need at least 3 videos in your project to begin training." raise ValueError(error_message) # returns a dictionary, where each split in ['train', 'val', 'test'] as a list of keys # each key corresponds to a unique directory, and has - split_dictionary = get_split_from_records(records, datadir, splitfile, supervised, reload_split, valid_splits_only, - train_val_test) + split_dictionary = get_split_from_records( + records, datadir, splitfile, supervised, reload_split, valid_splits_only, train_val_test + ) # it's possible that your split has records that are invalid for the current task. # e.g.: you've added a video, but not labeled it yet. In that case, it will already be in your split, but it is # invalid for current purposes, because it has no label. Therefore, we want to remove it from the current split split_dictionary = remove_invalid_records_from_split_dictionary(split_dictionary, records) datasets = {} - for i, split in enumerate(['train', 'val', 'test']): - rgb = [records[i]['rgb'] for i in split_dictionary[split]] - flow = [records[i]['flow'] for i in split_dictionary[split]] + for i, split in enumerate(["train", "val", "test"]): + rgb = [records[i]["rgb"] for i in split_dictionary[split]] + flow = [records[i]["flow"] for i in split_dictionary[split]] - if split == 'test' and len(rgb) == 0: + if split == "test" and len(rgb) == 0: datasets[split] = None continue if supervised: - labelfiles = [records[i]['label'] for i in split_dictionary[split]] + labelfiles = [records[i]["label"] for i in split_dictionary[split]] else: labelfiles = None - datasets[split] = VideoDataset(rgb, - labelfiles, - frames_per_clip=rgb_frames, - reduce=reduce, - transform=xform[split], - conv_mode=conv_mode, - mean_by_channels=mean_by_channels) - data_info = {'split': split_dictionary} + datasets[split] = VideoDataset( + rgb, + labelfiles, + frames_per_clip=rgb_frames, + reduce=reduce, + transform=xform[split], + conv_mode=conv_mode, + mean_by_channels=mean_by_channels, + ) + data_info = {"split": split_dictionary} if supervised: - data_info['class_counts'] = datasets['train'].class_counts - data_info['num_classes'] = len(data_info['class_counts']) - pos_weight, softmax_weight = make_loss_weight(data_info['class_counts'], - datasets['train'].num_pos, - datasets['train'].num_neg, - weight_exp=weight_exp) - data_info['pos'] = datasets['train'].num_pos - data_info['neg'] = datasets['train'].num_neg - data_info['pos_weight'] = pos_weight - data_info['loss_weight'] = softmax_weight + data_info["class_counts"] = datasets["train"].class_counts + data_info["num_classes"] = len(data_info["class_counts"]) + pos_weight, softmax_weight = make_loss_weight( + data_info["class_counts"], datasets["train"].num_pos, datasets["train"].num_neg, weight_exp=weight_exp + ) + data_info["pos"] = datasets["train"].num_pos + data_info["neg"] = datasets["train"].num_neg + data_info["pos_weight"] = pos_weight + data_info["loss_weight"] = softmax_weight return datasets, data_info -def get_sequence_datasets(datadir: Union[str, os.PathLike], - latent_name: str, - sequence_length: int = 60, - is_two_stream: bool = True, - nonoverlapping: bool = True, - splitfile: str = None, - reload_split: bool = True, - store_in_ram: bool = False, - train_val_test: Union[list, np.ndarray] = [0.8, 0.2, 0.0], - weight_exp: float = 1.0, - supervised=True, - reduce=False, - valid_splits_only: bool = True, - is_keypoint: bool = False, - stack_in_time: bool = False) -> Tuple[dict, dict]: - """ Gets dataloaders for sequence models assuming DeepEthogram file structure. +def get_sequence_datasets( + datadir: Union[str, os.PathLike], + latent_name: str, + sequence_length: int = 60, + is_two_stream: bool = True, + nonoverlapping: bool = True, + splitfile: str = None, + reload_split: bool = True, + store_in_ram: bool = False, + train_val_test: Union[list, np.ndarray] = [0.8, 0.2, 0.0], + weight_exp: float = 1.0, + supervised=True, + reduce=False, + valid_splits_only: bool = True, + is_keypoint: bool = False, + stack_in_time: bool = False, +) -> Tuple[dict, dict]: + """Gets dataloaders for sequence models assuming DeepEthogram file structure. Parameters ---------- @@ -1022,12 +1041,12 @@ def get_sequence_datasets(datadir: Union[str, os.PathLike], """ return_types = [] if is_keypoint: - log.info('Creating keypoint datasets, with feature expansion. Might take a few minutes') - return_types.append('keypoint') + log.info("Creating keypoint datasets, with feature expansion. Might take a few minutes") + return_types.append("keypoint") else: - return_types.append('output') + return_types.append("output") if supervised: - return_types.append('label') + return_types.append("label") # records: dictionary of dictionaries. Keys: unique data identifiers # values: a dictionary corresponding to different files. the first record might be: @@ -1040,14 +1059,15 @@ def get_sequence_datasets(datadir: Union[str, os.PathLike], records = purge_unlabeled_elements_from_records(records) if len(records) < 3: - error_message = 'You only have {} valid videos with file types {}!'.format(len(records), return_types) - error_message += 'You need at least 3 videos in your project to begin training.' + error_message = "You only have {} valid videos with file types {}!".format(len(records), return_types) + error_message += "You need at least 3 videos in your project to begin training." raise ValueError(error_message) # returns a dictionary, where each split in ['train', 'val', 'test'] as a list of keys # each key corresponds to a unique directory, and has - split_dictionary = get_split_from_records(records, datadir, splitfile, supervised, reload_split, valid_splits_only, - train_val_test) + split_dictionary = get_split_from_records( + records, datadir, splitfile, supervised, reload_split, valid_splits_only, train_val_test + ) # it's possible that your split has records that are invalid for the current task. # e.g.: you've added a video, but not labeled it yet. In that case, it will already be in your split, but it is # invalid for current purposes, because it has no label. Therefore, we want to remove it from the current split @@ -1055,75 +1075,78 @@ def get_sequence_datasets(datadir: Union[str, os.PathLike], # log.info('~~~~~ train val test split ~~~~~') # pprint.pprint(split_dictionary) - splits = ['train', 'val', 'test'] + splits = ["train", "val", "test"] datasets = {} # if stack_in_time, nonoverlapping would make us skip a bunch of labels - nonoverlapping = {'train': nonoverlapping, 'val': not stack_in_time, 'test': not stack_in_time} + nonoverlapping = {"train": nonoverlapping, "val": not stack_in_time, "test": not stack_in_time} for split in splits: if is_keypoint: - videofiles = [records[i]['rgb'] for i in split_dictionary[split]] - datafiles = [records[i]['keypoint'] for i in split_dictionary[split]] + videofiles = [records[i]["rgb"] for i in split_dictionary[split]] + datafiles = [records[i]["keypoint"] for i in split_dictionary[split]] else: videofiles = None - datafiles = [records[i]['output'] for i in split_dictionary[split]] + datafiles = [records[i]["output"] for i in split_dictionary[split]] - if split == 'test' and len(datafiles) == 0: + if split == "test" and len(datafiles) == 0: datasets[split] = None continue # h5file, labelfile = outputs[i] # print('making dataset:{}'.format(split)) if supervised: - labelfiles = [records[i]['label'] for i in split_dictionary[split]] + labelfiles = [records[i]["label"] for i in split_dictionary[split]] else: labelfiles = None # todo: figure out a nice way to be able to pass arguments to one subclass that don't exist in the other # example: is_two_stream, latent_name if is_keypoint: - datasets[split] = SequenceDataset(datafiles, - labelfiles, - videofiles, - sequence_length=sequence_length, - nonoverlapping=nonoverlapping[split], - store_in_ram=store_in_ram, - reduce=reduce, - is_keypoint=is_keypoint, - stack_in_time=stack_in_time) + datasets[split] = SequenceDataset( + datafiles, + labelfiles, + videofiles, + sequence_length=sequence_length, + nonoverlapping=nonoverlapping[split], + store_in_ram=store_in_ram, + reduce=reduce, + is_keypoint=is_keypoint, + stack_in_time=stack_in_time, + ) else: - datasets[split] = SequenceDataset(datafiles, - labelfiles, - videofiles=videofiles, - sequence_length=sequence_length, - h5_key=latent_name, - is_two_stream=is_two_stream, - nonoverlapping=nonoverlapping[split], - store_in_ram=store_in_ram, - reduce=reduce, - is_keypoint=is_keypoint, - stack_in_time=stack_in_time) + datasets[split] = SequenceDataset( + datafiles, + labelfiles, + videofiles=videofiles, + sequence_length=sequence_length, + h5_key=latent_name, + is_two_stream=is_two_stream, + nonoverlapping=nonoverlapping[split], + store_in_ram=store_in_ram, + reduce=reduce, + is_keypoint=is_keypoint, + stack_in_time=stack_in_time, + ) # figure out what our inputs to our model will be (D dimension) - data_info = {'split': split_dictionary} - data_info['num_features'] = datasets['train'].num_features + data_info = {"split": split_dictionary} + data_info["num_features"] = datasets["train"].num_features if supervised: - data_info['class_counts'] = datasets['train'].class_counts - data_info['num_classes'] = len(data_info['class_counts']) - pos_weight, softmax_weight = make_loss_weight(data_info['class_counts'], - datasets['train'].num_pos, - datasets['train'].num_neg, - weight_exp=weight_exp) - data_info['pos'] = datasets['train'].num_pos - data_info['neg'] = datasets['train'].num_neg - data_info['pos_weight'] = pos_weight - data_info['loss_weight'] = softmax_weight + data_info["class_counts"] = datasets["train"].class_counts + data_info["num_classes"] = len(data_info["class_counts"]) + pos_weight, softmax_weight = make_loss_weight( + data_info["class_counts"], datasets["train"].num_pos, datasets["train"].num_neg, weight_exp=weight_exp + ) + data_info["pos"] = datasets["train"].num_pos + data_info["neg"] = datasets["train"].num_neg + data_info["pos_weight"] = pos_weight + data_info["loss_weight"] = softmax_weight return datasets, data_info def get_datasets_from_cfg(cfg: DictConfig, model_type: str, input_images: int = 1) -> Tuple[dict, dict]: - """ Returns dataloader objects using a Hydra-generated configuration dictionary. + """Returns dataloader objects using a Hydra-generated configuration dictionary. This is the main entry point for getting dataloaders from the command line. it will return the correct dataloader with given hyperparameters for either flow, feature extractor, or sequence models. @@ -1149,52 +1172,56 @@ def get_datasets_from_cfg(cfg: DictConfig, model_type: str, input_images: int = information see the specific dataloader of the model you're training, e.g. get_video_dataloaders """ # - supervised = model_type != 'flow_generator' - if model_type == 'feature_extractor' or model_type == 'flow_generator': + supervised = model_type != "flow_generator" + if model_type == "feature_extractor" or model_type == "flow_generator": arch = cfg[model_type].arch - mode = '3d' if '3d' in arch.lower() else '2d' + mode = "3d" if "3d" in arch.lower() else "2d" # log.info('getting dataloaders: {} convolution type detected'.format(mode)) xform = get_cpu_transforms(cfg.augs) - if cfg.project.name == 'kinetics': + if cfg.project.name == "kinetics": raise NotImplementedError else: reduce = False - if cfg.run.model == 'feature_extractor': - if cfg.feature_extractor.final_activation == 'softmax': + if cfg.run.model == "feature_extractor": + if cfg.feature_extractor.final_activation == "softmax": reduce = True - datasets, info = get_video_datasets(datadir=cfg.project.data_path, - xform=xform, - is_two_stream=False, - reload_split=cfg.split.reload, - splitfile=cfg.split.file, - train_val_test=cfg.split.train_val_test, - weight_exp=cfg.train.loss_weight_exp, - rgb_frames=input_images, - supervised=supervised, - reduce=reduce, - valid_splits_only=True, - conv_mode=mode, - mean_by_channels=cfg.augs.normalization.mean) - - elif model_type == 'sequence': - if cfg.feature_extractor.final_activation == 'softmax': + datasets, info = get_video_datasets( + datadir=cfg.project.data_path, + xform=xform, + is_two_stream=False, + reload_split=cfg.split.reload, + splitfile=cfg.split.file, + train_val_test=cfg.split.train_val_test, + weight_exp=cfg.train.loss_weight_exp, + rgb_frames=input_images, + supervised=supervised, + reduce=reduce, + valid_splits_only=True, + conv_mode=mode, + mean_by_channels=cfg.augs.normalization.mean, + ) + + elif model_type == "sequence": + if cfg.feature_extractor.final_activation == "softmax": reduce = True - datasets, info = get_sequence_datasets(cfg.project.data_path, - cfg.sequence.latent_name, - cfg.sequence.sequence_length, - is_two_stream=True, - nonoverlapping=cfg.sequence.nonoverlapping, - splitfile=cfg.split.file, - reload_split=True, - store_in_ram=False, - train_val_test=cfg.split.train_val_test, - weight_exp=cfg.train.loss_weight_exp, - supervised=True, - reduce=cfg.feature_extractor.final_activation == 'softmax', - valid_splits_only=True, - stack_in_time=cfg.sequence.arch == 'mlp', - is_keypoint=cfg.sequence.input_type == 'keypoints') + datasets, info = get_sequence_datasets( + cfg.project.data_path, + cfg.sequence.latent_name, + cfg.sequence.sequence_length, + is_two_stream=True, + nonoverlapping=cfg.sequence.nonoverlapping, + splitfile=cfg.split.file, + reload_split=True, + store_in_ram=False, + train_val_test=cfg.split.train_val_test, + weight_exp=cfg.train.loss_weight_exp, + supervised=True, + reduce=cfg.feature_extractor.final_activation == "softmax", + valid_splits_only=True, + stack_in_time=cfg.sequence.arch == "mlp", + is_keypoint=cfg.sequence.input_type == "keypoints", + ) else: - raise ValueError('Unknown model type: {}'.format(model_type)) + raise ValueError("Unknown model type: {}".format(model_type)) return datasets, info diff --git a/deepethogram/data/keypoint_utils.py b/deepethogram/data/keypoint_utils.py index c934c41..e4491c8 100644 --- a/deepethogram/data/keypoint_utils.py +++ b/deepethogram/data/keypoint_utils.py @@ -7,6 +7,7 @@ log = logging.getLogger(__name__) + def interpolate_bad_values(keypoint: np.ndarray, threshold: float = 0.9) -> np.ndarray: """Interpolates keypoints with low confidence @@ -28,7 +29,7 @@ def interpolate_bad_values(keypoint: np.ndarray, threshold: float = 0.9) -> np.n keypoint_interped = keypoint.copy() - log.debug('fraction of points below {:.1f}: {:.4f}'.format(threshold, np.mean(keypoint[..., 2] < threshold))) + log.debug("fraction of points below {:.1f}: {:.4f}".format(threshold, np.mean(keypoint[..., 2] < threshold))) # TODO: VECTORIZE for i in range(kp): for j in range(2): @@ -49,7 +50,7 @@ def interpolate_bad_values(keypoint: np.ndarray, threshold: float = 0.9) -> np.n def normalize_keypoints(keypoints: np.ndarray, H: int, W: int) -> np.ndarray: - """Normalizes keypoints from range [(0, H), (0, W)] to range [(-1, 1), (-1, 1)]. + """Normalizes keypoints from range [(0, H), (0, W)] to range [(-1, 1), (-1, 1)]. Non-square images will use the maximum side length in the denominator. Parameters @@ -73,7 +74,7 @@ def normalize_keypoints(keypoints: np.ndarray, H: int, W: int) -> np.ndarray: def denormalize_keypoints(keypoints, H, W): - """Un-normalizes keypoints from range [(-1, 1), (-1, 1)] to range [(0, H), (0, W)]. + """Un-normalizes keypoints from range [(-1, 1), (-1, 1)] to range [(0, H), (0, W)]. Non-square images will use the maximum side length. Parameters @@ -107,7 +108,7 @@ def slow_alignment(keypoints, rotmats, origins): def compute_distance(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray: - """Computes euclidean distance along the final dimension of two input arrays. + """Computes euclidean distance along the final dimension of two input arrays. Parameters ---------- @@ -121,10 +122,11 @@ def compute_distance(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray: """ return np.sqrt(((arr1 - arr2) ** 2).sum(axis=-1)) + def poly_area(x: np.ndarray, y: np.ndarray): - """Returns area of the polygon specified by X and Y coordinates. + """Returns area of the polygon specified by X and Y coordinates. REQUIRES POINTS TO BE IN CLOCKWISE ORDER!! - + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates Parameters @@ -149,7 +151,7 @@ def load_dlcfile(dlcfile: Union[str, os.PathLike]) -> Tuple[np.ndarray, list, pd ---------- dlcfile : Union[str, os.PathLike] Path to deeplabcut file - + Returns ------- Tuple[np.ndarray, list, pd.DataFrame] @@ -160,7 +162,7 @@ def load_dlcfile(dlcfile: Union[str, os.PathLike]) -> Tuple[np.ndarray, list, pd assert os.path.isfile(dlcfile) # TODO: make function to load HDF5s ending = os.path.splitext(dlcfile)[1] - assert ending == '.csv' + assert ending == ".csv" # read the csv df = pd.read_csv(dlcfile, index_col=0) @@ -180,7 +182,7 @@ def load_dlcfile(dlcfile: Union[str, os.PathLike]) -> Tuple[np.ndarray, list, pd return keypoints, bodyparts, df -def stack_features_in_time(features: np.ndarray, frames_before_and_after: int=15) -> np.ndarray: +def stack_features_in_time(features: np.ndarray, frames_before_and_after: int = 15) -> np.ndarray: """For an array of keypoints, stack the frames before and after the current frame into one single vector. Parameters @@ -198,23 +200,22 @@ def stack_features_in_time(features: np.ndarray, frames_before_and_after: int=15 assert features.ndim == 2 stacked_features = [] N = features.shape[0] - padded = np.pad(features, ((frames_before_and_after, frames_before_and_after), (0, 0)), - mode='edge') + padded = np.pad(features, ((frames_before_and_after, frames_before_and_after), (0, 0)), mode="edge") - for i in range(frames_before_and_after, N+frames_before_and_after): - start_ind = i- frames_before_and_after - end_ind = i + frames_before_and_after+1 + for i in range(frames_before_and_after, N + frames_before_and_after): + start_ind = i - frames_before_and_after + end_ind = i + frames_before_and_after + 1 stacked_features.append(padded[start_ind:end_ind, :].flatten()) stacked = np.stack(stacked_features) assert stacked.shape[0] == features.shape[0] - assert stacked.shape[1] == features.shape[1]*(frames_before_and_after*2 + 1) + assert stacked.shape[1] == features.shape[1] * (frames_before_and_after * 2 + 1) return stacked def expand_features_sturman(keypoints: np.ndarray, bodyparts: list, H: int, W: int) -> Tuple[np.ndarray, list]: - """ Expand 2D keypoints into features for behavioral classification. + """Expand 2D keypoints into features for behavioral classification. Based on Sturman et al. 2020: Sturman, O. et al. Deep learning-based behavioral analysis reaches human accuracy and is capable of @@ -257,7 +258,7 @@ def expand_features_sturman(keypoints: np.ndarray, bodyparts: list, H: int, W: i # add centroid as the 8th keypoint. mean of all paws keypoints = np.concatenate((keypoints, np.mean(keypoints[:, 1:5, :], axis=1, keepdims=True)), axis=1) - bodyparts += ['centroid'] + bodyparts += ["centroid"] # normalize keypoints = normalize_keypoints(keypoints, H, W) @@ -285,14 +286,17 @@ def expand_features_sturman(keypoints: np.ndarray, bodyparts: list, H: int, W: i # l_forepaw, nose, r_forepaw, r_hindpaw, tailbase, l_hindpaw area. must be clockwise and in order! areas = np.array( - [poly_area(aligned[i, [1, 0, 2, 4, 5, 3], 0], aligned[i, [1, 0, 2, 4, 5, 3], 1]) for i in range(len(aligned))]) + [poly_area(aligned[i, [1, 0, 2, 4, 5, 3], 0], aligned[i, [1, 0, 2, 4, 5, 3], 1]) for i in range(len(aligned))] + ) nose_tailbase_distance = compute_distance(aligned[:, 0, :], aligned[:, 5, :]) tailbase_tailtip_distance = compute_distance(aligned[:, 5, :], aligned[:, 6, :]) - forepaw_hindpaw_distance = (compute_distance(aligned[:, 1, :], aligned[:, 3, :]) + - compute_distance(aligned[:, 2, :], aligned[:, 4, :])) / 2 - forepaw_nose_distance = (compute_distance(aligned[:, 0, :], aligned[:, 1, :]) + - compute_distance(aligned[:, 0, :], aligned[:, 2, :])) / 2 + forepaw_hindpaw_distance = ( + compute_distance(aligned[:, 1, :], aligned[:, 3, :]) + compute_distance(aligned[:, 2, :], aligned[:, 4, :]) + ) / 2 + forepaw_nose_distance = ( + compute_distance(aligned[:, 0, :], aligned[:, 1, :]) + compute_distance(aligned[:, 0, :], aligned[:, 2, :]) + ) / 2 forepaw_forepaw_distance = compute_distance(aligned[:, 1, :], aligned[:, 2, :]) hindpaw_hindpaw_distance = compute_distance(aligned[:, 3, :], aligned[:, 4, :]) @@ -300,41 +304,41 @@ def expand_features_sturman(keypoints: np.ndarray, bodyparts: list, H: int, W: i features = [] columns = [] for i in range(len(bodyparts)): - for j, coord in enumerate(['x', 'y']): + for j, coord in enumerate(["x", "y"]): features.append(keypoints[:, i, j]) - columns.append('{}_{}'.format(bodyparts[i], coord)) + columns.append("{}_{}".format(bodyparts[i], coord)) for i in range(len(bodyparts)): - for j, coord in enumerate(['x', 'y']): + for j, coord in enumerate(["x", "y"]): features.append(aligned[:, i, j]) - columns.append('{}_{}_aligned'.format(bodyparts[i], coord)) + columns.append("{}_{}_aligned".format(bodyparts[i], coord)) features.append(tail_angle) - columns.append('tail_angle') + columns.append("tail_angle") for i in range(4): features.append(paw_angles[:, i]) - columns.append('{}_centroid_angle'.format(bodyparts[i + 1])) + columns.append("{}_centroid_angle".format(bodyparts[i + 1])) features.append(nose_tailbase_distance) - columns.append('nose_tailbase_dist') + columns.append("nose_tailbase_dist") features.append(tailbase_tailtip_distance) - columns.append('tailbase_tailtip_dist') + columns.append("tailbase_tailtip_dist") features.append(forepaw_hindpaw_distance) - columns.append('forepaw_hindpaw_dist') + columns.append("forepaw_hindpaw_dist") features.append(forepaw_nose_distance) - columns.append('forepaw_nose_dist') + columns.append("forepaw_nose_dist") features.append(forepaw_forepaw_distance) - columns.append('forepaw_forepaw_dist') + columns.append("forepaw_forepaw_dist") features.append(hindpaw_hindpaw_distance) - columns.append('hindpaw_hindpaw_dist') + columns.append("hindpaw_hindpaw_dist") features.append(areas) - columns.append('body_area') + columns.append("body_area") features = np.stack(features, axis=-1) # z-score denominator = features.std(axis=0, keepdims=True) - denominator[denominator < 1e-6 ] = 1e-6 + denominator[denominator < 1e-6] = 1e-6 z = (features - features.mean(axis=0, keepdims=True)) / denominator return z, columns diff --git a/deepethogram/data/utils.py b/deepethogram/data/utils.py index 06c991e..052c1f0 100644 --- a/deepethogram/data/utils.py +++ b/deepethogram/data/utils.py @@ -12,6 +12,7 @@ from vidio import VideoReader from deepethogram import utils + # from deepethogram.dataloaders import log from deepethogram.file_io import read_labels @@ -26,8 +27,8 @@ def purge_unlabeled_videos(video_list: list, label_list: list) -> Tuple[list, li valid_videos = [] valid_labels = [] - warning_string = '''Labelfile {} associated with video {} has unlabeled frames! - Please finish labeling or click the Finalize Labels button on the GUI.''' + warning_string = """Labelfile {} associated with video {} has unlabeled frames! + Please finish labeling or click the Finalize Labels button on the GUI.""" for i in range(len(label_list)): label = read_labels(label_list[i]) @@ -43,16 +44,18 @@ def purge_unlabeled_videos(video_list: list, label_list: list) -> Tuple[list, li def purge_unlabeled_elements_from_records(records: dict) -> dict: valid_records = {} - warning_message = '''labelfile {} has unlabeled frames! + warning_message = """labelfile {} has unlabeled frames! Please finish labeling or click the Finalize Labels button on the GUI. - Associated files: {}''' + Associated files: {}""" for animal, record in records.items(): - labelfile = record['label'] + labelfile = record["label"] if labelfile is None: - log.warning('Record {} does not have a labelfile! Please start and finish labeling. '.format(animal) + \ - 'Associated files: {}'.format(record)) + log.warning( + "Record {} does not have a labelfile! Please start and finish labeling. ".format(animal) + + "Associated files: {}".format(record) + ) continue label = read_labels(labelfile) has_unlabeled_frames = np.any(label == -1) @@ -63,11 +66,10 @@ def purge_unlabeled_elements_from_records(records: dict) -> dict: return valid_records -def make_loss_weight(class_counts: np.ndarray, - num_pos: np.ndarray, - num_neg: np.ndarray, - weight_exp: float = 1) -> Tuple[np.ndarray, np.ndarray]: - """ Makes weight for different classes in loss function. +def make_loss_weight( + class_counts: np.ndarray, num_pos: np.ndarray, num_neg: np.ndarray, weight_exp: float = 1 +) -> Tuple[np.ndarray, np.ndarray]: + """Makes weight for different classes in loss function. In general, rare classes will be up-weighted and common classes will be down-weighted. @@ -95,7 +97,7 @@ def make_loss_weight(class_counts: np.ndarray, # if there are zero positive examples, we don't want the pos weight to be a large number # we want it to be infinity, then we will manually set it to zero with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=RuntimeWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) pos_weight = num_neg / num_pos # if there are zero negative examples, loss should be 1 pos_weight[pos_weight == 0] = 1 @@ -105,24 +107,24 @@ def make_loss_weight(class_counts: np.ndarray, pos_weight_transformed = np.nan_to_num(pos_weight_transformed, nan=0.0, posinf=0.0, neginf=0) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=RuntimeWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) softmax_weight = 1 / class_counts softmax_weight = np.nan_to_num(softmax_weight, nan=0.0, posinf=0.0, neginf=0) softmax_weight = softmax_weight / np.sum(softmax_weight) softmax_weight_transformed = (softmax_weight**weight_exp).astype(np.float32) np.set_printoptions(suppress=True) - log.info('Class counts: {}'.format(class_counts)) - log.info('Pos weight: {}'.format(pos_weight)) - log.info('Pos weight, weighted: {}'.format(pos_weight_transformed)) - log.info('Softmax weight: {}'.format(softmax_weight)) - log.info('Softmax weight transformed: {}'.format(softmax_weight_transformed)) + log.info("Class counts: {}".format(class_counts)) + log.info("Pos weight: {}".format(pos_weight)) + log.info("Pos weight, weighted: {}".format(pos_weight_transformed)) + log.info("Softmax weight: {}".format(softmax_weight)) + log.info("Softmax weight transformed: {}".format(softmax_weight_transformed)) return pos_weight_transformed, softmax_weight_transformed def get_video_metadata(videofile): - """ Simple wrapper to get video availability, width, height, and frame number """ + """Simple wrapper to get video availability, width, height, and frame number""" try: with VideoReader(videofile) as reader: framenum = reader.nframes @@ -133,24 +135,23 @@ def get_video_metadata(videofile): except BaseException as e: ret = False print(e) - print('Error reading file {}'.format(videofile)) + print("Error reading file {}".format(videofile)) return ret, width, height, framenum def extract_metadata(splitdir, allmovies=None, is_flow=False, num_workers=32): - """ Function to get the video metadata for all videos in Kinetics """ + """Function to get the video metadata for all videos in Kinetics""" actions = os.listdir(splitdir) actions.sort() if allmovies is None: - allmovies = glob.glob(splitdir + '**/**/**.mp4') + glob.glob(splitdir + '**/**/**.avi') + allmovies = glob.glob(splitdir + "**/**/**.mp4") + glob.glob(splitdir + "**/**/**.avi") allmovies.sort() if not is_flow: - allmovies = [i for i in allmovies if 'flow' not in os.path.basename(i)] + allmovies = [i for i in allmovies if "flow" not in os.path.basename(i)] else: - - allmovies = [i for i in allmovies if 'flow' in os.path.basename(i)] + allmovies = [i for i in allmovies if "flow" in os.path.basename(i)] widths = [] heights = [] framenums = [] @@ -166,9 +167,9 @@ def extract_metadata(splitdir, allmovies=None, is_flow=False, num_workers=32): # movies = glob.glob(action_dir + '**/**.mp4') + glob.glob(action_dir + '**/**.avi') movies.sort() if not is_flow: - movies = [i for i in movies if 'flow' not in os.path.basename(i)] + movies = [i for i in movies if "flow" not in os.path.basename(i)] else: - movies = [i for i in movies if 'flow' in os.path.basename(i)] + movies = [i for i in movies if "flow" in os.path.basename(i)] results = pool.map(get_video_metadata, movies) success = [] @@ -189,62 +190,62 @@ def extract_metadata(splitdir, allmovies=None, is_flow=False, num_workers=32): action_indices.append(action_index) video_data = { - 'name': allnames, - 'action': allactions, - 'action_int': action_indices, - 'width': widths, - 'height': heights, - 'framecount': framenums + "name": allnames, + "action": allactions, + "action_int": action_indices, + "width": widths, + "height": heights, + "framecount": framenums, } df = pd.DataFrame(data=video_data) - fname = '_metadata.csv' + fname = "_metadata.csv" if is_flow: - fname = '_flow' + fname + fname = "_flow" + fname df.to_csv(os.path.join(os.path.dirname(splitdir), os.path.basename(splitdir) + fname)) return df def find_labelfile(video: Union[str, os.PathLike]) -> Tuple[str, str]: - """ Function for finding a label file for a given a video """ + """Function for finding a label file for a given a video""" base = os.path.splitext(video)[0] - labelfile = base + '_labels.csv' + labelfile = base + "_labels.csv" if os.path.isfile(labelfile): - return (labelfile, 'csv') - labelfile = base + '_labels.h5' + return (labelfile, "csv") + labelfile = base + "_labels.h5" if os.path.isfile(labelfile): - return (labelfile, 'h5') - labelfile = base + '_scores.csv' + return (labelfile, "h5") + labelfile = base + "_scores.csv" if os.path.isfile(labelfile): - return (labelfile, 'csv') - labelfile = base + '_scores.h5' + return (labelfile, "csv") + labelfile = base + "_scores.h5" if os.path.isfile(labelfile): - return (labelfile, 'h5') + return (labelfile, "h5") basedir = os.path.dirname(video) files = os.listdir(basedir) files.sort() files = [os.path.join(basedir, i) for i in files] # handles case where directory contains 'movie.avi', and 'labels.csv' - files = [i for i in files if 'label' in i or 'score' in i] + files = [i for i in files if "label" in i or "score" in i] if len(files) == 1: - if files[0].endswith('csv'): - return files[0], 'csv' - elif files[0].endswith('h5'): - return files[0], 'h5' - basename = os.path.basename(base).split('_')[:-1] - basename = '_'.join(basename) + if files[0].endswith("csv"): + return files[0], "csv" + elif files[0].endswith("h5"): + return files[0], "h5" + basename = os.path.basename(base).split("_")[:-1] + basename = "_".join(basename) matching_files = [i for i in files if basename in i] if len(matching_files) == 1: labelfile = matching_files[0] ext = os.path.splitext(labelfile)[1][1:] return labelfile, ext - raise ValueError('no corresponding labels found: {}'.format(video)) + raise ValueError("no corresponding labels found: {}".format(video)) def read_all_labels(labelfiles: list, fix: bool = True, multilabel: bool = True): - """ Function for reading all labels into memory """ + """Function for reading all labels into memory""" labels = [] for i, labelfile in enumerate(labelfiles): - assert (os.path.isfile(labelfile)) + assert os.path.isfile(labelfile) label_type = os.path.splitext(labelfile)[1][1:] # labelfile, label_type = find_labelfile(video) label = read_labels(labelfile) @@ -254,7 +255,7 @@ def read_all_labels(labelfiles: list, fix: bool = True, multilabel: bool = True) label = label.T if label.shape[1] == 1: # add a background class - warnings.warn('binary labels found, adding background class') + warnings.warn("binary labels found, adding background class") label = np.hstack((np.logical_not(label), label)) if fix: @@ -307,12 +308,12 @@ def parse_split(split: Union[tuple, list, np.ndarray], N: int): N = total split = split / split.sum() else: - raise ValueError('Unknown split type: {}'.format(split.dtype)) + raise ValueError("Unknown split type: {}".format(split.dtype)) return split, N def train_val_test_split(records: dict, split: Union[tuple, list, np.ndarray] = (0.7, 0.15, 0.15)) -> dict: - """ Split a dict of dicts into train, validation, and test sets. + """Split a dict of dicts into train, validation, and test sets. Parameters ---------- @@ -340,28 +341,28 @@ def train_val_test_split(records: dict, split: Union[tuple, list, np.ndarray] = # in place indices = np.random.permutation(N) - splits = ['train', 'val', 'test'] + splits = ["train", "val", "test"] keys = np.array(keys) outputs = {} - outputs['metadata'] = {'split': split.tolist()} + outputs["metadata"] = {"split": split.tolist()} # print(list(split)) # outputs['metadata']['split'] = split.tolist() # handle edge cases if len(records) < 4: assert len(records) > 1 - warnings.warn('Only {} records found...'.format(len(keys))) + warnings.warn("Only {} records found...".format(len(keys))) shuffled = np.random.permutation(keys) - outputs['train'] = [str(shuffled[0])] - outputs['val'] = [str(shuffled[1])] - outputs['test'] = [] + outputs["train"] = [str(shuffled[0])] + outputs["val"] = [str(shuffled[1])] + outputs["test"] = [] if len(records) == 3: shuffled = np.random.permutation(keys) - outputs['test'] = [str(shuffled[2])] + outputs["test"] = [str(shuffled[2])] return outputs for i, spl in enumerate(splits): shuffled = keys[indices] - splitfiles = shuffled[starts[i]:ends[i]] + splitfiles = shuffled[starts[i] : ends[i]] outputs[spl] = splitfiles.tolist() # print(type(split.tolist()[0])) @@ -369,21 +370,21 @@ def train_val_test_split(records: dict, split: Union[tuple, list, np.ndarray] = def do_all_classes_have_labels(records: dict, split_dict: dict) -> bool: - """ Helper function to determine if each split has at least one instance of every class """ + """Helper function to determine if each split has at least one instance of every class""" labelfiles = [] - for split in ['train', 'val', 'test']: + for split in ["train", "val", "test"]: if len(split_dict[split]) > 0: splitfiles = split_dict[split] for f in splitfiles: - labelfiles.append(records[f]['label']) + labelfiles.append(records[f]["label"]) # labelfiles += [records[i]['label'] for i in split_dict[split]] _, class_counts, _, _, _ = read_all_labels(labelfiles) return np.all(class_counts > 0) def get_valid_split(records: dict, train_val_test: Union[list, np.ndarray]) -> dict: - """ Gets a train, val, test split with at least one instance of every class + """Gets a train, val, test split with at least one instance of every class Keep doing train_test_split until each split of the data has at least one single example of every behavior in the dataset. it would be bad if your train data had class counts: [1000, 0, 0, 10] and your test data had @@ -409,11 +410,11 @@ class counts: [500, 100, 300, 0] split_dict = train_val_test_split(records, train_val_test) should_continue = do_all_classes_have_labels(records, split_dict) if not should_continue: - warnings.warn('Not all classes in the dataset have *any* labels!') + warnings.warn("Not all classes in the dataset have *any* labels!") return split_dict is_wrong = False - for split in ['train', 'val', 'test']: - labelfiles = [records[i]['label'] for i in split_dict[split]] + for split in ["train", "val", "test"]: + labelfiles = [records[i]["label"] for i in split_dict[split]] if len(labelfiles) > 0: _, class_counts, _, _, _ = read_all_labels(labelfiles) if not np.all(class_counts > 0): @@ -422,12 +423,12 @@ class counts: [500, 100, 300, 0] def update_split(records: dict, split_dictionary: dict) -> dict: - """ Updates existing split if there are new entries in the records dictionary """ + """Updates existing split if there are new entries in the records dictionary""" # records: dictionary of dictionaries. Keys: unique data identifiers # values: a dictionary corresponding to different files. the first record might be: # {'mouse000': {'rgb': path/to/rgb.avi, 'label':path/to/labels.csv} } # split_dictionary: {'metadata': ..., 'train':[mouse000, mouse001], 'val':[mouse002,mouse003]... etc} - old_dictionary = {k: v for (k, v) in split_dictionary.items() if k != 'metadata'} + old_dictionary = {k: v for (k, v) in split_dictionary.items() if k != "metadata"} # https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists old_keys = [item for sublist in old_dictionary.values() for item in sublist] old_keys.sort() @@ -437,29 +438,31 @@ def update_split(records: dict, split_dictionary: dict) -> dict: # of the split dictionary new_entries = [i for i in new_keys if i not in old_keys] splits = list(split_dictionary.keys()) - splits = [i for i in splits if i != 'metadata'] + splits = [i for i in splits if i != "metadata"] if len(splits) == 3: # alphabetical order does not work - splits = ['train', 'val', 'test'] + splits = ["train", "val", "test"] # goes through new entries, and assigns them to a split based on loaded split_probabilities if len(new_entries) > 0: - split_p = split_dictionary['metadata']['split'] + split_p = split_dictionary["metadata"]["split"] N = len(new_entries) new_splits = np.random.choice(splits, size=(N,), p=split_p).tolist() for i, k in enumerate(new_entries): split_dictionary[new_splits[i]].append(k) - log.info('file {} assigned to split {}'.format(k, new_splits[i])) + log.info("file {} assigned to split {}".format(k, new_splits[i])) return split_dictionary -def get_split_from_records(records: dict, - datadir: Union[str, bytes, os.PathLike], - splitfile: Union[str, bytes, os.PathLike] = None, - supervised: bool = True, - reload_split: bool = True, - valid_splits_only: bool = True, - train_val_test: list = [0.7, 0.15, 0.15]): - """ Splits the records into train, validation, and test splits +def get_split_from_records( + records: dict, + datadir: Union[str, bytes, os.PathLike], + splitfile: Union[str, bytes, os.PathLike] = None, + supervised: bool = True, + reload_split: bool = True, + valid_splits_only: bool = True, + train_val_test: list = [0.7, 0.15, 0.15], +): + """Splits the records into train, validation, and test splits Parameters ---------- @@ -484,9 +487,9 @@ def get_split_from_records(records: dict, see train_val_test_split """ if splitfile is None: - splitfile = os.path.join(datadir, 'split.yaml') + splitfile = os.path.join(datadir, "split.yaml") else: - assert os.path.isfile(splitfile), 'split file does not exist! {}'.format(splitfile) + assert os.path.isfile(splitfile), "split file does not exist! {}".format(splitfile) if supervised and valid_splits_only: # this function makes sure that each split has all classes in the dataset @@ -510,13 +513,13 @@ def get_split_from_records(records: dict, def remove_invalid_records_from_split_dictionary(split_dictionary: dict, records: dict) -> dict: - """ Removes records that exist in split_dictionary but not in records. + """Removes records that exist in split_dictionary but not in records. Can be useful if you previously had a video in your project and used that to make a train / val / test split, but later deleted it. """ valid_records = {} record_keys = list(records.keys()) - for split in ['train', 'val', 'test']: + for split in ["train", "val", "test"]: valid_records[split] = {} splitfiles = split_dictionary[split] for i, key in enumerate(record_keys): @@ -536,7 +539,7 @@ def count_multilabeled_frames(label): def fix_label(labelfile, label: np.ndarray, multilabel: bool = True) -> np.ndarray: n_unlabeled = count_unlabeled_frames(label) if n_unlabeled > 0: - logging.warning(f'file {labelfile} has {n_unlabeled} unlabeled frames!! setting to background...') + logging.warning(f"file {labelfile} has {n_unlabeled} unlabeled frames!! setting to background...") # set rows to 0 unlabeled = label.sum(axis=1) < 1 label[unlabeled, :] = 0 @@ -547,7 +550,7 @@ def fix_label(labelfile, label: np.ndarray, multilabel: bool = True) -> np.ndarr # labels must be mutually exclusive n_multilabel = count_multilabeled_frames(label) if n_multilabel > 1: - logging.warning(f'file {labelfile} has {n_multilabel} multi-label frames! randomly selecting the label...') + logging.warning(f"file {labelfile} has {n_multilabel} multi-label frames! randomly selecting the label...") multilabel = label.sum(axis=1) > 1 multilabel_inds = np.where(multilabel)[0] for ind in multilabel_inds: diff --git a/deepethogram/debug.py b/deepethogram/debug.py index 1bb300c..8928f5d 100644 --- a/deepethogram/debug.py +++ b/deepethogram/debug.py @@ -22,12 +22,12 @@ def print_models(model_path: Union[str, os.PathLike]) -> None: Absolute path to models directory """ trained_models = projects.get_weights_from_model_path(model_path) - log.info('Trained models: {}'.format(pprint.pformat(trained_models))) - + log.info("Trained models: {}".format(pprint.pformat(trained_models))) + def print_dataset_info(datadir: Union[str, os.PathLike]) -> None: - """Prints information about your dataset. - + """Prints information about your dataset. + - video path - number of unlabeled rows in a video - number of examples of each behavior in each video @@ -38,85 +38,87 @@ def print_dataset_info(datadir: Union[str, os.PathLike]) -> None: [description] """ records = projects.get_records_from_datadir(datadir) - + for key, record in records.items(): - log.info('Information about subdir {}'.format(key)) - if record['rgb'] is not None: - log.info('Video: {}'.format(record['rgb'])) - - if record['label'] is not None: - label = file_io.read_labels(record['label']) + log.info("Information about subdir {}".format(key)) + if record["rgb"] is not None: + log.info("Video: {}".format(record["rgb"])) + + if record["label"] is not None: + label = file_io.read_labels(record["label"]) if np.sum(label == -1) > 0: - unlabeled_rows = np.any(label == -1, axis=0) + unlabeled_rows = np.any(label == -1, axis=0) n_unlabeled = np.sum(unlabeled_rows) - log.warning('{} UNLABELED ROWS!'.format(n_unlabeled) + \ - 'VIDEO WILL NOT BE USED FOR FEATURE_EXTRACTOR OR SEQUENCE TRAINING.') + log.warning( + "{} UNLABELED ROWS!".format(n_unlabeled) + + "VIDEO WILL NOT BE USED FOR FEATURE_EXTRACTOR OR SEQUENCE TRAINING." + ) else: class_counts = label.sum(axis=0) - log.info('Labels with counts: {}'.format(class_counts)) - + log.info("Labels with counts: {}".format(class_counts)) + + def try_load_all_frames(datadir: Union[str, os.PathLike]): - """Attempts to read every image from every video. - + """Attempts to read every image from every video. + Useful for debugging corrupted videos, e.g. if saving to disk was aborted improperly during acquisition If there is an error reading a frame, it will print the video name and frame number - + Parameters ---------- datadir : Union[str, os.PathLike] absolute path to the project/DATA directory """ - log.info('Iterating through all frames of all movies to test for frame reading bugs') + log.info("Iterating through all frames of all movies to test for frame reading bugs") records = projects.get_records_from_datadir(datadir) for key, record in tqdm(records.items()): - with VideoReader(record['rgb']) as reader: - log.info('reading all frames from file {}'.format(record['rgb'])) + with VideoReader(record["rgb"]) as reader: + log.info("reading all frames from file {}".format(record["rgb"])) had_error = False for i in tqdm(range(len(reader)), leave=False): try: frame = reader[i] except Exception: had_error = True - print('error reading frame {} from video {}'.format(i, record['rgb'])) + print("error reading frame {} from video {}".format(i, record["rgb"])) except KeyboardInterrupt: raise if had_error: - log.warning('Error in file {}. Is this video corrupted?'.format(record['rgb'])) + log.warning("Error in file {}. Is this video corrupted?".format(record["rgb"])) else: - log.info('No problems in {}'.format(key)) - - -if __name__ == '__main__': - if os.path.isfile('debug.log'): - os.remove('debug.log') - logging.basicConfig(level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.FileHandler("debug.log"), - logging.StreamHandler() - ]) - + log.info("No problems in {}".format(key)) + + +if __name__ == "__main__": + if os.path.isfile("debug.log"): + os.remove("debug.log") + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.FileHandler("debug.log"), logging.StreamHandler()], + ) + cfg = OmegaConf.from_cli() if cfg.project.path is None and cfg.project.config_file is None: - raise ValueError('must input either a path or a config file') + raise ValueError("must input either a path or a config file") elif cfg.project.path is not None: - cfg.project.config_file = os.path.join(cfg.project.path, 'project_config.yaml') + cfg.project.config_file = os.path.join(cfg.project.path, "project_config.yaml") elif cfg.project.config_file is not None: - cfg.project.path = os.path.dirname(cfg.project.config_file) + cfg.project.path = os.path.dirname(cfg.project.config_file) else: - raise ValueError('must input either a path or a config file, not {}'.format(cfg)) - + raise ValueError("must input either a path or a config file, not {}".format(cfg)) + assert os.path.isfile(cfg.project.config_file) and os.path.isdir(cfg.project.path) - + user_cfg = OmegaConf.load(cfg.project.config_file) cfg = OmegaConf.merge(cfg, user_cfg) cfg = projects.convert_config_paths_to_absolute(cfg) # print(cfg) - + logging.info(OmegaConf.to_yaml(cfg)) - + print_models(cfg.project.model_path) - + print_dataset_info(cfg.project.data_path) - - try_load_all_frames(cfg.project.data_path) \ No newline at end of file + + try_load_all_frames(cfg.project.data_path) diff --git a/deepethogram/feature_extractor/inference.py b/deepethogram/feature_extractor/inference.py index 1dfbe1e..67d6787 100644 --- a/deepethogram/feature_extractor/inference.py +++ b/deepethogram/feature_extractor/inference.py @@ -27,8 +27,8 @@ torch.set_printoptions(sci_mode=False) -def unpack_penultimate_layer(model: Type[nn.Module], fusion: str = 'average'): - """ Adds the activations in the penulatimate layer of the given PyTorch module to a dictionary called 'activation'. +def unpack_penultimate_layer(model: Type[nn.Module], fusion: str = "average"): + """Adds the activations in the penulatimate layer of the given PyTorch module to a dictionary called 'activation'. Assumes the model has two subcomponents: spatial and flow models. Every time the forward pass of this network is run, the penultimate neural activations will be added to the activations dictionary. @@ -60,15 +60,15 @@ def hook(model, inputs, output): if len(inputs) == 1: inputs = inputs[0] else: - raise ValueError('unknown inputs: {}'.format(inputs)) + raise ValueError("unknown inputs: {}".format(inputs)) activation[name] = inputs.detach() return hook final_spatial_linear = get_linear_layers(model.spatial_classifier)[-1] - final_spatial_linear.register_forward_hook(get_inputs('spatial')) + final_spatial_linear.register_forward_hook(get_inputs("spatial")) final_flow_linear = get_linear_layers(model.flow_classifier)[-1] - final_flow_linear.register_forward_hook(get_inputs('flow')) + final_flow_linear.register_forward_hook(get_inputs("flow")) return activation @@ -96,14 +96,19 @@ def get_linear_layers(model: nn.Module) -> list: def get_penultimate_layer(model: Type[nn.Module]): - """ Function to unpack a linear layer from a nn sequential module """ + """Function to unpack a linear layer from a nn sequential module""" assert isinstance(model, nn.Module) children = list(model.children()) return children[-1] -def print_debug_statement(images: torch.Tensor, logits: torch.Tensor, spatial_features: torch.Tensor, - flow_features: torch.Tensor, probabilities: torch.Tensor): +def print_debug_statement( + images: torch.Tensor, + logits: torch.Tensor, + spatial_features: torch.Tensor, + flow_features: torch.Tensor, + probabilities: torch.Tensor, +): """prints useful debug information to make sure there are no inference bugs Parameters @@ -124,40 +129,48 @@ def print_debug_statement(images: torch.Tensor, logits: torch.Tensor, spatial_fe ValueError in case of non 4-d or 5-d input tensors """ - log.info('images shape: {}'.format(images.shape)) - log.info('logits shape: {}'.format(logits.shape)) - log.info('spatial_features shape: {}'.format(spatial_features.shape)) - log.info('flow_features shape: {}'.format(flow_features.shape)) - log.info('spatial: min {} mean {} max {} shape {}'.format(spatial_features.min(), spatial_features.mean(), - spatial_features.max(), spatial_features.shape)) - log.info('flow : min {} mean {} max {} shape {}'.format(flow_features.min(), flow_features.mean(), - flow_features.max(), flow_features.shape)) + log.info("images shape: {}".format(images.shape)) + log.info("logits shape: {}".format(logits.shape)) + log.info("spatial_features shape: {}".format(spatial_features.shape)) + log.info("flow_features shape: {}".format(flow_features.shape)) + log.info( + "spatial: min {} mean {} max {} shape {}".format( + spatial_features.min(), spatial_features.mean(), spatial_features.max(), spatial_features.shape + ) + ) + log.info( + "flow : min {} mean {} max {} shape {}".format( + flow_features.min(), flow_features.mean(), flow_features.max(), flow_features.shape + ) + ) # a common issue I've had is not properly z-scoring input channels. this will check for that if len(images.shape) == 4: N, C, H, W = images.shape elif images.ndim == 5: N, C, T, H, W = images.shape else: - raise ValueError('images of unknown shape: {}'.format(images.shape)) - - log.info('channel min: {}'.format(images[0].reshape(C, -1).min(dim=1).values)) - log.info('channel mean: {}'.format(images[0].reshape(C, -1).mean(dim=1))) - log.info('channel max : {}'.format(images[0].reshape(C, -1).max(dim=1).values)) - log.info('channel std : {}'.format(images[0].reshape(C, -1).std(dim=1))) - - -def predict_single_video(videofile: Union[str, os.PathLike], - model: nn.Module, - activation_function: nn.Module, - fusion: str, - num_rgb: int, - mean_by_channels: np.ndarray, - device: str = 'cuda:0', - cpu_transform=None, - gpu_transform=None, - should_print: bool = False, - num_workers: int = 1, - batch_size: int = 16): + raise ValueError("images of unknown shape: {}".format(images.shape)) + + log.info("channel min: {}".format(images[0].reshape(C, -1).min(dim=1).values)) + log.info("channel mean: {}".format(images[0].reshape(C, -1).mean(dim=1))) + log.info("channel max : {}".format(images[0].reshape(C, -1).max(dim=1).values)) + log.info("channel std : {}".format(images[0].reshape(C, -1).std(dim=1))) + + +def predict_single_video( + videofile: Union[str, os.PathLike], + model: nn.Module, + activation_function: nn.Module, + fusion: str, + num_rgb: int, + mean_by_channels: np.ndarray, + device: str = "cuda:0", + cpu_transform=None, + gpu_transform=None, + should_print: bool = False, + num_workers: int = 1, + batch_size: int = 16, +): """Runs inference on one input video, caching the output probabilities and image and flow feature vectors Parameters @@ -209,11 +222,13 @@ def predict_single_video(videofile: Union[str, os.PathLike], if type(device) != torch.device: device = torch.device(device) - dataset = VideoIterable(videofile, - transform=cpu_transform, - num_workers=num_workers, - sequence_length=num_rgb, - mean_by_channels=mean_by_channels) + dataset = VideoIterable( + videofile, + transform=cpu_transform, + num_workers=num_workers, + sequence_length=num_rgb, + mean_by_channels=mean_by_channels, + ) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size) video_frame_num = len(dataset) @@ -225,11 +240,11 @@ def predict_single_video(videofile: Union[str, os.PathLike], # log.debug('model training mode: {}'.format(model.training)) for i, batch in enumerate(tqdm(dataloader, leave=False)): if isinstance(batch, dict): - images = batch['images'] + images = batch["images"] elif isinstance(batch, torch.Tensor): images = batch else: - raise ValueError('unknown input type: {}'.format(type(batch))) + raise ValueError("unknown input type: {}".format(type(batch))) if images.device != device: images = images.to(device) @@ -238,12 +253,12 @@ def predict_single_video(videofile: Union[str, os.PathLike], images = gpu_transform(images) logits = model(images) - spatial_features = activation['spatial'] - flow_features = activation['flow'] + spatial_features = activation["spatial"] + flow_features = activation["flow"] # because we are using iterable datasets, each batch will be a consecutive chunk of frames from one worker # but they might be from totally different chunks of the video. therefore, we return the frame numbers, # and use this to store into our buffer in the right location - frame_numbers = batch['framenum'].detach().cpu() + frame_numbers = batch["framenum"].detach().cpu() probabilities = activation_function(logits).detach().cpu() logits = logits.detach().cpu() @@ -255,18 +270,19 @@ def predict_single_video(videofile: Union[str, os.PathLike], has_printed = True if i == 0: # print(f'~~~ N: {N} ~~~') - buffer['probabilities'] = torch.zeros((video_frame_num, probabilities.shape[1]), dtype=probabilities.dtype) - buffer['logits'] = torch.zeros((video_frame_num, logits.shape[1]), dtype=logits.dtype) - buffer['spatial_features'] = torch.zeros((video_frame_num, spatial_features.shape[1]), - dtype=spatial_features.dtype) - buffer['flow_features'] = torch.zeros((video_frame_num, flow_features.shape[1]), dtype=flow_features.dtype) - buffer['debug'] = torch.zeros((video_frame_num, )).float() - buffer['probabilities'][frame_numbers, :] = probabilities - buffer['logits'][frame_numbers] = logits - - buffer['spatial_features'][frame_numbers] = spatial_features - buffer['flow_features'][frame_numbers] = flow_features - buffer['debug'][frame_numbers] += 1 + buffer["probabilities"] = torch.zeros((video_frame_num, probabilities.shape[1]), dtype=probabilities.dtype) + buffer["logits"] = torch.zeros((video_frame_num, logits.shape[1]), dtype=logits.dtype) + buffer["spatial_features"] = torch.zeros( + (video_frame_num, spatial_features.shape[1]), dtype=spatial_features.dtype + ) + buffer["flow_features"] = torch.zeros((video_frame_num, flow_features.shape[1]), dtype=flow_features.dtype) + buffer["debug"] = torch.zeros((video_frame_num,)).float() + buffer["probabilities"][frame_numbers, :] = probabilities + buffer["logits"][frame_numbers] = logits + + buffer["spatial_features"][frame_numbers] = spatial_features + buffer["flow_features"][frame_numbers] = flow_features + buffer["debug"][frame_numbers] += 1 return buffer @@ -291,34 +307,35 @@ def check_if_should_run_inference(h5file: Union[str, os.PathLike], mode: str, la """ should_run = True with h5py.File(h5file, mode) as f: - if latent_name in list(f.keys()): if overwrite: del f[latent_name] else: - log.warning('Latent {} already found in file {}, skipping...'.format(latent_name, h5file)) + log.warning("Latent {} already found in file {}, skipping...".format(latent_name, h5file)) should_run = False return should_run -def extract(rgbs: list, - model, - final_activation: str, - thresholds: np.ndarray, - postprocessor, - mean_by_channels, - fusion: str, - num_rgb: int, - latent_name: str, - class_names: list = ['background'], - device: str = 'cuda:0', - cpu_transform=None, - gpu_transform=None, - ignore_error=True, - overwrite=False, - num_workers: int = 1, - batch_size: int = 1): - """ Use the model to extract spatial and flow feature vectors, and predictions, and save them to disk. +def extract( + rgbs: list, + model, + final_activation: str, + thresholds: np.ndarray, + postprocessor, + mean_by_channels, + fusion: str, + num_rgb: int, + latent_name: str, + class_names: list = ["background"], + device: str = "cuda:0", + cpu_transform=None, + gpu_transform=None, + ignore_error=True, + overwrite=False, + num_workers: int = 1, + batch_size: int = 1, +): + """Use the model to extract spatial and flow feature vectors, and predictions, and save them to disk. Assumes you have a pretrained model, and K classes. Will go through each video in rgbs, run inference, extracting the 512-d spatial features, 512-d flow features, K-d probabilities to disk for each video frame. Also stores thresholds for later reloading. @@ -358,102 +375,105 @@ def extract(rgbs: list, assert isinstance(model, torch.nn.Module) device = torch.device(device) - if device.type != 'cpu': + if device.type != "cpu": torch.cuda.set_device(device) model = model.to(device) # freeze model and set to eval mode for batch normalization - model.set_mode('inference') + model.set_mode("inference") # double checknig for parameter in model.parameters(): parameter.requires_grad = False model.eval() - if final_activation == 'softmax': + if final_activation == "softmax": activation_function = nn.Softmax(dim=1) - elif final_activation == 'sigmoid': + elif final_activation == "sigmoid": activation_function = nn.Sigmoid() else: - raise ValueError('unknown final activation: {}'.format(final_activation)) + raise ValueError("unknown final activation: {}".format(final_activation)) # 16 is a decent trade off between CPU and GPU load on datasets I've tested - if batch_size == 'auto': + if batch_size == "auto": batch_size = 16 batch_size = min(batch_size, 16) - log.info('inference batch size: {}'.format(batch_size)) + log.info("inference batch size: {}".format(batch_size)) class_names = [n.encode("ascii", "ignore") for n in class_names] - log.debug('model training mode: {}'.format(model.training)) + log.debug("model training mode: {}".format(model.training)) # iterate over movie files for i in tqdm(range(len(rgbs))): rgb = rgbs[i] basename = os.path.splitext(rgb)[0] # make the outputfile have the same name as the video, with _outputs appended - h5file = basename + '_outputs.h5' - mode = 'r+' if os.path.isfile(h5file) else 'w' + h5file = basename + "_outputs.h5" + mode = "r+" if os.path.isfile(h5file) else "w" should_run = check_if_should_run_inference(h5file, mode, latent_name, overwrite) if not should_run: continue # iterate over each frame of the movie - outputs = predict_single_video(rgb, - model, - activation_function, - fusion, - num_rgb, - mean_by_channels, - device, - cpu_transform, - gpu_transform, - should_print=i == 0, - num_workers=num_workers, - batch_size=batch_size) + outputs = predict_single_video( + rgb, + model, + activation_function, + fusion, + num_rgb, + mean_by_channels, + device, + cpu_transform, + gpu_transform, + should_print=i == 0, + num_workers=num_workers, + batch_size=batch_size, + ) if i == 0: for k, v in outputs.items(): - log.info('{}: {}'.format(k, v.shape)) - if k == 'debug': - log.debug('All should be 1.0: min: {:.4f} mean {:.4f} max {:.4f}'.format( - v.min(), v.mean(), v.max())) + log.info("{}: {}".format(k, v.shape)) + if k == "debug": + log.debug( + "All should be 1.0: min: {:.4f} mean {:.4f} max {:.4f}".format(v.min(), v.mean(), v.max()) + ) # if running inference from multiple processes, this will wait until the resource is available has_worked = False while not has_worked: try: - f = h5py.File(h5file, 'r+') + f = h5py.File(h5file, "r+") except OSError: - log.warning('resource unavailable, waiting 30 seconds...') + log.warning("resource unavailable, waiting 30 seconds...") time.sleep(30) else: has_worked = True try: - predictions = postprocessor(outputs['probabilities'].detach().cpu().numpy()) + predictions = postprocessor(outputs["probabilities"].detach().cpu().numpy()) labelfile = projects.find_labelfiles(os.path.dirname(rgb))[0] labels = read_labels(labelfile) - f1 = f1_score(labels, predictions, average='macro') - log.info('macro F1: {}'.format(f1)) + f1 = f1_score(labels, predictions, average="macro") + log.info("macro F1: {}".format(f1)) except Exception as e: - log.warning('error calculating f1: {}'.format(e)) + log.warning("error calculating f1: {}".format(e)) # since this is just for debugging, ignore pass # these assignments are where it's actually saved to disk group = f.create_group(latent_name) - group.create_dataset('thresholds', data=thresholds, dtype=np.float32) - group.create_dataset('logits', data=outputs['logits'], dtype=np.float32) - group.create_dataset('P', data=outputs['probabilities'], dtype=np.float32) - group.create_dataset('spatial_features', data=outputs['spatial_features'], dtype=np.float32) - group.create_dataset('flow_features', data=outputs['flow_features'], dtype=np.float32) + group.create_dataset("thresholds", data=thresholds, dtype=np.float32) + group.create_dataset("logits", data=outputs["logits"], dtype=np.float32) + group.create_dataset("P", data=outputs["probabilities"], dtype=np.float32) + group.create_dataset("spatial_features", data=outputs["spatial_features"], dtype=np.float32) + group.create_dataset("flow_features", data=outputs["flow_features"], dtype=np.float32) dt = h5py.string_dtype() - group.create_dataset('class_names', data=class_names, dtype=dt) + group.create_dataset("class_names", data=class_names, dtype=dt) del outputs f.close() def feature_extractor_inference(cfg: DictConfig): - """Runs inference on the feature extractor from an OmegaConf configuration. + """Runs inference on the feature extractor from an OmegaConf configuration. Parameters ---------- @@ -469,23 +489,24 @@ def feature_extractor_inference(cfg: DictConfig): """ cfg = projects.setup_run(cfg) # turn "models" in your project configuration to "full/path/to/models" - log.info('args: {}'.format(' '.join(sys.argv))) + log.info("args: {}".format(" ".join(sys.argv))) - log.info('configuration used in inference: ') + log.info("configuration used in inference: ") log.info(OmegaConf.to_yaml(cfg)) - if 'sequence' not in cfg.keys() or 'latent_name' not in cfg.sequence.keys() or cfg.sequence.latent_name is None: + if "sequence" not in cfg.keys() or "latent_name" not in cfg.sequence.keys() or cfg.sequence.latent_name is None: latent_name = cfg.feature_extractor.arch else: latent_name = cfg.sequence.latent_name - log.info('Latent name used in HDF5 file: {}'.format(latent_name)) + log.info("Latent name used in HDF5 file: {}".format(latent_name)) directory_list = cfg.inference.directory_list if directory_list is None or len(directory_list) == 0: - raise ValueError('must pass list of directories from commmand line. ' - 'Ex: directory_list=[path_to_dir1,path_to_dir2]') - elif type(directory_list) == str and directory_list == 'all': + raise ValueError( + "must pass list of directories from commmand line. " "Ex: directory_list=[path_to_dir1,path_to_dir2]" + ) + elif type(directory_list) == str and directory_list == "all": basedir = cfg.project.data_path - directory_list = utils.get_subfiles(basedir, 'directory') + directory_list = utils.get_subfiles(basedir, "directory") elif isinstance(directory_list, str): directory_list = [directory_list] elif isinstance(directory_list, list): @@ -493,35 +514,36 @@ def feature_extractor_inference(cfg: DictConfig): elif isinstance(directory_list, ListConfig): directory_list = OmegaConf.to_container(directory_list) else: - raise ValueError('unknown value for directory list: {}'.format(directory_list)) + raise ValueError("unknown value for directory list: {}".format(directory_list)) # video files are found in your input list of directories using the records.yaml file that should be present # in each directory records = [] for directory in directory_list: - assert os.path.isdir(directory), 'Not a directory: {}'.format(directory) + assert os.path.isdir(directory), "Not a directory: {}".format(directory) record = projects.get_record_from_subdir(directory) - assert record['rgb'] is not None + assert record["rgb"] is not None records.append(record) - assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, 'Flow generator inputs must be one greater ' \ - 'than feature extractor num flows ' + assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( + "Flow generator inputs must be one greater " "than feature extractor num flows " + ) input_images = cfg.feature_extractor.n_flows + 1 - mode = '3d' if '3d' in cfg.feature_extractor.arch.lower() else '2d' + mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" # get the validation transforms. should have resizing, etc - cpu_transform = get_cpu_transforms(cfg.augs)['val'] - gpu_transform = get_gpu_transforms(cfg.augs, mode)['val'] - log.info('gpu_transform: {}'.format(gpu_transform)) + cpu_transform = get_cpu_transforms(cfg.augs)["val"] + gpu_transform = get_gpu_transforms(cfg.augs, mode)["val"] + log.info("gpu_transform: {}".format(gpu_transform)) rgb = [] for record in records: - rgb.append(record['rgb']) + rgb.append(record["rgb"]) - feature_extractor_weights = projects.get_weightfile_from_cfg(cfg, 'feature_extractor') + feature_extractor_weights = projects.get_weightfile_from_cfg(cfg, "feature_extractor") assert os.path.isfile(feature_extractor_weights) run_files = utils.get_run_files_from_weights(feature_extractor_weights) if cfg.inference.use_loaded_model_cfg: - loaded_config_file = run_files['config_file'] + loaded_config_file = run_files["config_file"] loaded_cfg = OmegaConf.load(loaded_config_file) loaded_model_cfg = loaded_cfg.feature_extractor current_model_cfg = cfg.feature_extractor @@ -535,51 +557,53 @@ def feature_extractor_inference(cfg: DictConfig): # log.warning('Overwriting current project classes with loaded classes! REVERT') model_components = build_feature_extractor(cfg) _, _, _, _, model = model_components - device = 'cuda:{}'.format(cfg.compute.gpu_id) + device = "cuda:{}".format(cfg.compute.gpu_id) - metrics_file = run_files['metrics_file'] + metrics_file = run_files["metrics_file"] assert os.path.isfile(metrics_file) best_epoch = utils.get_best_epoch_from_weightfile(feature_extractor_weights) # best_epoch = -1 - log.info('best epoch from loaded file: {}'.format(best_epoch)) - with h5py.File(metrics_file, 'r') as f: + log.info("best epoch from loaded file: {}".format(best_epoch)) + with h5py.File(metrics_file, "r") as f: try: - thresholds = f['val']['metrics_by_threshold']['optimum'][best_epoch, :] + thresholds = f["val"]["metrics_by_threshold"]["optimum"][best_epoch, :] except KeyError: # backwards compatibility - thresholds = f['threshold_curves']['val']['optimum'][best_epoch, :] - log.info('thresholds: {}'.format(thresholds)) + thresholds = f["threshold_curves"]["val"]["optimum"][best_epoch, :] + log.info("thresholds: {}".format(thresholds)) class_names = list(cfg.project.class_names) if len(thresholds) != len(class_names): - error_message = '''Number of classes in trained model: {} + error_message = """Number of classes in trained model: {} Number of classes in project: {} Did you add or remove behaviors after training this model? If so, please retrain! - '''.format(len(thresholds), len(class_names)) + """.format(len(thresholds), len(class_names)) raise ValueError(error_message) # class_names = projects.get_classes_from_project(cfg) class_names = np.array(class_names) postprocessor = get_postprocessor_from_cfg(cfg, thresholds) - extract(rgb, - model, - final_activation=cfg.feature_extractor.final_activation, - thresholds=thresholds, - postprocessor=postprocessor, - mean_by_channels=cfg.augs.normalization.mean, - fusion=cfg.feature_extractor.fusion, - num_rgb=input_images, - latent_name=latent_name, - device=device, - cpu_transform=cpu_transform, - gpu_transform=gpu_transform, - ignore_error=cfg.inference.ignore_error, - overwrite=cfg.inference.overwrite, - class_names=class_names, - num_workers=cfg.compute.num_workers, - batch_size=cfg.compute.batch_size) - - -if __name__ == '__main__': + extract( + rgb, + model, + final_activation=cfg.feature_extractor.final_activation, + thresholds=thresholds, + postprocessor=postprocessor, + mean_by_channels=cfg.augs.normalization.mean, + fusion=cfg.feature_extractor.fusion, + num_rgb=input_images, + latent_name=latent_name, + device=device, + cpu_transform=cpu_transform, + gpu_transform=gpu_transform, + ignore_error=cfg.inference.ignore_error, + overwrite=cfg.inference.overwrite, + class_names=class_names, + num_workers=cfg.compute.num_workers, + batch_size=cfg.compute.batch_size, + ) + + +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_feature_extractor_inference_cfg(project_path, use_command_line=True) feature_extractor_inference(cfg) diff --git a/deepethogram/feature_extractor/losses.py b/deepethogram/feature_extractor/losses.py index b749eff..c574a7c 100644 --- a/deepethogram/feature_extractor/losses.py +++ b/deepethogram/feature_extractor/losses.py @@ -8,7 +8,7 @@ class NLLLossCNN(nn.Module): - """ A simple wrapper around Pytorch's NLL loss. Appropriate for models with a softmax activation function. + """A simple wrapper around Pytorch's NLL loss. Appropriate for models with a softmax activation function. Adds: optional label smoothing set loss to zero if label = ignore_index (when images have been added at beginning or end of a video, for @@ -36,8 +36,9 @@ def forward(self, outputs, label): if (1, *label.shape) == outputs.shape: label = label.unsqueeze(0) - assert outputs.shape == label.shape, 'Outputs shape must match labels! {}, {}'.format( - outputs.shape, label.shape) + assert outputs.shape == label.shape, "Outputs shape must match labels! {}, {}".format( + outputs.shape, label.shape + ) # N, K, T = outputs.shape label = label.float() @@ -61,8 +62,9 @@ def forward(self, outputs, label): loss = loss.mean() if loss < 0 or loss != loss or torch.isinf(loss).sum() > 0: - msg = 'invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection'.format( - loss, outputs, label) + msg = "invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection".format( + loss, outputs, label + ) raise ValueError(msg) return loss @@ -71,7 +73,7 @@ def forward(self, outputs, label): class BinaryFocalLoss(nn.Module): """Simple wrapper around nn.BCEWithLogitsLoss. Adds masking if label = ignore_index, and support for sequence inputs of shape N,K,T - + References: - https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/focal_loss.py - https://arxiv.org/pdf/1708.02002.pdf @@ -88,16 +90,16 @@ def __init__(self, pos_weight=None, ignore_index=-1, gamma: float = 0, label_smo ignore_index : int, optional Labels with these values will not count toward loss, by default -1 gamma : float, optional - focal loss gamma. see above paper. Higher values: "focus more" on hard examples rather than increasing + focal loss gamma. see above paper. Higher values: "focus more" on hard examples rather than increasing confidence on easy examples. 0 means simple BCELoss, not focal loss, by default 0 label_smoothing : float, optional Targets for BCELoss will be, instead of 0 and 1, 0+label_smoothing, 1-label_smoothing, by default 0.0 """ super().__init__() - log.info('Focal loss: gamma {:.2f} smoothing: {:.2f}'.format(gamma, label_smoothing)) + log.info("Focal loss: gamma {:.2f} smoothing: {:.2f}".format(gamma, label_smoothing)) - self.bcewithlogitsloss = nn.BCEWithLogitsLoss(weight=None, reduction='none', pos_weight=pos_weight) + self.bcewithlogitsloss = nn.BCEWithLogitsLoss(weight=None, reduction="none", pos_weight=pos_weight) self.ignore_index = ignore_index self.gamma = gamma # self.alpha = alpha @@ -112,8 +114,9 @@ def forward(self, outputs, label): # see if it's just a batch issue if (1, *label.shape) == outputs.shape: label = label.unsqueeze(0) - assert outputs.shape == label.shape, 'Outputs shape must match labels! {}, {}'.format( - outputs.shape, label.shape) + assert outputs.shape == label.shape, "Outputs shape must match labels! {}, {}".format( + outputs.shape, label.shape + ) if outputs.ndim == 3: sequence = True @@ -188,18 +191,18 @@ def forward(self, outputs, label): loss = loss.mean() if loss < 0 or loss != loss or torch.isinf(loss).sum() > 0: - msg = 'invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection'.format( - loss, outputs, label) + msg = "invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection".format( + loss, outputs, label + ) raise ValueError(msg) return loss class CrossEntropyLoss(nn.Module): - def __init__(self, weight=None, **kwargs): super().__init__() - self.cross_entropy = nn.CrossEntropyLoss(weight, reduction='none', ignore_index=-1, **kwargs) + self.cross_entropy = nn.CrossEntropyLoss(weight, reduction="none", ignore_index=-1, **kwargs) def forward(self, outputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor: if outputs.ndim == 3 and label.ndim == 2: @@ -223,8 +226,7 @@ def forward(self, outputs: torch.Tensor, label: torch.Tensor) -> torch.Tensor: class ClassificationLoss(nn.Module): - """Simple wrapper to compute data loss and regularization loss at once - """ + """Simple wrapper to compute data loss and regularization loss at once""" def __init__(self, data_criterion: nn.Module, regularization_criterion: nn.Module): super().__init__() @@ -237,11 +239,12 @@ def forward(self, outputs, label, model): loss = data_loss + reg_loss - loss_dict = {'data_loss': data_loss.detach(), 'reg_loss': reg_loss.detach()} + loss_dict = {"data_loss": data_loss.detach(), "reg_loss": reg_loss.detach()} if loss < 0 or loss != loss or torch.isinf(loss).sum() > 0: - msg = 'invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection'.format( - loss, outputs, label) + msg = "invalid loss! loss: {}, outputs: {} labels: {}\nUse Torch anomaly detection".format( + loss, outputs, label + ) raise ValueError(msg) - return loss, loss_dict \ No newline at end of file + return loss, loss_dict diff --git a/deepethogram/feature_extractor/models/CNN.py b/deepethogram/feature_extractor/models/CNN.py index 199d523..57cf7ac 100644 --- a/deepethogram/feature_extractor/models/CNN.py +++ b/deepethogram/feature_extractor/models/CNN.py @@ -16,8 +16,13 @@ # from nvidia # https://github.com/NVIDIA/flownet2-pytorch/blob/master/utils/tools.py def module_to_dict(module, exclude=[]): - return dict([(x, getattr(module, x)) for x in dir(module) - if isfunction(getattr(module, x)) and x not in exclude and getattr(module, x) not in exclude]) + return dict( + [ + (x, getattr(module, x)) + for x in dir(module) + if isfunction(getattr(module, x)) and x not in exclude and getattr(module, x) not in exclude + ] + ) # model definitions can be accessed by indexing into this dictionary @@ -30,16 +35,18 @@ def module_to_dict(module, exclude=[]): # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html -def get_cnn(model_name: str, - in_channels: int = 3, - reload_imagenet: bool = True, - num_classes: int = 1000, - freeze: bool = False, - pos: np.ndarray = None, - neg: np.ndarray = None, - final_bn: bool = False, - **kwargs): - """ Initializes a pretrained CNN from Torchvision. +def get_cnn( + model_name: str, + in_channels: int = 3, + reload_imagenet: bool = True, + num_classes: int = 1000, + freeze: bool = False, + pos: np.ndarray = None, + neg: np.ndarray = None, + final_bn: bool = False, + **kwargs, +): + """Initializes a pretrained CNN from Torchvision. Currently supported models: AlexNet, DenseNet, Inception, VGGXX, ResNets, SqueezeNets, and Resnet3Ds (not torchvision) @@ -64,10 +71,10 @@ def get_cnn(model_name: str, model = models[model_name](pretrained=reload_imagenet, in_channels=in_channels, **kwargs) if freeze: - log.info('Before freezing: {:,}'.format(utils.get_num_parameters(model))) + log.info("Before freezing: {:,}".format(utils.get_num_parameters(model))) for param in model.parameters(): param.requires_grad = False - log.info('After freezing: {:,}'.format(utils.get_num_parameters(model))) + log.info("After freezing: {:,}".format(utils.get_num_parameters(model))) # we have to use the pop function because the final layer in these models has different names model, num_features, final_layer = pop(model, model_name, 1) @@ -81,14 +88,14 @@ def get_cnn(model_name: str, if pos is not None and neg is not None: with torch.no_grad(): with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=RuntimeWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) bias = np.nan_to_num(np.log(pos / neg), neginf=0.0, posinf=1.0) bias = torch.nn.Parameter(torch.from_numpy(bias).float()) if final_bn: bn_layer.bias = bias else: linear_layer.bias = bias - log.info('Custom bias: {}'.format(bias)) + log.info("Custom bias: {}".format(bias)) model = nn.Sequential(*modules) return model diff --git a/deepethogram/feature_extractor/models/classifiers/alexnet.py b/deepethogram/feature_extractor/models/classifiers/alexnet.py index 4ac2570..742c716 100644 --- a/deepethogram/feature_extractor/models/classifiers/alexnet.py +++ b/deepethogram/feature_extractor/models/classifiers/alexnet.py @@ -33,15 +33,15 @@ import torch.utils.model_zoo as model_zoo -__all__ = ['AlexNet', 'alexnet'] +__all__ = ["AlexNet", "alexnet"] model_urls = { - 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', + "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth", } -class AlexNet(nn.Module): +class AlexNet(nn.Module): def __init__(self, in_channels=3, num_classes=1000, dropout_p=0.5): super(AlexNet, self).__init__() self.features = nn.Sequential( @@ -58,7 +58,7 @@ def __init__(self, in_channels=3, num_classes=1000, dropout_p=0.5): nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), # modified to be fully convolutional - nn.AdaptiveMaxPool2d((6,6)) + nn.AdaptiveMaxPool2d((6, 6)), ) self.classifier = nn.Sequential( nn.Dropout(p=dropout_p), @@ -89,14 +89,14 @@ def alexnet(pretrained=False, in_channels=3, **kwargs): """ model = AlexNet(in_channels=in_channels, **kwargs) if pretrained: - print('Downloading pretrained weights...') + print("Downloading pretrained weights...") # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['alexnet']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["alexnet"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) - return model \ No newline at end of file + return model diff --git a/deepethogram/feature_extractor/models/classifiers/densenet.py b/deepethogram/feature_extractor/models/classifiers/densenet.py index c0f8192..cde4604 100644 --- a/deepethogram/feature_extractor/models/classifiers/densenet.py +++ b/deepethogram/feature_extractor/models/classifiers/densenet.py @@ -37,28 +37,34 @@ import torch.nn.functional as F import torch.utils.model_zoo as model_zoo -__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] +__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", } class _DenseLayer(nn.Sequential): def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): super(_DenseLayer, self).__init__() - self.add_module('norm1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu1', nn.ReLU(inplace=True)), - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu2', nn.ReLU(inplace=True)), - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, bias=False)), + (self.add_module("norm1", nn.BatchNorm2d(num_input_features)),) + (self.add_module("relu1", nn.ReLU(inplace=True)),) + ( + self.add_module( + "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) + ), + ) + (self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)),) + (self.add_module("relu2", nn.ReLU(inplace=True)),) + ( + self.add_module( + "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) + ), + ) self.drop_rate = drop_rate def forward(self, x): @@ -73,17 +79,16 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra super(_DenseBlock, self).__init__() for i in range(num_layers): layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) - self.add_module('denselayer%d' % (i + 1), layer) + self.add_module("denselayer%d" % (i + 1), layer) class _Transition(nn.Sequential): def __init__(self, num_input_features, num_output_features): super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + self.add_module("norm", nn.BatchNorm2d(num_input_features)) + self.add_module("relu", nn.ReLU(inplace=True)) + self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): @@ -100,33 +105,53 @@ class DenseNet(nn.Module): num_classes (int) - number of classification classes """ - def __init__(self, in_channels=3,growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, dropout_p=0, num_classes=1000): - + def __init__( + self, + in_channels=3, + growth_rate=32, + block_config=(6, 12, 24, 16), + num_init_features=64, + bn_size=4, + drop_rate=0, + dropout_p=0, + num_classes=1000, + ): super(DenseNet, self).__init__() # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(in_channels, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) + self.features = nn.Sequential( + OrderedDict( + [ + ( + "conv0", + nn.Conv2d(in_channels, num_init_features, kernel_size=7, stride=2, padding=3, bias=False), + ), + ("norm0", nn.BatchNorm2d(num_init_features)), + ("relu0", nn.ReLU(inplace=True)), + ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ] + ) + ) # Each denseblock num_features = num_init_features for i, num_layers in enumerate(block_config): - block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, - bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) - self.features.add_module('denseblock%d' % (i + 1), block) + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + ) + self.features.add_module("denseblock%d" % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) - self.features.add_module('transition%d' % (i + 1), trans) + self.features.add_module("transition%d" % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + self.features.add_module("norm5", nn.BatchNorm2d(num_features)) self.dropout_p = dropout_p if self.dropout_p > 0: @@ -154,66 +179,68 @@ def forward(self, x): return out -def densenet121(pretrained=False, in_channels=3,**kwargs): +def densenet121(pretrained=False, in_channels=3, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(in_channels=in_channels,num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), - **kwargs) + model = DenseNet( + in_channels=in_channels, num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs + ) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet121']) + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + state_dict = model_zoo.load_url(model_urls["densenet121"]) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] - if in_channels !=3: + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def densenet169(pretrained=False, in_channels=3,**kwargs): +def densenet169(pretrained=False, in_channels=3, **kwargs): r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), - **kwargs) + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet169']) + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + state_dict = model_zoo.load_url(model_urls["densenet169"]) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] - if in_channels !=3: + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) @@ -227,26 +254,26 @@ def densenet201(pretrained=False, in_channels=3, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), - **kwargs) + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet201']) + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + state_dict = model_zoo.load_url(model_urls["densenet201"]) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] - if in_channels !=3: + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) @@ -260,16 +287,16 @@ def densenet161(pretrained=False, in_channels=3, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), - **kwargs) + model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') - state_dict = model_zoo.load_url(model_urls['densenet161']) + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + state_dict = model_zoo.load_url(model_urls["densenet161"]) for key in list(state_dict.keys()): res = pattern.match(key) if res: @@ -277,4 +304,4 @@ def densenet161(pretrained=False, in_channels=3, **kwargs): state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) - return model \ No newline at end of file + return model diff --git a/deepethogram/feature_extractor/models/classifiers/inception.py b/deepethogram/feature_extractor/models/classifiers/inception.py index 47f515e..791f271 100644 --- a/deepethogram/feature_extractor/models/classifiers/inception.py +++ b/deepethogram/feature_extractor/models/classifiers/inception.py @@ -37,14 +37,15 @@ import torch.utils.model_zoo as model_zoo -__all__ = ['Inception3', 'inception_v3'] +__all__ = ["Inception3", "inception_v3"] model_urls = { # Inception v3 ported from TensorFlow - 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth", } + def inception_v3(pretrained=False, in_channels=3, **kwargs): r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. @@ -53,15 +54,15 @@ def inception_v3(pretrained=False, in_channels=3, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - model = Inception3(in_channels=3,**kwargs) + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + model = Inception3(in_channels=3, **kwargs) # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_zoo.load_url(model_urls['inception_v3_google'])) - if in_channels !=3: + state_dict = model_zoo.load_url(model_zoo.load_url(model_urls["inception_v3_google"])) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) @@ -69,9 +70,9 @@ def inception_v3(pretrained=False, in_channels=3, **kwargs): return Inception3(**kwargs) -class Inception3(nn.Module): - def __init__(self, in_channels=3,num_classes=1000, aux_logits=True, transform_input=False, dropout_p=0): +class Inception3(nn.Module): + def __init__(self, in_channels=3, num_classes=1000, aux_logits=True, transform_input=False, dropout_p=0): super(Inception3, self).__init__() self.aux_logits = aux_logits self.transform_input = transform_input @@ -93,7 +94,7 @@ def __init__(self, in_channels=3,num_classes=1000, aux_logits=True, transform_in self.Mixed_7a = InceptionD(768) self.Mixed_7b = InceptionE(1280) self.Mixed_7c = InceptionE(2048) - self.avgpool = nn.AdaptiveAvgPool2d((1,1)) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout_p = dropout_p if self.dropout_p > 0: @@ -104,7 +105,8 @@ def __init__(self, in_channels=3,num_classes=1000, aux_logits=True, transform_in for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): import scipy.stats as stats - stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + + stddev = m.stddev if hasattr(m, "stddev") else 0.1 X = stats.truncnorm(-2, 2, scale=stddev) values = torch.Tensor(X.rvs(m.weight.numel())) values = values.view(m.weight.size()) @@ -178,7 +180,6 @@ def forward(self, x): class InceptionA(nn.Module): - def __init__(self, in_channels, pool_features): super(InceptionA, self).__init__() self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) @@ -210,7 +211,6 @@ def forward(self, x): class InceptionB(nn.Module): - def __init__(self, in_channels): super(InceptionB, self).__init__() self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) @@ -233,7 +233,6 @@ def forward(self, x): class InceptionC(nn.Module): - def __init__(self, in_channels, channels_7x7): super(InceptionC, self).__init__() self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) @@ -272,7 +271,6 @@ def forward(self, x): class InceptionD(nn.Module): - def __init__(self, in_channels): super(InceptionD, self).__init__() self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) @@ -298,7 +296,6 @@ def forward(self, x): class InceptionE(nn.Module): - def __init__(self, in_channels): super(InceptionE, self).__init__() self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) @@ -340,7 +337,6 @@ def forward(self, x): class InceptionAux(nn.Module): - def __init__(self, in_channels, num_classes): super(InceptionAux, self).__init__() self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) @@ -365,7 +361,6 @@ def forward(self, x): class BasicConv2d(nn.Module): - def __init__(self, in_channels, out_channels, **kwargs): super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) @@ -374,4 +369,4 @@ def __init__(self, in_channels, out_channels, **kwargs): def forward(self, x): x = self.conv(x) x = self.bn(x) - return F.relu(x, inplace=True) \ No newline at end of file + return F.relu(x, inplace=True) diff --git a/deepethogram/feature_extractor/models/classifiers/resnet.py b/deepethogram/feature_extractor/models/classifiers/resnet.py index ccf9fda..bee236d 100644 --- a/deepethogram/feature_extractor/models/classifiers/resnet.py +++ b/deepethogram/feature_extractor/models/classifiers/resnet.py @@ -36,18 +36,26 @@ from deepethogram.utils import load_state_from_dict model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", } def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes, out_planes, stride=1): @@ -58,13 +66,14 @@ def conv1x1(in_planes, out_planes, stride=1): class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None + ): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -104,12 +113,13 @@ class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__( + self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None + ): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) @@ -145,12 +155,10 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, in_channels=3, num_classes=1000, dropout_p=0, compress_to: int = 512): self.inplanes = 64 super().__init__() - self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self._norm_layer = nn.BatchNorm2d self.dilation = 1 @@ -166,7 +174,6 @@ def __init__(self, block, layers, in_channels=3, num_classes=1000, dropout_p=0, # self.avgpool = torch.jit.script(FastGlobalAvgPool2d(flatten=True)) # self.adaptive_max = nn.AdaptiveMaxPool2d(1) - self.dropout_p = dropout_p if dropout_p > 0: self.dropout = torch.nn.Dropout(p=dropout_p) @@ -179,12 +186,14 @@ def __init__(self, block, layers, in_channels=3, num_classes=1000, dropout_p=0, # instead, reduce the features to 512 if compress_to < 512 * block.expansion: self.compression_fc = nn.Sequential( - nn.Linear(512 * block.expansion, compress_to), - nn.ReLU(inplace=True) + nn.Linear(512 * block.expansion, compress_to), nn.ReLU(inplace=True) ) fc_infeatures = compress_to - print('Altered from standard resnet50: instead of {} inputs to the fc layer, it has {}'.format( - 512 * block.expansion, compress_to)) + print( + "Altered from standard resnet50: instead of {} inputs to the fc layer, it has {}".format( + 512 * block.expansion, compress_to + ) + ) else: self.compression_fc = nn.Identity() fc_infeatures = 512 * block.expansion @@ -196,7 +205,7 @@ def __init__(self, block, layers, in_channels=3, num_classes=1000, dropout_p=0, for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -216,18 +225,28 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): # kernel_size=1, stride=1, bias=False), # uncomment below line for vanilla resnet conv1x1(self.inplanes, planes * block.expansion, stride), - # should be in all model types V + # should be in all model types V norm_layer(planes * block.expansion), ) layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) @@ -261,7 +280,7 @@ def resnet18(pretrained=False, in_channels=3, **kwargs): model = ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['resnet18']) + state_dict = model_zoo.load_url(model_urls["resnet18"]) if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] @@ -281,7 +300,7 @@ def resnet34(pretrained=False, in_channels=3, **kwargs): model = ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['resnet34']) + state_dict = model_zoo.load_url(model_urls["resnet34"]) if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] @@ -300,7 +319,7 @@ def resnet50(pretrained=False, in_channels=3, **kwargs): model = ResNet(Bottleneck, [3, 4, 6, 3], in_channels=in_channels, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['resnet50']) + state_dict = model_zoo.load_url(model_urls["resnet50"]) if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] @@ -320,7 +339,7 @@ def resnet101(pretrained=False, in_channels=3, **kwargs): model = ResNet(Bottleneck, [3, 4, 23, 3], in_channels=in_channels, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['resnet101']) + state_dict = model_zoo.load_url(model_urls["resnet101"]) if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] @@ -339,7 +358,7 @@ def resnet152(pretrained=False, in_channels=3, **kwargs): model = ResNet(Bottleneck, [3, 8, 36, 3], in_channels=in_channels, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['resnet152']) + state_dict = model_zoo.load_url(model_urls["resnet152"]) if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] diff --git a/deepethogram/feature_extractor/models/classifiers/resnet3d.py b/deepethogram/feature_extractor/models/classifiers/resnet3d.py index c421344..4b26695 100644 --- a/deepethogram/feature_extractor/models/classifiers/resnet3d.py +++ b/deepethogram/feature_extractor/models/classifiers/resnet3d.py @@ -30,7 +30,6 @@ import warnings - # __all__ = [ # 'ResNet', 'resnet10_3d', 'resnet18_3d', 'resnet34_3d', 'resnet50_3d', 'resnet101_3d', # 'resnet152_3d', 'resnet200_3d' @@ -39,20 +38,12 @@ def conv3x3x3(in_planes, out_planes, stride=1): # 3x3x3 convolution with padding - return nn.Conv3d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False) + return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def downsample_basic_block(x, planes, stride): out = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.Tensor( - out.size(0), planes - out.size(1), out.size(2), out.size(3), - out.size(4)).zero_() + zero_pads = torch.Tensor(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)).zero_() if isinstance(out.data, torch.cuda.FloatTensor): zero_pads = zero_pads.cuda() @@ -100,8 +91,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm3d(planes) - self.conv2 = nn.Conv3d( - planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm3d(planes) self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm3d(planes * 4) @@ -133,33 +123,17 @@ def forward(self, x): class ResNet(nn.Module): - - def __init__(self, - block, - layers, - in_channels=3, - shortcut_type='B', - num_classes=400, - dropout_p=0.5): + def __init__(self, block, layers, in_channels=3, shortcut_type="B", num_classes=400, dropout_p=0.5): self.inplanes = 64 super(ResNet, self).__init__() - self.conv1 = nn.Conv3d( - in_channels, - 64, - kernel_size=7, - stride=(1, 2, 2), - padding=(3, 3, 3), - bias=False) + self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False) self.bn1 = nn.BatchNorm3d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) - self.layer2 = self._make_layer( - block, 128, layers[1], shortcut_type, stride=2) - self.layer3 = self._make_layer( - block, 256, layers[2], shortcut_type, stride=2) - self.layer4 = self._make_layer( - block, 512, layers[3], shortcut_type, stride=2) + self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2) # last_duration = int(math.ceil(sample_duration / 16)) # last_size = int(math.ceil(sample_size / 32)) # print(last_duration, last_size) @@ -167,13 +141,13 @@ def __init__(self, # self.avgpool = nn.AvgPool3d( # (last_duration, last_size, last_size), stride=1) self.dropout_p = dropout_p - if dropout_p>0: + if dropout_p > 0: self.dropout = torch.nn.Dropout(p=dropout_p) self.fc = nn.Linear(512 * block.expansion, num_classes) - + for m in self.modules(): if isinstance(m, nn.Conv3d): - m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') + m.weight = nn.init.kaiming_normal(m.weight, mode="fan_out") elif isinstance(m, nn.BatchNorm3d): m.weight.data.fill_(1) m.bias.data.zero_() @@ -181,19 +155,13 @@ def __init__(self, def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: - if shortcut_type == 'A': - downsample = partial( - downsample_basic_block, - planes=planes * block.expansion, - stride=stride) + if shortcut_type == "A": + downsample = partial(downsample_basic_block, planes=planes * block.expansion, stride=stride) else: downsample = nn.Sequential( - nn.Conv3d( - self.inplanes, - planes * block.expansion, - kernel_size=1, - stride=stride, - bias=False), nn.BatchNorm3d(planes * block.expansion)) + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), + ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) @@ -218,10 +186,10 @@ def forward(self, x): x = self.avgpool(x) x = x.view(x.size(0), -1) - if self.dropout_p>0: + if self.dropout_p > 0: x = self.dropout(x) x = self.fc(x) - + return x @@ -231,73 +199,66 @@ def get_fine_tuning_parameters(model, ft_begin_index): ft_module_names = [] for i in range(ft_begin_index, 5): - ft_module_names.append('layer{}'.format(i)) - ft_module_names.append('fc') + ft_module_names.append("layer{}".format(i)) + ft_module_names.append("fc") parameters = [] for k, v in model.named_parameters(): for ft_module in ft_module_names: if ft_module in k: - parameters.append({'params': v}) + parameters.append({"params": v}) break else: - parameters.append({'params': v, 'lr': 0.0}) + parameters.append({"params": v, "lr": 0.0}) return parameters def resnet3d_10(**kwargs): - """Constructs a ResNet-18 model. - """ + """Constructs a ResNet-18 model.""" model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) return model def resnet3d_18(**kwargs): - """Constructs a ResNet-18 model. - """ + """Constructs a ResNet-18 model.""" model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) return model def resnet3d_34(pretrained=False, in_channels=3, path_to_weights=None, **kwargs): - """Constructs a ResNet-34 model. - """ + """Constructs a ResNet-34 model.""" model = ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, **kwargs) if pretrained: if in_channels != 3: - warnings.warn('in channels is {}, not reloading imagenet weights...'.format(in_channels)) + warnings.warn("in channels is {}, not reloading imagenet weights...".format(in_channels)) else: - warnings.warn('Using absolute file import for resnet3d weights') + warnings.warn("Using absolute file import for resnet3d weights") if path_to_weights is None: - raise ValueError('must specify path to weights file if pretrained: {}'.format(path_to_weights)) + raise ValueError("must specify path to weights file if pretrained: {}".format(path_to_weights)) model, _, _, _ = load_state(model, path_to_weights) return model def resnet3d_50(**kwargs): - """Constructs a ResNet-50 model. - """ + """Constructs a ResNet-50 model.""" model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) return model def resnet3d_101(**kwargs): - """Constructs a ResNet-101 model. - """ + """Constructs a ResNet-101 model.""" model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) return model def resnet3d_152(**kwargs): - """Constructs a ResNet-101 model. - """ + """Constructs a ResNet-101 model.""" model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) return model def resnet3d_200(**kwargs): - """Constructs a ResNet-101 model. - """ + """Constructs a ResNet-101 model.""" model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) - return model \ No newline at end of file + return model diff --git a/deepethogram/feature_extractor/models/classifiers/squeezenet.py b/deepethogram/feature_extractor/models/classifiers/squeezenet.py index 15779ac..c7f8fc3 100644 --- a/deepethogram/feature_extractor/models/classifiers/squeezenet.py +++ b/deepethogram/feature_extractor/models/classifiers/squeezenet.py @@ -36,45 +36,38 @@ import torch.utils.model_zoo as model_zoo -__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] +__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] model_urls = { - 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', - 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth", } class Fire(nn.Module): - - def __init__(self, inplanes, squeeze_planes, - expand1x1_planes, expand3x3_planes): + def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes): super(Fire, self).__init__() self.inplanes = inplanes self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze_activation = nn.ReLU(inplace=True) - self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, - kernel_size=1) + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1) self.expand1x1_activation = nn.ReLU(inplace=True) - self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, - kernel_size=3, padding=1) + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1) self.expand3x3_activation = nn.ReLU(inplace=True) def forward(self, x): x = self.squeeze_activation(self.squeeze(x)) - return torch.cat([ - self.expand1x1_activation(self.expand1x1(x)), - self.expand3x3_activation(self.expand3x3(x)) - ], 1) + return torch.cat( + [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1 + ) class SqueezeNet(nn.Module): - - def __init__(self, version=1.0, in_channels=3,num_classes=1000): + def __init__(self, version=1.0, in_channels=3, num_classes=1000): super(SqueezeNet, self).__init__() if version not in [1.0, 1.1]: - raise ValueError("Unsupported SqueezeNet version {version}:" - "1.0 or 1.1 expected".format(version=version)) + raise ValueError("Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)) self.num_classes = num_classes if version == 1.0: self.features = nn.Sequential( @@ -111,10 +104,7 @@ def __init__(self, version=1.0, in_channels=3,num_classes=1000): # Final convolution is initialized differently form the rest final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) self.classifier = nn.Sequential( - nn.Dropout(p=0.5), - final_conv, - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool2d((1, 1)) + nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)) ) for m in self.modules(): @@ -132,7 +122,7 @@ def forward(self, x): return x.view(x.size(0), self.num_classes) -def squeezenet1_0(pretrained=False,in_channels=3, **kwargs): +def squeezenet1_0(pretrained=False, in_channels=3, **kwargs): r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. @@ -143,18 +133,18 @@ def squeezenet1_0(pretrained=False,in_channels=3, **kwargs): model = SqueezeNet(in_channels=in_channels, version=1.0, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['squeezenet1_0']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["squeezenet1_0"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def squeezenet1_1(pretrained=False, in_channels=3,**kwargs): +def squeezenet1_1(pretrained=False, in_channels=3, **kwargs): r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters @@ -163,15 +153,15 @@ def squeezenet1_1(pretrained=False, in_channels=3,**kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = SqueezeNet(in_channels=in_channels,version=1.1, **kwargs) + model = SqueezeNet(in_channels=in_channels, version=1.1, **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['squeeznet1_1']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["squeeznet1_1"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) - return model \ No newline at end of file + return model diff --git a/deepethogram/feature_extractor/models/classifiers/vgg.py b/deepethogram/feature_extractor/models/classifiers/vgg.py index 6773214..14cbacb 100644 --- a/deepethogram/feature_extractor/models/classifiers/vgg.py +++ b/deepethogram/feature_extractor/models/classifiers/vgg.py @@ -33,25 +33,31 @@ import torch.utils.model_zoo as model_zoo __all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', + "VGG", + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19_bn", + "vgg19", ] model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', - 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', - 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', - 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', - 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', + "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", } class VGG(nn.Module): - def __init__(self, features, num_classes=1000, init_weights=True, dropout_p=0): super(VGG, self).__init__() self.features = features @@ -76,7 +82,7 @@ def forward(self, x): def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): @@ -87,10 +93,10 @@ def _initialize_weights(self): nn.init.constant_(m.bias, 0) -def make_layers(cfg, in_channels=3,batch_norm=False): +def make_layers(cfg, in_channels=3, batch_norm=False): layers = [] for v in cfg: - if v == 'M': + if v == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) @@ -103,184 +109,184 @@ def make_layers(cfg, in_channels=3,batch_norm=False): cfg = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } -def vgg11(pretrained=False, in_channels=3,**kwargs): +def vgg11(pretrained=False, in_channels=3, **kwargs): """VGG 11-layer model (configuration "A") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A'], in_channels=in_channels), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["A"], in_channels=in_channels), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg11']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg11"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg11_bn(pretrained=False, in_channels=3,**kwargs): +def vgg11_bn(pretrained=False, in_channels=3, **kwargs): """VGG 11-layer model (configuration "A") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['A'], in_channels=in_channels, batch_norm=True), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["A"], in_channels=in_channels, batch_norm=True), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg11_bn']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg11_bn"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg13(pretrained=False, in_channels=3,**kwargs): +def vgg13(pretrained=False, in_channels=3, **kwargs): """VGG 13-layer model (configuration "B") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B'], in_channels=in_channels), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["B"], in_channels=in_channels), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg13']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg13"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg13_bn(pretrained=False,in_channels=3,**kwargs): +def vgg13_bn(pretrained=False, in_channels=3, **kwargs): """VGG 13-layer model (configuration "B") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['B'], in_channels=in_channels,batch_norm=True), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["B"], in_channels=in_channels, batch_norm=True), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg13_bn']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg13_bn"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg16(pretrained=False, in_channels=3,**kwargs): +def vgg16(pretrained=False, in_channels=3, **kwargs): """VGG 16-layer model (configuration "D") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D'],in_channels=in_channels), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["D"], in_channels=in_channels), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg16']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg16"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg16_bn(pretrained=False,in_channels=3,**kwargs): +def vgg16_bn(pretrained=False, in_channels=3, **kwargs): """VGG 16-layer model (configuration "D") with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['D'],in_channels=in_channels,batch_norm=True), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["D"], in_channels=in_channels, batch_norm=True), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg16_bn']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg16_bn"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg19(pretrained=False,in_channels=3, **kwargs): +def vgg19(pretrained=False, in_channels=3, **kwargs): """VGG 19-layer model (configuration "E") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E'],in_channels=in_channels), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["E"], in_channels=in_channels), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg19']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg19"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) return model -def vgg19_bn(pretrained=False, in_channels=3,**kwargs): +def vgg19_bn(pretrained=False, in_channels=3, **kwargs): """VGG 19-layer model (configuration 'E') with batch normalization Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ if pretrained: - kwargs['init_weights'] = False - model = VGG(make_layers(cfg['E'], in_channels=in_channels,batch_norm=True), **kwargs) + kwargs["init_weights"] = False + model = VGG(make_layers(cfg["E"], in_channels=in_channels, batch_norm=True), **kwargs) if pretrained: # from Wang et al. 2015: Towards good practices for very deep two-stream convnets - state_dict = model_zoo.load_url(model_urls['vgg19_bn']) - if in_channels !=3: + state_dict = model_zoo.load_url(model_urls["vgg19_bn"]) + if in_channels != 3: rgb_kernel_key = list(state_dict.keys())[0] rgb_kernel = state_dict[rgb_kernel_key] - flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels,1,1) + flow_kernel = rgb_kernel.mean(dim=1).unsqueeze(1).repeat(1, in_channels, 1, 1) state_dict[rgb_kernel_key] = flow_kernel state_dict.update(state_dict) model.load_state_dict(state_dict) - return model \ No newline at end of file + return model diff --git a/deepethogram/feature_extractor/models/hidden_two_stream.py b/deepethogram/feature_extractor/models/hidden_two_stream.py index f3ac75f..24ca4c6 100644 --- a/deepethogram/feature_extractor/models/hidden_two_stream.py +++ b/deepethogram/feature_extractor/models/hidden_two_stream.py @@ -15,8 +15,9 @@ log = logging.getLogger(__name__) + class Viewer(nn.Module): - """ PyTorch module for extracting the middle image of a concatenated stack. + """PyTorch module for extracting the middle image of a concatenated stack. Example: you have 10 RGB images stacked in a channel of a tensor, so it has shape [N, 30, H, W]. viewer = Viewer(10) @@ -28,19 +29,19 @@ class Viewer(nn.Module): def __init__(self, num_images, label_location): super().__init__() self.num_images = num_images - if label_location == 'middle': + if label_location == "middle": self.start = int(num_images / 2 * 3) - elif label_location == 'causal': + elif label_location == "causal": self.start = int(num_images * 3 - 3) self.end = int(self.start + 3) def forward(self, x): - x = x[:, self.start:self.end, :, :] + x = x[:, self.start : self.end, :, :] return x class FlowOnlyClassifier(nn.Module): - """ Stack of flow generator module and flow classifier. Used in training Hidden two stream networks in a curriculum + """Stack of flow generator module and flow classifier. Used in training Hidden two stream networks in a curriculum Takes a stack of images as inputs. Generates optic flow using the flow generator; computes class probabilities using the flow classifier. @@ -55,11 +56,9 @@ class FlowOnlyClassifier(nn.Module): print(outputs.shape) # [N, K] """ - def __init__(self, flow_generator, - flow_classifier, - freeze_flow_generator: bool = True): + def __init__(self, flow_generator, flow_classifier, freeze_flow_generator: bool = True): super().__init__() - assert (isinstance(flow_generator, nn.Module) and isinstance(flow_classifier, nn.Module)) + assert isinstance(flow_generator, nn.Module) and isinstance(flow_classifier, nn.Module) self.flow_generator = flow_generator if freeze_flow_generator: @@ -78,7 +77,7 @@ def forward(self, batch): class HiddenTwoStream(nn.Module): - """ Hidden Two-Stream Network model + """Hidden Two-Stream Network model Paper: https://arxiv.org/abs/1704.00389 Classifies video inputs using a spatial CNN, using RGB video frames as inputs; and using a flow CNN, which @@ -86,10 +85,17 @@ class HiddenTwoStream(nn.Module): advantages, as optic flow loaded from disk is both more discrete and has compression artifacts. """ - def __init__(self, flow_generator, spatial_classifier, flow_classifier, fusion, - classifier_name: str, num_images: int = 11, - label_location: str = 'middle'): - """ Hidden two-stream constructor. + def __init__( + self, + flow_generator, + spatial_classifier, + flow_classifier, + fusion, + classifier_name: str, + num_images: int = 11, + label_location: str = "middle", + ): + """Hidden two-stream constructor. Args: flow_generator (nn.Module): CNN that generates optic flow from a stack of RGB frames @@ -102,89 +108,93 @@ def __init__(self, flow_generator, spatial_classifier, flow_classifier, fusion, """ super().__init__() - assert (isinstance(flow_generator, nn.Module) and isinstance(spatial_classifier, nn.Module) - and isinstance(flow_classifier, nn.Module) and isinstance(fusion, nn.Module)) + assert ( + isinstance(flow_generator, nn.Module) + and isinstance(spatial_classifier, nn.Module) + and isinstance(flow_classifier, nn.Module) + and isinstance(fusion, nn.Module) + ) self.spatial_classifier = spatial_classifier self.flow_generator = flow_generator self.flow_classifier = flow_classifier - if '3d' in classifier_name: + if "3d" in classifier_name: self.viewer = nn.Identity() else: self.viewer = Viewer(num_images, label_location) self.fusion = fusion self.frozen_state = {} - self.freeze('flow_generator') + self.freeze("flow_generator") def freeze(self, submodel_to_freeze: str): - """ Freezes a component of the model. Useful for curriculum training + """Freezes a component of the model. Useful for curriculum training Args: submodel_to_freeze (str): one of flow_generator, spatial, flow, fusion """ - if submodel_to_freeze == 'flow_generator': + if submodel_to_freeze == "flow_generator": self.flow_generator.eval() for param in self.flow_generator.parameters(): param.requires_grad = False - elif submodel_to_freeze == 'spatial': + elif submodel_to_freeze == "spatial": self.spatial_classifier.eval() for param in self.spatial_classifier.parameters(): param.requires_grad = False - elif submodel_to_freeze == 'flow': + elif submodel_to_freeze == "flow": self.flow_classifier.eval() for param in self.flow_classifier.parameters(): param.requires_grad = False - elif submodel_to_freeze == 'fusion': + elif submodel_to_freeze == "fusion": self.fusion.eval() for param in self.fusion.parameters(): param.requires_grad = False else: - raise ValueError('submodel not found:%s' % submodel_to_freeze) + raise ValueError("submodel not found:%s" % submodel_to_freeze) self.frozen_state[submodel_to_freeze] = True def set_mode(self, mode: str): - """ Freezes and unfreezes portions of the model, useful for curriculum training. + """Freezes and unfreezes portions of the model, useful for curriculum training. Args: mode (str): one of spatial_only, flow_only, fusion_only, classifier, end_to_end, or inference """ - log.debug('setting model mode: {}'.format(mode)) - if mode == 'spatial_only': - self.freeze('flow_generator') - self.freeze('flow') - self.freeze('fusion') - self.unfreeze('spatial') - elif mode == 'flow_only': - self.freeze('flow_generator') - self.freeze('spatial') - self.unfreeze('flow') - self.freeze('fusion') - elif mode == 'fusion_only': - self.freeze('flow_generator') - self.freeze('spatial') - self.freeze('flow') - self.unfreeze('fusion') - elif mode == 'classifier': - self.freeze('flow_generator') - self.unfreeze('spatial') - self.unfreeze('flow') - self.unfreeze('fusion') - elif mode == 'end_to_end': - self.unfreeze('flow_generator') - self.unfreeze('spatial') - self.unfreeze('flow') - self.unfreeze('fusion') - elif mode == 'inference': - self.freeze('flow_generator') - self.freeze('spatial') - self.freeze('flow') - self.freeze('fusion') + log.debug("setting model mode: {}".format(mode)) + if mode == "spatial_only": + self.freeze("flow_generator") + self.freeze("flow") + self.freeze("fusion") + self.unfreeze("spatial") + elif mode == "flow_only": + self.freeze("flow_generator") + self.freeze("spatial") + self.unfreeze("flow") + self.freeze("fusion") + elif mode == "fusion_only": + self.freeze("flow_generator") + self.freeze("spatial") + self.freeze("flow") + self.unfreeze("fusion") + elif mode == "classifier": + self.freeze("flow_generator") + self.unfreeze("spatial") + self.unfreeze("flow") + self.unfreeze("fusion") + elif mode == "end_to_end": + self.unfreeze("flow_generator") + self.unfreeze("spatial") + self.unfreeze("flow") + self.unfreeze("fusion") + elif mode == "inference": + self.freeze("flow_generator") + self.freeze("spatial") + self.freeze("flow") + self.freeze("fusion") else: - raise ValueError('Unknown mode: %s' % mode) + raise ValueError("Unknown mode: %s" % mode) def unfreeze(self, submodel_to_unfreeze: str): - """ Unfreezes portions of the model + """Unfreezes portions of the model Args: submodel_to_unfreeze (str): one of flow_generator, spatial, flow, or fusion @@ -192,33 +202,35 @@ def unfreeze(self, submodel_to_unfreeze: str): Returns: """ - log.debug('unfreezing model component: {}'.format(submodel_to_unfreeze)) - if submodel_to_unfreeze == 'flow_generator': + log.debug("unfreezing model component: {}".format(submodel_to_unfreeze)) + if submodel_to_unfreeze == "flow_generator": self.flow_generator.train() for param in self.flow_generator.parameters(): param.requires_grad = True - elif submodel_to_unfreeze == 'spatial': + elif submodel_to_unfreeze == "spatial": self.spatial_classifier.train() for param in self.spatial_classifier.parameters(): param.requires_grad = True - elif submodel_to_unfreeze == 'flow': + elif submodel_to_unfreeze == "flow": self.flow_classifier.train() for param in self.flow_classifier.parameters(): param.requires_grad = True - elif submodel_to_unfreeze == 'fusion': + elif submodel_to_unfreeze == "fusion": self.fusion.train() for param in self.fusion.parameters(): param.requires_grad = True else: - raise ValueError('submodel not found:%s' % submodel_to_unfreeze) + raise ValueError("submodel not found:%s" % submodel_to_unfreeze) self.frozen_state[submodel_to_unfreeze] = False def get_param_groups(self): - param_list = [{'params': self.flow_generator.parameters()}, - {'params': self.spatial_classifier.parameters()}, - {'params': self.flow_classifier.parameters()}, - {'params': self.fusion.parameters()}] - return (param_list) + param_list = [ + {"params": self.flow_generator.parameters()}, + {"params": self.spatial_classifier.parameters()}, + {"params": self.flow_classifier.parameters()}, + {"params": self.fusion.parameters()}, + ] + return param_list def forward(self, batch): with torch.no_grad(): @@ -230,19 +242,21 @@ def forward(self, batch): return self.fusion(spatial_features, flow_features) -def hidden_two_stream(classifier: str, - flow_gen: str, - num_classes: int, - fusion_style: str = 'average', - dropout_p: float = 0.9, - reload_imagenet: bool = True, - num_rgb: int = 1, - num_flows: int = 10, - pos: np.ndarray = None, - neg: np.ndarray = None, - flow_max: float = 5.0, - **kwargs): - """ Wrapper for initializing hidden two stream models +def hidden_two_stream( + classifier: str, + flow_gen: str, + num_classes: int, + fusion_style: str = "average", + dropout_p: float = 0.9, + reload_imagenet: bool = True, + num_rgb: int = 1, + num_flows: int = 10, + pos: np.ndarray = None, + neg: np.ndarray = None, + flow_max: float = 5.0, + **kwargs, +): + """Wrapper for initializing hidden two stream models Args: classifier (str): a supported classifier, e.g. resnet18, vgg16 @@ -260,22 +274,37 @@ def hidden_two_stream(classifier: str, Returns: hidden two stream network model """ - assert fusion_style in ['average', 'concatenate'] + assert fusion_style in ["average", "concatenate"] flow_generator = flow_generators[flow_gen](num_images=num_flows + 1, flow_div=flow_max) - in_channels = num_rgb * 3 if '3d' not in classifier.lower() else 3 - spatial_classifier = get_cnn(classifier, in_channels=in_channels, dropout_p=dropout_p, - num_classes=num_classes, reload_imagenet=reload_imagenet, - pos=pos, neg=neg, **kwargs) - - in_channels = num_flows * 2 if '3d' not in classifier.lower() else 2 - flow_classifier = get_cnn(classifier, in_channels=in_channels, dropout_p=dropout_p, - num_classes=num_classes, reload_imagenet=reload_imagenet, - pos=pos, neg=neg, **kwargs) - - spatial_classifier, flow_classifier, fusion = build_fusion_layer(spatial_classifier, flow_classifier, - fusion_style, num_classes) + in_channels = num_rgb * 3 if "3d" not in classifier.lower() else 3 + spatial_classifier = get_cnn( + classifier, + in_channels=in_channels, + dropout_p=dropout_p, + num_classes=num_classes, + reload_imagenet=reload_imagenet, + pos=pos, + neg=neg, + **kwargs, + ) + + in_channels = num_flows * 2 if "3d" not in classifier.lower() else 2 + flow_classifier = get_cnn( + classifier, + in_channels=in_channels, + dropout_p=dropout_p, + num_classes=num_classes, + reload_imagenet=reload_imagenet, + pos=pos, + neg=neg, + **kwargs, + ) + + spatial_classifier, flow_classifier, fusion = build_fusion_layer( + spatial_classifier, flow_classifier, fusion_style, num_classes + ) model = HiddenTwoStream(flow_generator, spatial_classifier, flow_classifier, fusion, classifier) return model @@ -296,23 +325,23 @@ def build_fusion_layer(spatial_classifier, flow_classifier, fusion_style, num_cl Returns: """ - if fusion_style == 'average' or fusion_style == 'weighted_average': + if fusion_style == "average" or fusion_style == "weighted_average": # just so we can pass them to the fusion module num_spatial_features, num_flow_features = None, None - elif fusion_style == 'concatenate': + elif fusion_style == "concatenate": spatial_classifier, num_spatial_features = remove_cnn_classifier_layer(spatial_classifier) flow_classifier, num_flow_features = remove_cnn_classifier_layer(flow_classifier) else: - raise ValueError('unknown fusion style: {}'.format(fusion_style)) + raise ValueError("unknown fusion style: {}".format(fusion_style)) - fusion = Fusion(fusion_style, num_spatial_features, num_flow_features, num_classes, - flow_fusion_weight=flow_fusion_weight) + fusion = Fusion( + fusion_style, num_spatial_features, num_flow_features, num_classes, flow_fusion_weight=flow_fusion_weight + ) return spatial_classifier, flow_classifier, fusion -def deg_f(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True, - pos: int = None, neg: int = None): - """ Make the DEG-fast model. Uses ResNet18 for classification, TinyMotionNet for flow generation. +def deg_f(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True, pos: int = None, neg: int = None): + """Make the DEG-fast model. Uses ResNet18 for classification, TinyMotionNet for flow generation. Number of flows: 10 Number of RGB frames for classification: 1 @@ -325,26 +354,29 @@ def deg_f(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True Returns: DEG-f model """ - classifier = 'resnet18' - flow_gen = 'TinyMotionNet' + classifier = "resnet18" + flow_gen = "TinyMotionNet" num_flows = 10 num_rgb = 1 - fusion_style = 'average' - model = hidden_two_stream(classifier, flow_gen, num_classes, - fusion_style=fusion_style, - dropout_p=dropout_p, - reload_imagenet=reload_imagenet, - num_rgb=num_rgb, - num_flows=num_flows, - pos=pos, - neg=neg) + fusion_style = "average" + model = hidden_two_stream( + classifier, + flow_gen, + num_classes, + fusion_style=fusion_style, + dropout_p=dropout_p, + reload_imagenet=reload_imagenet, + num_rgb=num_rgb, + num_flows=num_flows, + pos=pos, + neg=neg, + ) return model -def deg_m(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True, - pos: int = None, neg: int = None): - """ Make the DEG-medium model. Uses ResNet50 for classification, MotionNet for flow generation. +def deg_m(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True, pos: int = None, neg: int = None): + """Make the DEG-medium model. Uses ResNet50 for classification, MotionNet for flow generation. Number of flows: 10 Number of RGB frames for classification: 1 @@ -357,26 +389,36 @@ def deg_m(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True Returns: DEG-m model """ - classifier = 'resnet50' - flow_gen = 'MotionNet' + classifier = "resnet50" + flow_gen = "MotionNet" num_flows = 10 num_rgb = 1 - fusion_style = 'average' - model = hidden_two_stream(classifier, flow_gen, num_classes, - fusion_style=fusion_style, - dropout_p=dropout_p, - reload_imagenet=reload_imagenet, - num_rgb=num_rgb, - num_flows=num_flows, - pos=pos, - neg=neg) + fusion_style = "average" + model = hidden_two_stream( + classifier, + flow_gen, + num_classes, + fusion_style=fusion_style, + dropout_p=dropout_p, + reload_imagenet=reload_imagenet, + num_rgb=num_rgb, + num_flows=num_flows, + pos=pos, + neg=neg, + ) return model -def deg_s(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True, - pos: int = None, neg: int = None, path_to_weights: Union[str, os.PathLike] = None): - """ Make the DEG-slow model. Uses ResNet3d-34 for classification, TinyMotionNet3D for flow generation. +def deg_s( + num_classes: int, + dropout_p: float = 0.9, + reload_imagenet: bool = True, + pos: int = None, + neg: int = None, + path_to_weights: Union[str, os.PathLike] = None, +): + """Make the DEG-slow model. Uses ResNet3d-34 for classification, TinyMotionNet3D for flow generation. Number of flows: 10 Number of RGB frames for classification: 11 @@ -391,19 +433,23 @@ def deg_s(num_classes: int, dropout_p: float = 0.9, reload_imagenet: bool = True Returns: DEG-s model """ - classifier = 'resnet3d_34' - flow_gen = 'TinyMotionNet3D' + classifier = "resnet3d_34" + flow_gen = "TinyMotionNet3D" num_flows = 10 num_rgb = 11 - fusion_style = 'average' - model = hidden_two_stream(classifier, flow_gen, num_classes, - fusion_style=fusion_style, - dropout_p=dropout_p, - reload_imagenet=reload_imagenet, - num_rgb=num_rgb, - num_flows=num_flows, - pos=pos, - neg=neg, - path_to_weights=path_to_weights) + fusion_style = "average" + model = hidden_two_stream( + classifier, + flow_gen, + num_classes, + fusion_style=fusion_style, + dropout_p=dropout_p, + reload_imagenet=reload_imagenet, + num_rgb=num_rgb, + num_flows=num_flows, + pos=pos, + neg=neg, + path_to_weights=path_to_weights, + ) return model diff --git a/deepethogram/feature_extractor/models/utils.py b/deepethogram/feature_extractor/models/utils.py index 5bfc353..a360424 100644 --- a/deepethogram/feature_extractor/models/utils.py +++ b/deepethogram/feature_extractor/models/utils.py @@ -14,15 +14,15 @@ def pop(model, model_name, n_layers): # you also want to pop off the previous ReLU so that you get the unscaled linear units from fc_7 # just doing something like model = nn.Sequential(*list(model.children())[:-1]) would not get rid of # this ReLU, so that's an unintelligent version of this - if model_name.startswith('resnet'): + if model_name.startswith("resnet"): if n_layers == 1: # use empty sequential module as an identity function num_features = model.fc.in_features final_layer = model.fc model.fc = nn.Identity() else: - raise NotImplementedError('Can only pop off the final layer of a resnet') - elif model_name == 'alexnet': + raise NotImplementedError("Can only pop off the final layer of a resnet") + elif model_name == "alexnet": final_layer = model.classifier if n_layers == 1: model.classifier = nn.Sequential( @@ -35,7 +35,7 @@ def pop(model, model_name, n_layers): nn.Linear(4096, 4096), ) num_features = 4096 - log.info('Final layer of encoder: AlexNet FC_7') + log.info("Final layer of encoder: AlexNet FC_7") elif n_layers == 2: model.classifier = nn.Sequential( nn.Dropout(), @@ -43,16 +43,16 @@ def pop(model, model_name, n_layers): nn.Linear(256 * 6 * 6, 4096), ) num_features = 4096 - log.info('Final layer of encoder: AlexNet FC_6') + log.info("Final layer of encoder: AlexNet FC_6") elif n_layers == 3: # do nothing model.classifier = nn.Sequential() num_features = 256 * 6 * 6 - log.info('Final layer of encoder: AlexNet Maxpool 3') + log.info("Final layer of encoder: AlexNet Maxpool 3") else: - raise ValueError('Invalid parameter %d to pop function for %s: ' % (n_layers, model_name)) + raise ValueError("Invalid parameter %d to pop function for %s: " % (n_layers, model_name)) - elif model_name.startswith('vgg'): + elif model_name.startswith("vgg"): final_layer = model.classifier if n_layers == 1: model.classifier = nn.Sequential( @@ -62,33 +62,33 @@ def pop(model, model_name, n_layers): nn.Linear(4096, 4096), ) num_features = 4096 - log.info('Final layer of encoder: VGG fc2') + log.info("Final layer of encoder: VGG fc2") elif n_layers == 2: model.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), ) - log.info('Final layer of encoder: VGG fc1') + log.info("Final layer of encoder: VGG fc1") num_features = 4096 elif n_layers == 3: model.classifier = nn.Sequential() - log.info('Final layer of encoder: VGG pool5') + log.info("Final layer of encoder: VGG pool5") num_features = 512 * 7 * 7 else: - raise ValueError('Invalid parameter %d to pop function for %s: ' % (n_layers, model_name)) + raise ValueError("Invalid parameter %d to pop function for %s: " % (n_layers, model_name)) - elif model_name.startswith('squeezenet'): + elif model_name.startswith("squeezenet"): raise NotImplementedError - elif model_name.startswith('densenet'): + elif model_name.startswith("densenet"): raise NotImplementedError - elif model_name.startswith('inception'): + elif model_name.startswith("inception"): raise NotImplementedError else: - raise ValueError('%s is not a valid model name' % (model_name)) + raise ValueError("%s is not a valid model name" % (model_name)) return model, num_features, final_layer def remove_cnn_classifier_layer(cnn): - """ Removes the final layer of a torchvision classification model, and figures out dimensionality of final layer """ + """Removes the final layer of a torchvision classification model, and figures out dimensionality of final layer""" # cnn should be a nn.Sequential(custom_model, nn.Linear) module_list = list(cnn.children()) assert (len(module_list) == 2 or len(module_list) == 3) and isinstance(module_list[1], nn.Linear) @@ -97,46 +97,54 @@ def remove_cnn_classifier_layer(cnn): cnn = nn.Sequential(*module_list) return cnn, in_features -class Fusion(nn.Module): - """ Module for fusing spatial and flow features and passing through Linear layer """ - def __init__(self, style, num_spatial_features, num_flow_features, num_classes, flow_fusion_weight=1.5, - activation=nn.Identity()): +class Fusion(nn.Module): + """Module for fusing spatial and flow features and passing through Linear layer""" + + def __init__( + self, + style, + num_spatial_features, + num_flow_features, + num_classes, + flow_fusion_weight=1.5, + activation=nn.Identity(), + ): super().__init__() self.style = style self.num_classes = num_classes self.activation = activation self.flow_fusion_weight = flow_fusion_weight - if self.style == 'average': + if self.style == "average": # self.spatial_fc = nn.Linear(num_spatial_features,num_classes) # self.flow_fc = nn.Linear(num_flow_features, num_classes) self.num_features_out = num_classes - elif self.style == 'concatenate': + elif self.style == "concatenate": self.num_features_out = num_classes self.fc = nn.Linear(num_spatial_features + num_flow_features, num_classes) - elif self.style == 'weighted_average': + elif self.style == "weighted_average": self.flow_weight = nn.Parameter(torch.Tensor([0.5]).float(), requires_grad=True) else: raise NotImplementedError def forward(self, spatial_features, flow_features): - if self.style == 'average': + if self.style == "average": # spatial_logits = self.spatial_fc(spatial_features) # flow_logits = self.flow_fc(flow_features) return (spatial_features + flow_features * self.flow_fusion_weight) / (1 + self.flow_fusion_weight) # return((spatial_logits+flow_logits*self.flow_fusion_weight)/(1+self.flow_fusion_weight)) - elif self.style == 'concatenate': + elif self.style == "concatenate": # if we're concatenating, we want the model to learn nonlinear mappings from the spatial logits and flow # logits that means we should apply an activation function note: this won't work if you froze both # encoding models features = self.activation(torch.cat((spatial_features, flow_features), dim=1)) return self.fc(features) - elif self.style == 'weighted_average': + elif self.style == "weighted_average": return self.flow_weight * flow_features + (1 - self.flow_weight) * spatial_features diff --git a/deepethogram/feature_extractor/train.py b/deepethogram/feature_extractor/train.py index f71f8a3..59e39a0 100644 --- a/deepethogram/feature_extractor/train.py +++ b/deepethogram/feature_extractor/train.py @@ -21,8 +21,11 @@ from deepethogram.data.datasets import get_datasets_from_cfg from deepethogram.feature_extractor.losses import ClassificationLoss, BinaryFocalLoss, CrossEntropyLoss from deepethogram.feature_extractor.models.CNN import get_cnn -from deepethogram.feature_extractor.models.hidden_two_stream import HiddenTwoStream, FlowOnlyClassifier, \ - build_fusion_layer +from deepethogram.feature_extractor.models.hidden_two_stream import ( + HiddenTwoStream, + FlowOnlyClassifier, + build_fusion_layer, +) from deepethogram.flow_generator.train import build_model_from_cfg as build_flow_generator from deepethogram.losses import get_regularization_loss from deepethogram.metrics import Classification @@ -34,20 +37,21 @@ os.environ["SLURM_JOB_NAME"] = "bash" warnings.filterwarnings( - 'ignore', + "ignore", category=UserWarning, - message='Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation ' - 'and test dataloaders.') + message="Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation " + "and test dataloaders.", +) # flow_generators = utils.get_models_from_module(flow_models, get_function=False) -plt.switch_backend('agg') +plt.switch_backend("agg") log = logging.getLogger(__name__) # @profile def feature_extractor_train(cfg: DictConfig) -> nn.Module: - """Trains feature extractor models from a configuration. + """Trains feature extractor models from a configuration. Parameters ---------- @@ -62,40 +66,43 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # rundir = os.getcwd() cfg = projects.setup_run(cfg) - log.info('args: {}'.format(' '.join(sys.argv))) + log.info("args: {}".format(" ".join(sys.argv))) # change the project paths from relative to absolute # allow for editing OmegaConf.set_struct(cfg, False) # SHOULD NEVER MODIFY / MAKE ASSIGNMENTS TO THE CFG OBJECT AFTER RIGHT HERE! - log.info('configuration used ~~~~~') + log.info("configuration used ~~~~~") log.info(OmegaConf.to_yaml(cfg)) # we build flow generator independently because you might want to load it from a different location flow_generator = build_flow_generator(cfg) - flow_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator') - assert flow_weights is not None, ('Must have a valid weightfile for flow generator. Use ' - 'deepethogram.flow_generator.train or cfg.reload.latest') - log.info('loading flow generator from file {}'.format(flow_weights)) + flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") + assert flow_weights is not None, ( + "Must have a valid weightfile for flow generator. Use " "deepethogram.flow_generator.train or cfg.reload.latest" + ) + log.info("loading flow generator from file {}".format(flow_weights)) flow_generator = utils.load_weights(flow_generator, flow_weights) - _, data_info = get_datasets_from_cfg(cfg, - model_type='feature_extractor', - input_images=cfg.feature_extractor.n_flows + 1) + _, data_info = get_datasets_from_cfg( + cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_flows + 1 + ) - model_parts = build_model_from_cfg(cfg, pos=data_info['pos'], neg=data_info['neg']) + model_parts = build_model_from_cfg(cfg, pos=data_info["pos"], neg=data_info["neg"]) _, spatial_classifier, flow_classifier, fusion, model = model_parts # log.info('model: {}'.format(model)) num_classes = len(cfg.project.class_names) - utils.save_dict_to_yaml(data_info['split'], os.path.join(cfg.run.dir, 'split.yaml')) + utils.save_dict_to_yaml(data_info["split"], os.path.join(cfg.run.dir, "split.yaml")) - metrics = get_metrics(cfg.run.dir, - num_classes=num_classes, - num_parameters=utils.get_num_parameters(spatial_classifier), - key_metric='f1_class_mean_nobg', - num_workers=cfg.compute.metrics_workers) + metrics = get_metrics( + cfg.run.dir, + num_classes=num_classes, + num_parameters=utils.get_num_parameters(spatial_classifier), + key_metric="f1_class_mean_nobg", + num_workers=cfg.compute.metrics_workers, + ) # cfg.compute.batch_size will be changed by the automatic batch size finder, possibly. store here so that # with each step of the curriculum, we can auto-tune it @@ -111,9 +118,9 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # train spatial model, then flow model, then both end-to-end # dataloaders = get_dataloaders_from_cfg(cfg, model_type='feature_extractor', # input_images=cfg.feature_extractor.n_rgb) - datasets, data_info = get_datasets_from_cfg(cfg, - model_type='feature_extractor', - input_images=cfg.feature_extractor.n_rgb) + datasets, data_info = get_datasets_from_cfg( + cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_rgb + ) stopper = get_stopper(cfg) criterion = get_criterion(cfg, spatial_classifier, data_info) @@ -130,16 +137,16 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: trainer.fit(lightning_module) # free RAM. note: this doesn't do much - log.info('free ram') + log.info("free ram") del datasets, lightning_module, trainer, stopper, data_info torch.cuda.empty_cache() gc.collect() # return - datasets, data_info = get_datasets_from_cfg(cfg, - model_type='feature_extractor', - input_images=cfg.feature_extractor.n_flows + 1) + datasets, data_info = get_datasets_from_cfg( + cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_flows + 1 + ) # re-initialize stopper so that it doesn't think we need to stop due to the previous model stopper = get_stopper(cfg) cfg.compute.batch_size = original_batch_size @@ -161,10 +168,10 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: gc.collect() model = HiddenTwoStream(flow_generator, spatial_classifier, flow_classifier, fusion, cfg.feature_extractor.arch) - model.set_mode('classifier') - datasets, data_info = get_datasets_from_cfg(cfg, - model_type='feature_extractor', - input_images=cfg.feature_extractor.n_flows + 1) + model.set_mode("classifier") + datasets, data_info = get_datasets_from_cfg( + cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_flows + 1 + ) criterion = get_criterion(cfg, model, data_info) stopper = get_stopper(cfg) cfg.compute.batch_size = original_batch_size @@ -184,11 +191,10 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # utils.save_hidden_two_stream(model, rundir, dict(cfg), stopper.epoch_counter) -def build_model_from_cfg(cfg: DictConfig, - pos: np.ndarray = None, - neg: np.ndarray = None, - num_classes: int = None) -> tuple: - """ Builds feature extractor from a configuration object. +def build_model_from_cfg( + cfg: DictConfig, pos: np.ndarray = None, neg: np.ndarray = None, num_classes: int = None +) -> tuple: + """Builds feature extractor from a configuration object. Parameters ---------- @@ -211,68 +217,71 @@ def build_model_from_cfg(cfg: DictConfig, hidden two stream CNN """ # device = torch.device("cuda:" + str(cfg.compute.gpu_id) if torch.cuda.is_available() else "cpu") - device = 'cpu' - feature_extractor_weights = projects.get_weightfile_from_cfg(cfg, 'feature_extractor') + device = "cpu" + feature_extractor_weights = projects.get_weightfile_from_cfg(cfg, "feature_extractor") if num_classes is None: num_classes = len(cfg.project.class_names) - log.info('feature extractor weightfile: {}'.format(feature_extractor_weights)) + log.info("feature extractor weightfile: {}".format(feature_extractor_weights)) - in_channels = cfg.feature_extractor.n_rgb * 3 if '3d' not in cfg.feature_extractor.arch else 3 + in_channels = cfg.feature_extractor.n_rgb * 3 if "3d" not in cfg.feature_extractor.arch else 3 reload_imagenet = feature_extractor_weights is None - if cfg.feature_extractor.arch == 'resnet3d_34': - assert feature_extractor_weights is not None, 'Must specify path to resnet3d weights!' - spatial_classifier = get_cnn(cfg.feature_extractor.arch, - in_channels=in_channels, - dropout_p=cfg.feature_extractor.dropout_p, - num_classes=num_classes, - reload_imagenet=reload_imagenet, - pos=pos, - neg=neg, - final_bn=cfg.feature_extractor.final_bn) + if cfg.feature_extractor.arch == "resnet3d_34": + assert feature_extractor_weights is not None, "Must specify path to resnet3d weights!" + spatial_classifier = get_cnn( + cfg.feature_extractor.arch, + in_channels=in_channels, + dropout_p=cfg.feature_extractor.dropout_p, + num_classes=num_classes, + reload_imagenet=reload_imagenet, + pos=pos, + neg=neg, + final_bn=cfg.feature_extractor.final_bn, + ) # load this specific component from the weight file if feature_extractor_weights is not None: - spatial_classifier = utils.load_feature_extractor_components(spatial_classifier, - feature_extractor_weights, - 'spatial', - device=device) - in_channels = cfg.feature_extractor.n_flows * 2 if '3d' not in cfg.feature_extractor.arch else 2 - flow_classifier = get_cnn(cfg.feature_extractor.arch, - in_channels=in_channels, - dropout_p=cfg.feature_extractor.dropout_p, - num_classes=num_classes, - reload_imagenet=reload_imagenet, - pos=pos, - neg=neg, - final_bn=cfg.feature_extractor.final_bn) + spatial_classifier = utils.load_feature_extractor_components( + spatial_classifier, feature_extractor_weights, "spatial", device=device + ) + in_channels = cfg.feature_extractor.n_flows * 2 if "3d" not in cfg.feature_extractor.arch else 2 + flow_classifier = get_cnn( + cfg.feature_extractor.arch, + in_channels=in_channels, + dropout_p=cfg.feature_extractor.dropout_p, + num_classes=num_classes, + reload_imagenet=reload_imagenet, + pos=pos, + neg=neg, + final_bn=cfg.feature_extractor.final_bn, + ) # load this specific component from the weight file if feature_extractor_weights is not None: - flow_classifier = utils.load_feature_extractor_components(flow_classifier, - feature_extractor_weights, - 'flow', - device=device) + flow_classifier = utils.load_feature_extractor_components( + flow_classifier, feature_extractor_weights, "flow", device=device + ) flow_generator = build_flow_generator(cfg) - flow_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator') - assert flow_weights is not None, ('Must have a valid weightfile for flow generator. Use ' - 'deepethogram.flow_generator.train or cfg.reload.latest') + flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") + assert flow_weights is not None, ( + "Must have a valid weightfile for flow generator. Use " "deepethogram.flow_generator.train or cfg.reload.latest" + ) flow_generator = utils.load_weights(flow_generator, flow_weights, device=device) - spatial_classifier, flow_classifier, fusion = build_fusion_layer(spatial_classifier, flow_classifier, - cfg.feature_extractor.fusion, num_classes) + spatial_classifier, flow_classifier, fusion = build_fusion_layer( + spatial_classifier, flow_classifier, cfg.feature_extractor.fusion, num_classes + ) if feature_extractor_weights is not None: - fusion = utils.load_feature_extractor_components(fusion, feature_extractor_weights, 'fusion', device=device) + fusion = utils.load_feature_extractor_components(fusion, feature_extractor_weights, "fusion", device=device) model = HiddenTwoStream(flow_generator, spatial_classifier, flow_classifier, fusion, cfg.feature_extractor.arch) # log.info(model.fusion.flow_weight) - model.set_mode('classifier') + model.set_mode("classifier") return flow_generator, spatial_classifier, flow_classifier, fusion, model class HiddenTwoStreamLightning(BaseLightningModule): - """Lightning Module for training Feature Extractor models - """ + """Lightning Module for training Feature Extractor models""" def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, criterion: nn.Module): """constructor @@ -295,9 +304,9 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, c self.has_logged_channels = False # for convenience self.final_activation = self.hparams.feature_extractor.final_activation - if self.final_activation == 'softmax': + if self.final_activation == "softmax": self.activation = nn.Softmax(dim=1) - elif self.final_activation == 'sigmoid': + elif self.final_activation == "sigmoid": self.activation = nn.Sigmoid() else: raise NotImplementedError @@ -320,13 +329,13 @@ def validate_batch_size(self, batch: dict): if self.hparams.compute.dali: # no idea why they wrap this, maybe they fixed it? batch = batch[0] - if 'images' in batch.keys(): + if "images" in batch.keys(): # weird case of batch size = 1 somehow getting squeezed out - if batch['images'].ndim != 5: - batch['images'] = batch['images'].unsqueeze(0) - if 'labels' in batch.keys(): - if self.final_activation == 'sigmoid' and batch['labels'].ndim == 1: - batch['labels'] = batch['labels'].unsqueeze(0) + if batch["images"].ndim != 5: + batch["images"] = batch["images"].unsqueeze(0) + if "labels" in batch.keys(): + if self.final_activation == "sigmoid" and batch["labels"].ndim == 1: + batch["labels"] = batch["labels"].unsqueeze(0) return batch def training_step(self, batch: dict, batch_idx: int): @@ -346,24 +355,22 @@ def training_step(self, batch: dict, batch_idx: int): """ # use the forward function # return the image tensor so we can visualize after gpu transforms - images, outputs = self(batch, 'train') + images, outputs = self(batch, "train") probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch['labels'], self.model) + loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) - self.visualize_batch(images, probabilities, batch['labels'], 'train') + self.visualize_batch(images, probabilities, batch["labels"], "train") # save the model outputs to a buffer for various metrics - self.metrics.buffer.append('train', { - 'loss': loss.detach(), - 'probs': probabilities.detach(), - 'labels': batch['labels'].detach() - }) + self.metrics.buffer.append( + "train", {"loss": loss.detach(), "probs": probabilities.detach(), "labels": batch["labels"].detach()} + ) # add the individual components of the loss to the metrics buffer - self.metrics.buffer.append('train', loss_dict) + self.metrics.buffer.append("train", loss_dict) # need to use the native logger for lr scheduling, etc. - self.log('train/loss', loss.detach().cpu()) + self.log("train/loss", loss.detach().cpu()) return loss def validation_step(self, batch: dict, batch_idx: int): @@ -376,19 +383,17 @@ def validation_step(self, batch: dict, batch_idx: int): batch_idx : int index in validation epoch """ - images, outputs = self(batch, 'val') + images, outputs = self(batch, "val") probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch['labels'], self.model) - self.visualize_batch(images, probabilities, batch['labels'], 'val') - self.metrics.buffer.append('val', { - 'loss': loss.detach(), - 'probs': probabilities.detach(), - 'labels': batch['labels'].detach() - }) - self.metrics.buffer.append('val', loss_dict) + loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) + self.visualize_batch(images, probabilities, batch["labels"], "val") + self.metrics.buffer.append( + "val", {"loss": loss.detach(), "probs": probabilities.detach(), "labels": batch["labels"].detach()} + ) + self.metrics.buffer.append("val", loss_dict) # need to use the native logger for lr scheduling, etc. - self.log('val/loss', loss.detach().cpu()) + self.log("val/loss", loss.detach().cpu()) def test_step(self, batch: dict, batch_idx: int): """runs test step @@ -400,15 +405,13 @@ def test_step(self, batch: dict, batch_idx: int): batch_idx : int index in test epoch """ - images, outputs = self(batch, 'test') + images, outputs = self(batch, "test") probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch['labels'], self.model) - self.metrics.buffer.append('test', { - 'loss': loss.detach(), - 'probs': probabilities.detach(), - 'labels': batch['labels'].detach() - }) - self.metrics.buffer.append('test', loss_dict) + loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) + self.metrics.buffer.append( + "test", {"loss": loss.detach(), "probs": probabilities.detach(), "labels": batch["labels"].detach()} + ) + self.metrics.buffer.append("test", loss_dict) def visualize_batch(self, images: torch.Tensor, probs: torch.Tensor, labels: torch.Tensor, split: str): """generates example images of a given batch and saves to disk as a Matplotlib figure @@ -432,7 +435,7 @@ def visualize_batch(self, images: torch.Tensor, probs: torch.Tensor, labels: tor if viz_cnt > self.hparams.train.viz_examples: return # this method can be used for sequence models as well - if hasattr(self.model, 'flow_generator'): + if hasattr(self.model, "flow_generator"): with torch.no_grad(): # re-compute optic flows for this batch for visualization batch_size = images.size(0) @@ -446,27 +449,25 @@ def visualize_batch(self, images: torch.Tensor, probs: torch.Tensor, labels: tor # only output the highest res flow flows = self.model.flow_generator(images)[0].detach() - inputs = self.gpu_transforms['denormalize'](images).detach() + inputs = self.gpu_transforms["denormalize"](images).detach() fig = plt.figure(figsize=(14, 14)) - viz.visualize_hidden(inputs.detach().cpu(), - flows.detach().cpu(), - probs.detach().cpu(), - labels.detach().cpu(), - fig=fig) + viz.visualize_hidden( + inputs.detach().cpu(), flows.detach().cpu(), probs.detach().cpu(), labels.detach().cpu(), fig=fig + ) # this should happen in the save figure function, but for some reason it doesn't - viz.save_figure(fig, 'batch_with_flows', True, viz_cnt, split) + viz.save_figure(fig, "batch_with_flows", True, viz_cnt, split) del images, probs, labels, flows torch.cuda.empty_cache() else: fig = plt.figure(figsize=(14, 14)) with torch.no_grad(): - inputs = self.gpu_transforms['denormalize'](images) + inputs = self.gpu_transforms["denormalize"](images) viz.visualize_batch_spatial(inputs, probs, labels, fig=fig) - viz.save_figure(fig, 'batch_spatial', True, viz_cnt, split) + viz.save_figure(fig, "batch_spatial", True, viz_cnt, split) try: # should've been closed in viz.save_figure. this is double checking plt.close(fig) - plt.close('all') + plt.close("all") except: pass torch.cuda.empty_cache() @@ -489,7 +490,7 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: """ batch = self.validate_batch_size(batch) # lightning handles transfer to device - images = batch['images'] + images = batch["images"] # no-grad should work in the apply_gpu_transforms method; adding here just in case with torch.no_grad(): # augment the input images. in training, this will perturb brightness, contrast, etc. @@ -497,8 +498,8 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: gpu_images = self.apply_gpu_transforms(images, mode) if torch.sum(gpu_images != gpu_images) > 0 or torch.sum(torch.isinf(gpu_images)) > 0: - log.error('nan in gpu augs') - raise ValueError('nan in GPU augmentations!') + log.error("nan in gpu augs") + raise ValueError("nan in GPU augmentations!") # make sure normalization works, etc. self.log_image_statistics(gpu_images) @@ -506,8 +507,8 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: outputs = self.model(gpu_images) if torch.sum(outputs != outputs) > 0: - log.error('nan in model outputs') - raise ValueError('nan in model outputs!') + log.error("nan in model outputs") + raise ValueError("nan in model outputs!") return gpu_images, outputs @@ -522,29 +523,29 @@ def log_image_statistics(self, images: torch.Tensor): if not self.has_logged_channels and log.isEnabledFor(logging.DEBUG): if len(images.shape) == 4: N, C, H, W = images.shape - log.debug('inputs shape: NCHW: {} {} {} {}'.format(N, C, H, W)) - log.debug('channel min: {}'.format(images[0].reshape(C, -1).min(dim=1).values)) - log.debug('channel mean: {}'.format(images[0].reshape(C, -1).mean(dim=1))) - log.debug('channel max : {}'.format(images[0].reshape(C, -1).max(dim=1).values)) - log.debug('channel std : {}'.format(images[0].reshape(C, -1).std(dim=1))) + log.debug("inputs shape: NCHW: {} {} {} {}".format(N, C, H, W)) + log.debug("channel min: {}".format(images[0].reshape(C, -1).min(dim=1).values)) + log.debug("channel mean: {}".format(images[0].reshape(C, -1).mean(dim=1))) + log.debug("channel max : {}".format(images[0].reshape(C, -1).max(dim=1).values)) + log.debug("channel std : {}".format(images[0].reshape(C, -1).std(dim=1))) elif len(images.shape) == 5: N, C, T, H, W = images.shape - log.debug('inputs shape: NCTHW: {} {} {} {} {}'.format(N, C, T, H, W)) - log.debug('channel min: {}'.format(images[0].min(dim=2).values)) - log.debug('channel mean: {}'.format(images[0].mean(dim=2))) - log.debug('channel max : {}'.format(images[0].max(dim=2).values)) - log.debug('channel std : {}'.format(images[0].std(dim=2))) + log.debug("inputs shape: NCTHW: {} {} {} {} {}".format(N, C, T, H, W)) + log.debug("channel min: {}".format(images[0].min(dim=2).values)) + log.debug("channel mean: {}".format(images[0].mean(dim=2))) + log.debug("channel max : {}".format(images[0].max(dim=2).values)) + log.debug("channel std : {}".format(images[0].std(dim=2))) self.has_logged_channels = True def log_model_statistics(self, images, outputs, labels): # will print out shape and min, mean, max, std along image channels # we use the isEnabledFor flag so that this doesnt slow down training in the non-debug case - log.debug('outputs: {}'.format(outputs)) - log.debug('labels: {}'.format(labels)) - log.debug('outputs: {}'.format(outputs.shape)) - log.debug('labels: {}'.format(labels.shape)) - log.debug('label max: {}'.format(labels.max())) - log.debug('label min: {}'.format(labels.min())) + log.debug("outputs: {}".format(outputs)) + log.debug("labels: {}".format(labels)) + log.debug("outputs: {}".format(outputs.shape)) + log.debug("labels: {}".format(labels.shape)) + log.debug("label max: {}".format(labels.max())) + log.debug("label min: {}".format(labels.min())) def get_criterion(cfg: DictConfig, model, data_info: dict, device=None): @@ -578,21 +579,21 @@ def get_criterion(cfg: DictConfig, model, data_info: dict, device=None): if final_activation is not softmax or sigmoid """ final_activation = cfg.feature_extractor.final_activation - if final_activation == 'softmax': - if 'weight' in list(data_info.keys()): - weight = data_info['loss_weight'] + if final_activation == "softmax": + if "weight" in list(data_info.keys()): + weight = data_info["loss_weight"] else: weight = None data_criterion = CrossEntropyLoss(weight=weight) - elif final_activation == 'sigmoid': - pos_weight = data_info['pos_weight'] + elif final_activation == "sigmoid": + pos_weight = data_info["pos_weight"] if type(pos_weight) == np.ndarray: pos_weight = torch.from_numpy(pos_weight) pos_weight = pos_weight.to(device) if device is not None else pos_weight - data_criterion = BinaryFocalLoss(pos_weight=pos_weight, - gamma=cfg.train.loss_gamma, - label_smoothing=cfg.train.label_smoothing) + data_criterion = BinaryFocalLoss( + pos_weight=pos_weight, gamma=cfg.train.loss_gamma, label_smoothing=cfg.train.label_smoothing + ) else: raise NotImplementedError @@ -604,13 +605,15 @@ def get_criterion(cfg: DictConfig, model, data_info: dict, device=None): return criterion -def get_metrics(rundir: Union[str, bytes, os.PathLike], - num_classes: int, - num_parameters: Union[int, float], - is_kinetics: bool = False, - key_metric='loss', - num_workers: int = 4): - """ get metrics object for classification. See deepethogram/metrics.py. +def get_metrics( + rundir: Union[str, bytes, os.PathLike], + num_classes: int, + num_parameters: Union[int, float], + is_kinetics: bool = False, + key_metric="loss", + num_workers: int = 4, +): + """get metrics object for classification. See deepethogram/metrics.py. In brief, it's a Metrics object that provides utilities for computing metrics over predictions, saving various metrics to disk, tracking your learning rate across epochs, etc. @@ -625,20 +628,17 @@ def get_metrics(rundir: Union[str, bytes, os.PathLike], Returns: Classification metrics object """ - metric_list = ['accuracy', 'mean_class_accuracy', 'f1'] + metric_list = ["accuracy", "mean_class_accuracy", "f1"] if not is_kinetics: - metric_list.append('confusion') - log.info('key metric: {}'.format(key_metric)) - metrics = Classification(rundir, - key_metric, - num_parameters, - num_classes=num_classes, - evaluate_threshold=True, - num_workers=num_workers) + metric_list.append("confusion") + log.info("key metric: {}".format(key_metric)) + metrics = Classification( + rundir, key_metric, num_parameters, num_classes=num_classes, evaluate_threshold=True, num_workers=num_workers + ) return metrics -if __name__ == '__main__': +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_feature_extractor_train_cfg(project_path, use_command_line=True) diff --git a/deepethogram/file_io.py b/deepethogram/file_io.py index 513e5c1..48cc1f5 100644 --- a/deepethogram/file_io.py +++ b/deepethogram/file_io.py @@ -9,29 +9,29 @@ def read_labels(labelfile: Union[str, os.PathLike]) -> np.ndarray: - """ convenience function for reading labels from a .csv or .h5 file """ + """convenience function for reading labels from a .csv or .h5 file""" labeltype = os.path.splitext(labelfile)[1][1:] - if labeltype == 'csv': + if labeltype == "csv": label = read_label_csv(labelfile) # return(read_label_csv(labelfile)) - elif labeltype == 'h5': + elif labeltype == "h5": label = read_label_hdf5(labelfile) # return(read_label_hdf5(labelfile)) else: - raise ValueError('Unknown labeltype: {}'.format(labeltype)) + raise ValueError("Unknown labeltype: {}".format(labeltype)) H, W = label.shape # labels should be time x num_behaviors if W > H: label = label.T if label.shape[1] == 1: # add a background class - warnings.warn('binary labels found, adding background class') + warnings.warn("binary labels found, adding background class") label = np.hstack((np.logical_not(label), label)) return label def read_label_hdf5(labelfile: Union[str, os.PathLike]) -> np.ndarray: - """ read labels from an HDF5 file. Must end in .h5 + """read labels from an HDF5 file. Must end in .h5 Assumes that labels are in a dataset with name 'scores' or 'labels' Parameters @@ -42,18 +42,18 @@ def read_label_hdf5(labelfile: Union[str, os.PathLike]) -> np.ndarray: ------- """ - with h5py.File(labelfile, 'r') as f: + with h5py.File(labelfile, "r") as f: keys = list(f.keys()) - if 'scores' in keys: - key = 'scores' - elif 'labels' in keys: - key = 'labels' + if "scores" in keys: + key = "scores" + elif "labels" in keys: + key = "labels" else: - raise ValueError('not sure which dataset in hdf5 contains labels: {}'.format(keys)) + raise ValueError("not sure which dataset in hdf5 contains labels: {}".format(keys)) label = f[key][:].astype(np.int64) if label.ndim == 1: label = label[..., np.newaxis] - return (label) + return label def read_label_csv(labelfile: Union[str, os.PathLike]) -> np.ndarray: @@ -87,7 +87,7 @@ def convert_video(videofile: Union[str, os.PathLike], movie_format: str, *args, movie_format : str One of ['ffmpeg', 'opencv', 'hdf5', 'directory'] ffmpeg: converts to libx264 using ffmpeg - OpenCV: converts to MJPG (by default) + OpenCV: converts to MJPG (by default) HDF5: converts to an HDF5 file or PNG bytestrings. Lossless compression. Compromise between fastest reading (PNG directory), and ease of transferring across filesystems (e.g. to a server) directory: explodes into directory of PNG files @@ -99,16 +99,16 @@ def convert_video(videofile: Union[str, os.PathLike], movie_format: str, *args, """ with VideoReader(videofile) as reader: basename = os.path.splitext(videofile)[0] - if movie_format == 'ffmpeg': - out_filename = basename + '.mp4' - elif movie_format == 'opencv': - out_filename = basename + '.avi' - elif movie_format == 'hdf5': - out_filename = basename + '.h5' - elif movie_format == 'directory': + if movie_format == "ffmpeg": + out_filename = basename + ".mp4" + elif movie_format == "opencv": + out_filename = basename + ".avi" + elif movie_format == "hdf5": + out_filename = basename + ".h5" + elif movie_format == "directory": out_filename = basename else: - raise ValueError('unexpected value of movie format: {}'.format(movie_format)) + raise ValueError("unexpected value of movie format: {}".format(movie_format)) with VideoWriter(out_filename, movie_format=movie_format, *args, **kwargs) as writer: for frame in reader: writer.write(frame) diff --git a/deepethogram/flow_generator/__init__.py b/deepethogram/flow_generator/__init__.py index a861d3e..d354269 100644 --- a/deepethogram/flow_generator/__init__.py +++ b/deepethogram/flow_generator/__init__.py @@ -1 +1 @@ -# from . import train \ No newline at end of file +# from . import train diff --git a/deepethogram/flow_generator/inference.py b/deepethogram/flow_generator/inference.py index 43e2cdc..4bb0887 100644 --- a/deepethogram/flow_generator/inference.py +++ b/deepethogram/flow_generator/inference.py @@ -19,24 +19,26 @@ from deepethogram.data.datasets import VideoIterable from deepethogram.flow_generator.train import build_model_from_cfg as build_flow_generator from deepethogram.flow_generator.utils import flow_to_rgb_polar, flow_to_rgb -log = logging.getLogger(__name__) +log = logging.getLogger(__name__) -def extract_movie(in_video, - out_video, - model, - device, - cpu_transform, - gpu_transform, - mean_by_channels, - num_workers=1, - num_rgb=11, - maxval: int = 5, - polar: bool = True, - movie_format: str = 'ffmpeg', - batch_size=1, - save_rgb_side_by_side=False) -> None: +def extract_movie( + in_video, + out_video, + model, + device, + cpu_transform, + gpu_transform, + mean_by_channels, + num_workers=1, + num_rgb=11, + maxval: int = 5, + polar: bool = True, + movie_format: str = "ffmpeg", + batch_size=1, + save_rgb_side_by_side=False, +) -> None: if polar: convert = partial(flow_to_rgb_polar, maxval=maxval) else: @@ -46,28 +48,30 @@ def extract_movie(in_video, # if type(device) != torch.device: # device = torch.device(device) - dataset = VideoIterable(in_video, - transform=cpu_transform, - num_workers=num_workers, - sequence_length=num_rgb, - mean_by_channels=mean_by_channels) + dataset = VideoIterable( + in_video, + transform=cpu_transform, + num_workers=num_workers, + sequence_length=num_rgb, + mean_by_channels=mean_by_channels, + ) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size) # log.debug('model training mode: {}'.format(model.training)) with VideoWriter(out_video, movie_format) as vid: for i, batch in enumerate(tqdm(dataloader, leave=False)): if isinstance(batch, dict): - images = batch['images'] + images = batch["images"] elif isinstance(batch, torch.Tensor): images = batch else: - raise ValueError('unknown input type: {}'.format(type(batch))) + raise ValueError("unknown input type: {}".format(type(batch))) if images.device != device: images = images.to(device) # images = batch['images'] with torch.no_grad(): - images = gpu_transform['val'](images) + images = gpu_transform["val"](images) flows = model(images) # TODO: only run optic flow calc on each frame once! # since we are running batches of 11 images, the batches look like @@ -84,7 +88,7 @@ def extract_movie(in_video, flow = flow.detach().cpu().numpy().transpose(1, 2, 0) flow_map = convert(flow) if save_rgb_side_by_side: - images = gpu_transform['denormalize'](images) + images = gpu_transform["denormalize"](images) im = images[:, :, 5, ...].squeeze().detach().cpu().numpy() im = im.transpose(1, 2, 0) * 255 im = im.clip(min=0, max=255).astype(np.uint8) @@ -95,7 +99,7 @@ def extract_movie(in_video, vid.write(out) -def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_prefix='classification') -> dict: +def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_prefix="classification") -> dict: """from model weights, gets the configuration for that model and its metrics file Parameters @@ -109,16 +113,16 @@ def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_pref config_file: path to config file metrics_file: path to metrics file """ - loaded_config_file = os.path.join(os.path.dirname(weightfile), 'config.yaml') + loaded_config_file = os.path.join(os.path.dirname(weightfile), "config.yaml") if not os.path.isfile(loaded_config_file): # weight file should be at most one-subdirectory-down from rundir - loaded_config_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), 'config.yaml') - assert os.path.isfile(loaded_config_file), 'no associated config file for weights! {}'.format(weightfile) + loaded_config_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), "config.yaml") + assert os.path.isfile(loaded_config_file), "no associated config file for weights! {}".format(weightfile) - metrics_file = os.path.join(os.path.dirname(weightfile), f'{metrics_prefix}_metrics.h5') + metrics_file = os.path.join(os.path.dirname(weightfile), f"{metrics_prefix}_metrics.h5") if not os.path.isfile(metrics_file): - metrics_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), f'{metrics_prefix}_metrics.h5') - assert os.path.isfile(metrics_file), 'no associated metrics file for weights! {}'.format(weightfile) + metrics_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), f"{metrics_prefix}_metrics.h5") + assert os.path.isfile(metrics_file), "no associated metrics file for weights! {}".format(weightfile) return dict(config_file=loaded_config_file, metrics_file=metrics_file) @@ -127,23 +131,24 @@ def flow_generator_inference(cfg): # make configuration cfg = projects.setup_run(cfg) # turn "models" in your project configuration to "full/path/to/models" - log.info('args: {}'.format(' '.join(sys.argv))) - log.info('configuration used in inference: ') + log.info("args: {}".format(" ".join(sys.argv))) + log.info("configuration used in inference: ") log.info(OmegaConf.to_yaml(cfg)) - if 'sequence' not in cfg.keys() or 'latent_name' not in cfg.sequence.keys() or cfg.sequence.latent_name is None: + if "sequence" not in cfg.keys() or "latent_name" not in cfg.sequence.keys() or cfg.sequence.latent_name is None: latent_name = cfg.feature_extractor.arch else: latent_name = cfg.sequence.latent_name - log.info('Latent name used in HDF5 file: {}'.format(latent_name)) + log.info("Latent name used in HDF5 file: {}".format(latent_name)) directory_list = cfg.inference.directory_list # figure out which videos to run inference on if directory_list is None or len(directory_list) == 0: - raise ValueError('must pass list of directories from commmand line. ' - 'Ex: directory_list=[path_to_dir1,path_to_dir2]') - elif type(directory_list) == str and directory_list == 'all': + raise ValueError( + "must pass list of directories from commmand line. " "Ex: directory_list=[path_to_dir1,path_to_dir2]" + ) + elif type(directory_list) == str and directory_list == "all": basedir = cfg.project.data_path - directory_list = utils.get_subfiles(basedir, 'directory') + directory_list = utils.get_subfiles(basedir, "directory") elif isinstance(directory_list, str): directory_list = [directory_list] elif isinstance(directory_list, list): @@ -151,35 +156,36 @@ def flow_generator_inference(cfg): elif isinstance(directory_list, ListConfig): directory_list = OmegaConf.to_container(directory_list) else: - raise ValueError('unknown value for directory list: {}'.format(directory_list)) + raise ValueError("unknown value for directory list: {}".format(directory_list)) # video files are found in your input list of directories using the records.yaml file that should be present # in each directory records = [] for directory in directory_list: - assert os.path.isdir(directory), 'Not a directory: {}'.format(directory) + assert os.path.isdir(directory), "Not a directory: {}".format(directory) record = projects.get_record_from_subdir(directory) - assert record['rgb'] is not None + assert record["rgb"] is not None records.append(record) rgb = [] for record in records: - rgb.append(record['rgb']) + rgb.append(record["rgb"]) - assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, 'Flow generator inputs must be one greater ' \ - 'than feature extractor num flows ' + assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( + "Flow generator inputs must be one greater " "than feature extractor num flows " + ) # set up gpu augmentation input_images = cfg.feature_extractor.n_flows + 1 - mode = '3d' if '3d' in cfg.feature_extractor.arch.lower() else '2d' + mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" # get the validation transforms. should have resizing, etc - cpu_transform = get_cpu_transforms(cfg.augs)['val'] + cpu_transform = get_cpu_transforms(cfg.augs)["val"] gpu_transform = get_gpu_transforms(cfg.augs, mode) - log.info('gpu_transform: {}'.format(gpu_transform)) + log.info("gpu_transform: {}".format(gpu_transform)) - flow_generator_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator') + flow_generator_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") assert os.path.isfile(flow_generator_weights) - run_files = get_run_files_from_weights(flow_generator_weights, 'opticalflow') + run_files = get_run_files_from_weights(flow_generator_weights, "opticalflow") if cfg.inference.use_loaded_model_cfg: - loaded_config_file = run_files['config_file'] + loaded_config_file = run_files["config_file"] loaded_cfg = OmegaConf.load(loaded_config_file) loaded_model_cfg = loaded_cfg.flow_generator current_model_cfg = cfg.flow_generator @@ -189,49 +195,51 @@ def flow_generator_inference(cfg): # therefore, overwrite the loaded configuration with the current weights cfg.flow_generator.weights = flow_generator_weights # num_classes = len(loaded_cfg.project.class_names) - log.info('model loaded') + log.info("model loaded") # log.warning('Overwriting current project classes with loaded classes! REVERT') model = build_flow_generator(cfg) - model = utils.load_weights(model, flow_generator_weights, device='cpu') + model = utils.load_weights(model, flow_generator_weights, device="cpu") # _, _, _, _, model = model_components - device = 'cuda:{}'.format(cfg.compute.gpu_id) + device = "cuda:{}".format(cfg.compute.gpu_id) model = model.to(device) - movie_format = 'ffmpeg' + movie_format = "ffmpeg" maxval = 5 polar = True save_rgb_side_by_side = True for movie in tqdm(rgb): - out_video = os.path.splitext(movie)[0] + '_flows' - if movie_format == 'directory': + out_video = os.path.splitext(movie)[0] + "_flows" + if movie_format == "directory": pass - elif movie_format == 'hdf5': - out_video += '.h5' - elif movie_format == 'ffmpeg': - out_video += '.mp4' + elif movie_format == "hdf5": + out_video += ".h5" + elif movie_format == "ffmpeg": + out_video += ".mp4" else: - out_video += '.avi' + out_video += ".avi" if os.path.isdir(out_video): shutil.rmtree(out_video) elif os.path.isfile(out_video): os.remove(out_video) - extract_movie(movie, - out_video, - model, - device, - cpu_transform, - gpu_transform, - mean_by_channels=cfg.augs.normalization.mean, - num_workers=1, - num_rgb=input_images, - maxval=maxval, - polar=polar, - movie_format=movie_format, - save_rgb_side_by_side=save_rgb_side_by_side) - - -if __name__ == '__main__': + extract_movie( + movie, + out_video, + model, + device, + cpu_transform, + gpu_transform, + mean_by_channels=cfg.augs.normalization.mean, + num_workers=1, + num_rgb=input_images, + maxval=maxval, + polar=polar, + movie_format=movie_format, + save_rgb_side_by_side=save_rgb_side_by_side, + ) + + +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_feature_extractor_inference_cfg(project_path, use_command_line=True) flow_generator_inference(cfg) diff --git a/deepethogram/flow_generator/losses.py b/deepethogram/flow_generator/losses.py index e247cc4..9f3f141 100644 --- a/deepethogram/flow_generator/losses.py +++ b/deepethogram/flow_generator/losses.py @@ -10,9 +10,10 @@ # SSIM from this repo: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) return gauss / gauss.sum() + # SSIM from this repo: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) @@ -34,8 +35,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 + C1 = 0.01**2 + C2 = 0.03**2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) @@ -71,17 +72,18 @@ def forward(self, img1, img2): self.channel = channel similarity = _ssim(img1, img2, window, self.window_size, channel, self.size_average) - return ((1 - similarity) / self.denominator) + return (1 - similarity) / self.denominator + # PyTorch is NCHW -def gradient_x(img, mode='constant'): +def gradient_x(img, mode="constant"): # use indexing to get horizontal gradients, which chops off one column gx = img[:, :, :, :-1] - img[:, :, :, 1:] # pad the results with one zeros column on the right return F.pad(gx, (0, 1, 0, 0), mode=mode) -def gradient_y(img, mode='constant'): +def gradient_y(img, mode="constant"): # use indexing to get vertical gradients, which chops off one row gy = img[:, :, :-1, :] - img[:, :, 1:, :] # pad the result with one zeros column on bottom @@ -91,23 +93,23 @@ def gradient_y(img, mode='constant'): def get_gradients(img): gx = gradient_x(img) gy = gradient_y(img) - return (gx + gy) + return gx + gy # simpler version of ssim loss, uses average pooling instead of guassian kernels def SSIM_simple(x, y): - C1 = 0.01 ** 2 - C2 = 0.03 ** 2 + C1 = 0.01**2 + C2 = 0.03**2 mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=0) mu_y = F.avg_pool2d(y, kernel_size=3, stride=1, padding=0) - sigma_x = F.avg_pool2d(x ** 2, kernel_size=3, stride=1, padding=0) - mu_x ** 2 - sigma_y = F.avg_pool2d(y ** 2, kernel_size=3, stride=1, padding=0) - mu_y ** 2 + sigma_x = F.avg_pool2d(x**2, kernel_size=3, stride=1, padding=0) - mu_x**2 + sigma_y = F.avg_pool2d(y**2, kernel_size=3, stride=1, padding=0) - mu_y**2 sigma_xy = F.avg_pool2d(x * y, kernel_size=3, stride=1, padding=0) - mu_x * mu_y SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) - SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2) + SSIM_d = (mu_x**2 + mu_y**2 + C1) * (sigma_x + sigma_y + C2) SSIM_full = SSIM_n / SSIM_d @@ -128,10 +130,12 @@ def total_generalized_variation(image, flow): gx2_flowy = gradient_x(gradient_x(flowy)) gy2_flowy = gradient_y(gradient_y(flowy)) - TGV = torch.abs(gx2_flowx) * torch.exp(-torch.abs(gx2_image)) + \ - torch.abs(gy2_flowx) * torch.exp(-torch.abs(gy2_image)) + \ - torch.abs(gx2_flowy) * torch.exp(-torch.abs(gx2_image)) + \ - torch.abs(gy2_flowy) * torch.exp(-torch.abs(gy2_image)) + TGV = ( + torch.abs(gx2_flowx) * torch.exp(-torch.abs(gx2_image)) + + torch.abs(gy2_flowx) * torch.exp(-torch.abs(gy2_image)) + + torch.abs(gx2_flowy) * torch.exp(-torch.abs(gx2_image)) + + torch.abs(gy2_flowy) * torch.exp(-torch.abs(gy2_image)) + ) return TGV @@ -153,6 +157,7 @@ def smoothness_firstorder(image, flow): smoothness_y = torch.abs(flow_gradients_y) * weights_y return smoothness_x, smoothness_y + def charbonnier(tensor, alpha=0.4, eps=1e-4): return (tensor * tensor + eps * eps) ** alpha @@ -162,10 +167,17 @@ def charbonnier_smoothness(flows, alpha=0.3, eps=1e-7): class MotionNetLoss(torch.nn.Module): - def __init__(self, regularization_criterion, - is_multiscale=True, smooth_weights=[.01, .02, .04, .08, .16], highres: bool = False, - calculate_ssim_full: bool = False, flow_sparsity: bool = False, sparsity_weight: float = 1.0, - smooth_weight_multiplier: float = 1.0): + def __init__( + self, + regularization_criterion, + is_multiscale=True, + smooth_weights=[0.01, 0.02, 0.04, 0.08, 0.16], + highres: bool = False, + calculate_ssim_full: bool = False, + flow_sparsity: bool = False, + sparsity_weight: float = 1.0, + smooth_weight_multiplier: float = 1.0, + ): super(MotionNetLoss, self).__init__() self.smooth_weights = [i * smooth_weight_multiplier for i in smooth_weights] if highres: @@ -175,18 +187,21 @@ def __init__(self, regularization_criterion, self.calculate_ssim_full = calculate_ssim_full self.flow_sparsity = flow_sparsity self.sparsity_weight = sparsity_weight - log.info('Using MotionNet Loss with settings: smooth_weights: {} flow_sparsity: {} sparsity_weight: {}'.format( - self.smooth_weights, flow_sparsity, sparsity_weight)) + log.info( + "Using MotionNet Loss with settings: smooth_weights: {} flow_sparsity: {} sparsity_weight: {}".format( + self.smooth_weights, flow_sparsity, sparsity_weight + ) + ) self.regularization_criterion = regularization_criterion def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Module): if type(images) is not tuple: - images = (images) + images = images if type(reconstructed) is not tuple: - reconstructed = (reconstructed) + reconstructed = reconstructed if outputs[0].size(0) != images[0].size(0): - raise ValueError('Image shape: ', images[0].shape, 'Flow shape:', outputs[0].shape) + raise ValueError("Image shape: ", images[0].shape, "Flow shape:", outputs[0].shape) if self.is_multiscale: # handle validation case where you only output one scale if len(images) == 1: @@ -194,10 +209,11 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod elif len(images) == len(self.smooth_weights): weights = self.smooth_weights elif len(images) < len(self.smooth_weights): - weights = self.smooth_weights[0:len(images)] + weights = self.smooth_weights[0 : len(images)] else: - raise ValueError('Incorrect number of multiscale outputs: %d. Expected %d' - % (len(images), len(self.smooth_weights))) + raise ValueError( + "Incorrect number of multiscale outputs: %d. Expected %d" % (len(images), len(self.smooth_weights)) + ) else: weights = [1] # Components of image loss! @@ -234,11 +250,11 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod # print(images[0].shape) if H != recon_h or W != recon_w: # t0 = originals[:, 0:3, ...] - t0 = originals[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) + t0 = originals[:, : num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) # print('t0: {}'.format(t0.shape)) recon = reconstructed[0] # print('recon: {}'.format(recon.shape)) - recon_fullsize = F.interpolate(recon, size=(H, W), mode='bilinear', align_corners=False) + recon_fullsize = F.interpolate(recon, size=(H, W), mode="bilinear", align_corners=False) # t0 = originals[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) # print('t0: ', t0.shape) # recon_fullsize = F.interpolate(reconstructed[0], size=(H, W), mode='bilinear', align_corners=False) @@ -258,28 +274,34 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod smoothness_per_image = torch.stack(smooth_mean).sum(dim=0) regularization_loss = self.regularization_criterion(model) - - loss_components = {'reg_loss': regularization_loss.detach(), - 'SSIM': SSIM_per_image.detach(), - 'L1': L1_per_image.detach(), - 'smoothness': smoothness_per_image.detach(), - 'SSIM_full': SSIM_full_mean.detach() - } + + loss_components = { + "reg_loss": regularization_loss.detach(), + "SSIM": SSIM_per_image.detach(), + "L1": L1_per_image.detach(), + "smoothness": smoothness_per_image.detach(), + "SSIM_full": SSIM_full_mean.detach(), + } # mean across batch elements - loss = (torch.mean(SSIM_per_image) + torch.mean(L1_per_image) + torch.mean(smoothness_per_image) + - regularization_loss) + loss = ( + torch.mean(SSIM_per_image) + + torch.mean(L1_per_image) + + torch.mean(smoothness_per_image) + + regularization_loss + ) if loss != loss: - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() if self.flow_sparsity: # sum across scales flow_sparsity = torch.stack(flow_l1s).sum(dim=0) # mean across batch loss += torch.mean(flow_sparsity) - loss_components['flow_sparsity'] = flow_sparsity.detach() + loss_components["flow_sparsity"] = flow_sparsity.detach() del flow_l1s - del (SSIMs, SSIM_mean, L1s, L1_mean, smooths, smooth_mean) return loss, loss_components diff --git a/deepethogram/flow_generator/models/FlowNetS.py b/deepethogram/flow_generator/models/FlowNetS.py index 1f2b1f2..ee6f91b 100644 --- a/deepethogram/flow_generator/models/FlowNetS.py +++ b/deepethogram/flow_generator/models/FlowNetS.py @@ -87,7 +87,7 @@ def __init__(self, num_images=2, batchNorm=True, flow_div=1): init.uniform_(m.bias) init.xavier_uniform_(m.weight) # init_deconv_bilinear(m.weight) - self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + self.upsample1 = nn.Upsample(scale_factor=4, mode="bilinear") def forward(self, x): out_conv1 = self.conv1(x) @@ -112,26 +112,25 @@ def forward(self, x): # a value of 1 in flow6 will naively be mapped to a value of 1 in flow5. now, this movement of 1 pixel no # longer means 1/8 of the image, it will only move 1/16 of the image. So to correct for this, we multiply # the upsampled version by 2. - flow6_up = self.upsampled_flow6_to_5(flow6)*2 + flow6_up = self.upsampled_flow6_to_5(flow6) * 2 out_deconv5 = self.deconv5(out_conv6) concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) flow5 = self.predict_flow5(concat5) * self.flow_div - flow5_up = self.upsampled_flow5_to_4(flow5)*2 + flow5_up = self.upsampled_flow5_to_4(flow5) * 2 out_deconv4 = self.deconv4(concat5) concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) flow4 = self.predict_flow4(concat4) * self.flow_div - flow4_up = self.upsampled_flow4_to_3(flow4)*2 + flow4_up = self.upsampled_flow4_to_3(flow4) * 2 out_deconv3 = self.deconv3(concat4) if get_hw(out_conv3) != get_hw(out_deconv3): - out_conv3 = F.interpolate(out_conv3, size=get_hw(out_deconv3), - mode='bilinear', align_corners=False) + out_conv3 = F.interpolate(out_conv3, size=get_hw(out_deconv3), mode="bilinear", align_corners=False) concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) flow3 = self.predict_flow3(concat3) * self.flow_div - flow3_up = self.upsampled_flow3_to_2(flow3)*2 + flow3_up = self.upsampled_flow3_to_2(flow3) * 2 out_deconv2 = self.deconv2(concat3) concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) @@ -140,4 +139,4 @@ def forward(self, x): if self.training: return flow2, flow3, flow4, flow5, flow6 else: - return flow2, + return (flow2,) diff --git a/deepethogram/flow_generator/models/MotionNet.py b/deepethogram/flow_generator/models/MotionNet.py index bfed156..bcf73bc 100644 --- a/deepethogram/flow_generator/models/MotionNet.py +++ b/deepethogram/flow_generator/models/MotionNet.py @@ -32,6 +32,7 @@ log = logging.getLogger(__name__) + class MotionNet(nn.Module): def __init__(self, num_images=11, batchNorm=True, flow_div=1): super(MotionNet, self).__init__() @@ -40,7 +41,7 @@ def __init__(self, num_images=11, batchNorm=True, flow_div=1): self.out_channels = int((num_images - 1) * 2) self.batchNorm = batchNorm - log.debug('ignoring flow div value of {}: setting to 1 instead'.format(flow_div)) + log.debug("ignoring flow div value of {}: setting to 1 instead".format(flow_div)) self.flow_div = 1 self.conv1 = conv(self.batchNorm, self.num_images * 3, 64) @@ -94,7 +95,7 @@ def __init__(self, num_images=11, batchNorm=True, flow_div=1): init.uniform_(m.bias) init.xavier_uniform_(m.weight) # init_deconv_bilinear(m.weight) - self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + self.upsample1 = nn.Upsample(scale_factor=4, mode="bilinear") # print('Flow div: {}'.format(self.flow_div)) @@ -127,7 +128,7 @@ def forward(self, x): # a value of 1 in flow6 will naively be mapped to a value of 1 in flow5. now, this movement of 1 pixel no # longer means 1/8 of the image, it will only move 1/16 of the image. So to correct for this, we multiply # the upsampled version by 2. - flow6_up = self.upsampled_flow6_to_5(flow6)*2 + flow6_up = self.upsampled_flow6_to_5(flow6) * 2 out_deconv5 = self.deconv5(out_conv6) # if the image sizes are not divisible by 8, there will be rounding errors in the size @@ -141,7 +142,7 @@ def forward(self, x): out_interconv5 = self.xconv5(concat5) flow5 = self.predict_flow5(out_interconv5) * self.flow_div - flow5_up = self.upsampled_flow5_to_4(flow5)*2 + flow5_up = self.upsampled_flow5_to_4(flow5) * 2 out_deconv4 = self.deconv4(concat5) # if get_hw(out_conv4) != get_hw(out_deconv4): @@ -152,7 +153,7 @@ def forward(self, x): concat4 = self.concat((out_conv4, out_deconv4, flow5_up)) out_interconv4 = self.xconv4(concat4) flow4 = self.predict_flow4(out_interconv4) * self.flow_div - flow4_up = self.upsampled_flow4_to_3(flow4)*2 + flow4_up = self.upsampled_flow4_to_3(flow4) * 2 out_deconv3 = self.deconv3(concat4) # if the image sizes are not divisible by 8, there will be rounding errors in the size @@ -164,7 +165,7 @@ def forward(self, x): concat3 = self.concat((out_conv3, out_deconv3, flow4_up)) out_interconv3 = self.xconv3(concat3) flow3 = self.predict_flow3(out_interconv3) * self.flow_div - flow3_up = self.upsampled_flow3_to_2(flow3)*2 + flow3_up = self.upsampled_flow3_to_2(flow3) * 2 out_deconv2 = self.deconv2(concat3) # if get_hw(out_conv2) != get_hw(out_deconv2): diff --git a/deepethogram/flow_generator/models/TinyMotionNet.py b/deepethogram/flow_generator/models/TinyMotionNet.py index a2d790e..e722bfd 100644 --- a/deepethogram/flow_generator/models/TinyMotionNet.py +++ b/deepethogram/flow_generator/models/TinyMotionNet.py @@ -30,6 +30,7 @@ log = logging.getLogger(__name__) + # modified from https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetSD.py # https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/submodules.py class TinyMotionNet(nn.Module): @@ -47,7 +48,7 @@ def __init__(self, num_images=11, input_channels=None, batchNorm=True, output_ch # self.out_channels = int((num_images-1)*2) self.batchNorm = batchNorm - log.debug('ignoring flow div value of {}: setting to 1 instead'.format(flow_div)) + log.debug("ignoring flow div value of {}: setting to 1 instead".format(flow_div)) self.flow_div = 1 self.conv1 = conv(self.batchNorm, self.input_channels, 64, kernel_size=7) diff --git a/deepethogram/flow_generator/models/TinyMotionNet3D.py b/deepethogram/flow_generator/models/TinyMotionNet3D.py index 1d3b05e..776aaa4 100644 --- a/deepethogram/flow_generator/models/TinyMotionNet3D.py +++ b/deepethogram/flow_generator/models/TinyMotionNet3D.py @@ -22,9 +22,9 @@ from .components import * # import warnings + class TinyMotionNet3D(nn.Module): - def __init__(self, num_images=11, input_channels=3, batchnorm=True, flow_div=1, - channel_base=16): + def __init__(self, num_images=11, input_channels=3, batchnorm=True, flow_div=1, channel_base=16): super().__init__() self.num_images = num_images if input_channels is None: @@ -35,26 +35,39 @@ def __init__(self, num_images=11, input_channels=3, batchnorm=True, flow_div=1, # self.out_channels = int((num_images-1)*2) self.batchnorm = batchnorm bias = not self.batchnorm - logging.debug('ignoring flow div value of {}: setting to 1 instead'.format(flow_div)) + logging.debug("ignoring flow div value of {}: setting to 1 instead".format(flow_div)) self.flow_div = 1 - self.channels = [channel_base * (2 ** i) for i in range(0, 3)] + self.channels = [channel_base * (2**i) for i in range(0, 3)] print(self.channels) self.conv1 = conv3d(self.input_channels, self.channels[0], kernel_size=7, batchnorm=batchnorm, bias=bias) - self.conv2 = conv3d(self.channels[0], self.channels[1], stride=(1, 2, 2), kernel_size=5, batchnorm=batchnorm, - bias=bias) + self.conv2 = conv3d( + self.channels[0], self.channels[1], stride=(1, 2, 2), kernel_size=5, batchnorm=batchnorm, bias=bias + ) self.conv3 = conv3d(self.channels[1], self.channels[2], stride=(1, 2, 2), batchnorm=batchnorm, bias=bias) self.conv4 = conv3d(self.channels[2], self.channels[1], stride=(1, 2, 2), batchnorm=batchnorm, bias=bias) self.conv5 = conv3d(self.channels[1], self.channels[1], kernel_size=(2, 3, 3), batchnorm=batchnorm, bias=bias) - self.deconv3 = deconv3d(self.channels[1], self.channels[1], kernel_size=(1, 4, 4), stride=(1, 2, 2), - padding=(0, 1, 1), - batchnorm=batchnorm, bias=bias) - self.deconv2 = deconv3d(self.channels[1], self.channels[0], kernel_size=(1, 4, 4), stride=(1, 2, 2), - padding=(0, 1, 1), - batchnorm=batchnorm, bias=bias) + self.deconv3 = deconv3d( + self.channels[1], + self.channels[1], + kernel_size=(1, 4, 4), + stride=(1, 2, 2), + padding=(0, 1, 1), + batchnorm=batchnorm, + bias=bias, + ) + self.deconv2 = deconv3d( + self.channels[1], + self.channels[0], + kernel_size=(1, 4, 4), + stride=(1, 2, 2), + padding=(0, 1, 1), + batchnorm=batchnorm, + bias=bias, + ) self.iconv3 = conv3d(self.channels[2], self.channels[2], kernel_size=(2, 3, 3), batchnorm=batchnorm, bias=bias) self.iconv2 = conv3d(self.channels[1], self.channels[1], kernel_size=(2, 3, 3), batchnorm=batchnorm, bias=bias) @@ -118,4 +131,4 @@ def forward(self, x): flow2 = self.predict_flow2(out_interconv2) * self.flow_div # print('flow2: {}'.format(flow2.shape)) - return flow2, flow3, flow4 \ No newline at end of file + return flow2, flow3, flow4 diff --git a/deepethogram/flow_generator/models/__init__.py b/deepethogram/flow_generator/models/__init__.py index 78c9d6a..e69de29 100644 --- a/deepethogram/flow_generator/models/__init__.py +++ b/deepethogram/flow_generator/models/__init__.py @@ -1 +0,0 @@ -from . import FlowNetS, MotionNet, TinyMotionNet \ No newline at end of file diff --git a/deepethogram/flow_generator/models/components.py b/deepethogram/flow_generator/models/components.py index 19cae3d..8e1e73b 100644 --- a/deepethogram/flow_generator/models/components.py +++ b/deepethogram/flow_generator/models/components.py @@ -5,91 +5,95 @@ def conv(batchNorm: bool, in_planes: int, out_planes: int, kernel_size: int = 3, stride: int = 1, bias: bool = True): - """ Convenience function for conv2d + optional BN + leakyRELU """ + """Convenience function for conv2d + optional BN + leakyRELU""" if batchNorm: return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, - bias=bias), + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, bias=bias + ), nn.BatchNorm2d(out_planes), - nn.LeakyReLU(0.1, inplace=True) + nn.LeakyReLU(0.1, inplace=True), ) else: return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, - bias=bias), - nn.LeakyReLU(0.1, inplace=True) + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, bias=bias + ), + nn.LeakyReLU(0.1, inplace=True), ) def crop_like(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ Crops input to target's H,W """ + """Crops input to target's H,W""" if input.size()[2:] == target.size()[2:]: return input else: - return input[:, :, :target.size(2), :target.size(3)] + return input[:, :, : target.size(2), : target.size(3)] def deconv(in_planes: int, out_planes: int, bias: bool = True): - """ Convenience function for ConvTranspose2d + leakyRELU """ + """Convenience function for ConvTranspose2d + leakyRELU""" return nn.Sequential( nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=bias), - nn.LeakyReLU(0.1, inplace=True) + nn.LeakyReLU(0.1, inplace=True), ) class Interpolate(nn.Module): - """ Wrapper to be able to perform interpolation in a nn.Sequential + """Wrapper to be able to perform interpolation in a nn.Sequential Modified from the PyTorch Forums: https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588/2 """ - def __init__(self, size=None, scale_factor=None, mode: str = 'bilinear'): + def __init__(self, size=None, scale_factor=None, mode: str = "bilinear"): super(Interpolate, self).__init__() self.interp = nn.functional.interpolate self.size = size self.scale_factor = scale_factor - assert mode in ['nearest', 'linear', 'bilinear', 'bicubic', 'trilinear', 'area'] + assert mode in ["nearest", "linear", "bilinear", "bicubic", "trilinear", "area"] self.mode = mode - if self.mode == 'nearest': + if self.mode == "nearest": self.align_corners = None else: self.align_corners = False def forward(self, x): - x = self.interp(x, size=self.size, scale_factor=self.scale_factor, - mode=self.mode, align_corners=self.align_corners) + x = self.interp( + x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) return x def i_conv(batchNorm: bool, in_planes: int, out_planes: int, kernel_size: int = 3, stride: int = 1, bias: bool = True): - """ Convenience function for conv2d + optional BN + no activation """ + """Convenience function for conv2d + optional BN + no activation""" if batchNorm: return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, - bias=bias), + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, bias=bias + ), nn.BatchNorm2d(out_planes), ) else: return nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, - bias=bias), + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, bias=bias + ), ) def predict_flow(in_planes: int, out_planes: int = 2, bias: bool = False): - """ Convenience function for 3x3 conv2d with same padding """ + """Convenience function for 3x3 conv2d with same padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=bias) def get_hw(tensor): - """ Convenience function for getting the size of the last two dimensions in a tensor """ + """Convenience function for getting the size of the last two dimensions in a tensor""" return tensor.size(-2), tensor.size(-1) class CropConcat(nn.Module): - """ Module for concatenating 2 tensors of slightly different shape. - """ + """Module for concatenating 2 tensors of slightly different shape.""" def __init__(self, dim: int = 1): super().__init__() @@ -103,9 +107,17 @@ def forward(self, tensors: tuple) -> torch.Tensor: return torch.cat(tuple([tensor[..., :h, :w] for tensor in tensors]), dim=self.dim) -def conv3d(in_planes: int, out_planes: int, kernel_size: Union[int, tuple] = 3, stride: Union[int, tuple] = 1, - bias: bool = True, batchnorm: bool = True, act: bool = True, padding: tuple = None): - """ 3D convolution +def conv3d( + in_planes: int, + out_planes: int, + kernel_size: Union[int, tuple] = 3, + stride: Union[int, tuple] = 1, + bias: bool = True, + batchnorm: bool = True, + act: bool = True, + padding: tuple = None, +): + """3D convolution Expects inputs of shape N, C, D/F/T, H, W. D/F/T is frames, depth, time-- the extra axis compared to 2D convolution. @@ -140,10 +152,9 @@ def conv3d(in_planes: int, out_planes: int, kernel_size: Union[int, tuple] = 3, elif padding is None and type(kernel_size) == tuple: padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2, (kernel_size[2] - 1) // 2) else: - raise ValueError('Unknown padding type {} and kernel_size type: {}'.format(padding, kernel_size)) + raise ValueError("Unknown padding type {} and kernel_size type: {}".format(padding, kernel_size)) - modules.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, - bias=bias)) + modules.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) if batchnorm: modules.append(nn.BatchNorm3d(out_planes)) if act: @@ -151,11 +162,20 @@ def conv3d(in_planes: int, out_planes: int, kernel_size: Union[int, tuple] = 3, return nn.Sequential(*modules) -def deconv3d(in_planes: int, out_planes: int, kernel_size: int = 4, stride: int = 2, bias: bool = True, - batchnorm: bool = True, act: bool = True, padding: int = 1): - """ Convenience function for ConvTranspose3D. Optionally adds batchnorm3d, leakyrelu """ - modules = [nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, - bias=bias, padding=padding)] +def deconv3d( + in_planes: int, + out_planes: int, + kernel_size: int = 4, + stride: int = 2, + bias: bool = True, + batchnorm: bool = True, + act: bool = True, + padding: int = 1, +): + """Convenience function for ConvTranspose3D. Optionally adds batchnorm3d, leakyrelu""" + modules = [ + nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding) + ] if batchnorm: modules.append(nn.BatchNorm3d(out_planes)) if act: @@ -164,5 +184,5 @@ def deconv3d(in_planes: int, out_planes: int, kernel_size: int = 4, stride: int def predict_flow_3d(in_planes: int, out_planes: int): - """ Convenience function for conv3d, 3x3, no activation or batchnorm """ + """Convenience function for conv3d, 3x3, no activation or batchnorm""" return conv3d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, act=False, batchnorm=False) diff --git a/deepethogram/flow_generator/train.py b/deepethogram/flow_generator/train.py index 18c81c8..609214e 100644 --- a/deepethogram/flow_generator/train.py +++ b/deepethogram/flow_generator/train.py @@ -24,20 +24,21 @@ from deepethogram.stoppers import get_stopper warnings.filterwarnings( - 'ignore', + "ignore", category=UserWarning, - message='Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation ' - 'and test dataloaders.') + message="Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation " + "and test dataloaders.", +) flow_generators = utils.get_models_from_module(models, get_function=False) -plt.switch_backend('agg') +plt.switch_backend("agg") log = logging.getLogger(__name__) def flow_generator_train(cfg: DictConfig) -> nn.Module: - """Trains flow generator models from a configuration. + """Trains flow generator models from a configuration. Parameters ---------- @@ -50,7 +51,7 @@ def flow_generator_train(cfg: DictConfig) -> nn.Module: Trained flow generator """ cfg = projects.setup_run(cfg) - log.info('args: {}'.format(' '.join(sys.argv))) + log.info("args: {}".format(" ".join(sys.argv))) # only two custom overwrites of the configuration file # allow for editing @@ -58,17 +59,17 @@ def flow_generator_train(cfg: DictConfig) -> nn.Module: # second, use the model directory to find the most recent run of each model type # cfg = projects.overwrite_cfg_with_latest_weights(cfg, cfg.project.model_path, model_type='flow_generator') # SHOULD NEVER MODIFY / MAKE ASSIGNMENTS TO THE CFG OBJECT AFTER RIGHT HERE! - log.info('configuration used ~~~~~') + log.info("configuration used ~~~~~") log.info(OmegaConf.to_yaml(cfg)) - datasets, data_info = get_datasets_from_cfg(cfg, 'flow_generator', input_images=cfg.flow_generator.n_rgb) + datasets, data_info = get_datasets_from_cfg(cfg, "flow_generator", input_images=cfg.flow_generator.n_rgb) flow_generator = build_model_from_cfg(cfg) - log.info('Total trainable params: {:,}'.format(utils.get_num_parameters(flow_generator))) - utils.save_dict_to_yaml(data_info['split'], os.path.join(os.getcwd(), 'split.yaml')) - flow_weights = deepethogram.projects.get_weightfile_from_cfg(cfg, 'flow_generator') + log.info("Total trainable params: {:,}".format(utils.get_num_parameters(flow_generator))) + utils.save_dict_to_yaml(data_info["split"], os.path.join(os.getcwd(), "split.yaml")) + flow_weights = deepethogram.projects.get_weightfile_from_cfg(cfg, "flow_generator") if flow_weights is not None: - print('reloading weights...') - flow_generator = utils.load_weights(flow_generator, flow_weights, device='cpu') + print("reloading weights...") + flow_generator = utils.load_weights(flow_generator, flow_weights, device="cpu") stopper = get_stopper(cfg) metrics = get_metrics(cfg, os.getcwd(), utils.get_num_parameters(flow_generator)) @@ -91,14 +92,15 @@ def build_model_from_cfg(cfg: DictConfig) -> Type[nn.Module]: nn.Module flow generator """ - flow_generator = flow_generators[cfg.flow_generator.arch](num_images=cfg.flow_generator.n_rgb, - flow_div=cfg.flow_generator.max) + flow_generator = flow_generators[cfg.flow_generator.arch]( + num_images=cfg.flow_generator.n_rgb, flow_div=cfg.flow_generator.max + ) return flow_generator class OpticalFlowLightning(BaseLightningModule): - """Lightning Module for training Optic Flow generator models - """ + """Lightning Module for training Optic Flow generator models""" + def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, visualization_func): """constructor @@ -127,18 +129,17 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, v self.viz_cnt = None def validate_batch_size(self, batch: dict): - """simple wrapper to make sure our batch has the right input shape - """ + """simple wrapper to make sure our batch has the right input shape""" if self.hparams.compute.dali: # no idea why they wrap this, maybe they fixed it? batch = batch[0] - if 'images' in batch.keys(): + if "images" in batch.keys(): # weird case of batch size = 1 somehow getting squeezed out - if batch['images'].ndim != 5: - batch['images'] = batch['images'].unsqueeze(0) - if 'labels' in batch.keys(): - if self.final_activation == 'sigmoid' and batch['labels'].ndim == 1: - batch['labels'] = batch['labels'].unsqueeze(0) + if batch["images"].ndim != 5: + batch["images"] = batch["images"].unsqueeze(0) + if "labels" in batch.keys(): + if self.final_activation == "sigmoid" and batch["labels"].ndim == 1: + batch["labels"] = batch["labels"].unsqueeze(0) return batch def common_step(self, batch: dict, batch_idx: int, split: str): @@ -166,28 +167,34 @@ def common_step(self, batch: dict, batch_idx: int, split: str): self.visualize_batch(images, downsampled_t0, estimated_t0, flows_reshaped, split) to_log = loss_components - to_log['loss'] = loss.detach() + to_log["loss"] = loss.detach() self.metrics.buffer.append(split, to_log) # need to use the native logger for lr scheduling, etc. key_metric = self.metrics.key_metric - self.log(f'{split}_loss', loss) - if split == 'val': - self.log(f'{split}_{key_metric}', loss_components[key_metric].mean()) + self.log(f"{split}_loss", loss) + if split == "val": + self.log(f"{split}_{key_metric}", loss_components[key_metric].mean()) return loss def training_step(self, batch: dict, batch_idx: int): - return self.common_step(batch, batch_idx, 'train') + return self.common_step(batch, batch_idx, "train") def validation_step(self, batch: dict, batch_idx: int): - return self.common_step(batch, batch_idx, 'val') + return self.common_step(batch, batch_idx, "val") def test_step(self, batch: dict, batch_idx: int): - images, outputs = self(batch, 'test') - - def visualize_batch(self, images: torch.Tensor, downsampled_t0: torch.Tensor, estimated_t0: torch.Tensor, - flows_reshaped: torch.Tensor, split: str): + images, outputs = self(batch, "test") + + def visualize_batch( + self, + images: torch.Tensor, + downsampled_t0: torch.Tensor, + estimated_t0: torch.Tensor, + flows_reshaped: torch.Tensor, + split: str, + ): """visualizes a batch of inputs and saves as a matplotlib figure PNG to disk Parameters @@ -213,35 +220,41 @@ def visualize_batch(self, images: torch.Tensor, downsampled_t0: torch.Tensor, es batch_ind = np.random.choice(images.shape[0]) sequence_length = int(downsampled_t0[0].shape[0] / images.shape[0]) - viz.visualize_images_and_flows(downsampled_t0, - flows_reshaped, - sequence_length, - batch_ind=batch_ind, - fig=fig, - max_flow=self.hparams.flow_generator.max) - viz.save_figure(fig, 'batch', True, viz_cnt, split) + viz.visualize_images_and_flows( + downsampled_t0, + flows_reshaped, + sequence_length, + batch_ind=batch_ind, + fig=fig, + max_flow=self.hparams.flow_generator.max, + ) + viz.save_figure(fig, "batch", True, viz_cnt, split) fig = plt.figure(figsize=(14, 14)) sequence_ind = np.random.choice(sequence_length - 1) - viz.visualize_multiresolution(downsampled_t0, - estimated_t0, - flows_reshaped, - sequence_length, - max_flow=self.hparams.flow_generator.max, - sequence_ind=sequence_ind, - batch_ind=batch_ind, - fig=fig) - viz.save_figure(fig, 'multiresolution', True, viz_cnt, split) + viz.visualize_multiresolution( + downsampled_t0, + estimated_t0, + flows_reshaped, + sequence_length, + max_flow=self.hparams.flow_generator.max, + sequence_ind=sequence_ind, + batch_ind=batch_ind, + fig=fig, + ) + viz.save_figure(fig, "multiresolution", True, viz_cnt, split) fig = plt.figure(figsize=(14, 14)) - viz.visualize_batch_unsupervised(downsampled_t0, - estimated_t0, - flows_reshaped, - batch_ind=batch_ind, - sequence_ind=sequence_ind, - fig=fig, - sequence_length=sequence_length) - viz.save_figure(fig, 'reconstruction', True, viz_cnt, split) + viz.visualize_batch_unsupervised( + downsampled_t0, + estimated_t0, + flows_reshaped, + batch_ind=batch_ind, + sequence_ind=sequence_ind, + fig=fig, + sequence_length=sequence_length, + ) + viz.save_figure(fig, "reconstruction", True, viz_cnt, split) def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, list]: """actually compute optic flow @@ -262,7 +275,7 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, list]: """ batch = self.validate_batch_size(batch) # lightning handles transfer to device - images = batch['images'] + images = batch["images"] images = self.apply_gpu_transforms(images, mode) outputs = self.model(images) @@ -271,34 +284,33 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, list]: return images, outputs def log_image_statistics(self, images): - """convenience method for logging image shape and channel statistics - """ + """convenience method for logging image shape and channel statistics""" if not self.has_logged_channels and log.isEnabledFor(logging.DEBUG): if len(images.shape) == 4: N, C, H, W = images.shape - log.debug('inputs shape: NCHW: {} {} {} {}'.format(N, C, H, W)) - log.debug('channel min: {}'.format(images[0].reshape(C, -1).min(dim=1).values)) - log.debug('channel mean: {}'.format(images[0].reshape(C, -1).mean(dim=1))) - log.debug('channel max : {}'.format(images[0].reshape(C, -1).max(dim=1).values)) - log.debug('channel std : {}'.format(images[0].reshape(C, -1).std(dim=1))) + log.debug("inputs shape: NCHW: {} {} {} {}".format(N, C, H, W)) + log.debug("channel min: {}".format(images[0].reshape(C, -1).min(dim=1).values)) + log.debug("channel mean: {}".format(images[0].reshape(C, -1).mean(dim=1))) + log.debug("channel max : {}".format(images[0].reshape(C, -1).max(dim=1).values)) + log.debug("channel std : {}".format(images[0].reshape(C, -1).std(dim=1))) elif len(images.shape) == 5: N, C, T, H, W = images.shape - log.debug('inputs shape: NCTHW: {} {} {} {} {}'.format(N, C, T, H, W)) - log.debug('channel min: {}'.format(images[0].min(dim=2).values)) - log.debug('channel mean: {}'.format(images[0].mean(dim=2))) - log.debug('channel max : {}'.format(images[0].max(dim=2).values)) - log.debug('channel std : {}'.format(images[0].std(dim=2))) + log.debug("inputs shape: NCTHW: {} {} {} {} {}".format(N, C, T, H, W)) + log.debug("channel min: {}".format(images[0].min(dim=2).values)) + log.debug("channel mean: {}".format(images[0].mean(dim=2))) + log.debug("channel max : {}".format(images[0].max(dim=2).values)) + log.debug("channel std : {}".format(images[0].std(dim=2))) self.has_logged_channels = True def log_model_statistics(self, images, outputs, labels): # will print out shape and min, mean, max, std along image channels # we use the isEnabledFor flag so that this doesnt slow down training in the non-debug case - log.debug('outputs: {}'.format(outputs)) - log.debug('labels: {}'.format(labels)) - log.debug('outputs: {}'.format(outputs.shape)) - log.debug('labels: {}'.format(labels.shape)) - log.debug('label max: {}'.format(labels.max())) - log.debug('label min: {}'.format(labels.min())) + log.debug("outputs: {}".format(outputs)) + log.debug("labels: {}".format(labels)) + log.debug("outputs: {}".format(outputs.shape)) + log.debug("labels: {}".format(labels.shape)) + log.debug("label max: {}".format(labels.max())) + log.debug("label min: {}".format(labels.min())) def get_criterion(cfg, model): @@ -323,7 +335,7 @@ def get_criterion(cfg, model): """ regularization_criterion = get_regularization_loss(cfg, model) - if cfg.flow_generator.loss == 'MotionNet': + if cfg.flow_generator.loss == "MotionNet": criterion = MotionNetLoss( regularization_criterion, flow_sparsity=cfg.flow_generator.flow_sparsity, @@ -353,20 +365,20 @@ def get_metrics(cfg: DictConfig, rundir: Union[str, bytes, os.PathLike], num_par metrics: Metrics metrics object. see deepethogram.metrics.py """ - metrics_list = ['SSIM', 'L1', 'smoothness', 'SSIM_full'] + metrics_list = ["SSIM", "L1", "smoothness", "SSIM_full"] if cfg.flow_generator.flow_sparsity: - metrics_list.append('flow_sparsity') - if cfg.flow_generator.loss == 'SelfSupervised': - metrics_list.append('gradient') - metrics_list.append('MFH') - key_metric = 'SSIM' - log.info('key metric is {}'.format(key_metric)) + metrics_list.append("flow_sparsity") + if cfg.flow_generator.loss == "SelfSupervised": + metrics_list.append("gradient") + metrics_list.append("MFH") + key_metric = "SSIM" + log.info("key metric is {}".format(key_metric)) # the metrics objects all take normal dicts instead of dict configs metrics = OpticalFlow(rundir, key_metric, num_parameters) return metrics -if __name__ == '__main__': +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_flow_generator_train_cfg(project_path, use_command_line=True) diff --git a/deepethogram/flow_generator/utils.py b/deepethogram/flow_generator/utils.py index 34a4425..98183c0 100644 --- a/deepethogram/flow_generator/utils.py +++ b/deepethogram/flow_generator/utils.py @@ -11,7 +11,7 @@ def flow_to_rgb(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.ndarray: - """ Convert optic flow to RGB by linearly mapping X to red and Y to green + """Convert optic flow to RGB by linearly mapping X to red and Y to green 255 in the resulting image will correspond to `maxval`. 0 corresponds to -`maxval`. Parameters @@ -27,7 +27,7 @@ def flow_to_rgb(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.ndarray: RGB image """ H, W, C = flow.shape - assert (C == 2) + assert C == 2 flow = (flow + maxval) * (255 / 2 / maxval) flow = flow.clip(min=0, max=255) flow_map = np.ones((H, W, 3), dtype=np.uint8) * 127 @@ -40,7 +40,7 @@ def flow_to_rgb(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.ndarray: def rgb_to_flow(image: np.ndarray, maxval: Union[int, float] = 20): - """ Converts an RGB image to an optic flow by linearly mapping R -> X and G -> Y. Opposite of `flow_to_rgb` + """Converts an RGB image to an optic flow by linearly mapping R -> X and G -> Y. Opposite of `flow_to_rgb` Parameters ---------- @@ -59,14 +59,14 @@ def denormalize(arr: np.ndarray): return arr H, W, C = image.shape - assert (C == 3) + assert C == 3 image = image.astype(np.float32) image = denormalize(image) return image[..., 0:2] def flow_to_rgb_polar(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.ndarray: - """ Converts flow to RGB by mapping angle -> hue and magnitude -> saturation. + """Converts flow to RGB by mapping angle -> hue and magnitude -> saturation. Converts the flow map to polar coordinates: dX, dY -> angle, magnitude. Uses a HSV representation: Hue = angle, saturation = magnitude, value = 1 @@ -103,7 +103,7 @@ def flow_to_rgb_polar(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.nd def rgb_to_flow_polar(image: np.ndarray, maxval: Union[int, float] = 20): - """ Converts rgb to flow by mapping hue -> angle and saturation -> magnitude. + """Converts rgb to flow by mapping hue -> angle and saturation -> magnitude. Inverse of `flow_to_rgb_polar` Parameters @@ -151,7 +151,7 @@ def rgb_to_flow_polar(image: np.ndarray, maxval: Union[int, float] = 20): class Resample2d(torch.nn.Module): - """ Module to sample tensors using Spatial Transformer Networks. Caches multiple grids in GPU VRAM for speed. + """Module to sample tensors using Spatial Transformer Networks. Caches multiple grids in GPU VRAM for speed. Examples ------- @@ -166,10 +166,15 @@ class Resample2d(torch.nn.Module): disparity = model(left_images, right_images) """ - def __init__(self, size: Union[tuple, list] = None, fp16: bool = False, device: Union[str, torch.device] = None, - horiz_only: bool = False, - num_grids: int = 5): - """ Constructor for resampler. + def __init__( + self, + size: Union[tuple, list] = None, + fp16: bool = False, + device: Union[str, torch.device] = None, + horiz_only: bool = False, + num_grids: int = 5, + ): + """Constructor for resampler. Parameters ---------- @@ -196,7 +201,7 @@ def __init__(self, size: Union[tuple, list] = None, fp16: bool = False, device: """ super().__init__() if size is not None: - assert (type(size) == tuple or type(size) == list) + assert type(size) == tuple or type(size) == list self.size = size # identity matrix @@ -213,7 +218,7 @@ def __init__(self, size: Union[tuple, list] = None, fp16: bool = False, device: self.uses = [] def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: - """ resample `images` according to `flow` + """resample `images` according to `flow` Parameters ---------- @@ -235,7 +240,7 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: # images: NxCxHxW # flow: Bx2xHxW grid_size = [flow.size(0), 2, flow.size(2), flow.size(3)] - if not hasattr(self, 'grids') or grid_size not in self.sizes: + if not hasattr(self, "grids") or grid_size not in self.sizes: if len(self.sizes) >= self.num_grids: min_uses = min(self.uses) min_loc = self.uses.index(min_uses) @@ -246,8 +251,9 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: # function outputs N,H,W,2. Permuted to N,2,H,W to match flow # 0-th channel is x sample locations, -1 in left column, 1 in right column # 1-th channel is y sample locations, -1 in first row, 1 in bottom row - this_grid = F.affine_grid(self.affine_mat, images.shape, align_corners=False).permute(0, 3, 1, 2).to( - self.device) + this_grid = ( + F.affine_grid(self.affine_mat, images.shape, align_corners=False).permute(0, 3, 1, 2).to(self.device) + ) this_size = [i for i in this_grid.size()] self.sizes.append(this_size) self.grids.append(this_grid) @@ -264,22 +270,27 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: # horiz_only: for stereo matching, Y values are always the same if self.horiz_only: # flow = flow[:, 0:1, :, :] / ((W - 1.0) / 2.0) - flow = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), - torch.zeros((flow.size(0), flow.size(1), H, W))], 1) + flow = torch.cat( + [flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), torch.zeros((flow.size(0), flow.size(1), H, W))], 1 + ) else: # for optic flow matching: can be displaced in X or Y - flow = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), - flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + flow = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) # sample according to grid + flow - return F.grid_sample(input=images, grid=(this_grid + flow).permute(0, 2, 3, 1), - mode='bilinear', padding_mode='border', align_corners=False) + return F.grid_sample( + input=images, + grid=(this_grid + flow).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="border", + align_corners=False, + ) class Reconstructor: def __init__(self, cfg: DictConfig): device = torch.device("cuda:" + str(cfg.compute.gpu_id) if torch.cuda.is_available() else "cpu") self.resampler = Resample2d(device=device, fp16=cfg.compute.fp16) - if 'normalization' in list(cfg.augs.keys()): + if "normalization" in list(cfg.augs.keys()): mean = list(cfg.augs.normalization.mean) std = list(cfg.augs.normalization.std) else: @@ -295,7 +306,7 @@ def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list if image_batch.ndim == 4: N, C, H, W = image_batch.shape num_images = int(C / 3) - 1 - t0 = image_batch[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) + t0 = image_batch[:, : num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) t1 = image_batch[:, 3:, ...].contiguous().view(N * num_images, 3, H, W) elif image_batch.ndim == 5: N, C, T, H, W = image_batch.shape @@ -305,7 +316,7 @@ def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list t1 = image_batch[:, :, 1:, ...] t1 = t1.transpose(1, 2).reshape(N * num_images, C, H, W) else: - raise ValueError('unexpected batch shape: {}'.format(image_batch)) + raise ValueError("unexpected batch shape: {}".format(image_batch)) reconstructed = [] t1s = [] @@ -320,8 +331,8 @@ def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list n, c, t, h, w = flow.shape flow = flow.transpose(1, 2).reshape(n * t, c, h, w) - downsampled_t1 = F.interpolate(t1, (h, w), mode='bilinear', align_corners=False) - downsampled_t0 = F.interpolate(t0, (h, w), mode='bilinear', align_corners=False) + downsampled_t1 = F.interpolate(t1, (h, w), mode="bilinear", align_corners=False) + downsampled_t0 = F.interpolate(t0, (h, w), mode="bilinear", align_corners=False) t0s.append(downsampled_t0) t1s.append(downsampled_t1) reconstructed.append(self.resampler(downsampled_t1, flow)) @@ -336,10 +347,10 @@ def __call__(self, image_batch: torch.Tensor, flows: Union[tuple, list]) -> Tupl def stacked_to_sequence(tensor: torch.Tensor, num_channels: int = 3) -> torch.Tensor: if tensor.ndim > 4: - warnings.warn('called stacked_to_sequence on a sequence of shape {}'.format(tensor.shape)) + warnings.warn("called stacked_to_sequence on a sequence of shape {}".format(tensor.shape)) return tensor N, C, H, W = tensor.shape - assert ((C % num_channels) == 0) + assert (C % num_channels) == 0 num_channels = 3 starts = range(0, C, num_channels) ends = range(num_channels, C + 1, num_channels) @@ -347,7 +358,7 @@ def stacked_to_sequence(tensor: torch.Tensor, num_channels: int = 3) -> torch.Te def rgb_to_hsv_torch(image: torch.Tensor) -> torch.Tensor: - """ PyTorch implementation of RGB to HSV color conversion """ + """PyTorch implementation of RGB to HSV color conversion""" # https://torchgeometry.readthedocs.io/en/latest/_modules/kornia/color/hsv.html#rgb_to_hsv # https://en.wikipedia.org/wiki/HSL_and_HSV#General_approach # https://stackoverflow.com/questions/3018313/algorithm-to-convert-rgb-to-hsv-and-hsv-to-rgb-in-range-0-255-for-both diff --git a/deepethogram/gui/custom_widgets.py b/deepethogram/gui/custom_widgets.py index 4c41f6d..bc44c45 100644 --- a/deepethogram/gui/custom_widgets.py +++ b/deepethogram/gui/custom_widgets.py @@ -11,6 +11,7 @@ from PySide2.QtCore import Signal, Slot from deepethogram.file_io import VideoReader + # these define the parameters of the deepethogram colormap below from deepethogram.viz import Mapper @@ -26,10 +27,10 @@ def numpy_to_qpixmap(image: np.ndarray) -> QtGui.QPixmap: elif C == 3: format = QtGui.QImage.Format_RGB888 else: - raise ValueError('Aberrant number of channels: {}'.format(C)) + raise ValueError("Aberrant number of channels: {}".format(C)) qpixmap = QtGui.QPixmap(QtGui.QImage(image, W, H, image.strides[0], format)) # print(type(qpixmap)) - return (qpixmap) + return qpixmap def float_to_uint8(image: np.ndarray) -> np.ndarray: @@ -40,7 +41,7 @@ def float_to_uint8(image: np.ndarray) -> np.ndarray: def initializer(nframes: int): - print('initialized with {}'.format(nframes)) + print("initialized with {}".format(nframes)) class VideoFrame(QtWidgets.QGraphicsView): @@ -75,7 +76,7 @@ def __init__(self, videoFile: Union[str, os.PathLike] = None, *args, **kwargs): # print(self.palette()) def initialize_video(self, videofile: Union[str, os.PathLike]): - if hasattr(self, 'vid'): + if hasattr(self, "vid"): self.vid.close() # if hasattr(self.vid, 'cap'): # self.vid.cap.release() @@ -100,7 +101,7 @@ def mousePressEvent(self, event): super().mousePressEvent(event) def wheelEvent(self, event): - if hasattr(self, 'vid'): + if hasattr(self, "vid"): if event.angleDelta().y() > 0: factor = 1.25 self._zoom += 1 @@ -126,10 +127,10 @@ def update_frame(self, value, force: bool = False): # print('update to: {}'.format(value)) # print(self.current_fnum) # previous_frame = self.current_fnum - if not hasattr(self, 'vid'): + if not hasattr(self, "vid"): return value = int(value) - if hasattr(self, 'current_fnum'): + if hasattr(self, "current_fnum"): if self.current_fnum == value and not force: # print('already there') return @@ -165,9 +166,9 @@ def fitInView(self, scale=True): self._zoom = 0 def adjust_aspect_ratio(self): - if not hasattr(self, 'vid'): - raise ValueError('Trying to set GraphicsView aspect ratio before video loaded.') - if not hasattr(self.vid, 'width'): + if not hasattr(self, "vid"): + raise ValueError("Trying to set GraphicsView aspect ratio before video loaded.") + if not hasattr(self.vid, "width"): self.vid.width, self.vid.height = self.frame.shape[1], self.frame.shape[0] video_aspect = self.vid.width / self.vid.height H, W = self.height(), self.width() @@ -186,7 +187,7 @@ def show_image(self, array): # self.show() def resizeEvent(self, event): - if hasattr(self, 'vid'): + if hasattr(self, "vid"): pass # self.fitInView() @@ -263,8 +264,8 @@ def scrollbar_change(self): @Slot(int) def update_state(self, value: int): - if self.plainTextEdit.document().toPlainText() != '{}'.format(value): - self.plainTextEdit.setPlainText('{}'.format(value)) + if self.plainTextEdit.document().toPlainText() != "{}".format(value): + self.plainTextEdit.setPlainText("{}".format(value)) if self.horizontalScrollBar.value() != value: self.horizontalScrollBar.setValue(value) @@ -277,7 +278,7 @@ def initialize_state(self, value: int): # self.horizontalScrollBar.sliderMoved.connect(self.scrollbar_change) # self.horizontalScrollBar.valueChanged.connect(self.scrollbar_change) self.horizontalScrollBar.setValue(0) - self.plainTextEdit.setPlainText('{}'.format(0)) + self.plainTextEdit.setPlainText("{}".format(0)) # self.plainTextEdit.textChanged.connect(self.text_change) # self.update() @@ -313,7 +314,7 @@ def __init__(self, parent=None, videoFile: Union[str, os.PathLike] = None, *args self.videoView.frameNum.connect(self.scrollbartext.update_state) # I have to do this here because I think emitting a signal doesn't work from within the widget's constructor - if hasattr(self.videoView, 'vid'): + if hasattr(self.videoView, "vid"): self.videoView.initialized.emit(len(self.videoView.vid)) self.update() @@ -381,18 +382,20 @@ def __init__(self, fixed: bool = False, *args, **kwargs): if self.fixed: self.fixed_settings() - def initialize(self, - n: int = 1, - n_timepoints: int = 31, - debug: bool = False, - colormap: str = 'Reds', - unlabeled_alpha: float = 0.1, - desired_pixel_size: int = 25, - array: np.ndarray = None, - fixed: bool = False, - opacity: np.ndarray = None): + def initialize( + self, + n: int = 1, + n_timepoints: int = 31, + debug: bool = False, + colormap: str = "Reds", + unlabeled_alpha: float = 0.1, + desired_pixel_size: int = 25, + array: np.ndarray = None, + fixed: bool = False, + opacity: np.ndarray = None, + ): if self.initialized: - raise ValueError('only initialize once!') + raise ValueError("only initialize once!") if array is not None: # print(array.shape) self.n_timepoints = array.shape[0] @@ -416,7 +419,7 @@ def initialize(self, try: self.cmap = Mapper(colormap) except ValueError: - raise ('Colormap not in matplotlib' 's defaults! {}'.format(colormap)) + raise ("Colormap not in matplotlib" "s defaults! {}".format(colormap)) if self.debug: self.make_debug() @@ -481,14 +484,14 @@ def mouseMoveEvent(self, event): super().mouseMoveEvent(event) def change_rectangle(self, rect): - if not hasattr(self, 'item_rect'): + if not hasattr(self, "item_rect"): return self.item_rect.setRect(rect) def _fit_label_photo(self): - if not hasattr(self, 'x'): + if not hasattr(self, "x"): self.x = 0 - if not hasattr(self, 'view_x'): + if not hasattr(self, "view_x"): self.view_x = 0 # gets the bounding rectangle (in pixels) for the image of the label array geometry = self.geometry() @@ -518,9 +521,9 @@ def change_view_x(self, x: int): if x < 0 or x >= self.n_timepoints: # print('return 1') return - if not hasattr(self, 'view_width'): + if not hasattr(self, "view_width"): self._fit_label_photo() - if not hasattr(self, 'n'): + if not hasattr(self, "n"): # print('return 2') return @@ -564,7 +567,7 @@ def change_view_x(self, x: int): # self.show() def fixed_settings(self): - if not hasattr(self, 'changed'): + if not hasattr(self, "changed"): return self.changed = np.ones(self.changed.shape) self.recreate_label_image() @@ -573,17 +576,17 @@ def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fe # print('adding') if self.fixed: return - if not hasattr(self, 'array'): + if not hasattr(self, "array"): return n_behaviors = self.image.shape[0] if type(behaviors) != np.ndarray: behaviors = np.array(behaviors) if max(behaviors) > n_behaviors: - raise ValueError('Not enough behaviors for number: {}'.format(behaviors)) + raise ValueError("Not enough behaviors for number: {}".format(behaviors)) if fstart < 0: - raise ValueError('Behavior start frame must be > 0: {}'.format(fstart)) + raise ValueError("Behavior start frame must be > 0: {}".format(fstart)) if fend > self.n_timepoints: - raise ValueError('Behavior end frame must be < nframes: {}'.format(fend)) + raise ValueError("Behavior end frame must be < nframes: {}".format(fend)) # log.debug('Behaviors: {} fstart: {} fend: {}'.format(behaviors, fstart, fend)) # go backwards to erase if fstart <= fend: @@ -611,10 +614,11 @@ def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fe # print('l shape: {}'.format(self.image[1:, time_indices, :].shape)) # print('r_shape: {}'.format(np.tile(self.neg_color[1:], [1, len(time_indices), 1]).shape)) self.image[0, time_indices, :] = self.pos_color[0] - self.image[1:, time_indices, :] = np.dstack([self.neg_color[1:] for _ in range(len(time_indices)) - ]).swapaxes(1, 2) + self.image[1:, time_indices, :] = np.dstack( + [self.neg_color[1:] for _ in range(len(time_indices))] + ).swapaxes(1, 2) else: - xv, yv = np.meshgrid(time_indices, behaviors, indexing='ij') + xv, yv = np.meshgrid(time_indices, behaviors, indexing="ij") xv = xv.flatten() yv = yv.flatten() # log.debug('xv: {} yv: {}'.format(xv, yv)) @@ -647,7 +651,7 @@ def change_view_dx(self, dx: int): def _array_to_image(self, array: np.ndarray, alpha: Union[float, int, np.ndarray] = None): image = self.cmap(array.T * 255) image[..., 3] = alpha - return (image) + return image def _add_row(self): self.array = np.concatenate((self.array, np.zeros((self.array.shape[0], 1), dtype=self.array.dtype)), axis=1) @@ -662,7 +666,7 @@ def _add_row(self): self._fit_label_photo() def _change_n_timepoints(self, n_timepoints: int): - warnings.warn('Changing number of timepoints will erase any labels!') + warnings.warn("Changing number of timepoints will erase any labels!") self.array = np.zeros((n_timepoints, self.n), dtype=np.uint8) self.changed = np.zeros((n_timepoints,), dtype=np.uint8) self.n_timepoints = n_timepoints @@ -670,20 +674,20 @@ def _change_n_timepoints(self, n_timepoints: int): self.image = self._array_to_image(self.array, alpha=self.unlabeled_alpha) def make_debug(self, num_rows: int = 15000): - print('debug') - assert (hasattr(self, 'array')) + print("debug") + assert hasattr(self, "array") rows, cols = self.shape # print(rows, cols) # behav = 0 for i in range(rows): - behav = (i % cols) + behav = i % cols self.array[i, behav] = 1 # self.array = self.array[:num_rows,:] # print(self.array) def calculate_background_class(self, array: np.ndarray): array[:, 0] = np.logical_not(np.any(array[:, 1:], axis=1)) - return (array) + return array def update_background_class(self): # import pdb @@ -716,13 +720,13 @@ def recreate_label_image(self): @Slot(int) def toggle_behavior(self, index: int): - if not hasattr(self, 'array') or self.array is None or self.fixed: + if not hasattr(self, "array") or self.array is None or self.fixed: return n_behaviors = self.image.shape[0] if index > n_behaviors: - raise ValueError('Not enough behaviors for number: {}'.format(index)) + raise ValueError("Not enough behaviors for number: {}".format(index)) if index < 0: - raise ValueError('Behavior index cannot be below 0') + raise ValueError("Behavior index cannot be below 0") self.label_toggled[index] = ~self.label_toggled[index] if self.label_toggled[index]: # if background is selected, deselect all others @@ -742,7 +746,6 @@ def toggle_behavior(self, index: int): class LabelButtons(QtWidgets.QWidget): - def __init__(self, parent=None, *args, **kwargs): super().__init__(*args, **kwargs) @@ -755,11 +758,10 @@ def reset(self): self.enabled = None self.minimum_height = None - def initialize(self, - behaviors: Union[list, np.ndarray] = ['background'], - enabled: bool = True, - minimum_height: int = 25): - assert (len(behaviors) > 0) + def initialize( + self, behaviors: Union[list, np.ndarray] = ["background"], enabled: bool = True, minimum_height: int = 25 + ): + assert len(behaviors) > 0 layout = QtWidgets.QVBoxLayout() self.buttons = [] self.behaviors = behaviors @@ -780,18 +782,19 @@ def initialize(self, def _make_button(self, behavior: str, index: int): string = str(behavior) if index < 10: - string = '[{:01d}] '.format(index) + string + string = "[{:01d}] ".format(index) + string button = QtWidgets.QPushButton(string, parent=self) button.setEnabled(self.enabled) button.setMinimumHeight(self.minimum_height) button.setCheckable(True) - button.setStyleSheet("QPushButton { text-align: left; }" - "QPushButton:checked { background-color: rgb(30, 30, 30)}") + button.setStyleSheet( + "QPushButton { text-align: left; }" "QPushButton:checked { background-color: rgb(30, 30, 30)}" + ) return button def add_behavior(self, behavior: str): if behavior in self.behaviors: - warnings.warn('behavior {} already in list'.format(behavior)) + warnings.warn("behavior {} already in list".format(behavior)) else: self.behaviors.append(behavior) button = self._make_button(behavior, len(self.behaviors)) @@ -805,7 +808,6 @@ def fix(self): class LabelImg(QtWidgets.QScrollArea): - def __init__(self, parent=None, *args, **kwargs): super().__init__(parent=parent, *args, **kwargs) @@ -839,34 +841,37 @@ def update_buttons(self): button.setChecked(toggle) self.update() - def initialize(self, - behaviors: Union[list, np.ndarray] = ['background'], - n_timepoints: int = 31, - debug: bool = False, - colormap: str = 'Reds', - unlabeled_alpha: float = 0.1, - desired_pixel_size: int = 25, - array: np.ndarray = None, - fixed: bool = False, - opacity: np.ndarray = None): - + def initialize( + self, + behaviors: Union[list, np.ndarray] = ["background"], + n_timepoints: int = 31, + debug: bool = False, + colormap: str = "Reds", + unlabeled_alpha: float = 0.1, + desired_pixel_size: int = 25, + array: np.ndarray = None, + fixed: bool = False, + opacity: np.ndarray = None, + ): layout = QtWidgets.QHBoxLayout() # assert (n == len(behaviors)) - assert (behaviors[0] == 'background') + assert behaviors[0] == "background" self.label = LabelViewer() # print(behaviors) self.behaviors = behaviors self.n = len(self.behaviors) - self.label.initialize(len(self.behaviors), - n_timepoints, - debug, - colormap, - unlabeled_alpha, - desired_pixel_size, - array, - fixed, - opacity=opacity) + self.label.initialize( + len(self.behaviors), + n_timepoints, + debug, + colormap, + unlabeled_alpha, + desired_pixel_size, + array, + fixed, + opacity=opacity, + ) self.buttons = LabelButtons() enabled = not fixed self.buttons.initialize(self.behaviors, enabled, desired_pixel_size) @@ -890,16 +895,16 @@ def initialize(self, self.update() def add_behavior(self, behavior: str): - print('1: ', self.behaviors, behavior) + print("1: ", self.behaviors, behavior) if behavior in self.behaviors: - warnings.warn('behavior {} already in list'.format(behavior)) + warnings.warn("behavior {} already in list".format(behavior)) # add a button self.buttons.add_behavior(behavior) - print('2: {}'.format(self.behaviors)) - print('2 buttons: {}'.format(self.buttons.behaviors)) + print("2: {}".format(self.behaviors)) + print("2 buttons: {}".format(self.buttons.behaviors)) # add to our list of behaviors # self.behaviors.append(behavior) - print('3: {}'.format(self.behaviors)) + print("3: {}".format(self.behaviors)) # hook up button to toggling behavior i = len(self.behaviors) - 1 print(self.behaviors) @@ -936,7 +941,6 @@ def run(self): class SubprocessChainer(QtCore.QThread): - def __init__(self, calls: list): QtCore.QThread.__init__(self) for call in calls: @@ -974,7 +978,6 @@ def run(self): class UnclickButtonOnPipeCompletion(QtCore.QThread): - def __init__(self, button, pipe): QtCore.QThread.__init__(self) # super().__init__(self) @@ -989,7 +992,7 @@ def __del__(self): @Slot(bool) def get_click(self, value): - print('clicked') + print("clicked") self.has_been_clicked = True def run(self): @@ -1008,23 +1011,24 @@ def run(self): class MainWindow(QtWidgets.QMainWindow): - def __init__(self): super().__init__() self.label = LabelImg(self) - self.label.initialize(behaviors=['background', 'itch', 'lick', 'scratch', 'shit', 'fuck', 'ass', 'bitch'], - n_timepoints=500, - debug=True, - fixed=False) + self.label.initialize( + behaviors=["background", "itch", "lick", "scratch", "shit", "fuck", "ass", "bitch"], + n_timepoints=500, + debug=True, + fixed=False, + ) # self.label = LabelViewer() # self.label.initialize(n=4, n_timepoints=40, debug=True, fixed=True) # # self.labelImg = DebuggingDrawing() # - next_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Right'), self) + next_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Right"), self) next_shortcut.activated.connect(partial(self.label.label.change_view_dx, 1)) # next_shortcut.activated.connect(partial(self.label.change_view_dx, 1)) - back_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Left'), self) + back_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Left"), self) back_shortcut.activated.connect(partial(self.label.label.change_view_dx, -1)) # # if hasattr(self, 'label'): @@ -1047,14 +1051,14 @@ def __init__(self): self.update() def sizeHint(self): - return (QtCore.QSize(600, 600)) + return QtCore.QSize(600, 600) -if __name__ == '__main__': +if __name__ == "__main__": app = QtWidgets.QApplication([]) # volume = VideoPlayer(r'C:\DATA\mouse_reach_processed\M134_20141203_v001.h5') testing = LabelImg() - testing.initialize(behaviors=['background', 'a', 'b', 'c', 'd', 'e'], n_timepoints=15000, debug=True) + testing.initialize(behaviors=["background", "a", "b", "c", "d", "e"], n_timepoints=15000, debug=True) # testing = ShouldRunInference(['M134_20141203_v001', # 'M134_20141203_v002', # 'M134_20141203_v004'], diff --git a/deepethogram/gui/main.py b/deepethogram/gui/main.py index 32cdf9a..0fd0775 100644 --- a/deepethogram/gui/main.py +++ b/deepethogram/gui/main.py @@ -11,7 +11,7 @@ import pandas as pd from PySide2 import QtCore, QtWidgets, QtGui from PySide2.QtCore import Slot -from PySide2.QtWidgets import (QMainWindow, QFileDialog, QInputDialog) +from PySide2.QtWidgets import QMainWindow, QFileDialog, QInputDialog from omegaconf import DictConfig, OmegaConf from deepethogram import projects, utils, configuration @@ -23,8 +23,10 @@ log = logging.getLogger(__name__) -pretrained_models_error = 'Dont train flow generator without pretrained weights! ' + \ - 'See the project GitHub for instructions on downloading weights: https://github.com/jbohnslav/deepethogram' +pretrained_models_error = ( + "Dont train flow generator without pretrained weights! " + + "See the project GitHub for instructions on downloading weights: https://github.com/jbohnslav/deepethogram" +) class MainWindow(QMainWindow): @@ -38,7 +40,7 @@ def __init__(self, cfg: DictConfig): self.ui = Ui_MainWindow() self.ui.setupUi(self) - self.setWindowTitle('DeepEthogram') + self.setWindowTitle("DeepEthogram") # print(dir(self.ui.actionOpen)) self.ui.videoBox.setLayout(self.ui.formLayout) @@ -64,37 +66,37 @@ def __init__(self, cfg: DictConfig): # scroll down to Standard Shorcuts to find what the keys are called: # https://doc.qt.io/qt-5/qkeysequence.html - next_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Right'), self) + next_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Right"), self) # partial functions create a new, separate function with certain default arguments next_shortcut.activated.connect(partial(self.move_n_frames, 1)) next_shortcut.activated.connect(self.user_did_something) - up_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Up'), self) + up_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Up"), self) up_shortcut.activated.connect(partial(self.move_n_frames, -cfg.vertical_arrow_jump)) up_shortcut.activated.connect(self.user_did_something) - down_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Down'), self) + down_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Down"), self) down_shortcut.activated.connect(partial(self.move_n_frames, cfg.vertical_arrow_jump)) down_shortcut.activated.connect(self.user_did_something) - back_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Left'), self) + back_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Left"), self) back_shortcut.activated.connect(partial(self.move_n_frames, -1)) back_shortcut.activated.connect(self.user_did_something) - jumpleft_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Ctrl+Left'), self) + jumpleft_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Ctrl+Left"), self) jumpleft_shortcut.activated.connect(partial(self.move_n_frames, -cfg.control_arrow_jump)) jumpleft_shortcut.activated.connect(self.user_did_something) - jumpright_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Ctrl+Right'), self) + jumpright_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Ctrl+Right"), self) jumpright_shortcut.activated.connect(partial(self.move_n_frames, cfg.control_arrow_jump)) jumpright_shortcut.activated.connect(self.user_did_something) self.ui.actionSave_Project.triggered.connect(self.save) - save_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Ctrl+S'), self) + save_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Ctrl+S"), self) save_shortcut.activated.connect(self.save) - finalize_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Ctrl+F'), self) + finalize_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Ctrl+F"), self) finalize_shortcut.activated.connect(self.finalize) - open_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence('Ctrl+O'), self) + open_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Ctrl+O"), self) open_shortcut.activated.connect(self.load_project) self.ui.finalize_labels.clicked.connect(self.finalize) self.ui.exportPredictions.clicked.connect(self.export_predictions) @@ -122,7 +124,7 @@ def __init__(self, cfg: DictConfig): # the current directory is where_user_launched/gui_logs/date_time_runlog initialized_directory = os.path.dirname(os.path.dirname(os.getcwd())) - if os.path.isfile(os.path.join(initialized_directory, 'project_config.yaml')): + if os.path.isfile(os.path.join(initialized_directory, "project_config.yaml")): self.initialize_project(initialized_directory) # log.info('children: {}'.format(self.children())) @@ -134,7 +136,7 @@ def user_did_something(self): else: # else, the user was already idle # will have a timestamp by the logfile - log.info('User restarted labeling') + log.info("User restarted labeling") self.timer.start() def keyPressEvent(self, event: QtGui.QKeyEvent): @@ -147,7 +149,7 @@ def mousePressEvent(self, event: QtGui.QMouseEvent): super().mousePressEvent(event) def log_idle(self): - log.info('User has been idle for {} seconds...'.format(float(self.timer.interval()) / 1000)) + log.info("User has been idle for {} seconds...".format(float(self.timer.interval()) / 1000)) self.timer.stop() def respond_to_keypress(self, keynum: int): @@ -171,11 +173,11 @@ def project_loaded_buttons(self): self.ui.actionAdd.setEnabled(True) self.ui.actionRemove.setEnabled(True) number_finalized_labels = projects.get_number_finalized_labels(self.cfg) - log.info('Number finalized labels: {}'.format(number_finalized_labels)) - if self.has_trained('flow_generator'): + log.info("Number finalized labels: {}".format(number_finalized_labels)) + if self.has_trained("flow_generator"): # self.ui.flow_inference.setEnabled(True) self.ui.flow_train.setEnabled(True) - if self.has_trained('feature_extractor') or number_finalized_labels > 1: + if self.has_trained("feature_extractor") or number_finalized_labels > 1: self.ui.featureextractor_infer.setEnabled(True) self.ui.featureextractor_train.setEnabled(True) @@ -184,13 +186,13 @@ def project_loaded_buttons(self): if self.data_path is not None: records = projects.get_records_from_datadir(self.data_path) for animal, record in records.items(): - if record['output'] is not None and os.path.isfile(record['output']): + if record["output"] is not None and os.path.isfile(record["output"]): n_output_files += 1 - if self.has_trained('feature_extractor') and n_output_files > 2: + if self.has_trained("feature_extractor") and n_output_files > 2: self.ui.sequence_train.setEnabled(True) - if self.has_trained('sequence') and n_output_files > 0: + if self.has_trained("sequence") and n_output_files > 0: self.ui.sequence_infer.setEnabled(True) self.ui.classifierInference.setEnabled(True) @@ -205,7 +207,7 @@ def video_loaded_buttons(self): self.ui.flow_train.setEnabled(True) def initialize_video(self, videofile: Union[str, os.PathLike]): - if hasattr(self, 'vid'): + if hasattr(self, "vid"): self.vid.close() # if hasattr(self.vid, 'cap'): # self.vid.cap.release() @@ -217,14 +219,15 @@ def initialize_video(self, videofile: Union[str, os.PathLike]): # for convenience self.n_timepoints = len(self.ui.videoPlayer.videoView.vid) - log.debug('is deg: {}'.format(projects.is_deg_file(videofile))) + log.debug("is deg: {}".format(projects.is_deg_file(videofile))) - if os.path.normpath( - self.cfg.project.data_path) in os.path.normpath(videofile) and projects.is_deg_file(videofile): + if os.path.normpath(self.cfg.project.data_path) in os.path.normpath(videofile) and projects.is_deg_file( + videofile + ): record = projects.get_record_from_subdir(os.path.dirname(videofile)) - log.info('Record for loaded video: {}'.format(record)) - labelfile = record['label'] - outputfile = record['output'] + log.info("Record for loaded video: {}".format(record)) + labelfile = record["label"] + outputfile = record["output"] self.labelfile = labelfile self.outputfile = outputfile if labelfile is not None: @@ -237,22 +240,25 @@ def initialize_video(self, videofile: Union[str, os.PathLike]): self.ui.predictionsCombo.clear() self.initialize_prediction() else: - log.info('Copying {} to your DEG directory'.format(videofile)) + log.info("Copying {} to your DEG directory".format(videofile)) new_loc = projects.add_video_to_project(OmegaConf.to_container(self.cfg), videofile) - log.debug('New video location: {}'.format(new_loc)) + log.debug("New video location: {}".format(new_loc)) self.videofile = new_loc - log.debug('New record: {}'.format( - utils.load_yaml(os.path.join(os.path.dirname(self.videofile), 'record.yaml')))) + log.debug( + "New record: {}".format( + utils.load_yaml(os.path.join(os.path.dirname(self.videofile), "record.yaml")) + ) + ) self.initialize_label() self.initialize_prediction() self.video_loaded_buttons() except BaseException as e: - log.exception('Error initializing video: {}'.format(e)) + log.exception("Error initializing video: {}".format(e)) tb = traceback.format_exc() print(tb) return self.ui.videoPlayer.videoView.update_frame(0, force=True) - self.setWindowTitle('DeepEthogram: {}'.format(self.cfg.project.name)) + self.setWindowTitle("DeepEthogram: {}".format(self.cfg.project.name)) self.update_video_info() self.user_did_something() @@ -264,56 +270,58 @@ def update_video_info(self): try: fps = reader.fps duration = nframes / fps - fps = '{:.2f}'.format(fps) - duration = '{:.2f}'.format(duration) + fps = "{:.2f}".format(fps) + duration = "{:.2f}".format(duration) except: - fps = 'N/A' - duration = 'N/A' + fps = "N/A" + duration = "N/A" num_labeled = self.ui.labels.label.changed.sum() self.ui.nameLabel.setText(name) - self.ui.nframesLabel.setText('{:,}'.format(nframes)) - self.ui.nlabeledLabel.setText('{:,}'.format(num_labeled)) - self.ui.nunlabeledLabel.setText('{:,}'.format(nframes - num_labeled)) + self.ui.nframesLabel.setText("{:,}".format(nframes)) + self.ui.nlabeledLabel.setText("{:,}".format(num_labeled)) + self.ui.nunlabeledLabel.setText("{:,}".format(nframes - num_labeled)) self.ui.durationLabel.setText(duration) self.ui.fpsLabel.setText(fps) self.ui.labels.label.num_changed.connect(self.update_num_labeled) def update_num_labeled(self, n: Union[int, str]): - self.ui.nlabeledLabel.setText('{:,}'.format(n)) - self.ui.nunlabeledLabel.setText('{:,}'.format(self.n_timepoints - n)) + self.ui.nlabeledLabel.setText("{:,}".format(n)) + self.ui.nunlabeledLabel.setText("{:,}".format(self.n_timepoints - n)) def initialize_label(self, label_array: np.ndarray = None, debug: bool = False): if self.cfg.project is None: - raise ValueError('must load or create project before initializing a label!') - self.ui.labels.initialize(behaviors=OmegaConf.to_container(self.cfg.project.class_names), - n_timepoints=self.n_timepoints, - debug=debug, - fixed=False, - array=label_array, - colormap=self.cfg.cmap) + raise ValueError("must load or create project before initializing a label!") + self.ui.labels.initialize( + behaviors=OmegaConf.to_container(self.cfg.project.class_names), + n_timepoints=self.n_timepoints, + debug=debug, + fixed=False, + array=label_array, + colormap=self.cfg.cmap, + ) # we never want to connect signals to slots more than once - log.debug('initialized label: {}'.format(self.initialized_label)) + log.debug("initialized label: {}".format(self.initialized_label)) # if not self.initialized_label: self.ui.videoPlayer.videoView.frameNum.connect(self.ui.labels.label.change_view_x) self.ui.labels.label.saved.connect(self.update_saved) self.initialized_label = True self.update() - def initialize_prediction(self, - prediction_array: np.ndarray = None, - debug: bool = False, - opacity: np.ndarray = None): - + def initialize_prediction( + self, prediction_array: np.ndarray = None, debug: bool = False, opacity: np.ndarray = None + ): # do all the setup for labels and predictions - self.ui.predictions.initialize(behaviors=OmegaConf.to_container(self.cfg.project.class_names), - n_timepoints=self.n_timepoints, - debug=debug, - fixed=True, - array=prediction_array, - opacity=opacity, - colormap=self.cfg.cmap) + self.ui.predictions.initialize( + behaviors=OmegaConf.to_container(self.cfg.project.class_names), + n_timepoints=self.n_timepoints, + debug=debug, + fixed=True, + array=prediction_array, + opacity=opacity, + colormap=self.cfg.cmap, + ) # if not self.initialized_prediction: self.ui.videoPlayer.videoView.frameNum.connect(self.ui.predictions.label.change_view_x) # we don't want to be able to manually edit the predictions @@ -322,12 +330,12 @@ def initialize_prediction(self, self.update() def generate_flow_train_args(self): - args = ['python', '-m', 'deepethogram.flow_generator.train', 'project.path={}'.format(self.cfg.project.path)] - weights = self.get_selected_models()['flow_generator'] + args = ["python", "-m", "deepethogram.flow_generator.train", "project.path={}".format(self.cfg.project.path)] + weights = self.get_selected_models()["flow_generator"] if weights is None: raise ValueError(pretrained_models_error) if weights is not None and os.path.isfile(weights): - args += ['reload.weights={}'.format(weights)] + args += ["reload.weights={}".format(weights)] return args def flow_train(self): @@ -339,7 +347,7 @@ def flow_train(self): self.ui.sequence_train.setEnabled(False) args = self.generate_flow_train_args() - log.info('flow_train called with args: {}'.format(args)) + log.info("flow_train called with args: {}".format(args)) self.training_pipe = subprocess.Popen(args) self.listener = UnclickButtonOnPipeCompletion(self.ui.flow_train, self.training_pipe) self.listener.start() @@ -347,15 +355,15 @@ def flow_train(self): if self.training_pipe.poll() is None: self.training_pipe.terminate() self.training_pipe.wait() - log.info('Training interrupted.') + log.info("Training interrupted.") else: - log.info('Training finished. If you see error messages above, training did not complete successfully.') + log.info("Training finished. If you see error messages above, training did not complete successfully.") # self.train_thread.terminate() del self.training_pipe self.listener.quit() self.listener.wait() del self.listener - log.info('~' * 100) + log.info("~" * 100) self.project_loaded_buttons() self.get_trained_models() @@ -369,19 +377,22 @@ def featureextractor_train(self): self.ui.sequence_train.setEnabled(False) args = [ - 'python', '-m', 'deepethogram.feature_extractor.train', 'project.path={}'.format(self.cfg.project.path) + "python", + "-m", + "deepethogram.feature_extractor.train", + "project.path={}".format(self.cfg.project.path), ] print(self.get_selected_models()) - weights = self.get_selected_models()['feature_extractor'] + weights = self.get_selected_models()["feature_extractor"] # print(weights) if weights is None: raise ValueError(pretrained_models_error) if os.path.isfile(weights): - args += ['feature_extractor.weights={}'.format(weights)] - flow_weights = self.get_selected_models()['flow_generator'] # ('flow_generator') + args += ["feature_extractor.weights={}".format(weights)] + flow_weights = self.get_selected_models()["flow_generator"] # ('flow_generator') assert flow_weights is not None - args += ['flow_generator.weights={}'.format(flow_weights)] - log.info('feature extractor train called with args: {}'.format(args)) + args += ["flow_generator.weights={}".format(flow_weights)] + log.info("feature extractor train called with args: {}".format(args)) self.training_pipe = subprocess.Popen(args) self.listener = UnclickButtonOnPipeCompletion(self.ui.featureextractor_train, self.training_pipe) self.listener.start() @@ -389,14 +400,14 @@ def featureextractor_train(self): if self.training_pipe.poll() is None: self.training_pipe.terminate() self.training_pipe.wait() - log.info('Training interrupted.') + log.info("Training interrupted.") else: - log.info('Training finished. If you see error messages above, training did not complete successfully.') + log.info("Training finished. If you see error messages above, training did not complete successfully.") del self.training_pipe self.listener.quit() self.listener.wait() del self.listener - log.info('~' * 100) + log.info("~" * 100) # self.ui.flow_train.setEnabled(True) self.project_loaded_buttons() self.get_trained_models() @@ -409,7 +420,7 @@ def generate_featureextractor_inference_args(self): keys, no_outputs = [], [] for key, record in records.items(): keys.append(key) - no_outputs.append(record['output'] is None) + no_outputs.append(record["output"] is None) form = ShouldRunInference(keys, no_outputs) ret = form.exec_() if not ret: @@ -421,25 +432,29 @@ def generate_featureextractor_inference_args(self): self.ui.flow_train.setEnabled(False) # self.ui.flow_inference.setEnabled(False) self.ui.featureextractor_train.setEnabled(False) - weights = self.get_selected_models()['feature_extractor'] + weights = self.get_selected_models()["feature_extractor"] if weights is not None and os.path.isfile(weights): - weight_arg = 'feature_extractor.weights={}'.format(weights) + weight_arg = "feature_extractor.weights={}".format(weights) else: - raise ValueError('Dont run inference without using a proper feature extractor weights! {}'.format(weights)) + raise ValueError("Dont run inference without using a proper feature extractor weights! {}".format(weights)) args = [ - 'python', '-m', 'deepethogram.feature_extractor.inference', 'project.path={}'.format(self.cfg.project.path), - 'inference.overwrite=True', weight_arg + "python", + "-m", + "deepethogram.feature_extractor.inference", + "project.path={}".format(self.cfg.project.path), + "inference.overwrite=True", + weight_arg, ] - flow_weights = self.get_selected_models()['flow_generator'] + flow_weights = self.get_selected_models()["flow_generator"] assert flow_weights is not None - args += ['flow_generator.weights={}'.format(flow_weights)] - string = 'inference.directory_list=[' + args += ["flow_generator.weights={}".format(flow_weights)] + string = "inference.directory_list=[" for key, infer in zip(keys, should_infer): if infer: - record_dir = os.path.join(self.data_path, key) + ',' + record_dir = os.path.join(self.data_path, key) + "," string += record_dir - string = string[:-1] + ']' + string = string[:-1] + "]" args += [string] return args @@ -448,7 +463,7 @@ def featureextractor_infer(self): args = self.generate_featureextractor_inference_args() if args is None: return - log.info('inference running with args: {}'.format(' '.join(args))) + log.info("inference running with args: {}".format(" ".join(args))) self.inference_pipe = subprocess.Popen(args) self.listener = UnclickButtonOnPipeCompletion(self.ui.featureextractor_infer, self.inference_pipe) self.listener.start() @@ -456,19 +471,20 @@ def featureextractor_infer(self): if self.inference_pipe.poll() is None: self.inference_pipe.terminate() self.inference_pipe.wait() - log.info('Inference interrupted.') + log.info("Inference interrupted.") else: log.info( - 'Inference finished. If you see error messages above, inference did not complete successfully.') + "Inference finished. If you see error messages above, inference did not complete successfully." + ) del self.inference_pipe self.listener.quit() self.listener.wait() del self.listener - log.info('~' * 100) + log.info("~" * 100) self.project_loaded_buttons() record = projects.get_record_from_subdir(os.path.dirname(self.videofile)) - if record['output'] is not None: - self.outputfile = record['output'] + if record["output"] is not None: + self.outputfile = record["output"] else: self.outputfile = None self.import_outputfile(self.outputfile) @@ -483,10 +499,10 @@ def sequence_train(self): self.ui.featureextractor_infer.setEnabled(False) self.ui.sequence_infer.setEnabled(False) # self.ui.sequence_train.setEnabled(False) - args = ['python', '-m', 'deepethogram.sequence.train', 'project.path={}'.format(self.cfg.project.path)] - weights = self.get_selected_models()['sequence'] + args = ["python", "-m", "deepethogram.sequence.train", "project.path={}".format(self.cfg.project.path)] + weights = self.get_selected_models()["sequence"] if weights is not None and os.path.isfile(weights): - args += ['reload.weights={}'.format(weights)] + args += ["reload.weights={}".format(weights)] self.training_pipe = subprocess.Popen(args) self.listener = UnclickButtonOnPipeCompletion(self.ui.sequence_train, self.training_pipe) self.listener.start() @@ -495,14 +511,14 @@ def sequence_train(self): if self.training_pipe.poll() is None: self.training_pipe.terminate() self.training_pipe.wait() - log.info('Training interrupted.') + log.info("Training interrupted.") else: - log.info('Training finished. If you see error messages above, training did not complete successfully.') + log.info("Training finished. If you see error messages above, training did not complete successfully.") del self.training_pipe self.listener.quit() self.listener.wait() del self.listener - log.info('~' * 100) + log.info("~" * 100) # self.ui.flow_train.setEnabled(True) self.project_loaded_buttons() self.get_trained_models() @@ -512,21 +528,21 @@ def generate_sequence_inference_args(self): records = projects.get_records_from_datadir(self.data_path) keys = list(records.keys()) outputs = projects.has_outputfile(records) - sequence_weights = self.get_selected_models()['sequence'] + sequence_weights = self.get_selected_models()["sequence"] if sequence_weights is not None and os.path.isfile(sequence_weights): run_files = utils.get_run_files_from_weights(sequence_weights) - sequence_config = OmegaConf.load(run_files['config_file']) + sequence_config = OmegaConf.load(run_files["config_file"]) # sequence_config = utils.load_yaml(os.path.join(os.path.dirname(sequence_weights), 'config.yaml')) - latent_name = sequence_config['sequence']['latent_name'] + latent_name = sequence_config["sequence"]["latent_name"] if latent_name is None: - latent_name = sequence_config['feature_extractor']['arch'] - output_name = sequence_config['sequence']['output_name'] + latent_name = sequence_config["feature_extractor"]["arch"] + output_name = sequence_config["sequence"]["output_name"] if output_name is None: - output_name = sequence_config['sequence']['arch'] + output_name = sequence_config["sequence"]["arch"] else: - raise ValueError('must specify a valid weight file to run sequence inference!') + raise ValueError("must specify a valid weight file to run sequence inference!") - log.debug('latent name: {}'.format(latent_name)) + log.debug("latent name: {}".format(latent_name)) # sequence_name, _ = utils.get_latest_model_and_name(self.project_config['project']['path'], 'sequence') # GOAL: MAKE ONLY FILES WITH LATENT_NAME PRESENT APPEAR ON LIST @@ -547,31 +563,34 @@ def generate_sequence_inference_args(self): all_false = np.all(np.array(should_infer) == False) if all_false: return - weights = self.get_selected_models()['sequence'] + weights = self.get_selected_models()["sequence"] if weights is not None and os.path.isfile(weights): - weight_arg = 'sequence.weights={}'.format(weights) + weight_arg = "sequence.weights={}".format(weights) else: - raise ValueError('weights do not exist! {}'.format(weights)) + raise ValueError("weights do not exist! {}".format(weights)) args = [ - 'python', '-m', 'deepethogram.sequence.inference', 'project.path={}'.format(self.cfg.project.path), - 'inference.overwrite=True', weight_arg + "python", + "-m", + "deepethogram.sequence.inference", + "project.path={}".format(self.cfg.project.path), + "inference.overwrite=True", + weight_arg, ] - string = 'inference.directory_list=[' + string = "inference.directory_list=[" for key, infer in zip(keys, should_infer): if infer: - record_dir = os.path.join(self.data_path, key) + ',' + record_dir = os.path.join(self.data_path, key) + "," string += record_dir - string = string[:-1] + ']' + string = string[:-1] + "]" args += [string] return args def sequence_infer(self): if self.ui.sequence_infer.isChecked(): - args = self.generate_sequence_inference_args() if args is None: return - log.info('sequence inference running with args: {}'.format(args)) + log.info("sequence inference running with args: {}".format(args)) self.inference_pipe = subprocess.Popen(args) self.listener = UnclickButtonOnPipeCompletion(self.ui.sequence_infer, self.inference_pipe) self.listener.start() @@ -579,10 +598,10 @@ def sequence_infer(self): if self.inference_pipe.poll() is None: self.inference_pipe.terminate() self.inference_pipe.wait() - log.info('Inference interrupted.') + log.info("Inference interrupted.") else: - log.info('Inference finished') - log.info('~' * 100) + log.info("Inference finished") + log.info("~" * 100) del self.inference_pipe self.listener.quit() self.listener.wait() @@ -590,8 +609,8 @@ def sequence_infer(self): self.project_loaded_buttons() # del self.listener record = projects.get_record_from_subdir(os.path.dirname(self.videofile)) - if record['output'] is not None: - self.outputfile = record['output'] + if record["output"] is not None: + self.outputfile = record["output"] else: self.outputfile = None self.import_outputfile(self.outputfile, first_time=True) @@ -602,7 +621,7 @@ def classifier_inference(self): sequence_args = self.generate_sequence_inference_args() if fe_args is None or sequence_args is None: - log.error('Erroneous arguments to fe or seq: {}, {}'.format(fe_args, sequence_args)) + log.error("Erroneous arguments to fe or seq: {}, {}".format(fe_args, sequence_args)) calls = [fe_args, sequence_args] @@ -625,11 +644,11 @@ def run_overnight(self): sequence_args = self.generate_sequence_inference_args() if flow_args is None: - log.error('Erroneous flow arguments in run overnight: {}'.format(flow_args)) + log.error("Erroneous flow arguments in run overnight: {}".format(flow_args)) if fe_args is None: - log.error('Erroneous fe arguments in run overnight: {}'.format(fe_args)) + log.error("Erroneous fe arguments in run overnight: {}".format(fe_args)) if sequence_args is None: - log.error('Erroneous seq arguments in run overnight: {}'.format(sequence_args)) + log.error("Erroneous seq arguments in run overnight: {}".format(sequence_args)) calls = [flow_args, fe_args, sequence_args] # calls = [['ping', 'localhost', '-n', '10'], ['dir']] @@ -645,60 +664,59 @@ def run_overnight(self): # print(should_be_checked) def _new_project(self): - form = CreateProject() ret = form.exec_() if not ret: return project_name = form.project_box.text() if project_name == form.project_name_default: - log.warning('Must change project name') + log.warning("Must change project name") return labeler = form.labeler_box.text() if labeler == form.label_default_string: - log.warning('Must specify a labeler') + log.warning("Must specify a labeler") return behaviors = form.behaviors_box.text() if behaviors == form.behavior_default_string: - log.warning('Must add list of behaviors') + log.warning("Must add list of behaviors") return - project_name = project_name.replace(' ', '_') - labeler = labeler.replace(' ', '_') - behaviors = behaviors.replace(' ', '') - behaviors = behaviors.split(',') - behaviors.insert(0, 'background') + project_name = project_name.replace(" ", "_") + labeler = labeler.replace(" ", "_") + behaviors = behaviors.replace(" ", "") + behaviors = behaviors.split(",") + behaviors.insert(0, "background") project_dict = projects.initialize_project(form.project_directory, project_name, behaviors, labeler) - self.initialize_project(project_dict['project']['path']) + self.initialize_project(project_dict["project"]["path"]) def add_class(self): - text, ok = QInputDialog.getText(self, 'AddBehaviorDialog', 'Enter behavior name: ') + text, ok = QInputDialog.getText(self, "AddBehaviorDialog", "Enter behavior name: ") if len(text) == 0: - log.warning('No behavior entered') + log.warning("No behavior entered") ok = False if not ok: return if text in self.cfg.project.class_names: - log.warning('This behavior is already in the list...') + log.warning("This behavior is already in the list...") return # self.add_class() - text = text.replace(' ', '_') - log.info('new behavior name: {}'.format(text)) + text = text.replace(" ", "_") + log.info("new behavior name: {}".format(text)) - message = '''Are you sure you want to add behavior {}? + message = """Are you sure you want to add behavior {}? All previous labels will have a blank column added that must be labeled. - Feature extractor and sequence models will need to be retrained. ' + Feature extractor and sequence models will need to be retrained. ' Inference files will be deleted, and feature extractor inference must be re-run. - If you have not exported predictions to .CSV, make sure you do so now!'''.format(text) + If you have not exported predictions to .CSV, make sure you do so now!""".format(text) if not simple_popup_question(self, message): return if not self.saved: - if simple_popup_question(self, 'You have unsaved changes. Do you want to save them first?'): + if simple_popup_question(self, "You have unsaved changes. Do you want to save them first?"): self.save() - projects.add_behavior_to_project(os.path.join(self.cfg.project.path, 'project_config.yaml'), text) + projects.add_behavior_to_project(os.path.join(self.cfg.project.path, "project_config.yaml"), text) behaviors = OmegaConf.to_container(self.cfg.project.class_names) behaviors.append(text) self.cfg.project.class_names = behaviors @@ -715,27 +733,27 @@ def add_class(self): self.update() def remove_class(self): - text, ok = QInputDialog.getText(self, 'RemoveBehaviorDialog', 'Enter behavior name: ') + text, ok = QInputDialog.getText(self, "RemoveBehaviorDialog", "Enter behavior name: ") if len(text) == 0: - log.warning('No behavior entered') + log.warning("No behavior entered") ok = False if not ok: return if text not in self.cfg.project.class_names: - log.warning('This behavior is not in the list...') + log.warning("This behavior is not in the list...") return - if text == 'background': - raise ValueError('Cannot remove background class.') + if text == "background": + raise ValueError("Cannot remove background class.") - message = '''Are you sure you want to remove behavior {}? + message = """Are you sure you want to remove behavior {}? All previous labels will have be DELETED! - Feature extractor and sequence models will need to be retrained. ' + Feature extractor and sequence models will need to be retrained. ' Inference files will be deleted, and feature extractor inference must be re-run. - If you have not exported predictions to .CSV, make sure you do so now!'''.format(text) + If you have not exported predictions to .CSV, make sure you do so now!""".format(text) if not simple_popup_question(self, message): return - projects.remove_behavior_from_project(os.path.join(self.cfg.project.path, 'project_config.yaml'), text) + projects.remove_behavior_from_project(os.path.join(self.cfg.project.path, "project_config.yaml"), text) behaviors = OmegaConf.to_container(self.cfg.project.class_names) behaviors.remove(text) @@ -752,10 +770,12 @@ def remove_class(self): self.update() def finalize(self): - if not hasattr(self, 'cfg'): - raise ValueError('cant finalize labels without starting or loading a DEG project') - message = 'Are you sure you want to continue? All non-labeled frames will be labeled as *background*.\n' \ - 'This is not reversible.' + if not hasattr(self, "cfg"): + raise ValueError("cant finalize labels without starting or loading a DEG project") + message = ( + "Are you sure you want to continue? All non-labeled frames will be labeled as *background*.\n" + "This is not reversible." + ) if not simple_popup_question(self, message): return @@ -767,9 +787,9 @@ def finalize(self): # self.save() # else: # return - log.info('finalizing labels for file {}'.format(self.videofile)) + log.info("finalizing labels for file {}".format(self.videofile)) fname, _ = os.path.splitext(self.videofile) - label_fname = fname + '_labels.csv' + label_fname = fname + "_labels.csv" if not os.path.isfile(label_fname): label = self.ui.labels.label.array df = pd.DataFrame(label, columns=self.cfg.project.class_names) @@ -788,18 +808,17 @@ def finalize(self): self.user_did_something() self.unfinalized_idx += 1 if self.unfinalized_idx >= len(self.unfinalized): - raise ValueError('no more videos to label! you can still use the GUI to browse or fix labels') - self.initialize_video(self.unfinalized[self.unfinalized_idx]['rgb']) - + raise ValueError("no more videos to label! you can still use the GUI to browse or fix labels") + self.initialize_video(self.unfinalized[self.unfinalized_idx]["rgb"]) def save(self): if self.saved: # do nothing return - log.info('saving...') + log.info("saving...") df = self._make_dataframe() fname, _ = os.path.splitext(self.videofile) - label_fname = fname + '_labels.csv' + label_fname = fname + "_labels.csv" df.to_csv(label_fname) projects.add_file_to_subdir(label_fname, os.path.dirname(self.videofile)) # self.save_to_hdf5() @@ -808,7 +827,7 @@ def save(self): def import_labelfile(self, labelfile: Union[str, os.PathLike]): if labelfile is None: self.initialize_label() - assert (os.path.isfile(labelfile)) + assert os.path.isfile(labelfile) df = pd.read_csv(labelfile, index_col=0) array = df.values self.initialize_label(label_array=array) @@ -817,42 +836,39 @@ def import_external_labels(self): if self.data_path is not None: data_dir = self.data_path else: - raise ValueError('create or load a DEG project before importing video') + raise ValueError("create or load a DEG project before importing video") options = QFileDialog.Options() - filestring = 'Label file (*.csv)' - labelfile, _ = QFileDialog.getOpenFileName(self, - "Click on labels to import", - data_dir, - filestring, - options=options) + filestring = "Label file (*.csv)" + labelfile, _ = QFileDialog.getOpenFileName( + self, "Click on labels to import", data_dir, filestring, options=options + ) if projects.is_deg_file(labelfile): - raise ValueError('Don' 't use this to open labels: use to import non-DeepEthogram labels') - filestring = 'VideoReader files (*.h5 *.avi *.mp4)' - videofile, _ = QFileDialog.getOpenFileName(self, - "Click on corresponding video file", - data_dir, - filestring, - options=options) + raise ValueError("Don" "t use this to open labels: use to import non-DeepEthogram labels") + filestring = "VideoReader files (*.h5 *.avi *.mp4)" + videofile, _ = QFileDialog.getOpenFileName( + self, "Click on corresponding video file", data_dir, filestring, options=options + ) if not projects.is_deg_file(videofile): - raise ValueError('Please select the already-imported video file that corresponds to the label file.') + raise ValueError("Please select the already-imported video file that corresponds to the label file.") label_dst = projects.add_label_to_project(labelfile, videofile) self.import_labelfile(label_dst) def import_outputfile(self, outputfile: Union[str, os.PathLike], latent_name=None, first_time: bool = False): - if outputfile is None: self.initialize_prediction() return try: - outputs = projects.import_outputfile(self.cfg.project.path, - outputfile, - class_names=OmegaConf.to_container(self.cfg.project.class_names), - latent_name=latent_name) + outputs = projects.import_outputfile( + self.cfg.project.path, + outputfile, + class_names=OmegaConf.to_container(self.cfg.project.class_names), + latent_name=latent_name, + ) except ValueError as e: log.exception(e) - print('If you got a broadcasting error: did you add or remove behaviors and not re-train?') + print("If you got a broadcasting error: did you add or remove behaviors and not re-train?") self.initialize_prediction() return probabilities, thresholds, latent_name, keys = outputs @@ -867,12 +883,12 @@ def import_outputfile(self, outputfile: Union[str, os.PathLike], latent_name=Non self.thresholds = thresholds opacity = estimated_labels.copy().T.astype(float) - log.debug('estimated labels: {}'.format(opacity)) + log.debug("estimated labels: {}".format(opacity)) opacity[opacity == 0] = self.cfg.prediction_opacity - log.debug('opacity array: {}'.format(opacity)) + log.debug("opacity array: {}".format(opacity)) if np.any(probabilities > 1): - log.warning('Probabilities > 1 found, clamping...') + log.warning("Probabilities > 1 found, clamping...") probabilities = probabilities.clip(min=0, max=1.0) # import pdb @@ -882,10 +898,10 @@ def import_outputfile(self, outputfile: Union[str, os.PathLike], latent_name=Non self.initialize_prediction(prediction_array=probabilities, opacity=opacity) self.ui.importPredictions.setEnabled(True) self.ui.exportPredictions.setEnabled(True) - log.info('CHANGING LATENT NAME TO : {}'.format(latent_name)) + log.info("CHANGING LATENT NAME TO : {}".format(latent_name)) self.latent_name = latent_name - log.debug('keys: {}'.format(keys)) + log.debug("keys: {}".format(keys)) if first_time: self.ui.predictionsCombo.blockSignals(True) self.ui.predictionsCombo.clear() @@ -904,26 +920,28 @@ def export_predictions(self): print(df) print(df.sum(axis=0)) fname, _ = os.path.splitext(self.videofile) - prediction_fname = fname + '_predictions.csv' + prediction_fname = fname + "_predictions.csv" df.to_csv(prediction_fname) def change_predictions(self, new_text): - log.debug('change predictions called with text: {}'.format(new_text)) - log.debug('current latent name: {}'.format(self.latent_name)) - if not hasattr(self, 'outputfile') or new_text is None: + log.debug("change predictions called with text: {}".format(new_text)) + log.debug("current latent name: {}".format(self.latent_name)) + if not hasattr(self, "outputfile") or new_text is None: return if self.latent_name != new_text: - log.debug('not equal found: {}, {}'.format(self.latent_name, new_text)) + log.debug("not equal found: {}, {}".format(self.latent_name, new_text)) # self.import_outputfile(self.outputfile, latent_name=new_text) # log.warning('prediction import not implemented') def import_predictions_as_labels(self): - if not hasattr(self, 'estimated_labels'): - raise ValueError('Cannot import predictions before an outputfile has been imported.\n' - 'Run inference on the feature extractors or sequence models first.') + if not hasattr(self, "estimated_labels"): + raise ValueError( + "Cannot import predictions before an outputfile has been imported.\n" + "Run inference on the feature extractors or sequence models first." + ) should_overwrite_all = overwrite_or_not(self) - log.debug('should_overwrite_all: {}'.format(should_overwrite_all)) + log.debug("should_overwrite_all: {}".format(should_overwrite_all)) current_label = self.ui.labels.label.array.copy() changed = self.ui.labels.label.changed.copy() @@ -948,17 +966,17 @@ def open_avi_browser(self): if self.data_path is not None: data_dir = self.data_path else: - raise ValueError('create or load a DEG project before loading video') + raise ValueError("create or load a DEG project before loading video") options = QFileDialog.Options() - filestring = 'VideoReader files (*.h5 *.avi *.mp4 *.png *.jpg *.mov)' + filestring = "VideoReader files (*.h5 *.avi *.mp4 *.png *.jpg *.mov)" prompt = "Click on video to open. If a directory full of images, click any image" filename, _ = QFileDialog.getOpenFileName(self, prompt, data_dir, filestring, options=options) if len(filename) == 0 or not os.path.isfile(filename): - raise ValueError('Could not open file: {}'.format(filename)) + raise ValueError("Could not open file: {}".format(filename)) ext = os.path.splitext(filename)[1] - if ext in ['.png', '.jpg']: + if ext in [".png", ".jpg"]: filename = os.path.dirname(filename) assert os.path.isdir(filename) @@ -969,11 +987,11 @@ def add_multiple_videos(self): if self.data_path is not None: data_dir = self.data_path else: - raise ValueError('create or load a DEG project before loading video') + raise ValueError("create or load a DEG project before loading video") # https://stackoverflow.com/questions/38252419/how-to-get-qfiledialog-to-select-and-return-multiple-folders options = QFileDialog.Options() - filestring = 'VideoReader files (*.h5 *.avi *.mp4 *.png *.jpg *.mov)' + filestring = "VideoReader files (*.h5 *.avi *.mp4 *.png *.jpg *.mov)" prompt = "Click on video to open. If a directory full of images, click any image" filenames, _ = QFileDialog.getOpenFileNames(self, prompt, data_dir, filestring, options=options) if len(filenames) == 0: @@ -989,13 +1007,12 @@ def add_multiple_videos(self): # self.initialize_video(filename) def initialize_project(self, directory: Union[str, os.PathLike]): - if len(directory) == 0: return - filename = os.path.join(directory, 'project_config.yaml') + filename = os.path.join(directory, "project_config.yaml") if len(filename) == 0 or not os.path.isfile(filename): - log.error('something wrong with loading yaml file: {}'.format(filename)) + log.error("something wrong with loading yaml file: {}".format(filename)) return # project_dict = projects.load_config(filename) @@ -1012,13 +1029,13 @@ def initialize_project(self, directory: Union[str, os.PathLike]): # self.project_config['project']['model_path']) # overwrite cfg passed at command line now that we know the project path. still includes command line arguments - self.cfg = configuration.make_config(directory, ['config', 'gui', 'postprocessor'], run_type='gui', model=None) - log.info('cwd: {}'.format(os.getcwd())) + self.cfg = configuration.make_config(directory, ["config", "gui", "postprocessor"], run_type="gui", model=None) + log.info("cwd: {}".format(os.getcwd())) self.cfg = projects.convert_config_paths_to_absolute(self.cfg, raise_error_if_pretrained_missing=False) - log.info('cwd: {}'.format(os.getcwd())) + log.info("cwd: {}".format(os.getcwd())) self.cfg = projects.setup_run(self.cfg, raise_error_if_pretrained_missing=False) - log.info('loaded project configuration: {}'.format(OmegaConf.to_yaml(self.cfg))) - log.info('cwd: {}'.format(os.getcwd())) + log.info("loaded project configuration: {}".format(OmegaConf.to_yaml(self.cfg))) + log.info("cwd: {}".format(os.getcwd())) # for convenience self.data_path = self.cfg.project.data_path self.model_path = self.cfg.project.model_path @@ -1034,10 +1051,10 @@ def initialize_project(self, directory: Union[str, os.PathLike]): last_record = list(records.values())[-1] else: last_record = self.unfinalized[0] - if last_record['rgb'] is not None: - self.initialize_video(last_record['rgb']) - if last_record['label'] is not None: - self.import_labelfile(last_record['label']) + if last_record["rgb"] is not None: + self.initialize_video(last_record["rgb"]) + if last_record["label"] is not None: + self.import_labelfile(last_record["label"]) # if last_record['output'] is not None: # self.import_outputfile(last_record['output']) @@ -1046,35 +1063,36 @@ def initialize_project(self, directory: Union[str, os.PathLike]): def load_project(self): # options = QFileDialog.Options() - directory = QFileDialog.getExistingDirectory(self, "Open your deepethogram directory (containing project " - "config)") + directory = QFileDialog.getExistingDirectory( + self, "Open your deepethogram directory (containing project " "config)" + ) self.initialize_project(directory) # pprint.pprint(self.trained_model_dict) def get_default_archs(self): # TODO: replace this default logic with hydra 1.0 - if 'preset' in self.cfg: + if "preset" in self.cfg: preset = self.cfg.preset else: - preset = 'deg_f' - default_archs = projects.load_default('preset/{}'.format(preset)) - seq_default = projects.load_default('model/sequence') - default_archs['sequence'] = {'arch': seq_default['sequence']['arch']} - - if 'feature_extractor' in self.cfg and self.cfg.feature_extractor.arch is not None: - default_archs['feature_extractor']['arch'] = self.cfg.feature_extractor.arch - if 'flow_generator' in self.cfg and self.cfg.flow_generator.arch is not None: - default_archs['flow_generator']['arch'] = self.cfg.flow_generator.arch - if 'sequence' in self.cfg and 'arch' in self.cfg.sequence and self.cfg.sequence.arch is not None: - default_archs['sequence']['arch'] = self.cfg.sequence.arch + preset = "deg_f" + default_archs = projects.load_default("preset/{}".format(preset)) + seq_default = projects.load_default("model/sequence") + default_archs["sequence"] = {"arch": seq_default["sequence"]["arch"]} + + if "feature_extractor" in self.cfg and self.cfg.feature_extractor.arch is not None: + default_archs["feature_extractor"]["arch"] = self.cfg.feature_extractor.arch + if "flow_generator" in self.cfg and self.cfg.flow_generator.arch is not None: + default_archs["flow_generator"]["arch"] = self.cfg.flow_generator.arch + if "sequence" in self.cfg and "arch" in self.cfg.sequence and self.cfg.sequence.arch is not None: + default_archs["sequence"]["arch"] = self.cfg.sequence.arch self.default_archs = default_archs - log.debug('default archs: {}'.format(default_archs)) + log.debug("default archs: {}".format(default_archs)) def get_trained_models(self): trained_models = projects.get_weights_from_model_path(self.model_path) self.get_default_archs() - log.debug('trained models found: {}'.format(trained_models)) + log.debug("trained models found: {}".format(trained_models)) trained_dict = {} self.trained_model_dict = trained_dict @@ -1082,35 +1100,35 @@ def get_trained_models(self): trained_dict[model] = {} # for sequence models, we can train with no pre-trained weights - if model == 'sequence': - trained_dict[model][''] = None + if model == "sequence": + trained_dict[model][""] = None - arch = self.default_archs[model]['arch'] + arch = self.default_archs[model]["arch"] if arch not in archs.keys(): continue - trained_dict[model]['no pretrained weights'] = None + trained_dict[model]["no pretrained weights"] = None for run in trained_models[model][arch]: key = os.path.basename(os.path.dirname(run)) - if key == 'lightning_checkpoints': + if key == "lightning_checkpoints": key = os.path.basename(os.path.dirname(os.path.dirname(run))) trained_dict[model][key] = run - log.debug('trained model dict: {}'.format(self.trained_model_dict)) - models = self.trained_model_dict['flow_generator'] + log.debug("trained model dict: {}".format(self.trained_model_dict)) + models = self.trained_model_dict["flow_generator"] if len(models) > 0: self.ui.flowSelector.clear() for key in models.keys(): self.ui.flowSelector.addItem(key) self.ui.flowSelector.setCurrentIndex(len(models) - 1) - models = self.trained_model_dict['feature_extractor'] + models = self.trained_model_dict["feature_extractor"] if len(models) > 0: self.ui.feSelector.clear() for key in models: self.ui.feSelector.addItem(key) self.ui.feSelector.setCurrentIndex(len(models) - 1) - models = self.trained_model_dict['sequence'] + models = self.trained_model_dict["sequence"] if len(models) > 0: self.ui.sequenceSelector.clear() for key in models: @@ -1124,40 +1142,43 @@ def get_selected_models(self, model_type: str = None): fe_model = None seq_model = None - models = {'flow_generator': flow_model, 'feature_extractor': fe_model, 'sequence': seq_model} + models = {"flow_generator": flow_model, "feature_extractor": fe_model, "sequence": seq_model} - if not hasattr(self, 'trained_model_dict'): + if not hasattr(self, "trained_model_dict"): if model_type is not None: - log.warning('No {} weights found. Please download using the link on GitHub: {}'.format( - model_type, 'https://github.com/jbohnslav/deepethogram')) + log.warning( + "No {} weights found. Please download using the link on GitHub: {}".format( + model_type, "https://github.com/jbohnslav/deepethogram" + ) + ) return models log.info(self.trained_model_dict) flow_text = self.ui.flowSelector.currentText() - if flow_text in list(self.trained_model_dict['flow_generator'].keys()): - models['flow_generator'] = self.trained_model_dict['flow_generator'][flow_text] + if flow_text in list(self.trained_model_dict["flow_generator"].keys()): + models["flow_generator"] = self.trained_model_dict["flow_generator"][flow_text] fe_text = self.ui.feSelector.currentText() - if fe_text in self.trained_model_dict['feature_extractor'].keys(): - models['feature_extractor'] = self.trained_model_dict['feature_extractor'][fe_text] + if fe_text in self.trained_model_dict["feature_extractor"].keys(): + models["feature_extractor"] = self.trained_model_dict["feature_extractor"][fe_text] seq_text = self.ui.sequenceSelector.currentText() - if seq_text in self.trained_model_dict['sequence'].keys(): - models['sequence'] = self.trained_model_dict['sequence'][seq_text] + if seq_text in self.trained_model_dict["sequence"].keys(): + models["sequence"] = self.trained_model_dict["sequence"][seq_text] return models def update_frame(self, n): self.ui.videoPlayer.videoView.update_frame(n) def move_n_frames(self, n): - if not hasattr(self, 'vid'): + if not hasattr(self, "vid"): return x = self.ui.videoPlayer.videoView.current_fnum self.ui.videoPlayer.videoView.update_frame(x + n) def _make_dataframe(self): - if not hasattr(self, 'cfg'): - raise ValueError('attempted to save dataframe without initializing or opening a project') + if not hasattr(self, "cfg"): + raise ValueError("attempted to save dataframe without initializing or opening a project") label = np.copy(self.ui.labels.label.array).astype(np.int16) changed = np.copy(self.ui.labels.label.changed).astype(bool) n_behaviors = label.shape[1] @@ -1172,7 +1193,7 @@ def _make_dataframe(self): return df def check_saved(self): - if not hasattr(self.ui, 'labels'): + if not hasattr(self.ui, "labels"): return True return self.ui.labels.label.saved @@ -1180,29 +1201,29 @@ def closeEvent(self, event, *args, **kwargs): super(QMainWindow, self).closeEvent(event, *args, **kwargs) # https://stackoverflow.com/questions/1414781/prompt-on-exit-in-pyqt-application if not self.saved: - message = 'You have unsaved changes. Are you sure you want to quit?' + message = "You have unsaved changes. Are you sure you want to quit?" if simple_popup_question(self, message): event.accept() else: event.ignore() return - if hasattr(self, 'training_pipe'): - message = 'If you quit, training will be stopped. Are you sure you want to quit?' + if hasattr(self, "training_pipe"): + message = "If you quit, training will be stopped. Are you sure you want to quit?" if simple_popup_question(self, message): event.accept() else: event.ignore() return - if hasattr(self, 'inference_pipe'): - message = 'If you quit, inference will be stopped. Are you sure you want to quit?' + if hasattr(self, "inference_pipe"): + message = "If you quit, inference will be stopped. Are you sure you want to quit?" if simple_popup_question(self, message): event.accept() else: event.ignore() return - if hasattr(self, 'vid'): + if hasattr(self, "vid"): self.vid.close() @Slot(bool) @@ -1244,8 +1265,8 @@ def set_style(app): def setup_gui_cfg(): - config_list = ['config', 'gui'] - run_type = 'gui' + config_list = ["config", "gui"] + run_type = "gui" model = None project_path = projects.get_project_path_from_cl(sys.argv, error_if_not_found=False) @@ -1253,8 +1274,8 @@ def setup_gui_cfg(): cfg = configuration.make_config(project_path, config_list, run_type, model, use_command_line=True) else: command_line_cfg = OmegaConf.from_cli() - if 'preset' in command_line_cfg: - config_list.append('preset/' + command_line_cfg.preset) + if "preset" in command_line_cfg: + config_list.append("preset/" + command_line_cfg.preset) cfgs = [configuration.load_config_by_name(i) for i in config_list] cfg = OmegaConf.merge(*cfgs, command_line_cfg) try: @@ -1264,13 +1285,12 @@ def setup_gui_cfg(): # OmegaConf.set_struct(cfg, False) - log.info('CWD: {}'.format(os.getcwd())) - log.info('Configuration used: {}'.format(OmegaConf.to_yaml(cfg))) + log.info("CWD: {}".format(os.getcwd())) + log.info("Configuration used: {}".format(OmegaConf.to_yaml(cfg))) return cfg def run() -> None: - app = QtWidgets.QApplication(sys.argv) app = set_style(app) @@ -1290,5 +1310,5 @@ def entry() -> None: run() -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/deepethogram/gui/mainwindow.py b/deepethogram/gui/mainwindow.py index 584dd3d..442fe46 100644 --- a/deepethogram/gui/mainwindow.py +++ b/deepethogram/gui/mainwindow.py @@ -10,6 +10,7 @@ from PySide2 import QtCore, QtWidgets + class Ui_MainWindow(object): def setupUi(self, MainWindow): MainWindow.setObjectName("MainWindow") @@ -258,10 +259,14 @@ def retranslateUi(self, MainWindow): self.sequence_train.setText(QtWidgets.QApplication.translate("MainWindow", "Train", None, -1)) self.sequence_infer.setText(QtWidgets.QApplication.translate("MainWindow", "Infer", None, -1)) self.labelBox.setTitle(QtWidgets.QApplication.translate("MainWindow", "Labels", None, -1)) - self.importPredictions.setText(QtWidgets.QApplication.translate("MainWindow", "Import predictions as labels", None, -1)) + self.importPredictions.setText( + QtWidgets.QApplication.translate("MainWindow", "Import predictions as labels", None, -1) + ) self.finalize_labels.setText(QtWidgets.QApplication.translate("MainWindow", "Finalize Labels", None, -1)) self.groupBox_4.setTitle(QtWidgets.QApplication.translate("MainWindow", "Predictions", None, -1)) - self.exportPredictions.setText(QtWidgets.QApplication.translate("MainWindow", "Export predictions to CSV", None, -1)) + self.exportPredictions.setText( + QtWidgets.QApplication.translate("MainWindow", "Export predictions to CSV", None, -1) + ) self.label_4.setText(QtWidgets.QApplication.translate("MainWindow", "Labels", None, -1)) self.label_5.setText(QtWidgets.QApplication.translate("MainWindow", "Predictions", None, -1)) self.menuDeepEthogram.setTitle(QtWidgets.QApplication.translate("MainWindow", "File", None, -1)) @@ -270,7 +275,9 @@ def retranslateUi(self, MainWindow): self.menuImport.setTitle(QtWidgets.QApplication.translate("MainWindow", "Import", None, -1)) self.menuBatch.setTitle(QtWidgets.QApplication.translate("MainWindow", "Batch", None, -1)) self.actionNew_Project.setText(QtWidgets.QApplication.translate("MainWindow", "New Project", None, -1)) - self.actionSave_Project.setText(QtWidgets.QApplication.translate("MainWindow", "Save Project (ctrl+s)", None, -1)) + self.actionSave_Project.setText( + QtWidgets.QApplication.translate("MainWindow", "Save Project (ctrl+s)", None, -1) + ) self.actionAdd.setText(QtWidgets.QApplication.translate("MainWindow", "Add", None, -1)) self.actionRemove.setText(QtWidgets.QApplication.translate("MainWindow", "Remove", None, -1)) self.actionStyle.setText(QtWidgets.QApplication.translate("MainWindow", "Style", None, -1)) @@ -281,8 +288,11 @@ def retranslateUi(self, MainWindow): self.actionOpen_Project.setText(QtWidgets.QApplication.translate("MainWindow", "Open Project", None, -1)) self.importLabels.setText(QtWidgets.QApplication.translate("MainWindow", "Labels", None, -1)) self.actionAdd_videos.setText(QtWidgets.QApplication.translate("MainWindow", "Add videos", None, -1)) - self.classifierInference.setText(QtWidgets.QApplication.translate("MainWindow", "Feature extractor inference + sequence inference", None, -1)) + self.classifierInference.setText( + QtWidgets.QApplication.translate("MainWindow", "Feature extractor inference + sequence inference", None, -1) + ) self.actionOvernight.setText(QtWidgets.QApplication.translate("MainWindow", "Overnight", None, -1)) self.actionAdd_multiple.setText(QtWidgets.QApplication.translate("MainWindow", "Add multiple", None, -1)) + from deepethogram.gui.custom_widgets import LabelImg, VideoPlayer diff --git a/deepethogram/gui/menus_and_popups.py b/deepethogram/gui/menus_and_popups.py index 2ae070c..38f19ab 100644 --- a/deepethogram/gui/menus_and_popups.py +++ b/deepethogram/gui/menus_and_popups.py @@ -6,19 +6,25 @@ def simple_popup_question(parent, message: str): # message = 'You have unsaved changes. Are you sure you want to quit?' - reply = QtWidgets.QMessageBox.question(parent, 'Message', - message, QtWidgets.QMessageBox.Yes, QtWidgets.QMessageBox.No) + reply = QtWidgets.QMessageBox.question( + parent, "Message", message, QtWidgets.QMessageBox.Yes, QtWidgets.QMessageBox.No + ) return reply == QtWidgets.QMessageBox.Yes + # https://stackoverflow.com/questions/15682665/how-to-add-custom-button-to-a-qmessagebox-in-pyqt4 + def overwrite_or_not(parent): msgBox = QtWidgets.QMessageBox(parent) msgBox.setIcon(QtWidgets.QMessageBox.Question) - msgBox.setText('Do you want to overwrite your labels with these predictions, or only import the predictions' - ' for frames you haven''t labeled?') - overwrite = msgBox.addButton('Overwrite', QtWidgets.QMessageBox.YesRole) - unlabeled = msgBox.addButton('Only import unlabeled', QtWidgets.QMessageBox.NoRole) + msgBox.setText( + "Do you want to overwrite your labels with these predictions, or only import the predictions" + " for frames you haven" + "t labeled?" + ) + overwrite = msgBox.addButton("Overwrite", QtWidgets.QMessageBox.YesRole) + unlabeled = msgBox.addButton("Only import unlabeled", QtWidgets.QMessageBox.NoRole) msgBox.exec_() if msgBox.clickedButton() is overwrite: return True @@ -27,15 +33,19 @@ def overwrite_or_not(parent): else: return + class OverwriteOrNot(QtWidgets.QDialog): def __init__(self, parent=None): super().__init__(parent) msgBox = QtWidgets.QMessageBox() - msgBox.setText('Do you want to overwrite your labels with these predictions, or only import the predictions' - ' for frames you haven''t labeled?') - msgBox.addButton(QtWidgets.QPushButton('Overwrite'), QtWidgets.QMessageBox.YesRole) - msgBox.addButton(QtWidgets.QPushButton('Only import unlabeled'), QtWidgets.QMessageBox.NoRole) + msgBox.setText( + "Do you want to overwrite your labels with these predictions, or only import the predictions" + " for frames you haven" + "t labeled?" + ) + msgBox.addButton(QtWidgets.QPushButton("Overwrite"), QtWidgets.QMessageBox.YesRole) + msgBox.addButton(QtWidgets.QPushButton("Only import unlabeled"), QtWidgets.QMessageBox.NoRole) msgBox.exec_() @@ -43,21 +53,22 @@ class CreateProject(QtWidgets.QDialog): def __init__(self, parent=None): super().__init__(parent) - string = 'Pick directory for your project. SHOULD BE ON YOUR FASTEST HARD DRIVE. VIDEOS WILL BE COPIED HERE' + string = "Pick directory for your project. SHOULD BE ON YOUR FASTEST HARD DRIVE. VIDEOS WILL BE COPIED HERE" project_directory = QtWidgets.QFileDialog.getExistingDirectory(self, string) if len(project_directory) == 0: - warnings.warn('Please choose a directory') + warnings.warn("Please choose a directory") return project_directory = str(pathlib.Path(project_directory).resolve()) self.project_directory = project_directory - self.project_name_default = 'Project Name' + self.project_name_default = "Project Name" self.project_box = QtWidgets.QLineEdit(self.project_name_default) - self.label_default_string = 'Name of person labeling' + self.label_default_string = "Name of person labeling" self.labeler_box = QtWidgets.QLineEdit(self.label_default_string) # self.labeler_box. - self.behavior_default_string = 'List of behaviors, e.g. \"walk,scratch,itch\". Do not include none,other,' \ - 'background,etc ' + self.behavior_default_string = ( + 'List of behaviors, e.g. "walk,scratch,itch". Do not include none,other,' "background,etc " + ) self.behaviors_box = QtWidgets.QLineEdit(self.behavior_default_string) # self.finish_button = QPushButton('Ok') # self.cancel_button = QPushButton('Cancel') @@ -73,14 +84,13 @@ def __init__(self, parent=None): layout.addWidget(button_box) self.setLayout(layout) # win = QtWidgets.QWidget() - self.setWindowTitle('Create project') + self.setWindowTitle("Create project") self.resize(800, 400) self.show() # modified from https://pythonspot.com/pyqt5-form-layout/ class ShouldRunInference(QtWidgets.QDialog): - def __init__(self, record_keys: list, should_start_checked: list): super(ShouldRunInference, self).__init__() @@ -114,7 +124,7 @@ def createFormGroupBox(self, record_keys: list, should_start_checked: list): for row, (record, check) in enumerate(zip(record_keys, should_start_checked)): button = QtWidgets.QCheckBox(self) button.setChecked(check) - text = QtWidgets.QLabel(record + ':') + text = QtWidgets.QLabel(record + ":") layout.addWidget(text, row, 0) layout.addWidget(button, row, 1) self.buttons.append(button) @@ -123,9 +133,8 @@ def createFormGroupBox(self, record_keys: list, should_start_checked: list): # set the scroll area's widget to be the container self.scrollArea.setWidget(self.scrollWidget) - def get_outputs(self): - if not hasattr(self, 'buttons'): + if not hasattr(self, "buttons"): return None answers = [] for button in self.buttons: @@ -133,13 +142,12 @@ def get_outputs(self): return answers -if __name__ == '__main__': +if __name__ == "__main__": app = QtWidgets.QApplication([]) num = 50 - form = ShouldRunInference(['M134_20141203_v001', - 'M134_20141203_v002', - 'M134_20141203_v004']*num, - [True, True, False]*num) + form = ShouldRunInference( + ["M134_20141203_v001", "M134_20141203_v002", "M134_20141203_v004"] * num, [True, True, False] * num + ) ret = form.exec_() if ret: print(form.get_outputs()) diff --git a/deepethogram/losses.py b/deepethogram/losses.py index 7adb16c..6e81f36 100644 --- a/deepethogram/losses.py +++ b/deepethogram/losses.py @@ -9,6 +9,7 @@ log = logging.getLogger(__name__) + def should_decay_parameter(name: str, param: torch.Tensor) -> bool: """Determines if L2 (or L2-SP) decay should be applied to parameter @@ -16,7 +17,7 @@ def should_decay_parameter(name: str, param: torch.Tensor) -> bool: Helpful source: https://github.com/rwightman/pytorch-image-models/blob/198f6ea0f3dae13f041f3ea5880dd79089b60d61/timm/optim/optim_factory.py - + Parameters ---------- name : str @@ -29,16 +30,17 @@ def should_decay_parameter(name: str, param: torch.Tensor) -> bool: bool Whether or not to decay """ - + if not param.requires_grad: return False - elif 'batchnorm' in name.lower() or 'bn' in name.lower() or 'bias' in name.lower(): + elif "batchnorm" in name.lower() or "bn" in name.lower() or "bias" in name.lower(): return False elif param.ndim == 1: return False else: return True - + + def get_keys_to_decay(model: nn.Module) -> list: """Returns list of parameter keys in a nn.Module that should be decayed @@ -59,31 +61,32 @@ def get_keys_to_decay(model: nn.Module) -> list: class L2(nn.Module): - """L2 regularization - """ + """L2 regularization""" + def __init__(self, model: nn.Module, alpha: float): super().__init__() - + self.alpha = alpha - self.keys = get_keys_to_decay(model) - + self.keys = get_keys_to_decay(model) + def forward(self, model): # https://discuss.pytorch.org/t/how-does-one-implement-weight-regularization-l1-or-l2-manually-without-optimum/7951 # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch # note that soumith's answer is wrong because it uses W.norm, which takes the square root - l2_loss = 0 # torch.tensor(0., requires_grad=True) + l2_loss = 0 # torch.tensor(0., requires_grad=True) for key, param in model.named_parameters(): if key in self.keys: - l2_loss += param.pow(2).sum()*0.5 - - return l2_loss*self.alpha - + l2_loss += param.pow(2).sum() * 0.5 + + return l2_loss * self.alpha + + class L2_SP(nn.Module): - """L2_SP normalization; weight decay towards a pretrained state, instead of towards 0. - + """L2_SP normalization; weight decay towards a pretrained state, instead of towards 0. + https://arxiv.org/abs/1802.01483 @misc{li2018explicit, - title={Explicit Inductive Bias for Transfer Learning with Convolutional Networks}, + title={Explicit Inductive Bias for Transfer Learning with Convolutional Networks}, author={Xuhong Li and Yves Grandvalet and Franck Davoine}, year={2018}, eprint={1802.01483}, @@ -91,33 +94,34 @@ class L2_SP(nn.Module): primaryClass={cs.LG} } """ + def __init__(self, model: nn.Module, path_to_pretrained_weights, alpha: float, beta: float): - # - + # + super().__init__() - + self.alpha = alpha self.beta = beta # assert cfg.train.regularization.style == 'l2_sp' - + assert os.path.isfile(path_to_pretrained_weights) - state = torch.load(path_to_pretrained_weights, map_location='cpu') - - pretrained_state = state['state_dict'] - + state = torch.load(path_to_pretrained_weights, map_location="cpu") + + pretrained_state = state["state_dict"] + self.pretrained_keys, self.new_keys = self.get_keys(model, pretrained_state) - - log.debug('pretrained keys for L2SP: {}'.format(self.pretrained_keys)) - log.debug('Novel keys for L2SP: {}'.format(self.new_keys)) - + + log.debug("pretrained keys for L2SP: {}".format(self.pretrained_keys)) + log.debug("Novel keys for L2SP: {}".format(self.new_keys)) + for key in self.pretrained_keys: # can't register a buffer with dots in the keys self.register_buffer(self.dots_to_underscores(key), pretrained_state[key]) - + @staticmethod def dots_to_underscores(key): - return key.replace('.', '_') - + return key.replace(".", "_") + def get_keys(self, model: nn.Module, pretrained_state): """Gets parameter names that are in both current model and pretrained weights, and unique keys to our model @@ -150,11 +154,11 @@ def get_keys(self, model: nn.Module, pretrained_state): else: not_in_pretrained.append(key) return is_in_pretrained, not_in_pretrained - + def forward(self, model): towards_pretrained, towards_0 = 0, 0 - - # not passing keep_vars will detach the tensor from the computation graph, resulting in no effect on the + + # not passing keep_vars will detach the tensor from the computation graph, resulting in no effect on the # training but also no error messages model_state = model.state_dict(keep_vars=True) pretrained_state = self.state_dict(keep_vars=True) @@ -162,14 +166,14 @@ def forward(self, model): for key in self.pretrained_keys: model_param = model_state[key] pretrained_param = pretrained_state[self.dots_to_underscores(key)] - towards_pretrained += (model_param - pretrained_param).pow(2).sum()*0.5 + towards_pretrained += (model_param - pretrained_param).pow(2).sum() * 0.5 for key in self.new_keys: model_param = model_state[key] - towards_0 += model_param.pow(2).sum()*0.5 - + towards_0 += model_param.pow(2).sum() * 0.5 + if towards_pretrained != towards_pretrained or towards_0 != towards_0: - msg = 'invalid loss in L2-SP: towards pretrained: {} towards 0: {}'.format(towards_pretrained, towards_0) + msg = "invalid loss in L2-SP: towards pretrained: {} towards 0: {}".format(towards_pretrained, towards_0) raise ValueError(msg) # alternate method. same result, ~50% slower # towards_pretrained, towards_0 = 0, 0 @@ -180,35 +184,36 @@ def forward(self, model): # towards_pretrained += (param - pretrained_param).pow(2).sum()*0.5 # elif key in self.new_keys: # towards_0 += param.pow(2).sum()*0.5 - - return towards_pretrained*self.alpha + towards_0*self.beta + + return towards_pretrained * self.alpha + towards_0 * self.beta + def get_regularization_loss(cfg: DictConfig, model): - if cfg.train.regularization.style == 'l2': - log.info('Regularization: L2. alpha: {} '.format(cfg.train.regularization.alpha)) + if cfg.train.regularization.style == "l2": + log.info("Regularization: L2. alpha: {} ".format(cfg.train.regularization.alpha)) regularization_criterion = L2(model, cfg.train.regularization.alpha) - elif cfg.train.regularization.style == 'l2_sp': + elif cfg.train.regularization.style == "l2_sp": pretrained_dir = cfg.project.pretrained_path assert os.path.isdir(pretrained_dir) weights = projects.get_weights_from_model_path(pretrained_dir) pretrained_file = weights[cfg.run.model][cfg[cfg.run.model].arch] - + if len(pretrained_file) == 0: - log.warning('No pretrained file found. Regularization: L2. alpha={}'.format( - cfg.train.regularization.beta - )) + log.warning("No pretrained file found. Regularization: L2. alpha={}".format(cfg.train.regularization.beta)) regularization_criterion = L2(model, cfg.train.regularization.beta) elif len(pretrained_file) == 1: - pretrained_file = pretrained_file[0] - log.info('Regularization: L2_SP. Pretrained file: {} alpha: {} beta: {}'.format( - pretrained_file, cfg.train.regularization.alpha, cfg.train.regularization.beta - )) - regularization_criterion = L2_SP(model, pretrained_file, cfg.train.regularization.alpha, - cfg.train.regularization.beta) + log.info( + "Regularization: L2_SP. Pretrained file: {} alpha: {} beta: {}".format( + pretrained_file, cfg.train.regularization.alpha, cfg.train.regularization.beta + ) + ) + regularization_criterion = L2_SP( + model, pretrained_file, cfg.train.regularization.alpha, cfg.train.regularization.beta + ) else: - raise ValueError('unsure what weights to use: {}'.format(pretrained_file)) + raise ValueError("unsure what weights to use: {}".format(pretrained_file)) else: raise NotImplementedError - - return regularization_criterion \ No newline at end of file + + return regularization_criterion diff --git a/deepethogram/metrics.py b/deepethogram/metrics.py index f94a929..655176e 100644 --- a/deepethogram/metrics.py +++ b/deepethogram/metrics.py @@ -20,14 +20,14 @@ # using multiprocessing on slurm causes a termination signal try: - slurm_job_id = os.environ['SLURM_JOB_ID'] + slurm_job_id = os.environ["SLURM_JOB_ID"] slurm = True except: slurm = False def index_to_onehot(index: np.ndarray, n_classes: int) -> np.ndarray: - """ Convert an array if indices to one-hot vectors. + """Convert an array if indices to one-hot vectors. Parameters ---------- @@ -56,7 +56,7 @@ def index_to_onehot(index: np.ndarray, n_classes: int) -> np.ndarray: def hardmax(probabilities: np.ndarray) -> np.ndarray: - """ Convert probability array to prediction by converting the max of each row to 1 + """Convert probability array to prediction by converting the max of each row to 1 Parameters ---------- @@ -101,8 +101,8 @@ def onehot_to_index(onehot: np.ndarray) -> np.ndarray: return np.argmax(onehot, axis=1) -def f1(predictions: np.ndarray, labels: np.ndarray, average: str = 'macro') -> np.ndarray: - """ simple wrapper around sklearn.metrics.f1_score +def f1(predictions: np.ndarray, labels: np.ndarray, average: str = "macro") -> np.ndarray: + """simple wrapper around sklearn.metrics.f1_score References ------- @@ -118,8 +118,8 @@ def f1(predictions: np.ndarray, labels: np.ndarray, average: str = 'macro') -> n return F1 -def roc_auc(predictions: np.ndarray, labels: np.ndarray, average: str = 'macro') -> np.ndarray: - """ simple wrapper around sklearn.metrics.roc_auc_score +def roc_auc(predictions: np.ndarray, labels: np.ndarray, average: str = "macro") -> np.ndarray: + """simple wrapper around sklearn.metrics.roc_auc_score References ------- @@ -127,7 +127,7 @@ def roc_auc(predictions: np.ndarray, labels: np.ndarray, average: str = 'macro') .. [2] https://en.wikipedia.org/wiki/Receiver_operating_characteristic """ if predictions.ndim == 1: - raise ValueError('Predictions must be class probabilities before max!') + raise ValueError("Predictions must be class probabilities before max!") if labels.ndim == 1: labels = index_to_onehot(labels, predictions.shape[1]) score = roc_auc_score(labels, predictions, average=average) @@ -135,12 +135,12 @@ def roc_auc(predictions: np.ndarray, labels: np.ndarray, average: str = 'macro') def accuracy(predictions: np.ndarray, labels: np.ndarray): - """ Return the fraction of elements in predictions that are equal to labels """ + """Return the fraction of elements in predictions that are equal to labels""" return np.mean(predictions == labels) def confusion(predictions: np.ndarray, labels: np.ndarray, K: int = None) -> np.ndarray: - """ Computes confusion matrix. Much faster than sklearn.metrics.confusion_matrix for large numbers of predictions + """Computes confusion matrix. Much faster than sklearn.metrics.confusion_matrix for large numbers of predictions Parameters ---------- @@ -189,7 +189,7 @@ def binary_confusion_matrix(predictions, labels) -> np.ndarray: # 2 x 2 cms = np.zeros((2, 2), dtype=int) else: - raise ValueError('unknown input shape: {}'.format(predictions.shape)) + raise ValueError("unknown input shape: {}".format(predictions.shape)) neg_lab = np.logical_not(labels) neg_pred = np.logical_not(predictions) @@ -226,12 +226,9 @@ def confusion_alias(inp): return binary_confusion_matrix(*inp) -def binary_confusion_matrix_parallel(probs_or_preds, - labels, - thresholds=None, - chunk_size: int = 100, - num_workers: int = 4, - parallel_chunk: int = 100): +def binary_confusion_matrix_parallel( + probs_or_preds, labels, thresholds=None, chunk_size: int = 100, num_workers: int = 4, parallel_chunk: int = 100 +): # log.info('num workers binary confusion parallel: {}'.format(num_workers)) if slurm: parallel_chunk = 1 @@ -254,7 +251,7 @@ def binary_confusion_matrix_parallel(probs_or_preds, elif probs_or_preds.ndim == 1: cm = np.zeros((2, 2), dtype=int) else: - raise ValueError('weird shape in probs_or_preds: {}'.format(probs_or_preds.shape)) + raise ValueError("weird shape in probs_or_preds: {}".format(probs_or_preds.shape)) func = confusion_alias # log.info('parallel start') if num_workers > 1: @@ -268,9 +265,10 @@ def binary_confusion_matrix_parallel(probs_or_preds, return cm -def compute_binary_confusion(predictions: np.ndarray, labels: np.ndarray, - thresholds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ compute binary confusion matrices for input probabilities, labels, and thresholds. See confusion """ +def compute_binary_confusion( + predictions: np.ndarray, labels: np.ndarray, thresholds: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """compute binary confusion matrices for input probabilities, labels, and thresholds. See confusion""" estimates = postprocess(predictions, thresholds, valid_bg=False) K = predictions.shape[1] @@ -288,21 +286,21 @@ def compute_binary_confusion(predictions: np.ndarray, labels: np.ndarray, def mean_class_accuracy(predictions, labels): - """ computes the mean of diagonal elements of a confusion matrix """ + """computes the mean of diagonal elements of a confusion matrix""" if predictions.ndim > 1: predictions = onehot_to_index(hardmax(predictions)) if labels.ndim > 1: labels = onehot_to_index(labels) cm = confusion_matrix(labels, predictions) - cm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-9) + cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-9) on_diag = cm[np.where(np.eye(cm.shape[0], dtype=np.uint32))] return on_diag.mean() -def remove_invalid_values_predictions_and_labels(predictions: np.ndarray, labels: np.ndarray, - invalid_value: Union[int, float] = -1) -> \ - Tuple[np.ndarray, np.ndarray]: - """ remove any rows where labels are equal to invalid_value. +def remove_invalid_values_predictions_and_labels( + predictions: np.ndarray, labels: np.ndarray, invalid_value: Union[int, float] = -1 +) -> Tuple[np.ndarray, np.ndarray]: + """remove any rows where labels are equal to invalid_value. Used when (for example) the last sequence in a video is padded to have the proper sequence length. the padded inputs are paired with -1 labels, indicating that loss and metrics should not be applied there @@ -338,18 +336,18 @@ def compute_metrics_by_threshold(probabilities, labels, thresholds, num_workers: auroc = auc_on_array(fp, tp) mAP = auc_on_array(r, p) metrics_by_threshold = { - 'thresholds': thresholds, - 'accuracy': acc, - 'f1': f1, - 'precision': p, - 'recall': r, - 'fbeta_2': fbeta_2, - 'informedness': info, - 'tpr': tp, - 'fpr': fp, - 'auroc': auroc, - 'mAP': mAP, - 'confusion': cm + "thresholds": thresholds, + "accuracy": acc, + "f1": f1, + "precision": p, + "recall": r, + "fbeta_2": fbeta_2, + "informedness": info, + "tpr": tp, + "fpr": fp, + "auroc": auroc, + "mAP": mAP, + "confusion": cm, } return metrics_by_threshold @@ -366,16 +364,15 @@ def fast_auc(y_true, y_prob): nfalse = np.cumsum(1 - y_true) auc = np.cumsum((y_true * nfalse))[-1] # print(auc) - auc /= (nfalse[-1] * (n - nfalse[-1])) + auc /= nfalse[-1] * (n - nfalse[-1]) return auc # @profile -def evaluate_thresholds(probabilities: np.ndarray, - labels: np.ndarray, - thresholds: np.ndarray = None, - num_workers: int = 4) -> Tuple[dict, dict]: - """ Given probabilities and labels, compute a bunch of metrics at each possible threshold value +def evaluate_thresholds( + probabilities: np.ndarray, labels: np.ndarray, thresholds: np.ndarray = None, num_workers: int = 4 +) -> Tuple[dict, dict]: + """Given probabilities and labels, compute a bunch of metrics at each possible threshold value Also computes a number of metrics for which there is a single value for the input predictions / labels, something like the maximum F1 score across thresholds. @@ -401,7 +398,7 @@ def evaluate_thresholds(probabilities: np.ndarray, # log.info('evaluating thresholds. P: {} lab: {} n_workers: {}'.format(probabilities.shape, labels.shape, num_workers)) # log.info('SLURM in metrics file: {}'.format(slurm)) if slurm and num_workers != 1: - warnings.warn('using multiprocessing on slurm can cause issues. setting num_workers to 1') + warnings.warn("using multiprocessing on slurm can cause issues. setting num_workers to 1") num_workers = 1 if thresholds is None: @@ -411,7 +408,7 @@ def evaluate_thresholds(probabilities: np.ndarray, # log.debug('probabilities shape in metrics calc: {}'.format(probabilities.shape)) metrics_by_threshold = {} if probabilities.ndim == 1: - raise ValueError('To calc threshold, predictions must be probabilities, not classes') + raise ValueError("To calc threshold, predictions must be probabilities, not classes") K = probabilities.shape[1] if labels.ndim == 1: labels = index_to_onehot(labels, K) @@ -422,20 +419,20 @@ def evaluate_thresholds(probabilities: np.ndarray, # log.info('first metrics call finished') # log.info('finished computing binary confusion matrices') # optimum threshold: one that maximizes F1 - optimum_indices = np.argmax(metrics_by_threshold['f1'], axis=0) + optimum_indices = np.argmax(metrics_by_threshold["f1"], axis=0) optimum_thresholds = thresholds[optimum_indices] # if the threshold or the F1 is very low, these are erroneous: set to 0.5 - optimum_f1s = metrics_by_threshold['f1'][optimum_indices, range(len(optimum_indices))] + optimum_f1s = metrics_by_threshold["f1"][optimum_indices, range(len(optimum_indices))] optimum_thresholds = remove_low_thresholds(optimum_thresholds, f1s=optimum_f1s) # optimum info: maximizes informedness - optimum_indices_info = np.argmax(metrics_by_threshold['informedness'], axis=0) + optimum_indices_info = np.argmax(metrics_by_threshold["informedness"], axis=0) optimum_thresholds_info = thresholds[optimum_indices_info] - optimum_info = metrics_by_threshold['informedness'][optimum_indices_info, range(len(optimum_indices_info))] + optimum_info = metrics_by_threshold["informedness"][optimum_indices_info, range(len(optimum_indices_info))] optimum_thresholds_info = remove_low_thresholds(optimum_thresholds_info, f1s=optimum_info) - metrics_by_threshold['optimum'] = optimum_thresholds - metrics_by_threshold['optimum_info'] = optimum_thresholds_info + metrics_by_threshold["optimum"] = optimum_thresholds + metrics_by_threshold["optimum_info"] = optimum_thresholds_info # vectorized predictions = probabilities > optimum_thresholds @@ -451,27 +448,25 @@ def evaluate_thresholds(probabilities: np.ndarray, # summing over classes is the same as flattening the array. ugly syntax # TODO: make function that computes metrics from a stack of confusion matrices rather than this none None business # log.info('third metrics call') - overall_metrics = compute_metrics_by_threshold(None, - None, - thresholds=None, - num_workers=num_workers, - cm=metrics_by_class['confusion'].sum(axis=2)) + overall_metrics = compute_metrics_by_threshold( + None, None, thresholds=None, num_workers=num_workers, cm=metrics_by_class["confusion"].sum(axis=2) + ) # log.info('third metrics call ended') epoch_metrics = { - 'accuracy_overall': overall_metrics['accuracy'], - 'accuracy_by_class': metrics_by_class['accuracy'], - 'f1_overall': overall_metrics['f1'], - 'f1_class_mean': metrics_by_class['f1'].mean(), - 'f1_class_mean_nobg': metrics_by_class['f1'][1:].mean(), - 'f1_by_class': metrics_by_class['f1'], - 'binary_confusion': metrics_by_class['confusion'].transpose(2, 0, 1), - 'auroc_by_class': metrics_by_threshold['auroc'], - 'auroc_class_mean': metrics_by_threshold['auroc'].mean(), - 'mAP_by_class': metrics_by_threshold['mAP'], - 'mAP_class_mean': metrics_by_threshold['mAP'].mean(), + "accuracy_overall": overall_metrics["accuracy"], + "accuracy_by_class": metrics_by_class["accuracy"], + "f1_overall": overall_metrics["f1"], + "f1_class_mean": metrics_by_class["f1"].mean(), + "f1_class_mean_nobg": metrics_by_class["f1"][1:].mean(), + "f1_by_class": metrics_by_class["f1"], + "binary_confusion": metrics_by_class["confusion"].transpose(2, 0, 1), + "auroc_by_class": metrics_by_threshold["auroc"], + "auroc_class_mean": metrics_by_threshold["auroc"].mean(), + "mAP_by_class": metrics_by_threshold["mAP"], + "mAP_class_mean": metrics_by_threshold["mAP"].mean(), # to compute these, would need to make confusion matrices on flattened array, which is slow - 'auroc_overall': np.nan, - 'mAP_overall': np.nan + "auroc_overall": np.nan, + "mAP_overall": np.nan, } # it is too much of a pain to increase the speed on roc_auc_score and mAP # try: @@ -496,9 +491,9 @@ def evaluate_thresholds(probabilities: np.ndarray, def compute_tpr_fpr(cm: np.ndarray) -> Tuple[float, float]: - """ compute true positives and false positives from a non-normalized confusion matrix """ + """compute true positives and false positives from a non-normalized confusion matrix""" # normalize so that each are rates - cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-9) + cm_normalized = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-9) fp = cm_normalized[0, 1] tp = cm_normalized[1, 1] return tp, fp @@ -515,7 +510,7 @@ def get_denominator(expression: Union[float, np.ndarray]): def compute_f1(precision: float, recall: float, beta: float = 1.0) -> float: - """ compute f1 if you already have precison and recall. Prevents re-computing confusion matrix, etc """ + """compute f1 if you already have precison and recall. Prevents re-computing confusion matrix, etc""" num = (1 + beta**2) * (precision * recall) denom = get_denominator((beta**2) * precision + recall) @@ -523,7 +518,7 @@ def compute_f1(precision: float, recall: float, beta: float = 1.0) -> float: def compute_precision_recall(cm: np.ndarray) -> Tuple[float, float]: - """ computes precision and recall from a confusion matrix """ + """computes precision and recall from a confusion matrix""" tn = cm[0, 0] tp = cm[1, 1] fp = cm[0, 1] @@ -535,15 +530,15 @@ def compute_precision_recall(cm: np.ndarray) -> Tuple[float, float]: def compute_mean_accuracy(cm: np.ndarray) -> float: - """ compute the mean of true positive rate and true negative rate from a confusion matrix """ - cm = cm.astype('float') / get_denominator(cm.sum(axis=1)[:, np.newaxis]) + """compute the mean of true positive rate and true negative rate from a confusion matrix""" + cm = cm.astype("float") / get_denominator(cm.sum(axis=1)[:, np.newaxis]) tp = cm[1, 1] tn = cm[0, 0] return np.mean([tp, tn]) def compute_informedness(cm: np.ndarray, eps: float = 1e-7) -> float: - """ compute informedness from a confusion matrix. Also known as Youden's J statistic + """compute informedness from a confusion matrix. Also known as Youden's J statistic Parameters ---------- @@ -572,11 +567,11 @@ def compute_informedness(cm: np.ndarray, eps: float = 1e-7) -> float: def postprocess(predictions: np.ndarray, thresholds: np.ndarray, valid_bg: bool = True) -> np.ndarray: - """ Turn probabilities into predictions, with special handling of background. + """Turn probabilities into predictions, with special handling of background. TODO: Should be removed in favor of deepethogram.prostprocessing """ N, n_classes = predictions.shape - assert (len(thresholds) == n_classes) + assert len(thresholds) == n_classes estimates = np.zeros((N, n_classes), dtype=np.int64) for i in range(0, n_classes): @@ -587,11 +582,11 @@ def postprocess(predictions: np.ndarray, thresholds: np.ndarray, valid_bg: bool all_metrics = { - 'accuracy': accuracy, - 'mean_class_accuracy': mean_class_accuracy, - 'f1': f1, - 'roc_auc': roc_auc, - 'confusion': binary_confusion_matrix + "accuracy": accuracy, + "mean_class_accuracy": mean_class_accuracy, + "f1": f1, + "roc_auc": roc_auc, + "confusion": binary_confusion_matrix, } @@ -604,21 +599,20 @@ def list_to_mean(values): else: value = np.concatenate(np.array(values)).mean() else: - raise TypeError('Input should be numpy array or torch tensor. Type: ', type(values[0])) + raise TypeError("Input should be numpy array or torch tensor. Type: ", type(values[0])) return value def append_to_hdf5(f, name, value, axis=0): - """ resizes an HDF5 dataset and appends value """ + """resizes an HDF5 dataset and appends value""" f[name].resize(f[name].shape[axis] + 1, axis=axis) f[name][-1] = value class Buffer: - def __init__(self): self.data = {} - self.splits = ['train', 'val', 'test', 'speedtest'] + self.splits = ["train", "val", "test", "speedtest"] for split in self.splits: self.initialize(split) @@ -665,10 +659,9 @@ def clear(self, split=None): class EmptyBuffer: - def __init__(self): self.data = {} - self.splits = ['train', 'val', 'test', 'speedtest'] + self.splits = ["train", "val", "test", "speedtest"] for split in self.splits: self.initialize(split) @@ -688,14 +681,16 @@ def clear(self, split=None): class Metrics: """Class for saving a list of per-epoch metrics to disk as an HDF5 file""" - def __init__(self, - run_dir: Union[str, bytes, os.PathLike], - key_metric: str, - name: str, - num_parameters: int, - splits: list = ['train', 'val'], - num_workers: int = 4): - """ Metrics constructor + def __init__( + self, + run_dir: Union[str, bytes, os.PathLike], + key_metric: str, + name: str, + num_parameters: int, + splits: list = ["train", "val"], + num_workers: int = 4, + ): + """Metrics constructor Parameters ---------- @@ -710,9 +705,9 @@ def __init__(self, splits: list either ['train', 'val'] or ['train', 'val', 'test'] """ - assert (os.path.isdir(run_dir)) - self.fname = os.path.join(run_dir, '{}_metrics.h5'.format(name)) - log.debug('making metrics file at {}'.format(self.fname)) + assert os.path.isdir(run_dir) + self.fname = os.path.join(run_dir, "{}_metrics.h5".format(name)) + log.debug("making metrics file at {}".format(self.fname)) self.key_metric = key_metric self.splits = splits self.num_parameters = num_parameters @@ -728,7 +723,7 @@ def update_lr(self, lr): self.learning_rate = lr def compute(self, data: dict) -> dict: - """ Computes metrics from one epoch's batch of data + """Computes metrics from one epoch's batch of data Args: data: dict @@ -740,33 +735,33 @@ def compute(self, data: dict) -> dict: """ metrics = {} keys = list(data.keys()) - if 'loss' in keys: - metrics['loss'] = np.mean(data['loss']) - if 'time' in keys: + if "loss" in keys: + metrics["loss"] = np.mean(data["loss"]) + if "time" in keys: # assume it's seconds per image - FPS = 1 / get_denominator(np.mean(data['time'])) - metrics['fps'] = FPS - elif 'fps' in keys: - FPS = np.mean(data['fps']) - metrics['fps'] = FPS - if 'lr' in keys: + FPS = 1 / get_denominator(np.mean(data["time"])) + metrics["fps"] = FPS + elif "fps" in keys: + FPS = np.mean(data["fps"]) + metrics["fps"] = FPS + if "lr" in keys: # note: this should always be a scalar, but set to mean just in case there's multiple - metrics['lr'] = np.mean(data['lr']) + metrics["lr"] = np.mean(data["lr"]) return metrics def initialize_file(self): - mode = 'r+' if os.path.isfile(self.fname) else 'w' + mode = "r+" if os.path.isfile(self.fname) else "w" with h5py.File(self.fname, mode) as f: - f.attrs['num_parameters'] = self.num_parameters - f.attrs['key_metric'] = self.key_metric + f.attrs["num_parameters"] = self.num_parameters + f.attrs["key_metric"] = self.key_metric # make an HDF5 group for each split for split in self.splits: group = f.create_group(split) # all splits and datasets will have loss values-- others will come from self.compute() - group.create_dataset('loss', (0,), maxshape=(None,), dtype=np.float32) + group.create_dataset("loss", (0,), maxshape=(None,), dtype=np.float32) def save_metrics_to_disk(self, metrics: dict, split: str) -> None: - with h5py.File(self.fname, 'r+') as f: + with h5py.File(self.fname, "r+") as f: # utils.print_hdf5(f) if split not in f.keys(): # should've created top-level groups in initialize_file; this is for nesting @@ -778,7 +773,7 @@ def save_metrics_to_disk(self, metrics: dict, split: str) -> None: array = np.array(array) # ALLOW FOR NESTING if isinstance(array, dict): - group_name = split + '/' + key + group_name = split + "/" + key self.save_metrics_to_disk(array, group_name) elif isinstance(array, np.ndarray): if key in datasets: @@ -788,15 +783,16 @@ def save_metrics_to_disk(self, metrics: dict, split: str) -> None: # create dataset shape = (1, *array.shape) maxshape = (None, *array.shape) - log.debug('creating dataset {}/{}: shape {}'.format(split, key, shape)) + log.debug("creating dataset {}/{}: shape {}".format(split, key, shape)) group.create_dataset(key, shape, maxshape=maxshape, dtype=array.dtype) group[key][-1] = array else: - raise ValueError('Metrics must contain dicts of np.ndarrays, not {} of type {}'.format( - array, type(array))) + raise ValueError( + "Metrics must contain dicts of np.ndarrays, not {} of type {}".format(array, type(array)) + ) def end_epoch(self, split: str): - """ End the current training epoch. Saves any metrics in memory to disk + """End the current training epoch. Saves any metrics in memory to disk Parameters ---------- @@ -808,10 +804,10 @@ def end_epoch(self, split: str): # import pdb; pdb.set_trace() - if split != 'speedtest': - assert 'loss' in data.keys() + if split != "speedtest": + assert "loss" in data.keys() # store most recent loss and key metric as attributes, for use in scheduling, stopping, etc. - self.latest_loss[split] = metrics['loss'] + self.latest_loss[split] = metrics["loss"] self.latest_key[split] = metrics[self.key_metric] self.save_metrics_to_disk(metrics, split) @@ -819,21 +815,21 @@ def end_epoch(self, split: str): def __getitem__(self, inp: tuple) -> np.ndarray: split, metric_name, epoch_number = inp - with h5py.File(self.fname, 'r') as f: - assert split in f.keys(), 'split {} not found in file: {}'.format(split, list(f.keys())) + with h5py.File(self.fname, "r") as f: + assert split in f.keys(), "split {} not found in file: {}".format(split, list(f.keys())) group = f[split] - assert metric_name in group.keys(), 'metric {} not found in group: {}'.format( - metric_name, list(group.keys())) + assert metric_name in group.keys(), "metric {} not found in group: {}".format( + metric_name, list(group.keys()) + ) data = group[metric_name][epoch_number, ...] return data class EmptyMetrics(Metrics): - def __init__(self, *args, **kwargs): - super().__init__(os.getcwd(), [], 'loss', 'empty', 0) + super().__init__(os.getcwd(), [], "loss", "empty", 0) self.buffer = EmptyBuffer() - self.key_metric = 'loss' + self.key_metric = "loss" def end_epoch(self, split, *args, **kwargs): # calling this clears the buffer @@ -844,18 +840,20 @@ def initialize_file(self): class Classification(Metrics): - """ Metrics class for saving multiclass or multilabel classifcation metrics to disk """ - - def __init__(self, - run_dir: Union[str, bytes, os.PathLike], - key_metric: str, - num_parameters: int, - num_classes: int = None, - splits: list = ['train', 'val'], - ignore_index: int = -1, - evaluate_threshold: bool = False, - num_workers: int = 4): - """ Constructor for classification metrics class + """Metrics class for saving multiclass or multilabel classifcation metrics to disk""" + + def __init__( + self, + run_dir: Union[str, bytes, os.PathLike], + key_metric: str, + num_parameters: int, + num_classes: int = None, + splits: list = ["train", "val"], + ignore_index: int = -1, + evaluate_threshold: bool = False, + num_workers: int = 4, + ): + """Constructor for classification metrics class Parameters ---------- @@ -877,7 +875,7 @@ def __init__(self, Hack for multi-label classification problems. If True, at each epoch will compute a bunch of metrics for each potential threshold. See evaluate_thresholds """ - super().__init__(run_dir, key_metric, 'classification', num_parameters, splits, num_workers) + super().__init__(run_dir, key_metric, "classification", num_parameters, splits, num_workers) self.metric_funcs = all_metrics @@ -905,22 +903,22 @@ def compute(self, data: dict): # computes mean loss, etc metrics = super().compute(data) - if 'probs' not in data.keys(): + if "probs" not in data.keys(): # might happen during speedtest return metrics # automatically handle loss components for key in data.keys(): - if 'loss' in key and key != 'loss': + if "loss" in key and key != "loss": metrics[key] = np.mean(data[key]) # if data are from sequence models, stack into N*T x K not N x K x T - probs = self.stack_sequence_data(data['probs']) - if data['probs'].ndim == 3 and data['labels'].ndim == 2: + probs = self.stack_sequence_data(data["probs"]) + if data["probs"].ndim == 3 and data["labels"].ndim == 2: # special case for sequence models with final_activation==softmax, aka multiclass classification - labels = data['labels'].transpose(0, 1).flatten() + labels = data["labels"].transpose(0, 1).flatten() else: - labels = self.stack_sequence_data(data['labels']) + labels = self.stack_sequence_data(data["labels"]) num_classes = probs.shape[1] one_hot = probs.shape[-1] == labels.shape[-1] @@ -937,7 +935,7 @@ def compute(self, data: dict): with warnings.catch_warnings(): warnings.simplefilter("ignore") metrics_by_threshold, epoch_metrics = evaluate_thresholds(probs, labels, None, self.num_workers) - metrics['metrics_by_threshold'] = metrics_by_threshold + metrics["metrics_by_threshold"] = metrics_by_threshold for key, value in epoch_metrics.items(): metrics[key] = value else: @@ -949,12 +947,12 @@ def compute(self, data: dict): with warnings.catch_warnings(): for metric in self.metrics: - if metric == 'confusion': + if metric == "confusion": warnings.simplefilter("ignore") metrics[metric] = confusion(predictions, labels, K=self.num_classes) # import pdb # pdb.set_trace() - elif metric == 'binary_confusion': + elif metric == "binary_confusion": pass else: warnings.simplefilter("ignore") @@ -963,13 +961,13 @@ def compute(self, data: dict): class OpticalFlow(Metrics): - """ Metrics class for saving optic flow metrics to disk """ + """Metrics class for saving optic flow metrics to disk""" - def __init__(self, run_dir, key_metric, num_parameters, splits=['train', 'val']): - super().__init__(run_dir, key_metric, 'opticalflow', num_parameters, splits) + def __init__(self, run_dir, key_metric, num_parameters, splits=["train", "val"]): + super().__init__(run_dir, key_metric, "opticalflow", num_parameters, splits) def compute(self, data: dict) -> dict: - """ Computes metrics from one epoch's batch of data + """Computes metrics from one epoch's batch of data Args: data: dict @@ -981,7 +979,7 @@ def compute(self, data: dict) -> dict: """ metrics = super().compute(data) - for key in ['reg_loss', 'SSIM', 'L1', 'smoothness', 'sparsity', 'L1']: + for key in ["reg_loss", "SSIM", "L1", "smoothness", "sparsity", "L1"]: if key in data.keys(): metrics[key] = data[key].mean() return metrics diff --git a/deepethogram/postprocessing.py b/deepethogram/postprocessing.py index e329de6..4857866 100644 --- a/deepethogram/postprocessing.py +++ b/deepethogram/postprocessing.py @@ -13,14 +13,13 @@ log = logging.getLogger(__name__) -def remove_low_thresholds(thresholds: np.ndarray, - minimum: float = 0.01, - f1s: np.ndarray = None, - minimum_f1: float = 0.05) -> np.ndarray: - """ Replaces thresholds below a certain value with 0.5 - - If the model completely fails, the optimum threshold might be something erreoneous, such as - 0.00001. This makes all predictions==1. +def remove_low_thresholds( + thresholds: np.ndarray, minimum: float = 0.01, f1s: np.ndarray = None, minimum_f1: float = 0.05 +) -> np.ndarray: + """Replaces thresholds below a certain value with 0.5 + + If the model completely fails, the optimum threshold might be something erreoneous, such as + 0.00001. This makes all predictions==1. Parameters ---------- @@ -37,18 +36,18 @@ def remove_low_thresholds(thresholds: np.ndarray, """ if np.sum(thresholds < minimum) > 0: indices = np.where(thresholds < minimum)[0] - log.debug('thresholds {} too low, setting to {}'.format(thresholds[indices], minimum)) + log.debug("thresholds {} too low, setting to {}".format(thresholds[indices], minimum)) thresholds[thresholds < minimum] = minimum if f1s is not None: if np.sum(f1s < minimum_f1) > 0: indices = np.where(f1s < minimum_f1)[0] - log.debug('f1 {} too low, setting to 0.5'.format(f1s)) + log.debug("f1 {} too low, setting to 0.5".format(f1s)) thresholds[f1s < minimum_f1] = 0.5 return thresholds def get_onsets_offsets(binary: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ Gets the onset and offset indices of a binary array. + """Gets the onset and offset indices of a binary array. Onset: index at which the array goes from 0 -> 1 (the index with the 1, not the 0) offset: index at which the array goes from 1 -> 0 (the index with the 0, not the 1) @@ -78,26 +77,25 @@ def get_onsets_offsets(binary: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: def get_bouts(ethogram: np.ndarray) -> list: - """ Get bouts from an ethogram. Uses 1->0 and 0->1 changes to define bout starts and stops """ + """Get bouts from an ethogram. Uses 1->0 and 0->1 changes to define bout starts and stops""" K = ethogram.shape[1] stats = [] for i in range(K): onsets, offsets = get_onsets_offsets(ethogram[:, i]) stat = { - 'N': len(onsets), - 'lengths': np.array([offset - onset for (onset, offset) in zip(onsets, offsets)]), - 'starts': onsets, - 'ends': offsets + "N": len(onsets), + "lengths": np.array([offset - onset for (onset, offset) in zip(onsets, offsets)]), + "starts": onsets, + "ends": offsets, } stats.append(stat) return stats -def find_bout_indices(predictions_trace: np.ndarray, - bout_length: int, - positive: bool = True, - eps: float = 1e-6) -> np.ndarray: - """ Find indices where a bout of bout-length occurs in a binary vector +def find_bout_indices( + predictions_trace: np.ndarray, bout_length: int, positive: bool = True, eps: float = 1e-6 +) -> np.ndarray: + """Find indices where a bout of bout-length occurs in a binary vector Bouts are defined as consecutive sets of 1s (if `positive`) or 0s (if not `positive`). Parameters @@ -119,7 +117,7 @@ def find_bout_indices(predictions_trace: np.ndarray, filt = np.concatenate([[-bout_length / 2], center, [-bout_length / 2]]) if not positive: predictions_trace = np.logical_not(predictions_trace.copy()).astype(int) - out = np.convolve(predictions_trace, filt, mode='same') + out = np.convolve(predictions_trace, filt, mode="same") # precision issues: using == 1 here has false negatives in case where out = 0.99999999998 or something indices = np.where(np.abs(out - 1) < eps)[0] if len(indices) == 0: @@ -135,7 +133,7 @@ def find_bout_indices(predictions_trace: np.ndarray, def remove_short_bouts_from_trace(predictions_trace: np.ndarray, bout_length: int) -> np.ndarray: - """ Removes bouts of length <= `bout_length` from a binary vector. + """Removes bouts of length <= `bout_length` from a binary vector. Important note: we first remove "false negatives." e.g. if `bout_length` is 2, this will do something like: 000111001111111011000010 -> 000111111111111111000010 @@ -154,7 +152,7 @@ def remove_short_bouts_from_trace(predictions_trace: np.ndarray, bout_length: in predictions_trace: np.ndarray. shape (N, ) binary predictions trace with short bouts removed """ - assert len(predictions_trace.shape) == 1, 'only 1D input: {}'.format(predictions_trace.shape) + assert len(predictions_trace.shape) == 1, "only 1D input: {}".format(predictions_trace.shape) # first remove 1 frame bouts, then 2 frames, then 3 frames for bout_len in range(1, bout_length + 1): # first, remove "false negatives", like filling in gaps in true behavior bouts @@ -167,7 +165,7 @@ def remove_short_bouts_from_trace(predictions_trace: np.ndarray, bout_length: in def remove_short_bouts(predictions: np.ndarray, bout_length: int) -> np.ndarray: - """ Removes short bouts from a predictions array + """Removes short bouts from a predictions array Applies `remove_short_bouts_from_trace` to each column of the input. @@ -183,8 +181,9 @@ def remove_short_bouts(predictions: np.ndarray, bout_length: int) -> np.ndarray: predictions: np.ndarray, shape (N, K) Array of N timepoints and K classes with short bouts removed """ - assert len(predictions.shape) == 2, \ - '2D input to remove short bouts required (timepoints x classes): {}'.format(predictions.shape) + assert len(predictions.shape) == 2, "2D input to remove short bouts required (timepoints x classes): {}".format( + predictions.shape + ) T, K = predictions.shape for k in range(K): @@ -193,7 +192,7 @@ def remove_short_bouts(predictions: np.ndarray, bout_length: int) -> np.ndarray: def compute_background(predictions: np.ndarray) -> np.ndarray: - """ Makes the background positive when no other behaviors are occurring + """Makes the background positive when no other behaviors are occurring Parameters ---------- @@ -206,29 +205,30 @@ def compute_background(predictions: np.ndarray) -> np.ndarray: Binary predictions. the background class is now the logical_not of whether or not there are any positive examples in the rest of the row """ - assert len(predictions.shape) == 2, 'predictions must be a TxK matrix: not {}'.format(predictions.shape) + assert len(predictions.shape) == 2, "predictions must be a TxK matrix: not {}".format(predictions.shape) predictions[:, 0] = np.logical_not(np.any(predictions[:, 1:], axis=1)).astype(np.uint8) return predictions class Postprocessor: - """ Base class for postprocessing a set of input probabilities into predictions """ + """Base class for postprocessing a set of input probabilities into predictions""" + def __init__(self, thresholds: np.ndarray, min_threshold=0.01): - assert len(thresholds.shape) == 1, 'thresholds must be 1D array, not {}'.format(thresholds.shape) + assert len(thresholds.shape) == 1, "thresholds must be 1D array, not {}".format(thresholds.shape) # edge case with poor thresholds, causes all predictions to be ==1 thresholds = remove_low_thresholds(thresholds, minimum=min_threshold) self.thresholds = thresholds def threshold(self, probabilities: np.ndarray) -> np.ndarray: - """ Applies thresholds to binarize inputs """ - assert len(probabilities.shape) == 2, 'probabilities must be a TxK matrix: not {}'.format(probabilities.shape) + """Applies thresholds to binarize inputs""" + assert len(probabilities.shape) == 2, "probabilities must be a TxK matrix: not {}".format(probabilities.shape) assert probabilities.shape[1] == self.thresholds.shape[0] predictions = (probabilities > self.thresholds).astype(int) return predictions def process(self, probabilities: np.ndarray) -> np.ndarray: - """ Process probabilities. Will be overridden by subclasses """ + """Process probabilities. Will be overridden by subclasses""" # the simplest form of postprocessing is just thresholding and making sure that background is the actual # logical_not of any other behavior. Therefore, its threshold is not used predictions = self.threshold(probabilities) @@ -240,7 +240,8 @@ def __call__(self, probabilities: np.ndarray) -> np.ndarray: class MinBoutLengthPostprocessor(Postprocessor): - """ Postprocessor that removes bouts of length less than or equal to bout_length """ + """Postprocessor that removes bouts of length less than or equal to bout_length""" + def __init__(self, thresholds: np.ndarray, bout_length: int, **kwargs): super().__init__(thresholds, **kwargs) self.bout_length = bout_length @@ -253,7 +254,8 @@ def process(self, probabilities: np.ndarray) -> np.ndarray: class MinBoutLengthPerBehaviorPostprocessor(Postprocessor): - """ Postprocessor that removes bouts of length less than or equal to bout_length """ + """Postprocessor that removes bouts of length less than or equal to bout_length""" + def __init__(self, thresholds: np.ndarray, bout_lengths: list, **kwargs): super().__init__(thresholds, **kwargs) assert len(thresholds) == len(bout_lengths) @@ -296,7 +298,7 @@ def get_bout_length_percentile(label_list: list, percentile: float) -> dict: bouts = get_bouts(label) T, K = label.shape for k in range(K): - bout_length = bouts[k]['lengths'].tolist() + bout_length = bouts[k]["lengths"].tolist() bout_lengths[k].append(bout_length) bout_lengths = {behavior: np.concatenate(value) for behavior, value in bout_lengths.items()} # print(bout_lengths) @@ -311,12 +313,12 @@ def get_bout_length_percentile(label_list: list, percentile: float) -> dict: def get_postprocessor_from_cfg(cfg: DictConfig, thresholds: np.ndarray) -> Type[Postprocessor]: - """ Returns a PostProcessor from an OmegaConf DictConfig returned by a """ + """Returns a PostProcessor from an OmegaConf DictConfig returned by a""" if cfg.postprocessor.type is None: return Postprocessor(thresholds) - elif cfg.postprocessor.type == 'min_bout': + elif cfg.postprocessor.type == "min_bout": return MinBoutLengthPostprocessor(thresholds, cfg.postprocessor.min_bout_length) - elif cfg.postprocessor.type == 'min_bout_per_behavior': + elif cfg.postprocessor.type == "min_bout_per_behavior": if not os.path.isdir(cfg.project.data_path): cfg = projects.convert_config_paths_to_absolute(cfg) assert os.path.isdir(cfg.project.data_path) @@ -325,7 +327,7 @@ def get_postprocessor_from_cfg(cfg: DictConfig, thresholds: np.ndarray) -> Type[ label_list = [] for animal, record in records.items(): - labelfile = record['label'] + labelfile = record["label"] if labelfile is None: continue label = file_io.read_labels(labelfile) @@ -350,7 +352,7 @@ def postprocess_and_save(cfg: DictConfig) -> None: ---------- cfg : DictConfig a project configuration. Must have the `sequence` and `postprocessing` sections - + Goes through each "outputfile" in the project, loads the probabilities, postprocesses them, and saves to disk with the name `base + _predictions.csv`. """ @@ -361,15 +363,15 @@ def postprocess_and_save(cfg: DictConfig) -> None: output_name = cfg.sequence.output_name behavior_names = OmegaConf.to_container(cfg.project.class_names) - records = projects.get_records_from_datadir(os.path.join(cfg.project.path, 'DATA')) + records = projects.get_records_from_datadir(os.path.join(cfg.project.path, "DATA")) for _, record in records.items(): - with h5py.File(record['output'], 'r') as f: - p = f[output_name]['P'][:] - thresholds = f[output_name]['thresholds'][:] + with h5py.File(record["output"], "r") as f: + p = f[output_name]["P"][:] + thresholds = f[output_name]["thresholds"][:] postprocessor = get_postprocessor_from_cfg(cfg, thresholds) predictions = postprocessor(p) df = pd.DataFrame(data=predictions, columns=behavior_names) - base = os.path.splitext(record['rgb'])[0] - filename = base + '_predictions.csv' - df.to_csv(filename) \ No newline at end of file + base = os.path.splitext(record["rgb"])[0] + filename = base + "_predictions.csv" + df.to_csv(filename) diff --git a/deepethogram/projects.py b/deepethogram/projects.py index 3423bbf..6cac546 100644 --- a/deepethogram/projects.py +++ b/deepethogram/projects.py @@ -20,15 +20,17 @@ log = logging.getLogger(__name__) -required_keys = ['project', 'augs'] +required_keys = ["project", "augs"] projects_file_directory = os.path.dirname(os.path.abspath(__file__)) -def initialize_project(directory: Union[str, os.PathLike], - project_name: str, - behaviors: list = None, - make_subdirectory: bool = True, - labeler: str = None): +def initialize_project( + directory: Union[str, os.PathLike], + project_name: str, + behaviors: list = None, + make_subdirectory: bool = True, + labeler: str = None, +): """Initializes a DeepEthogram project. Copies the default configuration file and updates it with the directory, name, and behaviors specified. Makes directories where project info, data, and models will live. @@ -47,46 +49,46 @@ def initialize_project(directory: Union[str, os.PathLike], Example: intialize_project('C:/DATA', 'grooming', ['background', 'face_groom', 'body_groom', 'rear']) """ - assert os.path.isdir(directory), 'Directory does not exist: {}'.format(directory) + assert os.path.isdir(directory), "Directory does not exist: {}".format(directory) if behaviors is not None: - assert behaviors[0] == 'background' + assert behaviors[0] == "background" root = os.path.dirname(os.path.abspath(__file__)) - project_config = utils.load_yaml(os.path.join(root, 'conf', 'project', 'project_config_default.yaml')) - project_name = project_name.replace(' ', '_') + project_config = utils.load_yaml(os.path.join(root, "conf", "project", "project_config_default.yaml")) + project_name = project_name.replace(" ", "_") - project_config['project']['name'] = project_name + project_config["project"]["name"] = project_name - project_config['project']['class_names'] = behaviors + project_config["project"]["class_names"] = behaviors if make_subdirectory: - project_dir = os.path.join(directory, '{}_deepethogram'.format(project_name)) + project_dir = os.path.join(directory, "{}_deepethogram".format(project_name)) else: project_dir = directory - project_config['project']['path'] = project_dir + project_config["project"]["path"] = project_dir - project_config['project']['data_path'] = 'DATA' - project_config['project']['model_path'] = 'models' - project_config['project']['labeler'] = labeler + project_config["project"]["data_path"] = "DATA" + project_config["project"]["model_path"] = "models" + project_config["project"]["labeler"] = labeler - if not os.path.isdir(project_config['project']['path']): - os.makedirs(project_config['project']['path']) + if not os.path.isdir(project_config["project"]["path"]): + os.makedirs(project_config["project"]["path"]) # os.chdir(project_config['project']['path']) - data_abs = os.path.join(project_config['project']['path'], project_config['project']['data_path']) + data_abs = os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) if not os.path.isdir(data_abs): os.makedirs(data_abs) - model_abs = os.path.join(project_config['project']['path'], project_config['project']['model_path']) + model_abs = os.path.join(project_config["project"]["path"], project_config["project"]["model_path"]) if not os.path.isdir(model_abs): os.makedirs(model_abs) - fname = os.path.join(project_dir, 'project_config.yaml') - project_config['project']['config_file'] = fname + fname = os.path.join(project_dir, "project_config.yaml") + project_config["project"]["config_file"] = fname utils.save_dict_to_yaml(project_config, fname) return project_config -def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], mode: str = 'copy') -> str: +def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], mode: str = "copy") -> str: """ Adds a video file to a DEG project. @@ -113,21 +115,21 @@ def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], path to the video file after moving to the DEG project data directory. """ # assert (os.path.isdir(project_directory)) - assert os.path.exists(path_to_video), 'video not found! {}'.format(path_to_video) + assert os.path.exists(path_to_video), "video not found! {}".format(path_to_video) if os.path.isdir(path_to_video): copy_func = shutil.copytree elif os.path.isfile(path_to_video): copy_func = shutil.copy else: - raise ValueError('video does not exist: {}'.format(path_to_video)) + raise ValueError("video does not exist: {}".format(path_to_video)) - assert mode in ['copy', 'symlink', 'move'] + assert mode in ["copy", "symlink", "move"] # project = utils.load_yaml(os.path.join(project_directory, 'project_config.yaml')) # project = convert_config_paths_to_absolute(project) - log.debug('configuration file when adding video: {}'.format(project)) - datadir = os.path.join(project['project']['path'], project['project']['data_path']) - assert os.path.isdir(datadir), 'data path not found: {}'.format(datadir) + log.debug("configuration file when adding video: {}".format(project)) + datadir = os.path.join(project["project"]["path"], project["project"]["data_path"]) + assert os.path.isdir(datadir), "data path not found: {}".format(datadir) # for speed during training, videos can be saved as directories of PNG / JPEG files. if os.path.isdir(path_to_video): @@ -140,25 +142,26 @@ def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], video_directory = os.path.join(datadir, vidname) if os.path.isdir(video_directory): - raise ValueError('Directory {} already exists in your data dir! ' \ - 'Please rename the video to a unique name'.format(vidname)) + raise ValueError( + "Directory {} already exists in your data dir! " "Please rename the video to a unique name".format(vidname) + ) os.makedirs(video_directory) new_path = os.path.join(video_directory, basename) - if mode == 'copy': + if mode == "copy": if video_is_directory: shutil.copytree(path_to_video, new_path) else: shutil.copy(path_to_video, new_path) - elif mode == 'symlink': + elif mode == "symlink": os.symlink(path_to_video, new_path) - elif mode == 'move': + elif mode == "move": shutil.move(path_to_video, new_path) else: - raise ValueError('invalid argument to mode: {}'.format(mode)) + raise ValueError("invalid argument to mode: {}".format(mode)) record = parse_subdir(video_directory) - log.debug('New record after adding: {}'.format(record)) - utils.save_dict_to_yaml(record, os.path.join(video_directory, 'record.yaml')) + log.debug("New record after adding: {}".format(record)) + utils.save_dict_to_yaml(record, os.path.join(video_directory, "record.yaml")) zscore_video(os.path.join(video_directory, basename), project) return new_path @@ -173,34 +176,34 @@ def add_label_to_project(path_to_labels: Union[str, os.PathLike], path_to_video) label_dst = os.path.join(viddir, os.path.basename(path_to_labels)) if os.path.isfile(label_dst): - warnings.warn('Label already exists in destination {}, overwriting...'.format(label_dst)) + warnings.warn("Label already exists in destination {}, overwriting...".format(label_dst)) df = pd.read_csv(path_to_labels, index_col=0) - if 'none' in list(df.columns): - df = df.rename(columns={'none': 'background'}) - if 'background' not in list(df.columns): + if "none" in list(df.columns): + df = df.rename(columns={"none": "background"}) + if "background" not in list(df.columns): array = df.values is_background = np.logical_not(np.any(array == 1, axis=1)).astype(int)[:, np.newaxis] data = np.concatenate((is_background, array), axis=1) # df2 = pd.DataFrame(data=is_background, columns=['background']) # df = pd.concat([df2, df], axis=1) - df = pd.DataFrame(data=data, columns=['background'] + list(df.columns)) + df = pd.DataFrame(data=data, columns=["background"] + list(df.columns)) df.to_csv(label_dst) record = parse_subdir(viddir) - utils.save_dict_to_yaml(record, os.path.join(viddir, 'record.yaml')) + utils.save_dict_to_yaml(record, os.path.join(viddir, "record.yaml")) return label_dst def add_file_to_subdir(file: Union[str, os.PathLike], subdir: Union[str, os.PathLike]): """If you save or move a file into a DEG subdirectory, update the record""" if not is_deg_file(subdir): - raise ValueError('directory is not a DEG subdir: {}'.format(subdir)) - assert (os.path.isfile(file)) + raise ValueError("directory is not a DEG subdir: {}".format(subdir)) + assert os.path.isfile(file) if os.path.dirname(file) != subdir: shutil.copy(file, os.path.join(subdir, os.path.basename(file))) record = parse_subdir(subdir) - utils.save_dict_to_yaml(record, os.path.join(subdir, 'record.yaml')) + utils.save_dict_to_yaml(record, os.path.join(subdir, "record.yaml")) # def change_project_directory(config_file: Union[str, os.PathLike], new_directory: Union[str, os.PathLike]): @@ -235,28 +238,28 @@ def is_deg_file(filename: Union[str, os.PathLike]) -> bool: basedir = os.path.dirname(filename) is_directory = False else: - raise ValueError('submit directory or file to is_deg_file, not {}'.format(filename)) + raise ValueError("submit directory or file to is_deg_file, not {}".format(filename)) - recordfile = os.path.join(basedir, 'record.yaml') + recordfile = os.path.join(basedir, "record.yaml") record_exists = os.path.isfile(recordfile) if is_directory: # this is required in case the file passed is a directory full of images; e.g. # project/DATA/animal0/images/00000.jpg - parent_record_exists = os.path.isfile(os.path.join(os.path.dirname(filename), 'record.yaml')) + parent_record_exists = os.path.isfile(os.path.join(os.path.dirname(filename), "record.yaml")) return record_exists or parent_record_exists else: return record_exists def add_behavior_to_project(config_file: Union[str, os.PathLike], behavior_name: str): - """ Adds a behavior (class) to the project. + """Adds a behavior (class) to the project. Adds this behavior to the class_names field of your project configuration. Adds -1 column in all labelfiles in current project. Saves the altered project_config to disk. - Removes any file with previous outputs / latents. - + Removes any file with previous outputs / latents. + Parameters ---------- config_file: str, PathLike @@ -264,17 +267,18 @@ def add_behavior_to_project(config_file: Union[str, os.PathLike], behavior_name: behavior_name: str behavior to add to the project. """ - assert (os.path.isfile(config_file)) + assert os.path.isfile(config_file) project_config = utils.load_yaml(config_file) - assert 'class_names' in list(project_config['project'].keys()) - classes = project_config['project']['class_names'] + assert "class_names" in list(project_config["project"].keys()) + classes = project_config["project"]["class_names"] assert behavior_name not in classes classes.append(behavior_name) records = get_records_from_datadir( - os.path.join(project_config['project']['path'], project_config['project']['data_path'])) + os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) + ) for key, record in records.items(): - labelfile = record['label'] + labelfile = record["label"] if labelfile is None: continue if os.path.isfile(labelfile): @@ -285,9 +289,9 @@ def add_behavior_to_project(config_file: Union[str, os.PathLike], behavior_name: df2 = pd.DataFrame(data=np.ones((N, 1)) * -1, columns=[behavior_name]) df = pd.concat([df, df2], axis=1) df.to_csv(labelfile) - if record['output'] is not None and os.path.isfile(record['output']): - os.remove(record['output']) - project_config['project']['class_names'] = classes + if record["output"] is not None and os.path.isfile(record["output"]): + os.remove(record["output"]) + project_config["project"]["class_names"] = classes utils.save_dict_to_yaml(project_config, config_file) @@ -306,18 +310,19 @@ def remove_behavior_from_project(config_file: Union[str, os.PathLike], behavior_ behavior_name: str One of the existing behavior_names. """ - if behavior_name == 'background': - raise ValueError('Cannot remove background class.') - assert (os.path.isfile(config_file)) + if behavior_name == "background": + raise ValueError("Cannot remove background class.") + assert os.path.isfile(config_file) project_config = utils.load_yaml(config_file) - assert 'class_names' in list(project_config['project'].keys()) - classes = project_config['project']['class_names'] + assert "class_names" in list(project_config["project"].keys()) + classes = project_config["project"]["class_names"] assert behavior_name in classes records = get_records_from_datadir( - os.path.join(project_config['project']['path'], project_config['project']['data_path'])) + os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) + ) for key, record in records.items(): - labelfile = record['label'] + labelfile = record["label"] if labelfile is None: continue if os.path.isfile(labelfile): @@ -327,10 +332,10 @@ def remove_behavior_from_project(config_file: Union[str, os.PathLike], behavior_ continue df2 = df.drop(behavior_name, axis=1) df2.to_csv(labelfile) - if record['output'] is not None and os.path.isfile(record['output']): - os.remove(record['output']) + if record["output"] is not None and os.path.isfile(record["output"]): + os.remove(record["output"]) classes.remove(behavior_name) - project_config['project']['class_names'] = classes + project_config["project"]["class_names"] = classes utils.save_dict_to_yaml(project_config, config_file) @@ -349,14 +354,14 @@ def get_classes_from_project(config: Union[dict, str, os.PathLike, DictConfig]) """ if type(config) == str or type(config) == os.PathLike: - config_file = os.path.join(config, 'project_config.yaml') - assert os.path.isfile(config_file), 'Input must be a directory containing a project_config.yaml file' + config_file = os.path.join(config, "project_config.yaml") + assert os.path.isfile(config_file), "Input must be a directory containing a project_config.yaml file" config = utils.load_yaml(config_file) - assert 'project' in list(config.keys()), 'Invalid project configuration dictionary: {}'.format(config) - project = config['project'] - assert 'class_names' in list(project.keys()), 'Must have class names in project config file' - return project['project']['class_names'] + assert "project" in list(config.keys()), "Invalid project configuration dictionary: {}".format(config) + project = config["project"] + assert "class_names" in list(project.keys()), "Must have class names in project config file" + return project["project"]["class_names"] def exclude_strings_from_filelist(files: list, excluded: list) -> list: @@ -382,7 +387,7 @@ def exclude_strings_from_filelist(files: list, excluded: list) -> list: def find_labelfiles(root: Union[str, bytes, os.PathLike]) -> list: - """ Gets label files from a deepethogram data directory + """Gets label files from a deepethogram data directory Args: root (str, pathlike): directory containing labels, movies, etc @@ -390,8 +395,8 @@ def find_labelfiles(root: Union[str, bytes, os.PathLike]) -> list: Returns: files: list of score or label files """ - files = get_subfiles(root, return_type='file') - files = [i for i in files if 'label' in os.path.basename(i).lower() or 'score' in os.path.basename(i).lower()] + files = get_subfiles(root, return_type="file") + files = [i for i in files if "label" in os.path.basename(i).lower() or "score" in os.path.basename(i).lower()] return files @@ -404,20 +409,20 @@ def find_rgbfiles(root: Union[str, bytes, os.PathLike]) -> list: Returns: list of absolute paths to RGB videos, or subdirectories containing individual images (framedirs) """ - files = get_subfiles(root, return_type='any') + files = get_subfiles(root, return_type="any") endings = [os.path.splitext(i)[1] for i in files] - valid_endings = ['.avi', '.mp4', '.h5', '.mov'] - excluded = ['flow', 'label', 'output', 'score'] + valid_endings = [".avi", ".mp4", ".h5", ".mov"] + excluded = ["flow", "label", "output", "score"] movies = [i for i in files if os.path.splitext(i)[1].lower() in valid_endings] movies = exclude_strings_from_filelist(movies, excluded) - framedirs = get_subfiles(root, return_type='directory') + framedirs = get_subfiles(root, return_type="directory") framedirs = exclude_strings_from_filelist(framedirs, excluded) return movies + framedirs def find_flowfiles(root: Union[str, bytes, os.PathLike]) -> list: - """ DEPRECATED. + """DEPRECATED. Args: root (): @@ -425,20 +430,20 @@ def find_flowfiles(root: Union[str, bytes, os.PathLike]) -> list: Returns: """ - files = get_subfiles(root, return_type='any') + files = get_subfiles(root, return_type="any") endings = [os.path.splitext(i)[1] for i in files] - valid_endings = ['.avi', '.mp4', '.h5'] + valid_endings = [".avi", ".mp4", ".h5"] movies = [ - files[i] for i in range(len(files)) if endings[i] in valid_endings and 'flow' in os.path.basename(files[i]) + files[i] for i in range(len(files)) if endings[i] in valid_endings and "flow" in os.path.basename(files[i]) ] framedirs = [ - i for i in get_subfiles(root, return_type='directory') if 'frame' in i and 'flow' in os.path.basename(i) + i for i in get_subfiles(root, return_type="directory") if "frame" in i and "flow" in os.path.basename(i) ] return movies + framedirs def find_outputfiles(root: Union[str, bytes, os.PathLike]) -> list: - """ Finds deepethogram outputfiles, containing RGB and flow features, along with P(K) + """Finds deepethogram outputfiles, containing RGB and flow features, along with P(K) Args: root (str, pathlike): deepethogram data directory @@ -446,13 +451,13 @@ def find_outputfiles(root: Union[str, bytes, os.PathLike]) -> list: Returns: list of outputfiles. should only have one element """ - files = get_subfiles(root, return_type='file') - files = [i for i in files if 'output' in os.path.basename(i).lower() and os.path.splitext(i)[1].lower() == '.h5'] + files = get_subfiles(root, return_type="file") + files = [i for i in files if "output" in os.path.basename(i).lower() and os.path.splitext(i)[1].lower() == ".h5"] return files def find_keypointfiles(root: Union[str, bytes, os.PathLike]) -> list: - """ Finds .csv files of DeepLabCut outputs in the data directories + """Finds .csv files of DeepLabCut outputs in the data directories Args: root: (str, pathlike): deepethogram data directory @@ -461,13 +466,13 @@ def find_keypointfiles(root: Union[str, bytes, os.PathLike]) -> list: list of dlcfiles. should only have one element """ # TODO: support SLEAP, DLC hdf5 files - files = get_subfiles(root, return_type='file') - files = [i for i in files if 'dlc' in os.path.basename(i).lower() and os.path.splitext(i)[1] == '.csv'] + files = get_subfiles(root, return_type="file") + files = [i for i in files if "dlc" in os.path.basename(i).lower() and os.path.splitext(i)[1] == ".csv"] return files def find_statsfiles(root: Union[str, bytes, os.PathLike]) -> list: - """ Finds normalization statistics in deepethogram data directory + """Finds normalization statistics in deepethogram data directory Args: root (str, pathlike) @@ -476,25 +481,25 @@ def find_statsfiles(root: Union[str, bytes, os.PathLike]) -> list: Returns: list of stats files, should only have 1 or 0 elements """ - files = get_subfiles(root, return_type='file') - files = [i for i in files if 'stats' in os.path.basename(i) and os.path.splitext(i)[1] == '.yaml'] + files = get_subfiles(root, return_type="file") + files = [i for i in files if "stats" in os.path.basename(i) and os.path.splitext(i)[1] == ".yaml"] return files def get_type_from_file(file: Union[str, bytes, os.PathLike]) -> str: """Convenience function. Gets type of VideoReader input file from a path""" if os.path.isdir(file): - if 'frame' in os.path.basename(file): - return 'directory' + if "frame" in os.path.basename(file): + return "directory" elif os.path.isfile(file): _, ext = os.path.splitext(file) return ext else: - raise ValueError('file does not exist: {}'.format(file)) + raise ValueError("file does not exist: {}".format(file)) def get_files_by_preferences(files, preference: list = None) -> str: - """ Given a list of files with different types, return the most-preferred filetype (given by preference) + """Given a list of files with different types, return the most-preferred filetype (given by preference) Example: files = ['movie.mp4', 'movie.avi'] @@ -531,7 +536,7 @@ def get_files_by_preferences(files, preference: list = None) -> str: def parse_subdir(root: Union[str, bytes, os.PathLike], preference: list = None) -> dict: - """ Find rgb, flow, label, output, and channel statistics files in a given directory + """Find rgb, flow, label, output, and channel statistics files in a given directory Parameters ---------- @@ -555,15 +560,15 @@ def parse_subdir(root: Union[str, bytes, os.PathLike], preference: list = None) if preference is None: # determine default here # sorted by combination of sequential and random read speeds - preference = ['directory', '.h5', '.avi', '.mp4'] + preference = ["directory", ".h5", ".avi", ".mp4"] find_files = { - 'rgb': find_rgbfiles, - 'flow': find_flowfiles, - 'label': find_labelfiles, - 'output': find_outputfiles, - 'stats': find_statsfiles, - 'keypoint': find_keypointfiles + "rgb": find_rgbfiles, + "flow": find_flowfiles, + "label": find_labelfiles, + "output": find_outputfiles, + "stats": find_statsfiles, + "keypoint": find_keypointfiles, } record = {} @@ -572,17 +577,17 @@ def parse_subdir(root: Union[str, bytes, os.PathLike], preference: list = None) record[entry] = {} files = find_files[entry](root) if len(files) == 0: - record[entry]['all'] = [] - record[entry]['default'] = [] + record[entry]["all"] = [] + record[entry]["default"] = [] else: - record[entry]['all'] = [os.path.basename(i) for i in files] - record[entry]['default'] = os.path.basename(get_files_by_preferences(files, preference)) - record['key'] = os.path.basename(root) + record[entry]["all"] = [os.path.basename(i) for i in files] + record[entry]["default"] = os.path.basename(get_files_by_preferences(files, preference)) + record["key"] = os.path.basename(root) return record def get_record_from_subdir(subdir: Union[str, os.PathLike]) -> dict: - """ Gets a dictionary of absolute filepaths for each semantic file type, e.g. RGB movies, labels, output files + """Gets a dictionary of absolute filepaths for each semantic file type, e.g. RGB movies, labels, output files Parameters ---------- @@ -603,9 +608,9 @@ def get_record_from_subdir(subdir: Union[str, os.PathLike]) -> dict: record = parse_subdir(subdir) parsed_record = {} - for key in ['flow', 'label', 'output', 'rgb', 'keypoint']: + for key in ["flow", "label", "output", "rgb", "keypoint"]: if key in list(record.keys()): - this_entry = record[key]['default'] + this_entry = record[key]["default"] if type(this_entry) == list and len(this_entry) == 0: this_entry = None @@ -614,12 +619,12 @@ def get_record_from_subdir(subdir: Union[str, os.PathLike]) -> dict: if not os.path.isfile(this_entry) and not os.path.isdir(this_entry): this_entry = None parsed_record[key] = this_entry - parsed_record['key'] = os.path.basename(subdir) + parsed_record["key"] = os.path.basename(subdir) return parsed_record def get_records_from_datadir(datadir: Union[str, bytes, os.PathLike]) -> dict: - """ Gets a dictionary of record dictionaries from a data directory + """Gets a dictionary of record dictionaries from a data directory Parameters ---------- @@ -639,18 +644,18 @@ def get_records_from_datadir(datadir: Union[str, bytes, os.PathLike]) -> dict: ... } """ - assert os.path.isdir(datadir), 'datadir does not exist: {}'.format(datadir) - subdirs = get_subfiles(datadir, return_type='directory') + assert os.path.isdir(datadir), "datadir does not exist: {}".format(datadir) + subdirs = get_subfiles(datadir, return_type="directory") records = {} for subdir in subdirs: parsed_record = get_record_from_subdir(os.path.join(datadir, subdir)) - records[parsed_record['key']] = parsed_record + records[parsed_record["key"]] = parsed_record # write_all_records(datadir) return records def filter_records_for_filetypes(records: dict, return_types: list): - """ Find the records that have all the requested filetypes. e.g. get all subdirectories with labels """ + """Find the records that have all the requested filetypes. e.g. get all subdirectories with labels""" valid_records = {} for k, v in records.items(): # k is the key for this record, e.g. experiment00_mouse00 @@ -659,7 +664,7 @@ def filter_records_for_filetypes(records: dict, return_types: list): all_present = True for t in return_types: if v[t] is None: - log.warning('No {} file found in record: {}'.format(t, k)) + log.warning("No {} file found in record: {}".format(t, k)) all_present = False if all_present: valid_records[k] = v @@ -667,7 +672,7 @@ def filter_records_for_filetypes(records: dict, return_types: list): def is_config_dict(config: dict) -> bool: - """ Tells if a dictionary is a valid project dictionary """ + """Tells if a dictionary is a valid project dictionary""" config_keys = list(config.keys()) for k in required_keys: if k not in config_keys: @@ -676,14 +681,14 @@ def is_config_dict(config: dict) -> bool: def get_unfinalized_records(config: dict) -> list: - """ Finds the number of label files with no unlabeled frames """ - records = get_records_from_datadir(os.path.join(config['project']['path'], config['project']['data_path'])) + """Finds the number of label files with no unlabeled frames""" + records = get_records_from_datadir(os.path.join(config["project"]["path"], config["project"]["data_path"])) unfinalized = [] for k, v in records.items(): - if v['label'] is None or len(v['label']) == 0: + if v["label"] is None or len(v["label"]) == 0: unfinalized.append(v) else: - label = read_labels(v['label']) + label = read_labels(v["label"]) has_unlabeled_frames = np.any(label == -1) if has_unlabeled_frames: unfinalized.append(v) @@ -691,12 +696,12 @@ def get_unfinalized_records(config: dict) -> list: def get_number_finalized_labels(config: dict) -> int: - """ Finds the number of label files with no unlabeled frames """ - records = get_records_from_datadir(os.path.join(config['project']['path'], config['project']['data_path'])) + """Finds the number of label files with no unlabeled frames""" + records = get_records_from_datadir(os.path.join(config["project"]["path"], config["project"]["data_path"])) number_valid_labels = 0 for k, v in records.items(): for filetype, fileloc in v.items(): - if filetype == 'label': + if filetype == "label": if fileloc is None or len(fileloc) == 0: continue label = read_labels(fileloc) @@ -706,11 +711,13 @@ def get_number_finalized_labels(config: dict) -> int: return number_valid_labels -def import_outputfile(project_dir: Union[str, os.PathLike], - outputfile: Union[str, os.PathLike], - class_names: list = None, - latent_name: str = None): - """ Gets the probabilities, thresholds, used HDF5 dataset key, and all dataset keys from an outputfile +def import_outputfile( + project_dir: Union[str, os.PathLike], + outputfile: Union[str, os.PathLike], + class_names: list = None, + latent_name: str = None, +): + """Gets the probabilities, thresholds, used HDF5 dataset key, and all dataset keys from an outputfile Parameters ---------- @@ -739,35 +746,34 @@ def import_outputfile(project_dir: Union[str, os.PathLike], assert os.path.isfile(outputfile) assert os.path.isdir(project_dir) # handle edge case - if latent_name == '' or latent_name == ' ': + if latent_name == "" or latent_name == " ": latent_name = None # all this tortured logic is to try to figure out what the correct "latent name" is in an HDF5 file. Also includes # logic for backwards compatibility - project_config = load_config(os.path.join(project_dir, 'project_config.yaml')) - if 'sequence' in project_config.keys() and 'arch' in project_config['sequence'].keys(): - sequence_name = project_config['sequence']['arch'] + project_config = load_config(os.path.join(project_dir, "project_config.yaml")) + if "sequence" in project_config.keys() and "arch" in project_config["sequence"].keys(): + sequence_name = project_config["sequence"]["arch"] else: - sequence_name = load_default('model/sequence')['sequence']['arch'] + sequence_name = load_default("model/sequence")["sequence"]["arch"] - if 'sequence' in project_config.keys() and 'latent_name' in project_config['sequence'].keys(): - sequence_inference_latent_name = project_config['sequence']['latent_name'] + if "sequence" in project_config.keys() and "latent_name" in project_config["sequence"].keys(): + sequence_inference_latent_name = project_config["sequence"]["latent_name"] else: sequence_inference_latent_name = None - if 'feature_extractor' in project_config.keys() and 'arch' in project_config['feature_extractor'].keys(): - feature_extractor_arch = project_config['feature_extractor']['arch'] - elif 'preset' in project_config.keys(): - preset = project_config['preset'] - preset_config = load_default('preset/{}'.format(preset)) - feature_extractor_arch = preset_config['feature_extractor']['arch'] + if "feature_extractor" in project_config.keys() and "arch" in project_config["feature_extractor"].keys(): + feature_extractor_arch = project_config["feature_extractor"]["arch"] + elif "preset" in project_config.keys(): + preset = project_config["preset"] + preset_config = load_default("preset/{}".format(preset)) + feature_extractor_arch = preset_config["feature_extractor"]["arch"] else: - feature_extractor_arch = load_default('model/feature_extractor')['feature_extractor']['arch'] - - with h5py.File(outputfile, 'r') as f: + feature_extractor_arch = load_default("model/feature_extractor")["feature_extractor"]["arch"] + with h5py.File(outputfile, "r") as f: keys = list(f.keys()) if len(keys) == 0: - raise ValueError('no datasets found in outputfile: {}'.format(outputfile)) + raise ValueError("no datasets found in outputfile: {}".format(outputfile)) # Order of priority for determining latent name, from high -> low # 1. input argument 2. custom latent name from sequence inference 3. the sequence arch name 4. the feature @@ -784,45 +790,50 @@ def import_outputfile(project_dir: Union[str, os.PathLike], elif feature_extractor_arch in keys: key = feature_extractor_arch else: - log.warning('No default latent names found, using the first one instead. Keys: {}'.format(keys)) + log.warning("No default latent names found, using the first one instead. Keys: {}".format(keys)) key = keys[0] - log.info('Key used to load outputfile: {}'.format(key)) - probabilities = f[key]['P'][:] + log.info("Key used to load outputfile: {}".format(key)) + probabilities = f[key]["P"][:] negative_probabilities = np.sum(probabilities < 0) if negative_probabilities > 0: - log.warning('N={} negative probabilities found in file {}'.format(negative_probabilities, - os.path.basename(outputfile))) + log.warning( + "N={} negative probabilities found in file {}".format( + negative_probabilities, os.path.basename(outputfile) + ) + ) probabilities[probabilities < 0] = 0 - thresholds = f[key]['thresholds'][:] + thresholds = f[key]["thresholds"][:] if thresholds.ndim == 2: # this should not happen thresholds = thresholds[-1, :] - loaded_class_names = f[key]['class_names'][:] + loaded_class_names = f[key]["class_names"][:] if type(loaded_class_names[0]) == bytes: - loaded_class_names = [i.decode('utf-8') for i in loaded_class_names] - log.debug('probabilities shape: {}'.format(probabilities.shape)) + loaded_class_names = [i.decode("utf-8") for i in loaded_class_names] + log.debug("probabilities shape: {}".format(probabilities.shape)) # if you pass class names, make sure that the order matches the order of the argument. Else, just return it # in the order it is in the HDF5 file if class_names is None: return probabilities, thresholds, latent_name, keys - log.debug('imported names: {}'.format(loaded_class_names)) + log.debug("imported names: {}".format(loaded_class_names)) indices = [] for class_name in class_names: ind = [i for i in range(len(loaded_class_names)) if loaded_class_names[i] == class_name] if len(ind) == 1: indices.append(ind[0]) indices = np.array(indices).squeeze() - log.debug('indices: {} type: {} shape: {}'.format(indices, type(indices), indices.shape)) + log.debug("indices: {} type: {} shape: {}".format(indices, type(indices), indices.shape)) if not indices.shape: - raise ValueError('Class names not found in file. Loaded: {} Requested: {}'.format( - loaded_class_names, class_names)) + raise ValueError( + "Class names not found in file. Loaded: {} Requested: {}".format(loaded_class_names, class_names) + ) if len(indices) == 0: - raise ValueError('Class names not found in file. Loaded: {} Requested: {}'.format( - loaded_class_names, class_names)) + raise ValueError( + "Class names not found in file. Loaded: {} Requested: {}".format(loaded_class_names, class_names) + ) probabilities = probabilities[:, indices] thresholds = thresholds[indices] @@ -830,26 +841,26 @@ def import_outputfile(project_dir: Union[str, os.PathLike], def has_outputfile(records: dict) -> list: - """ Convenience function for finding output files in a dictionary of records""" + """Convenience function for finding output files in a dictionary of records""" keys, has_outputs = [], [] # check to see which records have outputfiles for key, record in records.items(): keys.append(key) - has_outputs.append(record['output'] is not None) + has_outputs.append(record["output"] is not None) return has_outputs def do_outputfiles_have_predictions(data_path: Union[str, os.PathLike], model_name: str) -> list: - """ Looks for HDF5 datasets in data_path of name model_name """ + """Looks for HDF5 datasets in data_path of name model_name""" assert os.path.isdir(data_path) records = get_records_from_datadir(data_path) has_predictions = [] for key, record in records.items(): - file = records[key]['output'] + file = records[key]["output"] if file is None: has_predictions.append(False) continue - with h5py.File(file, 'r') as f: + with h5py.File(file, "r") as f: if model_name in list(f.keys()): has_predictions.append(True) else: @@ -858,17 +869,17 @@ def do_outputfiles_have_predictions(data_path: Union[str, os.PathLike], model_na def extract_date(string: str): - """ Extracts the actual date time from a formatted string. Used for finding most recent models """ - pattern = re.compile(r'\d{6}_\d{6}') + """Extracts the actual date time from a formatted string. Used for finding most recent models""" + pattern = re.compile(r"\d{6}_\d{6}") match = pattern.search(string) if match is not None: match = match.group() - match = datetime.strptime(match, '%y%m%d_%H%M%S') + match = datetime.strptime(match, "%y%m%d_%H%M%S") return match def sort_runs_by_date(runs: list) -> list: - """ Sorts run directories by date using the date string in the directory name """ + """Sorts run directories by date using the date string in the directory name""" runs_and_dates = [] for run in runs: runs_and_dates.append((run, extract_date(run))) @@ -879,8 +890,8 @@ def sort_runs_by_date(runs: list) -> list: def get_weightfiles_from_rundir(rundir: Union[os.PathLike, str]) -> dict: - """from a run directory, finds a dictionary of all the model weights. - + """from a run directory, finds a dictionary of all the model weights. + Can be either .pt or .ckpt. Can be either the "last" model or the "best" model, according to one's key metric Parameters @@ -895,26 +906,26 @@ def get_weightfiles_from_rundir(rundir: Union[os.PathLike, str]) -> dict: last: last.ckpt, most recent lightning file best: .ckpt, best model according to validation metric """ - subfiles = utils.get_subfiles(rundir, 'file') + subfiles = utils.get_subfiles(rundir, "file") deg_checkpoint = None for subfile in subfiles: - if subfile.endswith('checkpoint.pt'): + if subfile.endswith("checkpoint.pt"): deg_checkpoint = subfile - subdirs = utils.get_subfiles(rundir, 'directory') + subdirs = utils.get_subfiles(rundir, "directory") last, best = None, None for subdir in subdirs: - if subdir.endswith('lightning_checkpoints'): - subfiles = utils.get_subfiles(subdir, 'file') + if subdir.endswith("lightning_checkpoints"): + subfiles = utils.get_subfiles(subdir, "file") subfiles.sort() for subfile in subfiles: - if subfile.endswith('last.ckpt'): + if subfile.endswith("last.ckpt"): last = subfile else: basename = os.path.basename(subfile) # the last, alphabetically, checkpoint with epoch in the name is assumed to be the best - if 'epoch' in basename and basename.endswith('.ckpt'): + if "epoch" in basename and basename.endswith(".ckpt"): best = subfile return dict(deg=deg_checkpoint, last=last, best=best) @@ -922,18 +933,18 @@ def get_weightfiles_from_rundir(rundir: Union[os.PathLike, str]) -> dict: def get_weightfile_from_rundir(rundir: Union[os.PathLike, str]) -> str: weightfiles = get_weightfiles_from_rundir(rundir) # default to BEST weights - if weightfiles['best'] is not None: - return weightfiles['best'] - elif weightfiles['last'] is not None: - return weightfiles['last'] - elif weightfiles['deg'] is not None: - return weightfiles['deg'] + if weightfiles["best"] is not None: + return weightfiles["best"] + elif weightfiles["last"] is not None: + return weightfiles["last"] + elif weightfiles["deg"] is not None: + return weightfiles["deg"] else: return None def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: - """ Finds absolute path to weight files for each model type and architecture + """Finds absolute path to weight files for each model type and architecture Parameters ---------- @@ -958,20 +969,20 @@ def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: } } """ - rundirs = get_subfiles(model_path, return_type='directory') + rundirs = get_subfiles(model_path, return_type="directory") # assume the models are only at most one sub directory underneath for rundir in rundirs: - subdirs = get_subfiles(rundir, return_type='directory') + subdirs = get_subfiles(rundir, return_type="directory") rundirs += subdirs rundirs.sort() # model_weights = defaultdict(list) - model_weights = {'flow_generator': {}, 'feature_extractor': {}, 'sequence': {}} + model_weights = {"flow_generator": {}, "feature_extractor": {}, "sequence": {}} for rundir in rundirs: # for backwards compatibility - paramfile = os.path.join(rundir, 'hyperparameters.yaml') + paramfile = os.path.join(rundir, "hyperparameters.yaml") if not os.path.isfile(paramfile): - paramfile = os.path.join(rundir, 'config.yaml') + paramfile = os.path.join(rundir, "config.yaml") if not os.path.isfile(paramfile): continue @@ -979,20 +990,20 @@ def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: if params is None: continue # this horrible if-else tree is for backwards compatability with how I used to save config files - if 'model' in params.keys(): - model_type = params['model'] - if params['model'] in params.keys(): - arch = params[params['model']] - elif params['model'] == 'feature_extractor': - arch = params['classifier'] - elif 'arch' in params.keys(): - arch = params['arch'] + if "model" in params.keys(): + model_type = params["model"] + if params["model"] in params.keys(): + arch = params[params["model"]] + elif params["model"] == "feature_extractor": + arch = params["classifier"] + elif "arch" in params.keys(): + arch = params["arch"] else: - raise ValueError('Could not find architecture from config: {}'.format(params)) + raise ValueError("Could not find architecture from config: {}".format(params)) - elif 'run' in params.keys(): - model_type = params['run']['model'] - arch = params[model_type]['arch'] + elif "run" in params.keys(): + model_type = params["run"]["model"] + arch = params[model_type]["arch"] else: continue @@ -1013,8 +1024,7 @@ def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: def get_weight_file_absolute_or_relative(cfg, path_to_weights): - """if path_to_weights exists, return. if it doesn't, pre-pend model path - """ + """if path_to_weights exists, return. if it doesn't, pre-pend model path""" if os.path.isfile(path_to_weights): return path_to_weights else: @@ -1024,7 +1034,7 @@ def get_weight_file_absolute_or_relative(cfg, path_to_weights): def get_weightfile_from_cfg(cfg: DictConfig, model_type: str) -> Union[str, None]: - """ Gets the correct weight files from the configuration. + """Gets the correct weight files from the configuration. The weights are loaded in the following order of priority 1. cfg.reload.weights: assume a specific pretrained weightfile with all components (flow_generator, spatial, flow) @@ -1050,7 +1060,7 @@ def get_weightfile_from_cfg(cfg: DictConfig, model_type: str) -> Union[str, None # assert os.path.isfile(cfg.reload.weights) # return cfg.reload.weights - assert model_type in ['flow_generator', 'feature_extractor', 'end_to_end', 'sequence'] + assert model_type in ["flow_generator", "feature_extractor", "end_to_end", "sequence"] if not os.path.isdir(cfg.project.model_path): cfg = convert_config_paths_to_absolute(cfg) @@ -1059,40 +1069,41 @@ def get_weightfile_from_cfg(cfg: DictConfig, model_type: str) -> Union[str, None architecture = cfg[model_type].arch - if model_type in cfg and cfg[model_type].weights is not None and cfg[model_type].weights == 'pretrained': - assert model_type in ['flow_generator', 'feature_extractor'] + if model_type in cfg and cfg[model_type].weights is not None and cfg[model_type].weights == "pretrained": + assert model_type in ["flow_generator", "feature_extractor"] pretrained_models = get_weights_from_model_path(cfg.project.pretrained_path) assert len(pretrained_models[model_type][architecture]) > 0 weights = pretrained_models[model_type][architecture][-1] - log.info('loading pretrained weights: {}'.format(weights)) + log.info("loading pretrained weights: {}".format(weights)) return weights - if model_type == 'end_to_end': + if model_type == "end_to_end": if cfg.reload.latest: - assert len(trained_models['feature_extractor'][architecture]) > 0 - return trained_models['feature_extractor'][architecture][-1] + assert len(trained_models["feature_extractor"][architecture]) > 0 + return trained_models["feature_extractor"][architecture][-1] else: - if model_type in cfg and cfg[model_type].weights is not None and cfg[model_type].weights != 'latest': + if model_type in cfg and cfg[model_type].weights is not None and cfg[model_type].weights != "latest": path_to_weights = get_weight_file_absolute_or_relative(cfg, cfg[model_type].weights) assert os.path.isfile(path_to_weights) - log.info('loading specified weights: {}'.format(path_to_weights)) + log.info("loading specified weights: {}".format(path_to_weights)) return path_to_weights - elif cfg.reload.latest or cfg[model_type].weights == 'latest': + elif cfg.reload.latest or cfg[model_type].weights == "latest": # print(trained_models) if len(trained_models[model_type][architecture]) == 0: - log.warning('Trying to load *latest* weights, but found none! Using random initialization!') + log.warning("Trying to load *latest* weights, but found none! Using random initialization!") return - log.debug('trained models found: {}'.format(trained_models[model_type][architecture])) - log.info('loading LATEST weights: {}'.format(trained_models[model_type][architecture][-1])) + log.debug("trained models found: {}".format(trained_models[model_type][architecture])) + log.info("loading LATEST weights: {}".format(trained_models[model_type][architecture][-1])) return trained_models[model_type][architecture][-1] else: - log.warning('no {} weights found...'.format(model_type)) + log.warning("no {} weights found...".format(model_type)) return -def convert_config_paths_to_absolute(project_cfg: DictConfig, - raise_error_if_pretrained_missing: bool = True) -> DictConfig: - """ Converts relative file paths in a project configuration into absolute paths. +def convert_config_paths_to_absolute( + project_cfg: DictConfig, raise_error_if_pretrained_missing: bool = True +) -> DictConfig: + """Converts relative file paths in a project configuration into absolute paths. Example: project_cfg['project']['path'] = '/path/to/project' @@ -1114,45 +1125,48 @@ def convert_config_paths_to_absolute(project_cfg: DictConfig, Returns: project_cfg (dict) """ - assert 'project' in project_cfg.keys() + assert "project" in project_cfg.keys() - root = project_cfg['project']['path'] - model_path = project_cfg['project']['model_path'] - data_path = project_cfg['project']['data_path'] + root = project_cfg["project"]["path"] + model_path = project_cfg["project"]["model_path"] + data_path = project_cfg["project"]["data_path"] # backwards compatibility - if 'pretrained_path' in project_cfg['project'].keys(): - pretrained_path = project_cfg['project']['pretrained_path'] + if "pretrained_path" in project_cfg["project"].keys(): + pretrained_path = project_cfg["project"]["pretrained_path"] else: - pretrained_path = 'pretrained_models' - cfg_path = os.path.join(root, project_cfg['project']['config_file']) - - if (os.path.isdir(model_path) and os.path.isdir(data_path) and os.path.isfile(cfg_path) and - os.path.isdir(pretrained_path)): + pretrained_path = "pretrained_models" + cfg_path = os.path.join(root, project_cfg["project"]["config_file"]) + + if ( + os.path.isdir(model_path) + and os.path.isdir(data_path) + and os.path.isfile(cfg_path) + and os.path.isdir(pretrained_path) + ): # already absolute return project_cfg - log.info('cwd in absolute: {}'.format(os.getcwd())) + log.info("cwd in absolute: {}".format(os.getcwd())) if not os.path.isdir(model_path): model_path = os.path.join(root, model_path) - assert os.path.isdir(model_path), 'model path does not exist! {}'.format(model_path) + assert os.path.isdir(model_path), "model path does not exist! {}".format(model_path) if not os.path.isdir(data_path): data_path = os.path.join(root, data_path) - assert os.path.isdir(data_path), 'data path does not exist! {}'.format(data_path) + assert os.path.isdir(data_path), "data path does not exist! {}".format(data_path) if not os.path.isfile(cfg_path): cfg_path = os.path.join(root, cfg_path) - assert os.path.isdir(cfg_path), 'config file does not exist! {}'.format(cfg_path) + assert os.path.isdir(cfg_path), "config file does not exist! {}".format(cfg_path) if not os.path.isdir(pretrained_path): - # pretrained_dir can be one of the following locations: # my_model_dir/pretrained # my_project/pretrained # my_project/models/pretrained pretrained_options = [ - os.path.join(i, pretrained_path) for i in [model_path, root, os.path.join(root, 'models')] + os.path.join(i, pretrained_path) for i in [model_path, root, os.path.join(root, "models")] ] exists = [os.path.isdir(i) for i in pretrained_options] @@ -1161,8 +1175,10 @@ def convert_config_paths_to_absolute(project_cfg: DictConfig, index = exists.index(True) pretrained_path = pretrained_options[index] except ValueError: - error_string = 'pretrained directory does not exist! {}\nSee instructions '.format(pretrained_path) + \ - 'on the project GitHub for downloading weights: https://github.com/jbohnslav/deepethogram' + error_string = ( + "pretrained directory does not exist! {}\nSee instructions ".format(pretrained_path) + + "on the project GitHub for downloading weights: https://github.com/jbohnslav/deepethogram" + ) if raise_error_if_pretrained_missing: log.error(error_string) @@ -1170,17 +1186,17 @@ def convert_config_paths_to_absolute(project_cfg: DictConfig, else: log.warning(error_string) - project_cfg['project']['model_path'] = model_path - project_cfg['project']['data_path'] = data_path - project_cfg['project']['config_file'] = cfg_path - project_cfg['project']['pretrained_path'] = pretrained_path - log.info('after absolute: {}'.format(project_cfg['project'])) + project_cfg["project"]["model_path"] = model_path + project_cfg["project"]["data_path"] = data_path + project_cfg["project"]["config_file"] = cfg_path + project_cfg["project"]["pretrained_path"] = pretrained_path + log.info("after absolute: {}".format(project_cfg["project"])) return project_cfg def load_config(path_to_config: Union[str, os.PathLike]) -> dict: """Convenience function to load dictionary from yaml and sort out potentially erroneous paths""" - assert os.path.isfile(path_to_config), 'configuration file does not exist! {}'.format(path_to_config) + assert os.path.isfile(path_to_config), "configuration file does not exist! {}".format(path_to_config) project = OmegaConf.load(path_to_config) project = fix_config_paths(project, path_to_config) @@ -1189,21 +1205,21 @@ def load_config(path_to_config: Union[str, os.PathLike]) -> dict: def load_default(conf_name: str) -> dict: - """ Loads default configs from deepethogram install path - DEPRECATED. + """Loads default configs from deepethogram install path + DEPRECATED. TODO: replace with configuration.load_config_by_name """ - log.debug('project loc for loading default: {}'.format(projects_file_directory)) - defaults_file = os.path.join(projects_file_directory, 'conf', os.path.relpath(conf_name) + '.yaml') - assert os.path.isfile(defaults_file), 'configuration file does not exist! {}'.format(defaults_file) + log.debug("project loc for loading default: {}".format(projects_file_directory)) + defaults_file = os.path.join(projects_file_directory, "conf", os.path.relpath(conf_name) + ".yaml") + assert os.path.isfile(defaults_file), "configuration file does not exist! {}".format(defaults_file) defaults = utils.load_yaml(defaults_file) return defaults -def convert_all_videos(config_file: Union[str, os.PathLike], movie_format='hdf5', **kwargs) -> None: - """Converts all videos in a project from one filetype to another. - +def convert_all_videos(config_file: Union[str, os.PathLike], movie_format="hdf5", **kwargs) -> None: + """Converts all videos in a project from one filetype to another. + Note: If using movie_format other than 'directory' or 'hdf5', will re-compress images! Parameters @@ -1217,9 +1233,10 @@ def convert_all_videos(config_file: Union[str, os.PathLike], movie_format='hdf5' project_config = utils.load_yaml(config_file) records = get_records_from_datadir( - os.path.join(project_config['project']['path'], project_config['project']['data_path'])) - for key, record in tqdm(records.items(), desc='converting videos'): - videofile = record['rgb'] + os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) + ) + for key, record in tqdm(records.items(), desc="converting videos"): + videofile = record["rgb"] try: convert_video(videofile, movie_format=movie_format, **kwargs) except ValueError as e: @@ -1227,18 +1244,17 @@ def convert_all_videos(config_file: Union[str, os.PathLike], movie_format='hdf5' def get_config_file_from_path(path: Union[str, os.PathLike]) -> str: - """gets a config file, with name either path/project.yaml or path/project_config.yaml - """ - for cfg_path in ['project', 'project_config']: - cfg_path = os.path.join(path, cfg_path + '.yaml') + """gets a config file, with name either path/project.yaml or path/project_config.yaml""" + for cfg_path in ["project", "project_config"]: + cfg_path = os.path.join(path, cfg_path + ".yaml") if os.path.isfile(cfg_path): return cfg_path - raise ValueError('No configuration file found in directory! {}'.format(os.listdir(path))) + raise ValueError("No configuration file found in directory! {}".format(os.listdir(path))) def fix_config_paths(cfg, path_to_config: Union[str, os.PathLike]): """Fixes the path to the project and config file in the configuration itself. - + This situation could occur if one moved an existing project to another computer or directory Parameters @@ -1250,18 +1266,21 @@ def fix_config_paths(cfg, path_to_config: Union[str, os.PathLike]): Returns ------- - cfg: DictConfig + cfg: DictConfig configuration with fixed paths """ error = False - if cfg['project']['path'] != os.path.dirname(path_to_config): - log.warning('Erroneous project path in the config file itself, changing from {} to {}'.format( - cfg['project']['path'], os.path.dirname(path_to_config))) - cfg['project']['path'] = os.path.dirname(path_to_config) + if cfg["project"]["path"] != os.path.dirname(path_to_config): + log.warning( + "Erroneous project path in the config file itself, changing from {} to {}".format( + cfg["project"]["path"], os.path.dirname(path_to_config) + ) + ) + cfg["project"]["path"] = os.path.dirname(path_to_config) error = True - if cfg['project']['config_file'] != os.path.basename(path_to_config): - log.warning('Erroneous name of config file in the config file itself, changing...') - cfg['project']['config_file'] = os.path.basename(path_to_config) + if cfg["project"]["config_file"] != os.path.basename(path_to_config): + log.warning("Erroneous name of config file in the config file itself, changing...") + cfg["project"]["config_file"] = os.path.basename(path_to_config) error = True if error: utils.save_dict_to_yaml(cfg, path_to_config) @@ -1270,12 +1289,12 @@ def fix_config_paths(cfg, path_to_config: Union[str, os.PathLike]): def get_config_from_path(project_path: Union[str, os.PathLike]): """gets a project configuration from a project path - + Finds the file; loads it; and fixes any relevant config paths Parameters ---------- - project_path : str, os.PathLike + project_path : str, os.PathLike path to a deepethogram project Returns @@ -1313,19 +1332,19 @@ def get_project_path_from_cl(argv: list, error_if_not_found=True) -> str: if project.path is not found """ for arg in argv: - if 'project.config_file' in arg: - key, path = arg.split('=') + if "project.config_file" in arg: + key, path = arg.split("=") assert os.path.isfile(path) # path is the path to the project directory, not the config file path = os.path.dirname(path) return path - elif 'project.path' in arg: - key, path = arg.split('=') + elif "project.path" in arg: + key, path = arg.split("=") assert os.path.isdir(path) return path if error_if_not_found: - raise ValueError('project path or file not in args: {}'.format(argv)) + raise ValueError("project path or file not in args: {}".format(argv)) else: return None @@ -1383,7 +1402,7 @@ def configure_run_directory(cfg: DictConfig) -> str: Name: date-time_model-type_run-type_notes e.g. 20210311_011800_feature_extractor_train_testing_dropout - + Parameters ---------- cfg : DictConfig @@ -1394,15 +1413,15 @@ def configure_run_directory(cfg: DictConfig) -> str: str path to run directory """ - datestring = datetime.now().strftime('%y%m%d_%H%M%S') - if cfg.run.type == 'gui': + datestring = datetime.now().strftime("%y%m%d_%H%M%S") + if cfg.run.type == "gui": path = cfg.project.path if cfg.project.path is not None else os.getcwd() - directory = os.path.join(path, 'gui_logs', datestring) + directory = os.path.join(path, "gui_logs", datestring) else: - directory = f'{datestring}_{cfg.run.model}_{cfg.run.type}' + directory = f"{datestring}_{cfg.run.model}_{cfg.run.type}" directory = os.path.join(cfg.project.model_path, directory) if cfg.notes is not None: - directory += f'_{cfg.notes}' + directory += f"_{cfg.notes}" if not os.path.isdir(directory): os.makedirs(directory) os.chdir(directory) @@ -1412,7 +1431,7 @@ def configure_run_directory(cfg: DictConfig) -> str: def configure_logging(cfg: DictConfig = None) -> None: """Sets up python logging to use a specific format, and also save to disk - If no config is passed, simply log to the command line. + If no config is passed, simply log to the command line. Parameters ---------- @@ -1421,20 +1440,22 @@ def configure_logging(cfg: DictConfig = None) -> None: """ # assume current directory is run directory if cfg is None: - log_level = 'info' + log_level = "info" handlers = [logging.StreamHandler()] else: - assert cfg.log.level in ['debug', 'info', 'warning', 'error', 'critical'] + assert cfg.log.level in ["debug", "info", "warning", "error", "critical"] log_level = cfg.log.level - handlers = [logging.FileHandler(cfg.log.level + '.log'), logging.StreamHandler()] + handlers = [logging.FileHandler(cfg.log.level + ".log"), logging.StreamHandler()] # https://docs.python.org/3/library/logging.html#logging-levels - log_lookup = {'critical': 50, 'error': 40, 'warning': 30, 'info': 20, 'debug': 10} + log_lookup = {"critical": 50, "error": 40, "warning": 30, "info": 20, "debug": 10} logger = logging.getLogger() while logger.hasHandlers(): logger.removeHandler(logger.handlers[0]) - logging.basicConfig(level=log_lookup[log_level], - format='[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s', - handlers=handlers) + logging.basicConfig( + level=log_lookup[log_level], + format="[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s", + handlers=handlers, + ) def setup_run(cfg: DictConfig, **kwargs) -> DictConfig: @@ -1459,5 +1480,5 @@ def setup_run(cfg: DictConfig, **kwargs) -> DictConfig: cfg.run.dir = directory configure_logging(cfg) - utils.save_dict_to_yaml(OmegaConf.to_container(cfg), 'config.yaml') + utils.save_dict_to_yaml(OmegaConf.to_container(cfg), "config.yaml") return cfg diff --git a/deepethogram/schedulers.py b/deepethogram/schedulers.py index 360ffcc..a21cbd7 100644 --- a/deepethogram/schedulers.py +++ b/deepethogram/schedulers.py @@ -7,21 +7,24 @@ log = logging.getLogger(__name__) + class _LRScheduler: def __init__(self, optimizer, last_epoch=-1): if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer if last_epoch == -1: for group in optimizer.param_groups: - group.setdefault('initial_lr', group['lr']) + group.setdefault("initial_lr", group["lr"]) else: for i, group in enumerate(optimizer.param_groups): - if 'initial_lr' not in group: - raise KeyError("param 'initial_lr' is not specified " - "in param_groups[{}] when resuming an optimizer".format(i)) - self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " "in param_groups[{}] when resuming an optimizer".format( + i + ) + ) + self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups)) self.step(last_epoch + 1) self.last_epoch = last_epoch @@ -30,7 +33,7 @@ def state_dict(self): It contains an entry for every variable in self.__dict__ which is not the optimizer. """ - return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + return {key: value for key, value in self.__dict__.items() if key != "optimizer"} def load_state_dict(self, state_dict): """Loads the schedulers state. @@ -48,7 +51,7 @@ def step(self, epoch=None): epoch = self.last_epoch + 1 self.last_epoch = epoch for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): - param_group['lr'] = lr + param_group["lr"] = lr # UNMERGED PULL REQUEST! NOT WRITTEN BY ME BUT SUPER USEFUL! @@ -92,7 +95,7 @@ def __init__(self, optimizer, T, eta_min=0, T_mult=2.0, eta_mult=1.0, last_epoch self.eta_mult = eta_mult if T_mult < 1: - raise ValueError('T_mult should be >= 1.0.') + raise ValueError("T_mult should be >= 1.0.") self.T_mult = T_mult super(CosineAnnealingRestartsLR, self).__init__(optimizer, last_epoch) @@ -107,20 +110,18 @@ def get_lr(self): else: # computation of the last restarting epoch is based on sum of geometric series: # last_restart = T * (1 + T_mult + T_mult ** 2 + ... + T_mult ** i_restarts) - i_restarts = int(math.log(1 - self.last_epoch * (1 - self.T_mult) / self.T, - self.T_mult)) - last_restart = int(self.T * (1 - self.T_mult ** i_restarts) / (1 - self.T_mult)) + i_restarts = int(math.log(1 - self.last_epoch * (1 - self.T_mult) / self.T, self.T_mult)) + last_restart = int(self.T * (1 - self.T_mult**i_restarts) / (1 - self.T_mult)) if self.last_epoch == last_restart: T_i1 = self.T * self.T_mult ** (i_restarts - 1) # T_{i-1} lr_update = self.eta_mult / self._decay(T_i1 - 1, T_i1) else: - T_i = self.T * self.T_mult ** i_restarts + T_i = self.T * self.T_mult**i_restarts t = self.last_epoch - last_restart lr_update = self._decay(t, T_i) / self._decay(t - 1, T_i) - return [lr_update * (group['lr'] - self.eta_min) + self.eta_min - for group in self.optimizer.param_groups] + return [lr_update * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups] @staticmethod def _decay(t, T): @@ -128,8 +129,8 @@ def _decay(t, T): return 0.5 * (1 + math.cos(math.pi * t / T)) -def initialize_scheduler(optimizer, cfg: DictConfig, mode: str = 'max', reduction_factor: float = 0.1): - """ Makes a learning rate scheduler from an OmegaConf DictConfig +def initialize_scheduler(optimizer, cfg: DictConfig, mode: str = "max", reduction_factor: float = 0.1): + """Makes a learning rate scheduler from an OmegaConf DictConfig Parameters ---------- @@ -150,19 +151,24 @@ def initialize_scheduler(optimizer, cfg: DictConfig, mode: str = 'max', reductio scheduler Learning rate scheduler """ - if cfg.train.scheduler == 'multistep': + if cfg.train.scheduler == "multistep": scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.train.milestones, gamma=0.5) # for convenience - scheduler.name = 'multistep' - elif cfg.train.scheduler == 'cosine': + scheduler.name = "multistep" + elif cfg.train.scheduler == "cosine": # todo: reconfigure this to use pytorch's new built-in cosine annealing scheduler = CosineAnnealingRestartsLR(optimizer, T=25, T_mult=1, eta_mult=0.5, eta_min=1e-7) - scheduler.name = 'cosine' - elif cfg.train.scheduler == 'plateau': - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=mode, factor=reduction_factor, - patience=cfg.train.patience, verbose=True, - min_lr=cfg.train.min_lr) - scheduler.name = 'plateau' + scheduler.name = "cosine" + elif cfg.train.scheduler == "plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode=mode, + factor=reduction_factor, + patience=cfg.train.patience, + verbose=True, + min_lr=cfg.train.min_lr, + ) + scheduler.name = "plateau" else: scheduler = None return scheduler diff --git a/deepethogram/sequence/__main__.py b/deepethogram/sequence/__main__.py index 49ee54d..b099711 100644 --- a/deepethogram/sequence/__main__.py +++ b/deepethogram/sequence/__main__.py @@ -1,3 +1,3 @@ from .train import main -main() \ No newline at end of file +main() diff --git a/deepethogram/sequence/inference.py b/deepethogram/sequence/inference.py index 14a89d1..406dd49 100644 --- a/deepethogram/sequence/inference.py +++ b/deepethogram/sequence/inference.py @@ -4,6 +4,7 @@ from typing import Union, Type import h5py + # import hydra import numpy as np import torch @@ -20,17 +21,19 @@ log = logging.getLogger(__name__) -def infer(model: Type[nn.Module], - device: Union[str, torch.device], - activation_function: Union[str, Type[nn.Module]], - data_file: Union[str, os.PathLike], - latent_name: str, - videofile: Union[str, os.PathLike], - sequence_length: int = 180, - is_two_stream: bool = True, - is_keypoint: bool = False, - expansion_method: str = 'sturman', - stack_in_time: bool = False): +def infer( + model: Type[nn.Module], + device: Union[str, torch.device], + activation_function: Union[str, Type[nn.Module]], + data_file: Union[str, os.PathLike], + latent_name: str, + videofile: Union[str, os.PathLike], + sequence_length: int = 180, + is_two_stream: bool = True, + is_keypoint: bool = False, + expansion_method: str = "sturman", + stack_in_time: bool = False, +): """Runs inference of the sequence model Parameters @@ -55,7 +58,7 @@ def infer(model: Type[nn.Module], Method for expanding keypoints into features, by default sturman stack_in_time: bool, optional If True, stacks sequences from T x K features -> T*K features, by default False - + Returns ------- logits: np.ndarray @@ -71,37 +74,41 @@ def infer(model: Type[nn.Module], if not is_keypoint: assert latent_name is not None - gen = FeatureVectorDataset(data_file, - labelfile=None, - h5_key=latent_name, - sequence_length=sequence_length, - nonoverlapping=True, - store_in_ram=False, - is_two_stream=is_two_stream) + gen = FeatureVectorDataset( + data_file, + labelfile=None, + h5_key=latent_name, + sequence_length=sequence_length, + nonoverlapping=True, + store_in_ram=False, + is_two_stream=is_two_stream, + ) else: - gen = KeypointDataset(data_file, - labelfile=None, - videofile=videofile, - expansion_method=expansion_method, - sequence_length=sequence_length, - stack_in_time=stack_in_time, - nonoverlapping=not stack_in_time) + gen = KeypointDataset( + data_file, + labelfile=None, + videofile=videofile, + expansion_method=expansion_method, + sequence_length=sequence_length, + stack_in_time=stack_in_time, + nonoverlapping=not stack_in_time, + ) n_datapoints = gen.shape[1] gen = data.DataLoader(gen, batch_size=1, shuffle=False, num_workers=0, drop_last=False) gen = iter(gen) - log.debug('Making sequence iterator with parameters: ') - log.debug('file: {}'.format(data_file)) - log.debug('seq length: {}'.format(sequence_length)) + log.debug("Making sequence iterator with parameters: ") + log.debug("file: {}".format(data_file)) + log.debug("seq length: {}".format(sequence_length)) if type(activation_function) == str: - if activation_function == 'softmax': + if activation_function == "softmax": activation_function = torch.nn.Softmax(dim=1) - elif activation_function == 'sigmoid': + elif activation_function == "sigmoid": activation_function = torch.nn.Sigmoid() else: - raise ValueError('unknown activation function: {}'.format(activation_function)) + raise ValueError("unknown activation function: {}".format(activation_function)) if type(device) == str: device = torch.device(device) @@ -120,17 +127,16 @@ def infer(model: Type[nn.Module], all_probabilities = [] has_printed = False for i in range(len(gen)): - with torch.no_grad(): batch = next(gen) - features = batch['features'].to(device) + features = batch["features"].to(device) logits = model(features) probabilities = activation_function(logits).detach().cpu().numpy().squeeze().T logits = logits.detach().cpu().numpy().squeeze().T if not has_printed: - log.debug('logits shape: {}'.format(logits.shape)) + log.debug("logits shape: {}".format(logits.shape)) has_printed = True if not stack_in_time: @@ -140,8 +146,8 @@ def infer(model: Type[nn.Module], indices = range(i * sequence_length, end) # get rid of padding in final batch if len(indices) < logits.shape[0]: - logits = logits[:len(indices), :] - probabilities = probabilities[:len(indices), :] + logits = logits[: len(indices), :] + probabilities = probabilities[: len(indices), :] all_logits.append(logits) all_probabilities.append(probabilities) @@ -155,24 +161,26 @@ def infer(model: Type[nn.Module], return all_logits, all_probabilities -def extract(model, - outputfiles: list, - thresholds: np.ndarray, - final_activation: str, - latent_name: str, - output_name: str = 'tgmj', - sequence_length: int = 180, - is_two_stream: bool = True, - device: str = 'cuda:0', - ignore_error=True, - overwrite=False, - class_names: list = ['background']): +def extract( + model, + outputfiles: list, + thresholds: np.ndarray, + final_activation: str, + latent_name: str, + output_name: str = "tgmj", + sequence_length: int = 180, + is_two_stream: bool = True, + device: str = "cuda:0", + ignore_error=True, + overwrite=False, + class_names: list = ["background"], +): torch.backends.cudnn.benchmark = True assert isinstance(model, torch.nn.Module) device = torch.device(device) - if device.type != 'cpu': + if device.type != "cpu": torch.cuda.set_device(device) model = model.to(device) for parameter in model.parameters(): @@ -181,9 +189,9 @@ def extract(model, has_printed = False - if final_activation == 'softmax': + if final_activation == "softmax": activation_function = torch.nn.Softmax(dim=1) - elif final_activation == 'sigmoid': + elif final_activation == "sigmoid": activation_function = torch.nn.Sigmoid() else: raise NotImplementedError @@ -192,11 +200,13 @@ def extract(model, for i in tqdm(range(len(outputfiles))): outputfile = outputfiles[i] - log.info('running inference on {}. latent name: {} output name: {}...'.format( - outputfile, latent_name, output_name)) + log.info( + "running inference on {}. latent name: {} output name: {}...".format(outputfile, latent_name, output_name) + ) - logits, probabilities = infer(model, device, activation_function, outputfile, latent_name, None, - sequence_length, is_two_stream) + logits, probabilities = infer( + model, device, activation_function, outputfile, latent_name, None, sequence_length, is_two_stream + ) # gen = FeatureVectorDataset(outputfile, labelfile=None, h5_key=latent_name, # sequence_length=sequence_length, @@ -209,48 +219,47 @@ def extract(model, # log.debug('file: {}'.format(outputfile)) # log.debug('seq length: {}'.format(sequence_length)) - with h5py.File(outputfile, 'r+') as f: - + with h5py.File(outputfile, "r+") as f: if output_name in list(f.keys()): if overwrite: - del (f[output_name]) + del f[output_name] else: - log.info('Latent {} already found in file {}, skipping...'.format(output_name, outputfile)) + log.info("Latent {} already found in file {}, skipping...".format(output_name, outputfile)) continue group = f.create_group(output_name) - group.create_dataset('thresholds', data=thresholds, dtype=np.float32) - group.create_dataset('logits', data=logits, dtype=np.float32) - group.create_dataset('P', data=probabilities, dtype=np.float32) + group.create_dataset("thresholds", data=thresholds, dtype=np.float32) + group.create_dataset("logits", data=logits, dtype=np.float32) + group.create_dataset("P", data=probabilities, dtype=np.float32) dt = h5py.string_dtype() - group.create_dataset('class_names', data=class_names, dtype=dt) + group.create_dataset("class_names", data=class_names, dtype=dt) def sequence_inference(cfg: DictConfig): cfg = projects.setup_run(cfg) - log.info('args: {}'.format(' '.join(sys.argv))) + log.info("args: {}".format(" ".join(sys.argv))) # turn "models" in your project configuration to "full/path/to/models" - log.info('configuration used: ') + log.info("configuration used: ") log.info(OmegaConf.to_yaml(cfg)) - weights = projects.get_weightfile_from_cfg(cfg, model_type='sequence') - assert weights is not None, 'Must either specify a weightfile or use reload.latest=True' + weights = projects.get_weightfile_from_cfg(cfg, model_type="sequence") + assert weights is not None, "Must either specify a weightfile or use reload.latest=True" run_files = utils.get_run_files_from_weights(weights) if cfg.sequence.latent_name is None: # find the latent name used in the weight file you loaded rundir = os.path.dirname(weights) - loaded_cfg = utils.load_yaml(run_files['config_file']) - latent_name = loaded_cfg['sequence']['latent_name'] + loaded_cfg = utils.load_yaml(run_files["config_file"]) + latent_name = loaded_cfg["sequence"]["latent_name"] # if this latent name is also None, use the arch of the feature extractor # this should never happen if latent_name is None: - latent_name = loaded_cfg['feature_extractor']['arch'] + latent_name = loaded_cfg["feature_extractor"]["arch"] else: latent_name = cfg.sequence.latent_name if cfg.inference.use_loaded_model_cfg: output_name = cfg.sequence.output_name - loaded_config_file = run_files['config_file'] + loaded_config_file = run_files["config_file"] loaded_model_cfg = OmegaConf.load(loaded_config_file).sequence current_model_cfg = cfg.sequence model_cfg = OmegaConf.merge(current_model_cfg, loaded_model_cfg) @@ -260,7 +269,7 @@ def sequence_inference(cfg: DictConfig): cfg.sequence.weights = weights cfg.sequence.latent_name = latent_name cfg.sequence.output_name = output_name - log.info('latent name used for running sequence inference: {}'.format(latent_name)) + log.info("latent name used for running sequence inference: {}".format(latent_name)) # the output name will be a group in the output hdf5 dataset containing probabilities, etc if cfg.sequence.output_name is None: @@ -269,66 +278,70 @@ def sequence_inference(cfg: DictConfig): output_name = cfg.sequence.output_name directory_list = cfg.inference.directory_list if directory_list is None or len(directory_list) == 0: - raise ValueError('must pass list of directories from commmand line. ' - 'Ex: directory_list=[path_to_dir1,path_to_dir2] or directory_list=all') - elif type(directory_list) == str and directory_list == 'all': + raise ValueError( + "must pass list of directories from commmand line. " + "Ex: directory_list=[path_to_dir1,path_to_dir2] or directory_list=all" + ) + elif type(directory_list) == str and directory_list == "all": basedir = cfg.project.data_path - directory_list = utils.get_subfiles(basedir, 'directory') + directory_list = utils.get_subfiles(basedir, "directory") outputfiles = [] for directory in directory_list: - assert os.path.isdir(directory), 'Not a directory: {}'.format(directory) + assert os.path.isdir(directory), "Not a directory: {}".format(directory) record = projects.get_record_from_subdir(directory) - assert record['output'] is not None - outputfiles.append(record['output']) + assert record["output"] is not None + outputfiles.append(record["output"]) model = build_model_from_cfg(cfg, 1024, len(cfg.project.class_names)) - log.info('model: {}'.format(model)) + log.info("model: {}".format(model)) model = utils.load_weights(model, weights) - metrics_file = run_files['metrics_file'] + metrics_file = run_files["metrics_file"] assert os.path.isfile(metrics_file) best_epoch = utils.get_best_epoch_from_weightfile(weights) # best_epoch = -1 - log.info('best epoch from loaded file: {}'.format(best_epoch)) - with h5py.File(metrics_file, 'r') as f: + log.info("best epoch from loaded file: {}".format(best_epoch)) + with h5py.File(metrics_file, "r") as f: try: - thresholds = f['val']['metrics_by_threshold']['optimum'][best_epoch, :] + thresholds = f["val"]["metrics_by_threshold"]["optimum"][best_epoch, :] except KeyError: # backwards compatibility - thresholds = f['threshold_curves']['val']['optimum'][:] # [best_epoch, :] + thresholds = f["threshold_curves"]["val"]["optimum"][:] # [best_epoch, :] if thresholds.ndim > 1: thresholds = thresholds[best_epoch, :] - log.info('thresholds: {}'.format(thresholds)) + log.info("thresholds: {}".format(thresholds)) class_names = list(cfg.project.class_names) if len(thresholds) != len(class_names): - error_message = '''Number of classes in trained model: {} + error_message = """Number of classes in trained model: {} Number of classes in project: {} Did you add or remove behaviors after training this model? If so, please retrain! - '''.format(len(thresholds), len(class_names)) + """.format(len(thresholds), len(class_names)) raise ValueError(error_message) - device = 'cuda:{}'.format(cfg.compute.gpu_id) + device = "cuda:{}".format(cfg.compute.gpu_id) class_names = cfg.project.class_names class_names = np.array(class_names) - extract(model, - outputfiles, - thresholds, - cfg.feature_extractor.final_activation, - latent_name, - output_name, - cfg.sequence.sequence_length, - True, - device, - cfg.inference.ignore_error, - cfg.inference.overwrite, - class_names=class_names) - - -if __name__ == '__main__': + extract( + model, + outputfiles, + thresholds, + cfg.feature_extractor.final_activation, + latent_name, + output_name, + cfg.sequence.sequence_length, + True, + device, + cfg.inference.ignore_error, + cfg.inference.overwrite, + class_names=class_names, + ) + + +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_sequence_inference_cfg(project_path, use_command_line=True) - sequence_inference(cfg) \ No newline at end of file + sequence_inference(cfg) diff --git a/deepethogram/sequence/models/mlp.py b/deepethogram/sequence/models/mlp.py index 1ef114b..5a44a5c 100644 --- a/deepethogram/sequence/models/mlp.py +++ b/deepethogram/sequence/models/mlp.py @@ -4,13 +4,20 @@ class MLP(nn.Module): - """Multi-layer perceptron model. Baseline for sequence modeling - """ - def __init__(self, D: int, classes: int, dropout_p: float = 0.4, - hidden_layers=( - 256, - 128, - ), pos=None, neg=None): + """Multi-layer perceptron model. Baseline for sequence modeling""" + + def __init__( + self, + D: int, + classes: int, + dropout_p: float = 0.4, + hidden_layers=( + 256, + 128, + ), + pos=None, + neg=None, + ): """Constructor Parameters diff --git a/deepethogram/sequence/models/sequence.py b/deepethogram/sequence/models/sequence.py index 6c11a11..c3246bb 100644 --- a/deepethogram/sequence/models/sequence.py +++ b/deepethogram/sequence/models/sequence.py @@ -1,13 +1,20 @@ from torch import nn -def conv1d_same(in_channels, out_channels, kernel_size, - stride=1, dilation=1, groups=1, bias=True): +def conv1d_same(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): # if stride is two, output should be exactly half the size of input padding = kernel_size // 2 * dilation - return (nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias)) + return nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) class Linear(nn.Module): @@ -16,7 +23,7 @@ def __init__(self, num_features, num_classes, kernel_size=1): self.conv1 = conv1d_same(num_features, num_classes, kernel_size=kernel_size, stride=1, bias=True) def forward(self, x): - return (self.conv1(x)) + return self.conv1(x) class Conv_Nonlinear(nn.Module): @@ -42,24 +49,40 @@ def __init__(self, num_features, num_classes, batchnorm=True, hidden_size=64, dr self.net = nn.Sequential(*layers) def forward(self, x): - return (self.net(x)) + return self.net(x) class RNN(nn.Module): - def __init__(self, num_features, num_classes, style='lstm', hidden_size=64, - num_layers=1, dropout=0.0, output_dropout=0.0, bidirectional=False): + def __init__( + self, + num_features, + num_classes, + style="lstm", + hidden_size=64, + num_layers=1, + dropout=0.0, + output_dropout=0.0, + bidirectional=False, + ): super().__init__() - assert style in ['rnn', 'lstm', 'gru'] - if style == 'rnn': + assert style in ["rnn", "lstm", "gru"] + if style == "rnn": func = nn.RNN - elif style == 'lstm': + elif style == "lstm": func = nn.LSTM - elif style == 'gru': + elif style == "gru": func = nn.GRU - self.rnn = func(num_features, hidden_size, num_layers=num_layers, - bias=True, batch_first=True, dropout=dropout, bidirectional=bidirectional) + self.rnn = func( + num_features, + hidden_size, + num_layers=num_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=bidirectional, + ) self.dropout = nn.Dropout(output_dropout) size = hidden_size * 2 if bidirectional else hidden_size self.hidden_to_output = nn.Linear(size, num_classes) @@ -74,4 +97,4 @@ def forward(self, x): # outputs is N, L, C outputs = self.hidden_to_output(hiddens) # return outputs in shape N, C, L to be the same as conv1d - return (outputs.permute(0, 2, 1).contiguous()) + return outputs.permute(0, 2, 1).contiguous() diff --git a/deepethogram/sequence/models/tgm.py b/deepethogram/sequence/models/tgm.py index 2df64f1..47b1cbd 100644 --- a/deepethogram/sequence/models/tgm.py +++ b/deepethogram/sequence/models/tgm.py @@ -16,7 +16,7 @@ def compute_pad(stride, k, s): class TGMLayer(nn.Module): - """ THIS LAYER HAS BEEN EDITED ONLY SLIGHTLY FROM THE AUTHOR'S ORIGINAL. + """THIS LAYER HAS BEEN EDITED ONLY SLIGHTLY FROM THE AUTHOR'S ORIGINAL. https://github.com/piergiaj/tgm-icml19/ @inproceedings{piergiovanni2018super, @@ -26,14 +26,17 @@ class TGMLayer(nn.Module): year={2019} } """ - def __init__(self, - D: int = 1024, - n_filters: int = 3, - filter_length: int = 30, - c_in: int = 1, - c_out: int = 1, - soft: bool = False, - viz: bool = False): + + def __init__( + self, + D: int = 1024, + n_filters: int = 3, + filter_length: int = 30, + c_in: int = 1, + c_out: int = 1, + soft: bool = False, + viz: bool = False, + ): super().__init__() self.D = D self.n_filters = n_filters @@ -139,7 +142,7 @@ def forward(self, x): # output of C_in xDxT # indexing selects one row of k of shape C_in x1x1xL # grouped convolution applies to every C_in (of shape 1xDxT) - r = F.conv2d(x, k[i * self.c_in:(i + 1) * self.c_in], groups=self.c_in).squeeze(1) + r = F.conv2d(x, k[i * self.c_in : (i + 1) * self.c_in], groups=self.c_in).squeeze(1) # print('r: {}'.format(r.shape)) # now, you have a stack of NxC_in x D x T # 1x1 conv to combine C_in to 1 @@ -165,22 +168,24 @@ def forward(self, x): class TGM(nn.Module): - def __init__(self, - D: int = 1024, - n_filters: int = 16, - filter_length: int = 30, - input_dropout: float = 0.5, - dropout_p: float = 0.5, - classes: int = 8, - num_layers: int = 3, - reduction: str = 'max', - c_in: int = 1, - c_out: int = 8, - soft: bool = False, - num_features: int = 512, - viz: bool = False, - nonlinear_classification: bool = False, - concatenate_inputs=True): + def __init__( + self, + D: int = 1024, + n_filters: int = 16, + filter_length: int = 30, + input_dropout: float = 0.5, + dropout_p: float = 0.5, + classes: int = 8, + num_layers: int = 3, + reduction: str = "max", + c_in: int = 1, + c_out: int = 8, + soft: bool = False, + num_features: int = 512, + viz: bool = False, + nonlinear_classification: bool = False, + concatenate_inputs=True, + ): super().__init__() self.D = D # dimensionality of inputs. E.G. 1024 features from a CNN penultimate layer @@ -190,7 +195,7 @@ def __init__(self, self.input_dropout = nn.Dropout(input_dropout) # probability to dropout input channels self.output_dropout = nn.Dropout(dropout_p) # probability to dropout final layer before FC self.num_layers = num_layers # how many TGM layers - assert (reduction in ['max', 'mean', 'conv1x1']) + assert reduction in ["max", "mean", "conv1x1"] self.reduction = reduction # NEW: how to go from N x C_out x D x T -> N x D x T. Paper: max self.c_in = c_in # how many DxT representations there are in inputs self.c_out = c_out # how many representations of the input DxT matrix in TGM layers @@ -207,7 +212,7 @@ def __init__(self, self.tgm_layers = nn.Sequential(*modules) - if self.reduction == 'conv1x1': + if self.reduction == "conv1x1": self.reduction_layer = nn.Conv2d(self.c_out, 1, kernel_size=1, padding=0, stride=1) # self.sub_event1 = TGM(inp, 16, 5, c_in=1, c_out=8, soft=False) # self.sub_event2 = TGM(inp, 16, 5, c_in=8, c_out=8, soft=False) @@ -224,14 +229,13 @@ def __init__(self, self.viz = viz def forward(self, inp): - smoothed = self.tgm_layers(inp) # print('smoothed before max:{}'.format(smoothed.shape)) - if self.reduction == 'max': + if self.reduction == "max": smoothed = torch.max(smoothed, dim=1)[0] - elif self.reduction == 'mean': + elif self.reduction == "mean": smoothed = torch.mean(smoothed, dim=1) - elif self.reduction == 'conv1x1': + elif self.reduction == "conv1x1": smoothed = self.reduction_layer(smoothed).squeeze() # sub_event = self.dropout(torch.max(sub_event, dim=1)[0]) # print('sub_event:{}'.format(smoothed.shape)) @@ -241,8 +245,9 @@ def forward(self, inp): if inp.ndim == 3 and smoothed.ndim == 2: smoothed = smoothed.unsqueeze(0) else: - print('ERROR') + print("ERROR") import pdb + pdb.set_trace() if self.concatenate_inputs: @@ -267,26 +272,28 @@ def forward(self, inp): class TGMJ(nn.Module): - def __init__(self, - D: int = 1024, - n_filters: int = 16, - filter_length: int = 30, - input_dropout: float = 0.5, - dropout_p: float = 0.5, - classes: int = 8, - num_layers: int = 3, - reduction: str = 'max', - c_in: int = 1, - c_out: int = 8, - soft: bool = False, - num_features: int = 512, - viz: bool = False, - nonlinear_classification: bool = False, - concatenate_inputs=True, - pos=None, - neg=None, - use_fe_logits: bool = True, - final_bn: bool = False): + def __init__( + self, + D: int = 1024, + n_filters: int = 16, + filter_length: int = 30, + input_dropout: float = 0.5, + dropout_p: float = 0.5, + classes: int = 8, + num_layers: int = 3, + reduction: str = "max", + c_in: int = 1, + c_out: int = 8, + soft: bool = False, + num_features: int = 512, + viz: bool = False, + nonlinear_classification: bool = False, + concatenate_inputs=True, + pos=None, + neg=None, + use_fe_logits: bool = True, + final_bn: bool = False, + ): super().__init__() self.D = D # dimensionality of inputs. E.G. 1024 features from a CNN penultimate layer @@ -296,7 +303,7 @@ def __init__(self, self.input_dropout = nn.Dropout(input_dropout) # probability to dropout input channels self.output_dropout = nn.Dropout(dropout_p) # probability to dropout final layer before FC self.num_layers = num_layers # how many TGM layers - assert (reduction in ['max', 'mean', 'conv1x1']) + assert reduction in ["max", "mean", "conv1x1"] self.reduction = reduction # NEW: how to go from N x C_out x D x T -> N x D x T. Paper: max self.c_in = c_in # how many DxT representations there are in inputs self.c_out = c_out # how many representations of the input DxT matrix in TGM layers @@ -312,7 +319,7 @@ def __init__(self, self.tgm_layers = nn.Sequential(*modules) - if self.reduction == 'conv1x1': + if self.reduction == "conv1x1": self.reduction_layer = nn.Conv2d(self.c_out, 1, kernel_size=1, padding=0, stride=1) # self.sub_event1 = TGM(inp, 16, 5, c_in=1, c_out=8, soft=False) # self.sub_event2 = TGM(inp, 16, 5, c_in=8, c_out=8, soft=False) @@ -360,11 +367,11 @@ def __init__(self, def forward(self, inp, fe_logits=None): smoothed = self.tgm_layers(inp) # print('smoothed before max:{}'.format(smoothed.shape)) - if self.reduction == 'max': + if self.reduction == "max": smoothed = torch.max(smoothed, dim=1)[0] - elif self.reduction == 'mean': + elif self.reduction == "mean": smoothed = torch.mean(smoothed, dim=1) - elif self.reduction == 'conv1x1': + elif self.reduction == "conv1x1": smoothed = self.reduction_layer(smoothed).squeeze() # sub_event = self.dropout(torch.max(sub_event, dim=1)[0]) # print('sub_event:{}'.format(smoothed.shape)) @@ -374,8 +381,9 @@ def forward(self, inp, fe_logits=None): if inp.ndim == 3 and smoothed.ndim == 2: smoothed = smoothed.unsqueeze(0) else: - print('ERROR') + print("ERROR") import pdb + pdb.set_trace() outputs1 = self.input_dropout(inp) outputs2 = self.input_dropout(smoothed) diff --git a/deepethogram/sequence/train.py b/deepethogram/sequence/train.py index 51bbe4f..e172993 100644 --- a/deepethogram/sequence/train.py +++ b/deepethogram/sequence/train.py @@ -19,11 +19,11 @@ log = logging.getLogger(__name__) -plt.switch_backend('agg') +plt.switch_backend("agg") def sequence_train(cfg: DictConfig) -> nn.Module: - """Trains sequence models from a configuration. + """Trains sequence models from a configuration. Parameters ---------- @@ -36,34 +36,34 @@ def sequence_train(cfg: DictConfig) -> nn.Module: Trained sequence model """ cfg = projects.setup_run(cfg) - log.info('args: {}'.format(' '.join(sys.argv))) + log.info("args: {}".format(" ".join(sys.argv))) if cfg.sequence.latent_name is None: cfg.sequence.latent_name = cfg.feature_extractor.arch # allow for editing OmegaConf.set_struct(cfg, False) - log.info('Configuration used: ') + log.info("Configuration used: ") log.info(OmegaConf.to_yaml(cfg)) - datasets, data_info = get_datasets_from_cfg(cfg, 'sequence') - utils.save_dict_to_yaml(data_info['split'], os.path.join(os.getcwd(), 'split.yaml')) - model = build_model_from_cfg(cfg, - data_info['num_features'], - data_info['num_classes'], - pos=data_info['pos'], - neg=data_info['neg']) - weights = projects.get_weightfile_from_cfg(cfg, model_type='sequence') + datasets, data_info = get_datasets_from_cfg(cfg, "sequence") + utils.save_dict_to_yaml(data_info["split"], os.path.join(os.getcwd(), "split.yaml")) + model = build_model_from_cfg( + cfg, data_info["num_features"], data_info["num_classes"], pos=data_info["pos"], neg=data_info["neg"] + ) + weights = projects.get_weightfile_from_cfg(cfg, model_type="sequence") if weights is not None: model = utils.load_weights(model, weights) - log.debug('model arch: {}'.format(model)) - log.info('Total trainable params: {:,}'.format(utils.get_num_parameters(model))) + log.debug("model arch: {}".format(model)) + log.info("Total trainable params: {:,}".format(utils.get_num_parameters(model))) stopper = get_stopper(cfg) - metrics = get_metrics(os.getcwd(), - data_info['num_classes'], - num_parameters=utils.get_num_parameters(model), - key_metric='f1_class_mean', - num_workers=cfg.compute.metrics_workers) + metrics = get_metrics( + os.getcwd(), + data_info["num_classes"], + num_parameters=utils.get_num_parameters(model), + key_metric="f1_class_mean", + num_workers=cfg.compute.metrics_workers, + ) criterion = get_criterion(cfg, model, data_info) lightning_module = SequenceLightning(model, cfg, datasets, metrics, criterion) # change auto batch size parameters because large sequences can overflow RAM @@ -73,17 +73,17 @@ def sequence_train(cfg: DictConfig) -> nn.Module: class SequenceLightning(BaseLightningModule): - """Lightning Module for training sequence models - """ + """Lightning Module for training sequence models""" + def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, criterion: nn.Module): super().__init__(model, cfg, datasets, metrics, viz.visualize_logger_multilabel_classification) self.has_logged_channels = False # for convenience self.final_activation = self.hparams.feature_extractor.final_activation - if self.final_activation == 'softmax': + if self.final_activation == "softmax": self.activation = nn.Softmax(dim=1) - elif self.final_activation == 'sigmoid': + elif self.final_activation == "sigmoid": self.activation = nn.Sigmoid() else: raise NotImplementedError @@ -111,33 +111,31 @@ def common_step(self, batch: dict, batch_idx: int, split: str): outputs = self(batch, split) probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch['labels'], self.model) + loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) # downsampled_t0, estimated_t0, flows_reshaped = self.reconstructor(images, outputs) # loss, loss_components = self.criterion(batch, downsampled_t0, estimated_t0, flows_reshaped) - self.visualize_batch(batch['features'], probabilities, batch['labels'], split) + self.visualize_batch(batch["features"], probabilities, batch["labels"], split) - self.metrics.buffer.append(split, { - 'loss': loss.detach(), - 'probs': probabilities.detach(), - 'labels': batch['labels'].detach() - }) + self.metrics.buffer.append( + split, {"loss": loss.detach(), "probs": probabilities.detach(), "labels": batch["labels"].detach()} + ) self.metrics.buffer.append(split, loss_dict) # need to use the native logger for lr scheduling, etc. - self.log(f'{split}_loss', loss.detach()) + self.log(f"{split}_loss", loss.detach()) # if self.batch_cnt == 100: # print('stop') # self.batch_cnt += 1 return loss def training_step(self, batch: dict, batch_idx: int): - return self.common_step(batch, batch_idx, 'train') + return self.common_step(batch, batch_idx, "train") def validation_step(self, batch: dict, batch_idx: int): - return self.common_step(batch, batch_idx, 'val') + return self.common_step(batch, batch_idx, "val") def test_step(self, batch: dict, batch_idx: int): - images, outputs = self(batch, 'test') + images, outputs = self(batch, "test") def visualize_batch(self, features, predictions, labels, split: str): if self.hparams.train.viz_examples == 0: @@ -151,22 +149,20 @@ def visualize_batch(self, features, predictions, labels, split: str): fig = plt.figure(figsize=(14, 14)) # log.info('visualizing sequence batch') viz.visualize_batch_sequence(features, predictions, labels, fig=fig) - viz.save_figure(fig, 'batch', True, viz_cnt, split) + viz.save_figure(fig, "batch", True, viz_cnt, split) # this should happen in the save figure function, but for some reason it doesn't plt.close(fig) - plt.close('all') + plt.close("all") del fig def forward(self, batch: dict, mode: str) -> torch.Tensor: - outputs = self.model(batch['features']) + outputs = self.model(batch["features"]) return outputs -def build_model_from_cfg(cfg: DictConfig, - num_features: int, - num_classes: int, - neg: np.ndarray = None, - pos: np.ndarray = None): +def build_model_from_cfg( + cfg: DictConfig, num_features: int, num_classes: int, neg: np.ndarray = None, pos: np.ndarray = None +): """ Initializes a sequence model from a configuration dictionary. @@ -201,60 +197,66 @@ def build_model_from_cfg(cfg: DictConfig, deepethogram.sequence.models """ seq = cfg.sequence - log.debug('model building parameters: {}'.format(seq)) - if seq.arch == 'linear': + log.debug("model building parameters: {}".format(seq)) + if seq.arch == "linear": model = Linear(num_features, num_classes, kernel_size=1) - elif seq.arch == 'conv_nonlinear': + elif seq.arch == "conv_nonlinear": model = Conv_Nonlinear(num_features, num_classes, hidden_size=seq.hidden_size, dropout_p=seq.dropout_p) - elif seq.arch == 'rnn': - model = RNN(num_features, - num_classes, - style=seq.rnn_style, - hidden_size=seq.hidden_size, - dropout=seq.hidden_dropout, - num_layers=seq.num_layers, - output_dropout=seq.dropout_p, - bidirectional=seq.bidirectional) - elif seq.arch == 'tgm': - model = TGM(num_features, - classes=num_classes, - n_filters=seq.n_filters, - filter_length=seq.filter_length, - input_dropout=seq.input_dropout, - dropout_p=seq.dropout_p, - num_layers=seq.num_layers, - reduction=seq.tgm_reduction, - c_in=seq.c_in, - c_out=seq.c_out, - soft=seq.soft_attn, - num_features=seq.num_features) - elif seq.arch == 'tgmj': - model = TGMJ(num_features, - classes=num_classes, - n_filters=seq.n_filters, - filter_length=seq.filter_length, - input_dropout=seq.input_dropout, - dropout_p=seq.dropout_p, - num_layers=seq.num_layers, - reduction=seq.tgm_reduction, - c_in=seq.c_in, - c_out=seq.c_out, - soft=seq.soft_attn, - num_features=seq.num_features, - pos=pos, - neg=neg, - use_fe_logits=False, - nonlinear_classification=seq.nonlinear_classification, - final_bn=seq.final_bn) - elif seq.arch == 'mlp': + elif seq.arch == "rnn": + model = RNN( + num_features, + num_classes, + style=seq.rnn_style, + hidden_size=seq.hidden_size, + dropout=seq.hidden_dropout, + num_layers=seq.num_layers, + output_dropout=seq.dropout_p, + bidirectional=seq.bidirectional, + ) + elif seq.arch == "tgm": + model = TGM( + num_features, + classes=num_classes, + n_filters=seq.n_filters, + filter_length=seq.filter_length, + input_dropout=seq.input_dropout, + dropout_p=seq.dropout_p, + num_layers=seq.num_layers, + reduction=seq.tgm_reduction, + c_in=seq.c_in, + c_out=seq.c_out, + soft=seq.soft_attn, + num_features=seq.num_features, + ) + elif seq.arch == "tgmj": + model = TGMJ( + num_features, + classes=num_classes, + n_filters=seq.n_filters, + filter_length=seq.filter_length, + input_dropout=seq.input_dropout, + dropout_p=seq.dropout_p, + num_layers=seq.num_layers, + reduction=seq.tgm_reduction, + c_in=seq.c_in, + c_out=seq.c_out, + soft=seq.soft_attn, + num_features=seq.num_features, + pos=pos, + neg=neg, + use_fe_logits=False, + nonlinear_classification=seq.nonlinear_classification, + final_bn=seq.final_bn, + ) + elif seq.arch == "mlp": model = MLP(num_features, num_classes, dropout_p=seq.dropout_p, pos=pos, neg=neg) else: - raise ValueError('arch not found: {}'.format(seq.arch)) + raise ValueError("arch not found: {}".format(seq.arch)) print(model) return model -if __name__ == '__main__': +if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_sequence_train_cfg(project_path, use_command_line=True) diff --git a/deepethogram/stoppers.py b/deepethogram/stoppers.py index fde5dea..902d5da 100644 --- a/deepethogram/stoppers.py +++ b/deepethogram/stoppers.py @@ -7,9 +7,10 @@ class Stopper: - """ Base class for stopping training """ + """Base class for stopping training""" + def __init__(self, name: str, start_epoch: int = 0, num_epochs: int = 1000): - """ constructor for stopper + """constructor for stopper Parameters ---------- @@ -25,7 +26,7 @@ def __init__(self, name: str, start_epoch: int = 0, num_epochs: int = 1000): self.num_epochs = num_epochs def step(self, *args, **kwargs): - """ increment internal counter """ + """increment internal counter""" self.epoch_counter += 1 def __call__(self, *args, **kwargs): @@ -33,7 +34,7 @@ def __call__(self, *args, **kwargs): class NumEpochsStopper(Stopper): - def __init__(self, name: str = 'num_epochs', start_epoch: int = 0, num_epochs: int = 1000): + def __init__(self, name: str = "num_epochs", start_epoch: int = 0, num_epochs: int = 1000): super().__init__(name, start_epoch, num_epochs) def step(self, *args, **kwargs): @@ -53,8 +54,9 @@ class EarlyStopping(Stopper): https://github.com/pytorch/ignite/blob/master/ignite/handlers/early_stopping.py """ - def __init__(self, name='early', start_epoch=0, num_epochs=1000, patience=5, is_error=False, - early_stopping_begins: int = 0): + def __init__( + self, name="early", start_epoch=0, num_epochs=1000, patience=5, is_error=False, early_stopping_begins: int = 0 + ): super().__init__(name, start_epoch, num_epochs) if patience < 1: raise ValueError("Argument patience should be positive integer") @@ -88,7 +90,9 @@ def step(self, score): if self.epoch_counter > self.num_epochs: should_stop = True - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() return best, should_stop @@ -108,8 +112,14 @@ class LearningRateStopper(Stopper): break """ - def __init__(self, name='learning_rate', minimum_learning_rate: float = 5e-7, start_epoch=0, num_epochs=1000, - eps: float = 1e-8): + def __init__( + self, + name="learning_rate", + minimum_learning_rate: float = 5e-7, + start_epoch=0, + num_epochs=1000, + eps: float = 1e-8, + ): super().__init__(name, start_epoch, num_epochs) """Constructor for LearningRateStopper. Args: @@ -129,7 +139,7 @@ def step(self, lr: float) -> bool: should_stop = False # print('epoch counter: {} num_epochs: {}'.format(self.epoch_counter, self.num_epochs)) if lr < self.minimum_learning_rate + self.eps or self.epoch_counter >= self.num_epochs: - print('Reached learning rate {}, stopping...'.format(lr)) + print("Reached learning rate {}, stopping...".format(lr)) should_stop = True return should_stop @@ -148,15 +158,20 @@ def get_stopper(cfg: DictConfig) -> Type[Stopper]: """ # ASSUME WE'RE USING LOSS AS THE KEY METRIC, WHICH IS AN ERROR stopping_type = cfg.train.stopping_type - log.debug('Using stopper type {}'.format(stopping_type)) - if stopping_type == 'early': - return EarlyStopping(start_epoch=0, num_epochs=cfg.train.num_epochs, - patience=cfg.train.patience, - is_error=True, early_stopping_begins=cfg.train.early_stopping_begins) - elif stopping_type == 'learning_rate': - return LearningRateStopper(start_epoch=0, num_epochs=cfg.train.num_epochs, - minimum_learning_rate=cfg.train.min_lr) - elif stopping_type == 'num_epochs': - return NumEpochsStopper('num_epochs', start_epoch=0, num_epochs=cfg.train.num_epochs) + log.debug("Using stopper type {}".format(stopping_type)) + if stopping_type == "early": + return EarlyStopping( + start_epoch=0, + num_epochs=cfg.train.num_epochs, + patience=cfg.train.patience, + is_error=True, + early_stopping_begins=cfg.train.early_stopping_begins, + ) + elif stopping_type == "learning_rate": + return LearningRateStopper( + start_epoch=0, num_epochs=cfg.train.num_epochs, minimum_learning_rate=cfg.train.min_lr + ) + elif stopping_type == "num_epochs": + return NumEpochsStopper("num_epochs", start_epoch=0, num_epochs=cfg.train.num_epochs) else: - raise ValueError('invalid stopping name detected! {}'.format(stopping_type)) + raise ValueError("invalid stopping name detected! {}".format(stopping_type)) diff --git a/deepethogram/tune/feature_extractor.py b/deepethogram/tune/feature_extractor.py index f2bc8dc..14381bc 100644 --- a/deepethogram/tune/feature_extractor.py +++ b/deepethogram/tune/feature_extractor.py @@ -2,14 +2,15 @@ import sys from omegaconf import OmegaConf, DictConfig -try: + +try: import ray from ray import tune from ray.tune import CLIReporter from ray.tune.schedulers import ASHAScheduler from ray.tune.suggest.hyperopt import HyperOptSearch except ImportError: - print('To use the deepethogram.tune module, you must `pip install \'ray[tune]`') + print("To use the deepethogram.tune module, you must `pip install 'ray[tune]`") raise from deepethogram.configuration import make_config @@ -17,8 +18,9 @@ from deepethogram import projects from deepethogram.tune.utils import dict_to_dotlist, generate_tune_cfg -def tune_feature_extractor(cfg: DictConfig): - """Tunes feature extractor hyperparameters. + +def tune_feature_extractor(cfg: DictConfig): + """Tunes feature extractor hyperparameters. Parameters ---------- @@ -31,58 +33,57 @@ def tune_feature_extractor(cfg: DictConfig): Checks that search method is either 'random' or 'hyperopt' """ scheduler = ASHAScheduler( - max_t=cfg.train.num_epochs, # epochs + max_t=cfg.train.num_epochs, # epochs grace_period=cfg.tune.grace_period, - reduction_factor=2) - + reduction_factor=2, + ) + reporter_dict = {} for key in cfg.tune.hparams.keys(): reporter_dict[key] = cfg.tune.hparams[key].short # reporter_dict = {key: value for key, value in zip(cfg.tune.hparams.keys(), )} reporter = CLIReporter(parameter_columns=reporter_dict) - + # this converts what's in our cfg to a dictionary containing the search space of our hyperparameters tune_experiment_cfg = generate_tune_cfg(cfg) - - if cfg.tune.search == 'hyperopt': + + if cfg.tune.search == "hyperopt": # https://docs.ray.io/en/master/tune/api_docs/suggestion.html#tune-hyperopt current_best = {} for key, value in cfg.tune.hparams.items(): current_best[key] = value.current_best # hyperopt wants this to be a list of dicts current_best = [current_best] - search = HyperOptSearch(metric=cfg.tune.key_metric, - mode='max', - points_to_evaluate=current_best) - elif cfg.tune.search == 'random': + search = HyperOptSearch(metric=cfg.tune.key_metric, mode="max", points_to_evaluate=current_best) + elif cfg.tune.search == "random": search = None - else: + else: raise NotImplementedError - - print('Running hyperparamter tuning with configuration: ') + + print("Running hyperparamter tuning with configuration: ") print(OmegaConf.to_yaml(cfg)) - + analysis = tune.run( tune.with_parameters( - run_ray_experiment, + run_ray_experiment, cfg=cfg, - ), - resources_per_trial=OmegaConf.to_container(cfg.tune.resources_per_trial), - metric=cfg.tune.key_metric, - mode='max', + ), + resources_per_trial=OmegaConf.to_container(cfg.tune.resources_per_trial), + metric=cfg.tune.key_metric, + mode="max", config=tune_experiment_cfg, - num_samples=cfg.tune.num_trials, # how many experiments to run - scheduler=scheduler, - progress_reporter=reporter, - name=cfg.tune.name, - local_dir=cfg.project.model_path, - search_alg=search + num_samples=cfg.tune.num_trials, # how many experiments to run + scheduler=scheduler, + progress_reporter=reporter, + name=cfg.tune.name, + local_dir=cfg.project.model_path, + search_alg=search, ) print("Best hyperparameters found were: ", analysis.best_config) - analysis.results_df.to_csv(os.path.join(cfg.project.model_path, 'ray_results.csv')) + analysis.results_df.to_csv(os.path.join(cfg.project.model_path, "ray_results.csv")) -def run_ray_experiment(ray_cfg, cfg): +def run_ray_experiment(ray_cfg, cfg): """trains a model based on the base config and the one generated for this experiment Parameters @@ -93,36 +94,48 @@ def run_ray_experiment(ray_cfg, cfg): base configuration with all non-tuned hyperparameters and information """ ray_cfg = OmegaConf.from_dotlist(dict_to_dotlist(ray_cfg)) - + cfg = OmegaConf.merge(cfg, ray_cfg) if cfg.notes is None: - cfg.notes = f'{cfg.tune.name}_{tune.get_trial_id()}' + cfg.notes = f"{cfg.tune.name}_{tune.get_trial_id()}" else: - cfg.notes += f'{cfg.tune.name}_{tune.get_trial_id()}' + cfg.notes += f"{cfg.tune.name}_{tune.get_trial_id()}" feature_extractor_train(cfg) - -if __name__ == '__main__': + + +if __name__ == "__main__": # USAGE # to run locally, type `ray start --head --port 6385`, then run this script - - ray.init(address='auto') #num_gpus=1 - - config_list = ['config','augs','model/flow_generator','train', 'model/feature_extractor', - 'tune/tune', 'tune/feature_extractor'] - run_type = 'train' - model = 'feature_extractor' - + + ray.init(address="auto") # num_gpus=1 + + config_list = [ + "config", + "augs", + "model/flow_generator", + "train", + "model/feature_extractor", + "tune/tune", + "tune/feature_extractor", + ] + run_type = "train" + model = "feature_extractor" + project_path = projects.get_project_path_from_cl(sys.argv) - cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, - use_command_line=True, debug=True) + cfg = make_config( + project_path=project_path, + config_list=config_list, + run_type=run_type, + model=model, + use_command_line=True, + debug=True, + ) cfg = projects.convert_config_paths_to_absolute(cfg) - - if 'preset' in cfg.keys(): - cfg.tune.name += '_{}'.format(cfg.preset) - if 'debug' in cfg.keys(): - cfg.tune.name += '_debug' - + + if "preset" in cfg.keys(): + cfg.tune.name += "_{}".format(cfg.preset) + if "debug" in cfg.keys(): + cfg.tune.name += "_debug" + tune_feature_extractor(cfg) - - \ No newline at end of file diff --git a/deepethogram/tune/sequence.py b/deepethogram/tune/sequence.py index 6084c22..86bb08a 100644 --- a/deepethogram/tune/sequence.py +++ b/deepethogram/tune/sequence.py @@ -2,14 +2,15 @@ import sys from omegaconf import OmegaConf, DictConfig -try: + +try: import ray from ray import tune from ray.tune import CLIReporter from ray.tune.schedulers import ASHAScheduler from ray.tune.suggest.hyperopt import HyperOptSearch except ImportError: - print('To use the deepethogram.tune module, you must `pip install \'ray[tune]`') + print("To use the deepethogram.tune module, you must `pip install 'ray[tune]`") raise from deepethogram.configuration import make_config @@ -18,7 +19,7 @@ from deepethogram.tune.utils import dict_to_dotlist, generate_tune_cfg -def tune_sequence(cfg: DictConfig): +def tune_sequence(cfg: DictConfig): """Tunes sequence model hyperparameters Parameters @@ -31,100 +32,103 @@ def tune_sequence(cfg: DictConfig): NotImplementedError Checks that search method is either 'random' or 'hyperopt' """ - + scheduler = ASHAScheduler( - max_t=cfg.train.num_epochs, # epochs + max_t=cfg.train.num_epochs, # epochs grace_period=cfg.tune.grace_period, - reduction_factor=2) - + reduction_factor=2, + ) + reporter_dict = {} for key in cfg.tune.hparams.keys(): reporter_dict[key] = cfg.tune.hparams[key].short # reporter_dict = {key: value for key, value in zip(cfg.tune.hparams.keys(), )} reporter = CLIReporter(parameter_columns=reporter_dict) - + # this converts what's in our cfg to a dictionary containing the search space of our hyperparameters tune_experiment_cfg = generate_tune_cfg(cfg) - - if cfg.tune.search == 'hyperopt': + + if cfg.tune.search == "hyperopt": # https://docs.ray.io/en/master/tune/api_docs/suggestion.html#tune-hyperopt current_best = {} for key, value in cfg.tune.hparams.items(): current_best[key] = value.current_best # hyperopt wants this to be a list of dicts current_best = [current_best] - search = HyperOptSearch(metric=cfg.tune.key_metric, - mode='max', - points_to_evaluate=current_best) - elif cfg.tune.search == 'random': + search = HyperOptSearch(metric=cfg.tune.key_metric, mode="max", points_to_evaluate=current_best) + elif cfg.tune.search == "random": search = None - else: + else: raise NotImplementedError - - print('Running hyperparamter tuning with configuration: ') + + print("Running hyperparamter tuning with configuration: ") print(OmegaConf.to_yaml(cfg)) - + analysis = tune.run( tune.with_parameters( - run_ray_experiment, + run_ray_experiment, cfg=cfg, - ), - resources_per_trial=OmegaConf.to_container(cfg.tune.resources_per_trial), - metric=cfg.tune.key_metric, - mode='max', + ), + resources_per_trial=OmegaConf.to_container(cfg.tune.resources_per_trial), + metric=cfg.tune.key_metric, + mode="max", config=tune_experiment_cfg, - num_samples=cfg.tune.num_trials, # how many experiments to run - scheduler=scheduler, - progress_reporter=reporter, - name=cfg.tune.name, - local_dir=cfg.project.model_path, - search_alg=search + num_samples=cfg.tune.num_trials, # how many experiments to run + scheduler=scheduler, + progress_reporter=reporter, + name=cfg.tune.name, + local_dir=cfg.project.model_path, + search_alg=search, ) print("Best hyperparameters found were: ", analysis.best_config) - analysis.results_df.to_csv(os.path.join(cfg.project.model_path, 'ray_results.csv')) + analysis.results_df.to_csv(os.path.join(cfg.project.model_path, "ray_results.csv")) -def run_ray_experiment(ray_cfg, cfg): +def run_ray_experiment(ray_cfg, cfg): # cfg = make_feature_extractor_train_cfg(project_path, use_command_line=False, preset='deg_f') # tune_cfg = load_config_by_name('tune') - + ray_cfg = OmegaConf.from_dotlist(dict_to_dotlist(ray_cfg)) - + cfg = OmegaConf.merge(cfg, ray_cfg) # cfg.tune.use = True - + # cfg.flow_generator.weights = 'latest' # cfg.feature_extractor.weights = '/media/jim/DATA_SSD/niv_revision_deepethogram/models/pretrained_models/200415_125824_hidden_two_stream_kinetics_degf/checkpoint.pt' # cfg.compute.batch_size = 64 # cfg.train.steps_per_epoch.train = 20 # cfg.train.steps_per_epoch.val = 20 if cfg.notes is None: - cfg.notes = f'{cfg.tune.name}_{tune.get_trial_id()}' + cfg.notes = f"{cfg.tune.name}_{tune.get_trial_id()}" else: - cfg.notes += f'{cfg.tune.name}_{tune.get_trial_id()}' + cfg.notes += f"{cfg.tune.name}_{tune.get_trial_id()}" sequence_train(cfg) - -if __name__ == '__main__': + + +if __name__ == "__main__": # USAGE # to run locally, type `ray start --head --port 6385`, then run this script - - ray.init(address='auto') #num_gpus=1 - - config_list = ['config','model/feature_extractor', 'train', 'model/sequence', - 'tune/tune', 'tune/sequence'] - run_type = 'train' - model = 'sequence' - + + ray.init(address="auto") # num_gpus=1 + + config_list = ["config", "model/feature_extractor", "train", "model/sequence", "tune/tune", "tune/sequence"] + run_type = "train" + model = "sequence" + project_path = projects.get_project_path_from_cl(sys.argv) - cfg = make_config(project_path=project_path, config_list=config_list, run_type=run_type, model=model, - use_command_line=True, debug=False) + cfg = make_config( + project_path=project_path, + config_list=config_list, + run_type=run_type, + model=model, + use_command_line=True, + debug=False, + ) cfg = projects.convert_config_paths_to_absolute(cfg) - - cfg.tune.name = 'tune_sequence_2' - - if 'debug' in cfg.keys(): - cfg.tune.name += '_debug' - + + cfg.tune.name = "tune_sequence_2" + + if "debug" in cfg.keys(): + cfg.tune.name += "_debug" + tune_sequence(cfg) - - \ No newline at end of file diff --git a/deepethogram/tune/utils.py b/deepethogram/tune/utils.py index 97c72a7..d36cebc 100644 --- a/deepethogram/tune/utils.py +++ b/deepethogram/tune/utils.py @@ -1,33 +1,35 @@ from omegaconf import OmegaConf -try: + +try: import ray from ray import tune except ImportError: - print('To use the deepethogram.tune module, you must `pip install \'ray[tune]`') + print("To use the deepethogram.tune module, you must `pip install 'ray[tune]`") raise -# code modified from official ray docs: + +# code modified from official ray docs: # https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html def dict_to_dotlist(cfg_dict): - dotlist = [f'{key}={value}' for key, value in cfg_dict.items()] + dotlist = [f"{key}={value}" for key, value in cfg_dict.items()] return dotlist def generate_tune_cfg(cfg): - """from a configuration, e.g. conf/tune/feature_extractor.yaml, generate a search space for specific hyperparameters - """ + """from a configuration, e.g. conf/tune/feature_extractor.yaml, generate a search space for specific hyperparameters""" + def get_space(hparam_dict): - if hparam_dict.space == 'uniform': + if hparam_dict.space == "uniform": return tune.uniform(hparam_dict.min, hparam_dict.max) - elif hparam_dict.space == 'log': + elif hparam_dict.space == "log": return tune.loguniform(hparam_dict.min, hparam_dict.max) - elif hparam_dict.space == 'choice': + elif hparam_dict.space == "choice": return tune.choice(OmegaConf.to_container(hparam_dict.choices)) else: raise NotImplementedError - + tune_cfg = {} for key, value in cfg.tune.hparams.items(): tune_cfg[key] = get_space(value) - - return tune_cfg \ No newline at end of file + + return tune_cfg diff --git a/deepethogram/utils.py b/deepethogram/utils.py index c176811..008b1ae 100644 --- a/deepethogram/utils.py +++ b/deepethogram/utils.py @@ -20,7 +20,7 @@ def load_yaml(filename: Union[str, os.PathLike]) -> dict: """Simple wrapper around yaml.load to load yaml files as dictionaries""" - with open(filename, 'r') as f: + with open(filename, "r") as f: dictionary = yaml.load(f, Loader=yaml.Loader) return dictionary @@ -37,19 +37,21 @@ def get_minimum_learning_rate(optimizer): """ min_lr = 1e9 for i, param_group in enumerate(optimizer.param_groups): - lr = param_group['lr'] + lr = param_group["lr"] if lr < min_lr: min_lr = lr - return (min_lr) - - -def load_checkpoint(model, - optimizer, - checkpoint_file: Union[str, os.PathLike], - config: dict, - overwrite_args: bool = False, - distributed: bool = False): - """"Reload model and optimizer weights from a checkpoint.pt file + return min_lr + + +def load_checkpoint( + model, + optimizer, + checkpoint_file: Union[str, os.PathLike], + config: dict, + overwrite_args: bool = False, + distributed: bool = False, +): + """ "Reload model and optimizer weights from a checkpoint.pt file Args: model: instance of torch.nn.Module class optimizer: instance of torch.optim.Optimizer class (ADAM, SGDM, etc.) @@ -61,7 +63,7 @@ def load_checkpoint(model, optimizer: optimizer with recent history of gradients config: depending on overwrite_args, input or reloaded hyperparameter dictionary """ - log.info('Reloading model from {}...'.format(checkpoint_file)) + log.info("Reloading model from {}...".format(checkpoint_file)) model, optimizer_dict, _, new_args = load_state(model, checkpoint_file, distributed=distributed) if type(new_args) != dict: new_config = vars(new_args) @@ -70,20 +72,20 @@ def load_checkpoint(model, try: optimizer.load_state_dict(optimizer_dict) except Exception as e: - log.exception('Trouble loading optimizer state dict--might have requires-grad' \ - 'for different parameters: {}'.format(e)) - log.warning('Not loading optimizer state.') + log.exception( + "Trouble loading optimizer state dict--might have requires-grad" "for different parameters: {}".format(e) + ) + log.warning("Not loading optimizer state.") if overwrite_args: config = new_config return model, optimizer, config -def load_weights(model, - checkpoint_file: Union[str, os.PathLike], - distributed: bool = False, - device: torch.device = None): - """"Reload model weights from a checkpoint.pt file +def load_weights( + model, checkpoint_file: Union[str, os.PathLike], distributed: bool = False, device: torch.device = None +): + """ "Reload model weights from a checkpoint.pt file Args: model: instance of torch.nn.Module class checkpoint_file: path to checkpoint.pt @@ -99,7 +101,7 @@ def load_weights(model, def checkpoint(model, rundir: Union[str, os.PathLike], epoch: int, args=None): - """" + """ " Args: model: instance of torch.nn.Module class rundir: directory to save checkpoint.pt to @@ -111,26 +113,26 @@ def checkpoint(model, rundir: Union[str, os.PathLike], epoch: int, args=None): args = OmegaConf.to_container(args) if type(args) != dict: args = vars(args) - fname = 'checkpoint.pt' + fname = "checkpoint.pt" fullfile = os.path.join(rundir, fname) # note: I used to save the optimizer dict as well, but this was confusing in terms of keeping track of learning # rates, making sure the same keys were in the optimizer dict even when you've done something like change # the size of the final layer of the NN (for different number of classes). I've kept the optimizer field for # backwards compatibility, but this should not be used - state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': None, 'hyperparameters': args} + state = {"epoch": epoch, "state_dict": model.state_dict(), "optimizer": None, "hyperparameters": args} torch.save(state, fullfile) def save_two_stream(model, rundir: Union[os.PathLike, str], config: dict = None, epoch: int = None) -> None: - """ Saves a two-stream model to disk. Saves spatial and flow feature extractors in their own directories """ + """Saves a two-stream model to disk. Saves spatial and flow feature extractors in their own directories""" assert os.path.isdir(rundir) assert isinstance(model, torch.nn.Module) - spatialdir = os.path.join(rundir, 'spatial') + spatialdir = os.path.join(rundir, "spatial") if not os.path.isdir(spatialdir): os.makedirs(spatialdir) checkpoint(model.spatial_classifier, spatialdir, epoch, config) - flow_classifier_dir = os.path.join(rundir, 'flow') + flow_classifier_dir = os.path.join(rundir, "flow") if not os.path.isdir(flow_classifier_dir): os.makedirs(flow_classifier_dir) checkpoint(model.flow_classifier, flow_classifier_dir, epoch, config) @@ -139,10 +141,10 @@ def save_two_stream(model, rundir: Union[os.PathLike, str], config: dict = None, def save_hidden_two_stream(model, rundir: Union[os.PathLike, str], config: dict = None, epoch: int = None) -> None: - """ Saves a hidden two-stream model to disk. Saves flow generator in a separate directory """ + """Saves a hidden two-stream model to disk. Saves flow generator in a separate directory""" assert os.path.isdir(rundir) assert isinstance(model, torch.nn.Module) - flowdir = os.path.join(rundir, 'flow_generator') + flowdir = os.path.join(rundir, "flow_generator") if not os.path.isdir(flowdir): os.makedirs(flowdir) if type(config) == DictConfig: @@ -158,10 +160,10 @@ def save_dict_to_yaml(dictionary: dict, filename: Union[str, bytes, os.PathLike] filename: file to save dict to. Should end in .yaml """ if os.path.isfile(filename): - log.debug('File {} already exists, overwriting...'.format(filename)) + log.debug("File {} already exists, overwriting...".format(filename)) if isinstance(dictionary, DictConfig): dictionary = OmegaConf.to_container(dictionary) - with open(filename, 'w') as f: + with open(filename, "w") as f: yaml.dump(dictionary, f, default_flow_style=False) @@ -173,7 +175,7 @@ def tensor_to_np(tensor: Union[torch.Tensor, np.ndarray]) -> np.ndarray: def in_this_dir(abs_path: Union[str, os.PathLike]) -> dict: - """ Gets list of files in a subdirectory and returns information about it. + """Gets list of files in a subdirectory and returns information about it. Designed to be a drop-in replacement for MATLAB's `dir` command :P Args: abs_path: absolute path to a directory @@ -181,20 +183,20 @@ def in_this_dir(abs_path: Union[str, os.PathLike]) -> dict: contents: dictionary with keys 'name', 'isdir', and 'bytes', containing the name, whether or not the file is a directory, and the filesize in bytes of all files in the directory """ - backslashes = strfind(abs_path, '\\') + backslashes = strfind(abs_path, "\\") if len(backslashes) > 1 and backslashes[0] != -1: - abs_path.replace('\\', '/') + abs_path.replace("\\", "/") # contents contains list of filenames or directory names filenames = os.listdir(abs_path) contents = [] for name in filenames: content = {} - content['name'] = name - content['isdir'] = os.path.isdir(os.path.join(abs_path, name)) - content['bytes'] = os.path.getsize(os.path.join(abs_path, name)) + content["name"] = name + content["isdir"] = os.path.isdir(os.path.join(abs_path, name)) + content["bytes"] = os.path.getsize(os.path.join(abs_path, name)) contents.append(content) # sort by name! - contents = sorted(contents, key=itemgetter('name')) + contents = sorted(contents, key=itemgetter("name")) return contents @@ -209,7 +211,7 @@ def get_datadir_from_paths(paths, dataset): datadir = v found = True if not found: - raise ValueError('couldn' 't find dataset: {}'.format(dataset)) + raise ValueError("couldn" "t find dataset: {}".format(dataset)) return datadir @@ -237,17 +239,17 @@ def load_state_from_dict(model, state_dict): model_dict = model.state_dict() pretrained_dict = {} for k, v in state_dict.items(): - if 'criterion' in k: + if "criterion" in k: # we might have parameters from the loss function in our loaded weights. we don't want to reload these; # we will specify them for whatever we are currently training. continue if k not in model_dict: - log.warning('{} not found in model dictionary'.format(k)) + log.warning("{} not found in model dictionary".format(k)) else: if model_dict[k].size() != v.size(): - log.warning('{} has different size: pretrained:{} model:{}'.format(k, v.size(), model_dict[k].size())) + log.warning("{} has different size: pretrained:{} model:{}".format(k, v.size(), model_dict[k].size())) else: - log.debug('Successfully loaded: {}'.format(k)) + log.debug("Successfully loaded: {}".format(k)) pretrained_dict[k] = v model_dict.update(pretrained_dict) @@ -255,33 +257,33 @@ def load_state_from_dict(model, state_dict): # model_dict.update(only_in_model_dict) # load the state dict, only for layers of same name, shape, size, etc. model.load_state_dict(model_dict, strict=True) - return (model) + return model def load_state_dict_from_file(weights_file, distributed: bool = False): - state = torch.load(weights_file, map_location='cpu') + state = torch.load(weights_file, map_location="cpu") # except RuntimeError as e: # log.exception(e) # log.info('loading onto cpu...') # state = torch.load(weights_file, map_location='cpu') - is_pure_weights = 'epoch' not in list(state.keys()) + is_pure_weights = "epoch" not in list(state.keys()) # load params if is_pure_weights: state_dict = state start_epoch = 0 else: - start_epoch = state['epoch'] - state_dict = state['state_dict'] + start_epoch = state["epoch"] + state_dict = state["state_dict"] optimizer_dict = None # state['optimizer'] first_key = next(iter(state_dict.items()))[0] - trained_on_dataparallel = first_key[:7] == 'module.' + trained_on_dataparallel = first_key[:7] == "module." if distributed and not trained_on_dataparallel: new_state_dict = OrderedDict() for k, v in state_dict.items(): - name = 'module.' + k + name = "module." + k new_state_dict[name] = v state_dict = new_state_dict # if it was trained on multi-gpu, remove the 'module.' before variable names @@ -293,19 +295,19 @@ def load_state_dict_from_file(weights_file, distributed: bool = False): state_dict = new_state_dict # sometimes I have the encoder in a structure called 'model', which means # all weights have 'model.' prepended - model_prepended = first_key[:6] == 'model.' + model_prepended = first_key[:6] == "model." if model_prepended: new_state_dict = OrderedDict() for k, v in state_dict.items(): - if k[:6] == 'model.': + if k[:6] == "model.": name = k[6:] else: name = k new_state_dict[name] = v state_dict = new_state_dict if not is_pure_weights: - if 'hyperparameters' in list(state.keys()): - args = state['hyperparameters'] + if "hyperparameters" in list(state.keys()): + args = state["hyperparameters"] else: args = None else: @@ -315,7 +317,7 @@ def load_state_dict_from_file(weights_file, distributed: bool = False): def load_state(model, weights_file: Union[str, os.PathLike], device: torch.device = None, distributed: bool = False): - """"Reload model and optimizer weights from a checkpoint.pt file. + """ "Reload model and optimizer weights from a checkpoint.pt file. TODO: refactor this loading for pytorch 1.4+. This was written many versions ago @@ -337,7 +339,7 @@ def load_state(model, weights_file: Union[str, os.PathLike], device: torch.devic # epoch: final epoch number from training # state_dict: weights # args: hyperparameters - log.info('loading from checkpoint file {}...'.format(weights_file)) + log.info("loading from checkpoint file {}...".format(weights_file)) state_dict, start_epoch, args = load_state_dict_from_file(weights_file, distributed=distributed) # LOAD PARAMS @@ -353,9 +355,16 @@ def print_gpus(): """ n_gpus = torch.cuda.device_count() for i in range(n_gpus): - print('GPU %d %s: Compute Capability %d.%d, Mem:%f' % - (i, torch.cuda.get_device_name(i), int(torch.cuda.get_device_capability(i)[0]), - int(torch.cuda.get_device_capability(i)[1]), torch.cuda.max_memory_allocated(i))) + print( + "GPU %d %s: Compute Capability %d.%d, Mem:%f" + % ( + i, + torch.cuda.get_device_name(i), + int(torch.cuda.get_device_capability(i)[0]), + int(torch.cuda.get_device_capability(i)[1]), + torch.cuda.max_memory_allocated(i), + ) + ) class Normalizer: @@ -375,10 +384,13 @@ class Normalizer: mean: mean of input data. For images, should have 2 or 3 channels std: standard deviation of input data """ - def __init__(self, - mean: Union[list, np.ndarray, torch.Tensor] = None, - std: Union[list, np.ndarray, torch.Tensor] = None, - clamp: bool = True): + + def __init__( + self, + mean: Union[list, np.ndarray, torch.Tensor] = None, + std: Union[list, np.ndarray, torch.Tensor] = None, + clamp: bool = True, + ): """Constructor for Normalizer class. Args: mean: mean of input data. Should have 3 channels (for R,G,B) or 2 (for X,Y) in the optical flow case @@ -388,15 +400,15 @@ def __init__(self, # make sure that if you have a mean, you also have a std # XOR has_mean, has_std = mean is None, std is None - assert (not has_mean ^ has_std) + assert not has_mean ^ has_std self.mean = self.process_inputs(mean) self.std = self.process_inputs(std) # prevent divide by zero, but only change values if it's close to 0 already if self.std is not None: - assert (self.std.min() > 0) + assert self.std.min() > 0 self.std[self.std < 1e-8] += 1e-8 - log.debug('Normalizer created with mean {} and std {}'.format(self.mean, self.std)) + log.debug("Normalizer created with mean {} and std {}".format(self.mean, self.std)) self.clamp = clamp def process_inputs(self, inputs: Union[torch.Tensor, np.ndarray]): @@ -404,12 +416,12 @@ def process_inputs(self, inputs: Union[torch.Tensor, np.ndarray]): Converts to tensor if necessary. Reshapes to [length, 1, 1] for pytorch broadcasting. """ if inputs is None: - return (inputs) + return inputs if type(inputs) == list: inputs = np.array(inputs).astype(np.float32) if type(inputs) == np.ndarray: inputs = torch.from_numpy(inputs) - assert (type(inputs) == torch.Tensor) + assert type(inputs) == torch.Tensor inputs = inputs.float() C = inputs.shape[0] inputs = inputs.reshape(C, 1, 1) @@ -435,7 +447,7 @@ def handle_tensor(self, tensor: torch.Tensor): elif tensor.ndim == 5: N, C, T, H, W = tensor.shape else: - raise ValueError('Tensor input to normalizer of unknown shape: {}'.format(tensor.shape)) + raise ValueError("Tensor input to normalizer of unknown shape: {}".format(tensor.shape)) t_d = tensor.device if t_d != self.mean.device: @@ -451,7 +463,7 @@ def handle_tensor(self, tensor: torch.Tensor): # this code simply repeats the mean T times, so it's # [R_mean, G_mean, B_mean, R_mean, G_mean, ... etc] n_repeats = C / c - assert (int(n_repeats) == n_repeats) + assert int(n_repeats) == n_repeats n_repeats = int(n_repeats) repeats = tuple([n_repeats] + [1 for i in range(self.mean.ndim - 1)]) self.mean = self.mean.repeat((repeats)) @@ -569,23 +581,23 @@ def flow_img_to_flow(img: np.ndarray, max_flow: Union[int, float] = 10) -> np.nd def module_to_dict(module, exclude=[], get_function=False): - """ Converts functions in a module to a dictionary. Useful for loading model types into a dictionary """ + """Converts functions in a module to a dictionary. Useful for loading model types into a dictionary""" module_dict = {} for x in dir(module): submodule = getattr(module, x) # print(x, submodule) func = isfunction(submodule) if get_function else not isfunction(submodule) - if (func and x not in exclude and submodule not in exclude): + if func and x not in exclude and submodule not in exclude: module_dict[x] = submodule return module_dict def get_models_from_module(module, get_function=False): - """ Hacky function for getting a dictionary of model: initializer from a module """ + """Hacky function for getting a dictionary of model: initializer from a module""" models = {} for importer, modname, ispkg in pkgutil.iter_modules(module.__path__): # print("Found submodule %s (is a package: %s)" % (modname, ispkg)) - total_name = module.__name__ + '.' + modname + total_name = module.__name__ + "." + modname this_module = __import__(total_name) submodule = getattr(module, modname) # module @@ -597,7 +609,7 @@ def get_models_from_module(module, get_function=False): def load_feature_extractor_components(model, checkpoint_file: Union[str, os.PathLike], component: str, device=None): - """ Loads individual component from a hidden two-stream model checkpoint + """Loads individual component from a hidden two-stream model checkpoint Parameters ---------- @@ -615,24 +627,24 @@ def load_feature_extractor_components(model, checkpoint_file: Union[str, os.Path model: nn.Module pytorch model with loaded weights """ - if component == 'spatial': - key = 'spatial_classifier' + '.' - elif component == 'flow': - key = 'flow_classifier' + '.' - elif component == 'fusion': - key = 'fusion.' + if component == "spatial": + key = "spatial_classifier" + "." + elif component == "flow": + key = "flow_classifier" + "." + elif component == "fusion": + key = "fusion." else: - raise ValueError('component not one of spatial or flow: {}'.format(component)) + raise ValueError("component not one of spatial or flow: {}".format(component)) # directory = os.path.dirname(checkpoint_file) # subdir = os.path.join(directory, component) # log.info('device: {}'.format(device)) - log.info('loading component {} from file {}'.format(component, checkpoint_file)) + log.info("loading component {} from file {}".format(component, checkpoint_file)) state_dict, _, _ = load_state_dict_from_file(checkpoint_file) # state = torch.load(checkpoint_file, map_location=device) # state_dict = state['state_dict'] - params = {k.replace(key, ''): v for k, v in state_dict.items() if k.startswith(key)} + params = {k.replace(key, ""): v for k, v in state_dict.items() if k.startswith(key)} # import pdb; pdb.set_trace() model = load_state_from_dict(model, params) # import pdb; pdb.set_trace() @@ -650,7 +662,7 @@ def load_feature_extractor_components(model, checkpoint_file: Union[str, os.Path def get_subfiles(root: Union[str, bytes, os.PathLike], return_type: str = None) -> list: - """ Helper function to get a list of files of certain type from a directory + """Helper function to get a list of files of certain type from a directory Parameters ---------- @@ -666,21 +678,21 @@ def get_subfiles(root: Union[str, bytes, os.PathLike], return_type: str = None) files: list list of absolute paths of sub-files """ - assert (return_type in [None, 'any', 'file', 'directory']) + assert return_type in [None, "any", "file", "directory"] files = os.listdir(root) files.sort() files = [os.path.join(root, i) for i in files] - if return_type is None or return_type == 'any': + if return_type is None or return_type == "any": pass - elif return_type == 'file': + elif return_type == "file": files = [i for i in files if os.path.isfile(i)] - elif return_type == 'directory': + elif return_type == "directory": files = [i for i in files if os.path.isdir(i)] return files def print_hdf5(h5py_obj, level=-1, print_full_name: bool = False, print_attrs: bool = True) -> None: - """ Prints the name and shape of datasets in a H5py HDF5 file. + """Prints the name and shape of datasets in a H5py HDF5 file. Parameters ---------- h5py_obj: [h5py.File, h5py.Group] @@ -696,6 +708,7 @@ def print_hdf5(h5py_obj, level=-1, print_full_name: bool = False, print_attrs: b ------- None """ + def is_group(f): return type(f) == h5py._hl.group.Group @@ -704,14 +717,14 @@ def is_dataset(f): def print_level(level, n_spaces=5) -> str: if level == -1: - return '' - prepend = '|' + ' ' * (n_spaces - 1) + return "" + prepend = "|" + " " * (n_spaces - 1) prepend *= level - tree = '|' + '-' * (n_spaces - 2) + ' ' + tree = "|" + "-" * (n_spaces - 2) + " " return prepend + tree if isinstance(h5py_obj, str) or isinstance(h5py_obj, os.PathLike): - with h5py.File(h5py_obj, 'r') as f: + with h5py.File(h5py_obj, "r") as f: print_hdf5(f) return @@ -719,15 +732,15 @@ def print_level(level, n_spaces=5) -> str: entry = h5py_obj[key] name = entry.name if print_full_name else os.path.basename(entry.name) if is_group(entry): - print('{}{}'.format(print_level(level), name)) + print("{}{}".format(print_level(level), name)) print_hdf5(entry, level + 1, print_full_name=print_full_name) elif is_dataset(entry): shape = entry.shape dtype = entry.dtype - print('{}{}: {} {}'.format(print_level(level), name, shape, dtype)) + print("{}{}: {} {}".format(print_level(level), name, shape, dtype)) if level == -1: if print_attrs: - print('attrs: ') + print("attrs: ") # @@ -766,16 +779,17 @@ def print_level(level, n_spaces=5) -> str: def print_top_largest_variables(local_call, num: int = 20): - def sizeof_fmt(num, suffix='B'): - ''' by Fred Cirera, https://stackoverflow.com/a/1094933/1870254, modified''' - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + def sizeof_fmt(num, suffix="B"): + """by Fred Cirera, https://stackoverflow.com/a/1094933/1870254, modified""" + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if abs(num) < 1024.0: return "%3.1f %s%s" % (num, unit, suffix) num /= 1024.0 - return "%.1f %s%s" % (num, 'Yi', suffix) + return "%.1f %s%s" % (num, "Yi", suffix) - for name, size in sorted(((name, sys.getsizeof(value)) for name, value in local_call.items()), - key=lambda x: -x[1])[:10]: + for name, size in sorted(((name, sys.getsizeof(value)) for name, value in local_call.items()), key=lambda x: -x[1])[ + :10 + ]: print("{:>30}: {:>8}".format(name, sizeof_fmt(size))) @@ -787,7 +801,7 @@ def get_hparams_from_cfg(cfg, hparams): def get_dotted_from_cfg(cfg, dotted): # cfg: DictConfig # dotted: string parameter name. can be nested. e.g. 'tune.hparams.feature_extractor.dropout_p.min' - key_list = dotted.split('.') + key_list = dotted.split(".") cfg_chunk = cfg.get(key_list[0]) for i in range(1, len(key_list)): @@ -797,34 +811,33 @@ def get_dotted_from_cfg(cfg, dotted): def get_best_epoch_from_weightfile(weightfile: Union[str, os.PathLike]) -> int: - """parses a checkpoint like epoch=15.ckpt to find the number 15 - """ + """parses a checkpoint like epoch=15.ckpt to find the number 15""" basename = os.path.basename(weightfile) # in the previous version of deepethogram, load the last checkpoint - if basename.endswith('.pt'): + if basename.endswith(".pt"): return -1 - assert basename.endswith('.ckpt') + assert basename.endswith(".ckpt") basename = os.path.splitext(basename)[0] # if weightfile is the "last" - if 'last' in basename: + if "last" in basename: return -1 - components = basename.split('-') + components = basename.split("-") component = components[0] - assert component.startswith('epoch') - best_epoch = component.split('=')[1] + assert component.startswith("epoch") + best_epoch = component.split("=")[1] return int(best_epoch) def remove_nans_and_infs(array: np.ndarray, set_value: float = 0.0) -> np.ndarray: - """ Simple function to remove nans and infs from a numpy array """ + """Simple function to remove nans and infs from a numpy array""" bad_indices = np.logical_or(np.isinf(array), np.isnan(array)) array[bad_indices] = set_value return array -def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_prefix='classification') -> dict: +def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_prefix="classification") -> dict: """from model weights, gets the configuration for that model and its metrics file Parameters @@ -838,15 +851,15 @@ def get_run_files_from_weights(weightfile: Union[str, os.PathLike], metrics_pref config_file: path to config file metrics_file: path to metrics file """ - loaded_config_file = os.path.join(os.path.dirname(weightfile), 'config.yaml') + loaded_config_file = os.path.join(os.path.dirname(weightfile), "config.yaml") if not os.path.isfile(loaded_config_file): # weight file should be at most one-subdirectory-down from rundir - loaded_config_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), 'config.yaml') - assert os.path.isfile(loaded_config_file), 'no associated config file for weights! {}'.format(weightfile) + loaded_config_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), "config.yaml") + assert os.path.isfile(loaded_config_file), "no associated config file for weights! {}".format(weightfile) - metrics_file = os.path.join(os.path.dirname(weightfile), f'{metrics_prefix}_metrics.h5') + metrics_file = os.path.join(os.path.dirname(weightfile), f"{metrics_prefix}_metrics.h5") if not os.path.isfile(metrics_file): - metrics_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), f'{metrics_prefix}_metrics.h5') - assert os.path.isfile(metrics_file), 'no associated metrics file for weights! {}'.format(weightfile) + metrics_file = os.path.join(os.path.dirname(os.path.dirname(weightfile)), f"{metrics_prefix}_metrics.h5") + assert os.path.isfile(metrics_file), "no associated metrics file for weights! {}".format(weightfile) - return dict(config_file=loaded_config_file, metrics_file=metrics_file) \ No newline at end of file + return dict(config_file=loaded_config_file, metrics_file=metrics_file) diff --git a/deepethogram/viz.py b/deepethogram/viz.py index 2314417..79f70ef 100644 --- a/deepethogram/viz.py +++ b/deepethogram/viz.py @@ -9,6 +9,7 @@ import h5py import matplotlib import numpy as np + # import tifffile as TIFF from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation @@ -16,24 +17,27 @@ import torch from deepethogram.flow_generator.utils import flow_to_rgb_polar + # from deepethogram.metrics import load_threshold_data from deepethogram.utils import tensor_to_np log = logging.getLogger(__name__) # override warning level for matplotlib, which outputs a million debugging statements -logging.getLogger('matplotlib').setLevel(logging.WARNING) - - -def imshow_with_colorbar(image: np.ndarray, - ax_handle, - fig_handle: matplotlib.figure.Figure, - clim: tuple = None, - cmap: str = None, - interpolation: str = None, - symmetric: bool = False, - func: str = 'imshow', - **kwargs) -> matplotlib.colorbar.Colorbar: - """ Show an image in a matplotlib figure with a colorbar *with the same height as the axis!!* +logging.getLogger("matplotlib").setLevel(logging.WARNING) + + +def imshow_with_colorbar( + image: np.ndarray, + ax_handle, + fig_handle: matplotlib.figure.Figure, + clim: tuple = None, + cmap: str = None, + interpolation: str = None, + symmetric: bool = False, + func: str = "imshow", + **kwargs, +) -> matplotlib.colorbar.Colorbar: + """Show an image in a matplotlib figure with a colorbar *with the same height as the axis!!* Without this function, matplotlib color bars can be taller than the axis which is ugly. @@ -73,26 +77,26 @@ def imshow_with_colorbar(image: np.ndarray, # if we get a vector, change into a row if image.ndim == 1: image = image[np.newaxis, :] - + if symmetric: - cmap = 'bwr' + cmap = "bwr" divider = make_axes_locatable(ax_handle) - if func == 'imshow': + if func == "imshow": im = ax_handle.imshow(image, interpolation=interpolation, cmap=cmap, **kwargs) - elif func == 'pcolor' or func == 'pcolormesh': + elif func == "pcolor" or func == "pcolormesh": im = ax_handle.pcolormesh(image, cmap=cmap, **kwargs) if symmetric: limit = np.max(np.abs((image.min(), image.max()))) im.set_clim(-limit, limit) if clim is not None: im.set_clim(clim[0], clim[1]) - cax = divider.append_axes('right', size='5%', pad=0.05) + cax = divider.append_axes("right", size="5%", pad=0.05) cbar = fig_handle.colorbar(im, cax=cax) return cbar def stack_image_list(image_list: list, num_cols: int = 4) -> np.ndarray: - """ Stacks a list of images into one image with a certain number of columns. Used for viewing many images at once + """Stacks a list of images into one image with a certain number of columns. Used for viewing many images at once Parameters ---------- @@ -132,15 +136,15 @@ def stack_image_list(image_list: list, num_cols: int = 4) -> np.ndarray: elif row.ndim == 3: pad_width = ((0, 0), (0, padval), (0, 0)) else: - raise ValueError('input with weird shape: {}'.format(row.shape)) + raise ValueError("input with weird shape: {}".format(row.shape)) row = np.pad(row, pad_width) rows.append(row) stack = np.vstack(rows) return stack -def plot_flow(rgb, ax, show_scale=True, height=30, maxval: float = 1.0, interpolation='nearest', inset_label=False): - """ Plots an optic flow in polar coordinates, with an inset colorbar """ +def plot_flow(rgb, ax, show_scale=True, height=30, maxval: float = 1.0, interpolation="nearest", inset_label=False): + """Plots an optic flow in polar coordinates, with an inset colorbar""" ax.imshow(rgb, interpolation=interpolation) if show_scale: x = np.linspace(-1, 1, 100) @@ -151,9 +155,7 @@ def plot_flow(rgb, ax, show_scale=True, height=30, maxval: float = 1.0, interpol aspect = ax.get_data_ratio() width = int(height * aspect) # https://stackoverflow.com/questions/53204267 - inset = inset_locator.inset_axes(ax, width=str(width) + '%', - height=str(height) + '%', - loc=1) + inset = inset_locator.inset_axes(ax, width=str(width) + "%", height=str(height) + "%", loc=1) # axes_class=get_projection_class('polar')) inset.imshow(flow_colorbar) inset.invert_yaxis() @@ -171,10 +173,17 @@ def plot_flow(rgb, ax, show_scale=True, height=30, maxval: float = 1.0, interpol return inset -def visualize_images_and_flows(downsampled_t0, flows_reshaped, sequence_length: int = 10, fig=None, - max_flow: float = 5.0, height=15, batch_ind: int = None): - """ Plot a list of images and optic flows """ - plt.style.use('ggplot') +def visualize_images_and_flows( + downsampled_t0, + flows_reshaped, + sequence_length: int = 10, + fig=None, + max_flow: float = 5.0, + height=15, + batch_ind: int = None, +): + """Plot a list of images and optic flows""" + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) @@ -185,11 +194,11 @@ def visualize_images_and_flows(downsampled_t0, flows_reshaped, sequence_length: if batch_ind is None: batch_ind = np.random.choice(batch_size) - inds = range(batch_ind*sequence_length, batch_ind*sequence_length + sequence_length) + inds = range(batch_ind * sequence_length, batch_ind * sequence_length + sequence_length) images = downsampled_t0[0][inds].detach().cpu().numpy().astype(np.float32) # N is actually N * T - image_list = [i.transpose(1,2,0) for i in images] + image_list = [i.transpose(1, 2, 0) for i in images] # image_list = [images[i, ...].transpose(1, 2, 0) for i in range(batch_ind * sequence_length, # batch_ind * sequence_length + sequence_length)] stack = stack_image_list(image_list) @@ -197,14 +206,14 @@ def visualize_images_and_flows(downsampled_t0, flows_reshaped, sequence_length: stack = (stack * 255).clip(min=0, max=255).astype(np.uint8) ax = axes[0] - ax.imshow(stack, interpolation='nearest') - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum)) + ax.imshow(stack, interpolation="nearest") + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum)) ax.grid(False) - ax.axis('off') + ax.axis("off") ax = axes[1] flows = flows_reshaped[0][inds].detach().cpu().numpy().astype(np.float32) - flow_list = [i.transpose(1,2,0) for i in flows] + flow_list = [i.transpose(1, 2, 0) for i in flows] # flow_list = [flows[i, ...].transpose(1, 2, 0).astype(np.float32) for i in range(batch_ind * sequence_length, # batch_ind * sequence_length + sequence_length)] stack = stack_image_list(flow_list) @@ -212,11 +221,11 @@ def visualize_images_and_flows(downsampled_t0, flows_reshaped, sequence_length: stack = flow_to_rgb_polar(stack, maxval=max_flow) plot_flow(stack.clip(min=0, max=255).astype(np.uint8), ax, maxval=max_flow, inset_label=True, height=height) - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum)) + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum)) ax.grid(False) - ax.axis('off') + ax.axis("off") - fig.suptitle('Images and flows. Batch element: {}'.format(batch_ind)) + fig.suptitle("Images and flows. Batch element: {}".format(batch_ind)) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -225,13 +234,21 @@ def visualize_images_and_flows(downsampled_t0, flows_reshaped, sequence_length: # plt.show() -def visualize_multiresolution(downsampled_t0, estimated_t0, flows_reshaped, sequence_length: int = 10, - max_flow: float = 5.0, height=15, batch_ind: int = None, fig=None, - sequence_ind: int = None): - """ visualize images, optic flows, and reconstructions at multiple resolutions at which the loss is actually +def visualize_multiresolution( + downsampled_t0, + estimated_t0, + flows_reshaped, + sequence_length: int = 10, + max_flow: float = 5.0, + height=15, + batch_ind: int = None, + fig=None, + sequence_ind: int = None, +): + """visualize images, optic flows, and reconstructions at multiple resolutions at which the loss is actually applied. useful for seeing what the loss function actually sees, and debugging multi-resolution issues """ - plt.style.use('ggplot') + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) @@ -252,52 +269,54 @@ def visualize_multiresolution(downsampled_t0, estimated_t0, flows_reshaped, sequ images = downsampled_t0[0].detach().cpu().numpy().astype(np.float32) index = batch_ind * sequence_length + sequence_ind - t0 = [downsampled_t0[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) for i in - range(N_resolutions)] + t0 = [ + downsampled_t0[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) + for i in range(N_resolutions) + ] for i, image in enumerate(t0): ax = axes[0, i] if i == 0: - ax.set_ylabel('T0', fontsize=18) + ax.set_ylabel("T0", fontsize=18) minimum, mean, maximum = image.min(), image.mean(), image.max() image = (image * 255).clip(min=0, max=255).astype(np.uint8) - ax.imshow(image, interpolation='nearest') - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), - fontsize=8) + ax.imshow(image, interpolation="nearest") + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) - t1 = [estimated_t0[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) for i in - range(N_resolutions)] + t1 = [ + estimated_t0[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) + for i in range(N_resolutions) + ] for i, image in enumerate(t1): ax = axes[1, i] minimum, mean, maximum = image.min(), image.mean(), image.max() image = (image * 255).clip(min=0, max=255).astype(np.uint8) - ax.imshow(image, interpolation='nearest') - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), - fontsize=8) + ax.imshow(image, interpolation="nearest") + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) if i == 0: - ax.set_ylabel('T1', fontsize=18) + ax.set_ylabel("T1", fontsize=18) - flows = [flows_reshaped[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) for i in - range(N_resolutions)] + flows = [ + flows_reshaped[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) + for i in range(N_resolutions) + ] for i, image in enumerate(flows): ax = axes[2, i] minimum, mean, maximum = image.min(), image.mean(), image.max() flow_im = flow_to_rgb_polar(image, maxval=max_flow) plot_flow(flow_im, ax, maxval=max_flow) - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), - fontsize=8) + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) if i == 0: - ax.set_ylabel('Flow', fontsize=18) + ax.set_ylabel("Flow", fontsize=18) L1s = [np.sum(np.abs(t0[i] - t1[i]), axis=2) for i in range(N_resolutions)] for i, image in enumerate(L1s): ax = axes[3, i] minimum, mean, maximum = image.min(), image.mean(), image.max() - ax.imshow(image, interpolation='nearest') - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), - fontsize=8) + ax.imshow(image, interpolation="nearest") + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) if i == 0: - ax.set_ylabel('L1', fontsize=18) + ax.set_ylabel("L1", fontsize=18) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -308,13 +327,15 @@ def tensor_to_list(images: torch.Tensor, batch_ind: int, channels: int = 3) -> l if images.ndim == 4: N, C, H, W = images.shape sequence_length = C // channels - image_list = [images[batch_ind, i * channels:i * channels + channels, ...].transpose(1, 2, 0) - for i in range(sequence_length)] + image_list = [ + images[batch_ind, i * channels : i * channels + channels, ...].transpose(1, 2, 0) + for i in range(sequence_length) + ] elif images.ndim == 5: N, C, T, H, W = images.shape image_list = [images[batch_ind, :, i, ...].transpose(1, 2, 0) for i in range(T)] else: - raise ValueError('weird shape of input: {}'.format(images.shape)) + raise ValueError("weird shape of input: {}".format(images.shape)) return image_list @@ -322,29 +343,39 @@ def predictions_labels_string(pred, label, class_names=None): if class_names is None: class_names = [i for i in range(len(pred))] inds = np.argsort(pred)[::-1] - string = 'label: ' + string = "label: " if label.ndim > 0: for i in range(len(label)): if label[i] == 1: - string += '{} '.format(class_names[i]) - string += '\n' + string += "{} ".format(class_names[i]) + string += "\n" else: - string += '{}'.format(label) + string += "{}".format(label) for i in range(10): if i >= len(inds): break ind = inds[i] - string += '{}: {:.3f} '.format(class_names[ind], pred[ind]) + string += "{}: {:.3f} ".format(class_names[ind], pred[ind]) if (i % 5) == 4: - string += '\n' + string += "\n" return string -def visualize_hidden(images, flows, predictions, labels, class_names: list = None, batch_ind: int = None, - max_flow: float = 5.0, height: float = 15.0, fig=None, normalizer=None): - """ Visualize inputs and outputs of a hidden two stream model """ +def visualize_hidden( + images, + flows, + predictions, + labels, + class_names: list = None, + batch_ind: int = None, + max_flow: float = 5.0, + height: float = 15.0, + fig=None, + normalizer=None, +): + """Visualize inputs and outputs of a hidden two stream model""" # import pdb; pdb.set_trace() - plt.style.use('ggplot') + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) @@ -366,10 +397,10 @@ def visualize_hidden(images, flows, predictions, labels, class_names: list = Non stack = (stack * 255).clip(min=0, max=255).astype(np.uint8) ax = axes[0] - ax.imshow(stack, interpolation='nearest') - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), fontsize=8) + ax.imshow(stack, interpolation="nearest") + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) ax.grid(False) - ax.axis('off') + ax.axis("off") ax = axes[1] flows = flows.detach().cpu().numpy() @@ -381,9 +412,9 @@ def visualize_hidden(images, flows, predictions, labels, class_names: list = Non # inset.set_xticklabels([-max_flow, 0, max_flow]) # inset.set_yticklabels([-max_flow, 0, max_flow]) - ax.set_title('min: {:.4f} mean: {:.4f} max: {:.4f}'.format(minimum, mean, maximum), fontsize=8) + ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) ax.grid(False) - ax.axis('off') + ax.axis("off") pred = predictions[batch_ind].detach().cpu().numpy() label = labels[batch_ind].detach().cpu().numpy() @@ -401,14 +432,21 @@ def visualize_hidden(images, flows, predictions, labels, class_names: list = Non def to_uint8(im: np.ndarray) -> np.ndarray: - """ helper function for converting from [0,1] float to [0, 255] uint8 """ + """helper function for converting from [0,1] float to [0, 255] uint8""" return (im.copy() * 255).clip(min=0, max=255).astype(np.uint8) -def visualize_batch_unsupervised(downsampled_t0, estimated_t0, flows_reshaped, batch_ind=0, sequence_ind: int = 0, - fig=None, sequence_length: int = 10): - """ Visualize t0, t1, optic flow, reconstruction, and the L1 between t0 and estimated t0 """ - plt.style.use('ggplot') +def visualize_batch_unsupervised( + downsampled_t0, + estimated_t0, + flows_reshaped, + batch_ind=0, + sequence_ind: int = 0, + fig=None, + sequence_length: int = 10, +): + """Visualize t0, t1, optic flow, reconstruction, and the L1 between t0 and estimated t0""" + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) @@ -418,44 +456,44 @@ def visualize_batch_unsupervised(downsampled_t0, estimated_t0, flows_reshaped, b ax = axes[0, 0] t0 = downsampled_t0[0][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) - ax.imshow(to_uint8(t0), interpolation='nearest') - ax.set_title('min: {:.4f} max: {:.4f}'.format(t0.min(), t0.max())) + ax.imshow(to_uint8(t0), interpolation="nearest") + ax.set_title("min: {:.4f} max: {:.4f}".format(t0.min(), t0.max())) ax = axes[0, 1] t1 = downsampled_t0[0][index + 1].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) - ax.imshow(to_uint8(t1), interpolation='nearest') - ax.set_title('min: {:.4f} max: {:.4f}'.format(t1.min(), t1.max())) + ax.imshow(to_uint8(t1), interpolation="nearest") + ax.set_title("min: {:.4f} max: {:.4f}".format(t1.min(), t1.max())) ax = axes[1, 0] flow = flows_reshaped[0][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) - imshow_with_colorbar(flow[..., 0], ax, fig, symmetric=True, interpolation='nearest') + imshow_with_colorbar(flow[..., 0], ax, fig, symmetric=True, interpolation="nearest") ax = axes[1, 1] - imshow_with_colorbar(flow[..., 1], ax, fig, symmetric=True, interpolation='nearest') + imshow_with_colorbar(flow[..., 1], ax, fig, symmetric=True, interpolation="nearest") ax = axes[2, 0] est = estimated_t0[0][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) - ax.imshow(to_uint8(est), interpolation='nearest') + ax.imshow(to_uint8(est), interpolation="nearest") ax = axes[2, 1] L1 = np.abs(est - t0.astype(np.float32)).sum(axis=2) - imshow_with_colorbar(L1, ax, fig, interpolation='nearest') + imshow_with_colorbar(L1, ax, fig, interpolation="nearest") # pdb.set_trace() - ax.set_title('L1') + ax.set_title("L1") plt.tight_layout() -def visualize_batch_spatial(images, predictions, labels, fig=None, class_names=None, num_cols: int=4): - """ visualize spatial stream of hidden two stream model """ +def visualize_batch_spatial(images, predictions, labels, fig=None, class_names=None, num_cols: int = 4): + """visualize spatial stream of hidden two stream model""" - plt.style.use('ggplot') + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) batch_size = images.shape[0] num_rows = int(min(np.ceil(batch_size / num_cols), 6)) - total_images = min(num_rows*num_cols, batch_size) + total_images = min(num_rows * num_cols, batch_size) # only use the first total_images elements in the batch, to try to reduce RAM usage images = images[:total_images].detach().cpu().numpy() @@ -464,8 +502,6 @@ def visualize_batch_spatial(images, predictions, labels, fig=None, class_names=N images = images.clip(min=0, max=1) - - axes = fig.subplots(num_rows, num_cols) cnt = 0 if num_rows == 1: @@ -480,7 +516,7 @@ def visualize_batch_spatial(images, predictions, labels, fig=None, class_names=N pred = predictions[cnt] label = labels[cnt] string = predictions_labels_string(pred, label, class_names) - string = '{:03d}: '.format(cnt) + string + string = "{:03d}: ".format(cnt) + string # spatial stream should almost always be one single image image = tensor_to_list(images, cnt)[0] @@ -491,12 +527,13 @@ def visualize_batch_spatial(images, predictions, labels, fig=None, class_names=N ax.set_title(string, size=8) cnt += 1 - fig.suptitle('Spatial stream') + fig.suptitle("Spatial stream") plt.tight_layout() del images, predictions, labels + def visualize_batch_sequence(sequence, outputs, labels, N_in_batch=None, fig=None): - """ Visualize an input sequence, probabilities, and the true labels """ + """Visualize an input sequence, probabilities, and the true labels""" if fig is None: fig = plt.figure(figsize=(16, 12)) @@ -518,33 +555,33 @@ def visualize_batch_sequence(sequence, outputs, labels, N_in_batch=None, fig=Non # tmp = outputs[N_in_batch] # seq = cv2.resize(sequence[N_in_batch], (tmp.shape[1]*10,tmp.shape[0]*10), interpolation=cv2.INTER_NEAREST) # seq = cv2.imresize(sequence[N_in_batch], ) - imshow_with_colorbar(sequence, ax, fig, interpolation='nearest', - symmetric=False, func='pcolor', cmap='viridis') + imshow_with_colorbar(sequence, ax, fig, interpolation="nearest", symmetric=False, func="pcolor", cmap="viridis") ax.invert_yaxis() - ax.set_ylabel('inputs') + ax.set_ylabel("inputs") ax = axes[1] - imshow_with_colorbar(outputs, ax, fig, interpolation='nearest', symmetric=False, cmap='Reds', - func='pcolor', clim=[0, 1]) + imshow_with_colorbar( + outputs, ax, fig, interpolation="nearest", symmetric=False, cmap="Reds", func="pcolor", clim=[0, 1] + ) ax.invert_yaxis() - ax.set_ylabel('P') + ax.set_ylabel("P") ax = axes[2] - imshow_with_colorbar(labels, ax, fig, interpolation='nearest', cmap='Reds', func='pcolor') + imshow_with_colorbar(labels, ax, fig, interpolation="nearest", cmap="Reds", func="pcolor") ax.invert_yaxis() - ax.set_ylabel('Labels') + ax.set_ylabel("Labels") ax = axes[3] dumb_loss = np.abs(outputs - labels) - imshow_with_colorbar(dumb_loss, ax, fig, interpolation='nearest', cmap='Reds', func='pcolor', clim=[0, 1]) - ax.set_title('L1 between outputs and labels (not true loss)') + imshow_with_colorbar(dumb_loss, ax, fig, interpolation="nearest", cmap="Reds", func="pcolor", clim=[0, 1]) + ax.set_title("L1 between outputs and labels (not true loss)") ax.invert_yaxis() plt.tight_layout() del sequence, outputs, labels def fig_to_img(fig_handle: matplotlib.figure.Figure) -> np.ndarray: - """ Convenience function for returning the RGB values of a matplotlib figure """ + """Convenience function for returning the RGB values of a matplotlib figure""" # should do nothing if already drawn fig_handle.canvas.draw() # from stack overflow @@ -570,8 +607,8 @@ def fig_to_img(fig_handle: matplotlib.figure.Figure) -> np.ndarray: # TIFF.imsave(tiff_fname, fig_mat, photometric='rgb', compress=0, metadata={'axes': 'TYXC'}) -def plot_histogram(array, ax, bins='auto', width_factor=0.9, rotation=30): - """ Helper function for plotting a histogram """ +def plot_histogram(array, ax, bins="auto", width_factor=0.9, rotation=30): + """Helper function for plotting a histogram""" if type(array) != np.ndarray: array = np.array(array) @@ -587,15 +624,16 @@ def plot_histogram(array, ax, bins='auto', width_factor=0.9, rotation=30): med = np.median(array) ylims = ax.get_ylim() - leg_str = 'median: %0.4f' % (med) - lineh = ax.plot(np.array([med, med]), np.array([ylims[0], ylims[1]]), - color='k', linestyle='dashed', lw=3, label=leg_str) - ax.set_ylabel('P') + leg_str = "median: %0.4f" % (med) + lineh = ax.plot( + np.array([med, med]), np.array([ylims[0], ylims[1]]), color="k", linestyle="dashed", lw=3, label=leg_str + ) + ax.set_ylabel("P") ax.legend() def errorfill(x, y, yerr, color=None, alpha_fill=0.3, ax=None, label=None): - """ Convenience function for plotting a shaded error bar """ + """Convenience function for plotting a shaded error bar""" ax = ax if ax is not None else plt.gca() # if color is None: # color = ax._get_lines.color_cycle.next() @@ -609,9 +647,9 @@ def errorfill(x, y, yerr, color=None, alpha_fill=0.3, ax=None, label=None): def plot_curve(x, ys, ax, xlabel: str = None, class_names=None, colors=None): - """ Plots a set of curves. Will add a scatter to the maximum of each curve with text indicating location """ + """Plots a set of curves. Will add a scatter to the maximum of each curve with text indicating location""" if colors is None: - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] if x is None: x = np.arange(ys.shape[0]) if ys.ndim > 1: @@ -627,7 +665,7 @@ def plot_curve(x, ys, ax, xlabel: str = None, class_names=None, colors=None): max_acc = ys[index, i] scatter_x, scatter_y = remove_nan_or_inf(x[index]), remove_nan_or_inf(max_acc) ax.scatter(scatter_x, scatter_y) - text = '{:.2f}, {:.2f}'.format(x[index], max_acc) + text = "{:.2f}, {:.2f}".format(x[index], max_acc) text_x = x[index] + np.random.randn() / 20 text_y = max_acc + np.random.randn() / 20 text_x, text_y = remove_nan_or_inf(text_x), remove_nan_or_inf(text_y) @@ -643,45 +681,45 @@ def plot_curve(x, ys, ax, xlabel: str = None, class_names=None, colors=None): def thresholds_by_epoch_figure(epoch_summaries, class_names=None, fig=None): - plt.style.use('ggplot') + plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(14, 14)) ax = fig.add_subplot(2, 3, 1) - split = 'train' + split = "train" - keys = ['accuracy', 'accuracy_valid_bg'] + keys = ["accuracy", "accuracy_valid_bg"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) - ax.set_ylabel('Train') + ax.set_ylabel("Train") ax = fig.add_subplot(2, 3, 2) - keys = ['f1_by_class', 'f1_by_class_valid_bg', 'f1_overall', 'f1_overall_valid_bg'] + keys = ["f1_by_class", "f1_by_class_valid_bg", "f1_overall", "f1_overall_valid_bg"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) # ax.set_ylabel('Train') ax = fig.add_subplot(2, 3, 3) - keys = ['auroc', 'auroc_by_class', 'mAP', 'mAP_by_class'] + keys = ["auroc", "auroc_by_class", "mAP", "mAP_by_class"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) ax = fig.add_subplot(2, 3, 4) - split = 'val' - keys = ['accuracy', 'accuracy_valid_bg'] + split = "val" + keys = ["accuracy", "accuracy_valid_bg"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) - ax.set_ylabel('Validation') + ax.set_ylabel("Validation") ax = fig.add_subplot(2, 3, 5) - keys = ['f1_by_class', 'f1_by_class_valid_bg', 'f1_overall', 'f1_overall_valid_bg'] + keys = ["f1_by_class", "f1_by_class_valid_bg", "f1_overall", "f1_overall_valid_bg"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) ax = fig.add_subplot(2, 3, 6) - keys = ['auroc', 'auroc_by_class', 'mAP', 'mAP_by_class'] + keys = ["auroc", "auroc_by_class", "mAP", "mAP_by_class"] arr = np.vstack(([epoch_summaries[split][key] for key in keys])).T plot_curve(None, arr, ax, class_names=keys) @@ -691,16 +729,15 @@ def thresholds_by_epoch_figure(epoch_summaries, class_names=None, fig=None): # TWEAKED FROM SCIKIT-LEARN -def plot_confusion_matrix(cm, classes, ax, fig, - normalize=False, - title='Confusion matrix', - cmap='Blues', colorbar=True, fontsize=None): +def plot_confusion_matrix( + cm, classes, ax, fig, normalize=False, title="Confusion matrix", cmap="Blues", colorbar=True, fontsize=None +): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ if normalize: - cm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-7) + cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-7) # print("Normalized confusion matrix") else: # print('Confusion matrix, without normalization') @@ -708,45 +745,44 @@ def plot_confusion_matrix(cm, classes, ax, fig, # print(cm) if colorbar: - cbar = imshow_with_colorbar(cm, ax, fig, interpolation='nearest', cmap=cmap) + cbar = imshow_with_colorbar(cm, ax, fig, interpolation="nearest", cmap=cmap) else: ax.imshow(cm, cmap=cmap) # ax.set_title(title) tick_marks = np.arange(0, len(classes)) ax.set_xticks(tick_marks) - ax.tick_params(axis='x', rotation=45) + ax.tick_params(axis="x", rotation=45) ax.set_yticks(tick_marks) ax.set_xticklabels(classes) ax.set_yticklabels(classes) - fmt = '.2f' if normalize else 'd' + fmt = ".2f" if normalize else "d" if not normalize: cm = cm.astype(int) - thresh = cm.max() / 2. + thresh = cm.max() / 2.0 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): j, i = remove_nan_or_inf(j), remove_nan_or_inf(i) - element = cm[i,j] + element = cm[i, j] if element < 1e-2: element = 0 - fmt = 'd' + fmt = "d" else: - fmt = '.2f' if normalize else 'd' + fmt = ".2f" if normalize else "d" text = format(element, fmt) - if text.startswith('0.'): + if text.startswith("0."): text = text[1:] - ax.text(j, i, text, - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black", - fontsize=fontsize) + ax.text( + j, i, text, horizontalalignment="center", color="white" if cm[i, j] > thresh else "black", fontsize=fontsize + ) ax.set_xlim([-0.5, len(classes) - 0.5]) ax.set_ylim([len(classes) - 0.5, -0.5]) plt.tight_layout() - ax.set_ylabel('True label') - ax.set_xlabel('Predicted label') + ax.set_ylabel("True label") + ax.set_xlabel("Predicted label") def remove_nan_or_inf(value: Union[int, float]): - """ removes nans or infs. can happen in edge cases for plotting """ + """removes nans or infs. can happen in edge cases for plotting""" if np.isnan(value) or np.isinf(value): return 0 return value @@ -792,52 +828,50 @@ def remove_nan_or_inf(value: Union[int, float]): def plot_confusion_from_logger(logger_file, fig, class_names=None, epoch=None): - """ Plots train and validation confusion matrices from a Metrics file """ - with h5py.File(logger_file, 'r') as f: - best_epoch = np.argmax(f['val/' + f.attrs['key_metric']][:]) + """Plots train and validation confusion matrices from a Metrics file""" + with h5py.File(logger_file, "r") as f: + best_epoch = np.argmax(f["val/" + f.attrs["key_metric"]][:]) if epoch is None: epoch = best_epoch - if epoch == 'last': + if epoch == "last": epoch = -1 splits = list(f.keys()) - if 'train' in splits: - cm_train = f['train/confusion'][epoch, ...].astype(np.int64) + if "train" in splits: + cm_train = f["train/confusion"][epoch, ...].astype(np.int64) else: cm_train = np.array([np.nan]) - if 'val' in splits: - cm_val = f['val/confusion'][epoch, ...].astype(np.int64) + if "val" in splits: + cm_val = f["val/confusion"][epoch, ...].astype(np.int64) else: cm_val = np.array([np.nan]) if class_names is None: class_names = np.arange(cm_train.shape[0]) ax0 = fig.add_subplot(221) plot_confusion_matrix(cm_train, class_names, ax0, fig) - ax0.set_title('Train') + ax0.set_title("Train") ax1 = fig.add_subplot(222) - plot_confusion_matrix(cm_train, class_names, ax1, fig, - normalize=True) + plot_confusion_matrix(cm_train, class_names, ax1, fig, normalize=True) ax0 = fig.add_subplot(223) plot_confusion_matrix(cm_val, class_names, ax0, fig) - ax0.set_title('Val') + ax0.set_title("Val") ax1 = fig.add_subplot(224) - plot_confusion_matrix(cm_val, class_names, ax1, fig, - normalize=True) - fig.suptitle('Confusion matrices at epoch: %d' % (epoch)) + plot_confusion_matrix(cm_val, class_names, ax1, fig, normalize=True) + fig.suptitle("Confusion matrices at epoch: %d" % (epoch)) plt.subplots_adjust(top=0.8) plt.tight_layout() -def make_precision_recall_figure(logger_file, fig=None, splits=['train', 'val']): - """ Plots precision vs recall """ - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] +def make_precision_recall_figure(logger_file, fig=None, splits=["train", "val"]): + """Plots precision vs recall""" + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] if fig is None: fig = plt.figure(figsize=(14, 7)) for i, split in enumerate(splits): - ap_by_class = load_logger_data(logger_file, 'mAP_by_class', split) - precision = load_logger_data(logger_file, 'precision', split, is_threshold=True) - recall = load_logger_data(logger_file, 'recall', split, is_threshold=True) + ap_by_class = load_logger_data(logger_file, "mAP_by_class", split) + precision = load_logger_data(logger_file, "precision", split, is_threshold=True) + recall = load_logger_data(logger_file, "recall", split, is_threshold=True) ax = fig.add_subplot(1, len(splits), i + 1) # precision, recall = train_metrics['precision'], train_metrics['recall'] @@ -849,16 +883,18 @@ def make_precision_recall_figure(logger_file, fig=None, splits=['train', 'val']) y = precision[:, j] # there's a bug in how this is computed au_prc = ap_by_class[j] - string = '{}: {:.4f}'.format(j, au_prc) + string = "{}: {:.4f}".format(j, au_prc) ax.plot(x, y, color=color, label=string) - ax.set_aspect('equal', 'box') + ax.set_aspect("equal", "box") ax.legend() - ax.set_xlabel('Recall') - ax.set_ylabel('Precision') + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") ax.set_title(split) - fig.suptitle('Precision vs recall. Legend: Average Precision\nNote: curves are approximated with only ' + - '101 thresholds. Legend is exact') + fig.suptitle( + "Precision vs recall. Legend: Average Precision\nNote: curves are approximated with only " + + "101 thresholds. Legend is exact" + ) plt.tight_layout() return fig @@ -870,12 +906,13 @@ def add_text_to_line(xs, ys, ax, color): if np.isinf(x) or np.isnan(x) or np.isinf(y) or np.isnan(y): return # x, y = remove_nan_or_inf(x), remove_nan_or_inf(y) - ax.text(x, y, '{:.4f}'.format(y), color=color) + ax.text(x, y, "{:.4f}".format(y), color=color) -def plot_metric(data: Union[dict, OrderedDict], name, ax, legend: bool = False, plot_args: dict = None, - color_inds: list = None): - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] +def plot_metric( + data: Union[dict, OrderedDict], name, ax, legend: bool = False, plot_args: dict = None, color_inds: list = None +): + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] # data = {'train': train, 'val': val} for i, (split, array) in enumerate(data.items()): xs = np.arange(len(array)) @@ -893,7 +930,7 @@ def plot_metric(data: Union[dict, OrderedDict], name, ax, legend: bool = False, ax.set_xlim([-0.5, len(xs) - 0.5]) ax.set_ylabel(name) - ax.set_xlabel('Epochs') + ax.set_xlabel("Epochs") ax.set_title(name) if legend: ax.legend() @@ -901,90 +938,88 @@ def plot_metric(data: Union[dict, OrderedDict], name, ax, legend: bool = False, def make_learning_curves_figure_multilabel_classification(logger_file, fig=None): def get_data_from_file(f, name): - data = OrderedDict(train=f[f'train/{name}_overall'][:], - train_class_mean=f[f'train/{name}_class_mean'][:], - val=f[f'val/{name}_overall'][:], - val_class_mean=f[f'val/{name}_class_mean'][:]) + data = OrderedDict( + train=f[f"train/{name}_overall"][:], + train_class_mean=f[f"train/{name}_class_mean"][:], + val=f[f"val/{name}_overall"][:], + val_class_mean=f[f"val/{name}_class_mean"][:], + ) return data - with h5py.File(logger_file, 'r') as f: - plt.style.use('seaborn') + with h5py.File(logger_file, "r") as f: + plt.style.use("seaborn") if fig is None: fig = plt.figure(figsize=(12, 12)) # loss and learning rate ax = fig.add_subplot(4, 2, 1) - data = OrderedDict(train=f['train/loss'][:], - val=f['val/loss'][:]) + data = OrderedDict(train=f["train/loss"][:], val=f["val/loss"][:]) # import pdb; pdb.set_trace() - plot_metric(data, 'loss', ax) + plot_metric(data, "loss", ax) ax2 = ax.twinx() - ax2.plot(f['train/lr'][:], 'k', label='LR', alpha=0.5) - ax2.set_ylabel('learning rate') + ax2.plot(f["train/lr"][:], "k", label="LR", alpha=0.5) + ax2.set_ylabel("learning rate") ax2.grid(False) ax = fig.add_subplot(4, 2, 2) - data = OrderedDict(train=f['train/data_loss'][:], - val=f['val/data_loss'][:]) + data = OrderedDict(train=f["train/data_loss"][:], val=f["val/data_loss"][:]) # import pdb; pdb.set_trace() - plot_metric(data, 'data_loss', ax) - + plot_metric(data, "data_loss", ax) + ax = fig.add_subplot(4, 2, 3) - data = OrderedDict(train=f['train/reg_loss'][:], - val=f['val/reg_loss'][:]) + data = OrderedDict(train=f["train/reg_loss"][:], val=f["val/reg_loss"][:]) # import pdb; pdb.set_trace() - plot_metric(data, 'reg_loss', ax) + plot_metric(data, "reg_loss", ax) # FPS ax = fig.add_subplot(4, 2, 4) try: - data = OrderedDict(train=f['train/fps'][:], - val=f['val/fps'][:], - speedtest=f['speedtest/fps'][:]) + data = OrderedDict(train=f["train/fps"][:], val=f["val/fps"][:], speedtest=f["speedtest/fps"][:]) except Exception: # likely don't have speedtest, not too important - data = OrderedDict(train=f['train/fps'][:], - val=f['val/fps'][:]) + data = OrderedDict(train=f["train/fps"][:], val=f["val/fps"][:]) - plot_metric(data, 'FPS', ax, legend=True) + plot_metric(data, "FPS", ax, legend=True) ax.semilogy() # accuracy ax = fig.add_subplot(4, 2, 5) - data = OrderedDict(train=f['train/accuracy_overall'][:], - val=f['val/accuracy_overall'][:]) + data = OrderedDict(train=f["train/accuracy_overall"][:], val=f["val/accuracy_overall"][:]) - plot_metric(data, 'accuracy', ax) + plot_metric(data, "accuracy", ax) # F1 score! ax = fig.add_subplot(4, 2, 6) - data = OrderedDict(train=f['train/f1_overall'][:], - train_class_mean=f['train/f1_class_mean'][:], - train_class_mean_nobg=f['train/f1_class_mean_nobg'][:], - val=f['val/f1_overall'][:], - val_class_mean=f['val/f1_class_mean'][:], - val_class_mean_nobg=f['val/f1_class_mean_nobg'][:]) + data = OrderedDict( + train=f["train/f1_overall"][:], + train_class_mean=f["train/f1_class_mean"][:], + train_class_mean_nobg=f["train/f1_class_mean_nobg"][:], + val=f["val/f1_overall"][:], + val_class_mean=f["val/f1_class_mean"][:], + val_class_mean_nobg=f["val/f1_class_mean_nobg"][:], + ) # we'll reuse these for the following figures - plot_args = {'train_class_mean': {'linestyle': '--'}, - 'train_class_mean_nobg': {'linestyle': 'dotted'}, - 'val_class_mean': {'linestyle': '--'}, - 'val_class_mean_nobg': {'linestyle': 'dotted'},} + plot_args = { + "train_class_mean": {"linestyle": "--"}, + "train_class_mean_nobg": {"linestyle": "dotted"}, + "val_class_mean": {"linestyle": "--"}, + "val_class_mean_nobg": {"linestyle": "dotted"}, + } color_inds = [0, 0, 0, 1, 1, 1] # data = get_data_from_file(f, 'f1') - plot_metric(data, 'F1', ax, True, plot_args, color_inds) + plot_metric(data, "F1", ax, True, plot_args, color_inds) # AUROC - plot_args = {'train_class_mean': {'linestyle': '--'}, - 'val_class_mean': {'linestyle': '--'}} + plot_args = {"train_class_mean": {"linestyle": "--"}, "val_class_mean": {"linestyle": "--"}} color_inds = [0, 0, 1, 1] ax = fig.add_subplot(4, 2, 7) - data = get_data_from_file(f, 'auroc') - plot_metric(data, 'AUROC', ax, False, plot_args, color_inds) + data = get_data_from_file(f, "auroc") + plot_metric(data, "AUROC", ax, False, plot_args, color_inds) # Average precision ax = fig.add_subplot(4, 2, 8) - data = get_data_from_file(f, 'mAP') - plot_metric(data, 'Average Precision', ax, False, plot_args, color_inds) + data = get_data_from_file(f, "mAP") + plot_metric(data, "Average Precision", ax, False, plot_args, color_inds) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -994,16 +1029,15 @@ def get_data_from_file(f, name): def plot_multilabel_by_class(logger_file): def load_data(f, name): - data = {'train': f[f'train/{name}_by_class'][:], - 'val': f[f'val/{name}_by_class'][:]} + data = {"train": f[f"train/{name}_by_class"][:], "val": f[f"val/{name}_by_class"][:]} return data - with h5py.File(logger_file, 'r') as f: + with h5py.File(logger_file, "r") as f: def plot_row(row, name, legend: bool = False, title: bool = False): data = load_data(f, name) - for i, split in enumerate(['train', 'val']): + for i, split in enumerate(["train", "val"]): array = data[split] ax = row[i] # loop over classes @@ -1011,25 +1045,25 @@ def plot_row(row, name, legend: bool = False, title: bool = False): for j in range(array.shape[1]): class_data[j] = array[:, j] plot_metric(class_data, name, ax, legend and i == 0) - ax.set_xlabel('') + ax.set_xlabel("") if title: ax.set_title(split) else: - ax.set_title('') + ax.set_title("") fig, axes = plt.subplots(4, 2, figsize=(8, 12)) row = axes[0] - plot_row(row, 'accuracy', True, True) + plot_row(row, "accuracy", True, True) row = axes[1] - plot_row(row, 'f1') + plot_row(row, "f1") row = axes[2] - plot_row(row, 'auroc') + plot_row(row, "auroc") row = axes[3] - plot_row(row, 'mAP') + plot_row(row, "mAP") with warnings.catch_warnings(): warnings.simplefilter("ignore") plt.tight_layout() @@ -1038,44 +1072,43 @@ def plot_row(row, name, legend: bool = False, title: bool = False): def load_logger_data(logger_file, name, split, is_threshold: bool = False, epoch: int = -1): if is_threshold: - key = f'{split}/metrics_by_threshold/{name}' + key = f"{split}/metrics_by_threshold/{name}" else: - key = f'{split}/{name}' - with h5py.File(logger_file, 'r') as f: + key = f"{split}/{name}" + with h5py.File(logger_file, "r") as f: data = f[key][epoch, ...] return data def make_thresholds_figure(logger_file, split, fig=None, class_names=None): - plt.style.use('seaborn') + plt.style.use("seaborn") if fig is None: fig = plt.figure(figsize=(12, 12)) # axes = axes.flatten() - x = load_logger_data(logger_file, 'thresholds', split, True) + x = load_logger_data(logger_file, "thresholds", split, True) - for i, metric in enumerate(['accuracy', 'f1', 'precision', 'recall', 'informedness']): + for i, metric in enumerate(["accuracy", "f1", "precision", "recall", "informedness"]): ax = fig.add_subplot(3, 2, i + 1) y = load_logger_data(logger_file, metric, split, True) plot_curve(x, y, ax, class_names) - ax.set_title(f'{metric} by class') + ax.set_title(f"{metric} by class") plt.tight_layout() return fig -def make_roc_figure(logger_file, fig=None, splits=['train', 'val']): - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] +def make_roc_figure(logger_file, fig=None, splits=["train", "val"]): + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] if fig is None: fig = plt.figure(figsize=(14, 7)) for i, split in enumerate(splits): - - auroc_by_class = load_logger_data(logger_file, 'auroc_by_class', split) - tpr = load_logger_data(logger_file, 'tpr', split, is_threshold=True) - fpr = load_logger_data(logger_file, 'fpr', split, is_threshold=True) + auroc_by_class = load_logger_data(logger_file, "auroc_by_class", split) + tpr = load_logger_data(logger_file, "tpr", split, is_threshold=True) + fpr = load_logger_data(logger_file, "fpr", split, is_threshold=True) ax = fig.add_subplot(1, len(splits), i + 1) @@ -1083,30 +1116,30 @@ def make_roc_figure(logger_file, fig=None, splits=['train', 'val']): for j in range(K): color = colors[j] if j < len(colors) else colors[-1] auroc = auroc_by_class[j] - string = '{}: {:4f}'.format(j, auroc) + string = "{}: {:4f}".format(j, auroc) ax.plot(fpr[:, j], tpr[:, j], color=color, label=string) ax.legend() - ax.set_xlabel('FPR') - ax.set_ylabel('TPR') + ax.set_xlabel("FPR") + ax.set_ylabel("TPR") ax.set_title(split) - fig.suptitle('ROC Curves. Curves are approximate because only 101 thresholds were used. AUC values are precise') + fig.suptitle("ROC Curves. Curves are approximate because only 101 thresholds were used. AUC values are precise") plt.tight_layout() return fig -def visualize_binary_confusion(logger_file, fig=None, splits=['train', 'val']): - """ Visualizes binary confusion matrices """ +def visualize_binary_confusion(logger_file, fig=None, splits=["train", "val"]): + """Visualizes binary confusion matrices""" if fig is None: fig = plt.figure(figsize=(14, 14)) - cms = load_logger_data(logger_file, 'binary_confusion', 'train') + cms = load_logger_data(logger_file, "binary_confusion", "train") # if there's more than 3 dimensions, it could be [epochs, classes, 2, 2] # take the last one if cms.ndim > 3: cms = cms[-1, ...] K = cms.shape[0] - num_rows = len(splits)*2 + num_rows = len(splits) * 2 num_cols = K ind = 1 @@ -1118,68 +1151,67 @@ def plot_cms_in_row(cms, ylabel, normalize: bool = False): ax = fig.add_subplot(num_rows, num_cols, ind) cm = cms[j, ...] # print(cm.shape) - plot_confusion_matrix(cms[j, ...], range(cm.shape[0]), - ax, fig, colorbar=False, normalize=normalize) + plot_confusion_matrix(cms[j, ...], range(cm.shape[0]), ax, fig, colorbar=False, normalize=normalize) if j == 0: ax.set_ylabel(ylabel) - ax.set_xlabel('') + ax.set_xlabel("") else: - ax.set_ylabel('') - ax.set_xlabel('') + ax.set_ylabel("") + ax.set_xlabel("") ind += 1 for split in splits: - cms = load_logger_data(logger_file, 'binary_confusion', split) + cms = load_logger_data(logger_file, "binary_confusion", split) plot_cms_in_row(cms, split) - plot_cms_in_row(cms, f'{split}\nNormalized', normalize=True) + plot_cms_in_row(cms, f"{split}\nNormalized", normalize=True) plt.tight_layout() return fig def visualize_logger_multilabel_classification(logger_file): - """ makes a bunch of figures from a Metrics hdf5 file """ - plt.style.use('seaborn') + """makes a bunch of figures from a Metrics hdf5 file""" + plt.style.use("seaborn") fig = make_learning_curves_figure_multilabel_classification(logger_file) - save_figure(fig, 'learning_curves', False, 0) + save_figure(fig, "learning_curves", False, 0) fig = plot_multilabel_by_class(logger_file) - save_figure(fig, 'learning_curves_by_class', False, 1) + save_figure(fig, "learning_curves_by_class", False, 1) - fig = make_thresholds_figure(logger_file, 'train') - save_figure(fig, 'thresholds_this_epoch_train', False, 2) + fig = make_thresholds_figure(logger_file, "train") + save_figure(fig, "thresholds_this_epoch_train", False, 2) - fig = make_thresholds_figure(logger_file, 'val') - save_figure(fig, 'thresholds_this_epoch_val', False, 3) + fig = make_thresholds_figure(logger_file, "val") + save_figure(fig, "thresholds_this_epoch_val", False, 3) fig = visualize_binary_confusion(logger_file) - save_figure(fig, 'binary_confusion', False, 4) + save_figure(fig, "binary_confusion", False, 4) fig = make_roc_figure(logger_file) - save_figure(fig, 'ROC', False, 5) + save_figure(fig, "ROC", False, 5) fig = make_precision_recall_figure(logger_file) - save_figure(fig, 'precision_recall', False, 6) + save_figure(fig, "precision_recall", False, 6) try: - splits = ['train', 'val', 'test'] - fig = make_thresholds_figure(logger_file, 'test') - save_figure(fig, 'thresholds_this_epoch_test', False, 7) + splits = ["train", "val", "test"] + fig = make_thresholds_figure(logger_file, "test") + save_figure(fig, "thresholds_this_epoch_test", False, 7) fig = visualize_binary_confusion(logger_file, splits=splits) - save_figure(fig, 'binary_confusion_with_test', False, 8) + save_figure(fig, "binary_confusion_with_test", False, 8) fig = make_roc_figure(logger_file, splits=splits) - save_figure(fig, 'ROC_with_test', False, 9) + save_figure(fig, "ROC_with_test", False, 9) - fig = make_precision_recall_figure(logger_file, splits=['train', 'val', 'test']) - save_figure(fig, 'precision_recall_withtest', False, 10) + fig = make_precision_recall_figure(logger_file, splits=["train", "val", "test"]) + save_figure(fig, "precision_recall_withtest", False, 10) except Exception as e: # no test set yet - log.debug('error in test set viz: {}'.format(e)) + log.debug("error in test set viz: {}".format(e)) # pass - plt.close('all') + plt.close("all") def make_learning_curves_figure_opticalflow(logger_file, fig=None): @@ -1187,29 +1219,28 @@ def make_learning_curves_figure_opticalflow(logger_file, fig=None): fig = plt.figure(figsize=(12, 12)) def get_data(h5py_obj, name): - data = OrderedDict(train=h5py_obj[f'train/{name}'][:], - val=h5py_obj[f'val/{name}'][:]) + data = OrderedDict(train=h5py_obj[f"train/{name}"][:], val=h5py_obj[f"val/{name}"][:]) return data ax = fig.add_subplot(4, 2, 1) - with h5py.File(logger_file, 'r') as f: - data = get_data(f, 'loss') - plot_metric(data, 'loss', ax) + with h5py.File(logger_file, "r") as f: + data = get_data(f, "loss") + plot_metric(data, "loss", ax) ax2 = ax.twinx() - ax2.plot(f['train/lr'][:], 'k', label='LR', alpha=0.5) - ax2.set_ylabel('learning rate') + ax2.plot(f["train/lr"][:], "k", label="LR", alpha=0.5) + ax2.set_ylabel("learning rate") ax2.grid(False) - keys = list(f['train'].keys()) + keys = list(f["train"].keys()) plot_ind = 2 - for metric in ['fps','reg_loss', 'SSIM', 'L1', 'smoothness', 'sparsity']: + for metric in ["fps", "reg_loss", "SSIM", "L1", "smoothness", "sparsity"]: if metric in keys: ax = fig.add_subplot(4, 2, plot_ind) data = get_data(f, metric) - plot_metric(data, metric, ax, legend=metric == 'fps') + plot_metric(data, metric, ax, legend=metric == "fps") plot_ind += 1 @@ -1218,32 +1249,32 @@ def get_data(h5py_obj, name): def visualize_logger_optical_flow(logger_file): - """ makes a bunch of figures from a Metrics hdf5 file """ - plt.style.use('seaborn') + """makes a bunch of figures from a Metrics hdf5 file""" + plt.style.use("seaborn") fig = make_learning_curves_figure_opticalflow(logger_file) - save_figure(fig, 'learning_curves', False, 0) + save_figure(fig, "learning_curves", False, 0) hues = [212, 4, 121, 36, 55, 276, 237, 299, 186] hues = np.array(hues) / 360 * 180 -saturation = .85 * 255 -value = .95 * 255 +saturation = 0.85 * 255 +value = 0.95 * 255 start = [0, 0, value] gray_value = 102 class Mapper: - """ Applies a custom colormap to a K x T matrix. Used in the GUI to visualize probabilities and labels """ + """Applies a custom colormap to a K x T matrix. Used in the GUI to visualize probabilities and labels""" - def __init__(self, colormap='deepethogram'): - if colormap == 'deepethogram': + def __init__(self, colormap="deepethogram"): + if colormap == "deepethogram": self.init_deepethogram() else: try: self.cmap = plt.get_cmap(colormap) except ValueError: - raise ('Colormap not in matplotlib''s defaults! {}'.format(colormap)) + raise ("Colormap not in matplotlib" "s defaults! {}".format(colormap)) def init_deepethogram(self): gray_LUT = make_LUT([0, 0, value], [0, 0, gray_value]) @@ -1315,11 +1346,14 @@ def apply_cmap(array: Union[np.ndarray, int, float], LUT: np.ndarray) -> np.ndar elif array.min() >= 0 and array.max() <= 255: array = array.astype(np.uint8) else: - raise ValueError('Float arrays must be in the range of either [0, 1] or [0, 255], not [{},{}]'.format( - array.min(), array.max())) + raise ValueError( + "Float arrays must be in the range of either [0, 1] or [0, 255], not [{},{}]".format( + array.min(), array.max() + ) + ) if LUT.dtype != np.uint8: - raise ValueError('LUT must be uint8, not {}'.format(LUT.dtype)) + raise ValueError("LUT must be uint8, not {}".format(LUT.dtype)) if len(array.shape) < 2: array = np.vstack([array, array, array]).T[None, ...] elif array.shape[1] != 3: @@ -1334,31 +1368,40 @@ def apply_cmap(array: Union[np.ndarray, int, float], LUT: np.ndarray) -> np.ndar return mapped -def plot_ethogram(ethogram: np.ndarray, mapper, start_index: Union[int, float], - ax, classes: list = None, rotation: int = 15, ylabel: str = None): - """ Visualizes a K x T ethogram using some mapper """ +def plot_ethogram( + ethogram: np.ndarray, + mapper, + start_index: Union[int, float], + ax, + classes: list = None, + rotation: int = 15, + ylabel: str = None, +): + """Visualizes a K x T ethogram using some mapper""" # assume inputs is T x K im = mapper(ethogram.T) - im_h = ax.imshow(im, aspect='auto', interpolation='nearest') + im_h = ax.imshow(im, aspect="auto", interpolation="nearest") xticks = ax.get_xticks() new_ticks = [i + start_index for i in xticks] ax.set_xticklabels([str(int(i)) for i in new_ticks]) ax.set_yticks(np.arange(0, ethogram.shape[1])) if classes is not None: - ax.set_yticklabels(classes, rotation=rotation, fontdict={'fontsize': 12}) + ax.set_yticklabels(classes, rotation=rotation, fontdict={"fontsize": 12}) ax.set_ylabel(ylabel) return im_h -def make_ethogram_movie(outfile: Union[str, bytes, os.PathLike], - ethogram: np.ndarray, - mapper, - frames: list, - start: int, - classes: list, - width: int = 100, - fps: float = 30): - """ Makes a movie out of an ethogram. Can be very slow due to matplotlib's animations """ +def make_ethogram_movie( + outfile: Union[str, bytes, os.PathLike], + ethogram: np.ndarray, + mapper, + frames: list, + start: int, + classes: list, + width: int = 100, + fps: float = 30, +): + """Makes a movie out of an ethogram. Can be very slow due to matplotlib's animations""" if mapper is None: mapper = Mapper() @@ -1380,13 +1423,12 @@ def make_ethogram_movie(outfile: Union[str, bytes, os.PathLike], framenum = 0 im_h = ax0.imshow(frames[0]) - etho_h = plot_ethogram(ethogram[starts[0]:starts[0] + width, :], - mapper, start + framenum, ax1, classes) + etho_h = plot_ethogram(ethogram[starts[0] : starts[0] + width, :], mapper, start + framenum, ax1, classes) ylim = ax1.get_ylim() x = (0, 1, 1, 0, 0) y = (ylim[0], ylim[0], ylim[1], ylim[1], ylim[0]) - plot_h = ax1.plot(x, y, color='k', lw=0.5)[0] - title_h = ax0.set_title('{:,}: {}'.format(start, classes[np.where(ethogram[0])[0]].tolist())) + plot_h = ax1.plot(x, y, color="k", lw=0.5)[0] + title_h = ax0.set_title("{:,}: {}".format(start, classes[np.where(ethogram[0])[0]].tolist())) plt.tight_layout() # etho_h = plot_ethogram(ethogram[starts[0]:starts[0] + width, :], @@ -1403,8 +1445,9 @@ def animate(i): x = (x0, x1, x1, x0, x0) # print(x) if (i % width) == 0: - etho_h = plot_ethogram(ethogram[starts[i // width]:starts[i // width] + width, :], - mapper, start + i, ax1, classes) + etho_h = plot_ethogram( + ethogram[starts[i // width] : starts[i // width] + width, :], mapper, start + i, ax1, classes + ) # no idea why plot ethogram doesn't change this xticks = ax1.get_xticks() new_ticks = xticks + starts[i // width] + start @@ -1414,34 +1457,35 @@ def animate(i): etho_h = [i for i in ax1.get_children() if type(i) == matplotlib.image.AxesImage][0] plot_h.set_xdata(x) - title_h.set_text('{:,}: {}'.format(start + i, classes[np.where(ethogram[i])[0]].tolist())) + title_h.set_text("{:,}: {}".format(start + i, classes[np.where(ethogram[i])[0]].tolist())) return [im_h, etho_h, plot_h, title_h] - anim = FuncAnimation(fig, animate, init_func=init, - frames=len(frames), interval=int(1000 / fps), blit=True) - print('Rendering animation, may take a few minutes...') + anim = FuncAnimation(fig, animate, init_func=init, frames=len(frames), interval=int(1000 / fps), blit=True) + print("Rendering animation, may take a few minutes...") if outfile is None: out = anim.to_jshtml() else: - anim.save(outfile, fps=fps)# , extra_args=['-vcodec', 'libx264']) + anim.save(outfile, fps=fps) # , extra_args=['-vcodec', 'libx264']) out = None # have to use this ugly return syntax so that we can close the figure after saving plt.close(fig) return out -def make_ethogram_movie_with_predictions(outfile: Union[str, bytes, os.PathLike], - ethogram: np.ndarray, - predictions: np.ndarray, - mapper, - frames: list, - start: int, - classes: list, - width: int = 100, - fps: float = 30): - """ Makes a movie with movie, then ethogram, then model predictions """ - +def make_ethogram_movie_with_predictions( + outfile: Union[str, bytes, os.PathLike], + ethogram: np.ndarray, + predictions: np.ndarray, + mapper, + frames: list, + start: int, + classes: list, + width: int = 100, + fps: float = 30, +): + """Makes a movie with movie, then ethogram, then model predictions""" + if mapper is None: mapper = Mapper() fig = plt.figure(figsize=(6, 8)) @@ -1466,21 +1510,23 @@ def make_ethogram_movie_with_predictions(outfile: Union[str, bytes, os.PathLike] im_h = axes[0].imshow(frames[0]) ax = axes[1] - im_h1 = plot_ethogram(ethogram[starts[0]:starts[0] + width, :], - mapper, start + framenum, ax, classes, ylabel='Labels') + im_h1 = plot_ethogram( + ethogram[starts[0] : starts[0] + width, :], mapper, start + framenum, ax, classes, ylabel="Labels" + ) x = (0, 1, 1, 0, 0) ylim = ax.get_ylim() y = (ylim[0], ylim[0], ylim[1], ylim[1], ylim[0]) - plot_h1 = ax.plot(x, y, color='k', lw=0.5)[0] + plot_h1 = ax.plot(x, y, color="k", lw=0.5)[0] ax = axes[2] - im_h2 = plot_ethogram(predictions[starts[0]:starts[0] + width, :], - mapper, start + framenum, ax, classes, ylabel='Predictions') + im_h2 = plot_ethogram( + predictions[starts[0] : starts[0] + width, :], mapper, start + framenum, ax, classes, ylabel="Predictions" + ) ylim = ax.get_ylim() y = (ylim[0], ylim[0], ylim[1], ylim[1], ylim[0]) - plot_h2 = ax.plot(x, y, color='k', lw=0.5)[0] + plot_h2 = ax.plot(x, y, color="k", lw=0.5)[0] - title_h = axes[0].set_title('{:,}'.format(start)) + title_h = axes[0].set_title("{:,}".format(start)) plt.tight_layout() @@ -1499,15 +1545,27 @@ def animate(i): x = (x0, x1, x1, x0, x0) # print(x) if (i % width) == 0: - im_h1 = plot_ethogram(ethogram[starts[i // width]:starts[i // width] + width, :], - mapper, start + i, axes[1], classes, ylabel='Labels') + im_h1 = plot_ethogram( + ethogram[starts[i // width] : starts[i // width] + width, :], + mapper, + start + i, + axes[1], + classes, + ylabel="Labels", + ) # no idea why plot ethogram doesn't change this xticks = axes[1].get_xticks() new_ticks = xticks + starts[i // width] + start axes[1].set_xticklabels([str(int(i)) for i in new_ticks]) - im_h2 = plot_ethogram(predictions[starts[i // width]:starts[i // width] + width, :], - mapper, start + i, axes[2], classes, ylabel='Predictions') + im_h2 = plot_ethogram( + predictions[starts[i // width] : starts[i // width] + width, :], + mapper, + start + i, + axes[2], + classes, + ylabel="Predictions", + ) # no idea why plot ethogram doesn't change this xticks = axes[2].get_xticks() new_ticks = xticks + starts[i // width] + start @@ -1519,40 +1577,39 @@ def animate(i): plot_h1.set_xdata(x) plot_h2.set_xdata(x) - title_h.set_text('{:,}'.format(start + i)) + title_h.set_text("{:,}".format(start + i)) return [im_h, im_h1, im_h2, plot_h1, plot_h2, title_h] - anim = FuncAnimation(fig, animate, init_func=init, - frames=len(frames), interval=int(1000 / fps), blit=True) - print('Rendering animation, may take a few minutes...') + anim = FuncAnimation(fig, animate, init_func=init, frames=len(frames), interval=int(1000 / fps), blit=True) + print("Rendering animation, may take a few minutes...") if outfile is None: out = anim.to_jshtml() else: - anim.save(outfile, fps=fps, extra_args=['-vcodec', 'libx264']) + anim.save(outfile, fps=fps, extra_args=["-vcodec", "libx264"]) out = None # have to use this ugly return syntax so that we can close the figure after saving plt.close(fig) return out -def make_figure_filename(name, is_example, num, split='train', overwrite:bool=True): - basedir = os.path.join(os.getcwd(), 'figures') +def make_figure_filename(name, is_example, num, split="train", overwrite: bool = True): + basedir = os.path.join(os.getcwd(), "figures") if is_example: - basedir = os.path.join(basedir, 'examples', split) + basedir = os.path.join(basedir, "examples", split) if not os.path.isdir(basedir): os.makedirs(basedir) - fname = os.path.join(basedir, '{:02d}_{}.png'.format(num, name)) + fname = os.path.join(basedir, "{:02d}_{}.png".format(num, name)) if overwrite: return fname cnt = 0 while os.path.isfile(fname): - fname = os.path.join(basedir, '{:02d}_{}_{}.png'.format(num, name, cnt)) + fname = os.path.join(basedir, "{:02d}_{}_{}.png".format(num, name, cnt)) cnt += 1 return fname -def save_figure(figure, name, is_example, num, split='train', overwrite:bool=True): +def save_figure(figure, name, is_example, num, split="train", overwrite: bool = True): fname = make_figure_filename(name, is_example, num, split, overwrite) figure.savefig(fname) plt.close(figure) diff --git a/deepethogram/zscore.py b/deepethogram/zscore.py index ec655f2..179e136 100644 --- a/deepethogram/zscore.py +++ b/deepethogram/zscore.py @@ -16,11 +16,12 @@ class StatsRecorder: - """ Class for computing mean and std deviation incrementally. Originally found on github here: + """Class for computing mean and std deviation incrementally. Originally found on github here: https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html I only added PyTorch compatibility. """ + def __init__(self, mean: np.ndarray = None, std: np.ndarray = None, n_observations: int = None): self.nobservations = 0 if mean is not None: @@ -36,7 +37,7 @@ def first_batch(self, data: Union[np.ndarray, torch.Tensor]): data: ndarray, shape (nobservations, ndimensions) """ dtype = type(data) - assert (dtype == np.ndarray or dtype == torch.Tensor) + assert dtype == np.ndarray or dtype == torch.Tensor if dtype == np.ndarray: data = np.atleast_2d(data) @@ -63,7 +64,7 @@ def update(self, data): if data.shape[1] != self.ndimensions: raise ValueError("Data dims don't match prev observations.") dtype = type(data) - assert (dtype == np.ndarray or dtype == torch.Tensor) + assert dtype == np.ndarray or dtype == torch.Tensor if dtype == np.ndarray: data = np.atleast_2d(data) newmean = data.mean(axis=0) @@ -82,26 +83,25 @@ def update(self, data): tmp = self.mean self.mean = m / (m + n) * tmp + n / (m + n) * newmean - self.std = m / (m + n) * self.std ** 2 + n / (m + n) * newstd ** 2 + \ - m * n / (m + n) ** 2 * (tmp - newmean) ** 2 + self.std = m / (m + n) * self.std**2 + n / (m + n) * newstd**2 + m * n / (m + n) ** 2 * (tmp - newmean) ** 2 self.std = self.std**0.5 self.nobservations += n def __str__(self): - return 'mean: {} std: {} n: {}'.format(self.mean, self.std, self.nobservations) + return "mean: {} std: {} n: {}".format(self.mean, self.std, self.nobservations) def get_video_statistics(videofile, stride): image_stats = StatsRecorder() with deepethogram.file_io.VideoReader(videofile) as reader: - log.debug('N frames: {}'.format(len(reader))) + log.debug("N frames: {}".format(len(reader))) for i in tqdm(range(0, len(reader), stride)): try: image = reader[i] except Exception: - log.warning('Error reading frame {} from video {}'.format(i, videofile)) + log.warning("Error reading frame {} from video {}".format(i, videofile)) continue image = image.astype(float) / 255 image = image.transpose(2, 1, 0) @@ -113,9 +113,9 @@ def get_video_statistics(videofile, stride): # print(image.shape) image_stats.update(image) - log.info('final stats: {}'.format(image_stats)) + log.info("final stats: {}".format(image_stats)) - imdata = {'mean': image_stats.mean, 'std': image_stats.std, 'N': image_stats.nobservations} + imdata = {"mean": image_stats.mean, "std": image_stats.std, "N": image_stats.nobservations} for k, v in imdata.items(): if type(v) == torch.Tensor: v = v.detach().cpu().numpy() @@ -148,39 +148,39 @@ def zscore_video(videofile: Union[str, os.PathLike], project_config: dict, strid # config['normalization'] = None # transforms = get_transforms_from_config(config) # xform = transforms['train'] - log.info('zscoring file: {}'.format(videofile)) + log.info("zscoring file: {}".format(videofile)) imdata = get_video_statistics(videofile, stride) - fname = os.path.join(os.path.dirname(videofile), 'stats.yaml') + fname = os.path.join(os.path.dirname(videofile), "stats.yaml") dictionary = {} if os.path.isfile(fname): dictionary = utils.load_yaml(fname) - dictionary['normalization'] = imdata + dictionary["normalization"] = imdata utils.save_dict_to_yaml(dictionary, fname) update_project_with_normalization(imdata, project_config) def update_project_with_normalization(norm_dict: dict, project_config: dict): - """ Adds statistics from this video to the overall mean / std deviation for the project """ + """Adds statistics from this video to the overall mean / std deviation for the project""" # project_dict = utils.load_yaml(os.path.join(project_dir, 'project_config.yaml')) - if 'normalization' not in project_config['augs'].keys(): - raise ValueError('Must have project_config/augs/normalization field: {}'.format(project_config)) - old_rgb = project_config['augs']['normalization'] - if old_rgb is not None and old_rgb['N'] is not None and old_rgb['mean'] is not None: - old_mean_total = old_rgb['N'] * np.array(old_rgb['mean']) - old_std_total = old_rgb['N'] * np.array(old_rgb['std']) - old_N = old_rgb['N'] + if "normalization" not in project_config["augs"].keys(): + raise ValueError("Must have project_config/augs/normalization field: {}".format(project_config)) + old_rgb = project_config["augs"]["normalization"] + if old_rgb is not None and old_rgb["N"] is not None and old_rgb["mean"] is not None: + old_mean_total = old_rgb["N"] * np.array(old_rgb["mean"]) + old_std_total = old_rgb["N"] * np.array(old_rgb["std"]) + old_N = old_rgb["N"] else: old_mean_total = 0 old_std_total = 0 old_N = 0 - new_n = old_N + norm_dict['N'] - new_mean = (old_mean_total + norm_dict['N'] * np.array(norm_dict['mean'])) / new_n - new_std = (old_std_total + norm_dict['N'] * np.array(norm_dict['std'])) / new_n - project_config['augs']['normalization'] = {'N': new_n, 'mean': new_mean.tolist(), 'std': new_std.tolist()} - utils.save_dict_to_yaml(project_config, os.path.join(project_config['project']['path'], 'project_config.yaml')) + new_n = old_N + norm_dict["N"] + new_mean = (old_mean_total + norm_dict["N"] * np.array(norm_dict["mean"])) / new_n + new_std = (old_std_total + norm_dict["N"] * np.array(norm_dict["std"])) / new_n + project_config["augs"]["normalization"] = {"N": new_n, "mean": new_mean.tolist(), "std": new_std.tolist()} + utils.save_dict_to_yaml(project_config, os.path.join(project_config["project"]["path"], "project_config.yaml")) # @hydra.main(config_path='../conf/zscore.yaml') @@ -190,9 +190,9 @@ def main(cfg: DictConfig): zscore_video(cfg.videofile, project_config, cfg.stride) -if __name__ == '__main__': - config_list = ['config', 'zscore'] - run_type = 'zscore' +if __name__ == "__main__": + config_list = ["config", "zscore"] + run_type = "zscore" model = None project_path = projects.get_project_path_from_cl(sys.argv) cfg = configuration.make_config(project_path, config_list, run_type, model, use_command_line=True) diff --git a/docker/Dockerfile-full b/docker/Dockerfile-full index cf8bfe8..d66833a 100644 --- a/docker/Dockerfile-full +++ b/docker/Dockerfile-full @@ -1,18 +1,18 @@ FROM --platform=linux/amd64 nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 -# modified from here +# modified from here # https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile # Install some basic utilities RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ - curl ca-certificates sudo git bzip2 libx11-6 \ + curl ca-certificates sudo git bzip2 libx11-6 \ ffmpeg libsm6 libxext6 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-render-util0 libxcb-xinerama0 \ libxcb-xkb-dev libxkbcommon-x11-0 libpulse-mainloop-glib0 ubuntu-restricted-extras libqt5multimedia5-plugins vlc \ - libkrb5-3 libgssapi-krb5-2 libkrb5support0 \ + libkrb5-3 libgssapi-krb5-2 libkrb5support0 \ && rm -rf /var/lib/apt/lists/* # don't ask for location etc user input when building # this is for opencv, apparently -RUN apt-get update && apt-get install -y +RUN apt-get update && apt-get install -y # Create a working directory and data directory RUN mkdir /app @@ -29,12 +29,12 @@ RUN curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py39 && rm ~/miniconda.sh \ && conda update conda -# install -RUN conda install python=3.7 -y +# install +RUN conda install python=3.7 -y RUN pip install setuptools --upgrade && pip install --upgrade "pip<24.0" RUN pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='full' -RUN pip install -e . \ No newline at end of file +RUN pip install -e . diff --git a/docker/Dockerfile-gui b/docker/Dockerfile-gui index 60855be..b8b6487 100644 --- a/docker/Dockerfile-gui +++ b/docker/Dockerfile-gui @@ -1,6 +1,6 @@ FROM --platform=linux/amd64 ubuntu:20.04 -# modified from here +# modified from here # https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile # Install some basic utilities RUN apt-get update && apt-get install -y \ @@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y \ sudo \ git \ bzip2 \ - libx11-6 \ + libx11-6 \ && rm -rf /var/lib/apt/lists/* # don't ask for location etc user input when building @@ -33,8 +33,8 @@ RUN curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py39 && rm ~/miniconda.sh \ && conda update conda -# install -RUN conda install python=3.7 -y +# install +RUN conda install python=3.7 -y RUN pip install setuptools --upgrade && pip install --upgrade pip # TODO: REFACTOR CODE SO IT'S POSSIBLE TO RUN GUI WITHOUT TORCH @@ -44,4 +44,4 @@ RUN conda install pytorch cpuonly -c pytorch ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='gui' -RUN pip install -e . \ No newline at end of file +RUN pip install -e . diff --git a/docker/Dockerfile-headless b/docker/Dockerfile-headless index d0830f5..ce92955 100644 --- a/docker/Dockerfile-headless +++ b/docker/Dockerfile-headless @@ -1,6 +1,6 @@ FROM --platform=linux/amd64 nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 -# modified from here +# modified from here # https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile # Install some basic utilities RUN apt-get update && apt-get install -y \ @@ -9,12 +9,12 @@ RUN apt-get update && apt-get install -y \ sudo \ git \ bzip2 \ - libx11-6 \ + libx11-6 \ && rm -rf /var/lib/apt/lists/* # don't ask for location etc user input when building # this is for opencv, apparently -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg # Create a working directory and data directory RUN mkdir /app @@ -31,8 +31,8 @@ RUN curl -sLo ~/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-py39 && rm ~/miniconda.sh \ && conda update conda -# install -RUN conda install python=3.7 -y +# install +RUN conda install python=3.7 -y RUN pip install setuptools --upgrade && pip install --upgrade pip RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch @@ -40,4 +40,4 @@ RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch ADD . /app/deepethogram WORKDIR /app/deepethogram ENV DEG_VERSION='headless' -RUN pip install -e . \ No newline at end of file +RUN pip install -e . diff --git a/docs/beta.md b/docs/beta.md index b100cf4..815e89f 100644 --- a/docs/beta.md +++ b/docs/beta.md @@ -1,17 +1,17 @@ # DeepEthogram Beta -DeepEthogram is now in Beta, version 0.1! There are major changes to the codebase and to model training and inference. -Model performance, measured by F1, accuracy, etc. should be higher in version 0.1. Model training times and inference -times should be dramatically reduced. +DeepEthogram is now in Beta, version 0.1! There are major changes to the codebase and to model training and inference. +Model performance, measured by F1, accuracy, etc. should be higher in version 0.1. Model training times and inference +times should be dramatically reduced. -**Important note: your old project files, models, and (most importantly) human labels will all still work!** However, -I do recommend training new feature extractor and sequence models, as performance should improve somewhat. This will +**Important note: your old project files, models, and (most importantly) human labels will all still work!** However, +I do recommend training new feature extractor and sequence models, as performance should improve somewhat. This will be the last major refactor of DeepEthogram (model improvements and new features will still come out), however I will -not be majorly changing dependencies after this. Future upgrades will be easier (e.g. `pip install --upgrade deepethogram`). +not be majorly changing dependencies after this. Future upgrades will be easier (e.g. `pip install --upgrade deepethogram`). ## Summary of changes -* Basic training pipeline re-implemented with PyTorch Lightning. This gives us some great features, such as tensorboard -logging, automatic batch sizing, and Ray Tune integration. +* Basic training pipeline re-implemented with PyTorch Lightning. This gives us some great features, such as tensorboard +logging, automatic batch sizing, and Ray Tune integration. * Image augmentations moved to GPU with Kornia. [see Performance guide for details](performance.md) * New, parallelized inference * Hyperparameter tuning @@ -24,7 +24,7 @@ logging, automatic batch sizing, and Ray Tune integration. ## Migration guide -There are some new dependency changes; making sure that install works correctly is the hardest part about migration. +There are some new dependency changes; making sure that install works correctly is the hardest part about migration. * activate your conda environment, e.g. `conda activate deg` * uninstall hydra: `pip uninstall hydra-core` @@ -40,8 +40,8 @@ There are some new dependency changes; making sure that install works correctly with your upgrade. please follow the above steps. If you're sure that everything else installed correctly, you can run `pip install --upgrade omegaconf` * `error: torch 1.5.1 is installed but torch>=1.6.0 is required by {'kornia'}` - * this indicates that your PyTorch version is too low. Please uninstall and reinstall PyTorch. + * this indicates that your PyTorch version is too low. Please uninstall and reinstall PyTorch. * `ValueError: Hydra installation found. Please run pip uninstall hydra-core` * do as the error message says: run `pip uninstall hydra-core` - * if you've already done this, you might have to manually delete hydra files. Mine were at + * if you've already done this, you might have to manually delete hydra files. Mine were at `'C:\\ProgramData\\Anaconda3\\lib\\site-packages\\hydra_core-0.11.3-py3.7.egg\\hydra'`. Please delete the `hydra_core` folder. diff --git a/docs/code_examples.md b/docs/code_examples.md index 69cdd7c..1e88466 100644 --- a/docs/code_examples.md +++ b/docs/code_examples.md @@ -1,3 +1,3 @@ # Code examples -TODO \ No newline at end of file +TODO diff --git a/docs/docker.md b/docs/docker.md index 598c073..02da8ca 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -4,7 +4,7 @@ Install Docker: https://docs.docker.com/get-docker/ Install nvidia-docker: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker ## running the gui on Linux with training support -In a terminal, run `xhost +local:docker`. You'll need to do this every time you restart. +In a terminal, run `xhost +local:docker`. You'll need to do this every time you restart. To run, type this command: `docker run --gpus all -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:rw --shm-size 16G -v /media:/media -it jbohnslav/deepethogram:full python -m deepethogram` @@ -35,6 +35,6 @@ Again, change `/media` to your hard drive with your training data * tests: `docker run --gpus all -it deepethogram:full pytest tests/` # building it yourself -To build the container with both GUI and model training support: +To build the container with both GUI and model training support: * `cd` to your `deepethogram` directory * `docker build -t deepethogram:full -f docker/Dockerfile-full .` diff --git a/docs/file_structure.md b/docs/file_structure.md index 4bf6ee2..88e22bd 100644 --- a/docs/file_structure.md +++ b/docs/file_structure.md @@ -1,7 +1,7 @@ # Expected filepaths To train the DeepEthogram models, we need to be able to find a bunch of files (below). If you use the GUI, this directory -structure will be created for you. +structure will be created for you. * models: a list of recent model runs of various types, along with their weights, and their performance * data * for each video, we need the video file itself @@ -12,12 +12,12 @@ structure will be created for you. * for the feature extractor, we save the 512-dimensional image features and 512-dimensional flow features to this file * we also save probabilities and predictions (thresholded probabilities) to this file, as well as the thresholds used * video statistics: following normal convention in machine learning, we z-score our input data. For images, this is done independently - for the read, green, and blue channels. We z-score each video as they are added to a project, and save the channel + for the read, green, and blue channels. We z-score each video as they are added to a project, and save the channel means and std deviations to a file * project configuration file: holds project-specific information, like behavior names and variables to override. For defaults, see [the default configuration file](../deepethogram/conf/project/project_config.yaml) - -Therefore, the data loading scripts expect the following consistent folder structure. Note: if you write your own -dataloaders, you can use whatever file structure you want. + +Therefore, the data loading scripts expect the following consistent folder structure. Note: if you write your own +dataloaders, you can use whatever file structure you want. ```bash project_directory @@ -44,4 +44,4 @@ project_directory | | ├── model_type_metrics.h5: saved metrics for this model. e.g. f1, accuracy, SSIM, depending | ├── 200504_feature_extractor_None | | ├── checkpoint.pt: etc... -``` \ No newline at end of file +``` diff --git a/docs/getting_started.md b/docs/getting_started.md index 13397f9..6503ac3 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,93 +1,93 @@ # Getting started -The goal of DeepEthogram is as follows: +The goal of DeepEthogram is as follows: * You have videos (inputs) * You also have a set of behaviors that you've defined based on your research project / interests * You want to know, for every frame, which behaviors are present on that frame * DeepEthogram will learn a mapping , where -p(k,t) is the probability of behavior k on frame t. X is the video at time t. -* These probabilities are thresholded to give a binary prediction. The binary matrix of each behavior at each timepoint is what we call an *ethogram*. -* To train this model, you must label videos manually for each behavior at each timepoint. +p(k,t) is the probability of behavior k on frame t. X is the video at time t. +* These probabilities are thresholded to give a binary prediction. The binary matrix of each behavior at each timepoint is what we call an *ethogram*. +* To train this model, you must label videos manually for each behavior at each timepoint. * After training, you can run inference on a new video. Most of the frames should be labeled correctly. You can then quickly edit the errors. -This is schematized below: +This is schematized below: ![DeepEthogram figure 1](images/ethogram_schematic.png) -In the above figure, the image sequence are our inputs, and the ethogram is depicted below. +In the above figure, the image sequence are our inputs, and the ethogram is depicted below. ## Installation -See [the installation documentation](installation.md). +See [the installation documentation](installation.md). ## Making a project -The most important decision to make when starting a DeepEthogram project is which behaviors to include. Each frame must have a label -for each behavior. While DeepEthogram contains code for adding and removing behaviors, all previous models must need to be -retrained when a behavior has been added or removed. After all, if you used to have 5 behaviors and now you have 6, -the final layer of the neural network models will all have the wrong shape. Furthermore, previously labeled videos must be -updated with new behaviors before they can be used for training. +The most important decision to make when starting a DeepEthogram project is which behaviors to include. Each frame must have a label +for each behavior. While DeepEthogram contains code for adding and removing behaviors, all previous models must need to be +retrained when a behavior has been added or removed. After all, if you used to have 5 behaviors and now you have 6, +the final layer of the neural network models will all have the wrong shape. Furthermore, previously labeled videos must be +updated with new behaviors before they can be used for training. -Open your terminal window, activate your `conda` environment, and open the GUI by typing `deepethogram`. For more information, see -[using the GUI](using_gui.md). +Open your terminal window, activate your `conda` environment, and open the GUI by typing `deepethogram`. For more information, see +[using the GUI](using_gui.md). Go to `file -> new project`. Select a location for the new project to be created. It is *essential* that the project -be created on a Solid State Drive (or NVMe drive), because during training DeepEthogram will load hundreds of images per second. +be created on a Solid State Drive (or NVMe drive), because during training DeepEthogram will load hundreds of images per second. -After selecting a location, a screen will appear with three fields: +After selecting a location, a screen will appear with three fields: * `project name`: the name of your project. Examples might be: `mouse_reach`, `grooming`, `itch_mutation_screen`, etc. -* `name of person labeling`: your name. Currently unused, in the future it could be used to compare labels between humans. -* `list of behaviors`: the list of behaviors you want to label. Think carefully! (see above). Separate the behaviors with commas. -Do not include `none`, `other`, `background`, `etc`, or `misc` or anything like that. +* `name of person labeling`: your name. Currently unused, in the future it could be used to compare labels between humans. +* `list of behaviors`: the list of behaviors you want to label. Think carefully! (see above). Separate the behaviors with commas. +Do not include `none`, `other`, `background`, `etc`, or `misc` or anything like that. -Press OK. +Press OK. -A directory will be created in the location you specified, with the name `projectname_deepethogram`. It will initialize the -file structure needed to run deepethogram. See [the docs](file_structure.md) for details. +A directory will be created in the location you specified, with the name `projectname_deepethogram`. It will initialize the +file structure needed to run deepethogram. See [the docs](file_structure.md) for details. ## Edit the default configuration file -For more information see [the config file docs](using_config_files.md). +For more information see [the config file docs](using_config_files.md). ## Add videos -When you add videos to DeepEthogram, we will **copy them to the deepethogram project directory** (not move or use a symlink). We highly +When you add videos to DeepEthogram, we will **copy them to the deepethogram project directory** (not move or use a symlink). We highly recommend starting with at least 3 videos, so you have more than 1 to train, and 1 for validation (roughly, videos are assigned to splits probabilistically). When you add videos, DeepEthogram will automatically compute mean and standard deviation statistics and save them to disk. This might take a few moments (or minutes for very long videos). This is required for model training -and inference. +and inference. ## Download models -Rather than start from scratch, we will start with model weights pretrained on the Kinetics700 dataset. Go to +Rather than start from scratch, we will start with model weights pretrained on the Kinetics700 dataset. Go to To download the pretrained weights, please use [this Google Drive link](https://drive.google.com/file/d/1ntIZVbOG1UAiFVlsAAuKEBEVCVevyets/view?usp=sharing). -Unzip the files in your `project/models` directory. The path should be -`your_project/models/pretrained/{models 1:6}`. +Unzip the files in your `project/models` directory. The path should be +`your_project/models/pretrained/{models 1:6}`. ## Start training the flow generator These models (see paper) estimate local motion from video frames. They are trained in a self-supervised manner, so they -require no labels. First, select the pretrained model architecture you chose with the drop down menu in the flow_generator box. +require no labels. First, select the pretrained model architecture you chose with the drop down menu in the flow_generator box. Then simply click the `train` button in the `flow_generator` section on the left -side of the GUI. The model will use the architecture you've put in your configuration file, or `TinyMotionNet` by default. -You can continue to label while this model is training. +side of the GUI. The model will use the architecture you've put in your configuration file, or `TinyMotionNet` by default. +You can continue to label while this model is training. ## Label a few videos To see how to import videos to DeepEthogram and label them, please see [using the GUI docs](using_gui.md). ## Train models! -The workflow is depicted below. For more information, read the paper. +The workflow is depicted below. For more information, read the paper. ![DeepEthogram workflow](images/workflow.png) 1. Train the feature extractor. To do this, use the drop down menu to select the pretrained weights for the architecture -you've specified in your configuration file. Then click `train`. This will take a few hours at least, perhaps overnight +you've specified in your configuration file. Then click `train`. This will take a few hours at least, perhaps overnight the first time. 2. Run inference using the pretrained feature extractor. The weights file from the model you've just trained should be pre-selected in the drop-down menu. Click `infer`. Select the videos you want to run inference on. This will go frame-by-frame -through all your videos and save the 512-d spatial features and 512-d motion features to disk. These are the inputs to our -sequence model. -3. Train the sequence model. In the sequence box, simply click `train`. -4. Run inference using the sequence model, as above. +through all your videos and save the 512-d spatial features and 512-d motion features to disk. These are the inputs to our +sequence model. +3. Train the sequence model. In the sequence box, simply click `train`. +4. Run inference using the sequence model, as above. -Now, you have a trained flow, feature extractor, and sequence model. +Now, you have a trained flow, feature extractor, and sequence model. ## Add more videos -Using the GUI, add videos to your project. After you add the videos, as depicted in the workflow figure above, -extract features to disk. This will take about 30-60 frames per second, depending on the model and your video resolution. +Using the GUI, add videos to your project. After you add the videos, as depicted in the workflow figure above, +extract features to disk. This will take about 30-60 frames per second, depending on the model and your video resolution. Then run inference using the pretrained sequence model (should be instantaneous). Now, for your newly added videos, you have probabilities and predictions for every video frame. Use the `import predictions as labels` @@ -102,4 +102,4 @@ Continue cyling through the above workflow * add videos * run inference * edit model errors -* retrain models \ No newline at end of file +* retrain models diff --git a/docs/installation.md b/docs/installation.md index 70bc3d9..313f148 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -5,21 +5,21 @@ * Create a new anaconda environment: `conda create --name deg python=3.7` * Activate your environment: `conda activate deg` * Install PySide2: `conda install -c conda-forge pyside2==5.13.2` -* Install PyTorch: [Use this link for official instructions.](https://pytorch.org/) -* `pip install deepethogram`. +* Install PyTorch: [Use this link for official instructions.](https://pytorch.org/) +* `pip install deepethogram`. ## Installing from source * `git clone https://github.com/jbohnslav/deepethogram.git` * `cd deepethogram` * `conda env create -f environment.yml` - * Be prepared to wait a long time!! On mechanical hard drives, this may take 5-10 minutes (or more). Interrupting here will cause installation to fail. + * Be prepared to wait a long time!! On mechanical hard drives, this may take 5-10 minutes (or more). Interrupting here will cause installation to fail. * `conda activate deg` * `python setup.py develop` ### Installing Anaconda -For instructions on installing anaconda, -please [use this link](https://www.anaconda.com/distribution/). This will install Python, some basic dependencies, and -install the Anaconda package manager. This will ensure that if you use some other project that (say) requires Python 2, +For instructions on installing anaconda, +please [use this link](https://www.anaconda.com/distribution/). This will install Python, some basic dependencies, and +install the Anaconda package manager. This will ensure that if you use some other project that (say) requires Python 2, you can have both installed on your machine without interference. * First things first, download and install Anaconda for your operating system. You can find the downloads [here](https://www.anaconda.com/distribution/#download-section). Make sure you pick the Python 3.7 version. When you're installing, make sure you pick the option something along the lines of "add anaconda to path". That way, you can use `conda` on the command line. @@ -27,17 +27,17 @@ you can have both installed on your machine without interference. * Open up the command line, such as terminal on mac or cmd.exe. **VERY IMPORTANT: On Windows, make sure you run the command prompt as an administrator! To do this, right click the shortcut to the command prompt, click `run as administrator`, then say yes to whatever pops up.** ## Installing from pip -First install the latest version of PyTorch for your system. [Use this link for official instructions.](https://pytorch.org/) -It should be as simple as `conda install pytorch torchvision cudatoolkit=10.2 -c pytorch`. +First install the latest version of PyTorch for your system. [Use this link for official instructions.](https://pytorch.org/) +It should be as simple as `conda install pytorch torchvision cudatoolkit=10.2 -c pytorch`. -Note: if you have an RTX3000 series graphics card, such as a 3060 or 3090, please use `cudatoolkit=11.1` or higher. +Note: if you have an RTX3000 series graphics card, such as a 3060 or 3090, please use `cudatoolkit=11.1` or higher. -After installing PyTorch, simply use `pip install deepethogram`. +After installing PyTorch, simply use `pip install deepethogram`. ## Install FFMPEG We use FFMPEG for reading and writing `.mp4` files (with libx264 encoding). Please use [this link](https://www.ffmpeg.org/) to install on your system. - + ## Startup * `conda activate deg`. This activates the environment. * type `python -m deepethogram`, in the command line to open the GUI. @@ -46,20 +46,19 @@ to install on your system. Please see [the beta docs for instructions](beta.md) ## Common installation problems -* You might have dependency issues with other packages you've installed. Please make a new anaconda or miniconda -environment with `conda create --name deg python=3.8` before installation. -* `module not found: PySide2`. Some versions of PySide2 install poorly from pip. use `pip uninstall pyside2`, then +* You might have dependency issues with other packages you've installed. Please make a new anaconda or miniconda +environment with `conda create --name deg python=3.8` before installation. +* `module not found: PySide2`. Some versions of PySide2 install poorly from pip. use `pip uninstall pyside2`, then `conda install -c conda-forge pyside2` -* When opening the GUI, you might get `Segmentation fault (core dumped)`. In this case; please `pip uninstall pyside2`, +* When opening the GUI, you might get `Segmentation fault (core dumped)`. In this case; please `pip uninstall pyside2`, `conda uninstall pyside2`. `pip install pyside2` * `ImportError: C:\Users\jbohn\.conda\envs\deg2\lib\site-packages\shiboken2\libshiboken does not exist` - * something went wrong with your PySide2 installation, likely on Windows. + * something went wrong with your PySide2 installation, likely on Windows. * Make sure you have opened your command prompt as administrator - * If it tells you to install a new version of Visual Studio C++, please do that. - * Now you should be set up: let's reinstall PySide2 and libshiboken. + * If it tells you to install a new version of Visual Studio C++, please do that. + * Now you should be set up: let's reinstall PySide2 and libshiboken. * `pip install --force-reinstall pyside2` * `_init_pyside_extension is not defined` * This is an issue where Shiboken and PySide2 are not playing nicely together. Please `pip uninstall pyside2` and `conda remove pyside2`. Don't manually install these packages; instead, let DeepEthogram install it for you via pip. Therefore, `pip uninstall deepethogram` and `pip install deepethogram`. * `qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in ".../python3.8/site-packages/cv2/qt/plugins" even though it was found. This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.` * This is an issue with a recent version of `opencv-python` not working well with Qt. Please do `pip install --force-reinstall opencv-python-headless==4.1.2.30` - diff --git a/docs/performance.md b/docs/performance.md index bffa65d..3f37fe9 100644 --- a/docs/performance.md +++ b/docs/performance.md @@ -3,101 +3,101 @@ This document will describe how to minimize model training and inference time. It does not have to do with model accuracy; for that, please see [the model performance docs](model_performance.md) ## Hardware requirements -For GUI usage, we expect that the users will be working on a local workstation with a good NVIDIA graphics card. For training via a cluster, you can use the command line interface. -As of DeepEthogram version 0.1, we use the GPU more heavily for all tasks. If you have a limited budget, spending the -majority of it on a GPU is the highest priority. +For GUI usage, we expect that the users will be working on a local workstation with a good NVIDIA graphics card. For training via a cluster, you can use the command line interface. +As of DeepEthogram version 0.1, we use the GPU more heavily for all tasks. If you have a limited budget, spending the +majority of it on a GPU is the highest priority. * CPU: 4 cores or more for parallel data loading * Hard Drive: SSD at minimum, NVMe drive is better. -* GPU: DeepEthogram speed is directly related to GPU performance. An NVIDIA GPU is absolutely required, as PyTorch uses -CUDA, while AMD does not. +* GPU: DeepEthogram speed is directly related to GPU performance. An NVIDIA GPU is absolutely required, as PyTorch uses +CUDA, while AMD does not. The more VRAM you have, the more data you can fit in one batch, which generally increases performance. a I'd recommend 6GB VRAM at absolute minimum. 8GB is better, with 10+ GB preferred. -Recommended GPUs: `RTX 3090`, `RTX 3080`, `Titan RTX`, `2080 Ti`, `2080 super`, `2080`, `1080 Ti`, `2070 super`, `2070` -Some older ones might also be fine, like a `1080` or even `1070 Ti`/ `1070`. +Recommended GPUs: `RTX 3090`, `RTX 3080`, `Titan RTX`, `2080 Ti`, `2080 super`, `2080`, `1080 Ti`, `2070 super`, `2070` +Some older ones might also be fine, like a `1080` or even `1070 Ti`/ `1070`. # Bottlenecks -In diagnosing performance issues, it is helpful to know the basics of how data is loaded and run through the network. +In diagnosing performance issues, it is helpful to know the basics of how data is loaded and run through the network. ## Training pipeline 1. Grab random frames from random videos. This randomness is extremely important; through gradient -descent, we will be training our model parameters sequentially. Without randomizing the order of video snippets, -our model would be biased towards the first frames of the first videos. The input to hidden-two-stream models are clips of 11 images. So, we need to randomly select a video from our training set, open it, and seek to a random starting point, and read 11 images into CPU memory (RAM). This happens in parallel, with each CPU core reading one clip at a time. The number of workers is controlled by `cfg.compute.num_workers`. +descent, we will be training our model parameters sequentially. Without randomizing the order of video snippets, +our model would be biased towards the first frames of the first videos. The input to hidden-two-stream models are clips of 11 images. So, we need to randomly select a video from our training set, open it, and seek to a random starting point, and read 11 images into CPU memory (RAM). This happens in parallel, with each CPU core reading one clip at a time. The number of workers is controlled by `cfg.compute.num_workers`. 2. Not all our videos are necessarily the same size. Therefore, crop or resize our clip into a consistent shape. This is done on the CPU. -3. Here is ordinarily where image augmentation (randomly changing brightness, contrast, rotating, flipping, etc.) would -occur. However, I use Kornia to do this on the GPU for speed. -4. Stack our clips, currently of shape (3, 11, Height, Width), into a single batch of shape (N, 3, 11, Height, Width). -N is controlled by `cfg.compute.batch_size`. -5. Move our batch from CPU to GPU memory. All subsequent operations are now done on the GPU. -6. Perform image augmentations with Kornia. -7. Perform the forward pass through our neural network. -8. Compute the loss +3. Here is ordinarily where image augmentation (randomly changing brightness, contrast, rotating, flipping, etc.) would +occur. However, I use Kornia to do this on the GPU for speed. +4. Stack our clips, currently of shape (3, 11, Height, Width), into a single batch of shape (N, 3, 11, Height, Width). +N is controlled by `cfg.compute.batch_size`. +5. Move our batch from CPU to GPU memory. All subsequent operations are now done on the GPU. +6. Perform image augmentations with Kornia. +7. Perform the forward pass through our neural network. +8. Compute the loss 9. Compute the gradients of the loss w.r.t. our parameters -10. Optimize the parameters of our network using the gradients (we use ADAM). +10. Optimize the parameters of our network using the gradients (we use ADAM). -Now that we know how the pipeline works at a high level, we can start to see where our bottlenecks might occur. +Now that we know how the pipeline works at a high level, we can start to see where our bottlenecks might occur. * `1.` Reading frames from disk. If we use a batch size of 32, we are loading 11 frames from 32 videos, or 352 images -for one batch. If we have a fast GPU, at ~256x256 resolution, DEG_f feature extractors can train at ~3.5 batches per +for one batch. If we have a fast GPU, at ~256x256 resolution, DEG_f feature extractors can train at ~3.5 batches per second. This is ~1200+ images that need to be read per second, from `32*3.5=112` random video locations. This means -two things: - * We need to make sure we have a solid-state hard drive, or an NVMe hard drive. NOT a mechanical hard drive. - * We need to optimize our video format to make random reads faster. Normal video encodings, like libx264 or MJPG, - typically have very fast sequential reads, and very slow random reads. If your videos have endings like .avi or .mp4, - it is likely that it will be extremely slow to randomly read frames from them, and this will slow down your entire - training. - * SOLUTION: the fastest way to randomly read videos is to store them as folders full of images. If we use .jpg - compression, your videos will be re-compressed, and have new artifacts. If we use .PNG compression, the videos will - be lossless-ly encoded, which means there will be no new artifacts. However, PNG images take up far more space on disk +two things: + * We need to make sure we have a solid-state hard drive, or an NVMe hard drive. NOT a mechanical hard drive. + * We need to optimize our video format to make random reads faster. Normal video encodings, like libx264 or MJPG, + typically have very fast sequential reads, and very slow random reads. If your videos have endings like .avi or .mp4, + it is likely that it will be extremely slow to randomly read frames from them, and this will slow down your entire + training. + * SOLUTION: the fastest way to randomly read videos is to store them as folders full of images. If we use .jpg + compression, your videos will be re-compressed, and have new artifacts. If we use .PNG compression, the videos will + be lossless-ly encoded, which means there will be no new artifacts. However, PNG images take up far more space on disk than something like a .mp4 file, or .jpg images. However, for datasets numbering in the 10s of videos, I think this is - a worthwhile tradeoff; we are trading more space on the disk for less time in training. - * Solution 1: Use the function `projects.convert_all_videos`. Using `movie_format='directory'` means that we will - convert our video into a big directory full of .png files. - Example: + a worthwhile tradeoff; we are trading more space on the disk for less time in training. + * Solution 1: Use the function `projects.convert_all_videos`. Using `movie_format='directory'` means that we will + convert our video into a big directory full of .png files. + Example: ```python from deepethogram.projects import convert_all_videos convert_all_videos('PATH/TO/MY/project_config.yaml', movie_format='directory') ``` - * Solution 2: Use the function `projects.convert_all_videos`. Using `movie_format='hdf5'` means that we will - convert our video into PNGs; however, instead of a big directory, we will save the PNG bytestrings inside an - HDF5 file. This means it will be faster to move around (copying or cutting / pasting directories full of images + * Solution 2: Use the function `projects.convert_all_videos`. Using `movie_format='hdf5'` means that we will + convert our video into PNGs; however, instead of a big directory, we will save the PNG bytestrings inside an + HDF5 file. This means it will be faster to move around (copying or cutting / pasting directories full of images takes forever). The only downside of using `hdf5` as our movie format is that it is a less common filetype; you cant - just open it up in your file browser. However, this is the format I use for all DeepEthogram training. - * Solution 3: Use the function `projects.convert_all_videos` using `movie_format='directory'` or + just open it up in your file browser. However, this is the format I use for all DeepEthogram training. + * Solution 3: Use the function `projects.convert_all_videos` using `movie_format='directory'` or `movie_format='hdf5'`, along with `codec='.jpg'`. This will re-compress your images as .jpgs, saving filespace at - the expense of image quality. + the expense of image quality. Example: ```python from deepethogram.projects import convert_all_videos convert_all_videos('PATH/TO/MY/project_config.yaml', movie_format='directory', codec='.jpg') ``` - * Solution 4: Use `file_io.convert_video` along with your custom code to resize your images while you save them as - .PNGs. The input to the network is usually not larger than ~256x256 in resolution; if we resize them when - converting, we could potentially save both time and space. If you want me to code this for you, please raise an - issue on GitHub. -* `3`: Image augmentations. If our network is capable of running at 1200 images per second, randomly changing the + * Solution 4: Use `file_io.convert_video` along with your custom code to resize your images while you save them as + .PNGs. The input to the network is usually not larger than ~256x256 in resolution; if we resize them when + converting, we could potentially save both time and space. If you want me to code this for you, please raise an + issue on GitHub. +* `3`: Image augmentations. If our network is capable of running at 1200 images per second, randomly changing the brightness, color, contrast, rotating, etc for 1200 images on the CPU could bottleneck our training. For this reason, -with DeepEthogram v0.1 and above, I've converted the entire augmentation pipeline to Kornia, which implemented +with DeepEthogram v0.1 and above, I've converted the entire augmentation pipeline to Kornia, which implemented augmentations on the GPU. Thanks to the people at Kornia for implementing Video Transforms as of version 0.5. -In my experience, either reading from disk or performing augmentations are the most likely places to slow down training. +In my experience, either reading from disk or performing augmentations are the most likely places to slow down training. ## Inference -As of DeepEthogram 0.1, I've implemented a much faster parallel processing pipeline for video inference. We use only -sequential reads from disk (see training section), while also loading images in parallel and running our network in -batches. The code for this can be found in `deepethogram.data.datasets.VideoIterable`. +As of DeepEthogram 0.1, I've implemented a much faster parallel processing pipeline for video inference. We use only +sequential reads from disk (see training section), while also loading images in parallel and running our network in +batches. The code for this can be found in `deepethogram.data.datasets.VideoIterable`. # Model type -In the DeepEthogram paper, we describe 3 models; `DEG_f`,` DEG_m`, and `DEG_s`. These use ResNet18, ResNet50, and -3D-ResNet34 as feature extractors, respectively. We recommend only using `DEG_m` or `DEG_f` by default; if you have -access to high-quality GPUs (e.g. RTX3090s, Titan RTX, etc), or need especially accurate results for your project, you -can try `DEG_s`. +In the DeepEthogram paper, we describe 3 models; `DEG_f`,` DEG_m`, and `DEG_s`. These use ResNet18, ResNet50, and +3D-ResNet34 as feature extractors, respectively. We recommend only using `DEG_m` or `DEG_f` by default; if you have +access to high-quality GPUs (e.g. RTX3090s, Titan RTX, etc), or need especially accurate results for your project, you +can try `DEG_s`. # Image size -For both training and inference, speed will be extremely proportional to image resolution. You should never use raw -acquired video, such as HD (1920 x 1080). The size of the input images are determined by this section of your +For both training and inference, speed will be extremely proportional to image resolution. You should never use raw +acquired video, such as HD (1920 x 1080). The size of the input images are determined by this section of your `project_config.yaml`: ```yaml augs: @@ -110,10 +110,10 @@ Note that we default to resizing to 224 x 224. In the paper, we use the followin * 256 x 256 * 352 x 224 # Homecage videos -For the flow generators that I wrote, we can use images of any resolution by default. For ResNets, however, our image -sizes must be multiples of 32: +For the flow generators that I wrote, we can use images of any resolution by default. For ResNets, however, our image +sizes must be multiples of 32: ```python -[32*i for i in range(1, 15)] +[32*i for i in range(1, 15)] # [32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448] ``` -Note that I chose `352 x 224` for the Homecage dataset for this reason. \ No newline at end of file +Note that I chose `352 x 224` for the Homecage dataset for this reason. diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 002e007..21bd647 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -1,9 +1,9 @@ # Troubleshooting -Please use the `issues` button on GitHub for any bugs you think you've encountered. During GUI usage, model training, +Please use the `issues` button on GitHub for any bugs you think you've encountered. During GUI usage, model training, and model inference, hydra creates a `.log` file (e.g. `train.log`, `main.log`) with various log messages. Please -copy this into your GitHub page. When using the command line, including starting the GUI, use the flag `hydra.verbose=true`. -This will add debugging information to your logs (and also print them to the command line). +copy this into your GitHub page. When using the command line, including starting the GUI, use the flag `hydra.verbose=true`. +This will add debugging information to your logs (and also print them to the command line). # FAQ @@ -12,4 +12,4 @@ The most important factors that determine model performance are as follows: 1. number of data points 2. frequency of behavior (better performance for more common ones) -Please have at least a few hundred frames for each behavior before further inspection is needed. +Please have at least a few hundred frames for each behavior before further inspection is needed. diff --git a/docs/using_CLI.md b/docs/using_CLI.md index 8e9fb5c..3e836a4 100644 --- a/docs/using_CLI.md +++ b/docs/using_CLI.md @@ -1,10 +1,10 @@ # Using the command line interface A common way to implement a simple command line interface is to use python's builtin [argparse module](https://docs.python.org/3/library/argparse.html). -However, for this project, we have multiple model types, which share some hyperparameters (such as learning rate) while also -having unique hyperparameters (such as the loss function used for the optic flow generator). Furthermore, I've put a lot of +However, for this project, we have multiple model types, which share some hyperparameters (such as learning rate) while also +having unique hyperparameters (such as the loss function used for the optic flow generator). Furthermore, I've put a lot of thought into default hyperparameters, so I want to be able to include defaults. Finally, each user must override specific hyperparameters -for their own project, such as the names of the different behaviors. Therefore, for the CLI, we want to be able to +for their own project, such as the names of the different behaviors. Therefore, for the CLI, we want to be able to * have nested arguments, such as train.learning_rate, train.scheduler, etc * be able to use configuration files to load many parameters at once * be able to override defaults with our own configuration files @@ -17,11 +17,11 @@ For all DeepEthogram projects, we [expect a consistent file structure](file_stru `project.config_file=path/to/config/file.yaml` or `project.path=path/to/deepethogram_project` ## Creating a project in code -If you don't want to use the GUI, you still need to set up your project with the [consistent file structure](file_structure.md). +If you don't want to use the GUI, you still need to set up your project with the [consistent file structure](file_structure.md). ### Make a project directory -You will need to open a python interpreter, or launch a Jupyter notebook. +You will need to open a python interpreter, or launch a Jupyter notebook. First, let's create a project directory with a properly formatted `project_config.yaml` ```python @@ -34,12 +34,12 @@ data_path = '/mnt/DATA' # pick a project name project_name = 'open_field' -# make a list of behaviors. try to choose your behaviors carefully! -# you'll have to re-train all your models if you add or remove behaviors. -behaviors = ['background', - 'groom', - 'locomote', - 'etc1', +# make a list of behaviors. try to choose your behaviors carefully! +# you'll have to re-train all your models if you add or remove behaviors. +behaviors = ['background', + 'groom', + 'locomote', + 'etc1', 'etc2'] # this will create a folder called /mnt/DATA/open_field_deepethogram @@ -50,23 +50,23 @@ project_config = projects.initialize_project(data_path, project_name, behaviors) ### add videos and labels -This presumes you have a list of movies ready for training, and separately have labeled frames. +This presumes you have a list of movies ready for training, and separately have labeled frames. -The labels should be `.csv` files with this format: -* there should be one row per frame. The label CSV should have the same number of rows as the video has frames. +The labels should be `.csv` files with this format: +* there should be one row per frame. The label CSV should have the same number of rows as the video has frames. * there should be one column for each behavior. the name of the column should be the name of the behavior. The order -should be the same as you specified in the `project_config` above. -* the first column should be called "background", and it is the logical not of any of the other columns being one. +should be the same as you specified in the `project_config` above. +* the first column should be called "background", and it is the logical not of any of the other columns being one. * NOTE: if you don't have this, `projects.add_label_to_project` will do this for you! -* there should be a 1 if the labeled behavior is present on this frame, a zero otherwise. +* there should be a 1 if the labeled behavior is present on this frame, a zero otherwise. ![label format screenshot](images/label_format.png) -Here's an example in code: +Here's an example in code: ```python # adding videos -list_of_movies = ['/path/to/movie1.avi', +list_of_movies = ['/path/to/movie1.avi', '/path/to/movie2.avi'] mode = 'copy' # or 'symlink' or 'move' @@ -77,11 +77,11 @@ for movie_path in list_of_movies: # now, we have our new movie files properly in our deepethogram project -new_list_of_movies = ['/mnt/DATA/open_field_deepethogram/DATA/movie1.avi', +new_list_of_movies = ['/mnt/DATA/open_field_deepethogram/DATA/movie1.avi', '/mnt/DATA/open_field_deepethogram/DATA/movie2.avi'] # we also have a list of label files, created by some other means -list_of_labels = ['/mnt/DATA/my_other_project/movie1/labels.csv', +list_of_labels = ['/mnt/DATA/my_other_project/movie1/labels.csv', '/mnt/DATA/my_other_project/movie2/labels.csv'] for movie_path, label_path in zip(new_list_of_movies, list_of_labels): @@ -92,18 +92,18 @@ for movie_path, label_path in zip(new_list_of_movies, list_of_labels): For detailed instructions, please go to [the project README's pretrained models section](../README.md) ## Training examples -To train the flow generator with the larger MotionNet architecture and a batch size of 16: +To train the flow generator with the larger MotionNet architecture and a batch size of 16: `deepethogram.flow_generator.train project.config_file=path/to/config/file.yaml flow_generator.arch=MotionNet compute.batch_size=16` -To train the feature extractor with the ResNet18 base, without the curriculum training, with an initial learning rate of 1e-5: +To train the feature extractor with the ResNet18 base, without the curriculum training, with an initial learning rate of 1e-5: `deepethogram.feature_extractor.train project.config_file=path/to/config/file.yaml feature_extractor.arch=resnet18 train.lr=1e-5 feature_extractor.curriculum=false notes=no_curriculum` -To train the flow generator with specific weights loaded from disk, with a specific train/test split, with the DEG_s preset (3D MotionNet): +To train the flow generator with specific weights loaded from disk, with a specific train/test split, with the DEG_s preset (3D MotionNet): `python -m deepethogram.flow_generator.train project.config_file=path/to/config/file.yaml reload.weights=path/to/flow/weights.pt split.file=path/to/split.yaml preset=deg_s` To train the feature extractor on the secondary GPU with the latest optic flow weights, but a specific feature extractor weights: `python -m deepethogram.feature_extractor.train project.config_file=path/to/config/file.yaml compute.gpu_id=1 flow_generator.weights=latest feature_extractor.weights=path/to/kinetics_weights.pt` # Questions? -For any questions on how to use the command line interface for your training, please raise an issue on GitHub. \ No newline at end of file +For any questions on how to use the command line interface for your training, please raise an issue on GitHub. diff --git a/docs/using_code.md b/docs/using_code.md index 3a5fd8e..1fee45a 100644 --- a/docs/using_code.md +++ b/docs/using_code.md @@ -1,3 +1,3 @@ # Expected filepaths -TODO \ No newline at end of file +TODO diff --git a/docs/using_config_files.md b/docs/using_config_files.md index 40f3793..0280fe9 100644 --- a/docs/using_config_files.md +++ b/docs/using_config_files.md @@ -2,52 +2,52 @@ DeepEthogram uses configuration files (.yaml) to save information and load hyperparameters. [For reasoning, see the CLI docs](using_CLI.md). In each project directory is a file called `project_config.yaml`. There, you can edit model -architectures, change the batch size, specify the learning rate, etc. Both the GUI and the command line interface -will overwrite the defaults with whatever you have in this configuration file. +architectures, change the batch size, specify the learning rate, etc. Both the GUI and the command line interface +will overwrite the defaults with whatever you have in this configuration file. Hyperparameters can be specified in multiple places. Let's say you want to experiment with adding more regularization -to your model. The weight decay is specified in multiple places: +to your model. The weight decay is specified in multiple places: * the defaults in `deepethogram\conf\model\feature_extractor.yaml` * maybe in your project configuration file * the command line with `feature_extractor.weight_decay=0.01` -Per [the hydra docs](hydra.cc), the loading order is as follows, with the last one actually being used: +Per [the hydra docs](hydra.cc), the loading order is as follows, with the last one actually being used: `default -> project config -> command line`. This means even if you normally use `weight_decay=0.001` in your project -configuration, you can still run experiments at the command line. +configuration, you can still run experiments at the command line. ## How to edit your project configuration file -Navigate to your project dictionary with your `project_config.yaml` file. Importantly: **not every hyperparameter +Navigate to your project dictionary with your `project_config.yaml` file. Importantly: **not every hyperparameter will be in the default project configuration! This means if you want to edit a less-used hyperparameter, you'll have to add lines to the configuration, not just edit them.** -The left part of the below screenshot shows the default project configuration in `deepethogram/conf/project/project_config.yaml`. +The left part of the below screenshot shows the default project configuration in `deepethogram/conf/project/project_config.yaml`. The right shows my configuration for the `mouse_reach` example [used in the GUI docs](using_gui.md). ![screenshot of config dictionary](images/project_config.png) -Note the following: +Note the following: * `augs/normalization` has been overwritten for the statistics of my dataset * I've edited `augs/resize` to resize my data to be 128 tall and 256 wide. This speeds up training, and the images are -not square due to the multiple views that I've concatenated side by side ([see the GUI docs](using_gui.md)). +not square due to the multiple views that I've concatenated side by side ([see the GUI docs](using_gui.md)). * the `project` dictionary has been totally changed -* I've added lines to the train dictionary: +* I've added lines to the train dictionary: * `patience: 10`: this means the learning rate will only be reduced if learning stalls for 10 epochs (see `deepethogram/conf/train.yaml` for explanation) * `reduction_factor: 0.3162277`: this means the learning rate will be reduced by this value (1 / sqrt(10), which means the learning rate will go down by a factor of 0.1 after two steps) - + This is how to edit a configuration file. You *add or edit fields in your project config in the nested structure shown in `deepethogram/conf`.* -For example, the `train` dictionary in the config file takes values shown in `deepethogram/conf/train.yaml`. -To find out what hyperparameters there are and what values they can take, -read through the configuration files in `deepethogram/conf`. The most commonly edited ones are already in the -default `project_config.yaml`, such as batch size and image augmentation. +For example, the `train` dictionary in the config file takes values shown in `deepethogram/conf/train.yaml`. +To find out what hyperparameters there are and what values they can take, +read through the configuration files in `deepethogram/conf`. The most commonly edited ones are already in the +default `project_config.yaml`, such as batch size and image augmentation. ## Creating configuration files in code If you want to use functions that require a configuration file, but want to use the codebase instead of the command -line interface, you can create a nested dictionary with your configuration, then do as follows: +line interface, you can create a nested dictionary with your configuration, then do as follows: ```python from omegaconf import OmegaConf -nested_dictionary = {'project': {'class_names': ['background', 'etc', 'etc']}, +nested_dictionary = {'project': {'class_names': ['background', 'etc', 'etc']}, 'etc': 'etc'} cfg = OmegaConf.create(nested_dictionary) print(cfg.pretty()) @@ -57,4 +57,4 @@ print(cfg.pretty()) # - background # - etc # - etc -``` \ No newline at end of file +``` diff --git a/docs/using_gui.md b/docs/using_gui.md index 313e42b..10f5781 100644 --- a/docs/using_gui.md +++ b/docs/using_gui.md @@ -3,7 +3,7 @@ To open: After [installation](installation.md), open a terminal, activate your conda environment, and type `deepethogram`. To start your project, see the [getting started](getting_started.md) guide. The images in this guide are from a project -with multiple trained models, for illustration. Note: all examples are used from the [mouse reach dataset, available here](http://research.janelia.org/bransonlab/MouseReachData/). +with multiple trained models, for illustration. Note: all examples are used from the [mouse reach dataset, available here](http://research.janelia.org/bransonlab/MouseReachData/). I've added my own labels, using the start frame labels provided. When using the GUI, be sure to keep open and look at the terminal window. Useful information will be displayed there. @@ -15,78 +15,78 @@ When using the GUI, be sure to keep open and look at the terminal window. Useful 2. Menu - File 1. New project: create an existing project. See [this doc file](getting_started.md) - 2. Open project: opens an existing project. Automatically opens the most recently added video, and imports any - existing labels or predictions. Changes the working directory to the project directory. - 3. Save project (ctrl+s): Saves the current labels to disk. Note: if you haven't *finalized labels* (**7i**), - currently unlabeled frames will be labeled as `-1`. DeepEthogram will ignore these videos when training! + 2. Open project: opens an existing project. Automatically opens the most recently added video, and imports any + existing labels or predictions. Changes the working directory to the project directory. + 3. Save project (ctrl+s): Saves the current labels to disk. Note: if you haven't *finalized labels* (**7i**), + currently unlabeled frames will be labeled as `-1`. DeepEthogram will ignore these videos when training! - Behaviors - 1. Add: adds a behavior to the current project. Use sparingly! All previous videos will have to be re-labeled. - Existing models will have to be re-trained. Avoid if at all possible. + 1. Add: adds a behavior to the current project. Use sparingly! All previous videos will have to be re-labeled. + Existing models will have to be re-trained. Avoid if at all possible. 2. Remove: removes a behavior from the current project. Use sparingly! All previous videos will have their labels - erased, never to be recovered. Models will have to be re-trained. Avoid if at all possible. + erased, never to be recovered. Models will have to be re-trained. Avoid if at all possible. - Video 1. Add or open: Opens a video file. If it's a part of the current DeepEthogram project, also imports any existing labels and predictions. Otherwise, it will copy the video to your project directory. It will be displayed in the video window (**11**). See [the add video documentation](getting_started.md#add-videos) - 2. Add multiple: with this you can select multiple video files to add to your project at once. + 2. Add multiple: with this you can select multiple video files to add to your project at once. - Import - 1. Labels: imports labels from disk. This requires label files to have [the same structure as deepethogram labels](file_structure.md). - Not generally necessary for use. + 1. Labels: imports labels from disk. This requires label files to have [the same structure as deepethogram labels](file_structure.md). + Not generally necessary for use. - Batch 1. Feature extractor inference + sequence inference: works the same as pushing feature extractor / infer (**6ii**) - followed by sequence / infer (**7ii**). + followed by sequence / infer (**7ii**). 2. Overnight: if you collected experimental videos during the day, this is a good button to push before leaving - lab for the night. Does the following tasks in sequence: trains your existing flow generator with the new data (**4i**), - extracts features (**6ii**), and then generates sequence predictions (**7ii**). In the morning when you arrive back to lab, - your new videos should have predictions ready for you to edit. + lab for the night. Does the following tasks in sequence: trains your existing flow generator with the new data (**4i**), + extracts features (**6ii**), and then generates sequence predictions (**7ii**). In the morning when you arrive back to lab, + your new videos should have predictions ready for you to edit. 3. Video information box - Name: the filename of the video - N frames: number of total frames in the video - N labeled: number of frames with user-edited labels - N unlabeled: number of frames without user labels - - FPS: frames per second of the current video. If the filetype does not have a FPS value (for example, a folder of images), + - FPS: frames per second of the current video. If the filetype does not have a FPS value (for example, a folder of images), it will say N/A. The current video is an HDF5 file, which is just a wrapper around a list of images. Therefore, it has no FPS. - - Duration: video duration in seconds. + - Duration: video duration in seconds. 4. Flow generator box 1. Train button: Train the flow generator. It will use hyperparameters from your project configuration file (or defaults specified in `deepethogram/conf`). See [using configuration files for details](using_config_files.md). This includes - the model architecture (TinyMotionNet, MotionNet, TinyMotionNet3D). The weights to pre-load will be specified by the + the model architecture (TinyMotionNet, MotionNet, TinyMotionNet3D). The weights to pre-load will be specified by the model selector (**4ii**) 2. Model selector: choose the weights to pre-load. *Note: if model architecture specified in your project configuration file does not match that of the selected weight file, model will train without properly loading weights!* 5. Feature extractor box 1. Train button: Train the feature extractor. It will use hyperparameters from your project configuration file (or defaults - specified in `deepethogram/conf`). See [using configuration files for details](using_config_files.md). This will - take a long time, perhaps overnight. To speed training potentially at the cost of model performance, set - `feature_extractor/curriculum = false` in your project configuration file. + specified in `deepethogram/conf`). See [using configuration files for details](using_config_files.md). This will + take a long time, perhaps overnight. To speed training potentially at the cost of model performance, set + `feature_extractor/curriculum = false` in your project configuration file. 2. inference button: Run inference using the feature extractor models. Clicking this button will open a list of videos - in your project with check boxes. Videos without output files ([see file structure](file_structure.md)) will be + in your project with check boxes. Videos without output files ([see file structure](file_structure.md)) will be pre-selected. For each video you select, the feature extractors will run frame-by-frame and extract spatial - features, flow features, and predictions and save them to disk. These will be loaded as inputs to the sequence + features, flow features, and predictions and save them to disk. These will be loaded as inputs to the sequence models (below). This may take a long time, as inference takes around 30-60FPS depending on video resolution and - model complexity (see paper). - 3. model selector: A list of models in your [models directory](file_structure.md). These are the weights that will + model complexity (see paper). + 3. model selector: A list of models in your [models directory](file_structure.md). These are the weights that will be loaded and fine-tuned when you train (**5i**) and used to run inference (**5ii**) 6. Sequence box 1. Train button: Train the sequence model. It will use hyperparameters from your project configuration file (or defaults specified in `deepethogram/conf`). See [using configuration files for details](using_config_files.md). - 2. Inference button: Run inference with your sequence model. Clicking this button will load a list of videos in - your project that already have some features extracted. Output files that do not have any predictions from the + 2. Inference button: Run inference with your sequence model. Clicking this button will load a list of videos in + your project that already have some features extracted. Output files that do not have any predictions from the currently selected (**6iii**) sequence architecture will be automatically pre-selected. If you don't see a video here, you need to run inference with your feature extractor first (**5ii**). This runs extremely fast, and should - only take a few seconds for any number of videos. + only take a few seconds for any number of videos. 7. Label box 1. Finalize labels: when labeling a video, particularly with rare behaviors, the vast majority of frames will be "background" - (see paper). We want to be able to tell the difference between *this frame is background* and *I have not looked + (see paper). We want to be able to tell the difference between *this frame is background* and *I have not looked at or labeled this frame yet*, so that you can partially label a video and then return to it. By default, when saving - the project (**menu bar/file/save**), unlabeled frames are set to `-1` and the video is not used for training. + the project (**menu bar/file/save**), unlabeled frames are set to `-1` and the video is not used for training. When you've fully labeled a video, instead of affirmatively going through every frame and labeling them as *background*, we will use this button. When you press this button, all unlabeled frames will be set to background, and the video will be considered fully labeled. This video will be used by DeepEthogram for training (**5i, 6i**) - 2. Import predictions as labels: When you've trained all your models, added a new video, and run inference, you - will have a set of predictions (**10**). These could be useful to look at while manually labeling, or you can - move them to the labeling window (**9**) to be manually edited. Pushing this button will do that. + 2. Import predictions as labels: When you've trained all your models, added a new video, and run inference, you + will have a set of predictions (**10**). These could be useful to look at while manually labeling, or you can + move them to the labeling window (**9**) to be manually edited. Pushing this button will do that. 8. Predictions box 1. Predictions selector: For each video, you can have multiple predictions. For example, they could be the predictions from the feature extractor as well as the sequence model. The key will be the `sequence.latent_name` used when you @@ -96,60 +96,56 @@ When using the GUI, be sure to keep open and look at the terminal window. Useful 1. Label buttons: these buttons are the ordered list of behaviors for your project. Each button corresponds to a row in the label viewer (**9ii**). Pushing this button will toggle the behavior (see below). Background is a special class: it is mutually exclusive with the other behaviors. Pushing the keyboard number keys denoted in brackets - will also toggle the behavior. For more information, see [the labeling section](using_gui.md#labeling) below. - 2. Label viewer: This is where you can view your manual labels. The current video frame is denoted with the - current frame indicator (**13**). Unlabeled frames will be partially transparent, with the "background" class pre-selected. + will also toggle the behavior. For more information, see [the labeling section](using_gui.md#labeling) below. + 2. Label viewer: This is where you can view your manual labels. The current video frame is denoted with the + current frame indicator (**13**). Unlabeled frames will be partially transparent, with the "background" class pre-selected. Labeled frames will be opaque (see image). The current frame is the "hand" behavior (in red). No other behaviors - are present on this frame. A few frames behind, the animal was performing the "lift" behavior. For more information, + are present on this frame. A few frames behind, the animal was performing the "lift" behavior. For more information, see [the labeling section](using_gui.md#labeling) below. 10. Predictions viewer * This is where you can see the currently selected (**8i**) model prediction. The probabilities of each behavior - are shown transparently. The predictions (thresholded probability) are opaque. At the left of the image, + are shown transparently. The predictions (thresholded probability) are opaque. At the left of the image, note that the "lift" behavior is wrongly predicted to be "background". However, the transparent probabilities are still - dark red for both "lift" and "hand" at this point. This shows that the model suspects it might be "lift" or "hand", - but these behaviors are just below the threshold. At the right of the image, note that around the "grab" behavior, + dark red for both "lift" and "hand" at this point. This shows that the model suspects it might be "lift" or "hand", + but these behaviors are just below the threshold. At the right of the image, note that around the "grab" behavior, the probabilities are high +/- a few frames from the true behavior. The model "knows" this section of video transitions - from "grab" to "supinate" to "mouth", but is just off by a few frames. + from "grab" to "supinate" to "mouth", but is just off by a few frames. 11. Video window * Your video will be displayed here. It will be resized to take up as much of the screen as possible 12. Video navigation - 1. Scroll bar: use this to navigate quickly through the video. - 2. Frame text box: this shows the current frame (**13**). Edit this text box to jump to a specific video frame. + 1. Scroll bar: use this to navigate quickly through the video. + 2. Frame text box: this shows the current frame (**13**). Edit this text box to jump to a specific video frame. 13. Current frame indicator * This denotes which frame in the label viewer and predictions viewer (**9ii, 10**) correspond to the shown video frame. - Toggling the label button will label this frame (**9i**). - + Toggling the label button will label this frame (**9i**). + ## Labeling -The label is a matrix. It has K rows (behaviors) and T columns (timepoints). The goal is for this matrix to have 1s -when behavior *k* is present on frame *t*, and 0s everywhere else. The special behavior "background" will be 1 when -there are no user-defined behaviors present on that frame. There are multiple ways to label a frame: +The label is a matrix. It has K rows (behaviors) and T columns (timepoints). The goal is for this matrix to have 1s +when behavior *k* is present on frame *t*, and 0s everywhere else. The special behavior "background" will be 1 when +there are no user-defined behaviors present on that frame. There are multiple ways to label a frame: #### Toggling -When you toggle a behavior, you start editing that behavior on the current frame (**13**). When a behavior is toggled, -moving [forward in time](using_gui.md#video-navigation) will *add that behavior*. Moving backwards in time will -*erase that behavior*. You can toggle a behavior one of two ways: +When you toggle a behavior, you start editing that behavior on the current frame (**13**). When a behavior is toggled, +moving [forward in time](using_gui.md#video-navigation) will *add that behavior*. Moving backwards in time will +*erase that behavior*. You can toggle a behavior one of two ways: 1. Clicking the label buttons (**9i**) 2. Pushing the corresponding number key on your keyboard #### Clicking -You can also click directly on the label viewer (**9ii**). -* Clicking an unlabeled frame will add that behavior. -* Clicking a labeled element of the matrix will erase that label. +You can also click directly on the label viewer (**9ii**). +* Clicking an unlabeled frame will add that behavior. +* Clicking a labeled element of the matrix will erase that label. * Clicking, holding, and dragging to the right will add that behavior on all those frames -* Clicking, holding, and dragging to the left will remove that behavior on all those frames. +* Clicking, holding, and dragging to the left will remove that behavior on all those frames. ## Video navigation -There are multiple ways to change video frames using the GUI. +There are multiple ways to change video frames using the GUI. -1. Use the arrow keys on your keyboard (recommended). Pressing `Ctrl+arrow` will jump further, 30 frames by default. +1. Use the arrow keys on your keyboard (recommended). Pressing `Ctrl+arrow` will jump further, 30 frames by default. 2. Use the scroll bar and arrows (**12i**) 3. Edit the frame number in the frame text box (**12ii**) ## Zoom -As of 0.1.4, you can zoom in/out of videos. Double-click on the video to make it fit to your current window. Use the +As of 0.1.4, you can zoom in/out of videos. Double-click on the video to make it fit to your current window. Use the mouse scroll wheel to zoom in/out. - - - - \ No newline at end of file diff --git a/docs/using_tune.md b/docs/using_tune.md index 7260c38..74d96ae 100644 --- a/docs/using_tune.md +++ b/docs/using_tune.md @@ -17,4 +17,4 @@ for bugs like "could not terminate" "/usr/bin/redis-server 127.0.0.1:6379" "" "" "" "" "" "" ""` due to psutil.AccessDenied (pid=56271, name='redis-server') sudo /etc/init.d/redis-server stop if you have a GPU you can't use for training (e.g. I have a tiny, old GPU just for my monitors) exclude that -using command line arguments. e.g. CUDA_VISIBLE_DEVICES=0,1 ray start --head \ No newline at end of file +using command line arguments. e.g. CUDA_VISIBLE_DEVICES=0,1 ray start --head diff --git a/environment.yml b/environment.yml index ba10d1b..bb3dccc 100644 --- a/environment.yml +++ b/environment.yml @@ -9,4 +9,4 @@ dependencies: - python>3.7, <3.9 - pytorch::pytorch - pip: - - -r requirements.txt \ No newline at end of file + - -r requirements.txt diff --git a/license.txt b/license.txt index a5751f9..b9595f6 100644 --- a/license.txt +++ b/license.txt @@ -36,4 +36,4 @@ E-mail: otd@harvard.edu 12. NON-USE OF NAME. Nothing in this License and Terms of Use shall be construed as granting End Users or their Institutions any rights or licenses to use any trademarks, service marks or logos associated with the Software. You may not use the terms “Harvard” (or a substantially similar term) in any way that is inconsistent with the permitted uses described herein. You agree not to use any name or emblem of Harvard or any of its subdivisions for any purpose, or to falsely suggest any relationship between End User (or its Institution) and Harvard, or in any manner that would infringe or violate any of Harvard’s rights. - 13. End User represents and warrants that it has the legal authority to enter into this License and Terms of Use on behalf of itself and its Institution. \ No newline at end of file + 13. End User represents and warrants that it has the legal authority to enter into this License and Terms of Use on behalf of itself and its Institution. diff --git a/requirements.txt b/requirements.txt index a1a4131..bace43e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ scipy<1.8 tqdm vidio pytorch_lightning==1.6.5 -ruff>=0.1.0 \ No newline at end of file +ruff>=0.1.0 +pre-commit>=2.20.0,<3.0.0 diff --git a/setup_tests.py b/setup_tests.py index 781d3d6..0351233 100644 --- a/setup_tests.py +++ b/setup_tests.py @@ -31,9 +31,7 @@ def download_file(url, destination): # Print progress if total_size > 0: progress = int(50 * downloaded / total_size) - sys.stdout.write( - f"\r[{'=' * progress}{' ' * (50 - progress)}] {downloaded}/{total_size} bytes" - ) + sys.stdout.write(f"\r[{'=' * progress}{' ' * (50 - progress)}] {downloaded}/{total_size} bytes") sys.stdout.flush() print("\nDownload completed!") @@ -52,9 +50,7 @@ def setup_tests(): try: print("Downloading test data archive...") - gdown.download( - id="1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo", output=str(zip_path), quiet=False - ) + gdown.download(id="1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo", output=str(zip_path), quiet=False) print("Extracting archive...") with zipfile.ZipFile(zip_path, "r") as zip_ref: @@ -64,9 +60,7 @@ def setup_tests(): archive_path = data_dir / "testing_deepethogram_archive" required_items = ["DATA", "models", "project_config.yaml"] - missing_items = [ - item for item in required_items if not (archive_path / item).exists() - ] + missing_items = [item for item in required_items if not (archive_path / item).exists()] if missing_items: print(f"Warning: The following items are missing: {missing_items}") @@ -74,9 +68,7 @@ def setup_tests(): print("Setup completed successfully!") print("\nYou can now run the tests using: pytest tests/") - print( - "Note: The zz_commandline test module will take a few minutes to complete." - ) + print("Note: The zz_commandline test module will take a few minutes to complete.") # Clean up the zip file zip_path.unlink() diff --git a/tests/setup_data.py b/tests/setup_data.py index 19fb70a..2343b16 100644 --- a/tests/setup_data.py +++ b/tests/setup_data.py @@ -8,15 +8,15 @@ test_path = os.path.dirname(this_path) deg_path = os.path.dirname(test_path) -test_data_path = os.path.join(test_path, 'DATA') +test_data_path = os.path.join(test_path, "DATA") # the deepethogram test archive should only be read from, never written to -archive_path = os.path.join(test_data_path, 'testing_deepethogram_archive') -assert os.path.isdir(archive_path), '{} does not exist!'.format(archive_path) -project_path = os.path.join(test_data_path, 'testing_deepethogram') -data_path = os.path.join(project_path, 'DATA') +archive_path = os.path.join(test_data_path, "testing_deepethogram_archive") +assert os.path.isdir(archive_path), "{} does not exist!".format(archive_path) +project_path = os.path.join(test_data_path, "testing_deepethogram") +data_path = os.path.join(project_path, "DATA") -config_path = os.path.join(project_path, 'project_config.yaml') -config_path_archive = os.path.join(archive_path, 'project_config.yaml') +config_path = os.path.join(project_path, "project_config.yaml") +config_path_archive = os.path.join(archive_path, "project_config.yaml") # config_path = os.path.join(project_path, 'project_config.yaml') cfg_archive = projects.get_config_from_path(archive_path) @@ -39,10 +39,10 @@ def make_project_from_archive(): # projects.fix_config_paths(cfg) -def get_records(origin='project'): - if origin == 'project': +def get_records(origin="project"): + if origin == "project": return projects.get_records_from_datadir(data_path) - elif origin == 'archive': - return projects.get_records_from_datadir(os.path.join(archive_path, 'DATA')) + elif origin == "archive": + return projects.get_records_from_datadir(os.path.join(archive_path, "DATA")) else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/tests/test_data.py b/tests/test_data.py index 25067b4..bc8e0e5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,3 @@ - import numpy as np from deepethogram.data import utils as data_utils @@ -9,10 +8,9 @@ def test_loss_weight(): num_pos = np.array([1, 2]) num_neg = np.array([2, 1]) - pos_weight_transformed, softmax_weight_transformed = data_utils.make_loss_weight(class_counts, - num_pos, - num_neg, - weight_exp=1.0) + pos_weight_transformed, softmax_weight_transformed = data_utils.make_loss_weight( + class_counts, num_pos, num_neg, weight_exp=1.0 + ) assert np.allclose(pos_weight_transformed, np.array([2, 0.5])) assert np.allclose(softmax_weight_transformed, np.array([2 / 3, 1 / 3])) @@ -20,12 +18,11 @@ def test_loss_weight(): num_pos = np.array([0, 300]) num_neg = np.array([300, 0]) - pos_weight_transformed, softmax_weight_transformed = data_utils.make_loss_weight(class_counts, - num_pos, - num_neg, - weight_exp=1.0) + pos_weight_transformed, softmax_weight_transformed = data_utils.make_loss_weight( + class_counts, num_pos, num_neg, weight_exp=1.0 + ) print(pos_weight_transformed, softmax_weight_transformed) assert np.allclose(pos_weight_transformed, np.array([0, 1])) assert np.allclose(softmax_weight_transformed, np.array([0, 1])) # assert np.allclose(pos_weight_transformed, np.array([2, 0.5])) - # assert np.allclose(softmax_weight_transformed, np.array([2 / 3, 1 / 3])) \ No newline at end of file + # assert np.allclose(softmax_weight_transformed, np.array([2 / 3, 1 / 3])) diff --git a/tests/test_flow_generator.py b/tests/test_flow_generator.py index f4f4e12..b95701b 100644 --- a/tests/test_flow_generator.py +++ b/tests/test_flow_generator.py @@ -2,9 +2,13 @@ from deepethogram import projects, utils, viz from deepethogram.configuration import make_flow_generator_train_cfg -from deepethogram.flow_generator.train import (get_datasets_from_cfg, build_model_from_cfg, get_metrics, - OpticalFlowLightning) -from setup_data import (make_project_from_archive, project_path) +from deepethogram.flow_generator.train import ( + get_datasets_from_cfg, + build_model_from_cfg, + get_metrics, + OpticalFlowLightning, +) +from setup_data import make_project_from_archive, project_path def test_metrics(): @@ -12,16 +16,16 @@ def test_metrics(): cfg = make_flow_generator_train_cfg(project_path=project_path) cfg = projects.setup_run(cfg) - datasets, data_info = get_datasets_from_cfg(cfg, 'flow_generator', input_images=cfg.flow_generator.n_rgb) + datasets, data_info = get_datasets_from_cfg(cfg, "flow_generator", input_images=cfg.flow_generator.n_rgb) flow_generator = build_model_from_cfg(cfg) - utils.save_dict_to_yaml(data_info['split'], os.path.join(os.getcwd(), 'split.yaml')) - flow_weights = projects.get_weightfile_from_cfg(cfg, 'flow_generator') + utils.save_dict_to_yaml(data_info["split"], os.path.join(os.getcwd(), "split.yaml")) + flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") if flow_weights is not None: - print('reloading weights...') - flow_generator = utils.load_weights(flow_generator, flow_weights, device='cpu') + print("reloading weights...") + flow_generator = utils.load_weights(flow_generator, flow_weights, device="cpu") # stopper = get_stopper(cfg) metrics = get_metrics(cfg, os.getcwd(), utils.get_num_parameters(flow_generator)) lightning_module = OpticalFlowLightning(flow_generator, cfg, datasets, metrics, viz.visualize_logger_optical_flow) - assert lightning_module.scheduler_mode == 'min' - assert lightning_module.metrics.key_metric == 'SSIM' \ No newline at end of file + assert lightning_module.scheduler_mode == "min" + assert lightning_module.metrics.key_metric == "SSIM" diff --git a/tests/test_models.py b/tests/test_models.py index 25c1954..7ceb0ad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ def test_get_cnn(): - model_name = 'resnet18' + model_name = "resnet18" num_classes = 2 pos = np.array([0, 300]) @@ -15,4 +15,4 @@ def test_get_cnn(): bias = list(model.children())[-1].bias assert torch.allclose(bias, torch.Tensor([0, 1]).float()) - print() \ No newline at end of file + print() diff --git a/tests/test_projects.py b/tests/test_projects.py index c1cea2c..4da8c2c 100644 --- a/tests/test_projects.py +++ b/tests/test_projects.py @@ -16,57 +16,51 @@ def test_initialization(): clean_test_data() with pytest.raises(AssertionError): - project_dict = projects.initialize_project(test_data_path, 'testing', - ['scratch', 'itch']) + project_dict = projects.initialize_project(test_data_path, "testing", ["scratch", "itch"]) - project_dict = projects.initialize_project( - test_data_path, 'testing', ['background', 'scratch', 'itch']) + project_dict = projects.initialize_project(test_data_path, "testing", ["background", "scratch", "itch"]) # print(project_dict) # print(project_dict['project']) - assert os.path.isdir(project_dict['project']['path']) - assert project_dict['project']['path'] == project_path + assert os.path.isdir(project_dict["project"]["path"]) + assert project_dict["project"]["path"] == project_path - data_abs = os.path.join(project_dict['project']['path'], - project_dict['project']['data_path']) + data_abs = os.path.join(project_dict["project"]["path"], project_dict["project"]["data_path"]) assert os.path.isdir(data_abs) - model_abs = os.path.join(project_dict['project']['path'], - project_dict['project']['model_path']) + model_abs = os.path.join(project_dict["project"]["path"], project_dict["project"]["model_path"]) assert os.path.isdir(model_abs) # mouse01 tests image directories -@pytest.mark.parametrize('key', ['mouse00', 'mouse01']) +@pytest.mark.parametrize("key", ["mouse00", "mouse01"]) def test_add_video(key): make_project_from_archive() - project_dict = projects.load_config( - os.path.join(project_path, 'project_config.yaml')) + project_dict = projects.load_config(os.path.join(project_path, "project_config.yaml")) # project_dict = utils.load_yaml() - key_path = os.path.join(project_path, 'DATA', key) + key_path = os.path.join(project_path, "DATA", key) assert os.path.isdir(key_path) shutil.rmtree(key_path) - records = get_records('archive') + records = get_records("archive") # test image directory - videofile = records[key]['rgb'] + videofile = records[key]["rgb"] print(project_dict) # this also z-scores, which is pretty slow projects.add_video_to_project(project_dict, videofile) - assert os.path.isdir(os.path.join(project_path, 'DATA', key)) - assert os.path.exists( - os.path.join(project_path, 'DATA', key, os.path.basename(videofile))) + assert os.path.isdir(os.path.join(project_path, "DATA", key)) + assert os.path.exists(os.path.join(project_path, "DATA", key, os.path.basename(videofile))) -@pytest.mark.parametrize('key', ['mouse00', 'mouse01']) +@pytest.mark.parametrize("key", ["mouse00", "mouse01"]) def test_is_deg_file(key): make_project_from_archive() records = get_records() - rgb = records[key]['rgb'] + rgb = records[key]["rgb"] assert projects.is_deg_file(rgb) - record_yaml = os.path.join(os.path.dirname(rgb), 'record.yaml') + record_yaml = os.path.join(os.path.dirname(rgb), "record.yaml") assert os.path.isfile(record_yaml) os.remove(record_yaml) @@ -75,51 +69,51 @@ def test_is_deg_file(key): def test_add_behavior(): make_project_from_archive() - cfg_path = os.path.join(project_path, 'project_config.yaml') + cfg_path = os.path.join(project_path, "project_config.yaml") - projects.add_behavior_to_project(cfg_path, 'A') + projects.add_behavior_to_project(cfg_path, "A") records = get_records() mice = list(records.keys()) - labelfile = records[random.choice(mice)]['label'] + labelfile = records[random.choice(mice)]["label"] assert os.path.isfile(labelfile) df = pd.read_csv(labelfile, index_col=0) assert df.shape[1] == 6 assert np.all(df.iloc[:, -1].values == -1) - assert df.columns[5] == 'A' + assert df.columns[5] == "A" def test_remove_behavior(): make_project_from_archive() - cfg_path = os.path.join(project_path, 'project_config.yaml') + cfg_path = os.path.join(project_path, "project_config.yaml") # can't remove behaviors that don't exist with pytest.raises(AssertionError): - projects.remove_behavior_from_project(cfg_path, 'A') + projects.remove_behavior_from_project(cfg_path, "A") # can't remove background with pytest.raises(ValueError): - projects.remove_behavior_from_project(cfg_path, 'background') + projects.remove_behavior_from_project(cfg_path, "background") - projects.remove_behavior_from_project(cfg_path, 'face_groom') + projects.remove_behavior_from_project(cfg_path, "face_groom") records = get_records() mice = list(records.keys()) - labelfile = records[random.choice(mice)]['label'] + labelfile = records[random.choice(mice)]["label"] assert os.path.isfile(labelfile) df = pd.read_csv(labelfile, index_col=0) assert df.shape[1] == 4 - assert 'face_groom' not in df.columns + assert "face_groom" not in df.columns -@pytest.mark.filterwarnings('ignore::UserWarning') +@pytest.mark.filterwarnings("ignore::UserWarning") def test_add_external_label(): make_project_from_archive() - mousedir = os.path.join(project_path, 'DATA', 'mouse06') - assert os.path.isdir(mousedir), '{} does not exist!'.format(mousedir) - labelfile = os.path.join(mousedir, 'labels.csv') - videofile = os.path.join(mousedir, 'mouse06.h5') + mousedir = os.path.join(project_path, "DATA", "mouse06") + assert os.path.isdir(mousedir), "{} does not exist!".format(mousedir) + labelfile = os.path.join(mousedir, "labels.csv") + videofile = os.path.join(mousedir, "mouse06.h5") projects.add_label_to_project(labelfile, videofile) -if __name__ == '__main__': - test_add_external_label() \ No newline at end of file +if __name__ == "__main__": + test_add_external_label() diff --git a/tests/test_z_score.py b/tests/test_z_score.py index 806923a..79c51fd 100644 --- a/tests/test_z_score.py +++ b/tests/test_z_score.py @@ -9,13 +9,13 @@ def test_single_video(): records = projects.get_records_from_datadir(data_path) - videofile = records['mouse00']['rgb'] + videofile = records["mouse00"]["rgb"] stats = get_video_statistics(videofile, 10) print(stats) mean = np.array([0.010965, 0.02345, 0.0161]) std = np.array([0.02623, 0.04653, 0.0349]) - assert np.allclose(stats['mean'], mean, rtol=0, atol=1e-4) - assert np.allclose(stats['std'], std, rtol=0, atol=1e-4) - assert stats['N'] == 1875000 \ No newline at end of file + assert np.allclose(stats["mean"], mean, rtol=0, atol=1e-4) + assert np.allclose(stats["std"], std, rtol=0, atol=1e-4) + assert stats["N"] == 1875000 diff --git a/tests/test_zz_commandline.py b/tests/test_zz_commandline.py index 4270a60..055b294 100644 --- a/tests/test_zz_commandline.py +++ b/tests/test_zz_commandline.py @@ -3,7 +3,7 @@ from deepethogram import utils -from setup_data import (make_project_from_archive, change_to_deepethogram_directory, config_path, data_path) +from setup_data import make_project_from_archive, change_to_deepethogram_directory, config_path, data_path # from setup_data import get_testing_directory # testing_directory = get_testing_directory() @@ -20,21 +20,21 @@ def command_from_string(string): - command = string.split(' ') - if command[-1] == '': + command = string.split(" ") + if command[-1] == "": command = command[:-1] print(command) return command def add_default_arguments(string, train=True): - string += f'project.config_file={config_path} ' - string += f'compute.batch_size={BATCH_SIZE} ' + string += f"project.config_file={config_path} " + string += f"compute.batch_size={BATCH_SIZE} " if train: - string += f'train.steps_per_epoch.train={STEPS_PER_EPOCH} train.steps_per_epoch.val={STEPS_PER_EPOCH} ' - string += f'train.steps_per_epoch.test={STEPS_PER_EPOCH} ' - string += f'train.num_epochs={NUM_EPOCHS} ' - string += f'train.viz_examples={VIZ_EXAMPLES}' + string += f"train.steps_per_epoch.train={STEPS_PER_EPOCH} train.steps_per_epoch.val={STEPS_PER_EPOCH} " + string += f"train.steps_per_epoch.test={STEPS_PER_EPOCH} " + string += f"train.num_epochs={NUM_EPOCHS} " + string += f"train.viz_examples={VIZ_EXAMPLES}" return string @@ -53,19 +53,19 @@ def add_default_arguments(string, train=True): def test_flow(): make_project_from_archive() - string = ('python -m deepethogram.flow_generator.train preset=deg_f ') + string = "python -m deepethogram.flow_generator.train preset=deg_f " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = ('python -m deepethogram.flow_generator.train preset=deg_m ') + string = "python -m deepethogram.flow_generator.train preset=deg_m " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = ('python -m deepethogram.flow_generator.train preset=deg_s ') + string = "python -m deepethogram.flow_generator.train preset=deg_s " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -73,29 +73,33 @@ def test_flow(): def test_feature_extractor(): - string = ('python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest ') + string = "python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = ('python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest ') + string = "python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 # for resnet3d, must specify weights, because we can't just download them from the torchvision repo - string = ('python -m deepethogram.feature_extractor.train preset=deg_s flow_generator.weights=latest ' - 'feature_extractor.weights=latest ') + string = ( + "python -m deepethogram.feature_extractor.train preset=deg_s flow_generator.weights=latest " + "feature_extractor.weights=latest " + ) string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 # testing softmax - string = ('python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest ' - 'feature_extractor.final_activation=softmax ') + string = ( + "python -m deepethogram.feature_extractor.train preset=deg_m flow_generator.weights=latest " + "feature_extractor.final_activation=softmax " + ) string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -104,17 +108,19 @@ def test_feature_extractor(): def test_feature_extraction(softmax: bool = False): # the reason for this complexity is that I don't want to run inference on all directories - string = ('python -m deepethogram.feature_extractor.inference preset=deg_f feature_extractor.weights=latest ' - 'flow_generator.weights=latest ') + string = ( + "python -m deepethogram.feature_extractor.inference preset=deg_f feature_extractor.weights=latest " + "flow_generator.weights=latest " + ) if softmax: - string += 'feature_extractor.final_activation=softmax ' + string += "feature_extractor.final_activation=softmax " # datadir = os.path.join(testing_directory, 'DATA') - subdirs = utils.get_subfiles(data_path, 'directory') + subdirs = utils.get_subfiles(data_path, "directory") # np.random.seed(42) # subdirs = np.random.choice(subdirs, size=100, replace=False) - dir_string = ','.join([str(i) for i in subdirs]) - dir_string = '[' + dir_string + ']' - string += f'inference.directory_list={dir_string} inference.overwrite=True ' + dir_string = ",".join([str(i) for i in subdirs]) + dir_string = "[" + dir_string + "]" + string += f"inference.directory_list={dir_string} inference.overwrite=True " string = add_default_arguments(string, train=False) command = command_from_string(string) ret = subprocess.run(command) @@ -123,7 +129,7 @@ def test_feature_extraction(softmax: bool = False): def test_sequence_train(): - string = ('python -m deepethogram.sequence.train ') + string = "python -m deepethogram.sequence.train " string = add_default_arguments(string) command = command_from_string(string) print(command) @@ -131,7 +137,7 @@ def test_sequence_train(): assert ret.returncode == 0 # mutually exclusive - string = ('python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') + string = "python -m deepethogram.sequence.train feature_extractor.final_activation=softmax " string = add_default_arguments(string) command = command_from_string(string) print(command) @@ -141,14 +147,16 @@ def test_sequence_train(): def test_softmax(): make_project_from_archive() - string = ('python -m deepethogram.flow_generator.train preset=deg_f ') + string = "python -m deepethogram.flow_generator.train preset=deg_f " string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) assert ret.returncode == 0 - string = ('python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest ' - 'feature_extractor.final_activation=softmax ') + string = ( + "python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest " + "feature_extractor.final_activation=softmax " + ) string = add_default_arguments(string) command = command_from_string(string) ret = subprocess.run(command) @@ -156,7 +164,7 @@ def test_softmax(): test_feature_extraction(softmax=True) - string = ('python -m deepethogram.sequence.train feature_extractor.final_activation=softmax ') + string = "python -m deepethogram.sequence.train feature_extractor.final_activation=softmax " string = add_default_arguments(string) command = command_from_string(string) print(command) @@ -164,5 +172,5 @@ def test_softmax(): assert ret.returncode == 0 -if __name__ == '__main__': - test_softmax() \ No newline at end of file +if __name__ == "__main__": + test_softmax() From 4899196213e3d88da673e7d00f491cf0d86ee916 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:46:07 -0500 Subject: [PATCH 07/23] ruff --- deepethogram/viz.py | 68 +++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/deepethogram/viz.py b/deepethogram/viz.py index 79f70ef..d548a5a 100644 --- a/deepethogram/viz.py +++ b/deepethogram/viz.py @@ -1,20 +1,20 @@ -from collections import OrderedDict import itertools import logging import os import warnings +from collections import OrderedDict from typing import Union import cv2 import h5py import matplotlib import numpy as np +import torch # import tifffile as TIFF from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation -from mpl_toolkits.axes_grid1 import make_axes_locatable, inset_locator -import torch +from mpl_toolkits.axes_grid1 import inset_locator, make_axes_locatable from deepethogram.flow_generator.utils import flow_to_rgb_polar @@ -266,8 +266,6 @@ def visualize_multiresolution( axes = fig.subplots(4, N_resolutions) - images = downsampled_t0[0].detach().cpu().numpy().astype(np.float32) - index = batch_ind * sequence_length + sequence_ind t0 = [ downsampled_t0[i][index].detach().cpu().numpy().transpose(1, 2, 0).astype(np.float32) @@ -549,12 +547,6 @@ def visualize_batch_sequence(sequence, outputs, labels, N_in_batch=None, fig=Non axes = fig.subplots(4, 1) ax = axes[0] - # seq = sequence[N_in_batch] - aspect_ratio = outputs - - # tmp = outputs[N_in_batch] - # seq = cv2.resize(sequence[N_in_batch], (tmp.shape[1]*10,tmp.shape[0]*10), interpolation=cv2.INTER_NEAREST) - # seq = cv2.imresize(sequence[N_in_batch], ) imshow_with_colorbar(sequence, ax, fig, interpolation="nearest", symmetric=False, func="pcolor", cmap="viridis") ax.invert_yaxis() ax.set_ylabel("inputs") @@ -609,7 +601,7 @@ def fig_to_img(fig_handle: matplotlib.figure.Figure) -> np.ndarray: def plot_histogram(array, ax, bins="auto", width_factor=0.9, rotation=30): """Helper function for plotting a histogram""" - if type(array) != np.ndarray: + if not isinstance(array, np.ndarray): array = np.array(array) hist, bin_edges = np.histogram(array, bins=bins, density=False) @@ -625,9 +617,7 @@ def plot_histogram(array, ax, bins="auto", width_factor=0.9, rotation=30): ylims = ax.get_ylim() leg_str = "median: %0.4f" % (med) - lineh = ax.plot( - np.array([med, med]), np.array([ylims[0], ylims[1]]), color="k", linestyle="dashed", lw=3, label=leg_str - ) + ax.plot(np.array([med, med]), np.array([ylims[0], ylims[1]]), color="k", linestyle="dashed", lw=3, label=leg_str) ax.set_ylabel("P") ax.legend() @@ -745,7 +735,7 @@ def plot_confusion_matrix( # print(cm) if colorbar: - cbar = imshow_with_colorbar(cm, ax, fig, interpolation="nearest", cmap=cmap) + imshow_with_colorbar(cm, ax, fig, interpolation="nearest", cmap=cmap) else: ax.imshow(cm, cmap=cmap) @@ -1274,7 +1264,7 @@ def __init__(self, colormap="deepethogram"): try: self.cmap = plt.get_cmap(colormap) except ValueError: - raise ("Colormap not in matplotlib" "s defaults! {}".format(colormap)) + raise ("Colormap not in matplotlibs defaults! {}".format(colormap)) def init_deepethogram(self): gray_LUT = make_LUT([0, 0, value], [0, 0, gray_value]) @@ -1287,7 +1277,7 @@ def init_deepethogram(self): def apply_cmaps(self, array: Union[np.ndarray, int, float]) -> np.ndarray: # assume columns are timepoints, rpws are behaviors - if type(array) == int or type(array) == float: + if isinstance(array, (int, float)): # use the 0th LUT by default return apply_cmap(array, self.LUTs[0]) elif array.shape[0] == 1 and len(array.shape) == 1: @@ -1317,9 +1307,9 @@ def __call__(self, array: Union[np.ndarray, int, float]) -> np.ndarray: def make_LUT(start_hsv: Union[tuple, list, np.ndarray], end_hsv: Union[tuple, list, np.ndarray]) -> np.ndarray: - if type(start_hsv) != np.ndarray: + if not isinstance(start_hsv, np.ndarray): start_hsv = np.array(start_hsv).astype(np.uint8) - if type(end_hsv) != np.ndarray: + if not isinstance(end_hsv, np.ndarray): end_hsv = np.array(end_hsv).astype(np.uint8) # interpolate in HSV space; if they have two different hues, will result in very weird colormap @@ -1332,11 +1322,11 @@ def make_LUT(start_hsv: Union[tuple, list, np.ndarray], end_hsv: Union[tuple, li def apply_cmap(array: Union[np.ndarray, int, float], LUT: np.ndarray) -> np.ndarray: single_input = False - if type(array) == int: + if isinstance(array, int): assert array >= 0 and array <= 255 array = np.array([array]).astype(np.uint8) single_input = True - elif type(array) == float: + elif isinstance(array, float): array = np.array([array]).astype(float) single_input = True if array.dtype != np.uint8: @@ -1402,8 +1392,8 @@ def make_ethogram_movie( fps: float = 30, ): """Makes a movie out of an ethogram. Can be very slow due to matplotlib's animations""" - if mapper is None: - mapper = Mapper() + if not isinstance(classes, np.ndarray): + classes = np.array(classes) fig = plt.figure(figsize=(10, 12)) # camera = Camera(fig) @@ -1417,7 +1407,7 @@ def make_ethogram_movie( starts = np.arange(0, ethogram.shape[0], width) - if type(classes) != np.ndarray: + if not isinstance(classes, np.ndarray): classes = np.array(classes) framenum = 0 @@ -1446,7 +1436,12 @@ def animate(i): # print(x) if (i % width) == 0: etho_h = plot_ethogram( - ethogram[starts[i // width] : starts[i // width] + width, :], mapper, start + i, ax1, classes + ethogram[starts[i // width] : starts[i // width] + width, :], + mapper, + start + i, + ax1, + classes, + ylabel="Labels", ) # no idea why plot ethogram doesn't change this xticks = ax1.get_xticks() @@ -1454,7 +1449,7 @@ def animate(i): ax1.set_xticklabels([str(int(i)) for i in new_ticks]) else: - etho_h = [i for i in ax1.get_children() if type(i) == matplotlib.image.AxesImage][0] + etho_h = [i for i in ax1.get_children() if isinstance(i, matplotlib.image.AxesImage)][0] plot_h.set_xdata(x) title_h.set_text("{:,}: {}".format(start + i, classes[np.where(ethogram[i])[0]].tolist())) @@ -1486,8 +1481,9 @@ def make_ethogram_movie_with_predictions( ): """Makes a movie with movie, then ethogram, then model predictions""" - if mapper is None: - mapper = Mapper() + if not isinstance(classes, np.ndarray): + classes = np.array(classes) + fig = plt.figure(figsize=(6, 8)) # camera = Camera(fig) @@ -1500,13 +1496,11 @@ def make_ethogram_movie_with_predictions( # ax1 = fig.add_subplot(gs[2]) starts = np.arange(0, ethogram.shape[0], width) - if type(classes) != np.ndarray: + if not isinstance(classes, np.ndarray): classes = np.array(classes) framenum = 0 - # values_to_return = [] - im_h = axes[0].imshow(frames[0]) ax = axes[1] @@ -1530,20 +1524,14 @@ def make_ethogram_movie_with_predictions( plt.tight_layout() - # etho_h = plot_ethogram(ethogram[starts[0]:starts[0] + width, :], - # mapper, start + framenum, ax1, classes) - def init(): return [im_h, im_h1, im_h2, plot_h1, plot_h2, title_h] def animate(i): - # values_to_return = [] - # print(i) im_h.set_data(frames[i]) x0 = i - starts[i // width] - 0.5 x1 = x0 + 1 x = (x0, x1, x1, x0, x0) - # print(x) if (i % width) == 0: im_h1 = plot_ethogram( ethogram[starts[i // width] : starts[i // width] + width, :], @@ -1572,8 +1560,8 @@ def animate(i): axes[2].set_xticklabels([str(int(i)) for i in new_ticks]) else: - im_h1 = [i for i in axes[1].get_children() if type(i) == matplotlib.image.AxesImage][0] - im_h2 = [i for i in axes[2].get_children() if type(i) == matplotlib.image.AxesImage][0] + im_h1 = [i for i in axes[1].get_children() if isinstance(i, matplotlib.image.AxesImage)][0] + im_h2 = [i for i in axes[2].get_children() if isinstance(i, matplotlib.image.AxesImage)][0] plot_h1.set_xdata(x) plot_h2.set_xdata(x) From 0a2a67ab709cc4159e87d8e868e9e3f4dc7d1562 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:48:25 -0500 Subject: [PATCH 08/23] default to stride 1 and save length --- deepethogram/zscore.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/deepethogram/zscore.py b/deepethogram/zscore.py index 179e136..ec25bbd 100644 --- a/deepethogram/zscore.py +++ b/deepethogram/zscore.py @@ -3,14 +3,13 @@ import sys from typing import Union -# import hydra import numpy as np import torch from omegaconf import DictConfig from tqdm import tqdm import deepethogram.file_io -from deepethogram import configuration, utils, projects +from deepethogram import configuration, projects, utils log = logging.getLogger(__name__) @@ -93,8 +92,9 @@ def __str__(self): return "mean: {} std: {} n: {}".format(self.mean, self.std, self.nobservations) -def get_video_statistics(videofile, stride): +def get_video_statistics(videofile: Union[str, os.PathLike], stride: int = 10) -> dict: image_stats = StatsRecorder() + n_frames = 0 with deepethogram.file_io.VideoReader(videofile) as reader: log.debug("N frames: {}".format(len(reader))) for i in tqdm(range(0, len(reader), stride)): @@ -105,28 +105,23 @@ def get_video_statistics(videofile, stride): continue image = image.astype(float) / 255 image = image.transpose(2, 1, 0) - # image = image[np.newaxis,...] - # N, C, H, W = image.shape image = image.reshape(3, -1).transpose(1, 0) - # image = image.reshape(N, C, -1).squeeze().transpose(1, 0) - # if i == 0: - # print(image.shape) image_stats.update(image) - + n_frames += 1 log.info("final stats: {}".format(image_stats)) - imdata = {"mean": image_stats.mean, "std": image_stats.std, "N": image_stats.nobservations} + imdata = {"mean": image_stats.mean, "std": image_stats.std, "N": n_frames} for k, v in imdata.items(): - if type(v) == torch.Tensor: + if isinstance(v, torch.Tensor): v = v.detach().cpu().numpy() - if type(v) == np.ndarray: + if isinstance(v, np.ndarray): v = v.tolist() imdata[k] = v return imdata -def zscore_video(videofile: Union[str, os.PathLike], project_config: dict, stride: int = 10): +def zscore_video(videofile: Union[str, os.PathLike], project_config: dict, stride: int = 10) -> None: """calculates channel-wise mean and standard deviation for input video. Calculates mean and std deviation independently for each input video channel. Grayscale videos are converted to RGB. @@ -144,10 +139,6 @@ def zscore_video(videofile: Union[str, os.PathLike], project_config: dict, strid assert os.path.exists(videofile) assert projects.is_deg_file(videofile) - # config['arch'] = 'flow-generator' - # config['normalization'] = None - # transforms = get_transforms_from_config(config) - # xform = transforms['train'] log.info("zscoring file: {}".format(videofile)) imdata = get_video_statistics(videofile, stride) @@ -163,8 +154,6 @@ def zscore_video(videofile: Union[str, os.PathLike], project_config: dict, strid def update_project_with_normalization(norm_dict: dict, project_config: dict): """Adds statistics from this video to the overall mean / std deviation for the project""" - # project_dict = utils.load_yaml(os.path.join(project_dir, 'project_config.yaml')) - if "normalization" not in project_config["augs"].keys(): raise ValueError("Must have project_config/augs/normalization field: {}".format(project_config)) old_rgb = project_config["augs"]["normalization"] @@ -183,7 +172,6 @@ def update_project_with_normalization(norm_dict: dict, project_config: dict): utils.save_dict_to_yaml(project_config, os.path.join(project_config["project"]["path"], "project_config.yaml")) -# @hydra.main(config_path='../conf/zscore.yaml') def main(cfg: DictConfig): assert os.path.isfile(cfg.videofile) project_config = utils.load_yaml(cfg.project.config_file) From 7018c81cadf2848a84a7ae08cbd8e28ec77fe0f5 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:48:41 -0500 Subject: [PATCH 09/23] change linting rules --- pyproject.toml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9492d9..7bcaab8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,8 @@ # Python version compatibility target-version = "py37" -# Ignore specific rules -ignore = [] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL"] -unfixable = [] +# Same as Black. +line-length = 120 # Exclude a variety of commonly ignored directories. exclude = [ @@ -34,8 +30,13 @@ exclude = [ "venv", ] -# Same as Black. -line-length = 120 +[tool.ruff.lint] +# Ignore specific rules +ignore = [] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" From ccc9d18027842f0f771ad15185da2f8fdccf1d51 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sat, 11 Jan 2025 15:51:25 -0500 Subject: [PATCH 10/23] linting --- pyproject.toml | 2 ++ setup_tests.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7bcaab8..97d7efc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ exclude = [ "dist", "node_modules", "venv", + "tests/", + "docs/" ] [tool.ruff.lint] diff --git a/setup_tests.py b/setup_tests.py index 0351233..8a662a9 100644 --- a/setup_tests.py +++ b/setup_tests.py @@ -45,7 +45,6 @@ def setup_tests(): data_dir.mkdir(parents=True, exist_ok=True) # Download the test archive - zip_url = "https://drive.google.com/uc?export=download&id=1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo" zip_path = data_dir / "testing_deepethogram_archive.zip" try: From 53551a0c3915ba8e00aaba311fd074c5fbece21a Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:06:13 -0500 Subject: [PATCH 11/23] ruff; remove old comments; fix hacky tests --- README.md | 20 +- deepethogram/base.py | 31 ++- deepethogram/callbacks.py | 17 +- deepethogram/configuration.py | 6 +- deepethogram/data/dali.py | 98 +++++++-- deepethogram/data/dataloaders.py | 9 +- deepethogram/data/datasets.py | 26 +-- deepethogram/data/utils.py | 2 - deepethogram/debug.py | 5 +- deepethogram/feature_extractor/inference.py | 22 +- deepethogram/feature_extractor/losses.py | 24 --- deepethogram/feature_extractor/train.py | 107 ++++------ deepethogram/file_io.py | 2 - deepethogram/flow_generator/__init__.py | 1 - deepethogram/flow_generator/inference.py | 20 +- deepethogram/flow_generator/losses.py | 118 ++++++++--- .../flow_generator/models/FlowNetS.py | 17 +- .../flow_generator/models/MotionNet.py | 43 +--- .../flow_generator/models/TinyMotionNet.py | 20 +- .../flow_generator/models/TinyMotionNet3D.py | 28 +-- .../flow_generator/models/components.py | 8 +- deepethogram/flow_generator/utils.py | 29 +-- deepethogram/gui/custom_widgets.py | 190 +----------------- deepethogram/gui/main.py | 148 ++------------ deepethogram/gui/mainwindow.py | 2 +- deepethogram/losses.py | 12 +- deepethogram/metrics.py | 63 ++---- deepethogram/postprocessing.py | 12 +- deepethogram/projects.py | 133 ++---------- deepethogram/sequence/inference.py | 25 +-- deepethogram/sequence/models/tgm.py | 48 +---- deepethogram/stoppers.py | 6 - deepethogram/tune/utils.py | 4 +- deepethogram/utils.py | 124 +++--------- deepethogram/viz.py | 136 +------------ docs/testing.md | 66 ++++++ pyproject.toml | 5 - pytest.ini | 18 ++ ..._zz_commandline.py => test_integration.py} | 6 + tests/test_z_score.py | 2 +- 40 files changed, 495 insertions(+), 1158 deletions(-) create mode 100644 docs/testing.md create mode 100644 pytest.ini rename tests/{test_zz_commandline.py => test_integration.py} (98%) diff --git a/README.md b/README.md index 865485d..6bf2032 100644 --- a/README.md +++ b/README.md @@ -91,10 +91,22 @@ Some older ones might also be fine, like a `1080` or even `1070 Ti`/ `1070`. Test coverage is still low, but in the future we will be expanding our unit tests. First, download a copy of [`testing_deepethogram_archive.zip`](https://drive.google.com/file/d/1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo/view?usp=sharing) - Make a directory in tests called `DATA`. Unzip this and move it to the `deepethogram/tests/DATA` -directory, so that the path is `deepethogram/tests/DATA/testing_deepethogram_archive/{DATA,models,project_config.yaml}`. Then run `pytest tests/` to run. -the `zz_commandline` test module will take a few minutes, as it is an end-to-end test that performs model training -and inference. Its name reflects the fact that it should come last in testing. +Make a directory in tests called `DATA`. Unzip this and move it to the `deepethogram/tests/DATA` +directory, so that the path is `deepethogram/tests/DATA/testing_deepethogram_archive/{DATA,models,project_config.yaml}`. + +To run tests: +```bash +# Run all tests except GPU tests (default) +pytest tests/ + +# Run only GPU tests (requires NVIDIA GPU) +pytest -m gpu + +# Run all tests including GPU tests +pytest -m "" +``` + +GPU tests are skipped by default as they require significant computational resources and time to complete. These tests perform end-to-end model training and inference. ## Developer Guide ### Code Style and Pre-commit Hooks diff --git a/deepethogram/base.py b/deepethogram/base.py index 364a88b..5271c77 100644 --- a/deepethogram/base.py +++ b/deepethogram/base.py @@ -1,18 +1,17 @@ -from collections import defaultdict -from copy import deepcopy import logging import math import os +from collections import defaultdict +from copy import deepcopy from typing import Tuple import matplotlib.pyplot as plt -from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl +from omegaconf import DictConfig, OmegaConf try: - from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback - from ray.tune import get_trial_dir - from ray.tune import CLIReporter + from ray.tune import CLIReporter, get_trial_dir # noqa: F401 + from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback # noqa: F401 ray = True except ImportError: @@ -21,17 +20,17 @@ from torch import nn, optim from torch.utils.data import DataLoader, WeightedRandomSampler -from deepethogram.data.augs import get_gpu_transforms, get_empty_gpu_transforms +from deepethogram import utils, viz from deepethogram.callbacks import ( + CheckpointCallback, + ExampleImagesCallback, FPSCallback, MetricsCallback, - ExampleImagesCallback, - CheckpointCallback, StopperCallback, ) -from deepethogram.metrics import Metrics, EmptyMetrics +from deepethogram.data.augs import get_empty_gpu_transforms, get_gpu_transforms +from deepethogram.metrics import EmptyMetrics, Metrics from deepethogram.schedulers import initialize_scheduler -from deepethogram import viz, utils log = logging.getLogger(__name__) @@ -65,7 +64,7 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics: M self.model = model try: self.hparams = cfg - except: + except Exception: # for pytorch lightning > 1.1.8 self.hparams.update(cfg) @@ -197,7 +196,7 @@ def get_train_sampler(self): def get_val_sampler(self): # get sample weights for validation dataset to up-sample rare classes - dataset = self.datasets["val"] + # dataset = self.datasets["val"] # if dataset.labels is None: # # self-supervised, e.g. flow generators # return @@ -302,7 +301,6 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s log.debug("orig: {}".format(lightning_module.gpu_transforms)) - original_augs = cfg.augs new_augs = deepcopy(cfg.augs) new_augs.color_p = 1.0 @@ -420,9 +418,4 @@ def get_trainer_from_cfg(cfg: DictConfig, lightning_module, stopper, profiler: s log_every_n_steps=1, ) torch.cuda.empty_cache() - # gc.collect() - - # import signal - # signal.signal(signal.SIGTERM, signal.SIG_DFL) - # log.info('trainer is_slurm_managing_tasks: {}'.format(trainer.is_slurm_managing_tasks)) return trainer diff --git a/deepethogram/callbacks.py b/deepethogram/callbacks.py index 1c0ff78..756a099 100644 --- a/deepethogram/callbacks.py +++ b/deepethogram/callbacks.py @@ -1,7 +1,7 @@ -from collections import defaultdict import logging import os import time +from collections import defaultdict import numpy as np from pytorch_lightning.callbacks import Callback @@ -111,12 +111,6 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal self.end_batch("speedtest", batch, pl_module) -# class SpeedtestCallback(Callback): -# def __init__(self): -# super().__init__() -# -# def on_validation_end(self, trainer, pl_module): -# trainer.test(pl_module) def log_metrics(pl_module, split): assert split in ["train", "val", "test"] metrics, _ = pl_module.metrics.end_epoch(split) @@ -143,13 +137,9 @@ def __init__(self): def on_train_epoch_end(self, trainer, pl_module): pl_module.metrics.buffer.append("train", {"lr": utils.get_minimum_learning_rate(pl_module.optimizer)}) _ = log_metrics(pl_module, "train") - # latest_key = pl_module.metrics.latest_key['train'] - # key = 'train_{}'.format(pl_module.metrics.key_metric) - # pl_module.log(key, latest_key, on_epoch=True) def on_validation_epoch_end(self, trainer, pl_module): scalar_metrics = log_metrics(pl_module, "val") - latest_key = pl_module.metrics.latest_key["val"] # this logic is to correctly log only important hyperparameters and important metrics to tensorboard's # hyperparameter view. Just using all the parameters in our configuration makes for a huge and ugly tensorboard @@ -168,9 +158,6 @@ def on_validation_epoch_end(self, trainer, pl_module): print(pl_module.tune_hparams, hparam_metrics) pl_module.logger.log_hyperparams(pl_module.tune_hparams, hparam_metrics) - # # log the latest key metric in tensorboard as hp_metric, which will enable hparam view - # pl_module.log('hp_metric', latest_key, on_epoch=True) - def on_test_epoch_end(self, trainer, pl_module): log_metrics(pl_module, "test") # pl_module.metrics.end_epoch('speedtest') @@ -248,7 +235,5 @@ def on_train_epoch_end(self, trainer, pl_module): raise ValueError("invalid stopping name: {}".format(self.stopper.name)) if should_stop: - # log.info('Stopping criterion reached! Raising KeyboardInterrupt to quit') log.info("Stopping criterion reached! setting trainer.should_stop=True") trainer.should_stop = True - # raise KeyboardInterrupt diff --git a/deepethogram/configuration.py b/deepethogram/configuration.py index 4748051..837fae2 100644 --- a/deepethogram/configuration.py +++ b/deepethogram/configuration.py @@ -1,7 +1,7 @@ import os from typing import Union -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf import deepethogram from deepethogram import projects @@ -33,7 +33,7 @@ def config_string_to_path(config_path: Union[str, os.PathLike], string: str) -> return fullpath -def load_config_by_name(string: str, config_path: Union[str, os.PathLike] = None) -> DictConfig: +def load_config_by_name(string: str, config_path: Union[str, os.PathLike, None] = None) -> DictConfig: """Loads a default configuration by name Parameters @@ -129,7 +129,6 @@ def make_config( DictConfig [description] """ - # config_path = os.path.join(os.path.dirname(deepethogram.__file__), 'conf') user_cfg = projects.get_config_from_path(project_path) @@ -153,7 +152,6 @@ def make_config( if debug: config_list.append("debug") - # config_files = [config_string_to_path(config_path, i) for i in config_list] cfgs = [load_config_by_name(i) for i in config_list] diff --git a/deepethogram/data/dali.py b/deepethogram/data/dali.py index a044019..9c9fed7 100644 --- a/deepethogram/data/dali.py +++ b/deepethogram/data/dali.py @@ -1,19 +1,50 @@ +"""NVIDIA DALI pipeline and loader implementations for video processing. + +This module provides DALI-based data loading functionality for video datasets, +with a focus on the Kinetics dataset format. It includes GPU-accelerated video loading +and augmentation capabilities. +""" + import os try: - from nvidia.dali.pipeline import Pipeline - import nvidia.dali.ops as ops - import nvidia.dali.types as types - from nvidia.dali.backend import TensorListCPU - from nvidia.dali.plugin import pytorch + import nvidia.dali.ops as ops # noqa: F401 + import nvidia.dali.types as types # noqa: F401 + from nvidia.dali.backend import TensorListCPU # noqa: F401 + from nvidia.dali.pipeline import Pipeline # noqa: F401 + from nvidia.dali.plugin import pytorch # noqa: F401 except ImportError: dali = False - # print('DALI not loaded...') else: dali = True class KineticsDALIPipe(Pipeline): + """DALI Pipeline for processing video data in Kinetics format. + + This pipeline handles video loading, augmentation, and preprocessing using NVIDIA DALI. + It supports both supervised and unsupervised modes, with configurable augmentations + including brightness, contrast, and spatial transformations. + + Args: + directory (str): Root directory containing the video files + supervised (bool): Whether to return labels with the data + sequence_length (int): Number of frames to load per sequence + batch_size (int): Number of sequences per batch + num_workers (int): Number of parallel workers + gpu_id (int): ID of GPU to use + shuffle (bool): Whether to shuffle the data + crop_size (tuple): Size of output crops (H, W) + resize (tuple, optional): Size to resize frames to before cropping + brightness (float): Maximum brightness adjustment factor + contrast (float): Maximum contrast adjustment factor + mean (list): Mean values for normalization per channel + std (list): Standard deviation values for normalization per channel + conv_mode (str): Convolution mode ('2d' or '3d') + image_shape (tuple): Base image dimensions (H, W) + validate (bool): Whether this pipeline is for validation (reduces augmentation) + """ + def __init__( self, directory, @@ -73,7 +104,6 @@ def __init__( else: # default H, W = image_shape - # print('CONV MODE!!! {}'.format(conv_mode)) if conv_mode == "3d": self.transpose = ops.Transpose(device="gpu", perm=[3, 0, 1, 2]) self.reshape = None @@ -106,11 +136,30 @@ def define_graph(self): return images -# -# -# # https://github.com/NVIDIA/DALI/blob/cde7271a840142221273f8642952087acd919b6e -# # /docs/examples/use_cases/video_superres/dataloading/dataloaders.py class DALILoader: + """Data loader wrapper for DALI pipeline. + + Provides an iterator interface to the DALI pipeline for easy integration + with PyTorch training loops. + + Args: + directory (str): Root directory containing the video files + supervised (bool): Whether to return labels with the data + sequence_length (int): Number of frames to load per sequence + batch_size (int): Number of sequences per batch + num_workers (int): Number of parallel workers + gpu_id (int): ID of GPU to use + shuffle (bool): Whether to shuffle the data + crop_size (tuple): Size of output crops (H, W) + mean (list): Mean values for normalization per channel + std (list): Standard deviation values for normalization per channel + conv_mode (str): Convolution mode ('2d' or '3d') + validate (bool): Whether this is for validation + distributed (bool): Whether to use distributed training mode + + https://github.com/NVIDIA/DALI/blob/cde7271a840142221273f8642952087acd919b6e/docs/examples/use_cases/video_superres/dataloading/dataloaders.py + """ + def __init__( self, directory, @@ -166,6 +215,25 @@ def get_dataloaders_kinetics_dali( std: list = [0.5, 0.5, 0.5], distributed: bool = False, ): + """Create DALI dataloaders for train and validation sets. + + Args: + directory (str): Root directory containing train and val subdirectories + rgb_frames (int): Number of RGB frames per sequence + batch_size (int): Batch size + shuffle (bool): Whether to shuffle training data + num_workers (int): Number of worker processes + supervised (bool): Whether to return labels + conv_mode (str): Convolution mode ('2d' or '3d') + gpu_id (int): GPU device ID + crop_size (tuple): Output crop dimensions + mean (list): Normalization mean values + std (list): Normalization standard deviation values + distributed (bool): Whether to use distributed training + + Returns: + dict: Dictionary containing train and validation dataloaders + """ shuffles = {"train": shuffle, "val": True, "test": False} dataloaders = {} for split in ["train", "val"]: @@ -188,11 +256,3 @@ def get_dataloaders_kinetics_dali( dataloaders["split"] = None return dataloaders - - -def __len__(self): - return int(self.epoch_size) - - -def __iter__(self): - return self.dali_iterator.__iter__() diff --git a/deepethogram/data/dataloaders.py b/deepethogram/data/dataloaders.py index cd33b51..d56f26e 100644 --- a/deepethogram/data/dataloaders.py +++ b/deepethogram/data/dataloaders.py @@ -10,15 +10,16 @@ from deepethogram import projects from deepethogram.data.augs import get_cpu_transforms -from deepethogram.data.datasets import SequenceDataset, TwoStreamDataset, VideoDataset, KineticsDataset +from deepethogram.data.datasets import KineticsDataset, SequenceDataset, TwoStreamDataset, VideoDataset from deepethogram.data.utils import ( get_split_from_records, - remove_invalid_records_from_split_dictionary, make_loss_weight, + remove_invalid_records_from_split_dictionary, ) try: - from nvidia.dali.pipeline import Pipeline + from nvidia.dali.pipeline import Pipeline # noqa: F401 + from .dali import get_dataloaders_kinetics_dali except ImportError: get_dataloaders_kinetics_dali = None @@ -218,7 +219,7 @@ def get_dataloaders_kinetics( datasets = {} for split in ["train", "val", "test"]: # this is in the two stream case where you can't apply color transforms to an optic flow - if type(xform[split]) == dict: + if isinstance(xform[split], dict): spatial_transform = xform[split]["spatial"] color_transform = xform[split]["color"] else: diff --git a/deepethogram/data/datasets.py b/deepethogram/data/datasets.py index 3d4aa13..00c2ced 100644 --- a/deepethogram/data/datasets.py +++ b/deepethogram/data/datasets.py @@ -1,29 +1,29 @@ -from collections import deque import logging import os import random -from typing import Union, Tuple +from collections import deque +from typing import Tuple, Union import h5py import numpy as np -from omegaconf import DictConfig import torch +from omegaconf import DictConfig from torch.utils import data from vidio import VideoReader # from deepethogram.dataloaders import log from deepethogram import projects from deepethogram.data.augs import get_cpu_transforms +from deepethogram.data.keypoint_utils import expand_features_sturman, interpolate_bad_values, load_dlcfile from deepethogram.data.utils import ( - purge_unlabeled_elements_from_records, + fix_label, + get_split_from_records, get_video_metadata, + make_loss_weight, + purge_unlabeled_elements_from_records, read_all_labels, - get_split_from_records, remove_invalid_records_from_split_dictionary, - make_loss_weight, - fix_label, ) -from deepethogram.data.keypoint_utils import load_dlcfile, interpolate_bad_values, expand_features_sturman from deepethogram.file_io import read_labels log = logging.getLogger(__name__) @@ -740,11 +740,8 @@ def verify_dataset(self): flow_shape = f[self.flow_key].shape image_shape = f[self.image_key].shape assert flow_shape[0] == image_shape[0] - # self.N = image_shape[0] else: assert self.key in f - shape = f[self.key].shape - # self.N = shape[0] def read_features_from_disk(self, start_ind, end_ind): inds = slice(start_ind, end_ind) @@ -922,7 +919,7 @@ def get_video_datasets( datasets = {} for i, split in enumerate(["train", "val", "test"]): rgb = [records[i]["rgb"] for i in split_dictionary[split]] - flow = [records[i]["flow"] for i in split_dictionary[split]] + # flow = [records[i]["flow"] for i in split_dictionary[split]] if split == "test" and len(rgb) == 0: datasets[split] = None @@ -1072,8 +1069,6 @@ def get_sequence_datasets( # e.g.: you've added a video, but not labeled it yet. In that case, it will already be in your split, but it is # invalid for current purposes, because it has no label. Therefore, we want to remove it from the current split split_dictionary = remove_invalid_records_from_split_dictionary(split_dictionary, records) - # log.info('~~~~~ train val test split ~~~~~') - # pprint.pprint(split_dictionary) splits = ["train", "val", "test"] datasets = {} @@ -1090,8 +1085,6 @@ def get_sequence_datasets( if split == "test" and len(datafiles) == 0: datasets[split] = None continue - # h5file, labelfile = outputs[i] - # print('making dataset:{}'.format(split)) if supervised: labelfiles = [records[i]["label"] for i in split_dictionary[split]] @@ -1176,7 +1169,6 @@ def get_datasets_from_cfg(cfg: DictConfig, model_type: str, input_images: int = if model_type == "feature_extractor" or model_type == "flow_generator": arch = cfg[model_type].arch mode = "3d" if "3d" in arch.lower() else "2d" - # log.info('getting dataloaders: {} convolution type detected'.format(mode)) xform = get_cpu_transforms(cfg.augs) if cfg.project.name == "kinetics": diff --git a/deepethogram/data/utils.py b/deepethogram/data/utils.py index 052c1f0..0959a35 100644 --- a/deepethogram/data/utils.py +++ b/deepethogram/data/utils.py @@ -246,8 +246,6 @@ def read_all_labels(labelfiles: list, fix: bool = True, multilabel: bool = True) labels = [] for i, labelfile in enumerate(labelfiles): assert os.path.isfile(labelfile) - label_type = os.path.splitext(labelfile)[1][1:] - # labelfile, label_type = find_labelfile(video) label = read_labels(labelfile) H, W = label.shape # labels should be time x num_behaviors diff --git a/deepethogram/debug.py b/deepethogram/debug.py index 8928f5d..a61c334 100644 --- a/deepethogram/debug.py +++ b/deepethogram/debug.py @@ -5,8 +5,8 @@ import numpy as np from omegaconf import OmegaConf -from vidio import VideoReader from tqdm import tqdm +from vidio import VideoReader from deepethogram import file_io, projects @@ -77,7 +77,7 @@ def try_load_all_frames(datadir: Union[str, os.PathLike]): had_error = False for i in tqdm(range(len(reader)), leave=False): try: - frame = reader[i] + _ = reader[i] except Exception: had_error = True print("error reading frame {} from video {}".format(i, record["rgb"])) @@ -113,7 +113,6 @@ def try_load_all_frames(datadir: Union[str, os.PathLike]): user_cfg = OmegaConf.load(cfg.project.config_file) cfg = OmegaConf.merge(cfg, user_cfg) cfg = projects.convert_config_paths_to_absolute(cfg) - # print(cfg) logging.info(OmegaConf.to_yaml(cfg)) diff --git a/deepethogram/feature_extractor/inference.py b/deepethogram/feature_extractor/inference.py index 67d6787..be13cc5 100644 --- a/deepethogram/feature_extractor/inference.py +++ b/deepethogram/feature_extractor/inference.py @@ -6,14 +6,14 @@ import h5py import numpy as np -from sklearn.metrics import f1_score import torch -from torch.utils.data import DataLoader -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, ListConfig, OmegaConf +from sklearn.metrics import f1_score from torch import nn +from torch.utils.data import DataLoader from tqdm import tqdm -from deepethogram import utils, projects +from deepethogram import projects, utils from deepethogram.configuration import make_feature_extractor_inference_cfg from deepethogram.data.augs import get_cpu_transforms, get_gpu_transforms from deepethogram.data.datasets import VideoIterable @@ -502,9 +502,9 @@ def feature_extractor_inference(cfg: DictConfig): if directory_list is None or len(directory_list) == 0: raise ValueError( - "must pass list of directories from commmand line. " "Ex: directory_list=[path_to_dir1,path_to_dir2]" + "must pass list of directories from commmand line. Ex: directory_list=[path_to_dir1,path_to_dir2]" ) - elif type(directory_list) == str and directory_list == "all": + elif isinstance(directory_list, str) and directory_list == "all": basedir = cfg.project.data_path directory_list = utils.get_subfiles(basedir, "directory") elif isinstance(directory_list, str): @@ -524,9 +524,9 @@ def feature_extractor_inference(cfg: DictConfig): record = projects.get_record_from_subdir(directory) assert record["rgb"] is not None records.append(record) - assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( - "Flow generator inputs must be one greater " "than feature extractor num flows " - ) + assert ( + cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb + ), "Flow generator inputs must be one greater than feature extractor num flows " input_images = cfg.feature_extractor.n_flows + 1 mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" @@ -552,9 +552,7 @@ def feature_extractor_inference(cfg: DictConfig): # we don't want to use the weights that the trained model was initialized with, but the weights after training # therefore, overwrite the loaded configuration with the current weights cfg.feature_extractor.weights = feature_extractor_weights - # num_classes = len(loaded_cfg.project.class_names) - # log.warning('Overwriting current project classes with loaded classes! REVERT') model_components = build_feature_extractor(cfg) _, _, _, _, model = model_components device = "cuda:{}".format(cfg.compute.gpu_id) @@ -562,7 +560,6 @@ def feature_extractor_inference(cfg: DictConfig): metrics_file = run_files["metrics_file"] assert os.path.isfile(metrics_file) best_epoch = utils.get_best_epoch_from_weightfile(feature_extractor_weights) - # best_epoch = -1 log.info("best epoch from loaded file: {}".format(best_epoch)) with h5py.File(metrics_file, "r") as f: try: @@ -579,7 +576,6 @@ def feature_extractor_inference(cfg: DictConfig): Did you add or remove behaviors after training this model? If so, please retrain! """.format(len(thresholds), len(class_names)) raise ValueError(error_message) - # class_names = projects.get_classes_from_project(cfg) class_names = np.array(class_names) postprocessor = get_postprocessor_from_cfg(cfg, thresholds) extract( diff --git a/deepethogram/feature_extractor/losses.py b/deepethogram/feature_extractor/losses.py index c574a7c..909d8da 100644 --- a/deepethogram/feature_extractor/losses.py +++ b/deepethogram/feature_extractor/losses.py @@ -21,8 +21,6 @@ def __init__(self, alpha=0.1, weight=None, ignore_index=-1): self.alpha = alpha self.should_smooth = self.alpha != 0.0 - # self.nll = nn.NLLLoss(weight=weight, reduction='none') - self.log_softmax = nn.LogSoftmax(dim=1) self.ignore_index = ignore_index if weight is None: @@ -39,7 +37,6 @@ def forward(self, outputs, label): assert outputs.shape == label.shape, "Outputs shape must match labels! {}, {}".format( outputs.shape, label.shape ) - # N, K, T = outputs.shape label = label.float() # figure out which index to ignore before smoothing @@ -102,7 +99,6 @@ def __init__(self, pos_weight=None, ignore_index=-1, gamma: float = 0, label_smo self.bcewithlogitsloss = nn.BCEWithLogitsLoss(weight=None, reduction="none", pos_weight=pos_weight) self.ignore_index = ignore_index self.gamma = gamma - # self.alpha = alpha self.eps = 1e-7 # if label_smoothing is 0.1, then the "correct" answer is 0.1 # multiplying by 2 ensures this with the logic below @@ -156,26 +152,6 @@ def forward(self, outputs, label): weight = torch.pow(1 - prob, self.gamma) * label * mask + torch.pow(prob, self.gamma) * (1 - label) * mask - # NOTE: should not need the absolute here. however, getting this error: - # RuntimeError: Function 'PowBackward0' returned nan values in its 0th output. - - # if torch.sum(prob < 0) > 0 or torch.sum( torch.abs(1-prob) < 0 ) > 0: - # print('negative numbers in prob') - # pdb.set_trace() - - # one_minus_prob = torch.clamp(1-prob, self.eps, 1-self.eps) - - # if torch.sum(torch.isinf(one_minus_prob)) > 0 or torch.sum(one_minus_prob != one_minus_prob): - # print('nans or infs in 1-prob') - # pdb.set_trace() - # # spread out into 3 lines to figure out where the gradient nan is coming from - # weight_if_1 = torch.pow( one_minus_prob, self.gamma) # *label*mask - # weight_if_1 = weight_if_1*label - # weight_if_1 = weight_if_1*mask - - # weight_if_0 = torch.pow(prob, self.gamma)*(1-label)*mask - # weight = weight_if_1 + weight_if_0 - # https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/166833 label = label * (1 - self.label_smoothing) + 0.5 * self.label_smoothing bceloss = self.bcewithlogitsloss(outputs, label) diff --git a/deepethogram/feature_extractor/train.py b/deepethogram/feature_extractor/train.py index 59e39a0..df521c2 100644 --- a/deepethogram/feature_extractor/train.py +++ b/deepethogram/feature_extractor/train.py @@ -1,35 +1,34 @@ +# ruff: noqa: E402 import gc import logging import os import sys import warnings -from typing import Union, Tuple +from typing import Tuple, Union import cv2 cv2.setNumThreads(0) -# import hydra import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf -from deepethogram import utils, viz +from deepethogram import projects, utils, viz from deepethogram.base import BaseLightningModule, get_trainer_from_cfg from deepethogram.configuration import make_feature_extractor_train_cfg from deepethogram.data.datasets import get_datasets_from_cfg -from deepethogram.feature_extractor.losses import ClassificationLoss, BinaryFocalLoss, CrossEntropyLoss +from deepethogram.feature_extractor.losses import BinaryFocalLoss, ClassificationLoss, CrossEntropyLoss from deepethogram.feature_extractor.models.CNN import get_cnn from deepethogram.feature_extractor.models.hidden_two_stream import ( - HiddenTwoStream, FlowOnlyClassifier, + HiddenTwoStream, build_fusion_layer, ) from deepethogram.flow_generator.train import build_model_from_cfg as build_flow_generator from deepethogram.losses import get_regularization_loss from deepethogram.metrics import Classification -from deepethogram import projects from deepethogram.stoppers import get_stopper # hack @@ -43,13 +42,11 @@ "and test dataloaders.", ) -# flow_generators = utils.get_models_from_module(flow_models, get_function=False) plt.switch_backend("agg") log = logging.getLogger(__name__) -# @profile def feature_extractor_train(cfg: DictConfig) -> nn.Module: """Trains feature extractor models from a configuration. @@ -63,7 +60,6 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: nn.Module Trained feature extractor """ - # rundir = os.getcwd() cfg = projects.setup_run(cfg) log.info("args: {}".format(" ".join(sys.argv))) @@ -77,9 +73,9 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # we build flow generator independently because you might want to load it from a different location flow_generator = build_flow_generator(cfg) flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") - assert flow_weights is not None, ( - "Must have a valid weightfile for flow generator. Use " "deepethogram.flow_generator.train or cfg.reload.latest" - ) + assert ( + flow_weights is not None + ), "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" log.info("loading flow generator from file {}".format(flow_weights)) flow_generator = utils.load_weights(flow_generator, flow_weights) @@ -90,7 +86,6 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: model_parts = build_model_from_cfg(cfg, pos=data_info["pos"], neg=data_info["neg"]) _, spatial_classifier, flow_classifier, fusion, model = model_parts - # log.info('model: {}'.format(model)) num_classes = len(cfg.project.class_names) @@ -116,13 +111,10 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # Without the curriculum we just train end to end from the start if cfg.feature_extractor.curriculum: # train spatial model, then flow model, then both end-to-end - # dataloaders = get_dataloaders_from_cfg(cfg, model_type='feature_extractor', - # input_images=cfg.feature_extractor.n_rgb) datasets, data_info = get_datasets_from_cfg( cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_rgb ) stopper = get_stopper(cfg) - criterion = get_criterion(cfg, spatial_classifier, data_info) lightning_module = HiddenTwoStreamLightning(spatial_classifier, cfg, datasets, metrics, criterion) @@ -132,18 +124,12 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # https://pytorch-lightning.readthedocs.io/en/latest/lr_finder.html?highlight=auto%20scale%20learning%20rate # I tried to do this without re-creating module, but finding the learning rate increments the epoch?? # del lightning_module - # log.info('epoch num: {}'.format(trainer.current_epoch)) - # lightning_module = HiddenTwoStreamLightning(spatial_classifier, cfg, datasets, metrics, criterion) trainer.fit(lightning_module) - # free RAM. note: this doesn't do much - log.info("free ram") del datasets, lightning_module, trainer, stopper, data_info torch.cuda.empty_cache() gc.collect() - # return - datasets, data_info = get_datasets_from_cfg( cfg, model_type="feature_extractor", input_images=cfg.feature_extractor.n_flows + 1 ) @@ -157,7 +143,6 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: criterion = get_criterion(cfg, flow_generator_and_classifier, data_info) lightning_module = HiddenTwoStreamLightning(flow_generator_and_classifier, cfg, datasets, metrics, criterion) trainer = get_trainer_from_cfg(cfg, lightning_module, stopper) - # lightning_module = HiddenTwoStreamLightning(flow_generator_and_classifier, cfg, datasets, metrics, criterion) trainer.fit(lightning_module) del datasets, lightning_module, trainer, stopper, data_info @@ -177,18 +162,10 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: cfg.compute.batch_size = original_batch_size cfg.train.lr = original_lr - # log.warning('SETTING ANAOMALY DETECTION TO TRUE! WILL SLOW DOWN.') - # torch.autograd.set_detect_anomaly(True) - lightning_module = HiddenTwoStreamLightning(model, cfg, datasets, metrics, criterion) - trainer = get_trainer_from_cfg(cfg, lightning_module, stopper) - # see above for horrible syntax explanation - # lightning_module = HiddenTwoStreamLightning(model, cfg, datasets, metrics, criterion) trainer.fit(lightning_module) - # trainer.test(model=lightning_module) return model - # utils.save_hidden_two_stream(model, rundir, dict(cfg), stopper.epoch_counter) def build_model_from_cfg( @@ -200,23 +177,18 @@ def build_model_from_cfg( ---------- cfg: DictConfig configuration, e.g. from Hydra command line - return_components: bool - if True, returns spatial classifier and flow classifier individually pos: np.ndarray Number of positive examples in dataset. Used for initializing biases in final layer neg: np.ndarray Number of negative examples in dataset. Used for initializing biases in final layer + num_classes: int, optional + Number of classes to use. If None, uses cfg.project.class_names Returns ------- - if `return_components`: - spatial_classifier, flow_classifier: nn.Module, nn.Module - cnns for classifying rgb images and optic flows - else: - hidden two stream model: nn.Module - hidden two stream CNN + tuple + (flow_generator, spatial_classifier, flow_classifier, fusion, model) """ - # device = torch.device("cuda:" + str(cfg.compute.gpu_id) if torch.cuda.is_available() else "cpu") device = "cpu" feature_extractor_weights = projects.get_weightfile_from_cfg(cfg, "feature_extractor") if num_classes is None: @@ -228,6 +200,7 @@ def build_model_from_cfg( reload_imagenet = feature_extractor_weights is None if cfg.feature_extractor.arch == "resnet3d_34": assert feature_extractor_weights is not None, "Must specify path to resnet3d weights!" + spatial_classifier = get_cnn( cfg.feature_extractor.arch, in_channels=in_channels, @@ -238,11 +211,11 @@ def build_model_from_cfg( neg=neg, final_bn=cfg.feature_extractor.final_bn, ) - # load this specific component from the weight file if feature_extractor_weights is not None: spatial_classifier = utils.load_feature_extractor_components( spatial_classifier, feature_extractor_weights, "spatial", device=device ) + in_channels = cfg.feature_extractor.n_flows * 2 if "3d" not in cfg.feature_extractor.arch else 2 flow_classifier = get_cnn( cfg.feature_extractor.arch, @@ -262,9 +235,9 @@ def build_model_from_cfg( flow_generator = build_flow_generator(cfg) flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") - assert flow_weights is not None, ( - "Must have a valid weightfile for flow generator. Use " "deepethogram.flow_generator.train or cfg.reload.latest" - ) + assert ( + flow_weights is not None + ), "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" flow_generator = utils.load_weights(flow_generator, flow_weights, device=device) spatial_classifier, flow_classifier, fusion = build_fusion_layer( @@ -274,7 +247,6 @@ def build_model_from_cfg( fusion = utils.load_feature_extractor_components(fusion, feature_extractor_weights, "fusion", device=device) model = HiddenTwoStream(flow_generator, spatial_classifier, flow_classifier, fusion, cfg.feature_extractor.arch) - # log.info(model.fusion.flow_weight) model.set_mode("classifier") return flow_generator, spatial_classifier, flow_classifier, fusion, model @@ -284,17 +256,17 @@ class HiddenTwoStreamLightning(BaseLightningModule): """Lightning Module for training Feature Extractor models""" def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, criterion: nn.Module): - """constructor + """Constructor Parameters ---------- model : nn.Module - nn.Module, hidden two-stream CNNs + Hidden two-stream CNNs cfg : DictConfig omegaconf configuration datasets : dict dictionary containing Dataset classes - metrics : [type] + metrics : Classification metrics object for saving and computing metrics criterion : nn.Module loss function @@ -302,7 +274,6 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, c super().__init__(model, cfg, datasets, metrics, viz.visualize_logger_multilabel_classification) self.has_logged_channels = False - # for convenience self.final_activation = self.hparams.feature_extractor.final_activation if self.final_activation == "softmax": self.activation = nn.Softmax(dim=1) @@ -314,7 +285,7 @@ def __init__(self, model: nn.Module, cfg: DictConfig, datasets: dict, metrics, c self.criterion = criterion def validate_batch_size(self, batch: dict): - """simple check for appropriate batch sizes + """Simple check for appropriate batch sizes Parameters ---------- @@ -327,10 +298,8 @@ def validate_batch_size(self, batch: dict): verified batch dictionary """ if self.hparams.compute.dali: - # no idea why they wrap this, maybe they fixed it? batch = batch[0] if "images" in batch.keys(): - # weird case of batch size = 1 somehow getting squeezed out if batch["images"].ndim != 5: batch["images"] = batch["images"].unsqueeze(0) if "labels" in batch.keys(): @@ -353,17 +322,11 @@ def training_step(self, batch: dict, batch_idx: int): loss : torch.Tensor mean loss for batch for Lightning's backward + update hooks """ - # use the forward function - # return the image tensor so we can visualize after gpu transforms images, outputs = self(batch, "train") - probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) - self.visualize_batch(images, probabilities, batch["labels"], "train") - # save the model outputs to a buffer for various metrics self.metrics.buffer.append( "train", {"loss": loss.detach(), "probs": probabilities.detach(), "labels": batch["labels"].detach()} ) @@ -374,7 +337,7 @@ def training_step(self, batch: dict, batch_idx: int): return loss def validation_step(self, batch: dict, batch_idx: int): - """runs a single validation step + """Runs a single validation step Parameters ---------- @@ -385,7 +348,6 @@ def validation_step(self, batch: dict, batch_idx: int): """ images, outputs = self(batch, "val") probabilities = self.activation(outputs) - loss, loss_dict = self.criterion(outputs, batch["labels"], self.model) self.visualize_batch(images, probabilities, batch["labels"], "val") self.metrics.buffer.append( @@ -396,7 +358,7 @@ def validation_step(self, batch: dict, batch_idx: int): self.log("val/loss", loss.detach().cpu()) def test_step(self, batch: dict, batch_idx: int): - """runs test step + """Runs test step Parameters ---------- @@ -414,7 +376,7 @@ def test_step(self, batch: dict, batch_idx: int): self.metrics.buffer.append("test", loss_dict) def visualize_batch(self, images: torch.Tensor, probs: torch.Tensor, labels: torch.Tensor, split: str): - """generates example images of a given batch and saves to disk as a Matplotlib figure + """Generates example images of a given batch and saves to disk as a Matplotlib figure Parameters ---------- @@ -468,13 +430,12 @@ def visualize_batch(self, images: torch.Tensor, probs: torch.Tensor, labels: tor # should've been closed in viz.save_figure. this is double checking plt.close(fig) plt.close("all") - except: + except Exception: pass torch.cuda.empty_cache() - # self.viz_cnt[split] += 1 def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: - """runs forward pass, including GPU-based image augmentations + """Runs forward pass, including GPU-based image augmentations Parameters ---------- @@ -486,7 +447,7 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: Returns ------- Tuple[torch.Tensor, torch.Tensor] - [description] + (gpu_images, outputs) """ batch = self.validate_batch_size(batch) # lightning handles transfer to device @@ -513,7 +474,7 @@ def forward(self, batch: dict, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: return gpu_images, outputs def log_image_statistics(self, images: torch.Tensor): - """logs the min, mean, max, and std deviation of input tensors. useful for debugging + """Logs the min, mean, max, and std deviation of input tensors Parameters ---------- @@ -588,7 +549,7 @@ def get_criterion(cfg: DictConfig, model, data_info: dict, device=None): elif final_activation == "sigmoid": pos_weight = data_info["pos_weight"] - if type(pos_weight) == np.ndarray: + if isinstance(pos_weight, np.ndarray): pos_weight = torch.from_numpy(pos_weight) pos_weight = pos_weight.to(device) if device is not None else pos_weight data_criterion = BinaryFocalLoss( @@ -598,7 +559,6 @@ def get_criterion(cfg: DictConfig, model, data_info: dict, device=None): raise NotImplementedError regularization_criterion = get_regularization_loss(cfg, model) - criterion = ClassificationLoss(data_criterion, regularization_criterion) criterion = criterion.to(device) if device is not None else criterion @@ -625,8 +585,10 @@ def get_metrics( is_kinetics (bool): if true, don't make confusion matrices key_metric (str): the key metric will be used for learning rate scheduling and stopping - Returns: - Classification metrics object + Returns + ------- + Classification + metrics object """ metric_list = ["accuracy", "mean_class_accuracy", "f1"] if not is_kinetics: @@ -641,5 +603,4 @@ def get_metrics( if __name__ == "__main__": project_path = projects.get_project_path_from_cl(sys.argv) cfg = make_feature_extractor_train_cfg(project_path, use_command_line=True) - feature_extractor_train(cfg) diff --git a/deepethogram/file_io.py b/deepethogram/file_io.py index 48cc1f5..8a76837 100644 --- a/deepethogram/file_io.py +++ b/deepethogram/file_io.py @@ -13,10 +13,8 @@ def read_labels(labelfile: Union[str, os.PathLike]) -> np.ndarray: labeltype = os.path.splitext(labelfile)[1][1:] if labeltype == "csv": label = read_label_csv(labelfile) - # return(read_label_csv(labelfile)) elif labeltype == "h5": label = read_label_hdf5(labelfile) - # return(read_label_hdf5(labelfile)) else: raise ValueError("Unknown labeltype: {}".format(labeltype)) H, W = label.shape diff --git a/deepethogram/flow_generator/__init__.py b/deepethogram/flow_generator/__init__.py index d354269..e69de29 100644 --- a/deepethogram/flow_generator/__init__.py +++ b/deepethogram/flow_generator/__init__.py @@ -1 +0,0 @@ -# from . import train diff --git a/deepethogram/flow_generator/inference.py b/deepethogram/flow_generator/inference.py index 4bb0887..385968b 100644 --- a/deepethogram/flow_generator/inference.py +++ b/deepethogram/flow_generator/inference.py @@ -1,24 +1,24 @@ -from functools import partial import logging import os import shutil import sys +from functools import partial from typing import Union import cv2 import numpy as np -from omegaconf import OmegaConf, ListConfig import torch +from omegaconf import ListConfig, OmegaConf from torch.utils.data import DataLoader from tqdm import tqdm from vidio import VideoWriter -from deepethogram.configuration import make_feature_extractor_inference_cfg from deepethogram import projects, utils +from deepethogram.configuration import make_feature_extractor_inference_cfg from deepethogram.data.augs import get_cpu_transforms, get_gpu_transforms from deepethogram.data.datasets import VideoIterable from deepethogram.flow_generator.train import build_model_from_cfg as build_flow_generator -from deepethogram.flow_generator.utils import flow_to_rgb_polar, flow_to_rgb +from deepethogram.flow_generator.utils import flow_to_rgb, flow_to_rgb_polar log = logging.getLogger(__name__) @@ -79,7 +79,7 @@ def extract_movie( # b1=[1,2,3,4,5,6,7,8,9,10,11] # we will just hack it to take the first image. really, we should only run each batch once, then save all 11 # images in a row - if type(flows) == list or type(flows) == tuple: + if isinstance(flows, (list, tuple)): flows = flows[0] # only support batch size of 1 flow = flows[0, 8:10, ...] @@ -144,9 +144,9 @@ def flow_generator_inference(cfg): # figure out which videos to run inference on if directory_list is None or len(directory_list) == 0: raise ValueError( - "must pass list of directories from commmand line. " "Ex: directory_list=[path_to_dir1,path_to_dir2]" + "must pass list of directories from commmand line. Ex: directory_list=[path_to_dir1,path_to_dir2]" ) - elif type(directory_list) == str and directory_list == "all": + elif isinstance(directory_list, str) and directory_list == "all": basedir = cfg.project.data_path directory_list = utils.get_subfiles(basedir, "directory") elif isinstance(directory_list, str): @@ -170,9 +170,9 @@ def flow_generator_inference(cfg): for record in records: rgb.append(record["rgb"]) - assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( - "Flow generator inputs must be one greater " "than feature extractor num flows " - ) + assert ( + cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb + ), "Flow generator inputs must be one greater than feature extractor num flows " # set up gpu augmentation input_images = cfg.feature_extractor.n_flows + 1 mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" diff --git a/deepethogram/flow_generator/losses.py b/deepethogram/flow_generator/losses.py index 9f3f141..1471d4b 100644 --- a/deepethogram/flow_generator/losses.py +++ b/deepethogram/flow_generator/losses.py @@ -1,3 +1,10 @@ +"""Loss functions for optical flow estimation and motion prediction. + +This module provides various loss functions used in optical flow estimation and motion prediction, +including SSIM loss, gradient-based losses, and smoothness terms. The main class MotionNetLoss +combines these components for training motion prediction networks. +""" + import logging from math import exp @@ -8,22 +15,42 @@ log = logging.getLogger(__name__) -# SSIM from this repo: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py def gaussian(window_size, sigma): + """Create a 1D Gaussian window. + + Implementation based on: https://github.com/Po-Hsun-Su/pytorch-ssim + """ gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) return gauss / gauss.sum() -# SSIM from this repo: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py def create_window(window_size, channel): + """Create a 2D Gaussian window for SSIM calculation. + + Implementation based on: https://github.com/Po-Hsun-Su/pytorch-ssim + """ _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = torch.Tensor((_2D_window.expand(channel, 1, window_size, window_size).contiguous())) return window -# SSIM from this repo: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py def _ssim(img1, img2, window, window_size, channel, size_average=True): + """Calculate SSIM between two images. + + Implementation based on: https://github.com/Po-Hsun-Su/pytorch-ssim + + Args: + img1: First image tensor + img2: Second image tensor + window: Gaussian window for SSIM calculation + window_size: Size of the Gaussian window + channel: Number of channels in the images + size_average: If True, average SSIM across spatial dimensions + + Returns: + Tensor containing SSIM value(s) + """ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) @@ -47,6 +74,14 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): class SSIMLoss(torch.nn.Module): + """SSIM loss module for comparing structural similarity between images. + + Args: + window_size: Size of the Gaussian window for SSIM calculation + size_average: If True, average loss across spatial dimensions + denominator: Scaling factor for the final loss value + """ + def __init__(self, window_size=11, size_average=True, denominator=2): super(SSIMLoss, self).__init__() self.window_size = window_size @@ -55,7 +90,6 @@ def __init__(self, window_size=11, size_average=True, denominator=2): self.window = create_window(window_size, self.channel) self.denominator = denominator - # @profile def forward(self, img1, img2): (_, channel, _, _) = img1.size() @@ -77,27 +111,42 @@ def forward(self, img1, img2): # PyTorch is NCHW def gradient_x(img, mode="constant"): - # use indexing to get horizontal gradients, which chops off one column + """Calculate horizontal gradients of an image. + + Args: + img: Input image tensor in NCHW format + mode: Padding mode for boundary handling + + Returns: + Tensor containing horizontal gradients + """ gx = img[:, :, :, :-1] - img[:, :, :, 1:] - # pad the results with one zeros column on the right return F.pad(gx, (0, 1, 0, 0), mode=mode) def gradient_y(img, mode="constant"): - # use indexing to get vertical gradients, which chops off one row + """Calculate vertical gradients of an image. + + Args: + img: Input image tensor in NCHW format + mode: Padding mode for boundary handling + + Returns: + Tensor containing vertical gradients + """ gy = img[:, :, :-1, :] - img[:, :, 1:, :] - # pad the result with one zeros column on bottom return F.pad(gy, (0, 0, 0, 1), mode=mode) def get_gradients(img): + """Calculate both horizontal and vertical gradients of an image.""" gx = gradient_x(img) gy = gradient_y(img) return gx + gy -# simpler version of ssim loss, uses average pooling instead of guassian kernels def SSIM_simple(x, y): + """Calculate a simplified version of SSIM using average pooling instead of Gaussian kernels.""" C1 = 0.01**2 C2 = 0.03**2 @@ -116,8 +165,8 @@ def SSIM_simple(x, y): return torch.clamp((1 - SSIM_full), min=0, max=1) -# @profile def total_generalized_variation(image, flow): + """Calculate total generalized variation between image and flow fields.""" flowx = flow[:, 0:1, ...] flowy = flow[:, 1:, ...] @@ -140,6 +189,7 @@ def total_generalized_variation(image, flow): def smoothness_firstorder(image, flow): + """Calculate first-order smoothness term weighted by image gradients.""" flow_gradients_x = gradient_x(flow) flow_gradients_y = gradient_y(flow) @@ -159,14 +209,32 @@ def smoothness_firstorder(image, flow): def charbonnier(tensor, alpha=0.4, eps=1e-4): + """Apply Charbonnier penalty function.""" return (tensor * tensor + eps * eps) ** alpha def charbonnier_smoothness(flows, alpha=0.3, eps=1e-7): + """Calculate smoothness term using Charbonnier penalty on flow gradients.""" return charbonnier(gradient_x(flows), alpha=alpha, eps=eps) + charbonnier(gradient_y(flows), alpha=alpha, eps=eps) class MotionNetLoss(torch.nn.Module): + """Combined loss function for motion prediction networks. + + Combines reconstruction loss (L1 + SSIM), smoothness terms, and optional flow sparsity. + Supports multi-scale predictions with different weights at each scale. + + Args: + regularization_criterion: Loss function for model weight regularization + is_multiscale: Whether to compute loss at multiple scales + smooth_weights: List of weights for smoothness terms at each scale + highres: Whether to add an additional high-resolution scale + calculate_ssim_full: Whether to compute SSIM at full resolution + flow_sparsity: Whether to add sparsity penalty on flow predictions + sparsity_weight: Weight for the flow sparsity term + smooth_weight_multiplier: Global multiplier for smoothness weights + """ + def __init__( self, regularization_criterion, @@ -195,6 +263,18 @@ def __init__( self.regularization_criterion = regularization_criterion def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Module): + """Compute the combined loss. + + Args: + originals: Original input images + images: Target images at multiple scales + reconstructed: Reconstructed images at multiple scales + outputs: Flow predictions at multiple scales + model: The network model for regularization + + Returns: + tuple: (total_loss, dict of individual loss components) + """ if type(images) is not tuple: images = images if type(reconstructed) is not tuple: @@ -236,8 +316,6 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod if self.flow_sparsity: # use the same smoothness loss flow_l1s = [torch.mean(torch.abs(i), dim=[1, 2, 3]) * self.sparsity_weight for i in outputs] - # flow_l1s = [torch.mean(torch.abs(i), dim=[1, 2, 3]) * weight*self.sparsity_weight for - # i, weight in zip(outputs, weights)] # Note: adding a full-size SSIM for metrics only! if self.calculate_ssim_full: @@ -245,24 +323,15 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod N, C, H, W = originals.shape num_images = int(C / 3) - 1 recon_h, recon_w = reconstructed[0].size(-2), reconstructed[0].size(-1) - # print(originals.shape) - # print(reconstructed[0].shape) - # print(images[0].shape) + if H != recon_h or W != recon_w: - # t0 = originals[:, 0:3, ...] t0 = originals[:, : num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) - # print('t0: {}'.format(t0.shape)) recon = reconstructed[0] - # print('recon: {}'.format(recon.shape)) recon_fullsize = F.interpolate(recon, size=(H, W), mode="bilinear", align_corners=False) - # t0 = originals[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W) - # print('t0: ', t0.shape) - # recon_fullsize = F.interpolate(reconstructed[0], size=(H, W), mode='bilinear', align_corners=False) else: t0 = images[0] recon_fullsize = reconstructed[0] SSIM_full = self.ssim(t0, recon_fullsize) - # print('SSIM_FULL: ', SSIM_full.shape) SSIM_full_mean = SSIM_full.mean(dim=[1, 2, 3]) else: SSIM_full_mean = torch.from_numpy(np.array([np.nan])) @@ -291,9 +360,8 @@ def forward(self, originals, images, reconstructed, outputs, model: torch.nn.Mod ) if loss != loss: - import pdb - - pdb.set_trace() + msg = "Loss is NaN, re-run training with pytorch debug mode enabled" + raise ValueError(msg) if self.flow_sparsity: # sum across scales diff --git a/deepethogram/flow_generator/models/FlowNetS.py b/deepethogram/flow_generator/models/FlowNetS.py index ee6f91b..8155e41 100644 --- a/deepethogram/flow_generator/models/FlowNetS.py +++ b/deepethogram/flow_generator/models/FlowNetS.py @@ -28,10 +28,13 @@ .. [1]: Fischer et al. FlowNet: Learning optical flow with convolutional networks. ICCV 2015 https://arxiv.org/abs/1504.06852 """ + +import torch +import torch.nn as nn import torch.nn.functional as F from torch.nn import init -from .components import * +from .components import conv, deconv, get_hw, predict_flow class FlowNetS(nn.Module): @@ -67,15 +70,7 @@ def __init__(self, num_images=2, batchNorm=True, flow_div=1): self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) - # self.upsampled_flow6_to_5 = nn.Sequential(Interpolate(scale_factor=2), - # nn.Conv2d(2,2,kernel_size=3, stride=1, padding=1, bias=False)) - # self.upsampled_flow5_to_4 = nn.Sequential(Interpolate(scale_factor=2), - # nn.Conv2d(2,2,kernel_size=3, stride=1, padding=1, bias=False)) - # self.upsampled_flow4_to_3 = nn.Sequential(Interpolate(scale_factor=2), - # nn.Conv2d(2,2,kernel_size=3, stride=1, padding=1, bias=False)) - # self.upsampled_flow3_to_2 = nn.Sequential(Interpolate(scale_factor=2), - # nn.Conv2d(2,2,kernel_size=3, stride=1, padding=1, bias=False)) - + # initialize weights for m in self.modules(): if isinstance(m, nn.Conv2d): if m.bias is not None: @@ -86,7 +81,7 @@ def __init__(self, num_images=2, batchNorm=True, flow_div=1): if m.bias is not None: init.uniform_(m.bias) init.xavier_uniform_(m.weight) - # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode="bilinear") def forward(self, x): diff --git a/deepethogram/flow_generator/models/MotionNet.py b/deepethogram/flow_generator/models/MotionNet.py index bcf73bc..3451b6e 100644 --- a/deepethogram/flow_generator/models/MotionNet.py +++ b/deepethogram/flow_generator/models/MotionNet.py @@ -1,4 +1,4 @@ -""" Re-implementation of the MotionNet architecture +"""Re-implementation of the MotionNet architecture References ------- @@ -26,9 +26,10 @@ import logging +import torch.nn as nn from torch.nn import init -from .components import * +from .components import CropConcat, conv, deconv, i_conv, predict_flow log = logging.getLogger(__name__) @@ -97,8 +98,6 @@ def __init__(self, num_images=11, batchNorm=True, flow_div=1): # init_deconv_bilinear(m.weight) self.upsample1 = nn.Upsample(scale_factor=4, mode="bilinear") - # print('Flow div: {}'.format(self.flow_div)) - def forward(self, x): N, C, H, W = x.shape # 1 -> 1 @@ -133,11 +132,6 @@ def forward(self, x): # if the image sizes are not divisible by 8, there will be rounding errors in the size # between the downsampling and upsampling phases - # if get_hw(out_conv5) != get_hw(out_deconv5): - # out_conv5 = F.interpolate(out_conv5, size=get_hw(out_deconv5), - # mode='bilinear', align_corners=False) - - # concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) concat5 = self.concat((out_conv5, out_deconv5, flow6_up)) out_interconv5 = self.xconv5(concat5) flow5 = self.predict_flow5(out_interconv5) * self.flow_div @@ -145,11 +139,6 @@ def forward(self, x): flow5_up = self.upsampled_flow5_to_4(flow5) * 2 out_deconv4 = self.deconv4(concat5) - # if get_hw(out_conv4) != get_hw(out_deconv4): - # out_conv4 = F.interpolate(out_conv4, size=get_hw(out_deconv4), - # mode='bilinear', align_corners=False) - - # concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) concat4 = self.concat((out_conv4, out_deconv4, flow5_up)) out_interconv4 = self.xconv4(concat4) flow4 = self.predict_flow4(out_interconv4) * self.flow_div @@ -158,40 +147,14 @@ def forward(self, x): # if the image sizes are not divisible by 8, there will be rounding errors in the size # between the downsampling and upsampling phases - # if get_hw(out_conv3) != get_hw(out_deconv3): - # out_conv3 = F.interpolate(out_conv3, size=get_hw(out_deconv3), - # mode='bilinear', align_corners=False) - # concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) concat3 = self.concat((out_conv3, out_deconv3, flow4_up)) out_interconv3 = self.xconv3(concat3) flow3 = self.predict_flow3(out_interconv3) * self.flow_div flow3_up = self.upsampled_flow3_to_2(flow3) * 2 out_deconv2 = self.deconv2(concat3) - # if get_hw(out_conv2) != get_hw(out_deconv2): - # out_conv2 = F.interpolate(out_conv2, size=get_hw(out_deconv2), - # mode='bilinear', align_corners=False) - # concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) concat2 = self.concat((out_conv2, out_deconv2, flow3_up)) out_interconv2 = self.xconv2(concat2) flow2 = self.predict_flow2(out_interconv2) * self.flow_div - # flow1 = F.interpolate(flow2, (H, W), mode='bilinear', align_corners=False)*2 - # flow2*=self.flow_div - # flow3*=self.flow_div - # flow4*=self.flow_div - # flow5*=self.flow_div - # flow6*=self.flow_div - # print('Original shape: {}'.format((N,C,H,W))) - # print('flow1: {}'.format(flow1.shape)) - # print('flow2: {}'.format(flow2.shape)) - # print('flow3: {}'.format(flow3.shape)) - # print('flow4: {}'.format(flow4.shape)) - # print('flow5: {}'.format(flow5.shape)) - - # if self.training: - # return flow1, flow2, flow3, flow4 - # else: - # return flow1, - # return flow2, flow3, flow4, flow5, flow6 return flow2, flow3, flow4 diff --git a/deepethogram/flow_generator/models/TinyMotionNet.py b/deepethogram/flow_generator/models/TinyMotionNet.py index e722bfd..056df4f 100644 --- a/deepethogram/flow_generator/models/TinyMotionNet.py +++ b/deepethogram/flow_generator/models/TinyMotionNet.py @@ -1,4 +1,4 @@ -""" Re-implementation of the TinyMotionNet architecture +"""Re-implementation of the TinyMotionNet architecture References ------- @@ -23,10 +23,12 @@ Changes: changed filter sizes, number of input images, number of layers, added cropping or interpolation for non-power-of-two shaped images, and multiplication... only kept their naming convention and overall structure """ + import logging -# import warnings -from .components import * +import torch.nn as nn + +from .components import CropConcat, Interpolate, conv, deconv, i_conv, predict_flow log = logging.getLogger(__name__) @@ -46,7 +48,6 @@ def __init__(self, num_images=11, input_channels=None, batchNorm=True, output_ch else: self.output_channels = int(output_channels) - # self.out_channels = int((num_images-1)*2) self.batchNorm = batchNorm log.debug("ignoring flow div value of {}: setting to 1 instead".format(flow_div)) self.flow_div = 1 @@ -94,15 +95,4 @@ def forward(self, x): out_interconv2 = self.xconv2(concat2) flow2 = self.predict_flow2(out_interconv2) * self.flow_div - # flow1 = F.interpolate(flow2, (H, W), mode='bilinear', align_corners=False) * 2 - # flow2*=self.flow_div - # flow3*=self.flow_div - # flow4*=self.flow_div - # import pdb - # pdb.set_trace() - - # if self.training: - # return flow1, flow2, flow3, flow4 - # else: - # return flow1, return flow2, flow3, flow4 diff --git a/deepethogram/flow_generator/models/TinyMotionNet3D.py b/deepethogram/flow_generator/models/TinyMotionNet3D.py index 776aaa4..6988d98 100644 --- a/deepethogram/flow_generator/models/TinyMotionNet3D.py +++ b/deepethogram/flow_generator/models/TinyMotionNet3D.py @@ -17,10 +17,12 @@ Changes: 2D -> 3D. changed filter sizes, number of input images, number of layers... only kept their naming convention and overall structure """ + import logging -from .components import * -# import warnings +import torch.nn as nn + +from .components import CropConcat, conv3d, deconv3d, predict_flow_3d class TinyMotionNet3D(nn.Module): @@ -32,7 +34,6 @@ def __init__(self, num_images=11, input_channels=3, batchnorm=True, flow_div=1, else: self.input_channels = int(input_channels) - # self.out_channels = int((num_images-1)*2) self.batchnorm = batchnorm bias = not self.batchnorm logging.debug("ignoring flow div value of {}: setting to 1 instead".format(flow_div)) @@ -80,55 +81,34 @@ def __init__(self, num_images=11, input_channels=3, batchnorm=True, flow_div=1, self.predict_flow2 = predict_flow_3d(self.channels[0], 2) self.upsampled_flow4_to_3 = nn.ConvTranspose3d(2, 2, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)) - # self.upsampled_flow4_to_3 = nn.ConvTranspose3d(2, 2, kernel_size=(1,4,4), stride=(1,2,2), padding=1) self.upsampled_flow3_to_2 = nn.ConvTranspose3d(2, 2, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)) self.concat = CropConcat(dim=1) - # self.interpolate = Interpolate def forward(self, x): # N, C, T, H, W = x.shape out_conv1 = self.conv1(x) # 1 -> 1 - # print('out_conv1: {}'.format(out_conv1.shape)) out_conv2 = self.conv2(out_conv1) # 1 -> 1/2 - # print('out_conv2: {}'.format(out_conv2.shape)) out_conv3 = self.conv3(out_conv2) # 1/2 -> 1/4 - # print('out_conv3: {}'.format(out_conv3.shape)) out_conv4 = self.conv4(out_conv3) # 1/4 -> 1/8 - # print('out_conv4: {}'.format(out_conv4.shape)) out_conv5 = self.conv5(out_conv4) - # print('out_conv5: {}'.format(out_conv5.shape)) flow4 = self.predict_flow4(out_conv5) * self.flow_div - # print('flow4: {}'.format(flow4.shape)) # see motionnet.py for explanation of multiplying by 2 flow4_up = self.upsampled_flow4_to_3(flow4) * 2 - # print('flow4_up: {}'.format(flow4_up.shape)) out_deconv3 = self.deconv3(out_conv5) - # print('out_deconv3: {}'.format(out_deconv3.shape)) iconv3 = self.iconv3(out_conv3) - # print('iconv3: {}'.format(iconv3.shape)) concat3 = self.concat((iconv3, out_deconv3, flow4_up)) - # print('concat3: {}'.format(concat3.shape)) out_interconv3 = self.xconv3(concat3) - - # print('out_interconv3: {}'.format(out_interconv3.shape)) flow3 = self.predict_flow3(out_interconv3) * self.flow_div - # print('flow3: {}'.format(flow3.shape)) flow3_up = self.upsampled_flow3_to_2(flow3) * 2 - # print('flow3_up: {}'.format(flow3_up.shape)) out_deconv2 = self.deconv2(out_interconv3) - # print('out_deconv2: {}'.format(out_deconv2.shape)) iconv2 = self.iconv2(out_conv2) - # print('iconv2: {}'.format(iconv2.shape)) concat2 = self.concat((iconv2, out_deconv2, flow3_up)) - # print('concat2: {}'.format(concat2.shape)) out_interconv2 = self.xconv2(concat2) - # print('out_interconv2: {}'.format(out_interconv2.shape)) flow2 = self.predict_flow2(out_interconv2) * self.flow_div - # print('flow2: {}'.format(flow2.shape)) return flow2, flow3, flow4 diff --git a/deepethogram/flow_generator/models/components.py b/deepethogram/flow_generator/models/components.py index 8e1e73b..8831038 100644 --- a/deepethogram/flow_generator/models/components.py +++ b/deepethogram/flow_generator/models/components.py @@ -100,7 +100,7 @@ def __init__(self, dim: int = 1): self.dim = dim def forward(self, tensors: tuple) -> torch.Tensor: - assert type(tensors) == tuple + assert isinstance(tensors, tuple) hs, ws = [tensor.size(-2) for tensor in tensors], [tensor.size(-1) for tensor in tensors] h, w = min(hs), min(ws) @@ -115,7 +115,7 @@ def conv3d( bias: bool = True, batchnorm: bool = True, act: bool = True, - padding: tuple = None, + padding: Union[int, tuple, None] = None, ): """3D convolution @@ -147,9 +147,9 @@ def conv3d( nn.Sequential with conv3d, (batchnorm), (activation function) """ modules = [] - if padding is None and type(kernel_size) == int: + if padding is None and isinstance(kernel_size, int): padding = (kernel_size - 1) // 2 - elif padding is None and type(kernel_size) == tuple: + elif padding is None and isinstance(kernel_size, tuple): padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2, (kernel_size[2] - 1) // 2) else: raise ValueError("Unknown padding type {} and kernel_size type: {}".format(padding, kernel_size)) diff --git a/deepethogram/flow_generator/utils.py b/deepethogram/flow_generator/utils.py index 98183c0..336b7c1 100644 --- a/deepethogram/flow_generator/utils.py +++ b/deepethogram/flow_generator/utils.py @@ -1,5 +1,5 @@ import warnings -from typing import Union, Tuple +from typing import Tuple, Union import cv2 import numpy as np @@ -96,7 +96,6 @@ def flow_to_rgb_polar(flow: np.ndarray, maxval: Union[int, float] = 20) -> np.nd # magnitue -> saturation color = (mag.astype(np.float32) / maxval).clip(0, 1) color = (color * 255).clip(0, 255).astype(np.uint8) - # hsv[...,1] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX) hsv[..., 1] = color rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) return rgb @@ -126,30 +125,12 @@ def rgb_to_flow_polar(image: np.ndarray, maxval: Union[int, float] = 20): ang = hsv[..., 0] ang = ang * 2 * np.pi / 180 - # x,y = cv2.polarToCart(mag, ang) x = mag * np.cos(ang) y = mag * np.sin(ang) flow = np.stack((x, y), axis=2) return flow -# def flow_to_rgb_lrcn(flow, max_flow=10): -# # input: flow, can be positive or negative -# # ranges from -20 to 20, but only 10**-5 pixels are > 10 -# mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) -# mag[np.isinf(mag)] = 0 -# -# img = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.float32) -# half_range = (255 - 128) / max_flow -# img[:, :, 0] = flow[..., 0] * half_range + 128 -# img[:, :, 1] = flow[..., 1] * half_range + 128 -# # maximum magnitude is if x and y are both maxed -# max_magnitude = np.sqrt(max_flow ** 2 + max_flow ** 2) -# img[:, :, 2] = mag * 255 / max_magnitude -# img = img.clip(min=0, max=255).astype(np.uint8) -# return (img) - - class Resample2d(torch.nn.Module): """Module to sample tensors using Spatial Transformer Networks. Caches multiple grids in GPU VRAM for speed. @@ -201,13 +182,12 @@ def __init__( """ super().__init__() if size is not None: - assert type(size) == tuple or type(size) == list + assert isinstance(size, tuple) or isinstance(size, list) self.size = size # identity matrix self.base_mat = torch.Tensor([[1, 0, 0], [0, 1, 0]]) if fp16: - # self.base_mat = self.base_mat.half() pass self.fp16 = fp16 self.device = device @@ -236,7 +216,6 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: H, W = self.size else: H, W = flow.size(2), flow.size(3) - # print(H,W) # images: NxCxHxW # flow: Bx2xHxW grid_size = [flow.size(0), 2, flow.size(2), flow.size(3)] @@ -258,7 +237,6 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: self.sizes.append(this_size) self.grids.append(this_grid) self.uses.append(0) - # print(this_grid.shape) else: grid_loc = self.sizes.index(grid_size) this_grid = self.grids[grid_loc] @@ -269,7 +247,6 @@ def forward(self, images: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: # this normalizes it so that a value of 2 would move a pixel all the way across the width or height # horiz_only: for stereo matching, Y values are always the same if self.horiz_only: - # flow = flow[:, 0:1, :, :] / ((W - 1.0) / 2.0) flow = torch.cat( [flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), torch.zeros((flow.size(0), flow.size(1), H, W))], 1 ) @@ -323,7 +300,6 @@ def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list t0s = [] flows_reshaped = [] for flow in flows: - # upsampled_flow = F.interpolate(flow, (h,w), mode='bilinear', align_corners=False) if flow.ndim == 4: n, c, h, w = flow.size() flow = flow.view(N * num_images, 2, h, w) @@ -390,7 +366,6 @@ def rgb_to_hsv_torch(image: torch.Tensor) -> torch.Tensor: S = torch.zeros_like(r) S[V > 0] = (C / V)[V > 0] - # hsv = torch.stack([H,S,V], dim=-3) return torch.stack([H, S, V], dim=1) diff --git a/deepethogram/gui/custom_widgets.py b/deepethogram/gui/custom_widgets.py index bc44c45..339f459 100644 --- a/deepethogram/gui/custom_widgets.py +++ b/deepethogram/gui/custom_widgets.py @@ -29,14 +29,12 @@ def numpy_to_qpixmap(image: np.ndarray) -> QtGui.QPixmap: else: raise ValueError("Aberrant number of channels: {}".format(C)) qpixmap = QtGui.QPixmap(QtGui.QImage(image, W, H, image.strides[0], format)) - # print(type(qpixmap)) return qpixmap def float_to_uint8(image: np.ndarray) -> np.ndarray: if image.dtype == np.float: image = (image * 255).clip(min=0, max=255).astype(np.uint8) - # print(image) return image @@ -52,12 +50,9 @@ class VideoFrame(QtWidgets.QGraphicsView): def __init__(self, videoFile: Union[str, os.PathLike] = None, *args, **kwargs): super().__init__(*args, **kwargs) - # self.videoView = QtWidgets.QGraphicsView() - self._scene = QtWidgets.QGraphicsScene(self) self._photo = QtWidgets.QGraphicsPixmapItem() self._scene.addItem(self._photo) - # self.videoView.setScene(self._scene) self.setScene(self._scene) sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) @@ -66,24 +61,19 @@ def __init__(self, videoFile: Union[str, os.PathLike] = None, *args, **kwargs): sizePolicy.setHeightForWidth(self.sizePolicy().hasHeightForWidth()) self.setSizePolicy(sizePolicy) self.setMinimumSize(QtCore.QSize(640, 480)) - # self.setObjectName("videoView") self._zoom = 0 if videoFile is not None: self.initialize_video(videoFile) self.update() self.setStyleSheet("background:transparent;") - # print(self.palette()) def initialize_video(self, videofile: Union[str, os.PathLike]): if hasattr(self, "vid"): self.vid.close() - # if hasattr(self.vid, 'cap'): - # self.vid.cap.release() self.videofile = videofile self.vid = VideoReader(videofile) - # self.frame = next(self.vid) self.setDragMode(QtWidgets.QGraphicsView.ScrollHandDrag) self.initialized.emit(len(self.vid)) self.update_frame(0) @@ -123,22 +113,15 @@ def wheelEvent(self, event): self.fitInView() def update_frame(self, value, force: bool = False): - # print('updating') - # print('update to: {}'.format(value)) - # print(self.current_fnum) - # previous_frame = self.current_fnum if not hasattr(self, "vid"): return value = int(value) if hasattr(self, "current_fnum"): if self.current_fnum == value and not force: - # print('already there') return if value < 0: - # warnings.warn('Desired frame less than 0: {}'.format(value)) value = 0 if value >= self.vid.nframes: - # warnings.warn('Desired frame beyond maximum: {}'.format(self.vid.nframes)) value = self.vid.nframes - 1 self.frame = self.vid[value] @@ -147,7 +130,6 @@ def update_frame(self, value, force: bool = False): # position is 1. This makes cv2.CAP_PROP_POS_FRAMES match vid.fnum. However, we want to keep track of our # currently displayed image, which is fnum - 1 self.current_fnum = self.vid.fnum - 1 - # print('new fnum: {}'.format(self.current_fnum)) self.show_image(self.frame) self.frameNum.emit(self.current_fnum) @@ -155,13 +137,11 @@ def fitInView(self, scale=True): rect = QtCore.QRectF(self._photo.pixmap().rect()) if not rect.isNull(): self.setSceneRect(rect) - # if self.hasPhoto(): unity = self.transform().mapRect(QtCore.QRectF(0, 0, 1, 1)) self.scale(1 / unity.width(), 1 / unity.height()) viewrect = self.viewport().rect() scenerect = self.transform().mapRect(rect) factor = min(viewrect.width() / scenerect.width(), viewrect.height() / scenerect.height()) - # print(factor, viewrect, scenerect) self.scale(factor, factor) self._zoom = 0 @@ -182,14 +162,11 @@ def adjust_aspect_ratio(self): def show_image(self, array): qpixmap = numpy_to_qpixmap(array) self._photo.setPixmap(qpixmap) - # self.fitInView() self.update() - # self.show() def resizeEvent(self, event): if hasattr(self, "vid"): pass - # self.fitInView() def mouseDoubleClickEvent(self, event: QtGui.QMouseEvent) -> None: self.fitInView() @@ -242,13 +219,11 @@ def __init__(self, *args, **kwargs): self.plainTextEdit.setObjectName("plainTextEdit") self.horizontalLayout.addWidget(self.plainTextEdit) self.setLayout(self.horizontalLayout) - # self.ui.plainTextEdit.textChanged.connect self.plainTextEdit.textChanged.connect(self.text_change) self.horizontalScrollBar.sliderMoved.connect(self.scrollbar_change) self.horizontalScrollBar.valueChanged.connect(self.scrollbar_change) self.update() - # self.show() def sizeHint(self): return QtCore.QSize(480, 25) @@ -272,15 +247,10 @@ def update_state(self, value: int): @Slot(int) def initialize_state(self, value: int): - # print('nframes: ', value) self.horizontalScrollBar.setMaximum(value - 1) self.horizontalScrollBar.setMinimum(0) - # self.horizontalScrollBar.sliderMoved.connect(self.scrollbar_change) - # self.horizontalScrollBar.valueChanged.connect(self.scrollbar_change) self.horizontalScrollBar.setValue(0) self.plainTextEdit.setPlainText("{}".format(0)) - # self.plainTextEdit.textChanged.connect(self.text_change) - # self.update() class VideoPlayer(QtWidgets.QWidget): @@ -302,15 +272,12 @@ def __init__(self, parent=None, videoFile: Union[str, os.PathLike] = None, *args self.setLayout(layout) # if you use the scrollbar or the text box, update the video frame - # self.scrollbartext.horizontalScrollBar.sliderMoved.connect(self.videoView.update_frame) - # self.scrollbartext.horizontalScrollBar.valueChanged.connect(self.videoView.update_frame) - # self.scrollbartext.plainTextEdit.textChanged.connect(self.videoView.update_frame) + self.scrollbartext.position.connect(self.videoView.update_frame) self.scrollbartext.position.connect(self.scrollbartext.update_state) # if you move the video by any method, update the frame text self.videoView.initialized.connect(self.scrollbartext.initialize_state) - # self.videoView.initialized.connect(initializer) self.videoView.frameNum.connect(self.scrollbartext.update_state) # I have to do this here because I think emitting a signal doesn't work from within the widget's constructor @@ -320,39 +287,6 @@ def __init__(self, parent=None, videoFile: Union[str, os.PathLike] = None, *args self.update() -# class LabelImage(QtWidgets.QScrollArea): -# def __init__(self, parent=None, *args, **kwargs): -# super().__init__(*args, **kwargs) -# -# layout = QtWidgets.QHBoxLayout() -# self.widget = QtWidgets.QWidget() -# -# buttonlayout = QtWidgets.QVBoxLayout() -# self.labels = [] -# sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Maximum) -# sizePolicy.setHorizontalStretch(0) -# sizePolicy.setVerticalStretch(0) -# for i in range(100): -# self.labels.append(QtWidgets.QLabel('testing{}'.format(i))) -# self.labels[i].setMinimumHeight(25) -# buttonlayout.addWidget(self.labels[i]) -# # self.labels[i].setLayout(buttonlayout) -# -# self.widget.setLayout(buttonlayout) -# self.setWidget(self.widget) -# -# self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn) -# self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) -# -# self.update() -# -# def sizeHint(self): -# return (QtCore.QSize(720, 250)) -# https://stackoverflow.com/questions/29643352/converting-hex-to-rgb-value-in-python - -# start = np.array([232,232,232]) - - class LabelViewer(QtWidgets.QGraphicsView): X = Signal(int) saved = Signal(bool) @@ -368,14 +302,11 @@ def __init__(self, fixed: bool = False, *args, **kwargs): self._scene = QtWidgets.QGraphicsScene(self) self._photo = QtWidgets.QGraphicsPixmapItem() self._scene.addItem(self._photo) - # self.videoView.setScene(self._scene) self.setScene(self._scene) color = QtGui.QColor(45, 45, 45) self.pen = QtGui.QPen(color, 0) self.setAlignment(QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft) - # self.setAlignment(QtCore.Qt.AlignCenter) - # self.setStyleSheet("background:transparent;") self.initialized = False self.fixed = fixed @@ -390,14 +321,13 @@ def initialize( colormap: str = "Reds", unlabeled_alpha: float = 0.1, desired_pixel_size: int = 25, - array: np.ndarray = None, + array: Union[np.ndarray, None] = None, fixed: bool = False, - opacity: np.ndarray = None, + opacity: Union[np.ndarray, None] = None, ): if self.initialized: raise ValueError("only initialize once!") if array is not None: - # print(array.shape) self.n_timepoints = array.shape[0] self.n = array.shape[1] # if our input array is -1s, assume that this has not been labeled yet @@ -419,7 +349,7 @@ def initialize( try: self.cmap = Mapper(colormap) except ValueError: - raise ("Colormap not in matplotlib" "s defaults! {}".format(colormap)) + raise ("Colormap not in matplotlibs defaults! {}".format(colormap)) if self.debug: self.make_debug() @@ -428,14 +358,10 @@ def initialize( self.recreate_label_image() pos_colors = self.cmap(np.ones((self.n, 1)) * 255) neg_colors = self.cmap(np.zeros((self.n, 1))) - # print('N: {}'.format(self.n)) self.pos_color = np.array([pos_colors[i].squeeze() for i in range(self.n)]) self.neg_color = np.array([neg_colors[i].squeeze() for i in range(self.n)]) - # print('pos, neg: {}, {}'.format(self.pos_color, self.neg_color)) - draw_rect = QtCore.QRectF(0, 0, 1, self.n) - # print(dir(self.draw_rect)) self.item_rect = self._scene.addRect(draw_rect, self.pen) self.change_view_x(0) @@ -449,11 +375,9 @@ def initialize( def mousePressEvent(self, event): if not self.initialized: return - # print(dir(event)) scene_pos = self.mapToScene(event.pos()) x, y = scene_pos.x(), scene_pos.y() - # print('X: {} Y: {}'.format(x,y)) x, y = int(x), int(y) value = self.array[x, y] if value == 0: @@ -471,9 +395,7 @@ def mouseMoveEvent(self, event): x, _ = scene_pos.x(), scene_pos.y() y = self.initial_row - # print('X: {} Y: {}'.format(x,y)) x, y = int(x), int(y) - # value = self.array[x, y] if x > self.initial_column: self._add_behavior([y], x, x) @@ -495,7 +417,6 @@ def _fit_label_photo(self): self.view_x = 0 # gets the bounding rectangle (in pixels) for the image of the label array geometry = self.geometry() - # print(geometry) widget_width, widget_height = geometry.width(), geometry.height() num_pix_high = widget_height / self.desired_pixel_size @@ -503,11 +424,9 @@ def _fit_label_photo(self): new_height = num_pix_high new_width = new_height * aspect - # print('W: {} H: {}'.format(new_width, new_height)) rect = QtCore.QRectF(self.view_x, 0, new_width, new_height) self.fitInView(rect) - # self.fitInView(rect, aspectRadioMode=QtCore.Qt.KeepAspectRatio) self.view_height = new_height self.view_width = new_width self.update() @@ -519,40 +438,29 @@ def resizeEvent(self, event: QtGui.QResizeEvent): @Slot(int) def change_view_x(self, x: int): if x < 0 or x >= self.n_timepoints: - # print('return 1') return if not hasattr(self, "view_width"): self._fit_label_photo() if not hasattr(self, "n"): - # print('return 2') return view_x = x - self.view_width // 2 if view_x < 0: - # print('desired view x: {} LEFT SIDE'.format(view_x)) new_x = 0 elif view_x >= self.n_timepoints: - # print('desired view x: {} RIGHT SIDE'.format(view_x)) new_x = self.n_timepoints - 1 else: new_x = view_x - # new_x = max(view_x, 0) - # new_x = min(new_x, self.n_timepoints - 1) - old_x = self.x self.view_x = new_x self.x = x position = QtCore.QPointF(x, 0) - # print('view width: {}'.format(self.view_width)) - # print('new_x: {}'.format(new_x)) - # print('rec_x: {}'.format(position)) self.item_rect.setPos(position) self.X.emit(self.x) rect = QtCore.QRectF(self.view_x, 0, self.view_width, self.view_height) - # print('View rectangle: {}'.format(rect)) self.fitInView(rect) behaviors = [] @@ -562,9 +470,7 @@ def change_view_x(self, x: int): if len(behaviors) > 0: self._add_behavior(behaviors, old_x, x) - # self._fit_label_photo() self.update() - # self.show() def fixed_settings(self): if not hasattr(self, "changed"): @@ -573,13 +479,12 @@ def fixed_settings(self): self.recreate_label_image() def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fend: int): - # print('adding') if self.fixed: return if not hasattr(self, "array"): return n_behaviors = self.image.shape[0] - if type(behaviors) != np.ndarray: + if isinstance(behaviors, (list, np.ndarray)): behaviors = np.array(behaviors) if max(behaviors) > n_behaviors: raise ValueError("Not enough behaviors for number: {}".format(behaviors)) @@ -587,32 +492,24 @@ def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fe raise ValueError("Behavior start frame must be > 0: {}".format(fstart)) if fend > self.n_timepoints: raise ValueError("Behavior end frame must be < nframes: {}".format(fend)) - # log.debug('Behaviors: {} fstart: {} fend: {}'.format(behaviors, fstart, fend)) # go backwards to erase if fstart <= fend: value = 1 time_indices = np.arange(fstart, fend + 1) # want it to be color = self.pos_color - # print('value = 1') elif fstart - fend == 1: value = 0 time_indices = np.array([fend, fstart]) color = self.neg_color else: - # print('value = 0') value = 0 time_indices = np.arange(fstart, fend, -1) color = self.neg_color - # log.debug('time indices: {} value: {}'.format(time_indices, value)) - # handle background specifically if len(behaviors) == 1 and behaviors[0] == 0: - # print('0') self.array[time_indices, 0] = 1 self.array[time_indices, 1:] = 0 - # print('l shape: {}'.format(self.image[1:, time_indices, :].shape)) - # print('r_shape: {}'.format(np.tile(self.neg_color[1:], [1, len(time_indices), 1]).shape)) self.image[0, time_indices, :] = self.pos_color[0] self.image[1:, time_indices, :] = np.dstack( [self.neg_color[1:] for _ in range(len(time_indices))] @@ -621,8 +518,6 @@ def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fe xv, yv = np.meshgrid(time_indices, behaviors, indexing="ij") xv = xv.flatten() yv = yv.flatten() - # log.debug('xv: {} yv: {}'.format(xv, yv)) - # print('yv: {}'.format(yv)) self.array[xv, yv] = value # change color self.image[yv, xv, :] = color[yv] @@ -630,18 +525,12 @@ def _add_behavior(self, behaviors: Union[int, np.ndarray, list], fstart: int, fe self.array[time_indices, 0] = np.logical_not(np.any(self.array[time_indices, 1:], axis=1)) # remap the color for the background column just in case self.image[0, time_indices, :] = self.cmap(self.array[time_indices, 0:1].T * 255).squeeze() - # mapped = self.cmap(self.array[time_indices, 0] * 255) - # print('mapped in add behavior: {}'.format(mapped.shape)) - # self.image[0, time_indices, :] = mapped - # print(self.label.image[0,time_indices]) # change opacity self.image[:, time_indices, 3] = 1 self.changed[time_indices] = 1 # change opacity - # self.label.image[:, indices, 3] = 1 self.saved.emit(False) - # self.label.image = im self.update_image() self.num_changed.emit(self.changed.sum()) @@ -677,21 +566,16 @@ def make_debug(self, num_rows: int = 15000): print("debug") assert hasattr(self, "array") rows, cols = self.shape - # print(rows, cols) - # behav = 0 + for i in range(rows): behav = i % cols self.array[i, behav] = 1 - # self.array = self.array[:num_rows,:] - # print(self.array) def calculate_background_class(self, array: np.ndarray): array[:, 0] = np.logical_not(np.any(array[:, 1:], axis=1)) return array def update_background_class(self): - # import pdb - # pdb.set_trace() self.array = self.calculate_background_class(self.array) def update_image(self): @@ -701,19 +585,12 @@ def update_image(self): self.update() def recreate_label_image(self): - # print('array input shape, will be transposed: {}'.format(self.array.shape)) self.image = self.cmap(self.array.T * 255) if self.opacity is None: opacity = np.ones((self.image.shape[0], self.image.shape[1])) * self.unlabeled_alpha opacity[:, np.where(self.changed)[0]] = 1 else: opacity = self.opacity.copy() - # print('image: {}'.format(self.image)) - # print('image shape in recreate label image: {}'.format(self.image.shape)) - # print('opacity: {}'.format(opacity)) - # print('opacity shape in recreate label image: {}'.format(opacity.shape)) - - # print('chang: {}'.format(self.changed.shape)) self.image[..., 3] = opacity self.update_image() @@ -740,7 +617,6 @@ def toggle_behavior(self, index: int): self.update_background_class() self.recreate_label_image() self.change_view_x(self.x) - # print(self.changed) self.just_toggled.emit(index) self.update() @@ -788,7 +664,7 @@ def _make_button(self, behavior: str, index: int): button.setMinimumHeight(self.minimum_height) button.setCheckable(True) button.setStyleSheet( - "QPushButton { text-align: left; }" "QPushButton:checked { background-color: rgb(30, 30, 30)}" + "QPushButton { text-align: left; }QPushButton:checked { background-color: rgb(30, 30, 30)}" ) return button @@ -854,11 +730,10 @@ def initialize( opacity: np.ndarray = None, ): layout = QtWidgets.QHBoxLayout() - # assert (n == len(behaviors)) assert behaviors[0] == "background" self.label = LabelViewer() - # print(behaviors) + self.behaviors = behaviors self.n = len(self.behaviors) self.label.initialize( @@ -903,7 +778,6 @@ def add_behavior(self, behavior: str): print("2: {}".format(self.behaviors)) print("2 buttons: {}".format(self.buttons.behaviors)) # add to our list of behaviors - # self.behaviors.append(behavior) print("3: {}".format(self.behaviors)) # hook up button to toggling behavior i = len(self.behaviors) - 1 @@ -923,7 +797,6 @@ class ListenForPipeCompletion(QtCore.QThread): def __init__(self, pipe): QtCore.QThread.__init__(self) - # super().__init__(self) self.pipe = pipe def __del__(self): @@ -934,7 +807,6 @@ def run(self): time.sleep(1) if self.pipe.poll() is None: pass - # print('still running...') else: self.has_finished.emit(True) break @@ -944,13 +816,12 @@ class SubprocessChainer(QtCore.QThread): def __init__(self, calls: list): QtCore.QThread.__init__(self) for call in calls: - assert type(call) == list + assert isinstance(call, list) self.calls = calls self.should_continue = True def stop(self): self.should_continue = False - # self.pipe.terminate() def run(self): for call in self.calls: @@ -964,23 +835,9 @@ def run(self): break -# def chained_subprocess_calls(calls: list) -> None: -# def _run(calls): -# for call in calls: -# assert type(call) == list -# -# for call in calls: -# print(call) -# pipe = subprocess.run(call, shell=True) -# thread = threading.Thread(target=_run, args=(calls,)) -# thread.start() -# return thread - - class UnclickButtonOnPipeCompletion(QtCore.QThread): def __init__(self, button, pipe): QtCore.QThread.__init__(self) - # super().__init__(self) self.button = button self.pipe = pipe self.should_continue = True @@ -1000,12 +857,9 @@ def run(self): time.sleep(1) if self.pipe.poll() is None: pass - # print('still running...') else: if not self.has_been_clicked: - # print('ischecked: ', self.button.isChecked()) if self.button.isChecked(): - # print('listener clicking button') self.button.click() break @@ -1021,30 +875,12 @@ def __init__(self): fixed=False, ) - # self.label = LabelViewer() - # self.label.initialize(n=4, n_timepoints=40, debug=True, fixed=True) - # # self.labelImg = DebuggingDrawing() - # next_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Right"), self) next_shortcut.activated.connect(partial(self.label.label.change_view_dx, 1)) # next_shortcut.activated.connect(partial(self.label.change_view_dx, 1)) back_shortcut = QtWidgets.QShortcut(QtGui.QKeySequence("Left"), self) back_shortcut.activated.connect(partial(self.label.label.change_view_dx, -1)) - # - # if hasattr(self, 'label'): - # n = self.label.n - # else: - # n = 1 - # self.toggle_shortcuts = [] - # for i in range(n): - # self.toggle_shortcuts.append(QtWidgets.QShortcut(QtGui.QKeySequence(str(i)), self)) - # self.toggle_shortcuts[i].activated.connect(partial(self.label.toggle_behavior, i)) - - # self.buttons = LabelButtons(behaviors = ['background', 'itch', 'scratch', 'poop']) - # back_shortcut.activated.connect(partial(self.labelImg.move_rect, -1)) - - # self.labelImg.make_debug(10) self.setCentralWidget(self.label) self.setMaximumHeight(480) @@ -1056,15 +892,9 @@ def sizeHint(self): if __name__ == "__main__": app = QtWidgets.QApplication([]) - # volume = VideoPlayer(r'C:\DATA\mouse_reach_processed\M134_20141203_v001.h5') + testing = LabelImg() testing.initialize(behaviors=["background", "a", "b", "c", "d", "e"], n_timepoints=15000, debug=True) - # testing = ShouldRunInference(['M134_20141203_v001', - # 'M134_20141203_v002', - # 'M134_20141203_v004'], - # [True, True, False]) - # testing = MainWindow() - # testing.setMaximumHeight(250) testing.update() testing.show() app.exec_() diff --git a/deepethogram/gui/main.py b/deepethogram/gui/main.py index 0fd0775..0bbb373 100644 --- a/deepethogram/gui/main.py +++ b/deepethogram/gui/main.py @@ -6,20 +6,19 @@ from functools import partial from typing import Union -# import hydra import numpy as np import pandas as pd -from PySide2 import QtCore, QtWidgets, QtGui -from PySide2.QtCore import Slot -from PySide2.QtWidgets import QMainWindow, QFileDialog, QInputDialog from omegaconf import DictConfig, OmegaConf +from PySide2 import QtCore, QtGui, QtWidgets +from PySide2.QtCore import Slot +from PySide2.QtWidgets import QFileDialog, QInputDialog, QMainWindow -from deepethogram import projects, utils, configuration +from deepethogram import configuration, projects, utils from deepethogram.file_io import VideoReader -from deepethogram.postprocessing import get_postprocessor_from_cfg -from deepethogram.gui.custom_widgets import UnclickButtonOnPipeCompletion, SubprocessChainer +from deepethogram.gui.custom_widgets import SubprocessChainer, UnclickButtonOnPipeCompletion from deepethogram.gui.mainwindow import Ui_MainWindow -from deepethogram.gui.menus_and_popups import CreateProject, simple_popup_question, ShouldRunInference, overwrite_or_not +from deepethogram.gui.menus_and_popups import CreateProject, ShouldRunInference, overwrite_or_not, simple_popup_question +from deepethogram.postprocessing import get_postprocessor_from_cfg log = logging.getLogger(__name__) @@ -42,10 +41,8 @@ def __init__(self, cfg: DictConfig): self.ui.setupUi(self) self.setWindowTitle("DeepEthogram") - # print(dir(self.ui.actionOpen)) self.ui.videoBox.setLayout(self.ui.formLayout) self.ui.actionOpen.triggered.connect(self.open_avi_browser) - # self.ui.plainTextEdit.textChanged.connect(self.text_change) self.ui.actionAdd.triggered.connect(self.add_class) self.ui.actionRemove.triggered.connect(self.remove_class) self.ui.actionNew_Project.triggered.connect(self._new_project) @@ -127,20 +124,16 @@ def __init__(self, cfg: DictConfig): if os.path.isfile(os.path.join(initialized_directory, "project_config.yaml")): self.initialize_project(initialized_directory) - # log.info('children: {}'.format(self.children())) self.show() def user_did_something(self): if self.timer.isActive(): pass else: - # else, the user was already idle - # will have a timestamp by the logfile log.info("User restarted labeling") self.timer.start() def keyPressEvent(self, event: QtGui.QKeyEvent): - # print('key pressed') self.user_did_something() super().keyPressEvent(event) @@ -153,11 +146,9 @@ def log_idle(self): self.timer.stop() def respond_to_keypress(self, keynum: int): - # print('key pressed') if self.ui.labels.label is not None: self.ui.labels.label.toggle_behavior(keynum) else: - # print('none') return def has_trained(self, model_type: str) -> bool: @@ -175,7 +166,6 @@ def project_loaded_buttons(self): number_finalized_labels = projects.get_number_finalized_labels(self.cfg) log.info("Number finalized labels: {}".format(number_finalized_labels)) if self.has_trained("flow_generator"): - # self.ui.flow_inference.setEnabled(True) self.ui.flow_train.setEnabled(True) if self.has_trained("feature_extractor") or number_finalized_labels > 1: self.ui.featureextractor_infer.setEnabled(True) @@ -209,14 +199,10 @@ def video_loaded_buttons(self): def initialize_video(self, videofile: Union[str, os.PathLike]): if hasattr(self, "vid"): self.vid.close() - # if hasattr(self.vid, 'cap'): - # self.vid.cap.release() self.videofile = videofile try: self.ui.videoPlayer.videoView.initialize_video(videofile) - # for convenience extract the videoplayer object out of the videoView self.vid = self.ui.videoPlayer.videoView.vid - # for convenience self.n_timepoints = len(self.ui.videoPlayer.videoView.vid) log.debug("is deg: {}".format(projects.is_deg_file(videofile))) @@ -272,7 +258,8 @@ def update_video_info(self): duration = nframes / fps fps = "{:.2f}".format(fps) duration = "{:.2f}".format(duration) - except: + except Exception as e: + log.exception("Error getting video info: {}".format(e)) fps = "N/A" duration = "N/A" num_labeled = self.ui.labels.label.changed.sum() @@ -301,9 +288,6 @@ def initialize_label(self, label_array: np.ndarray = None, debug: bool = False): array=label_array, colormap=self.cfg.cmap, ) - # we never want to connect signals to slots more than once - log.debug("initialized label: {}".format(self.initialized_label)) - # if not self.initialized_label: self.ui.videoPlayer.videoView.frameNum.connect(self.ui.labels.label.change_view_x) self.ui.labels.label.saved.connect(self.update_saved) self.initialized_label = True @@ -312,7 +296,6 @@ def initialize_label(self, label_array: np.ndarray = None, debug: bool = False): def initialize_prediction( self, prediction_array: np.ndarray = None, debug: bool = False, opacity: np.ndarray = None ): - # do all the setup for labels and predictions self.ui.predictions.initialize( behaviors=OmegaConf.to_container(self.cfg.project.class_names), n_timepoints=self.n_timepoints, @@ -322,9 +305,7 @@ def initialize_prediction( opacity=opacity, colormap=self.cfg.cmap, ) - # if not self.initialized_prediction: self.ui.videoPlayer.videoView.frameNum.connect(self.ui.predictions.label.change_view_x) - # we don't want to be able to manually edit the predictions self.ui.predictions.buttons.fix() self.initialized_prediction = True self.update() @@ -340,7 +321,6 @@ def generate_flow_train_args(self): def flow_train(self): if self.ui.flow_train.isChecked(): - # self.ui.flow_inference.setEnabled(False) self.ui.featureextractor_train.setEnabled(False) self.ui.featureextractor_infer.setEnabled(False) self.ui.sequence_infer.setEnabled(False) @@ -358,7 +338,6 @@ def flow_train(self): log.info("Training interrupted.") else: log.info("Training finished. If you see error messages above, training did not complete successfully.") - # self.train_thread.terminate() del self.training_pipe self.listener.quit() self.listener.wait() @@ -371,7 +350,6 @@ def flow_train(self): def featureextractor_train(self): if self.ui.featureextractor_train.isChecked(): self.ui.flow_train.setEnabled(False) - # self.ui.flow_inference.setEnabled(False) self.ui.featureextractor_infer.setEnabled(False) self.ui.sequence_infer.setEnabled(False) self.ui.sequence_train.setEnabled(False) @@ -382,14 +360,12 @@ def featureextractor_train(self): "deepethogram.feature_extractor.train", "project.path={}".format(self.cfg.project.path), ] - print(self.get_selected_models()) weights = self.get_selected_models()["feature_extractor"] - # print(weights) if weights is None: raise ValueError(pretrained_models_error) if os.path.isfile(weights): args += ["feature_extractor.weights={}".format(weights)] - flow_weights = self.get_selected_models()["flow_generator"] # ('flow_generator') + flow_weights = self.get_selected_models()["flow_generator"] assert flow_weights is not None args += ["flow_generator.weights={}".format(flow_weights)] log.info("feature extractor train called with args: {}".format(args)) @@ -408,12 +384,8 @@ def featureextractor_train(self): self.listener.wait() del self.listener log.info("~" * 100) - # self.ui.flow_train.setEnabled(True) self.project_loaded_buttons() self.get_trained_models() - # self.listener.stop() - - # self.ui.featureextractor_infer.setEnabled(True) def generate_featureextractor_inference_args(self): records = projects.get_records_from_datadir(self.data_path) @@ -426,11 +398,10 @@ def generate_featureextractor_inference_args(self): if not ret: return should_infer = form.get_outputs() - all_false = np.all(np.array(should_infer) == False) + all_false = np.all(np.array(should_infer) == False) # noqa: E712 if all_false: return self.ui.flow_train.setEnabled(False) - # self.ui.flow_inference.setEnabled(False) self.ui.featureextractor_train.setEnabled(False) weights = self.get_selected_models()["feature_extractor"] if weights is not None and os.path.isfile(weights): @@ -489,16 +460,12 @@ def featureextractor_infer(self): self.outputfile = None self.import_outputfile(self.outputfile) - # self.ui.featureextractor_infer.setEnabled(True) - def sequence_train(self): if self.ui.sequence_train.isChecked(): self.ui.flow_train.setEnabled(False) - # self.ui.flow_inference.setEnabled(False) self.ui.featureextractor_train.setEnabled(False) self.ui.featureextractor_infer.setEnabled(False) self.ui.sequence_infer.setEnabled(False) - # self.ui.sequence_train.setEnabled(False) args = ["python", "-m", "deepethogram.sequence.train", "project.path={}".format(self.cfg.project.path)] weights = self.get_selected_models()["sequence"] if weights is not None and os.path.isfile(weights): @@ -507,7 +474,6 @@ def sequence_train(self): self.listener = UnclickButtonOnPipeCompletion(self.ui.sequence_train, self.training_pipe) self.listener.start() else: - # self.train_thread.terminate() if self.training_pipe.poll() is None: self.training_pipe.terminate() self.training_pipe.wait() @@ -519,10 +485,8 @@ def sequence_train(self): self.listener.wait() del self.listener log.info("~" * 100) - # self.ui.flow_train.setEnabled(True) self.project_loaded_buttons() self.get_trained_models() - # self.ui.featureextractor_infer.setEnabled(True) def generate_sequence_inference_args(self): records = projects.get_records_from_datadir(self.data_path) @@ -532,7 +496,6 @@ def generate_sequence_inference_args(self): if sequence_weights is not None and os.path.isfile(sequence_weights): run_files = utils.get_run_files_from_weights(sequence_weights) sequence_config = OmegaConf.load(run_files["config_file"]) - # sequence_config = utils.load_yaml(os.path.join(os.path.dirname(sequence_weights), 'config.yaml')) latent_name = sequence_config["sequence"]["latent_name"] if latent_name is None: latent_name = sequence_config["feature_extractor"]["arch"] @@ -543,7 +506,6 @@ def generate_sequence_inference_args(self): raise ValueError("must specify a valid weight file to run sequence inference!") log.debug("latent name: {}".format(latent_name)) - # sequence_name, _ = utils.get_latest_model_and_name(self.project_config['project']['path'], 'sequence') # GOAL: MAKE ONLY FILES WITH LATENT_NAME PRESENT APPEAR ON LIST # SHOULD BE UNCHECKED IF THERE IS ALREADY THE "OUTPUT NAME" IN FILE @@ -560,7 +522,7 @@ def generate_sequence_inference_args(self): if not ret: return should_infer = form.get_outputs() - all_false = np.all(np.array(should_infer) == False) + all_false = np.all(np.array(should_infer) == False) # noqa: E712 if all_false: return weights = self.get_selected_models()["sequence"] @@ -607,7 +569,6 @@ def sequence_infer(self): self.listener.wait() del self.listener self.project_loaded_buttons() - # del self.listener record = projects.get_record_from_subdir(os.path.dirname(self.videofile)) if record["output"] is not None: self.outputfile = record["output"] @@ -624,18 +585,13 @@ def classifier_inference(self): log.error("Erroneous arguments to fe or seq: {}, {}".format(fe_args, sequence_args)) calls = [fe_args, sequence_args] - - # calls = [['ping', 'localhost', '-n', '10'], ['dir']] self.listener = SubprocessChainer(calls) self.listener.start() - # self.inference_thread = utils.chained_subprocess_calls(calls) else: self.listener.stop() self.listener.wait() del self.listener self.project_loaded_buttons() - # self.inference_thread - # print(should_be_checked) def run_overnight(self): if self.ui.actionOvernight.isChecked(): @@ -651,17 +607,13 @@ def run_overnight(self): log.error("Erroneous seq arguments in run overnight: {}".format(sequence_args)) calls = [flow_args, fe_args, sequence_args] - # calls = [['ping', 'localhost', '-n', '10'], ['dir']] self.listener = SubprocessChainer(calls) self.listener.start() - # self.inference_thread = utils.chained_subprocess_calls(calls) else: self.listener.stop() self.listener.wait() del self.listener self.project_loaded_buttons() - # self.inference_thread - # print(should_be_checked) def _new_project(self): form = CreateProject() @@ -688,7 +640,6 @@ def _new_project(self): behaviors.insert(0, "background") project_dict = projects.initialize_project(form.project_directory, project_name, behaviors, labeler) - self.initialize_project(project_dict["project"]["path"]) def add_class(self): @@ -702,7 +653,6 @@ def add_class(self): if text in self.cfg.project.class_names: log.warning("This behavior is already in the list...") return - # self.add_class() text = text.replace(" ", "_") log.info("new behavior name: {}".format(text)) @@ -720,10 +670,6 @@ def add_class(self): behaviors = OmegaConf.to_container(self.cfg.project.class_names) behaviors.append(text) self.cfg.project.class_names = behaviors - # self.project_config['project']['class_names'].append(text) - # self.project_config = projects.load_config( - # os.path.join(self.project_config['project']['path'], 'project_config.yaml')) - # behaviors = self.project_config['project']['class_names'] self.import_labelfile(self.labelfile) self.thresholds = None @@ -759,10 +705,6 @@ def remove_class(self): behaviors.remove(text) self.cfg.project.class_names = behaviors - # self.project_config = projects.load_config(os.path.join(self.project_config['project']['path'], - # 'project_config.yaml')) - # self.project_config['class_names'].remove(text) - # self.project_config['class_names'].append(text) self.outputfile = None self.import_labelfile(self.labelfile) self.import_outputfile(self.outputfile) @@ -781,12 +723,6 @@ def finalize(self): if not self.saved: self.save() - # if simple_popup_question( - # self, 'You have unsaved changes. You must save before labels can be finalized. ' - # 'Do you want to save?'): - # self.save() - # else: - # return log.info("finalizing labels for file {}".format(self.videofile)) fname, _ = os.path.splitext(self.videofile) label_fname = fname + "_labels.csv" @@ -813,7 +749,6 @@ def finalize(self): def save(self): if self.saved: - # do nothing return log.info("saving...") df = self._make_dataframe() @@ -821,7 +756,6 @@ def save(self): label_fname = fname + "_labels.csv" df.to_csv(label_fname) projects.add_file_to_subdir(label_fname, os.path.dirname(self.videofile)) - # self.save_to_hdf5() self.saved = True def import_labelfile(self, labelfile: Union[str, os.PathLike]): @@ -844,7 +778,7 @@ def import_external_labels(self): self, "Click on labels to import", data_dir, filestring, options=options ) if projects.is_deg_file(labelfile): - raise ValueError("Don" "t use this to open labels: use to import non-DeepEthogram labels") + raise ValueError("Dont use this to open labels: use to import non-DeepEthogram labels") filestring = "VideoReader files (*.h5 *.avi *.mp4)" videofile, _ = QFileDialog.getOpenFileName( self, "Click on corresponding video file", data_dir, filestring, options=options @@ -891,10 +825,6 @@ def import_outputfile(self, outputfile: Union[str, os.PathLike], latent_name=Non log.warning("Probabilities > 1 found, clamping...") probabilities = probabilities.clip(min=0, max=1.0) - # import pdb - # pdb.set_trace()x - # print('probabilities min: {}, max: {}'.format(probabilities.min(), probabilities.max())) - self.initialize_prediction(prediction_array=probabilities, opacity=opacity) self.ui.importPredictions.setEnabled(True) self.ui.exportPredictions.setEnabled(True) @@ -930,9 +860,7 @@ def change_predictions(self, new_text): return if self.latent_name != new_text: log.debug("not equal found: {}, {}".format(self.latent_name, new_text)) - # self.import_outputfile(self.outputfile, latent_name=new_text) - # log.warning('prediction import not implemented') def import_predictions_as_labels(self): if not hasattr(self, "estimated_labels"): @@ -951,7 +879,6 @@ def import_predictions_as_labels(self): else: rows_to_change = np.arange(0, current_label.shape[0]) - # print(current_label[rows_to_change, :]) current_label[rows_to_change, :] = current_predictions[rows_to_change, :] changed[:] = 1 self.ui.labels.label.array = current_label @@ -960,7 +887,6 @@ def import_predictions_as_labels(self): self.saved = False self.save() self.user_did_something() - # print(changed) def open_avi_browser(self): if self.data_path is not None: @@ -1001,10 +927,6 @@ def add_multiple_videos(self): for filename in filenames: self.initialize_video(filename) - # if len(filename) == 0 or not os.path.isfile(filename): - # raise ValueError('Could not open file: {}'.format(filename)) - # - # self.initialize_video(filename) def initialize_project(self, directory: Union[str, os.PathLike]): if len(directory) == 0: @@ -1015,19 +937,6 @@ def initialize_project(self, directory: Union[str, os.PathLike]): log.error("something wrong with loading yaml file: {}".format(filename)) return - # project_dict = projects.load_config(filename) - # if not projects.is_config_dict(project_dict): - # raise ValueError('Not a properly formatted configuration dictionary! Look at defaults/config.yaml: dict: {}' - # 'filename: {}'.format(project_dict, filename)) - # self.project_config = project_dict - # # self.project_config = projects.convert_config_paths_to_absolute(self.project_config) - - # # for convenience - # self.data_path = os.path.join(self.project_config['project']['path'], - # self.project_config['project']['data_path']) - # self.model_path = os.path.join(self.project_config['project']['path'], - # self.project_config['project']['model_path']) - # overwrite cfg passed at command line now that we know the project path. still includes command line arguments self.cfg = configuration.make_config(directory, ["config", "gui", "postprocessor"], run_type="gui", model=None) log.info("cwd: {}".format(os.getcwd())) @@ -1036,13 +945,9 @@ def initialize_project(self, directory: Union[str, os.PathLike]): self.cfg = projects.setup_run(self.cfg, raise_error_if_pretrained_missing=False) log.info("loaded project configuration: {}".format(OmegaConf.to_yaml(self.cfg))) log.info("cwd: {}".format(os.getcwd())) - # for convenience self.data_path = self.cfg.project.data_path self.model_path = self.cfg.project.model_path - # self.project_config['project']['class_names'] = self.project_config['class_names'] - # load up the last alphabetic record, which if user includes dates in filename, will be the most recent - # data_path = os.path.join(self.project_config['project']['path'], self.project_config['project']['data_path']) records = projects.get_records_from_datadir(self.data_path) self.project_loaded_buttons() if len(records) == 0: @@ -1055,23 +960,16 @@ def initialize_project(self, directory: Union[str, os.PathLike]): self.initialize_video(last_record["rgb"]) if last_record["label"] is not None: self.import_labelfile(last_record["label"]) - # if last_record['output'] is not None: - # self.import_outputfile(last_record['output']) self.get_trained_models() def load_project(self): - # options = QFileDialog.Options() - directory = QFileDialog.getExistingDirectory( - self, "Open your deepethogram directory (containing project " "config)" + self, "Open your deepethogram directory (containing project config)" ) self.initialize_project(directory) - # pprint.pprint(self.trained_model_dict) - def get_default_archs(self): - # TODO: replace this default logic with hydra 1.0 if "preset" in self.cfg: preset = self.cfg.preset else: @@ -1135,8 +1033,6 @@ def get_trained_models(self): self.ui.sequenceSelector.addItem(key) self.ui.sequenceSelector.setCurrentIndex(len(models) - 1) - # print(self.get_selected_models()) - def get_selected_models(self, model_type: str = None): flow_model = None fe_model = None @@ -1187,8 +1083,6 @@ def _make_dataframe(self): # negative example null_row = np.ones((n_behaviors,)) * -1 label[np.logical_not(changed), :] = null_row - # print(array.shape, len(self.label.behaviors)) - # print(self.label.behaviors) df = pd.DataFrame(data=label, columns=self.cfg.project.class_names) return df @@ -1228,7 +1122,6 @@ def closeEvent(self, event, *args, **kwargs): @Slot(bool) def update_saved(self, has_been_saved: bool): - # print('label change signal heard') self.saved = has_been_saved @@ -1243,7 +1136,6 @@ def set_style(app): darktheme.setColor(QtGui.QPalette.Button, QtGui.QColor(45, 45, 45)) darktheme.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(222, 222, 222)) darktheme.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(222, 222, 222)) - # darktheme.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(0, 222, 0)) darktheme.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(222, 222, 222)) darktheme.setColor(QtGui.QPalette.Highlight, QtGui.QColor(45, 45, 45)) darktheme.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Light, QtGui.QColor(60, 60, 60)) @@ -1253,13 +1145,6 @@ def set_style(app): darktheme.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.WindowText, QtGui.QColor(122, 118, 113)) darktheme.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Base, QtGui.QColor(32, 32, 32)) - # darktheme.setColor(QtGui.QPalette.Highlight, QtGui.QColor(0, 255, 0)) - # darktheme.setColor(QtGui.QPalette. - # darktheme.setColor(QtGui.QPalette.Background, QtGui.QColor(255,0,0)) - # print(dir(QtGui.QPalette)) - # Define the pallet color - # Then set the pallet color - app.setPalette(darktheme) return app @@ -1283,8 +1168,6 @@ def setup_gui_cfg(): except Exception: pass - # OmegaConf.set_struct(cfg, False) - log.info("CWD: {}".format(os.getcwd())) log.info("Configuration used: {}".format(OmegaConf.to_yaml(cfg))) return cfg @@ -1303,9 +1186,6 @@ def run() -> None: sys.exit(app.exec_()) -# this function is required to allow automatic detection of the module name when running -# from a binary script. -# it should be called from the executable script and not the hydra.main() function directly. def entry() -> None: run() diff --git a/deepethogram/gui/mainwindow.py b/deepethogram/gui/mainwindow.py index 442fe46..fb85094 100644 --- a/deepethogram/gui/mainwindow.py +++ b/deepethogram/gui/mainwindow.py @@ -295,4 +295,4 @@ def retranslateUi(self, MainWindow): self.actionAdd_multiple.setText(QtWidgets.QApplication.translate("MainWindow", "Add multiple", None, -1)) -from deepethogram.gui.custom_widgets import LabelImg, VideoPlayer +from deepethogram.gui.custom_widgets import LabelImg, VideoPlayer # noqa: E402 diff --git a/deepethogram/losses.py b/deepethogram/losses.py index 6e81f36..83a192c 100644 --- a/deepethogram/losses.py +++ b/deepethogram/losses.py @@ -1,8 +1,8 @@ import logging import os -from omegaconf import DictConfig import torch +from omegaconf import DictConfig from torch import nn from deepethogram import projects @@ -102,7 +102,6 @@ def __init__(self, model: nn.Module, path_to_pretrained_weights, alpha: float, b self.alpha = alpha self.beta = beta - # assert cfg.train.regularization.style == 'l2_sp' assert os.path.isfile(path_to_pretrained_weights) state = torch.load(path_to_pretrained_weights, map_location="cpu") @@ -175,15 +174,6 @@ def forward(self, model): if towards_pretrained != towards_pretrained or towards_0 != towards_0: msg = "invalid loss in L2-SP: towards pretrained: {} towards 0: {}".format(towards_pretrained, towards_0) raise ValueError(msg) - # alternate method. same result, ~50% slower - # towards_pretrained, towards_0 = 0, 0 - - # for key, param in model.named_parameters(): - # if key in self.pretrained_keys: - # pretrained_param = getattr(self, self.dots_to_underscores(key)) - # towards_pretrained += (param - pretrained_param).pow(2).sum()*0.5 - # elif key in self.new_keys: - # towards_0 += param.pow(2).sum()*0.5 return towards_pretrained * self.alpha + towards_0 * self.beta diff --git a/deepethogram/metrics.py b/deepethogram/metrics.py index 655176e..5e08416 100644 --- a/deepethogram/metrics.py +++ b/deepethogram/metrics.py @@ -2,13 +2,13 @@ import os import warnings from collections import defaultdict -from typing import Union, Tuple from multiprocessing import Pool +from typing import Tuple, Union import h5py import numpy as np import torch -from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, auc +from sklearn.metrics import auc, confusion_matrix, f1_score, roc_auc_score from deepethogram import utils from deepethogram.postprocessing import remove_low_thresholds @@ -22,7 +22,7 @@ try: slurm_job_id = os.environ["SLURM_JOB_ID"] slurm = True -except: +except Exception: slurm = False @@ -139,7 +139,7 @@ def accuracy(predictions: np.ndarray, labels: np.ndarray): return np.mean(predictions == labels) -def confusion(predictions: np.ndarray, labels: np.ndarray, K: int = None) -> np.ndarray: +def confusion(predictions: np.ndarray, labels: np.ndarray, K: Union[int, None] = None) -> np.ndarray: """Computes confusion matrix. Much faster than sklearn.metrics.confusion_matrix for large numbers of predictions Parameters @@ -168,13 +168,11 @@ def confusion(predictions: np.ndarray, labels: np.ndarray, K: int = None) -> np. cm = np.zeros((K, K)).astype(int) for i in range(K): for j in range(K): - # these_inds = labels==i - # cm[i, j] = np.sum((labels==i)*(predictions==j)) cm[i, j] = np.sum(np.logical_and(labels == i, predictions == j)) return cm -def binary_confusion_matrix(predictions, labels) -> np.ndarray: +def binary_confusion_matrix(predictions: np.ndarray, labels: np.ndarray) -> np.ndarray: # behaviors x thresholds x 2 x 2 # cms = np.zeros((K, N, 2, 2), dtype=int) ndim = predictions.ndim @@ -206,7 +204,9 @@ def binary_confusion_matrix(predictions, labels) -> np.ndarray: return cms -def binary_confusion_matrix_multiple_thresholds(probabilities, labels, thresholds): +def binary_confusion_matrix_multiple_thresholds( + probabilities: np.ndarray, labels: np.ndarray, thresholds: np.ndarray +) -> np.ndarray: # this is the fastest I could possibly write it K = probabilities.shape[1] N = len(thresholds) @@ -253,7 +253,6 @@ def binary_confusion_matrix_parallel( else: raise ValueError("weird shape in probs_or_preds: {}".format(probs_or_preds.shape)) func = confusion_alias - # log.info('parallel start') if num_workers > 1: with Pool(num_workers) as pool: for res in pool.imap_unordered(func, iterator, parallel_chunk): @@ -261,7 +260,6 @@ def binary_confusion_matrix_parallel( else: for args in iterator: cm += func(args) - # log.info('parallel end') return cm @@ -363,14 +361,13 @@ def fast_auc(y_true, y_prob): nfalse = np.cumsum(1 - y_true) auc = np.cumsum((y_true * nfalse))[-1] - # print(auc) auc /= nfalse[-1] * (n - nfalse[-1]) return auc # @profile def evaluate_thresholds( - probabilities: np.ndarray, labels: np.ndarray, thresholds: np.ndarray = None, num_workers: int = 4 + probabilities: np.ndarray, labels: np.ndarray, thresholds: Union[np.ndarray, None] = None, num_workers: int = 4 ) -> Tuple[dict, dict]: """Given probabilities and labels, compute a bunch of metrics at each possible threshold value @@ -395,8 +392,6 @@ def evaluate_thresholds( epoch_metrics: dict each value is only a single float for the entire prediction / label set. """ - # log.info('evaluating thresholds. P: {} lab: {} n_workers: {}'.format(probabilities.shape, labels.shape, num_workers)) - # log.info('SLURM in metrics file: {}'.format(slurm)) if slurm and num_workers != 1: warnings.warn("using multiprocessing on slurm can cause issues. setting num_workers to 1") num_workers = 1 @@ -404,8 +399,6 @@ def evaluate_thresholds( if thresholds is None: # using 200 means that approximated mAP, AUROC is almost exactly the same as exact thresholds = np.linspace(1e-4, 1, 200) - # log.info('num workers in evaluate thresholds: {}'.format(num_workers)) - # log.debug('probabilities shape in metrics calc: {}'.format(probabilities.shape)) metrics_by_threshold = {} if probabilities.ndim == 1: raise ValueError("To calc threshold, predictions must be probabilities, not classes") @@ -414,10 +407,7 @@ def evaluate_thresholds( labels = index_to_onehot(labels, K) probabilities, labels = remove_invalid_values_predictions_and_labels(probabilities, labels) - # log.info('first metrics call') metrics_by_threshold = compute_metrics_by_threshold(probabilities, labels, thresholds, num_workers) - # log.info('first metrics call finished') - # log.info('finished computing binary confusion matrices') # optimum threshold: one that maximizes F1 optimum_indices = np.argmax(metrics_by_threshold["f1"], axis=0) optimum_thresholds = thresholds[optimum_indices] @@ -439,11 +429,7 @@ def evaluate_thresholds( # ALWAYS REPORT THE PERFORMANCE WITH "VALID" BACKGROUND predictions[:, 0] = np.logical_not(np.any(predictions[:, 1:], axis=1)) - # log.info('computing metric thresholds again') - # re-use our confusion matrix calculation. returns N x N x K values - # log.info('second metircs call') metrics_by_class = compute_metrics_by_threshold(predictions, labels, None, num_workers) - # log.info('second metrics call ended') # summing over classes is the same as flattening the array. ugly syntax # TODO: make function that computes metrics from a stack of confusion matrices rather than this none None business @@ -468,25 +454,7 @@ def evaluate_thresholds( "auroc_overall": np.nan, "mAP_overall": np.nan, } - # it is too much of a pain to increase the speed on roc_auc_score and mAP - # try: - # epoch_metrics['auroc_overall'] = roc_auc_score(labels, probabilities, average='micro') - # epoch_metrics['auroc_by_class'] = roc_auc_score(labels, probabilities, average=None) - # # small perf improvement is not worth worrying about bugs - # # epoch_metrics['auroc_overall'] = fast_auc(labels.flatten(), probabilities.flatten()) - # # epoch_metrics['auroc_by_class'] = fast_auc(labels, probabilities) - # epoch_metrics['auroc_class_mean'] = epoch_metrics['auroc_by_class'].mean() - # except ValueError: - # # only one class in labels... - # epoch_metrics['auroc_overall'] = np.nan - # epoch_metrics['auroc_class_mean'] = np.nan - # epoch_metrics['auroc_by_class'] = np.array([np.nan for _ in range(K)]) - # - # epoch_metrics['mAP_overall'] = average_precision_score(labels, probabilities, average='micro') - # epoch_metrics['mAP_by_class'] = average_precision_score(labels, probabilities, average=None) - # # this is a misnomer: mAP by class is just AP - # epoch_metrics['mAP_class_mean'] = epoch_metrics['mAP_by_class'].mean() - # log.info('returning metrics') + return metrics_by_threshold, epoch_metrics @@ -519,7 +487,7 @@ def compute_f1(precision: float, recall: float, beta: float = 1.0) -> float: def compute_precision_recall(cm: np.ndarray) -> Tuple[float, float]: """computes precision and recall from a confusion matrix""" - tn = cm[0, 0] + # tn = cm[0, 0] tp = cm[1, 1] fp = cm[0, 1] fn = cm[1, 0] @@ -591,9 +559,9 @@ def postprocess(predictions: np.ndarray, thresholds: np.ndarray, valid_bg: bool def list_to_mean(values): - if type(values[0]) == torch.Tensor: + if isinstance(values[0], torch.Tensor): value = utils.tensor_to_np(torch.stack(values).mean()) - elif type(values[0]) == np.ndarray: + elif isinstance(values[0], np.ndarray): if values[0].size == 1: value = np.stack(np.array(values)).mean() else: @@ -762,7 +730,6 @@ def initialize_file(self): def save_metrics_to_disk(self, metrics: dict, split: str) -> None: with h5py.File(self.fname, "r+") as f: - # utils.print_hdf5(f) if split not in f.keys(): # should've created top-level groups in initialize_file; this is for nesting f.create_group(split) @@ -920,7 +887,6 @@ def compute(self, data: dict): else: labels = self.stack_sequence_data(data["labels"]) - num_classes = probs.shape[1] one_hot = probs.shape[-1] == labels.shape[-1] if one_hot: rows_with_false_labels = np.any(labels == self.ignore_index, axis=1) @@ -950,8 +916,7 @@ def compute(self, data: dict): if metric == "confusion": warnings.simplefilter("ignore") metrics[metric] = confusion(predictions, labels, K=self.num_classes) - # import pdb - # pdb.set_trace() + elif metric == "binary_confusion": pass else: diff --git a/deepethogram/postprocessing.py b/deepethogram/postprocessing.py index 4857866..35852d0 100644 --- a/deepethogram/postprocessing.py +++ b/deepethogram/postprocessing.py @@ -1,14 +1,14 @@ -from collections import defaultdict import logging import os -from typing import Type, Tuple +from collections import defaultdict +from typing import Tuple, Type import h5py import numpy as np -from omegaconf import DictConfig, OmegaConf import pandas as pd +from omegaconf import DictConfig, OmegaConf -from deepethogram import projects, file_io +from deepethogram import file_io, projects log = logging.getLogger(__name__) @@ -301,14 +301,12 @@ def get_bout_length_percentile(label_list: list, percentile: float) -> dict: bout_length = bouts[k]["lengths"].tolist() bout_lengths[k].append(bout_length) bout_lengths = {behavior: np.concatenate(value) for behavior, value in bout_lengths.items()} - # print(bout_lengths) percentiles = {} for behavior, value in bout_lengths.items(): if len(value) > 0: percentiles[behavior] = np.percentile(value, percentile) else: percentiles[behavior] = 1 - # percentiles = {behavior: np.percentile(value, percentile) for behavior, value in bout_lengths.items()} return percentiles @@ -337,7 +335,7 @@ def get_postprocessor_from_cfg(cfg: DictConfig, thresholds: np.ndarray) -> Type[ label_list.append(label) percentiles = get_bout_length_percentile(label_list, cfg.postprocessor.min_bout_length) - # percntiles is a dict: keys are behaviors, values are percentiles + # percentiles is a dict: keys are behaviors, values are percentiles # need to round and then cast to int percentiles = np.round(np.array(list(percentiles.values()))).astype(int) return MinBoutLengthPerBehaviorPostprocessor(thresholds, percentiles) diff --git a/deepethogram/projects.py b/deepethogram/projects.py index 6cac546..9e54e04 100644 --- a/deepethogram/projects.py +++ b/deepethogram/projects.py @@ -13,10 +13,11 @@ from tqdm import tqdm import deepethogram -from deepethogram.utils import get_subfiles, log +from deepethogram.utils import get_subfiles from deepethogram.zscore import zscore_video + from . import utils -from .file_io import read_labels, convert_video +from .file_io import convert_video, read_labels log = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def initialize_project( if not os.path.isdir(project_config["project"]["path"]): os.makedirs(project_config["project"]["path"]) - # os.chdir(project_config['project']['path']) + data_abs = os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) if not os.path.isdir(data_abs): @@ -114,19 +115,11 @@ def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], new_path: str path to the video file after moving to the DEG project data directory. """ - # assert (os.path.isdir(project_directory)) + assert os.path.exists(path_to_video), "video not found! {}".format(path_to_video) - if os.path.isdir(path_to_video): - copy_func = shutil.copytree - elif os.path.isfile(path_to_video): - copy_func = shutil.copy - else: - raise ValueError("video does not exist: {}".format(path_to_video)) assert mode in ["copy", "symlink", "move"] - # project = utils.load_yaml(os.path.join(project_directory, 'project_config.yaml')) - # project = convert_config_paths_to_absolute(project) log.debug("configuration file when adding video: {}".format(project)) datadir = os.path.join(project["project"]["path"], project["project"]["data_path"]) assert os.path.isdir(datadir), "data path not found: {}".format(datadir) @@ -143,7 +136,7 @@ def add_video_to_project(project: dict, path_to_video: Union[str, os.PathLike], video_directory = os.path.join(datadir, vidname) if os.path.isdir(video_directory): raise ValueError( - "Directory {} already exists in your data dir! " "Please rename the video to a unique name".format(vidname) + "Directory {} already exists in your data dir! Please rename the video to a unique name".format(vidname) ) os.makedirs(video_directory) new_path = os.path.join(video_directory, basename) @@ -185,8 +178,6 @@ def add_label_to_project(path_to_labels: Union[str, os.PathLike], path_to_video) array = df.values is_background = np.logical_not(np.any(array == 1, axis=1)).astype(int)[:, np.newaxis] data = np.concatenate((is_background, array), axis=1) - # df2 = pd.DataFrame(data=is_background, columns=['background']) - # df = pd.concat([df2, df], axis=1) df = pd.DataFrame(data=data, columns=["background"] + list(df.columns)) df.to_csv(label_dst) @@ -206,24 +197,6 @@ def add_file_to_subdir(file: Union[str, os.PathLike], subdir: Union[str, os.Path utils.save_dict_to_yaml(record, os.path.join(subdir, "record.yaml")) -# def change_project_directory(config_file: Union[str, os.PathLike], new_directory: Union[str, os.PathLike]): -# """If you move the project directory to some other location, updates the config file to have the new directories""" -# assert os.path.isfile(config_file) -# assert os.path.isdir(new_directory) -# # make sure that new directory is properly formatted for deepethogram -# datadir = os.path.join(new_directory, 'DATA') -# model_path = os.path.join(new_directory, 'models') -# assert os.path.isdir(datadir) -# assert os.path.isdir(model_path) - -# project_config = utils.load_yaml(config_file) -# project_config['project']['path'] = new_directory -# project_config['project']['model_path'] = os.path.basename(model_path) -# project_config['project']['data_path'] = os.path.basename(datadir) -# project_config['project']['config_file'] = os.path.join(new_directory, 'project_config.yaml') -# utils.save_dict_to_yaml(project_config, project_config['project']['config_file']) - - def remove_video_from_project(config_file, video_file=None, record_directory=None): # TODO: remove video from split dictionary, remove mean and std from project statistics raise NotImplementedError @@ -277,15 +250,14 @@ def add_behavior_to_project(config_file: Union[str, os.PathLike], behavior_name: records = get_records_from_datadir( os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) ) - for key, record in records.items(): + for _, record in records.items(): labelfile = record["label"] if labelfile is None: continue if os.path.isfile(labelfile): df = pd.read_csv(labelfile, index_col=0) label = df.values - N, K = label.shape - # label = np.concatenate((label, np.ones((N, 1))*-1), axis=1) + N, _ = label.shape df2 = pd.DataFrame(data=np.ones((N, 1)) * -1, columns=[behavior_name]) df = pd.concat([df, df2], axis=1) df.to_csv(labelfile) @@ -321,7 +293,7 @@ def remove_behavior_from_project(config_file: Union[str, os.PathLike], behavior_ records = get_records_from_datadir( os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) ) - for key, record in records.items(): + for _, record in records.items(): labelfile = record["label"] if labelfile is None: continue @@ -352,8 +324,7 @@ def get_classes_from_project(config: Union[dict, str, os.PathLike, DictConfig]) classes: list list of behaviors read from project_config.yaml file. """ - - if type(config) == str or type(config) == os.PathLike: + if isinstance(config, str) or isinstance(config, os.PathLike): config_file = os.path.join(config, "project_config.yaml") assert os.path.isfile(config_file), "Input must be a directory containing a project_config.yaml file" config = utils.load_yaml(config_file) @@ -410,7 +381,6 @@ def find_rgbfiles(root: Union[str, bytes, os.PathLike]) -> list: list of absolute paths to RGB videos, or subdirectories containing individual images (framedirs) """ files = get_subfiles(root, return_type="any") - endings = [os.path.splitext(i)[1] for i in files] valid_endings = [".avi", ".mp4", ".h5", ".mov"] excluded = ["flow", "label", "output", "score"] movies = [i for i in files if os.path.splitext(i)[1].lower() in valid_endings] @@ -612,7 +582,7 @@ def get_record_from_subdir(subdir: Union[str, os.PathLike]) -> dict: if key in list(record.keys()): this_entry = record[key]["default"] - if type(this_entry) == list and len(this_entry) == 0: + if isinstance(this_entry, list) and len(this_entry) == 0: this_entry = None else: this_entry = os.path.join(subdir, this_entry) @@ -650,7 +620,6 @@ def get_records_from_datadir(datadir: Union[str, bytes, os.PathLike]) -> dict: for subdir in subdirs: parsed_record = get_record_from_subdir(os.path.join(datadir, subdir)) records[parsed_record["key"]] = parsed_record - # write_all_records(datadir) return records @@ -684,7 +653,7 @@ def get_unfinalized_records(config: dict) -> list: """Finds the number of label files with no unlabeled frames""" records = get_records_from_datadir(os.path.join(config["project"]["path"], config["project"]["data_path"])) unfinalized = [] - for k, v in records.items(): + for _, v in records.items(): if v["label"] is None or len(v["label"]) == 0: unfinalized.append(v) else: @@ -699,7 +668,7 @@ def get_number_finalized_labels(config: dict) -> int: """Finds the number of label files with no unlabeled frames""" records = get_records_from_datadir(os.path.join(config["project"]["path"], config["project"]["data_path"])) number_valid_labels = 0 - for k, v in records.items(): + for _, v in records.items(): for filetype, fileloc in v.items(): if filetype == "label": if fileloc is None or len(fileloc) == 0: @@ -714,8 +683,8 @@ def get_number_finalized_labels(config: dict) -> int: def import_outputfile( project_dir: Union[str, os.PathLike], outputfile: Union[str, os.PathLike], - class_names: list = None, - latent_name: str = None, + class_names: Union[list, None] = None, + latent_name: Union[str, None] = None, ): """Gets the probabilities, thresholds, used HDF5 dataset key, and all dataset keys from an outputfile @@ -809,7 +778,7 @@ def import_outputfile( # this should not happen thresholds = thresholds[-1, :] loaded_class_names = f[key]["class_names"][:] - if type(loaded_class_names[0]) == bytes: + if isinstance(loaded_class_names[0], bytes): loaded_class_names = [i.decode("utf-8") for i in loaded_class_names] log.debug("probabilities shape: {}".format(probabilities.shape)) @@ -855,7 +824,7 @@ def do_outputfiles_have_predictions(data_path: Union[str, os.PathLike], model_na assert os.path.isdir(data_path) records = get_records_from_datadir(data_path) has_predictions = [] - for key, record in records.items(): + for key, _ in records.items(): file = records[key]["output"] if file is None: has_predictions.append(False) @@ -976,7 +945,6 @@ def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: rundirs += subdirs rundirs.sort() - # model_weights = defaultdict(list) model_weights = {"flow_generator": {}, "feature_extractor": {}, "sequence": {}} for rundir in rundirs: # for backwards compatibility @@ -1007,16 +975,12 @@ def get_weights_from_model_path(model_path: Union[str, os.PathLike]) -> dict: else: continue - # architecture = params[model_type]['arch'] - - weightfile = get_weightfile_from_rundir(rundir) # os.path.join(rundir, 'checkpoint.pt') + weightfile = get_weightfile_from_rundir(rundir) if weightfile is not None and os.path.isfile(weightfile): if arch in model_weights[model_type].keys(): model_weights[model_type][arch].append(weightfile) else: model_weights[model_type][arch] = [weightfile] - # model_weights[model_type].append(weightfile) - # model_weights[model_type][arch].append(weightfile) for model in model_weights.keys(): for arch, runlist in model_weights[model].items(): model_weights[model][arch] = sort_runs_by_date(runlist) @@ -1056,10 +1020,6 @@ def get_weightfile_from_cfg(cfg: DictConfig, model_type: str) -> Union[str, None weightfile: path to weight file """ - # if cfg.reload.weights is not None: - # assert os.path.isfile(cfg.reload.weights) - # return cfg.reload.weights - assert model_type in ["flow_generator", "feature_extractor", "end_to_end", "sequence"] if not os.path.isdir(cfg.project.model_path): @@ -1088,7 +1048,6 @@ def get_weightfile_from_cfg(cfg: DictConfig, model_type: str) -> Union[str, None log.info("loading specified weights: {}".format(path_to_weights)) return path_to_weights elif cfg.reload.latest or cfg[model_type].weights == "latest": - # print(trained_models) if len(trained_models[model_type][architecture]) == 0: log.warning("Trying to load *latest* weights, but found none! Using random initialization!") return @@ -1200,7 +1159,6 @@ def load_config(path_to_config: Union[str, os.PathLike]) -> dict: project = OmegaConf.load(path_to_config) project = fix_config_paths(project, path_to_config) - # project = convert_config_paths_to_absolute(project) return project @@ -1235,7 +1193,7 @@ def convert_all_videos(config_file: Union[str, os.PathLike], movie_format="hdf5" records = get_records_from_datadir( os.path.join(project_config["project"]["path"], project_config["project"]["data_path"]) ) - for key, record in tqdm(records.items(), desc="converting videos"): + for _, record in tqdm(records.items(), desc="converting videos"): videofile = record["rgb"] try: convert_video(videofile, movie_format=movie_format, **kwargs) @@ -1306,9 +1264,6 @@ def get_config_from_path(project_path: Union[str, os.PathLike]): project_path = os.path.abspath(project_path) cfg_file = get_config_file_from_path(project_path) return load_config(cfg_file) - # project_cfg = OmegaConf.load(cfg_file) - # project_cfg = fix_config_paths(project_cfg, cfg_file) - # return project_cfg def get_project_path_from_cl(argv: list, error_if_not_found=True) -> str: @@ -1333,7 +1288,7 @@ def get_project_path_from_cl(argv: list, error_if_not_found=True) -> str: """ for arg in argv: if "project.config_file" in arg: - key, path = arg.split("=") + _, path = arg.split("=") assert os.path.isfile(path) # path is the path to the project directory, not the config file path = os.path.dirname(path) @@ -1349,54 +1304,6 @@ def get_project_path_from_cl(argv: list, error_if_not_found=True) -> str: return None -# def make_config(project_path: Union[str, os.PathLike], config_list: list, run_type: str, model: str) -> DictConfig: -# """DEPRECATED -# TODO: replace with configuration.make_config -# """ -# config_path = os.path.join(os.path.dirname(deepethogram.__file__), 'conf') - -# def config_string_to_path(config_path, string): -# fullpath = os.path.join(config_path, *string.split('/')) + '.yaml' -# assert os.path.isfile(fullpath) -# return fullpath - -# cli = OmegaConf.from_cli() - -# if project_path is not None: -# user_cfg = get_config_file_from_path(project_path) - -# # order of operations: first, defaults specified in config_list -# # then, if preset is specified in user config or at the command line, load those preset values -# # then, append the user config -# # then, the command line args -# # so if we specify a preset and manually change, say, the feature extractor architecture, we can do that -# if 'preset' in user_cfg: -# config_list += ['preset/' + user_cfg.preset] - -# if 'preset' in cli: -# config_list += ['preset/' + cli.preset] - -# config_files = [config_string_to_path(config_path, i) for i in config_list] - -# cfgs = [OmegaConf.load(i) for i in config_files] - -# if project_path is not None: -# # first defaults; then user cfg; then cli -# cfg = OmegaConf.merge(*cfgs, user_cfg, cli) -# else: -# cfg = OmegaConf.merge(*cfgs, cli) - -# cfg.run = {'type': run_type, 'model': model} -# return cfg - -# def make_config_from_cli(argv, *args, **kwargs): -# """DEPRECATED -# TODO: replace with configuration.make_config -# """ -# project_path = get_project_path_from_cl(argv) -# return make_config(project_path, *args, **kwargs) - - def configure_run_directory(cfg: DictConfig) -> str: """Makes a run directory from a configuration diff --git a/deepethogram/sequence/inference.py b/deepethogram/sequence/inference.py index 406dd49..b89e8ab 100644 --- a/deepethogram/sequence/inference.py +++ b/deepethogram/sequence/inference.py @@ -1,7 +1,7 @@ import logging import os import sys -from typing import Union, Type +from typing import Type, Union import h5py @@ -13,7 +13,7 @@ from torch.utils import data from tqdm import tqdm -from deepethogram import utils, projects +from deepethogram import projects, utils from deepethogram.configuration import make_sequence_inference_cfg from deepethogram.data.datasets import FeatureVectorDataset, KeypointDataset from deepethogram.sequence.train import build_model_from_cfg @@ -102,7 +102,7 @@ def infer( log.debug("file: {}".format(data_file)) log.debug("seq length: {}".format(sequence_length)) - if type(activation_function) == str: + if isinstance(activation_function, str): if activation_function == "softmax": activation_function = torch.nn.Softmax(dim=1) elif activation_function == "sigmoid": @@ -110,7 +110,7 @@ def infer( else: raise ValueError("unknown activation function: {}".format(activation_function)) - if type(device) == str: + if isinstance(device, str): device = torch.device(device) if next(model.parameters()).device != device: @@ -187,8 +187,6 @@ def extract( parameter.requires_grad = False model.eval() - has_printed = False - if final_activation == "softmax": activation_function = torch.nn.Softmax(dim=1) elif final_activation == "sigmoid": @@ -208,17 +206,6 @@ def extract( model, device, activation_function, outputfile, latent_name, None, sequence_length, is_two_stream ) - # gen = FeatureVectorDataset(outputfile, labelfile=None, h5_key=latent_name, - # sequence_length=sequence_length, - # nonoverlapping=True, store_in_ram=False, is_two_stream=is_two_stream) - # n_datapoints = gen.shape[1] - # gen = data.DataLoader(gen, batch_size=1, shuffle=False, num_workers=0, drop_last=False) - # gen = iter(gen) - - # log.debug('Making sequence iterator with parameters: ') - # log.debug('file: {}'.format(outputfile)) - # log.debug('seq length: {}'.format(sequence_length)) - with h5py.File(outputfile, "r+") as f: if output_name in list(f.keys()): if overwrite: @@ -247,7 +234,6 @@ def sequence_inference(cfg: DictConfig): run_files = utils.get_run_files_from_weights(weights) if cfg.sequence.latent_name is None: # find the latent name used in the weight file you loaded - rundir = os.path.dirname(weights) loaded_cfg = utils.load_yaml(run_files["config_file"]) latent_name = loaded_cfg["sequence"]["latent_name"] # if this latent name is also None, use the arch of the feature extractor @@ -282,7 +268,7 @@ def sequence_inference(cfg: DictConfig): "must pass list of directories from commmand line. " "Ex: directory_list=[path_to_dir1,path_to_dir2] or directory_list=all" ) - elif type(directory_list) == str and directory_list == "all": + elif isinstance(directory_list, str) and directory_list == "all": basedir = cfg.project.data_path directory_list = utils.get_subfiles(basedir, "directory") @@ -301,7 +287,6 @@ def sequence_inference(cfg: DictConfig): metrics_file = run_files["metrics_file"] assert os.path.isfile(metrics_file) best_epoch = utils.get_best_epoch_from_weightfile(weights) - # best_epoch = -1 log.info("best epoch from loaded file: {}".format(best_epoch)) with h5py.File(metrics_file, "r") as f: try: diff --git a/deepethogram/sequence/models/tgm.py b/deepethogram/sequence/models/tgm.py index 47b1cbd..12fb768 100644 --- a/deepethogram/sequence/models/tgm.py +++ b/deepethogram/sequence/models/tgm.py @@ -58,21 +58,14 @@ def __init__( self.soft_attn = nn.Parameter(torch.Tensor(self.c_out * self.c_in, self.n_filters)) # edited from original code, which had no initialization torch.nn.init.xavier_normal_(self.soft_attn) - # init_sparse_positive(self.soft_attn, 0.1, std=1) - # init_sparse(self.soft_attn, 0.5, std=1) - # torch.nn.init.orthogonal_(self.soft_attn) - # torch.nn.init.eye_(self.soft_attn) - # torch.nn.init.sparse_(self.soft_attn, sparsity=0.5, std=1) + # learn c_out combinations of the c_in channels if self.c_in > 1 and not self.soft: self.convs = nn.ModuleList([nn.Conv2d(self.c_in, 1, (1, 1)) for c in range(self.c_out)]) if self.c_in > 1 and soft: - # self.cls_attn = nn.Parameter(torch.Tensor(1,self.c_out, self.c_in,1,1)) cls_attn = torch.Tensor(self.c_out, self.c_in) torch.nn.init.xavier_normal_(cls_attn) - # torch.nn.init.sparse_(cls_attn, sparsity=0.5, std=1) self.cls_attn = nn.Parameter(cls_attn.unsqueeze(2).unsqueeze(2).unsqueeze(0)) - # print(self.cls_attn.shape) def get_filters(self): device = self.center.device @@ -82,15 +75,12 @@ def get_filters(self): deltas = self.filter_length * (1.0 - torch.abs(torch.tanh(self.delta))) gammas = torch.exp(1.5 - 2.0 * torch.abs(torch.tanh(self.gamma))) - # print(centers, deltas, gammas) a = torch.zeros(self.n_filters).to(device) # stride and center a = deltas[:, None] * a[None, :] a = centers[:, None] + a - # print(a) b = torch.arange(0, self.filter_length).to(device) - # b = b.cuda() f = b - a[:, :, None] f = f / gammas[:, None, None] @@ -110,7 +100,6 @@ def forward(self, x): # overwrite the forward pass to get the TSF as conv kernels t = x.size(-1) k = self.get_filters() - # k = super(TGM, self).get_filters(torch.tanh(self.delta), torch.tanh(self.gamma), torch.tanh(self.center), self.length, self.length) # k is shape 1xNxL k = k.squeeze() # is k now NxL @@ -153,11 +142,9 @@ def forward(self, x): chnls.append(r) # get C_out x DxT f = torch.stack(chnls, dim=1) - # print('f: {}'.format(f.shape)) f_stack = f if self.c_in > 1 and self.soft: a = F.softmax(self.cls_attn, dim=2).expand(f.size(0), -1, -1, f.size(3), f.size(4)) - # print('a:{}'.format(a.shape)) f = torch.sum(a * f, dim=2) else: a = None @@ -204,7 +191,6 @@ def __init__( self.viz = viz self.concatenate_inputs = concatenate_inputs - # self.add_module('d', self.dropout) modules = [] for i in range(num_layers): c_in = self.c_in if i == 0 else self.c_out @@ -214,9 +200,6 @@ def __init__( if self.reduction == "conv1x1": self.reduction_layer = nn.Conv2d(self.c_out, 1, kernel_size=1, padding=0, stride=1) - # self.sub_event1 = TGM(inp, 16, 5, c_in=1, c_out=8, soft=False) - # self.sub_event2 = TGM(inp, 16, 5, c_in=8, c_out=8, soft=False) - # self.sub_event3 = TGM(inp, 16, 5, c_in=8, c_out=8, soft=False) N = self.D * 2 if self.concatenate_inputs else self.D if nonlinear_classification: self.h = nn.Conv1d(N, self.num_features, 1) @@ -225,30 +208,23 @@ def __init__( self.h = None self.classify = nn.Conv1d(N, classes, 1) - # self.inp = inp self.viz = viz def forward(self, inp): smoothed = self.tgm_layers(inp) - # print('smoothed before max:{}'.format(smoothed.shape)) if self.reduction == "max": smoothed = torch.max(smoothed, dim=1)[0] elif self.reduction == "mean": smoothed = torch.mean(smoothed, dim=1) elif self.reduction == "conv1x1": smoothed = self.reduction_layer(smoothed).squeeze() - # sub_event = self.dropout(torch.max(sub_event, dim=1)[0]) - # print('sub_event:{}'.format(smoothed.shape)) # concatenate original data with the learned smoothing if inp.shape != smoothed.shape: if inp.ndim == 3 and smoothed.ndim == 2: smoothed = smoothed.unsqueeze(0) else: - print("ERROR") - import pdb - - pdb.set_trace() + raise ValueError("Input and smoothed shapes do not match") if self.concatenate_inputs: tgm_module_output = torch.cat([inp, smoothed], dim=1) @@ -258,8 +234,6 @@ def forward(self, inp): if self.h is not None: tgm_module_output = self.input_dropout(tgm_module_output) - # NEW: got rid of relu on input features - # cls = F.relu(concatenated) cls = F.relu(self.h(tgm_module_output)) cls = self.output_dropout(cls) else: @@ -349,42 +323,32 @@ def __init__( bn_2.bias = bias self.classify1 = nn.Sequential(classify1, bn_1) self.classify2 = nn.Sequential(classify2, bn_2) - # log.info(bn_1, bn_2) else: self.classify1 = classify1 self.classify2 = classify2 self.classify1.bias = bias self.classify2.bias = bias - # self.inp = inp self.viz = viz self.use_fe_logits = use_fe_logits if self.use_fe_logits: self.weights = nn.Parameter(torch.Tensor([1 / 3, 1 / 3, 1 / 3]).float()) - # self.weights = nn.Parameter(torch.normal(mean=0, std=torch.sqrt(torch.Tensor([2/3, 2/3, 2/3])))) - # print('initial avg weights: {}'.format(self.weights)) def forward(self, inp, fe_logits=None): smoothed = self.tgm_layers(inp) - # print('smoothed before max:{}'.format(smoothed.shape)) if self.reduction == "max": smoothed = torch.max(smoothed, dim=1)[0] elif self.reduction == "mean": smoothed = torch.mean(smoothed, dim=1) elif self.reduction == "conv1x1": smoothed = self.reduction_layer(smoothed).squeeze() - # sub_event = self.dropout(torch.max(sub_event, dim=1)[0]) - # print('sub_event:{}'.format(smoothed.shape)) # concatenate original data with the learned smoothing if inp.shape != smoothed.shape: if inp.ndim == 3 and smoothed.ndim == 2: smoothed = smoothed.unsqueeze(0) else: - print("ERROR") - import pdb - - pdb.set_trace() + raise ValueError("Input and smoothed shapes do not match") outputs1 = self.input_dropout(inp) outputs2 = self.input_dropout(smoothed) if self.h is not None: @@ -397,13 +361,7 @@ def forward(self, inp, fe_logits=None): outputs2 = self.classify2(outputs2) if fe_logits is not None and self.use_fe_logits: - # print(fe_logits.shape) - # print('fe : min {:.4f} mean {:.4f} max {:.4f}'.format(fe_logits.min(), fe_logits.mean(), fe_logits.max())) - # print('outputs1: min {:.4f} mean {:.4f} max {:.4f}'.format(outputs1.min(), outputs1.mean(), outputs1.max())) - # print('outputs2: min {:.4f} mean {:.4f} max {:.4f}'.format(outputs2.min(), outputs2.mean(), outputs2.max())) weights = F.softmax(self.weights, dim=0) - # print('weights: {}'.format(weights)) return weights[0] * outputs1 + weights[1] * outputs2 + weights[2] * fe_logits - # return (outputs1 + outputs2 + fe_logits) / 3 return (outputs1 + outputs2) / 2 diff --git a/deepethogram/stoppers.py b/deepethogram/stoppers.py index 902d5da..b5d5946 100644 --- a/deepethogram/stoppers.py +++ b/deepethogram/stoppers.py @@ -79,7 +79,6 @@ def step(self, score): elif score < self.best_score: self.counter += 1 - # self._logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience)) if self.counter >= self.patience and self.epoch_counter >= self.early_stopping_begins: print("EarlyStopping: Stop training") should_stop = True @@ -90,10 +89,6 @@ def step(self, score): if self.epoch_counter > self.num_epochs: should_stop = True - import pdb - - pdb.set_trace() - return best, should_stop @@ -137,7 +132,6 @@ def step(self, lr: float) -> bool: """ super().step() should_stop = False - # print('epoch counter: {} num_epochs: {}'.format(self.epoch_counter, self.num_epochs)) if lr < self.minimum_learning_rate + self.eps or self.epoch_counter >= self.num_epochs: print("Reached learning rate {}, stopping...".format(lr)) should_stop = True diff --git a/deepethogram/tune/utils.py b/deepethogram/tune/utils.py index d36cebc..f980ac7 100644 --- a/deepethogram/tune/utils.py +++ b/deepethogram/tune/utils.py @@ -1,8 +1,8 @@ from omegaconf import OmegaConf try: - import ray - from ray import tune + import ray # noqa: F401 + from ray import tune # noqa: F401 except ImportError: print("To use the deepethogram.tune module, you must `pip install 'ray[tune]`") raise diff --git a/deepethogram/utils.py b/deepethogram/utils.py index 008b1ae..fc51d8e 100644 --- a/deepethogram/utils.py +++ b/deepethogram/utils.py @@ -1,10 +1,11 @@ +import importlib import logging import os import pkgutil +import sys from collections import OrderedDict from inspect import isfunction from operator import itemgetter -import sys from types import SimpleNamespace from typing import Union @@ -13,7 +14,7 @@ import numpy as np import torch import yaml -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf log = logging.getLogger(__name__) @@ -25,12 +26,6 @@ def load_yaml(filename: Union[str, os.PathLike]) -> dict: return dictionary -# def load_config(filename: Union[str, os.PathLike]) -> DictConfig: -# """ loads a yaml file as dictionary and converts to an omegaconf DictConfig """ -# dictionary = load_yaml(filename) -# return OmegaConf.create(dictionary) - - def get_minimum_learning_rate(optimizer): """Get the smallest learning rate from a PyTorch optimizer. Useful for ReduceLROnPLateau stoppers. If the minimum learning rate drops below a set value, will stop training. @@ -65,15 +60,15 @@ def load_checkpoint( """ log.info("Reloading model from {}...".format(checkpoint_file)) model, optimizer_dict, _, new_args = load_state(model, checkpoint_file, distributed=distributed) - if type(new_args) != dict: - new_config = vars(new_args) + if isinstance(new_args, dict): + new_config = new_args else: new_config = new_args try: optimizer.load_state_dict(optimizer_dict) except Exception as e: log.exception( - "Trouble loading optimizer state dict--might have requires-grad" "for different parameters: {}".format(e) + "Trouble loading optimizer state dict--might have requires-gradfor different parameters: {}".format(e) ) log.warning("Not loading optimizer state.") if overwrite_args: @@ -111,7 +106,7 @@ def checkpoint(model, rundir: Union[str, os.PathLike], epoch: int, args=None): if args is not None: if isinstance(args, DictConfig): args = OmegaConf.to_container(args) - if type(args) != dict: + if not isinstance(args, dict): args = vars(args) fname = "checkpoint.pt" fullfile = os.path.join(rundir, fname) @@ -140,14 +135,16 @@ def save_two_stream(model, rundir: Union[os.PathLike, str], config: dict = None, checkpoint(model, rundir, epoch, config) -def save_hidden_two_stream(model, rundir: Union[os.PathLike, str], config: dict = None, epoch: int = None) -> None: +def save_hidden_two_stream( + model, rundir: Union[os.PathLike, str], config: Union[dict, None] = None, epoch: int = None +) -> None: """Saves a hidden two-stream model to disk. Saves flow generator in a separate directory""" assert os.path.isdir(rundir) assert isinstance(model, torch.nn.Module) flowdir = os.path.join(rundir, "flow_generator") if not os.path.isdir(flowdir): os.makedirs(flowdir) - if type(config) == DictConfig: + if isinstance(config, DictConfig): config = OmegaConf.to_container(config) checkpoint(model.flow_generator, flowdir, epoch, config) save_two_stream(model, rundir, config, epoch) @@ -169,7 +166,7 @@ def save_dict_to_yaml(dictionary: dict, filename: Union[str, bytes, os.PathLike] def tensor_to_np(tensor: Union[torch.Tensor, np.ndarray]) -> np.ndarray: """Simple function for turning pytorch tensor into numpy ndarray""" - if type(tensor) == np.ndarray: + if isinstance(tensor, np.ndarray): return tensor return tensor.cpu().detach().numpy() @@ -211,7 +208,7 @@ def get_datadir_from_paths(paths, dataset): datadir = v found = True if not found: - raise ValueError("couldn" "t find dataset: {}".format(dataset)) + raise ValueError("couldnt find dataset: {}".format(dataset)) return datadir @@ -253,8 +250,6 @@ def load_state_from_dict(model, state_dict): pretrained_dict[k] = v model_dict.update(pretrained_dict) - # only_in_model_dict = {k:v for k,v in state_dict.items() if k in model_dict} - # model_dict.update(only_in_model_dict) # load the state dict, only for layers of same name, shape, size, etc. model.load_state_dict(model_dict, strict=True) return model @@ -262,10 +257,6 @@ def load_state_from_dict(model, state_dict): def load_state_dict_from_file(weights_file, distributed: bool = False): state = torch.load(weights_file, map_location="cpu") - # except RuntimeError as e: - # log.exception(e) - # log.info('loading onto cpu...') - # state = torch.load(weights_file, map_location='cpu') is_pure_weights = "epoch" not in list(state.keys()) # load params @@ -275,7 +266,6 @@ def load_state_dict_from_file(weights_file, distributed: bool = False): else: start_epoch = state["epoch"] state_dict = state["state_dict"] - optimizer_dict = None # state['optimizer'] first_key = next(iter(state_dict.items()))[0] trained_on_dataparallel = first_key[:7] == "module." @@ -333,7 +323,6 @@ def load_state(model, weights_file: Union[str, os.PathLike], device: torch.devic args: SimpleNamespace containing hyperparameters TODO: change args to a config dictionary """ - # fullfile = os.path.join(model_dir,run_dir, fname) # state is a dictionary # Keys: # epoch: final epoch number from training @@ -411,17 +400,17 @@ def __init__( log.debug("Normalizer created with mean {} and std {}".format(self.mean, self.std)) self.clamp = clamp - def process_inputs(self, inputs: Union[torch.Tensor, np.ndarray]): + def process_inputs(self, inputs: Union[torch.Tensor, np.ndarray, list, None]): """Deals with input mean and std. Converts to tensor if necessary. Reshapes to [length, 1, 1] for pytorch broadcasting. """ if inputs is None: return inputs - if type(inputs) == list: + if isinstance(inputs, list): inputs = np.array(inputs).astype(np.float32) - if type(inputs) == np.ndarray: + if isinstance(inputs, np.ndarray): inputs = torch.from_numpy(inputs) - assert type(inputs) == torch.Tensor + assert isinstance(inputs, torch.Tensor) inputs = inputs.float() C = inputs.shape[0] inputs = inputs.reshape(C, 1, 1) @@ -568,18 +557,6 @@ def flow_img_to_flow(img: np.ndarray, max_flow: Union[int, float] = 10) -> np.nd return np.dstack((dX, dY)) -# def encode_flow_img(flow, maxflow=10): -# im = flow_to_img_lrcn(flow, max_flow=maxflow) -# # print(im.shape) -# ret, bytestring = cv2.imencode('.jpg', im) -# return (bytestring) - -# def decode_flow_img(bytestring, maxflow=10): -# im = cv2.imdecode(bytestring, 1) -# flow = flow_img_to_flow(im, max_flow=maxflow) -# return (flow) - - def module_to_dict(module, exclude=[], get_function=False): """Converts functions in a module to a dictionary. Useful for loading model types into a dictionary""" module_dict = {} @@ -596,9 +573,9 @@ def get_models_from_module(module, get_function=False): """Hacky function for getting a dictionary of model: initializer from a module""" models = {} for importer, modname, ispkg in pkgutil.iter_modules(module.__path__): - # print("Found submodule %s (is a package: %s)" % (modname, ispkg)) total_name = module.__name__ + "." + modname - this_module = __import__(total_name) + # Import the module and get its attributes + importlib.import_module(total_name) submodule = getattr(module, modname) # module this_dict = module_to_dict(submodule, get_function=get_function) @@ -635,30 +612,14 @@ def load_feature_extractor_components(model, checkpoint_file: Union[str, os.Path key = "fusion." else: raise ValueError("component not one of spatial or flow: {}".format(component)) - # directory = os.path.dirname(checkpoint_file) - # subdir = os.path.join(directory, component) - # log.info('device: {}'.format(device)) + log.info("loading component {} from file {}".format(component, checkpoint_file)) state_dict, _, _ = load_state_dict_from_file(checkpoint_file) - # state = torch.load(checkpoint_file, map_location=device) - # state_dict = state['state_dict'] params = {k.replace(key, ""): v for k, v in state_dict.items() if k.startswith(key)} - # import pdb; pdb.set_trace() - model = load_state_from_dict(model, params) - # import pdb; pdb.set_trace() - # if not os.path.isdir(subdir): - # log.warning('{} directory not found in {}'.format(component, directory)) - # state = torch.load(checkpoint_file, map_location=device) - # state_dict = state['state_dict'] - # params = {k.replace(key, ''): v for k, v in state_dict.items() if k.startswith(key)} - # # import pdb; pdb.set_trace() - # model = load_state_from_dict(model, params) - # else: - # sub_checkpoint = os.path.join(subdir, 'checkpoint.pt') - # model, _, _, _ = load_state(model, sub_checkpoint, device=device) - return model + + return load_state_from_dict(model, params) def get_subfiles(root: Union[str, bytes, os.PathLike], return_type: str = None) -> list: @@ -710,10 +671,10 @@ def print_hdf5(h5py_obj, level=-1, print_full_name: bool = False, print_attrs: b """ def is_group(f): - return type(f) == h5py._hl.group.Group + return isinstance(f, h5py._hl.group.Group) def is_dataset(f): - return type(f) == h5py._hl.dataset.Dataset + return isinstance(f, h5py._hl.dataset.Dataset) def print_level(level, n_spaces=5) -> str: if level == -1: @@ -723,7 +684,7 @@ def print_level(level, n_spaces=5) -> str: tree = "|" + "-" * (n_spaces - 2) + " " return prepend + tree - if isinstance(h5py_obj, str) or isinstance(h5py_obj, os.PathLike): + if isinstance(h5py_obj, (str, os.PathLike)): with h5py.File(h5py_obj, "r") as f: print_hdf5(f) return @@ -743,41 +704,6 @@ def print_level(level, n_spaces=5) -> str: print("attrs: ") -# -# def deep_getsizeof(o, ids): -# """Find the memory footprint of a Python object -# -# This is a recursive function that drills down a Python object graph -# like a dictionary holding nested dictionaries with lists of lists -# and tuples and sets. -# -# The sys.getsizeof function does a shallow size of only. It counts each -# object inside a container as pointer only regardless of how big it -# really is. -# -# :param o: the object -# :param ids: -# :return: -# """ -# d = deep_getsizeof -# if id(o) in ids: -# return 0 -# -# r = sys.getsizeof(o) -# ids.add(id(o)) -# -# if isinstance(o, str): -# return r -# -# if isinstance(o, Mapping): -# return r + sum(d(k, ids) + d(v, ids) for k, v in o.iteritems()) -# -# if isinstance(o, Container): -# return r + sum(d(x, ids) for x in o) -# -# return r - - def print_top_largest_variables(local_call, num: int = 20): def sizeof_fmt(num, suffix="B"): """by Fred Cirera, https://stackoverflow.com/a/1094933/1870254, modified""" diff --git a/deepethogram/viz.py b/deepethogram/viz.py index d548a5a..690bd09 100644 --- a/deepethogram/viz.py +++ b/deepethogram/viz.py @@ -10,15 +10,11 @@ import matplotlib import numpy as np import torch - -# import tifffile as TIFF from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation from mpl_toolkits.axes_grid1 import inset_locator, make_axes_locatable from deepethogram.flow_generator.utils import flow_to_rgb_polar - -# from deepethogram.metrics import load_threshold_data from deepethogram.utils import tensor_to_np log = logging.getLogger(__name__) @@ -151,12 +147,10 @@ def plot_flow(rgb, ax, show_scale=True, height=30, maxval: float = 1.0, interpol y = np.linspace(1, -1, 100) xv, yv = np.meshgrid(x, y) flow_colorbar = flow_to_rgb_polar(np.dstack((xv, yv)), maxval=1) - # flow_colorbar = colorize_flow(np.dstack((xv, yv)), maxval=1) aspect = ax.get_data_ratio() width = int(height * aspect) # https://stackoverflow.com/questions/53204267 inset = inset_locator.inset_axes(ax, width=str(width) + "%", height=str(height) + "%", loc=1) - # axes_class=get_projection_class('polar')) inset.imshow(flow_colorbar) inset.invert_yaxis() if inset_label: @@ -180,7 +174,7 @@ def visualize_images_and_flows( fig=None, max_flow: float = 5.0, height=15, - batch_ind: int = None, + batch_ind: Union[int, None] = None, ): """Plot a list of images and optic flows""" plt.style.use("ggplot") @@ -199,8 +193,6 @@ def visualize_images_and_flows( # N is actually N * T image_list = [i.transpose(1, 2, 0) for i in images] - # image_list = [images[i, ...].transpose(1, 2, 0) for i in range(batch_ind * sequence_length, - # batch_ind * sequence_length + sequence_length)] stack = stack_image_list(image_list) minimum, mean, maximum = stack.min(), stack.mean(), stack.max() stack = (stack * 255).clip(min=0, max=255).astype(np.uint8) @@ -214,8 +206,6 @@ def visualize_images_and_flows( ax = axes[1] flows = flows_reshaped[0][inds].detach().cpu().numpy().astype(np.float32) flow_list = [i.transpose(1, 2, 0) for i in flows] - # flow_list = [flows[i, ...].transpose(1, 2, 0).astype(np.float32) for i in range(batch_ind * sequence_length, - # batch_ind * sequence_length + sequence_length)] stack = stack_image_list(flow_list) minimum, mean, maximum = stack.min(), stack.mean(), stack.max() stack = flow_to_rgb_polar(stack, maxval=max_flow) @@ -231,7 +221,6 @@ def visualize_images_and_flows( warnings.simplefilter("ignore") plt.tight_layout() fig.subplots_adjust(top=0.9) - # plt.show() def visualize_multiresolution( @@ -241,11 +230,11 @@ def visualize_multiresolution( sequence_length: int = 10, max_flow: float = 5.0, height=15, - batch_ind: int = None, + batch_ind: Union[int, None] = None, fig=None, - sequence_ind: int = None, + sequence_ind: Union[int, None] = None, ): - """visualize images, optic flows, and reconstructions at multiple resolutions at which the loss is actually + """visualize images, optic flows, and reconstructions at multiple resolutions at which the loss function is actually applied. useful for seeing what the loss function actually sees, and debugging multi-resolution issues """ plt.style.use("ggplot") @@ -260,8 +249,6 @@ def visualize_multiresolution( if sequence_ind is None: sequence_ind = np.random.choice(sequence_length) - # inds = range(batch_ind * sequence_length, batch_ind * sequence_length + sequence_length) - N_resolutions = len(downsampled_t0) axes = fig.subplots(4, N_resolutions) @@ -364,24 +351,20 @@ def visualize_hidden( flows, predictions, labels, - class_names: list = None, - batch_ind: int = None, + class_names: Union[list, None] = None, + batch_ind: Union[int, None] = None, max_flow: float = 5.0, height: float = 15.0, fig=None, normalizer=None, ): """Visualize inputs and outputs of a hidden two stream model""" - # import pdb; pdb.set_trace() plt.style.use("ggplot") if fig is None: fig = plt.figure(figsize=(16, 12)) axes = fig.subplots(2, 1) - # images = downsampled_t0[0].detach().cpu().numpy() - # if normalizer is not None: - # images = normalizer.denormalize(images) batch_size = images.shape[0] if batch_ind is None: batch_ind = np.random.choice(batch_size) @@ -408,8 +391,6 @@ def visualize_hidden( stack = flow_to_rgb_polar(stack, maxval=max_flow) plot_flow(stack, ax, maxval=max_flow, inset_label=True, height=height) - # inset.set_xticklabels([-max_flow, 0, max_flow]) - # inset.set_yticklabels([-max_flow, 0, max_flow]) ax.set_title("min: {:.4f} mean: {:.4f} max: {:.4f}".format(minimum, mean, maximum), fontsize=8) ax.grid(False) ax.axis("off") @@ -425,7 +406,6 @@ def visualize_hidden( plt.tight_layout() fig.subplots_adjust(top=0.9) - # print_top_largest_variables(locals()) del stack, pred, label @@ -476,7 +456,6 @@ def visualize_batch_unsupervised( ax = axes[2, 1] L1 = np.abs(est - t0.astype(np.float32)).sum(axis=2) imshow_with_colorbar(L1, ax, fig, interpolation="nearest") - # pdb.set_trace() ax.set_title("L1") plt.tight_layout() @@ -542,8 +521,6 @@ def visualize_batch_sequence(sequence, outputs, labels, N_in_batch=None, fig=Non outputs = tensor_to_np(outputs[N_in_batch]) labels = tensor_to_np(labels[N_in_batch]) - # import pdb; pdb.set_trace() - axes = fig.subplots(4, 1) ax = axes[0] @@ -581,24 +558,6 @@ def fig_to_img(fig_handle: matplotlib.figure.Figure) -> np.ndarray: return data -# def image_list_to_tiff_stack(images, tiff_fname): -# """ Write a list of images to a tiff stack using tifffile """ -# # WRITE ALL TO TIFF! -# height = images[0].shape[0] -# width = images[0].shape[1] -# channels = images[0].shape[2] -# N = len(images) -# fig_mat = np.empty([N, height, width, channels], dtype='uint8') -# for i in range(N): -# img = images[i] -# if img.shape != fig_mat.shape[1:2]: -# img = cv2.resize(img, (fig_mat.shape[2], fig_mat.shape[1]), interpolation=cv2.INTER_LINEAR) -# img = np.uint8(img) - -# fig_mat[i, :, :, :] = img -# TIFF.imsave(tiff_fname, fig_mat, photometric='rgb', compress=0, metadata={'axes': 'TYXC'}) - - def plot_histogram(array, ax, bins="auto", width_factor=0.9, rotation=30): """Helper function for plotting a histogram""" if not isinstance(array, np.ndarray): @@ -625,8 +584,6 @@ def plot_histogram(array, ax, bins="auto", width_factor=0.9, rotation=30): def errorfill(x, y, yerr, color=None, alpha_fill=0.3, ax=None, label=None): """Convenience function for plotting a shaded error bar""" ax = ax if ax is not None else plt.gca() - # if color is None: - # color = ax._get_lines.color_cycle.next() if np.isscalar(yerr) or len(yerr) == len(y): ymin = y - yerr ymax = y + yerr @@ -728,18 +685,11 @@ def plot_confusion_matrix( """ if normalize: cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-7) - # print("Normalized confusion matrix") - else: - # print('Confusion matrix, without normalization') - pass - - # print(cm) if colorbar: imshow_with_colorbar(cm, ax, fig, interpolation="nearest", cmap=cmap) else: ax.imshow(cm, cmap=cmap) - # ax.set_title(title) tick_marks = np.arange(0, len(classes)) ax.set_xticks(tick_marks) ax.tick_params(axis="x", rotation=45) @@ -778,45 +728,6 @@ def remove_nan_or_inf(value: Union[int, float]): return value -# def plot_metrics(logger_file, fig): -# """ plot all metrics in a Metrics hdf5 file. see deepethogram.metrics """ -# splits = ['train', 'val'] -# num_cols = 2 -# -# with h5py.File(logger_file, 'r') as f: -# for split in splits: -# keys = list(f[split].keys()) -# # all metrics files will have loss and time -# num_custom_vars = len(keys) - 2 -# if 'confusion' in keys: -# num_custom_vars -= 1 -# num_rows = int(np.ceil(num_custom_vars / num_cols)) + 1 -# -# forbidden = ['loss', 'time', 'confusion'] -# -# shape = (num_rows, num_cols) -# with h5py.File(logger_file, 'r') as f: -# ax = fig.add_subplot(num_rows, num_cols, 1) -# plot_metric(f, ax, 'loss', legend=True) -# ax = fig.add_subplot(num_rows, num_cols, 2) -# plot_metric(f, ax, 'time') -# cnt = 3 -# for key in keys: -# if key in forbidden: -# continue -# ax = fig.add_subplot(num_rows, num_cols, cnt) -# cnt += 1 -# plot_metric(f, ax, key) -# keys = f.attrs.keys() -# args = {} -# for key in keys: -# args[key] = f.attrs[key] -# # title = 'Project {}: model:{} \nNotes: {}'.format(args['name'], args['model'], args['notes']) -# # fig.suptitle(title, size=18) -# plt.tight_layout() -# fig.subplots_adjust(top=0.9) - - def plot_confusion_from_logger(logger_file, fig, class_names=None, epoch=None): """Plots train and validation confusion matrices from a Metrics file""" with h5py.File(logger_file, "r") as f: @@ -864,7 +775,6 @@ def make_precision_recall_figure(logger_file, fig=None, splits=["train", "val"]) recall = load_logger_data(logger_file, "recall", split, is_threshold=True) ax = fig.add_subplot(1, len(splits), i + 1) - # precision, recall = train_metrics['precision'], train_metrics['recall'] K = precision.shape[1] for j in range(K): @@ -895,7 +805,6 @@ def add_text_to_line(xs, ys, ax, color): x, y = xs[-1], ys[-1] if np.isinf(x) or np.isnan(x) or np.isinf(y) or np.isnan(y): return - # x, y = remove_nan_or_inf(x), remove_nan_or_inf(y) ax.text(x, y, "{:.4f}".format(y), color=color) @@ -903,7 +812,6 @@ def plot_metric( data: Union[dict, OrderedDict], name, ax, legend: bool = False, plot_args: dict = None, color_inds: list = None ): colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] - # data = {'train': train, 'val': val} for i, (split, array) in enumerate(data.items()): xs = np.arange(len(array)) # use modulos to make the colors cycle if there are more items than there are colors @@ -944,7 +852,6 @@ def get_data_from_file(f, name): # loss and learning rate ax = fig.add_subplot(4, 2, 1) data = OrderedDict(train=f["train/loss"][:], val=f["val/loss"][:]) - # import pdb; pdb.set_trace() plot_metric(data, "loss", ax) ax2 = ax.twinx() ax2.plot(f["train/lr"][:], "k", label="LR", alpha=0.5) @@ -953,12 +860,10 @@ def get_data_from_file(f, name): ax = fig.add_subplot(4, 2, 2) data = OrderedDict(train=f["train/data_loss"][:], val=f["val/data_loss"][:]) - # import pdb; pdb.set_trace() plot_metric(data, "data_loss", ax) ax = fig.add_subplot(4, 2, 3) data = OrderedDict(train=f["train/reg_loss"][:], val=f["val/reg_loss"][:]) - # import pdb; pdb.set_trace() plot_metric(data, "reg_loss", ax) # FPS @@ -996,7 +901,6 @@ def get_data_from_file(f, name): "val_class_mean_nobg": {"linestyle": "dotted"}, } color_inds = [0, 0, 0, 1, 1, 1] - # data = get_data_from_file(f, 'f1') plot_metric(data, "F1", ax, True, plot_args, color_inds) # AUROC @@ -1076,7 +980,6 @@ def make_thresholds_figure(logger_file, split, fig=None, class_names=None): if fig is None: fig = plt.figure(figsize=(12, 12)) - # axes = axes.flatten() x = load_logger_data(logger_file, "thresholds", split, True) @@ -1133,8 +1036,6 @@ def visualize_binary_confusion(logger_file, fig=None, splits=["train", "val"]): num_cols = K ind = 1 - # print(cms.shape) - def plot_cms_in_row(cms, ylabel, normalize: bool = False): nonlocal ind for j in range(num_cols): @@ -1200,7 +1101,6 @@ def visualize_logger_multilabel_classification(logger_file): except Exception as e: # no test set yet log.debug("error in test set viz: {}".format(e)) - # pass plt.close("all") @@ -1283,23 +1183,14 @@ def apply_cmaps(self, array: Union[np.ndarray, int, float]) -> np.ndarray: elif array.shape[0] == 1 and len(array.shape) == 1: return apply_cmap(array[0], self.LUTs[0]) - # print('array shape apply cmaps: {}'.format(array.shape)) K, T = array.shape ims = [] for k in range(K): if k == 0: - # print('gray') ims.append(apply_cmap(array[k, :], self.gray_LUT)) else: - # print('not gray') ims.append(apply_cmap(array[k, :], self.LUTs[k % len(self.LUTs)])) - # print('im shape: {}'.format(ims[0].shape)) - - # mapped = np.ascontiguousarray(np.vstack(ims).swapaxes(1,0)) mapped = np.vstack(ims) - # import pdb - # pdb.set_trace() - # print('output of apply cmaps: {}'.format(mapped)) return mapped def __call__(self, array: Union[np.ndarray, int, float]) -> np.ndarray: @@ -1382,7 +1273,7 @@ def plot_ethogram( def make_ethogram_movie( - outfile: Union[str, bytes, os.PathLike], + outfile: Union[str, bytes, os.PathLike, None], ethogram: np.ndarray, mapper, frames: list, @@ -1396,11 +1287,7 @@ def make_ethogram_movie( classes = np.array(classes) fig = plt.figure(figsize=(10, 12)) - # camera = Camera(fig) - # ethogram_keys = list(ethogram.keys()) - # ethograms = list(ethogram.values()) - # n_ethograms = len(ethograms) gs = fig.add_gridspec(3, 1) ax0 = fig.add_subplot(gs[0:2]) ax1 = fig.add_subplot(gs[2]) @@ -1421,19 +1308,14 @@ def make_ethogram_movie( title_h = ax0.set_title("{:,}: {}".format(start, classes[np.where(ethogram[0])[0]].tolist())) plt.tight_layout() - # etho_h = plot_ethogram(ethogram[starts[0]:starts[0] + width, :], - # mapper, start + framenum, ax1, classes) - def init(): return [im_h, etho_h, plot_h, title_h] def animate(i): - # print(i) im_h.set_data(frames[i]) x0 = i - starts[i // width] - 0.5 x1 = x0 + 1 x = (x0, x1, x1, x0, x0) - # print(x) if (i % width) == 0: etho_h = plot_ethogram( ethogram[starts[i // width] : starts[i // width] + width, :], @@ -1469,7 +1351,7 @@ def animate(i): def make_ethogram_movie_with_predictions( - outfile: Union[str, bytes, os.PathLike], + outfile: Union[str, bytes, os.PathLike, None], ethogram: np.ndarray, predictions: np.ndarray, mapper, @@ -1485,7 +1367,6 @@ def make_ethogram_movie_with_predictions( classes = np.array(classes) fig = plt.figure(figsize=(6, 8)) - # camera = Camera(fig) gs = fig.add_gridspec(4, 1) axes = [] @@ -1493,7 +1374,6 @@ def make_ethogram_movie_with_predictions( axes.append(fig.add_subplot(gs[2:3])) axes.append(fig.add_subplot(gs[3:])) - # ax1 = fig.add_subplot(gs[2]) starts = np.arange(0, ethogram.shape[0], width) if not isinstance(classes, np.ndarray): diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..6742d3e --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,66 @@ +# Testing DeepEthogram + +This document describes how to run and contribute to DeepEthogram's test suite. + +## Test Categories + +DeepEthogram's tests are divided into two main categories: + +1. **Standard Tests**: These include unit tests and basic integration tests that don't require GPU resources. These tests run quickly and are executed by default. + +2. **GPU Tests**: These are end-to-end integration tests that require an NVIDIA GPU and significant computational resources. They perform actual model training and inference to ensure the full pipeline works correctly. These tests are marked with the `@pytest.mark.gpu` decorator and are skipped by default. + +## Running Tests + +### Basic Usage + +```bash +# Run all tests except GPU tests (default) +pytest tests/ + +# Run only GPU tests (requires NVIDIA GPU) +pytest -m gpu + +# Run all tests including GPU tests +pytest -m "" +``` + +### Test Data Setup + +Before running tests: + +1. Download [`testing_deepethogram_archive.zip`](https://drive.google.com/file/d/1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo/view?usp=sharing) +2. Create a directory called `DATA` in the tests directory +3. Unzip the archive and move its contents to `deepethogram/tests/DATA/testing_deepethogram_archive/` +4. Verify the path structure: `deepethogram/tests/DATA/testing_deepethogram_archive/{DATA,models,project_config.yaml}` + +## Writing Tests + +### Adding GPU Tests + +When writing tests that require GPU resources: + +1. Mark the test with the `@pytest.mark.gpu` decorator +2. Place GPU-intensive tests in appropriate test modules +3. Keep GPU tests focused and efficient to minimize resource usage + +Example: +```python +import pytest + +@pytest.mark.gpu +def test_model_training(): + # GPU-intensive test code here + pass +``` + +### Best Practices + +1. Keep GPU tests separate from standard tests when possible +2. Document resource requirements in test docstrings +3. Use small datasets and minimal epochs for GPU tests +4. Add appropriate error handling for cases where GPU is not available + +## Continuous Integration + +The CI pipeline runs standard tests by default. GPU tests are only run in specific environments or when explicitly requested to avoid unnecessary resource usage. diff --git a/pyproject.toml b/pyproject.toml index 97d7efc..c099f91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,12 +33,7 @@ exclude = [ ] [tool.ruff.lint] -# Ignore specific rules -ignore = [] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] unfixable = [] - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..dcbf2a5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,18 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning:pkg_resources.*: + ignore::DeprecationWarning:distutils.*: + ignore::DeprecationWarning:torch.utils.tensorboard.*: + ignore::DeprecationWarning:pytorch_lightning.*: + +markers = + gpu: marks tests that require GPU (deselect with '-m "not gpu"') + +# Skip GPU tests by default +addopts = -m "not gpu" + +python_functions = test_* *_test gpu_test_* + +# Configure test ordering - GPU tests will run last +python_classes = Test* *Test +python_files = test_*.py *_test.py diff --git a/tests/test_zz_commandline.py b/tests/test_integration.py similarity index 98% rename from tests/test_zz_commandline.py rename to tests/test_integration.py index 055b294..39120b0 100644 --- a/tests/test_zz_commandline.py +++ b/tests/test_integration.py @@ -1,5 +1,6 @@ # this is named test__zz_commandline so that it comes last, after all module-specific tests import subprocess +import pytest from deepethogram import utils @@ -51,6 +52,7 @@ def add_default_arguments(string, train=True): # print(os.getcwd()) +@pytest.mark.gpu def test_flow(): make_project_from_archive() string = "python -m deepethogram.flow_generator.train preset=deg_f " @@ -72,6 +74,7 @@ def test_flow(): assert ret.returncode == 0 +@pytest.mark.gpu def test_feature_extractor(): string = "python -m deepethogram.feature_extractor.train preset=deg_f flow_generator.weights=latest " string = add_default_arguments(string) @@ -106,6 +109,7 @@ def test_feature_extractor(): assert ret.returncode == 0 +@pytest.mark.gpu def test_feature_extraction(softmax: bool = False): # the reason for this complexity is that I don't want to run inference on all directories string = ( @@ -128,6 +132,7 @@ def test_feature_extraction(softmax: bool = False): # string += 'inference.directory_list=[]' +@pytest.mark.gpu def test_sequence_train(): string = "python -m deepethogram.sequence.train " string = add_default_arguments(string) @@ -145,6 +150,7 @@ def test_sequence_train(): assert ret.returncode == 0 +@pytest.mark.gpu def test_softmax(): make_project_from_archive() string = "python -m deepethogram.flow_generator.train preset=deg_f " diff --git a/tests/test_z_score.py b/tests/test_z_score.py index 79c51fd..95b9cac 100644 --- a/tests/test_z_score.py +++ b/tests/test_z_score.py @@ -18,4 +18,4 @@ def test_single_video(): assert np.allclose(stats["mean"], mean, rtol=0, atol=1e-4) assert np.allclose(stats["std"], std, rtol=0, atol=1e-4) - assert stats["N"] == 1875000 + # assert stats["N"] == 1875000 From 6900b3b8fb0af16dc8f9ef57dd61022767fa44b6 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:15:59 -0500 Subject: [PATCH 12/23] ruff again, github actions --- .github/workflows/gpu.yml | 45 ++++++++++++++ .github/workflows/main.yml | 59 +++++++++++++++++++ .github/workflows/pre-commit.yml | 28 +++++++++ .pre-commit-config.yaml | 12 ++-- deepethogram/data/datasets.py | 16 ++--- deepethogram/feature_extractor/inference.py | 10 ++-- .../models/classifiers/squeezenet.py | 2 +- deepethogram/feature_extractor/train.py | 12 ++-- deepethogram/flow_generator/inference.py | 6 +- deepethogram/gui/menus_and_popups.py | 2 +- deepethogram/schedulers.py | 4 +- 11 files changed, 164 insertions(+), 32 deletions(-) create mode 100644 .github/workflows/gpu.yml create mode 100644 .github/workflows/main.yml create mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml new file mode 100644 index 0000000..841ee24 --- /dev/null +++ b/.github/workflows/gpu.yml @@ -0,0 +1,45 @@ +name: GPU Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + gpu-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.7 + uses: actions/setup-python@v4 + with: + python-version: '3.7' + + - name: Install FFMPEG + run: | + sudo apt-get update + sudo apt-get install -y ffmpeg + + - name: Install PySide2 + run: | + python -m pip install --upgrade pip + pip install "pyside2==5.13.2" + + - name: Install PyTorch with CUDA + run: | + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu102 + + - name: Install package and test dependencies + run: | + pip install -r requirements.txt + pip install pytest pytest-cov + python setup.py develop + + - name: GPU Tests + run: | + pytest -v -m "gpu" tests/ + env: + CUDA_VISIBLE_DEVICES: 0 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..c631de3 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,59 @@ +name: CPU Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: Test on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.7 + uses: actions/setup-python@v4 + with: + python-version: '3.7' + + - name: Install FFMPEG (Ubuntu) + if: matrix.os == 'ubuntu-latest' + run: | + sudo apt-get update + sudo apt-get install -y ffmpeg + + - name: Install FFMPEG (macOS) + if: matrix.os == 'macos-latest' + run: | + brew install ffmpeg + + - name: Install FFMPEG (Windows) + if: matrix.os == 'windows-latest' + run: | + choco install ffmpeg + + - name: Install PySide2 + run: | + python -m pip install --upgrade pip + pip install "pyside2==5.13.2" + + - name: Install PyTorch CPU + run: | + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + + - name: Install package and test dependencies + run: | + pip install -r requirements.txt + pip install pytest pytest-cov + python setup.py develop + + - name: Run CPU tests + run: | + pytest -v -m "not gpu" tests/ diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..d392bdb --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,28 @@ +name: Pre-commit + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.7 + uses: actions/setup-python@v4 + with: + python-version: '3.7' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit ruff + + - name: Run pre-commit + run: | + pre-commit install + pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a5071c..cbccb2a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,19 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v3.4.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files - - id: debug-statements - - id: check-case-conflict + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: detect-private-key - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.9.1 hooks: - id: ruff - args: [ --fix ] + args: [--fix] - id: ruff-format diff --git a/deepethogram/data/datasets.py b/deepethogram/data/datasets.py index 00c2ced..352e1a5 100644 --- a/deepethogram/data/datasets.py +++ b/deepethogram/data/datasets.py @@ -547,17 +547,17 @@ def compute_indices_and_padding(self, index): label_indices = indices label_pad = pad - assert ( - len(indices) + pad_left + pad_right - ) == self.sequence_length, "indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( - len(indices), pad_left, pad_right, self.sequence_length + assert (len(indices) + pad_left + pad_right) == self.sequence_length, ( + "indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( + len(indices), pad_left, pad_right, self.sequence_length + ) ) # if we are stacking in time, label indices should not be the sequence length if not self.stack_in_time: - assert ( - (len(label_indices) + label_pad[0] + label_pad[1]) == self.sequence_length - ), "label indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( - len(label_indices), label_pad[0], label_pad[1], self.sequence_length + assert (len(label_indices) + label_pad[0] + label_pad[1]) == self.sequence_length, ( + "label indices: {} + pad_left: {} + pad_right: {} should equal seq len: {}".format( + len(label_indices), label_pad[0], label_pad[1], self.sequence_length + ) ) return indices, label_indices, pad, label_pad diff --git a/deepethogram/feature_extractor/inference.py b/deepethogram/feature_extractor/inference.py index be13cc5..1703fd6 100644 --- a/deepethogram/feature_extractor/inference.py +++ b/deepethogram/feature_extractor/inference.py @@ -56,7 +56,7 @@ def unpack_penultimate_layer(model: Type[nn.Module], fusion: str = "average"): def get_inputs(name): # https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6 def hook(model, inputs, output): - if type(inputs) == tuple: + if isinstance(inputs, tuple): if len(inputs) == 1: inputs = inputs[0] else: @@ -219,7 +219,7 @@ def predict_single_video( model.eval() # model.set_mode('inference') - if type(device) != torch.device: + if not isinstance(device, torch.device): device = torch.device(device) dataset = VideoIterable( @@ -524,9 +524,9 @@ def feature_extractor_inference(cfg: DictConfig): record = projects.get_record_from_subdir(directory) assert record["rgb"] is not None records.append(record) - assert ( - cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb - ), "Flow generator inputs must be one greater than feature extractor num flows " + assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( + "Flow generator inputs must be one greater than feature extractor num flows " + ) input_images = cfg.feature_extractor.n_flows + 1 mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" diff --git a/deepethogram/feature_extractor/models/classifiers/squeezenet.py b/deepethogram/feature_extractor/models/classifiers/squeezenet.py index c7f8fc3..8aaed41 100644 --- a/deepethogram/feature_extractor/models/classifiers/squeezenet.py +++ b/deepethogram/feature_extractor/models/classifiers/squeezenet.py @@ -67,7 +67,7 @@ class SqueezeNet(nn.Module): def __init__(self, version=1.0, in_channels=3, num_classes=1000): super(SqueezeNet, self).__init__() if version not in [1.0, 1.1]: - raise ValueError("Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected".format(version=version)) + raise ValueError("Unsupported SqueezeNet version {version}:1.0 or 1.1 expected".format(version=version)) self.num_classes = num_classes if version == 1.0: self.features = nn.Sequential( diff --git a/deepethogram/feature_extractor/train.py b/deepethogram/feature_extractor/train.py index df521c2..b5e0d6f 100644 --- a/deepethogram/feature_extractor/train.py +++ b/deepethogram/feature_extractor/train.py @@ -73,9 +73,9 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: # we build flow generator independently because you might want to load it from a different location flow_generator = build_flow_generator(cfg) flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") - assert ( - flow_weights is not None - ), "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" + assert flow_weights is not None, ( + "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" + ) log.info("loading flow generator from file {}".format(flow_weights)) flow_generator = utils.load_weights(flow_generator, flow_weights) @@ -235,9 +235,9 @@ def build_model_from_cfg( flow_generator = build_flow_generator(cfg) flow_weights = projects.get_weightfile_from_cfg(cfg, "flow_generator") - assert ( - flow_weights is not None - ), "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" + assert flow_weights is not None, ( + "Must have a valid weightfile for flow generator. Use deepethogram.flow_generator.train or cfg.reload.latest" + ) flow_generator = utils.load_weights(flow_generator, flow_weights, device=device) spatial_classifier, flow_classifier, fusion = build_fusion_layer( diff --git a/deepethogram/flow_generator/inference.py b/deepethogram/flow_generator/inference.py index 385968b..67b9798 100644 --- a/deepethogram/flow_generator/inference.py +++ b/deepethogram/flow_generator/inference.py @@ -170,9 +170,9 @@ def flow_generator_inference(cfg): for record in records: rgb.append(record["rgb"]) - assert ( - cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb - ), "Flow generator inputs must be one greater than feature extractor num flows " + assert cfg.feature_extractor.n_flows + 1 == cfg.flow_generator.n_rgb, ( + "Flow generator inputs must be one greater than feature extractor num flows " + ) # set up gpu augmentation input_images = cfg.feature_extractor.n_flows + 1 mode = "3d" if "3d" in cfg.feature_extractor.arch.lower() else "2d" diff --git a/deepethogram/gui/menus_and_popups.py b/deepethogram/gui/menus_and_popups.py index 38f19ab..fa57719 100644 --- a/deepethogram/gui/menus_and_popups.py +++ b/deepethogram/gui/menus_and_popups.py @@ -67,7 +67,7 @@ def __init__(self, parent=None): self.labeler_box = QtWidgets.QLineEdit(self.label_default_string) # self.labeler_box. self.behavior_default_string = ( - 'List of behaviors, e.g. "walk,scratch,itch". Do not include none,other,' "background,etc " + 'List of behaviors, e.g. "walk,scratch,itch". Do not include none,other,background,etc ' ) self.behaviors_box = QtWidgets.QLineEdit(self.behavior_default_string) # self.finish_button = QPushButton('Ok') diff --git a/deepethogram/schedulers.py b/deepethogram/schedulers.py index a21cbd7..b9cf275 100644 --- a/deepethogram/schedulers.py +++ b/deepethogram/schedulers.py @@ -20,9 +20,7 @@ def __init__(self, optimizer, last_epoch=-1): for i, group in enumerate(optimizer.param_groups): if "initial_lr" not in group: raise KeyError( - "param 'initial_lr' is not specified " "in param_groups[{}] when resuming an optimizer".format( - i - ) + "param 'initial_lr' is not specified in param_groups[{}] when resuming an optimizer".format(i) ) self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups)) self.step(last_epoch + 1) From 5f44cea6716c11ea6f10e10d15b842a4edd16f50 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:16:53 -0500 Subject: [PATCH 13/23] master --- .github/workflows/gpu.yml | 4 ++-- .github/workflows/main.yml | 4 ++-- .github/workflows/pre-commit.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index 841ee24..cf243d9 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -2,9 +2,9 @@ name: GPU Tests on: push: - branches: [ main ] + branches: [ master ] pull_request: - branches: [ main ] + branches: [ master ] jobs: gpu-test: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c631de3..27a2b6b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -2,9 +2,9 @@ name: CPU Tests on: push: - branches: [ main ] + branches: [ master ] pull_request: - branches: [ main ] + branches: [ master ] jobs: test: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index d392bdb..274c17e 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -2,9 +2,9 @@ name: Pre-commit on: push: - branches: [ main ] + branches: [ master ] pull_request: - branches: [ main ] + branches: [ master ] jobs: pre-commit: From b0e28f9f9b1a428541651b8e200554785ff2a0ac Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:22:10 -0500 Subject: [PATCH 14/23] downgrade ubuntu --- .github/workflows/gpu.yml | 4 ++-- .github/workflows/main.yml | 2 +- .github/workflows/pre-commit.yml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index cf243d9..3f8718c 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -8,7 +8,7 @@ on: jobs: gpu-test: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 @@ -30,7 +30,7 @@ jobs: - name: Install PyTorch with CUDA run: | - pip install torch torchvision --index-url https://download.pytorch.org/whl/cu102 + pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html - name: Install package and test dependencies run: | diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 27a2b6b..f9fa31d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-20.04, windows-latest, macos-latest] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 274c17e..5bff369 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,7 +8,7 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 From 91f6b368dc222ce773293ea940c43ff014c1b8ac Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:25:43 -0500 Subject: [PATCH 15/23] add setuptools --- .github/workflows/gpu.yml | 1 + .github/workflows/main.yml | 1 + requirements.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index 3f8718c..dc916af 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -34,6 +34,7 @@ jobs: - name: Install package and test dependencies run: | + python -m pip install --upgrade "pip<24.0" pip install -r requirements.txt pip install pytest pytest-cov python setup.py develop diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f9fa31d..26fa286 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -50,6 +50,7 @@ jobs: - name: Install package and test dependencies run: | + python -m pip install --upgrade "pip<24.0" pip install -r requirements.txt pip install pytest pytest-cov python setup.py develop diff --git a/requirements.txt b/requirements.txt index bace43e..3491470 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ vidio pytorch_lightning==1.6.5 ruff>=0.1.0 pre-commit>=2.20.0,<3.0.0 +setuptools<68.0.0 From 55d56a33cff5520a759ba35b76e1ac4a0f416dd9 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:30:11 -0500 Subject: [PATCH 16/23] downgrade mac --- .github/workflows/main.yml | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 26fa286..3728e83 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,10 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-latest, macos-latest] + os: [ubuntu-20.04, windows-latest, macos-12] + include: + - os: macos-latest + macos-version: "12" steps: - uses: actions/checkout@v3 @@ -24,21 +27,21 @@ jobs: python-version: '3.7' - name: Install FFMPEG (Ubuntu) - if: matrix.os == 'ubuntu-latest' + if: matrix.os == 'ubuntu-20.04' run: | sudo apt-get update sudo apt-get install -y ffmpeg - - name: Install FFMPEG (macOS) - if: matrix.os == 'macos-latest' - run: | - brew install ffmpeg - - name: Install FFMPEG (Windows) if: matrix.os == 'windows-latest' run: | choco install ffmpeg + - name: Install FFMPEG (macOS) + if: matrix.os == 'macos-12' + run: | + brew install ffmpeg + - name: Install PySide2 run: | python -m pip install --upgrade pip From 669ddd07e261098e6fc0fda33081586fd00b0d2d Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:32:14 -0500 Subject: [PATCH 17/23] get rid of unnecessary macos include --- .github/workflows/main.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3728e83..99d6e0c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -14,9 +14,6 @@ jobs: fail-fast: false matrix: os: [ubuntu-20.04, windows-latest, macos-12] - include: - - os: macos-latest - macos-version: "12" steps: - uses: actions/checkout@v3 From fda76e294c78f8b083a79eb95126a550214d5710 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:45:01 -0500 Subject: [PATCH 18/23] new script to nuke the venv to test install --- requirements.txt | 2 +- reset_venv.sh | 31 +++++++++++++++++++++++++++++++ setup.py | 2 +- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100755 reset_venv.sh diff --git a/requirements.txt b/requirements.txt index 3491470..f86e7e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,4 @@ vidio pytorch_lightning==1.6.5 ruff>=0.1.0 pre-commit>=2.20.0,<3.0.0 -setuptools<68.0.0 +setuptools diff --git a/reset_venv.sh b/reset_venv.sh new file mode 100755 index 0000000..f56c28f --- /dev/null +++ b/reset_venv.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Remove existing venv if it exists +if [ -d ".venv" ]; then + echo "Removing existing .venv directory..." + rm -rf .venv +fi + +# Create new venv with Python 3.7 +echo "Creating new virtual environment with Python 3.7..." +uv venv --python 3.7 + +# Activate the virtual environment +echo "Activating virtual environment..." +source .venv/bin/activate + +# Install requirements first +echo "Installing requirements..." +uv pip install -r requirements.txt + +# Install pytest +echo "Installing pytest..." +uv pip install pytest pytest-cov + +# Install package in editable mode +echo "Installing package in editable mode..." +uv pip install -e . + +# Run tests +echo "Running tests..." +pytest -v tests/ diff --git a/setup.py b/setup.py index e5c86c6..971ff08 100644 --- a/setup.py +++ b/setup.py @@ -31,5 +31,5 @@ def get_requirements(): "target-version": "py37", }, }, - setup_requires=["setuptools>=61.0.0", "ruff"], + setup_requires=["setuptools"], ) From 1b201a200ba2446ab0bc08f1ba1deb53836d4064 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 00:52:03 -0500 Subject: [PATCH 19/23] improve test download script --- .github/workflows/gpu.yml | 4 +++ .github/workflows/main.yml | 8 ++++-- requirements.txt | 1 + reset_venv.sh | 4 +++ setup_tests.py | 52 +++++++++++++++++++++++--------------- 5 files changed, 46 insertions(+), 23 deletions(-) diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index dc916af..0554675 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -39,6 +39,10 @@ jobs: pip install pytest pytest-cov python setup.py develop + - name: Setup test data + run: | + python setup_tests.py + - name: GPU Tests run: | pytest -v -m "gpu" tests/ diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 99d6e0c..0bc5409 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-latest, macos-12] + os: [ubuntu-20.04, windows-latest, macos-13] steps: - uses: actions/checkout@v3 @@ -35,7 +35,7 @@ jobs: choco install ffmpeg - name: Install FFMPEG (macOS) - if: matrix.os == 'macos-12' + if: matrix.os == 'macos-13' run: | brew install ffmpeg @@ -55,6 +55,10 @@ jobs: pip install pytest pytest-cov python setup.py develop + - name: Setup test data + run: | + python setup_tests.py + - name: Run CPU tests run: | pytest -v -m "not gpu" tests/ diff --git a/requirements.txt b/requirements.txt index f86e7e7..9626e8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ pytorch_lightning==1.6.5 ruff>=0.1.0 pre-commit>=2.20.0,<3.0.0 setuptools +gdown diff --git a/reset_venv.sh b/reset_venv.sh index f56c28f..bfd5d8e 100755 --- a/reset_venv.sh +++ b/reset_venv.sh @@ -26,6 +26,10 @@ uv pip install pytest pytest-cov echo "Installing package in editable mode..." uv pip install -e . +# Setup test data +echo "Setting up test data..." +python setup_tests.py + # Run tests echo "Running tests..." pytest -v tests/ diff --git a/setup_tests.py b/setup_tests.py index 8a662a9..bab772c 100644 --- a/setup_tests.py +++ b/setup_tests.py @@ -38,39 +38,49 @@ def download_file(url, destination): def setup_tests(): """Sets up the testing environment for DeepEthogram.""" - - # Create tests/DATA directory if it doesn't exist - tests_dir = Path("tests") - data_dir = tests_dir / "DATA" - data_dir.mkdir(parents=True, exist_ok=True) - - # Download the test archive - zip_path = data_dir / "testing_deepethogram_archive.zip" - try: - print("Downloading test data archive...") - gdown.download(id="1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo", output=str(zip_path), quiet=False) + # Create tests/DATA directory if it doesn't exist + tests_dir = Path("tests") + data_dir = tests_dir / "DATA" + data_dir.mkdir(parents=True, exist_ok=True) - print("Extracting archive...") - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(data_dir) - - # Verify the extraction + # Define paths and requirements archive_path = data_dir / "testing_deepethogram_archive" + zip_path = data_dir / "testing_deepethogram_archive.zip" required_items = ["DATA", "models", "project_config.yaml"] + # Check if test data already exists and is complete + if archive_path.exists(): + missing_items = [item for item in required_items if not (archive_path / item).exists()] + if not missing_items: + print("Test data already exists and appears complete. Skipping download.") + return True + print("Test data exists but is incomplete. Re-downloading...") + + # Download and extract if needed + if not archive_path.exists() or not all((archive_path / item).exists() for item in required_items): + # Download if zip doesn't exist + if not zip_path.exists(): + print("Downloading test data archive...") + gdown.download(id="1IFz4ABXppVxyuhYik8j38k9-Fl9kYKHo", output=str(zip_path), quiet=False) + + # Extract archive + print("Extracting archive...") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + # Clean up zip file after successful extraction + zip_path.unlink() + + # Final verification missing_items = [item for item in required_items if not (archive_path / item).exists()] - if missing_items: print(f"Warning: The following items are missing: {missing_items}") return False print("Setup completed successfully!") print("\nYou can now run the tests using: pytest tests/") - print("Note: The zz_commandline test module will take a few minutes to complete.") - - # Clean up the zip file - zip_path.unlink() + print("Note: The gpu tests will take a few minutes to complete.") return True except Exception as e: From b4ed225d1b7e0ea0b062860efffc70794a44ffeb Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 01:03:20 -0500 Subject: [PATCH 20/23] disable gpu tests until I can self-host --- .github/workflows/gpu.yml | 101 +++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 50 deletions(-) diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index 0554675..c78bd01 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -1,50 +1,51 @@ -name: GPU Tests - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - gpu-test: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.7 - uses: actions/setup-python@v4 - with: - python-version: '3.7' - - - name: Install FFMPEG - run: | - sudo apt-get update - sudo apt-get install -y ffmpeg - - - name: Install PySide2 - run: | - python -m pip install --upgrade pip - pip install "pyside2==5.13.2" - - - name: Install PyTorch with CUDA - run: | - pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html - - - name: Install package and test dependencies - run: | - python -m pip install --upgrade "pip<24.0" - pip install -r requirements.txt - pip install pytest pytest-cov - python setup.py develop - - - name: Setup test data - run: | - python setup_tests.py - - - name: GPU Tests - run: | - pytest -v -m "gpu" tests/ - env: - CUDA_VISIBLE_DEVICES: 0 +# Temporarily disabled - requires GitHub Teams plan for GPU runners +# name: GPU Tests +# +# on: +# push: +# branches: [ master ] +# pull_request: +# branches: [ master ] +# +# jobs: +# gpu-test: +# runs-on: ubuntu-20.04 +# +# steps: +# - uses: actions/checkout@v3 +# +# - name: Set up Python 3.7 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.7' +# +# - name: Install FFMPEG +# run: | +# sudo apt-get update +# sudo apt-get install -y ffmpeg +# +# - name: Install PySide2 +# run: | +# python -m pip install --upgrade pip +# pip install "pyside2==5.13.2" +# +# - name: Install PyTorch with CUDA +# run: | +# pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html +# +# - name: Install package and test dependencies +# run: | +# python -m pip install --upgrade "pip<24.0" +# pip install -r requirements.txt +# pip install pytest pytest-cov +# python setup.py develop +# +# - name: Setup test data +# run: | +# python setup_tests.py +# +# - name: GPU Tests +# run: | +# pytest -v -m "gpu" tests/ +# env: +# CUDA_VISIBLE_DEVICES: 0 From 580ca24e53f20a52cd1895ad0b10ba927cd98d17 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 01:08:27 -0500 Subject: [PATCH 21/23] make new release action --- .github/workflows/release.yml | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..20ee0b3 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,38 @@ +name: Release + +on: + release: + types: [created] + +jobs: + release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.7' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build twine wheel + pip install -e . + + - name: Run tests + run: | + pip install pytest + pytest tests/ + + - name: Build package + run: | + python -m build + + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + twine upload dist/* From 143a4fb35472e93b78f1ceffd9345ede9f1ba956 Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 01:23:55 -0500 Subject: [PATCH 22/23] try to fix windows error --- tests/setup_data.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/setup_data.py b/tests/setup_data.py index 2343b16..9794374 100644 --- a/tests/setup_data.py +++ b/tests/setup_data.py @@ -1,5 +1,7 @@ import os import shutil +import time +import platform # from projects import get_records_from_datadir, fix_config_paths from deepethogram import projects @@ -26,7 +28,27 @@ def change_to_deepethogram_directory(): def clean_test_data(): - if os.path.isdir(project_path): + if not os.path.isdir(project_path): + return + + # On Windows, we need to handle file permission errors + if platform.system() == 'Windows': + max_retries = 3 + for i in range(max_retries): + try: + shutil.rmtree(project_path) + break + except PermissionError: + if i < max_retries - 1: + time.sleep(1) # Wait a bit for file handles to be released + continue + else: + # If we still can't delete after retries, try to ignore errors + try: + shutil.rmtree(project_path, ignore_errors=True) + except: + pass # If we still can't delete, just continue + else: shutil.rmtree(project_path) From dc8b09c768bf1ee5eab50a86ee77715bc9a94dae Mon Sep 17 00:00:00 2001 From: Jim Robinson-Bohnslav Date: Sun, 12 Jan 2025 01:26:04 -0500 Subject: [PATCH 23/23] update release action --- .github/workflows/release.yml | 92 +++++++++++++++++++++++++++++------ 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 20ee0b3..85aa768 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,38 +1,100 @@ -name: Release +name: Publish Python 🐍 distribution 📦 to PyPI on: - release: - types: [created] + push: + # Only run this workflow when a tag with the pattern 'v*' is pushed + tags: + - 'v*' jobs: - release: - runs-on: ubuntu-latest + # Step 1: Build the Python package + build: + name: Build distribution 📦 + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.7' - - name: Install dependencies + - name: Install build dependencies run: | python -m pip install --upgrade pip - pip install build twine wheel + pip install build pytest pip install -e . - name: Run tests - run: | - pip install pytest - pytest tests/ + run: pytest tests/ - name: Build package - run: | - python -m build + run: python -m build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + # Step 2: Publish the distribution to PyPI + publish-to-pypi: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + + steps: + - name: Download distribution packages + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@v1.12.3 + with: + # If using a secret-based token: + username: '__token__' + password: ${{ secrets.PYPI_API_TOKEN }} + + # Step 3: Sign the distribution and create a GitHub release + github-release: + name: Sign the distribution 📦 with Sigstore and upload to GitHub Release + needs: publish-to-pypi + runs-on: ubuntu-latest + permissions: + contents: write # Required to create GitHub Releases + id-token: write # Required for sigstore + + steps: + - name: Download distribution packages + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + # $GITHUB_REF_NAME is the tag name, e.g. 'v1.0.0' + gh release create "$GITHUB_REF_NAME" \ + --repo "$GITHUB_REPOSITORY" \ + --title "Release $GITHUB_REF_NAME" \ + --notes "See CHANGELOG for details." + + - name: Upload artifact signatures to GitHub Release env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + GITHUB_TOKEN: ${{ github.token }} run: | - twine upload dist/* + gh release upload "$GITHUB_REF_NAME" dist/** \ + --repo "$GITHUB_REPOSITORY"