Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster shortcut for working out coordinates values for non-correlated WCS #780

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/780.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Better changelog.

42 changes: 42 additions & 0 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ def wcs_3d_lt_ln_l():
return WCS(header=header)


@pytest.fixture
def wcs_3d_wave_lt_ln():
header = {
'CTYPE1': 'WAVE ',
'CUNIT1': 'Angstrom',
'CDELT1': 0.2,
'CRPIX1': 0,
'CRVAL1': 10,

'CTYPE2': 'HPLT-TAN',
'CUNIT2': 'deg',
'CDELT2': 0.5,
'CRPIX2': 2,
'CRVAL2': 0.5,

'CTYPE3': 'HPLN-TAN ',
'CUNIT3': 'deg',
'CDELT3': 0.4,
'CRPIX3': 2,
'CRVAL3': 1,
}
return WCS(header=header)


@pytest.fixture
def wcs_2d_lt_ln():
spatial = {
Expand Down Expand Up @@ -445,6 +469,24 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
return cube


@pytest.fixture
def ndcube_3d_wave_lt_ln_ec_time(wcs_3d_wave_lt_ln):
shape = (3, 4, 5)
wcs_3d_wave_lt_ln.array_shape = shape
data = data_nd(shape)
mask = data > 0
cube = NDCube(
data,
wcs_3d_wave_lt_ln,
mask=mask,
uncertainty=data,
)
base_time = Time('2000-01-01', format='fits', scale='utc')
timestamps = Time([base_time + TimeDelta(60 * i, format='sec') for i in range(data.shape[0])])
cube.extra_coords.add('time', 0, timestamps)
return cube


@pytest.fixture
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],
Expand Down
23 changes: 21 additions & 2 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,25 @@
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
# Create meshgrid of all pixel coordinates.
# If user wants pixel_corners, set pixel values to pixel corners.
# This will generate only the coordinates that are needed if there is no correlation within the WCS
# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably shouldn't be a limitation but right now I can't work out a good idea.

Copy link
Member

@DanRyanIrish DanRyanIrish Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on what the difficulty is here? This is probably related to the comment below

Copy link
Contributor Author

@nabobalis nabobalis Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too stupid to work out the line of code to overcome this limitation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Put another way, why does enforcing this limitation make the problem simpler? If you just remove this condition, so long as np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1, does it still work as intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the one case in the test suite, no this is not enough.

(0, self.data.shape[::-1][needed_axes[0]]) errors with *** IndexError: tuple index out of range as the "self.data.shape" is length 3 but the needed_axis is "3".

(Pdb++) needed_axes
array([3])
(Pdb++) self.data.shape
(3, 4, 5)

There is no direct translation between the needed world coordinate and the corresponding pixel coordinate in this block.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this util function solve that problem?

if needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1:
# Account for non-pixel axes affecting the value of needed_axes
Copy link
Member

@DanRyanIrish DanRyanIrish Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you better explain this comment. Are you trying to convert world axes to pixel axes here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the passed in needed_axes is given in world indices but the loop below works on the data shape, so I made a random guess at modifying the needed axes to give a variable which does not error.

It isn't correct

# Only works for one axis
if np.max(wcs.axis_correlation_matrix[needed_axes][0].shape) == needed_axes[0]:
needed_axis = needed_axes[0] - 1
else:
needed_axis = needed_axes[0]
lims = (-0.5, self.data.shape[::-1][needed_axis] + 1) if pixel_corners else (0, self.data.shape[::-1][needed_axis])
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axis]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axis])
return world_coords

# Create a meshgrid of all pixel coordinates.
# If the user wants pixel corners, set pixel values to pixel corners.
# Else make pixel centers.
pixel_shape = self.data.shape[::-1]
if pixel_corners:
Expand Down Expand Up @@ -544,6 +561,8 @@
orig_wcs = wcs
if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs
if not wcs:
return ()

Check warning on line 565 in ndcube/ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/ndcube.py#L565

Added line #L565 was not covered by tests
Comment on lines +564 to +565
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was missing and is included in the other version of this method, so I added it.


world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)

Expand Down
23 changes: 21 additions & 2 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,32 @@
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)


def test_axis_world_coords_combined_wcs(ndcube_3d_wave_lt_ln_ec_time):
# This replicates a specific NDCube test in the visualization.rst
Copy link
Contributor Author

@nabobalis nabobalis Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So my understanding of this and I don't know if this is correct.

This NDCube has a 3D WCS, which is Wavelength, Space, Space and has time as an extra coord.

We pass to axis_world_coords_values the combined WCS which has bolted on time at the end and then asking the code to do a WCS look up on that.

<ndcube.wcs.wrappers.compound_wcs.CompoundLowLevelWCS object at 0x7a606abf92b0>
CompoundLowLevelWCS Transformation

This transformation has 3 pixel and 4 world dimensions

Array shape (Numpy order): None

Pixel Dim  Axis Name  Data size  Bounds
        0  None            None  None
        1  None            None  None
        2  None            None  None

World Dim  Axis Name  Physical Type                   Units
        0  None       em.wl                           m
        1  None       custom:pos.helioprojective.lat  deg
        2  None       custom:pos.helioprojective.lon  deg
        3  time       time                            s

Correlation between pixel and world axes:

             Pixel Dim
World Dim    0    1    2
        0  yes   no   no
        1   no  yes  yes
        2   no  yes  yes
        3   no   no  yes

Time here while in the WCS is not a pixel dimension. So I think that

(Pdb++) wcs.pixel_to_world_values(0,0,2,0)
(array(1.02e-09), array(1.26915033e-05), array(1.39997827), np.float64(119.99999999999957))

gives me the time axis but its number 3 in the list.

So I need a way to get the correct axis and work out how many coords there are so I can get it from pixel to world. But I don't see that information in the WCS.

This feels like it shouldn't hit this code path at all.

coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords[0], Time)
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))

# This fails and returns the wrong coords
Copy link
Contributor Author

@nabobalis nabobalis Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operation returns very very wrong values. I failed to work out why.

coords = ndcube_3d_wave_lt_ln_ec_time.axis_world_coords_values('time', wcs=ndcube_3d_wave_lt_ln_ec_time.combined_wcs)
assert len(coords) == 1
assert isinstance(coords.time, Time)
assert np.all(coords.time == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))

Check warning on line 249 in ndcube/tests/test_ndcube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/tests/test_ndcube.py#L249

Added line #L249 was not covered by tests


@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):

# We go from 4 pixels to 6 pixels when we add pixel corners
Copy link
Contributor Author

@nabobalis nabobalis Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this correct at all. Going to lean towards no.

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=False)
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
Copy link
Contributor Author

@nabobalis nabobalis Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This grew another dim but I do not know why.


coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)


@pytest.mark.parametrize(("ndc", "item"),
Expand Down
Loading