Skip to content

Commit

Permalink
Merge pull request #296 from smash-transport/constantin/fix_multiplic…
Browse files Browse the repository at this point in the history
…ity_cut

fix multiplicity_cut() function
  • Loading branch information
NGoetz authored Nov 18, 2024
2 parents ac35cc0 + 95c9bd4 commit cb5a3b8
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 33 deletions.
24 changes: 15 additions & 9 deletions src/sparkx/BaseStorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,16 +585,19 @@ def spacetime_rapidity_cut(

return self

def multiplicity_cut(self, min_multiplicity: int) -> "BaseStorer":
def multiplicity_cut(
self, cut_value_tuple: Tuple[Union[float, None], Union[float, None]]
) -> "BaseStorer":
"""
Apply multiplicity cut. Remove all events with a multiplicity lower
than :code:`min_multiplicity`.
Apply multiplicity cut. Remove all events with a multiplicity not
complying with cut_value.
Parameters
----------
min_multiplicity : int
Lower bound for multiplicity. If the multiplicity of an event is
lower than :code:`min_multiplicity`, this event is discarded.
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.
Returns
-------
Expand All @@ -603,7 +606,7 @@ def multiplicity_cut(self, min_multiplicity: int) -> "BaseStorer":
"""

self.particle_list_ = multiplicity_cut(
self.particle_list_, min_multiplicity
self.particle_list_, cut_value_tuple
)
self._update_num_output_per_event_after_filter()

Expand Down Expand Up @@ -840,16 +843,19 @@ def _update_num_output_per_event_after_filter(self) -> None:
raise ValueError("num_output_per_event_ is not set")
if self.particle_list_ is None:
raise ValueError("particle_list_ is not set")

if self.num_output_per_event_.ndim == 1:
# Handle the case where num_output_per_event_ is a one-dimensional array
self.num_output_per_event_[1] = len(self.particle_list_[0])
elif self.num_output_per_event_.ndim == 2:
# Handle the case where num_output_per_event_ is a two-dimensional array
updated_num_output_per_event : np.ndarray = np.ndarray((len(self.particle_list_),2), dtype=int)
for event in range(len(self.particle_list_)):
self.num_output_per_event_[event][1] = len(
updated_num_output_per_event[event][0] = event + self.num_output_per_event_[0][0]
updated_num_output_per_event[event][1] = len(
self.particle_list_[event]
)
self.num_output_per_event_ = updated_num_output_per_event
self.num_events_ = len(self.particle_list_)
else:
raise ValueError(
"num_output_per_event_ has an unexpected number of dimensions"
Expand Down
46 changes: 35 additions & 11 deletions src/sparkx/Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,35 +809,59 @@ def spacetime_rapidity_cut(


def multiplicity_cut(
particle_list: List[List[Particle]], min_multiplicity: int
particle_list: List[List[Particle]], cut_value_tuple: tuple
) -> List[List[Particle]]:
"""
Apply multiplicity cut. Remove all events with a multiplicity lower
than :code:`min_multiplicity`.
Apply multiplicity cut. Remove all events with a multiplicity not complying
with cut_value_tuple.
Parameters
----------
particle_list:
List with lists containing particle objects for the events
min_multiplicity : int
Lower bound for multiplicity. If the multiplicity of an event is
lower than :code:`min_multiplicity`, this event is discarded.
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.
Returns
-------
list of lists
Filtered list of lists containing particle objects for each event
"""
if not isinstance(min_multiplicity, int):
raise TypeError("Input value for multiplicity cut must be an int")
if min_multiplicity < 0:
raise ValueError("Minimum multiplicity must >= 0")
if not isinstance(cut_value_tuple, tuple):
raise TypeError(
"Input value must be a tuple containing either "
+ "positive numbers or None of length two"
)

__ensure_tuple_is_valid_else_raise_error(cut_value_tuple, allow_none=True)

# Check if the cut limits are positive if they are not None
if (cut_value_tuple[0] is not None and cut_value_tuple[0] < 0) or (
cut_value_tuple[1] is not None and cut_value_tuple[1] < 0
):
raise ValueError("The cut limits must be positive or None")

if cut_value_tuple[0] is None:
lower_cut = float("-inf")
else:
lower_cut = cut_value_tuple[0]

if cut_value_tuple[1] is None:
upper_cut = float("inf")
else:
upper_cut = cut_value_tuple[1]

# Ensure cut values are in the correct order
lim_max = max(upper_cut, lower_cut)
lim_min = min(upper_cut, lower_cut)

idx_keep_event = []
for idx, event_particles in enumerate(particle_list):
multiplicity = len(event_particles)
if multiplicity >= min_multiplicity:
if multiplicity >= lower_cut and multiplicity < upper_cut:
idx_keep_event.append(idx)

particle_list = [particle_list[idx] for idx in idx_keep_event]
Expand Down
10 changes: 6 additions & 4 deletions src/sparkx/Jetscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Jetscape(BaseStorer):
>>> jetscape = Jetscape(JETSCAPE_FILE_PATH)
>>>
>>> pions = jetscape.multiplicity_cut(500).participants().particle_species((211, -211, 111))
>>> pions = jetscape.multiplicity_cut(500, None).participants().particle_species((211, -211, 111))
>>>
>>> # save the pions of all events as nested list
>>> pions_list = pions.particle_list()
Expand All @@ -151,7 +151,7 @@ class Jetscape(BaseStorer):
Let's assume we only want to keep pions in events with a
multiplicity > 500:
>>> jetscape = Jetscape(JETSCAPE_FILE_PATH, filters={'multiplicity_cut':500, 'particle_species':(211, -211, 111)}})
>>> jetscape = Jetscape(JETSCAPE_FILE_PATH, filters={'multiplicity_cut':(500,None), 'particle_species':(211, -211, 111)}})
>>>
>>> # print the pions to a jetscape file
>>> jetscape.print_particle_lists_to_file('./particle_lists.dat')
Expand Down Expand Up @@ -363,13 +363,15 @@ def print_particle_lists_to_file(self, output_file: str) -> None:
raise ValueError("The number of output per event is empty.")
if self.num_events_ is None:
raise ValueError("The number of events is empty.")

# Open the output file with buffered writing (25 MB)
with open(output_file, "w", buffering=25 * 1024 * 1024) as f_out:
f_out.write(header_file)

list_of_particles = self.particle_list()
if self.num_events_ > 1:
if self.num_events_ == 0:
warnings.warn("The number of events is zero.")
elif self.num_events_ > 1:
for i in range(self.num_events_):
event = self.num_output_per_event_[i, 0]
num_out = self.num_output_per_event_[i, 1]
Expand Down
9 changes: 6 additions & 3 deletions src/sparkx/Oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sparkx.Filter import *
import numpy as np
import warnings
from sparkx.loader.OscarLoader import OscarLoader
from sparkx.BaseStorer import BaseStorer
from typing import Any, List, Optional, Union, Dict
Expand Down Expand Up @@ -133,7 +134,7 @@ class Oscar(BaseStorer):
>>> oscar = Oscar(OSCAR_FILE_PATH)
>>>
>>> pions = oscar.multiplicity_cut(500).participants().particle_species((211, -211, 111))
>>> pions = oscar.multiplicity_cut(500, None).participants().particle_species((211, -211, 111))
>>>
>>> # save the pions of all events as nested list
>>> pions_list = pions.particle_list()
Expand All @@ -153,7 +154,7 @@ class Oscar(BaseStorer):
Let's assume we only want to keep participant pions in events with a
multiplicity > 500:
>>> oscar = Oscar(OSCAR_FILE_PATH, filters={'multiplicity_cut':500, 'participants':True, 'particle_species':(211, -211, 111)})
>>> oscar = Oscar(OSCAR_FILE_PATH, filters={'multiplicity_cut':(500,None), 'participants':True, 'particle_species':(211, -211, 111)})
>>>
>>> # print the pions to an oscar file
>>> oscar.print_particle_lists_to_file('./particle_lists.oscar')
Expand Down Expand Up @@ -387,7 +388,9 @@ def print_particle_lists_to_file(self, output_file: str) -> None:
raise ValueError("The number of output per event is empty.")
if self.num_events_ is None:
raise ValueError("The number of events is empty.")
if self.num_events_ > 1:
if self.num_events_ == 0:
warnings.warn("The number of events is zero.")
elif self.num_events_ > 1:
for i in range(self.num_events_):
event = self.num_output_per_event_[i, 0]
num_out = self.num_output_per_event_[i, 1]
Expand Down
14 changes: 8 additions & 6 deletions tests/test_Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,18 +770,20 @@ def particle_list_multiplicity():

def test_multiplicity_cut(particle_list_multiplicity):
# Test cases for valid input
assert multiplicity_cut(particle_list_multiplicity, 7) == [
assert multiplicity_cut(particle_list_multiplicity, (7,None)) == [
particle_list_multiplicity[1]
]
assert multiplicity_cut(particle_list_multiplicity, 11) == []
assert multiplicity_cut(particle_list_multiplicity, (5,10)) == [
particle_list_multiplicity[0]
]
assert multiplicity_cut(particle_list_multiplicity, (11, None)) == []

# Test cases for invalid input
with pytest.raises(TypeError):
multiplicity_cut(particle_list_multiplicity, min_multiplicity=3.5)

with pytest.raises(ValueError):
multiplicity_cut(particle_list_multiplicity, min_multiplicity=-1)
multiplicity_cut(particle_list_multiplicity, cut_value=(-3.5,4))

with pytest.raises(TypeError):
multiplicity_cut(particle_list_multiplicity, cut_value=(0,'a'))

@pytest.fixture
def particle_list_status():
Expand Down
9 changes: 9 additions & 0 deletions tests/test_Jetscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,15 @@ def test_Jetscape_print_with_empty_events(
assert filecmp.cmp(jetscape_file_no_hadrons, output_path)
os.remove(output_path)

def test_Jetscape_print_with_no_events(
jetscape_file_path, output_path
):
jetscape = Jetscape(jetscape_file_path)
jetscape.particle_list_ = [[], [], [], [], []]
jetscape.multiplicity_cut((100000000,None))
with pytest.warns(UserWarning):
jetscape.print_particle_lists_to_file(output_path)
os.remove(output_path)

def test_Jetscape_get_sigmaGen(jetscape_file_path):
jetscape = Jetscape(jetscape_file_path)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_Oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,16 @@ def test_standard_oscar_print(tmp_path, output_path):
os.remove(output_path)


def test_empty_oscar_print(tmp_path, output_path):
tmp_oscar_file = create_temporary_oscar_file(
tmp_path, 5, "Oscar2013", [1, 7, 0, 36, 5]
)
oscar = Oscar(tmp_oscar_file).multiplicity_cut((100000000,None))
with pytest.warns(UserWarning):
oscar.print_particle_lists_to_file(output_path)
os.remove(output_path)


def test_extended_oscar_print(tmp_path, output_path):
tmp_oscar_file = create_temporary_oscar_file(
tmp_path, 5, "Oscar2013Extended", [4, 1, 42, 0, 3]
Expand Down

0 comments on commit cb5a3b8

Please sign in to comment.