Skip to content

Commit

Permalink
First attempt at reverse projections.
Browse files Browse the repository at this point in the history
  • Loading branch information
erykoff committed Oct 19, 2024
1 parent bea2c8b commit d28adcf
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 61 deletions.
31 changes: 19 additions & 12 deletions skyproj/_skyproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, LogNorm

from .skycrs import get_crs, GnomonicCRS, proj, proj_inverse
from .skycrs import get_crs, GnomonicCRS, PlateCarreeCRS, proj, proj_inverse
from .hpx_utils import healpix_pixels_range, hspmap_to_xy, hpxmap_to_xy, healpix_to_xy, healpix_bin
from .utils import wrap_values, _get_boundary_poly_xy, get_autoscale_vmin_vmax

Expand Down Expand Up @@ -101,11 +101,14 @@ def __init__(
lon_0 = 179.9999

kwargs['lon_0'] = lon_0
crs = get_crs(projection_name, **kwargs)
crs = get_crs(projection_name, celestial=celestial, **kwargs)
# crs = get_crs(projection_name, **kwargs)
self._ax = fig.add_subplot(subspec, projection=crs)
self._crs_orig = crs
self._reprojected = False

self._plate_carree = PlateCarreeCRS(celestial=celestial)

self._celestial = celestial
self._gridlines = gridlines
self._autorescale = autorescale
Expand Down Expand Up @@ -167,7 +170,7 @@ def proj(self, lon, lat):
y : `np.ndarray`
Array of y values.
"""
return proj(lon, lat, projection=self.crs, pole_clip=self._pole_clip)
return proj(lon, lat, projection=self.crs, plate_carree=self._plate_carree, pole_clip=self._pole_clip)

def proj_inverse(self, x, y):
"""Apply inverse projection to a set of points.
Expand All @@ -188,7 +191,7 @@ def proj_inverse(self, x, y):
lat : `np.ndarray`
Array of latitudes (degrees).
"""
return proj_inverse(x, y, self.crs)
return proj_inverse(x, y, projection=self.crs, plate_carree=self._plate_carree)

def _initialize_axes(self, extent, extent_xy=None):
"""Initialize the axes with a given extent.
Expand All @@ -205,15 +208,16 @@ def _initialize_axes(self, extent, extent_xy=None):
"""
self._set_axes_limits(extent, extent_xy=extent_xy, invert=False)
self._create_axes(extent)
self._set_axes_limits(extent, extent_xy=extent_xy, invert=self._celestial)
# self._set_axes_limits(extent, extent_xy=extent_xy, invert=self._celestial)
# Necessary?
# self._set_axes_limits(extent, extent_xy=extent_xy, invert=False)

self._ax.set_frame_on(False)
if self._gridlines:
self._ax.grid(visible=True, linestyle=':', color='k', lw=0.5,
n_grid_lon=self._n_grid_lon, n_grid_lat=self._n_grid_lat,
longitude_ticks=self._longitude_ticks,
equatorial_labels=self._equatorial_labels,
celestial=self._celestial,
full_circle=self._full_circle,
wrap=self._wrap,
min_lon_ticklabel_delta=self._min_lon_ticklabel_delta,
Expand All @@ -226,15 +230,12 @@ def _initialize_axes(self, extent, extent_xy=None):
def set_extent(self, extent):
"""Set the extent.
Axes will be properly inverted if Skyproj was initialized with
``celestial=True``.
Parameters
----------
extent : array-like
Extent as [lon_min, lon_max, lat_min, lat_max].
"""
self._set_axes_limits(extent, invert=self._celestial)
self._set_axes_limits(extent)
self._extent_xy = self._ax.get_extent(lonlat=False)

self._draw_bounds()
Expand All @@ -249,7 +250,13 @@ def _draw_bounds(self):

extent_xy = self._ax.get_extent(lonlat=False)
bounds_xy = self._compute_proj_boundary_xy()
bounds_xy_clipped = _get_boundary_poly_xy(bounds_xy, extent_xy, self.proj, self.proj_inverse)
bounds_xy_clipped = _get_boundary_poly_xy(
bounds_xy,
extent_xy,
self.proj,
self.proj_inverse,
self._celestial,
)

self._boundary_lines = self._ax.plot(bounds_xy_clipped[:, 0],
bounds_xy_clipped[:, 1],
Expand Down Expand Up @@ -278,7 +285,7 @@ def set_autorescale(self, autorescale):
"""
self._autorescale = autorescale

def _set_axes_limits(self, extent, extent_xy=None, invert=True):
def _set_axes_limits(self, extent, extent_xy=None, invert=False):
"""Set axis limits from an extent.
Parameters
Expand Down
5 changes: 2 additions & 3 deletions skyproj/skyaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SkyAxes(matplotlib.axes.Axes):
def __init__(self, *args, **kwargs):
self.projection = kwargs.pop("sky_crs")

self.plate_carree = PlateCarreeCRS()
self.plate_carree = PlateCarreeCRS(celestial=self.projection.celestial)

# Would like to fix this up.
self.gridlines = SkyGridlines([])
Expand Down Expand Up @@ -110,7 +110,7 @@ def clear(self):

def grid(self, visible=False, which="major", axis="both",
n_grid_lon=None, n_grid_lat=None,
longitude_ticks="positive", equatorial_labels=False, celestial=True,
longitude_ticks="positive", equatorial_labels=False,
full_circle=False, wrap=0.0, min_lon_ticklabel_delta=0.1,
draw_inner_lon_labels=False,
**kwargs):
Expand All @@ -127,7 +127,6 @@ def grid(self, visible=False, which="major", axis="both",
n_grid_lon_default=n_grid_lon,
n_grid_lat_default=n_grid_lat,
longitude_ticks=longitude_ticks,
celestial=celestial,
equatorial_labels=equatorial_labels,
full_circle=full_circle,
min_lon_ticklabel_delta=min_lon_ticklabel_delta,
Expand Down
74 changes: 46 additions & 28 deletions skyproj/skycrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


RADIUS = 1.0
CELESTIAL = True


class SkyCRS(CRS):
Expand All @@ -30,11 +31,14 @@ class SkyCRS(CRS):
Name of projection CRS type.
radius : `float`, optional
Radius of projected sphere.
celestial : `bool`, optional
Is this a celestial (inverted) projection?
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name=None, radius=RADIUS, **kwargs):
def __init__(self, name=None, radius=RADIUS, celestial=CELESTIAL, **kwargs):
self._name = name
self._celestial = celestial
self.proj4_params = {'ellps': 'sphere',
'R': radius}
self.proj4_params.update(**kwargs)
Expand All @@ -60,7 +64,11 @@ def with_new_center(self, lon_0, lat_0=None):
if lat_0 is not None:
proj4_params['lat_0'] = lat_0

return self.__class__(**proj4_params)
return self.__class__(**proj4_params, celestial=self._celestial)

@property
def celestial(self):
return self._celestial

def transform_points(self, src_crs, x, y):
"""Transform points from a source coordinate reference system (CRS)
Expand All @@ -86,6 +94,8 @@ def transform_points(self, src_crs, x, y):
x = x.ravel()
y = y.ravel()

sign = -1 if self._celestial else 1

npts = x.shape[0]

result = np.zeros([npts, 2], dtype=np.float64)
Expand All @@ -96,10 +106,10 @@ def transform_points(self, src_crs, x, y):
try:
transformer = Transformer.from_crs(src_crs, self, always_xy=True)
if len(x) == 1:
_x = x[0]
_x = sign * x[0]
_y = y[0]
else:
_x = x
_x = sign * x
_y = y
result[:, 0], result[:, 1] = transformer.transform(_x, _y, None, errcheck=False)
except ProjError as err:
Expand Down Expand Up @@ -185,14 +195,14 @@ class PlateCarreeCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='cyl', lon_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='cyl', lon_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'eqc',
'lon_0': lon_0,
'to_meter': math.radians(1)*radius,
'vto_meter': 1}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class McBrydeThomasFlatPolarQuarticCRS(SkyCRS):
Expand All @@ -209,12 +219,12 @@ class McBrydeThomasFlatPolarQuarticCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='mbtfpq', lon_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='mbtfpq', lon_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'mbtfpq',
'lon_0': lon_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class MollweideCRS(SkyCRS):
Expand All @@ -231,12 +241,12 @@ class MollweideCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='moll', lon_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='moll', lon_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'moll',
'lon_0': lon_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class ObliqueMollweideCRS(SkyCRS):
Expand All @@ -257,15 +267,15 @@ class ObliqueMollweideCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='obmoll', lon_0=0.0, lat_p=90.0, lon_p=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='obmoll', lon_0=0.0, lat_p=90.0, lon_p=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'ob_tran',
'o_proj': 'moll',
'o_lat_p': lat_p,
'o_lon_p': lon_p,
'lon_0': lon_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)

