From 97ef13cc3681740813984307263453451cd7ac77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zo=C3=AB=20Bilodeau?= Date: Tue, 2 Jul 2024 17:33:57 +0200 Subject: [PATCH] changed hadd progress bar --- src/hepconvert/histogram_adding.py | 19 +++---------------- tests/test_add_histograms.py | 8 ++++++++ 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/hepconvert/histogram_adding.py b/src/hepconvert/histogram_adding.py index eddcc7e..66c5f05 100644 --- a/src/hepconvert/histogram_adding.py +++ b/src/hepconvert/histogram_adding.py @@ -453,15 +453,11 @@ def add_histograms( keys = file.keys(filter_classname="TH[1|2|3][I|S|F|D|C]", cycle=False) if progress_bar is not False: tqdm = _utils.check_tqdm() - file_bar = progress_bar - hist_bar = tqdm.tqdm(desc="Histograms added") number_of_items = len(files) if progress_bar is True: - file_bar = tqdm.tqdm(desc="Files added") - hist_bar = tqdm.tqdm(desc="Histograms added") - else: - hist_bar = None - file_bar.reset(number_of_items) + file_bar = tqdm.tqdm(desc="Files summed") + file_bar.reset(number_of_items) + if same_names: if union: for i, _value in enumerate(files[1:]): @@ -491,8 +487,6 @@ def add_histograms( msg = f"File: {input_file} does not exist or is corrupt." raise FileNotFoundError(msg) from None if same_names: - if progress_bar and hist_bar: - hist_bar.reset(len(keys)) for key in keys: try: in_file[key] @@ -510,15 +504,10 @@ def add_histograms( else: h_sum = _hadd_3d(hists, in_file, key, first) - if progress_bar: - hist_bar.update(n=1) - if h_sum is not None: hists[key] = h_sum else: n_keys = in_file.keys(filter_classname="TH[1|2|3][I|S|F|D|C]", cycle=False) - if progress_bar: - hist_bar.reset(len(n_keys)) for i, _value in enumerate(keys): if len(in_file[n_keys[i]].axes) == 1: h_sum = _hadd_1d(out_file, in_file, keys[i], first, n_key=n_keys[i]) @@ -531,8 +520,6 @@ def add_histograms( if h_sum is not None: out_file[keys[i]] = h_sum - if progress_bar: - hist_bar.update(n=1) if progress_bar: file_bar.update(n=1) diff --git a/tests/test_add_histograms.py b/tests/test_add_histograms.py index b6888ab..5301adb 100644 --- a/tests/test_add_histograms.py +++ b/tests/test_add_histograms.py @@ -71,6 +71,14 @@ def test_simple(tmp_path): ).all +def test_glob(tmp_path): + hepconvert.add_histograms( + os.path.join(tmp_path, "dest.root"), + "tests/samples/hists", progress_bar=True + ) + +test_glob("tests/samples/") + def mult_1D(tmp_path, file_paths): gauss_1 = ROOT.TH1I("name1", "title", 5, -4, 4) gauss_1.FillRandom("gaus")