diff --git a/src/sparkx/CentralityClasses.py b/src/sparkx/CentralityClasses.py index 6d11b7ad..878bb67b 100644 --- a/src/sparkx/CentralityClasses.py +++ b/src/sparkx/CentralityClasses.py @@ -1,4 +1,5 @@ import numpy as np +import warnings class CentralityClasses: @@ -58,8 +59,32 @@ def __init__(self,events_multiplicity,centrality_bins): if not isinstance(centrality_bins, (list,np.ndarray)): raise TypeError("'centrality_bins' is not list or numpy.ndarray") + # Check if centrality_bins is sorted + if not all(centrality_bins[i] <= centrality_bins[i+1] for i in range(len(centrality_bins)-1)): + warnings.warn("'centrality_bins' is not sorted. Sorting automatically.") + centrality_bins.sort() + + # Check for uniqueness of values + # Remove duplicates from the list + unique_bins = [] + seen = set() + multiple_same_entries = False + for item in centrality_bins: + if item not in seen: + unique_bins.append(item) + seen.add(item) + else: + multiple_same_entries = True + + if multiple_same_entries: + warnings.warn("'centrality_bins' contains duplicate values. They are removed automatically.") + + # Check for negative values and values greater than 100 + if any(value < 0.0 or value > 100.0 for value in centrality_bins): + raise ValueError("'centrality_bins' contains values less than 0 or greater than 100.") + self.events_multiplicity_ = events_multiplicity - self.centrality_bins_ = centrality_bins + self.centrality_bins_ = unique_bins self.dNchdetaMin_ = [] self.dNchdetaMax_ = [] diff --git a/src/sparkx/Histogram.py b/src/sparkx/Histogram.py index b5bf450e..2fb68d5c 100755 --- a/src/sparkx/Histogram.py +++ b/src/sparkx/Histogram.py @@ -444,6 +444,8 @@ def scale_histogram(self,value): raise ValueError("The scaling factor of the histogram cannot be negative") elif isinstance(value, (list, np.ndarray)) and sum(1 for number in value if number < 0) > 0: raise ValueError("The scaling factor of the histogram cannot be negative") + elif isinstance(value, (list, np.ndarray)) and len(value) != self.number_of_bins_: + raise ValueError("The length of list/array not compatible with number_of_bins_ of the histogram") if self.histograms_.ndim == 1: if isinstance(value, (int, float, np.number)): @@ -479,7 +481,7 @@ def set_error(self,own_error): value: list, numpy.ndarray Values for the uncertainties of the individual bins. """ - if len(own_error) != self.number_of_bins_ and\ + if len(own_error) != self.number_of_bins_ or\ not isinstance(own_error, (list,np.ndarray)): error_message = "The input error has a different length than the"\ + " number of histogram bins or it is not a list/numpy.ndarray" @@ -498,7 +500,7 @@ def set_systematic_error(self,own_error): value: list, numpy.ndarray Values for the systematic uncertainties of the individual bins. """ - if len(own_error) != self.number_of_bins_ and\ + if len(own_error) != self.number_of_bins_ or\ not isinstance(own_error, (list,np.ndarray)): error_message = "The input error has a different length than the"\ + " number of histogram bins or it is not a list/numpy.ndarray" diff --git a/tests/test_CentralityClasses.py b/tests/test_CentralityClasses.py new file mode 100644 index 00000000..d3b4e95a --- /dev/null +++ b/tests/test_CentralityClasses.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +from sparkx.CentralityClasses import CentralityClasses + + +@pytest.fixture +def centrality_obj(): + bins = [0,10,20,30,40,50,60,70,80,90,100] + numbers_sequence = range(1,101) + multiplicities = [num for num in numbers_sequence for _ in range(100)] + return CentralityClasses(events_multiplicity=multiplicities, + centrality_bins=bins) + +def test_init_with_invalid_input(): + with pytest.raises(TypeError): + CentralityClasses(events_multiplicity=10, centrality_bins=[0, 25, 50, 75, 100]) + with pytest.raises(TypeError): + CentralityClasses(events_multiplicity=[0,10,20,30,40,50,60,70,80,90,100], + centrality_bins=0) + + numbers_sequence = range(1,101) + multiplicities = [num for num in numbers_sequence for _ in range(100)] + with pytest.raises(ValueError): + CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[0,10,20,30,40,50,60,70,80,90,100,110]) + with pytest.raises(ValueError): + CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[-10,0,10,20,30,40,50,60,70,80,90,100]) + with pytest.warns(UserWarning, match=r"'centrality_bins' contains duplicate values. They are removed automatically."): + a = CentralityClasses(events_multiplicity=multiplicities,centrality_bins=[0,10,20,30,40,40,50,60,70,80,90,100]) + assert a.centrality_bins_ == [0,10,20,30,40,50,60,70,80,90,100] + +def test_create_centrality_classes(centrality_obj): + # Assuming there are 10 bins, so there should be 10 minimum values + assert len(centrality_obj.dNchdetaMin_) == 10 + assert len(centrality_obj.dNchdetaMax_) == 10 + assert len(centrality_obj.dNchdetaAvg_) == 10 + assert len(centrality_obj.dNchdetaAvgErr_) == 10 + +def test_get_centrality_class(centrality_obj): + assert centrality_obj.get_centrality_class(99) == 0 + assert centrality_obj.get_centrality_class(105) == 0 + assert centrality_obj.get_centrality_class(0) == 9 + + assert centrality_obj.get_centrality_class(1) == 9 + assert centrality_obj.get_centrality_class(2) == 9 + assert centrality_obj.get_centrality_class(3) == 9 + assert centrality_obj.get_centrality_class(4) == 9 + assert centrality_obj.get_centrality_class(5) == 9 + assert centrality_obj.get_centrality_class(6) == 9 + assert centrality_obj.get_centrality_class(7) == 9 + assert centrality_obj.get_centrality_class(8) == 9 + assert centrality_obj.get_centrality_class(9) == 9 + assert centrality_obj.get_centrality_class(10) == 9 + assert centrality_obj.get_centrality_class(11) == 8 + assert centrality_obj.get_centrality_class(19) == 8 + assert centrality_obj.get_centrality_class(20) == 8 + assert centrality_obj.get_centrality_class(21) == 7 + assert centrality_obj.get_centrality_class(29) == 7 + +def test_output_centrality_classes(centrality_obj, tmp_path): + output_file = tmp_path / "centrality_output.txt" + centrality_obj.output_centrality_classes(str(output_file)) + assert output_file.is_file() + + # Check content of the output file + with open(output_file, 'r') as f: + lines = f.readlines() + assert lines[0].startswith("# CentralityMin CentralityMax") + assert len(lines) == 10 + +def test_create_centrality_classes_error(): + with pytest.raises(ValueError): + CentralityClasses(events_multiplicity=[1, 2, 3], centrality_bins=[0, 25, 50, 75, 100]) + + with pytest.raises(ValueError): + CentralityClasses(events_multiplicity=[10, 15, -20, 25], centrality_bins=[0, 25, 50, 75, 100]) + +def test_output_centrality_classes_with_invalid_fname(centrality_obj): + with pytest.raises(TypeError): + centrality_obj.output_centrality_classes(123) \ No newline at end of file diff --git a/tests/test_Histogram.py b/tests/test_Histogram.py new file mode 100644 index 00000000..73aa82e1 --- /dev/null +++ b/tests/test_Histogram.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest +import csv +from sparkx.Histogram import Histogram + +def test_histogram_creation_with_tuple(): + # Test histogram creation with a tuple + hist = Histogram((0, 10, 10)) + assert hist.number_of_bins_ == 10 + assert np.allclose(hist.bin_edges_, np.linspace(0, 10, num=11)) + +def test_histogram_creation_with_list(): + # Test histogram creation with a list + hist = Histogram([0, 2, 4, 6, 8, 10]) + assert hist.number_of_bins_ == 5 + assert np.allclose(hist.bin_edges_, np.array([0, 2, 4, 6, 8, 10])) + +def test_histogram_creation_with_invalid_input(): + # Test histogram creation with invalid input + with pytest.raises(TypeError): + # Passing incorrect input type + hist = Histogram("invalid_input") + + with pytest.raises(ValueError): + # Passing tuple with invalid values + hist = Histogram((10, 0, 10)) # hist_min > hist_max + + with pytest.raises(ValueError): + # Passing tuple with non-integer number of bins + hist = Histogram((0, 10, 1.4)) + + with pytest.raises(ValueError): + hist = Histogram((0, 10, -1)) # num_bins <= 0 + +def test_set_error_with_invalid_input(): + # Test setting error with invalid input + hist = Histogram((0, 10, 10)) + with pytest.raises(ValueError): + # Passing incorrect length of error list + hist.set_error([1, 2, 3]) + +def test_set_systematic_error_with_invalid_input(): + # Test setting systematic error with invalid input + hist = Histogram((0, 10, 10)) + with pytest.raises(ValueError): + # Passing incorrect length of systematic error list + hist.set_systematic_error([1, 2, 3]) + +def test_add_value_single_number(): + # Test adding a single number to the histogram + hist = Histogram((0, 10, 10)) + hist.add_value(4.5) + assert np.allclose(hist.histogram(), np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])) + +def test_add_value_list(): + # Test adding a list of numbers to the histogram + hist = Histogram((0, 10, 10)) + hist.add_value([1, 3, 5, 7, 9]) + assert np.allclose(hist.histogram(), np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])) + +def test_average(): + # Test averaging histograms + hist = Histogram((0, 10, 10)) + hist.add_value([1, 3, 5, 7, 9]) + hist.add_histogram() + hist.add_value([2, 4, 6, 8, 9]) + hist.average() + assert np.allclose(hist.histogram(), np.array([0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0])) + +def test_write_to_file(tmp_path): + # Test writing histograms to a file + hist = Histogram((0, 4, 4)) + hist.add_value([1, 2, 3]) + hist.statistical_error() + hist_labels = [{'bin_center': 'Bin Center', 'bin_low': 'Bin Low', 'bin_high': 'Bin High', + 'distribution': 'Distribution', 'stat_err+': 'Stat Error+', 'stat_err-': 'Stat Error-', + 'sys_err+': 'Sys Error+', 'sys_err-': 'Sys Error-'}] + filename = tmp_path / "test_histograms.csv" + hist.write_to_file(filename, hist_labels) + + # Check if the file exists + assert filename.is_file() + + # Read the file and verify its content + with open(filename, 'r') as f: + reader = csv.reader(f) + headers = next(reader) + expected_headers = ['Bin Center', 'Bin Low', 'Bin High', 'Distribution', + 'Stat Error+', 'Stat Error-', 'Sys Error+', 'Sys Error-'] + assert headers == expected_headers + + # Check the content of the file + rows = [row for row in reader] + rows = rows[:-1] # neglect the last empty line + expected_rows = [ + ['0.5', '0.0', '1.0', '0.0', '0.0', '0.0', '0.0', '0.0'], + ['1.5', '1.0', '2.0', '1.0', '1.0', '1.0', '0.0', '0.0'], + ['2.5', '2.0', '3.0', '1.0', '1.0', '1.0', '0.0', '0.0'], + ['3.5', '3.0', '4.0', '1.0', '1.0', '1.0', '0.0', '0.0'] + ] + assert rows == expected_rows + +def test_average_weighted(): + # Test weighted averaging histograms + hist = Histogram((0, 10, 10)) + hist.add_value([1, 3, 5, 7, 9]) + hist.add_histogram() + hist.add_value([2, 4, 6, 8, 9]) + weights = np.array([0.5, 0.5]) + hist.average_weighted(weights) + assert np.allclose(hist.histogram(), np.array([0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0])) + +def test_scale_histogram_single_factor(): + # Test scaling histogram with a single factor + hist = Histogram((0, 10, 10)) + hist.add_value([1, 3, 5, 7, 9]) + hist.scale_histogram(2) + assert np.allclose(hist.histogram(), np.array([0, 2, 0, 2, 0, 2, 0, 2, 0, 2])) + +def test_scale_histogram_multiple_factors(): + # Test scaling histogram with multiple factors + hist = Histogram((0, 10, 10)) + hist.add_value([1, 3, 5, 7, 9]) + hist.add_histogram() + hist.add_value([2, 4, 6, 8, 9]) + with pytest.raises(ValueError): + # Passing incorrect length of scaling list + hist.scale_histogram([2, 0.5]) + hist.scale_histogram([2,2,2,2,2,2,2,2,2,2]) + assert np.allclose(hist.histogram(), np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 0, 2, 0, 2, 0, 2, 0, 2, 2]]))