Skip to content

Commit

Permalink
Feature branch that add PSF-realted plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
LR-inaf committed Nov 15, 2024
1 parent 6674256 commit 2c9fd8f
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
138 changes: 138 additions & 0 deletions python/lsst/donut/viz/plot_aos_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import lsst.pipe.base.connectionTypes as ct
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
import yaml
from lsst.utils.timer import timeMethod
from lsst.ts.wep.utils import convertZernikesToPsfWidth

from .utilities import (
add_rotated_axis,
Expand All @@ -17,6 +19,7 @@
rose,
)
from .zernike_pyramid import zernikePyramid
from .psf_from_zern import psfPanel

try:
from lsst.rubintv.production.uploaders import MultiUploader
Expand All @@ -30,6 +33,9 @@
"PlotDonutTaskConnections",
"PlotDonutTaskConfig",
"PlotDonutTask",
"PlotPsfZernTaskConnections",
"PlotPsfZernTaskConfig",
"PlotPsfZernTask",
]


Expand Down Expand Up @@ -380,3 +386,135 @@ def runQuantum(
seqNum=seq_num,
filename=donut_gallery_fn,
)

class PlotPsfZernTaskConnections(
pipeBase.PipelineTaskConnections,
dimensions=("visit", "instrument"),
):
zernikes = ct.Input(
doc="Zernikes catalog",
dimensions=("visit", "instrument", "detector"),
storageClass="AstropyTable",
multiple=True,
name="zernikes",
)
psfFromZernPanel = ct.Output(
doc="PSF value retrieved from zernikes",
dimensions=("visit", "instrument"),
storageClass="Plot",
name="psfFromZernPanel",
)


class PlotPsfZernTaskConfig(
pipeBase.PipelineTaskConfig,
pipelineConnections=PlotPsfZernTaskConnections,
):
doRubinTVUpload = pexConfig.Field(
dtype=bool,
doc="Upload to RubinTV",
default=False,
)


class PlotPsfZernTask(pipeBase.PipelineTask):
ConfigClass = PlotPsfZernTaskConfig
_DefaultName = "plotPsfZernTask"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if self.config.doRubinTVUpload:
if not MultiUploader:
raise RuntimeError("MultiUploader is not available")
self.uploader = MultiUploader()

@timeMethod
def runQuantum(
self,
butlerQC: pipeBase.QuantumContext,
inputRefs: pipeBase.InputQuantizedConnection,
outputRefs: pipeBase.OutputQuantizedConnection,
) -> None:
zernikes = butlerQC.get(inputRefs.zernikes)

zkPanel = self.plotPsfFromZern(zernikes)
zkPanel.suptitle(
f"PSF from Zernikes\nvisit: {inputRefs.zernikes[-1].dataId['visit']}",
fontsize="xx-large",
fontweight="book"
)

butlerQC.put(zkPanel, outputRefs.psfFromZernPanel)

if self.config.doRubinTVUpload:
instrument = inputRefs.zernikes.dataId["instrument"]
visit = inputRefs.zernikes.dataId["visit"]
day_obs, seq_num = get_day_obs_seq_num_from_visitid(visit)
with tempfile.TemporaryDirectory() as tmpdir:
psf_zk_panel = Path(tmpdir) / "psf_zk_panel.png"
zkPanel.savefig(psf_zk_panel)

self.uploader.uploadPerSeqNumPlot(
instrument=get_instrument_channel_name(instrument),
plotName="psf_zk_panel",
dayObs=day_obs,
seqNum=seq_num,
filename=zk_meas_fn,
)

def get_psf_degr(self, zset):
return np.sqrt(np.sum(convertZernikesToPsfWidth(zset)**2))

def get_rtp_q(self, qtable):
q = qtable.meta["extra"]["boresight_par_angle_rad"]
rot = qtable.meta["extra"]["boresight_rot_angle_rad"]
rtp = q - rot - np.pi / 2
return rtp, q

