Skip to content

Commit

Permalink
Merge pull request #874 from StingraySoftware/covariance_spectrum_fixes
Browse files Browse the repository at this point in the history
Covariance spectrum fixes
  • Loading branch information
matteobachetti authored Feb 3, 2025
2 parents 4e7ceec + ef47943 commit c298e2c
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 118 deletions.
1 change: 1 addition & 0 deletions docs/changes/874.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Set a lower limit to the number of photons in a segment of data used for certain ``VarEnergySpectrum`` subclasses. This avoids, e.g., spurious high covariance measurements in low-count data sets.
8 changes: 5 additions & 3 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,17 @@ def __eq__(self, other_ts):
if not getattr(self, attr, None) == getattr(other_ts, attr, None):
return False
else:
if not np.array_equal(getattr(self, attr, None), getattr(other_ts, attr, None)):
if not np.array_equal(
getattr(self, attr, None), getattr(other_ts, attr, None), equal_nan=True
):
return False

for attr in self.array_attrs():
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr)):
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr), equal_nan=True):
return False

for attr in self.internal_array_attrs():
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr)):
if not np.array_equal(getattr(self, attr), getattr(other_ts, attr), equal_nan=True):
return False

return True
Expand Down
273 changes: 186 additions & 87 deletions stingray/tests/test_varenergyspectrum.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from multiprocessing import Event
import os
import numpy as np
Expand Down Expand Up @@ -89,6 +90,14 @@ def test_no_spectrum_func_raises(self):
with pytest.raises(TypeError):
ref_int = VarEnergySpectrum(self.events, [0.0, 10000], (0.5, 5, 10, "log"), [0.3, 10])

@pytest.mark.parametrize("energy_spec", [2, "a"])
def test_invalid_energy_spec(self, energy_spec):
with pytest.raises(
ValueError,
match=f"Energy specification must be a tuple or a list .input: {energy_spec}.",
):
DummyVarEnergy(self.events, [0.0, 10000], energy_spec=energy_spec)

