Skip to content

Commit

Permalink
Merge pull request #287 from lsst/tickets/DM-41955
Browse files Browse the repository at this point in the history
DM-41955: Remove unphysical diaSources from the output of detectAndMeasure
  • Loading branch information
isullivan authored Jan 4, 2024
2 parents 7dcc1e6 + 3c9f4bd commit 8096f4f
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 23 deletions.
51 changes: 44 additions & 7 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
target=SkyObjectsTask,
doc="Generate sky sources",
)
badSourceFlags = lsst.pex.config.ListField(
dtype=str,
doc="Sources with any of these flags set are removed before writing the output catalog.",
default=("base_PixelFlags_flag_offimage",
),
)
idGenerator = DetectorVisitIdGeneratorConfig.make_field()

def setDefaults(self):
Expand Down Expand Up @@ -222,6 +228,10 @@ def __init__(self, **kwargs):
self.makeSubtask("skySources")
self.skySourceKey = self.schema.addField("sky_source", type="Flag", doc="Sky objects.")

# Check that the schema and config are consistent
for flag in self.config.badSourceFlags:
if flag not in self.schema:
raise pipeBase.InvalidQuantumError("Field %s not in schema" % flag)
# initialize InitOutputs
self.outputSchema = afwTable.SourceCatalog(self.schema)
self.outputSchema.getTable().setMetadata(self.algMetadata)
Expand Down Expand Up @@ -353,17 +363,18 @@ def processResults(self, science, matchedTemplate, difference, sources, table,
fpSet = positiveFootprints
fpSet.merge(negativeFootprints, self.config.growFootprint,
self.config.growFootprint, False)
diaSources = afwTable.SourceCatalog(table)
fpSet.makeSources(diaSources)
self.log.info("Merging detections into %d sources", len(diaSources))
initialDiaSources = afwTable.SourceCatalog(table)
fpSet.makeSources(initialDiaSources)
self.log.info("Merging detections into %d sources", len(initialDiaSources))
else:
diaSources = sources
self.metadata.add("nMergedDiaSources", len(diaSources))
initialDiaSources = sources
self.metadata.add("nMergedDiaSources", len(initialDiaSources))

if self.config.doSkySources:
self.addSkySources(diaSources, difference.mask, difference.info.id)
self.addSkySources(initialDiaSources, difference.mask, difference.info.id)

self.measureDiaSources(diaSources, science, difference, matchedTemplate)
self.measureDiaSources(initialDiaSources, science, difference, matchedTemplate)
diaSources = self._removeBadSources(initialDiaSources)

if self.config.doForcedMeasurement:
self.measureForcedSources(diaSources, science, difference.getWcs())
Expand All @@ -376,6 +387,32 @@ def processResults(self, science, matchedTemplate, difference, sources, table,

return measurementResults

def _removeBadSources(self, diaSources):
"""Remove bad diaSources from the catalog.
Parameters
----------
diaSources : `lsst.afw.table.SourceCatalog`
The catalog of detected sources.
Returns
-------
diaSources : `lsst.afw.table.SourceCatalog`
The updated catalog of detected sources, with any source that has a
flag in ``config.badSourceFlags`` set removed.
"""
nBadTotal = 0
selector = np.ones(len(diaSources), dtype=bool)
for flag in self.config.badSourceFlags:
flags = diaSources[flag]
nBad = np.count_nonzero(flags)
if nBad > 0:
self.log.info("Found and removed %d unphysical sources with flag %s.", nBad, flag)
selector &= ~flags
nBadTotal += nBad
self.metadata.add("nRemovedBadFlaggedSources", nBadTotal)
return diaSources[selector].copy(deep=True)

def addSkySources(self, diaSources, mask, seed):
"""Add sources in empty regions of the difference image
for measuring the background.
Expand Down
90 changes: 74 additions & 16 deletions tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import lsst.geom
from lsst.ip.diffim import detectAndMeasure, subtractImages
from lsst.ip.diffim.utils import makeTestImage
from lsst.pipe.base import InvalidQuantumError
import lsst.utils.tests


Expand Down Expand Up @@ -92,33 +93,28 @@ def _check_values(self, values, minValue=None, maxValue=None):
if maxValue is not None:
self.assertTrue(np.all(values <= maxValue))

def _setup_detection(self, doApCorr=False, doMerge=False,
doSkySources=False, doForcedMeasurement=False):
def _setup_detection(self, doSkySources=False, nSkySources=5, **kwargs):
"""Setup and configure the detection and measurement PipelineTask.
Parameters
----------
doApCorr : `bool`, optional
Run subtask to apply aperture corrections.
doMerge : `bool`, optional
Merge positive and negative diaSources.
doSkySources : `bool`, optional
Generate sky sources.
doForcedMeasurement : `bool`, optional
Force photometer diaSource locations on PVI.
nSkySources : `int`, optional
The number of sky sources to add in isolated background regions.
**kwargs
Any additional config parameters to set.
Returns
-------
`lsst.pipe.base.PipelineTask`
The configured Task to use for detection and measurement.
"""
config = self.detectionTask.ConfigClass()
config.doApCorr = doApCorr
config.doMerge = doMerge
config.doSkySources = doSkySources
config.doForcedMeasurement = doForcedMeasurement
if doSkySources:
config.skySources.nSources = 5
config.skySources.nSources = nSkySources
config.update(**kwargs)
return self.detectionTask(config=config)


Expand Down Expand Up @@ -189,6 +185,68 @@ def test_measurements_finite(self):
self._check_values(output.diaSources.getY(), minValue=0, maxValue=ySize)
self._check_values(output.diaSources.getPsfInstFlux())

def test_raise_config_schema_mismatch(self):
"""Check that sources with specified flags are removed from the catalog.
"""
# Configure the detection Task, and and set a config that is not in the schema
with self.assertRaises(InvalidQuantumError):
self._setup_detection(badSourceFlags=["Bogus_flag_42"])

def test_remove_unphysical(self):
"""Check that sources with specified flags are removed from the catalog.
"""
# Set up the simulated images
noiseLevel = 1.
staticSeed = 1
xSize = 256
ySize = 256
kwargs = {"psfSize": 2.4, "xSize": xSize, "ySize": ySize}
science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6,
nSrc=1, **kwargs)
matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7,
nSrc=1, **kwargs)
difference = science.clone()
bbox = difference.getBBox()
difference.maskedImage -= matchedTemplate.maskedImage

# Configure the detection Task, and do not remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=[])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrcNoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBadNoRemove = np.count_nonzero(badDiaSrcNoRemove)
# Verify that unphysical sources exist
self.assertGreater(nBadNoRemove, 0)

# Configure the detection Task, and remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=["base_PixelFlags_flag_offimage", ])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrcDoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBadDoRemove = np.count_nonzero(badDiaSrcDoRemove)
# Verify that all sources are physical
self.assertEqual(nBadDoRemove, 0)
# Set a few centroids outside the image bounding box
nSetBad = 5
for src in diaSources[0: nSetBad]:
src["slot_Centroid_x"] += xSize
src["slot_Centroid_y"] += ySize
src["base_PixelFlags_flag_offimage"] = True
# Verify that these sources are outside the image
badDiaSrc = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBad = np.count_nonzero(badDiaSrc)
self.assertEqual(nBad, nSetBad)
diaSourcesNoBad = detectionTask._removeBadSources(diaSources)
badDiaSrcNoBad = ~bbox.contains(diaSourcesNoBad.getX(), diaSourcesNoBad.getY())

# Verify that no sources outside the image bounding box remain
self.assertEqual(np.count_nonzero(badDiaSrcNoBad), 0)
self.assertEqual(len(diaSourcesNoBad), len(diaSources) - nSetBad)

def test_detect_transients(self):
"""Run detection on a difference image containing transients.
"""
Expand All @@ -202,7 +260,7 @@ def test_detect_transients(self):
matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs)

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)
kwargs["seed"] = transientSeed
kwargs["nSrc"] = 10
kwargs["fluxLevel"] = 1000
Expand Down Expand Up @@ -254,7 +312,7 @@ def test_detect_dipoles(self):
difference.maskedImage -= matchedTemplate.maskedImage[science.getBBox()]

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
output = detectionTask.run(science, matchedTemplate, difference)
Expand Down Expand Up @@ -462,7 +520,7 @@ def test_detect_transients(self):
subtractTask = subtractImages.AlardLuptonPreconvolveSubtractTask()

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)
kwargs["seed"] = transientSeed
kwargs["nSrc"] = 10
kwargs["fluxLevel"] = 1000
Expand Down Expand Up @@ -532,7 +590,7 @@ def test_detect_dipoles(self):
score = subtractTask._convolveExposure(difference, scienceKernel, subtractTask.convolutionControl)

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
output = detectionTask.run(science, matchedTemplate, difference, score)
Expand Down

0 comments on commit 8096f4f

Please sign in to comment.