diff --git a/python/lsst/donut/viz/aggregate_visit.py b/python/lsst/donut/viz/aggregate_visit.py index 33a52b8..125d596 100644 --- a/python/lsst/donut/viz/aggregate_visit.py +++ b/python/lsst/donut/viz/aggregate_visit.py @@ -3,7 +3,8 @@ import lsst.pipe.base as pipeBase import numpy as np from astropy import units as u -from astropy.table import Table, vstack +from astropy.coordinates import Angle +from astropy.table import Table, QTable, vstack from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS from lsst.geom import Point2D, radians from lsst.pipe.base import connectionTypes as ct @@ -25,6 +26,9 @@ "AggregateDonutStampsTaskConnections", "AggregateDonutStampsTaskConfig", "AggregateDonutStampsTask", + "AggregateZernikesTaskConnections", + "AggregateZernikesTaskConfig", + "AggregateZernikesTask", ] @@ -661,3 +665,174 @@ def runQuantum( DonutStamps(extraStampsListRavel, metadata=extra.metadata), outputRefs.donutStampsExtraVisit, ) + + +def is_integer(s): + try: + int(s) + return True + except ValueError: + return False + + +def is_zk_column(s): + return s[0] == "Z" and is_integer(s[1:]) + + +class AggregateZernikesTaskConnections( + pipeBase.PipelineTaskConnections, + dimensions=("instrument", "visit"), +): + zernikes = ct.Input( + doc="Zernike Coefficients from all donuts", + dimensions=("visit", "detector", "instrument"), + storageClass="AstropyQTable", + name="zernikes", + multiple=True, + deferGraphConstraint=True, + ) + camera = ct.PrerequisiteInput( + name="camera", + storageClass="Camera", + doc="Input camera to construct complete exposures.", + dimensions=["instrument"], + isCalibration=True, + ) + aggregateZernikes = ct.Output( + doc="Visit-level table of Zernikes", + dimensions=("visit", "instrument"), + storageClass="AstropyQTable", + name="aggregateZernikes", + ) + + +class AggregateZernikesTaskConfig( + pipeBase.PipelineTaskConfig, + pipelineConnections=AggregateZernikesTaskConnections, +): + pass + + +class AggregateZernikesTask(pipeBase.PipelineTask): + ConfigClass = AggregateZernikesTaskConfig + _DefaultName = "AggregateZernikes" + + def runQuantum( + self, + butlerQC: pipeBase.QuantumContext, + inputRefs: pipeBase.InputQuantizedConnection, + outputRefs: pipeBase.OutputQuantizedConnection, + ): + camera = butlerQC.get(inputRefs.camera) + + zernike_tables = [] + for zernikesRef in inputRefs.zernikes: + det_num = zernikesRef.dataId["detector"] + zernike_table = butlerQC.get(zernikesRef) + zernike_table["det_num"] = det_num + zernike_table["det_name"] = camera[det_num].getName() + zernike_tables.append(zernike_table) + + # Get Noll indices used + noll_indices = sorted( + [ + int(name[1:]) for name in zernike_tables[0].columns if is_zk_column(name) + ] + ) + + # Assemble compound dtype + zk_dtype = np.dtype( + [ + (f"Z{j}", np.float32) + for j in noll_indices + ] + ) + + # assemble output table + out_tables = [] + for zernike_table in zernike_tables: + out = QTable() + # Temporary storage for the zernikes so we can place them in the last column + zk_CCS = np.zeros(len(zernike_table), dtype=zk_dtype) * u.micron + for col, val in zernike_table.columns.items(): + if is_zk_column(col): + j = int(col[1:]) + zk_CCS[f'Z{j}'] = val + else: + out[col] = val + out['zk_CCS'] = zk_CCS + out_tables.append(out) + out_table = vstack(out_tables) + + # Get meta from one of the input tables. They're all the same aside from the + # detector. + table_meta = zernike_tables[0].meta + + out.meta["extra"] = { + "visit": table_meta["extra"]["visit"], + "parallacticAngle": table_meta["extra"]["boresight_par_angle_rad"]*u.rad, + "rotAngle": table_meta["extra"]["boresight_rot_angle_rad"]*u.rad, + "rotTelPos": table_meta["extra"]["boresight_par_angle_rad"]*u.rad + - table_meta["extra"]["boresight_rot_angle_rad"]*u.rad + - (np.pi / 2)*u.rad, + "ra": table_meta["extra"]["boresight_ra_rad"]*u.rad, + "dec": table_meta["extra"]["boresight_dec_rad"]*u.rad, + "az": table_meta["extra"]["boresight_az_rad"]*u.rad, + "alt": table_meta["extra"]["boresight_alt_rad"]*u.rad, + "mjd": table_meta["extra"]["mjd"], + } + out.meta["intra"] = { + "visit": table_meta["intra"]["visit"], + "parallacticAngle": table_meta["intra"]["boresight_par_angle_rad"]*u.rad, + "rotAngle": table_meta["intra"]["boresight_rot_angle_rad"]*u.rad, + "rotTelPos": table_meta["intra"]["boresight_par_angle_rad"]*u.rad + - table_meta["intra"]["boresight_rot_angle_rad"]*u.rad + - (np.pi / 2)*u.rad, + "ra": table_meta["intra"]["boresight_ra_rad"]*u.rad, + "dec": table_meta["intra"]["boresight_dec_rad"]*u.rad, + "az": table_meta["intra"]["boresight_az_rad"]*u.rad, + "alt": table_meta["intra"]["boresight_alt_rad"]*u.rad, + "mjd": table_meta["intra"]["mjd"], + } + + # Carefully average angles in meta + out.meta["average"] = {} + for k in ( + "parallacticAngle", + "rotAngle", + "rotTelPos", + "ra", + "dec", + "az", + "alt", + ): + # Use the DM angle handling to wrap + a1 = out.meta["extra"][k].to_value(u.rad) * radians + a2 = out.meta["intra"][k].to_value(u.rad) * radians + a2 = a2.wrapNear(a1) + out.meta["average"][k] = ((a1 + a2) / 2).wrapCtr().asRadians() * u.rad + + # Easier to average the MJDs + out.meta["average"]["mjd"] = 0.5 * ( + out.meta["extra"]["mjd"] + out.meta["intra"]["mjd"] + ) + + q = out.meta["average"]["parallacticAngle"] + rtp = out.meta["average"]["rotTelPos"] + + jmax = np.max(noll_indices) + zk_CCS_full = np.zeros((len(out), jmax + 1)) + for j in noll_indices: + zk_CCS_full[:, j] = out["zk_CCS"]["Z" + str(j)].to_value(u.um) + rot_OCS = galsim.zernike.zernikeRotMatrix(jmax, -rtp) + rot_NW = galsim.zernike.zernikeRotMatrix(jmax, -q) + + zk_OCS_full = np.dot(zk_CCS_full, rot_OCS) + zk_NW_full = np.dot(zk_CCS_full, rot_NW) + out["zk_OCS"] = np.zeros(len(out), dtype=zk_dtype) * u.micron + out["zk_NW"] = np.zeros(len(out), dtype=zk_dtype) * u.micron + for j in noll_indices: + out["zk_OCS"]["Z" + str(j)] = zk_OCS_full[:, j] * u.micron + out["zk_NW"]["Z" + str(j)] = zk_NW_full[:, j] * u.micron + + butlerQC.put(out, outputRefs.aggregateZernikes)