Skip to content

Commit

Permalink
Merge pull request #283 from smash-transport/ngoetz/bugfix_histogram
Browse files Browse the repository at this point in the history
Bugfix np.arrays in Histogram
  • Loading branch information
nilssass authored Oct 9, 2024
2 parents aa11ca7 + 950197f commit 526df69
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
24 changes: 21 additions & 3 deletions src/sparkx/Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def average_weighted(self, weights):
variance = np.average(
(self.histograms_ - average) ** 2.0, axis=0, weights=weights
)
# Ensure the result is a 2D array
if average.ndim == 1:
average = average.reshape(1, -1)

self.histograms_ = average
self.error_ = np.sqrt(variance)
Expand All @@ -574,6 +577,9 @@ def average_weighted(self, weights):
self.histogram_raw_count_ = np.sum(self.histograms_raw_count_, axis=0)
self.scaling_ = self.scaling_[0]

if self.scaling_.ndim == 1:
self.scaling_ = self.scaling_.reshape(1, -1)

self.number_of_histograms_ = 1

return self
Expand Down Expand Up @@ -607,6 +613,10 @@ def average_weighted_by_error(self):
weights = 1 / self.error_**2
average = np.average(self.histograms_, axis=0, weights=weights)

# Ensure the result is a 2D array
if average.ndim == 1:
average = average.reshape(1, -1)

self.histograms_ = average
self.error_ = np.sqrt(
1.0 / np.sum(1.0 / np.square(self.error_), axis=0)
Expand All @@ -617,6 +627,9 @@ def average_weighted_by_error(self):
self.histogram_raw_count_ = np.sum(self.histograms_raw_count_, axis=0)
self.scaling_ = self.scaling_[0]

if self.scaling_.ndim == 1:
self.scaling_ = self.scaling_.reshape(1, -1)

self.number_of_histograms_ = 1

return self
Expand Down Expand Up @@ -688,9 +701,14 @@ def scale_histogram(self, value):
self.error_[-1] *= value

elif isinstance(value, (list, np.ndarray)):
self.histograms_[-1] *= np.asarray(value)
self.scaling_[-1] *= np.asarray(value)
self.scaling_[-1] *= np.asarray(value)
if np.asarray(value).shape != self.histograms_[-1].shape:
raise ValueError(
"The shape of the scaling factor array is not compatible with the histogram shape"
)

value_array=np.asarray(value)
self.histograms_[-1] *= value_array
self.scaling_[-1] *= value_array

def set_error(self, own_error):
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def test_average():
hist.error_, np.array([0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.0])
)
assert isinstance(hist.scaling_, np.ndarray)
assert not any(isinstance(i, np.ndarray) for i in hist.scaling_)
assert all(isinstance(i, np.ndarray) for i in hist.scaling_)
assert np.allclose(counts_summed, hist.histogram_raw_count_)


Expand Down Expand Up @@ -361,7 +361,7 @@ def test_average_weighted_by_error():
hist.error_, np.array([0.89442719, 1.41421356, 2.12132034]), atol=0.01
)
assert isinstance(hist.scaling_, np.ndarray)
assert not any(isinstance(i, np.ndarray) for i in hist.scaling_)
assert all(isinstance(i, np.ndarray) for i in hist.scaling_)
assert np.allclose(counts_summed, hist.histogram_raw_count_)


Expand Down

0 comments on commit 526df69

Please sign in to comment.