Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow gpu support #742

Draft
wants to merge 36 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d5d16af
added args to select main lobe.
aewallwi Sep 4, 2021
94a9a0f
add mainlobe radius filtering mode.
aewallwi Sep 4, 2021
d2c6285
add mainlobe radius arg.
aewallwi Sep 4, 2021
bb2502d
style fix.
aewallwi Sep 4, 2021
f48c203
pass mainlobe arg
aewallwi Sep 4, 2021
e02cb97
Merge branch 'master' into select_main_lobe
aewallwi Sep 4, 2021
c1dc372
Merge branch 'master' into select_main_lobe
aewallwi Sep 8, 2021
d18aebb
beam based frfilter infrastructure.
aewallwi Sep 13, 2021
05ca60e
fix up things breaking existing unittests.
aewallwi Sep 13, 2021
219ff6d
make sure to install tophat_frfilter_run.py.
aewallwi Sep 13, 2021
3390185
get rid of blt read in chunker.
aewallwi Sep 14, 2021
f0bdda3
import uvbeam!
aewallwi Sep 14, 2021
ae1cf5a
us to_filter instead of skip_autos.
aewallwi Sep 14, 2021
e134fdd
ad frfilter test data.
aewallwi Sep 14, 2021
f115958
replace with beam written by more up to date astropy version.
aewallwi Sep 14, 2021
207c54b
unittests for mainlobe fr filter.
aewallwi Sep 15, 2021
9d69bf9
style fixes.
aewallwi Sep 15, 2021
c9310d9
fix binning bug.
aewallwi Sep 15, 2021
26899b1
trim functionality.
aewallwi Sep 16, 2021
50b4cd3
udate script.
aewallwi Sep 16, 2021
660e3ed
fix unittests.
aewallwi Sep 16, 2021
0005713
find and replace mainlobe radius.
aewallwi Sep 16, 2021
09a4c19
add verbose args
aewallwi Sep 17, 2021
eb1ea2a
standardize units and make sure to phasor whatever input_data is spec…
aewallwi Sep 17, 2021
223df1a
broadcast phasor.
aewallwi Sep 17, 2021
bad5dc1
make sure to get values.
aewallwi Sep 17, 2021
8f8681f
hard code sidereal day.
aewallwi Sep 17, 2021
17a1341
use const keyword
aewallwi Sep 17, 2021
1080e8c
use pre-computed fr-profiles if you wish.
aewallwi Sep 17, 2021
becb48c
wrong argparser!
aewallwi Sep 18, 2021
1176e12
fix freq_skip arg.
aewallwi Sep 18, 2021
c599be2
fix kwarg.
aewallwi Sep 20, 2021
c6219da
tensorflow arg and gpu support
aewallwi Sep 21, 2021
f9edc39
add to ci deps.
aewallwi Sep 21, 2021
65fdbe5
add imports.
aewallwi Sep 28, 2021
3cbcc3e
Merge branch 'main' into tensorflow_gpu_support
aewallwi Sep 28, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Run Tests

on:
pull_request:
branches: [ master ]
branches: [ main ]
push:
branches: [ master ]
branches: [ main ]

jobs:
tests:
Expand All @@ -25,7 +25,7 @@ jobs:
fail-fast: false

steps:
- uses: actions/checkout@master
- uses: actions/checkout@main
with:
fetch-depth: 0

Expand Down
2 changes: 1 addition & 1 deletion ci/hera_cal_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies:

