
Faster shortcut for working out coordinates values for non-correlated WCS #780

Open
wants to merge 8 commits into
base: main

Contributor

@nabobalis nabobalis commented Nov 8, 2024

Fixes #585

In my case, I use a DKIST compound model which has two coupled pixel axes and a time axis.
If I want just the time axis, the current code works out the entire grid, leading to a very dense and memory-intensive result.

So I added a really specific hack which works for my code and does not seem to break the rest of the test suite.

TODO:

Unit tests:

  • EDGE CASES - THERE SHOULD BE SOME?

@nabobalis nabobalis added this to the 2.2.4 milestone Nov 8, 2024
@nabobalis nabobalis added the backport 2.2 on-merge: backport to 2.2 label Nov 8, 2024
@nabobalis nabobalis modified the milestones: 2.2.4, 2.2.5 Nov 8, 2024
# If user wants pixel_corners, set pixel values to pixel corners.
# This will generate only the coordinates that are needed if there is no correlation within the WCS
# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
Contributor Author

This probably shouldn't be a limitation but right now I can't work out a good idea.

Member

@DanRyanIrish DanRyanIrish Nov 21, 2024

Can you elaborate on what the difficulty is here? This is probably related to the comment below

Contributor Author

@nabobalis nabobalis Nov 21, 2024

Too stupid to work out the line of code to overcome this limitation.

Member

OK. Put another way, why does enforcing this limitation make the problem simpler? If you just remove this condition, so long as np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1, does it still work as intended?

Contributor Author

In the one case in the test suite, no this is not enough.

(0, self.data.shape[::-1][needed_axes[0]]) raises *** IndexError: tuple index out of range, because self.data.shape has length 3 but needed_axes[0] is 3.

(Pdb++) needed_axes
array([3])
(Pdb++) self.data.shape
(3, 4, 5)

There is no direct translation between the needed world coordinate and the corresponding pixel coordinate in this block.
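The failure above can be reproduced without a WCS at all. This is only a minimal sketch that reuses the shape and index printed in the pdb session; it shows that a world-axis index cannot be used directly as an index into the (reversed) data shape when there are more world axes than pixel axes.

```python
import numpy as np

data_shape = (3, 4, 5)        # self.data.shape from the pdb session (NumPy order)
needed_axes = np.array([3])   # world-axis index from the pdb session

pixel_shape = data_shape[::-1]  # pixel (WCS) order: (5, 4, 3)

try:
    pixel_shape[needed_axes[0]]
    error = None
except IndexError as exc:
    # World index 3 has no counterpart in a 3-element pixel shape.
    error = exc

print(type(error).__name__, "-", error)
```

The world index first has to be translated to a pixel index (for example through the axis correlation matrix) before the shape lookup is meaningful.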

Member

Would this util function solve that problem?

@nabobalis nabobalis changed the title from "My finest work" to "Faster Shortcut for working out coordinates values for non-correlated WCS" Nov 11, 2024
@nabobalis nabobalis changed the title from "Faster Shortcut for working out coordinates values for non-correlated WCS" to "Faster shortcut for working out coordinates values for non-correlated WCS" Nov 11, 2024
ndcube/ndcube.py Outdated
# This will generate only the coordinates that are needed if there is no correlation within the WCS
# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1 and len(self.data.shape) == wcs.world_n_dim:
Member

@DanRyanIrish DanRyanIrish Nov 21, 2024

For future developers (and reviewer), could you add comments stating why each of these conditions are required? This may clarify some questions I have, such as:

  • why does this only work when not pixel_corners?
  • Does this only work when there is only one needed_axis?
    • If not, what is the significance of the magic 0 index below, e.g. wcs.axis_correlation_matrix[needed_axes][0]]

Contributor Author

For future developers (and reviewer), could you add comments stating why each of these conditions are required? This may clarify some questions I have, such as:

Fair

  • why does this only work when not pixel_corners?

Too stupid to work out how to fix it.

  • Does this only work when there is only one needed_axis?

Too stupid to work out how to fix it.

  • If not, what is the significance of the magic 0 index below, e.g. wcs.axis_correlation_matrix[needed_axes][0]]

I think wcs.axis_correlation_matrix[needed_axes] returns a list that has to be escaped hence the [0]. But I should check.

Member

  • why does this only work when not pixel_corners?

Too stupid to work out how to fix it.

See suggested change below for initial guess at how to solve this. Can you write a test to check out if this in fact does fix it?

  • Does this only work when there is only one needed_axis?

Too stupid to work out how to fix it.

Would this work if you made the unwanted dimensions for indices have shape (1) * n where n is the number of dimensions. And the indices in the wanted dimensions can have the np.arange output. Since the dimensions are independent, this should result in broadcastable inputs to wcs.pixel_to_world_values, and so you should be able to get coords out for multiple independent axes, without blowing up the memory. At least, I think this is true.

  • If not, what is the significance of the magic 0 index below, e.g. wcs.axis_correlation_matrix[needed_axes][0]]

I think wcs.axis_correlation_matrix[needed_axes] returns a list that has to be escaped hence the [0]. But I should check.

Is this list length 1 because you enforced this condition above: np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1?
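The broadcasting idea described above could look roughly like the sketch below. It is not ndcube code: `toy_pixel_to_world`, the scale factors, the axis lengths, and the choice of wanted axes are all hypothetical stand-ins, used only to show that index arrays of shape (1, ..., n, ..., 1) stay small instead of forming a dense grid.

```python
import numpy as np


def toy_pixel_to_world(*pixel_arrays):
    """Hypothetical stand-in for wcs.pixel_to_world_values with fully independent axes."""
    scales = (0.02, 0.5, 60.0)  # made-up linear scale per axis
    return tuple(np.asarray(p) * s for p, s in zip(pixel_arrays, scales))


pixel_shape = (5, 4, 3)   # hypothetical axis lengths, in pixel (WCS) order
wanted = (0, 2)           # pixel axes whose coordinates are wanted
ndim = len(pixel_shape)

indices = []
for axis, length in enumerate(pixel_shape):
    shape = [1] * ndim
    if axis in wanted:
        # Wanted axis: full range of indices along its own dimension only.
        shape[axis] = length
        indices.append(np.arange(length).reshape(shape))
    else:
        # Unwanted axis: a single zero that broadcasts against everything else.
        indices.append(np.zeros(shape, dtype=int))

world = toy_pixel_to_world(*indices)
print([w.shape for w in world])
print(np.broadcast_shapes(*(i.shape for i in indices)))
```

A real WCS may broadcast its inputs against each other, so each output could come back with the broadcast shape (here (5, 1, 3)) rather than the per-axis shapes the toy function returns. Either way the unwanted axis keeps length 1, so the full (5, 4, 3) grid is never built.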

Contributor Author

  • why does this only work when not pixel_corners?

Too stupid to work out how to fix it.

See suggested change below for initial guess at how to solve this. Can you write a test to check out if this in fact does fix it?

I added your change and we have a few unit tests that already cover this.
They fail but they are close.

  • Does this only work when there is only one needed_axis?

Too stupid to work out how to fix it.

Would this work if you made the unwanted dimensions for indices have shape (1) * n where n is the number of dimensions. And the indices in the wanted dimensions can have the np.arange output. Since the dimensions are independent, this should result in broadcastable inputs to wcs.pixel_to_world_values, and so you should be able to get coords out for multiple independent axes, without blowing up the memory. At least, I think this is true.

Yeah that would work I just don't know how to code that.

  • If not, what is the significance of the magic 0 index below, e.g. wcs.axis_correlation_matrix[needed_axes][0]]

I think wcs.axis_correlation_matrix[needed_axes] returns a list that has to be escaped hence the [0]. But I should check.

Is this list length 1 because you enforced this condition above: np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1?

I was wrong, the reason for the [0] is that:

(Pdb++) wcs.axis_correlation_matrix[needed_axes]
array([[False, False,  True]])

This returns a nested array, so I need to escape that.
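The nested result is standard NumPy behaviour: indexing with an integer array keeps the indexed dimension, while indexing with a scalar drops it. A small sketch, using a made-up 3x3 correlation matrix whose last row matches the pdb output above:

```python
import numpy as np

# Hypothetical correlation matrix: rows are world axes, columns are pixel axes.
axis_correlation_matrix = np.array([
    [True, False, False],
    [False, True, False],
    [False, False, True],
])

needed_axes = np.array([2])

rows = axis_correlation_matrix[needed_axes]     # index array: row dimension is kept
row = axis_correlation_matrix[needed_axes[0]]   # scalar index: row dimension is dropped

print(rows.shape)
print(row.shape)
```

So the trailing [0] only unwraps the length-1 leading dimension; it would silently ignore any further rows if needed_axes ever held more than one axis.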

ndcube/ndcube.py Outdated
# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
if not pixel_corners and needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1 and len(self.data.shape) == wcs.world_n_dim:
indices = [np.arange(self.data.shape[::-1][needed_axes[0]]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]]
Member

I think something like below is what you need to make this valid for pixel_corners as well.

Suggested change
- indices = [np.arange(self.data.shape[::-1][needed_axes[0]]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]]
+ lims = (-0.5, self.data.shape[::-1][needed_axes[0]] + 1) if pixel_corners else (0, self.data.shape[::-1][needed_axes[0]])
+ indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[needed_axes][0]]

Contributor Author

@nabobalis nabobalis Nov 22, 2024

With this, I think it works perfectly thank you. I just have to account for a test change:

-> assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
(Pdb++) coords
CoordValues(em_wl=<Quantity [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] m>)
(Pdb++) coords[0]
<Quantity [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] m>
(Pdb++) 

I need to escape another dim which I didn't need to before?

There is also one more element now. I do not know why.
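One possible explanation for the extra element, sketched with plain NumPy: with an upper limit of length + 1, np.arange starting at -0.5 yields length + 2 values, whereas n pixels only have n + 1 corners. The axis length of 4 is an assumption taken from the four pixel-centre values in the existing test, and the "corrected" limit is a guess at the intent, not a confirmed fix.

```python
import numpy as np

n_pixels = 4  # assumed axis length (four pixel centres in the existing test)

centers = np.arange(0, n_pixels)              # 0, 1, 2, 3
suggested = np.arange(-0.5, n_pixels + 1)     # limits from the suggested change
edges = np.arange(-0.5, n_pixels + 0.5)       # assumed intent: one value per pixel corner

print(len(centers), len(suggested), len(edges))
print(suggested)
print(edges)
```

Under that assumption the corner coordinates would run from -0.5 to n_pixels - 0.5, which matches the five expected values in the original assertion.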

@@ -237,11 +237,16 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):

@pytest.mark.parametrize("axes", [[-1], [2], ["em"]])
def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):

# We go from 4 pixels to 6 pixels when we add pixel corners
Contributor Author

@nabobalis nabobalis Nov 22, 2024

Not sure if this is correct at all. Going to lean towards no.

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
- assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)
+ assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
Contributor Author

@nabobalis nabobalis Nov 22, 2024

This grew another dim but I do not know why.

assert isinstance(coords[0], Time)
assert np.all(coords[0] == Time(['2000-01-01T00:00:00.000', '2000-01-01T00:01:00.000', '2000-01-01T00:02:00.000']))

# This fails and returns the wrong coords
Contributor Author

@nabobalis nabobalis Nov 22, 2024

This operation returns very very wrong values. I failed to work out why.

@@ -0,0 +1 @@
Added an internal method to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.
Contributor Author

TODO: Better changelog.

Comment on lines +564 to +565
if not wcs:
    return ()
Contributor Author

This was missing and is included in the other version of this method, so I added it.

@@ -235,13 +235,32 @@ def test_axis_world_coords_single(axes, ndcube_3d_ln_lt_l):
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)


def test_axis_world_coords_crazy(ndcube_3d_wave_lt_ln_ec_time):
# This replicates a specific NDCube test in the visualization.rst
Contributor Author

@nabobalis nabobalis Nov 26, 2024

This is my understanding of this, and I don't know if it is correct.

This NDCube has a 3D WCS, which is Wavelength, Space, Space, and has time as an extra coord.

We pass axis_world_coords_values the combined WCS, which has time bolted on at the end, and then ask the code to do a WCS lookup on that.

<ndcube.wcs.wrappers.compound_wcs.CompoundLowLevelWCS object at 0x7a606abf92b0>
CompoundLowLevelWCS Transformation

This transformation has 3 pixel and 4 world dimensions

Array shape (Numpy order): None

Pixel Dim  Axis Name  Data size  Bounds
        0  None            None  None
        1  None            None  None
        2  None            None  None

World Dim  Axis Name  Physical Type                   Units
        0  None       em.wl                           m
        1  None       custom:pos.helioprojective.lat  deg
        2  None       custom:pos.helioprojective.lon  deg
        3  time       time                            s

Correlation between pixel and world axes:

             Pixel Dim
World Dim    0    1    2
        0  yes   no   no
        1   no  yes  yes
        2   no  yes  yes
        3   no   no  yes

Time here, while in the WCS, is not a pixel dimension. So I think that

(Pdb++) wcs.pixel_to_world_values(0,0,2,0)
(array(1.02e-09), array(1.26915033e-05), array(1.39997827), np.float64(119.99999999999957))

gives me the time axis, but it's number 3 in the list.

So I need a way to get the correct axis and work out how many coords there are so I can get it from pixel to world. But I don't see that information in the WCS.

This feels like it shouldn't hit this code path at all.
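For what it's worth, the correlation table printed above does seem to carry the missing information: the row for a world axis says which pixel axes feed it. A rough sketch using only NumPy follows; the matrix is copied from the WCS summary, while the helper name and the data shape are hypothetical.

```python
import numpy as np

# Copied from the WCS summary above: rows are world axes, columns are pixel axes.
axis_correlation_matrix = np.array([
    [True,  False, False],   # em.wl
    [False, True,  True],    # custom:pos.helioprojective.lat
    [False, True,  True],    # custom:pos.helioprojective.lon
    [False, False, True],    # time
])


def pixel_axes_for_world_axis(world_axis, matrix):
    """Hypothetical helper: pixel-axis indices (WCS order) correlated with one world axis."""
    return np.flatnonzero(matrix[world_axis])


data_shape = (3, 4, 5)  # hypothetical array shape, NumPy order

time_pixel_axes = pixel_axes_for_world_axis(3, axis_correlation_matrix)
# The number of time coordinates is the data length along that pixel axis.
n_coords = data_shape[::-1][time_pixel_axes[0]]

print(time_pixel_axes, n_coords)
```

With this lookup, world axis 3 (time) maps to pixel axis 2, so the shape lookup uses index 2 rather than 3 and no longer runs off the end of the tuple.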

Member

@DanRyanIrish DanRyanIrish left a comment

More review to follow.

# This bypasses the entire rest of the function below which works out the full set of coordinates
# This only works for WCS that have the same number of world and pixel dimensions
if needed_axes is not None and not isinstance(wcs, ExtraCoords) and np.sum(wcs.axis_correlation_matrix[needed_axes]) == 1:
# Account for non-pixel axes affecting the value of needed_axes
Member

@DanRyanIrish DanRyanIrish Nov 27, 2024

Can you better explain this comment. Are you trying to convert world axes to pixel axes here?

Contributor Author

So the passed in needed_axes is given in world indices but the loop below works on the data shape, so I made a random guess at modifying the needed axes to give a variable which does not error.

It isn't correct

Member

DanRyanIrish commented Nov 28, 2024

Hi @nabobalis. My reading of this PR is that it should be generalised to make it a bit simpler to implement and follow. If I understand correctly, you want to special-case the scenario where all wanted world axes (whether 1 or more) correspond to their own unique pixel axis. In this scenario, you don't have to work out all the coordinates for the entire (m, n) pixel grid, but instead just the 1-D dimensions corresponding to the desired world axes. This saves RAM and time. Am I right so far?

If this is the case, then I think it would help readability, to move your code to a new private method, e.g. _generate_independent_world_coords(). The pre-existing code could be moved to another private method, e.g. _generate_dependent_world_coords(), and then _generate_world_coords() could simply check whether the needed axes are independent, and then call the appropriate method above. Something like this:

def _generate_world_coords(self, pixel_corners, wcs, needed_axes=None, *, units):
    if needed_axes is None:
        needed_axes = list(range(wcs.world_n_dim))
    axes_are_independent = []
    # Use a set so pixel axes shared by several world axes are only counted once.
    pixel_axes = set()
    for world_axis in needed_axes:
        pix_ax = ndcube.utils.wcs.world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
        axes_are_independent.append(len(pix_ax) == 1)
        pixel_axes = pixel_axes.union(set(pix_ax))
    # Independent means: one pixel axis per world axis, and no pixel axis shared.
    if all(axes_are_independent) and len(pixel_axes) == len(needed_axes):
        return self._generate_independent_world_coords(pixel_corners, wcs, pixel_axes, units)
    else:
        return self._generate_dependent_world_coords(pixel_corners, wcs, pixel_axes, units)

This brings us to your new code, i.e. _generate_independent_world_coords(). Since the check whether the needed axes are all independent will be done in the new _generate_world_coords(), we don't have to repeat that. We need to check whether we want pixel centers or corners, then define the pixel grids, where unwanted axes are set to np.array([0], dtype=int).reshape([1] * ndim), where ndim is the number of dimensions in the data array. Wanted dimensions are then set to np.arange(n).astype(int).reshape((1, ..., n, ..., 1)), where n is the length of the relevant axis. So something like this:

def _generate_independent_world_coords(self, pixel_corners, wcs, pixel_axes, units):
    naxes = len(self.shape)
    pixel_indices = [np.array([0], dtype=int).reshape([1] * naxes)] * naxes
    for pixel_axis in pixel_axes:
        # Define limits of desired pixel range based on whether corners or centers are desired
        len_axis = self.data.shape[::-1][pixel_axis]
        # n pixels have n + 1 corners, running from -0.5 to n - 0.5
        lims = (-0.5, len_axis + 0.5) if pixel_corners else (0, len_axis)
        pix_ind = np.arange(lims[0], lims[1])
        shape = [1] * naxes
        shape[pixel_axis] = len(pix_ind)
        pix_ind = pix_ind.reshape(shape)
        pixel_indices[pixel_axis] = pix_ind
    world_coords = wcs.pixel_to_world_values(*pixel_indices)
    if units:
        # Attach units to each coord here.
        pass
    return world_coords

I presume there's still some work to actually make this code work, but hopefully it's a good start. Let me know if you have questions.

(And these private methods should probably have their own docstrings so it's easier for developers to understand them in the future.)
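The independence check itself can be tried out on its own, away from NDCube. The sketch below is a hypothetical standalone helper, not part of ndcube, that applies the same two conditions as the dispatch above (exactly one pixel axis per needed world axis, and no pixel axis shared) to a made-up correlation matrix.

```python
import numpy as np


def needed_axes_are_independent(matrix, needed_world_axes):
    """Hypothetical helper: True if every needed world axis depends on exactly one
    pixel axis and no two needed world axes depend on the same pixel axis."""
    matrix = np.asarray(matrix, dtype=bool)
    seen = set()
    for world_axis in needed_world_axes:
        pixel_axes = np.flatnonzero(matrix[world_axis])
        if len(pixel_axes) != 1 or int(pixel_axes[0]) in seen:
            return False
        seen.add(int(pixel_axes[0]))
    return True


# Made-up matrix: world axis 0 is separable, world axes 1 and 2 are coupled.
example_matrix = [
    [True,  False, False],
    [False, True,  True],
    [False, True,  True],
]

print(needed_axes_are_independent(example_matrix, [0]))     # separable axis only
print(needed_axes_are_independent(example_matrix, [0, 1]))  # includes a coupled axis
```

A check like this is cheap compared with building the coordinate grid, so running it on every call should not cost much, though that has not been measured here.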

Labels
backport 2.2 on-merge: backport to 2.2
Development

Successfully merging this pull request may close these issues.

_generate_world_coords is slow and uses a lot of memory
2 participants