diff --git a/python/lsst/analysis/tools/tasks/metricAnalysis.py b/python/lsst/analysis/tools/tasks/metricAnalysis.py index 30f14117a..2e1a1b26f 100644 --- a/python/lsst/analysis/tools/tasks/metricAnalysis.py +++ b/python/lsst/analysis/tools/tasks/metricAnalysis.py @@ -26,6 +26,7 @@ ) +from lsst.pex.config import ListField from lsst.pipe.base import connectionTypes as ct from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask @@ -33,26 +34,63 @@ class MetricAnalysisConnections( AnalysisBaseConnections, - dimensions=("skymap",), - defaultTemplates={"metricBundleName": "objectTableCore_metrics"}, + dimensions=(), + defaultTemplates={"metricBundleName": ""}, ): + data = ct.Input( - doc="A summary table of all metrics by tract.", + doc="A table containing metrics.", name="{metricBundleName}Table", storageClass="ArrowAstropy", - dimensions=("skymap",), deferLoad=True, + dimensions=(), ) + def __init__(self, *, config=None): + super().__init__(config=config) + self.dimensions.update(frozenset(sorted(config.outputDataDimensions))) + self.data = ct.Input( + doc=self.data.doc, + name=self.data.name, + storageClass=self.data.storageClass, + deferLoad=self.data.deferLoad, + dimensions=frozenset(sorted(config.inputDataDimensions)), + ) + class MetricAnalysisConfig(AnalysisBaseConfig, pipelineConnections=MetricAnalysisConnections): - pass + inputDataDimensions = ListField( + doc="Dimensions of the input data.", + default=(), + dtype=str, + optional=False, + ) + outputDataDimensions = ListField( + doc="Dimensions of the input data.", + default=(), + dtype=str, + optional=False, + ) class MetricAnalysisTask(AnalysisPipelineTask): - """Turn metric bundles which are per tract into a - summary metric table. - """ + """Perform an analysis of a metric table.""" ConfigClass = MetricAnalysisConfig _DefaultName = "metricAnalysis" + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + + inputs = butlerQC.get(inputRefs) + dataId = butlerQC.quantum.dataId + plotInfo = self.parsePlotInfo(inputs, dataId) + + data = self.loadData(inputs.pop("data")) + + # This check may not be necessary... + if "band" in data.columns: + outputs = self.run(data=data, plotInfo=plotInfo, band=dataId["band"], **inputs) + else: + outputs = self.run(data=data, plotInfo=plotInfo, **inputs) + + butlerQC.put(outputs, outputRefs)