@property
def lon_0(self):
Expand All @@ -286,12 +296,12 @@ class HammerCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='hammer', lon_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='hammer', lon_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'hammer',
'lon_0': lon_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class EqualEarthCRS(SkyCRS):
Expand All @@ -308,12 +318,12 @@ class EqualEarthCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='eqearth', lon_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='eqearth', lon_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'eqearth',
'lon_0': lon_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class LambertAzimuthalEqualAreaCRS(SkyCRS):
Expand All @@ -332,13 +342,13 @@ class LambertAzimuthalEqualAreaCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='laea', lon_0=0.0, lat_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='laea', lon_0=0.0, lat_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'laea',
'lon_0': lon_0,
'lat_0': lat_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class GnomonicCRS(SkyCRS):
Expand All @@ -357,13 +367,13 @@ class GnomonicCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='gnom', lon_0=0.0, lat_0=0.0, radius=RADIUS, **kwargs):
def __init__(self, name='gnom', lon_0=0.0, lat_0=0.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'gnom',
'lon_0': lon_0,
'lat_0': lat_0}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


class AlbersEqualAreaCRS(SkyCRS):
Expand All @@ -384,14 +394,14 @@ class AlbersEqualAreaCRS(SkyCRS):
**kwargs : `dict`, optional
Additional kwargs for PROJ4 parameters.
"""
def __init__(self, name='aea', lon_0=0.0, lat_1=15.0, lat_2=45.0, radius=RADIUS, **kwargs):
def __init__(self, name='aea', lon_0=0.0, lat_1=15.0, lat_2=45.0, radius=RADIUS, celestial=CELESTIAL, **kwargs):
proj4_params = {'proj': 'aea',
'lon_0': lon_0,
'lat_1': lat_1,
'lat_2': lat_2}
proj4_params = {**proj4_params, **kwargs}

super().__init__(name=name, radius=radius, **proj4_params)
super().__init__(name=name, radius=radius, celestial=celestial, **proj4_params)


_crss = {
Expand All @@ -407,7 +417,7 @@ def __init__(self, name='aea', lon_0=0.0, lat_1=15.0, lat_2=45.0, radius=RADIUS,
}


def get_crs(name, **kwargs):
def get_crs(name, celestial=CELESTIAL, **kwargs):
"""Return a skyproj CRS.
For list of projections available, use skyproj.get_available_crs().
Expand All @@ -416,6 +426,8 @@ def get_crs(name, **kwargs):
----------
name : `str`
Skyproj name of projection CRS.
celestial : `bool`
Is this a celestial (inverted) projection?
**kwargs :
Additional kwargs appropriate for given projection CRS.
Expand All @@ -429,7 +441,7 @@ def get_crs(name, **kwargs):

