diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 87c66399c..aebccd6de 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -33,5 +33,6 @@ Table.alter() # Comment regarding the change - [ ] If release, I have updated the `CITATION.cff` - [ ] This PR makes edits to table definitions: (yes/no) - [ ] If table edits, I have included an `alter` snippet for release notes. +- [ ] If this PR makes changes to position, I ran the relevant tests locally. - [ ] I have updated the `CHANGELOG.md` with PR number and description. - [ ] I have added/edited docs/notebooks to reflect the changes diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 3b39b877c..db1cf6224 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -5,6 +5,7 @@ on: - "*.*.*" # For docs bump, use X.X.XaX branches: - test_branch + workflow_dispatch: # Manually trigger with 'Run workflow' button permissions: contents: write diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml index 6432b366e..fd9245c8e 100644 --- a/.github/workflows/test-conda.yml +++ b/.github/workflows/test-conda.yml @@ -1,4 +1,4 @@ -name: Test conda env and run tests +name: Tests on: push: @@ -7,52 +7,74 @@ on: - '!documentation' schedule: # once a day at midnight UTC - cron: '0 0 * * *' + workflow_dispatch: # Manually trigger with 'Run workflow' button + +concurrency: # Replace Cancel Workflow Action + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: run-tests: - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest defaults: run: shell: bash -l {0} - strategy: - matrix: - os: [ubuntu-latest] #, macos-latest, windows-latest] env: - OS: ${{ matrix.os }} - PYTHON: '3.8' + OS: ubuntu-latest + PYTHON: '3.9' + UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }} # for download and testing + UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} + services: + mysql: + image: datajoint/mysql:8.0 + env: # args: mysql -h 127.0.0.1 -P 3308 -uroot -ptutorial -e "CMD;" + MYSQL_DATABASE: localhost + MYSQL_ROOT_PASSWORD: tutorial + ports: + - 3308:3306 + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 steps: - - name: Cancel Workflow Action - uses: styfle/cancel-workflow-action@0.11.0 - with: - access_token: ${{ github.token }} - all_but_latest: true - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ env.PYTHON }} - - name: Set up conda environment - uses: conda-incubator/setup-miniconda@v2 + - name: Set up conda + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: spyglass environment-file: environment.yml miniforge-variant: Mambaforge miniforge-version: latest - - name: Install spyglass + use-mamba: true + - name: Install apt dependencies run: | - pip install -e .[test] + sudo apt-get update # First mysql options + sudo apt-get install mysql-client libmysqlclient-dev libgirepository1.0-dev -y + sudo apt-get install ffmpeg libsm6 libxext6 -y # non-dlc position deps + - name: Run pip install for test deps + run: | + pip install --quiet .[test] - name: Download data env: - UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }} - UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} - WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb + BASEURL: ftps://ftp.box.com/trodes_to_nwb_test_data/
NWBFILE: minirec20230622.nwb # Relative to Base URL + VID_ONE: 20230622_sample_01_a1/20230622_sample_01_a1.1.h264 + VID_TWO: 20230622_sample_02_a1/20230622_sample_02_a1.1.h264 RAW_DIR: /home/runner/work/spyglass/spyglass/tests/_data/raw/ + VID_DIR: /home/runner/work/spyglass/spyglass/tests/_data/video/ run: | - mkdir -p $RAW_DIR - wget --recursive --no-verbose --no-host-directories --no-directories \ - --user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \ - -P $RAW_DIR $WEBSITE + mkdir -p $RAW_DIR $VID_DIR + wget_opts() { # Declare func with download options + wget \ + --recursive --no-verbose --no-host-directories --no-directories \ + --user "$UCSF_BOX_USER" --password "$UCSF_BOX_TOKEN" \ + -P "$1" "$BASEURL""$2" + } + wget_opts $RAW_DIR $NWBFILE + wget_opts $VID_DIR $VID_ONE + wget_opts $VID_DIR $VID_TWO - name: Run tests run: | - pytest -rP # env vars are set within certain tests + pytest --no-docker --no-dlc diff --git a/.github/workflows/test-package-build.yml b/.github/workflows/test-package-build.yml index 0098982cb..41aace719 100644 --- a/.github/workflows/test-package-build.yml +++ b/.github/workflows/test-package-build.yml @@ -13,6 +13,7 @@ on: branches: - master - maint/* + workflow_dispatch: # Manually trigger with 'Run workflow' button defaults: run: shell: bash @@ -20,10 +21,10 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: 3.9 - run: pip install --upgrade build twine @@ -31,14 +32,14 @@ jobs: run: python -m build - run: twine check dist/* - name: Upload sdist and wheel artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist path: dist/ - name: Build git archive run: mkdir archive && git archive -v -o archive/archive.tgz HEAD - name: Upload git archive artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: archive path: archive/ @@ -51,13 +52,13 @@ jobs: steps: - name: Download sdist and wheel artifacts if: matrix.package != 'archive' - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: dist path: dist/ - name: Download git archive artifact if: matrix.package == 'archive' - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: archive path: archive/ @@ -74,13 +75,9 @@ jobs: - name: Install sdist if: matrix.package == 'sdist' run: pip install dist/*.tar.gz - - name: Install archive - if: matrix.package == 'archive' + - name: Install archive # requires tag + if: matrix.package == 'archive' && startsWith(github.ref, 'refs/tags/') run: pip install archive/archive.tgz - # - name: Install test extras - # run: pip install spyglass[test] - # - name: Run tests - # run: pytest --doctest-modules -v --pyargs spyglass publish: name: Upload release to PyPI runs-on: ubuntu-latest @@ -92,7 +89,7 @@ jobs: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist path: dist/ diff --git a/.gitignore b/.gitignore index 052080023..6319e5f1c 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ coverage.xml *.cover .hypothesis/ .pytest_cache/ +tests/_data/* # Translations *.mo diff --git a/CHANGELOG.md b/CHANGELOG.md index 54d6f087b..565b0c301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,11 
+14,18 @@ - Add long-distance restrictions via `<<` and `>>` operators. #943, #969 - Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968 - Clean up old `TableChain.join` call in mixin delete. #982 +- Add pytests for position pipeline, various `test_mode` exceptions #966 +- Migrate `pip` dependencies from `environment.yml`s to `pyproject.toml` #966 ### Pipelines +- Common + - `PositionVideo` table now inserts into self after `make` #966 +- Decoding: Default values for classes on `ImportError` #966 - DLC - Allow dlc without pre-existing tracking data #973, #975 + - Raise `KeyError` for missing input parameters across helper funcs #966 + - `DLCPosVideo` table now inserts into self after `make` #966 ## [0.5.2] (April 22, 2024) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index d146e54a8..df319d828 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,4 +1,3 @@ - # Contributor Covenant Code of Conduct ## Our Pledge @@ -6,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. @@ -18,24 +17,24 @@ diverse, inclusive, and healthy community. Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, - and learning from the experience -* Focusing on what is best not just for us as individuals, but for the overall - community +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall + community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or advances of - any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email address, - without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting +- The use of sexualized language or imagery, and sexual attention or advances of + any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, + without their explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting ## Enforcement Responsibilities @@ -61,8 +60,8 @@ representative at an online or offline event. Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at -eric.denovellis@ucsf.edu. 
-All complaints will be reviewed and investigated promptly and fairly. +eric.denovellis@ucsf.edu. All complaints will be reviewed and investigated +promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. @@ -120,14 +119,14 @@ version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. Community Impact Guidelines were inspired by -[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. +[Mozilla's code of conduct enforcement ladder][mozilla coc]. For answers to common questions about this code of conduct, see the FAQ at -[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/faq][faq]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. +[faq]: https://www.contributor-covenant.org/faq [homepage]: https://www.contributor-covenant.org -[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html -[Mozilla CoC]: https://github.com/mozilla/diversity -[FAQ]: https://www.contributor-covenant.org/faq +[mozilla coc]: https://github.com/mozilla/diversity [translations]: https://www.contributor-covenant.org/translations +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html diff --git a/docs/src/misc/mixin.md b/docs/src/misc/mixin.md index 229747402..ab49b0c49 100644 --- a/docs/src/misc/mixin.md +++ b/docs/src/misc/mixin.md @@ -21,6 +21,11 @@ schema = dj.schema("my_schema") @schema class MyOldTable(dj.Manual): pass + + +@schema +class MyNewTable(SpyglassMixin, dj.Manual): + pass ``` **NOTE**: The mixin must be the first class inherited from in order to override @@ -60,10 +65,10 @@ key and `>>` as a shorthand for `restrict_by` a downstream key. from spyglass.example import AnyTable AnyTable() << 'upstream_attribute="value"' -AnyTable() >> 'downsteam_attribute="value"' +AnyTable() >> 'downstream_attribute="value"' # Equivalent to -AnyTable().restrict_by('downsteam_attribute="value"', direction="down") +AnyTable().restrict_by('downstream_attribute="value"', direction="down") AnyTable().restrict_by('upstream_attribute="value"', direction="up") ``` @@ -136,7 +141,7 @@ function, `delete_downstream_merge`, to handle this, which is run by default when calling `delete`. `delete_downstream_merge`, also aliased as `ddm`, identifies all Merge tables -downsteam of where it is called. If `dry_run=True`, it will return a list of +downstream of where it is called. If `dry_run=True`, it will return a list of entries that would be deleted, otherwise it will delete them. Importantly, `delete_downstream_merge` cannot properly interact with tables that @@ -156,7 +161,7 @@ from spyglass.example import MyMerge restricted_nwbfile.delete_downstream_merge(reload_cache=True, dry_run=False) ``` -Because each table keeps a cache of downsteam merge tables, it is important to +Because each table keeps a cache of downstream merge tables, it is important to reload the cache if the table has been imported after the cache was created. Speed gains can also be achieved by avoiding re-instancing the table each time. 
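A minimal sketch of the cache-reuse advice above, assuming the placeholder `AnyTable` from these mixin docs; the first `delete_downstream_merge` call builds the downstream-merge cache, and keeping a single instance lets later calls reuse it.

```python
# Minimal sketch, assuming the placeholder AnyTable from the mixin docs above.
from spyglass.example import AnyTable

tbl = AnyTable() & 'upstream_attribute="value"'
would_delete = tbl.delete_downstream_merge(dry_run=True)  # builds the cache, lists entries
tbl.delete_downstream_merge(dry_run=False)  # reuses the cache on the same instance
```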
diff --git a/environment.yml b/environment.yml index 0f5e19187..7fa1b51ea 100644 --- a/environment.yml +++ b/environment.yml @@ -10,9 +10,10 @@ name: spyglass channels: - conda-forge - defaults - # - pytorch # dlc-only - franklab - edeno + # - pytorch # dlc-only + # - anaconda # dlc-only, for cudatoolkit dependencies: - bottleneck # - cudatoolkit=11.3 # dlc-only @@ -36,15 +37,6 @@ dependencies: # - torchvision # dlc-only - track_linearization>=2.3 - pip: - - "black[jupyter]" - - datajoint>=0.13.6 - # - deeplabcut<2.3.0 # dlc-only - ghostipy # for common_filter - - ndx-franklab-novela>=0.1.0 - mountainsort4 - - panel<=1.3.5 # See panel #6325 - - pubnub<=6.4.0 - - pynwb>=2.2.0,<3 - - sortingview>=0.11 - - spikeinterface>=0.99.1,<0.100 - . diff --git a/environment_dlc.yml b/environment_dlc.yml index 45fd107c8..9870a0424 100644 --- a/environment_dlc.yml +++ b/environment_dlc.yml @@ -10,9 +10,10 @@ name: spyglass-dlc channels: - conda-forge - defaults - - pytorch # dlc-only - franklab - edeno + - pytorch # dlc-only + - anaconda # dlc-only, for cudatoolkit dependencies: - bottleneck - cudatoolkit=11.3 # dlc-only @@ -22,6 +23,7 @@ dependencies: - libgcc # dlc-only - matplotlib - non_local_detector + - numpy<1.24 - pip>=20.2.* - position_tools - pybind11 # req by mountainsort4 -> isosplit5 @@ -35,16 +37,6 @@ dependencies: - torchvision # dlc-only - track_linearization>=2.3 - pip: - - "black[jupyter]" - - datajoint>=0.13.6 - - deeplabcut<2.3.0 # dlc-only - ghostipy # for common_filter - - ndx-franklab-novela>=0.1.0 - mountainsort4 - - panel<=1.3.5 # See panel #6325 - - pubnub<=6.4.0 - - pynwb>=2.2.0,<3 - - sortingview>=0.11 - - spikeinterface>=0.98.2,<0.99 - - tensorflow<=2.12 # dlc-only - .[dlc] diff --git a/pyproject.toml b/pyproject.toml index ffb8d0df6..78d189b73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ keywords = [ "kachery", "sortingview", ] +dynamic = ["version"] dependencies = [ "black[jupyter]", "bottleneck", @@ -47,8 +48,8 @@ dependencies = [ "non_local_detector", "numpy<1.24", "opencv-python", - "panel<=1.3.4", # See panel #6325 - "position_tools", + "panel>=1.4.0", # panel #6325 resolved + "position_tools>=0.1.0", "pubnub<6.4.0", # TODO: remove this when sortingview is updated "pydotplus", "pynwb>=2.2.0,<3", @@ -58,31 +59,27 @@ dependencies = [ "spikeinterface>=0.99.1,<0.100", "track_linearization>=2.3", ] -dynamic = ["version"] - -[project.scripts] -spyglass_cli = "spyglass.cli:cli" - -[project.urls] -"Homepage" = "https://github.com/LorenFrankLab/spyglass" -"Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues" [project.optional-dependencies] -dlc = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"] +dlc = [ + "ffmpeg", + "deeplabcut[tf]", # removing dlc pin removes need to pin tf/numba +] test = [ - "click", # for CLI subpackage only - "docker", # for tests in a container + "click", # for CLI subpackage only + "docker", # for tests in a container "ghostipy", - "kachery", # database access + "kachery", # database access "kachery-client", "kachery-cloud>=0.4.0", - "pre-commit", # linting - "pytest", # unit testing - "pytest-cov", # code coverage + "pre-commit", # linting + "pytest", # unit testing + "pytest-cov", # code coverage + "pytest-xvfb", # for headless testing of Qt ] docs = [ "hatch", # Get version from env - "jupytext==1.16.0", # Convert notebooks to .py + "jupytext", # Convert notebooks to .py "mike", # Docs versioning "mkdocs", # Docs core "mkdocs-exclude", # Docs exclude files @@ -94,6 +91,13 @@ docs = [ "mkdocstrings[python]", # Docs API 
docstrings ] +[project.scripts] +spyglass_cli = "spyglass.cli:cli" + +[project.urls] +"Homepage" = "https://github.com/LorenFrankLab/spyglass" +"Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues" + [tool.hatch.version] source = "vcs" @@ -120,20 +124,28 @@ ignore-words-list = 'nevers' [tool.pytest.ini_options] minversion = "7.0" addopts = [ - "-sv", + # "-sv", # no capture, verbose output # "--sw", # stepwise: resume with next test after failure # "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests - # "--quiet-spy", # don't show logging from spyglass + # "--quiet-spy", # don't show logging from spyglass + # "--no-dlc", # don't run DLC tests "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger + "--doctest-modules", # run doctests in all modules "--cov=spyglass", "--cov-report=term-missing", "--no-cov-on-fail", ] testpaths = ["tests"] log_level = "INFO" +env = [ + "QT_QPA_PLATFORM = offscreen", # QT fails headless without this + "DISPLAY = :0", # QT fails headless without this + "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs + "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings +] [tool.coverage.run] source = ["*/src/spyglass/*"] diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 722b32f74..f9abff647 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -625,7 +625,7 @@ def set_lfp_band_electrodes( if lfp_sampling_rate // decimation != lfp_band_sampling_rate: raise ValueError( f"lfp_band_sampling rate {lfp_band_sampling_rate} is not an integer divisor of lfp " - f"samping rate {lfp_sampling_rate}" + f"sampling rate {lfp_sampling_rate}" ) # filter query = FirFilterParameters() & { diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 382e39069..ed91aa463 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -24,7 +24,7 @@ from spyglass.common.common_behav import RawPosition, VideoFile from spyglass.common.common_interval import IntervalList # noqa F401 from spyglass.common.common_nwbfile import AnalysisNwbfile -from spyglass.settings import raw_dir, video_dir +from spyglass.settings import raw_dir, test_mode, video_dir from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_helper_fn import deprecated_factory @@ -601,6 +601,7 @@ def make(self, key): cm_to_pixels=cm_per_pixel, disable_progressbar=False, ) + self.insert1(key) @staticmethod def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): @@ -644,6 +645,7 @@ def make_video( disable_progressbar=False, arrow_radius=15, circle_radius=8, + truncate_data=False, # reduce data to min length across all variables ): import cv2 # noqa: F401 @@ -657,6 +659,25 @@ def make_video( frame_rate = video.get(5) n_frames = int(head_orientation_mean.shape[0]) + if test_mode or truncate_data: + # pytest video data has mismatched shapes in some cases + # centroid (267, 2), video_time (270, 2), position_time (5193,) + min_len = min( + n_frames, + len(video_time), + len(position_time), + len(head_position_mean), + len(head_orientation_mean), + min(len(v) for v in centroids.values()), + ) + n_frames = min_len + video_time = video_time[:min_len] + position_time = position_time[:min_len] + head_position_mean = head_position_mean[:min_len] + head_orientation_mean = head_orientation_mean[:min_len] + for color, data in centroids.items(): + 
centroids[color] = data[:min_len] + out = cv2.VideoWriter( output_video_filename, fourcc, frame_rate, frame_size, True ) @@ -749,7 +770,10 @@ def make_video( video.release() out.release() - cv2.destroyAllWindows() + try: + cv2.destroyAllWindows() + except cv2.error: # if cv is already closed or does not have func + pass # ----------------------------- Migrated Tables ----------------------------- diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py b/src/spyglass/decoding/v0/dj_decoder_conversion.py index edcb0d637..1cf6d30c4 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -26,6 +26,21 @@ ObservationModel, ) except ImportError as e: + ( + Identity, + RandomWalk, + RandomWalkDirection1, + RandomWalkDirection2, + Uniform, + DiagonalDiscrete, + RandomDiscrete, + UniformDiscrete, + UserDefinedDiscrete, + Environment, + UniformInitialConditions, + UniformOneEnvironmentInitialConditions, + ObservationModel, + ) = [None] * 13 logger.warning(e) from track_linearization import make_track_graph diff --git a/src/spyglass/position/v1/dlc_reader.py b/src/spyglass/position/v1/dlc_reader.py index c2e56063f..caa3c2e5c 100644 --- a/src/spyglass/position/v1/dlc_reader.py +++ b/src/spyglass/position/v1/dlc_reader.py @@ -8,6 +8,8 @@ import pandas as pd import ruamel.yaml as yaml +from spyglass.settings import test_mode + class PoseEstimation: def __init__( @@ -32,10 +34,11 @@ def __init__( pkl_paths = list( self.dlc_dir.rglob(f"{filename_prefix}*meta.pickle") ) - assert len(pkl_paths) == 1, ( - "Unable to find one unique .pickle file in: " - + f"{dlc_dir} - Found: {len(pkl_paths)}" - ) + if not test_mode: + assert len(pkl_paths) == 1, ( + "Unable to find one unique .pickle file in: " + + f"{dlc_dir} - Found: {len(pkl_paths)}" + ) self.pkl_path = pkl_paths[0] else: self.pkl_path = Path(pkl_path) @@ -44,18 +47,20 @@ def __init__( # data file: h5 - body part outputs from the DLC post estimation step if h5_path is None: h5_paths = list(self.dlc_dir.rglob(f"{filename_prefix}*.h5")) - assert len(h5_paths) == 1, ( - "Unable to find one unique .h5 file in: " - + f"{dlc_dir} - Found: {len(h5_paths)}" - ) + if not test_mode: + assert len(h5_paths) == 1, ( + "Unable to find one unique .h5 file in: " + + f"{dlc_dir} - Found: {len(h5_paths)}" + ) self.h5_path = h5_paths[0] else: self.h5_path = Path(h5_path) assert self.h5_path.exists() - assert ( - self.pkl_path.stem == self.h5_path.stem + "_meta" - ), f"Mismatching h5 ({self.h5_path.stem}) and pickle {self.pkl_path.stem}" + if not test_mode: + assert ( + self.pkl_path.stem == self.h5_path.stem + "_meta" + ), f"Mismatching h5 ({self.h5_path.stem}) and pickle {self.pkl_path.stem}" # config file: yaml - configuration for invoking the DLC post estimation step if yml_path is None: @@ -65,10 +70,11 @@ def __init__( yml_paths = [ val for val in yml_paths if val.stem == "dj_dlc_config" ] - assert len(yml_paths) == 1, ( - "Unable to find one unique .yaml file in: " - + f"{dlc_dir} - Found: {len(yml_paths)}" - ) + if not test_mode: + assert len(yml_paths) == 1, ( + "Unable to find one unique .yaml file in: " + + f"{dlc_dir} - Found: {len(yml_paths)}" + ) self.yml_path = yml_paths[0] else: self.yml_path = Path(yml_path) diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 369207886..6d27615e4 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -11,7 +11,7 @@ from contextlib import redirect_stdout from itertools 
import groupby from operator import itemgetter -from typing import Union +from typing import Iterable, Union import datajoint as dj import matplotlib.pyplot as plt @@ -20,8 +20,8 @@ from tqdm import tqdm as tqdm from spyglass.common.common_behav import VideoFile +from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir, test_mode from spyglass.utils import logger -from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir def validate_option( @@ -62,7 +62,10 @@ def validate_option( f"Unknown {name}: {option} " f"Available options: {options}" ) - if types and not isinstance(option, tuple(types)): + if types is not None and not isinstance(types, Iterable): + types = (types,) + + if types is not None and not isinstance(option, types): raise TypeError(f"{name} is {type(option)}. Available types {types}") if val_range and not (val_range[0] <= option <= val_range[1]): @@ -108,7 +111,6 @@ def validate_smooth_params(params): if not params.get("smooth"): return smoothing_params = params.get("smoothing_params") - validate_option(smoother=smoothing_params, name="smoothing_params") validate_option( option=smoothing_params.get("smooth_method"), name="smooth_method", @@ -194,7 +196,7 @@ class OutputLogger: # TODO: migrate to spyglass.utils.logger def __init__(self, name, path, level="INFO", **kwargs): self.logger = self.setup_logger(name, path, **kwargs) self.name = self.logger.name - self.level = getattr(logging, level) + self.level = 30 if test_mode else getattr(logging, level) def setup_logger( self, name_logfile, path_logfile, print_console=False @@ -383,7 +385,17 @@ def infer_output_dir(key, makedir=True): """ # TODO: add check to make sure interval_list_name refers to a single epoch # Or make key include epoch in and of itself instead of interval_list_name - nwb_file_name = key["nwb_file_name"].split("_.")[0] + + file_name = key.get("nwb_file_name") + dlc_model_name = key.get("dlc_model_name") + epoch = key.get("epoch") + + if not all([file_name, dlc_model_name, epoch]): + raise ValueError( + "Key must contain 'nwb_file_name', 'dlc_model_name', and 'epoch'" + ) + + nwb_file_name = file_name.split("_.")[0] output_dir = pathlib.Path(dlc_output_dir) / pathlib.Path( f"{nwb_file_name}/{nwb_file_name}_{key['epoch']:02}" f"_model_" + key["dlc_model_name"].replace(" ", "-") @@ -1019,7 +1031,10 @@ def make_video( video.release() out.release() print("destroying cv2 windows") - cv2.destroyAllWindows() + try: + cv2.destroyAllWindows() + except cv2.error: # if cv is already closed or does not have func + pass print("finished making video with opencv") return diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index f1f077d6a..70a1c1252 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -170,7 +170,7 @@ def make(self, key): for point in required_points: bodypart = points[point] if bodypart not in bodyparts_avail: - raise ValueError( + raise ValueError( # TODO: migrate to input validation "Bodypart in points not in model." f"\tBodypart {bodypart}" f"\tIn Model {bodyparts_avail}" @@ -222,6 +222,7 @@ def make(self, key): "smoothing_duration" ) if not smoothing_duration: + # TODO: remove - validated with `validate_smooth_params` raise KeyError( "smoothing_duration needs to be passed within smoothing_params" ) @@ -368,6 +369,7 @@ def four_led_centroid(pos_df: pd.DataFrame, **params): """Determines the centroid of 4 LEDS on an implant LED ring. 
Assumed to be the Green LED, and 3 red LEDs called: redLED_C, redLED_L, redLED_R By default, uses (greenled + redLED_C) / 2 to calculate centroid + If Green LED is NaN, but red center LED is not, then the red center LED is called the centroid If green and red center LEDs are NaN, but red left and red right LEDs are not, @@ -397,6 +399,9 @@ def four_led_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not (params.get("max_LED_separation") and params.get("points")): + raise KeyError("max_LED_separation/points need to be passed in params") + centroid = np.zeros(shape=(len(pos_df), 2)) idx = pd.IndexSlice # TODO: this feels messy, clean-up @@ -722,6 +727,8 @@ def two_pt_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not (params.get("max_LED_separation") and params.get("points")): + raise KeyError("max_LED_separation/points need to be passed in params") idx = pd.IndexSlice centroid = np.zeros(shape=(len(pos_df), 2)) @@ -797,6 +804,8 @@ def one_pt_centroid(pos_df: pd.DataFrame, **params): numpy array with shape (n_time, 2) centroid[0] is the x coord and centroid[1] is the y coord """ + if not params.get("points"): + raise KeyError("points need to be passed in params") idx = pd.IndexSlice PT1 = params["points"].pop("point1", None) centroid = pos_df.loc[:, idx[PT1, ("x", "y")]].to_numpy() diff --git a/src/spyglass/position/v1/position_dlc_cohort.py b/src/spyglass/position/v1/position_dlc_cohort.py index b265a1ce5..6cf1f0eee 100644 --- a/src/spyglass/position/v1/position_dlc_cohort.py +++ b/src/spyglass/position/v1/position_dlc_cohort.py @@ -113,6 +113,12 @@ def make(self, key): bodyparts_params_dict ), "more entries found in DLCSmoothInterp than specified in bodyparts_params_dict" + + if len(table_entries) == 0: + raise ValueError( + f"No entries found in DLCSmoothInterp for {temp_key}" + ) + table_column_names = list(table_entries[0].dtype.fields.keys()) for table_entry in table_entries: entry_key = { **{ diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 6a670fc31..6ae7669bf 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -35,7 +35,7 @@ class DLCPoseEstimationSelection(SpyglassMixin, dj.Manual): """ @classmethod - def get_video_crop(cls, video_path): + def get_video_crop(cls, video_path, crop_input=None): """ Queries the user to determine the cropping parameters for a given video @@ -61,9 +61,13 @@ ax.set_yticks(np.arange(ylims[0], ylims[-1], -50)) ax.grid(visible=True, color="white", lw=0.5, alpha=0.5) display(fig) - crop_input = input( - "Please enter the crop parameters for your video in format xmin, xmax, ymin, ymax, or 'none'\n" - ) + + if crop_input is None: + crop_input = input( + "Please enter the crop parameters for your video in format " + + "xmin, xmax, ymin, ymax, or 'none'\n" + ) + plt.close() if crop_input.lower() == "none": return None @@ -98,6 +102,10 @@ def insert_estimation_task( video_path, video_filename, _, _ = get_video_path(key) output_dir = infer_output_dir(key) + + if not video_path: + raise FileNotFoundError(f"Video file not found for {key}") + with OutputLogger( name=f"{key['nwb_file_name']}_{key['epoch']}_{key['dlc_model_name']}_log", path=f"{output_dir.as_posix()}/log.log",
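A hedged usage sketch for the `crop_input` parameter added to `get_video_crop` above: supplying it skips the interactive `input()` prompt, which is what headless test runs rely on. The path and crop values here are hypothetical.

```python
# Hedged sketch, not part of this diff: pass crop bounds directly instead of
# answering the interactive prompt.
from spyglass.position.v1.position_dlc_pose_estimation import (
    DLCPoseEstimationSelection,
)

crop = DLCPoseEstimationSelection.get_video_crop(
    video_path="/path/to/video.mp4",  # hypothetical video file
    crop_input="0, 400, 0, 400",  # xmin, xmax, ymin, ymax; "none" returns None
)
```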
diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 11c7019f3..c18eafd62 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -13,6 +13,7 @@ validate_option, validate_smooth_params, ) +from spyglass.settings import test_mode from spyglass.utils.dj_mixin import SpyglassMixin from .position_dlc_pose_estimation import DLCPoseEstimation @@ -176,7 +177,12 @@ def make(self, key): params = (DLCSmoothInterpParams() & key).fetch1("params") # Get DLC output dataframe logger.logger.info("fetching Pose Estimation Dataframe") - dlc_df = (DLCPoseEstimation.BodyPart() & key).fetch1_dataframe() + + bp_key = key.copy() + if test_mode: # during testing, analysis_file not in BodyPart table + bp_key.pop("analysis_file_name", None) + + dlc_df = (DLCPoseEstimation.BodyPart() & bp_key).fetch1_dataframe() dt = np.median(np.diff(dlc_df.index.to_numpy())) sampling_rate = 1 / dt logger.logger.info("Identifying indices to NaN") @@ -223,7 +229,7 @@ def make(self, key): final_df = smooth_df.drop(["likelihood"], axis=1) final_df = final_df.rename_axis("time").reset_index() position_nwb_data = ( - (DLCPoseEstimation.BodyPart() & key) + (DLCPoseEstimation.BodyPart() & bp_key) .fetch_nwb()[0]["dlc_pose_estimation_position"] .get_spatial_series() ) @@ -338,6 +344,11 @@ def nan_inds( subthresh_inds_mask, inds_to_span=inds_to_span ) + if len(good_spans) == 0: + # Prevents ref before assignment error of mask on return + # TODO: instead of raise, insert empty dataframe + raise ValueError("No good spans found in the data") + for span in good_spans[::-1]: if np.sum(np.isnan(dlc_df.iloc[span[0] : span[-1]].x)) > 0: nan_mask = np.isnan(dlc_df.iloc[span[0] : span[-1]].x) diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index b140111e1..02692ce14 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -276,7 +276,7 @@ def insert_default(cls): def get_default(cls): query = cls & {"dlc_pos_video_params_name": "default"} if not len(query) > 0: - cls().insert_default(skip_duplicates=True) + cls().insert_default() default = (cls & {"dlc_pos_video_params_name": "default"}).fetch1() else: default = query.fetch1() @@ -304,6 +304,8 @@ class DLCPosVideo(SpyglassMixin, dj.Computed): --- """ + # TODO: Shouldn't this keep track of the video file it creates?
+ def make(self, key): from tqdm import tqdm as tqdm @@ -432,3 +434,4 @@ def make(self, key): crop=crop, **params["video_params"], ) + self.insert1(key) diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index ec40d43e0..393eb6af9 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -6,6 +6,7 @@ from spyglass.position.v1.dlc_utils import OutputLogger from spyglass.position.v1.position_dlc_project import DLCProject +from spyglass.settings import test_mode from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("position_v1_dlc_training") @@ -107,7 +108,7 @@ class DLCModelTrainingSelection(SpyglassMixin, dj.Manual): """ def insert1(self, key, **kwargs): - training_id = key["training_id"] + training_id = key.get("training_id") if training_id is None: training_id = ( dj.U().aggr(self & key, n="max(training_id)").fetch1("n") or 0 @@ -185,6 +186,7 @@ def make(self, key): if k in training_dataset_input_args } logger.logger.info("creating training dataset") + # err here create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) # ---- Trigger DLC model training job ---- train_network_input_args = list( @@ -198,6 +200,8 @@ def make(self, key): for k in ["shuffle", "trainingsetindex", "maxiters"]: if k in train_network_kwargs: train_network_kwargs[k] = int(train_network_kwargs[k]) + if test_mode: + train_network_kwargs["maxiters"] = 2 try: train_network(dlc_cfg_filepath, **train_network_kwargs) except ( diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 1a422b86f..86487ad23 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -11,6 +11,7 @@ from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.common.common_position import IntervalPositionInfo from spyglass.position.v1.dlc_utils import check_videofile, get_video_path +from spyglass.settings import test_mode from spyglass.utils import SpyglassMixin, logger schema = dj.schema("position_v1_trodes_position") @@ -337,11 +338,23 @@ def convert_to_pixels(data, frame_size, cm_to_pixels=1.0): return data / cm_to_pixels @staticmethod - def fill_nan(variable, video_time, variable_time): + def fill_nan(variable, video_time, variable_time, truncate_data=False): + """Fill in missing values in variable with nans at video_time. + + Parameters + ---------- + variable : ndarray, shape (n_time,) or (n_time, n_dims) + The variable to fill in. + video_time : ndarray, shape (n_video_time,) + The time points of the video. + variable_time : ndarray, shape (n_variable_time,) + The time points of the variable. 
+ """ # TODO: Reduce duplication across dlc_utils and common_position - video_ind = np.digitize(variable_time, video_time[1:]) + video_ind = np.digitize(variable_time, video_time[1:]) n_video_time = len(video_time) + try: n_variable_dims = variable.shape[1] filled_variable = np.full((n_video_time, n_variable_dims), np.nan) @@ -365,6 +378,7 @@ def make_video( disable_progressbar=False, arrow_radius=15, circle_radius=8, + truncate_data=False, # reduce data to min length across all variables ): import cv2 @@ -382,8 +396,31 @@ def make_video( output_video_filename, fourcc, frame_rate, frame_size, True ) + if test_mode or truncate_data: + # pytest video data has mismatched shapes in some cases + # centroid (267, 2), video_time (270, 2), position_time (5193,) + min_len = min( + n_frames, + len(video_time), + len(position_time), + len(position_mean), + len(orientation_mean), + min(len(v) for v in centroids.values()), + ) + n_frames = min_len + video_time = video_time[:min_len] + position_time = position_time[:min_len] + position_mean = position_mean[:min_len] + orientation_mean = orientation_mean[:min_len] + for color, data in centroids.items(): + centroids[color] = data[:min_len] + centroids = { - color: self.fill_nan(data, video_time, position_time) + color: self.fill_nan( + variable=data, + video_time=video_time, + variable_time=position_time, + ) for color, data in centroids.items() } position_mean = self.fill_nan(position_mean, video_time, position_time) diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py index 7e1834313..1ad6efaa4 100755 --- a/src/spyglass/utils/database_settings.py +++ b/src/spyglass/utils/database_settings.py @@ -1,7 +1,9 @@ #!/usr/bin/env python + import os import sys import tempfile +from functools import cached_property from pathlib import Path import datajoint as dj @@ -37,6 +39,7 @@ def __init__( target_database=None, exec_user=None, exec_pass=None, + test_mode=False, ): """Class to manage common database settings @@ -66,6 +69,9 @@ def __init__( User for executing commands. If None, use dj.config exec_pass : str, optional Password for executing commands. If None, use dj.config + test_mode : bool, optional + Default False. If True, prepend sudo to commands for use in CI/CD + Only true in github actions, not true in local testing. 
""" self.shared_modules = [f"{m}{ESC}" for m in SHARED_MODULES] self.user = user_name or dj.config["database.user"] @@ -76,6 +82,7 @@ def __init__( self.target_database = target_database or "mysql" self.exec_user = exec_user or dj.config["database.user"] self.exec_pass = exec_pass or dj.config["database.password"] + self.test_mode = test_mode @property def _create_roles_dict(self): @@ -102,7 +109,7 @@ def _create_roles_dict(self): ], ) - @property + @cached_property def _create_roles_sql(self): return sum(self._create_roles_dict.values(), []) @@ -214,10 +221,16 @@ def exec(self, file): if self.debug: return + if self.test_mode: + prefix = "sudo mysql -h 127.0.0.1 -P 3308 -uroot -ptutorial" + else: + prefix = f"mysql -h {self.host} -u {self.exec_user} -p" + cmd = ( - f"mysql -p -h {self.host} < {file.name}" + f"{prefix} < {file.name}" if self.target_database == "mysql" else f"docker exec -i {self.target_database} mysql -u " + f"{self.exec_user} --password={self.exec_pass} < {file.name}" ) + os.system(cmd) diff --git a/tests/README.md b/tests/README.md index 476dbb4c8..36b6ab71f 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,5 +1,25 @@ # PyTests +## Environment + +To allow pytest helpers to automatically dowlnoad requisite data, you'll need to +set credentials for Box. Consider adding these to a private `.env` file. + +- `UCSF_BOX_USER`: UCSF email address +- `UCSF_BOX_TOKEN`: Token generated from UCSF Box account + +To facilitate headless testing of various Qt-based tools as well as Tensorflow, +`pyproject.toml` includes some environment variables associated with the +display. These are... + +- `QT_QPA_PLATFORM`: Set to `offscreen` to prevent the need for a display. +- `TF_ENABLE_ONEDNN_OPTS`: Set to `1` to enable Tensorflow optimizations. +- `TF_CPP_MIN_LOG_LEVEL`: Set to `2` to suppress Tensorflow warnings. + + + +## Options + This directory is contains files for testing the code. Simply by running `pytest` from the root directory, all tests will be run with default parameters specified in `pyproject.toml`. Notable optional parameters include... @@ -7,7 +27,7 @@ specified in `pyproject.toml`. Notable optional parameters include... - Coverage items. The coverage report indicates what percentage of the code was included in tests. - - `--cov=spyglatss`: Which package should be described in the coverage report + - `--cov=spyglass`: Which package should be described in the coverage report - `--cov-report term-missing`: Include lines of items missing in coverage - Verbosity. @@ -18,23 +38,24 @@ specified in `pyproject.toml`. Notable optional parameters include... - Data and database. - - `--no-server`: Default False, launch Docker container from python. When - True, no server is started and tests attempt to connect to existing - container. + - `--base_dir`: Default `./tests/test_data/`. Where to store downloaded and + created files. - `--no-teardown`: Default False. When True, docker database tables are preserved on exit. Set to false to inspect output items after testing. - - `--my-datadir ./rel-path/`: Default `./tests/test_data/`. Where to store - created files. + - `--no-docker`: Default False, launch Docker container from python. When + True, no server is started and tests attempt to connect to existing + container. For github actions, `--no-docker` is set to configure the + container class as null. + - `--no-dlc`: Default False. When True, skip data downloads for and tests of + features that require DeepLabCut. - Incremental running. 
- - `-m`: Run tests with the - [given marker](https://docs.pytest.org/en/6.2.x/usage.html#specifying-tests-selecting-tests) - (e.g., `pytest -m current`). - - `--sw`: Stepwise. Continue from previously failed test when starting again. - `-s`: No capture. By including `from IPython import embed; embed()` in a test, and using this flag, you can open an IPython environment from within a test + - `-v`: Verbose. List individual tests, report pass/fail. + - `--sw`: Stepwise. Continue from previously failed test when starting again. - `--pdb`: Enter debug mode if a test fails. - `tests/test_file.py -k test_name`: To run just a set of tests, specify the file name at the end of the command. To run a single test, further specify diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index 1f4767dfb..6f2daa690 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -79,22 +79,18 @@ def test_populate_state_script(common, pop_state_script): ), "StateScript populate unexpected effect" -@pytest.mark.skip(reason="No video files in mini") -def test_videofile_no_transaction(common, mini_restr): - """Test no transaction""" - common.VideoFile()._no_transaction_make(mini_restr) - - -@pytest.mark.skip(reason="No video files in mini") -def test_videofile_update_entries(common): +def test_videofile_update_entries(common, video_keys): """Test update entries""" - common.VideoFile().update_entries() + key = common.VideoFile().fetch(as_dict=True)[0] + common.VideoFile().update_entries(key) -@pytest.mark.skip(reason="No video files in mini") -def test_videofile_getabspath(common, mini_restr): +def test_videofile_getabspath(common, video_keys): """Test get absolute path""" - common.VideoFile().getabspath(mini_restr) + key = video_keys[0] + path = common.VideoFile().get_abs_path(key) + file_part = key["nwb_file_name"].split("2")[0] + "_0" + str(key["epoch"]) + assert file_part in path, "VideoFile get_abs_path failed" @pytest.mark.skipif(not TEARDOWN, reason="No teardown: expect no change.") diff --git a/tests/common/test_position.py b/tests/common/test_position.py index b10c0654b..bb74e213c 100644 --- a/tests/common/test_position.py +++ b/tests/common/test_position.py @@ -1,29 +1,31 @@ +import numpy as np +import pandas as pd import pytest -@pytest.fixture +@pytest.fixture(scope="session") def common_position(common): yield common.common_position -@pytest.fixture +@pytest.fixture(scope="session") def interval_position_info(common_position): yield common_position.IntervalPositionInfo -@pytest.fixture +@pytest.fixture(scope="session") def default_param_key(): yield {"position_info_param_name": "default"} -@pytest.fixture +@pytest.fixture(scope="session") def interval_key(common): yield (common.IntervalList & "interval_list_name LIKE 'pos 0%'").fetch1( "KEY" ) -@pytest.fixture +@pytest.fixture(scope="session") def param_table(common_position, default_param_key, teardown): param_table = common_position.PositionInfoParameters() param_table.insert1(default_param_key, skip_duplicates=True) @@ -32,7 +34,7 @@ def param_table(common_position, default_param_key, teardown): param_table.delete(safemode=False) -@pytest.fixture +@pytest.fixture(scope="session") def upsample_position( common, common_position, @@ -63,7 +65,7 @@ def upsample_position( (param_table & upsample_param_key).delete(safemode=False) -@pytest.fixture +@pytest.fixture(scope="session") def interval_pos_key(upsample_position): yield upsample_position @@ -72,7 +74,7 @@ def test_interval_position_info_insert(common_position, 
interval_pos_key): assert common_position.IntervalPositionInfo & interval_pos_key -@pytest.fixture +@pytest.fixture(scope="session") def upsample_position_error( upsample_position, default_param_key, @@ -147,6 +149,76 @@ def test_interval_position_info_kwarg_alias(interval_position_info): ), "IntervalPositionInfo._fix_kwargs() should alias old arg names." -@pytest.mark.skip(reason="Not testing with video data yet.") -def test_position_video(common_position): - pass +@pytest.fixture(scope="session") +def position_video(common_position): + yield common_position.PositionVideo() + + +def test_position_video(position_video, upsample_position): + _ = position_video.populate() + assert len(position_video) == 1, "Failed to populate PositionVideo table." + + +def test_convert_to_pixels(position_video): + + data = np.array([[2, 4], [6, 8]]) + expect = np.array([[1, 2], [3, 4]]) + output = position_video.convert_to_pixels(data, "junk", 2) + + assert np.array_equal(output, expect), "Failed to convert to pixels." + + +@pytest.fixture(scope="session") +def rename_default_cols(common_position): + yield common_position._fix_col_names, ["xloc", "yloc", "xloc2", "yloc2"] + + +@pytest.mark.parametrize( + "col_type, cols", + [ + ("DEFAULT_COLS", ["xloc", "yloc", "xloc2", "yloc2"]), + ("ONE_IDX_COLS", ["xloc1", "yloc1", "xloc2", "yloc2"]), + ("ZERO_IDX_COLS", ["xloc0", "yloc0", "xloc1", "yloc1"]), + ], +) +def test_rename_columns(rename_default_cols, col_type, cols): + + _fix_col_names, defaults = rename_default_cols + df = pd.DataFrame([range(len(cols) + 1)], columns=["junk"] + cols) + result = _fix_col_names(df).columns.tolist() + + assert result == defaults, f"_fix_col_names failed to rename {col_type}." + + +def test_rename_three_d(rename_default_cols): + _fix_col_names, _ = rename_default_cols + three_d = ["junk", "x", "y", "z"] + df = pd.DataFrame([range(4)], columns=three_d) + result = _fix_col_names(df).columns.tolist() + + assert ( + result == three_d[1:] + ), "_fix_col_names failed to rename THREE_D_COLS." + + +def test_rename_non_default_columns(monkeypatch, rename_default_cols): + _fix_col_names, defaults = rename_default_cols + df = pd.DataFrame([range(4)], columns=["a", "b", "c", "d"]) + + # Monkeypatch the input function + monkeypatch.setattr("builtins.input", lambda _: "yes") + result = _fix_col_names(df).columns.tolist() + + assert ( + result == defaults + ), "_fix_col_names failed to rename non-default cols." 
+ + +def test_rename_non_default_columns_err(monkeypatch, rename_default_cols): + _fix_col_names, defaults = rename_default_cols + df = pd.DataFrame([range(4)], columns=["a", "b", "c", "d"]) + + monkeypatch.setattr("builtins.input", lambda _: "no") + + with pytest.raises(ValueError): + _fix_col_names(df) diff --git a/tests/conftest.py b/tests/conftest.py index cd9350ff1..fe8ce1a5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ import warnings from contextlib import nullcontext from pathlib import Path -from subprocess import Popen +from shutil import rmtree as shutil_rmtree from time import sleep as tsleep import datajoint as dj @@ -18,15 +18,22 @@ import pynwb import pytest from datajoint.logging import logger as dj_logger +from numba import NumbaWarning +from pandas.errors import PerformanceWarning from .container import DockerMySQLManager +from .data_downloader import DataDownloader warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") +warnings.filterwarnings("ignore", module="tensorflow") +warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") +warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas") +warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") # ------------------------------- TESTS CONFIG ------------------------------- # globals in pytest_configure: -# BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD +# BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOADS, NO_DLC def pytest_addoption(parser): @@ -39,10 +46,10 @@ def pytest_addoption(parser): Parameters ---------- --quiet-spy (bool): Default False. Allow print statements from Spyglass. + --base-dir (str): Default './tests/_data/'. Dir for local input file. --no-teardown (bool): Default False. Delete pipeline on close. - --no-server (bool): Default False. Run datajoint server in Docker. - --datadir (str): Default './tests/test_data/'. Dir for local input file. - WARNING: not yet implemented. + --no-docker (bool): Default False. Do not run datajoint mysql server in Docker. + --no-dlc (bool): Default False. Skip DLC tests. Also skip video downloads.
""" parser.addoption( "--quiet-spy", @@ -52,11 +59,11 @@ def pytest_addoption(parser): help="Quiet logging from Spyglass.", ) parser.addoption( - "--no-server", - action="store_true", - dest="no_server", - default=False, - help="Do not launch datajoint server in Docker.", + "--base-dir", + action="store", + default="./tests/_data/", + dest="base_dir", + help="Directory for local input file.", ) parser.addoption( "--no-teardown", @@ -66,20 +73,28 @@ def pytest_addoption(parser): help="Tear down tables after tests.", ) parser.addoption( - "--base-dir", - action="store", - default="./tests/_data/", - dest="base_dir", - help="Directory for local input file.", + "--no-docker", + action="store_true", + dest="no_docker", + default=False, + help="Do not launch datajoint server in Docker.", + ) + parser.addoption( + "--no-dlc", + action="store_true", + dest="no_dlc", + default=False, + help="Skip downloads for and tests of DLC-dependent features.", ) def pytest_configure(config): - global BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD + global BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOADS, NO_DLC TEST_FILE = "minirec20230622.nwb" TEARDOWN = not config.option.no_teardown VERBOSE = not config.option.quiet_spy + NO_DLC = config.option.no_dlc BASE_DIR = Path(config.option.base_dir).absolute() BASE_DIR.mkdir(parents=True, exist_ok=True) @@ -89,50 +104,16 @@ def pytest_configure(config): SERVER = DockerMySQLManager( restart=TEARDOWN, shutdown=TEARDOWN, - null_server=config.option.no_server, + null_server=config.option.no_docker, verbose=VERBOSE, ) - DOWNLOAD = download_data(verbose=VERBOSE) - -def data_is_downloaded(): - """Check if data is downloaded.""" - return os.path.exists(RAW_DIR / TEST_FILE) - - -def download_data(verbose=False): - """Download data from BOX using environment variable credentials. - - Note: In gh-actions, this is handled by the test-conda workflow. - """ - if data_is_downloaded(): - return None - UCSF_BOX_USER = os.environ.get("UCSF_BOX_USER") - UCSF_BOX_TOKEN = os.environ.get("UCSF_BOX_TOKEN") - if not all([UCSF_BOX_USER, UCSF_BOX_TOKEN]): - raise ValueError( - "Missing data, no credentials: UCSF_BOX_USER or UCSF_BOX_TOKEN." 
- ) - data_url = f"ftps://ftp.box.com/trodes_to_nwb_test_data/{TEST_FILE}" - - cmd = [ - "wget", - "--recursive", - "--no-host-directories", - "--no-directories", - "--user", - UCSF_BOX_USER, - "--password", - UCSF_BOX_TOKEN, - "-P", - RAW_DIR, - data_url, - ] - if not verbose: - cmd.insert(cmd.index("--recursive") + 1, "--no-verbose") - cmd_kwargs = dict(stdout=sys.stdout, stderr=sys.stderr) if verbose else {} - - return Popen(cmd, **cmd_kwargs) + DOWNLOADS = DataDownloader( + nwb_file_name=TEST_FILE, + base_dir=BASE_DIR, + verbose=VERBOSE, + download_dlc=not NO_DLC, + ) def pytest_unconfigure(config): @@ -231,10 +212,10 @@ def mini_path(raw_dir): path = raw_dir / TEST_FILE # wait for wget download to finish - if DOWNLOAD is not None: - DOWNLOAD.wait() + if (nwb_download := DOWNLOADS.file_downloads.get(TEST_FILE)) is not None: + nwb_download.wait() - # wait for gh-actions download to finish + # wait for download to finish timeout, wait, found = 60, 5, False for _ in range(timeout // wait): if path.exists(): @@ -248,6 +229,17 @@ def mini_path(raw_dir): yield path +@pytest.fixture(scope="session") +def nodlc(request): + yield NO_DLC + + +@pytest.fixture(scope="session") +def skipif_nodlc(request): + if NO_DLC: + yield pytest.mark.skip(reason="Skipping DLC-dependent tests.") + + @pytest.fixture(scope="session") def mini_copy_name(mini_path): from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename # noqa: E402 @@ -324,7 +316,7 @@ def mini_insert( yield close_nwb_files() - # Note: no need to run deletes in teardown, bc removing the container + # Note: teardown will remove the container, deleting all data @pytest.fixture(scope="session") @@ -403,6 +395,19 @@ def populate_exception(): yield PopulateException +# -------------------------- FIXTURES, COMMON TABLES -------------------------- + + +@pytest.fixture(scope="session") +def video_keys(common, base_dir): + for file, download in DOWNLOADS.file_downloads.items(): + if file.endswith(".h264") and download is not None: + download.wait() # wait for videos to finish downloading + DOWNLOADS.rename_files() + + return common.VideoFile().fetch(as_dict=True) + + # ------------------------- FIXTURES, POSITION TABLES ------------------------- @@ -439,11 +444,11 @@ def trodes_params(trodes_params_table, teardown): "params": { **params, "is_upsampled": 1, - "upsampling_sampling_rate": 500, + "upsampling_sampling_rate": 500, # TODO - lower this to speed up }, }, } - trodes_params_table.get_default() + _ = trodes_params_table.get_default() trodes_params_table.insert( [v for k, v in paramsets.items()], skip_duplicates=True ) @@ -771,3 +776,453 @@ def lfp_merge_key(populate_lfp): @pytest.fixture(scope="session") def lfp_v1_key(lfp, lfp_s_key): yield (lfp.v1.LFPV1 & lfp_s_key).fetch1("KEY") + + +# --------------------------- FIXTURES, DLC TABLES ---------------------------- +# ---------------- Note: DLCOutput is used to test RestrGraph ----------------- + + +@pytest.fixture(scope="session") +def bodyparts(sgp): + bps = ["whiteLED", "tailBase", "tailMid", "tailTip"] + sgp.v1.BodyPart.insert( + [{"bodypart": bp, "bodypart_description": "none"} for bp in bps], + skip_duplicates=True, + ) + + yield bps + + +@pytest.fixture(scope="session") +def dlc_project_tbl(sgp): + yield sgp.v1.DLCProject() + + +@pytest.fixture(scope="session") +def dlc_project_name(): + yield "pytest_proj" + + +@pytest.fixture(scope="session") +def insert_project( + verbose_context, + teardown, + dlc_project_name, + dlc_project_tbl, + common, + bodyparts, + mini_copy_name, +): + if 
NO_DLC: + pytest.skip("Skipping DLC-dependent tests.") + + from deeplabcut.utils.auxiliaryfunctions import read_config, write_config + + team_name = "sc_eb" + common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) + with verbose_context: + project_key = dlc_project_tbl.insert_new_project( + project_name=dlc_project_name, + bodyparts=bodyparts, + lab_team=team_name, + frames_per_video=100, + video_list=[ + {"nwb_file_name": mini_copy_name, "epoch": 0}, + {"nwb_file_name": mini_copy_name, "epoch": 1}, + ], + skip_duplicates=True, + ) + config_path = (dlc_project_tbl & project_key).fetch1("config_path") + cfg = read_config(config_path) + cfg.update( + { + "numframes2pick": 2, + "maxiters": 2, + "scorer": team_name, + "skeleton": [ + ["whiteLED"], + [ + ["tailMid", "tailMid"], + ["tailBase", "tailBase"], + ["tailTip", "tailTip"], + ], + ], # eb's has video_sets: {1: {'crop': [0, 1260, 0, 728]}} + } + ) + + write_config(config_path, cfg) + + yield project_key, cfg, config_path + + if teardown: + (dlc_project_tbl & project_key).delete(safemode=False) + shutil_rmtree(str(Path(config_path).parent)) + + +@pytest.fixture(scope="session") +def project_key(insert_project): + yield insert_project[0] + + +@pytest.fixture(scope="session") +def dlc_config(insert_project): + yield insert_project[1] + + +@pytest.fixture(scope="session") +def config_path(insert_project): + yield insert_project[2] + + +@pytest.fixture(scope="session") +def project_dir(config_path): + yield Path(config_path).parent + + +@pytest.fixture(scope="session") +def extract_frames( + verbose_context, dlc_project_tbl, project_key, dlc_config, project_dir +): + with verbose_context: + dlc_project_tbl.run_extract_frames( + project_key, userfeedback=False, mode="automatic" + ) + vid_name = list(dlc_config["video_sets"].keys())[0].split("/")[-1] + label_dir = project_dir / "labeled-data" / vid_name.split(".")[0] + + yield label_dir + + for file in label_dir.glob("*png"): + if file.stem in ["img000", "img001"]: + continue + file.unlink() + + +@pytest.fixture(scope="session") +def labeled_vid_dir(extract_frames): + yield extract_frames + + +@pytest.fixture(scope="session") +def fix_downloaded(labeled_vid_dir, project_dir): + """Grabs CollectedData and img files from project_dir, moves to labeled""" + for file in project_dir.parent.parent.glob("*"): + if file.is_dir(): + continue + dest = labeled_vid_dir / file.name + if dest.exists(): + dest.unlink() + dest.write_bytes(file.read_bytes()) + # TODO: revert to rename before merge + # file.rename(labeled_vid_dir / file.name) + + yield + + +@pytest.fixture(scope="session") +def add_training_files(dlc_project_tbl, project_key, fix_downloaded): + dlc_project_tbl.add_training_files(project_key, skip_duplicates=True) + yield + + +@pytest.fixture(scope="session") +def dlc_training_params(sgp): + params_tbl = sgp.v1.DLCModelTrainingParams() + params_name = "pytest" + yield params_tbl, params_name + + +@pytest.fixture(scope="session") +def training_params_key(verbose_context, sgp, project_key, dlc_training_params): + params_tbl, params_name = dlc_training_params + with verbose_context: + params_tbl.insert_new_params( + paramset_name=params_name, + params={ + "trainingsetindex": 0, + "shuffle": 1, + "gputouse": None, + "TFGPUinference": False, + "net_type": "resnet_50", + "augmenter_type": "imgaug", + "video_sets": "test skipping param", + }, + skip_duplicates=True, + ) + yield {"dlc_training_params_name": params_name} + + +@pytest.fixture(scope="session") +def model_train_key(sgp, 
+
+
+@pytest.fixture(scope="session")
+def model_train_key(sgp, project_key, training_params_key):
+    _ = project_key.pop("config_path", None)
+    model_train_key = {
+        **project_key,
+        **training_params_key,
+    }
+    sgp.v1.DLCModelTrainingSelection().insert1(
+        {
+            **model_train_key,
+            "model_prefix": "",
+        },
+        skip_duplicates=True,
+    )
+    yield model_train_key
+
+
+@pytest.fixture(scope="session")
+def populate_training(sgp, fix_downloaded, model_train_key, add_training_files):
+    train_tbl = sgp.v1.DLCModelTraining
+    if len(train_tbl & model_train_key) == 0:
+        _ = add_training_files
+        _ = fix_downloaded
+        sgp.v1.DLCModelTraining.populate(model_train_key)
+    yield model_train_key
+
+
+@pytest.fixture(scope="session")
+def model_source_key(sgp, model_train_key, populate_training):
+    yield (sgp.v1.DLCModelSource & model_train_key).fetch1("KEY")
+
+
+@pytest.fixture(scope="session")
+def model_key(sgp, model_source_key):
+    model_key = {**model_source_key, "dlc_model_params_name": "default"}
+    _ = sgp.v1.DLCModelParams.get_default()
+    sgp.v1.DLCModelSelection().insert1(model_key, skip_duplicates=True)
+    yield model_key
+
+
+@pytest.fixture(scope="session")
+def populate_model(sgp, model_key):
+    model_tbl = sgp.v1.DLCModel
+    if model_tbl & model_key:
+        yield
+    else:
+        sgp.v1.DLCModel.populate(model_key)
+        yield
+
+
+@pytest.fixture(scope="session")
+def pose_estimation_key(sgp, mini_copy_name, populate_model, model_key):
+    yield sgp.v1.DLCPoseEstimationSelection.insert_estimation_task(
+        {
+            "nwb_file_name": mini_copy_name,
+            "epoch": 1,
+            "video_file_num": 0,
+            **model_key,
+        },
+        task_mode="trigger",  # trigger or load
+        params={"gputouse": None, "videotype": "mp4", "TFGPUinference": False},
+    )
+
+
+@pytest.fixture(scope="session")
+def populate_pose_estimation(sgp, pose_estimation_key):
+    pose_est_tbl = sgp.v1.DLCPoseEstimation()
+    if len(pose_est_tbl & pose_estimation_key) < 1:
+        pose_est_tbl.populate(pose_estimation_key)
+    yield pose_est_tbl
+
+
+@pytest.fixture(scope="session")
+def si_params_name(sgp, populate_pose_estimation):
+    params_name = "low_bar"
+    params_tbl = sgp.v1.DLCSmoothInterpParams
+    # if len(params_tbl & {"dlc_si_params_name": params_name}) < 1:
+    if True:  # TODO: remove before merge
+        nan_params = params_tbl.get_nan_params()
+        nan_params["dlc_si_params_name"] = params_name
+        nan_params["params"].update(
+            {
+                "likelihood_thresh": 0.4,
+                "max_cm_between_pts": 100,
+                "num_inds_to_span": 50,
+                # Smoothing and Interpolation added later - must check
+                "smoothing_params": {"smoothing_duration": 0.05},
+                "interp_params": {"max_cm_to_interp": 100},
+            }
+        )
+        params_tbl.insert1(nan_params, skip_duplicates=True)
+
+    yield params_name
+
+
+@pytest.fixture(scope="session")
+def si_key(sgp, bodyparts, si_params_name, pose_estimation_key):
+    key = {
+        key: val
+        for key, val in pose_estimation_key.items()
+        if key in sgp.v1.DLCSmoothInterpSelection.primary_key
+    }
+    sgp.v1.DLCSmoothInterpSelection.insert(
+        [
+            {
+                **key,
+                "bodypart": bodypart,
+                "dlc_si_params_name": si_params_name,
+            }
+            for bodypart in bodyparts[:1]
+        ],
+        skip_duplicates=True,
+    )
+    yield key
+
+
+@pytest.fixture(scope="session")
+def populate_si(sgp, si_key, populate_pose_estimation):
+    sgp.v1.DLCSmoothInterp.populate()
+    yield
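The cohort fixtures below group per-bodypart smoothing results under one name via bodyparts_params_dict. A hedged sketch of what a multi-bodypart cohort key would look like (values illustrative; these fixtures only populate whiteLED):

cohort_key = {
    "nwb_file_name": "minirec20230622_.nwb",  # illustrative copy-file name
    "epoch": 1,
    "dlc_si_cohort_selection_name": "leds_and_tail",
    "bodyparts_params_dict": {  # bodypart -> dlc_si_params_name
        "whiteLED": "low_bar",
        "tailBase": "low_bar",
    },
}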
+
+
+@pytest.fixture(scope="session")
+def cohort_selection(sgp, si_key, si_params_name):
+    cohort_key = {
+        k: v
+        for k, v in {
+            **si_key,
+            "dlc_si_cohort_selection_name": "whiteLED",
+            "bodyparts_params_dict": {
+                "whiteLED": si_params_name,
+            },
+        }.items()
+        if k not in ["bodypart", "dlc_si_params_name"]
+    }
+    sgp.v1.DLCSmoothInterpCohortSelection().insert1(
+        cohort_key, skip_duplicates=True
+    )
+    yield cohort_key
+
+
+@pytest.fixture(scope="session")
+def cohort_key(sgp, cohort_selection):
+    yield cohort_selection.copy()
+
+
+@pytest.fixture(scope="session")
+def populate_cohort(sgp, cohort_selection, populate_si):
+    sgp.v1.DLCSmoothInterpCohort.populate(cohort_selection)
+
+
+@pytest.fixture(scope="session")
+def centroid_params(sgp):
+    params_tbl = sgp.v1.DLCCentroidParams
+    params_key = {"dlc_centroid_params_name": "one_test"}
+    if len(params_tbl & params_key) == 0:
+        params_tbl.insert1(
+            {
+                **params_key,
+                "params": {
+                    "centroid_method": "one_pt_centroid",
+                    "points": {"point1": "whiteLED"},
+                    "interpolate": True,
+                    "interp_params": {"max_cm_to_interp": 100},
+                    "smooth": True,
+                    "smoothing_params": {
+                        "smoothing_duration": 0.05,
+                        "smooth_method": "moving_avg",
+                    },
+                    "max_LED_separation": 50,
+                    "speed_smoothing_std_dev": 0.100,
+                },
+            }
+        )
+    yield params_key
+
+
+@pytest.fixture(scope="session")
+def centroid_selection(sgp, cohort_key, populate_cohort, centroid_params):
+    centroid_key = {
+        key: val
+        for key, val in cohort_key.items()
+        if key in sgp.v1.DLCCentroidSelection.primary_key
+    }
+    centroid_key.update(centroid_params)
+    sgp.v1.DLCCentroidSelection.insert1(centroid_key, skip_duplicates=True)
+    yield centroid_key
+
+
+@pytest.fixture(scope="session")
+def centroid_key(sgp, centroid_selection):
+    yield centroid_selection.copy()
+
+
+@pytest.fixture(scope="session")
+def populate_centroid(sgp, centroid_selection):
+    sgp.v1.DLCCentroid.populate(centroid_selection)
+
+
+@pytest.fixture(scope="session")
+def orient_params(sgp):
+    params_tbl = sgp.v1.DLCOrientationParams
+    params_key = {"dlc_orientation_params_name": "none"}
+    if len(params_tbl & params_key) == 0:
+        params_tbl.insert1(
+            {
+                **params_key,
+                "params": {
+                    "orient_method": "none",
+                    "bodypart1": "whiteLED",
+                    "orientation_smoothing_std_dev": 0.001,
+                },
+            }
+        )
+    return params_key
+
+
+@pytest.fixture(scope="session")
+def orient_selection(sgp, cohort_key, orient_params):
+    orient_key = {
+        key: val
+        for key, val in cohort_key.items()
+        if key in sgp.v1.DLCOrientationSelection.primary_key
+    }
+    orient_key.update(orient_params)
+    sgp.v1.DLCOrientationSelection().insert1(orient_key, skip_duplicates=True)
+    yield orient_key
+
+
+@pytest.fixture(scope="session")
+def orient_key(sgp, orient_selection):
+    yield orient_selection.copy()
+
+
+@pytest.fixture(scope="session")
+def populate_orient(sgp, orient_selection):
+    sgp.v1.DLCOrientation().populate(orient_selection)
+    yield
+
+
+@pytest.fixture(scope="session")
+def dlc_selection(sgp, centroid_key, orient_key, populate_orient):
+    dlc_key = {
+        key: val
+        for key, val in centroid_key.items()
+        if key in sgp.v1.DLCPosV1.primary_key
+    }
+    dlc_key.update(
+        {
+            "dlc_si_cohort_centroid": centroid_key[
+                "dlc_si_cohort_selection_name"
+            ],
+            "dlc_si_cohort_orientation": orient_key[
+                "dlc_si_cohort_selection_name"
+            ],
+            "dlc_orientation_params_name": orient_key[
+                "dlc_orientation_params_name"
+            ],
+        }
+    )
+    sgp.v1.DLCPosSelection().insert1(dlc_key, skip_duplicates=True)
+    yield dlc_key
+
+
+@pytest.fixture(scope="session")
+def dlc_key(sgp, dlc_selection):
+    yield dlc_selection.copy()
+
+
+@pytest.fixture(scope="session")
+def populate_dlc(sgp, dlc_key):
+    sgp.v1.DLCPosV1().populate(dlc_key)
+    yield
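Read top to bottom, these fixtures encode the DLC pipeline's dependency chain. Run by hand rather than through pytest, the equivalent populate order would be roughly the following (a sketch using table names from this diff; restriction keys and selection inserts omitted):

def populate_dlc_by_hand(sgp):
    # Each table depends on the one populated before it
    sgp.v1.DLCModelTraining.populate()
    sgp.v1.DLCModel.populate()
    sgp.v1.DLCPoseEstimation.populate()
    sgp.v1.DLCSmoothInterp.populate()
    sgp.v1.DLCSmoothInterpCohort.populate()
    sgp.v1.DLCCentroid.populate()
    sgp.v1.DLCOrientation.populate()
    sgp.v1.DLCPosV1.populate()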
diff --git a/tests/container.py b/tests/container.py
index fa26f1c46..b9d77263e 100644
--- a/tests/container.py
+++ b/tests/container.py
@@ -46,7 +46,7 @@ def __init__(
         self.mysql_version = mysql_version
         self.container_name = container_name
         self.port = port or "330" + self.mysql_version[0]
-        self.client = docker.from_env()
+        self.client = None if null_server else docker.from_env()
         self.null_server = null_server
         self.password = "tutorial"
         self.user = "root"
@@ -64,10 +64,14 @@ def __init__(

     @property
     def container(self) -> docker.models.containers.Container:
+        if self.null_server:
+            return self.container_name
         return self.client.containers.get(self.container_name)

     @property
     def container_status(self) -> str:
+        if self.null_server:
+            return None
         try:
             self.container.reload()
             return self.container.status
@@ -76,6 +80,8 @@

     @property
     def container_health(self) -> str:
+        if self.null_server:
+            return None
         try:
             self.container.reload()
             return self.container.health
@@ -125,7 +131,6 @@ def wait(self, timeout=120, wait=3) -> None:
         wait : int
             Time to wait between checks in seconds. Default 5.
         """
-
         if self.null_server:
             return None
         if not self.container_status or self.container_status == "exited":
@@ -209,9 +214,10 @@ def stop(self, remove=True) -> None:
         if not self.container_status or self.container_status == "exited":
             return

+        container_name = self.container_name
         self.container.stop()
-        self.logger.info(f"Container {self.container_name} stopped.")
+        self.logger.info(f"Container {container_name} stopped.")

         if remove:
             self.container.remove()
-            self.logger.info(f"Container {self.container_name} removed.")
+            self.logger.info(f"Container {container_name} removed.")
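The null_server branches above give callers a no-docker mode: no client is constructed, and status/health read as None instead of raising. A usage sketch (the class and constructor names are assumed from tests/container.py, which is only partially shown in this diff):

server = DockerMySQLManager(null_server=True)  # hypothetical class name
assert server.container_status is None  # no docker daemon consulted
assert server.container_health is None
server.wait()  # returns immediately instead of polling a container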
diff --git a/tests/data_downloader.py b/tests/data_downloader.py
new file mode 100644
index 000000000..98a254eda
--- /dev/null
+++ b/tests/data_downloader.py
@@ -0,0 +1,139 @@
+from functools import cached_property
+from os import environ as os_environ
+from pathlib import Path
+from subprocess import DEVNULL, Popen
+from sys import stderr, stdout
+from typing import Dict, Union
+
+UCSF_BOX_USER = os_environ.get("UCSF_BOX_USER")
+UCSF_BOX_TOKEN = os_environ.get("UCSF_BOX_TOKEN")
+BASE_URL = "ftps://ftp.box.com/trodes_to_nwb_test_data/"
+
+NON_DLC = 3  # First N items below are not for DeepLabCut
+FILE_PATHS = [
+    {
+        "relative_dir": "raw",
+        "target_name": "minirec20230622.nwb",
+        "url": BASE_URL + "minirec20230622.nwb",
+    },
+    {
+        "relative_dir": "video",
+        "target_name": "20230622_minirec_01_s1.1.h264",
+        "url": BASE_URL + "20230622_sample_01_a1/20230622_sample_01_a1.1.h264",
+    },
+    {
+        "relative_dir": "video",
+        "target_name": "20230622_minirec_02_s2.1.h264",
+        "url": BASE_URL + "20230622_sample_02_a1/20230622_sample_02_a1.1.h264",
+    },
+    {
+        "relative_dir": "deeplabcut",
+        "target_name": "CollectedData_sc_eb.csv",
+        "url": BASE_URL + "minirec_dlc_items/CollectedData_sc_eb.csv",
+    },
+    {
+        "relative_dir": "deeplabcut",
+        "target_name": "CollectedData_sc_eb.h5",
+        "url": BASE_URL + "minirec_dlc_items/CollectedData_sc_eb.h5",
+    },
+    {
+        "relative_dir": "deeplabcut",
+        "target_name": "img000.png",
+        "url": BASE_URL + "minirec_dlc_items/img000.png",
+    },
+    {
+        "relative_dir": "deeplabcut",
+        "target_name": "img001.png",
+        "url": BASE_URL + "minirec_dlc_items/img001.png",
+    },
+]
+
+
+class DataDownloader:
+    def __init__(
+        self,
+        nwb_file_name,
+        file_paths=FILE_PATHS,
+        base_dir=".",
+        download_dlc=True,
+        verbose=True,
+    ):
+        if not all([UCSF_BOX_USER, UCSF_BOX_TOKEN]):
+            raise ValueError(
+                "Missing os.environ credentials: UCSF_BOX_USER, UCSF_BOX_TOKEN."
+            )
+        if nwb_file_name != file_paths[0]["target_name"]:
+            raise ValueError(
+                f"Please adjust data_downloader.py to match: {nwb_file_name}"
+            )
+
+        self.cmd = [
+            "wget",
+            "--recursive",
+            "--no-host-directories",
+            "--no-directories",
+            "--user",
+            UCSF_BOX_USER,
+            "--password",
+            UCSF_BOX_TOKEN,
+            "-P",  # Then need relative path, then url
+        ]
+
+        self.verbose = verbose
+        if not verbose:
+            self.cmd.insert(self.cmd.index("--recursive") + 1, "--no-verbose")
+            self.cmd_kwargs = dict(stdout=DEVNULL, stderr=DEVNULL)
+        else:
+            self.cmd_kwargs = dict(stdout=stdout, stderr=stderr)
+
+        self.base_dir = Path(base_dir).resolve()
+        self.file_paths = file_paths if download_dlc else file_paths[:NON_DLC]
+        self.base_dir.mkdir(exist_ok=True)
+
+        # Start downloads
+        _ = self.file_downloads
+
+    def rename_files(self):
+        """Redundant with file_downloads, but allows rerun in conftest startup."""
+        for path in self.file_paths:
+            target, url = path["target_name"], path["url"]
+            target_dir = self.base_dir / path["relative_dir"]
+            orig = target_dir / url.split("/")[-1]
+            dest = target_dir / target
+
+            if orig.exists():
+                orig.rename(dest)
+
+    @cached_property  # Only make list of processes once
+    def file_downloads(self) -> Dict[str, Union[Popen, None]]:
+        """{File: Popen/None} for each file. If exists/finished, None."""
+        ret = dict()
+        self.rename_files()
+        for path in self.file_paths:
+            target, url = path["target_name"], path["url"]
+            target_dir = self.base_dir / path["relative_dir"]
+            dest = target_dir / target
+
+            if dest.exists():
+                ret[target] = None
+                continue
+
+            target_dir.mkdir(exist_ok=True, parents=True)
+            ret[target] = Popen(self.cmd + [target_dir, url], **self.cmd_kwargs)
+        return ret
+
+    def check_download(self, download, info):
+        if download is not None:
+            download.wait()
+            if download.returncode:
+                return download
+        return None
+
+    @property
+    def download_errors(self):
+        ret = []
+        for download, item in zip(self.file_downloads.values(), self.file_paths):
+            if d_status := self.check_download(download, item):
+                ret.append(d_status)
+        return ret
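Downloads start as background wget processes at construction time; callers block only on the files they actually need. A minimal usage sketch mirroring what conftest.py does (requires the UCSF_BOX_USER/UCSF_BOX_TOKEN environment variables):

dl = DataDownloader(
    nwb_file_name="minirec20230622.nwb", base_dir="tests/_data"
)
proc = dl.file_downloads.get("minirec20230622.nwb")
if proc is not None:  # None means the file was already on disk
    proc.wait()
dl.rename_files()  # give videos their target names once complete
assert not dl.download_errors, "One or more wget processes failed"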
diff --git a/tests/position/__init__.py b/tests/position/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/position/conftest.py b/tests/position/conftest.py
new file mode 100644
index 000000000..c6c58d199
--- /dev/null
+++ b/tests/position/conftest.py
@@ -0,0 +1,92 @@
+"""
+The following lines are not used in the course of regular pose processing and
+can be removed so long as other functionality is not impacted.
+
+position_merge.py: 106-107, 110-123, 139-262
+dlc_decorators.py: 11, 16-18, 22
+dlc_reader.py:
+    24, 38, 44-45, 51, 57-58, 61, 70, 74, 80-81, 135-137, 146, 149-162, 214,
+    218
+dlc_utils.py:
+    58, 61, 69, 72, 97-100, 104, 149-161, 232-235, 239-241, 246, 259, 280,
+    293-305, 310-316, 328-341, 356-373, 395, 404, 480, 487-488, 530, 548-561,
+    594-601, 611-612, 641-657, 682-736, 762-772, 787, 809-1286
+"""
+
+from itertools import product as iter_product
+
+import numpy as np
+import pandas as pd
+import pytest
+
+
+@pytest.fixture(scope="session")
+def dlc_video_params(sgp):
+    sgp.v1.DLCPosVideoParams.insert_default()
+    params_key = {"dlc_pos_video_params_name": "five_percent"}
+    sgp.v1.DLCPosVideoParams.insert1(
+        {
+            **params_key,
+            "params": {
+                "percent_frames": 0.05,
+                "incl_likelihood": True,
+            },
+        },
+        skip_duplicates=True,
+    )
+    yield params_key
+
+
+@pytest.fixture(scope="session")
+def dlc_video_selection(sgp, dlc_key, dlc_video_params, populate_dlc):
+    s_key = {**dlc_key, **dlc_video_params}
+    sgp.v1.DLCPosVideoSelection.insert1(s_key, skip_duplicates=True)
+    yield dlc_key
+
+
+@pytest.fixture(scope="session")
+def populate_dlc_video(sgp, dlc_video_selection):
+    sgp.v1.DLCPosVideo.populate(dlc_video_selection)
+    yield sgp.v1.DLCPosVideo()
+
+
+@pytest.fixture(scope="session")
+def populate_evaluation(sgp, populate_model):
+    sgp.v1.DLCEvaluation.populate()
+    yield
+
+
+def generate_led_df(leds, inc_vals=False):
+    """Returns df with all combinations of 1 and np.nan for each led.
+
+    If inc_vals is True, the values will be incremented by 1 for each non-nan.
+    """
+    all_vals = list(zip(*iter_product([1, np.nan], repeat=len(leds))))
+    n_rows = len(all_vals[0])
+    indices = np.random.uniform(1.6223e09, 1.6224e09, n_rows)
+
+    data = dict()
+    for led, values in zip(leds, all_vals):
+        data.update(
+            {
+                (led, "video_frame_id"): {
+                    i: f for i, f in zip(indices, range(n_rows + 1))
+                },
+                (led, "x"): {i: v for i, v in zip(indices, values)},
+                (led, "y"): {i: v for i, v in zip(indices, values)},
+            }
+        )
+    df = pd.DataFrame(data)
+
+    if not inc_vals:
+        return df
+
+    count = [0]
+
+    def increment_count():
+        count[0] += 1
+        return count[0]
+
+    def process_value(x):
+        return increment_count() if x == 1 else x
+
+    return df.applymap(process_value)
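For reference, generate_led_df produces a two-level (led, field) column index with 2**len(leds) rows covering every 1/NaN combination. A quick check of its shape (assuming the function is importable as in the tests below):

from tests.position.conftest import generate_led_df

df = generate_led_df(["led1", "led2"])
assert len(df) == 4  # rows: (1,1), (1,NaN), (NaN,1), (NaN,NaN)
assert list(df.columns) == [
    ("led1", "video_frame_id"), ("led1", "x"), ("led1", "y"),
    ("led2", "video_frame_id"), ("led2", "x"), ("led2", "y"),
]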
diff --git a/tests/position/test_dlc_cent.py b/tests/position/test_dlc_cent.py
new file mode 100644
index 000000000..a3675b2ae
--- /dev/null
+++ b/tests/position/test_dlc_cent.py
@@ -0,0 +1,63 @@
+import numpy as np
+import pytest
+
+from .conftest import generate_led_df
+
+
+@pytest.fixture(scope="session")
+def centroid_df(sgp, centroid_key, populate_centroid):
+    yield (sgp.v1.DLCCentroid & centroid_key).fetch1_dataframe()
+
+
+def test_centroid_fetch1_dataframe(centroid_df):
+    df_cols = centroid_df.columns
+    exp_cols = [
+        "video_frame_ind",
+        "position_x",
+        "position_y",
+        "velocity_x",
+        "velocity_y",
+        "speed",
+    ]
+
+    assert all(
+        e in df_cols for e in exp_cols
+    ), f"Unexpected cols in DLCCentroid dataframe: {df_cols}"
+
+
+@pytest.fixture(scope="session")
+def params_tbl(sgp):
+    yield sgp.v1.DLCCentroidParams()
+
+
+def test_insert_default_params(params_tbl):
+    ret = params_tbl.get_default()
+    assert "default" in params_tbl.fetch(
+        "dlc_centroid_params_name"
+    ), "Default params not inserted"
+    assert (
+        ret["dlc_centroid_params_name"] == "default"
+    ), "Default params not inserted"
+
+
+def test_validate_params(params_tbl):
+    params = params_tbl.get_default()
+    params["dlc_centroid_params_name"] = "other test"
+    params_tbl.insert1(params, skip_duplicates=True)
+
+
+@pytest.mark.parametrize(
+    "key", ["four_led_centroid", "two_pt_centroid", "one_pt_centroid"]
+)
+def test_centroid_calcs(key, sgp):
+    points = sgp.v1.position_dlc_centroid._key_to_points[key]
+    func = sgp.v1.position_dlc_centroid._key_to_func_dict[key]
+
+    df = generate_led_df(points)
+    ret = func(df, max_LED_separation=100, points={p: p for p in points})
+
+    assert np.all(ret[:-1] == 1), f"Centroid calculation failed for {key}"
+    assert np.all(np.isnan(ret[-1])), f"Centroid calculation failed for {key}"
+
+    with pytest.raises(KeyError):
+        func(df)  # Missing led separation/point names
diff --git a/tests/position/test_dlc_model.py b/tests/position/test_dlc_model.py
new file mode 100644
index 000000000..6f1ccf89d
--- /dev/null
+++ b/tests/position/test_dlc_model.py
@@ -0,0 +1,18 @@
+import pytest
+
+
+def test_model_params_default(sgp):
+    assert sgp.v1.DLCModelParams.get_default() == {
+        "dlc_model_params_name": "default",
+        "params": {
+            "params": {},
+            "shuffle": 1,
+            "trainingsetindex": 0,
+            "model_prefix": "",
+        },
+    }
+
+
+def test_model_input_assert(sgp):
+    with pytest.raises(AssertionError):
+        sgp.v1.DLCModelInput().insert1({"config_path": "/fake/path/"})
diff --git a/tests/position/test_dlc_orient.py b/tests/position/test_dlc_orient.py
new file mode 100644
index 000000000..826df4cf9
--- /dev/null
+++ b/tests/position/test_dlc_orient.py
@@ -0,0 +1,45 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from .conftest import generate_led_df
+
+
+def test_insert_params(sgp):
+    params_name = "test_params"
+    params_key = {"dlc_orientation_params_name": params_name}
+    params_tbl = sgp.v1.DLCOrientationParams()
+    params_tbl.insert_params(
+        params_name=params_name, params={}, skip_duplicates=True
+    )
+    assert params_tbl & params_key, "Failed to insert params"
+
+    defaults = params_tbl.get_default()
+    assert (
+        defaults.get("params", {}).get("bodypart1") == "greenLED"
+    ), "Failed to insert default params"
+
+
+def test_orient_fetch1_dataframe(sgp, orient_key, populate_orient):
+    """Fetches dataframe, but example data has one led, no orientation"""
+    fetched_df = (sgp.v1.DLCOrientation & orient_key).fetch1_dataframe()
+    assert isinstance(fetched_df, pd.DataFrame)
+
+
+@pytest.mark.parametrize(
+    "key, points, exp_sum",
+    [
+        ("none", ["none"], 0.0),
+        ("red_green_orientation", ["bodypart1", "bodypart2"], -2.356),
+        ("red_led_bisector", ["led1", "led2", "led3"], -1.571),
+    ],
+)
+def test_orient_calcs(sgp, key, points, exp_sum):
+    func = sgp.v1.position_dlc_orient._key_to_func_dict[key]
+
+    df = generate_led_df(points, inc_vals=True)
+    df_sum = np.nansum(func(df, **{p: p for p in points}))
+
+    assert np.isclose(
+        df_sum, exp_sum, atol=0.001
+    ), f"Failed to calculate orient via {key}"
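The expected sums in the parametrization above are recognizable angles rather than arbitrary constants: -2.356 is -3π/4 and -1.571 is -π/2 to three decimals, consistent with arctan2-based orientation:

import numpy as np

assert np.isclose(-3 * np.pi / 4, -2.356, atol=0.001)
assert np.isclose(-np.pi / 2, -1.571, atol=0.001)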
f"{pos_est_sel.table_name}.get_video_crop did not return expected output" + + +def test_invalid_video(pos_est_sel, pose_estimation_key): + _ = pose_estimation_key # Ensure populated + example_key = pos_est_sel.fetch("KEY", as_dict=True)[0] + example_key["nwb_file_name"] = "invalid.nwb" + with pytest.raises(FileNotFoundError): + pos_est_sel.insert_estimation_task(example_key) + + +def test_pose_est_dataframe(populate_pose_estimation): + pose_cols = populate_pose_estimation.fetch_dataframe().columns + + for bp in ["tailBase", "tailMid", "tailTip"]: + for val in ["video_frame_ind", "x", "y"]: + col = (bp, val) + assert col in pose_cols, f"PoseEstimation df missing column {col}." diff --git a/tests/position/test_dlc_position.py b/tests/position/test_dlc_position.py new file mode 100644 index 000000000..94646f315 --- /dev/null +++ b/tests/position/test_dlc_position.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="session") +def si_params_tbl(sgp): + yield sgp.v1.DLCSmoothInterpParams() + + +def test_si_params_default(si_params_tbl): + assert si_params_tbl.get_default() == { + "dlc_si_params_name": "default", + "params": { + "interp_params": {"max_cm_to_interp": 15}, + "interpolate": True, + "likelihood_thresh": 0.95, + "max_cm_between_pts": 20, + "num_inds_to_span": 20, + "smooth": True, + "smoothing_params": { + "smooth_method": "moving_avg", + "smoothing_duration": 0.05, + }, + }, + } + assert si_params_tbl.get_nan_params() == { + "dlc_si_params_name": "just_nan", + "params": { + "interpolate": False, + "likelihood_thresh": 0.95, + "max_cm_between_pts": 20, + "num_inds_to_span": 20, + "smooth": False, + }, + } + assert list(si_params_tbl.get_available_methods()) == [ + "moving_avg" + ], f"{si_params_tbl.table_name}: unexpected available methods" + + +def test_invalid_params_insert(si_params_tbl): + with pytest.raises(KeyError): + si_params_tbl.insert1({"params": "invalid"}) + + +@pytest.fixture(scope="session") +def si_df(sgp, si_key, populate_si, bodyparts): + yield ( + sgp.v1.DLCSmoothInterp() & {**si_key, "bodypart": bodyparts[0]} + ).fetch1_dataframe() + + +def test_cohort_fetch1_dataframe(si_df): + df_cols = si_df.columns + exp_cols = ["video_frame_ind", "x", "y"] + assert all( + e in df_cols for e in exp_cols + ), f"Unexpected cols in DLCSmoothInterp dataframe: {df_cols}" + + +def test_all_nans(populate_pose_estimation, sgp): + pose_est_tbl = populate_pose_estimation + df = pose_est_tbl.BodyPart().fetch1_dataframe() + with pytest.raises(ValueError): + sgp.v1.position_dlc_position.nan_inds(df, 10, 0.99, 10) diff --git a/tests/position/test_dlc_proj.py b/tests/position/test_dlc_proj.py new file mode 100644 index 000000000..7eaba196d --- /dev/null +++ b/tests/position/test_dlc_proj.py @@ -0,0 +1,68 @@ +import pytest + + +def test_bp_insert(sgp): + bp_tbl = sgp.v1.position_dlc_project.BodyPart() + + bp_w_desc, desc = "test_bp", "test_desc" + bp_no_desc = "test_bp_no_desc" + + bp_tbl.add_from_config([bp_w_desc], [desc]) + bp_tbl.add_from_config([bp_no_desc]) + + assert bp_tbl & { + "bodypart": bp_w_desc, + "description": desc, + }, "Bodypart with description not inserted correctly" + assert bp_tbl & { + "bodypart": bp_no_desc, + "description": bp_no_desc, + }, "Bodypart without description not inserted correctly" + + +def test_project_insert(dlc_project_tbl, project_key): + assert dlc_project_tbl & project_key, "Project not inserted correctly" + + +@pytest.fixture +def new_project_key(): + return { + "project_name": "test_project_name", + "bodyparts": ["bp1"], + "lab_team": 
"any", + "frames_per_video": 1, + "video_list": ["any"], + "groupname": "fake group", + } + + +def test_failed_name_insert( + dlc_project_tbl, dlc_project_name, config_path, new_project_key +): + new_project_key.update({"project_name": dlc_project_name}) + existing_key = dlc_project_tbl.insert_new_project( + project_name=dlc_project_name, + bodyparts=["bp1"], + lab_team="any", + frames_per_video=1, + video_list=["any"], + groupname="any", + ) + expected_key = { + "project_name": dlc_project_name, + "config_path": config_path, + } + assert ( + existing_key == expected_key + ), "Project re-insert did not return expected key" + + +def test_failed_group_insert(dlc_project_tbl, new_project_key): + with pytest.raises(ValueError): + dlc_project_tbl.insert_new_project(**new_project_key) + + +def test_extract_frames(extract_frames, labeled_vid_dir): + extracted_files = list(labeled_vid_dir.glob("*.png")) + stems = set([f.stem for f in extracted_files]) - {"img000", "img001"} + assert len(stems) == 2, "Incorrect number of frames extracted" diff --git a/tests/position/test_dlc_sel.py b/tests/position/test_dlc_sel.py new file mode 100644 index 000000000..35b33fe06 --- /dev/null +++ b/tests/position/test_dlc_sel.py @@ -0,0 +1,17 @@ +def test_dlcvideo_default(sgp): + expected_default = { + "dlc_pos_video_params_name": "default", + "params": { + "incl_likelihood": True, + "percent_frames": 1, + "video_params": {"arrow_radius": 20, "circle_radius": 6}, + }, + } + + # run twice to trigger fetch existing + assert sgp.v1.DLCPosVideoParams.get_default() == expected_default + assert sgp.v1.DLCPosVideoParams.get_default() == expected_default + + +def test_dlc_video_populate(populate_dlc_video): + assert len(populate_dlc_video) > 0, "DLCPosVideo table is empty" diff --git a/tests/position/test_dlc_train.py b/tests/position/test_dlc_train.py new file mode 100644 index 000000000..eefa26f66 --- /dev/null +++ b/tests/position/test_dlc_train.py @@ -0,0 +1,37 @@ +import pytest + + +def test_existing_params( + verbose_context, dlc_training_params, training_params_key +): + params_tbl, params_name = dlc_training_params + + _ = training_params_key # Ensure populated + params_query = params_tbl & {"dlc_training_params_name": params_name} + assert params_query, "Existing params not found" + + with verbose_context: + params_tbl.insert_new_params( + paramset_name=params_name, + params={ + "shuffle": 1, + "trainingsetindex": 0, + "net_type": "any", + "gputouse": None, + }, + skip_duplicates=False, + ) + + assert len(params_query) == 1, "Existing params duplicated" + + +@pytest.mark.usefixtures("skipif_nodlc") +def test_get_params(nodlc, verbose_context, dlc_training_params): + if nodlc: # Decorator wasn't working here, so duplicate skipif + pytest.skip(reason="Skipping DLC-dependent tests.") + + params_tbl, _ = dlc_training_params + with verbose_context: + accepted_params = params_tbl.get_accepted_params() + + assert accepted_params is not None, "Failed to get accepted params" diff --git a/tests/position/test_pos_merge.py b/tests/position/test_pos_merge.py new file mode 100644 index 000000000..047129cd5 --- /dev/null +++ b/tests/position/test_pos_merge.py @@ -0,0 +1,24 @@ +import pytest + + +@pytest.fixture(scope="session") +def merge_df(sgp, pos_merge, dlc_key, populate_dlc): + merge_key = (pos_merge.DLCPosV1 & dlc_key).fetch1("KEY") + yield (pos_merge & merge_key).fetch1_dataframe() + + +def test_merge_dlc_fetch1_dataframe(merge_df): + df_cols = merge_df.columns + exp_cols = [ + "video_frame_ind", + "position_x", + 
"position_y", + "orientation", + "velocity_x", + "velocity_y", + "speed", + ] + + assert all( + e in df_cols for e in exp_cols + ), f"Unexpected cols in position merge dataframe: {df_cols}" diff --git a/tests/position/test_trodes.py b/tests/position/test_trodes.py index d4bc617f6..92fdfeeb1 100644 --- a/tests/position/test_trodes.py +++ b/tests/position/test_trodes.py @@ -59,3 +59,9 @@ def test_fetch_df(trodes_pos_v1, trodes_params): ) hash_exp = "5296e74dea2e5e68d39f81bc81723a12" assert hash_df == hash_exp, "Dataframe differs from expected" + + +def test_trodes_video(sgp): + vid_tbl = sgp.v1.TrodesPosVideo() + _ = vid_tbl.populate() + assert len(vid_tbl) == 2, "Failed to populate TrodesPosVideo" diff --git a/tests/utils/test_db_settings.py b/tests/utils/test_db_settings.py index 1c3efbead..3b72ec885 100644 --- a/tests/utils/test_db_settings.py +++ b/tests/utils/test_db_settings.py @@ -7,12 +7,16 @@ def db_settings(user_name): from spyglass.utils.database_settings import DatabaseSettings + id = getattr(docker_server.container, "id", None) + no_docker = id is None # If 'None', we're --no-docker in gh actions + return DatabaseSettings( user_name=user_name, host_name=docker_server.creds["database.host"], - target_database=docker_server.container.id, + target_database=id, exec_user=docker_server.creds["database.user"], exec_pass=docker_server.creds["database.password"], + test_mode=no_docker, ) diff --git a/tests/utils/test_mixin.py b/tests/utils/test_mixin.py index 010abf03c..5b6beb4d0 100644 --- a/tests/utils/test_mixin.py +++ b/tests/utils/test_mixin.py @@ -41,15 +41,19 @@ def test_merge_detect(Nwbfile, pos_merge_tables): ), "Merges not detected by mixin." -def test_merge_chain_join(Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key): - """Test that the mixin can join merge chains.""" - _ = lin_v1, lfp_merge_key # merge tables populated +def test_merge_chain_join( + Nwbfile, pos_merge_tables, lin_v1, lfp_merge_key, populate_dlc +): + """Test that the mixin can join merge chains. + + NOTE: This will change if more data is added to merge tables.""" + _ = lin_v1, lfp_merge_key, populate_dlc # merge tables populated all_chains = [ chains.cascade(True, direction="down") for chains in Nwbfile._merge_chains.values() ] - end_len = [len(chain[0]) for chain in all_chains if chain] + end_len = [len(chain) for chain in all_chains] assert sum(end_len) == 4, "Merge chains not joined correctly."