diff --git a/CHANGELOG.md b/CHANGELOG.md index f14ce7c8d..bf2134756 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061) - Updated testing to not install in editable mode and not run `coverage` by default. @rly [#1107](https://github.com/hdmf-dev/hdmf/pull/1107) - Add `post_init_method` parameter when generating classes to perform post-init functionality, i.e., validation. @mavaylon1 [#1089](https://github.com/hdmf-dev/hdmf/pull/1089) +- Exposed `progress_bar_class` to the `GenericDataChunkIterator` for more custom control over display of progress while iterating. @codycbakerphd [#1110](https://github.com/hdmf-dev/hdmf/pull/1110) - Updated loading, unloading, and getting the `TypeConfigurator` to support a `TypeMap` parameter. @mavaylon1 [#1117](https://github.com/hdmf-dev/hdmf/pull/1117) ### Bug Fixes diff --git a/src/hdmf/data_utils.py b/src/hdmf/data_utils.py index 23f0b4019..0e83bde2d 100644 --- a/src/hdmf/data_utils.py +++ b/src/hdmf/data_utils.py @@ -1,9 +1,9 @@ import copy import math from abc import ABCMeta, abstractmethod -from collections.abc import Iterable +from collections.abc import Iterable, Callable from warnings import warn -from typing import Tuple, Callable +from typing import Tuple from itertools import product, chain import h5py @@ -179,9 +179,15 @@ class GenericDataChunkIterator(AbstractDataChunkIterator): doc="Display a progress bar with iteration rate and estimated completion time.", default=False, ), + dict( + name="progress_bar_class", + type=Callable, + doc="The progress bar class to use. Defaults to tqdm.tqdm if the TQDM package is installed.", + default=None, + ), dict( name="progress_bar_options", - type=None, + type=dict, doc="Dictionary of keyword arguments to be passed directly to tqdm.", default=None, ), @@ -199,8 +205,23 @@ def __init__(self, **kwargs): HDF5 recommends chunk size in the range of 2 to 16 MB for optimal cloud performance. https://youtu.be/rcS5vt-mKok?t=621 """ - buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, progress_bar_options = getargs( - "buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs + ( + buffer_gb, + buffer_shape, + chunk_mb, + chunk_shape, + self.display_progress, + progress_bar_class, + progress_bar_options, + ) = getargs( + "buffer_gb", + "buffer_shape", + "chunk_mb", + "chunk_shape", + "display_progress", + "progress_bar_class", + "progress_bar_options", + kwargs, ) self.progress_bar_options = progress_bar_options or dict() @@ -277,11 +298,13 @@ def __init__(self, **kwargs): try: from tqdm import tqdm + progress_bar_class = progress_bar_class or tqdm + if "total" in self.progress_bar_options: warn("Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.") self.progress_bar_options.pop("total") - self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options) + self.progress_bar = progress_bar_class(total=self.num_buffers, **self.progress_bar_options) except ImportError: warn( "You must install tqdm to use the progress bar feature (pip install tqdm)! " diff --git a/tests/unit/utils_test/test_core_GenericDataChunkIterator.py b/tests/unit/utils_test/test_core_GenericDataChunkIterator.py index debac9cab..2117eb6d0 100644 --- a/tests/unit/utils_test/test_core_GenericDataChunkIterator.py +++ b/tests/unit/utils_test/test_core_GenericDataChunkIterator.py @@ -4,7 +4,7 @@ from pathlib import Path from tempfile import mkdtemp from shutil import rmtree -from typing import Tuple, Iterable, Callable +from typing import Tuple, Iterable, Callable, Union from sys import version_info import h5py @@ -408,6 +408,33 @@ def test_progress_bar(self): first_line = file.read() self.assertIn(member=desc, container=first_line) + @unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed") + def test_progress_bar_class(self): + import tqdm + + class MyCustomProgressBar(tqdm.tqdm): + def update(self, n: int = 1) -> Union[bool, None]: + displayed = super().update(n) + print(f"Custom injection on step {n}") # noqa: T201 + + return displayed + + out_text_file = self.test_dir / "test_progress_bar_class.txt" + desc = "Testing progress bar..." + with open(file=out_text_file, mode="w") as file: + iterator = self.TestNumpyArrayDataChunkIterator( + array=self.test_array, + display_progress=True, + progress_bar_class=MyCustomProgressBar, + progress_bar_options=dict(file=file, desc=desc), + ) + j = 0 + for buffer in iterator: + j += 1 # dummy operation; must be silent for proper updating of bar + with open(file=out_text_file, mode="r") as file: + first_line = file.read() + self.assertIn(member=desc, container=first_line) + @unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is installed") def test_progress_bar_no_options(self): dci = self.TestNumpyArrayDataChunkIterator(array=self.test_array, display_progress=True)