Skip to content

Commit

Permalink
BUG: Ignore categorical columns not in measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jan 23, 2024
1 parent bf265e4 commit 13aa686
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 97 deletions.
4 changes: 2 additions & 2 deletions geocube/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def make_geocube(
geobox_maker=geobox_maker,
fill=fill,
categorical_enums=categorical_enums,
).make_geocube(
measurements=measurements,
datetime_measurements=datetime_measurements,
group_by=group_by,
datetime_measurements=datetime_measurements,
).make_geocube(
interpolate_na_method=interpolate_na_method,
rasterize_function=rasterize_function,
)
39 changes: 27 additions & 12 deletions geocube/geo_utils/geobox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Iterable
from typing import Any, Optional, Union

import fiona.errors
import geopandas
import rioxarray # noqa: F401 pylint: disable=unused-import
import shapely.geometry.base
Expand Down Expand Up @@ -46,14 +47,19 @@ def geobox_from_rio(xds: Union[xarray.Dataset, xarray.DataArray]) -> GeoBox:


def load_vector_data(
vector_data: Union[str, os.PathLike, geopandas.GeoDataFrame]
vector_data: Union[str, os.PathLike, geopandas.GeoDataFrame],
measurements: Optional[list[str]] = None,
) -> geopandas.GeoDataFrame:
"""
Parameters
----------
vector_data: str, path-like object or :obj:`geopandas.GeoDataFrame`
A file path to an OGR supported source or GeoDataFrame containing
the vector data.
measurements: list[str], optional
Attribute name or list of attribute names to be included. If a list is specified,
the measurements will be returned in the order requested.
By default all available measurements are included.
Returns
-------
Expand All @@ -63,17 +69,31 @@ def load_vector_data(
logger = get_logger()

if isinstance(vector_data, (str, os.PathLike)):
vector_data = geopandas.read_file(vector_data)
try:
vector_data = geopandas.read_file(vector_data, include_fields=measurements)
except fiona.errors.DriverError as error:
if "ignore_fields" not in str(error):
raise
vector_data = geopandas.read_file(vector_data)

elif not isinstance(vector_data, geopandas.GeoDataFrame):
vector_data = geopandas.GeoDataFrame(vector_data)
else:
vector_data = vector_data.copy()

if vector_data.empty:
raise VectorDataError("Empty GeoDataFrame.")
if "geometry" not in vector_data.columns:

if measurements is not None:
vector_data = vector_data[measurements + [vector_data.geometry.name]]

try:
vector_data.geometry
except AttributeError as error:
raise VectorDataError(
"'geometry' column missing. Columns in file: "
f"{vector_data.columns.values.tolist()}"
)
) from error

# make sure projection is set
if not vector_data.crs:
Expand Down Expand Up @@ -132,25 +152,20 @@ def __init__(
self.geom = geom
self.like = like

def from_vector(
self, vector_data: Union[str, os.PathLike, geopandas.GeoDataFrame]
) -> GeoBox:
def from_vector(self, vector_data: geopandas.GeoDataFrame) -> GeoBox:
"""Get the geobox to use for the grid.
Parameters
----------
vector_data: str, path-like object or :obj:`geopandas.GeoDataFrame`
A file path to an OGR supported source or GeoDataFrame
containing the vector data.
vector_data: geopandas.GeoDataFrame
A GeoDataFrame containing the vector data.
Returns
-------
:obj:`odc.geo.geobox.GeoBox`
The geobox for the grid to be generated from the vector data.
"""
vector_data = load_vector_data(vector_data)

if self.like is not None:
assert (
self.output_crs is None
Expand Down
155 changes: 78 additions & 77 deletions geocube/vector_to_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _format_series_data(data_series: geopandas.GeoSeries) -> geopandas.GeoSeries


class VectorToCube:
# pylint: disable=too-many-instance-attributes
"""
Tool that facilitates converting vector data to raster data into
an :obj:`xarray.Dataset`.
Expand All @@ -55,6 +56,9 @@ def __init__(
geobox_maker: GeoBoxMaker,
fill: float,
categorical_enums: Optional[dict[str, list]],
measurements: Optional[list[str]] = None,
datetime_measurements: Optional[list[str]] = None,
group_by: Optional[str] = None,
):
"""
Initialize the VectorToCube class.
Expand All @@ -73,33 +77,88 @@ def __init__(
categorical data. The categories will be made unique and sorted
if they are not already.
E.g. {'column_name': ['a', 'b'], 'other_column': ['c', 'd']}
measurements: list[str], optional
Attribute name or list of attribute names to be included. If a list is specified,
the measurements will be returned in the order requested.
By default all available measurements are included.
datetime_measurements: list[str], optional
Attributes that are temporal in nature and should be converted to the
datetime format. These are only included if listed in 'measurements'.
group_by: str, optional
When specified, perform basic combining/reducing of the data on this column.
"""
self._vector_data = load_vector_data(vector_data)
self._fill = fill if fill is not None else numpy.nan
self._group_by = group_by
self._measurements = measurements
self._rasterize_function: Callable[..., Optional[NDArray]] = rasterize_image

