diff --git a/src/kbmod/mocking/__init__.py b/src/kbmod/mocking/__init__.py index d678d95d4..b79086419 100644 --- a/src/kbmod/mocking/__init__.py +++ b/src/kbmod/mocking/__init__.py @@ -3,4 +3,3 @@ from .headers import * from .fits_data import * from .fits import * -#from . import test_mocking diff --git a/src/kbmod/mocking/catalogs.py b/src/kbmod/mocking/catalogs.py index 06fa5dadd..6e4fc853f 100644 --- a/src/kbmod/mocking/catalogs.py +++ b/src/kbmod/mocking/catalogs.py @@ -1,15 +1,18 @@ import abc import numpy as np -from astropy.time import Time -from astropy.table import QTable, vstack +from astropy.table import QTable +from .config import Config __all__ = [ "gen_catalog", "CatalogFactory", - "SimpleSourceCatalog", - "SimpleObjectCatalog", + "SimpleCatalog", + "SourceCatalogConfig", + "SourceCatalog", + "ObjectCatalogConfig", + "ObjectCatalog", ] @@ -26,84 +29,161 @@ def gen_catalog(n, param_ranges, seed=None): # conversion assumes a gaussian if "flux" in param_ranges and "amplitude" not in param_ranges: - xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1 - ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1 + xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0 + ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0 cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd) return cat - class CatalogFactory(abc.ABC): @abc.abstractmethod - def gen_realization(self, *args, t=None, dt=None, **kwargs): + def mock(self, *args, **kwargs): raise NotImplementedError() - def mock(self, *args, **kwargs): - return self.gen_realization(self, *args, **kwargs) +class SimpleCatalogConfig(Config): + return_copy = False + seed = None + n = 100 + param_ranges = {} + + +class SimpleCatalog(CatalogFactory): + default_config = SimpleCatalogConfig + + def __init_from_table(self, table, config=None, **kwargs): + config = self.default_config(config=config, **kwargs) + config.n = len(table) + params = {} + for col in table.keys(): + params[col] = (table[col].min(), table[col].max()) + config.param_ranges.update(params) + return config, table + + def __init_from_config(self, config, **kwargs): + config = self.default_config(config=config, method="subset", **kwargs) + table = gen_catalog(config.n, config.param_ranges, config.seed) + return config, table + + def __init_from_ranges(self, **kwargs): + param_ranges = kwargs.pop("param_ranges", None) + if param_ranges is None: + param_ranges = {k: v for k, v in kwargs.items() if k in self.default_config.param_ranges} + kwargs = {k: v for k, v in kwargs.items() if k not in self.default_config.param_ranges} + + config = self.default_config(**kwargs, method="subset") + config.param_ranges.update(param_ranges) + return self.__init_from_config(config=config) + + def __init__(self, table=None, config=None, **kwargs): + if table is not None: + config, table = self.__init_from_table(table, config=config, **kwargs) + elif isinstance(config, Config): + config, table = self.__init_from_config(config=config, **kwargs) + elif isinstance(config, dict) or kwargs: + config = {} if config is None else config + config, table = self.__init_from_ranges(**{**config, **kwargs}) + else: + raise ValueError( + "Expected table or config, or keyword arguments of expected " + f"catalog value ranges, got:\n table={table}\n config={config} " + f"\n kwargs={kwargs}" + ) + + self.config = config + self.table = table + self.current = 0 -class SimpleSourceCatalog(CatalogFactory): - base_param_ranges = { - "amplitude": [500, 2000], - "x_mean": [0, 4096], - "y_mean": [0, 2048], - "x_stddev": [1, 7], - "y_stddev": [1, 7], - "theta": [0, np.pi], - } + @classmethod + def from_config(cls, config, **kwargs): + config = cls.default_config(config=config, method="subset", **kwargs) + return cls(gen_catalog(config.n, config.param_ranges, config.seed), config=config) - def __init__(self, table, return_copy=False): - self.table = table - self.return_copy = return_copy + @classmethod + def from_ranges(cls, n=None, config=None, **kwargs): + config = cls.default_config(n=n, config=config, method="subset") + config.param_ranges.update(**kwargs) + return cls.from_config(config) @classmethod - def from_params(cls, n=100, param_ranges=None): - param_ranges = {} if param_ranges is None else param_ranges - tmp = cls.base_param_ranges.copy() - tmp.update(param_ranges) - return cls(gen_catalog(n, tmp)) - - def gen_realization(self, *args, t=None, dt=None, **kwargs): - if self.return_copy: + def from_table(cls, table): + config = cls.default_config() + config.n = len(table) + params = {} + for col in table.keys(): + params[col] = (table[col].min(), table[col].max()) + config["param_ranges"] = params + return cls(table, config=config) + + def mock(self): + self.current += 1 + if self.config.return_copy: return self.table.copy() return self.table -class SimpleObjectCatalog(CatalogFactory): - base_param_ranges = { - "amplitude": [1, 100], - "x_mean": [0, 4096], - "y_mean": [0, 2048], - "vx": [500, 1000], - "vy": [500, 1000], - "stddev": [1, 1.8], - "theta": [0, np.pi], +class SourceCatalogConfig(SimpleCatalogConfig): + param_ranges = { + "amplitude": [1., 10.], + "x_mean": [0., 4096.], + "y_mean": [0., 2048.], + "x_stddev": [1., 3.], + "y_stddev": [1., 3.], + "theta": [0., np.pi], } - def __init__(self, table, obstime=None): - self.table = table - self._realization = table.copy() + +class SourceCatalog(SimpleCatalog): + default_config = SourceCatalogConfig + + +class ObjectCatalogConfig(SimpleCatalogConfig): + param_ranges = { + "amplitude": [0.1, 3.0], + "x_mean": [0., 4096.], + "y_mean": [0., 2048.], + "vx": [500., 1000.], + "vy": [500., 1000.], + "stddev": [0.25, 1.5], + "theta": [0., np.pi], + } + + +class ObjectCatalog(SimpleCatalog): + default_config = ObjectCatalogConfig + + def __init__(self, table=None, obstime=None, config=None, **kwargs): + # put return_copy into kwargs to override whatever user might have + # supplied, and to guarantee the default is overriden + kwargs["return_copy"] = True + super().__init__(table=table, config=config, **kwargs) + self._realization = self.table.copy() self.obstime = 0 if obstime is None else obstime - @classmethod - def from_params(cls, n=100, param_ranges=None): - param_ranges = {} if param_ranges is None else param_ranges - tmp = cls.base_param_ranges.copy() - tmp.update(param_ranges) - return cls(gen_catalog(n, tmp)) + def reset(self): + self.current = 0 + self._realization = self.table.copy() def gen_realization(self, t=None, dt=None, **kwargs): if t is None and dt is None: return self._realization dt = dt if t is None else t - self.obstime - self._realization["x_mean"] += self._realization["vx"] * dt - self._realization["y_mean"] += self._realization["vy"] * dt + self._realization["x_mean"] += self.table["vx"] * dt + self._realization["y_mean"] += self.table["vy"] * dt return self._realization def mock(self, n=1, **kwargs): + breakpoint() if n == 1: - return self.gen_realization(**kwargs) - return [self.gen_realization(**kwargs).copy() for i in range(n)] + data = self.gen_realization(**kwargs) + self.current += 1 + else: + data = [] + for i in range(n): + data.append(self.gen_realization(**kwargs).copy()) + self.current += 1 + + return data diff --git a/src/kbmod/mocking/config.py b/src/kbmod/mocking/config.py index 0bb94a8ac..ed34aa649 100644 --- a/src/kbmod/mocking/config.py +++ b/src/kbmod/mocking/config.py @@ -1,3 +1,5 @@ +import copy + __all__ = ["Config", "ConfigurationError"] @@ -22,37 +24,66 @@ class attributes. Particular attributes can be overriden on an per-instance Keyword arguments, assigned as configuration key-values. """ - def __init__(self, config=None, **kwargs): + def __init__(self, config=None, method="default", **kwargs): # This is a bit hacky, but it makes life a lot easier because it # enables automatic loading of the default configuration and separation # of default config from instance bound config keys = list(set(dir(self.__class__)) - set(dir(Config))) # First fill out all the defaults by copying cls attrs - self._conf = {k: getattr(self, k) for k in keys} + self._conf = {k: copy.copy(getattr(self, k)) for k in keys} # Then override with any user-specified values - conf = config - if isinstance(config, Config): - conf = config._conf + self.update(config=config, method=method, **kwargs) - if conf is not None: - self._conf.update(config) - self._conf.update(kwargs) + @classmethod + def from_configs(cls, *args): + config = cls() + for conf in args: + config.update(config=conf, method="extend") + return config - # now just shortcut the most common dict operations def __getitem__(self, key): return self._conf[key] + # now just shortcut the most common dict operations + def __getattribute__(self, key): + hasconf = "_conf" in object.__getattribute__(self, "__dict__") + if hasconf: + conf = object.__getattribute__(self, "_conf") + if key in conf: + return conf[key] + return object.__getattribute__(self, key) + def __setitem__(self, key, value): self._conf[key] = value + def __repr__(self): + res = f"{self.__class__.__name__}(" + for k, v in self.items(): + res += f"{k}: {v}, " + return res[:-2] + ")" + def __str__(self): res = f"{self.__class__.__name__}(" for k, v in self.items(): res += f"{k}: {v}, " return res[:-2] + ")" + def _repr_html_(self): + repr = f""" + + + + + + + """ + for k, v in self.items(): + repr += f"
{self.__class__.__name__}
KeyValue
{k}{v}\n" + repr += "
" + return repr + def __len__(self): return len(self._conf) @@ -76,7 +107,7 @@ def __or__(self, other): elif isinstance(other, dict): return self.__class__(config=self._conf | other) else: - raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}") + raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}") def keys(self): """A set-like object providing a view on config's keys.""" @@ -90,7 +121,10 @@ def items(self): """A set-like object providing a view on config's items.""" return self._conf.items() - def update(self, conf=None, **kwargs): + def copy(self): + return self.__class__(config=self._conf.copy()) + + def update(self, config=None, method="default", **kwargs): """Update this config from dict/other config/iterable and apply any explicit keyword overrides. @@ -107,9 +141,46 @@ def update(self, conf=None, **kwargs): for k in kwargs: this[k] = kwargs[k] """ - if conf is not None: - self._conf.update(conf) - self._conf.update(kwargs) + # Python < 3.9 does not support set operations for dicts + # [fixme]: Update this to: other = conf | kwargs + # and remove current implementation when 3.9 gets too old. Order of + # conf and kwargs matter to correctly apply explicit overrides + + # Check if both conf and kwargs are given, just conf or just + # kwargs. If none are given do nothing to comply with default + # dict behavior + if config is not None and kwargs: + other = {**config, **kwargs} + elif config is not None: + other = config + elif kwargs is not None: + other = kwargs + else: + return + + # then, see if we the given config and overrides are a subset of this + # config or it's superset. Depending on the selected method then raise + # errors, ignore or extend the current config if the given config is a + # superset (or disjoint) from the current one. + subset = {k: v for k, v in other.items() if k in self._conf} + superset = {k: v for k, v in other.items() if k not in subset} + + if method.lower() == "default": + if superset: + raise ConfigurationError( + "Tried setting the following fields, not a part of " + f"this configuration options: {superset}" + ) + conf = other # == subset + elif method.lower() == "subset": + conf = subset + elif method.lower() == "extend": + conf = other + else: + raise ValueError("Method expected to be one of 'default', " + f"'subset' or 'extend'. Got {method} instead.") + + self._conf.update(conf) def toDict(self): """Return this config as a dict.""" diff --git a/src/kbmod/mocking/fits.py b/src/kbmod/mocking/fits.py index f47579461..9eb46a62c 100644 --- a/src/kbmod/mocking/fits.py +++ b/src/kbmod/mocking/fits.py @@ -4,6 +4,7 @@ import numpy as np from astropy.wcs import WCS +from astropy.table import Table from astropy.modeling import models from astropy.io.fits import ( HDUList, @@ -13,68 +14,73 @@ Header ) -from .headers import HeaderFactory, ArchivedHeader -from .catalogs import gen_catalog, SimpleSourceCatalog, SimpleObjectCatalog +from .config import Config +from .headers import HeaderFactory, ArchivedHeader, HeaderFactoryConfig +from .catalogs import gen_catalog, SimpleCatalog, SourceCatalog, ObjectCatalog from .fits_data import ( + DataFactoryConfig, DataFactory, - ZeroedData, + SimpleImageConfig, SimpleImage, + SimulatedImageConfig, + SimulatedImage, + SimpleVarianceConfig, SimpleVariance, + SimpleMaskConfig, SimpleMask, add_model_objects ) __all__ = [ - "HDUFactory", - "HDUListFactory", "callback", + "EmptyFitsConfig", "EmptyFits", + "SimpleFitsConfig", "SimpleFits", - "DECamImdiffs", + "DECamImdiffConfig", + "DECamImdiff", ] -class HDUFactory: - def __init__(self, hdu_cls, header_factory, data_factory=None): - self.hdu_cls = hdu_cls - self.header_factory = header_factory - self.data_factory = data_factory - self.update_data = False if data_factory is None else True +class HDUListFactoryConfig(Config): + validate_header = False + """Call ``Header.update`` instead of assigning header values. This enforces + FITS standards and may strip non-standard header keywords.""" - def mock(self, **kwargs): - hdu = self.hdu_cls() + update_header = False + """After mocking and assigning data call ``Header.update_header``. This + enforces header to be consistent to the data type and shape. May alter or + remove keywords from the mocked header.""" - header = self.header_factory.mock(hdu=hdu, **kwargs) - # hdu.header.update(header) costs more but maybe better? - hdu.header = header - if self.update_data: - data = self.data_factory.mock(hdu=hdu, **kwargs) - hdu.data = data - # not sure if we want this tbh - # hdu.update_header() +class HDUListFactory(abc.ABC): + default_config = HDUListFactoryConfig - return hdu + def __init__(self, config=None, **kwargs): + self.config = self.default_config(config=config, **kwargs) + self.current = 0 + def hdu_cast(self, hdu_cls, hdr, data=None, config=None, **kwargs): + hdu = hdu_cls() -class HDUListFactory(abc.ABC): - def __init__(self, layout, base_primary={}, base_ext={}, base_wcs={}): - self.layout = layout - self.base_primary = base_primary - self.base_ext = base_ext - self.base_wcs = base_wcs - - def generate(self, **kwargs): - hdul = HDUList() - for layer in self.layout: - hdul.append(layer.mock(**kwargs)) - return hdul + if self.config.validate_header: + hdu.header.update(hdr) + else: + hdu.header = hdr + + if data is not None: + hdu.data = data + if self.config.update_header: + hdu.update_header() + + return hdu @abc.abstractmethod def mock(self, n=1): raise NotImplementedError() + # I am sure a decorator like this must exist somewhere in functools, but can't # find it and I'm doing something wrong with functools.partial because that's # strictly right-side binding? @@ -90,9 +96,7 @@ def wrapper(*args, **kwargs): def f(*fargs, **fkwargs): kwargs.update(fkwargs) return func(*(args + fargs), **kwargs) - return f - else: # functions, static methods def wrapper(*args, **kwargs): @@ -100,170 +104,375 @@ def wrapper(*args, **kwargs): def f(*fargs, **fkwargs): kwargs.update(fkwargs) return func(*fargs, **kwargs) - return f - return wrapper +class EmptyFitsConfig(HDUListFactoryConfig): + editable_images = False + separate_masks = False + writeable_mask = False + dt = 0.001 + + shape = (100, 100) + """Dimensions of the generated images.""" + + class EmptyFits(HDUListFactory): + default_config = EmptyFitsConfig + @callback @staticmethod def increment_obstime(old, dt): return old + dt - def __init__(self, primary_hdr=None, dt=0.001): - # header and data factories that go into creating HDUs - self.primary = HeaderFactory.from_base_primary( - metadata=primary_hdr, + def __init__(self, metadata=None, config=None, **kwargs): + super().__init__(config=config, method="extend", **kwargs) + + # 1. Update all the default configs. Use class configs that simplify + # and unite the many smaller config settings of each underlying factory. + breakpoint() + hdr_conf = HeaderFactoryConfig(config=self.config, method="subset", shape=self.config.shape) + + # 2. Set up Header and Data factories that go into creating HDUs + # 2.1) First headers, since that metadata specified data formats + self.primary_hdr = HeaderFactory.from_primary_template( + overrides=metadata, mutables=["OBS-MJD"], - callbacks=[self.increment_obstime(dt=dt)], + callbacks=[self.increment_obstime(dt=self.config.dt)], + config=hdr_conf ) - self.image = HeaderFactory.from_base_ext({"EXTNAME": "IMAGE"}) - self.variance = HeaderFactory.from_base_ext({"EXTNAME": "VARIANCE"}) - self.mask = HeaderFactory.from_base_ext({"EXTNAME": "MASK"}) - data = ZeroedData() - - # a map of HDU class and their respective header and data generators - layout = [ - HDUFactory(PrimaryHDU, self.primary), - HDUFactory(CompImageHDU, self.image, data), - HDUFactory(CompImageHDU, self.variance, data), - HDUFactory(CompImageHDU, self.mask, data), - ] - super().__init__(layout) + self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, config=hdr_conf) + self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, config=hdr_conf) + self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, config=hdr_conf) + + # 2.2) Then data factories, attempt to save performance and memory + # where possible by really only allocating 1 array whenever the data + # is read-only and content-static between created HDUs. + writeable, return_copy = False, False + if self.config.editable_images: + writeable, return_copy = True, True + + self.data = DataFactory.from_header( + kind="image", + header=self.img_hdr.header, + writeable=writeable, + return_copy=return_copy + ) + self.shared_data = DataFactory.from_header("image", self.mask_hdr.header, + writeable=self.config.writeable_mask) def mock(self, n=1): - if n == 1: - return self.step() + # 3) Finally when mocking, vectorize as many operations as possible. + # The amount of data scales fast with number of images generated so + # that even modest iterative generation of data costs significantly. + # (F.e. for each HDU - 3 images are allocated, if we use DECam, for + # each HDU, 62 to 70 images can be generated). Many times generated values + # are not trivial zeros, but randomly drawn, modified, and then + # calculated. + var_hdr = self.var_hdr.mock() + mask_hdr = self.mask_hdr.mock() + + imgs = self.data.mock(n) + variances = self.data.mock(n) + masks = self.shared_data.mock(n) - shape = (n*3, self.image.header["NAXIS1"], self.image.header["NAXIS1"]) - images = np.zeros(shape, dtype=np.float32) - imghdr, varhdr, maskhdr = self.image.mock(), self.variance.mock(), self.mask.mock() hduls = [] for i in range(n): - hduls.append(HDUList( - hdus=[ - PrimaryHDU(header=self.primary.mock()), - CompImageHDU(header=imghdr, data=images[i]), - CompImageHDU(header=varhdr, data=images[n+i]), - CompImageHDU(header=maskhdr, data=images[2*n+i]) - ] - )) - + hduls.append(HDUList(hdus=[ + self.hdu_cast(PrimaryHDU, self.primary_hdr.mock()), + self.hdu_cast(CompImageHDU, self.img_hdr.mock(), imgs[i]), + self.hdu_cast(CompImageHDU, var_hdr, variances[i]), + self.hdu_cast(CompImageHDU, mask_hdr, masks[i]) + ])) + + self.current += n return hduls +class SimpleFitsConfig(HDUListFactoryConfig): + editable_images = False + separate_masks = False + writeable_mask = False + noise_generation = "simplistic" + dt = 0.001 + + shape = (100, 100) + """Dimensions of the generated images.""" + class SimpleFits(HDUListFactory): + default_config = SimpleFitsConfig + @callback @staticmethod def increment_obstime(old, dt): return old + dt - def __init__(self, shape, noise=0, noise_std=1, dt=0.001, source_catalog=None, object_catalog=None): - # internal counter of n images created so far - self._idx = 0 - self.dt = dt + def __init__(self, metadata=None, source_catalog=None, object_catalog=None, + config=None, **kwargs): + super().__init__(config=config, method="extend", **kwargs) + self.src_cat = source_catalog self.obj_cat = object_catalog - self.shape = shape - self.noise = noise - self.noise_std = noise_std - dims = {"NAXIS1": shape[0], "NAXIS2": shape[1]} - - # Then we can generate the header factories primary header contains no - # data, but does contain update-able metadata fields (f.e. timestamps), - # others contain static headers, but dynamic, mutually connected data - self.primary = HeaderFactory.from_base_primary( + + # 1. Update all the default configs using more user-friendly kwargs + hdr_conf = HeaderFactoryConfig(config=self.config, shape=self.config.shape, method="subset") + + if self.config.noise_generation == "realistic": + img_cfg = SimulatedImageConfig(config=self.config, method="subset") + else: + img_cfg = SimpleImageConfig(config=self.config, method="subset") + var_cfg = SimpleVarianceConfig(config=self.config, method="subset") + mask_cfg = SimpleMaskConfig(config=self.config, method="subset") + + # 2. Set up Header and Data factories that go into creating HDUs + # 2.1) First headers, since that metadata specified data formats + self.primary_hdr = HeaderFactory.from_primary_template( + overrides=metadata, mutables=["OBS-MJD"], - callbacks=[self.increment_obstime(dt=0.001)], - ) - self.image = HeaderFactory.from_base_ext({"EXTNAME": "IMAGE", **dims}) - self.variance = HeaderFactory.from_base_ext({"EXTNAME": "VARIANCE", **dims}) - self.mask = HeaderFactory.from_base_ext({"EXTNAME": "MASK", **dims}) - - self.img_data = SimpleImage( - shape=shape, - noise=noise, - noise_std=noise_std, - src_cat=self.src_cat, + callbacks=[self.increment_obstime(dt=self.config.dt)], + config=hdr_conf ) - self.var_data = SimpleVariance(self.img_data.base, read_noise=noise, gain=1.0) - self.mask_data = DataFactory(np.zeros(shape)) - - # Now we can build the HDU map and the HDUList layout - layout = [ - HDUFactory(PrimaryHDU, self.primary), - HDUFactory(CompImageHDU, self.image, self.img_data), - HDUFactory(CompImageHDU, self.variance, self.var_data), - HDUFactory(CompImageHDU, self.mask, self.mask_data), - ] - - super().__init__(layout) + self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, config=hdr_conf) + self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, config=hdr_conf) + self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, config=hdr_conf) + + # 2.2) Then data factories + if self.config.noise_generation == "realistic": + self.img_data = SimulatedImage(src_cat=self.src_cat, config=img_cfg) + else: + self.img_data = SimpleImage(src_cat=self.src_cat, config=img_cfg) + self.var_data = SimpleVariance(self.img_data.base, config=var_cfg) + self.mask_data = SimpleMask.from_image(self.img_data.base, config=mask_cfg) @classmethod - def from_defaults(cls, shape=(2048, 4096), add_static_sources=False, add_moving_objects=True, **kwargs): - source_catalog, object_catalog = None, None - cat_lims = {"x_mean": [0, shape[1]], "y_mean": [0, shape[0]]} - if add_static_sources: - source_catalog = SimpleSourceCatalog.from_params(param_ranges=cat_lims) - if add_moving_objects: - object_catalog = SimpleObjectCatalog.from_params(param_ranges=cat_lims) - - return cls(shape=shape, source_catalog=source_catalog, object_catalog=object_catalog, **kwargs) + def from_defaults(cls, config=None, sources=None, objects=None, **kwargs): + config = cls.default_config(config, **kwargs, method="extend") + cat_lims = {"x_mean": [0, config.shape[1]], "y_mean": [0, config.shape[0]]} + src_cat, obj_cat = None, None + if sources is not None: + if isinstance(sources, SimpleCatalog): + src_cat = sources + elif isinstance(sources, Table): + src_cat = SourceCatalog(table=sources) + elif isinstance(sources, Config) or sources == True: + sources.param_ranges.update(cat_lims) + src_cat = SourceCatalog(**sources) + elif isinstance(sources, dict): + sources.update(cat_lims) + src_cat = SourceCatalog(**sources) + elif isinstance(sources, int): + src_cat = SourceCatalog(n=sources, **cat_lims) + elif sources is True: + obj_cat = SourceCatalog(**cat_lims) + else: + raise ValueError("Sources are expected to be a catalog, table, config or config overrides.") + if objects is not None: + if isinstance(objects, SimpleCatalog): + obj_cat = objects + elif isinstance(objects, Table): + obj_cat = ObjectCatalog(table=objects) + elif isinstance(objects, Config): + objects.param_ranges.update(cat_lims) + obj_cat = ObjectCatalog(**objects) + elif isinstance(objects, dict): + objects.update(cat_lims) + obj_cat = ObjectCatalog(**objects) + elif isinstance(objects, int): + obj_cat = ObjectCatalog(n=objects, **cat_lims) + elif objects is True: + obj_cat = ObjectCatalog(**cat_lims) + else: + raise ValueError("Objects are expected to be a catalog, table, config or config overrides.") + return cls(source_catalog=src_cat, object_catalog=obj_cat, config=config, **kwargs) def mock(self, n=1): obj_cats = None - if self.obj_cat is not None: - obj_cats = self.obj_cat.mock(n, dt=self.dt) + obj_cats = self.obj_cat.mock(n, dt=self.config.dt) + + var_hdr = self.var_hdr.mock() + mask_hdr = self.mask_hdr.mock() images = self.img_data.mock(n, obj_cats=obj_cats) variances = self.var_data.mock(images=images) - mask = self.mask_data.mock() - imghdr, varhdr, maskhdr = self.image.mock(), self.variance.mock(), self.mask.mock() + masks = self.mask_data.mock(n) hduls = [] for i in range(n): - hduls.append(HDUList( - hdus=[ - PrimaryHDU(header=self.primary.mock()), - CompImageHDU(header=imghdr, data=images[i]), - CompImageHDU(header=varhdr, data=variances[i]), - CompImageHDU(header=maskhdr, data=mask), - ] - )) + hduls.append(HDUList(hdus=[ + self.hdu_cast(PrimaryHDU, self.primary_hdr.mock()), + self.hdu_cast(CompImageHDU, self.img_hdr.mock(), images[i]), + self.hdu_cast(CompImageHDU, var_hdr, variances[i]), + self.hdu_cast(CompImageHDU, mask_hdr, masks[i]) + ])) + + self.current += n return hduls -class DECamImdiffs(HDUListFactory): - def __init__(self): - headers = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv") - data = ZeroedData() - image = HDUFactory(CompImageHDU, headers, data) - - # DECam imdiff data products consist of 16 headers, first 4 are: - # primary, science variance and mask and the rest are PSF, SkyWCS, - # catalog metadata, chebyshev higher order corrections etc. We don't - # use these, so leave them with no data generators. The first 4 we fill - # with all-zeros. - # don't let Black have its way with these lines because it's a massacre - # fmt: off - layout = [ - HDUFactory(PrimaryHDU, headers), - image, - image, - image, - ] - layout.extend([HDUFactory(BinTableHDU, headers) ] * 12) - # fmt: on +class DECamImdiffConfig2(HDUListFactoryConfig): + archive_name = "headers_archive.tar.bz2" + file_name = "decam_imdiff_headers.ecsv" + n_hdrs_per_hdu = 16 + + +class DECamImdiff2(HDUListFactory): + default_config = DECamImdiffConfig2 + + def __init__(self, config=None, **kwargs): + super().__init__(config=config, **kwargs) + self.hdr_factory = ArchivedHeader(self.config.archive_name, self.config.file_name) + + def mock(self, n=1): + hdrs = self.hdr_factory.mock(n) + + hduls = [] + for hdul_idx in range(n): + hduls.append(HDUList(hdus=[ + self.hdu_cast(PrimaryHDU, hdrs[hdul_idx][0]), + self.hdu_cast(CompImageHDU, hdrs[hdul_idx][1]), + self.hdu_cast(CompImageHDU, hdrs[hdul_idx][2]), + self.hdu_cast(CompImageHDU, hdrs[hdul_idx][3]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][4]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][5]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][6]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][7]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][8]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][9]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][10]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][11]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][12]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][13]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][14]), + self.hdu_cast(BinTableHDU, hdrs[hdul_idx][15]), + ])) + + self.current += 1 + return hduls + + + + +class DECamImdiffConfig(HDUListFactoryConfig): + archive_name = "headers_archive.tar.bz2" + file_name = "decam_imdiff_headers.ecsv" + + with_data=False + + editable_images = False + separate_masks = False + writeable_mask = False + noise_generation = "simplistic" + dt = 0.001 + + shape = (100, 100) + """Dimensions of the generated images.""" - super().__init__(layout) + +class NoneFactory: + "Kinda makes some code later prettier. Kinda" + def mock(self, n): + return [None, ]*n + + +class DECamImdiff(HDUListFactory): + default_config = DECamImdiffConfig + + def __init__(self, source_catalog=None, object_catalog=None, config=None, **kwargs): + super().__init__(config=config, **kwargs) + + # 1. Get the header factory - this is different than before. In a + # header preserving factory, it's the headers that dictate the shape + # and format of data. Since these are also optional for this class, we + # just create empty placeholders for data and only fill it with + # factories if we need to. + self.hdr_factory = ArchivedHeader(self.config.archive_name, + self.config.file_name) + + self.hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU] + self.hdu_types.extend([BinTableHDU, ]*12) + self.data = [NoneFactory(), ]*16 + + self.src_cat = source_catalog + self.obj_cat = object_catalog + + if self.config.with_data: + self.__init_data_factories() + + def __init_data_factories(self): + # 2. To fill in data placehodlers, we get an example set of headers in + # a file. Since they dictate data, we must derive whatever configs we + # need to and use them as overrides to trump defaults and user-given + # values. This is the point at which we rely on user telling this + # factory what the important header keys are since these could be + # non-standard. And since these change for each HDU, it is rather + # difficult to generalize without resorting to iterative and fully + # dynamical data generation - which will be slow. The only mercy is + # that we only have to do it for the HDUs we care about and use factory + # methods for the rest. + headers = self.hdr_factory.get(0) + + img_shape = (headers[1]["NAXIS1"], headers[1]["NAXIS2"]) + img_overrides = { + "shape": img_shape, + "dtype": DataFactoryConfig.bitpix_type_map[headers[1]["BITPIX"]], + } + + if self.config.noise_generation == "realistic": + img_cfg = SimulatedImageConfig(config=self.config, **img_overrides, method="subset") + else: + img_cfg = SimpleImageConfig(config=self.config, **img_overrides, method="subset") + var_cfg = SimpleVarianceConfig(config=self.config, **img_overrides, method="subset") + mask_cfg = SimpleMaskConfig(config=self.config, **img_overrides, method="subset") + + # 2.1) Now we can instantiate data factories with correct configs + # and fill in the data placeholder + if self.config.noise_generation == "realistic": + self.img_data = SimulatedImage(src_cat=self.src_cat, config=img_cfg) + else: + self.img_data = SimpleImage(src_cat=self.src_cat, config=img_cfg) + self.var_data = SimpleVariance(self.img_data.base, config=var_cfg) + self.mask_data = SimpleMask.from_image(self.img_data.base, config=mask_cfg) + + self.hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU] + self.hdu_types.extend([BinTableHDU, ]*12) + + self.data = [ + NoneFactory(), + self.img_data, + self.var_data, + self.mask_data + ] + self.data.extend([ + DataFactory.from_header("table", h) for h in headers[4:] + ]) def mock(self, n=1): - if n==1: - return self.generate() - hduls = [self.generate() for i in range(n)] + obj_cats = None + if self.obj_cat is not None: + obj_cats = self.obj_cat.mock(n, dt=self.config.dt) + + hdrs = self.hdr_factory.mock(n) + + if self.config.with_data: + images = self.img_data.mock(n, obj_cats=obj_cats) + variances = self.var_data.mock(images=images) + data = [self.data[0].mock(n), images, variances] + for factory in self.data[3:]: + data.append(factory.mock(n=n)) + else: + data = [f.mock(n=n) for f in self.data] + + hduls = [] + for hdul_idx in range(n): + hdus = [] + for hdu_idx, hdu_cls in enumerate(self.hdu_types): + hdus.append(self.hdu_cast( + hdu_cls, hdrs[hdul_idx][hdu_idx], data[hdu_idx][hdul_idx] + )) + hduls.append(HDUList(hdus=hdus)) + self.current += n return hduls diff --git a/src/kbmod/mocking/fits_data.py b/src/kbmod/mocking/fits_data.py index ffd8d79fb..79b6da1e3 100644 --- a/src/kbmod/mocking/fits_data.py +++ b/src/kbmod/mocking/fits_data.py @@ -9,17 +9,25 @@ from astropy.modeling import models from .config import Config, ConfigurationError +from kbmod import Logging __all__ = [ "add_model_objects", + "DataFactoryConfig", "DataFactory", - "ZeroedData", + "SimpleImageConfig", "SimpleImage", - "SimpleMask" + "SimpleMaskConfig", + "SimpleMask", + "SimulatedImageConfig", + "SimulatedImage" ] +logger = Logging.getLogger(__name__) + + def add_model_objects(img, catalog, model): """Adds a catalog of model objects to the image. @@ -50,6 +58,9 @@ def add_model_objects(img, catalog, model): # Save the initial model parameters so we can set them back init_params = {param: getattr(model, param) for param in params_to_set} + # model could throw a value error if drawn amplitude was too large, we must + # restore the model back to its starting values to cover for a general + # use-case because Astropy modelling is a bit awkward. try: for i, source in enumerate(catalog): for param in params_to_set: @@ -62,7 +73,6 @@ def add_model_objects(img, catalog, model): model.y_mean < img.shape[0] ]): model.render(img) - finally: for param, value in init_params.items(): setattr(model, param, value) @@ -74,6 +84,13 @@ class DataFactoryConfig(Config): """Data factory configuration primarily controls mutability of the given and returned mocked datasets. """ + default_img_shape = (5, 5) + + default_img_bit_width = 32 + + default_tbl_length = 5 + + default_tbl_dtype = np.dtype([("a", int), ("b", int)]) writeable = False """Sets the base array ``writeable`` flag. Default `False`.""" @@ -84,6 +101,21 @@ class DataFactoryConfig(Config): otherwise the original (possibly mutable!) object is returned. Default `False`. """ + # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man + bitpix_type_map = { + # or char + 8: int, + # actually no idea what dtype, or C type for that matter, + # are used to represent these values. But default Headers return them + 16: np.float16, + 32: np.float32, + 64: np.float64, + # classic IEEE float and double + -32: np.float32, + -64: np.float64, + } + """Map between FITS header BITPIX keyword value and NumPy return type.""" + class DataFactory: """Generic data factory. @@ -132,32 +164,103 @@ class DataFactory: default_config = DataFactoryConfig """Default configuration.""" - def __init__(self, base=None, config=None, **kwargs): - self.config = self.default_config() - self.config.update(config, **kwargs) + def __init__(self, base, config=None, **kwargs): + self.config = self.default_config(config, **kwargs) self.base = base - if base is not None: + if base is None: + self.shape = None + self.dtype = None + else: + self.shape = base.shape + self.dtype = base.dtype self.base.flags.writeable = self.config.writeable self.counter = 0 - def mock(self, **kwargs): - """Return a mocked fits data and increase counter. + @classmethod + def gen_image(cls, metadata=None, config=None, **kwargs): + conf = cls.default_config(config, method="subset", **kwargs) + cols = metadata.get("NAXIS1", conf.default_img_shape[0]) + rows = metadata.get("NAXIS2", conf.default_img_shape[1]) + bitwidth = metadata.get("BITPIX", conf.default_img_bit_width) + dtype = conf.bitpix_type_map[bitwidth] + shape = (cols, rows) - Parameters - ---------- - **kwargs - Any additional keyword arguments are ignored. + return np.zeros(shape, dtype) + + @classmethod + def gen_table(cls, metadata=None, config=None, **kwargs): + conf = cls.default_config(config, **kwargs, method="subset") + + # FITS format standards prescribe FORTRAN-77-like input format strings + # for different types, but the base set has been extended significantly + # since and Rubin uses completely non-standard keys with support for + # their own internal abstractions like 'Angle' objects: + # https://archive.stsci.edu/fits/fits_standard/node58.html + # https://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation + # https://github.com/lsst/afw/blob/main/src/fits.cc#L207 + # So we really don't have much of a choice but to force a default + # AstroPy HDU and then call the update. This might not preserve the + # header or the data formats exactly and if metadata isn't given + # could even assume a wrong class all together. The TableHDU is + # almost never used however - so hopefully this keeps on working. + table_cls = BinTableHDU + data = None + if metadata is not None: + if metadata["XTENSION"] == "BINTABLE": + table_cls = BinTableHDU + elif metadata["XTENSION"] == "TABLE": + table_cls = TableHDU + + hdu = table_cls() + hdu.header.update(metadata) + + rows = metadata.get("NAXIS2", conf.default_tbl_length) + shape = (rows, ) + data = np.zeros(shape, dtype=hdu.data.dtype) + else: + hdu = table_cls() + shape = (conf.default_tbl_length, ) + data = np.zeros(shape, dtype=conf.default_tbl_dtype) + + return data + + @classmethod + def from_hdu(cls, hdu, config=None, **kwargs): + if isinstance(hdu, (PrimaryHDU, CompImageHDU, ImageHDU)): + return cls(base=cls.gen_image(hdu), config=config, **kwargs) + elif isinstance(hdu, (TableHDU, BinTableHDU)): + return cls(base=cls.gen_table(hdu), config=config, **kwargs) + else: + raise TypeError(f"Expected an HDU, got {type(hdu)} instead.") + + @classmethod + def from_header(cls, kind, header, config=None, **kwargs): + if kind.lower() == "image": + return cls(base=cls.gen_image(header), config=config, **kwargs) + elif kind.lower() == "table": + return cls(base=cls.gen_table(header), config=config, **kwargs) + else: + raise TypeError(f"Expected an 'image' or 'table', got {kind} instead.") + + @classmethod + def zeros(cls, shape, dtype, config=None, **kwargs): + return cls(np.zeros(shape, dtype), config, **kwargs) + + def mock(self, n=1, **kwargs): + if self.base is None: + raise ValueError( + "Expected a DataFactory that has a base, but none was set. " + "Use `zeros` or `from_hdu` to construct this object correctly." + ) - Returns - ------- - data : `no.array` - Mocked data array - """ - self.counter += 1 if self.config.return_copy: - return self.base.copy() - return self.base + base = np.repeat(self.base[np.newaxis, ], (n,), axis=0) + else: + base = np.broadcast_to(self.base, (n, *self.shape)) + base.flags.writeable = self.config.writeable + + return base class SimpleVarianceConfig(DataFactoryConfig): @@ -207,7 +310,15 @@ def mock(self, images=None): class SimpleMaskConfig(DataFactoryConfig): """Simple mask configuration.""" - pass + dtype = np.float32 + + threshold = 1e-05 + + shape = (5, 5) + padding = 0 + bad_columns = [] + patches = [] + class SimpleMask(DataFactory): """Simple mask factory. @@ -223,11 +334,19 @@ class SimpleMask(DataFactory): """ default_config = SimpleMaskConfig - def __init__(self, mask): - super().__init__(base=mask) + + def __init__(self, mask, config=None, **kwargs): + super().__init__(base=mask, config=config, **kwargs) @classmethod - def from_params(cls, shape, padding=0, bad_columns=[], patches=[]): + def from_image(cls, image, config=None, **kwargs): + config = cls.default_config(config=config, **kwargs, method="subset") + mask = image.copy() + mask[image > config.threshold] = 1 + return cls(mask) + + @classmethod + def from_params(cls, config=None, **kwargs): """Create a mask by adding a padding around the edges of the array with the given dimensions and mask out bad columns. @@ -273,19 +392,22 @@ def from_params(cls, shape, padding=0, bad_columns=[], patches=[]): [1., 0., 1., 1., 0., 0., 0., 0., 0., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]) """ - mask = np.zeros(shape) + config = cls.default_config(config=config, **kwargs, method="subset") + mask = np.zeros(config.shape, dtype=config.dtype) + + shape, padding = config.shape, config.padding # padding mask[:padding] = 1 - mask[shape[0] - padding :] = 1 + mask[shape[0] - padding:] = 1 mask[:, :padding] = 1 - mask[: shape[1] - padding :] = 1 + mask[: shape[1] - padding:] = 1 # bad columns - for col in bad_columns: + for col in config.bad_columns: mask[:, col] = 1 - for patch, value in patches: + for patch, value in config.patches: if isinstance(patch, tuple): mask[*patch] = 1 elif isinstance(slice): @@ -293,83 +415,7 @@ def from_params(cls, shape, padding=0, bad_columns=[], patches=[]): else: raise ValueError(f"Expected a tuple (x, y), (slice, slice) or slice, got {patch} instead.") - return cls(mask) - - -class ZeroedDataConfig(DataFactoryConfig): - """Zeroed data config.""" - return_copy = True - - shape = (5, 5) - """Default image size.""" - - # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man - bitpix_type_map = { - # or char - 8: int, - # actually no idea what dtype, or C type for that matter, - # are used to represent these values. But default Headers return them - 16: np.float16, - 32: np.float32, - 64: np.float64, - # classic IEEE float and double - -32: np.float32, - -64: np.float64, - } - """Map between FITS header BITPIX keyword value and NumPy return type.""" - - -class ZeroedData(DataFactory): - """Data factory that creates zeroed data arrays from header definitions. - - A convenience factory able to generate the correct size image and bintable - data from a header. - - Parameters - ---------- - base : `np.array` - Static data shared by all mocked instances. - config : `DataFactoryConfig` - Configuration of the data factory. - **kwargs : - Additional keyword arguments are applied as config - overrides. - """ - default_config = ZeroedDataConfig - - def __init__(self, base=None, config=None, **kwargs): - super().__init__(base, config, **kwargs) - - def mock_image_data(self, hdu): - cols = hdu.header.get("NAXIS1", False) - rows = hdu.header.get("NAXIS2", False) - shape = (cols, rows) if all((cols, rows)) else self.config.shape - - data = np.zeros( - shape, - dtype=self.config.bitpix_type_map[hdu.header["BITPIX"]] - ) - return data - - def mock_table_data(self, hdu): - # interestingly, table HDUs create their own empty - # tables from headers, but image HDUs do not, this - # is why hdu.data exists and has a valid dtype - nrows = hdu.header["TFIELDS"] - return np.zeros((nrows,), dtype=hdu.data.dtype) - - def mock(self, hdu=None, **kwargs): - # two cases: a) static data shared by all instances created by the - # factory set at init time or b) dynamic data generated from a header - # specification at call-time. - if self.base is not None: - return super().mock(hdu=hdu, **kwargs) - if isinstance(hdu, (PrimaryHDU, CompImageHDU, ImageHDU)): - return self.mock_image_data(hdu) - elif isinstance(hdu, (TableHDU, BinTableHDU)): - return self.mock_table_data(hdu) - else: - raise TypeError(f"Expected an HDU, got {type(hdu)} instead.") + return cls(mask, config=config) class SimpleImageConfig(DataFactoryConfig): @@ -386,7 +432,7 @@ class SimpleImageConfig(DataFactoryConfig): seed = None """Seed of the random number generator used to generate noise.""" - noise = 0 + noise = 10 """Mean of the standard Gaussian distribution of the noise.""" noise_std = 1.0 @@ -422,8 +468,10 @@ class SimpleImage(DataFactory): """ default_config = SimpleImageConfig - def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs): - super().__init__(image, config, **kwargs) + def __init__(self, image=None, src_cat=None, obj_cat=None, + config=None, **kwargs): + self.config = self.default_config(config=config, **kwargs) + super().__init__(image, self.config, **kwargs) if image is None: image = np.zeros(self.config.shape, dtype=np.float32) @@ -431,9 +479,15 @@ def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs image = image self.config.shape = image.shape + # Astropy throws a strange ValueError instead of reporting a non-writeable + # array, This must be a bug TODO: report. It's not safe to edit a + # non-writeable array and then revert writeability so make a copy. self.src_cat = src_cat if self.src_cat is not None: - add_model_objects(image, src_cat.table, self.config.model) + image = image if image.flags.writeable else image.copy() + add_model_objects(image, src_cat.table, + self.config.model(x_stddev=1, y_stddev=1)) + image.flags.writeable = self.config.writeable self.base = image self._base_contains_data = image.sum() != 0 @@ -456,7 +510,7 @@ def add_noise(cls, images, config): rng.standard_normal(size=shape, dtype=images.dtype, out=images) # There's a lot of multiplications that happen, skip if possible - if self.config.noise_std != 1.0: + if config.noise_std != 1.0: images *= config.noise_std images += config.noise @@ -476,16 +530,21 @@ def mock(self, n=1, obj_cats=None, **kwargs): shape = (n, *self.config.shape) images = np.zeros(shape, dtype=np.float32) - if self.config.noise != 0: - images = self.add_noise(n, images, self.config) + if self.config.add_noise: + images = self.add_noise(images=images, config=self.config) - # but if base has no data (no sources, bad cols etc) skip + # if base has no data (no sources, bad cols etc) skip if self._base_contains_data: images += self.base - # same with moving objects + # same with moving objects, each image has to have a new realization of + # a catalog in which moving objects have different coordinate. This way + # any trajectory can be mocked. When we have only 1 mocked image though + # zip will attempt to iterate over the next available dimension, and + # that's rows of the image and the table - we don't want that. if obj_cats is not None: - for i, (img, cat) in enumerate(zip(images, obj_cats)): + pairs = [(images[0], obj_cats),] if n == 1 else zip(images, obj_cats) + for i, (img, cat) in enumerate(pairs): add_model_objects( img, cat, @@ -526,6 +585,9 @@ class SimulatedImageConfig(DataFactoryConfig): # not sure this is a smart idea to put here rng = np.random.default_rng() + seed = None + """Random number generator seed shared by all number generators.""" + # image properties shape = (100, 100) """Dimensions of the created images.""" @@ -546,7 +608,7 @@ class SimulatedImageConfig(DataFactoryConfig): bias = 0.0 """Bias in counts.""" - add_bad_cols = False + add_bad_columns = False """Add bad columns to the image.""" bad_cols_method = "random" @@ -559,12 +621,15 @@ class SimulatedImageConfig(DataFactoryConfig): n_bad_cols = 5 """When bad columns method is random, sets the number of bad columns.""" - bad_cols_seed = 123 + bad_cols_seed = seed """Seed for the bad columns random number generator.""" - bad_col_pattern_offset = 0.1 + bad_col_offset = 100 + """Bad column signal offset (in counts) with regards to the baseline noise.""" + + bad_col_pattern_offset = 10 """Random-looking noise variation along the length of the bad columns is - offset from the base column counts by this amount multiplied by bias.""" + offset from the mean bad column counts by this amount.""" dark_current_gen = rng.poisson """Dark current follows a Poisson distribution.""" @@ -582,7 +647,7 @@ class SimulatedImageConfig(DataFactoryConfig): hot_pix_locs = [] """A `list[tuple]` of hot pixel indices.""" - hot_pix_seed = 321 + hot_pix_seed = seed """Seed for hot pixel random number generator.""" n_hot_pix = 10 @@ -644,25 +709,29 @@ def add_bad_cols(cls, image, config): image : `np.array` Image. """ - if not config.add_bad_cols: + if not config.add_bad_columns: return image + shape = image.shape + rng = np.random.RandomState(seed=config.bad_cols_seed) if config.bad_cols_method == "random": - rng = np.random.RandomState(seed=self.bad_cols_seed) bad_cols = rng.randint(0, shape[1], size=config.n_bad_cols) elif config.bad_col_locs: bad_cols = config.bad_col_locs else: - raise ConfigurationError("Bad columns method is not 'random', but `bad_col_locs` contains no bad column indices.") + raise ConfigurationError( + "Bad columns method is not 'random', but `bad_col_locs` " + "contains no column indices." + ) - self.col_pattern = rng.randint( + col_pattern = rng.randint( low=0, - high=int(config.bad_col_pattern_offset * config.bias), + high=int(config.bad_col_pattern_offset), size=shape[0] ) - for col in columns: - image[:, col] = config.bias + col_pattern + for col in bad_cols: + image[:, col] += col_pattern + config.bad_col_offset return image @@ -688,23 +757,27 @@ def add_hot_pixels(cls, image, config): if not config.add_hot_pixels: return image + shape = image.shape if config.hot_pix_method == "random": rng = np.random.RandomState(seed=config.hot_pix_seed) x = rng.randint(0, shape[1], size=config.n_hot_pix) y = rng.randint(0, shape[0], size=config.n_hot_pix) - hot_pixels = np.column_stack(x, y) + hot_pixels = np.column_stack([x, y]) elif config.hot_pix_locs: hot_pixels = pixels else: - raise ConfigurationError("Hot pixels method is not 'random', but `hot_pix_locs` contains no (col, row) location indices of hot pixels.") + raise ConfigurationError( + "Hot pixels method is not 'random', but `hot_pix_locs` contains " + "no (col, row) location indices of hot pixels." + ) for pix in hot_pixels: - image[*pix] += offset + image[*pix] += config.hot_pix_offset return image @classmethod - def add_noise(cls, n, images, config): + def add_noise(cls, images, config): """Adds read noise (gaussian), dark noise (poissonian) and sky background (poissonian) noise to the image. @@ -723,7 +796,7 @@ def add_noise(cls, n, images, config): shape = images.shape # add read noise - images += self.read_noise_gen( + images += config.read_noise_gen( scale=config.read_noise / config.gain, size=shape ) @@ -733,7 +806,7 @@ def add_noise(cls, n, images, config): images += config.dark_current_gen(current, size=shape) # add sky counts - images += self.sky_count_gen( + images += config.sky_count_gen( lam=config.sky_level * config.gain, size=shape ) / config.gain @@ -761,14 +834,20 @@ def gen_base_image(cls, config=None, src_cat=None): # empty image base = np.zeros(config.shape, dtype=np.float32) base += config.bias - base = cls.add_bad_cols(base, config) base = cls.add_hot_pixels(base, config) + base = cls.add_bad_cols(base, config) if src_cat is not None: - add_model_objects(base, src_cat.table, config.model) + add_model_objects(base, src_cat.table, + config.model(x_stddev=1, y_stddev=1)) return base - def __init__(self, image=None, config=None, src_cat=None, obj_cat=None): - conf = self.default_config(config) - base = self.gen_base_image(conf) - super().__init__(image=base, config=conf, src_cat=src_cat, obj_cat=obj_cat) + def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs): + conf = self.default_config(config=config, **kwargs) + # static objects are added in SimpleImage init + super().__init__( + image=self.gen_base_image(conf), + config=conf, + src_cat=src_cat, + obj_cat=obj_cat + ) diff --git a/src/kbmod/mocking/headers.py b/src/kbmod/mocking/headers.py index 033ecee1f..3947d978c 100644 --- a/src/kbmod/mocking/headers.py +++ b/src/kbmod/mocking/headers.py @@ -5,16 +5,18 @@ from astropy.io.fits import Header from .utils import header_archive_to_table +from .config import Config __all__ = [ "HeaderFactory", + "HeaderFactoryConfig", "ArchivedHeader", ] -class HeaderFactory: - base_primary = { +class HeaderFactoryConfig(Config): + primary_template = { "EXTNAME": "PRIMARY", "NAXIS": 0, "BITPIX": 8, @@ -26,9 +28,16 @@ class HeaderFactory: "OBSERVAT": "CTIO" } - base_ext = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "BITPIX": 32} + ext_template = { + "NAXIS": 2, + "NAXIS1": 2048, + "NAXIS2": 4096, + "CRPIX1": 1024, + "CPRIX2": 2048, + "BITPIX": 32 + } - base_wcs = { + wcs_template = { "ctype": ["RA---TAN", "DEC--TAN"], "crval": [351, -5], "cunit": ["deg", "deg"], @@ -36,6 +45,19 @@ class HeaderFactory: "cd": [[-1.44e-07, 7.32e-05], [7.32e-05, 1.44e-05]], } + def __init__(self, config=None, method="default", shape=None, **kwargs): + super().__init__(config=config, method=method, **kwargs) + + if shape is not None: + self.ext_template["NAXIS1"] = shape[1] + self.ext_template["NAXIS2"] = shape[0] + self.ext_template["CRPIX1"] = shape[1]//2 + self.ext_template["CRPIX2"] = shape[0]//2 + + +class HeaderFactory: + default_config = HeaderFactoryConfig + def __validate_mutables(self): # !xor if bool(self.mutables) != bool(self.callbacks): @@ -59,7 +81,10 @@ def __validate_mutables(self): "provide the required metadata keys." ) - def __init__(self, metadata=None, mutables=None, callbacks=None): + def __init__(self, metadata=None, mutables=None, callbacks=None, + config=None, **kwargs): + self.config = self.default_config(config=config, **kwargs) + cards = [] if metadata is None else metadata self.header = Header(cards=cards) @@ -67,50 +92,61 @@ def __init__(self, metadata=None, mutables=None, callbacks=None): self.callbacks = callbacks self.__validate_mutables() - def mock(self, hdu=None, **kwargs): + self.is_dynamic = self.mutables is None + self.counter = 0 + + def get(self): if self.mutables is not None: for i, mutable in enumerate(self.mutables): self.header[mutable] = self.callbacks[i](self.header[mutable]) - return self.header + def mock(self, **kwargs): + self.counter += 1 + return self.get(**kwargs) + @classmethod - def gen_wcs(cls, crval, metadata=None): - metadata = cls.base_wcs if metadata is None else metadata + def gen_wcs(cls, metadata): wcs = WCS(naxis=2) for k, v in metadata.items(): setattr(wcs.wcs, k, v) return wcs.to_header() @classmethod - def gen_header(cls, base, metadata, extend, add_wcs): - header = Header(base) if extend else Header() - header.update(metadata) + def gen_header(cls, base, overrides, wcs_base=None): + header = Header() if base is None else Header(base) + header.update(overrides) - if add_wcs: + if wcs_base is not None: naxis1 = header.get("NAXIS1", False) naxis2 = header.get("NAXIS2", False) if not all((naxis1, naxis2)): - raise ValueError("Adding a WCS to the header requires NAXIS1 and NAXIS2 keys.") - crpix = [naxis1 / 2.0, naxis2 / 2.0] - header.update(cls.gen_wcs(crpix)) + raise ValueError("Adding a WCS to the header requires " + "NAXIS1 and NAXIS2 keys.") + header.update(cls.gen_wcs(wcs_base)) return header @classmethod - def from_base_primary(cls, metadata=None, mutables=None, callbacks=None, extend_base=True, add_wcs=False): - hdr = cls.gen_header(cls.base_primary, metadata, extend_base, add_wcs) - return cls(hdr, mutables, callbacks) + def from_primary_template(cls, overrides=None, mutables=None, callbacks=None, + config=None): + config = cls.default_config(config=config) + hdr = cls.gen_header(base=config.primary_template, overrides=overrides) + return cls(hdr, mutables, callbacks, config=config) @classmethod - def from_base_ext( - cls, metadata=None, mutables=None, callbacks=None, extend_base=True, add_wcs=True, dims=None - ): - hdr = cls.gen_header(cls.base_ext, metadata, extend_base, add_wcs) - return cls(hdr, mutables, callbacks) - - -class ArchivedHeader(HeaderFactory): + def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, + wcs=None, config=None): + config = cls.default_config(config=config) + hdr = cls.gen_header( + base=config.ext_template, + overrides=overrides, + wcs_base=config.wcs_template if wcs is None else wcs + ) + return cls(hdr, mutables, callbacks, config=config) + + +class ArchivedHeaderConfig(HeaderFactoryConfig): # will almost never be anything else. Rather, it would be a miracle if it # were something else, since FITS standard shouldn't allow it. Further # casting by some packages will always be casting implemented in terms of @@ -123,31 +159,61 @@ class ArchivedHeader(HeaderFactory): } """Map between type names and types themselves.""" - def __init__(self, archive_name, fname, compression="bz2", format="ascii.ecsv"): - self.table = header_archive_to_table(archive_name, fname, compression, format) + compression = "bz2" + + format = "ascii.ecsv" + + n_hdrs_per_hdu = 1 + +class ArchivedHeader(HeaderFactory): + default_config = ArchivedHeaderConfig + + def __init__(self, archive_name, fname, config=None, **kwargs): + super().__init__(config, **kwargs) + self.table = header_archive_to_table( + archive_name, fname, self.config.compression, self.config.format + ) # Create HDU groups for easier iteration - self.table = self.table.group_by(["filename", "hdu"]) + self.table = self.table.group_by("filename") self.n_hdus = len(self.table) - # Internal counter for the current fits index, - # so that we may auto-increment it and avoid returning - # the same data all the time. - self._current = 0 - def lexical_cast(self, value, format): """Cast str literal of a type to the type itself. Supports just the builtin Python types. """ - if format in self.lexical_type_map: - return self.lexical_type_map[format](value) + if format in self.config.lexical_type_map: + return self.config.lexical_type_map[format](value) return value - def mock(self, hdu=None): + def get_item(self, group_idx, hdr_idx): header = Header() - warnings.filterwarnings("ignore", category=AstropyUserWarning) - for k, v, f in self.table.groups[self._current]["keyword", "value", "format"]: + # this is equivalent to one hdulist worth of headers + group = self.table.groups[group_idx] + # this is equivalent to one HDU's header + subgroup = group.group_by("hdu") + for k, v, f in subgroup.groups[hdr_idx]["keyword", "value", "format"]: header[k] = self.lexical_cast(v, f) warnings.resetwarnings() - self._current += 1 return header + + def get(self, group_idx): + headers = [] + # this is a bit repetitive but it saves recreating + # groups for one HDUL-equivalent many times + group = self.table.groups[group_idx] + subgroup = group.group_by("hdu") + headers = [] + for subgroup in subgroup.groups: + header = Header() + for k, v, f in subgroup["keyword", "value", "format"]: + header[k] = self.lexical_cast(v, f) + headers.append(header) + return headers + + def mock(self, n=1): + res = [] + for _ in range(n): + res.append(self.get(self.counter)) + self.counter += 1 + return res