def test_ref_band_none(self):
events = EventList(
[0.09, 0.21, 0.23, 0.32, 0.4, 0.54], energy=[0, 0, 0, 0, 1, 1], gti=[[0, 0.65]]
Expand Down Expand Up @@ -210,20 +219,16 @@ def test_create_complexcovariance(self):
assert np.all(np.iscomplex(spec.spectrum))

@pytest.mark.parametrize("cross", [True, False])
@pytest.mark.parametrize("kind", ["rms", "cov", "lag"])
def test_empty_subband_cov(self, cross, kind):
def test_empty_subband_lag(self, cross):
ev2 = None
if cross:
ev2 = self.test_ev2_small

if kind == "rms":
func = RmsSpectrum
elif kind == "lag":
func = LagSpectrum
elif kind == "cov":
func = ComplexCovarianceSpectrum

spec = func(
# Note: energy_spec is a list, so it's actually the edges
# of the energy bins. So, the covariance spectrum will be
# calculated in two bands: 0.3-12 keV and 12-15 keV. But
# the 12-15 keV band is empty (see definition of test_ev1_small)
spec = LagSpectrum(
self.test_ev1_small,
freq_interval=[0.00001, 0.1],
energy_spec=[0.3, 12, 15],
Expand All @@ -235,6 +240,51 @@ def test_empty_subband_cov(self, cross, kind):
good = ~np.isnan(spec.spectrum)
assert np.count_nonzero(good) == 1

@pytest.mark.parametrize("cross", [True, False])
@pytest.mark.parametrize("kind", ["rms", "cov"])
def test_empty_subband_cov(self, cross, kind):
ev2 = None
if cross:
ev2 = self.test_ev2_small

if kind == "rms":
func = RmsSpectrum
elif kind == "cov":
func = ComplexCovarianceSpectrum
# Note: energy_spec is a list, so it's actually the edges
# of the energy bins. So, the covariance spectrum will be
# calculated in two bands: 0.3-12 keV and 12-15 keV. But
# the 12-15 keV band is empty (see definition of test_ev1_small)
with pytest.warns(UserWarning, match="Low count rate in the 12-15 subject band"):
spec = func(
self.test_ev1_small,
freq_interval=[0.00001, 0.1],
energy_spec=[0.3, 12, 15],
ref_band=[[0.3, 12]],
bin_time=self.bin_time / 2,
segment_size=200,
events2=ev2,
)
good = ~np.isnan(spec.spectrum)
assert np.count_nonzero(good) == 1

def test_empty_subband_cov_ev2(self):
ev2 = copy.deepcopy(self.test_ev2_small)
# We empty out only the second event list above 5 keV
ev2.filter_energy_range([0.3, 5], inplace=True)

with pytest.warns(UserWarning, match="Low count rate in the 5-12 subject band"):
spec = RmsSpectrum(
self.test_ev1_small,
freq_interval=[0.00001, 0.1],
energy_spec=[0.3, 5, 12],
bin_time=self.bin_time / 2,
segment_size=200,
events2=ev2,
)
good = ~np.isnan(spec.spectrum)
assert np.count_nonzero(good) == 1

@pytest.mark.parametrize("norm", ["frac", "abs"])
def test_correct_rms_values_vs_cross(self, norm):
"""The rms calculated with independent event lists (from the cospectrum)
Expand Down Expand Up @@ -355,6 +405,132 @@ def test_rms_invalid_evlist_warns(self):
assert np.all(np.isnan(rms.spectrum_error))


import abc


class BaseTestIO(abc.ABC):
@property
@abc.abstractmethod
def variant(self):
pass

@classmethod
def setup_class(cls):
if cls.variant == "rms":
cls.func = RmsSpectrum
elif cls.variant == "complcov":
cls.func = ComplexCovarianceSpectrum
elif cls.variant == "cov":
cls.func = CovarianceSpectrum
elif cls.variant == "lag":
cls.func = LagSpectrum
spec = cls.func(energy_spec=[0.3, 12, 15])
spec.freq_interval = [0.1, 0.2]
spec.ref_band = [0.3, 12]
spec.bin_time = 0.01
spec.segment_size = 100
spec.cross = cls.variant == "complcov"
cls.sting_obj = spec

def test_astropy_roundtrip(self):
so = copy.deepcopy(self.sting_obj)
ts = so.to_astropy_table()
new_so = self.func.from_astropy_table(ts)
assert so == new_so

@pytest.mark.skipif("not _HAS_XARRAY")
def test_xarray_roundtrip(self):
so = copy.deepcopy(self.sting_obj)
ts = so.to_xarray()
new_so = self.func.from_xarray(ts)
assert so == new_so

@pytest.mark.skipif("not _HAS_PANDAS")
def test_pandas_roundtrip(self):
so = copy.deepcopy(self.sting_obj)
ts = so.to_pandas()
new_so = self.func.from_pandas(ts)
assert so == new_so

def test_astropy_roundtrip_empty(self):
# Set an attribute to a DummyStingrayObj. It will *not* be saved
so = self.func()
ts = so.to_astropy_table()
new_so = self.func.from_astropy_table(ts)
assert new_so.energy == []
assert so == new_so

@pytest.mark.skipif("not _HAS_XARRAY")
def test_xarray_roundtrip_empty(self):
so = self.func()
ts = so.to_xarray()
new_so = self.func.from_xarray(ts)
assert new_so.energy == []
assert so == new_so

@pytest.mark.skipif("not _HAS_PANDAS")
def test_pandas_roundtrip_empty(self):
so = self.func()
ts = so.to_pandas()
new_so = self.func.from_pandas(ts)
assert new_so.energy == []
assert so == new_so

@pytest.mark.skipif("not _HAS_H5PY")
def test_hdf_roundtrip(self):
so = copy.deepcopy(self.sting_obj)
so.write("dummy.hdf5")
new_so = so.read("dummy.hdf5")
os.unlink("dummy.hdf5")

assert so == new_so

def test_file_roundtrip_fits(self):
so = copy.deepcopy(self.sting_obj)
with pytest.warns(
UserWarning, match=".* output does not serialize the metadata at the moment"
):
so.write("dummy.fits")
new_so = self.func.read("dummy.fits")
os.unlink("dummy.fits")
assert so == new_so

@pytest.mark.parametrize("fmt", ["ascii", "ascii.ecsv"])
def test_file_roundtrip(self, fmt):
so = copy.deepcopy(self.sting_obj)
with pytest.warns(UserWarning, match=".* output does not serialize the metadata"):
so.write(f"dummy.{fmt}", fmt=fmt)
new_so = self.func.read(f"dummy.{fmt}", fmt=fmt)
os.unlink(f"dummy.{fmt}")

assert so == new_so

def test_file_roundtrip_pickle(self):
fmt = "pickle"
so = copy.deepcopy(self.sting_obj)
so.write(f"dummy.{fmt}", fmt=fmt)
new_so = self.func.read(f"dummy.{fmt}", fmt=fmt)
os.unlink(f"dummy.{fmt}")

assert so == new_so


class TestCovarianceIO(BaseTestIO):
variant = "cov"


class TestComplexCovarianceIO(BaseTestIO):
variant = "complcov"


class TestRmsIO(BaseTestIO):
variant = "rms"


class TestLagIO(BaseTestIO):
variant = "lag"


@pytest.mark.slow
class TestLagEnergySpectrum(object):
@classmethod
Expand Down Expand Up @@ -427,80 +603,3 @@ def test_lagspectrum_invalid_warns(self):

assert np.all(np.isnan(lag.spectrum))
assert np.all(np.isnan(lag.spectrum_error))


class TestRoundTrip:
@classmethod
def setup_class(cls):
tstart = 0.0
tend = 100.0
nphot = 1000
alltimes = np.random.uniform(tstart, tend, nphot)
alltimes.sort()
cls.events = EventList(
alltimes, energy=np.random.uniform(0.3, 12, nphot), gti=[[tstart, tend]]
)
cls.vespec = DummyVarEnergy(
cls.events, [0.0, 10000], (0.5, 5, 10, "lin"), [0.3, 10], bin_time=0.1
)
cls.vespec.spectrum = np.zeros_like(cls.vespec.energy)
cls.vespec.spectrum_error = np.zeros_like(cls.vespec.energy)

def _check_equal(self, so, table):
for attr in ["energy", "spectrum", "spectrum_error"]:
assert np.allclose(getattr(so, attr), table[attr])

if hasattr(table, "meta"):
for attr in ["freq_interval"]:
assert getattr(so, attr) == table.meta[attr]
if hasattr(table, "attrs"):
for attr in ["freq_interval"]:
assert getattr(so, attr) == table.attrs[attr]

def test_astropy_export(self):
so = self.vespec
ts = so.to_astropy_table()
self._check_equal(so, ts)
with pytest.raises(NotImplementedError):
so.from_astropy_table(ts)

@pytest.mark.skipif("not _HAS_XARRAY")
def test_xarray_export(self):
so = self.vespec
ts = so.to_xarray()
self._check_equal(so, ts)
with pytest.raises(NotImplementedError):
so.from_xarray(ts)

@pytest.mark.skipif("not _HAS_PANDAS")
def test_pandas_export(self):
so = self.vespec
ts = so.to_pandas()
self._check_equal(so, ts)
with pytest.raises(NotImplementedError):
so.from_pandas(ts)

@pytest.mark.skipif("not _HAS_H5PY")
def test_hdf_export(self):
so = self.vespec
so.write("dummy.hdf5")
new_so = Table.read("dummy.hdf5")
os.unlink("dummy.hdf5")
self._check_equal(so, new_so)

@pytest.mark.parametrize("fmt", ["ascii.ecsv", "ascii", "fits"])
def test_file_export(self, fmt):
so = self.vespec
with pytest.warns(UserWarning, match=".* output does not serialize the metadata"):
so.write("dummy", fmt=fmt)
new_so = Table.read("dummy", format=fmt)
os.unlink("dummy")
self._check_equal(so, new_so)

@pytest.mark.parametrize("fmt", ["pickle"])
def test_file_export_pickle(self, fmt):
so = self.vespec
so.write("dummy", fmt=fmt)
new_so = so.read("dummy", fmt=fmt)
os.unlink("dummy")
self._check_equal(so, new_so.to_astropy_table())
Loading

0 comments on commit c298e2c

Please sign in to comment.