Skip to content

Commit

Permalink
Add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
rai-harshit committed Jan 27, 2024
1 parent 694a14f commit 72e9761
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
9 changes: 7 additions & 2 deletions python/lsst/ap/association/filterDiaSourceCatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import lsst.pipe.base.connectionTypes as connTypes
from lsst.meas.base import DetectorVisitIdGeneratorConfig
from lsst.utils.timer import timeMethod
import numpy as np


class FilterDiaSourceCatalogConnections(
Expand Down Expand Up @@ -98,5 +99,9 @@ def run(self, diaSourceCat, ccdVisitId):
results : `lsst.pipe.base.Struct`
Results struct with components.
"""
filteredDiaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]
return pipeBase.Struct(filteredDiaSourceCat=filteredDiaSourceCat)
if self.config.doRemoveSkySources:
sky_source_column = diaSourceCat["sky_source"]
num_sky_sources = np.sum(sky_source_column)
diaSourceCat = diaSourceCat[~sky_source_column]
self.log.info(f"Filtered {num_sky_sources} sky sources.")
return pipeBase.Struct(filteredDiaSourceCat=diaSourceCat)
36 changes: 34 additions & 2 deletions tests/test_filterDiaSourceCatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,42 @@

import unittest
import lsst.utils.tests
import lsst.meas.base.tests as measTests
import lsst.geom as geom
from lsst.ap.association.filterDiaSourceCatalog import (FilterDiaSourceCatalogConfig,
FilterDiaSourceCatalogTask)


class TestFilterDiaSourceCatalogTask(unittest.TestCase):
pass


def setUp(self):
self.nSources = 10
self.yLoc = 100
self.expId = 4321
self.bbox = geom.Box2I(geom.Point2I(0, 0),
geom.Extent2I(1024, 1153))
dataset = measTests.TestDataset(self.bbox)
for srcIdx in range(self.nSources):
dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc))
schema = dataset.makeMinimalSchema()
schema.addField("sky_source", type="Flag", doc="Sky objects.")
_, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234)
self.diaSourceCat[0:5]["sky_source"] = True
self.config = FilterDiaSourceCatalogConfig()

def test_run(self):
self.config.doRemoveSkySources = False
filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
result = filterDiaSourceCatalogTask.run(self.diaSourceCat, ccdVisitId=self.expId)
self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat))

def test_run_no_filter(self):
self.config.doRemoveSkySources = True
filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
result = filterDiaSourceCatalogTask.run(self.diaSourceCat, ccdVisitId=self.expId)
self.assertEqual(len(result.filteredDiaSourceCat),
len(self.diaSourceCat[~self.diaSourceCat['sky_source']]))


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit 72e9761

Please sign in to comment.