Change tasks to use run methods inside runQuantum to make things easier to test and enable running tasks interactively.
jbkalmbach committed Nov 23, 2024
1 parent 06b31d4 commit 15b4f49
Showing 2 changed files with 115 additions and 60 deletions.
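
The refactor described in the commit message follows a common pipeline-task pattern: runQuantum handles only butler I/O and delegates all of the logic to a run method that can be called directly from tests or a notebook. A minimal sketch of that split, using illustrative class and connection names that are not taken from the diff below:

    import lsst.pipe.base as pipeBase
    from lsst.utils.timer import timeMethod


    class ExampleAggregateTask(pipeBase.PipelineTask):
        """Illustrative task showing the runQuantum/run split.

        (Config and connections classes omitted for brevity.)
        """

        def runQuantum(self, butlerQC, inputRefs, outputRefs):
            # runQuantum only talks to the butler: fetch the inputs ...
            tables = butlerQC.get(inputRefs.inputTables)
            # ... hand the real work to run() ...
            output = self.run(tables)
            # ... and persist the result.
            butlerQC.put(output, outputRefs.outputTable)

        @timeMethod
        def run(self, tables):
            # All of the logic lives here, so tests and interactive sessions
            # can call task.run(in_memory_tables) without a quantum graph.
            ...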
8 changes: 8 additions & 0 deletions doc/versionHistory.rst
@@ -4,6 +4,14 @@
Version History
##################

.. _lsst.ts.donut.viz-1.2.2:

-------------
1.2.2
-------------

* Change tasks to use run methods inside runQuantum to make things easier to test and enable running tasks interactively.

.. _lsst.ts.donut.viz-1.2.1:

-------------
167 changes: 107 additions & 60 deletions python/lsst/donut/viz/aggregate_visit.py
@@ -1,9 +1,12 @@
import typing

import galsim
import lsst.daf.base as dafBase
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import numpy as np
from astropy import units as u
from astropy.table import Table, vstack
from astropy.table import QTable, Table, vstack
from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS
from lsst.geom import Point2D, radians
from lsst.pipe.base import connectionTypes as ct
@@ -40,13 +43,6 @@ class AggregateZernikeTablesTaskConnections(
multiple=True,
deferGraphConstraint=True,
)
camera = ct.PrerequisiteInput(
name="camera",
storageClass="Camera",
doc="Input camera to construct complete exposures.",
dimensions=["instrument"],
isCalibration=True,
)
aggregateZernikesRaw = ct.Output(
doc="Visit-level table of donuts and Zernikes",
dimensions=("visit", "instrument"),
@@ -79,12 +75,21 @@ def runQuantum(
inputRefs: pipeBase.InputQuantizedConnection,
outputRefs: pipeBase.OutputQuantizedConnection,
):
camera = butlerQC.get(inputRefs.camera)

zernike_tables = butlerQC.get(inputRefs.zernikeTable)
out_raw, out_avg = self.run(zernike_tables)

# Find the right output references
butlerQC.put(out_raw, outputRefs.aggregateZernikesRaw)
butlerQC.put(out_avg, outputRefs.aggregateZernikesAvg)

@timeMethod
def run(self, zernike_tables: typing.List[QTable]) -> tuple[Table, Table]:

raw_tables = []
avg_tables = []
for zernikesRef in inputRefs.zernikeTable:
zernike_table = butlerQC.get(zernikesRef)

for zernike_table in zernike_tables:
raw_table = Table()
zernikes_merged = []
noll_indices = []
@@ -96,11 +101,11 @@ def runQuantum(
zernikes_merged = np.array(zernikes_merged).T
noll_indices = np.array(noll_indices)
raw_table["zk_CCS"] = np.atleast_2d(zernikes_merged[1:])
raw_table["detector"] = camera[zernikesRef.dataId["detector"]].getName()
raw_table["detector"] = zernike_table.meta["extra"]["det_name"]
raw_tables.append(raw_table)
avg_table = Table()
avg_table["zk_CCS"] = np.atleast_2d(zernikes_merged[0])
avg_table["detector"] = camera[zernikesRef.dataId["detector"]].getName()
avg_table["detector"] = zernike_table.meta["extra"]["det_name"]
avg_tables.append(avg_table)
out_raw = vstack(raw_tables)
out_avg = vstack(avg_tables)
@@ -144,9 +149,7 @@ def runQuantum(
cat["zk_OCS"] = cat["zk_OCS"][:, noll_indices - 4]
cat["zk_NW"] = cat["zk_NW"][:, noll_indices - 4]

# Find the right output references
butlerQC.put(out_raw, outputRefs.aggregateZernikesRaw)
butlerQC.put(out_avg, outputRefs.aggregateZernikesAvg)
return out_raw, out_avg


# Note: cannot make visit a dimension because we have not yet paired visits.
@@ -255,18 +258,53 @@ def runQuantum(
pairs = self.pairer.run(visitInfoDict)

# Make dictionaries to match visits and detectors
donutRefDict = {
(ref.dataId["visit"], ref.dataId["detector"]): ref
donutTables = {
(ref.dataId["visit"], ref.dataId["detector"]): butlerQC.get(ref)
for ref in inputRefs.donutTables
}
qualityRefDict = {
(ref.dataId["visit"], ref.dataId["detector"]): ref
qualityTables = {
(ref.dataId["visit"], ref.dataId["detector"]): butlerQC.get(ref)
for ref in inputRefs.qualityTables
}

pairTables = self.run(camera, visitInfoDict, pairs, donutTables, qualityTables)

# Put pairTables in butler
for pairTable, pairTableRef in zip(pairTables, outputRefs.aggregateDonutTable):
butlerQC.put(pairTable, pairTableRef)

@timeMethod
def run(
self,
camera,
visitInfoDict: dict,
pairs: list,
donutTables: dict,
qualityTables: dict,
) -> typing.List[QTable]:
"""Aggregate donut tables for a set of visits.
Parameters
----------
camera : lsst.afw.cameraGeom.Camera
The camera object.
visitInfoDict : dict
Dictionary of visit info objects keyed by visit ID.
pairs : list
List of visit pairs.
donutTables : dict
Dictionary of donut tables keyed by (visit, detector).
qualityTables : dict
Dictionary of quality tables keyed by (visit, detector).

Returns
-------
list of astropy.table.QTable
List of aggregated donut tables, one per visit pair.
"""
# Find common (visit, detector) extra-focal pairs
# DonutQualityTables only saved under extra-focal ids
extra_keys = set(donutRefDict) & set(qualityRefDict)
extra_keys = set(donutTables) & set(qualityTables)

# Raise an error if there are no matches
if len(extra_keys) == 0:
@@ -275,6 +313,7 @@ def runQuantum(
"the donut and quality tables"
)

pairTables = []
for pair in pairs:
intraVisitInfo = visitInfoDict[pair.intra]
extraVisitInfo = visitInfoDict[pair.extra]
@@ -293,9 +332,9 @@ def runQuantum(
tform = det.getTransform(PIXELS, FIELD_ANGLE)

# Load the donut catalog table, and the donut quality table
intraDonutTable = butlerQC.get(donutRefDict[(pair.intra, detector)])
extraDonutTable = butlerQC.get(donutRefDict[(pair.extra, detector)])
qualityTable = butlerQC.get(qualityRefDict[(pair.extra, detector)])
intraDonutTable = donutTables[(pair.intra, detector)]
extraDonutTable = donutTables[(pair.extra, detector)]
qualityTable = qualityTables[(pair.extra, detector)]

# Get rows of quality table for this exposure
intraQualityTable = qualityTable[
@@ -336,9 +375,7 @@ def runQuantum(

out = vstack(tables)

# TODO: Swap parallactic angle for pseudo parallactic angle.
# See SMTN-019 for details.

# Add metadata for extra and intra focal exposures
out.meta["extra"] = {
"visit": pair.extra,
"focusZ": extraVisitInfo.focusZ,
@@ -389,22 +426,17 @@ def runQuantum(
out.meta["extra"]["mjd"] + out.meta["intra"]["mjd"]
)

# Calculate coordinates in different reference frames
q = out.meta["average"]["parallacticAngle"]
rtp = out.meta["average"]["rotTelPos"]
out["thx_OCS"] = np.cos(rtp) * out["thx_CCS"] - np.sin(rtp) * out["thy_CCS"]
out["thy_OCS"] = np.sin(rtp) * out["thx_CCS"] + np.cos(rtp) * out["thy_CCS"]
out["th_N"] = np.cos(q) * out["thx_CCS"] - np.sin(q) * out["thy_CCS"]
out["th_W"] = np.sin(q) * out["thx_CCS"] + np.cos(q) * out["thy_CCS"]

# Find the right output references
for outRef in outputRefs.aggregateDonutTable:
if outRef.dataId["visit"] == pair.extra:
butlerQC.put(out, outRef)
break
else:
raise ValueError(
f"Expected to find an output reference with visit {pair.extra}"
)
pairTables.append(out)

return pairTables


class AggregateAOSVisitTableTaskConnections(
@@ -471,6 +503,15 @@ def runQuantum(
azr = butlerQC.get(inputRefs.aggregateZernikesRaw)
aza = butlerQC.get(inputRefs.aggregateZernikesAvg)

avg_table, raw_table = self.run(adc, azr, aza)

butlerQC.put(avg_table, outputRefs.aggregateAOSAvg)
butlerQC.put(raw_table, outputRefs.aggregateAOSRaw)

@timeMethod
def run(
self, adc: typing.List[Table], azr: typing.List[Table], aza: typing.List[Table]
) -> tuple[Table, Table]:
dets = np.unique(adc["detector"])
avg_table = aza.copy()
avg_keys = [
@@ -528,8 +569,7 @@ def runQuantum(
raw_table[k + "_intra"][w] = adc[k][wadc][wintra]
raw_table[k + "_extra"][w] = adc[k][wadc][wextra]

butlerQC.put(avg_table, outputRefs.aggregateAOSAvg)
butlerQC.put(raw_table, outputRefs.aggregateAOSRaw)
return avg_table, raw_table


class AggregateDonutStampsTaskConnections(
@@ -601,19 +641,37 @@ def runQuantum(
inputRefs: pipeBase.InputQuantizedConnection,
outputRefs: pipeBase.OutputQuantizedConnection,
) -> None:

intraStampsList, extraStampsList, intraMeta, extraMeta = self.run(
butlerQC.get(inputRefs.donutStampsIntra),
butlerQC.get(inputRefs.donutStampsExtra),
butlerQC.get(inputRefs.qualityTables),
)

intraStampsListRavel = np.ravel(intraStampsList)
extraStampsListRavel = np.ravel(extraStampsList)

butlerQC.put(
DonutStamps(intraStampsListRavel, metadata=intraMeta),
outputRefs.donutStampsIntraVisit,
)

butlerQC.put(
DonutStamps(extraStampsListRavel, metadata=extraMeta),
outputRefs.donutStampsExtraVisit,
)

@timeMethod
def run(
self,
intraStamps: typing.List,
extraStamps: typing.List,
qualityTables: typing.List,
) -> tuple[typing.List, typing.List, dafBase.PropertyList, dafBase.PropertyList]:
intraStampsList = []
extraStampsList = []
for intraRef, extraRef, qualityRef in zip(
inputRefs.donutStampsIntra,
inputRefs.donutStampsExtra,
inputRefs.qualityTables,
):
# Load the donuts
intra = butlerQC.get(intraRef)
extra = butlerQC.get(extraRef)

for intra, extra, quality in zip(intraStamps, extraStamps, qualityTables):
# Load the quality table and determine which donuts were selected
quality = butlerQC.get(qualityRef)
intraSelect = quality[quality["DEFOCAL_TYPE"] == "intra"]["FINAL_SELECT"]
extraSelect = quality[quality["DEFOCAL_TYPE"] == "extra"]["FINAL_SELECT"]

@@ -649,15 +707,4 @@ def runQuantum(
intraStampsList.append(intra[: self.config.maxDonutsPerDetector])
extraStampsList.append(extra[: self.config.maxDonutsPerDetector])

intraStampsListRavel = np.ravel(intraStampsList)
extraStampsListRavel = np.ravel(extraStampsList)

butlerQC.put(
DonutStamps(intraStampsListRavel, metadata=intra.metadata),
outputRefs.donutStampsIntraVisit,
)

butlerQC.put(
DonutStamps(extraStampsListRavel, metadata=extra.metadata),
outputRefs.donutStampsExtraVisit,
)
return intraStampsList, extraStampsList, intra.metadata, extra.metadata
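
Because the logic now lives in run, the tasks can also be exercised interactively without building a quantum graph. A rough sketch for the Zernike aggregation, assuming the task class name follows the connections class shown above and using placeholder repository, collection, and dataset-type names:

    from lsst.daf.butler import Butler
    from lsst.donut.viz.aggregate_visit import AggregateZernikeTablesTask

    # Placeholder repo path, collection, and dataset type -- substitute your own.
    butler = Butler("/path/to/repo", collections=["u/me/donut_run"])
    refs = butler.registry.queryDatasets(
        "zernikes", where="instrument='LSSTComCam' AND visit=123456"
    )
    zernike_tables = [butler.get(ref) for ref in refs]

    # Call run() directly on in-memory tables; runQuantum would normally do this.
    task = AggregateZernikeTablesTask()
    out_raw, out_avg = task.run(zernike_tables)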
