diff --git a/src/kbmod/mocking/__init__.py b/src/kbmod/mocking/__init__.py
index d678d95d4..b79086419 100644
--- a/src/kbmod/mocking/__init__.py
+++ b/src/kbmod/mocking/__init__.py
@@ -3,4 +3,3 @@
from .headers import *
from .fits_data import *
from .fits import *
-#from . import test_mocking
diff --git a/src/kbmod/mocking/catalogs.py b/src/kbmod/mocking/catalogs.py
index 06fa5dadd..6e4fc853f 100644
--- a/src/kbmod/mocking/catalogs.py
+++ b/src/kbmod/mocking/catalogs.py
@@ -1,15 +1,18 @@
import abc
import numpy as np
-from astropy.time import Time
-from astropy.table import QTable, vstack
+from astropy.table import QTable
+from .config import Config
__all__ = [
"gen_catalog",
"CatalogFactory",
- "SimpleSourceCatalog",
- "SimpleObjectCatalog",
+ "SimpleCatalog",
+ "SourceCatalogConfig",
+ "SourceCatalog",
+ "ObjectCatalogConfig",
+ "ObjectCatalog",
]
@@ -26,84 +29,161 @@ def gen_catalog(n, param_ranges, seed=None):
# conversion assumes a gaussian
if "flux" in param_ranges and "amplitude" not in param_ranges:
- xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1
- ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1
+ xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0
+ ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0
cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd)
return cat
-
class CatalogFactory(abc.ABC):
@abc.abstractmethod
- def gen_realization(self, *args, t=None, dt=None, **kwargs):
+ def mock(self, *args, **kwargs):
raise NotImplementedError()
- def mock(self, *args, **kwargs):
- return self.gen_realization(self, *args, **kwargs)
+class SimpleCatalogConfig(Config):
+ return_copy = False
+ seed = None
+ n = 100
+ param_ranges = {}
+
+
+class SimpleCatalog(CatalogFactory):
+ default_config = SimpleCatalogConfig
+
+ def __init_from_table(self, table, config=None, **kwargs):
+ config = self.default_config(config=config, **kwargs)
+ config.n = len(table)
+ params = {}
+ for col in table.keys():
+ params[col] = (table[col].min(), table[col].max())
+ config.param_ranges.update(params)
+ return config, table
+
+ def __init_from_config(self, config, **kwargs):
+ config = self.default_config(config=config, method="subset", **kwargs)
+ table = gen_catalog(config.n, config.param_ranges, config.seed)
+ return config, table
+
+ def __init_from_ranges(self, **kwargs):
+ param_ranges = kwargs.pop("param_ranges", None)
+ if param_ranges is None:
+ param_ranges = {k: v for k, v in kwargs.items() if k in self.default_config.param_ranges}
+ kwargs = {k: v for k, v in kwargs.items() if k not in self.default_config.param_ranges}
+
+ config = self.default_config(**kwargs, method="subset")
+ config.param_ranges.update(param_ranges)
+ return self.__init_from_config(config=config)
+
+ def __init__(self, table=None, config=None, **kwargs):
+ if table is not None:
+ config, table = self.__init_from_table(table, config=config, **kwargs)
+ elif isinstance(config, Config):
+ config, table = self.__init_from_config(config=config, **kwargs)
+ elif isinstance(config, dict) or kwargs:
+ config = {} if config is None else config
+ config, table = self.__init_from_ranges(**{**config, **kwargs})
+ else:
+ raise ValueError(
+ "Expected table or config, or keyword arguments of expected "
+ f"catalog value ranges, got:\n table={table}\n config={config} "
+ f"\n kwargs={kwargs}"
+ )
+
+ self.config = config
+ self.table = table
+ self.current = 0
-class SimpleSourceCatalog(CatalogFactory):
- base_param_ranges = {
- "amplitude": [500, 2000],
- "x_mean": [0, 4096],
- "y_mean": [0, 2048],
- "x_stddev": [1, 7],
- "y_stddev": [1, 7],
- "theta": [0, np.pi],
- }
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ config = cls.default_config(config=config, method="subset", **kwargs)
+ return cls(gen_catalog(config.n, config.param_ranges, config.seed), config=config)
- def __init__(self, table, return_copy=False):
- self.table = table
- self.return_copy = return_copy
+ @classmethod
+ def from_ranges(cls, n=None, config=None, **kwargs):
+ config = cls.default_config(n=n, config=config, method="subset")
+ config.param_ranges.update(**kwargs)
+ return cls.from_config(config)
@classmethod
- def from_params(cls, n=100, param_ranges=None):
- param_ranges = {} if param_ranges is None else param_ranges
- tmp = cls.base_param_ranges.copy()
- tmp.update(param_ranges)
- return cls(gen_catalog(n, tmp))
-
- def gen_realization(self, *args, t=None, dt=None, **kwargs):
- if self.return_copy:
+ def from_table(cls, table):
+ config = cls.default_config()
+ config.n = len(table)
+ params = {}
+ for col in table.keys():
+ params[col] = (table[col].min(), table[col].max())
+ config["param_ranges"] = params
+ return cls(table, config=config)
+
+ def mock(self):
+ self.current += 1
+ if self.config.return_copy:
return self.table.copy()
return self.table
-class SimpleObjectCatalog(CatalogFactory):
- base_param_ranges = {
- "amplitude": [1, 100],
- "x_mean": [0, 4096],
- "y_mean": [0, 2048],
- "vx": [500, 1000],
- "vy": [500, 1000],
- "stddev": [1, 1.8],
- "theta": [0, np.pi],
+class SourceCatalogConfig(SimpleCatalogConfig):
+ param_ranges = {
+ "amplitude": [1., 10.],
+ "x_mean": [0., 4096.],
+ "y_mean": [0., 2048.],
+ "x_stddev": [1., 3.],
+ "y_stddev": [1., 3.],
+ "theta": [0., np.pi],
}
- def __init__(self, table, obstime=None):
- self.table = table
- self._realization = table.copy()
+
+class SourceCatalog(SimpleCatalog):
+ default_config = SourceCatalogConfig
+
+
+class ObjectCatalogConfig(SimpleCatalogConfig):
+ param_ranges = {
+ "amplitude": [0.1, 3.0],
+ "x_mean": [0., 4096.],
+ "y_mean": [0., 2048.],
+ "vx": [500., 1000.],
+ "vy": [500., 1000.],
+ "stddev": [0.25, 1.5],
+ "theta": [0., np.pi],
+ }
+
+
+class ObjectCatalog(SimpleCatalog):
+ default_config = ObjectCatalogConfig
+
+ def __init__(self, table=None, obstime=None, config=None, **kwargs):
+ # put return_copy into kwargs to override whatever user might have
+ # supplied, and to guarantee the default is overriden
+ kwargs["return_copy"] = True
+ super().__init__(table=table, config=config, **kwargs)
+ self._realization = self.table.copy()
self.obstime = 0 if obstime is None else obstime
- @classmethod
- def from_params(cls, n=100, param_ranges=None):
- param_ranges = {} if param_ranges is None else param_ranges
- tmp = cls.base_param_ranges.copy()
- tmp.update(param_ranges)
- return cls(gen_catalog(n, tmp))
+ def reset(self):
+ self.current = 0
+ self._realization = self.table.copy()
def gen_realization(self, t=None, dt=None, **kwargs):
if t is None and dt is None:
return self._realization
dt = dt if t is None else t - self.obstime
- self._realization["x_mean"] += self._realization["vx"] * dt
- self._realization["y_mean"] += self._realization["vy"] * dt
+ self._realization["x_mean"] += self.table["vx"] * dt
+ self._realization["y_mean"] += self.table["vy"] * dt
return self._realization
def mock(self, n=1, **kwargs):
+ breakpoint()
if n == 1:
- return self.gen_realization(**kwargs)
- return [self.gen_realization(**kwargs).copy() for i in range(n)]
+ data = self.gen_realization(**kwargs)
+ self.current += 1
+ else:
+ data = []
+ for i in range(n):
+ data.append(self.gen_realization(**kwargs).copy())
+ self.current += 1
+
+ return data
diff --git a/src/kbmod/mocking/config.py b/src/kbmod/mocking/config.py
index 0bb94a8ac..ed34aa649 100644
--- a/src/kbmod/mocking/config.py
+++ b/src/kbmod/mocking/config.py
@@ -1,3 +1,5 @@
+import copy
+
__all__ = ["Config", "ConfigurationError"]
@@ -22,37 +24,66 @@ class attributes. Particular attributes can be overriden on an per-instance
Keyword arguments, assigned as configuration key-values.
"""
- def __init__(self, config=None, **kwargs):
+ def __init__(self, config=None, method="default", **kwargs):
# This is a bit hacky, but it makes life a lot easier because it
# enables automatic loading of the default configuration and separation
# of default config from instance bound config
keys = list(set(dir(self.__class__)) - set(dir(Config)))
# First fill out all the defaults by copying cls attrs
- self._conf = {k: getattr(self, k) for k in keys}
+ self._conf = {k: copy.copy(getattr(self, k)) for k in keys}
# Then override with any user-specified values
- conf = config
- if isinstance(config, Config):
- conf = config._conf
+ self.update(config=config, method=method, **kwargs)
- if conf is not None:
- self._conf.update(config)
- self._conf.update(kwargs)
+ @classmethod
+ def from_configs(cls, *args):
+ config = cls()
+ for conf in args:
+ config.update(config=conf, method="extend")
+ return config
- # now just shortcut the most common dict operations
def __getitem__(self, key):
return self._conf[key]
+ # now just shortcut the most common dict operations
+ def __getattribute__(self, key):
+ hasconf = "_conf" in object.__getattribute__(self, "__dict__")
+ if hasconf:
+ conf = object.__getattribute__(self, "_conf")
+ if key in conf:
+ return conf[key]
+ return object.__getattribute__(self, key)
+
def __setitem__(self, key, value):
self._conf[key] = value
+ def __repr__(self):
+ res = f"{self.__class__.__name__}("
+ for k, v in self.items():
+ res += f"{k}: {v}, "
+ return res[:-2] + ")"
+
def __str__(self):
res = f"{self.__class__.__name__}("
for k, v in self.items():
res += f"{k}: {v}, "
return res[:-2] + ")"
+ def _repr_html_(self):
+ repr = f"""
+
+ {self.__class__.__name__}
+
+ Key |
+ Value |
+
+ """
+ for k, v in self.items():
+ repr += f"{k} | {v}\n"
+ repr += " |
"
+ return repr
+
def __len__(self):
return len(self._conf)
@@ -76,7 +107,7 @@ def __or__(self, other):
elif isinstance(other, dict):
return self.__class__(config=self._conf | other)
else:
- raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}")
+ raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}")
def keys(self):
"""A set-like object providing a view on config's keys."""
@@ -90,7 +121,10 @@ def items(self):
"""A set-like object providing a view on config's items."""
return self._conf.items()
- def update(self, conf=None, **kwargs):
+ def copy(self):
+ return self.__class__(config=self._conf.copy())
+
+ def update(self, config=None, method="default", **kwargs):
"""Update this config from dict/other config/iterable and
apply any explicit keyword overrides.
@@ -107,9 +141,46 @@ def update(self, conf=None, **kwargs):
for k in kwargs: this[k] = kwargs[k]
"""
- if conf is not None:
- self._conf.update(conf)
- self._conf.update(kwargs)
+ # Python < 3.9 does not support set operations for dicts
+ # [fixme]: Update this to: other = conf | kwargs
+ # and remove current implementation when 3.9 gets too old. Order of
+ # conf and kwargs matter to correctly apply explicit overrides
+
+ # Check if both conf and kwargs are given, just conf or just
+ # kwargs. If none are given do nothing to comply with default
+ # dict behavior
+ if config is not None and kwargs:
+ other = {**config, **kwargs}
+ elif config is not None:
+ other = config
+ elif kwargs is not None:
+ other = kwargs
+ else:
+ return
+
+ # then, see if we the given config and overrides are a subset of this
+ # config or it's superset. Depending on the selected method then raise
+ # errors, ignore or extend the current config if the given config is a
+ # superset (or disjoint) from the current one.
+ subset = {k: v for k, v in other.items() if k in self._conf}
+ superset = {k: v for k, v in other.items() if k not in subset}
+
+ if method.lower() == "default":
+ if superset:
+ raise ConfigurationError(
+ "Tried setting the following fields, not a part of "
+ f"this configuration options: {superset}"
+ )
+ conf = other # == subset
+ elif method.lower() == "subset":
+ conf = subset
+ elif method.lower() == "extend":
+ conf = other
+ else:
+ raise ValueError("Method expected to be one of 'default', "
+ f"'subset' or 'extend'. Got {method} instead.")
+
+ self._conf.update(conf)
def toDict(self):
"""Return this config as a dict."""
diff --git a/src/kbmod/mocking/fits.py b/src/kbmod/mocking/fits.py
index f47579461..9eb46a62c 100644
--- a/src/kbmod/mocking/fits.py
+++ b/src/kbmod/mocking/fits.py
@@ -4,6 +4,7 @@
import numpy as np
from astropy.wcs import WCS
+from astropy.table import Table
from astropy.modeling import models
from astropy.io.fits import (
HDUList,
@@ -13,68 +14,73 @@
Header
)
-from .headers import HeaderFactory, ArchivedHeader
-from .catalogs import gen_catalog, SimpleSourceCatalog, SimpleObjectCatalog
+from .config import Config
+from .headers import HeaderFactory, ArchivedHeader, HeaderFactoryConfig
+from .catalogs import gen_catalog, SimpleCatalog, SourceCatalog, ObjectCatalog
from .fits_data import (
+ DataFactoryConfig,
DataFactory,
- ZeroedData,
+ SimpleImageConfig,
SimpleImage,
+ SimulatedImageConfig,
+ SimulatedImage,
+ SimpleVarianceConfig,
SimpleVariance,
+ SimpleMaskConfig,
SimpleMask,
add_model_objects
)
__all__ = [
- "HDUFactory",
- "HDUListFactory",
"callback",
+ "EmptyFitsConfig",
"EmptyFits",
+ "SimpleFitsConfig",
"SimpleFits",
- "DECamImdiffs",
+ "DECamImdiffConfig",
+ "DECamImdiff",
]
-class HDUFactory:
- def __init__(self, hdu_cls, header_factory, data_factory=None):
- self.hdu_cls = hdu_cls
- self.header_factory = header_factory
- self.data_factory = data_factory
- self.update_data = False if data_factory is None else True
+class HDUListFactoryConfig(Config):
+ validate_header = False
+ """Call ``Header.update`` instead of assigning header values. This enforces
+ FITS standards and may strip non-standard header keywords."""
- def mock(self, **kwargs):
- hdu = self.hdu_cls()
+ update_header = False
+ """After mocking and assigning data call ``Header.update_header``. This
+ enforces header to be consistent to the data type and shape. May alter or
+ remove keywords from the mocked header."""
- header = self.header_factory.mock(hdu=hdu, **kwargs)
- # hdu.header.update(header) costs more but maybe better?
- hdu.header = header
- if self.update_data:
- data = self.data_factory.mock(hdu=hdu, **kwargs)
- hdu.data = data
- # not sure if we want this tbh
- # hdu.update_header()
+class HDUListFactory(abc.ABC):
+ default_config = HDUListFactoryConfig
- return hdu
+ def __init__(self, config=None, **kwargs):
+ self.config = self.default_config(config=config, **kwargs)
+ self.current = 0
+ def hdu_cast(self, hdu_cls, hdr, data=None, config=None, **kwargs):
+ hdu = hdu_cls()
-class HDUListFactory(abc.ABC):
- def __init__(self, layout, base_primary={}, base_ext={}, base_wcs={}):
- self.layout = layout
- self.base_primary = base_primary
- self.base_ext = base_ext
- self.base_wcs = base_wcs
-
- def generate(self, **kwargs):
- hdul = HDUList()
- for layer in self.layout:
- hdul.append(layer.mock(**kwargs))
- return hdul
+ if self.config.validate_header:
+ hdu.header.update(hdr)
+ else:
+ hdu.header = hdr
+
+ if data is not None:
+ hdu.data = data
+ if self.config.update_header:
+ hdu.update_header()
+
+ return hdu
@abc.abstractmethod
def mock(self, n=1):
raise NotImplementedError()
+
# I am sure a decorator like this must exist somewhere in functools, but can't
# find it and I'm doing something wrong with functools.partial because that's
# strictly right-side binding?
@@ -90,9 +96,7 @@ def wrapper(*args, **kwargs):
def f(*fargs, **fkwargs):
kwargs.update(fkwargs)
return func(*(args + fargs), **kwargs)
-
return f
-
else:
# functions, static methods
def wrapper(*args, **kwargs):
@@ -100,170 +104,375 @@ def wrapper(*args, **kwargs):
def f(*fargs, **fkwargs):
kwargs.update(fkwargs)
return func(*fargs, **kwargs)
-
return f
-
return wrapper
+class EmptyFitsConfig(HDUListFactoryConfig):
+ editable_images = False
+ separate_masks = False
+ writeable_mask = False
+ dt = 0.001
+
+ shape = (100, 100)
+ """Dimensions of the generated images."""
+
+
class EmptyFits(HDUListFactory):
+ default_config = EmptyFitsConfig
+
@callback
@staticmethod
def increment_obstime(old, dt):
return old + dt
- def __init__(self, primary_hdr=None, dt=0.001):
- # header and data factories that go into creating HDUs
- self.primary = HeaderFactory.from_base_primary(
- metadata=primary_hdr,
+ def __init__(self, metadata=None, config=None, **kwargs):
+ super().__init__(config=config, method="extend", **kwargs)
+
+ # 1. Update all the default configs. Use class configs that simplify
+ # and unite the many smaller config settings of each underlying factory.
+ breakpoint()
+ hdr_conf = HeaderFactoryConfig(config=self.config, method="subset", shape=self.config.shape)
+
+ # 2. Set up Header and Data factories that go into creating HDUs
+ # 2.1) First headers, since that metadata specified data formats
+ self.primary_hdr = HeaderFactory.from_primary_template(
+ overrides=metadata,
mutables=["OBS-MJD"],
- callbacks=[self.increment_obstime(dt=dt)],
+ callbacks=[self.increment_obstime(dt=self.config.dt)],
+ config=hdr_conf
)
- self.image = HeaderFactory.from_base_ext({"EXTNAME": "IMAGE"})
- self.variance = HeaderFactory.from_base_ext({"EXTNAME": "VARIANCE"})
- self.mask = HeaderFactory.from_base_ext({"EXTNAME": "MASK"})
- data = ZeroedData()
-
- # a map of HDU class and their respective header and data generators
- layout = [
- HDUFactory(PrimaryHDU, self.primary),
- HDUFactory(CompImageHDU, self.image, data),
- HDUFactory(CompImageHDU, self.variance, data),
- HDUFactory(CompImageHDU, self.mask, data),
- ]
- super().__init__(layout)
+ self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, config=hdr_conf)
+ self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, config=hdr_conf)
+ self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, config=hdr_conf)
+
+ # 2.2) Then data factories, attempt to save performance and memory
+ # where possible by really only allocating 1 array whenever the data
+ # is read-only and content-static between created HDUs.
+ writeable, return_copy = False, False
+ if self.config.editable_images:
+ writeable, return_copy = True, True
+
+ self.data = DataFactory.from_header(
+ kind="image",
+ header=self.img_hdr.header,
+ writeable=writeable,
+ return_copy=return_copy
+ )
+ self.shared_data = DataFactory.from_header("image", self.mask_hdr.header,
+ writeable=self.config.writeable_mask)
def mock(self, n=1):
- if n == 1:
- return self.step()
+ # 3) Finally when mocking, vectorize as many operations as possible.
+ # The amount of data scales fast with number of images generated so
+ # that even modest iterative generation of data costs significantly.
+ # (F.e. for each HDU - 3 images are allocated, if we use DECam, for
+ # each HDU, 62 to 70 images can be generated). Many times generated values
+ # are not trivial zeros, but randomly drawn, modified, and then
+ # calculated.
+ var_hdr = self.var_hdr.mock()
+ mask_hdr = self.mask_hdr.mock()
+
+ imgs = self.data.mock(n)
+ variances = self.data.mock(n)
+ masks = self.shared_data.mock(n)
- shape = (n*3, self.image.header["NAXIS1"], self.image.header["NAXIS1"])
- images = np.zeros(shape, dtype=np.float32)
- imghdr, varhdr, maskhdr = self.image.mock(), self.variance.mock(), self.mask.mock()
hduls = []
for i in range(n):
- hduls.append(HDUList(
- hdus=[
- PrimaryHDU(header=self.primary.mock()),
- CompImageHDU(header=imghdr, data=images[i]),
- CompImageHDU(header=varhdr, data=images[n+i]),
- CompImageHDU(header=maskhdr, data=images[2*n+i])
- ]
- ))
-
+ hduls.append(HDUList(hdus=[
+ self.hdu_cast(PrimaryHDU, self.primary_hdr.mock()),
+ self.hdu_cast(CompImageHDU, self.img_hdr.mock(), imgs[i]),
+ self.hdu_cast(CompImageHDU, var_hdr, variances[i]),
+ self.hdu_cast(CompImageHDU, mask_hdr, masks[i])
+ ]))
+
+ self.current += n
return hduls
+class SimpleFitsConfig(HDUListFactoryConfig):
+ editable_images = False
+ separate_masks = False
+ writeable_mask = False
+ noise_generation = "simplistic"
+ dt = 0.001
+
+ shape = (100, 100)
+ """Dimensions of the generated images."""
+
class SimpleFits(HDUListFactory):
+ default_config = SimpleFitsConfig
+
@callback
@staticmethod
def increment_obstime(old, dt):
return old + dt
- def __init__(self, shape, noise=0, noise_std=1, dt=0.001, source_catalog=None, object_catalog=None):
- # internal counter of n images created so far
- self._idx = 0
- self.dt = dt
+ def __init__(self, metadata=None, source_catalog=None, object_catalog=None,
+ config=None, **kwargs):
+ super().__init__(config=config, method="extend", **kwargs)
+
self.src_cat = source_catalog
self.obj_cat = object_catalog
- self.shape = shape
- self.noise = noise
- self.noise_std = noise_std
- dims = {"NAXIS1": shape[0], "NAXIS2": shape[1]}
-
- # Then we can generate the header factories primary header contains no
- # data, but does contain update-able metadata fields (f.e. timestamps),
- # others contain static headers, but dynamic, mutually connected data
- self.primary = HeaderFactory.from_base_primary(
+
+ # 1. Update all the default configs using more user-friendly kwargs
+ hdr_conf = HeaderFactoryConfig(config=self.config, shape=self.config.shape, method="subset")
+
+ if self.config.noise_generation == "realistic":
+ img_cfg = SimulatedImageConfig(config=self.config, method="subset")
+ else:
+ img_cfg = SimpleImageConfig(config=self.config, method="subset")
+ var_cfg = SimpleVarianceConfig(config=self.config, method="subset")
+ mask_cfg = SimpleMaskConfig(config=self.config, method="subset")
+
+ # 2. Set up Header and Data factories that go into creating HDUs
+ # 2.1) First headers, since that metadata specified data formats
+ self.primary_hdr = HeaderFactory.from_primary_template(
+ overrides=metadata,
mutables=["OBS-MJD"],
- callbacks=[self.increment_obstime(dt=0.001)],
- )
- self.image = HeaderFactory.from_base_ext({"EXTNAME": "IMAGE", **dims})
- self.variance = HeaderFactory.from_base_ext({"EXTNAME": "VARIANCE", **dims})
- self.mask = HeaderFactory.from_base_ext({"EXTNAME": "MASK", **dims})
-
- self.img_data = SimpleImage(
- shape=shape,
- noise=noise,
- noise_std=noise_std,
- src_cat=self.src_cat,
+ callbacks=[self.increment_obstime(dt=self.config.dt)],
+ config=hdr_conf
)
- self.var_data = SimpleVariance(self.img_data.base, read_noise=noise, gain=1.0)
- self.mask_data = DataFactory(np.zeros(shape))
-
- # Now we can build the HDU map and the HDUList layout
- layout = [
- HDUFactory(PrimaryHDU, self.primary),
- HDUFactory(CompImageHDU, self.image, self.img_data),
- HDUFactory(CompImageHDU, self.variance, self.var_data),
- HDUFactory(CompImageHDU, self.mask, self.mask_data),
- ]
-
- super().__init__(layout)
+ self.img_hdr = HeaderFactory.from_ext_template({"EXTNAME": "IMAGE"}, config=hdr_conf)
+ self.var_hdr = HeaderFactory.from_ext_template({"EXTNAME": "VARIANCE"}, config=hdr_conf)
+ self.mask_hdr = HeaderFactory.from_ext_template({"EXTNAME": "MASK"}, config=hdr_conf)
+
+ # 2.2) Then data factories
+ if self.config.noise_generation == "realistic":
+ self.img_data = SimulatedImage(src_cat=self.src_cat, config=img_cfg)
+ else:
+ self.img_data = SimpleImage(src_cat=self.src_cat, config=img_cfg)
+ self.var_data = SimpleVariance(self.img_data.base, config=var_cfg)
+ self.mask_data = SimpleMask.from_image(self.img_data.base, config=mask_cfg)
@classmethod
- def from_defaults(cls, shape=(2048, 4096), add_static_sources=False, add_moving_objects=True, **kwargs):
- source_catalog, object_catalog = None, None
- cat_lims = {"x_mean": [0, shape[1]], "y_mean": [0, shape[0]]}
- if add_static_sources:
- source_catalog = SimpleSourceCatalog.from_params(param_ranges=cat_lims)
- if add_moving_objects:
- object_catalog = SimpleObjectCatalog.from_params(param_ranges=cat_lims)
-
- return cls(shape=shape, source_catalog=source_catalog, object_catalog=object_catalog, **kwargs)
+ def from_defaults(cls, config=None, sources=None, objects=None, **kwargs):
+ config = cls.default_config(config, **kwargs, method="extend")
+ cat_lims = {"x_mean": [0, config.shape[1]], "y_mean": [0, config.shape[0]]}
+ src_cat, obj_cat = None, None
+ if sources is not None:
+ if isinstance(sources, SimpleCatalog):
+ src_cat = sources
+ elif isinstance(sources, Table):
+ src_cat = SourceCatalog(table=sources)
+ elif isinstance(sources, Config) or sources == True:
+ sources.param_ranges.update(cat_lims)
+ src_cat = SourceCatalog(**sources)
+ elif isinstance(sources, dict):
+ sources.update(cat_lims)
+ src_cat = SourceCatalog(**sources)
+ elif isinstance(sources, int):
+ src_cat = SourceCatalog(n=sources, **cat_lims)
+ elif sources is True:
+ obj_cat = SourceCatalog(**cat_lims)
+ else:
+ raise ValueError("Sources are expected to be a catalog, table, config or config overrides.")
+ if objects is not None:
+ if isinstance(objects, SimpleCatalog):
+ obj_cat = objects
+ elif isinstance(objects, Table):
+ obj_cat = ObjectCatalog(table=objects)
+ elif isinstance(objects, Config):
+ objects.param_ranges.update(cat_lims)
+ obj_cat = ObjectCatalog(**objects)
+ elif isinstance(objects, dict):
+ objects.update(cat_lims)
+ obj_cat = ObjectCatalog(**objects)
+ elif isinstance(objects, int):
+ obj_cat = ObjectCatalog(n=objects, **cat_lims)
+ elif objects is True:
+ obj_cat = ObjectCatalog(**cat_lims)
+ else:
+ raise ValueError("Objects are expected to be a catalog, table, config or config overrides.")
+ return cls(source_catalog=src_cat, object_catalog=obj_cat, config=config, **kwargs)
def mock(self, n=1):
obj_cats = None
-
if self.obj_cat is not None:
- obj_cats = self.obj_cat.mock(n, dt=self.dt)
+ obj_cats = self.obj_cat.mock(n, dt=self.config.dt)
+
+ var_hdr = self.var_hdr.mock()
+ mask_hdr = self.mask_hdr.mock()
images = self.img_data.mock(n, obj_cats=obj_cats)
variances = self.var_data.mock(images=images)
- mask = self.mask_data.mock()
- imghdr, varhdr, maskhdr = self.image.mock(), self.variance.mock(), self.mask.mock()
+ masks = self.mask_data.mock(n)
hduls = []
for i in range(n):
- hduls.append(HDUList(
- hdus=[
- PrimaryHDU(header=self.primary.mock()),
- CompImageHDU(header=imghdr, data=images[i]),
- CompImageHDU(header=varhdr, data=variances[i]),
- CompImageHDU(header=maskhdr, data=mask),
- ]
- ))
+ hduls.append(HDUList(hdus=[
+ self.hdu_cast(PrimaryHDU, self.primary_hdr.mock()),
+ self.hdu_cast(CompImageHDU, self.img_hdr.mock(), images[i]),
+ self.hdu_cast(CompImageHDU, var_hdr, variances[i]),
+ self.hdu_cast(CompImageHDU, mask_hdr, masks[i])
+ ]))
+
+ self.current += n
return hduls
-class DECamImdiffs(HDUListFactory):
- def __init__(self):
- headers = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv")
- data = ZeroedData()
- image = HDUFactory(CompImageHDU, headers, data)
-
- # DECam imdiff data products consist of 16 headers, first 4 are:
- # primary, science variance and mask and the rest are PSF, SkyWCS,
- # catalog metadata, chebyshev higher order corrections etc. We don't
- # use these, so leave them with no data generators. The first 4 we fill
- # with all-zeros.
- # don't let Black have its way with these lines because it's a massacre
- # fmt: off
- layout = [
- HDUFactory(PrimaryHDU, headers),
- image,
- image,
- image,
- ]
- layout.extend([HDUFactory(BinTableHDU, headers) ] * 12)
- # fmt: on
+class DECamImdiffConfig2(HDUListFactoryConfig):
+ archive_name = "headers_archive.tar.bz2"
+ file_name = "decam_imdiff_headers.ecsv"
+ n_hdrs_per_hdu = 16
+
+
+class DECamImdiff2(HDUListFactory):
+ default_config = DECamImdiffConfig2
+
+ def __init__(self, config=None, **kwargs):
+ super().__init__(config=config, **kwargs)
+ self.hdr_factory = ArchivedHeader(self.config.archive_name, self.config.file_name)
+
+ def mock(self, n=1):
+ hdrs = self.hdr_factory.mock(n)
+
+ hduls = []
+ for hdul_idx in range(n):
+ hduls.append(HDUList(hdus=[
+ self.hdu_cast(PrimaryHDU, hdrs[hdul_idx][0]),
+ self.hdu_cast(CompImageHDU, hdrs[hdul_idx][1]),
+ self.hdu_cast(CompImageHDU, hdrs[hdul_idx][2]),
+ self.hdu_cast(CompImageHDU, hdrs[hdul_idx][3]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][4]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][5]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][6]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][7]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][8]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][9]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][10]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][11]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][12]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][13]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][14]),
+ self.hdu_cast(BinTableHDU, hdrs[hdul_idx][15]),
+ ]))
+
+ self.current += 1
+ return hduls
+
+
+
+
+class DECamImdiffConfig(HDUListFactoryConfig):
+ archive_name = "headers_archive.tar.bz2"
+ file_name = "decam_imdiff_headers.ecsv"
+
+ with_data=False
+
+ editable_images = False
+ separate_masks = False
+ writeable_mask = False
+ noise_generation = "simplistic"
+ dt = 0.001
+
+ shape = (100, 100)
+ """Dimensions of the generated images."""
- super().__init__(layout)
+
+class NoneFactory:
+ "Kinda makes some code later prettier. Kinda"
+ def mock(self, n):
+ return [None, ]*n
+
+
+class DECamImdiff(HDUListFactory):
+ default_config = DECamImdiffConfig
+
+ def __init__(self, source_catalog=None, object_catalog=None, config=None, **kwargs):
+ super().__init__(config=config, **kwargs)
+
+ # 1. Get the header factory - this is different than before. In a
+ # header preserving factory, it's the headers that dictate the shape
+ # and format of data. Since these are also optional for this class, we
+ # just create empty placeholders for data and only fill it with
+ # factories if we need to.
+ self.hdr_factory = ArchivedHeader(self.config.archive_name,
+ self.config.file_name)
+
+ self.hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
+ self.hdu_types.extend([BinTableHDU, ]*12)
+ self.data = [NoneFactory(), ]*16
+
+ self.src_cat = source_catalog
+ self.obj_cat = object_catalog
+
+ if self.config.with_data:
+ self.__init_data_factories()
+
+ def __init_data_factories(self):
+ # 2. To fill in data placehodlers, we get an example set of headers in
+ # a file. Since they dictate data, we must derive whatever configs we
+ # need to and use them as overrides to trump defaults and user-given
+ # values. This is the point at which we rely on user telling this
+ # factory what the important header keys are since these could be
+ # non-standard. And since these change for each HDU, it is rather
+ # difficult to generalize without resorting to iterative and fully
+ # dynamical data generation - which will be slow. The only mercy is
+ # that we only have to do it for the HDUs we care about and use factory
+ # methods for the rest.
+ headers = self.hdr_factory.get(0)
+
+ img_shape = (headers[1]["NAXIS1"], headers[1]["NAXIS2"])
+ img_overrides = {
+ "shape": img_shape,
+ "dtype": DataFactoryConfig.bitpix_type_map[headers[1]["BITPIX"]],
+ }
+
+ if self.config.noise_generation == "realistic":
+ img_cfg = SimulatedImageConfig(config=self.config, **img_overrides, method="subset")
+ else:
+ img_cfg = SimpleImageConfig(config=self.config, **img_overrides, method="subset")
+ var_cfg = SimpleVarianceConfig(config=self.config, **img_overrides, method="subset")
+ mask_cfg = SimpleMaskConfig(config=self.config, **img_overrides, method="subset")
+
+ # 2.1) Now we can instantiate data factories with correct configs
+ # and fill in the data placeholder
+ if self.config.noise_generation == "realistic":
+ self.img_data = SimulatedImage(src_cat=self.src_cat, config=img_cfg)
+ else:
+ self.img_data = SimpleImage(src_cat=self.src_cat, config=img_cfg)
+ self.var_data = SimpleVariance(self.img_data.base, config=var_cfg)
+ self.mask_data = SimpleMask.from_image(self.img_data.base, config=mask_cfg)
+
+ self.hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
+ self.hdu_types.extend([BinTableHDU, ]*12)
+
+ self.data = [
+ NoneFactory(),
+ self.img_data,
+ self.var_data,
+ self.mask_data
+ ]
+ self.data.extend([
+ DataFactory.from_header("table", h) for h in headers[4:]
+ ])
def mock(self, n=1):
- if n==1:
- return self.generate()
- hduls = [self.generate() for i in range(n)]
+ obj_cats = None
+ if self.obj_cat is not None:
+ obj_cats = self.obj_cat.mock(n, dt=self.config.dt)
+
+ hdrs = self.hdr_factory.mock(n)
+
+ if self.config.with_data:
+ images = self.img_data.mock(n, obj_cats=obj_cats)
+ variances = self.var_data.mock(images=images)
+ data = [self.data[0].mock(n), images, variances]
+ for factory in self.data[3:]:
+ data.append(factory.mock(n=n))
+ else:
+ data = [f.mock(n=n) for f in self.data]
+
+ hduls = []
+ for hdul_idx in range(n):
+ hdus = []
+ for hdu_idx, hdu_cls in enumerate(self.hdu_types):
+ hdus.append(self.hdu_cast(
+ hdu_cls, hdrs[hdul_idx][hdu_idx], data[hdu_idx][hdul_idx]
+ ))
+ hduls.append(HDUList(hdus=hdus))
+ self.current += n
return hduls
diff --git a/src/kbmod/mocking/fits_data.py b/src/kbmod/mocking/fits_data.py
index ffd8d79fb..79b6da1e3 100644
--- a/src/kbmod/mocking/fits_data.py
+++ b/src/kbmod/mocking/fits_data.py
@@ -9,17 +9,25 @@
from astropy.modeling import models
from .config import Config, ConfigurationError
+from kbmod import Logging
__all__ = [
"add_model_objects",
+ "DataFactoryConfig",
"DataFactory",
- "ZeroedData",
+ "SimpleImageConfig",
"SimpleImage",
- "SimpleMask"
+ "SimpleMaskConfig",
+ "SimpleMask",
+ "SimulatedImageConfig",
+ "SimulatedImage"
]
+logger = Logging.getLogger(__name__)
+
+
def add_model_objects(img, catalog, model):
"""Adds a catalog of model objects to the image.
@@ -50,6 +58,9 @@ def add_model_objects(img, catalog, model):
# Save the initial model parameters so we can set them back
init_params = {param: getattr(model, param) for param in params_to_set}
+ # model could throw a value error if drawn amplitude was too large, we must
+ # restore the model back to its starting values to cover for a general
+ # use-case because Astropy modelling is a bit awkward.
try:
for i, source in enumerate(catalog):
for param in params_to_set:
@@ -62,7 +73,6 @@ def add_model_objects(img, catalog, model):
model.y_mean < img.shape[0]
]):
model.render(img)
-
finally:
for param, value in init_params.items():
setattr(model, param, value)
@@ -74,6 +84,13 @@ class DataFactoryConfig(Config):
"""Data factory configuration primarily controls mutability of the given
and returned mocked datasets.
"""
+ default_img_shape = (5, 5)
+
+ default_img_bit_width = 32
+
+ default_tbl_length = 5
+
+ default_tbl_dtype = np.dtype([("a", int), ("b", int)])
writeable = False
"""Sets the base array ``writeable`` flag. Default `False`."""
@@ -84,6 +101,21 @@ class DataFactoryConfig(Config):
otherwise the original (possibly mutable!) object is returned. Default `False`.
"""
+ # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man
+ bitpix_type_map = {
+ # or char
+ 8: int,
+ # actually no idea what dtype, or C type for that matter,
+ # are used to represent these values. But default Headers return them
+ 16: np.float16,
+ 32: np.float32,
+ 64: np.float64,
+ # classic IEEE float and double
+ -32: np.float32,
+ -64: np.float64,
+ }
+ """Map between FITS header BITPIX keyword value and NumPy return type."""
+
class DataFactory:
"""Generic data factory.
@@ -132,32 +164,103 @@ class DataFactory:
default_config = DataFactoryConfig
"""Default configuration."""
- def __init__(self, base=None, config=None, **kwargs):
- self.config = self.default_config()
- self.config.update(config, **kwargs)
+ def __init__(self, base, config=None, **kwargs):
+ self.config = self.default_config(config, **kwargs)
self.base = base
- if base is not None:
+ if base is None:
+ self.shape = None
+ self.dtype = None
+ else:
+ self.shape = base.shape
+ self.dtype = base.dtype
self.base.flags.writeable = self.config.writeable
self.counter = 0
- def mock(self, **kwargs):
- """Return a mocked fits data and increase counter.
+ @classmethod
+ def gen_image(cls, metadata=None, config=None, **kwargs):
+ conf = cls.default_config(config, method="subset", **kwargs)
+ cols = metadata.get("NAXIS1", conf.default_img_shape[0])
+ rows = metadata.get("NAXIS2", conf.default_img_shape[1])
+ bitwidth = metadata.get("BITPIX", conf.default_img_bit_width)
+ dtype = conf.bitpix_type_map[bitwidth]
+ shape = (cols, rows)
- Parameters
- ----------
- **kwargs
- Any additional keyword arguments are ignored.
+ return np.zeros(shape, dtype)
+
+ @classmethod
+ def gen_table(cls, metadata=None, config=None, **kwargs):
+ conf = cls.default_config(config, **kwargs, method="subset")
+
+ # FITS format standards prescribe FORTRAN-77-like input format strings
+ # for different types, but the base set has been extended significantly
+ # since and Rubin uses completely non-standard keys with support for
+ # their own internal abstractions like 'Angle' objects:
+ # https://archive.stsci.edu/fits/fits_standard/node58.html
+ # https://docs.astropy.org/en/stable/io/fits/usage/table.html#column-creation
+ # https://github.com/lsst/afw/blob/main/src/fits.cc#L207
+ # So we really don't have much of a choice but to force a default
+ # AstroPy HDU and then call the update. This might not preserve the
+ # header or the data formats exactly and if metadata isn't given
+ # could even assume a wrong class all together. The TableHDU is
+ # almost never used however - so hopefully this keeps on working.
+ table_cls = BinTableHDU
+ data = None
+ if metadata is not None:
+ if metadata["XTENSION"] == "BINTABLE":
+ table_cls = BinTableHDU
+ elif metadata["XTENSION"] == "TABLE":
+ table_cls = TableHDU
+
+ hdu = table_cls()
+ hdu.header.update(metadata)
+
+ rows = metadata.get("NAXIS2", conf.default_tbl_length)
+ shape = (rows, )
+ data = np.zeros(shape, dtype=hdu.data.dtype)
+ else:
+ hdu = table_cls()
+ shape = (conf.default_tbl_length, )
+ data = np.zeros(shape, dtype=conf.default_tbl_dtype)
+
+ return data
+
+ @classmethod
+ def from_hdu(cls, hdu, config=None, **kwargs):
+ if isinstance(hdu, (PrimaryHDU, CompImageHDU, ImageHDU)):
+ return cls(base=cls.gen_image(hdu), config=config, **kwargs)
+ elif isinstance(hdu, (TableHDU, BinTableHDU)):
+ return cls(base=cls.gen_table(hdu), config=config, **kwargs)
+ else:
+ raise TypeError(f"Expected an HDU, got {type(hdu)} instead.")
+
+ @classmethod
+ def from_header(cls, kind, header, config=None, **kwargs):
+ if kind.lower() == "image":
+ return cls(base=cls.gen_image(header), config=config, **kwargs)
+ elif kind.lower() == "table":
+ return cls(base=cls.gen_table(header), config=config, **kwargs)
+ else:
+ raise TypeError(f"Expected an 'image' or 'table', got {kind} instead.")
+
+ @classmethod
+ def zeros(cls, shape, dtype, config=None, **kwargs):
+ return cls(np.zeros(shape, dtype), config, **kwargs)
+
+ def mock(self, n=1, **kwargs):
+ if self.base is None:
+ raise ValueError(
+ "Expected a DataFactory that has a base, but none was set. "
+ "Use `zeros` or `from_hdu` to construct this object correctly."
+ )
- Returns
- -------
- data : `no.array`
- Mocked data array
- """
- self.counter += 1
if self.config.return_copy:
- return self.base.copy()
- return self.base
+ base = np.repeat(self.base[np.newaxis, ], (n,), axis=0)
+ else:
+ base = np.broadcast_to(self.base, (n, *self.shape))
+ base.flags.writeable = self.config.writeable
+
+ return base
class SimpleVarianceConfig(DataFactoryConfig):
@@ -207,7 +310,15 @@ def mock(self, images=None):
class SimpleMaskConfig(DataFactoryConfig):
"""Simple mask configuration."""
- pass
+ dtype = np.float32
+
+ threshold = 1e-05
+
+ shape = (5, 5)
+ padding = 0
+ bad_columns = []
+ patches = []
+
class SimpleMask(DataFactory):
"""Simple mask factory.
@@ -223,11 +334,19 @@ class SimpleMask(DataFactory):
"""
default_config = SimpleMaskConfig
- def __init__(self, mask):
- super().__init__(base=mask)
+
+ def __init__(self, mask, config=None, **kwargs):
+ super().__init__(base=mask, config=config, **kwargs)
@classmethod
- def from_params(cls, shape, padding=0, bad_columns=[], patches=[]):
+ def from_image(cls, image, config=None, **kwargs):
+ config = cls.default_config(config=config, **kwargs, method="subset")
+ mask = image.copy()
+ mask[image > config.threshold] = 1
+ return cls(mask)
+
+ @classmethod
+ def from_params(cls, config=None, **kwargs):
"""Create a mask by adding a padding around the edges of the array with
the given dimensions and mask out bad columns.
@@ -273,19 +392,22 @@ def from_params(cls, shape, padding=0, bad_columns=[], patches=[]):
[1., 0., 1., 1., 0., 0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
"""
- mask = np.zeros(shape)
+ config = cls.default_config(config=config, **kwargs, method="subset")
+ mask = np.zeros(config.shape, dtype=config.dtype)
+
+ shape, padding = config.shape, config.padding
# padding
mask[:padding] = 1
- mask[shape[0] - padding :] = 1
+ mask[shape[0] - padding:] = 1
mask[:, :padding] = 1
- mask[: shape[1] - padding :] = 1
+ mask[: shape[1] - padding:] = 1
# bad columns
- for col in bad_columns:
+ for col in config.bad_columns:
mask[:, col] = 1
- for patch, value in patches:
+ for patch, value in config.patches:
if isinstance(patch, tuple):
mask[*patch] = 1
elif isinstance(slice):
@@ -293,83 +415,7 @@ def from_params(cls, shape, padding=0, bad_columns=[], patches=[]):
else:
raise ValueError(f"Expected a tuple (x, y), (slice, slice) or slice, got {patch} instead.")
- return cls(mask)
-
-
-class ZeroedDataConfig(DataFactoryConfig):
- """Zeroed data config."""
- return_copy = True
-
- shape = (5, 5)
- """Default image size."""
-
- # https://archive.stsci.edu/fits/fits_standard/node39.html#s:man
- bitpix_type_map = {
- # or char
- 8: int,
- # actually no idea what dtype, or C type for that matter,
- # are used to represent these values. But default Headers return them
- 16: np.float16,
- 32: np.float32,
- 64: np.float64,
- # classic IEEE float and double
- -32: np.float32,
- -64: np.float64,
- }
- """Map between FITS header BITPIX keyword value and NumPy return type."""
-
-
-class ZeroedData(DataFactory):
- """Data factory that creates zeroed data arrays from header definitions.
-
- A convenience factory able to generate the correct size image and bintable
- data from a header.
-
- Parameters
- ----------
- base : `np.array`
- Static data shared by all mocked instances.
- config : `DataFactoryConfig`
- Configuration of the data factory.
- **kwargs :
- Additional keyword arguments are applied as config
- overrides.
- """
- default_config = ZeroedDataConfig
-
- def __init__(self, base=None, config=None, **kwargs):
- super().__init__(base, config, **kwargs)
-
- def mock_image_data(self, hdu):
- cols = hdu.header.get("NAXIS1", False)
- rows = hdu.header.get("NAXIS2", False)
- shape = (cols, rows) if all((cols, rows)) else self.config.shape
-
- data = np.zeros(
- shape,
- dtype=self.config.bitpix_type_map[hdu.header["BITPIX"]]
- )
- return data
-
- def mock_table_data(self, hdu):
- # interestingly, table HDUs create their own empty
- # tables from headers, but image HDUs do not, this
- # is why hdu.data exists and has a valid dtype
- nrows = hdu.header["TFIELDS"]
- return np.zeros((nrows,), dtype=hdu.data.dtype)
-
- def mock(self, hdu=None, **kwargs):
- # two cases: a) static data shared by all instances created by the
- # factory set at init time or b) dynamic data generated from a header
- # specification at call-time.
- if self.base is not None:
- return super().mock(hdu=hdu, **kwargs)
- if isinstance(hdu, (PrimaryHDU, CompImageHDU, ImageHDU)):
- return self.mock_image_data(hdu)
- elif isinstance(hdu, (TableHDU, BinTableHDU)):
- return self.mock_table_data(hdu)
- else:
- raise TypeError(f"Expected an HDU, got {type(hdu)} instead.")
+ return cls(mask, config=config)
class SimpleImageConfig(DataFactoryConfig):
@@ -386,7 +432,7 @@ class SimpleImageConfig(DataFactoryConfig):
seed = None
"""Seed of the random number generator used to generate noise."""
- noise = 0
+ noise = 10
"""Mean of the standard Gaussian distribution of the noise."""
noise_std = 1.0
@@ -422,8 +468,10 @@ class SimpleImage(DataFactory):
"""
default_config = SimpleImageConfig
- def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs):
- super().__init__(image, config, **kwargs)
+ def __init__(self, image=None, src_cat=None, obj_cat=None,
+ config=None, **kwargs):
+ self.config = self.default_config(config=config, **kwargs)
+ super().__init__(image, self.config, **kwargs)
if image is None:
image = np.zeros(self.config.shape, dtype=np.float32)
@@ -431,9 +479,15 @@ def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs
image = image
self.config.shape = image.shape
+ # Astropy throws a strange ValueError instead of reporting a non-writeable
+ # array, This must be a bug TODO: report. It's not safe to edit a
+ # non-writeable array and then revert writeability so make a copy.
self.src_cat = src_cat
if self.src_cat is not None:
- add_model_objects(image, src_cat.table, self.config.model)
+ image = image if image.flags.writeable else image.copy()
+ add_model_objects(image, src_cat.table,
+ self.config.model(x_stddev=1, y_stddev=1))
+ image.flags.writeable = self.config.writeable
self.base = image
self._base_contains_data = image.sum() != 0
@@ -456,7 +510,7 @@ def add_noise(cls, images, config):
rng.standard_normal(size=shape, dtype=images.dtype, out=images)
# There's a lot of multiplications that happen, skip if possible
- if self.config.noise_std != 1.0:
+ if config.noise_std != 1.0:
images *= config.noise_std
images += config.noise
@@ -476,16 +530,21 @@ def mock(self, n=1, obj_cats=None, **kwargs):
shape = (n, *self.config.shape)
images = np.zeros(shape, dtype=np.float32)
- if self.config.noise != 0:
- images = self.add_noise(n, images, self.config)
+ if self.config.add_noise:
+ images = self.add_noise(images=images, config=self.config)
- # but if base has no data (no sources, bad cols etc) skip
+ # if base has no data (no sources, bad cols etc) skip
if self._base_contains_data:
images += self.base
- # same with moving objects
+ # same with moving objects, each image has to have a new realization of
+ # a catalog in which moving objects have different coordinate. This way
+ # any trajectory can be mocked. When we have only 1 mocked image though
+ # zip will attempt to iterate over the next available dimension, and
+ # that's rows of the image and the table - we don't want that.
if obj_cats is not None:
- for i, (img, cat) in enumerate(zip(images, obj_cats)):
+ pairs = [(images[0], obj_cats),] if n == 1 else zip(images, obj_cats)
+ for i, (img, cat) in enumerate(pairs):
add_model_objects(
img,
cat,
@@ -526,6 +585,9 @@ class SimulatedImageConfig(DataFactoryConfig):
# not sure this is a smart idea to put here
rng = np.random.default_rng()
+ seed = None
+ """Random number generator seed shared by all number generators."""
+
# image properties
shape = (100, 100)
"""Dimensions of the created images."""
@@ -546,7 +608,7 @@ class SimulatedImageConfig(DataFactoryConfig):
bias = 0.0
"""Bias in counts."""
- add_bad_cols = False
+ add_bad_columns = False
"""Add bad columns to the image."""
bad_cols_method = "random"
@@ -559,12 +621,15 @@ class SimulatedImageConfig(DataFactoryConfig):
n_bad_cols = 5
"""When bad columns method is random, sets the number of bad columns."""
- bad_cols_seed = 123
+ bad_cols_seed = seed
"""Seed for the bad columns random number generator."""
- bad_col_pattern_offset = 0.1
+ bad_col_offset = 100
+ """Bad column signal offset (in counts) with regards to the baseline noise."""
+
+ bad_col_pattern_offset = 10
"""Random-looking noise variation along the length of the bad columns is
- offset from the base column counts by this amount multiplied by bias."""
+ offset from the mean bad column counts by this amount."""
dark_current_gen = rng.poisson
"""Dark current follows a Poisson distribution."""
@@ -582,7 +647,7 @@ class SimulatedImageConfig(DataFactoryConfig):
hot_pix_locs = []
"""A `list[tuple]` of hot pixel indices."""
- hot_pix_seed = 321
+ hot_pix_seed = seed
"""Seed for hot pixel random number generator."""
n_hot_pix = 10
@@ -644,25 +709,29 @@ def add_bad_cols(cls, image, config):
image : `np.array`
Image.
"""
- if not config.add_bad_cols:
+ if not config.add_bad_columns:
return image
+ shape = image.shape
+ rng = np.random.RandomState(seed=config.bad_cols_seed)
if config.bad_cols_method == "random":
- rng = np.random.RandomState(seed=self.bad_cols_seed)
bad_cols = rng.randint(0, shape[1], size=config.n_bad_cols)
elif config.bad_col_locs:
bad_cols = config.bad_col_locs
else:
- raise ConfigurationError("Bad columns method is not 'random', but `bad_col_locs` contains no bad column indices.")
+ raise ConfigurationError(
+ "Bad columns method is not 'random', but `bad_col_locs` "
+ "contains no column indices."
+ )
- self.col_pattern = rng.randint(
+ col_pattern = rng.randint(
low=0,
- high=int(config.bad_col_pattern_offset * config.bias),
+ high=int(config.bad_col_pattern_offset),
size=shape[0]
)
- for col in columns:
- image[:, col] = config.bias + col_pattern
+ for col in bad_cols:
+ image[:, col] += col_pattern + config.bad_col_offset
return image
@@ -688,23 +757,27 @@ def add_hot_pixels(cls, image, config):
if not config.add_hot_pixels:
return image
+ shape = image.shape
if config.hot_pix_method == "random":
rng = np.random.RandomState(seed=config.hot_pix_seed)
x = rng.randint(0, shape[1], size=config.n_hot_pix)
y = rng.randint(0, shape[0], size=config.n_hot_pix)
- hot_pixels = np.column_stack(x, y)
+ hot_pixels = np.column_stack([x, y])
elif config.hot_pix_locs:
hot_pixels = pixels
else:
- raise ConfigurationError("Hot pixels method is not 'random', but `hot_pix_locs` contains no (col, row) location indices of hot pixels.")
+ raise ConfigurationError(
+ "Hot pixels method is not 'random', but `hot_pix_locs` contains "
+ "no (col, row) location indices of hot pixels."
+ )
for pix in hot_pixels:
- image[*pix] += offset
+ image[*pix] += config.hot_pix_offset
return image
@classmethod
- def add_noise(cls, n, images, config):
+ def add_noise(cls, images, config):
"""Adds read noise (gaussian), dark noise (poissonian) and sky
background (poissonian) noise to the image.
@@ -723,7 +796,7 @@ def add_noise(cls, n, images, config):
shape = images.shape
# add read noise
- images += self.read_noise_gen(
+ images += config.read_noise_gen(
scale=config.read_noise / config.gain,
size=shape
)
@@ -733,7 +806,7 @@ def add_noise(cls, n, images, config):
images += config.dark_current_gen(current, size=shape)
# add sky counts
- images += self.sky_count_gen(
+ images += config.sky_count_gen(
lam=config.sky_level * config.gain,
size=shape
) / config.gain
@@ -761,14 +834,20 @@ def gen_base_image(cls, config=None, src_cat=None):
# empty image
base = np.zeros(config.shape, dtype=np.float32)
base += config.bias
- base = cls.add_bad_cols(base, config)
base = cls.add_hot_pixels(base, config)
+ base = cls.add_bad_cols(base, config)
if src_cat is not None:
- add_model_objects(base, src_cat.table, config.model)
+ add_model_objects(base, src_cat.table,
+ config.model(x_stddev=1, y_stddev=1))
return base
- def __init__(self, image=None, config=None, src_cat=None, obj_cat=None):
- conf = self.default_config(config)
- base = self.gen_base_image(conf)
- super().__init__(image=base, config=conf, src_cat=src_cat, obj_cat=obj_cat)
+ def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs):
+ conf = self.default_config(config=config, **kwargs)
+ # static objects are added in SimpleImage init
+ super().__init__(
+ image=self.gen_base_image(conf),
+ config=conf,
+ src_cat=src_cat,
+ obj_cat=obj_cat
+ )
diff --git a/src/kbmod/mocking/headers.py b/src/kbmod/mocking/headers.py
index 033ecee1f..3947d978c 100644
--- a/src/kbmod/mocking/headers.py
+++ b/src/kbmod/mocking/headers.py
@@ -5,16 +5,18 @@
from astropy.io.fits import Header
from .utils import header_archive_to_table
+from .config import Config
__all__ = [
"HeaderFactory",
+ "HeaderFactoryConfig",
"ArchivedHeader",
]
-class HeaderFactory:
- base_primary = {
+class HeaderFactoryConfig(Config):
+ primary_template = {
"EXTNAME": "PRIMARY",
"NAXIS": 0,
"BITPIX": 8,
@@ -26,9 +28,16 @@ class HeaderFactory:
"OBSERVAT": "CTIO"
}
- base_ext = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "BITPIX": 32}
+ ext_template = {
+ "NAXIS": 2,
+ "NAXIS1": 2048,
+ "NAXIS2": 4096,
+ "CRPIX1": 1024,
+ "CPRIX2": 2048,
+ "BITPIX": 32
+ }
- base_wcs = {
+ wcs_template = {
"ctype": ["RA---TAN", "DEC--TAN"],
"crval": [351, -5],
"cunit": ["deg", "deg"],
@@ -36,6 +45,19 @@ class HeaderFactory:
"cd": [[-1.44e-07, 7.32e-05], [7.32e-05, 1.44e-05]],
}
+ def __init__(self, config=None, method="default", shape=None, **kwargs):
+ super().__init__(config=config, method=method, **kwargs)
+
+ if shape is not None:
+ self.ext_template["NAXIS1"] = shape[1]
+ self.ext_template["NAXIS2"] = shape[0]
+ self.ext_template["CRPIX1"] = shape[1]//2
+ self.ext_template["CRPIX2"] = shape[0]//2
+
+
+class HeaderFactory:
+ default_config = HeaderFactoryConfig
+
def __validate_mutables(self):
# !xor
if bool(self.mutables) != bool(self.callbacks):
@@ -59,7 +81,10 @@ def __validate_mutables(self):
"provide the required metadata keys."
)
- def __init__(self, metadata=None, mutables=None, callbacks=None):
+ def __init__(self, metadata=None, mutables=None, callbacks=None,
+ config=None, **kwargs):
+ self.config = self.default_config(config=config, **kwargs)
+
cards = [] if metadata is None else metadata
self.header = Header(cards=cards)
@@ -67,50 +92,61 @@ def __init__(self, metadata=None, mutables=None, callbacks=None):
self.callbacks = callbacks
self.__validate_mutables()
- def mock(self, hdu=None, **kwargs):
+ self.is_dynamic = self.mutables is None
+ self.counter = 0
+
+ def get(self):
if self.mutables is not None:
for i, mutable in enumerate(self.mutables):
self.header[mutable] = self.callbacks[i](self.header[mutable])
-
return self.header
+ def mock(self, **kwargs):
+ self.counter += 1
+ return self.get(**kwargs)
+
@classmethod
- def gen_wcs(cls, crval, metadata=None):
- metadata = cls.base_wcs if metadata is None else metadata
+ def gen_wcs(cls, metadata):
wcs = WCS(naxis=2)
for k, v in metadata.items():
setattr(wcs.wcs, k, v)
return wcs.to_header()
@classmethod
- def gen_header(cls, base, metadata, extend, add_wcs):
- header = Header(base) if extend else Header()
- header.update(metadata)
+ def gen_header(cls, base, overrides, wcs_base=None):
+ header = Header() if base is None else Header(base)
+ header.update(overrides)
- if add_wcs:
+ if wcs_base is not None:
naxis1 = header.get("NAXIS1", False)
naxis2 = header.get("NAXIS2", False)
if not all((naxis1, naxis2)):
- raise ValueError("Adding a WCS to the header requires NAXIS1 and NAXIS2 keys.")
- crpix = [naxis1 / 2.0, naxis2 / 2.0]
- header.update(cls.gen_wcs(crpix))
+ raise ValueError("Adding a WCS to the header requires "
+ "NAXIS1 and NAXIS2 keys.")
+ header.update(cls.gen_wcs(wcs_base))
return header
@classmethod
- def from_base_primary(cls, metadata=None, mutables=None, callbacks=None, extend_base=True, add_wcs=False):
- hdr = cls.gen_header(cls.base_primary, metadata, extend_base, add_wcs)
- return cls(hdr, mutables, callbacks)
+ def from_primary_template(cls, overrides=None, mutables=None, callbacks=None,
+ config=None):
+ config = cls.default_config(config=config)
+ hdr = cls.gen_header(base=config.primary_template, overrides=overrides)
+ return cls(hdr, mutables, callbacks, config=config)
@classmethod
- def from_base_ext(
- cls, metadata=None, mutables=None, callbacks=None, extend_base=True, add_wcs=True, dims=None
- ):
- hdr = cls.gen_header(cls.base_ext, metadata, extend_base, add_wcs)
- return cls(hdr, mutables, callbacks)
-
-
-class ArchivedHeader(HeaderFactory):
+ def from_ext_template(cls, overrides=None, mutables=None, callbacks=None,
+ wcs=None, config=None):
+ config = cls.default_config(config=config)
+ hdr = cls.gen_header(
+ base=config.ext_template,
+ overrides=overrides,
+ wcs_base=config.wcs_template if wcs is None else wcs
+ )
+ return cls(hdr, mutables, callbacks, config=config)
+
+
+class ArchivedHeaderConfig(HeaderFactoryConfig):
# will almost never be anything else. Rather, it would be a miracle if it
# were something else, since FITS standard shouldn't allow it. Further
# casting by some packages will always be casting implemented in terms of
@@ -123,31 +159,61 @@ class ArchivedHeader(HeaderFactory):
}
"""Map between type names and types themselves."""
- def __init__(self, archive_name, fname, compression="bz2", format="ascii.ecsv"):
- self.table = header_archive_to_table(archive_name, fname, compression, format)
+ compression = "bz2"
+
+ format = "ascii.ecsv"
+
+ n_hdrs_per_hdu = 1
+
+class ArchivedHeader(HeaderFactory):
+ default_config = ArchivedHeaderConfig
+
+ def __init__(self, archive_name, fname, config=None, **kwargs):
+ super().__init__(config, **kwargs)
+ self.table = header_archive_to_table(
+ archive_name, fname, self.config.compression, self.config.format
+ )
# Create HDU groups for easier iteration
- self.table = self.table.group_by(["filename", "hdu"])
+ self.table = self.table.group_by("filename")
self.n_hdus = len(self.table)
- # Internal counter for the current fits index,
- # so that we may auto-increment it and avoid returning
- # the same data all the time.
- self._current = 0
-
def lexical_cast(self, value, format):
"""Cast str literal of a type to the type itself. Supports just
the builtin Python types.
"""
- if format in self.lexical_type_map:
- return self.lexical_type_map[format](value)
+ if format in self.config.lexical_type_map:
+ return self.config.lexical_type_map[format](value)
return value
- def mock(self, hdu=None):
+ def get_item(self, group_idx, hdr_idx):
header = Header()
- warnings.filterwarnings("ignore", category=AstropyUserWarning)
- for k, v, f in self.table.groups[self._current]["keyword", "value", "format"]:
+ # this is equivalent to one hdulist worth of headers
+ group = self.table.groups[group_idx]
+ # this is equivalent to one HDU's header
+ subgroup = group.group_by("hdu")
+ for k, v, f in subgroup.groups[hdr_idx]["keyword", "value", "format"]:
header[k] = self.lexical_cast(v, f)
warnings.resetwarnings()
- self._current += 1
return header
+
+ def get(self, group_idx):
+ headers = []
+ # this is a bit repetitive but it saves recreating
+ # groups for one HDUL-equivalent many times
+ group = self.table.groups[group_idx]
+ subgroup = group.group_by("hdu")
+ headers = []
+ for subgroup in subgroup.groups:
+ header = Header()
+ for k, v, f in subgroup["keyword", "value", "format"]:
+ header[k] = self.lexical_cast(v, f)
+ headers.append(header)
+ return headers
+
+ def mock(self, n=1):
+ res = []
+ for _ in range(n):
+ res.append(self.get(self.counter))
+ self.counter += 1
+ return res