Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix multiplicity_cut() function #296

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions src/sparkx/BaseStorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,16 +585,19 @@ def spacetime_rapidity_cut(

return self

def multiplicity_cut(self, min_multiplicity: int) -> "BaseStorer":
def multiplicity_cut(
self, cut_value_tuple: Tuple[Union[float, None], Union[float, None]]
) -> "BaseStorer":
"""
Apply multiplicity cut. Remove all events with a multiplicity lower
than :code:`min_multiplicity`.
Apply multiplicity cut. Remove all events with a multiplicity not
complying with cut_value.

Parameters
----------
min_multiplicity : int
Lower bound for multiplicity. If the multiplicity of an event is
lower than :code:`min_multiplicity`, this event is discarded.
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.

Returns
-------
Expand All @@ -603,7 +606,7 @@ def multiplicity_cut(self, min_multiplicity: int) -> "BaseStorer":
"""

self.particle_list_ = multiplicity_cut(
self.particle_list_, min_multiplicity
self.particle_list_, cut_value_tuple
)
self._update_num_output_per_event_after_filter()

Expand Down Expand Up @@ -840,16 +843,19 @@ def _update_num_output_per_event_after_filter(self) -> None:
raise ValueError("num_output_per_event_ is not set")
if self.particle_list_ is None:
raise ValueError("particle_list_ is not set")

if self.num_output_per_event_.ndim == 1:
# Handle the case where num_output_per_event_ is a one-dimensional array
self.num_output_per_event_[1] = len(self.particle_list_[0])
elif self.num_output_per_event_.ndim == 2:
# Handle the case where num_output_per_event_ is a two-dimensional array
updated_num_output_per_event : np.ndarray = np.ndarray((len(self.particle_list_),2), dtype=int)
for event in range(len(self.particle_list_)):
self.num_output_per_event_[event][1] = len(
updated_num_output_per_event[event][0] = event + self.num_output_per_event_[0][0]
updated_num_output_per_event[event][1] = len(
self.particle_list_[event]
)
self.num_output_per_event_ = updated_num_output_per_event
self.num_events_ = len(self.particle_list_)
else:
raise ValueError(
"num_output_per_event_ has an unexpected number of dimensions"
Expand Down
46 changes: 35 additions & 11 deletions src/sparkx/Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,35 +809,59 @@ def spacetime_rapidity_cut(


def multiplicity_cut(
particle_list: List[List[Particle]], min_multiplicity: int
particle_list: List[List[Particle]], cut_value_tuple: tuple
) -> List[List[Particle]]:
"""
Apply multiplicity cut. Remove all events with a multiplicity lower
than :code:`min_multiplicity`.
Apply multiplicity cut. Remove all events with a multiplicity not complying
with cut_value_tuple.

Parameters
----------
particle_list:
List with lists containing particle objects for the events

min_multiplicity : int
Lower bound for multiplicity. If the multiplicity of an event is
lower than :code:`min_multiplicity`, this event is discarded.
cut_value_tuple : tuple
Upper and lower bound for multiplicity. If the multiplicity of an event is
not in this range, the event is discarded. The range is inclusive on the
lower bound and exclusive on the upper bound.

Returns
-------
list of lists
Filtered list of lists containing particle objects for each event
"""
if not isinstance(min_multiplicity, int):
raise TypeError("Input value for multiplicity cut must be an int")
if min_multiplicity < 0:
raise ValueError("Minimum multiplicity must >= 0")
if not isinstance(cut_value_tuple, tuple):
raise TypeError(
"Input value must be a tuple containing either "
+ "positive numbers or None of length two"
)

__ensure_tuple_is_valid_else_raise_error(cut_value_tuple, allow_none=True)

# Check if the cut limits are positive if they are not None
if (cut_value_tuple[0] is not None and cut_value_tuple[0] < 0) or (
cut_value_tuple[1] is not None and cut_value_tuple[1] < 0
):
raise ValueError("The cut limits must be positive or None")

if cut_value_tuple[0] is None:
lower_cut = float("-inf")
else:
lower_cut = cut_value_tuple[0]

if cut_value_tuple[1] is None:
upper_cut = float("inf")
else:
upper_cut = cut_value_tuple[1]

# Ensure cut values are in the correct order
lim_max = max(upper_cut, lower_cut)
lim_min = min(upper_cut, lower_cut)

idx_keep_event = []
for idx, event_particles in enumerate(particle_list):
multiplicity = len(event_particles)
if multiplicity >= min_multiplicity:
if multiplicity >= lower_cut and multiplicity < upper_cut:
idx_keep_event.append(idx)

particle_list = [particle_list[idx] for idx in idx_keep_event]
Expand Down
10 changes: 6 additions & 4 deletions src/sparkx/Jetscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Jetscape(BaseStorer):

>>> jetscape = Jetscape(JETSCAPE_FILE_PATH)
>>>
>>> pions = jetscape.multiplicity_cut(500).participants().particle_species((211, -211, 111))
>>> pions = jetscape.multiplicity_cut(500, None).participants().particle_species((211, -211, 111))
>>>
>>> # save the pions of all events as nested list
>>> pions_list = pions.particle_list()
Expand All @@ -151,7 +151,7 @@ class Jetscape(BaseStorer):
Let's assume we only want to keep pions in events with a
multiplicity > 500:

>>> jetscape = Jetscape(JETSCAPE_FILE_PATH, filters={'multiplicity_cut':500, 'particle_species':(211, -211, 111)}})
>>> jetscape = Jetscape(JETSCAPE_FILE_PATH, filters={'multiplicity_cut':(500,None), 'particle_species':(211, -211, 111)}})
>>>
>>> # print the pions to a jetscape file
>>> jetscape.print_particle_lists_to_file('./particle_lists.dat')
Expand Down Expand Up @@ -363,13 +363,15 @@ def print_particle_lists_to_file(self, output_file: str) -> None:
raise ValueError("The number of output per event is empty.")
if self.num_events_ is None:
raise ValueError("The number of events is empty.")

# Open the output file with buffered writing (25 MB)
with open(output_file, "w", buffering=25 * 1024 * 1024) as f_out:
f_out.write(header_file)

list_of_particles = self.particle_list()
if self.num_events_ > 1:
if self.num_events_ == 0:
warnings.warn("The number of events is zero.")
elif self.num_events_ > 1:
for i in range(self.num_events_):
event = self.num_output_per_event_[i, 0]
num_out = self.num_output_per_event_[i, 1]
Expand Down
9 changes: 6 additions & 3 deletions src/sparkx/Oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sparkx.Filter import *
import numpy as np
import warnings
from sparkx.loader.OscarLoader import OscarLoader
from sparkx.BaseStorer import BaseStorer
from typing import Any, List, Optional, Union, Dict
Expand Down Expand Up @@ -133,7 +134,7 @@ class Oscar(BaseStorer):

>>> oscar = Oscar(OSCAR_FILE_PATH)
>>>
>>> pions = oscar.multiplicity_cut(500).participants().particle_species((211, -211, 111))
>>> pions = oscar.multiplicity_cut(500, None).participants().particle_species((211, -211, 111))
>>>
>>> # save the pions of all events as nested list
>>> pions_list = pions.particle_list()
Expand All @@ -153,7 +154,7 @@ class Oscar(BaseStorer):
Let's assume we only want to keep participant pions in events with a
multiplicity > 500:

>>> oscar = Oscar(OSCAR_FILE_PATH, filters={'multiplicity_cut':500, 'participants':True, 'particle_species':(211, -211, 111)})
>>> oscar = Oscar(OSCAR_FILE_PATH, filters={'multiplicity_cut':(500,None), 'participants':True, 'particle_species':(211, -211, 111)})
>>>
>>> # print the pions to an oscar file
>>> oscar.print_particle_lists_to_file('./particle_lists.oscar')
Expand Down Expand Up @@ -387,7 +388,9 @@ def print_particle_lists_to_file(self, output_file: str) -> None:
raise ValueError("The number of output per event is empty.")
if self.num_events_ is None:
raise ValueError("The number of events is empty.")
if self.num_events_ > 1:
if self.num_events_ == 0:
NGoetz marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn("The number of events is zero.")
LucasConstantin marked this conversation as resolved.
Show resolved Hide resolved
elif self.num_events_ > 1:
for i in range(self.num_events_):
event = self.num_output_per_event_[i, 0]
num_out = self.num_output_per_event_[i, 1]
Expand Down
14 changes: 8 additions & 6 deletions tests/test_Filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,18 +770,20 @@ def particle_list_multiplicity():

def test_multiplicity_cut(particle_list_multiplicity):
# Test cases for valid input
assert multiplicity_cut(particle_list_multiplicity, 7) == [
assert multiplicity_cut(particle_list_multiplicity, (7,None)) == [
particle_list_multiplicity[1]
]
assert multiplicity_cut(particle_list_multiplicity, 11) == []
assert multiplicity_cut(particle_list_multiplicity, (5,10)) == [
particle_list_multiplicity[0]
]
assert multiplicity_cut(particle_list_multiplicity, (11, None)) == []

# Test cases for invalid input
with pytest.raises(TypeError):
multiplicity_cut(particle_list_multiplicity, min_multiplicity=3.5)

with pytest.raises(ValueError):
multiplicity_cut(particle_list_multiplicity, min_multiplicity=-1)
multiplicity_cut(particle_list_multiplicity, cut_value=(-3.5,4))

with pytest.raises(TypeError):
multiplicity_cut(particle_list_multiplicity, cut_value=(0,'a'))

NGoetz marked this conversation as resolved.
Show resolved Hide resolved
@pytest.fixture
def particle_list_status():
Expand Down
9 changes: 9 additions & 0 deletions tests/test_Jetscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,15 @@ def test_Jetscape_print_with_empty_events(
assert filecmp.cmp(jetscape_file_no_hadrons, output_path)
os.remove(output_path)

def test_Jetscape_print_with_no_events(
jetscape_file_path, output_path
):
jetscape = Jetscape(jetscape_file_path)
jetscape.particle_list_ = [[], [], [], [], []]
jetscape.multiplicity_cut((100000000,None))
with pytest.warns(UserWarning):
jetscape.print_particle_lists_to_file(output_path)
os.remove(output_path)

def test_Jetscape_get_sigmaGen(jetscape_file_path):
jetscape = Jetscape(jetscape_file_path)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_Oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,16 @@ def test_standard_oscar_print(tmp_path, output_path):
os.remove(output_path)


def test_empty_oscar_print(tmp_path, output_path):
LucasConstantin marked this conversation as resolved.
Show resolved Hide resolved
tmp_oscar_file = create_temporary_oscar_file(
tmp_path, 5, "Oscar2013", [1, 7, 0, 36, 5]
)
oscar = Oscar(tmp_oscar_file).multiplicity_cut((100000000,None))
with pytest.warns(UserWarning):
oscar.print_particle_lists_to_file(output_path)
os.remove(output_path)


def test_extended_oscar_print(tmp_path, output_path):
tmp_oscar_file = create_temporary_oscar_file(
tmp_path, 5, "Oscar2013Extended", [4, 1, 42, 0, 3]
Expand Down
Loading