load_measurements = measurements
if (
measurements is not None
and self._group_by is not None
and self._group_by not in measurements
):
load_measurements.append(self._group_by)
self._vector_data = load_vector_data(
vector_data, measurements=load_measurements
)
self._geobox = geobox_maker.from_vector(self._vector_data)
self._grid_coords = affine_to_coords(
self._geobox.affine, self._geobox.width, self._geobox.height
)
self._fill = fill if fill is not None else numpy.nan
if self._geobox.crs is not None:
self._vector_data = self._vector_data.to_crs(self._geobox.crs)

if self._measurements is None:
self._measurements = self._vector_data.columns.tolist()
self._measurements.remove("geometry")

if categorical_enums is not None:
for column_name, categories in categorical_enums.items():
if column_name not in self._measurements:
continue
category_type = pandas.api.types.CategoricalDtype(
categories=sorted(set(categories)) + ["nodata"]
)
self._vector_data[column_name] = self._vector_data[column_name].astype(
category_type
)

# define defaults
self._rasterize_function: Callable[..., Optional[NDArray]] = rasterize_image
# get categorical enumerations if they exist
self._categorical_enums: dict[str, list] = {
categorical_column: self._vector_data[categorical_column].cat.categories
for categorical_column in self._vector_data.select_dtypes(
["category"]
).columns
if categorical_column in self._measurements
}

self._datetime_measurements: tuple[str, ...] = ()
self._categorical_enums: dict[str, list] = {}
if datetime_measurements is not None:
self._datetime_measurements = tuple(
set(datetime_measurements) & set(self._measurements)
)

# convert to datetime
for datetime_measurement in self._datetime_measurements: # type: ignore
date_data = pandas.to_datetime(self._vector_data[datetime_measurement])
try:
date_data = date_data.dt.tz_convert("UTC")
except TypeError:
pass
self._vector_data[datetime_measurement] = date_data.dt.tz_localize(
None
).astype("datetime64[ns]")

if self._group_by:
self._vector_data = self._vector_data.groupby(self._group_by)
try:
self._measurements.remove(self._group_by)
except ValueError:
pass

