Skip to content

Commit

Permalink
tests: last of the rng fixes?
Browse files Browse the repository at this point in the history
  • Loading branch information
r-pascua committed Apr 24, 2024
1 parent 7e0beac commit f21b3a1
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 26 deletions.
3 changes: 2 additions & 1 deletion hera_sim/tests/test_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def test_match_subarray():

def test_match_translated_array():
# A simple translation should just be undone
translation = np.random.uniform(-1, 1, 3)
rng = np.random.default_rng(0)
translation = rng.uniform(-1, 1, 3)
array_1 = {0: [0, 0, 0], 1: [1, 0, 0], 2: [1, 1, 0]}
array_2 = {ant: np.array(pos) - translation for ant, pos in array_1.items()}
# Won't be an exact match to machine precision, so need some small tolerance.
Expand Down
7 changes: 4 additions & 3 deletions hera_sim/tests/test_compare_pyuvsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def get_sky_model(uvdata, nsource):
sources = [
[125.7, -30.72, 2, 0], # Fix a single source near zenith
]
rng = np.random.default_rng(0)
if nsource > 1: # Add random other sources
ra = np.random.uniform(low=0.0, high=360.0, size=nsource - 1)
dec = -30.72 + np.random.random(nsource - 1) * 10.0
flux = np.random.random(nsource - 1) * 4
ra = rng.uniform(low=0.0, high=360.0, size=nsource - 1)
dec = -30.72 + rng.random(nsource - 1) * 10.0
flux = rng.random(nsource - 1) * 4
for i in range(nsource - 1):
sources.append([ra[i], dec[i], flux[i], 0])
sources = np.array(sources)
Expand Down
4 changes: 2 additions & 2 deletions hera_sim/tests/test_foregrounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def test_foreground_conjugation(freqs, lsts, Tsky_mdl, omega_p, model):

conj_kwargs = kwargs.copy()
conj_kwargs["bl_vec"] = -bl_vec
vis = model(**kwargs)
conj_vis = model(**conj_kwargs)
vis = model(**kwargs, rng=np.random.default_rng(0))
conj_vis = model(**conj_kwargs, rng=np.random.default_rng(0))
assert np.allclose(vis, conj_vis.conj()) # Assert V_ij = V*_ji


Expand Down
2 changes: 2 additions & 0 deletions hera_sim/tests/test_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ def test_thermal_noise_with_phase_wrap(freqs, omega_p, autovis, expectation):
channel_width = np.mean(np.diff(freqs)) * units.GHz.to("Hz")
expected_SNR = np.sqrt(integration_time * channel_width)
Trx = 0
rng = np.random.default_rng(0)
if autovis is not None:
autovis = np.ones((wrapped_lsts.size, freqs.size), dtype=complex)
noise_sim = noise.ThermalNoise(
Tsky_mdl=noise.HERA_Tsky_mdl["xx"],
omega_p=omega_p,
Trx=Trx,
autovis=autovis,
rng=rng,
)
with expectation:
vis = noise_sim(lsts=wrapped_lsts, freqs=freqs)
Expand Down
6 changes: 4 additions & 2 deletions hera_sim/tests/test_rfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def lsts():
@pytest.mark.parametrize("station_freq", [0.150, 0.1505])
def test_rfi_station_strength(freqs, lsts, station_freq):
# Generate RFI for a single station.
station = rfi.RfiStation(station_freq, std=0.0)
station = rfi.RfiStation(station_freq, std=0.0, rng=np.random.default_rng(0))
rfi_vis = station(lsts, freqs)

# Check that the RFI is inserted where it should be at the correct level.
Expand Down Expand Up @@ -48,7 +48,9 @@ def test_rfi_station_from_file(freqs, lsts):
filename = DATA_PATH / "HERA_H1C_RFI_STATIONS.npy"
station_params = np.load(filename)
Nstations = station_params.shape[0]
rfi_vis = rfi.rfi_stations(lsts, freqs, stations=filename)
rfi_vis = rfi.rfi_stations(
lsts, freqs, stations=filename, rng=np.random.default_rng(0)
)
assert np.sum(np.sum(np.abs(rfi_vis), axis=0).astype(bool)) >= Nstations


Expand Down
16 changes: 12 additions & 4 deletions hera_sim/tests/test_sigchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_gen_bandpass():
assert 2 in g
assert g[1].size == fqs.size
assert np.all(g[1] == g[2])
g = sigchain.gen_bandpass(fqs, list(range(10)), 0.2)
g = sigchain.gen_bandpass(fqs, list(range(10)), 0.2, rng=np.random.default_rng(0))
assert not np.all(g[1] == g[2])