descr, crsclass = _crss[name]

return crsclass(name=name, **kwargs)
return crsclass(name=name, celestial=celestial, **kwargs)


def get_available_crs():
Expand All @@ -447,29 +459,35 @@ def get_available_crs():
return available_crs


def proj(lon, lat, projection=None, pole_clip=None, wrap=None):
def proj(lon, lat, projection=None, plate_carree=None, pole_clip=None, wrap=None):
if projection is None:
raise RuntimeError("Must specify a projection.")

if plate_carree is None:
plate_carree = PlateCarreeCRS(celestial=projection.celestial)

lon = np.atleast_1d(lon)
lat = np.atleast_1d(lat)
if pole_clip is not None:
out = ((lat < (-90.0 + pole_clip))
| (lat > (90.0 - pole_clip)))
if wrap is not None:
lon[np.isclose(lon, wrap)] = wrap - 1e-10
proj_xy = projection.transform_points(PlateCarreeCRS(), lon, lat)
proj_xy = projection.transform_points(plate_carree, lon, lat)
if pole_clip is not None:
proj_xy[..., 1][out] = np.nan

return proj_xy[..., 0], proj_xy[..., 1]


def proj_inverse(x, y, projection=None):
def proj_inverse(x, y, projection=None, plate_carree=None):
if projection is None:
raise RuntimeError("Must specify a projection.")

if plate_carree is None:
plate_carree = PlateCarreeCRS(celestial=projection.celestial)

x = np.atleast_1d(x)
y = np.atleast_1d(y)
proj_lonlat = PlateCarreeCRS().transform_points(projection, x, y)
proj_lonlat = plate_carree.transform_points(projection, x, y)
return proj_lonlat[..., 0], proj_lonlat[..., 1]
Loading

0 comments on commit d28adcf

Please sign in to comment.