Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix PPVCube #187

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 24 additions & 59 deletions yt_astro_analysis/ppv_cube/ppv_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
# The full license is in the file COPYING.txt, distributed with this software.
# -----------------------------------------------------------------------------

import re

import numpy as np

import yt.units.dimensions as ytdims
from yt.funcs import get_pbar, is_root, is_sequence
from yt.units.yt_array import YTQuantity
from yt.utilities.on_demand_imports import _astropy
from yt.utilities.orientation import Orientation
from yt.utilities.parallel_tools.parallel_analysis_interface import (
parallel_objects,
parallel_root_only,
Expand All @@ -29,33 +27,6 @@

from . import ppv_utils


def create_vlos(normal, no_shifting):
if no_shifting:

def _v_los(field, data):
return data.ds.arr(data["index", "zeros"], "cm/s")

elif isinstance(normal, str):

def _v_los(field, data):
return -data["gas", "velocity_%s" % normal]

else:
orient = Orientation(normal)
los_vec = orient.unit_vectors[2]

def _v_los(field, data):
vz = (
data["gas", "velocity_x"] * los_vec[0]
+ data["gas", "velocity_y"] * los_vec[1]
+ data["gas", "velocity_z"] * los_vec[2]
)
return -vz

return _v_los


fits_info = {
"velocity": ("m/s", "VOPT", "v"),
"frequency": ("Hz", "FREQ", "f"),
Expand All @@ -77,7 +48,6 @@ def __init__(
thermal_broad=False,
atomic_weight=56.0,
depth=(1.0, "unitary"),
depth_res=256,
method="integrate",
weight_field=None,
no_shifting=False,
Expand Down Expand Up @@ -127,9 +97,6 @@ def __init__(
A tuple containing the depth to project through and the string
key of the unit: (width, 'unit'). If set to a float, code units
are assumed. Only for off-axis cubes.
depth_res : integer, optional
Deprecated, this is still in the function signature for API
compatibility
method : string, optional
Set the projection method to be used.
"integrate" : line of sight integration over the line element.
Expand Down Expand Up @@ -162,7 +129,6 @@ def __init__(
"""

self.ds = ds
self.field = field
self.width = width
self.particle_mass = atomic_weight * mh
self.thermal_broad = thermal_broad
Expand Down Expand Up @@ -196,6 +162,8 @@ def __init__(

dd = ds.all_data()
fd = dd._determine_fields(field)[0]
self.field = fd
ftype = fd[0]
self.field_units = ds._get_field_info(fd).units

self.vbins = ds.arr(
Expand All @@ -211,37 +179,37 @@ def __init__(

self.current_v = 0.0

_vlos = create_vlos(normal, self.no_shifting)
self.ds.add_field(
("gas", "v_los"), function=_vlos, units="cm/s", sampling_type="cell"
)
if self.no_shifting:
self.velocity_field = (ftype, "zeros")
else:
self.velocity_field = (ftype, "velocity_los")

_intensity = self._create_intensity()
_intensity = self._create_intensity(ftype=ftype)
self.ds.add_field(
("gas", "intensity"),
(ftype, "intensity"),
function=_intensity,
units=self.field_units,
sampling_type="cell",
sampling_type="local",
)

if method == "integrate" and weight_field is None:
self.proj_units = str(ds.quan(1.0, self.field_units + "*cm").units)
elif method == "sum":
else:
self.proj_units = self.field_units

storage = {}
pbar = get_pbar("Generating cube.", self.nv)
pbar = get_pbar("Generating cube", self.nv)
for sto, i in parallel_objects(range(self.nv), storage=storage):
self.current_v = self.vmid_cgs[i]
if isinstance(normal, str):
prj = ds.proj(
"intensity",
(ftype, "intensity"),
ds.coordinates.axis_id[normal],
method=method,
weight_field=weight_field,
data_source=data_source,
)
buf = prj.to_frb(width, self.nx, center=self.center)["intensity"]
buf = prj.to_frb(width, self.nx, center=self.center)[ftype, "intensity"]
else:
if data_source is None:
source = ds
Expand All @@ -253,7 +221,7 @@ def __init__(
normal,
width,
(self.nx, self.ny),
"intensity",
(ftype, "intensity"),
north_vector=north_vector,
no_ghost=no_ghost,
method=method,
Expand All @@ -277,8 +245,7 @@ def __init__(
elif not isinstance(self.width, YTQuantity):
self.width = ds.quan(self.width, "code_length")

self.ds.field_info.pop(("gas", "intensity"))
self.ds.field_info.pop(("gas", "v_los"))
self.ds.field_info.pop((ftype, "intensity"))

def transform_spectral_axis(self, rest_value, units):
"""
Expand Down Expand Up @@ -344,8 +311,8 @@ def write_fits(

Notes
-----
Additional keyword arguments are passed to
:meth:`~astropy.io.fits.HDUList.writeto`.
All other optional arguments are passed to
:meth:`~yt.visualization.fits_image.FITSImageData.create_sky_wcs`.

Examples
--------
Expand Down Expand Up @@ -373,12 +340,10 @@ def write_fits(
w.wcs.cunit = [units, units, vunit]
w.wcs.ctype = ["LINEAR", "LINEAR", vtype]

fib = FITSImageData(self.data.transpose(), fields=self.field, wcs=w)
fib.update_all_headers("bunit", re.sub("()", "", str(self.proj_units)))
fib.update_all_headers("btype", self.field)
fib = FITSImageData(self.data, fields=self.field, wcs=w)
if sky_scale is not None and sky_center is not None:
fib.create_sky_wcs(sky_center, sky_scale)
fib.writeto(filename, overwrite=overwrite, **kwargs)
fib.create_sky_wcs(sky_center, sky_scale, **kwargs)
fib.writeto(filename, overwrite=overwrite)

def __repr__(self):
return "PPVCube [%d %d %d] (%s < %s < %s)" % (
Expand All @@ -393,12 +358,12 @@ def __repr__(self):
def __getitem__(self, item):
return self.data[item]

def _create_intensity(self):
def _create_intensity(self, ftype="gas"):
if self.thermal_broad:

def _intensity(field, data):
v = self.current_v - data["gas", "v_los"].in_cgs().v
T = data["gas", "temperature"].in_cgs().v
v = self.current_v - data[self.velocity_field].in_cgs().v
T = data[ftype, "temperature"].in_cgs().v
w = ppv_utils.compute_weight(
self.thermal_broad,
self.dv_cgs,
Expand All @@ -414,7 +379,7 @@ def _intensity(field, data):
def _intensity(field, data):
w = (
1.0
- np.fabs(self.current_v - data["gas", "v_los"].in_cgs().v)
- np.fabs(self.current_v - data[self.velocity_field].in_cgs().v)
/ self.dv_cgs
)
w[w < 0.0] = 0.0
Expand Down