- pip:
#- git+https://github.com/RadioAstronomySoftwareGroup/pyuvdata
- git+https://github.com/HERA-Team/uvtools
- git+https://github.com/HERA-Team/uvtools@gpu_support
- git+https://github.com/HERA-Team/linsolve
- git+https://github.com/HERA-Team/hera_qm
- git+https://github.com/RadioAstronomySoftwareGroup/pyuvsim
Expand Down
2 changes: 1 addition & 1 deletion hera_cal/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def chunk_files(filenames, inputfile, outputfile, chunk_size, type="data",
polarizations = chunked_files.pols
if spw_range is None:
spw_range = (0, chunked_files.Nfreqs)
data, flags, nsamples = chunked_files.read(axis='blt', polarizations=polarizations,
data, flags, nsamples = chunked_files.read(polarizations=polarizations,
freq_chans=range(spw_range[0], spw_range[1]))
elif type == 'gains':
chunked_files.read()
Expand Down
Binary file added hera_cal/data/fr_unittest_beam.beamfits
Binary file not shown.
Binary file added hera_cal/data/fr_unittest_data.uvh5
Binary file not shown.
478 changes: 450 additions & 28 deletions hera_cal/frf.py

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions hera_cal/lstbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@ def baselines_same_across_nights(data_list):
# check whether baselines are the same across all nights
# by checking that every baseline occurs in data_list the same number times.
same_across_nights = False
baseline_counts = {}
baseline_counts = DataContainer({})
for dlist in data_list:
for k in dlist:
if k in baseline_counts:
baseline_counts[k] += 1
elif utils.reverse_bl(k) in baseline_counts:
baseline_counts[utils.reverse_bl(k)] += 1
else:
baseline_counts[k] = 1
same_across_nights = np.all([baseline_counts[k] == baseline_counts[bl] for bl in baseline_counts])
Expand Down Expand Up @@ -788,7 +786,7 @@ def lst_bin_files(data_files, input_cals=None, dlst=None, verbose=True, ntimes_p
key_baselines.append(key_bl)
reds.append(bl_nightly_dict[j])
bls_to_load.extend(bl_nightly_dict[j])

data, flags, nsamps = hd.read(bls=bls_to_load, times=tarr[tinds])
# if we want to throw away data associated with flagged antennas, throw it away.
if ex_ant_yaml_files is not None:
Expand Down Expand Up @@ -1080,6 +1078,7 @@ def gen_bl_nightly_dicts(hds, bl_error_tol=1.0, include_autos=True, redundant=Fa
for bl in grp:
# store baseline vectors for all data.
blvecs[bl] = hd.antpos[bl[1]] - hd.antpos[bl[0]]
blvecs[bl[::-1]] = hd.antpos[bl[0]] - hd.antpos[bl[1]]
# otherwise, loop through baselines, for each bl_nightly_dict, see if the first
# entry matches (or conjugate matches). If yes, append to that bl_nightly_dict
for grp in reds:
Expand Down
70 changes: 67 additions & 3 deletions hera_cal/tests/test_frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import unittest
from scipy import stats
from scipy import constants
from pyuvdata import UVFlag


from pyuvdata import UVFlag, UVBeam
from .. import utils
from .. import datacontainer, io, frf
from ..data import DATA_PATH

Expand Down Expand Up @@ -423,6 +422,71 @@ def test_load_tophat_frfilter_and_write_multifile(self, tmpdir):
for k in doutput:
assert doutput[k].shape == d[k].shape

def test_get_fringe_rate_limits(self):
    """Check frf.get_fringe_rate_limits against brute-force fringe-rate windows.

    For each baseline in the simulated data, integrate the fringe-rate power
    spectrum and take the 5% / 95% cumulative-power points as a brute-force
    center/half-width, then require the beam-based limits from
    frf.get_fringe_rate_limits to agree to within 0.2 mHz.
    """
    # simulations constructed with the notebook at https://drive.google.com/file/d/1jPPSmL3nqQbp7tTgP77j9KC0802iWyow/view?usp=sharing
    test_beam = os.path.join(DATA_PATH, "fr_unittest_beam.beamfits")
    test_data = os.path.join(DATA_PATH, "fr_unittest_data.uvh5")
    uvd = UVData()
    uvd.read_uvh5(test_data)
    myfrf = frf.FRFilter(uvd)
    myfrf.fft_data(data=myfrf.data, ax='time')
    sim_c_frates = {}
    sim_w_frates = {}
    for bl in myfrf.dfft:
        # cumulative fringe-rate power distribution, normalized to 1.
        csum = np.cumsum(np.abs(myfrf.dfft[bl]) ** 2.)
        csum /= csum.max()
        # fringe rates bracketing the central 90% of the power.
        frlow, frhigh = (myfrf.frates[np.argmin(np.abs(csum - 0.05))], myfrf.frates[np.argmin(np.abs(csum - 0.95))])
        sim_c_frates[bl] = .5 * (frlow + frhigh)
        sim_w_frates[bl] = .5 * np.abs(frlow - frhigh)
        # conjugate baseline: same half-width, negated center.
        sim_w_frates[utils.reverse_bl(bl)] = sim_w_frates[bl]
        sim_c_frates[utils.reverse_bl(bl)] = -sim_c_frates[bl]
    uvb = UVBeam()
    uvb.read_beamfits(test_beam)
    c_frs, w_frs = frf.get_fringe_rate_limits(uvd, uvb)
    # removed unused local `profile_frates = {}` -- it was never populated or read.
    for bl in sim_c_frates:
        assert np.isclose(c_frs[bl], sim_c_frates[bl], atol=0.2, rtol=0.)
        assert np.isclose(w_frs[bl], sim_w_frates[bl], atol=0.2, rtol=0.)

def test_load_tophat_frfilter_and_write_beam_frates(self, tmpdir):
    """Smoke test of beam-based fringe-rate filtering through load_tophat_frfilter_and_write.

    Runs the filter on simulated data with a primary-beam model, then checks
    that the residual power is at most 20% of the input power per baseline.
    Reading the model/filled outputs back verifies they were written intact.
    """
    # simulations constructed with the notebook at https://drive.google.com/file/d/1jPPSmL3nqQbp7tTgP77j9KC0802iWyow/view?usp=sharing
    # load in primary beam model and isotropic noise model of sky.
    test_beam = os.path.join(DATA_PATH, "fr_unittest_beam.beamfits")
    test_data = os.path.join(DATA_PATH, "fr_unittest_data.uvh5")
    tmp_path = tmpdir.strpath
    res_outfilename = os.path.join(tmp_path, 'resid.uvh5')
    CLEAN_outfilename = os.path.join(tmp_path, 'model.uvh5')
    filled_outfilename = os.path.join(tmp_path, 'filled.uvh5')
    # perform cleaning.  (Removed unused local `to_filter = []` -- it was never
    # passed to the filter call.)
    frf.load_tophat_frfilter_and_write(datafile_list=[test_data], uvbeam=test_beam, mode='dpss_leastsq', filled_outfilename=filled_outfilename,
                                       CLEAN_outfilename=CLEAN_outfilename, frate_standoff=0.075,
                                       res_outfilename=res_outfilename, percentile_high=97.5, percentile_low=2.5)
    hd_input = io.HERAData(test_data)
    data, flags, nsamples = hd_input.read()
    hd_resid = io.HERAData(res_outfilename)
    data_r, flags_r, nsamples_r = hd_resid.read()
    # read the model / filled outputs back as a sanity check that they are valid files.
    hd_filled = io.HERAData(filled_outfilename)
    data_f, flags_f, nsamples_f = hd_filled.read()

    for bl in data:
        # filter should remove at least 80% of the per-baseline power.
        assert np.mean(np.abs(data_r[bl]) ** 2.) <= .2 * np.mean(np.abs(data[bl]) ** 2.)

def test_sky_frates_minfrate_and_to_filter(self):
    """Check sky_frates key selection (blkeys) and the min_frate_width floor."""
    filt = frf.FRFilter(os.path.join(DATA_PATH, "PyGSM_Jy_downselect.uvh5"))
    filt.read()
    # run once over every baseline (blkeys=None) and once over a single key.
    for subset in (None, list(filt.data.keys())[:1]):
        centers, widths = frf.sky_frates(filt.hd, min_frate_width=1000, blkeys=subset)
        if subset is None:
            # blkeys=None -> every data key must appear in both outputs.
            for key in filt.data:
                assert key in centers and key in widths
        # an enormous min_frate_width should floor every half-width at 1000.
        for key in centers:
            assert widths[key] == 1000.

def test_load_tophat_frfilter_and_write(self, tmpdir):
tmp_path = tmpdir.strpath
uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
Expand Down
15 changes: 0 additions & 15 deletions hera_cal/tests/test_vis_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,21 +1028,6 @@ def test_neb(self):
n = vis_clean.noise_eq_bandwidth(dspec.gen_window('blackmanharris', 10000))
assert np.isclose(n, 1.9689862471203075)

def test_sky_frates_minfrate_and_to_filter(self):
# test edge frates
V = VisClean(os.path.join(DATA_PATH, "PyGSM_Jy_downselect.uvh5"))
V.read()
for to_filter in [None, list(V.data.keys())[:1]]:
cfrates, wfrates = V.sky_frates(min_frate=1000, to_filter=to_filter)
# to_filter set to None -> all keys should be present.
if to_filter is None:
for k in V.data:
assert k in cfrates
assert k in wfrates
# min_frate = 1000 should set all wfrates to 1000
for k in cfrates:
assert wfrates[k] == 1000.

def test_zeropad(self):
fname = os.path.join(DATA_PATH, "zen.2458043.40141.xx.HH.XRAA.uvh5")
V = VisClean(fname, filetype='uvh5')
Expand Down
87 changes: 28 additions & 59 deletions hera_cal/vis_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .datacontainer import DataContainer
from .utils import echo
from .flag_utils import factorize_flags
import tensorflow as tf


def find_discontinuity_edges(x, xtol=1e-3):
Expand Down Expand Up @@ -793,6 +794,7 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean',
keep_flags=False, clean_flags_in_resid_flags=False,
skip_if_flag_within_edge_distance=0,
flag_model_rms_outliers=False, model_rms_threshold=1.1,
gpu_index=None, gpu_memory_limit=None,
**filter_kwargs):
"""
Generalized fourier filtering wrapper for uvtools.dspec.fourier_filter.
Expand Down Expand Up @@ -981,6 +983,19 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean',
unnecessarily long. If it is too high, clean does a poor job of deconvolving.
'alpha': float, if window is 'tukey', this is its alpha parameter.
"""
gpus = tf.config.list_physical_devices("GPU")
if gpu_index is not None:
# See https://www.tensorflow.org/guide/gpu
if gpus:
if gpu_memory_limit is None:
tf.config.set_visible_devices(gpus[gpu_index], "GPU")
else:
tf.config.set_logical_device_configuration(
gpus[gpu_index], [tf.config.LogicalDeviceConfiguration(memory_limit=gpu_memory_limit * 1024)]
)

logical_gpus = tf.config.list_logical_devices("GPU")
echo(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU", verbose=verbose)
# type checks
if ax == 'both':
if zeropad is None:
Expand Down Expand Up @@ -1113,10 +1128,18 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean',
win = flag_rows_with_flags_within_edge_distance(xp, win, skip_if_flag_within_edge_distance, ax=ax)

mdl, res = np.zeros_like(d), np.zeros_like(d)
mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers,
filter_half_widths=filter_half_widths,
mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt,
**filter_kwargs)
# limit computations to use specified GPU if they were provided.
if gpu_index is not None and gpus:
with tf.device(f"/device:GPU:{gpus[gpu_index].name[-1]}"):
mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers,
filter_half_widths=filter_half_widths,
mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt,
**filter_kwargs)
else:
mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers,
filter_half_widths=filter_half_widths,
mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt,
**filter_kwargs)
# insert back the filtered model if we are skipping flagged edgs.
if skip_flagged_edges:
mdl = restore_flagged_edges(xp, mdl, edges, ax=ax)
Expand Down Expand Up @@ -1412,61 +1435,6 @@ def write_filtered_data(self, res_outfilename=None, CLEAN_outfilename=None, fill
self.write_data(data_out, outfilename, filetype=filetype, overwrite=clobber, flags=flags_out,
add_to_history=add_to_history, extra_attrs=extra_attrs, **kwargs)

def sky_frates(self, to_filter=None, frate_standoff=0.0, frac_frate_sky_max=1.0, min_frate=0.025):
    """Automatically compute sky fringe-rate ranges based on baselines and telescope location.

    Parameters
    ----------
    to_filter: list of antpairpol tuples, optional
        list of antpairpols to generate sky fringe-rate centers and widths for.
        Default is None -> use all keys in self.data.
    frate_standoff: float, optional
        Additional fringe-rate standoff in mHz to add to Omega_E b_{EW} nu/c for fringe-rate inpainting.
        default = 0.0.
    frac_frate_sky_max: float, optional
        fraction of horizon to fringe-rate filter.
        default is 1.0
    min_frate: float, optional
        minimum fringe-rate half-width to filter, regardless of baseline length, in mHz.
        Default is 0.025
    TODO: Look into per-frequency fringe-rate centers and widths (currently uses max freq for broadest frate range).

    Returns
    -------
    center_frates: DataContainer object,
        DataContainer with the center fringe-rate of each baseline in to_filter in units of mHz.
    width_frates: DataContainer object
        DataContainer with the half widths of each fringe-rate window around the center_frates in units of mHz.

    """
    if to_filter is None:
        to_filter = self.data.keys()
    # unit-normalized east-west component of each baseline vector
    # (EW projection sets the maximum attainable fringe rate).
    blcosines = {k: self.blvecs[k[:2]][0] / np.linalg.norm(self.blvecs[k[:2]]) for k in to_filter}
    # maximum fringe-rate amplitude per baseline in mHz:
    # 2 pi |b| nu_max / c per day; 1/(24*3.6) converts cycles/day to mHz
    # (1000 / 86400 = 1 / 86.4 = 1 / (24 * 3.6)).
    # NOTE(review): this uses a 24-hour day rather than the sidereal day -- confirm intended.
    frateamps = {k: 1. / (24. * 3.6) * self.freqs.max() / 3e8 * 2 * np.pi * np.linalg.norm(self.blvecs[k[:2]]) for k in to_filter}
    # set autocorrs (zero-length baselines, whose normalization is nan) to blcosine of 0.0
    for k in blcosines:
        if np.isnan(blcosines[k]):
            blcosines[k] = 0.0
    sinlat = np.sin(np.abs(self.hd.telescope_location_lat_lon_alt[0]))
    max_frates = io.DataContainer({})
    min_frates = io.DataContainer({})
    center_frates = io.DataContainer({})
    width_frates = io.DataContainer({})
    # calculate min/max center fringerates.
    # these depend on the sign of the blcosine.
    for k in to_filter:
        if blcosines[k] >= 0:
            max_frates[k] = frateamps[k] * np.sqrt(sinlat ** 2. + blcosines[k] ** 2. * (1 - sinlat ** 2.))
            min_frates[k] = -frateamps[k] * sinlat
        else:
            min_frates[k] = -frateamps[k] * np.sqrt(sinlat ** 2. + blcosines[k] ** 2. * (1 - sinlat ** 2.))
            max_frates[k] = frateamps[k] * sinlat
        # center = midpoint of the allowed interval; half-width scaled by
        # frac_frate_sky_max and padded by frate_standoff (all in mHz).
        center_frates[k] = (max_frates[k] + min_frates[k]) / 2.
        width_frates[k] = np.abs(max_frates[k] - min_frates[k]) / 2. * frac_frate_sky_max + frate_standoff
        width_frates[k] = np.max([width_frates[k], min_frate])  # Don't allow frates smaller than min_frate
    return center_frates, width_frates

def zeropad_data(self, data, binvals=None, zeropad=0, axis=-1, undo=False):
"""
Iterate through DataContainer "data" and zeropad it inplace.
Expand Down Expand Up @@ -1979,6 +1947,7 @@ def list_of_int_tuples(v):
ap.add_argument("--filter_spw_ranges", default=None, type=list_of_int_tuples, help="List of spw channel selections to filter independently. Two acceptable formats are "
"Ex1: '200~300,500~650' --> [(200, 300), (500, 650), ...] and "
"Ex2: '200 300, 500 650' --> [(200, 300), (500, 650), ...]")
ap.add_argument("--use_tensorflow", default=False, action="store_true", help="If provided, will use tensorflow GPU accelerated methods where possible.")
# Arguments for CLEAN. Not used in linear filtering methods.
clean_options = ap.add_argument_group(title='Options for CLEAN (arguments only used if mode=="clean"!)')
clean_options.add_argument("--window", type=str, default='blackman-harris', help='window function for frequency filtering (default "blackman-harris",\
Expand Down
2 changes: 1 addition & 1 deletion scripts/delay_filter_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@
CLEAN_outfilename=ap.CLEAN_outfilename,
standoff=ap.standoff, horizon=ap.horizon, tol=ap.tol,
skip_wgt=ap.skip_wgt, min_dly=ap.min_dly, zeropad=ap.zeropad,
filter_spw_ranges=ap.filter_spw_ranges,
filter_spw_ranges=ap.filter_spw_ranges, use_tensorflow=use_tensorflow,
clean_flags_in_resid_flags=True, **filter_kwargs)
9 changes: 5 additions & 4 deletions scripts/tophat_frfilter_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
filter_kwargs['skip_contiguous_flags'] = False
filter_kwargs['skip_flagged_edges'] = False
filter_kwargs['flag_model_rms_outliers'] = False
elif ap.mode == 'dpss_leastsq':
elif ap.mode == 'dpss_leastsq' or ap.mode == 'dft_leastsq':
filter_kwargs = {}
avg_red_bllens = True
filter_kwargs['skip_contiguous_flags']=True
filter_kwargs['skip_flagged_edges'] = True
filter_kwargs['max_contiguous_edge_flags'] = 1
filter_kwargs['flag_model_rms_outliers'] = True


if ap.cornerturnfile is not None:
baseline_list = io.baselines_from_filelist_position(filename=ap.cornerturnfile, filelist=ap.datafilelist)
else:
Expand All @@ -60,5 +59,7 @@
overwrite_flags=ap.overwrite_flags, skip_autos=ap.skip_autos,
skip_if_flag_within_edge_distance=ap.skip_if_flag_within_edge_distance,
zeropad=ap.zeropad, tol=ap.tol, skip_wgt=ap.skip_wgt, max_frate_coeffs=ap.max_frate_coeffs,
frac_frate_sky_max=ap.frac_frate_sky_max, frate_standoff=ap.frate_standoff, min_frate=ap.min_frate,
clean_flags_in_resid_flags=True, **filter_kwargs)
frate_standoff=ap.frate_standoff, min_frate_half_width=ap.min_frate_half_width,
frate_width_multiplier=ap.frate_width_multiplier, fr_freq_skip=ap.fr_freq_skip,
uvbeam=ap.uvbeam, percentile_low=ap.percentile_low, percentile_high=ap.percentile_high,
clean_flags_in_resid_flags=True, use_tensorflow=ap.use_tensorflow, **filter_kwargs)
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def package_files(package_dir, subdirectory):
'scripts/smooth_cal_run.py', 'scripts/redcal_run.py',
'scripts/auto_reflection_run.py', 'scripts/noise_from_autos.py',
'scripts/query_ex_ants.py', 'scripts/red_average.py',
'scripts/time_average.py',
'scripts/time_average.py', 'scripts/tophat_frfilter_run.py',
'scripts/time_chunk_from_baseline_chunks_run.py', 'scripts/chunk_files.py', 'scripts/transfer_flags.py',
'scripts/flag_all.py', 'scripts/throw_away_flagged_antennas.py', 'scripts/select_spw_ranges.py'],
'version': version.version,
Expand All @@ -58,7 +58,7 @@ def package_files(package_dir, subdirectory):
'extras_require': {
"all": [
'aipy>=3.0',
'uvtools @ git+git://github.com/HERA-Team/uvtools',
'uvtools @ git+git://github.com/HERA-Team/uvtools@gpu_support',
]
},
'zip_safe': False,
Expand Down