Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for Histogram and CentralityClasses #185

Merged
merged 4 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/sparkx/CentralityClasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import warnings


class CentralityClasses:
Expand Down Expand Up @@ -58,8 +59,32 @@ def __init__(self,events_multiplicity,centrality_bins):
if not isinstance(centrality_bins, (list,np.ndarray)):
raise TypeError("'centrality_bins' is not list or numpy.ndarray")

# Check if centrality_bins is sorted
if not all(centrality_bins[i] <= centrality_bins[i+1] for i in range(len(centrality_bins)-1)):
warnings.warn("'centrality_bins' is not sorted. Sorting automatically.")
centrality_bins.sort()

# Check for uniqueness of values
# Remove duplicates from the list
unique_bins = []
seen = set()
multiple_same_entries = False
for item in centrality_bins:
if item not in seen:
unique_bins.append(item)
seen.add(item)
else:
multiple_same_entries = True

if multiple_same_entries:
warnings.warn("'centrality_bins' contains duplicate values. They are removed automatically.")

# Check for negative values and values greater than 100
if any(value < 0.0 or value > 100.0 for value in centrality_bins):
raise ValueError("'centrality_bins' contains values less than 0 or greater than 100.")

self.events_multiplicity_ = events_multiplicity
self.centrality_bins_ = centrality_bins
self.centrality_bins_ = unique_bins

self.dNchdetaMin_ = []
self.dNchdetaMax_ = []
Expand Down
6 changes: 4 additions & 2 deletions src/sparkx/Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,8 @@ def scale_histogram(self,value):
raise ValueError("The scaling factor of the histogram cannot be negative")
elif isinstance(value, (list, np.ndarray)) and sum(1 for number in value if number < 0) > 0:
raise ValueError("The scaling factor of the histogram cannot be negative")
elif isinstance(value, (list, np.ndarray)) and len(value) != self.number_of_bins_:
raise ValueError("The length of list/array not compatible with number_of_bins_ of the histogram")

if self.histograms_.ndim == 1:
if isinstance(value, (int, float, np.number)):
Expand Down Expand Up @@ -479,7 +481,7 @@ def set_error(self,own_error):
value: list, numpy.ndarray
Values for the uncertainties of the individual bins.
"""
if len(own_error) != self.number_of_bins_ and\
if len(own_error) != self.number_of_bins_ or\
not isinstance(own_error, (list,np.ndarray)):
error_message = "The input error has a different length than the"\
+ " number of histogram bins or it is not a list/numpy.ndarray"
Expand All @@ -498,7 +500,7 @@ def set_systematic_error(self,own_error):
value: list, numpy.ndarray
Values for the systematic uncertainties of the individual bins.
"""
if len(own_error) != self.number_of_bins_ and\
if len(own_error) != self.number_of_bins_ or\
not isinstance(own_error, (list,np.ndarray)):
error_message = "The input error has a different length than the"\
+ " number of histogram bins or it is not a list/numpy.ndarray"
Expand Down
79 changes: 79 additions & 0 deletions tests/test_CentralityClasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
import numpy as np
from sparkx.CentralityClasses import CentralityClasses


Hendrik1704 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.fixture
def centrality_obj():
bins = [0,10,20,30,40,50,60,70,80,90,100]
numbers_sequence = range(1,101)
multiplicities = [num for num in numbers_sequence for _ in range(100)]
return CentralityClasses(events_multiplicity=multiplicities,
centrality_bins=bins)

def test_init_with_invalid_input():
with pytest.raises(TypeError):
CentralityClasses(events_multiplicity=10, centrality_bins=[0, 25, 50, 75, 100])
with pytest.raises(TypeError):
CentralityClasses(events_multiplicity=[0,10,20,30,40,50,60,70,80,90,100],
centrality_bins=0)

