diff --git a/changelog/780.bugfix.rst b/changelog/780.bugfix.rst new file mode 100644 index 000000000..f6c872dd6 --- /dev/null +++ b/changelog/780.bugfix.rst @@ -0,0 +1 @@ +Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances. diff --git a/ndcube/conftest.py b/ndcube/conftest.py index ad34271d0..8960e8879 100644 --- a/ndcube/conftest.py +++ b/ndcube/conftest.py @@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l(): return WCS(header=header) +@pytest.fixture +def wcs_3d_wave_lt_ln(): + header = { + 'CTYPE1': 'WAVE ', + 'CUNIT1': 'Angstrom', + 'CDELT1': 0.2, + 'CRPIX1': 0, + 'CRVAL1': 10, + + 'CTYPE2': 'HPLT-TAN', + 'CUNIT2': 'deg', + 'CDELT2': 0.5, + 'CRPIX2': 2, + 'CRVAL2': 0.5, + + 'CTYPE3': 'HPLN-TAN ', + 'CUNIT3': 'deg', + 'CDELT3': 0.4, + 'CRPIX3': 2, + 'CRVAL3': 1, + } + return WCS(header=header) + + @pytest.fixture def wcs_2d_lt_ln(): spatial = { @@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d): return cube +@pytest.fixture +def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln): + shape = (3, 4, 5) + wcs_3d_wave_lt_ln.array_shape = shape + data = data_nd(shape) + mask = data > 0 + cube = NDCube( + data, + wcs_3d_wave_lt_ln, + mask=mask, + uncertainty=data, + ) + base_time = Time('2000-01-01', format='fits', scale='utc') + timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])]) + cube.extra_coords.add('time', 0, timestamps) + return cube + + @pytest.fixture def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d): data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]], diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 29d258977..6e61cbc6f 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -12,6 +12,9 @@ import astropy.nddata import astropy.units as u from astropy.units import UnitsError +from astropy.wcs.utils import _split_matrix + +from ndcube.utils.wcs import world_axis_to_pixel_axes try: # Import sunpy coordinates if available to register the frames and WCS functions with astropy @@ -20,7 +23,6 @@ pass from astropy.wcs import WCS -from astropy.wcs.utils import _split_matrix from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects @@ -443,9 +445,33 @@ def quantity(self): """Unitful representation of the NDCube data.""" return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED) - def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units): - # Create meshgrid of all pixel coordinates. - # If user wants pixel_corners, set pixel values to pixel corners. + + def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, units): + naxes = len(self.data.shape) + pixel_indices = [np.array([0], dtype=int).reshape([1] * naxes).squeeze()] * naxes + for pixel_axis in pixel_axes: + len_axis = self.data.shape[::-1][pixel_axis] + # Define limits of desired pixel range based on whether corners or centers are desired + lims = (-0.5, len_axis + 1) if pixel_corners else (0, len_axis) + pix_ind = np.arange(lims[0], lims[1]) + shape = [1] * naxes + shape[pixel_axis] = len(pix_ind) + pixel_indices[pixel_axis] = pix_ind.reshape(shape) + world_coords = wcs.pixel_to_world_values(*pixel_indices) + # TODO: Remove NaNs??? These should not be here + if np.isnan(world_coords).any(): + if isinstance(world_coords, tuple| list): + world_coords = [world_coord[~np.isnan(world_coord)] for world_coord in world_coords] + else: + world_coords = world_coords[~np.isnan(world_coords)] + if units: + mod = abs(wcs.world_n_dim - naxes) if wcs.world_n_dim > naxes else 0 + world_coords = world_coords << u.Unit(wcs.world_axis_units[np.squeeze(pixel_axes)+mod]) + return world_coords + + def _generate_dependent_world_coords(self, pixel_corners, wcs, pixel_axes, units): + # Create a meshgrid of all pixel coordinates. + # If the user wants pixel corners, set pixel values to pixel corners. # Else make pixel centers. pixel_shape = self.data.shape[::-1] if pixel_corners: @@ -453,21 +479,13 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) ranges = [np.arange(i) - 0.5 for i in pixel_shape] else: ranges = [np.arange(i) for i in pixel_shape] - - # Limit the pixel dimensions to the ones present in the ExtraCoords - if isinstance(wcs, ExtraCoords): - ranges = [ranges[i] for i in wcs.mapping] - wcs = wcs.wcs - if wcs is None: - return [] - # This value of zero will be returned as a throwaway for unneeded axes, and a numerical value is # required so values_to_high_level_objects in the calling function doesn't crash or warn world_coords = [0] * wcs.world_n_dim for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix): - if (needed_axes is not None - and len(needed_axes) - and not any(world_axis in needed_axes for world_axis in world_axes_indices)): + if (pixel_axes is not None + and len(pixel_axes) + and not any(world_axis in pixel_axes for world_axis in world_axes_indices)): # needed_axes indicates which values in world_coords will be used by the calling # function, so skip this iteration if we won't be producing any of those values continue @@ -492,47 +510,67 @@ def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units) array_slice[wcs.axis_correlation_matrix[idx]] = slice(None) tmp_world = world[idx][tuple(array_slice)].T world_coords[idx] = tmp_world - if units: for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)): world_coords[i] = coord << u.Unit(unit) + return world_coords + + def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes=None, units=None): + # TODO: Workout why I need this twice now. + if isinstance(wcs, ExtraCoords): + wcs = wcs.wcs + if not wcs: + return () + if needed_axes is None or len(needed_axes) == 0: + needed_axes = np.array(list(range(wcs.world_n_dim)),dtype=int) + axes_are_independent = [] + pixel_axes = set() + for world_axis in needed_axes: + pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix) + axes_are_independent.append(len(pix_ax) == 1) + pixel_axes = pixel_axes.union(set(pix_ax)) + if len(pixel_axes) == 1: + pixel_axes = list(pixel_axes) + if all(axes_are_independent) and len(pixel_axes) == len(needed_axes): + world_coords = self._generate_independent_world_coords(pixel_corners, wcs, pixel_axes, units) + else: + world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, pixel_axes, units) + if len(world_coords) > 1 and isinstance(world_coords, tuple | list): + world_coords = [np.squeeze(world_coord) for world_coord in world_coords] + else: + world_coords = np.squeeze(world_coords) return world_coords + @utils.cube.sanitize_wcs def axis_world_coords(self, *axes, pixel_corners=False, wcs=None): # Docstring in NDCubeABC. if isinstance(wcs, BaseHighLevelWCS): wcs = wcs.low_level_wcs - orig_wcs = wcs if isinstance(wcs, ExtraCoords): wcs = wcs.wcs if not wcs: return () - object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components]) unique_obj_names = utils.misc.unique_sorted(object_names) world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names] - # Create a mapping from world index in the WCS to object index in axes_coords world_index_to_object_index = {} for object_index, world_axes in enumerate(world_axes_for_obj): for world_index in world_axes: world_index_to_object_index[world_index] = object_index - world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) object_indices = utils.misc.unique_sorted( [world_index_to_object_index[world_index] for world_index in world_indices] ) - - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=False) - + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=False) + if not isinstance(axes_coords, list): + axes_coords = [axes_coords] axes_coords = values_to_high_level_objects(*axes_coords, low_level_wcs=wcs) - if not axes: return tuple(axes_coords) - return tuple(axes_coords[i] for i in object_indices) @utils.cube.sanitize_wcs @@ -540,23 +578,19 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None): # Docstring in NDCubeABC. if isinstance(wcs, BaseHighLevelWCS): wcs = wcs.low_level_wcs - orig_wcs = wcs if isinstance(wcs, ExtraCoords): wcs = wcs.wcs - + if not wcs: + return () world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes) - - axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, world_indices, units=True) - + axes_coords = self._generate_world_coords(pixel_corners, orig_wcs, needed_axes=world_indices, units=True) world_axis_physical_types = wcs.world_axis_physical_types - # If user has supplied axes, extract only the # world coords that correspond to those axes. if axes: axes_coords = [axes_coords[i] for i in world_indices] world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices]) - # Return in array order. # First replace characters in physical types forbidden for namedtuple identifiers. identifiers = [] @@ -566,7 +600,8 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None): identifier = identifier.replace("-", "__") identifiers.append(identifier) CoordValues = namedtuple("CoordValues", identifiers) - return CoordValues(*axes_coords[::-1]) + flag = len(axes_coords) == 1 or isinstance(axes_coords, tuple | list) + return CoordValues(*axes_coords[::-1]) if flag else CoordValues(axes_coords) def crop(self, *points, wcs=None, keepdims=False): # The docstring is defined in NDCubeABC diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index e1fab7d05..0a2d05ec8 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -12,6 +12,7 @@ from astropy.coordinates import SkyCoord, SpectralCoord from astropy.io import fits from astropy.nddata import UnknownUncertainty +from astropy.tests.helper import assert_quantity_allclose from astropy.time import Time from astropy.units import UnitsError from astropy.wcs import WCS @@ -235,13 +236,31 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l): assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) +def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time): + # This replicates a specific NDCube object in visualization.rst + coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) + assert len(coords) == 1 + assert isinstance(coords[0], Time) + assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000'])) + + coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs) + assert len(coords) == 1 + assert isinstance(coords.time, u.Quantity) + assert_quantity_allclose(coords.time, [0, 60, 120] * u.second) + + @pytest.mark.parametrize("axes", [[-1], [2], ["em"]]) def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): + + # We go from 4 pixels to 6 pixels when we add pixel corners + coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True) - assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) + assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True) - assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m) + assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m) @pytest.mark.parametrize(("ndc", "item"), @@ -252,10 +271,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l): indirect=("ndc",)) def test_axis_world_coords_sliced_all_3d(ndc, item): coords = ndc[item].axis_world_coords_values() - assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) coords = ndc[item].axis_world_coords() - assert u.allclose(coords, [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) + assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m) @pytest.mark.parametrize(("ndc", "item"),