Expand Down Expand Up @@ -221,7 +221,7 @@ def test_reflection_spectrum():
dlys = np.arange(-1000, 1001, 5)
fqs = uvtools.utils.fourier_freqs(dlys)
fqs += 0.1 - fqs.min() # Range from 100 MHz to whatever the upper bound is
reflections = reflections(fqs, range(100))
reflections = reflections(fqs, range(100), rng=np.random.default_rng(0))
reflections = np.vstack(list(reflections.values()))
spectra = np.abs(uvtools.utils.FFT(reflections, axis=1))
spectra = spectra / spectra.max(axis=1).reshape(-1, 1)
Expand Down Expand Up @@ -270,8 +270,9 @@ def test_amp_jitter():
ants = range(10000)
amp = 5
amp_jitter = 0.1
rng = np.random.default_rng(0)
jittered_amps = sigchain.Reflections._complete_params(
ants, amp=amp, amp_jitter=amp_jitter
ants, amp=amp, amp_jitter=amp_jitter, rng=rng
)[0]
assert np.isclose(jittered_amps.mean(), amp, rtol=0.05)
assert np.isclose(jittered_amps.std(), amp * amp_jitter, rtol=0.05)
Expand All @@ -281,8 +282,9 @@ def test_dly_jitter():
ants = range(10000)
dly = 500
dly_jitter = 20
rng = np.random.default_rng(0)
jittered_dlys = sigchain.Reflections._complete_params(
ants, dly=dly, dly_jitter=dly_jitter
ants, dly=dly, dly_jitter=dly_jitter, rng=rng
)[1]
assert np.isclose(jittered_dlys.mean(), dly, rtol=0.05)
assert np.isclose(jittered_dlys.std(), dly_jitter, rtol=0.05)
Expand All @@ -297,6 +299,7 @@ def test_cross_coupling_spectrum(fqs, dlys, Tsky):
amp_range=amp_range,
dly_range=dly_range,
symmetrize=True,
rng=np.random.default_rng(0),
)
amplitudes = np.logspace(*amp_range, n_copies)
delays = np.linspace(*dly_range, n_copies)
Expand Down Expand Up @@ -352,6 +355,7 @@ def test_over_air_cross_coupling(Tsky_mdl, lsts):
cable_delays=cable_delays,
max_delay=max_delay,
amp_decay_fac=amp_decay_fac,
rng=np.random.default_rng(0),
)
xtalk = gen_xtalk(fqs, (0, 1), antpos, Tsky, Tsky)
xt_fft = uvtools.utils.FFT(xtalk, axis=1, taper="bh7")
Expand Down Expand Up @@ -740,12 +744,14 @@ def test_vary_gain_amp_sinusoidal(gains, times, fringe_rates, fringe_keys):

def test_vary_gain_amp_noiselike(gains, times):
vary_amp = 0.1
rng = np.random.default_rng(0)
varied_gain = sigchain.vary_gains_in_time(
gains=gains,
times=times,
parameter="amp",
variation_mode="noiselike",
variation_amp=vary_amp,
rng=rng,
)[0]

# Check that the mean gain amplitude is the original gain amplitude.
Expand Down Expand Up @@ -817,6 +823,7 @@ def test_vary_gain_phase_noiselike(gains, times, delay_phases, phase_offsets):
parameter="phs",
variation_mode="noiselike",
variation_amp=vary_amp,
rng=np.random.default_rng(0),
)[0]

varied_phases = np.angle(varied_gain)
Expand Down Expand Up @@ -894,6 +901,7 @@ def test_vary_gain_delay_noiselike(gains, times, freqs, delays):
parameter="dly",
variation_amp=vary_amp,
variation_mode="noiselike",
rng=np.random.default_rng(0),
)[0]

# Determine the bandpass delay at each time.
Expand Down
34 changes: 24 additions & 10 deletions hera_sim/tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,27 @@ def test_nondefault_blt_order_lsts():


def test_add_with_str(base_sim):
base_sim.add("noiselike_eor")
base_sim.add("noiselike_eor", rng=np.random.default_rng(0))
assert not np.all(base_sim.data.data_array == 0)


def test_add_with_builtin_class(base_sim):
base_sim.add(DiffuseForeground, Tsky_mdl=Tsky_mdl, omega_p=omega_p)
base_sim.add(
DiffuseForeground,
Tsky_mdl=Tsky_mdl,
omega_p=omega_p,
rng=np.random.default_rng(0),
)
assert not np.all(np.isclose(base_sim.data.data_array, 0))


def test_add_with_class_instance(base_sim):
base_sim.add(diffuse_foreground, Tsky_mdl=Tsky_mdl, omega_p=omega_p)
base_sim.add(
diffuse_foreground,
Tsky_mdl=Tsky_mdl,
omega_p=omega_p,
rng=np.random.default_rng(0),
)
assert not np.all(np.isclose(base_sim.data.data_array, 0))


Expand Down Expand Up @@ -301,7 +311,9 @@ def test_get_multiplicative_effect(base_sim, pol, ant1):


def test_not_add_vis(base_sim):
vis = base_sim.add("noiselike_eor", add_vis=False, ret_vis=True)
vis = base_sim.add(
"noiselike_eor", add_vis=False, ret_vis=True, rng=np.random.default_rng(0)
)

