Skip to content

Commit

Permalink
Merge pull request #319 from smash-transport/sass/Histogram_bug_fix
Browse files Browse the repository at this point in the history
Sass/histogram bug fix
  • Loading branch information
nilssass authored Nov 22, 2024
2 parents e2cbd16 + 00ae01f commit d0263a2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/sparkx/Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def average_weighted(self, weights: np.ndarray) -> "Histogram":
average = average.reshape(1, -1)

self.histograms_ = average
self.error_ = np.sqrt(variance)
self.error_ = np.sqrt(variance).reshape(1, -1)
self.systematic_error_ = np.sqrt(
np.average(self.systematic_error_**2.0, axis=0, weights=weights)
)
Expand Down
41 changes: 32 additions & 9 deletions tests/test_Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,26 +103,28 @@ def test_add_value_single_number_to_multiple_histograms():
),
)


def test_add_value_out_of_range():
# Test adding values inside and outside the histograms range
# Test adding values inside and outside the histograms range
hist = Histogram((0, 20, 20))
valid_values = [1, 2, 10, 18]
for value in valid_values:
hist.add_value(value)

hist.add_histogram()

# Testing the second layer
outlier_values = [-1, 21, 40, 20]
with pytest.warns(UserWarning, match="Exceeding values are ignored"):
for value in outlier_values:
hist.add_value(value)
hist.add_value(value)

# Testing the first layer
hist1 = Histogram((0, 20, 20))
with pytest.warns(UserWarning, match="Exceeding values are ignored"):
for value in outlier_values:
hist1.add_value(value)
hist1.add_value(value)


def test_remove_bin_out_of_range():
# Test removing a bin at an index out of range
Expand Down Expand Up @@ -542,13 +544,34 @@ def test_scale_histogram_multiple_factors():
hist.add_value([1, 3, 5, 7, 9])
hist.add_histogram()
hist.add_value([2, 4, 6, 8, 9])

# Passing incorrect length of scaling list
with pytest.raises(ValueError):
# Passing incorrect length of scaling list
hist.scale_histogram([2, 0.5])
hist.scale_histogram([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

# Scale histogram with valid scaling list
hist.scale_histogram([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
assert np.allclose(
hist.histogram(),
np.array(
[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 0, 2, 0, 2, 0, 2, 0, 2, 2]]
[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 0, 3, 0, 5, 0, 7, 0, 9, 10]]
),
)


def test_scale_histogram_after_averaging():
hist = Histogram((0, 10, 10))
hist.add_value([1, 1, 1, 1, 3, 5, 7, 9])
hist.add_histogram()
hist.add_value([2, 4, 4.5, 6, 8, 8, 9])

# Scale the averaged histogram
hist.average()
hist.scale_histogram(2.0)

expected_histogram = np.array(
[[0.0, 4.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0, 2.0]]
)

assert hist.histogram().shape == expected_histogram.shape
assert np.allclose(hist.histogram(), expected_histogram)

0 comments on commit d0263a2

Please sign in to comment.