Skip to content

Commit

Permalink
Generalize MetricAnalysisTask
Browse files Browse the repository at this point in the history
Enable MetricAnalysisTask to work with any metric table. Inputs and dimensions specified at runtime via the pipeline yaml file.
  • Loading branch information
jrmullaney committed Nov 25, 2024
1 parent e46c45b commit 7b3eeaa
Showing 1 changed file with 46 additions and 8 deletions.
54 changes: 46 additions & 8 deletions python/lsst/analysis/tools/tasks/metricAnalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,71 @@
)


from lsst.pex.config import ListField
from lsst.pipe.base import connectionTypes as ct

from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask


class MetricAnalysisConnections(
AnalysisBaseConnections,
dimensions=("skymap",),
defaultTemplates={"metricBundleName": "objectTableCore_metrics"},
dimensions=(),
defaultTemplates={"metricBundleName": ""},
):

data = ct.Input(
doc="A summary table of all metrics by tract.",
doc="A table containing metrics.",
name="{metricBundleName}Table",
storageClass="ArrowAstropy",
dimensions=("skymap",),
deferLoad=True,
dimensions=(),
)

def __init__(self, *, config=None):
super().__init__(config=config)
self.dimensions.update(frozenset(sorted(config.outputDataDimensions)))
self.data = ct.Input(
doc=self.data.doc,
name=self.data.name,
storageClass=self.data.storageClass,
deferLoad=self.data.deferLoad,
dimensions=frozenset(sorted(config.inputDataDimensions)),
)


class MetricAnalysisConfig(AnalysisBaseConfig, pipelineConnections=MetricAnalysisConnections):
pass
inputDataDimensions = ListField(
doc="Dimensions of the input data.",
default=(),
dtype=str,
optional=False,
)
outputDataDimensions = ListField(
doc="Dimensions of the input data.",
default=(),
dtype=str,
optional=False,
)


class MetricAnalysisTask(AnalysisPipelineTask):
"""Turn metric bundles which are per tract into a
summary metric table.
"""
"""Perform an analysis of a metric table."""

ConfigClass = MetricAnalysisConfig
_DefaultName = "metricAnalysis"

def runQuantum(self, butlerQC, inputRefs, outputRefs):

inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
plotInfo = self.parsePlotInfo(inputs, dataId)

data = self.loadData(inputs.pop("data"))

# This check may not be necessary...
if "band" in data.columns:
outputs = self.run(data=data, plotInfo=plotInfo, band=dataId["band"], **inputs)
else:
outputs = self.run(data=data, plotInfo=plotInfo, **inputs)

butlerQC.put(outputs, outputRefs)

0 comments on commit 7b3eeaa

Please sign in to comment.