Skip to content

Commit

Permalink
Merge pull request #767 from svank/faster-axis_world_coords
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair authored Oct 16, 2024
2 parents 81d3dcb + 0e1d41f commit 84f6f62
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
3 changes: 3 additions & 0 deletions changelog/767.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
When calling :meth:`ndcube.NDCubeBase.axis_world_coords` or :meth:`ndcube.NDCubeBase.axis_world_coords_values` with a
specific axis or axes specified, the methods now avoid doing calculations for any other uncorrelated axes, offering
significant speedups when those other axes are large.
45 changes: 24 additions & 21 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,16 +443,9 @@ def quantity(self):
"""Unitful representation of the NDCube data."""
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_world_coords(self, pixel_corners, wcs):
# TODO: We can improve this by not always generating all coordinates
# To make our lives easier here we generate all the coordinates for all
# pixels and then choose the ones we want to return to the user based
# on the axes argument. We could be smarter by integrating this logic
# into the main loop, this would potentially reduce the number of calls
# to pixel_to_world_values

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None):
# Create meshgrid of all pixel coordinates.
# If user, wants pixel_corners, set pixel values to pixel pixel_corners.
# If user wants pixel_corners, set pixel values to pixel corners.
# Else make pixel centers.
pixel_shape = self.data.shape[::-1]
if pixel_corners:
Expand All @@ -468,8 +461,16 @@ def _generate_world_coords(self, pixel_corners, wcs):
if wcs is None:
return []

world_coords = [None] * wcs.world_n_dim
# This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is
# required so values_to_high_level_objects in the calling function doesn't crash or warn
world_coords = [0] * wcs.world_n_dim
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
if (needed_axes is not None
and len(needed_axes)
and not any(world_axis in needed_axes for world_axis in world_axes_indices)):
# needed_axes indicates which values in world_coords will be used by the calling
# function, so skip this iteration if we won't be producing any of those values
continue
# First construct a range of pixel indices for this set of coupled dimensions
sub_range = [ranges[idx] for idx in pixel_axes_indices]
# Then get a set of non correlated dimensions
Expand Down Expand Up @@ -499,23 +500,16 @@ def _generate_world_coords(self, pixel_corners, wcs):

@utils.cube.sanitize_wcs
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):

# Docstring in NDCubeABC.
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

axes_coords = self._generate_world_coords(pixel_corners, wcs)

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return tuple()

axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
unique_obj_names = utils.misc.unique_sorted(object_names)
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]
Expand All @@ -531,6 +525,13 @@ def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
[world_index_to_object_index[world_index] for world_index in world_indices]
)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices)

axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

return tuple(axes_coords[i] for i in object_indices)

@utils.cube.sanitize_wcs
Expand All @@ -539,17 +540,19 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

axes_coords = self._generate_world_coords(pixel_corners, wcs)

orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)

axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices)

world_axis_physical_types = wcs.world_axis_physical_types

# If user has supplied axes, extract only the
# world coords that correspond to those axes.
if axes:
world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
axes_coords = [axes_coords[i] for i in world_indices]
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])

Expand Down

0 comments on commit 84f6f62

Please sign in to comment.