numbers_sequence = range(1,101)
multiplicities = [num for num in numbers_sequence for _ in range(100)]
with pytest.raises(ValueError):
CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[0,10,20,30,40,50,60,70,80,90,100,110])
with pytest.raises(ValueError):
CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[-10,0,10,20,30,40,50,60,70,80,90,100])
with pytest.warns(UserWarning, match=r"'centrality_bins' contains duplicate values. They are removed automatically."):
a = CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[0,10,20,30,40,40,50,60,70,80,90,100])
assert a.centrality_bins_ == [0,10,20,30,40,50,60,70,80,90,100]

def test_create_centrality_classes(centrality_obj):
# Assuming there are 10 bins, so there should be 10 minimum values
assert len(centrality_obj.dNchdetaMin_) == 10
assert len(centrality_obj.dNchdetaMax_) == 10
assert len(centrality_obj.dNchdetaAvg_) == 10
assert len(centrality_obj.dNchdetaAvgErr_) == 10

def test_get_centrality_class(centrality_obj):
assert centrality_obj.get_centrality_class(99) == 0
assert centrality_obj.get_centrality_class(105) == 0
assert centrality_obj.get_centrality_class(0) == 9

assert centrality_obj.get_centrality_class(1) == 9
assert centrality_obj.get_centrality_class(2) == 9
assert centrality_obj.get_centrality_class(3) == 9
assert centrality_obj.get_centrality_class(4) == 9
assert centrality_obj.get_centrality_class(5) == 9
assert centrality_obj.get_centrality_class(6) == 9
assert centrality_obj.get_centrality_class(7) == 9
assert centrality_obj.get_centrality_class(8) == 9
assert centrality_obj.get_centrality_class(9) == 9
assert centrality_obj.get_centrality_class(10) == 9
assert centrality_obj.get_centrality_class(11) == 8
assert centrality_obj.get_centrality_class(19) == 8
assert centrality_obj.get_centrality_class(20) == 8
assert centrality_obj.get_centrality_class(21) == 7
assert centrality_obj.get_centrality_class(29) == 7

def test_output_centrality_classes(centrality_obj, tmp_path):
output_file = tmp_path / "centrality_output.txt"
centrality_obj.output_centrality_classes(str(output_file))
assert output_file.is_file()

# Check content of the output file
with open(output_file, 'r') as f:
lines = f.readlines()
assert lines[0].startswith("# CentralityMin CentralityMax")
assert len(lines) == 10

def test_create_centrality_classes_error():
with pytest.raises(ValueError):
CentralityClasses(events_multiplicity=[1, 2, 3], centrality_bins=[0, 25, 50, 75, 100])

with pytest.raises(ValueError):
CentralityClasses(events_multiplicity=[10, 15, -20, 25], centrality_bins=[0, 25, 50, 75, 100])

def test_output_centrality_classes_with_invalid_fname(centrality_obj):
with pytest.raises(TypeError):
centrality_obj.output_centrality_classes(123)
130 changes: 130 additions & 0 deletions tests/test_Histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import numpy as np
import pytest
import csv
from sparkx.Histogram import Histogram

def test_histogram_creation_with_tuple():
# Test histogram creation with a tuple
hist = Histogram((0, 10, 10))
assert hist.number_of_bins_ == 10
assert np.allclose(hist.bin_edges_, np.linspace(0, 10, num=11))

def test_histogram_creation_with_list():
# Test histogram creation with a list
hist = Histogram([0, 2, 4, 6, 8, 10])
assert hist.number_of_bins_ == 5
assert np.allclose(hist.bin_edges_, np.array([0, 2, 4, 6, 8, 10]))

def test_histogram_creation_with_invalid_input():
# Test histogram creation with invalid input
with pytest.raises(TypeError):
# Passing incorrect input type
hist = Histogram("invalid_input")

with pytest.raises(ValueError):
# Passing tuple with invalid values
hist = Histogram((10, 0, 10)) # hist_min > hist_max

with pytest.raises(ValueError):
# Passing tuple with non-integer number of bins
hist = Histogram((0, 10, 1.4))

with pytest.raises(ValueError):
hist = Histogram((0, 10, -1)) # num_bins <= 0

