From 72e976157619c008d2c1ec46e26841381e001bee Mon Sep 17 00:00:00 2001 From: rai-harshit Date: Fri, 26 Jan 2024 20:16:03 -0800 Subject: [PATCH] Add test cases --- .../ap/association/filterDiaSourceCatalog.py | 9 +++-- tests/test_filterDiaSourceCatalog.py | 36 +++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/python/lsst/ap/association/filterDiaSourceCatalog.py b/python/lsst/ap/association/filterDiaSourceCatalog.py index 9d3832d9..202f616b 100644 --- a/python/lsst/ap/association/filterDiaSourceCatalog.py +++ b/python/lsst/ap/association/filterDiaSourceCatalog.py @@ -30,6 +30,7 @@ import lsst.pipe.base.connectionTypes as connTypes from lsst.meas.base import DetectorVisitIdGeneratorConfig from lsst.utils.timer import timeMethod +import numpy as np class FilterDiaSourceCatalogConnections( @@ -98,5 +99,9 @@ def run(self, diaSourceCat, ccdVisitId): results : `lsst.pipe.base.Struct` Results struct with components. """ - filteredDiaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]] - return pipeBase.Struct(filteredDiaSourceCat=filteredDiaSourceCat) + if self.config.doRemoveSkySources: + sky_source_column = diaSourceCat["sky_source"] + num_sky_sources = np.sum(sky_source_column) + diaSourceCat = diaSourceCat[~sky_source_column] + self.log.info(f"Filtered {num_sky_sources} sky sources.") + return pipeBase.Struct(filteredDiaSourceCat=diaSourceCat) diff --git a/tests/test_filterDiaSourceCatalog.py b/tests/test_filterDiaSourceCatalog.py index 11202b58..91ac3710 100644 --- a/tests/test_filterDiaSourceCatalog.py +++ b/tests/test_filterDiaSourceCatalog.py @@ -21,10 +21,42 @@ import unittest import lsst.utils.tests +import lsst.meas.base.tests as measTests +import lsst.geom as geom +from lsst.ap.association.filterDiaSourceCatalog import (FilterDiaSourceCatalogConfig, + FilterDiaSourceCatalogTask) + class TestFilterDiaSourceCatalogTask(unittest.TestCase): - pass - + + def setUp(self): + self.nSources = 10 + self.yLoc = 100 + self.expId = 4321 + self.bbox = geom.Box2I(geom.Point2I(0, 0), + geom.Extent2I(1024, 1153)) + dataset = measTests.TestDataset(self.bbox) + for srcIdx in range(self.nSources): + dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc)) + schema = dataset.makeMinimalSchema() + schema.addField("sky_source", type="Flag", doc="Sky objects.") + _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234) + self.diaSourceCat[0:5]["sky_source"] = True + self.config = FilterDiaSourceCatalogConfig() + + def test_run(self): + self.config.doRemoveSkySources = False + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, ccdVisitId=self.expId) + self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat)) + + def test_run_no_filter(self): + self.config.doRemoveSkySources = True + filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config) + result = filterDiaSourceCatalogTask.run(self.diaSourceCat, ccdVisitId=self.expId) + self.assertEqual(len(result.filteredDiaSourceCat), + len(self.diaSourceCat[~self.diaSourceCat['sky_source']])) + class MemoryTester(lsst.utils.tests.MemoryTestCase): pass