Skip to content

Commit

Permalink
Expose progress bar class control (#1110)
Browse files Browse the repository at this point in the history
* expose progress bar class control

* update types

* grab progress bar class from kwargs

* fix

* swap back to callable but from typing

* swap from typing to collections

* Update CHANGELOG.md

* add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CodyCBakerPhD and pre-commit-ci[bot] authored May 20, 2024
1 parent 8a2658f commit e6e6c5b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061)
- Updated testing to not install in editable mode and not run `coverage` by default. @rly [#1107](https://github.com/hdmf-dev/hdmf/pull/1107)
- Add `post_init_method` parameter when generating classes to perform post-init functionality, i.e., validation. @mavaylon1 [#1089](https://github.com/hdmf-dev/hdmf/pull/1089)
- Exposed `progress_bar_class` to the `GenericDataChunkIterator` for more custom control over display of progress while iterating. @codycbakerphd [#1110](https://github.com/hdmf-dev/hdmf/pull/1110)
- Updated loading, unloading, and getting the `TypeConfigurator` to support a `TypeMap` parameter. @mavaylon1 [#1117](https://github.com/hdmf-dev/hdmf/pull/1117)

### Bug Fixes
Expand Down
35 changes: 29 additions & 6 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import copy
import math
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from collections.abc import Iterable, Callable
from warnings import warn
from typing import Tuple, Callable
from typing import Tuple
from itertools import product, chain

import h5py
Expand Down Expand Up @@ -179,9 +179,15 @@ class GenericDataChunkIterator(AbstractDataChunkIterator):
doc="Display a progress bar with iteration rate and estimated completion time.",
default=False,
),
dict(
name="progress_bar_class",
type=Callable,
doc="The progress bar class to use. Defaults to tqdm.tqdm if the TQDM package is installed.",
default=None,
),
dict(
name="progress_bar_options",
type=None,
type=dict,
doc="Dictionary of keyword arguments to be passed directly to tqdm.",
default=None,
),
Expand All @@ -199,8 +205,23 @@ def __init__(self, **kwargs):
HDF5 recommends chunk size in the range of 2 to 16 MB for optimal cloud performance.
https://youtu.be/rcS5vt-mKok?t=621
"""
buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, progress_bar_options = getargs(
"buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs
(
buffer_gb,
buffer_shape,
chunk_mb,
chunk_shape,
self.display_progress,
progress_bar_class,
progress_bar_options,
) = getargs(
"buffer_gb",
"buffer_shape",
"chunk_mb",
"chunk_shape",
"display_progress",
"progress_bar_class",
"progress_bar_options",
kwargs,
)
self.progress_bar_options = progress_bar_options or dict()

Expand Down Expand Up @@ -277,11 +298,13 @@ def __init__(self, **kwargs):
try:
from tqdm import tqdm

progress_bar_class = progress_bar_class or tqdm

if "total" in self.progress_bar_options:
warn("Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.")
self.progress_bar_options.pop("total")

self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options)
self.progress_bar = progress_bar_class(total=self.num_buffers, **self.progress_bar_options)
except ImportError:
warn(
"You must install tqdm to use the progress bar feature (pip install tqdm)! "
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/utils_test/test_core_GenericDataChunkIterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from tempfile import mkdtemp
from shutil import rmtree
from typing import Tuple, Iterable, Callable
from typing import Tuple, Iterable, Callable, Union
from sys import version_info

import h5py
Expand Down Expand Up @@ -408,6 +408,33 @@ def test_progress_bar(self):
first_line = file.read()
self.assertIn(member=desc, container=first_line)

@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
def test_progress_bar_class(self):
import tqdm

class MyCustomProgressBar(tqdm.tqdm):
def update(self, n: int = 1) -> Union[bool, None]:
displayed = super().update(n)
print(f"Custom injection on step {n}") # noqa: T201

return displayed

out_text_file = self.test_dir / "test_progress_bar_class.txt"
desc = "Testing progress bar..."
with open(file=out_text_file, mode="w") as file:
iterator = self.TestNumpyArrayDataChunkIterator(
array=self.test_array,
display_progress=True,
progress_bar_class=MyCustomProgressBar,
progress_bar_options=dict(file=file, desc=desc),
)
j = 0
for buffer in iterator:
j += 1 # dummy operation; must be silent for proper updating of bar
with open(file=out_text_file, mode="r") as file:
first_line = file.read()
self.assertIn(member=desc, container=first_line)

@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is installed")
def test_progress_bar_no_options(self):
dci = self.TestNumpyArrayDataChunkIterator(array=self.test_array, display_progress=True)
Expand Down

0 comments on commit e6e6c5b

Please sign in to comment.