assert np.all(base_sim.data.data_array == 0)

Expand All @@ -315,14 +327,16 @@ def test_not_add_vis(base_sim):


def test_adding_vis_but_also_returning(base_sim):
vis = base_sim.add("noiselike_eor", ret_vis=True)
vis = base_sim.add("noiselike_eor", ret_vis=True, rng=np.random.default_rng(0))

assert not np.all(vis == 0)
assert np.all(np.isclose(vis, base_sim.data.data_array))

# use season defaults for simplicity
defaults.set("h1c")
vis += base_sim.add("diffuse_foreground", ret_vis=True)
vis += base_sim.add(
"diffuse_foreground", ret_vis=True, rng=np.random.default_rng(90)
)
# deactivate defaults for good measure
defaults.deactivate()
assert np.all(np.isclose(vis, base_sim.data.data_array))
Expand All @@ -334,7 +348,7 @@ def test_filter():
# only add visibilities for the (0,1) baseline
vis_filter = (0, 1, "xx")

sim.add("noiselike_eor", vis_filter=vis_filter)
sim.add("noiselike_eor", vis_filter=vis_filter, rng=np.random.default_rng(10))
assert np.all(sim.data.get_data(0, 0) == 0)
assert np.all(sim.data.get_data(1, 1) == 0)
assert np.all(sim.data.get_data(0, 1) != 0)
Expand Down Expand Up @@ -599,13 +613,13 @@ def test_legacy_funcs(component):

def test_vis_filter_single_pol():
sim = create_sim(polarization_array=["xx", "yy"])
sim.add("noiselike_eor", vis_filter=["xx"])
sim.add("noiselike_eor", vis_filter=["xx"], rng=np.random.default_rng(99))
assert np.all(sim.get_data("xx")) and not np.any(sim.get_data("yy"))


def test_vis_filter_two_pol():
sim = create_sim(polarization_array=["xx", "xy", "yx", "yy"])
sim.add("noiselike_eor", vis_filter=["xx", "yy"])
sim.add("noiselike_eor", vis_filter=["xx", "yy"], rng=np.random.default_rng(5))
assert all(
[
np.all(sim.get_data("xx")),
Expand All @@ -621,7 +635,7 @@ def test_vis_filter_arbitrary_key():
array_layout=hex_array(2, split_core=False, outriggers=0),
polarization_array=["xx", "yy"],
)
sim.add("noiselike_eor", vis_filter=[1, 3, 5, "xx"])
sim.add("noiselike_eor", vis_filter=[1, 3, 5, "xx"], rng=np.random.default_rng(7))
bls = sim.data.get_antpairs()
assert not np.any(sim.get_data("yy"))
assert all(
Expand Down
9 changes: 5 additions & 4 deletions hera_sim/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ def test_rough_filter_noisy_data(freqs, lsts, filter_type):
"fringe_filter_type": "gauss",
"fr_width": 1e-4,
}
rng = np.random.default_rng(0)
for i in range(Nrealizations):
data = utils.gen_white_noise((lsts.size, freqs.size))
data = utils.gen_white_noise((lsts.size, freqs.size), rng=rng)
filtered_data = filt(data, *args, **kwargs)
filtered_data_mean = np.mean(filtered_data)
mean_values[i] = filtered_data_mean.real, filtered_data_mean.imag
Expand Down Expand Up @@ -276,7 +277,7 @@ def test_fringe_filter_custom(freqs, lsts, fringe_rates):
@pytest.mark.parametrize("bl_len_ns", [50, 150])
@pytest.mark.parametrize("fr_width", [1e-4, 3e-4])
def test_rough_fringe_filter_noisy_data(freqs, lsts, fringe_rates, bl_len_ns, fr_width):
data = utils.gen_white_noise((lsts.size, freqs.size))
data = utils.gen_white_noise((lsts.size, freqs.size), rng=np.random.default_rng(0))
max_fringe_rates = utils.calc_max_fringe_rate(freqs, bl_len_ns)
filt_data = utils.rough_fringe_filter(
data, lsts, freqs, bl_len_ns, fringe_filter_type="gauss", fr_width=fr_width
Expand Down Expand Up @@ -333,15 +334,15 @@ def test_gen_white_noise_shape(shape):

@pytest.mark.parametrize("shape", [100, (100, 200)])
def test_gen_white_noise_mean(shape):
noise = utils.gen_white_noise(shape)
noise = utils.gen_white_noise(shape, rng=np.random.default_rng(0))
assert np.allclose(
[noise.mean().real, noise.mean().imag], 0, rtol=0, atol=5 / np.sqrt(noise.size)
)


@pytest.mark.parametrize("shape", [100, (100, 200)])
def test_gen_white_noise_variance(shape):
noise = utils.gen_white_noise(shape)
noise = utils.gen_white_noise(shape, rng=np.random.default_rng(0))
assert np.isclose(np.std(noise), 1, rtol=0, atol=0.1)


Expand Down

0 comments on commit f21b3a1

Please sign in to comment.