def get_rose_vecs(self, rtp, q):
vecs_xy = {
r"$x_\mathrm{Opt}$": (1, 0),
r"$y_\mathrm{Opt}$": (0, -1),
r"$x_\mathrm{Cam}$": (np.cos(rtp), -np.sin(rtp)),
r"$y_\mathrm{Cam}$": (-np.sin(rtp), -np.cos(rtp)),
}

vecs_NE = {
"az": (1, 0),
"alt": (0, +1),
"N": (np.sin(q), np.cos(q)),
"E": (np.sin(q - np.pi / 2), np.cos(q - np.pi / 2)),
}

return vecs_xy, vecs_NE

def plotPsfFromZern(self, zernikes):
xs = []
ys = []
zs = []
dname = []
for i, qt in enumerate(zernikes):
dname.append(qt.meta["extra"]["det_name"])
xs.append(qt["extra_centroid"]["x"][1:].value)
ys.append(qt["extra_centroid"]["y"][1:].value)
z = []
for row in qt[[col for col in qt.colnames if "Z" in col]][1:].iterrows():
z.append([el.to(u.micron).value for el in row])
zs.append(np.array(z))

xs = np.array(xs)
ys = np.array(ys)
zs = np.array(zs)
psf = np.array([[self.get_psf_degr(pair) for pair in det] for det in zs])

fig = psfPanel(xs, ys, psf, dname)

# draw rose
rtp, q = self.get_rtp_q(zernikes[-1])
vecs_xy, vecs_NE = self.get_rose_vecs(rtp, q)
rose(fig, vecs_xy, p0=(0.15, 0.94))
rose(fig, vecs_NE, p0=(0.85, 0.94))

return fig

71 changes: 71 additions & 0 deletions python/lsst/donut/viz/psf_from_zern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec

def psfPanel(
xs,
ys,
psf,
detname,
fig=None,
figsize=(11,14),
cmap="cool",
**kwargs
) -> Figure:
"""Make a per-detector psf scatter plot
Subplots shows for each detector the psf retrieve from the zernike value
for each pair of intra-extra focal images. The points are placed using
pixel coordinates.
Parameters
----------
xs, ys: array of float, shape (ndet, npair)
Points coordinates in pixel.
psf: array of float, shape (ndet, npair)
PSF value for each point.
detname: list of strings, shape (ndet,)
Detector labels.
fig: matplotlib Figure, optional
If provided, use this figure. Default None.
figsize: tuple of float, optional
Figure size in inches. Default (11, 12).
cmap: str, optional
Colormap name. Default 'seismic'.
**kwargs:
Additional keyword arguments passed to matplotlib Figure constructor.
Returns
-------
fig: matplotlib Figure
The figure.
"""

# generating figure if None
if fig is None:
fig = Figure(figsize=figsize, **kwargs)

# creating the gridspec grid (3x3 equal axes and the bottom cbar ax)
gs = GridSpec(nrows=4, ncols=3, figure=fig, width_ratios=[1, 1, 1], height_ratios=[1, 1, 1, 0.1])
axs = [fig.add_subplot(gs[i, j]) for i in range(3) for j in range(3)]
ax_cbar = fig.add_subplot(gs[-1, :])

# setting the detector size (maybe there is a more wise way to retrieve it from the data metadata)
det_lim_y = (0., 4000.)
det_lim_x = (0., 4072.)

# setting the common colormap limits
pmax = np.nanmax(psf)
pmin = np.nanmin(psf)

# cycling through the axes.
for i, ax in enumerate(axs):
im = ax.scatter(xs[i], ys[i], c=psf[i], cmap=cmap, vmax=pmax, vmin=pmin)
ax.set_title(f"{detname[i]}: {np.nanmean(psf[i]):.3f} +/- {np.nanstd(psf[i]):.3f}")
ax.set(xlim=det_lim_x, ylim=det_lim_y, xticks=[], yticks=[], aspect="equal")

# setting the colorbar
cb = fig.colorbar(im, cax=ax_cbar, location="bottom")
cb.set_label(label="PSF width, arcsecond", fontsize="large")

return fig

0 comments on commit 2c9fd8f

Please sign in to comment.