diff --git a/ci/hera_cal_tests.yml b/ci/hera_cal_tests.yml
index 6ae7b9333..de9534b28 100644
--- a/ci/hera_cal_tests.yml
+++ b/ci/hera_cal_tests.yml
@@ -23,7 +23,7 @@ dependencies:
   - pip:
     #- git+https://github.com/RadioAstronomySoftwareGroup/pyuvdata
-    - git+https://github.com/HERA-Team/uvtools
+    - git+https://github.com/HERA-Team/uvtools@gpu_support
     - git+https://github.com/HERA-Team/linsolve
     - git+https://github.com/HERA-Team/hera_qm
     - git+https://github.com/RadioAstronomySoftwareGroup/pyuvsim
diff --git a/hera_cal/chunker.py b/hera_cal/chunker.py
index a312338a0..67e171d14 100644
--- a/hera_cal/chunker.py
+++ b/hera_cal/chunker.py
@@ -68,7 +68,7 @@ def chunk_files(filenames, inputfile, outputfile, chunk_size, type="data",
         polarizations = chunked_files.pols
         if spw_range is None:
             spw_range = (0, chunked_files.Nfreqs)
-        data, flags, nsamples = chunked_files.read(axis='blt', polarizations=polarizations,
+        data, flags, nsamples = chunked_files.read(polarizations=polarizations,
                                                    freq_chans=range(spw_range[0], spw_range[1]))
     elif type == 'gains':
         chunked_files.read()
diff --git a/hera_cal/data/fr_unittest_beam.beamfits b/hera_cal/data/fr_unittest_beam.beamfits
new file mode 100644
index 000000000..701e07619
Binary files /dev/null and b/hera_cal/data/fr_unittest_beam.beamfits differ
diff --git a/hera_cal/data/fr_unittest_data.uvh5 b/hera_cal/data/fr_unittest_data.uvh5
new file mode 100644
index 000000000..088fd9180
Binary files /dev/null and b/hera_cal/data/fr_unittest_data.uvh5 differ
diff --git a/hera_cal/delay_filter.py b/hera_cal/delay_filter.py
index 272b0e3a0..9a79e5739 100644
--- a/hera_cal/delay_filter.py
+++ b/hera_cal/delay_filter.py
@@ -212,5 +212,4 @@ def delay_filter_argparser():
     filt_options.add_argument("--horizon", type=float, default=1.0, help='proportionality constant for bl_len where 1.0 (default) is the horizon\
                               (full light travel time)')
     filt_options.add_argument("--min_dly", type=float, default=0.0, help="A minimum delay threshold [ns] used for filtering.")
-    filt_options.add_argument("--skip_if_flag_within_edge_distance", type=int, default=0, help="skip integrations if there is a flag within this integer distance of edge.")
     return a
diff --git a/hera_cal/frf.py b/hera_cal/frf.py
index b8fcf609f..2d35eb342 100644
--- a/hera_cal/frf.py
+++ b/hera_cal/frf.py
@@ -13,14 +13,359 @@
 from .datacontainer import DataContainer
 from .vis_clean import VisClean
-from pyuvdata import UVData, UVFlag
+from pyuvdata import UVData, UVFlag, UVBeam
 import argparse
 from . import io
+from . import vis_clean
 import warnings
+from astropy.coordinates import SkyCoord, EarthLocation, AltAz, ITRS
+import astropy.units as units
+import astropy.constants as const
+import healpy as hp
+from astropy.time import Time
+from pyuvdata import utils as uvutils
+from . import utils
+import copy
+from .utils import echo
+import datetime
+
+
+SPEED_OF_LIGHT = const.c.si.value
+SDAY_KSEC = 86163.93 / 1000.
+
+
+def build_fringe_rate_profiles(uvd, uvb, blkeys=None, normed=True, combine_pols=True, nfr=None, dfr=None,
+                               taper='none', fr_freq_skip=1, verbose=False):
+    """
+    Calculate fringe-rate profiles to either directly apply as an FIR filter or to set a range to filter.
+
+    Parameters
+    ----------
+    uvd: UVData object
+        UVData holding baselines for which we will build fringe-rate profiles.
+    uvb: UVBeam object
+        UVBeam object holding the beam model from which to build the fringe-rate profiles.
+    blkeys: list of antpairpol tuples, optional
+        list of antpairpol tuples of baselines to calculate fringe-rate profiles for.
+        default is None -> use all antpairpols in uvd.
+    normed: bool, optional
+        if True, normalize each profile to sum to unity. Default is True.
+    combine_pols: bool, optional
+        if True, sum the profiles over all polarizations for each antenna pair. Default is True.
+    dfr: float, optional
+        spacing of the fringe-rate grid used to perform the binning,
+        in units of mHz.
+        default is None -> set to 1 / (dtime * Ntimes) of uvd.
+    nfr: int, optional
+        number of points on the fringe-rate grid used to perform the binning.
+        default is None -> set to uvd.Ntimes.
+    taper: str, optional
+        taper expected for power spectrum calculations. Fringe-rates from different
+        frequencies are weighted by the square of this taper.
+    fr_freq_skip: int, optional
+        bin fringe rates from every fr_freq_skip channels.
+        default is 1 -> takes a long time. We recommend setting this to be larger.
+    verbose: bool, optional
+        lots of text output.
+
+    Returns
+    -------
+    fr_grid: np.ndarray
+        length nfr grid of fringe-rates spaced by dfr and centered at zero.
+    profiles: dict object. Maps antpairpol tuples to numpy.ndarray objects
+        holding the sum of the squared beam in all directions falling into
+        each fringe-rate bin.
+    """
+    uvb = copy.deepcopy(uvb)
+    # convert to power and healpix if necessary.
+    if uvb.beam_type == 'efield':
+        uvb.efield_to_power()
+    try:
+        uvb.to_healpix()
+    except ValueError:
+        warnings.warn("UVBeam object already in healpix format...")
+
+    antpos_trf = uvd.antenna_positions  # earth-centered antenna positions.
+    antnums = uvd.antenna_numbers  # antenna numbers.
+
+    lat, lon, alt = uvd.telescope_location_lat_lon_alt_degrees
+    location = EarthLocation(lon=lon * units.deg, lat=lat * units.deg, height=alt * units.m)
+
+    # get topocentric AzEl beam coordinates.
+    npix = uvb.data_array.shape[-1]
+    nside = hp.npix2nside(npix)
+    polar, az = hp.pix2ang(nside=nside, ipix=range(npix))
+    # convert colatitude to altitude.
+    alt = np.pi / 2. - polar
+
+    # Convert AltAz coordinates of UVBeam pixels to ITRS coordinates.
+    obstime = Time(np.median(np.unique(uvd.time_array)), format='jd')
+    altaz = AltAz(obstime=obstime, location=location)
+    itrs = ITRS()
+    # coordinates of beam pixels in the topocentric frame.
+    altaz_coords = SkyCoord(alt=alt * units.rad, az=az * units.rad, frame=altaz)
+    # transform beam pixels from topocentric to ITRS.
+    eq_coords = altaz_coords.transform_to(itrs)
+    # get cartesian xyz unit vectors in the direction of each beam pixel.
+    eq_xyz = np.vstack([eq_coords.x, eq_coords.y, eq_coords.z])
+
+    # generate fringe-rate grid in mHz.
+    dt = SDAY_KSEC * np.median(np.diff(np.unique(uvd.time_array)))  # times in kSec.
+    if nfr is None:
+        nfr = uvd.Ntimes
+    if dfr is None:
+        # if no dfr provided, set to 1 / (ntimes * dt).
+        dfr = 1. / (dt * nfr)
+
+    # build grid.
+    fr_grid = np.arange(-nfr // 2, nfr // 2) * dfr
+
+    # frequency tapering function expected for power spectra,
+    # squared for power-spectrum power preservation.
+    ftaper = dspec.gen_window(taper, uvd.Nfreqs) ** 2.
+    if blkeys is None:
+        blkeys = uvd.get_antpairpols()
+    profiles = {}
+    # keep track of the polarizations for each antenna pair in case we sum over polarizations.
+    ap_blkeys = {}
+    for bl in blkeys:
+        echo("Generating FR-Profile of {} at {}".format(bl, str(datetime.datetime.now())), verbose=verbose)
+        # bin the beam from all frequencies together.
+        # get polarization number.
+        polnum = np.where(uvutils.polstr2num(bl[-1], x_orientation=uvb.x_orientation) == uvb.polarization_array)[0][0]
+        # get baseline vector in equatorial coordinates.
+        ind1 = np.where(antnums == bl[0])[0][0]
+        ind2 = np.where(antnums == bl[1])[0][0]
+        blvec = antpos_trf[ind2] - antpos_trf[ind1]
+        # initialize binned power.
+        # we will bin frate power together for all frequencies, weighted by the taper.
+        binned_power = np.zeros_like(fr_grid)
+        # iterate over each frequency and its ftaper weight.
+        # use linspace to make sure we get the first and last frequencies.
+        chans_to_use = np.linspace(0, uvd.Nfreqs - 1, int(uvd.Nfreqs / fr_freq_skip)).astype(int)
+        for f0, fw in zip(uvd.freq_array[0, chans_to_use], ftaper[chans_to_use]):
+            frates = np.dot(np.cross(np.array([0, 0, 1.]), blvec), eq_xyz) * 2 * np.pi * f0 / SPEED_OF_LIGHT / SDAY_KSEC
+            # square of the power-beam values in the directions of the sky pixels.
+            bsq = np.abs(uvb.data_array[0, 0, polnum, np.argmin(np.abs(f0 - uvb.freq_array[0])), :].squeeze()) ** 2.
+            # set the beam below the horizon to zero.
+            bsq[polar >= np.pi / 2.] = 0.
+            # get fringe-rate bin membership for each pixel.
+            fr_bins = np.round(frates / dfr + nfr / 2).astype(int)
+            # bin power.
+            for binnum in range(nfr):
+                # for each bin, add the sum of beam-squared values over all pixels that fall in that fr bin,
+                # weighted by the frequency taper value.
+                binned_power[binnum] += np.sum(bsq[fr_bins == binnum]) * fw
+        profiles[bl] = binned_power
+        profiles[utils.reverse_bl(bl)] = binned_power[::-1]
+        if bl[:2] not in ap_blkeys:
+            ap_blkeys[bl[:2]] = [bl]
+            ap_blkeys[utils.reverse_bl(bl)[:2]] = [utils.reverse_bl(bl)]
+        else:
+            ap_blkeys[bl[:2]].append(bl)  # append baselines.
+            ap_blkeys[utils.reverse_bl(bl)[:2]].append(utils.reverse_bl(bl))
+
+    # combine polarizations by summing over all profiles for each antenna pair.
+    if combine_pols:
+        for ap in ap_blkeys:
+            profile_summed_over_pols = np.sum([profiles[bl] for bl in ap_blkeys[ap]], axis=0)
+            for bl in ap_blkeys[ap]:
+                profiles[bl] = profile_summed_over_pols
+    # normalize if desired.
+    if normed:
+        for bl in profiles:
+            profiles[bl] /= np.sum(profiles[bl])
+    return fr_grid, profiles
+
+
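A minimal usage sketch of the profile builder above (illustrative only; the uvh5 and beamfits
paths are placeholders, while the pyuvdata readers and the hera_cal call are the ones used in
this patch):

    from pyuvdata import UVData, UVBeam
    from hera_cal import frf

    uvd = UVData()
    uvd.read_uvh5("some_data.uvh5")          # hypothetical input visibilities
    uvb = UVBeam()
    uvb.read_beamfits("some_beam.beamfits")  # hypothetical primary-beam model

    # profiles[(a1, a2, pol)] is beam-squared power binned by fringe-rate on fr_grid [mHz];
    # fr_freq_skip > 1 trades accuracy for speed, as the docstring recommends.
    fr_grid, profiles = frf.build_fringe_rate_profiles(uvd, uvb, fr_freq_skip=4)
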
+def get_fringe_rate_limits(uvd, uvb=None, fr_profiles=None, percentile_low=5., percentile_high=95., blkeys=None,
+                           dfr=None, nfr=None, taper='none', frate_standoff=0.0,
+                           frate_width_multiplier=1.0, min_frate_half_width=0.025,
+                           fr_freq_skip=1, verbose=False):
+    """
+    Get bounding fringe-rates for isotropic emission for a UVBeam object across all frequencies.
+
+    Parameters
+    ----------
+    uvd: UVData object
+        UVData holding baselines for which we will build fringe-rate profiles.
+    uvb: UVBeam object, optional
+        UVBeam object holding the beam model from which to build fringe-rate profiles.
+        default is None -> use pre-computed fr_profiles.
+    fr_profiles: dict object, optional
+        Dictionary mapping antpairpol keys to fringe-rate profiles
+        on a grid of length nfr spaced by dfr and centered at 0.
+    percentile_low: float, optional
+        Percent of beam-squared power below the lower fringe-rate limit.
+    percentile_high: float, optional
+        Percent of beam-squared power below the upper fringe-rate limit.
+    blkeys: list of antpairpol tuples, optional
+        list of antpairpol tuples of baselines to calculate fringe-rate limits for.
+        default is None -> use all antpairpols in uvd.
+    dfr: float, optional
+        spacing of the fringe-rate grid used to perform the binning and percentile calc,
+        in units of mHz.
+        default is None -> set to 1 / (dtime * Ntimes) of uvd.
+    nfr: int, optional
+        number of points on the fringe-rate grid used to perform the binning and percentile calc.
+        default is None -> set to uvd.Ntimes.
+    taper: str, optional
+        taper expected for power spectrum calculations. Fringe-rates from different
+        frequencies are weighted by the square of this taper.
+    frate_standoff: float, optional
+        Additional fringe-rate standoff in mHz to add to the filter half-width.
+        default = 0.0.
+    frate_width_multiplier: float, optional
+        fraction of the percentile-based fringe-rate width to filter.
+        default is 1.0.
+    min_frate_half_width: float, optional
+        minimum half-width of the fringe-rate window to filter, regardless of baseline length, in mHz.
+        Default is 0.025.
+    fr_freq_skip: int, optional
+        bin fringe rates from every fr_freq_skip channels.
+        default is 1 -> takes a long time. We recommend setting this to be larger.
+    verbose: bool, optional
+        lots of text output, default is False.
+
+    Returns
+    -------
+    frate_centers: dict object
+        Dictionary with the center fringe-rate of each baseline in blkeys in units of mHz.
+    frate_half_widths: dict object
+        Dictionary with the half-width of each fringe-rate window around frate_centers in units of mHz.
+    """
+    frate_centers = {}
+    frate_half_widths = {}
+    if blkeys is None:
+        blkeys = uvd.get_antpairpols()
+
+    if fr_profiles is None:
+        if uvb is not None:
+            fr_grid, fr_profiles = build_fringe_rate_profiles(uvd=uvd, uvb=uvb, blkeys=blkeys, normed=True, nfr=nfr, dfr=dfr,
+                                                              taper=taper, fr_freq_skip=fr_freq_skip, verbose=verbose)
+        else:
+            raise ValueError("Must either supply uvb or fr_profiles!")
+    else:
+        if nfr is None:
+            nfr = uvd.Ntimes
+        if dfr is None:
+            dfr = 1. / (nfr * np.mean(np.diff(np.unique(uvd.time_array))) * SDAY_KSEC)
+        fr_grid = np.arange(-nfr // 2, nfr // 2) * dfr
+    for bl in blkeys:
+        # normalize to sum to 100 (without mutating the input profiles).
+        binned_power = 100. * fr_profiles[bl] / np.sum(fr_profiles[bl])
+        # get the CDF as a function of fringe-rate bin.
+        cspower = np.cumsum(binned_power)
+        # find the low and high bins bounding the mass between percentile_low and percentile_high.
+        frlow = np.argmin(np.abs(cspower - percentile_low))
+        frhigh = np.argmin(np.abs(cspower - percentile_high))
+        frlow = fr_grid[frlow]
+        frhigh = fr_grid[frhigh]
+        # save the center and half-width for bl and its conjugate.
+        frate_centers[bl] = .5 * (frlow + frhigh)
+        frate_half_widths[bl] = .5 * np.abs(frlow - frhigh) * frate_width_multiplier + frate_standoff
+        frate_half_widths[bl] = np.max([frate_half_widths[bl], min_frate_half_width])
+
+        frate_centers[utils.reverse_bl(bl)] = -frate_centers[bl]
+        frate_half_widths[utils.reverse_bl(bl)] = frate_half_widths[bl]
+
+    return frate_centers, frate_half_widths
+
+
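The percentile step above reduces to reading bounds off a CDF. A self-contained numpy sketch of
the same calculation on a toy Gaussian profile (toy numbers, not HERA values; the multiplier and
standoff terms of the full function are omitted):

    import numpy as np

    nfr, dfr = 100, 0.1                      # 100 bins spaced 0.1 mHz
    fr_grid = np.arange(-nfr // 2, nfr // 2) * dfr
    profile = np.exp(-0.5 * ((fr_grid - 1.0) / 0.5) ** 2)  # toy profile centered at 1 mHz

    cdf = 100. * np.cumsum(profile) / np.sum(profile)
    frlow = fr_grid[np.argmin(np.abs(cdf - 5.))]     # percentile_low
    frhigh = fr_grid[np.argmin(np.abs(cdf - 95.))]   # percentile_high
    center, half_width = .5 * (frlow + frhigh), .5 * np.abs(frhigh - frlow)
    # center ~ 1 mHz and half_width ~ 0.8 mHz (~1.645 sigma) for this toy profile.
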
+def sky_frates(uvd, blkeys=None, frate_standoff=0.0, frate_width_multiplier=1.0, min_frate_half_width=0.025):
+    """Automatically compute sky fringe-rate ranges based on baselines and telescope location.
+
+    Parameters
+    ----------
+    uvd: UVData object
+        uvdata object of data to compute sky-frate limits for.
+    blkeys: list of antpairpol tuples, optional
+        list of antpairpols to generate sky fringe-rate centers and widths for.
+        Default is None -> use all antpairpols in uvd.
+    frate_standoff: float, optional
+        Additional fringe-rate standoff in mHz to add to Omega_E * b_{EW} * nu / c for fringe-rate inpainting.
+        default = 0.0.
+    frate_width_multiplier: float, optional
+        fraction of the horizon fringe-rate range to filter.
+        default is 1.0.
+    min_frate_half_width: float, optional
+        minimum half-width of the fringe-rate window to filter, regardless of baseline length, in mHz.
+        Default is 0.025.
+
+    Returns
+    -------
+    frate_centers: dict object
+        Dictionary with the center fringe-rate of each baseline in blkeys in units of mHz.
+    frate_half_widths: dict object
+        Dictionary with the half-width of each fringe-rate window around frate_centers in units of mHz.
-def timeavg_waterfall(data, Navg, flags=None, nsamples=None, wgt_by_nsample=True,
-                      wgt_by_favg_nsample=False, rephase=False, lsts=None, freqs=None,
+    """
+    if blkeys is None:
+        blkeys = uvd.get_antpairpols()
+    antpos, antnums = uvd.get_ENU_antpos()
+    sinlat = np.sin(np.abs(uvd.telescope_location_lat_lon_alt[0]))
+    frate_centers = {}
+    frate_half_widths = {}
+
+    # compute the maximum sky fringe-rate for each baseline based on its length and orientation.
+    for k in blkeys:
+        ind1 = np.where(antnums == k[0])[0][0]
+        ind2 = np.where(antnums == k[1])[0][0]
+        blvec = antpos[ind1] - antpos[ind2]
+        blcos = blvec[0] / np.linalg.norm(blvec[:2])
+        # blcos is NaN for autocorrelations, which are assigned zero fringe-rate below.
+        if np.isfinite(blcos):
+            frateamp_df = np.linalg.norm(blvec[:2]) / SDAY_KSEC / SPEED_OF_LIGHT * 2 * np.pi
+            if blcos >= 0:
+                max_frate_df = frateamp_df * np.sqrt(sinlat ** 2. + blcos ** 2. * (1 - sinlat ** 2.))
+                min_frate_df = -frateamp_df * sinlat
+            else:
+                min_frate_df = -frateamp_df * np.sqrt(sinlat ** 2. + blcos ** 2. * (1 - sinlat ** 2.))
+                max_frate_df = frateamp_df * sinlat
+
+            min_frate = np.min([f0 * min_frate_df for f0 in uvd.freq_array[0]])
+            max_frate = np.max([f0 * max_frate_df for f0 in uvd.freq_array[0]])
+        else:
+            max_frate = 0.0
+            min_frate = 0.0
+
+        frate_centers[k] = (max_frate + min_frate) / 2.
+        frate_centers[utils.reverse_bl(k)] = -frate_centers[k]
+
+        frate_half_widths[k] = np.abs(max_frate - min_frate) / 2. * frate_width_multiplier + frate_standoff
+        frate_half_widths[k] = np.max([frate_half_widths[k], min_frate_half_width])  # don't allow half-widths smaller than min_frate_half_width.
+        frate_half_widths[utils.reverse_bl(k)] = frate_half_widths[k]
+
+    return frate_centers, frate_half_widths
+
+
+def sky_mainlobe_complement(uvd, uvb, percentile_low=5., percentile_high=95., blkeys=None,
+                            dfr=None, nfr=None, taper='none', frate_standoff=0.0,
+                            frate_width_multiplier=1.0, min_frate_half_width=0.025, side='lower', extension=0.0):
+    """
+    Find the fringe-rates that fall on the sky but outside of the main lobe.
+
+    This method is meant to allow for subtracting overwhelmingly bright emission from outside of the main lobe
+    whose side-lobes might be picked up by a main-lobe filter if it is not first subtracted.
+    It computes main-lobe fringe-rates and sky fringe-rates and then gives the fringe-rate bounds for emission
+    outside of the main lobe but still on the sky.
+
+    Parameters mirror those of get_fringe_rate_limits; `side` selects the 'lower' or 'upper'
+    complement and `extension` pads the returned range in mHz. Not yet implemented.
+    """
+    return
+
+
+def timeavg_waterfall(data, Navg, flags=None, nsamples=None, wgt_by_nsample=True,
+                      wgt_by_favg_nsample=False, rephase=False, lsts=None, freqs=None,
                       bl_vec=None, lat=-30.72152, extra_arrays={}, verbose=True):
     """
     Calculate the time average of a visibility waterfall. The average is optionally
@@ -497,6 +842,143 @@ def filter_data(self, data, frps, flags=None, nsamples=None,
                 filt_flags[k] = f
                 filt_nsamples[k] = eff_nsamples
 
+    def run_tophat_frfilter(self, to_filter=None, weight_dict=None, mode='clean',
+                            uvb=None, percentile_low=5., percentile_high=95.,
+                            frate_standoff=0.0,
+                            frate_width_multiplier=1.0, min_frate_half_width=0.025, max_frate_coeffs=None,
+                            skip_wgt=0.1, tol=1e-9, verbose=False, cache_dir=None, read_cache=False,
+                            write_cache=False, taper='none',
+                            data=None, flags=None, center_before_filtering=True, fr_freq_skip=1,
+                            **filter_kwargs):
+        '''
+        Interpolate / filter data in time using the physical fringe-rates of the sky (or a constant fringe-rate range).
+
+        Arguments:
+            to_filter: list of visibilities to filter in the (i,j,pol) format.
+                If None (the default), all visibilities are filtered.
+            weight_dict: dictionary or DataContainer with all the same keys as self.data.
+                Linear multiplicative weights to use for the fr filter. Default is to use np.logical_not
+                of self.flags. uvtools.dspec.fourier_filter will renormalize to compensate.
+            mode: string specifying filtering mode. See fourier_filter or uvtools.dspec.fourier_filter for supported modes.
+            uvb: UVBeam object, optional
+                UVBeam object with a model of the primary beam, used for determining fringe-rate limits.
+                Supersedes max_frate_coeffs.
+            percentile_low: float, optional
+                Percent of beam-squared power below the lower fringe-rate limit if uvb is provided.
+            percentile_high: float, optional
+                Percent of beam-squared power below the upper fringe-rate limit if uvb is provided.
+            frate_standoff: float, optional
+                Additional fringe-rate standoff in mHz to add to the filter half-width.
+                default = 0.0.
+            frate_width_multiplier: float, optional
+                fraction of the horizon or main-lobe fringe-rate range to filter.
+                default is 1.0.
+            min_frate_half_width: float, optional
+                minimum half-width of the fringe-rate window to filter, regardless of baseline length, in mHz.
+                Default is 0.025.
+            max_frate_coeffs: 2-tuple of floats, optional
+                Maximum fringe-rate coefficients for the model max_frate [mHz] = x1 * EW_bl_len [m] + x2.
+                Providing these overrides the sky-based fringe-rate determination! Default is None.
+            skip_wgt: skips filtering rows with very low total weight (unflagged fraction ~< skip_wgt).
+                Model is left as 0s, residual is left as data, and info is {'skipped': True} for that
+                time. Skipped times are then flagged in self.flags.
+                Only works properly when all weights are between 0 and 1.
+            tol: float, optional. To what level are foregrounds subtracted.
+            verbose: If True print feedback to stdout.
+            cache_dir: string, optional, path to cache file that contains pre-computed dayenu matrices.
+                see uvtools.dspec.dayenu_filter for key formats.
+            read_cache: bool, If true, read existing cache files in cache_dir before running.
+            write_cache: bool. If true, create new cache file with precomputed matrices
+                that were not in previously loaded cache files.
+            taper: str, optional
+                taper expected for power spectrum calculations. Fringe-rates from different
+                frequencies are weighted by the square of this taper if uvb is provided.
+            data: DataContainer, optional
+                data to filter. Default is None -> use self.data.
+            flags: DataContainer, optional
+                flags to use. Default is None -> use self.flags.
+            center_before_filtering: bool, optional
+                shift the data by multiplying by the center fringe-rate and filter a window centered at zero fringe-rate.
+                This improves filter stability when the filtering window is highly offset from zero, since we avoid
+                interpolating with fine-time-scale modes.
+            fr_freq_skip: int, optional
+                bin fringe rates from every fr_freq_skip channels.
+                default is 1 -> takes a long time. We recommend setting this to be larger.
+            filter_kwargs: see fourier_filter for a full list of filter-specific arguments,
+                e.g. cache (dictionary of pre-computed filter products) and skip_flagged_edges
+                (if True, do not include flagged edge times in the filtering region).
+
+        Results are stored in:
+            self.clean_resid: DataContainer formatted like self.data with only high-fringe-rate components
+            self.clean_model: DataContainer formatted like self.data with only low-fringe-rate components
+            self.clean_info: Dictionary of info from uvtools.dspec.fourier_filter with the same keys as self.data
+        '''
+        if to_filter is None:
+            to_filter = list(self.data.keys())
+        # read in cache
+        if not mode == 'clean':
+            if read_cache:
+                filter_cache = io.read_filter_cache_scratch(cache_dir)
+            else:
+                filter_cache = {}
+            keys_before = list(filter_cache.keys())
+        else:
+            filter_cache = None
+        if uvb is None and max_frate_coeffs is None:
+            # if both uvb and max_frate_coeffs are None, fringe-rate filter all modes that could be occupied by sky emission.
+            frate_centers, frate_half_widths = sky_frates(self.hd, blkeys=to_filter, frate_standoff=frate_standoff,
+                                                          frate_width_multiplier=frate_width_multiplier, min_frate_half_width=min_frate_half_width)
+        elif uvb is None and max_frate_coeffs is not None:
+            # if uvb is None and max_frate_coeffs is not None, use max_frate_coeffs.
+            frate_half_widths = io.DataContainer({k: np.max([max_frate_coeffs[0] * self.blvecs[k[:2]][0] + max_frate_coeffs[1], 0.0]) for k in to_filter})
+            frate_centers = io.DataContainer({k: 0.0 for k in to_filter})
+        elif uvb is not None:
+            # if uvb is not None, get fringe-rate limits from binning the beam.
+            frate_centers, frate_half_widths = get_fringe_rate_limits(self.hd, uvb,
+                                                                      percentile_low=percentile_low,
+                                                                      percentile_high=percentile_high,
+                                                                      blkeys=to_filter, verbose=verbose,
+                                                                      frate_standoff=frate_standoff, fr_freq_skip=fr_freq_skip,
+                                                                      frate_width_multiplier=frate_width_multiplier,
+                                                                      min_frate_half_width=min_frate_half_width)
+
+        wgts = io.DataContainer({k: (~self.flags[k]).astype(float) for k in self.flags})
+        for k in to_filter:
+            if mode != 'clean':
+                filter_kwargs['suppression_factors'] = [tol]
+            else:
+                filter_kwargs['tol'] = tol
+            # center sky modes at zero fringe-rate if we chose center_before_filtering.
+            if center_before_filtering:
+                phasor = np.exp(2j * np.pi * self.times * SDAY_KSEC * frate_centers[k])
+                if 'data' in filter_kwargs:
+                    input_data = filter_kwargs['data']
+                else:
+                    input_data = self.data
+                input_data[k] /= phasor[:, None]
+                filter_center_to_use = 0.0
+            else:
+                filter_center_to_use = frate_centers[k]
+            self.fourier_filter(keys=[k], filter_centers=[filter_center_to_use], filter_half_widths=[frate_half_widths[k]],
+                                mode=mode, x=self.times * SDAY_KSEC,
+                                flags=self.flags, wgts=wgts,
+                                ax='time', cache=filter_cache, skip_wgt=skip_wgt, verbose=verbose, **filter_kwargs)
+
+            # undo the centering by multiplying the phasor back in, if we chose center_before_filtering.
+            if center_before_filtering:
+                if 'output_prefix' in filter_kwargs:
+                    filtered_data = getattr(self, filter_kwargs['output_prefix'] + '_data')
+                    filtered_model = getattr(self, filter_kwargs['output_prefix'] + '_model')
+                    filtered_resid = getattr(self, filter_kwargs['output_prefix'] + '_resid')
+                else:
+                    filtered_data = self.clean_data
+                    filtered_model = self.clean_model
+                    filtered_resid = self.clean_resid
+                filtered_data[k] *= phasor[:, None]
+                filtered_model[k] *= phasor[:, None]
+                filtered_resid[k] *= phasor[:, None]
+                if 'data' in filter_kwargs:
+                    input_data = filter_kwargs['data']
+                else:
+                    input_data = self.data
+                input_data[k] *= phasor[:, None]
+        if not mode == 'clean':
+            if write_cache:
+                filter_cache = io.write_filter_cache_scratch(filter_cache, cache_dir, skip_keys=keys_before)
+
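A minimal sketch of calling the method above through the FRFilter class (the file paths are
placeholders; FRFilter, read, and run_tophat_frfilter are defined in this module and in VisClean):

    from pyuvdata import UVBeam
    from hera_cal import frf

    frfil = frf.FRFilter("some_data.uvh5", filetype='uvh5')  # hypothetical input file
    frfil.read()
    uvb = UVBeam()
    uvb.read_beamfits("some_beam.beamfits")                  # hypothetical beam model

    # beam-based fringe-rate limits with a small standoff; results land in
    # frfil.clean_model / frfil.clean_resid / frfil.clean_info.
    frfil.run_tophat_frfilter(uvb=uvb, frate_standoff=0.05)
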
 def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=None,
                             wgt_by_nsample=True, wgt_by_favg_nsample=False, rephase=False,
@@ -544,7 +1026,7 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
                       "in your dataset. Exiting without writing any output.", RuntimeWarning)
     else:
         fr = FRFilter(input_data_list, filetype=filetype)
-        fr.read(bls=baseline_list, axis='blt')
+        fr.read(bls=baseline_list)
         fr.timeavg_data(fr.data, fr.times, fr.lsts, t_avg, flags=fr.flags, nsamples=fr.nsamples,
                         wgt_by_nsample=wgt_by_nsample, wgt_by_favg_nsample=wgt_by_favg_nsample, rephase=rephase)
@@ -558,6 +1040,173 @@ def time_avg_data_and_write(input_data_list, output_data, t_avg, baseline_list=N
         uvf.write(flag_output, clobber=clobber)
 
+def tophat_frfilter_argparser(mode='clean'):
+    '''Arg parser for commandline operation of tophat fr-filters.
+
+    Parameters
+    ----------
+    mode: string, optional.
+        Determines which sets of arguments to load.
+        Can be 'clean', 'dayenu', or 'dpss_leastsq'.
+
+    Returns
+    -------
+    argparser
+        argparser for tophat fringe-rate (time-domain) filtering for the specified filtering mode.
+    '''
+    ap = vis_clean._filter_argparser()
+    filt_options = ap.add_argument_group(title='Options for the fr-filter')
+    filt_options.add_argument("--frate_width_multiplier", type=float, default=1.0,
+                              help="Fraction of the maximum sky fringe-rate to interpolate / filter. "
+                                   "Used if uvbeam and max_frate_coeffs are not specified.")
+    filt_options.add_argument("--frate_standoff", type=float, default=0.0,
+                              help="Standoff in fringe-rate to filter [mHz]. "
+                                   "Used if uvbeam and max_frate_coeffs are not specified.")
+    filt_options.add_argument("--min_frate_half_width", type=float, default=0.025,
+                              help="Minimum half-width of the fringe-rate window to filter [mHz].")
+    filt_options.add_argument("--max_frate_coeffs", type=float, default=None, nargs=2,
+                              help="Maximum fringe-rate coefficients for the model max_frate [mHz] = x1 * EW_bl_len [m] + x2. "
+                                   "Providing these overrides the sky-based fringe-rate determination! Default is None.")
+    filt_options.add_argument("--skip_autos", default=False, action="store_true", help="Exclude autos from filtering.")
+    filt_options.add_argument("--uvbeam", default=None, type=str,
+                              help="Path to a UVBeam beamfits file to use for determining the sky fringe-rates to filter.")
+    filt_options.add_argument("--percentile_low", default=5.0, type=float,
+                              help="Reject fringe-rates with beam power below this percentile if uvbeam is provided.")
+    filt_options.add_argument("--percentile_high", default=95.0, type=float,
+                              help="Reject fringe-rates with beam power above this percentile if uvbeam is provided.")
+    filt_options.add_argument("--taper", default='none', type=str,
+                              help="Weight fringe-rates at different frequencies by the square of this taper if uvbeam is provided.")
+    filt_options.add_argument("--fr_freq_skip", default=1, type=int,
+                              help="Bin fringe-rates from every fr_freq_skip channels. "
+                                   "Default is 1 -> takes a long time. We recommend setting this to be larger.")
+    return ap
+
+
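A short sketch of driving the parser above, mirroring the unit tests later in this patch (the
positional datafilelist argument and the shared options come from vis_clean._filter_argparser):

    import sys
    from hera_cal import frf

    sys.argv = [sys.argv[0], 'zen.data.uvh5', '--clobber',
                '--max_frate_coeffs', '0.024', '-0.229']
    args = frf.tophat_frfilter_argparser().parse_args()
    # args.max_frate_coeffs == [0.024, -0.229]; these coefficients bypass the
    # beam/sky-based fringe-rate determination, as the help text states.
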
+def load_tophat_frfilter_and_write(datafile_list, baseline_list=None, calfile_list=None,
+                                   Nbls_per_load=None, spw_range=None, cache_dir=None,
+                                   read_cache=False, write_cache=False, external_flags=None,
+                                   factorize_flags=False, time_thresh=0.05,
+                                   res_outfilename=None, CLEAN_outfilename=None, filled_outfilename=None,
+                                   clobber=False, add_to_history='', avg_red_bllens=False, polarizations=None,
+                                   skip_flagged_edges=False, overwrite_flags=False,
+                                   flag_yaml=None, skip_autos=False, uvbeam=None, verbose=False,
+                                   clean_flags_in_resid_flags=True, **filter_kwargs):
+    '''
+    A tophat fr-filtering method that simultaneously loads and writes only a user-provided
+    list of baselines. This supports parallelization over baselines (rather than times) when baseline_list is specified.
+
+    Arguments:
+        datafile_list: list of data files to perform fringe-rate filtering on.
+        baseline_list: list of antenna-pair-pol triplets to filter and write out from the datafile_list.
+            If None, load all baselines in files. Default is None.
+        calfile_list: optional list of calibration files to apply to data before fr filtering.
+        Nbls_per_load: int, the number of baselines to load at once.
+            If None, load all baselines at once. default: None.
+        spw_range: 2-tuple or 2-list, spw_range of data to filter.
+        cache_dir: string, optional, path to cache file that contains pre-computed dayenu matrices.
+            see uvtools.dspec.dayenu_filter for key formats.
+        read_cache: bool, If true, read existing cache files in cache_dir before running.
+        write_cache: bool. If true, create new cache file with precomputed matrices
+            that were not in previously loaded cache files.
+        external_flags: path to an external flag file to apply to the data, optional.
+        factorize_flags: bool, optional
+            If True, factorize flags before running the fr filter. See vis_clean.factorize_flags.
+        time_thresh: float, optional
+            Fractional threshold of flagged pixels across time needed to flag all times
+            per freq channel. It is not recommended to set this greater than 0.5.
+            Fully flagged integrations do not count towards triggering time_thresh.
+        res_outfilename: path for writing the filtered visibilities with flags.
+        CLEAN_outfilename: path for writing the CLEAN model visibilities (with the same flags).
+        filled_outfilename: path for writing the original data but with flags unflagged and replaced
+            with CLEAN models wherever possible.
+        clobber: if True, overwrites existing file at the outfilename.
+        add_to_history: string appended to the history of the output file.
+        avg_red_bllens: bool, if True, round baseline lengths to redundant average. Default is False.
+        polarizations: list of polarizations to process (and write out). Default None operates on all polarizations in data.
+        skip_flagged_edges: bool, if True do not include flagged edge times in the filtering region (filter over sub-region).
+        overwrite_flags: bool, if True reset data flags to False except for flagged antennas.
+        flag_yaml: path to manual flagging text file.
+        skip_autos: bool, if True, exclude autocorrelations from filtering. Default is False.
+            Autos are still written to the outputs: the residuals contain the original data and
+            the models are zeros (with the original flags).
+        uvbeam: str, optional.
+            path to a UVBeam file for calculating frate bounds.
+            default is None -> use other filter kwargs to determine fringe-rate bounds.
+        verbose: bool, optional
+            Helpful text outputs.
+        clean_flags_in_resid_flags: bool, optional. If True, include clean flags in the residual flags that get written.
+            default is True.
+        filter_kwargs: additional keyword arguments to be passed to FRFilter.run_tophat_frfilter()
+    '''
+    if baseline_list is not None and Nbls_per_load is not None:
+        raise NotImplementedError("baseline loading and partial i/o not yet implemented.")
+    hd = io.HERAData(datafile_list, filetype='uvh5')
+    if baseline_list is not None and len(baseline_list) == 0:
+        warnings.warn("Length of baseline list is zero. "
+                      "This can happen under normal circumstances when there are more files in datafile_list than baselines "
+                      "in your dataset. Exiting without writing any output.", RuntimeWarning)
+    else:
+        if baseline_list is None:
+            if len(hd.filepaths) > 1:
+                baseline_list = list(hd.bls.values())[0]
+            else:
+                baseline_list = hd.bls
+        if spw_range is None:
+            spw_range = [0, hd.Nfreqs]
+        freqs = hd.freq_array.flatten()[spw_range[0]:spw_range[1]]
+        baseline_antennas = []
+        for blpolpair in baseline_list:
+            baseline_antennas += list(blpolpair[:2])
+        baseline_antennas = np.unique(baseline_antennas).astype(int)
+        if calfile_list is not None:
+            cals = io.HERACal(calfile_list)
+            cals.read(antenna_nums=baseline_antennas, frequencies=freqs)
+        else:
+            cals = None
+        if polarizations is None:
+            if len(hd.filepaths) > 1:
+                polarizations = list(hd.pols.values())[0]
+            else:
+                polarizations = hd.pols
+        if Nbls_per_load is None:
+            Nbls_per_load = len(baseline_list)
+        for i in range(0, len(baseline_list), Nbls_per_load):
+            frfil = FRFilter(hd, input_cal=cals)
+            frfil.read(bls=baseline_list[i:i + Nbls_per_load], frequencies=freqs)
+            if avg_red_bllens:
+                frfil.avg_red_baseline_vectors()
+            if external_flags is not None:
+                frfil.apply_flags(external_flags, overwrite_flags=overwrite_flags)
+            if flag_yaml is not None:
+                frfil.apply_flags(flag_yaml, overwrite_flags=overwrite_flags, filetype='yaml')
+            if factorize_flags:
+                frfil.factorize_flags(time_thresh=time_thresh, inplace=True)
+            to_filter = frfil.data.keys()
+            if skip_autos:
+                to_filter = [bl for bl in to_filter if bl[0] != bl[1]]
+            if uvbeam is not None:
+                uvb = UVBeam()
+                uvb.read_beamfits(uvbeam)
+            else:
+                uvb = None
+            if len(to_filter) > 0:
+                frfil.run_tophat_frfilter(cache_dir=cache_dir, read_cache=read_cache, write_cache=write_cache, uvb=uvb,
+                                          skip_flagged_edges=skip_flagged_edges, to_filter=to_filter, verbose=verbose, **filter_kwargs)
+            else:
+                frfil.clean_data = DataContainer({})
+                frfil.clean_flags = DataContainer({})
+                frfil.clean_resid = DataContainer({})
+                frfil.clean_resid_flags = DataContainer({})
+                frfil.clean_model = DataContainer({})
+            # put autocorrelation data into the filtered data containers if skip_autos is True,
+            # so that it can be written out into the filtered files.
+ if skip_autos: + for bl in frfil.data.keys(): + if bl[0] == bl[1]: + frfil.clean_data[bl] = frfil.data[bl] + frfil.clean_flags[bl] = frfil.flags[bl] + frfil.clean_resid[bl] = frfil.data[bl] + frfil.clean_model[bl] = np.zeros_like(frfil.data[bl]) + frfil.clean_resid_flags[bl] = frfil.flags[bl] + + frfil.write_filtered_data(res_outfilename=res_outfilename, CLEAN_outfilename=CLEAN_outfilename, + filled_outfilename=filled_outfilename, partial_write=Nbls_per_load < len(baseline_list), + clobber=clobber, add_to_history=add_to_history, + extra_attrs={'Nfreqs': frfil.hd.Nfreqs, 'freq_array': frfil.hd.freq_array}) + frfil.hd.data_array = None # this forces a reload in the next loop + + def time_average_argparser(): """ Define an argument parser for time averaging data. @@ -584,4 +1233,5 @@ def time_average_argparser(): ap.add_argument("--verbose", default=False, action="store_true", help="verbose output.") ap.add_argument("--flag_output", default=None, type=str, help="optional filename to save a separate copy of the time-averaged flags as a uvflag object.") ap.add_argument("--filetype", default="uvh5", type=str, help="optional filetype specifier. Default is 'uvh5'. Set to 'miriad' if reading miriad files etc...") + return ap diff --git a/hera_cal/tests/test_frf.py b/hera_cal/tests/test_frf.py index b53c76100..b1368a3a6 100644 --- a/hera_cal/tests/test_frf.py +++ b/hera_cal/tests/test_frf.py @@ -14,7 +14,9 @@ from pyuvdata import utils as uvutils import unittest from scipy import stats - +from scipy import constants +from pyuvdata import UVFlag, UVBeam +from .. import utils from .. import datacontainer, io, frf from ..data import DATA_PATH @@ -87,11 +89,11 @@ def test_timeavg_waterfall(): n = np.ones((4, 10)) n[0, 0:5] *= 2 ad, _, _, _, _ = frf.timeavg_waterfall(d, 2, rephase=False, nsamples=n, wgt_by_nsample=True) - np.testing.assert_array_equal(ad[1, :], 1.0) - np.testing.assert_array_equal(ad[0, 0:5], 5. / 3) + np.testing.assert_array_equal(ad[1, :], 1.0) + np.testing.assert_array_equal(ad[0, 0:5], 5. / 3) np.testing.assert_array_equal(ad[0, 5:10], 1.5) ad, _, _, _, _ = frf.timeavg_waterfall(d, 2, rephase=False, nsamples=n, wgt_by_nsample=False, wgt_by_favg_nsample=True) - np.testing.assert_array_equal(ad[1, :], 1.0) + np.testing.assert_array_equal(ad[1, :], 1.0) np.testing.assert_array_equal(ad[0, :], 1.6) @@ -260,3 +262,475 @@ def test_time_average_argparser_multifile(self): assert not args.verbose assert args.flag_output is None assert args.filetype == "uvh5" + + def test_run_tophat_frfilter(self): + fname = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") + k = (24, 25, 'ee') + frfil = frf.FRFilter(fname, filetype='miriad') + frfil.read(bls=[k]) + bl = np.linalg.norm(frfil.antpos[24] - frfil.antpos[25]) / constants.c * 1e9 + sdf = (frfil.freqs[1] - frfil.freqs[0]) / 1e9 + + frfil.run_tophat_frfilter(tol=1e-2, output_prefix='frfiltered') + for k in frfil.data.keys(): + assert frfil.frfiltered_resid[k].shape == (60, 64) + assert frfil.frfiltered_model[k].shape == (60, 64) + assert k in frfil.frfiltered_info + + # test skip_wgt imposition of flags + fname = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") + k = (24, 25, 'ee') + # check successful run when avg_red_bllens is True and when False. 
+        for avg_red_bllens in [True, False]:
+            frfil = frf.FRFilter(fname, filetype='miriad')
+            frfil.read(bls=[k])
+            if avg_red_bllens:
+                frfil.avg_red_baseline_vectors()
+            wgts = {k: np.ones_like(frfil.flags[k], dtype=float)}
+            wgts[k][:, 0] = 0.0
+            frfil.run_tophat_frfilter(to_filter=[k], weight_dict=wgts, tol=1e-5, window='blackman-harris', skip_wgt=0.1, maxiter=100)
+            assert frfil.clean_info[k][(0, frfil.Nfreqs)]['status']['axis_0'][0] == 'skipped'
+            np.testing.assert_array_equal(frfil.clean_flags[k][:, 0], np.ones_like(frfil.flags[k][:, 0]))
+            np.testing.assert_array_equal(frfil.clean_model[k][:, 0], np.zeros_like(frfil.clean_resid[k][:, 0]))
+            np.testing.assert_array_equal(frfil.clean_resid[k][:, 0], np.zeros_like(frfil.clean_resid[k][:, 0]))
+
+    def test_load_tophat_frfilter_and_write_baseline_list(self, tmpdir):
+        tmp_path = tmpdir.strpath
+        uvh5 = [os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.first.uvh5"),
+                os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.second.uvh5")]
+        cals = [os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only.part1"),
+                os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only.part2")]
+        outfilename = os.path.join(tmp_path, 'temp.h5')
+        cdir = os.path.join(tmp_path, 'cache_temp')
+        # make a cache directory
+        if os.path.isdir(cdir):
+            shutil.rmtree(cdir)
+        os.mkdir(cdir)
+        # test graceful exit with a baseline list of length zero.
+        with pytest.warns(RuntimeWarning):
+            frf.load_tophat_frfilter_and_write(datafile_list=uvh5, baseline_list=[],
+                                               calfile_list=cals, spw_range=[100, 200], cache_dir=cdir,
+                                               read_cache=True, write_cache=True, avg_red_bllens=True,
+                                               res_outfilename=outfilename, clobber=True,
+                                               mode='dayenu')
+        for avg_bl in [True, False]:
+            frf.load_tophat_frfilter_and_write(datafile_list=uvh5, baseline_list=[(53, 54)], polarizations=['ee'],
+                                               calfile_list=cals, spw_range=[100, 200], cache_dir=cdir,
+                                               read_cache=True, write_cache=True, avg_red_bllens=avg_bl,
+                                               res_outfilename=outfilename, clobber=True,
+                                               mode='dayenu')
+            hd = io.HERAData(outfilename)
+            d, f, n = hd.read()
+            assert len(list(d.keys())) == 1
+            assert d[(53, 54, 'ee')].shape[1] == 100
+            assert d[(53, 54, 'ee')].shape[0] == 60
+            # now do no spw range and no cal files just to cover those lines.
+            frf.load_tophat_frfilter_and_write(datafile_list=uvh5, baseline_list=[(53, 54)], polarizations=['ee'],
+                                               cache_dir=cdir,
+                                               read_cache=True, write_cache=True, avg_red_bllens=avg_bl,
+                                               res_outfilename=outfilename, clobber=True,
+                                               mode='dayenu')
+            hd = io.HERAData(outfilename)
+            d, f, n = hd.read()
+            assert len(list(d.keys())) == 1
+            assert d[(53, 54, 'ee')].shape[1] == 1024
+            assert d[(53, 54, 'ee')].shape[0] == 60
+        # now test flag factorization and time thresholding.
+        # prepare an input file for broadcasting flags.
+        uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
+        input_file = os.path.join(tmp_path, 'temp_special_flags.h5')
+        shutil.copy(uvh5, input_file)
+        hd = io.HERAData(input_file)
+        _, flags, _ = hd.read()
+        ntimes_before = hd.Ntimes
+        nfreqs_before = hd.Nfreqs
+        freqs_before = hd.freqs
+        times_before = hd.times
+        for bl in flags:
+            flags[bl][:] = False
+            flags[bl][0, :hd.Nfreqs // 2] = True  # first time has 50% flagged
+            flags[bl][-3:, -1] = True  # last channel has flags for three integrations
+        hd.update(flags=flags)
+        hd.write_uvh5(input_file, clobber=True)
+        # this time_threshold will result in
+        # the entire first integration being flagged
+        # and the entire final channel being flagged
+        # when flags are broadcast.
+        time_thresh = 2. / hd.Ntimes
+        for blnum, bl in enumerate(flags.keys()):
+            outfilename = os.path.join(tmp_path, 'bl_chunk_%d.h5' % blnum)
+            frf.load_tophat_frfilter_and_write(datafile_list=[input_file], res_outfilename=outfilename,
+                                               tol=1e-4, baseline_list=[bl[:2]], polarizations=[bl[-1]],
+                                               cache_dir=cdir,
+                                               factorize_flags=True,
+                                               time_thresh=time_thresh, clobber=True)
+        # now load all of the outputs in
+        output_files = glob.glob(tmp_path + '/bl_chunk_*.h5')
+        hd = io.HERAData(output_files)
+        d, f, n = hd.read()
+        hd_original = io.HERAData(uvh5)
+        for bl in hd_original.bls:
+            assert bl in d.keys()
+
+        for bl in f:
+            assert np.all(f[bl][:, -1])
+            assert np.all(f[bl][0, :])
+
+        # test a priori flags and flag_yaml
+        flag_yaml = os.path.join(DATA_PATH, 'test_input/a_priori_flags_sample.yaml')
+        uvf = UVFlag(hd, mode='flag', copy_flags=True)
+        uvf.to_waterfall(keep_pol=False, method='and')
+        uvf.flag_array[:] = False
+        flagfile = os.path.join(tmp_path, 'test_flag.h5')
+        uvf.write(flagfile, clobber=True)
+        frf.load_tophat_frfilter_and_write(datafile_list=[input_file], res_outfilename=outfilename,
+                                           tol=1e-4, baseline_list=[bl[:2]], polarizations=[bl[-1]],
+                                           clobber=True, mode='dayenu',
+                                           external_flags=flagfile, overwrite_flags=True)
+        # test that all flags are False
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read()
+        for k in f:
+            assert np.all(~f[k])
+        # now do the external yaml
+        frf.load_tophat_frfilter_and_write(datafile_list=[input_file], res_outfilename=outfilename,
+                                           tol=1e-4, baseline_list=[bl[:2]], polarizations=[bl[-1]],
+                                           clobber=True, mode='dayenu',
+                                           external_flags=flagfile, overwrite_flags=True,
+                                           flag_yaml=flag_yaml)
+        # test that all flags match the a priori yaml flags
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read()
+        for k in f:
+            assert np.all(f[k][:, 0])
+            assert np.all(f[k][:, 1])
+            assert np.all(f[k][:, 10:20])
+            assert np.all(f[k][:, 60])
+        os.remove(outfilename)
+        shutil.rmtree(cdir)
+
+    def test_load_tophat_frfilter_and_write_multifile(self, tmpdir):
+        # cover the case where baseline_list is None and multiple files are provided.
+        uvh5s = sorted(glob.glob(DATA_PATH + '/zen.2458045.*.uvh5'))
+        tmp_path = tmpdir.strpath
+        outfilename = os.path.join(tmp_path, 'temp_output.uvh5')
+        frf.load_tophat_frfilter_and_write(uvh5s, filled_outfilename=outfilename, tol=1e-4, clobber=True)
+        hd = io.HERAData(uvh5s)
+        d, f, n = hd.read()
+        hdoutput = io.HERAData(outfilename)
+        doutput, foutput, noutput = hdoutput.read()
+        for k in doutput:
+            assert doutput[k].shape == d[k].shape
+
+    def test_get_fringe_rate_limits(self):
+        # simulations constructed with the notebook at https://drive.google.com/file/d/1jPPSmL3nqQbp7tTgP77j9KC0802iWyow/view?usp=sharing
+        test_beam = os.path.join(DATA_PATH, "fr_unittest_beam.beamfits")
+        test_data = os.path.join(DATA_PATH, "fr_unittest_data.uvh5")
+        uvd = UVData()
+        uvd.read_uvh5(test_data)
+        myfrf = frf.FRFilter(uvd)
+        myfrf.fft_data(data=myfrf.data, ax='time')
+        sim_c_frates = {}
+        sim_w_frates = {}
+        for bl in myfrf.dfft:
+            csum = np.cumsum(np.abs(myfrf.dfft[bl]) ** 2.)
+            csum /= csum.max()
+            frlow, frhigh = (myfrf.frates[np.argmin(np.abs(csum - 0.05))], myfrf.frates[np.argmin(np.abs(csum - 0.95))])
+            sim_c_frates[bl] = .5 * (frlow + frhigh)
+            sim_w_frates[bl] = .5 * np.abs(frlow - frhigh)
+            sim_w_frates[utils.reverse_bl(bl)] = sim_w_frates[bl]
+            sim_c_frates[utils.reverse_bl(bl)] = -sim_c_frates[bl]
+        uvb = UVBeam()
+        uvb.read_beamfits(test_beam)
+        c_frs, w_frs = frf.get_fringe_rate_limits(uvd, uvb)
+        for bl in sim_c_frates:
+            assert np.isclose(c_frs[bl], sim_c_frates[bl], atol=0.2, rtol=0.)
+            assert np.isclose(w_frs[bl], sim_w_frates[bl], atol=0.2, rtol=0.)
+
+    def test_load_tophat_frfilter_and_write_beam_frates(self, tmpdir):
+        # simulations constructed with the notebook at https://drive.google.com/file/d/1jPPSmL3nqQbp7tTgP77j9KC0802iWyow/view?usp=sharing
+        # load in a primary beam model and an isotropic noise model of the sky.
+        test_beam = os.path.join(DATA_PATH, "fr_unittest_beam.beamfits")
+        test_data = os.path.join(DATA_PATH, "fr_unittest_data.uvh5")
+        tmp_path = tmpdir.strpath
+        res_outfilename = os.path.join(tmp_path, 'resid.uvh5')
+        CLEAN_outfilename = os.path.join(tmp_path, 'model.uvh5')
+        filled_outfilename = os.path.join(tmp_path, 'filled.uvh5')
+        # perform the filtering.
+        frf.load_tophat_frfilter_and_write(datafile_list=[test_data], uvbeam=test_beam, mode='dpss_leastsq', filled_outfilename=filled_outfilename,
+                                           CLEAN_outfilename=CLEAN_outfilename, frate_standoff=0.075,
+                                           res_outfilename=res_outfilename, percentile_high=97.5, percentile_low=2.5)
+        hd_input = io.HERAData(test_data)
+        data, flags, nsamples = hd_input.read()
+        hd_resid = io.HERAData(res_outfilename)
+        data_r, flags_r, nsamples_r = hd_resid.read()
+        hd_filled = io.HERAData(filled_outfilename)
+        data_f, flags_f, nsamples_f = hd_filled.read()
+
+        for bl in data:
+            assert np.mean(np.abs(data_r[bl]) ** 2.) <= .2 * np.mean(np.abs(data[bl]) ** 2.)
+
+    def test_sky_frates_minfrate_and_to_filter(self):
+        # test edge frates
+        V = frf.FRFilter(os.path.join(DATA_PATH, "PyGSM_Jy_downselect.uvh5"))
+        V.read()
+        for to_filter in [None, list(V.data.keys())[:1]]:
+            cfrates, wfrates = frf.sky_frates(V.hd, min_frate_half_width=1000, blkeys=to_filter)
+            # to_filter set to None -> all keys should be present.
+            if to_filter is None:
+                for k in V.data:
+                    assert k in cfrates
+                    assert k in wfrates
+            # min_frate_half_width = 1000 should set all half-widths to 1000.
+            for k in cfrates:
+                assert wfrates[k] == 1000.
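For orientation, the maximum sky fringe-rate computed by sky_frates scales as
2*pi * b_EW * nu / (c * T_sidereal). A back-of-the-envelope check of that scale (numbers
illustrative, matching the constants defined in frf.py above):

    import numpy as np

    b_ew = 14.6          # EW baseline length in meters (a short HERA baseline)
    nu = 150e6           # observing frequency in Hz
    c = 299792458.0
    sday_ksec = 86163.93 / 1000.  # sidereal day in kiloseconds, as in frf.py
    frate = 2 * np.pi * b_ew * nu / c / sday_ksec  # in mHz (i.e. 1/kSec)
    # frate ~ 0.5 mHz, so the default min_frate_half_width = 0.025 mHz only
    # dominates for very short or nearly north-south baselines.
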
+
+    def test_load_tophat_frfilter_and_write(self, tmpdir):
+        tmp_path = tmpdir.strpath
+        uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
+        outfilename = os.path.join(tmp_path, 'temp.h5')
+        CLEAN_outfilename = os.path.join(tmp_path, 'temp_clean.h5')
+        filled_outfilename = os.path.join(tmp_path, 'temp_filled.h5')
+        frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, tol=1e-4, clobber=True, Nbls_per_load=1)
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        for bl in d:
+            assert not np.all(np.isclose(d[bl], 0.))
+
+        frfil = frf.FRFilter(uvh5, filetype='uvh5')
+        frfil.read(bls=[(53, 54, 'ee')])
+        frfil.run_tophat_frfilter(to_filter=[(53, 54, 'ee')], tol=1e-4, verbose=True)
+        np.testing.assert_almost_equal(d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5)
+        np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])
+        # test NotImplementedError
+        pytest.raises(NotImplementedError, frf.load_tophat_frfilter_and_write, uvh5, res_outfilename=outfilename, tol=1e-4,
+                      clobber=True, Nbls_per_load=1, avg_red_bllens=True, baseline_list=[(54, 54)], polarizations=['ee'])
+
+        # test loading and writing all baselines at once.
+        uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
+        outfilename = os.path.join(tmp_path, 'temp.h5')
+        for avg_bl in [True, False]:
+            frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, tol=1e-4, clobber=True,
+                                               Nbls_per_load=None, avg_red_bllens=avg_bl)
+            hd = io.HERAData(outfilename)
+            d, f, n = hd.read(bls=[(53, 54, 'ee')])
+            for bl in d:
+                assert not np.all(np.isclose(d[bl], 0.))
+
+            frfil = frf.FRFilter(uvh5, filetype='uvh5')
+            frfil.read(bls=[(53, 54, 'ee')])
+            frfil.run_tophat_frfilter(to_filter=[(53, 54, 'ee')], tol=1e-4, verbose=True)
+            np.testing.assert_almost_equal(d[(53, 54, 'ee')], frfil.clean_resid[(53, 54, 'ee')], decimal=5)
+            np.testing.assert_array_equal(f[(53, 54, 'ee')], frfil.flags[(53, 54, 'ee')])
+
+        cal = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only")
+        outfilename = os.path.join(tmp_path, 'temp.h5')
+        os.remove(outfilename)
+        for avg_bl in [True, False]:
+            frf.load_tophat_frfilter_and_write(uvh5, calfile_list=cal, tol=1e-4, res_outfilename=outfilename,
+                                               Nbls_per_load=2, clobber=True, avg_red_bllens=avg_bl)
+            hd = io.HERAData(outfilename)
+            assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '')
+            d, f, n = hd.read()
+            for bl in d:
+                if not np.all(f[bl]):
+                    assert not np.all(np.isclose(d[bl], 0.))
+            np.testing.assert_array_equal(f[(53, 54, 'ee')], True)
+            os.remove(outfilename)
+        # test skip_autos
+        frf.load_tophat_frfilter_and_write(uvh5, calfile_list=None, tol=1e-4, res_outfilename=outfilename,
+                                           filled_outfilename=filled_outfilename, CLEAN_outfilename=CLEAN_outfilename,
+                                           Nbls_per_load=2, clobber=True, avg_red_bllens=avg_bl, skip_autos=True)
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read()
+        hd_original = io.HERAData(uvh5)
+        do, fo, no = hd_original.read()
+        chd = io.HERAData(CLEAN_outfilename)
+        cd, cf, cn = chd.read()
+        fhd = io.HERAData(filled_outfilename)
+        fd, ff, fn = fhd.read()
+        # test that the auto residuals are equal to the original data.
+        for bl in do:
+            if bl[0] == bl[1]:
+                assert np.allclose(do[bl], d[bl])  # check that resid equals original data.
+                assert np.allclose(fo[bl], f[bl])
+                assert np.allclose(no[bl], n[bl])
+                assert np.allclose(cd[bl], np.zeros_like(cd[bl]))  # check that all model values are zero.
+                assert np.allclose(fd[bl][~f[bl]], d[bl][~f[bl]])  # check that filled data equals original data.
+            else:
+                assert not np.allclose(do[bl], d[bl])
+                assert np.allclose(no[bl], n[bl])
+
+        # prepare an input file for broadcasting flags
+        input_file = os.path.join(tmp_path, 'temp_special_flags.h5')
+        shutil.copy(uvh5, input_file)
+        hd = io.HERAData(input_file)
+        _, flags, _ = hd.read()
+        ntimes_before = hd.Ntimes
+        nfreqs_before = hd.Nfreqs
+        freqs_before = hd.freqs
+        times_before = hd.times
+        for bl in flags:
+            flags[bl][:] = False
+            flags[bl][0, :hd.Nfreqs // 2] = True  # first time has 50% flagged
+            flags[bl][-3:, -1] = True  # last channel has flags for three integrations
+        hd.update(flags=flags)
+        hd.write_uvh5(input_file, clobber=True)
+        # this time_threshold will result in
+        # the entire first integration being flagged
+        # and the entire final channel being flagged
+        # when flags are broadcast.
+        time_thresh = 2. / hd.Ntimes
+        frf.load_tophat_frfilter_and_write(input_file, res_outfilename=outfilename, tol=1e-4,
+                                           factorize_flags=True, time_thresh=time_thresh, clobber=True)
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        for bl in f:
+            assert np.any(f[bl][:, :-1])
+            assert np.all(f[bl][0, :])
+
+        # test fr filtering and writing with factorized flags
+        frf.load_tophat_frfilter_and_write(input_file, res_outfilename=outfilename, tol=1e-4,
+                                           factorize_flags=True, time_thresh=time_thresh, clobber=True)
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        for bl in f:
+            # check that flags were broadcast.
+            assert np.all(f[bl][0, :])
+            assert np.all(f[bl][:, -1])
+            assert not np.all(np.isclose(d[bl], 0.))
+
+        # same again with partial i/o.
+        frf.load_tophat_frfilter_and_write(input_file, res_outfilename=outfilename, tol=1e-4, Nbls_per_load=1,
+                                           factorize_flags=True, time_thresh=time_thresh, clobber=True)
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        for bl in f:
+            # check that flags were broadcast.
+            assert np.all(f[bl][0, :])
+            assert np.all(f[bl][:, -1])
+            assert not np.all(np.isclose(d[bl], 0.))
+
+        # test a priori flags and flag_yaml
+        hd = io.HERAData(uvh5)
+        hd.read()
+        flag_yaml = os.path.join(DATA_PATH, 'test_input/a_priori_flags_sample.yaml')
+        uvf = UVFlag(hd, mode='flag', copy_flags=True)
+        uvf.to_waterfall(keep_pol=False, method='and')
+        uvf.flag_array[:] = False
+        flagfile = os.path.join(tmp_path, 'test_flag.h5')
+        uvf.write(flagfile, clobber=True)
+        frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename,
+                                           Nbls_per_load=1, clobber=True, mode='dayenu',
+                                           external_flags=flagfile,
+                                           overwrite_flags=True)
+        # test that all flags are False
+        hd = io.HERAData(outfilename)
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        for k in f:
+            assert np.all(~f[k])
+        # now without partial i/o.
+ frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, + clobber=True, mode='dayenu', + external_flags=flagfile, + overwrite_flags=True) + # test that all flags are False + hd = io.HERAData(outfilename) + d, f, n = hd.read(bls=[(53, 54, 'ee')]) + for k in f: + assert np.all(~f[k]) + + def test_load_dayenu_filter_and_write(self, tmpdir): + tmp_path = tmpdir.strpath + uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5") + cdir = os.path.join(tmp_path, 'cache_temp') + # make a cache directory + if os.path.isdir(cdir): + shutil.rmtree(cdir) + os.mkdir(cdir) + outfilename = os.path.join(tmp_path, 'temp.h5') + # run dayenu filter + avg_bl = True + frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, + cache_dir=cdir, mode='dayenu', + Nbls_per_load=1, clobber=True, avg_red_bllens=avg_bl, + spw_range=(0, 32), write_cache=True) + # generate duplicate cache files to test duplicate key handle for cache load. + frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, cache_dir=cdir, + mode='dayenu', avg_red_bllens=avg_bl, + Nbls_per_load=1, clobber=True, read_cache=False, + spw_range=(0, 32), write_cache=True) + # there should now be six cache files (one per i/o/filter). There are three baselines. + assert len(glob.glob(cdir + '/*')) == 6 + hd = io.HERAData(outfilename) + assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '') + d, f, n = hd.read(bls=[(53, 54, 'ee')]) + np.testing.assert_array_equal(f[(53, 54, 'ee')], True) + os.remove(outfilename) + shutil.rmtree(cdir) + os.mkdir(cdir) + # now do all the baselines at once. + for avg_bl in [True, False]: + frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, + cache_dir=cdir, mode='dayenu', avg_red_bllens=avg_bl, + Nbls_per_load=None, clobber=True, + spw_range=(0, 32), write_cache=True) + if avg_bl: + assert len(glob.glob(cdir + '/*')) == 1 + hd = io.HERAData(outfilename) + assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '') + d, f, n = hd.read(bls=[(53, 54, 'ee')]) + np.testing.assert_array_equal(f[(53, 54, 'ee')], True) + os.remove(outfilename) + shutil.rmtree(cdir) + os.mkdir(cdir) + # run again using computed cache. + calfile = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only") + frf.load_tophat_frfilter_and_write(uvh5, res_outfilename=outfilename, max_frate_coeffs=[0.0, 0.025], + cache_dir=cdir, calfile_list=calfile, read_cache=True, + Nbls_per_load=1, clobber=True, mode='dayenu', + spw_range=(0, 32), write_cache=True) + # no new cache files should be generated. 
+        assert len(glob.glob(cdir + '/*')) == 1
+        hd = io.HERAData(outfilename)
+        assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '')
+        d, f, n = hd.read(bls=[(53, 54, 'ee')])
+        np.testing.assert_array_equal(f[(53, 54, 'ee')], True)
+        os.remove(outfilename)
+        shutil.rmtree(cdir)
+
+    def test_tophat_clean_argparser(self):
+        sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229']
+        parser = frf.tophat_frfilter_argparser()
+        a = parser.parse_args()
+        assert a.datafilelist == ['a']
+        assert a.clobber is True
+        assert a.window == 'blackmanharris'
+        assert a.max_frate_coeffs[0] == 0.024
+        assert a.max_frate_coeffs[1] == -0.229
+        assert a.time_thresh == 0.05
+        assert not a.factorize_flags
+
+    def test_tophat_linear_argparser(self):
+        sys.argv = [sys.argv[0], 'a', '--clobber', '--write_cache', '--cache_dir', '/blah/', '--max_frate_coeffs', '0.024', '-0.229', '--mode', 'dayenu']
+        parser = frf.tophat_frfilter_argparser()
+        a = parser.parse_args()
+        assert a.datafilelist == ['a']
+        assert a.clobber is True
+        assert a.write_cache is True
+        assert a.cache_dir == '/blah/'
+        assert a.max_frate_coeffs[0] == 0.024
+        assert a.max_frate_coeffs[1] == -0.229
+        assert a.time_thresh == 0.05
+        assert not a.factorize_flags
diff --git a/hera_cal/tests/test_vis_clean.py b/hera_cal/tests/test_vis_clean.py
index 713e6e5ab..7eeceaf05 100644
--- a/hera_cal/tests/test_vis_clean.py
+++ b/hera_cal/tests/test_vis_clean.py
@@ -19,7 +19,7 @@
 from hera_cal import vis_clean
 from hera_cal.vis_clean import VisClean
 from hera_cal.data import DATA_PATH
-from hera_cal import xtalk_filter as xf
+from hera_cal import frf
 import glob
 import copy
@@ -1134,9 +1134,9 @@ def test_time_chunk_from_baseline_chunks(self, tmp_path):
             baselines = io.baselines_from_filelist_position(file, datafiles)
             fname = 'temp.fragment.part.%d.h5' % filenum
             fragment_filename = tmp_path / fname
-            xf.load_xtalk_filter_and_write(datafiles, baseline_list=baselines, calfile_list=cals,
-                                           spw_range=[0, 20], cache_dir=cdir, read_cache=True, write_cache=True,
-                                           res_outfilename=fragment_filename, clobber=True)
+            frf.load_tophat_frfilter_and_write(datafiles, baseline_list=baselines, calfile_list=cals,
+                                               spw_range=[0, 20], cache_dir=cdir, read_cache=True, write_cache=True,
+                                               res_outfilename=fragment_filename, clobber=True)
             # load in fragment and make sure the number of baselines is equal to the length of the baseline list
             hd_fragment = io.HERAData(str(fragment_filename))
             assert len(hd_fragment.bls) == len(baselines)
         hd_reconstituted = io.HERAData(glob.glob(str(tmp_path / 'temp.reconstituted.part.*.h5')))
         hd_reconstituted.read()
         # compare to xtalk filtering the whole file.
- xf.load_xtalk_filter_and_write(datafile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"), - calfile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only"), - res_outfilename=str(tmp_path / 'temp.h5'), clobber=True, spw_range=[0, 20]) + frf.load_tophat_frfilter_and_write(datafile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"), + calfile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only"), + res_outfilename=str(tmp_path / 'temp.h5'), clobber=True, spw_range=[0, 20]) hd = io.HERAData(str(tmp_path / 'temp.h5')) hd.read() assert np.all(np.isclose(hd.data_array, hd_reconstituted.data_array)) @@ -1174,9 +1174,9 @@ def test_time_chunk_from_baseline_chunks(self, tmp_path): hd_reconstituted = io.HERAData(glob.glob(str(tmp_path / 'temp.reconstituted.part.*.h5'))) hd_reconstituted.read() # compare to xtalk filtering the whole file. - xf.load_xtalk_filter_and_write(datafile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"), - calfile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only"), - res_outfilename=str(tmp_path / 'temp.h5'), clobber=True, spw_range=[0, 20]) + frf.load_tophat_frfilter_and_write(datafile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5"), + calfile_list=os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only"), + res_outfilename=str(tmp_path / 'temp.h5'), clobber=True, spw_range=[0, 20]) hd = io.HERAData(str(tmp_path / 'temp.h5')) hd.read() assert np.all(np.isclose(hd.data_array, hd_reconstituted.data_array)) diff --git a/hera_cal/tests/test_xtalk_filter.py b/hera_cal/tests/test_xtalk_filter.py deleted file mode 100644 index ec0e1ca23..000000000 --- a/hera_cal/tests/test_xtalk_filter.py +++ /dev/null @@ -1,407 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 the HERA Project -# Licensed under the MIT License - -import pytest -import numpy as np -from copy import deepcopy -import os -import sys -import shutil -from scipy import constants -from pyuvdata import UVCal, UVData - -from .. import io -from .. import xtalk_filter as xf -from ..data import DATA_PATH -import glob -from .. import vis_clean -from .. import utils as utils -from pyuvdata import UVFlag - - -class Test_XTalkFilter(object): - def test_run_xtalk_filter(self): - fname = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") - k = (24, 25, 'ee') - xfil = xf.XTalkFilter(fname, filetype='miriad') - xfil.read(bls=[k]) - bl = np.linalg.norm(xfil.antpos[24] - xfil.antpos[25]) / constants.c * 1e9 - sdf = (xfil.freqs[1] - xfil.freqs[0]) / 1e9 - - xfil.run_xtalk_filter(to_filter=xfil.data.keys(), tol=1e-2) - for k in xfil.data.keys(): - assert xfil.clean_resid[k].shape == (60, 64) - assert xfil.clean_model[k].shape == (60, 64) - assert k in xfil.clean_info - - # test skip_wgt imposition of flags - fname = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA") - k = (24, 25, 'ee') - # check successful run when avg_red_bllens is True and when False. 
- for avg_red_bllens in [True, False]: - xfil = xf.XTalkFilter(fname, filetype='miriad') - xfil.read(bls=[k]) - if avg_red_bllens: - xfil.avg_red_baseline_vectors() - wgts = {k: np.ones_like(xfil.flags[k], dtype=np.float)} - wgts[k][:, 0] = 0.0 - xfil.run_xtalk_filter(to_filter=[k], weight_dict=wgts, tol=1e-5, window='blackman-harris', skip_wgt=0.1, maxiter=100) - assert xfil.clean_info[k][(0, xfil.Nfreqs)]['status']['axis_0'][0] == 'skipped' - np.testing.assert_array_equal(xfil.clean_flags[k][:, 0], np.ones_like(xfil.flags[k][:, 0])) - np.testing.assert_array_equal(xfil.clean_model[k][:, 0], np.zeros_like(xfil.clean_resid[k][:, 0])) - np.testing.assert_array_equal(xfil.clean_resid[k][:, 0], np.zeros_like(xfil.clean_resid[k][:, 0])) - - def test_load_xtalk_filter_and_write_baseline_list(self, tmpdir): - tmp_path = tmpdir.strpath - uvh5 = [os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.first.uvh5"), - os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.second.uvh5")] - cals = [os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only.part1"), - os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only.part2")] - outfilename = os.path.join(tmp_path, 'temp.h5') - cdir = os.path.join(tmp_path, 'cache_temp') - # make a cache directory - if os.path.isdir(cdir): - shutil.rmtree(cdir) - os.mkdir(cdir) - # check graceful exist with length zero baseline list. - with pytest.warns(RuntimeWarning): - xf.load_xtalk_filter_and_write(datafile_list=uvh5, baseline_list=[], polarizations=['ee'], - calfile_list=cals, spw_range=[100, 200], cache_dir=cdir, - read_cache=True, write_cache=True, avg_red_bllens=True, - res_outfilename=outfilename, clobber=True, - mode='dayenu') - for avg_bl in [True, False]: - xf.load_xtalk_filter_and_write(datafile_list=uvh5, baseline_list=[(53, 54)], polarizations=['ee'], - calfile_list=cals, spw_range=[100, 200], cache_dir=cdir, - read_cache=True, write_cache=True, avg_red_bllens=avg_bl, - res_outfilename=outfilename, clobber=True, - mode='dayenu') - hd = io.HERAData(outfilename) - d, f, n = hd.read() - assert len(list(d.keys())) == 1 - assert d[(53, 54, 'ee')].shape[1] == 100 - assert d[(53, 54, 'ee')].shape[0] == 60 - # now do no spw range and no cal files just to cover those lines. - xf.load_xtalk_filter_and_write(datafile_list=uvh5, baseline_list=[(53, 54)], polarizations=['ee'], - cache_dir=cdir, - read_cache=True, write_cache=True, avg_red_bllens=avg_bl, - res_outfilename=outfilename, clobber=True, - mode='dayenu') - hd = io.HERAData(outfilename) - d, f, n = hd.read() - assert len(list(d.keys())) == 1 - assert d[(53, 54, 'ee')].shape[1] == 1024 - assert d[(53, 54, 'ee')].shape[0] == 60 - - # test baseline_list = None - xf.load_xtalk_filter_and_write(datafile_list=uvh5, baseline_list=None, - calfile_list=cals, spw_range=[100, 200], cache_dir=cdir, - read_cache=True, write_cache=True, avg_red_bllens=True, - res_outfilename=outfilename, clobber=True, - mode='dayenu') - hd = io.HERAData(outfilename) - d, f, n = hd.read() - assert d[(53, 54, 'ee')].shape[1] == 100 - assert d[(53, 54, 'ee')].shape[0] == 60 - hdall = io.HERAData(uvh5) - hdall.read() - assert np.allclose(hd.baseline_array, hdall.baseline_array) - assert np.allclose(hd.time_array, hdall.time_array) - # now test flag factorization and time thresholding. 
- # prepare an input files for broadcasting flags - uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5") - input_file = os.path.join(tmp_path, 'temp_special_flags.h5') - shutil.copy(uvh5, input_file) - hd = io.HERAData(input_file) - _, flags, _ = hd.read() - ntimes_before = hd.Ntimes - nfreqs_before = hd.Nfreqs - freqs_before = hd.freqs - times_before = hd.times - for bl in flags: - flags[bl][:] = False - flags[bl][0, :hd.Nfreqs // 2] = True # first time has 50% flagged - flags[bl][-3:, -1] = True # last channel has flags for three integrations - hd.update(flags=flags) - hd.write_uvh5(input_file, clobber=True) - # this time_threshold will result in - # entire first integration begin flagged - # and entire final channel being flagged - # when flags are broadcasted. - time_thresh = 2. / hd.Ntimes - for blnum, bl in enumerate(flags.keys()): - outfilename = os.path.join(tmp_path, 'bl_chunk_%d.h5' % blnum) - xf.load_xtalk_filter_and_write(datafile_list=[input_file], res_outfilename=outfilename, - tol=1e-4, baseline_list=[bl[:2]], polarizations=[bl[-1]], - cache_dir=cdir, - factorize_flags=True, - time_thresh=time_thresh, clobber=True) - # now load all of the outputs in - output_files = glob.glob(tmp_path + '/bl_chunk_*.h5') - hd = io.HERAData(output_files) - d, f, n = hd.read() - hd_original = io.HERAData(uvh5) - for bl in hd_original.bls: - assert bl in d.keys() - - for bl in f: - assert np.all(f[bl][:, -1]) - assert np.all(f[bl][0, :]) - - # test apriori flags and flag_yaml - flag_yaml = os.path.join(DATA_PATH, 'test_input/a_priori_flags_sample.yaml') - uvf = UVFlag(hd, mode='flag', copy_flags=True) - uvf.to_waterfall(keep_pol=False, method='and') - uvf.flag_array[:] = False - flagfile = os.path.join(tmp_path, 'test_flag.h5') - uvf.write(flagfile, clobber=True) - xf.load_xtalk_filter_and_write(datafile_list=[input_file], res_outfilename=outfilename, - tol=1e-4, baseline_list=[bl[:2]], - clobber=True, mode='dayenu', - external_flags=flagfile, overwrite_flags=True) - # test that all flags are False - hd = io.HERAData(outfilename) - d, f, n = hd.read() - for k in f: - assert np.all(~f[k]) - # now do the external yaml - xf.load_xtalk_filter_and_write(datafile_list=[input_file], res_outfilename=outfilename, - tol=1e-4, baseline_list=[bl[:2]], - clobber=True, mode='dayenu', - external_flags=flagfile, overwrite_flags=True, - flag_yaml=flag_yaml) - # test that all flags are af yaml flags - hd = io.HERAData(outfilename) - d, f, n = hd.read() - for k in f: - assert np.all(f[k][:, 0]) - assert np.all(f[k][:, 1]) - assert np.all(f[k][:, 10:20]) - assert np.all(f[k][:, 60]) - os.remove(outfilename) - shutil.rmtree(cdir) - - def test_load_xtalk_filter_and_write(self, tmpdir): - tmp_path = tmpdir.strpath - uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5") - outfilename = os.path.join(tmp_path, 'temp.h5') - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, tol=1e-4, clobber=True, Nbls_per_load=1) - hd = io.HERAData(outfilename) - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - for bl in d: - assert not np.all(np.isclose(d[bl], 0.)) - - xfil = xf.XTalkFilter(uvh5, filetype='uvh5') - xfil.read(bls=[(53, 54, 'ee')]) - xfil.run_xtalk_filter(to_filter=[(53, 54, 'ee')], tol=1e-4, verbose=True) - np.testing.assert_almost_equal(d[(53, 54, 'ee')], xfil.clean_resid[(53, 54, 'ee')], decimal=5) - np.testing.assert_array_equal(f[(53, 54, 'ee')], xfil.flags[(53, 54, 'ee')]) - # test NotImplementedError - 
-        pytest.raises(NotImplementedError, xf.load_xtalk_filter_and_write, uvh5, res_outfilename=outfilename, tol=1e-4,
-                      clobber=True, Nbls_per_load=1, avg_red_bllens=True, baseline_list=[(54, 54)], polarizations=['ee'])
-
-        # test loading and writing all baselines at once.
-        uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5")
-        outfilename = os.path.join(tmp_path, 'temp.h5')
-        for avg_bl in [True, False]:
-            xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, tol=1e-4, clobber=True,
-                                           Nbls_per_load=None, avg_red_bllens=avg_bl)
-            hd = io.HERAData(outfilename)
-            d, f, n = hd.read(bls=[(53, 54, 'ee')])
-            for bl in d:
-                assert not np.all(np.isclose(d[bl], 0.))
-
-            xfil = xf.XTalkFilter(uvh5, filetype='uvh5')
-            xfil.read(bls=[(53, 54, 'ee')])
-            xfil.run_xtalk_filter(to_filter=[(53, 54, 'ee')], tol=1e-4, verbose=True)
-            np.testing.assert_almost_equal(d[(53, 54, 'ee')], xfil.clean_resid[(53, 54, 'ee')], decimal=5)
-            np.testing.assert_array_equal(f[(53, 54, 'ee')], xfil.flags[(53, 54, 'ee')])
-
-        cal = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only")
-        outfilename = os.path.join(tmp_path, 'temp.h5')
-        os.remove(outfilename)
-        for avg_bl in [True, False]:
-            xf.load_xtalk_filter_and_write(uvh5, calfile_list=cal, tol=1e-4, res_outfilename=outfilename,
-                                           Nbls_per_load=2, clobber=True, avg_red_bllens=avg_bl)
-            hd = io.HERAData(outfilename)
-            assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '')
-            d, f, n = hd.read()
-            for bl in d:
-                if not np.all(f[bl]):
-                    assert not np.all(np.isclose(d[bl], 0.))
-            np.testing.assert_array_equal(f[(53, 54, 'ee')], True)
-            os.remove(outfilename)
-
-        # prepare an input file for broadcasting flags
-        input_file = os.path.join(tmp_path, 'temp_special_flags.h5')
-        shutil.copy(uvh5, input_file)
-        hd = io.HERAData(input_file)
-        _, flags, _ = hd.read()
-        ntimes_before = hd.Ntimes
-        nfreqs_before = hd.Nfreqs
-        freqs_before = hd.freqs
-        times_before = hd.times
-        for bl in flags:
-            flags[bl][:] = False
-            flags[bl][0, :hd.Nfreqs // 2] = True  # first time has 50% flagged
-            flags[bl][-3:, -1] = True  # last channel has flags for three integrations
-        hd.update(flags=flags)
-        hd.write_uvh5(input_file, clobber=True)
-        # this time_threshold will result in
-        # entire first integration begin flagged
-        # and entire final channel being flagged
-        # when flags are broadcasted.
-        time_thresh = 2. / hd.Ntimes
-        xf.load_xtalk_filter_and_write(input_file, res_outfilename=outfilename, tol=1e-4,
-                                       factorize_flags=True, time_thresh=time_thresh, clobber=True)
-        hd = io.HERAData(outfilename)
-        d, f, n = hd.read(bls=[(53, 54, 'ee')])
-        for bl in f:
-            assert np.any(f[bl][:, :-1])
-            assert np.all(f[bl][0, :])
-
-        # test delay filtering and writing with factorized flags and partial i/o
-        xf.load_xtalk_filter_and_write(input_file, res_outfilename=outfilename, tol=1e-4,
-                                       factorize_flags=True, time_thresh=time_thresh, clobber=True)
-        hd = io.HERAData(outfilename)
-        d, f, n = hd.read(bls=[(53, 54, 'ee')])
-        for bl in f:
-            # check that flags were broadcasted.
-            assert np.all(f[bl][0, :])
-            assert np.all(f[bl][:, -1])
-            assert not np.all(np.isclose(d[bl], 0.))
-
-        xf.load_xtalk_filter_and_write(input_file, res_outfilename=outfilename, tol=1e-4, Nbls_per_load=1,
-                                       factorize_flags=True, time_thresh=time_thresh, clobber=True)
-        hd = io.HERAData(outfilename)
-        d, f, n = hd.read(bls=[(53, 54, 'ee')])
-        for bl in f:
-            # check that flags were broadcasted.
- assert np.all(f[bl][0, :]) - assert np.all(f[bl][:, -1]) - assert not np.all(np.isclose(d[bl], 0.)) - - # test apriori flags and flag_yaml - hd = io.HERAData(uvh5) - hd.read() - flag_yaml = os.path.join(DATA_PATH, 'test_input/a_priori_flags_sample.yaml') - uvf = UVFlag(hd, mode='flag', copy_flags=True) - uvf.to_waterfall(keep_pol=False, method='and') - uvf.flag_array[:] = False - flagfile = os.path.join(tmp_path, 'test_flag.h5') - uvf.write(flagfile, clobber=True) - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, - Nbls_per_load=1, clobber=True, mode='dayenu', - external_flags=flagfile, - overwrite_flags=True) - # test that all flags are False - hd = io.HERAData(outfilename) - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - for k in f: - assert np.all(~f[k]) - # now without parital io. - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, - clobber=True, mode='dayenu', - external_flags=flagfile, - overwrite_flags=True) - # test that all flags are False - hd = io.HERAData(outfilename) - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - for k in f: - assert np.all(~f[k]) - - def test_load_dayenu_filter_and_write(self, tmpdir): - tmp_path = tmpdir.strpath - uvh5 = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.OCR_53x_54x_only.uvh5") - cdir = os.getcwd() - cdir = os.path.join(cdir, 'cache_temp') - # make a cache directory - if os.path.isdir(cdir): - shutil.rmtree(cdir) - os.mkdir(cdir) - outfilename = os.path.join(tmp_path, 'temp.h5') - # run dayenu filter - avg_bl = True - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, - cache_dir=cdir, mode='dayenu', - Nbls_per_load=1, clobber=True, avg_red_bllens=avg_bl, - spw_range=(0, 32), write_cache=True) - # generate duplicate cache files to test duplicate key handle for cache load. - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, cache_dir=cdir, - mode='dayenu', avg_red_bllens=avg_bl, - Nbls_per_load=1, clobber=True, read_cache=False, - spw_range=(0, 32), write_cache=True) - # there should now be six cache files (one per i/o/filter). There are three baselines. - assert len(glob.glob(cdir + '/*')) == 6 - hd = io.HERAData(outfilename) - assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '') - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - np.testing.assert_array_equal(f[(53, 54, 'ee')], True) - os.remove(outfilename) - shutil.rmtree(cdir) - os.mkdir(cdir) - # now do all the baselines at once. - for avg_bl in [True, False]: - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, - cache_dir=cdir, mode='dayenu', avg_red_bllens=avg_bl, - Nbls_per_load=None, clobber=True, - spw_range=(0, 32), write_cache=True) - if avg_bl: - assert len(glob.glob(cdir + '/*')) == 1 - hd = io.HERAData(outfilename) - assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '') - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - np.testing.assert_array_equal(f[(53, 54, 'ee')], True) - os.remove(outfilename) - shutil.rmtree(cdir) - os.mkdir(cdir) - # run again using computed cache. - calfile = os.path.join(DATA_PATH, "test_input/zen.2458101.46106.xx.HH.uv.abs.calfits_54x_only") - xf.load_xtalk_filter_and_write(uvh5, res_outfilename=outfilename, - cache_dir=cdir, calfile_list=calfile, read_cache=True, - Nbls_per_load=1, clobber=True, mode='dayenu', - spw_range=(0, 32), write_cache=True) - # now new cache files should be generated. 
- assert len(glob.glob(cdir + '/*')) == 1 - hd = io.HERAData(outfilename) - assert 'Thisfilewasproducedbythefunction' in hd.history.replace('\n', '').replace(' ', '') - d, f, n = hd.read(bls=[(53, 54, 'ee')]) - np.testing.assert_array_equal(f[(53, 54, 'ee')], True) - os.remove(outfilename) - shutil.rmtree(cdir) - - def test_xtalk_clean_argparser(self): - sys.argv = [sys.argv[0], 'a', '--clobber', '--window', 'blackmanharris', '--max_frate_coeffs', '0.024', '-0.229', '--mode', 'clean'] - parser = xf.xtalk_filter_argparser() - a = parser.parse_args() - assert a.datafilelist == ['a'] - assert a.clobber is True - assert a.window == 'blackmanharris' - assert a.max_frate_coeffs[0] == 0.024 - assert a.max_frate_coeffs[1] == -0.229 - assert a.time_thresh == 0.05 - assert not a.factorize_flags - - def test_xtalk_linear_argparser(self): - sys.argv = [sys.argv[0], 'a', '--clobber', '--write_cache', '--cache_dir', '/blah/', '--max_frate_coeffs', '0.024', '-0.229', '--mode', 'dayenu'] - parser = xf.xtalk_filter_argparser() - a = parser.parse_args() - assert a.datafilelist == ['a'] - assert a.clobber is True - assert a.write_cache is True - assert a.cache_dir == '/blah/' - assert a.max_frate_coeffs[0] == 0.024 - assert a.max_frate_coeffs[1] == -0.229 - assert a.time_thresh == 0.05 - assert not a.factorize_flags - parser = xf.xtalk_filter_argparser() - a = parser.parse_args() - assert a.datafilelist == ['a'] - assert a.clobber is True - assert a.write_cache is True - assert a.cache_dir == '/blah/' - assert a.max_frate_coeffs[0] == 0.024 - assert a.max_frate_coeffs[1] == -0.229 - assert a.time_thresh == 0.05 - assert not a.factorize_flags diff --git a/hera_cal/vis_clean.py b/hera_cal/vis_clean.py index e7f8a3ce6..42e53e70d 100644 --- a/hera_cal/vis_clean.py +++ b/hera_cal/vis_clean.py @@ -19,6 +19,7 @@ from .datacontainer import DataContainer from .utils import echo from .flag_utils import factorize_flags +import tensorflow as tf def find_discontinuity_edges(x, xtol=1e-3): @@ -793,6 +794,7 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean', keep_flags=False, clean_flags_in_resid_flags=False, skip_if_flag_within_edge_distance=0, flag_model_rms_outliers=False, model_rms_threshold=1.1, + gpu_index=None, gpu_memory_limit=None, **filter_kwargs): """ Generalized fourier filtering wrapper for uvtools.dspec.fourier_filter. @@ -981,6 +983,19 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean', unnecessarily long. If it is too high, clean does a poor job of deconvolving. 'alpha': float, if window is 'tukey', this is its alpha parameter. 
""" + gpus = tf.config.list_physical_devices("GPU") + if gpu_index is not None: + # See https://www.tensorflow.org/guide/gpu + if gpus: + if gpu_memory_limit is None: + tf.config.set_visible_devices(gpus[gpu_index], "GPU") + else: + tf.config.set_logical_device_configuration( + gpus[gpu_index], [tf.config.LogicalDeviceConfiguration(memory_limit=gpu_memory_limit * 1024)] + ) + + logical_gpus = tf.config.list_logical_devices("GPU") + echo(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU", verbose=verbose) # type checks if ax == 'both': if zeropad is None: @@ -1113,10 +1128,18 @@ def fourier_filter(self, filter_centers, filter_half_widths, mode='clean', win = flag_rows_with_flags_within_edge_distance(xp, win, skip_if_flag_within_edge_distance, ax=ax) mdl, res = np.zeros_like(d), np.zeros_like(d) - mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers, - filter_half_widths=filter_half_widths, - mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt, - **filter_kwargs) + # limit computations to use specified GPU if they were provided. + if gpu_index is not None and gpus: + with tf.device(f"/device:GPU:{gpus[gpu_index].name[-1]}"): + mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers, + filter_half_widths=filter_half_widths, + mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt, + **filter_kwargs) + else: + mdl, res, info = dspec.fourier_filter(x=xp, data=din, wgts=win, filter_centers=filter_centers, + filter_half_widths=filter_half_widths, + mode=mode, filter_dims=filterdim, skip_wgt=skip_wgt, + **filter_kwargs) # insert back the filtered model if we are skipping flagged edgs. if skip_flagged_edges: mdl = restore_flagged_edges(xp, mdl, edges, ax=ax) @@ -1920,10 +1943,12 @@ def list_of_int_tuples(v): ap.add_argument("--flag_yaml", default=None, type=str, help="path to a flagging yaml containing apriori antenna, freq, and time flags.") ap.add_argument("--polarizations", default=None, type=str, nargs="+", help="list of polarizations to filter.") ap.add_argument("--verbose", default=False, action="store_true", help="Lots of text.") + ap.add_argument("--skip_if_flag_within_edge_distance", type=int, default=0, help="skip integrations channels if there is a flag within this integer distance of edge.") ap.add_argument("--filter_spw_ranges", default=None, type=list_of_int_tuples, help="List of spw channel selections to filter independently. Two acceptable formats are " "Ex1: '200~300,500~650' --> [(200, 300), (500, 650), ...] and " "Ex2: '200 300, 500 650' --> [(200, 300), (500, 650), ...]") - # clean arguments. + ap.add_argument("--use_tensorflow", default=False, action="store_true", help="If provided, will use tensorflow GPU accelerated methods where possible.") + # Arguments for CLEAN. Not used in linear filtering methods. 
clean_options = ap.add_argument_group(title='Options for CLEAN (arguments only used if mode=="clean"!)') clean_options.add_argument("--window", type=str, default='blackman-harris', help='window function for frequency filtering (default "blackman-harris",\ see uvtools.dspec.gen_window for options') @@ -1932,10 +1957,12 @@ def list_of_int_tuples(v): clean_options.add_argument("--edgecut_hi", default=0, type=int, help="Number of channels to flag on upper band edge and exclude from window function.") clean_options.add_argument("--gain", type=float, default=0.1, help="Fraction of residual to use in each iteration.") clean_options.add_argument("--alpha", type=float, default=.5, help="If window='tukey', use this alpha parameter (default .5).") + # Options for caching for linear filtering. cache_options = ap.add_argument_group(title='Options for caching (arguments only used if mode!="clean")') cache_options.add_argument("--write_cache", default=False, action="store_true", help="if True, writes newly computed filter matrices to cache.") cache_options.add_argument("--cache_dir", type=str, default=None, help="directory to store cached filtering matrices in.") cache_options.add_argument("--read_cache", default=False, action="store_true", help="If true, read in cache files in directory specified by cache_dir.") + # Options that are only used for linear filters like dayenu and dpss_leastsq. linear_options = ap.add_argument_group(title="Options for linear filtering (dayenu and dpss_leastsq)") linear_options.add_argument("--max_contiguous_edge_flags", type=int, default=1, help="Skip integrations with at least this number of contiguous edge flags.") return ap diff --git a/hera_cal/xtalk_filter.py b/hera_cal/xtalk_filter.py deleted file mode 100644 index 8f43dbb4d..000000000 --- a/hera_cal/xtalk_filter.py +++ /dev/null @@ -1,218 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 the HERA Project -# Licensed under the MIT License - -"""Module for xtalk filtering data and related operations.""" - -import numpy as np - -from . import io -from . import version -from .vis_clean import VisClean -from . import vis_clean - -import pickle -import random -import glob -import os -import warnings -from copy import deepcopy -from pyuvdata import UVCal -import argparse - - -class XTalkFilter(VisClean): - """ - XTalkFilter object. - - Used for fringe-rate Xtalk CLEANing and filtering. - See vis_clean.VisClean for CLEAN functions. - """ - - def run_xtalk_filter(self, to_filter=None, weight_dict=None, max_frate_coeffs=[0.024, -0.229], mode='clean', - skip_wgt=0.1, tol=1e-9, cache_dir=None, read_cache=False, - write_cache=False, skip_flagged_edges=False, keep_flags=True, - data=None, flags=None, **filter_kwargs): - ''' - Run a cross-talk filter on data where the maximum fringe rate is set by the baseline length. - - Run a delay-filter on (a subset of) the data stored in the object. - Uses stored flags unless explicitly overridden with weight_dict. - - Arguments: - to_filter: list of visibilities to filter in the (i,j,pol) format. - If None (the default), all visibilities are filtered. - weight_dict: dictionary or DataContainer with all the same keys as self.data. - Linear multiplicative weights to use for the delay filter. Default, use np.logical_not - of self.flags. uvtools.dspec.xtalk_filter will renormalize to compensate. - max_frate_coeffs: All fringe-rates below this value are filtered (or interpolated) (in milliseconds). - max_frate [mHz] = x1 * EW_bl_len [ m ] + x2 - mode: string specifying filtering mode. 
See fourier_filter or uvtools.dspec.xtalk_filter for supported modes. - skip_wgt: skips filtering rows with very low total weight (unflagged fraction ~< skip_wgt). - Model is left as 0s, residual is left as data, and info is {'skipped': True} for that - time. Skipped channels are then flagged in self.flags. - Only works properly when all weights are all between 0 and 1. - tol : float, optional. To what level are foregrounds subtracted. - cache_dir: string, optional, path to cache file that contains pre-computed dayenu matrices. - see uvtools.dspec.dayenu_filter for key formats. - read_cache: bool, If true, read existing cache files in cache_dir before running. - write_cache: bool. If true, create new cache file with precomputed matrices - that were not in previously loaded cache files. - cache: dictionary containing pre-computed filter products. - skip_flagged_edges : bool, if true do not include edge times in filtering region (filter over sub-region). - keep_flags : bool, if true, retain data flags in filled data. - filter_kwargs: see fourier_filter for a full list of filter_specific arguments. - - Results are stored in: - self.clean_resid: DataContainer formatted like self.data with only high-delay components - self.clean_model: DataContainer formatted like self.data with only low-delay components - self.clean_info: Dictionary of info from uvtools.dspec.xtalk_filter with the same keys as self.data - ''' - # read in cache - if not mode == 'clean': - if read_cache: - filter_cache = io.read_filter_cache_scratch(cache_dir) - else: - filter_cache = {} - keys_before = list(filter_cache.keys()) - else: - filter_cache = None - # compute maximum fringe rate dict based on EW baseline lengths. - max_frate = io.DataContainer({k: np.max([max_frate_coeffs[0] * self.blvecs[k[:2]][0] + max_frate_coeffs[1], 0.0]) for k in self.data}) - # loop over all baselines in increments of Nbls - self.vis_clean(keys=to_filter, data=self.data, flags=self.flags, wgts=weight_dict, - ax='time', x=(self.times - np.mean(self.times)) * 24. * 3600., - cache=filter_cache, mode=mode, tol=tol, skip_wgt=skip_wgt, max_frate=max_frate, - overwrite=True, skip_flagged_edges=skip_flagged_edges, - keep_flags=keep_flags, **filter_kwargs) - if not mode == 'clean': - if write_cache: - filter_cache = io.write_filter_cache_scratch(filter_cache, cache_dir, skip_keys=keys_before) - - -def load_xtalk_filter_and_write(datafile_list, baseline_list=None, calfile_list=None, - Nbls_per_load=None, spw_range=None, cache_dir=None, - read_cache=False, write_cache=False, external_flags=None, - factorize_flags=False, time_thresh=0.05, - res_outfilename=None, CLEAN_outfilename=None, filled_outfilename=None, - clobber=False, add_to_history='', avg_red_bllens=False, polarizations=None, - skip_flagged_edges=False, overwrite_flags=False, - flag_yaml=None, - clean_flags_in_resid_flags=True, **filter_kwargs): - ''' - A xtalk filtering method that only simultaneously loads and writes user-provided - list of baselines. This is to support parallelization over baseline (rather then time). - While this function reads from multiple files (in datafile_list) - it always writes to a single file for the resid, filled, and model files. - - Arguments: - datafile_list: list of data files to perform cross-talk filtering on - baseline_list: list of antenna-pair 2-tuples. - to filter and write out from the datafile_list. - If None, load all baselines in files in datafile_list. Default is None. 
- calfile_list: optional list of calibration files to apply to data before xtalk filtering - Nbls_per_load: int, the number of baselines to load at once. - If None, load all baselines at once. default : None. - spw_range: 2-tuple or 2-list, spw_range of data to filter. - cache_dir: string, optional, path to cache file that contains pre-computed dayenu matrices. - see uvtools.dspec.dayenu_filter for key formats. - read_cache: bool, If true, read existing cache files in cache_dir before running. - write_cache: bool. If true, create new cache file with precomputed matrices - that were not in previously loaded cache files. - factorize_flags: bool, optional - If True, factorize flags before running delay filter. See vis_clean.factorize_flags. - time_thresh : float, optional - Fractional threshold of flagged pixels across time needed to flag all times - per freq channel. It is not recommend to set this greater than 0.5. - Fully flagged integrations do not count towards triggering time_thresh. - res_outfilename: path for writing the filtered visibilities with flags - CLEAN_outfilename: path for writing the CLEAN model visibilities (with the same flags) - filled_outfilename: path for writing the original data but with flags unflagged and replaced - with CLEAN models wherever possible - clobber: if True, overwrites existing file at the outfilename - add_to_history: string appended to the history of the output file - avg_red_bllens: bool, if True, round baseline lengths to redundant average. Default is False. - polarizations : list of polarizations to process (and write out). Default None operates on all polarizations in data. - skip_flagged_edges : bool, if true do not include edge times in filtering region (filter over sub-region). - overwrite_flags : bool, if true reset data flags to False except for flagged antennas. - flag_yaml: path to manual flagging text file. - clean_flags_in_resid_flags: bool, optional. If true, include clean flags in residual flags that get written. - default is True. - filter_kwargs: additional keyword arguments to be passed to XTalkFilter.run_xtalk_filter() - ''' - if baseline_list is not None and Nbls_per_load is not None: - raise NotImplementedError("baseline loading and partial i/o not yet implemented.") - hd = io.HERAData(datafile_list, filetype='uvh5', axis='blt') - if baseline_list is None: - if len(hd.filepaths) > 1: - baseline_list = list(hd.bls.values())[0] - else: - baseline_list = hd.bls - if len(baseline_list) == 0: - warnings.warn("Length of baseline list is zero." - "This can happen under normal circumstances when there are more files in datafile_list then baselines." - "in your dataset. 
Exiting without writing any output.", RuntimeWarning) - else: - if spw_range is None: - spw_range = [0, hd.Nfreqs] - freqs = hd.freq_array.flatten()[spw_range[0]:spw_range[1]] - baseline_antennas = [] - for blpolpair in baseline_list: - baseline_antennas += list(blpolpair[:2]) - baseline_antennas = np.unique(baseline_antennas).astype(int) - if calfile_list is not None: - cals = io.HERACal(calfile_list) - cals.read(antenna_nums=baseline_antennas, frequencies=freqs) - else: - cals = None - if polarizations is None: - if len(hd.filepaths) > 1: - polarizations = list(hd.pols.values())[0] - else: - polarizations = hd.pols - if Nbls_per_load is None: - Nbls_per_load = len(baseline_list) - for i in range(0, len(baseline_list), Nbls_per_load): - xf = XTalkFilter(hd, input_cal=cals, axis='blt') - xf.read(bls=baseline_list[i:i + Nbls_per_load], frequencies=freqs) - if avg_red_bllens: - xf.avg_red_baseline_vectors() - if external_flags is not None: - xf.apply_flags(external_flags, overwrite_flags=overwrite_flags) - if flag_yaml is not None: - xf.apply_flags(flag_yaml, overwrite_flags=overwrite_flags, filetype='yaml') - if factorize_flags: - xf.factorize_flags(time_thresh=time_thresh, inplace=True) - xf.run_xtalk_filter(cache_dir=cache_dir, read_cache=read_cache, write_cache=write_cache, - skip_flagged_edges=skip_flagged_edges, **filter_kwargs) - xf.write_filtered_data(res_outfilename=res_outfilename, CLEAN_outfilename=CLEAN_outfilename, - filled_outfilename=filled_outfilename, partial_write=Nbls_per_load < len(baseline_list), - clobber=clobber, add_to_history=add_to_history, - extra_attrs={'Nfreqs': xf.hd.Nfreqs, 'freq_array': xf.hd.freq_array}) - xf.hd.data_array = None # this forces a reload in the next loop - -# ------------------------------------------ -# Here are arg-parsers for xtalk-filtering. -# ------------------------------------------ - - -def xtalk_filter_argparser(): - '''Arg parser for commandline operation of xtalk filters. - - Parameters - ---------- - mode : string, optional. - Determines sets of arguments to load. - Can be 'clean', 'dayenu', or 'dpss_leastsq'. 
-
-    Returns
-    -------
-    argparser
-        argparser for xtalk (time-domain) filtering for specified filtering mode
-
-    '''
-    a = vis_clean._filter_argparser()
-    filt_options = a.add_argument_group(title='Options for the cross-talk filter')
-    filt_options.add_argument("--max_frate_coeffs", type=float, nargs=2, help="Maximum fringe-rate coefficients for the model max_frate [mHz] = x1 * EW_bl_len [ m ] + x2.")
-    filt_options.add_argument("--skip_if_flag_within_edge_distance", type=int, default=0, help="skip integrations channels if there is a flag within this integer distance of edge.")
-    return a
diff --git a/scripts/delay_filter_run.py b/scripts/delay_filter_run.py
index 3e5d7a35d..1329ed869 100644
--- a/scripts/delay_filter_run.py
+++ b/scripts/delay_filter_run.py
@@ -60,5 +60,5 @@
                              CLEAN_outfilename=ap.CLEAN_outfilename,
                              standoff=ap.standoff, horizon=ap.horizon, tol=ap.tol, skip_wgt=ap.skip_wgt,
                              min_dly=ap.min_dly, zeropad=ap.zeropad,
-                             filter_spw_ranges=ap.filter_spw_ranges,
+                             filter_spw_ranges=ap.filter_spw_ranges, use_tensorflow=ap.use_tensorflow,
                              clean_flags_in_resid_flags=True, **filter_kwargs)
diff --git a/scripts/tophat_frfilter_run.py b/scripts/tophat_frfilter_run.py
new file mode 100644
index 000000000..1459e45ee
--- /dev/null
+++ b/scripts/tophat_frfilter_run.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2020 the HERA Project
+# Licensed under the MIT License
+
+"""Command-line driver script for tophat fringe-rate filtering with hera_cal.frf, supporting baseline parallelization."""
+
+from hera_cal import frf
+import sys
+import hera_cal.io as io
+
+parser = frf.tophat_frfilter_argparser()
+
+ap = parser.parse_args()
+
+# set kwargs
+if ap.mode == 'clean':
+    filter_kwargs = {'window': ap.window,
+                     'maxiter': ap.maxiter, 'edgecut_hi': ap.edgecut_hi,
+                     'edgecut_low': ap.edgecut_low, 'gain': ap.gain}
+    if ap.window == 'tukey':
+        filter_kwargs['alpha'] = ap.alpha
+    avg_red_bllens = False
+elif ap.mode == 'dayenu':
+    filter_kwargs = {}
+    avg_red_bllens = True
+    filter_kwargs['max_contiguous_edge_flags'] = 10000
+    filter_kwargs['skip_contiguous_flags'] = False
+    filter_kwargs['skip_flagged_edges'] = False
+    filter_kwargs['flag_model_rms_outliers'] = False
+elif ap.mode in ('dpss_leastsq', 'dft_leastsq'):
+    filter_kwargs = {}
+    avg_red_bllens = True
+    filter_kwargs['skip_contiguous_flags'] = True
+    filter_kwargs['skip_flagged_edges'] = True
+    filter_kwargs['max_contiguous_edge_flags'] = 1
+    filter_kwargs['flag_model_rms_outliers'] = True
+
+if ap.cornerturnfile is not None:
+    baseline_list = io.baselines_from_filelist_position(filename=ap.cornerturnfile, filelist=ap.datafilelist)
+else:
+    baseline_list = None
+
+# select the spw range to filter.
+spw_range = ap.spw_range
+# allow the string 'none' to be passed through for ap.calfilelist
+if isinstance(ap.calfilelist, str) and ap.calfilelist.lower() == 'none':
+    ap.calfilelist = None
+# run the tophat fringe-rate filter.
+frf.load_tophat_frfilter_and_write(ap.datafilelist, calfile_list=ap.calfilelist, avg_red_bllens=avg_red_bllens,
+                                   baseline_list=baseline_list, spw_range=spw_range,
+                                   cache_dir=ap.cache_dir, filled_outfilename=ap.filled_outfilename,
+                                   clobber=ap.clobber, write_cache=ap.write_cache, CLEAN_outfilename=ap.CLEAN_outfilename,
+                                   read_cache=ap.read_cache, mode=ap.mode, res_outfilename=ap.res_outfilename,
+                                   factorize_flags=ap.factorize_flags, time_thresh=ap.time_thresh,
+                                   add_to_history=' '.join(sys.argv), verbose=ap.verbose,
+                                   flag_yaml=ap.flag_yaml, Nbls_per_load=ap.Nbls_per_load,
+                                   external_flags=ap.external_flags, filter_spw_ranges=ap.filter_spw_ranges,
+                                   overwrite_flags=ap.overwrite_flags, skip_autos=ap.skip_autos,
+                                   skip_if_flag_within_edge_distance=ap.skip_if_flag_within_edge_distance,
+                                   zeropad=ap.zeropad, tol=ap.tol, skip_wgt=ap.skip_wgt, max_frate_coeffs=ap.max_frate_coeffs,
+                                   frate_standoff=ap.frate_standoff, min_frate_half_width=ap.min_frate_half_width,
+                                   frate_width_multiplier=ap.frate_width_multiplier, fr_freq_skip=ap.fr_freq_skip,
+                                   uvbeam=ap.uvbeam, percentile_low=ap.percentile_low, percentile_high=ap.percentile_high,
+                                   clean_flags_in_resid_flags=True, use_tensorflow=ap.use_tensorflow, **filter_kwargs)
diff --git a/scripts/xtalk_filter_run.py b/scripts/xtalk_filter_run.py
deleted file mode 100644
index 0004171d9..000000000
--- a/scripts/xtalk_filter_run.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright 2020 the HERA Project
-# Licensed under the MIT License
-
-"""Command-line drive script for hera_cal.xtalk_filter with baseline parallelization. Only performs DAYENU Filtering"""
-
-from hera_cal import xtalk_filter
-import sys
-import hera_cal.io as io
-
-parser = xtalk_filter.xtalk_filter_argparser()
-
-a = parser.parse_args()
-
-# set kwargs
-if ap.mode == 'clean':
-    filter_kwargs = {'window': ap.window,
-                     'skip_wgt': ap.skip_wgt, 'maxiter': ap.maxiter, 'edgecut_hi': ap.edgecut_hi,
-                     'edgecut_low': ap.edgecut_low, 'gain': ap.gain}
-    if ap.window == 'tukey':
-        filter_kwargs['alpha'] = ap.alpha
-    avg_red_bllens = False
-elif ap.mode == 'dayenu':
-    filter_kwargs = {}
-    avg_red_bllens = True
-    filter_kwargs['max_contiguous_edge_flags'] = 10000
-    filter_kwargs['skip_contiguous_flags'] = False
-    filter_kwargs['skip_flagged_edges'] = False
-    filter_kwargs['flag_model_rms_outliers'] = False
-elif ap.mode == 'dpss_leastsq':
-    filter_kwargs = {}
-    avg_red_bllens = True
-    filter_kwargs['skip_contiguous_flags'] = True
-    filter_kwargs['skip_flagged_edges'] = True
-    filter_kwargs['max_contiguous_edge_flags'] = 1
-    filter_kwargs['flag_model_rms_outliers'] = True
-filter_kwargs['zeropad'] = a.zeropad
-
-if args.cornerturnfile is not None:
-    baseline_list = io.baselines_from_filelist_position(filename=ap.cornerturnfile, filelist=ap.datafilelist)
-else:
-    baseline_list = None
-
-# modify output file name to include index.
-spw_range = ap.spw_range -# allow none string to be passed through to ap.calfile -if isinstance(ap.calfilelist, str) and ap.calfilelist.lower() == 'none': - ap.calfilelist = None -# Run Xtalk Filter -xtalk_filter.load_xtalk_filter_and_write(ap.datafilelist, calfile_list=ap.calfilelist, avg_red_bllens=True, - baseline_list=baseline_list, spw_range=ap.spw_range, - cache_dir=ap.cache_dir, filled_outfilename=ap.filled_outfilename, - clobber=ap.clobber, write_cache=ap.write_cache, CLEAN_outfilename=ap.CLEAN_outfilename, - read_cache=ap.read_cache, mode=ap.mode, res_outfilename=ap.res_outfilename, - factorize_flags=ap.factorize_flags, time_thresh=ap.time_thresh, - add_to_history=' '.join(sys.argv), verbose=ap.verbose, - tol=ap.tol, max_frate_coeffs=ap.max_frate_coeffs, - flag_yaml=ap.flag_yaml, Nbls_per_load=ap.Nbls_per_load, - external_flags=ap.external_flags, frate_standoff=ap.frate_standoff, - overwrite_flags=ap.overwrite_flags, - clean_flags_in_resid_flags=True, **filter_kwargs) diff --git a/setup.py b/setup.py index 0cadf1d28..03f9df756 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def package_files(package_dir, subdirectory): 'scripts/smooth_cal_run.py', 'scripts/redcal_run.py', 'scripts/auto_reflection_run.py', 'scripts/noise_from_autos.py', 'scripts/query_ex_ants.py', 'scripts/red_average.py', - 'scripts/xtalk_filter_run.py', 'scripts/time_average.py', + 'scripts/time_average.py', 'scripts/tophat_frfilter_run.py', 'scripts/time_chunk_from_baseline_chunks_run.py', 'scripts/chunk_files.py', 'scripts/transfer_flags.py', 'scripts/flag_all.py', 'scripts/throw_away_flagged_antennas.py', 'scripts/select_spw_ranges.py'], 'version': version.version, @@ -58,7 +58,7 @@ def package_files(package_dir, subdirectory): 'extras_require': { "all": [ 'aipy>=3.0', - 'uvtools @ git+git://github.com/HERA-Team/uvtools', + 'uvtools @ git+git://github.com/HERA-Team/uvtools@gpu_support', ] }, 'zip_safe': False,
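
Usage note: the GPU selection logic added to VisClean.fourier_filter follows the standard TensorFlow 2.x device-configuration pattern. The sketch below is illustrative rather than part of the changeset; pin_gpu is a hypothetical helper name, and it assumes (as the "* 1024" scaling in the diff implies) that the memory limit is given in GiB while LogicalDeviceConfiguration expects MiB.

    import tensorflow as tf

    def pin_gpu(gpu_index, gpu_memory_limit=None):
        # Hypothetical helper mirroring the device-selection branch added to
        # VisClean.fourier_filter above; gpu_memory_limit is in GiB.
        gpus = tf.config.list_physical_devices("GPU")
        if not gpus:
            return None  # CPU-only host: nothing to pin.
        if gpu_memory_limit is None:
            # Hide the other GPUs; the chosen one becomes logical device GPU:0.
            tf.config.set_visible_devices(gpus[gpu_index], "GPU")
        else:
            # Keep all GPUs visible but cap TensorFlow's footprint on this one.
            tf.config.set_logical_device_configuration(
                gpus[gpu_index],
                [tf.config.LogicalDeviceConfiguration(memory_limit=gpu_memory_limit * 1024)])
        return tf.config.list_logical_devices("GPU")

Note that device configuration must happen before TensorFlow initializes the GPUs, which is why the new code runs it at the top of fourier_filter rather than just before the dspec.fourier_filter call.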
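Usage note: the consolidated tophat fringe-rate filter can also be invoked from Python rather than through scripts/tophat_frfilter_run.py. A minimal sketch, with placeholder file names; the keyword arguments mirror the calls that appear in the tests above.

    from hera_cal import frf

    # Filter a single uvh5 file and write the residual; 'zen.example.uvh5'
    # and the output name are placeholders, not files shipped with hera_cal.
    frf.load_tophat_frfilter_and_write(
        ['zen.example.uvh5'],
        res_outfilename='zen.example.frf.res.uvh5',
        mode='dpss_leastsq',
        max_frate_coeffs=[0.024, -0.229],  # max_frate [mHz] = x1 * EW_bl_len [m] + x2
        skip_flagged_edges=True,
        clobber=True)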