Skip to content

Commit

Permalink
Merge pull request #340 from smash-transport/roch/fix_centrality_example
Browse files Browse the repository at this point in the history
Fix centrality class example, enhance documentation
  • Loading branch information
Hendrik1704 authored Nov 29, 2024
2 parents ec1e57f + fbd9b77 commit 885b942
Show file tree
Hide file tree
Showing 21 changed files with 447 additions and 315 deletions.
58 changes: 34 additions & 24 deletions src/sparkx/BaseStorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BaseStorer(ABC):
----------
num_output_per_event_ : numpy.array
Array containing the event number and the number of particles in this
event as :code:`num_output_per_event_[event i][num_output in event i]`
event as :code:`num_output_per_event_[event i][num_output in event i]`
(updated when filters are applied)
num_events_ : int
Number of events contained in the Oscar object (updated when filters
Expand Down Expand Up @@ -105,11 +105,11 @@ def __init__(
self.particle_list_,
self.num_events_,
self.num_output_per_event_,
self.custom_attr_list
self.custom_attr_list,
) = self.loader_.load(**kwargs)
else:
raise ValueError("Loader has not been created properly")

def __add__(self, other: "BaseStorer") -> "BaseStorer":
"""
Adds two BaseStorer objects by combining their particle lists and updating num_output_per_event accordingly.
Expand All @@ -134,12 +134,14 @@ def __add__(self, other: "BaseStorer") -> "BaseStorer":
"""
if not isinstance(other, BaseStorer):
raise TypeError("Can only add BaseStorer objects")

# Ensure that both instances are of the same class
if type(self) is not type(other):
raise TypeError("Can only add objects of the same class")

combined_particle_list: list = self.particle_list_ + other.particle_list_
combined_particle_list: list = (
self.particle_list_ + other.particle_list_
)

# Ensure num_output_per_event_ is not None
if self.num_output_per_event_ is None:
Expand All @@ -156,18 +158,22 @@ def __add__(self, other: "BaseStorer") -> "BaseStorer":
)

# Adjust event_number for the parts that originally belonged to other
combined_num_output_per_event[self.num_events_:, 0] += self.num_events_
combined_num_output_per_event[self.num_events_ :, 0] += self.num_events_

combined_storer: BaseStorer = self.__class__.__new__(self.__class__)
combined_storer.__dict__.update(self.__dict__) # Inherit all properties from self
combined_storer.__dict__.update(
self.__dict__
) # Inherit all properties from self
combined_storer._update_after_merge(other)
combined_storer.particle_list_ = combined_particle_list
combined_storer.num_output_per_event_ = combined_num_output_per_event
combined_storer.num_events_ = self.num_events_ + other.num_events_
combined_storer.loader_ = None # Loader is not applicable for combined object
combined_storer.loader_ = (
None # Loader is not applicable for combined object
)

return combined_storer

