Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-41955: Remove unphysical diaSources from the output of detectAndMeasure #287

Merged
merged 6 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions python/lsst/ip/diffim/detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
target=SkyObjectsTask,
doc="Generate sky sources",
)
badSourceFlags = lsst.pex.config.ListField(
dtype=str,
doc="Sources with any of these flags set are removed before writing the output catalog.",
default=("base_PixelFlags_flag_offimage",
),
)
idGenerator = DetectorVisitIdGeneratorConfig.make_field()

def setDefaults(self):
Expand Down Expand Up @@ -222,6 +228,10 @@ def __init__(self, **kwargs):
self.makeSubtask("skySources")
self.skySourceKey = self.schema.addField("sky_source", type="Flag", doc="Sky objects.")

# Check that the schema and config are consistent
for flag in self.config.badSourceFlags:
if flag not in self.schema:
raise pipeBase.InvalidQuantumError("Field %s not in schema" % flag)
# initialize InitOutputs
self.outputSchema = afwTable.SourceCatalog(self.schema)
self.outputSchema.getTable().setMetadata(self.algMetadata)
Expand Down Expand Up @@ -353,17 +363,18 @@ def processResults(self, science, matchedTemplate, difference, sources, table,
fpSet = positiveFootprints
fpSet.merge(negativeFootprints, self.config.growFootprint,
self.config.growFootprint, False)
diaSources = afwTable.SourceCatalog(table)
fpSet.makeSources(diaSources)
self.log.info("Merging detections into %d sources", len(diaSources))
initialDiaSources = afwTable.SourceCatalog(table)
fpSet.makeSources(initialDiaSources)
self.log.info("Merging detections into %d sources", len(initialDiaSources))
else:
diaSources = sources
self.metadata.add("nMergedDiaSources", len(diaSources))
initialDiaSources = sources
self.metadata.add("nMergedDiaSources", len(initialDiaSources))

if self.config.doSkySources:
self.addSkySources(diaSources, difference.mask, difference.info.id)
self.addSkySources(initialDiaSources, difference.mask, difference.info.id)

self.measureDiaSources(diaSources, science, difference, matchedTemplate)
self.measureDiaSources(initialDiaSources, science, difference, matchedTemplate)
diaSources = self._removeBadSources(initialDiaSources)

if self.config.doForcedMeasurement:
self.measureForcedSources(diaSources, science, difference.getWcs())
Expand All @@ -376,6 +387,32 @@ def processResults(self, science, matchedTemplate, difference, sources, table,

return measurementResults

def _removeBadSources(self, diaSources):
"""Remove bad diaSources from the catalog.

Parameters
----------
diaSources : `lsst.afw.table.SourceCatalog`
The catalog of detected sources.

Returns
-------
diaSources : `lsst.afw.table.SourceCatalog`
The updated catalog of detected sources, with any source that has a
flag in ``config.badSourceFlags`` set removed.
"""
nBadTotal = 0
selector = np.ones(len(diaSources), dtype=bool)
for flag in self.config.badSourceFlags:
flags = diaSources[flag]
nBad = np.count_nonzero(flags)
if nBad > 0:
self.log.info("Found and removed %d unphysical sources with flag %s.", nBad, flag)
selector &= ~flags
nBadTotal += nBad
self.metadata.add("nRemovedBadFlaggedSources", nBadTotal)
return diaSources[selector].copy(deep=True)

def addSkySources(self, diaSources, mask, seed):
"""Add sources in empty regions of the difference image
for measuring the background.
Expand Down
90 changes: 74 additions & 16 deletions tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import lsst.geom
from lsst.ip.diffim import detectAndMeasure, subtractImages
from lsst.ip.diffim.utils import makeTestImage
from lsst.pipe.base import InvalidQuantumError
import lsst.utils.tests


Expand Down Expand Up @@ -92,33 +93,28 @@ def _check_values(self, values, minValue=None, maxValue=None):
if maxValue is not None:
self.assertTrue(np.all(values <= maxValue))

def _setup_detection(self, doApCorr=False, doMerge=False,
doSkySources=False, doForcedMeasurement=False):
def _setup_detection(self, doSkySources=False, nSkySources=5, **kwargs):
"""Setup and configure the detection and measurement PipelineTask.

Parameters
----------
doApCorr : `bool`, optional
Run subtask to apply aperture corrections.
doMerge : `bool`, optional
Merge positive and negative diaSources.
doSkySources : `bool`, optional
Generate sky sources.
doForcedMeasurement : `bool`, optional
Force photometer diaSource locations on PVI.
nSkySources : `int`, optional
The number of sky sources to add in isolated background regions.
**kwargs
Any additional config parameters to set.

Returns
-------
`lsst.pipe.base.PipelineTask`
The configured Task to use for detection and measurement.
"""
config = self.detectionTask.ConfigClass()
config.doApCorr = doApCorr
config.doMerge = doMerge
config.doSkySources = doSkySources
config.doForcedMeasurement = doForcedMeasurement
if doSkySources:
config.skySources.nSources = 5
config.skySources.nSources = nSkySources
config.update(**kwargs)
return self.detectionTask(config=config)


Expand Down Expand Up @@ -189,6 +185,68 @@ def test_measurements_finite(self):
self._check_values(output.diaSources.getY(), minValue=0, maxValue=ySize)
self._check_values(output.diaSources.getPsfInstFlux())

def test_raise_config_schema_mismatch(self):
"""Check that sources with specified flags are removed from the catalog.
"""
parejkoj marked this conversation as resolved.
Show resolved Hide resolved
# Configure the detection Task, and and set a config that is not in the schema
with self.assertRaises(InvalidQuantumError):
self._setup_detection(badSourceFlags=["Bogus_flag_42"])

parejkoj marked this conversation as resolved.
Show resolved Hide resolved
def test_remove_unphysical(self):
"""Check that sources with specified flags are removed from the catalog.
"""
# Set up the simulated images
noiseLevel = 1.
staticSeed = 1
xSize = 256
ySize = 256
kwargs = {"psfSize": 2.4, "xSize": xSize, "ySize": ySize}
science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6,
nSrc=1, **kwargs)
matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7,
nSrc=1, **kwargs)
difference = science.clone()
bbox = difference.getBBox()
difference.maskedImage -= matchedTemplate.maskedImage

# Configure the detection Task, and do not remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=[])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrcNoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBadNoRemove = np.count_nonzero(badDiaSrcNoRemove)
# Verify that unphysical sources exist
self.assertGreater(nBadNoRemove, 0)

# Configure the detection Task, and remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=["base_PixelFlags_flag_offimage", ])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrcDoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBadDoRemove = np.count_nonzero(badDiaSrcDoRemove)
# Verify that all sources are physical
self.assertEqual(nBadDoRemove, 0)
# Set a few centroids outside the image bounding box
nSetBad = 5
for src in diaSources[0: nSetBad]:
src["slot_Centroid_x"] += xSize
src["slot_Centroid_y"] += ySize
src["base_PixelFlags_flag_offimage"] = True
# Verify that these sources are outside the image
badDiaSrc = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBad = np.count_nonzero(badDiaSrc)
self.assertEqual(nBad, nSetBad)
diaSourcesNoBad = detectionTask._removeBadSources(diaSources)
badDiaSrcNoBad = ~bbox.contains(diaSourcesNoBad.getX(), diaSourcesNoBad.getY())

# Verify that no sources outside the image bounding box remain
self.assertEqual(np.count_nonzero(badDiaSrcNoBad), 0)
self.assertEqual(len(diaSourcesNoBad), len(diaSources) - nSetBad)

def test_detect_transients(self):
"""Run detection on a difference image containing transients.
"""
Expand All @@ -202,7 +260,7 @@ def test_detect_transients(self):
matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs)

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)
kwargs["seed"] = transientSeed
kwargs["nSrc"] = 10
kwargs["fluxLevel"] = 1000
Expand Down Expand Up @@ -254,7 +312,7 @@ def test_detect_dipoles(self):
difference.maskedImage -= matchedTemplate.maskedImage[science.getBBox()]

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
output = detectionTask.run(science, matchedTemplate, difference)
Expand Down Expand Up @@ -462,7 +520,7 @@ def test_detect_transients(self):
subtractTask = subtractImages.AlardLuptonPreconvolveSubtractTask()

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)
kwargs["seed"] = transientSeed
kwargs["nSrc"] = 10
kwargs["fluxLevel"] = 1000
Expand Down Expand Up @@ -532,7 +590,7 @@ def test_detect_dipoles(self):
score = subtractTask._convolveExposure(difference, scienceKernel, subtractTask.convolutionControl)

# Configure the detection Task
detectionTask = self._setup_detection()
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
output = detectionTask.run(science, matchedTemplate, difference, score)
Expand Down