From a63266ac9c3d4c998fddc10d9bcef8ba16b59fa2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jul 2023 01:16:06 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- hera_commissioning_tools/plots.py | 466 ++++++++++++++++++------------ hera_commissioning_tools/utils.py | 198 +++++++------ setup.py | 20 +- 3 files changed, 404 insertions(+), 280 deletions(-) diff --git a/hera_commissioning_tools/plots.py b/hera_commissioning_tools/plots.py index a277711..19414b6 100644 --- a/hera_commissioning_tools/plots.py +++ b/hera_commissioning_tools/plots.py @@ -35,15 +35,20 @@ githash = utils.get_git_revision_hash() curr_file = os.path.dirname(os.path.abspath(__file__)) -def testWriteArgs(x,y,z,bananas,h=5,p=[1,2,3],m=False,outfig='testWriteArgs.png'): + +def testWriteArgs( + x, y, z, bananas, h=5, p=[1, 2, 3], m=False, outfig="testWriteArgs.png" +): import inspect + fig = plt.figure() - plt.plot([1,2],[4,5]) + plt.plot([1, 2], [4, 5]) plt.savefig(outfig) if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) + def plot_autos( uvd, @@ -53,7 +58,7 @@ def plot_autos( savefig=False, write_params=True, outfig="", - plot_nodes='all', + plot_nodes="all", time_slice=False, slice_freq_inds=[], dtype="sky", @@ -96,27 +101,27 @@ def plot_autos( times = uvd.time_array maxants = 0 for node in nodes: - if plot_nodes is not 'all' and node not in plot_nodes: + if plot_nodes is not "all" and node not in plot_nodes: continue n = len(nodes[node]["ants"]) if n > maxants: maxants = n Nside = maxants - if plot_nodes == 'all': + if plot_nodes == "all": Yside = len(inclNodes) else: Yside = len(plot_nodes) - t_index = len(np.unique(times))//2 + t_index = len(np.unique(times)) // 2 jd = times[t_index] utc = Time(jd, format="jd").datetime h = cm_active.get_active(at_date=jd, float_format="jd") xlim = (np.min(freqs), np.max(freqs)) - colorsx = ['gold','darkorange','orangered','red','maroon'] - colorsy = ['powderblue','deepskyblue','dodgerblue','royalblue','blue'] + colorsx = ["gold", "darkorange", "orangered", "red", "maroon"] + colorsy = ["powderblue", "deepskyblue", "dodgerblue", "royalblue", "blue"] if ylim is None: if dtype == "sky": @@ -133,11 +138,11 @@ def plot_autos( fig.tight_layout(rect=(0, 0, 1, 0.95)) fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=1, wspace=0.05, hspace=0.3) k = 0 - i=-1 + i = -1 yrange_set = False for _, n in enumerate(inclNodes): slots_filled = [] - if plot_nodes is not 'all': + if plot_nodes is not "all": inclNodes = plot_nodes if n not in plot_nodes: continue @@ -148,63 +153,67 @@ def plot_autos( if a not in ants: continue status = utils.get_ant_status(h, a) - slot = utils.get_slot_number(uvd, a, sorted_ants, sortedSnapLocs, sortedSnapInputs) + slot = utils.get_slot_number( + uvd, a, sorted_ants, sortedSnapLocs, sortedSnapInputs + ) slots_filled.append(slot) ax = axes[i, slot] if time_slice is True: - colors = ['r','b'] + colors = ["r", "b"] lsts = uvd.lst_array * 3.819719 inds = np.unique(lsts, return_index=True)[1] lsts = np.asarray([lsts[ind] for ind in sorted(inds)]) -# if a == sorted_ants[0]: -# print(lsts) -# print(len(lsts)) - dx = np.log10(np.abs(uvd.get_data((a,a,'xx')))) - for s,ind in enumerate(slice_freq_inds): - dslice = dx[:,ind] - dslice = np.subtract(dslice,np.nanmean(dslice)) + # if a == sorted_ants[0]: + # print(lsts) + # print(len(lsts)) + dx = np.log10(np.abs(uvd.get_data((a, a, "xx")))) + for s, ind in enumerate(slice_freq_inds): + dslice = dx[:, ind] + dslice = np.subtract(dslice, np.nanmean(dslice)) (px,) = ax.plot( dslice, color=colorsx[s], alpha=1, linewidth=1.2, - label=f'XX - {int(freqs[ind])} MHz' + label=f"XX - {int(freqs[ind])} MHz", ) - dy = np.log10(np.abs(uvd.get_data((a,a,'yy')))) - for s,ind in enumerate(slice_freq_inds): - dslice = dy[:,ind] - dslice = np.subtract(dslice,np.nanmean(dslice)) -# print(dslice) + dy = np.log10(np.abs(uvd.get_data((a, a, "yy")))) + for s, ind in enumerate(slice_freq_inds): + dslice = dy[:, ind] + dslice = np.subtract(dslice, np.nanmean(dslice)) + # print(dslice) (py,) = ax.plot( dslice, color=colorsy[s], alpha=1, linewidth=1.2, - label=f'YY - {int(freqs[ind])} MHz' + label=f"YY - {int(freqs[ind])} MHz", ) if yrange_set is False: - xdiff = np.abs(np.subtract(np.nanmax(dx),np.nanmin(dx))) - ydiff = np.abs(np.subtract(np.nanmax(dy),np.nanmin(dy))) - yrange = np.nanmax([xdiff,ydiff]) + xdiff = np.abs(np.subtract(np.nanmax(dx), np.nanmin(dx))) + ydiff = np.abs(np.subtract(np.nanmax(dy), np.nanmin(dy))) + yrange = np.nanmax([xdiff, ydiff]) if math.isinf(yrange) or math.isnan(yrange): yrange = 10 else: yrange_set = True if ylim is None: - ymin = np.nanmin([np.nanmin(dx),np.nanmin(dy)]) + ymin = np.nanmin([np.nanmin(dx), np.nanmin(dy)]) if math.isinf(ymin): - ymin=0 + ymin = 0 if math.isnan(ymin): - ymin=0 + ymin = 0 ylim = (ymin, ymin + yrange) - if yrange > 3*np.abs(np.subtract(np.nanmax(dx),np.nanmin(dx))) or yrange > 3*np.abs(np.subtract(np.nanmax(dy),np.nanmin(dy))): - xdiff = np.abs(np.subtract(np.nanmax(dx),np.nanmin(dx))) - ydiff = np.abs(np.subtract(np.nanmax(dy),np.nanmin(dy))) - yrange_temp = np.nanmax([xdiff,ydiff]) + if yrange > 3 * np.abs( + np.subtract(np.nanmax(dx), np.nanmin(dx)) + ) or yrange > 3 * np.abs(np.subtract(np.nanmax(dy), np.nanmin(dy))): + xdiff = np.abs(np.subtract(np.nanmax(dx), np.nanmin(dx))) + ydiff = np.abs(np.subtract(np.nanmax(dy), np.nanmin(dy))) + yrange_temp = np.nanmax([xdiff, ydiff]) ylim = (ymin, ymin + yrange_temp) - ax.tick_params(color='red', labelcolor='red') + ax.tick_params(color="red", labelcolor="red") for spine in ax.spines.values(): - spine.set_edgecolor('red') + spine.set_edgecolor("red") elif logscale is True: (px,) = ax.plot( freqs, @@ -253,7 +262,7 @@ def plot_autos( ) if k == 0: if time_slice: - ax.legend(loc='upper left',bbox_to_anchor=(0, 1.7),ncol=2) + ax.legend(loc="upper left", bbox_to_anchor=(0, 1.7), ncol=2) else: ax.legend([px, py], ["NN", "EE"]) if i == len(inclNodes) - 1: @@ -284,11 +293,11 @@ def plot_autos( f"Node {n}", (1.1, 0.3), xycoords="axes fraction", rotation=270 ) if savefig is True: - plt.savefig(outfig,bbox_inches='tight') + plt.savefig(outfig, bbox_inches="tight") if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.show() else: plt.show() @@ -304,7 +313,7 @@ def plot_wfs( vmin=None, vmax=None, wrongAnts=[], - plot_nodes='all', + plot_nodes="all", logscale=True, uvd_diff=None, metric=None, @@ -403,14 +412,14 @@ def plot_wfs( vmax = vmaxAuto for node in nodes: - if plot_nodes is not 'all' and node not in plot_nodes: + if plot_nodes is not "all" and node not in plot_nodes: continue n = len(nodes[node]["ants"]) if n > maxants: maxants = n Nside = maxants - if plot_nodes == 'all': + if plot_nodes == "all": Yside = len(inclNodes) else: Yside = len(plot_nodes) @@ -427,7 +436,7 @@ def plot_wfs( i = -1 for _, n in enumerate(inclNodes): slots_filled = [] - if plot_nodes is not 'all': + if plot_nodes is not "all": inclNodes = plot_nodes if n not in plot_nodes: continue @@ -438,7 +447,9 @@ def plot_wfs( if a not in ants: continue status = utils.get_ant_status(h, a) - slot = utils.get_slot_number(uvd, a, sorted_ants, sortedSnapLocs, sortedSnapInputs) + slot = utils.get_slot_number( + uvd, a, sorted_ants, sortedSnapLocs, sortedSnapInputs + ) slots_filled.append(slot) abb = status_abbreviations[status] ax = axes[i, slot] @@ -543,7 +554,7 @@ def plot_wfs( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.show() plt.close() @@ -687,8 +698,8 @@ def waterfall_lineplot( if sliceind == None: dat = np.abs(dat[len(dat) // 2, :]) else: - dat = np.abs(dat[sliceind,:]) - waterfall.axhline(sliceind,color='r') + dat = np.abs(dat[sliceind, :]) + waterfall.axhline(sliceind, color="r") plt.plot(freq, dat) line2.set_yscale("log") line2.set_xlabel("Frequency (MHz)") @@ -727,7 +738,7 @@ def waterfall_lineplot( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.show() plt.close() @@ -1032,13 +1043,25 @@ def plotVisibilitySpectra(uv, use_ants="all", badAnts=[], pols=["xx", "yy"]): fig.subplots_adjust(top=0.94, wspace=0.05) plt.show() plt.close() - -def plotCrossAndAutos(uv,crossVis,lsts,bls,ant1,ant2,plot_type = 'raw',savefig=False,write_params=True,outfig='', - norm_type='abs',pols=['EE','NN']): + +def plotCrossAndAutos( + uv, + crossVis, + lsts, + bls, + ant1, + ant2, + plot_type="raw", + savefig=False, + write_params=True, + outfig="", + norm_type="abs", + pols=["EE", "NN"], +): """ Plots cross visibility waterfall alongside waterfalls of constituent autos. Commonly used in conjunction with utils.loadCrossVis. - + Parameters: ---------- uv: UVData Object @@ -1063,102 +1086,130 @@ def plotCrossAndAutos(uv,crossVis,lsts,bls,ant1,ant2,plot_type = 'raw',savefig=F Can be any of 'abs','real','imag','phase' - sets how the complex cross visibilities are handled. pols: List Set of polarizations to include. Can be any set of ['EE','NN','EN','NE'] - + Returns: -------- None - + """ - + import matplotlib.colors as colors from matplotlib import pyplot as plt - - fig, ax = plt.subplots(len(pols),3,figsize=(16,10)) - freqs = uv.freq_array[0]*1e-6 + + fig, ax = plt.subplots(len(pols), 3, figsize=(16, 10)) + freqs = uv.freq_array[0] * 1e-6 yticks = [int(i) for i in np.linspace(0, len(lst_array) - 1, 6)] yticklabels = [np.around(lst_array[ytick], 1) for ytick in yticks] xticks = [int(i) for i in np.linspace(0, len(freqs) - 1, 5)] xticklabels = np.around(freqs[xticks], 0) - for i,p in enumerate(pols): - ind1 = bls.index((ant1,ant1)) - ind2 = bls.index((ant2,ant2)) - indcross = bls.index((ant1,ant2)) - auto1 = np.log10(abs(crossVis[:,:,i,ind1])) - auto2 = np.log10(abs(crossVis[:,:,i,ind2])) - cmap = 'viridis' - if norm_type == 'abs': - cross = np.log10(abs(crossVis[:,:,i,indcross])) + for i, p in enumerate(pols): + ind1 = bls.index((ant1, ant1)) + ind2 = bls.index((ant2, ant2)) + indcross = bls.index((ant1, ant2)) + auto1 = np.log10(abs(crossVis[:, :, i, ind1])) + auto2 = np.log10(abs(crossVis[:, :, i, ind2])) + cmap = "viridis" + if norm_type == "abs": + cross = np.log10(abs(crossVis[:, :, i, indcross])) cmin = 3 cmax = 6 - elif norm_type == 'real': - cross = np.real(crossVis[:,:,i,indcross]) + elif norm_type == "real": + cross = np.real(crossVis[:, :, i, indcross]) cmin = -1e5 cmax = 1e5 - elif norm_type == 'imag': - cross = np.imag(crossVis[:,:,i,indcross]) - cmin =-1e7 - cmax=1e7 - elif norm_type == 'phase': - cross = np.angle(crossVis[:,:,i,indcross]) - cmap = 'twilight' + elif norm_type == "imag": + cross = np.imag(crossVis[:, :, i, indcross]) + cmin = -1e7 + cmax = 1e7 + elif norm_type == "phase": + cross = np.angle(crossVis[:, :, i, indcross]) + cmap = "twilight" cmin = -np.pi cmax = np.pi - - if plot_type == 'raw': - im = ax[i][0].imshow(auto1,aspect='auto',interpolation='nearest',vmin=6,vmax=8) - plt.colorbar(im,ax=ax[i][0]) - im = ax[i][1].imshow(cross,aspect='auto',interpolation='nearest',vmin=cmin,vmax=cmax,cmap=cmap) - plt.colorbar(im,ax=ax[i][1]) - if norm_type in ['real','imag']: - ax[i][1].set_yscale('symlog') - im = ax[i][2].imshow(auto2,aspect='auto',interpolation='nearest',vmin=6,vmax=8) - plt.colorbar(im,ax=ax[i][2]) - elif plot_type == 'ms': - auto1 = np.subtract(auto1,np.nanmean(auto1,axis=0)) - auto2 = np.subtract(auto2,np.nanmean(auto2,axis=0)) - if i<2: - im = ax[i][0].imshow(auto1,aspect='auto',interpolation='nearest',vmin=-0.07,vmax=0.07) - plt.colorbar(im,ax=ax[i][0]) - im = ax[i][2].imshow(auto2,aspect='auto',interpolation='nearest',vmin=-0.07,vmax=0.07) - plt.colorbar(im,ax=ax[i][2]) + + if plot_type == "raw": + im = ax[i][0].imshow( + auto1, aspect="auto", interpolation="nearest", vmin=6, vmax=8 + ) + plt.colorbar(im, ax=ax[i][0]) + im = ax[i][1].imshow( + cross, + aspect="auto", + interpolation="nearest", + vmin=cmin, + vmax=cmax, + cmap=cmap, + ) + plt.colorbar(im, ax=ax[i][1]) + if norm_type in ["real", "imag"]: + ax[i][1].set_yscale("symlog") + im = ax[i][2].imshow( + auto2, aspect="auto", interpolation="nearest", vmin=6, vmax=8 + ) + plt.colorbar(im, ax=ax[i][2]) + elif plot_type == "ms": + auto1 = np.subtract(auto1, np.nanmean(auto1, axis=0)) + auto2 = np.subtract(auto2, np.nanmean(auto2, axis=0)) + if i < 2: + im = ax[i][0].imshow( + auto1, aspect="auto", interpolation="nearest", vmin=-0.07, vmax=0.07 + ) + plt.colorbar(im, ax=ax[i][0]) + im = ax[i][2].imshow( + auto2, aspect="auto", interpolation="nearest", vmin=-0.07, vmax=0.07 + ) + plt.colorbar(im, ax=ax[i][2]) else: - ax[i][0].axis('off') - ax[i][2].axis('off') - if norm_type in ['real','imag']: - im = ax[i][1].imshow(cross,aspect='auto',interpolation='nearest',cmap=cmap, - norm=colors.SymLogNorm(linthresh=0.03, linscale=0.03, - vmin=cmin, vmax=cmax)) + ax[i][0].axis("off") + ax[i][2].axis("off") + if norm_type in ["real", "imag"]: + im = ax[i][1].imshow( + cross, + aspect="auto", + interpolation="nearest", + cmap=cmap, + norm=colors.SymLogNorm( + linthresh=0.03, linscale=0.03, vmin=cmin, vmax=cmax + ), + ) else: - im = ax[i][1].imshow(cross,aspect='auto',interpolation='nearest',vmin=cmin,vmax=cmax,cmap=cmap) - plt.colorbar(im,ax=ax[i][1]) - - for n in [0,1,2]: - if i==3: + im = ax[i][1].imshow( + cross, + aspect="auto", + interpolation="nearest", + vmin=cmin, + vmax=cmax, + cmap=cmap, + ) + plt.colorbar(im, ax=ax[i][1]) + + for n in [0, 1, 2]: + if i == 3: ax[i][n].set_xticks(xticks) ax[i][n].set_xticklabels(xticklabels) else: ax[i][n].set_xticks([]) - if i==0: - ax[i][0].set_title(f'{ant1} auto') - ax[i][1].set_title(f'({ant1},{ant2})') - ax[i][2].set_title(f'{ant2} auto') - plt.suptitle(f'{plot_type} - {norm_type}') + if i == 0: + ax[i][0].set_title(f"{ant1} auto") + ax[i][1].set_title(f"({ant1},{ant2})") + ax[i][2].set_title(f"{ant2} auto") + plt.suptitle(f"{plot_type} - {norm_type}") for n in range(len(pols)): ax[n][0].set_yticks(yticks) ax[n][0].set_yticklabels(yticklabels) ax[n][0].set_ylabel(pols[n]) - for m in [1,2]: + for m in [1, 2]: ax[n][m].set_yticks([]) if savefig: - plt.savefig(outfig,bbox_inches='tight') + plt.savefig(outfig, bbox_inches="tight") if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.close() -def plot_antenna_positions(uv, badAnts=[], flaggedAnts={}, use_ants="all",hexsize=35): +def plot_antenna_positions(uv, badAnts=[], flaggedAnts={}, use_ants="all", hexsize=35): """ Plots the positions of all antennas that have data, colored by node. @@ -1629,7 +1680,7 @@ def plotCrossWaterfallsByCorrValue( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) def plotTimeDifferencedSumWaterfalls( @@ -1755,7 +1806,7 @@ def plotTimeDifferencedSumWaterfalls( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.show() plt.close() return uvd, sdiffs @@ -1774,7 +1825,7 @@ def plot_single_matrix( savefig=False, write_params=True, outfig="", - figtype='png', + figtype="png", cmap="plasma", title="Corr Matrix", incAntLines=False, @@ -1945,7 +1996,7 @@ def plot_single_matrix( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.close() else: plt.show() @@ -2106,13 +2157,13 @@ def makeCorrMatrices( nanDiffs=False, pols=["EE", "NN", "EN", "NE"], interleave="even_odd", - plot_nodes='all', + plot_nodes="all", savefig=False, write_params=True, outfig="", write_uvh5=False, plotMatrices=True, - figtype='png', + figtype="png", printStatusUpdates=False, ): """ @@ -2167,17 +2218,19 @@ def makeCorrMatrices( if nHH <= nfilesUse: use_files_sum = HHfiles else: - use_files_sum = HHfiles[nHH // 2 - nfilesUse // 2 : nHH // 2 + nfilesUse // 2] + use_files_sum = HHfiles[ + nHH // 2 - nfilesUse // 2 : nHH // 2 + nfilesUse // 2 + ] use_files_diff = [file.split("sum")[0] + "diff.uvh5" for file in use_files_sum] print(f"{len(use_files_sum)} times") if len(freq_inds) == 0: print("All frequency bins") - nfreqs = 'all' + nfreqs = "all" else: nfreqs = freq_inds[1] - freq_inds[0] print(f"{nfreqs} frequency bins") - - if plot_nodes is not 'all': + + if plot_nodes is not "all": if sm is None: uv = UVData() uv.read(HHfiles[0]) @@ -2187,9 +2240,9 @@ def makeCorrMatrices( plot_ants = [] for node in nodeDict: if node in plot_nodes: - for ant in nodeDict[node]['ants']: + for ant in nodeDict[node]["ants"]: plot_ants.append(ant) - if use_ants == 'all': + if use_ants == "all": use_ants = plot_ants else: use_ants = [a for a in use_ants if a in plot_ants] @@ -2212,28 +2265,28 @@ def makeCorrMatrices( print("Reading diff files") df.read(use_files_diff, antenna_nums=use_ants) elif df is None and type(dffile) is str: - write_diff=False + write_diff = False if printStatusUpdates: print("Reading diff files") df = UVData() - df.read(dffile,antenna_nums=use_ants) + df.read(dffile, antenna_nums=use_ants) times = np.unique(df.time_array) - if len(times) > 2*nfilesUse: - timesUse = times[0:nfilesUse*2] + if len(times) > 2 * nfilesUse: + timesUse = times[0 : nfilesUse * 2] df.select(times=timesUse) else: - write_diff=False - df = df.select(antenna_nums=use_ants,inplace=False) + write_diff = False + df = df.select(antenna_nums=use_ants, inplace=False) times = np.unique(df.time_array) - if len(times) > 2*nfilesUse: - timesUse = times[0:nfilesUse*2] + if len(times) > 2 * nfilesUse: + timesUse = times[0 : nfilesUse * 2] df.select(times=timesUse) if write_diff: JD = int(df.time_array[0]) if printStatusUpdates: - print('Writing uvh5 diff files') - df.write_uvh5(f'{JD}_{nfilesUse}files_{nfreqs}freqs_diff.uvh5',clobber=True) - + print("Writing uvh5 diff files") + df.write_uvh5(f"{JD}_{nfilesUse}files_{nfreqs}freqs_diff.uvh5", clobber=True) + if sm is None and smfile is None: if write_uvh5: write_sum = True @@ -2244,27 +2297,27 @@ def makeCorrMatrices( print("Reading sum files") sm.read(use_files_sum, antenna_nums=use_ants) elif sm is None and type(smfile) is str: - write_sum=False + write_sum = False if printStatusUpdates: print("Reading sum files") sm = UVData() - sm.read(smfile,antenna_nums=use_ants) + sm.read(smfile, antenna_nums=use_ants) times = np.unique(sm.time_array) - if len(times) > 2*nfilesUse: - timesUse = times[0:nfilesUse*2] + if len(times) > 2 * nfilesUse: + timesUse = times[0 : nfilesUse * 2] sm.select(times=timesUse) else: - write_sum=False - sm = sm.select(antenna_nums=use_ants,inplace=False) + write_sum = False + sm = sm.select(antenna_nums=use_ants, inplace=False) times = np.unique(sm.time_array) - if len(times) > 2*nfilesUse: - timesUse = times[0:nfilesUse*2] + if len(times) > 2 * nfilesUse: + timesUse = times[0 : nfilesUse * 2] sm.select(times=timesUse) if write_sum: if printStatusUpdates: - print('Writing uvh5 sum files') + print("Writing uvh5 sum files") JD = int(sm.time_array[0]) - sm.write_uvh5(f'{JD}_{nfilesUse}files_{nfreqs}freqs_sum.uvh5',clobber=True) + sm.write_uvh5(f"{JD}_{nfilesUse}files_{nfreqs}freqs_sum.uvh5", clobber=True) # Calculate real and imaginary correlation matrices if printStatusUpdates: @@ -2295,8 +2348,16 @@ def makeCorrMatrices( if printStatusUpdates: print("Plotting real matrix") plot_single_matrix( - sm, corr_real, logScale=True, vmin=0.01, title="|Real|", pols=pols, savefig=savefig, write_params=write_params, - figtype=figtype, outfig=f'{outfig}_{nfilesUse}files_{nfreqs}freqs_real.{figtype}', + sm, + corr_real, + logScale=True, + vmin=0.01, + title="|Real|", + pols=pols, + savefig=savefig, + write_params=write_params, + figtype=figtype, + outfig=f"{outfig}_{nfilesUse}files_{nfreqs}freqs_real.{figtype}", ) # Plot matrix of imaginary values if printStatusUpdates: @@ -2313,7 +2374,7 @@ def makeCorrMatrices( savefig=savefig, write_params=write_params, figtype=figtype, - outfig=f'{outfig}_{nfilesUse}files_{nfreqs}freqs_imag.{figtype}', + outfig=f"{outfig}_{nfilesUse}files_{nfreqs}freqs_imag.{figtype}", ) # Plot matrix of real values on linlog color scale. if printStatusUpdates: @@ -2330,7 +2391,7 @@ def makeCorrMatrices( savefig=savefig, write_params=write_params, figtype=figtype, - outfig=f'{outfig}_{nfilesUse}files_{nfreqs}freqs_linlog.{figtype}', + outfig=f"{outfig}_{nfilesUse}files_{nfreqs}freqs_linlog.{figtype}", ) return sm, df, corr_real, corr_imag, perBlSummary @@ -2348,7 +2409,7 @@ def plotPerNodeSpectraAndHists( freq_range=[132, 148], pol="allpols", percentage=10, - plot_nodes='all', + plot_nodes="all", avg="mean", printStatusUpdates=False, ): @@ -2385,7 +2446,7 @@ def plotPerNodeSpectraAndHists( """ from matplotlib.lines import Line2D - if plot_nodes == 'all': + if plot_nodes == "all": nodes, antDict, inclNodes = utils.generate_nodeDict(sm) plot_nodes = inclNodes if printStatusUpdates: @@ -2402,27 +2463,54 @@ def plotPerNodeSpectraAndHists( freqs_all = [freqs_all[int(i)] for i in inds] c_node = { - n: {'base': [], 'full': [], 'freqs': [], 'hist_vals': []} - for n in plot_nodes - } + n: {"base": [], "full": [], "freqs": [], "hist_vals": []} for n in plot_nodes + } for node in plot_nodes: - c_node[node]['full'] = np.asarray(perNodeSummary[pol][node]["intranode"]) - c_node[node]['freqs'] = np.tile(freqs, np.shape(c_node[node]['full'])[0] * int(np.shape(c)[1] / 1536)) - c_node[node]['base'] = np.asarray(c_node[node]['full']).flatten() - c_node[node]['base'], inds = utils.getRandPercentage(c_node[node]['base'], percentage) - c_node[node]['freqs'] = [c_node[node]['freqs'][int(i)] for i in inds] + c_node[node]["full"] = np.asarray(perNodeSummary[pol][node]["intranode"]) + c_node[node]["freqs"] = np.tile( + freqs, np.shape(c_node[node]["full"])[0] * int(np.shape(c)[1] / 1536) + ) + c_node[node]["base"] = np.asarray(c_node[node]["full"]).flatten() + c_node[node]["base"], inds = utils.getRandPercentage( + c_node[node]["base"], percentage + ) + c_node[node]["freqs"] = [c_node[node]["freqs"][int(i)] for i in inds] freq_min_ind = np.argmin(np.abs(np.subtract(freqs, freq_range[0]))) freq_max_ind = np.argmin(np.abs(np.subtract(freqs, freq_range[1]))) for node in plot_nodes: - c_node[node]['hist_vals'] = c_node[node]['full'][:, freq_min_ind:freq_max_ind].flatten() + c_node[node]["hist_vals"] = c_node[node]["full"][ + :, freq_min_ind:freq_max_ind + ].flatten() if printStatusUpdates: print("Plotting") fig, axes = plt.subplots(2, 2, figsize=(16, 16)) - colors = ['tab:blue', 'tab:red', 'tab:orange', 'tab:green', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan', 'blue', 'indigo', 'magenta', 'salmon', 'lawngreen', 'peachpuff', 'powderblue', 'darkseagreen', 'gold', 'teal', 'darkred'] + colors = [ + "tab:blue", + "tab:red", + "tab:orange", + "tab:green", + "tab:purple", + "tab:brown", + "tab:pink", + "tab:gray", + "tab:olive", + "tab:cyan", + "blue", + "indigo", + "magenta", + "salmon", + "lawngreen", + "peachpuff", + "powderblue", + "darkseagreen", + "gold", + "teal", + "darkred", + ] legend_elements = [ Line2D( @@ -2434,11 +2522,18 @@ def plotPerNodeSpectraAndHists( color=colors[i], alpha=1, ) - for i,node in enumerate(plot_nodes)] + for i, node in enumerate(plot_nodes) + ] - for i,node in enumerate(plot_nodes): + for i, node in enumerate(plot_nodes): axes[0][0].scatter( - c_node[node]['freqs'], np.real(c_node[node]['base']), s=1.5, alpha=0.2, color=colors[i], label=node) + c_node[node]["freqs"], + np.real(c_node[node]["base"]), + s=1.5, + alpha=0.2, + color=colors[i], + label=node, + ) axes[0][0].scatter(freqs, np.real(c_avg), s=1, color="k") axes[0][0].legend(handles=legend_elements, fontsize=12) @@ -2453,9 +2548,15 @@ def plotPerNodeSpectraAndHists( axes[0][0].axvline(freq_range[1], color="g") axes[0][0].set_title("Real Spectra") - for i,node in enumerate(plot_nodes): + for i, node in enumerate(plot_nodes): axes[0][1].scatter( - c_node[node]['freqs'], np.imag(c_node[node]['base']), s=1.5, alpha=0.2, color=colors[i], label=node) + c_node[node]["freqs"], + np.imag(c_node[node]["base"]), + s=1.5, + alpha=0.2, + color=colors[i], + label=node, + ) axes[0][1].scatter(freqs, np.imag(c_avg), s=1, color="k") axes[0][1].set_ylim(-1, 1) @@ -2474,7 +2575,7 @@ def plotPerNodeSpectraAndHists( ####### Connectivity Histograms ###### for i, node in enumerate(plot_nodes): axes[1][0].hist( - np.real(c_node[node]['hist_vals']), + np.real(c_node[node]["hist_vals"]), log=True, bins=bins, edgecolor=colors[i], @@ -2484,7 +2585,7 @@ def plotPerNodeSpectraAndHists( density=True, ) axes[1][1].hist( - np.imag(c_node[node]['hist_vals']), + np.imag(c_node[node]["hist_vals"]), log=True, bins=bins, edgecolor=colors[i], @@ -2502,16 +2603,17 @@ def plotPerNodeSpectraAndHists( axes[1][1].set_ylim(10e-4, 10e0) axes[1][1].set_title("Imag Histogram") axes[1][0].legend() - + if savefig: plt.savefig(outfig, bbox_inches="tight") if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) return perNodeSummary + def plotCorrSpectraAndHists( sm, df, @@ -2524,7 +2626,7 @@ def plotCorrSpectraAndHists( freq_range=[132, 148], pol="allpols", percentage=10, - plot_nodes='all', + plot_nodes="all", avg="mean", printStatusUpdates=True, ): @@ -2809,7 +2911,7 @@ def plotCorrSpectraAndHists( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.close() else: plt.show() @@ -3045,14 +3147,14 @@ def plotSmithChartByNode( if write_params: args = locals() curr_func = inspect.stack()[0][3] - utils.write_params_to_text(outfig,args,curr_func,curr_file,githash) + utils.write_params_to_text(outfig, args, curr_func, curr_file, githash) plt.close() else: plt.show() plt.close("all") -def plot_wfs_delay(uvd, _data_sq, pol,plot_nodes='all'): +def plot_wfs_delay(uvd, _data_sq, pol, plot_nodes="all"): """ Waterfall diagram for autocorrelation delay spectrum. @@ -3085,14 +3187,14 @@ def plot_wfs_delay(uvd, _data_sq, pol,plot_nodes='all'): maxants = 0 polnames = ["nn", "ee", "ne", "en"] for node in nodes: - if plot_nodes is not 'all' and node not in plot_nodes: + if plot_nodes is not "all" and node not in plot_nodes: continue n = len(nodes[node]["ants"]) if n > maxants: maxants = n Nside = maxants - if plot_nodes == 'all': + if plot_nodes == "all": Yside = len(inclNodes) else: Yside = len(plot_nodes) @@ -3121,7 +3223,7 @@ def plot_wfs_delay(uvd, _data_sq, pol,plot_nodes='all'): yticks = [int(i) for i in np.linspace(0, len(lsts) - 1, 6)] yticklabels = [np.around(lsts[ytick], 1) for ytick in yticks] for i, n in enumerate(inclNodes): - if plot_nodes is not 'all': + if plot_nodes is not "all": inclNodes = plot_nodes if n not in plot_nodes: continue diff --git a/hera_commissioning_tools/utils.py b/hera_commissioning_tools/utils.py index 40dfffa..631864a 100644 --- a/hera_commissioning_tools/utils.py +++ b/hera_commissioning_tools/utils.py @@ -6,78 +6,97 @@ import subprocess - def get_git_revision_hash(dirpath=None) -> str: """ Function to get current git hash of this repo. """ if dirpath is None: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .decode("ascii") + .strip() + ) else: - return subprocess.check_output(['git', '-C', str(dirpath), 'rev-parse', 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(["git", "-C", str(dirpath), "rev-parse", "HEAD"]) + .decode("ascii") + .strip() + ) -def write_params_to_text(outfile,args,curr_func=None,curr_file=None,githash=None,**kwargs): + +def write_params_to_text( + outfile, args, curr_func=None, curr_file=None, githash=None, **kwargs +): from pyuvdata import UVCal import matplotlib - - with open(f'{outfile}.txt', 'w') as f: - if 'nb_path' in kwargs.keys(): - nb_path = kwargs['nb_path'] - f.write(f'Call came from within notebook: {nb_path}\n') + + with open(f"{outfile}.txt", "w") as f: + if "nb_path" in kwargs.keys(): + nb_path = kwargs["nb_path"] + f.write(f"Call came from within notebook: {nb_path}\n") if curr_func is not None: - f.write(f'Called from {curr_func}() \n') + f.write(f"Called from {curr_func}() \n") if curr_file is not None: - f.write(f'Within {curr_file} \n') + f.write(f"Within {curr_file} \n") if githash is not None: - f.write(f'Above file has githash: {githash} \n') - f.write(f'pyuvdata version: {pyuvdata.__version__} \n') - f.write(f'Numpy version: {np.__version__} \n') - f.write(f'Matplotlib version: {matplotlib.__version__} \n') - f.write('\n \n') - f.write('------------------ FUNCTION ARGS ------------------ \n') + f.write(f"Above file has githash: {githash} \n") + f.write(f"pyuvdata version: {pyuvdata.__version__} \n") + f.write(f"Numpy version: {np.__version__} \n") + f.write(f"Matplotlib version: {matplotlib.__version__} \n") + f.write("\n \n") + f.write("------------------ FUNCTION ARGS ------------------ \n") for arg in args.keys(): val = args[arg] if type(val) is list and len(val) > 150: - f.write(f'{arg}: [{val[0]} ... {val[-1]}]') - elif isinstance(val,UVData) or isinstance(val,UVCal): - if isinstance(val,UVData): - f.write(f'{arg}: UVData Object \n') - f.write(f' antenna numbers: {val.get_ants()} \n') - elif isinstance(val,UVCal): - f.write(f'{arg}: UVCal Object \n') - f.write(f' jd range: {val.time_array[0]} - {val.time_array[-1]} \n') - f.write(f' lst range: {val.lst_array[0]* 3.819719} - {val.lst_array[-1]* 3.819719} \n') - f.write(f' freq range: {val.freq_array[0][0]*1e-6} - {val.freq_array[0][-1]*1e-6} \n') - if hasattr(val,'file_name'): - f.write(f' File name(s): {val.file_name} \n') - elif type(val)==np.ufunc: - f.write(f'{arg}: {val.__name__}') + f.write(f"{arg}: [{val[0]} ... {val[-1]}]") + elif isinstance(val, UVData) or isinstance(val, UVCal): + if isinstance(val, UVData): + f.write(f"{arg}: UVData Object \n") + f.write(f" antenna numbers: {val.get_ants()} \n") + elif isinstance(val, UVCal): + f.write(f"{arg}: UVCal Object \n") + f.write(f" jd range: {val.time_array[0]} - {val.time_array[-1]} \n") + f.write( + f" lst range: {val.lst_array[0]* 3.819719} - {val.lst_array[-1]* 3.819719} \n" + ) + f.write( + f" freq range: {val.freq_array[0][0]*1e-6} - {val.freq_array[0][-1]*1e-6} \n" + ) + if hasattr(val, "file_name"): + f.write(f" File name(s): {val.file_name} \n") + elif type(val) == np.ufunc: + f.write(f"{arg}: {val.__name__}") else: - f.write(f'{arg}: {val}') - f.write('\n') - f.write(' \n') - f.write(' \n') - f.write('------------------ ADDITIONAL INFO ------------------ \n') + f.write(f"{arg}: {val}") + f.write("\n") + f.write(" \n") + f.write(" \n") + f.write("------------------ ADDITIONAL INFO ------------------ \n") for arg in kwargs.keys(): val = kwargs[arg] - if arg == 'nb_path': + if arg == "nb_path": continue - if isinstance(val,UVData) or isinstance(val,UVCal): - if isinstance(val,UVData): - f.write(f'{arg}: UVData Object \n') - f.write(f' antenna numbers: {val.get_ants()} \n') - elif isinstance(val,UVCal): - f.write(f'{arg}: UVCal Object \n') - f.write(f' jd range: {val.time_array[0]} - {val.time_array[-1]} \n') - f.write(f' lst range: {val.lst_array[0]* 3.819719} - {val.lst_array[-1]* 3.819719} \n') - f.write(f' freq range: {val.freq_array[0][0]*1e-6} - {val.freq_array[0][-1]*1e-6} \n') - if hasattr(val,'file_name'): - f.write(f' File name(s): {val.file_name} \n') - f.write(' \n') + if isinstance(val, UVData) or isinstance(val, UVCal): + if isinstance(val, UVData): + f.write(f"{arg}: UVData Object \n") + f.write(f" antenna numbers: {val.get_ants()} \n") + elif isinstance(val, UVCal): + f.write(f"{arg}: UVCal Object \n") + f.write(f" jd range: {val.time_array[0]} - {val.time_array[-1]} \n") + f.write( + f" lst range: {val.lst_array[0]* 3.819719} - {val.lst_array[-1]* 3.819719} \n" + ) + f.write( + f" freq range: {val.freq_array[0][0]*1e-6} - {val.freq_array[0][-1]*1e-6} \n" + ) + if hasattr(val, "file_name"): + f.write(f" File name(s): {val.file_name} \n") + f.write(" \n") else: - f.write(f'{arg}: {kwargs[arg]}') - f.write('\n') - f.write('------------------------------------------ \n') + f.write(f"{arg}: {kwargs[arg]}") + f.write("\n") + f.write("------------------------------------------ \n") + def get_files(data_path, JD): """ @@ -267,10 +286,11 @@ def detectWrongConnectionAnts(uvd, dtype="load"): wrongAnts.append(ant) return wrongAnts -def loadCrossVis(HHfiles, ants, savearray=False, outfile='', printStatusUpdates=False): + +def loadCrossVis(HHfiles, ants, savearray=False, outfile="", printStatusUpdates=False): """ Loads cross and auto visibilitis for all baselines constructed by combining antennas in ants. - + Parameters: ---------- HHfiles: List @@ -283,7 +303,7 @@ def loadCrossVis(HHfiles, ants, savearray=False, outfile='', printStatusUpdates= File prefix to save data to if savearray is True. printStatusUpdates: Boolean Option to print updates as files are being loaded. - + Returns: --------- vis_array: Numpy Array @@ -293,27 +313,30 @@ def loadCrossVis(HHfiles, ants, savearray=False, outfile='', printStatusUpdates= bls: List List of baselines included in the visibilitiy array. The order of these maps to the baseline axis of the visibility array. """ - bls = [(a1,a2) for a1 in ants for a2 in ants] - vis_array = np.zeros((2*len(HHfiles),len(uv_sum_sky.freq_array[0]),4,len(bls)),dtype=np.complex_) - lst_array = np.zeros((2*len(HHfiles))) - for i,f in enumerate(HHfiles): - if i%10==0 and printStatusUpdates: - print(f'reading file {i}') - JD = float(f.split('zen.')[1].split('.sum')[0]) + bls = [(a1, a2) for a1 in ants for a2 in ants] + vis_array = np.zeros( + (2 * len(HHfiles), len(uv_sum_sky.freq_array[0]), 4, len(bls)), + dtype=np.complex_, + ) + lst_array = np.zeros((2 * len(HHfiles))) + for i, f in enumerate(HHfiles): + if i % 10 == 0 and printStatusUpdates: + print(f"reading file {i}") + JD = float(f.split("zen.")[1].split(".sum")[0]) uv = UVData() - uv.read(f,antenna_nums=ants) - for j,bl in enumerate(bls): - vis = uv.get_data(bl[0],bl[1]) - vis_array[i*2:i*2+2,:,:,j] = vis + uv.read(f, antenna_nums=ants) + for j, bl in enumerate(bls): + vis = uv.get_data(bl[0], bl[1]) + vis_array[i * 2 : i * 2 + 2, :, :, j] = vis lsts = uv.lst_array * 3.819719 inds = np.unique(lsts, return_index=True)[1] lsts = [lsts[ind] for ind in sorted(inds)] - lst_array[i*2:i*2+2] = lsts + lst_array[i * 2 : i * 2 + 2] = lsts if savearray: - print(f'saving to {outfile}') - np.save(f'{outfile}_vis', vis_array) - np.save(f'{outfile}_lsts', lst_array) - np.save(f'{outfile}_bls', bls) + print(f"saving to {outfile}") + np.save(f"{outfile}_vis", vis_array) + np.save(f"{outfile}_lsts", lst_array) + np.save(f"{outfile}_bls", bls) return vis_array, lst_array, bls @@ -383,8 +406,9 @@ def generate_nodeDict(uv, pols=["E"]): inclNodes = np.unique(inclNodes) return nodes, antDict, inclNodes + def get_slot_number(uv, ant, sortedAntennas, sortedSnapLocs, sortedSnapInputs): - ind = np.argmin(abs(np.subtract(sortedAntennas,ant))) + ind = np.argmin(abs(np.subtract(sortedAntennas, ant))) loc = sortedSnapLocs[ind] inp = sortedSnapInputs[ind] slot = loc * 3 @@ -742,7 +766,7 @@ def calc_corr_metric( nanDiffs=False, interleave="even_odd", divideByAbs=True, - plot_nodes='all', + plot_nodes="all", perNodeSummary=False, printStatusUpdates=True, crossPolCheck=False, @@ -787,7 +811,7 @@ def calc_corr_metric( from hera_mc import cm_hookup if printStatusUpdates: - print('Getting metadata') + print("Getting metadata") if type(pols) == str: pols = [pols] if use_ants == "all": @@ -828,21 +852,15 @@ def calc_corr_metric( } if perNodeSummary: perNodeSummary = { - pol: { - node: { - 'intranode': [], - 'all': [] - } - for node in inclNodes - } + pol: {node: {"intranode": [], "all": []} for node in inclNodes} for pol in np.append(pols, "allpols") } x = cm_hookup.get_hookup("default") if printStatusUpdates: - print('Calculating') + print("Calculating") for i, a1 in enumerate(useAnts): - if i%10 == 0 and printStatusUpdates: - print(f'Calculating for antenna {i}') + if i % 10 == 0 and printStatusUpdates: + print(f"Calculating for antenna {i}") for j, a2 in enumerate(useAnts): if len(antpols) > 1: ant1 = int(a1[:-1]) @@ -964,7 +982,9 @@ def calc_corr_metric( ) perBlSummary["allpols"]["intranode_bls"].append((a1, a2)) if perNodeSummary: - perNodeSummary[pol][n1]['intranode'].append(np.nanmean(product, axis=0)) + perNodeSummary[pol][n1]["intranode"].append( + np.nanmean(product, axis=0) + ) else: perBlSummary[pol]["internode_vals"].append( np.nanmean(product, axis=0) @@ -975,8 +995,12 @@ def calc_corr_metric( ) perBlSummary["allpols"]["internode_bls"].append((a1, a2)) if perNodeSummary: - perNodeSummary[pol][n1]['all'].append(np.nanmean(product, axis=0)) - perNodeSummary[pol][n2]['all'].append(np.nanmean(product, axis=0)) + perNodeSummary[pol][n1]["all"].append( + np.nanmean(product, axis=0) + ) + perNodeSummary[pol][n2]["all"].append( + np.nanmean(product, axis=0) + ) for key in polInds.keys(): polInds[key][1] = 0 if perNodeSummary: diff --git a/setup.py b/setup.py index e438ad2..89c0083 100644 --- a/setup.py +++ b/setup.py @@ -4,17 +4,15 @@ long_description = fh.read() setuptools.setup( - name='hera_commissioning_tools', - version='0.0.0', - author='Dara Storer', - author_email='darajstorer@gmail.com', - description='Collection of HERA commissioning analysis tools.', + name="hera_commissioning_tools", + version="0.0.0", + author="Dara Storer", + author_email="darajstorer@gmail.com", + description="Collection of HERA commissioning analysis tools.", long_description=long_description, long_description_content_type="text/markdown", - url='https://github.com/HERA-Team/hera_commissioning_tools', - license='MIT', - packages=['hera_commissioning_tools'], - install_requires=['numpy>=1.18', - 'matplotlib', - 'pyuvdata'], + url="https://github.com/HERA-Team/hera_commissioning_tools", + license="MIT", + packages=["hera_commissioning_tools"], + install_requires=["numpy>=1.18", "matplotlib", "pyuvdata"], )