@abstractmethod
def _update_after_merge(self, other: "BaseStorer") -> None:
"""
Expand Down Expand Up @@ -203,7 +209,7 @@ def num_output_per_event(self) -> Optional[np.ndarray]:
:code:`num_output_per_event[event_n, number_of_particles_in_event_n]`
:code:`num_output_per_event` is updated with every manipulation e.g.
:code:`num_output_per_event` is updated with every manipulation e.g.
after applying cuts.
Returns
Expand Down Expand Up @@ -336,7 +342,7 @@ def particle_species(
Returns
-------
self : BaseStorer object
Containing only particle species specified by :code:`pdg_list` for
Containing only particle species specified by :code:`pdg_list` for
every event
"""
self.particle_list_ = particle_species(self.particle_list_, pdg_list)
Expand All @@ -348,7 +354,7 @@ def remove_particle_species(
self, pdg_list: Union[int, Union[Tuple[int], List[int], np.ndarray]]
) -> "BaseStorer":
"""
Remove particle species from :code:`particle_list` by their PDG ID in
Remove particle species from :code:`particle_list` by their PDG ID in
every event.
Parameters
Expand Down Expand Up @@ -409,7 +415,7 @@ def lower_event_energy_cut(
Parameters
----------
minimum_event_energy : int or float
The minimum event energy threshold. Should be a positive integer or
The minimum event energy threshold. Should be a positive integer or
float.
Returns
Expand All @@ -421,10 +427,10 @@ def lower_event_energy_cut(
Raises
------
TypeError
If the :code:`minimum_event_energy` parameter is not an integer or
If the :code:`minimum_event_energy` parameter is not an integer or
float.
ValueError
If the :code:`minimum_event_energy` parameter is less than or
If the :code:`minimum_event_energy` parameter is less than or
equal to 0.
"""
self.particle_list_ = lower_event_energy_cut(
Expand Down Expand Up @@ -502,7 +508,7 @@ def rapidity_cut(
cut_value : float
If a single value is passed, the cut is applied symmetrically
around 0.
For example, if :code:`cut_value = 1`, only particles with rapidity
For example, if :code:`cut_value = 1`, only particles with rapidity
in :code:`[-1.0, 1.0]` are kept.
cut_value : tuple
Expand Down Expand Up @@ -556,7 +562,7 @@ def spacetime_rapidity_cut(
) -> "BaseStorer":
"""
Apply spacetime rapidity (space-time rapidity) cut to all events and
remove all particles with spacetime rapidity not complying with
remove all particles with spacetime rapidity not complying with
cut_value.
Parameters
Expand Down Expand Up @@ -587,7 +593,7 @@ def spacetime_rapidity_cut(

def multiplicity_cut(
self, cut_value_tuple: Tuple[Union[float, None], Union[float, None]]
) -> "BaseStorer":
) -> "BaseStorer":
"""
Apply multiplicity cut. Remove all events with a multiplicity not
complying with cut_value.
Expand All @@ -596,7 +602,7 @@ def multiplicity_cut(
----------
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.
Returns
Expand Down Expand Up @@ -634,7 +640,7 @@ def spacetime_cut(
Returns
-------
self : BaseStorer object
Containing only particles complying with the spacetime cut for all
Containing only particles complying with the spacetime cut for all
events
"""
self.particle_list_ = spacetime_cut(
Expand Down Expand Up @@ -662,7 +668,7 @@ def particle_status(
Returns
-------
self : BaseStorer object
Containing only hadrons with status specified by
Containing only hadrons with status specified by
:code:`status_list` for every event
"""
self.particle_list_ = particle_status(self.particle_list_, status_list)
Expand Down Expand Up @@ -848,9 +854,13 @@ def _update_num_output_per_event_after_filter(self) -> None:
self.num_output_per_event_[1] = len(self.particle_list_[0])
elif self.num_output_per_event_.ndim == 2:
# Handle the case where num_output_per_event_ is a two-dimensional array
updated_num_output_per_event : np.ndarray = np.ndarray((len(self.particle_list_),2), dtype=int)
updated_num_output_per_event: np.ndarray = np.ndarray(
(len(self.particle_list_), 2), dtype=int
)
for event in range(len(self.particle_list_)):
updated_num_output_per_event[event][0] = event + self.num_output_per_event_[0][0]
updated_num_output_per_event[event][0] = (
event + self.num_output_per_event_[0][0]
)
updated_num_output_per_event[event][1] = len(
self.particle_list_[event]
)
Expand All @@ -869,7 +879,7 @@ def print_particle_lists_to_file(self, output_file: str) -> None:
Prints the particle lists to a specified file.
This method should be implemented by subclasses to print the particle
lists to the specified output file. The method raises a
lists to the specified output file. The method raises a
:code:`NotImplementedError` if it is not overridden by a subclass.
Parameters
Expand Down
49 changes: 23 additions & 26 deletions src/sparkx/CentralityClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ class CentralityClasses:
"""
Class for defining centrality classes based on event multiplicity.
.. note::
It is the user's responsibility to ensure that the number of events used
to determine the centrality classes is sufficient to provide reliable
results. At least a few hundred events are recommended.
Parameters
----------
events_multiplicity : list or numpy.ndarray
Expand All @@ -27,7 +33,7 @@ class CentralityClasses:
Raises
------
TypeError
If :code:`events_multiplicity` or :code:`centrality_bins` is not a list
If :code:`events_multiplicity` or :code:`centrality_bins` is not a list
or :code:`numpy.ndarray`.
Attributes
Expand Down Expand Up @@ -58,10 +64,17 @@ class CentralityClasses:
.. code-block:: python
:linenos:
>>> centrality_obj = CentralityClasses(events_multiplicity=[10, 15, 20, 25],
... centrality_bins=[0, 25, 50, 75, 100])
>>> centrality_obj.get_centrality_class(18)
1
>>> from sparkx import *
>>> import random as rd
>>> rd.seed(0)
>>> # generate 300 random event multiplicities between 50 and 1500
>>> events_multiplicity = [rd.randint(50, 1500) for _ in range(300)]
>>> centrality_bins = [0, 10, 30, 50, 70, 90, 100]
>>> centrality_obj = CentralityClasses(events_multiplicity=events_multiplicity,
... centrality_bins=centrality_bins)
>>> centrality_obj.get_centrality_class(1490)
0
>>> centrality_obj.output_centrality_classes('centrality_output.txt')
"""

Expand Down Expand Up @@ -220,6 +233,11 @@ def get_centrality_class(self, dNchdEta: float) -> int:
This function determines the index of the centrality bin for a given
multiplicity value based on the predefined centrality classes.
If the multiplicity input is larger than the largest (or smaller than the
smallest) multiplicity used to determine the centrality classes, the
function returns the index of the most central (or most peripheral) bin,
respectively.
Parameters
----------
dNchdEta : float
Expand All @@ -229,17 +247,6 @@ def get_centrality_class(self, dNchdEta: float) -> int:
-------
int
Index of the centrality bin.
Examples
--------
.. highlight:: python
.. code-block:: python
:linenos:
>>> centrality_obj = CentralityClasses(events_multiplicity=[10, 15, 20, 25],
... centrality_bins=[0, 25, 50, 75, 100])
>>> centrality_obj.get_centrality_class(18)
1
"""
# check if the multiplicity is in the most central bin
if dNchdEta >= self.dNchdetaMin_[0]:
Expand Down Expand Up @@ -271,16 +278,6 @@ def output_centrality_classes(self, fname: str) -> None:
TypeError
If :code:`fname` is not a string.
Examples
--------
.. highlight:: python
.. code-block:: python
:linenos:
>>> centrality_obj = CentralityClasses(events_multiplicity=[10, 15, 20, 25],
... centrality_bins=[0, 25, 50, 75, 100])
>>> centrality_obj.output_centrality_classes('centrality_output.txt')
Notes
-----
This function writes the centrality class information, including minimum,
Expand Down
28 changes: 19 additions & 9 deletions src/sparkx/EventCharacteristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ def generate_eBQS_densities_Milne_from_OSCAR_IC(
eta_range : list, tuple
A list containing the minimum and maximum values of spacetime
rapidity (eta) and the number of grid points.
output_filename : str
The name of the output file where the densities will be saved.
kernel : str
kernel : str
The type of kernel to use for smearing the particle data. Supported
values are 'gaussian' and 'covariant'. The default is "gaussian".
IC_info : str
A string containing info about the initial condition, e.g.,
collision energy or centrality.
Expand Down Expand Up @@ -405,8 +405,13 @@ def generate_eBQS_densities_Milne_from_OSCAR_IC(
)
if not isinstance(self.event_data_, list):
raise TypeError("The input is not a list.")
if not isinstance(kernel, str) or kernel not in {"gaussian", "covariant"}:
raise TypeError("Kernel must be a string and must be either 'covariant' or 'gaussian'.")
if not isinstance(kernel, str) or kernel not in {
"gaussian",
"covariant",
}:
raise TypeError(
"Kernel must be a string and must be either 'covariant' or 'gaussian'."
)

energy_density = Lattice3D(
x_min,
Expand Down Expand Up @@ -608,8 +613,8 @@ def generate_eBQS_densities_Minkowski_from_OSCAR_IC(
output_filename : str
The name of the output file where the densities will be saved.
kernel : str
kernel : str
The type of kernel to use for smearing the particle data. Supported
values are 'gaussian' and 'covariant'. The default is "gaussian".
Expand Down Expand Up @@ -651,8 +656,13 @@ def generate_eBQS_densities_Minkowski_from_OSCAR_IC(
)
if not isinstance(self.event_data_, list):
raise TypeError("The input is not a list.")
if not isinstance(kernel, str) or kernel not in {"gaussian", "covariant"}:
raise TypeError("Kernel must be a string and must be either 'covariant' or 'gaussian'.")
if not isinstance(kernel, str) or kernel not in {
"gaussian",
"covariant",
}:
raise TypeError(
"Kernel must be a string and must be either 'covariant' or 'gaussian'."
)

energy_density = Lattice3D(
x_min,
Expand Down
8 changes: 4 additions & 4 deletions src/sparkx/Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def lower_event_energy_cut(
List with lists containing particle objects for the events
minimum_event_energy : int or float
The minimum event energy threshold. Should be a positive integer or
The minimum event energy threshold. Should be a positive integer or
float.
Returns
Expand Down Expand Up @@ -680,7 +680,7 @@ def pseudorapidity_cut(
cut_value : float
If a single value is passed, the cut is applied symmetrically
around 0.
For example, if :code:`cut_value = 1`, only particles with
For example, if :code:`cut_value = 1`, only particles with
pseudo-rapidity in [-1.0, 1.0] are kept.
cut_value : tuple
Expand Down Expand Up @@ -752,7 +752,7 @@ def spacetime_rapidity_cut(
cut_value : float
If a single value is passed, the cut is applied symmetrically
around 0.
For example, if :code:`cut_value = 1`, only particles with spacetime
For example, if :code:`cut_value = 1`, only particles with spacetime
rapidity in [-1.0, 1.0] are kept.
cut_value : tuple
Expand Down Expand Up @@ -822,7 +822,7 @@ def multiplicity_cut(
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.
Returns
Expand Down
2 changes: 1 addition & 1 deletion src/sparkx/Histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def add_value(
"""
Add value(s) to the latest histogram.
Different cases, if there is just one number added or a whole
Different cases, if there is just one number added or a whole
list/array of numbers.
Parameters
Expand Down
Loading

0 comments on commit 885b942

Please sign in to comment.