Skip to content

Commit

Permalink
more types
Browse files Browse the repository at this point in the history
  • Loading branch information
martinfleis committed Jan 11, 2024
1 parent c521ac8 commit d61e175
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 77 deletions.
117 changes: 65 additions & 52 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

import warnings
from collections.abc import Hashable, Mapping, Sequence
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, cast

import numpy as np
import pandas as pd # type: ignore
import shapely # type: ignore
import pandas as pd
import shapely
import xarray as xr
from pyproj import CRS, Transformer

from .index import GeometryIndex
from .zonal import _zonal_stats_iterative, _zonal_stats_rasterize

if TYPE_CHECKING:
from geopandas import GeoDataFrame

Check warning on line 17 in xvec/accessor.py

View check run for this annotation

Codecov / codecov/patch

xvec/accessor.py#L17

Added line #L17 was not covered by tests


@xr.register_dataarray_accessor("xvec")
@xr.register_dataset_accessor("xvec")
Expand All @@ -22,7 +25,7 @@ class XvecAccessor:
Currently works on coordinates with :class:`xvec.GeometryIndex`.
"""

def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> None:
"""xvec init, nothing to be done here."""
self._obj = xarray_obj
self._geom_coords_all = [
Expand All @@ -36,7 +39,9 @@ def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
if self.is_geom_variable(name, has_index=True)
]

def is_geom_variable(self, name: Hashable, has_index: bool = True):
def is_geom_variable(
self, name: Hashable, has_index: bool = True
) -> bool | np.bool_:
"""Check if coordinate variable is composed of :class:`shapely.Geometry`.
Can return all such variables or only those using :class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -208,7 +213,7 @@ def to_crs(
self,
variable_crs: Mapping[Any, Any] | None = None,
**variable_crs_kwargs: Any,
):
) -> xr.DataArray | xr.Dataset:
"""
Transform :class:`shapely.Geometry` objects of a variable to a new coordinate
reference system.
Expand Down Expand Up @@ -313,20 +318,15 @@ def to_crs(
currently wraps :meth:`Dataset.assign_coords <xarray.Dataset.assign_coords>`
or :meth:`DataArray.assign_coords <xarray.DataArray.assign_coords>`.
"""
if variable_crs and variable_crs_kwargs:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
"'.xvec.to_crs'."
)
variable_crs_solved = _resolve_input(
variable_crs, variable_crs_kwargs, "to_crs"
)

_obj = self._obj.copy(deep=False)

if variable_crs_kwargs:
variable_crs = variable_crs_kwargs

transformed = {}

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if not isinstance(self._obj.xindexes[key], GeometryIndex):
raise ValueError(
f"The index '{key}' is not an xvec.GeometryIndex. "
Expand All @@ -335,7 +335,7 @@ def to_crs(
)

data = _obj[key]
data_crs = self._obj.xindexes[key].crs
data_crs = self._obj.xindexes[key].crs # type: ignore

# transformation code taken from geopandas (BSD 3-clause license)
if data_crs is None:
Expand Down Expand Up @@ -374,21 +374,21 @@ def to_crs(
for key, (result, _crs) in transformed.items():
_obj = _obj.assign_coords({key: result})

_obj = _obj.drop_indexes(variable_crs.keys())
_obj = _obj.drop_indexes(variable_crs_solved.keys())

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if crs:
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)

return _obj

def set_crs(
self,
variable_crs: Mapping[Any, Any] | None = None,
allow_override=False,
allow_override: bool = False,
**variable_crs_kwargs: Any,
):
) -> xr.DataArray | xr.Dataset:
"""Set the Coordinate Reference System (CRS) of coordinates backed by
:class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -480,27 +480,21 @@ def set_crs(
transform the geometries to a new CRS, use the :meth:`to_crs`
method.
"""

if variable_crs and variable_crs_kwargs:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
".xvec.set_crs."
)
variable_crs_solved = _resolve_input(
variable_crs, variable_crs_kwargs, "set_crs"
)

_obj = self._obj.copy(deep=False)

if variable_crs_kwargs:
variable_crs = variable_crs_kwargs

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if not isinstance(self._obj.xindexes[key], GeometryIndex):
raise ValueError(
f"The index '{key}' is not an xvec.GeometryIndex. "
"Set the xvec.GeometryIndex using '.xvec.set_geom_indexes' before "
"handling projection information."
)

data_crs = self._obj.xindexes[key].crs
data_crs = self._obj.xindexes[key].crs # type: ignore

if not allow_override and data_crs is not None and not data_crs == crs:
raise ValueError(
Expand All @@ -510,12 +504,12 @@ def set_crs(
"want to transform the geometries, use '.xvec.to_crs' instead."
)

_obj = _obj.drop_indexes(variable_crs.keys())
_obj = _obj.drop_indexes(variable_crs_solved.keys())

for key, crs in variable_crs.items():
for key, crs in variable_crs_solved.items():
if crs:
_obj[key].attrs["crs"] = CRS.from_user_input(crs)
_obj = _obj.set_xindex(key, GeometryIndex, crs=crs)
_obj = _obj.set_xindex([key], GeometryIndex, crs=crs)

return _obj

Expand All @@ -526,7 +520,7 @@ def query(
predicate: str | None = None,
distance: float | Sequence[float] | None = None,
unique: bool = False,
):
) -> xr.DataArray | xr.Dataset:
"""Return a subset of a DataArray/Dataset filtered using a spatial query on
:class:`~xvec.GeometryIndex`.
Expand Down Expand Up @@ -619,12 +613,12 @@ def query(
"""
if isinstance(geometry, shapely.Geometry):
ilocs = self._obj.xindexes[coord_name].sindex.query(
ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
geometry, predicate=predicate, distance=distance
)

