Skip to content

Commit

Permalink
Merge pull request #319 from HERA-Team/use-uvdata-new
Browse files Browse the repository at this point in the history
Update to use numpy=2 and pyuvdata=3
  • Loading branch information
steven-murray authored Aug 5, 2024
2 parents 57cde9b + 2be06c3 commit 4e0f1e7
Show file tree
Hide file tree
Showing 21 changed files with 87 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_suite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-13]
python-version: ["3.10", "3.11", "3.12"]

steps:
Expand Down
40 changes: 23 additions & 17 deletions hera_sim/adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,12 @@ def match_antennas(
for target_ant in target_copy.ant_2_array
]
)
for i, bl in enumerate(target_copy.baseline_array):
ant1, ant2 = target.baseline_to_antnums(bl)
ant1 = target_to_reference_map[ant1]
ant2 = target_to_reference_map[ant2]
newbl = target.antnums_to_baseline(ant1, ant2)
target_copy.baseline_array[i] = newbl

attrs_to_update = tuple()
if relabel_antennas:
Expand Down Expand Up @@ -359,6 +365,9 @@ def match_antennas(
]
)

target_copy._clear_key2ind_cache(target_copy)
target_copy._clear_antpair2ind_cache(target_copy)

# Now update the data... this will be a little messy.
for antpairpol, vis in target.antpairpol_iter():
ant1, ant2, pol = antpairpol
Expand All @@ -378,17 +387,21 @@ def match_antennas(

# Figure out how to slice through the new data array.
blts, conj_blts, pol_inds = target_copy._key2inds(new_antpairpol)
if len(blts) > 0:

if blts is not None:
# The new baseline has the same conjugation as the old one.
this_slice = (blts, slice(None), pol_inds[0])
this_slice = (
blts,
slice(None),
pol_inds[0].start,
)
else: # pragma: no cover
# The new baseline is conjugated relative to the old one.
# Given the handling of the antenna relabeling, this might not actually
# ever be called.
this_slice = (conj_blts, slice(None), pol_inds[1])
vis = vis.conj()
new_antpairpol = new_antpairpol[:2][::-1] + (pol,)

# If we needed to reflect the entire array to find the best match, then
# we need to make sure to conjugate the visibilities since the reflection
# is effectively undone by baseline conjugation.
Expand All @@ -400,13 +413,6 @@ def match_antennas(
target_copy.flag_array[this_slice] = target.get_flags(antpairpol)
target_copy.nsample_array[this_slice] = target.get_nsamples(antpairpol)

# Update the baseline array in case the antenna numbers got jumbled.
old_bl_int = target.antnums_to_baseline(ant1, ant2)
new_bl_int = target.antnums_to_baseline(*new_antpairpol[:2])
target_copy.baseline_array[target_copy.baseline_array == old_bl_int] = (
new_bl_int
)

# Update the uvw array just to be safe.
target_copy.set_uvws_from_antenna_positions()

Expand Down Expand Up @@ -584,9 +590,9 @@ def iswrapped(lsts):
ant1, ant2 = antpair
this_slice = slice(i, None, target.Nbls)
old_blts = target._key2inds(antpair)[0] # As a reference
this_uvw = target.uvw_array[old_blts[0]]
this_baseline = target.baseline_array[old_blts[0]]
this_integration_time = target.integration_time[old_blts[0]]
this_uvw = target.uvw_array[old_blts][0]
this_baseline = target.baseline_array[old_blts][0]
this_integration_time = target.integration_time[old_blts][0]

# Now actually update the metadata.
new_ant_1_array[this_slice] = ant1
Expand Down Expand Up @@ -791,10 +797,10 @@ def rephase_to_reference(
for i, antpair in enumerate(target.get_antpairs()):
ant1, ant2 = antpair
this_slice = slice(i, None, target.Nbls)
old_blt = target._key2inds(antpair)[0][0] # As a reference
this_uvw = target.uvw_array[old_blt]
this_baseline = target.baseline_array[old_blt]
this_integration_time = target.integration_time[old_blt]
old_blts = target._key2inds(antpair)[0] # As a reference
this_uvw = target.uvw_array[old_blts][0]
this_baseline = target.baseline_array[old_blts][0]
this_integration_time = target.integration_time[old_blts][0]

# Update the metadata.
new_ant_1_array[this_slice] = ant1
Expand Down
4 changes: 2 additions & 2 deletions hera_sim/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def _check_config(self):
"""Check and warn if any keys in the configuration are repeated."""
# initialize dictionaries that enumerate the key, value pairs
# in the raw configuration dictionary
counts = {key: 0 for key in self().keys()}
values = {key: [] for key in self().keys()}
counts = dict.fromkeys(self().keys(), 0)
values = {k: [] for k in self().keys()}

# actually do the enumeration
self._recursive_enumerate(counts, values, self._raw_config)
Expand Down
5 changes: 4 additions & 1 deletion hera_sim/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ def empty_uvdata(
# only specify defaults this way for
# things that are *not* season-specific
polarization_array = kwargs.pop("polarization_array", ["xx"])
telescope_location = list(kwargs.pop("telescope_location", HERA_LAT_LON_ALT))
telescope_location = [
float(x) for x in kwargs.pop("telescope_location", HERA_LAT_LON_ALT)
]

telescope_name = kwargs.pop("telescope_name", "hera_sim")
write_files = kwargs.pop("write_files", False)

Expand Down
3 changes: 1 addition & 2 deletions hera_sim/sigchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,12 +1142,11 @@ def build_coupling_matrix(
else:
power_beam = uvbeam.copy()
power_beam.efield_to_power()
if power_beam.freq_interp_kind is None:
power_beam.freq_interp_kind = freq_interp
power_beam = power_beam.interp(
freq_array=freqs * units.GHz.to("Hz"),
new_object=True,
interpolation_function=pixel_interp,
freq_interp_kind=freq_interp,
) # Interpolate to the desired frequencies
power_beam.to_healpix()
power_beam.peak_normalize()
Expand Down
6 changes: 3 additions & 3 deletions hera_sim/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def get(
# First, find out if it needs to be conjugated.
try:
blt_inds = self.data.antpair2ind(ant1, ant2)
if blt_inds.size == 0:
if blt_inds is None:
raise ValueError
conj_data = False
except ValueError:
Expand Down Expand Up @@ -958,7 +958,7 @@ def _iterate_antpair_pols(self):
for ant1, ant2, pol in self.data.get_antpairpols():
blt_inds = self.data.antpair2ind((ant1, ant2))
pol_ind = self.data.get_pols().index(pol)
if blt_inds.size:
if blt_inds is not None:
yield ant1, ant2, pol, blt_inds, pol_ind

def _iteratively_apply(
Expand Down Expand Up @@ -1541,7 +1541,7 @@ def checkpol(pol):

if key is None:
ant1, ant2, pol = None, None, None
elif np.issubdtype(type(key), int):
elif np.issubdtype(type(key), np.integer):
# Figure out if it's an antenna or baseline integer
if key in self.antpos:
ant1, ant2, pol = key, None, None
Expand Down
4 changes: 3 additions & 1 deletion hera_sim/tests/test_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def get_all_baselines(antpairs):
ant_1_array = [antpair[0] for antpair in antpairs]
ant_2_array = [antpair[1] for antpair in antpairs]
return set(
antnums_to_baseline(ant_1_array + ant_2_array, ant_2_array + ant_1_array, 0)
antnums_to_baseline(
ant_1_array + ant_2_array, ant_2_array + ant_1_array, Nants_telescope=0
)
)


Expand Down
2 changes: 1 addition & 1 deletion hera_sim/tests/test_beams.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def run_sim(
polarization_array=pol_array,
x_orientation="east",
)
freqs = np.unique(uvdata.freq_array)
freqs = uvdata.freq_array
ra_dec, flux, spectral_index = sources

# calculate source fluxes for hera_sim
Expand Down
4 changes: 3 additions & 1 deletion hera_sim/tests/test_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,6 @@ def test_thermal_noise_with_phase_wrap(freqs, omega_p, autovis, expectation):
)
with expectation:
vis = noise_sim(lsts=wrapped_lsts, freqs=freqs)
assert np.isclose(np.std(vis), 1 / expected_SNR, rtol=1 / np.sqrt(vis.size))
np.testing.assert_allclose(
np.std(vis), 1 / expected_SNR, rtol=1 / np.sqrt(vis.size)
)
6 changes: 3 additions & 3 deletions hera_sim/tests/test_sigchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def test_mutual_coupling(use_numba):

# Now let's actually mock up the visibilities.
for ai, aj in uvdata.get_antpairs():
if uvdata.antpair2ind(ai, aj).size == 0:
if uvdata.antpair2ind(ai, aj) is None:
continue
ecef_bl = ecef_antpos[aj] - ecef_antpos[ai]
enu_bl = enu_antpos[aj] - enu_antpos[ai]
Expand Down Expand Up @@ -524,8 +524,8 @@ def uvbeam(tmp_path):
# Setup some things needed to mock up the UVBeam object
az = np.linspace(0, 2 * np.pi, 100)
za = np.linspace(np.pi / 2 - np.pi / 6, np.pi / 2 + np.pi / 6, 15)
freqs = np.linspace(99e6, 101e6, 20)[None, :]
data_shape = (2, 1, 2, freqs.size, za.size, az.size)
freqs = np.linspace(99e6, 101e6, 20)
data_shape = (2, 2, freqs.size, za.size, az.size)
basis_shape = (2, 2, za.size, az.size)
az_mesh, za_mesh = np.meshgrid(az, za)

Expand Down
2 changes: 1 addition & 1 deletion hera_sim/tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_nondefault_blt_order_lsts():
start_time=2458120.15,
array_layout=array_layout,
)
sim.data.reorder_blts("baseline", "time")
sim.data.reorder_blts("baseline", minor_order="time")
iswrapped = sim.lsts < sim.lsts[0]
lsts = sim.lsts + np.where(iswrapped, 2 * np.pi, 0)
assert np.all(lsts[1:] > lsts[:-1])
Expand Down
18 changes: 8 additions & 10 deletions hera_sim/tests/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_shapes(uvdata, simulator):
n_side=2**4,
)

assert sim.simulate().shape == (uvdata.Nblts, 1, NFREQ, uvdata.Npols)
assert sim.simulate().shape == (uvdata.Nblts, NFREQ, uvdata.Npols)


@pytest.mark.parametrize("precision, cdtype", [(1, np.complex64), (2, complex)])
Expand Down Expand Up @@ -447,14 +447,12 @@ def test_comparison(simulator, uvdata2, sky_model, beam_model):
.copy()
)

print(v0[0, 0, 0, 0])

v1 = VisibilitySimulation(
data_model=model_data, simulator=simulator(), n_side=2**4
).simulate()

assert v0.shape == v1.shape
print(v0[-9:, 0, 0, :], v1[-9:, 0, 0, :])
print(v0[-9:, 0, :], v1[-9:, 0, :])
np.testing.assert_allclose(v0, v1, rtol=0.05)


Expand Down Expand Up @@ -484,20 +482,20 @@ def test_ordering(uvdata_linear, simulator, order, conj):
sim.uvdata.reorder_blts(order="time", conj_convention="ant1<ant2")

assert np.allclose(
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 1), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(1, 2), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 1), 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(1, 2), 0, 0],
)

