From 3c287e8fe9609419bc73992d84dd43efb5c25b58 Mon Sep 17 00:00:00 2001 From: Varun Iyer Date: Wed, 6 Apr 2022 10:25:23 -0700 Subject: [PATCH] Black-formatted all files pylint should be cleaner, functionality unchanged --- sdi/__init__.py | 4 +- sdi/_cli.py | 8 ++ sdi/align.py | 54 ++++++++----- sdi/collate.py | 177 ++++++++++++++++++++++++++++++++++--------- sdi/combine.py | 24 ++++-- sdi/display.py | 34 ++++++--- sdi/extract.py | 75 ++++++++++++------ sdi/fitsio.py | 42 ++++++++-- sdi/ref.py | 83 ++++++++++++++------ sdi/secid.py | 96 +++++++++++++++-------- sdi/snr_function.py | 15 ++-- sdi/subtract.py | 41 +++++++--- sdi/test/align.py | 37 ++++++--- sdi/test/collate.py | 15 ++-- sdi/test/combine.py | 31 +++++--- sdi/test/extract.py | 31 ++++++-- sdi/test/fitsio.py | 34 +++++++-- sdi/test/subtract.py | 36 ++++++--- sdi/test_cmd.py | 11 ++- setup.py | 21 ++++- 20 files changed, 639 insertions(+), 230 deletions(-) diff --git a/sdi/__init__.py b/sdi/__init__.py index adbdc40..4db6ac7 100644 --- a/sdi/__init__.py +++ b/sdi/__init__.py @@ -9,7 +9,7 @@ from .ref import ref from .secid import secid from .collate import collate -from .snr_function import snr #New +from .snr_function import snr # New from . import test # in order to test click commands @@ -20,7 +20,7 @@ from .extract import extract_cmd as _extract_cmd from .subtract import subtract_cmd as _subtract_cmd from .collate import collate_cmd as _collate_cmd -from .snr_function import snr_cmd as _snr_cmd #New +from .snr_function import snr_cmd as _snr_cmd # New from . import test_cmd _scidir = os.path.join(os.path.dirname(__file__), "test/fixtures/science") diff --git a/sdi/_cli.py b/sdi/_cli.py index 3b1e115..8f14433 100644 --- a/sdi/_cli.py +++ b/sdi/_cli.py @@ -5,12 +5,14 @@ from functools import update_wrapper import click + @click.group(chain=True) def cli(): """ Chains together pipeline commands. """ + @cli.resultcallback() def run_pipeline(operators): """ @@ -26,6 +28,7 @@ def run_pipeline(operators): # do necessary things on overall outputs pass + def operator(func): """ Decorator which wraps commands so that they return functions after being @@ -33,22 +36,27 @@ def operator(func): All of the returned functions are passed as an iterable into run_pipeline. """ + def new_func(*args, **kwargs): def operator(hduls): # args and kwargs are subcommand-specific return func(hduls, *args, **kwargs) + return operator # basically return new_func, but better return update_wrapper(new_func, func) + def generator(func): """ Does what operator does, but for the first thing in the series, like opening the hduls. Works with sub-funcs that do not have 'hduls' as the first argument. """ + def new_func(hduls, *args, **kwargs): yield from hduls yield from func(*args, **kwargs) + return operator(update_wrapper(new_func, func)) diff --git a/sdi/align.py b/sdi/align.py index c069172..8485b63 100644 --- a/sdi/align.py +++ b/sdi/align.py @@ -10,72 +10,84 @@ import click import numpy as np from astropy.io.fits import CompImageHDU -#from sources import Source + +# from sources import Source import astroalign -from .snr_function import snr -#from _scripts import snr +from .snr_function import snr + +# from _scripts import snr from . import _cli as cli + def align(hduls, name="SCI", ref=None): """ Aligns the source astronomical image(s) to the reference astronomical image \b :param hduls: list of fitsfiles :param name: header containing image data. - :param reference: index of desired + :param reference: index of desired :return: list of FITS files with HDU aligned """ hduls_list = [hdul for hdul in hduls] outputs = [] - hduls_list = snr(hduls_list,name) #SNR function appends snr to each hdul's "CAT" table. - hduls_list = list(hduls_list) #convert generator to list. + hduls_list = snr( + hduls_list, name + ) # SNR function appends snr to each hdul's "CAT" table. + hduls_list = list(hduls_list) # convert generator to list. - if ref is None: #No reference index given. we establish reference based on best signal to noise ratio + if ( + ref is None + ): # No reference index given. we establish reference based on best signal to noise ratio - reference = hduls_list[0][name] #0th index reference is used by default + reference = hduls_list[0][name] # 0th index reference is used by default ref_SNR = hduls_list[0]["CAT"].header["SNR"] - for hdul in hduls_list: #loops though hdul lists finding the hdul with greatest snr - if hdul["CAT"].header["SNR"] > ref_SNR: #compares SNR value of current hdul to ref + for ( + hdul + ) in hduls_list: # loops though hdul lists finding the hdul with greatest snr + if ( + hdul["CAT"].header["SNR"] > ref_SNR + ): # compares SNR value of current hdul to ref ref_SNR = hdul["CAT"].header["SNR"] reference = hdul[name] - else: #ref index is provided + else: # ref index is provided reference = hduls_list[ref][name] try: ref_data = reference.data except AttributeError: - print("The reference file have doesn't have Attribute: Data") + print("The reference file have doesn't have Attribute: Data") pass - for i,hdul in enumerate(hduls_list): + for i, hdul in enumerate(hduls_list): np_src = hdul[name] - + # possibly unneccessary but unsure about scoping output = np.array([]) - + try: - output = astroalign.register(np_src,ref_data)[0] + output = astroalign.register(np_src, ref_data)[0] except ValueError: np_src = hdul[name].data.byteswap().newbyteorder() output = astroalign.register(np_src, ref_data)[0] pass - if hasattr(hdul[name], "data"): + if hasattr(hdul[name], "data"): idx = hdul.index_of(name) hdul[idx].data = output - hdul[idx].header['EXTNAME'] = ("ALGN ") - hdul[idx].header = reference.header #hdul['sci'].header = reference.header - + hdul[idx].header["EXTNAME"] = "ALGN " + hdul[idx].header = reference.header # hdul['sci'].header = reference.header + return (hdul for hdul in hduls_list) + @cli.cli.command("align") @click.option("-n", "--name", default="SCI", help="The HDU to be aligned.") @cli.operator -#TODO use CAT sources if they exist +# TODO use CAT sources if they exist ## align function wrapper def align_cmd(hduls, name="SCI", ref=None): diff --git a/sdi/collate.py b/sdi/collate.py index 5ed035b..b67a811 100644 --- a/sdi/collate.py +++ b/sdi/collate.py @@ -6,29 +6,32 @@ from sklearn.cluster import dbscan, OPTICS from astropy.io import fits -#Subfunctions utilized in clustering function +# Subfunctions utilized in clustering function def _norm(array): array -= min(array) - array *= 1/max(array) + array *= 1 / max(array) return array + def dist_func(a1, a2): tot = 0 for i in range(len(a1)): - tot += (a1[i]-a2[i])**2 + tot += (a1[i] - a2[i]) ** 2 return tot**0.5 + def obj_mean(cluster): means = None for source in cluster: if source != None: if not means: - means = [0 for i in range(len(source))] + means = [0 for i in range(len(source))] for idx, coord in enumerate(source): means[idx] += coord - return [mean/(len(cluster)-cluster.count(None)) for mean in means] + return [mean / (len(cluster) - cluster.count(None)) for mean in means] + -#Exact same function as in _scripts/collateutils, called at the end of running collate +# Exact same function as in _scripts/collateutils, called at the end of running collate def cluster_ratio(hduls, name="CAT", tablename="OBJ"): """ Prints number of clustered sources out of total sources as well as the percent of clustered sources. @@ -46,13 +49,28 @@ def cluster_ratio(hduls, name="CAT", tablename="OBJ"): if point[0] == -1: noise += 1 clustered_points += len(hdul[tablename].data) - noise - print("{} out of {} points clustered".format(clustered_points, points), clustered_points/points) + print( + "{} out of {} points clustered".format(clustered_points, points), + clustered_points / points, + ) -#Collate function itself -def collate(hduls, name="CAT", tablename="OBJ", coords = ["x", "y", "flux"], algorithm="DBSCAN", minpts=4, eps=0.001, maxeps=np.inf, xi=0.85, collision_fix=False, show_collisions=False): +# Collate function itself +def collate( + hduls, + name="CAT", + tablename="OBJ", + coords=["x", "y", "flux"], + algorithm="DBSCAN", + minpts=4, + eps=0.001, + maxeps=np.inf, + xi=0.85, + collision_fix=False, + show_collisions=False, +): """ - Clusters source data from HDULs to predict sky objects and then appends a TableHDU to each HDUL containing object/cluster data. + Clusters source data from HDULs to predict sky objects and then appends a TableHDU to each HDUL containing object/cluster data. The nth index of each TableHDU for each HDUL refers to the same object with the value of the index referring to the (value)th source in the HDUL's source catalog. Collisions refer to when two sources in the same HDUL are clustered to the same object. In the event of a collision, @@ -82,43 +100,63 @@ def collate(hduls, name="CAT", tablename="OBJ", coords = ["x", "y", "flux"], alg normirdf = [_norm(array) for array in irdf] if algorithm == "DBSCAN": - print('\n Collating using {} with {} coordinates, minpts={} and eps={}'.format(algorithm, coords, minpts, eps)) + print( + "\n Collating using {} with {} coordinates, minpts={} and eps={}".format( + algorithm, coords, minpts, eps + ) + ) cores, labels = dbscan(np.vstack(normirdf).T, eps=eps, min_samples=minpts) elif algorithm == "OPTICS": - print('\n Collating using {} with {} coordinates, minpts={}, xi={}, and maxeps={}'.format(algorithm, coords, minpts, xi, maxeps)) - labels = OPTICS(min_samples=minpts, max_eps=maxeps, xi=xi).fit(np.vstack(normirdf).T).labels_ + print( + "\n Collating using {} with {} coordinates, minpts={}, xi={}, and maxeps={}".format( + algorithm, coords, minpts, xi, maxeps + ) + ) + labels = ( + OPTICS(min_samples=minpts, max_eps=maxeps, xi=xi) + .fit(np.vstack(normirdf).T) + .labels_ + ) else: print('Invalid clustering method, please use "DBSCAN" or "OPTICS"') return (hdul for hdul in hduls) - + obj_count = max(labels) if obj_count == -1: - print('No clusters could be formed. Please try using more images, increasing eps, or changing algorithms.') + print( + "No clusters could be formed. Please try using more images, increasing eps, or changing algorithms." + ) return (hdul for hdul in hduls) for i, hdul in enumerate(hduls): source_count = len(hdul[name].data) - cols = [-1 for j in range(obj_count+1)] + cols = [-1 for j in range(obj_count + 1)] for j in range(source_count): if labels[j] != -1: if cols[labels[j]] == -1: cols[labels[j]] = j else: collision_count += 1 - #Store labels for collision image, cluster containing collision, and the two conflicting sources + # Store labels for collision image, cluster containing collision, and the two conflicting sources collisions.append([i, labels[j], j, cols[labels[j]]]) if show_collisions: - print("collision at hdul {}, object '{}' at sources {} {}".format(i, labels[j], j, cols[labels[j]])) + print( + "collision at hdul {}, object '{}' at sources {} {}".format( + i, labels[j], j, cols[labels[j]] + ) + ) if len(labels) != source_count: labels = labels[source_count:] - table = [fits.Column(name='objects', format="I", array=np.array(cols), ascii=True)] + table = [ + fits.Column(name="objects", format="I", array=np.array(cols), ascii=True) + ] table = fits.TableHDU.from_columns(table) table.name = tablename hdul.append(table) print("{} collisions".format(collision_count)) - #Dealing with collisions, takes colliding sources and keeps the source closest to the mean coordinate of their corresponding cluster. + # Dealing with collisions, takes colliding sources and keeps the source closest to the mean coordinate of their corresponding cluster. if collision_fix == True: if len(collisions) > 0: print("Resolving collisions") @@ -138,7 +176,7 @@ def collate(hduls, name="CAT", tablename="OBJ", coords = ["x", "y", "flux"], alg for collision in collisions: sources = all_sources[collision[0]] mean = means[collision[1]] - #Gets coords of the two colliding sources before calculating their distance from the cluster mean to determine which source is kept. + # Gets coords of the two colliding sources before calculating their distance from the cluster mean to determine which source is kept. s1 = [sources[collision[2]][coord] for coord in coords] s2 = [sources[collision[3]][coord] for coord in coords] if dist_func(s1, mean) < dist_func(s2, mean): @@ -148,24 +186,81 @@ def collate(hduls, name="CAT", tablename="OBJ", coords = ["x", "y", "flux"], alg cluster_ratio(hduls, name, tablename) return (hdul for hdul in hduls) + @cli.cli.command("collate") -@click.option("-n", "--name", default="CAT", help="name of HDU for each HDUL containing catalog of sources") -@click.option("-t", "--tablename", default="OBJ", help="name of TableHDU to store cluster data") -#@click.option("-c", "--coords", default="xy", type=click.Choice(["xy", "radec"], case_sensitive=False), help="choice of either 'xy' or 'radec' coordinate system used to calculate distance between sources for clustering") -@click.option("-c", "--coords", default=["x", "y", "flux"], multiple=True, help="list of dimensions to cluster sources over, multiple custom coordinates may be specified using: -c dim1 -c dim2 etc.") -@click.option("-a", "--algorithm", default="DBSCAN", type=click.Choice(["DBSCAN", "OPTICS"], case_sensitive=False), help="choice of 'DBSCAN' or 'OPTICS' clustering algorithm") +@click.option( + "-n", + "--name", + default="CAT", + help="name of HDU for each HDUL containing catalog of sources", +) +@click.option( + "-t", "--tablename", default="OBJ", help="name of TableHDU to store cluster data" +) +# @click.option("-c", "--coords", default="xy", type=click.Choice(["xy", "radec"], case_sensitive=False), help="choice of either 'xy' or 'radec' coordinate system used to calculate distance between sources for clustering") +@click.option( + "-c", + "--coords", + default=["x", "y", "flux"], + multiple=True, + help="list of dimensions to cluster sources over, multiple custom coordinates may be specified using: -c dim1 -c dim2 etc.", +) +@click.option( + "-a", + "--algorithm", + default="DBSCAN", + type=click.Choice(["DBSCAN", "OPTICS"], case_sensitive=False), + help="choice of 'DBSCAN' or 'OPTICS' clustering algorithm", +) @click.option("-p", "--minpts", default=4, help="minimum cluster size") -@click.option("-e", "--eps", default=0.001, help="point neighborhood radius parameter, used only in DBSCAN") -@click.option("-m", "--maxeps", default=np.inf, help="maximum neighborhood radius parameter, used only in OPTICS") -@click.option("-x", "--xi", default=0.85, help="minimum steepness of cluster boundary, used only in OPTICS") -@click.option("-f", "--collision_fix", default=False, help="allows for optional resolving of collisions which may result in increased runtime") -@click.option("-s", "--show_collisions", default=False, help="Display colliding sources and their corresponding images and clusters") +@click.option( + "-e", + "--eps", + default=0.001, + help="point neighborhood radius parameter, used only in DBSCAN", +) +@click.option( + "-m", + "--maxeps", + default=np.inf, + help="maximum neighborhood radius parameter, used only in OPTICS", +) +@click.option( + "-x", + "--xi", + default=0.85, + help="minimum steepness of cluster boundary, used only in OPTICS", +) +@click.option( + "-f", + "--collision_fix", + default=False, + help="allows for optional resolving of collisions which may result in increased runtime", +) +@click.option( + "-s", + "--show_collisions", + default=False, + help="Display colliding sources and their corresponding images and clusters", +) @cli.operator -#collate function wrapper -def collate_cmd(hduls, name="CAT", tablename="OBJ", coords=["x", "y", "flux"], algorithm="DBSCAN", minpts=4, eps=0.001, maxeps=np.inf, xi=0.85, collision_fix=False, show_collisions=False): +# collate function wrapper +def collate_cmd( + hduls, + name="CAT", + tablename="OBJ", + coords=["x", "y", "flux"], + algorithm="DBSCAN", + minpts=4, + eps=0.001, + maxeps=np.inf, + xi=0.85, + collision_fix=False, + show_collisions=False, +): """ - Clusters source data from HDULs to predict sky objects and then appends a TableHDU to each HDUL containing object/cluster data. + Clusters source data from HDULs to predict sky objects and then appends a TableHDU to each HDUL containing object/cluster data. The nth index of each TableHDU for each HDUL refers to the same object with the value of the index referring to the (value)th source in the HDUL's source catalog. Collisions refer to when two sources in the same HDUL are clustered to the same object. In the event of a collision, @@ -183,4 +278,16 @@ def collate_cmd(hduls, name="CAT", tablename="OBJ", coords=["x", "y", "flux"], a collision_fix -- allows for optional resolving of collisions show_collisions -- Display colliding sources and their corresponding images and clusters """ - return collate(hduls, name, tablename, coords, algorithm, minpts, eps, maxeps, xi, collision_fix, show_collisions) + return collate( + hduls, + name, + tablename, + coords, + algorithm, + minpts, + eps, + maxeps, + xi, + collision_fix, + show_collisions, + ) diff --git a/sdi/combine.py b/sdi/combine.py index 5382172..82735ac 100644 --- a/sdi/combine.py +++ b/sdi/combine.py @@ -9,6 +9,7 @@ from astropy.io import fits from . import _cli as cli + def combine(hduls, name="ALGN"): """ Combine takes a pixel-by-pixel median of a set of astronomical data to @@ -20,22 +21,29 @@ def combine(hduls, name="ALGN"): :param hduls: list of fits hdul's :param name: the name of the HDU to sum among the HDULS :returns: a list with a single hdul representing the median image. - """ + """ hduls_list = [hdul for hdul in hduls] try: - data = [hdul[name].data for hdul in hduls_list] #creates list of all data arrays from all the hdul's in the list. + data = [ + hdul[name].data for hdul in hduls_list + ] # creates list of all data arrays from all the hdul's in the list. except KeyError: hduls_list[0].info() - raise KeyError(str(f"Name {name} not found in HDUList! Try running again with `combine -n [name]` from above")) from None + raise KeyError( + str( + f"Name {name} not found in HDUList! Try running again with `combine -n [name]` from above" + ) + ) from None - comb = np.median(data, axis=0) + comb = np.median(data, axis=0) hdu = fits.PrimaryHDU(comb) - #hduls_list += [fits.HDUList([hdu])] - #^^We do not need to create a list of HDUL's here. We return a single median image. + # hduls_list += [fits.HDUList([hdu])] + # ^^We do not need to create a list of HDUL's here. We return a single median image. return fits.HDUList([hdu]) + @cli.cli.command("combine") @click.option("-n", "--name", default="ALGN", help="The HDU to be aligned.") @cli.operator @@ -54,4 +62,6 @@ def combine_cmd(hduls, name="ALGN"): :param name: the name of the HDU to sum among the HDULS :returns: a HDUList with a single HDU representing the median image. """ - return [combine(hduls, name),] + return [ + combine(hduls, name), + ] diff --git a/sdi/display.py b/sdi/display.py index 94f2399..0fa23e3 100644 --- a/sdi/display.py +++ b/sdi/display.py @@ -6,7 +6,8 @@ from pyds9 import DS9 from . import _cli as cli -def display(hduls, cats=None, color='green', size=40): + +def display(hduls, cats=None, color="green", size=40): """ Opens a series of HDU Images on ds9, optionally with circled catalogs. :param hduls: list of HDUL to display @@ -32,22 +33,35 @@ def display(hduls, cats=None, color='green', size=40): ds9.set_pyfits(hdul) ds9.set("zoom to fit") for source in cat: - ds9.set('regions command {{circle {} {} {} #color={}}}' \ - .format(source['x'], source['y'], size, color)) + ds9.set( + "regions command {{circle {} {} {} #color={}}}".format( + source["x"], source["y"], size, color + ) + ) ds9.set("frame first") ds9.set("frame delete") return (hdul for hdul in hduls) + @cli.cli.command("display") -@click.option("-e", "--cat_ext", type=str, \ - help="EXTNAME for a TableHDU in each HDUL which " - "contains a record array of sources to circle in ds9.") -@click.option("-c", "--color", default="green", help="A string color to draw" - "circles in, e.g. 'green'") -@click.option("-s", "--size", default=40, help="The radius of the circles" - "drawn", type=int) +@click.option( + "-e", + "--cat_ext", + type=str, + help="EXTNAME for a TableHDU in each HDUL which " + "contains a record array of sources to circle in ds9.", +) +@click.option( + "-c", + "--color", + default="green", + help="A string color to draw" "circles in, e.g. 'green'", +) +@click.option( + "-s", "--size", default=40, help="The radius of the circles" "drawn", type=int +) @cli.operator def display_cmd(hduls, cat_ext=None, color="green", size=40): """ diff --git a/sdi/extract.py b/sdi/extract.py index 1719ea1..fa5e6b0 100644 --- a/sdi/extract.py +++ b/sdi/extract.py @@ -11,6 +11,7 @@ import sep from . import _cli as cli + def extract(hduls, stddev_thresh=3.0, read_ext="SUB", write_ext="XRT"): """ Uses sep to find sources on a residual image(s) @@ -27,18 +28,22 @@ def extract(hduls, stddev_thresh=3.0, read_ext="SUB", write_ext="XRT"): for hdul in hduls: data = hdul[read_ext].data bkg = None - try: + try: bkg = sep.Background(data) except: data = data.byteswap().newbyteorder() bkg = sep.Background(data) - sources = sep.extract(data - bkg.back(), bkg.globalrms * stddev_thresh, - segmentation_map=False) - #Convert xy to RA/DEC in ICRS + sources = sep.extract( + data - bkg.back(), bkg.globalrms * stddev_thresh, segmentation_map=False + ) + # Convert xy to RA/DEC in ICRS from astropy import wcs from astropy.coordinates import SkyCoord - coords = wcs.utils.pixel_to_skycoord(sources['x'],sources['y'],wcs.WCS(hdul[read_ext].header),origin = 0) - coords_icrs = SkyCoord(coords,frame = 'icrs') + + coords = wcs.utils.pixel_to_skycoord( + sources["x"], sources["y"], wcs.WCS(hdul[read_ext].header), origin=0 + ) + coords_icrs = SkyCoord(coords, frame="icrs") ra = coords_icrs.ra.degree dec = coords_icrs.dec.degree extname = write_ext @@ -49,29 +54,51 @@ def extract(hduls, stddev_thresh=3.0, read_ext="SUB", write_ext="XRT"): extname = write_ext[0] except (ValueError, TypeError): pass - cat = fits.BinTableHDU(data=sources, header=header, - name=extname, ver=extver) - print('name = ',extname) + cat = fits.BinTableHDU(data=sources, header=header, name=extname, ver=extver) + print("name = ", extname) out_cols = cat.data.columns - new_cols = fits.ColDefs([fits.Column(name = 'ra',format = 'D',array=ra),fits.Column(name = 'dec', format = 'D', array=dec)]) - new_table = fits.BinTableHDU(header = header, name = extname, ver=extver).from_columns(out_cols + new_cols) - new_hdu = fits.BinTableHDU(new_table.data, header = header, name = extname, ver = extver) + new_cols = fits.ColDefs( + [ + fits.Column(name="ra", format="D", array=ra), + fits.Column(name="dec", format="D", array=dec), + ] + ) + new_table = fits.BinTableHDU( + header=header, name=extname, ver=extver + ).from_columns(out_cols + new_cols) + new_hdu = fits.BinTableHDU( + new_table.data, header=header, name=extname, ver=extver + ) hdul.append(new_hdu) yield hdul -@cli.cli.command("extract") -@click.option("-t", "--threshold", default=3.0, - help="A threshold value to use for source extraction in terms of" - "the number of stddevs above the background noise.", type=float) -@click.option("-r", "--read_ext", default="SUB", - help="An index number or ext name that identifies the data in" - "input hduls that you want source extraction for. For LCO, this " - "is 1 or SUB.") ##We might want to re-word this now. -@click.option("-w", "--write_ext", default=("XRT", 1), - help="An extension name and extension version that will identify" - "the HDUL that the resulting BinTable gets written to. Default" - "is `XRT 1`", type=(str, int)) +@cli.cli.command("extract") +@click.option( + "-t", + "--threshold", + default=3.0, + help="A threshold value to use for source extraction in terms of" + "the number of stddevs above the background noise.", + type=float, +) +@click.option( + "-r", + "--read_ext", + default="SUB", + help="An index number or ext name that identifies the data in" + "input hduls that you want source extraction for. For LCO, this " + "is 1 or SUB.", +) ##We might want to re-word this now. +@click.option( + "-w", + "--write_ext", + default=("XRT", 1), + help="An extension name and extension version that will identify" + "the HDUL that the resulting BinTable gets written to. Default" + "is `XRT 1`", + type=(str, int), +) @cli.operator def extract_cmd(hduls, threshold, read_ext, write_ext): """ diff --git a/sdi/fitsio.py b/sdi/fitsio.py index d8a9387..9641d52 100644 --- a/sdi/fitsio.py +++ b/sdi/fitsio.py @@ -9,6 +9,7 @@ from astropy.io import fits from . import _cli as cli + def read(directory): """ Takes a directory containing fits files and returns them as a list @@ -16,24 +17,27 @@ def read(directory): paths = glob.glob("{}/*.fits*".format(directory)) try: # if file name is numeric, paths will be sorted in ascending order - paths = sorted(paths, key = lambda item: int(item[len(directory):len(item)-5])) + paths = sorted( + paths, key=lambda item: int(item[len(directory) : len(item) - 5]) + ) except: pass hduls = [fits.open(p) for p in paths] return hduls + def write(hduls, directory, format_): """ Writes all given hduls into the directory specified """ import os - #check if directory exists + # check if directory exists is_dir = os.path.isdir(directory) if is_dir == False: os.mkdir(directory) - print('directory was not present, now made at' + directory) + print("directory was not present, now made at" + directory) for i, h in enumerate(hduls): path = os.path.join(directory, format_.format(number=i)) @@ -41,9 +45,23 @@ def write(hduls, directory, format_): h.writeto(path) return hduls + @cli.cli.command("write") -@click.option('-d', '--directory', type=str, help="Specify path to directory to save fitsfiles.", default="./") -@click.option('-f', '--format', "format_", type=str, help="Specify string format for filename.", default="{number}.fits") +@click.option( + "-d", + "--directory", + type=str, + help="Specify path to directory to save fitsfiles.", + default="./", +) +@click.option( + "-f", + "--format", + "format_", + type=str, + help="Specify string format for filename.", + default="{number}.fits", +) @cli.operator ## write function wrapper def write_cmd(hduls, directory, format_): @@ -52,8 +70,15 @@ def write_cmd(hduls, directory, format_): """ return write(hduls, directory, format_) + @cli.cli.command("read") -@click.option('-d', '--directory', type=str, help="Specify path to directory of fitsfiles.", required=False) +@click.option( + "-d", + "--directory", + type=str, + help="Specify path to directory of fitsfiles.", + required=False, +) @cli.generator ## read function wrapper def read_cmd(directory): @@ -65,7 +90,10 @@ def read_cmd(directory): try: directory = askdirectory() except: - click.echo("Visual file dialog does not exist, please use option -d and specify path to directory to read fitsfiles.", err=True) + click.echo( + "Visual file dialog does not exist, please use option -d and specify path to directory to read fitsfiles.", + err=True, + ) sys.exit() hduls = read(directory) if hduls: diff --git a/sdi/ref.py b/sdi/ref.py index 4ac23a6..9aec7c0 100644 --- a/sdi/ref.py +++ b/sdi/ref.py @@ -16,10 +16,23 @@ import warnings # define specific columns so we don't get dtype issues from the chaff -COLUMNS = ["source_id", "ra", "ra_error", "dec", "dec_error", - "phot_g_mean_flux", "phot_g_mean_flux_error", "phot_g_mean_mag", - "phot_rp_mean_flux", "phot_rp_mean_flux_error", "phot_rp_mean_mag", - "phot_bp_mean_flux","phot_bp_mean_flux_error","phot_bp_mean_mag"] +COLUMNS = [ + "source_id", + "ra", + "ra_error", + "dec", + "dec_error", + "phot_g_mean_flux", + "phot_g_mean_flux_error", + "phot_g_mean_mag", + "phot_rp_mean_flux", + "phot_rp_mean_flux_error", + "phot_rp_mean_mag", + "phot_bp_mean_flux", + "phot_bp_mean_flux_error", + "phot_bp_mean_mag", +] + def _in_cone(coord: SkyCoord, cone_center: SkyCoord, cone_radius: u.degree): """ @@ -27,7 +40,8 @@ def _in_cone(coord: SkyCoord, cone_center: SkyCoord, cone_radius: u.degree): cone_radius """ d = (coord.ra - cone_center.ra) ** 2 + (coord.dec - cone_center.dec) ** 2 - return d < (cone_radius ** 2) + return d < (cone_radius**2) + def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): """ @@ -47,6 +61,7 @@ def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): # This import is here because the GAIA import slows down `import sdi` # substantially; we don't want to import it unless we need it from astroquery.gaia import Gaia + Gaia.ROW_LIMIT = -1 queried_coords = [] @@ -54,27 +69,34 @@ def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): cached_table = np.array([]) # an arbitrary cone search radius, separate from the threshold value - cone_radius = u.Quantity(0.05 , u.deg) + cone_radius = u.Quantity(0.05, u.deg) # we need this to track blanks till we know the dtype initial_empty = 0 # An adaptive method of obtaining the threshold value for hdul in hduls: - threshold = max(hdul[read_ext].data["a"])*hdul['ALGN'].header['PIXSCALE']/3600 + threshold = ( + max(hdul[read_ext].data["a"]) * hdul["ALGN"].header["PIXSCALE"] / 3600 + ) threshold = u.Quantity(threshold, u.deg) ra = hdul[read_ext].data["RA"] dec = hdul[read_ext].data["DEC"] sources = hdul[read_ext].data output_table = np.array([]) - coordinates = SkyCoord(ra,dec,unit="deg") + coordinates = SkyCoord(ra, dec, unit="deg") for coord in coordinates: ########### Query an area if we have not done so already ########### # Check to see if we've queried the area - if not any((_in_cone(coord, query, cone_radius - 2 * threshold) \ - for query in queried_coords)): + if not any( + ( + _in_cone(coord, query, cone_radius - 2 * threshold) + for query in queried_coords + ) + ): # we have never queried the area. Do a GAIA cone search - data = Gaia.cone_search_async(coord, cone_radius, columns=COLUMNS, - output_format="csv").get_results() + data = Gaia.cone_search_async( + coord, cone_radius, columns=COLUMNS, output_format="csv" + ).get_results() data = data.as_array() # add the cache table to the data if len(cached_table): @@ -83,8 +105,9 @@ def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): cached_table = data.data for d in data: # construct Coord objects for the new data - cached_coords.append(SkyCoord(d["ra"], d["dec"], - unit=(u.deg, u.deg))) + cached_coords.append( + SkyCoord(d["ra"], d["dec"], unit=(u.deg, u.deg)) + ) # note that we have now queried this area queried_coords.append(coord) @@ -98,8 +121,12 @@ def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): output_table = np.hstack((output_table, np.copy(ct))) else: output_table = np.copy(ct) - output_table = np.hstack((np.empty(shape=initial_empty, - dtype=output_table.dtype), output_table)) + output_table = np.hstack( + ( + np.empty(shape=initial_empty, dtype=output_table.dtype), + output_table, + ) + ) appended = True break else: @@ -118,23 +145,33 @@ def ref(hduls, read_ext=-1, write_ext="REF", threshold=0.001): extname = write_ext header = fits.Header([fits.Card("HISTORY", "From the GAIA remote db")]) # replace nan values with 0.0 - for i,elm in enumerate(output_table): - for j,val in enumerate(elm): + for i, elm in enumerate(output_table): + for j, val in enumerate(elm): if np.isnan(val): elm[j] = 0.0 # only append the hdul if output_table is not empty if len(output_table): - hdul.append(fits.BinTableHDU(data=output_table, header=header, name=extname)) + hdul.append( + fits.BinTableHDU(data=output_table, header=header, name=extname) + ) else: - warnings.warn(f"empty reference table created, no stars found in the database corresponding to {hdul}") - yield hdul # not sure if I should be yielding hdul even if output_table is empty + warnings.warn( + f"empty reference table created, no stars found in the database corresponding to {hdul}" + ) + yield hdul # not sure if I should be yielding hdul even if output_table is empty return + @cli.cli.command("ref") @click.option("-r", "--read-ext", default="CAT", help="The HDU to match") @click.option("-w", "--write-ext", default="REF", help="The HDU to load ref into") -@click.option("-t", "--threshold", default=0.001, type=float, - help="The threshold in degrees for a cone search") +@click.option( + "-t", + "--threshold", + default=0.001, + type=float, + help="The threshold in degrees for a cone search", +) @cli.operator def ref_cmd(hduls, read_ext=-1, write_ext="REF", threshold=0.001): """ diff --git a/sdi/secid.py b/sdi/secid.py index 3ca3fcc..5c3139a 100644 --- a/sdi/secid.py +++ b/sdi/secid.py @@ -9,11 +9,29 @@ from astropy import units as u from . import _cli as cli -RABOUND = [12.44,11.97,11.48,11,10.52,10.04,9.56] #Bounds used to define sections -_RASCALE = [-1,0,9,18,27,36,45] #Values to sum to find section ID according to a grid -DECBOUND = [42.77,42.45,42.13,41.81,41.49,41.17,40.85,40.53,40.21,39.89] -_DECSCALE = [-1,9,8,7,6,5,4,3,2,1] -def secid(hduls, read_ext='SCI', write_ext=-1): +RABOUND = [ + 12.44, + 11.97, + 11.48, + 11, + 10.52, + 10.04, + 9.56, +] # Bounds used to define sections +_RASCALE = [ + -1, + 0, + 9, + 18, + 27, + 36, + 45, +] # Values to sum to find section ID according to a grid +DECBOUND = [42.77, 42.45, 42.13, 41.81, 41.49, 41.17, 40.85, 40.53, 40.21, 39.89] +_DECSCALE = [-1, 9, 8, 7, 6, 5, 4, 3, 2, 1] + + +def secid(hduls, read_ext="SCI", write_ext=-1): """ secid adds M31 section ID information to the header of the HDU specified in write_ext. If an HDUL is not in M31, the SECID is set to -1. @@ -21,47 +39,63 @@ def secid(hduls, read_ext='SCI', write_ext=-1): :param read_ext: an extname of an HDU to read RA and DEC information from. :param write_ext: an extname or index number of an HDU to write secid information into. """ - if write_ext == -1: #Defaults the write_ext to the read_ext + if write_ext == -1: # Defaults the write_ext to the read_ext write_ext = read_ext for hdul in hduls: - RA = hdul[read_ext].header['RA'] #Reads in RA and DEC - DEC = hdul[read_ext].header['DEC'] - sc = SkyCoord(RA + ' ' + DEC, unit = (u.hourangle,u.deg)) #Combines RA and Dec into a skycoord with correct units - SCRA = sc.ra*u.deg - SCDEC = sc.dec*u.deg - RAid = -1 #starts the values at -1 in case they are too small + RA = hdul[read_ext].header["RA"] # Reads in RA and DEC + DEC = hdul[read_ext].header["DEC"] + sc = SkyCoord( + RA + " " + DEC, unit=(u.hourangle, u.deg) + ) # Combines RA and Dec into a skycoord with correct units + SCRA = sc.ra * u.deg + SCDEC = sc.dec * u.deg + RAid = -1 # starts the values at -1 in case they are too small DECid = -1 - #Constraints the SECID to one 'row' based on the RA - for i,j in enumerate(RABOUND): - if SCRA.value > RABOUND[i]: #Checks the values in a loop until it fails to be larger than one. + # Constraints the SECID to one 'row' based on the RA + for i, j in enumerate(RABOUND): + if ( + SCRA.value > RABOUND[i] + ): # Checks the values in a loop until it fails to be larger than one. RAid = _RASCALE[i] break - - #Figures out what SECID you are precisely in using DEC - for i,j in enumerate(DECBOUND): + + # Figures out what SECID you are precisely in using DEC + for i, j in enumerate(DECBOUND): if SCDEC.value > DECBOUND[i]: DECid = _DECSCALE[i] break - if DECid != -1 and RAid != -1: #Creates the SECid if both the RA and DEC were found within bounds. + if ( + DECid != -1 and RAid != -1 + ): # Creates the SECid if both the RA and DEC were found within bounds. SECid = DECid + RAid else: SECid = -1 - - hdul[write_ext].header.set('SECID', SECid, comment='Setion ID in M31 as according to Lubin mosaic') #Writes the secid into the hdul - yield hdul #Passes the hdul down the pipeline + + hdul[write_ext].header.set( + "SECID", SECid, comment="Setion ID in M31 as according to Lubin mosaic" + ) # Writes the secid into the hdul + yield hdul # Passes the hdul down the pipeline + @cli.cli.command("secid") -@click.option("-r", "--read_ext", default='SCI', - help="An index number or ext name that identifies the data in" - "input hduls that you want source extraction for. For LCO, this " - "is 0 or SCI.") -@click.option("-w", "--write_ext", default=-1, - help="An extension name and extension version that will identify" - "the HDUL that the resulting secid gets written to. Default" - "is -1") +@click.option( + "-r", + "--read_ext", + default="SCI", + help="An index number or ext name that identifies the data in" + "input hduls that you want source extraction for. For LCO, this " + "is 0 or SCI.", +) +@click.option( + "-w", + "--write_ext", + default=-1, + help="An extension name and extension version that will identify" + "the HDUL that the resulting secid gets written to. Default" + "is -1", +) @cli.operator - def secid_cmd(hduls, read_ext, write_ext): """ Use secid to find the section of an image diff --git a/sdi/snr_function.py b/sdi/snr_function.py index 26b87a1..6927c8a 100644 --- a/sdi/snr_function.py +++ b/sdi/snr_function.py @@ -9,6 +9,7 @@ import click from . import _cli as cli + def snr(hduls, name="SCI"): """ calculates the signal-to-noise ratio of a fits file @@ -18,7 +19,7 @@ def snr(hduls, name="SCI"): data = hdul[name].data # identify background rms - boxsize = (data.shape) + boxsize = data.shape bkg = Background2D(data, boxsize) bkg_mean_rms = np.mean(bkg.background_rms) @@ -27,10 +28,12 @@ def snr(hduls, name="SCI"): # set threshold and detect sources, threshold 5*std above background threshold = detect_threshold(data=new_data, nsigma=5.0, background=0.0) - SegmentationImage = detect_sources(data=new_data, threshold=threshold, npixels=10) + SegmentationImage = detect_sources( + data=new_data, threshold=threshold, npixels=10 + ) SourceCatalog = source_properties(new_data, SegmentationImage) - columns = ['id', 'xcentroid', 'ycentroid', 'source_sum'] + columns = ["id", "xcentroid", "ycentroid", "source_sum"] source_max_values = SourceCatalog.max_value avg_source_max_values = np.mean(source_max_values) @@ -38,15 +41,15 @@ def snr(hduls, name="SCI"): # calculate signal to noise ratio signal = avg_source_max_values noise = bkg_mean_rms - SNR = (signal)/(noise) - hdul["CAT"].header.append(('SNR', SNR, "signal to noise ratio")) + SNR = (signal) / (noise) + hdul["CAT"].header.append(("SNR", SNR, "signal to noise ratio")) return (hdul for hdul in hduls) + @cli.cli.command("snr") @click.option("-n", "--name", default="SCI", help="The HDU to calculate for") @cli.operator - def snr_cmd(hduls, name="SCI"): """ snr function wrapper diff --git a/sdi/subtract.py b/sdi/subtract.py index 21d961c..56a6f29 100644 --- a/sdi/subtract.py +++ b/sdi/subtract.py @@ -4,42 +4,61 @@ from . import _cli as cli from .combine import combine -def subtract(hduls, name="ALGN", method: ("ois", "numpy")="ois"): + +def subtract(hduls, name="ALGN", method: ("ois", "numpy") = "ois"): """ Returns differences of a set of images from a template image Arguments: - hduls -- list of fits hdul's where the last image is the template + hduls -- list of fits hdul's where the last image is the template name -- name of the HDU to use image that the other images will be subtracted from method -- method of subtraction. OIS is default (best/slow). Numpy is alternative (worse/quick) """ hduls = [h for h in hduls] outputs = [] - template = combine(hduls, name)["PRIMARY"].data # this is a temporary HDUL containing 1 PrimaryHDU with combined data + template = combine(hduls, name)[ + "PRIMARY" + ].data # this is a temporary HDUL containing 1 PrimaryHDU with combined data if method == "ois": - for i,hdu in enumerate(hduls): + for i, hdu in enumerate(hduls): try: - diff = ois.optimal_system(image=hdu[name].data, refimage=template, method='Bramich')[0] + diff = ois.optimal_system( + image=hdu[name].data, refimage=template, method="Bramich" + )[0] except ValueError: - diff = ois.optimal_system(image=hdu[name].data.byteswap().newbyteorder(), refimage=template.byteswap().newbyteorder(), method='Bramich')[0] - hdu.insert(1,CompImageHDU(data = diff, header = hduls[i]['ALGN'].header, name = "SUB")) + diff = ois.optimal_system( + image=hdu[name].data.byteswap().newbyteorder(), + refimage=template.byteswap().newbyteorder(), + method="Bramich", + )[0] + hdu.insert( + 1, CompImageHDU(data=diff, header=hduls[i]["ALGN"].header, name="SUB") + ) outputs.append(hdu) elif method == "numpy": - for i,hdu in enumerate(hduls): + for i, hdu in enumerate(hduls): diff = template - hdu[name].data - hdu.insert(1,CompImageHDU(data = diff, header = hduls[i]['ALGN'].header, name = "SUB")) + hdu.insert( + 1, CompImageHDU(data=diff, header=hduls[i]["ALGN"].header, name="SUB") + ) outputs.append(hdu) else: raise ValueError(f"method {method} unknown!") return (hdul for hdul in outputs) + @cli.cli.command("subtract") @click.option("-n", "--name", default="ALGN", help="The HDU to be aligned.") -@click.option("-m", "--method", default="ois", help="The subtraction method to use; ois or numpy (straight subtraction).") +@click.option( + "-m", + "--method", + default="ois", + help="The subtraction method to use; ois or numpy (straight subtraction).", +) @cli.operator ## subtract function wrapper -def subtract_cmd(hduls,name="ALGN", method="ois"): +def subtract_cmd(hduls, name="ALGN", method="ois"): """ Returns differences of a set of images from a template image\n Arguments:\n diff --git a/sdi/test/align.py b/sdi/test/align.py index 6aaaf72..dcdba27 100644 --- a/sdi/test/align.py +++ b/sdi/test/align.py @@ -13,19 +13,36 @@ from skimage.util.dtype import img_as_float import glob -class TestAlign(unittest.TestCase): +class TestAlign(unittest.TestCase): @classmethod def setUpClass(cls): # uses two fits files as align param to decrease runtime - cls.paths = glob.glob(os.path.join(os.path.dirname(__file__), "fixtures/science/tfn0m414-kb99-20180831-029[0-1]-e91.fits.fz")) + cls.paths = glob.glob( + os.path.join( + os.path.dirname(__file__), + "fixtures/science/tfn0m414-kb99-20180831-029[0-1]-e91.fits.fz", + ) + ) cls.read = [fits.open(p) for p in cls.paths] cls.output = list(sdi.align(cls.read)) # sorts the known true output to be congruent with the current output - cls.path_len = len(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/alignData")) - cls.paths_true = glob.glob("{}/*.fits*".format(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/alignData"))) - cls.paths_true = sorted(cls.paths_true, key = lambda item: int(item[cls.path_len+1:len(item)-5])) + cls.path_len = len( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/alignData" + ) + ) + cls.paths_true = glob.glob( + "{}/*.fits*".format( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/alignData" + ) + ) + ) + cls.paths_true = sorted( + cls.paths_true, key=lambda item: int(item[cls.path_len + 1 : len(item) - 5]) + ) cls.true_output = [fits.open(p) for p in cls.paths_true] @classmethod @@ -39,18 +56,20 @@ def test_type(self): self.assertIsInstance(o, fits.HDUList, "Did not align type fits.HDUList") def test_length(self): - self.assertEqual(len(TestAlign.output), 2, "Did not align the correct number of fits") + self.assertEqual( + len(TestAlign.output), 2, "Did not align the correct number of fits" + ) def test_output(self): for t, o in zip(TestAlign.true_output, TestAlign.output): ## converts the true_output SNR value to type float32 to be congruent with output - t[0].header['SNR'] = np.float32(t[0].header['SNR']) + t[0].header["SNR"] = np.float32(t[0].header["SNR"]) compare = fits.FITSDiff(t, o) - self.assertEqual(compare.identical, True , compare.report(fileobj = None)) + self.assertEqual(compare.identical, True, compare.report(fileobj=None)) def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._align_cmd, ['-n', 'name']) + result = runner.invoke(sdi._align_cmd, ["-n", "name"]) assert result.exit_code == 0 diff --git a/sdi/test/collate.py b/sdi/test/collate.py index 0d698d6..ecaee32 100644 --- a/sdi/test/collate.py +++ b/sdi/test/collate.py @@ -11,8 +11,9 @@ import sdi import glob + class TestCollate(unittest.TestCase): - + path = os.path.join(os.path.dirname(__file__), "fixtures/extractedsims") @classmethod @@ -33,18 +34,22 @@ def test_type(self): self.assertIsInstance(o, fits.HDUList, "Output is not type fits.HDUList") def test_length(self): - self.assertEqual(len(TestCollate.output), 10, "Did not collate correct number of HDULs from fixtures/extractedsims") + self.assertEqual( + len(TestCollate.output), + 10, + "Did not collate correct number of HDULs from fixtures/extractedsims", + ) def test_output(self): for t, o in zip(TestCollate.true_output, TestCollate.output): compare = fits.FITSDiff(fits.HDUList(t), fits.HDUList(o)) - self.assertEqual(compare.identical, True, compare.report(fileobj = None)) + self.assertEqual(compare.identical, True, compare.report(fileobj=None)) def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._collate_cmd, ['-c', 'xy']) + result = runner.invoke(sdi._collate_cmd, ["-c", "xy"]) assert result.exit_code == 0 + if __name__ == "__main__": unittest.main() - diff --git a/sdi/test/combine.py b/sdi/test/combine.py index 63fe1b6..71fc885 100644 --- a/sdi/test/combine.py +++ b/sdi/test/combine.py @@ -10,6 +10,7 @@ from astropy.io import fits import sdi + class TestCombine(unittest.TestCase): path = os.path.join(os.path.dirname(__file__), "fixtures/science") @@ -18,10 +19,11 @@ class TestCombine(unittest.TestCase): def setUpClass(cls): cls.read = [s for s in sdi.read(cls.path)] cls.output = sdi.combine(cls.read) - true_path = os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/combineData/0.fits") + true_path = os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/combineData/0.fits" + ) cls.true_output = [fits.open(true_path)] - @classmethod def tearDownClass(cls): for o, t in zip(cls.read, cls.true_output): @@ -29,20 +31,29 @@ def tearDownClass(cls): t.close() def test_length(self): - self.assertEqual(len(TestCombine.output), 1, "Combine did not return one PrimaryHDU") + self.assertEqual( + len(TestCombine.output), 1, "Combine did not return one PrimaryHDU" + ) def test_type(self): - self.assertIsInstance(TestCombine.output[0], fits.PrimaryHDU, "Output is not of type PrimaryHDU") + self.assertIsInstance( + TestCombine.output[0], fits.PrimaryHDU, "Output is not of type PrimaryHDU" + ) def test_output(self): - self.compare = fits.FITSDiff(fits.HDUList(TestCombine.output[0]), fits.HDUList(TestCombine.true_output[0])) - self.assertEqual(self.compare.identical, True, self.compare.report(fileobj = None)) - + self.compare = fits.FITSDiff( + fits.HDUList(TestCombine.output[0]), + fits.HDUList(TestCombine.true_output[0]), + ) + self.assertEqual( + self.compare.identical, True, self.compare.report(fileobj=None) + ) + def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._combine_cmd, ['-n', 'name']) + result = runner.invoke(sdi._combine_cmd, ["-n", "name"]) assert result.exit_code == 0 - + + if __name__ == "__main__": unittest.main() - diff --git a/sdi/test/extract.py b/sdi/test/extract.py index 36e38f7..9cc8a75 100644 --- a/sdi/test/extract.py +++ b/sdi/test/extract.py @@ -11,6 +11,7 @@ import sdi import glob + class TestExtract(unittest.TestCase): path = os.path.join(os.path.dirname(__file__), "fixtures/residuals") @@ -21,9 +22,21 @@ def setUpClass(cls): cls.output = list(sdi.extract(cls.read)) # sorts the known true output to be congruent with the current output - cls.path_len = len(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/extractData")) - cls.paths = glob.glob("{}/*.fits*".format(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/extractData"))) - cls.paths = sorted(cls.paths, key = lambda item: int(item[cls.path_len+1:len(item)-5])) + cls.path_len = len( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/extractData" + ) + ) + cls.paths = glob.glob( + "{}/*.fits*".format( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/extractData" + ) + ) + ) + cls.paths = sorted( + cls.paths, key=lambda item: int(item[cls.path_len + 1 : len(item) - 5]) + ) cls.true_output = [fits.open(p) for p in cls.paths] @classmethod @@ -37,18 +50,22 @@ def test_type(self): self.assertIsInstance(o, fits.HDUList, "Output is not of type fits.HDUList") def test_length(self): - self.assertEqual(len(TestExtract.output), 10, "Did not extract the correct number of HDULs from fixtures/residuals") + self.assertEqual( + len(TestExtract.output), + 10, + "Did not extract the correct number of HDULs from fixtures/residuals", + ) def test_output(self): for t, o in zip(TestExtract.true_output, TestExtract.output): compare = fits.FITSDiff(fits.HDUList(t), fits.HDUList(o)) - self.assertEqual(compare.identical, True, compare.report(fileobj = None)) + self.assertEqual(compare.identical, True, compare.report(fileobj=None)) def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._extract_cmd, ['-t', 3.0, '-r', 0, '-w', 'XRT', 1]) + result = runner.invoke(sdi._extract_cmd, ["-t", 3.0, "-r", 0, "-w", "XRT", 1]) assert result.exit_code == 0 + if __name__ == "__main__": unittest.main() - diff --git a/sdi/test/fitsio.py b/sdi/test/fitsio.py index b8cb88b..42e1df1 100644 --- a/sdi/test/fitsio.py +++ b/sdi/test/fitsio.py @@ -13,6 +13,7 @@ import getpass from datetime import datetime + class TestRead(unittest.TestCase): path = os.path.join(os.path.dirname(__file__), "fixtures/science") @@ -27,8 +28,14 @@ def tearDown(self): hdul.close() def test_length(self): - self.assertEqual(len(self.output), 10, "Did not read the correct number of HDULs in directroy, fixtures/science") - self.assertEqual(len(self.noutput), 0, "Did not read zero HDULs in from empty directory.") + self.assertEqual( + len(self.output), + 10, + "Did not read the correct number of HDULs in directroy, fixtures/science", + ) + self.assertEqual( + len(self.noutput), 0, "Did not read zero HDULs in from empty directory." + ) def test_type(self): for o in self.output: @@ -36,11 +43,11 @@ def test_type(self): def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._read_cmd, ['-d', 'filePath']) + result = runner.invoke(sdi._read_cmd, ["-d", "filePath"]) assert result.exit_code == 0 -class TestWrite(unittest.TestCase): +class TestWrite(unittest.TestCase): @classmethod def setUpClass(self): ## Creates a unique dir to write to @@ -49,8 +56,14 @@ def setUpClass(self): self.dirName = os.path.join(self.tmpdir.name, self.testDir) os.mkdir(self.dirName) self.h = sdi.read(os.path.join(os.path.dirname(__file__), "fixtures/science")) - self.output = sdi.write(self.h, os.path.join(os.path.dirname(__file__), self.dirName), "{number}.fits") - self.paths = glob.glob("{}/*.fits*".format(os.path.join(os.path.dirname(__file__), self.dirName))) + self.output = sdi.write( + self.h, + os.path.join(os.path.dirname(__file__), self.dirName), + "{number}.fits", + ) + self.paths = glob.glob( + "{}/*.fits*".format(os.path.join(os.path.dirname(__file__), self.dirName)) + ) @classmethod def tearDownClass(self): @@ -63,12 +76,17 @@ def test_type(self): self.assertIsInstance(o, fits.HDUList, "Did not write type fits.HDUList") def test_dir(self): - self.assertEqual(len(self.paths), 10,f"Did not write the correct number of HDULs to directory {self.paths}") + self.assertEqual( + len(self.paths), + 10, + f"Did not write the correct number of HDULs to directory {self.paths}", + ) def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._write_cmd, ['-d', 'filePath', '-f', 'format']) + result = runner.invoke(sdi._write_cmd, ["-d", "filePath", "-f", "format"]) assert result.exit_code == 0 + if __name__ == "__main__": unittest.main() diff --git a/sdi/test/subtract.py b/sdi/test/subtract.py index a31a62f..b1c9df8 100644 --- a/sdi/test/subtract.py +++ b/sdi/test/subtract.py @@ -11,19 +11,36 @@ import sdi import glob + class TestSubtract(unittest.TestCase): - @classmethod def setUpClass(cls): # uses one fits file as subtract param to decrease runtime - cls.path = glob.glob(os.path.join(os.path.dirname(__file__), "fixtures/science/tfn0m414-kb99-20180831-029[0-1]-e91.fits.fz")) + cls.path = glob.glob( + os.path.join( + os.path.dirname(__file__), + "fixtures/science/tfn0m414-kb99-20180831-029[0-1]-e91.fits.fz", + ) + ) cls.read = [fits.open(p) for p in cls.path] cls.output = list(sdi.subtract(cls.read)) # sorts the known true output to be congruent with the current output - cls.path_len = len(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/subtractData")) - cls.paths = glob.glob("{}/*.fits*".format(os.path.join(os.path.dirname(__file__), "fixtures/comparativeData/subtractData"))) - cls.paths = sorted(cls.paths, key = lambda item: int(item[cls.path_len+1:len(item)-5])) + cls.path_len = len( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/subtractData" + ) + ) + cls.paths = glob.glob( + "{}/*.fits*".format( + os.path.join( + os.path.dirname(__file__), "fixtures/comparativeData/subtractData" + ) + ) + ) + cls.paths = sorted( + cls.paths, key=lambda item: int(item[cls.path_len + 1 : len(item) - 5]) + ) cls.true_output = [fits.open(p) for p in cls.paths] @classmethod @@ -37,19 +54,20 @@ def test_type(self): self.assertIsInstance(o, fits.HDUList, "Did not output type fits.HDUList") def test_length(self): - self.assertEqual(len(TestSubtract.output), 2, "Did not output the correct number of HDULs") + self.assertEqual( + len(TestSubtract.output), 2, "Did not output the correct number of HDULs" + ) def test_output(self): for t, o in zip(TestSubtract.true_output, TestSubtract.output): compare = fits.FITSDiff(fits.HDUList(t), fits.HDUList(o)) - self.assertEqual(compare.identical, True, compare.report(fileobj = None)) + self.assertEqual(compare.identical, True, compare.report(fileobj=None)) def test_click(self): runner = CliRunner() - result = runner.invoke(sdi._subtract_cmd, ['-n', 'name']) + result = runner.invoke(sdi._subtract_cmd, ["-n", "name"]) assert result.exit_code == 0 if __name__ == "__main__": unittest.main() - diff --git a/sdi/test_cmd.py b/sdi/test_cmd.py index 1521db1..724cf7e 100644 --- a/sdi/test_cmd.py +++ b/sdi/test_cmd.py @@ -4,8 +4,15 @@ from . import _cli as cli from . import test + @cli.cli.command("test") -@click.option('-n', '--name', type=str, help="Specify the name of the function you'd like to test", default="test") +@click.option( + "-n", + "--name", + type=str, + help="Specify the name of the function you'd like to test", + default="test", +) @cli.operator def test_cmd(hduls, name): """ @@ -15,7 +22,7 @@ def test_cmd(hduls, name): To run all tests, simply don't pass any parameters """ - if name == "test": + if name == "test": test_suite = unittest.TestLoader().loadTestsFromModule(test) else: test_suite = unittest.TestLoader().loadTestsFromModule(getattr(test, name)) diff --git a/setup.py b/setup.py index 49f1a91..dee893a 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,22 @@ from setuptools import setup, find_packages + try: import numpy except ModuleNotFoundError: import sys - sys.exit("numpy not found, sdi requires numpy for installation.\n Please try '$pip3 install numpy'.") + + sys.exit( + "numpy not found, sdi requires numpy for installation.\n Please try '$pip3 install numpy'." + ) try: import setuptools_rust except ModuleNotFoundError: import sys - sys.exit("setuptools_rust not found, sdi requires setuptools_rust for installation.\n Please try '$pip3 install setuptools_rust'.") + + sys.exit( + "setuptools_rust not found, sdi requires setuptools_rust for installation.\n Please try '$pip3 install setuptools_rust'." + ) setup( @@ -18,7 +25,15 @@ py_modules=["sdi"], # packages=find_packages(include=["openfits"]), include_package_data=True, - install_requires=["click", "astropy", "photutils", "ois", "pyds9", "astroalign", "astroquery", "sklearn"], + install_requires=[ + "click", + "astropy", + "photutils", + "ois", + "astroalign", + "astroquery", + "sklearn", + ], entry_points=""" [console_scripts] sdi=sdi._cli:cli