From 15b4f49f6249a1b6a84e0512d82b535afffc98c3 Mon Sep 17 00:00:00 2001 From: "J. Bryce Kalmbach" Date: Sat, 23 Nov 2024 14:27:04 -0800 Subject: [PATCH] Change tasks to use run methods inside runQuantum to make things easier to test and enable running tasks interactively. --- doc/versionHistory.rst | 8 ++ python/lsst/donut/viz/aggregate_visit.py | 167 +++++++++++++++-------- 2 files changed, 115 insertions(+), 60 deletions(-) diff --git a/doc/versionHistory.rst b/doc/versionHistory.rst index e8ae878..45d0ae5 100644 --- a/doc/versionHistory.rst +++ b/doc/versionHistory.rst @@ -4,6 +4,14 @@ Version History ################## +.._lsst.ts.donut.viz-1.2.2 + +------------- +1.2.2 +------------- + +* Change tasks to use run methods inside runQuantum to make things easier to test and enable running tasks interactively. + .._lsst.ts.donut.viz-1.2.1 ------------- diff --git a/python/lsst/donut/viz/aggregate_visit.py b/python/lsst/donut/viz/aggregate_visit.py index 33a52b8..3ce7df7 100644 --- a/python/lsst/donut/viz/aggregate_visit.py +++ b/python/lsst/donut/viz/aggregate_visit.py @@ -1,9 +1,12 @@ +import typing + import galsim +import lsst.daf.base as dafBase import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import numpy as np from astropy import units as u -from astropy.table import Table, vstack +from astropy.table import QTable, Table, vstack from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS from lsst.geom import Point2D, radians from lsst.pipe.base import connectionTypes as ct @@ -40,13 +43,6 @@ class AggregateZernikeTablesTaskConnections( multiple=True, deferGraphConstraint=True, ) - camera = ct.PrerequisiteInput( - name="camera", - storageClass="Camera", - doc="Input camera to construct complete exposures.", - dimensions=["instrument"], - isCalibration=True, - ) aggregateZernikesRaw = ct.Output( doc="Visit-level table of donuts and Zernikes", dimensions=("visit", "instrument"), @@ -79,12 +75,21 @@ def runQuantum( inputRefs: pipeBase.InputQuantizedConnection, outputRefs: pipeBase.OutputQuantizedConnection, ): - camera = butlerQC.get(inputRefs.camera) + + zernike_tables = butlerQC.get(inputRefs.zernikeTable) + out_raw, out_avg = self.run(zernike_tables) + + # Find the right output references + butlerQC.put(out_raw, outputRefs.aggregateZernikesRaw) + butlerQC.put(out_avg, outputRefs.aggregateZernikesAvg) + + @timeMethod + def run(self, zernike_tables: typing.List[QTable]) -> tuple[Table, Table]: raw_tables = [] avg_tables = [] - for zernikesRef in inputRefs.zernikeTable: - zernike_table = butlerQC.get(zernikesRef) + + for zernike_table in zernike_tables: raw_table = Table() zernikes_merged = [] noll_indices = [] @@ -96,11 +101,11 @@ def runQuantum( zernikes_merged = np.array(zernikes_merged).T noll_indices = np.array(noll_indices) raw_table["zk_CCS"] = np.atleast_2d(zernikes_merged[1:]) - raw_table["detector"] = camera[zernikesRef.dataId["detector"]].getName() + raw_table["detector"] = zernike_table.meta["extra"]["det_name"] raw_tables.append(raw_table) avg_table = Table() avg_table["zk_CCS"] = np.atleast_2d(zernikes_merged[0]) - avg_table["detector"] = camera[zernikesRef.dataId["detector"]].getName() + avg_table["detector"] = zernike_table.meta["extra"]["det_name"] avg_tables.append(avg_table) out_raw = vstack(raw_tables) out_avg = vstack(avg_tables) @@ -144,9 +149,7 @@ def runQuantum( cat["zk_OCS"] = cat["zk_OCS"][:, noll_indices - 4] cat["zk_NW"] = cat["zk_NW"][:, noll_indices - 4] - # Find the right output references - butlerQC.put(out_raw, outputRefs.aggregateZernikesRaw) - butlerQC.put(out_avg, outputRefs.aggregateZernikesAvg) + return out_raw, out_avg # Note: cannot make visit a dimension because we have not yet paired visits. @@ -255,18 +258,53 @@ def runQuantum( pairs = self.pairer.run(visitInfoDict) # Make dictionaries to match visits and detectors - donutRefDict = { - (ref.dataId["visit"], ref.dataId["detector"]): ref + donutTables = { + (ref.dataId["visit"], ref.dataId["detector"]): butlerQC.get(ref) for ref in inputRefs.donutTables } - qualityRefDict = { - (ref.dataId["visit"], ref.dataId["detector"]): ref + qualityTables = { + (ref.dataId["visit"], ref.dataId["detector"]): butlerQC.get(ref) for ref in inputRefs.qualityTables } + pairTables = self.run(camera, visitInfoDict, pairs, donutTables, qualityTables) + + # Put pairTables in butler + for pairTable, pairTableRef in zip(pairTables, outputRefs.aggregateDonutTable): + butlerQC.put(pairTable, pairTableRef) + + @timeMethod + def run( + self, + camera, + visitInfoDict: dict, + pairs: list, + donutTables: dict, + qualityTables: dict, + ) -> typing.List[QTable]: + """Aggregate donut tables for a set of visits. + + Parameters + ---------- + camera : lsst.afw.cameraGeom.Camera + The camera object. + visitInfoDict : dict + Dictionary of visit info objects keyed by visit ID. + pairs : list + List of visit pairs. + donutTables : dict + Dictionary of donut tables keyed by (visit, detector). + qualityTables : dict + Dictionary of quality tables keyed by (visit, detector). + + Returns + ------- + list of astropy.table.QTable + List of aggregated donut tables, one per visit pair. + """ # Find common (visit, detector) extra-focal pairs # DonutQualityTables only saved under extra-focal ids - extra_keys = set(donutRefDict) & set(qualityRefDict) + extra_keys = set(donutTables) & set(qualityTables) # Raise error if there's no matches if len(extra_keys) == 0: @@ -275,6 +313,7 @@ def runQuantum( "the donut and quality tables" ) + pairTables = [] for pair in pairs: intraVisitInfo = visitInfoDict[pair.intra] extraVisitInfo = visitInfoDict[pair.extra] @@ -293,9 +332,9 @@ def runQuantum( tform = det.getTransform(PIXELS, FIELD_ANGLE) # Load the donut catalog table, and the donut quality table - intraDonutTable = butlerQC.get(donutRefDict[(pair.intra, detector)]) - extraDonutTable = butlerQC.get(donutRefDict[(pair.extra, detector)]) - qualityTable = butlerQC.get(qualityRefDict[(pair.extra, detector)]) + intraDonutTable = donutTables[(pair.intra, detector)] + extraDonutTable = donutTables[(pair.extra, detector)] + qualityTable = qualityTables[(pair.extra, detector)] # Get rows of quality table for this exposure intraQualityTable = qualityTable[ @@ -336,9 +375,7 @@ def runQuantum( out = vstack(tables) - # TODO: Swap parallactic angle for pseudo parallactic angle. - # See SMTN-019 for details. - + # Add metadata for extra and intra focal exposures out.meta["extra"] = { "visit": pair.extra, "focusZ": extraVisitInfo.focusZ, @@ -389,6 +426,7 @@ def runQuantum( out.meta["extra"]["mjd"] + out.meta["intra"]["mjd"] ) + # Calculate coordinates in different reference frames q = out.meta["average"]["parallacticAngle"] rtp = out.meta["average"]["rotTelPos"] out["thx_OCS"] = np.cos(rtp) * out["thx_CCS"] - np.sin(rtp) * out["thy_CCS"] @@ -396,15 +434,9 @@ def runQuantum( out["th_N"] = np.cos(q) * out["thx_CCS"] - np.sin(q) * out["thy_CCS"] out["th_W"] = np.sin(q) * out["thx_CCS"] + np.cos(q) * out["thy_CCS"] - # Find the right output references - for outRef in outputRefs.aggregateDonutTable: - if outRef.dataId["visit"] == pair.extra: - butlerQC.put(out, outRef) - break - else: - raise ValueError( - f"Expected to find an output reference with visit {pair.extra}" - ) + pairTables.append(out) + + return pairTables class AggregateAOSVisitTableTaskConnections( @@ -471,6 +503,15 @@ def runQuantum( azr = butlerQC.get(inputRefs.aggregateZernikesRaw) aza = butlerQC.get(inputRefs.aggregateZernikesAvg) + avg_table, raw_table = self.run(adc, azr, aza) + + butlerQC.put(avg_table, outputRefs.aggregateAOSAvg) + butlerQC.put(raw_table, outputRefs.aggregateAOSRaw) + + @timeMethod + def run( + self, adc: typing.List[Table], azr: typing.List[Table], aza: typing.List[Table] + ) -> tuple[Table, Table]: dets = np.unique(adc["detector"]) avg_table = aza.copy() avg_keys = [ @@ -528,8 +569,7 @@ def runQuantum( raw_table[k + "_intra"][w] = adc[k][wadc][wintra] raw_table[k + "_extra"][w] = adc[k][wadc][wextra] - butlerQC.put(avg_table, outputRefs.aggregateAOSAvg) - butlerQC.put(raw_table, outputRefs.aggregateAOSRaw) + return avg_table, raw_table class AggregateDonutStampsTaskConnections( @@ -601,19 +641,37 @@ def runQuantum( inputRefs: pipeBase.InputQuantizedConnection, outputRefs: pipeBase.OutputQuantizedConnection, ) -> None: + + intraStampsList, extraStampsList, intraMeta, extraMeta = self.run( + butlerQC.get(inputRefs.donutStampsIntra), + butlerQC.get(inputRefs.donutStampsExtra), + butlerQC.get(inputRefs.qualityTables), + ) + + intraStampsListRavel = np.ravel(intraStampsList) + extraStampsListRavel = np.ravel(extraStampsList) + + butlerQC.put( + DonutStamps(intraStampsListRavel, metadata=intraMeta), + outputRefs.donutStampsIntraVisit, + ) + + butlerQC.put( + DonutStamps(extraStampsListRavel, metadata=extraMeta), + outputRefs.donutStampsExtraVisit, + ) + + @timeMethod + def run( + self, + intraStamps: typing.List, + extraStamps: typing.List, + qualityTables: typing.List, + ) -> tuple[typing.List, typing.List, dafBase.PropertyList, dafBase.PropertyList]: intraStampsList = [] extraStampsList = [] - for intraRef, extraRef, qualityRef in zip( - inputRefs.donutStampsIntra, - inputRefs.donutStampsExtra, - inputRefs.qualityTables, - ): - # Load the donuts - intra = butlerQC.get(intraRef) - extra = butlerQC.get(extraRef) - + for intra, extra, quality in zip(intraStamps, extraStamps, qualityTables): # Load the quality table and determine which donuts were selected - quality = butlerQC.get(qualityRef) intraSelect = quality[quality["DEFOCAL_TYPE"] == "intra"]["FINAL_SELECT"] extraSelect = quality[quality["DEFOCAL_TYPE"] == "extra"]["FINAL_SELECT"] @@ -649,15 +707,4 @@ def runQuantum( intraStampsList.append(intra[: self.config.maxDonutsPerDetector]) extraStampsList.append(extra[: self.config.maxDonutsPerDetector]) - intraStampsListRavel = np.ravel(intraStampsList) - extraStampsListRavel = np.ravel(extraStampsList) - - butlerQC.put( - DonutStamps(intraStampsListRavel, metadata=intra.metadata), - outputRefs.donutStampsIntraVisit, - ) - - butlerQC.put( - DonutStamps(extraStampsListRavel, metadata=extra.metadata), - outputRefs.donutStampsExtraVisit, - ) + return intraStampsList, extraStampsList, extra.metadata, intra.metadata