assert not np.allclose(sim.uvdata.get_data((0, 1)), sim.uvdata.get_data((0, 3)))

assert not np.allclose(
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 1), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 3), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 1), 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 3), 0, 0],
)

assert not np.allclose(
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 2), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 3), 0, 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 2), 0, 0],
sim.uvdata.data_array[sim.uvdata.antpair2ind(0, 3), 0, 0],
)


Expand Down
6 changes: 4 additions & 2 deletions hera_sim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,15 +675,17 @@ def find_baseline_orientations(
Dictionary mapping antenna pairs ``(ai,aj)`` to baseline orientations.
Orientations are defined on [0,2pi).
"""
groups, baselines = uvutils.get_antenna_redundancies(
groups, baselines = uvutils.redundancy.get_antenna_redundancies(
antenna_numbers, enu_antpos, include_autos=False
)[:2]
antpair2angle = {}
for group, (e, n, _u) in zip(groups, baselines):
angle = Longitude(np.arctan2(n, e) * units.rad).value
conj_angle = Longitude((angle + np.pi) * units.rad).value
for blnum in group:
ai, aj = uvutils.baseline_to_antnums(blnum, antenna_numbers.size)
ai, aj = uvutils.baseline_to_antnums(
blnum, Nants_telescope=antenna_numbers.size
)
antpair2angle[(ai, aj)] = angle
antpair2angle[(aj, ai)] = conj_angle
return antpair2angle
Expand Down
15 changes: 12 additions & 3 deletions hera_sim/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
DEFAULT_FQS = np.linspace(0.1, 0.2, 1024, endpoint=False)


def sim_red_data(reds, gains=None, shape=(10, 10), gain_scatter=0.1):
def sim_red_data(
reds,
gains=None,
shape=(10, 10),
gain_scatter=0.1,
rng: np.random.Generator | None = None,
):
"""
Simulate thermal-noise-free random but redundant (up to gains) visibilities.
Expand Down Expand Up @@ -37,15 +43,18 @@ def sim_red_data(reds, gains=None, shape=(10, 10), gain_scatter=0.1):
"""
from hera_cal.utils import split_bl

if rng is None:
rng = np.random.default_rng()

data, true_vis = {}, {}
ants = sorted({ant for bls in reds for bl in bls for ant in split_bl(bl)})
gains = {} if gains is None else deepcopy(gains)
for ant in ants:
gains[ant] = gains.get(
ant, 1 + gain_scatter * noise.white_noise((1,))
ant, 1 + gain_scatter * noise.white_noise((1,), rng=rng)
) * np.ones(shape, dtype=complex)
for bls in reds:
true_vis[bls[0]] = noise.white_noise(shape)
true_vis[bls[0]] = noise.white_noise(shape, rng=rng)
for bl in bls:
data[bl] = (
true_vis[bls[0]]
Expand Down
6 changes: 5 additions & 1 deletion hera_sim/visibilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def run_vis_sim(args):
blt_inds = np.load(args.compress)

data_model.uvdata._select_by_index(
blt_inds, None, None, "Compressed by redundancy", keep_all_metadata=True
blt_inds=blt_inds,
pol_inds=None,
freq_inds=None,
history_update_string="Compressed by redundancy",
keep_all_metadata=True,
)

logger.info("Done with compression.")
Expand Down
6 changes: 3 additions & 3 deletions hera_sim/visibilities/fftvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def validate(self, data_model: ModelData):
# TODO: the following is extremely slow. If possible, it would be good to
# find a better way to do it.
if any(
len(data_model.uvdata.antpair2ind(ai, aj)) > 0
and len(data_model.uvdata.antpair2ind(aj, ai)) > 0
data_model.uvdata.antpair2ind(ai, aj) is not None
and data_model.uvdata.antpair2ind(aj, ai) is not None
for ai, aj in data_model.uvdata.get_antpairs()
if ai != aj
):
Expand Down Expand Up @@ -313,7 +313,7 @@ def simulate(self, data_model):

logger.info("... re-ordering visibilities...")
self._reorder_vis(
req_pols, data_model.uvdata, visfull[:, 0, i], vis, antpairs, polarized
req_pols, data_model.uvdata, visfull[:, i], vis, antpairs, polarized
)

# Reduce visfull array if in MPI mode
Expand Down
8 changes: 4 additions & 4 deletions hera_sim/visibilities/matvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def validate(self, data_model: ModelData):
# TODO: the following is extremely slow. If possible, it would be good to
# find a better way to do it.
if any(
len(data_model.uvdata.antpair2ind(ai, aj)) > 0
and len(data_model.uvdata.antpair2ind(aj, ai)) > 0
data_model.uvdata.antpair2ind(ai, aj) is not None
and data_model.uvdata.antpair2ind(aj, ai) is not None
for ai, aj in data_model.uvdata.get_antpairs()
if ai != aj
):
Expand Down Expand Up @@ -395,7 +395,7 @@ def simulate(self, data_model):

logger.info("... re-ordering visibilities...")
self._reorder_vis(
req_pols, data_model.uvdata, visfull[:, 0, i], vis, ant_list, polarized
req_pols, data_model.uvdata, visfull[:, i], vis, ant_list, polarized
)

# Reduce visfull array if in MPI mode
Expand Down Expand Up @@ -448,7 +448,7 @@ def _reorder_vis(self, req_pols, uvdata, visfull, vis, ant_list, polarized):

# get all blt indices corresponding to this antpair
indx = uvdata.antpair2ind(antnum1, antnum2)
if len(indx) == 0:
if indx is None:
# maybe we chose the wrong ordering according to the data. Then
# we just conjugate.
indx = uvdata.antpair2ind(antnum2, antnum1)
Expand Down
2 changes: 0 additions & 2 deletions hera_sim/visibilities/pyuvsim_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,4 @@ def simulate(self, data_model: ModelData):
catalog=pyuvsim.simsetup.SkyModelData(data_model.sky_model),
quiet=self.quiet,
)
out_uv.use_current_array_shapes()
data_model.uvdata.use_current_array_shapes()
return out_uv.data_array
Loading

0 comments on commit 4e0f1e7

Please sign in to comment.