else:
_, ilocs = self._obj.xindexes[coord_name].sindex.query(
_, ilocs = self._obj.xindexes[coord_name].sindex.query( # type: ignore
geometry, predicate=predicate, distance=distance
)
if unique:
Expand All @@ -637,8 +631,8 @@ def set_geom_indexes(
coord_names: str | Sequence[str],
crs: Any = None,
allow_override: bool = False,
**kwargs,
):
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Set a new :class:`~xvec.GeometryIndex` for one or more existing
coordinate(s). One :class:`~xvec.GeometryIndex` is set per coordinate. Only
1-dimensional coordinates are supported.
Expand Down Expand Up @@ -691,7 +685,7 @@ def set_geom_indexes(

for coord in coord_names:
if isinstance(self._obj.xindexes[coord], GeometryIndex):
data_crs = self._obj.xindexes[coord].crs
data_crs = self._obj.xindexes[coord].crs # type: ignore

if not allow_override and data_crs is not None and not data_crs == crs:
raise ValueError(
Expand All @@ -710,7 +704,7 @@ def set_geom_indexes(

return _obj

def to_geopandas(self):
def to_geopandas(self) -> GeoDataFrame | pd.DataFrame:
"""Convert this array into a GeoPandas :class:`~geopandas.GeoDataFrame`
Returns a :class:`~geopandas.GeoDataFrame` with coordinates based on a
Expand All @@ -736,7 +730,7 @@ def to_geopandas(self):
to_geodataframe
"""
try:
import geopandas as gpd # type: ignore
import geopandas as gpd
except ImportError as err:
raise ImportError(
"The geopandas package is required for `xvec.to_geodataframe()`. "
Expand Down Expand Up @@ -766,7 +760,7 @@ def to_geopandas(self):
gdf = self._obj.to_pandas()
if gdf.columns.name == self._geom_indexes[0]:
gdf = gdf.T
return gdf.reset_index().set_geometry(
return gdf.reset_index().set_geometry( # type: ignore
self._geom_indexes[0],
crs=self._obj.xindexes[self._geom_indexes[0]].crs,
)
Expand All @@ -775,7 +769,7 @@ def to_geopandas(self):
UserWarning,
stacklevel=2,
)
return self._obj.to_pandas()
return cast(pd.DataFrame, self._obj.to_pandas())

# Dataset
gdf = self._obj.to_pandas()
Expand All @@ -801,7 +795,7 @@ def to_geopandas(self):
stacklevel=2,
)

return gdf
return cast(pd.DataFrame, gdf)

def to_geodataframe(
self,
Expand All @@ -810,7 +804,7 @@ def to_geodataframe(
dim_order: Sequence[Hashable] | None = None,
geometry: Hashable | None = None,
long: bool = True,
):
) -> GeoDataFrame | pd.DataFrame:
"""Convert this array and its coordinates into a tidy geopandas.GeoDataFrame.
The GeoDataFrame is indexed by the Cartesian product of index coordinates
Expand Down Expand Up @@ -884,7 +878,7 @@ def to_geodataframe(
level
for level in df.index.names
if level not in self._geom_coords_all
]
] # type: ignore
)

if isinstance(df.index, pd.MultiIndex):
Expand All @@ -907,7 +901,7 @@ def to_geodataframe(
if geometry is not None:
return df.set_geometry(
geometry, crs=self._obj[geometry].attrs.get("crs", None)
)
) # type: ignore

warnings.warn(
"No active geometry column to be set. The resulting object "
Expand All @@ -930,8 +924,8 @@ def zonal_stats(
method: str = "rasterize",
all_touched: bool = False,
n_jobs: int = -1,
**kwargs,
):
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Extract the values from a dataset indexed by a set of geometries
Given an object indexed by x and y coordinates (or latitude and longitude), such
Expand Down Expand Up @@ -1123,7 +1117,7 @@ def extract_points(
name: str = "geometry",
crs: Any | None = None,
index: bool | None = None,
):
) -> xr.DataArray | xr.Dataset:
"""Extract points from a DataArray or a Dataset indexed by spatial coordinates
Given an object indexed by x and y coordinates (or latitude and longitude), such
Expand Down Expand Up @@ -1263,3 +1257,22 @@ def extract_points(
}
)
return result


def _resolve_input(
positional: Mapping[Any, Any] | None,
keyword: Mapping[str, Any],
func_name: str,
) -> Mapping[Hashable, Any]:
"""Resolve combination of positional and keyword arguments.
Based on xarray's ``either_dict_or_kwargs``.
"""
if positional and keyword:
raise ValueError(
"Cannot specify both keyword and positional arguments to "
f"'.xvec.{func_name}'."
)
if positional is None or positional == {}:
return cast(Mapping[Hashable, Any], keyword)
return positional
Loading

0 comments on commit d61e175

Please sign in to comment.