def test_set_error_with_invalid_input():
# Test setting error with invalid input
hist = Histogram((0, 10, 10))
with pytest.raises(ValueError):
# Passing incorrect length of error list
hist.set_error([1, 2, 3])

def test_set_systematic_error_with_invalid_input():
# Test setting systematic error with invalid input
hist = Histogram((0, 10, 10))
with pytest.raises(ValueError):
# Passing incorrect length of systematic error list
hist.set_systematic_error([1, 2, 3])

def test_add_value_single_number():
# Test adding a single number to the histogram
hist = Histogram((0, 10, 10))
hist.add_value(4.5)
assert np.allclose(hist.histogram(), np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0]))

def test_add_value_list():
# Test adding a list of numbers to the histogram
hist = Histogram((0, 10, 10))
hist.add_value([1, 3, 5, 7, 9])
assert np.allclose(hist.histogram(), np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1]))

def test_average():
# Test averaging histograms
hist = Histogram((0, 10, 10))
hist.add_value([1, 3, 5, 7, 9])
hist.add_histogram()
hist.add_value([2, 4, 6, 8, 9])
hist.average()
assert np.allclose(hist.histogram(), np.array([0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0]))

def test_write_to_file(tmp_path):
# Test writing histograms to a file
hist = Histogram((0, 4, 4))
hist.add_value([1, 2, 3])
hist.statistical_error()
hist_labels = [{'bin_center': 'Bin Center', 'bin_low': 'Bin Low', 'bin_high': 'Bin High',
'distribution': 'Distribution', 'stat_err+': 'Stat Error+', 'stat_err-': 'Stat Error-',
'sys_err+': 'Sys Error+', 'sys_err-': 'Sys Error-'}]
filename = tmp_path / "test_histograms.csv"
hist.write_to_file(filename, hist_labels)

# Check if the file exists
assert filename.is_file()

# Read the file and verify its content
with open(filename, 'r') as f:
reader = csv.reader(f)
headers = next(reader)
expected_headers = ['Bin Center', 'Bin Low', 'Bin High', 'Distribution',
'Stat Error+', 'Stat Error-', 'Sys Error+', 'Sys Error-']
assert headers == expected_headers

# Check the content of the file
rows = [row for row in reader]
rows = rows[:-1] # neglect the last empty line
expected_rows = [
['0.5', '0.0', '1.0', '0.0', '0.0', '0.0', '0.0', '0.0'],
['1.5', '1.0', '2.0', '1.0', '1.0', '1.0', '0.0', '0.0'],
['2.5', '2.0', '3.0', '1.0', '1.0', '1.0', '0.0', '0.0'],
['3.5', '3.0', '4.0', '1.0', '1.0', '1.0', '0.0', '0.0']
]
assert rows == expected_rows

def test_average_weighted():
# Test weighted averaging histograms
hist = Histogram((0, 10, 10))
hist.add_value([1, 3, 5, 7, 9])
hist.add_histogram()
hist.add_value([2, 4, 6, 8, 9])
weights = np.array([0.5, 0.5])
hist.average_weighted(weights)
assert np.allclose(hist.histogram(), np.array([0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0]))

def test_scale_histogram_single_factor():
# Test scaling histogram with a single factor
hist = Histogram((0, 10, 10))
hist.add_value([1, 3, 5, 7, 9])
hist.scale_histogram(2)
assert np.allclose(hist.histogram(), np.array([0, 2, 0, 2, 0, 2, 0, 2, 0, 2]))

def test_scale_histogram_multiple_factors():
# Test scaling histogram with multiple factors
hist = Histogram((0, 10, 10))
hist.add_value([1, 3, 5, 7, 9])
hist.add_histogram()
hist.add_value([2, 4, 6, 8, 9])
with pytest.raises(ValueError):
# Passing incorrect length of scaling list
hist.scale_histogram([2, 0.5])
hist.scale_histogram([2,2,2,2,2,2,2,2,2,2])
assert np.allclose(hist.histogram(), np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 0, 2, 0, 2, 0, 2, 0, 2, 2]]))
Loading