def make_geocube(
self,
measurements: Optional[list[str]] = None,
datetime_measurements: Optional[list[str]] = None,
group_by: Optional[str] = None,
interpolate_na_method: Optional[Literal["linear", "nearest", "cubic"]] = None,
rasterize_function: Optional[Callable[..., Optional[NDArray]]] = None,
) -> xarray.Dataset:
Expand All @@ -112,15 +171,6 @@ def make_geocube(
Parameters
----------
measurements: list[str], optional
Attributes name or list of names to be included. If a list is specified,
the measurements will be returned in the order requested.
By default all available measurements are included.
datetime_measurements: list[str], optional
Attributes that are temporal in nature and should be converted to the
datetime format. These are only included if listed in 'measurements'.
group_by: str, optional
When specified, perform basic combining/reducing of the data on this column.
interpolate_na_method: {'linear', 'nearest', 'cubic'}, optional
This is the method for interpolation to use to fill in the nodata with
:func:`scipy.interpolate.griddata`.
Expand All @@ -137,47 +187,8 @@ def make_geocube(
self._rasterize_function = (
rasterize_image if rasterize_function is None else rasterize_function # type: ignore
)
if measurements is None:
measurements = self._vector_data.columns.tolist()
measurements.remove("geometry")

self._datetime_measurements = ()
if datetime_measurements is not None:
self._datetime_measurements = tuple(
set(datetime_measurements) & set(measurements)
)
# reproject vector data to the projection of the output raster
if self._geobox.crs is not None:
vector_data = self._vector_data.to_crs(self._geobox.crs)

# convert to datetime
for datetime_measurement in self._datetime_measurements: # type: ignore
date_data = pandas.to_datetime(vector_data[datetime_measurement])
try:
date_data = date_data.dt.tz_convert("UTC")
except TypeError:
pass
vector_data[datetime_measurement] = date_data.dt.tz_localize(None).astype(
"datetime64[ns]"
)

# get categorical enumerations if they exist
self._categorical_enums = {}
for categorical_column in vector_data.select_dtypes(["category"]).columns:
self._categorical_enums[categorical_column] = vector_data[
categorical_column
].cat.categories

# map the shape data to the grid
if group_by:
vector_data = vector_data.groupby(group_by)
try:
measurements.remove(group_by)
except ValueError:
pass

return self._get_dataset(
vector_data, measurements, group_by, interpolate_na_method
interpolate_na_method=interpolate_na_method,
)

@staticmethod
Expand Down Expand Up @@ -225,22 +236,13 @@ def _update_time_attrs(self, attrs: dict[str, Any], image_data: NDArray) -> None

def _get_dataset(
self,
vector_data: geopandas.GeoDataFrame,
measurements: list[str],
group_by: Optional[str],
interpolate_na_method: Optional[str],
) -> xarray.Dataset:
"""
Parameters
----------
vector_data: :obj:`geopandas.GeoDataFrame`
A GeoDataFrame containing the vector data.
measurements: list[str]
Attributes name or list of names to be included. If a list is specified,
the measurements will be returned in the order requested.
By default all available measurements are included.
group_by: str, optional
When specified, perform basic combining/reducing of the data on this column.
interpolate_na_method: {'linear', 'nearest', 'cubic'}, optional
This is the method for interpolation to use to fill in the nodata with
:func:`scipy.interpolate.griddata`.
Expand All @@ -252,20 +254,22 @@ def _get_dataset(
"""
data_vars = {}
for measurement in measurements:
if group_by:
for measurement in self._measurements:
if self._group_by:
grid_array = self._get_grouped_grid(
vector_data[[measurement, "geometry"]], measurement, group_by
self._vector_data[[measurement, "geometry"]],
measurement_name=measurement,
)
else:
grid_array = self._get_grid(
vector_data[[measurement, "geometry"]], measurement
self._vector_data[[measurement, "geometry"]],
measurement_name=measurement,
)
if grid_array is not None:
data_vars[measurement] = grid_array

if group_by:
self._grid_coords[group_by] = list(vector_data.groups.keys()) # type: ignore
if self._group_by:
self._grid_coords[self._group_by] = list(self._vector_data.groups.keys()) # type: ignore

out_xds = xarray.Dataset(data_vars=data_vars, coords=self._grid_coords)

Expand All @@ -288,7 +292,6 @@ def _get_grouped_grid(
self,
grouped_dataframe: geopandas.GeoDataFrame,
measurement_name: str,
group_by: str,
) -> Optional[tuple]:
"""Retrieve the variable data to append to the ssurgo :obj:`xarray.Dataset`.
This method is designed specifically to work on a dataframe that has
Expand All @@ -302,8 +305,6 @@ def _get_grouped_grid(
Attributes name or list of names to be included. If a list is specified,
the measurements will be returned in the order requested.
By default all available measurements are included.
group_by: str
Perform basic combining/reducing of the data on this column.
Returns
-------
Expand Down Expand Up @@ -343,7 +344,7 @@ def _get_grouped_grid(
self._update_time_attrs(attrs, image_data)

return (
(group_by, "y", "x"),
(self._group_by, "y", "x"),
image_data,
attrs,
{"grid_mapping": DEFAULT_GRID_MAP},
Expand Down
Loading

0 comments on commit 13aa686

Please sign in to comment.