Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support spatialpandas DaskGeoDataFrame #4792

Merged
merged 18 commits into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions holoviews/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
from .multipath import MultiInterface # noqa (API import)
from .image import ImageInterface # noqa (API import)
from .pandas import PandasInterface # noqa (API import)
from .spatialpandas import SpatialPandasInterface # noqa (API import)
from .spatialpandas import SpatialPandasInterface # noqa (API import)
from .spatialpandas_dask import DaskSpatialPandasInterface # noqa (API import)
from .xarray import XArrayInterface # noqa (API import)

default_datatype = 'dataframe'

datatypes = ['dataframe', 'dictionary', 'grid', 'xarray', 'dask',
'cuDF', 'spatialpandas', 'array', 'multitabular', 'ibis']
datatypes = ['dataframe', 'dictionary', 'grid', 'xarray', 'multitabular',
'spatialpandas', 'dask_spatialpandas', 'dask', 'cuDF', 'array',
'ibis']


def concat(datasets, datatype=None):
Expand Down
102 changes: 43 additions & 59 deletions holoviews/core/data/spatialpandas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import, division

import sys
import warnings

from collections import defaultdict

Expand All @@ -16,12 +15,14 @@

class SpatialPandasInterface(MultiInterface):

types = ()
base_interface = PandasInterface

datatype = 'spatialpandas'

multi = True

types = ()

@classmethod
def loaded(cls):
return 'spatialpandas' in sys.modules
Expand All @@ -30,37 +31,50 @@ def loaded(cls):
def applies(cls, obj):
if not cls.loaded():
return False
from spatialpandas import GeoDataFrame, GeoSeries
is_sdf = isinstance(obj, (GeoDataFrame, GeoSeries))
is_sdf = isinstance(obj, cls.data_types())
if 'geopandas' in sys.modules and not 'geoviews' in sys.modules:
import geopandas as gpd
is_sdf |= isinstance(obj, (gpd.GeoDataFrame, gpd.GeoSeries))
return is_sdf

@classmethod
def geo_column(cls, data):
def data_types(cls):
from spatialpandas import GeoDataFrame, GeoSeries
return (GeoDataFrame, GeoSeries)

@classmethod
def series_type(cls):
from spatialpandas import GeoSeries
return GeoSeries

@classmethod
def frame_type(cls):
from spatialpandas import GeoDataFrame
return GeoDataFrame

@classmethod
def geo_column(cls, data):
col = 'geometry'
if col in data and isinstance(data[col], GeoSeries):
stypes = cls.series_type()
if col in data and isinstance(data[col], stypes):
return col
cols = [c for c in data.columns if isinstance(data[c], GeoSeries)]
cols = [c for c in data.columns if isinstance(data[c], stypes)]
if not cols:
raise ValueError('No geometry column found in spatialpandas.GeoDataFrame, '
'use the PandasInterface instead.')
return cols[0]

@classmethod
def init(cls, eltype, data, kdims, vdims):
import pandas as pd
from spatialpandas import GeoDataFrame, GeoSeries
from spatialpandas import GeoDataFrame

if kdims is None:
kdims = eltype.kdims

if vdims is None:
vdims = eltype.vdims

if isinstance(data, GeoSeries):
if isinstance(data, cls.series_type()):
data = data.to_frame()

if 'geopandas' in sys.modules:
Expand All @@ -74,8 +88,8 @@ def init(cls, eltype, data, kdims, vdims):
data = from_shapely(data)
if isinstance(data, list):
data = from_multi(eltype, data, kdims, vdims)
elif not isinstance(data, GeoDataFrame):
raise ValueError("SpatialPandasInterface only support spatialpandas DataFrames.")
elif not isinstance(data, cls.frame_type()):
raise ValueError("%s only support spatialpandas DataFrames." % cls.__name__)
elif 'geometry' not in data:
cls.geo_column(data)

Expand Down Expand Up @@ -116,7 +130,7 @@ def dtype(cls, dataset, dimension):
dim = dataset.get_dimension(dimension, strict=True)
if dim in cls.geom_dims(dataset):
col = cls.geo_column(dataset.data)
return dataset.data[col].values.numpy_dtype
return dataset.data[col].dtype.subtype
return dataset.data[dim.name].dtype

@classmethod
Expand Down Expand Up @@ -157,43 +171,14 @@ def select(cls, dataset, selection_mask=None, **selection):
elif selection_mask is None:
selection_mask = cls.select_mask(dataset, selection)
indexed = cls.indexed(dataset, selection)
df = df.iloc[selection_mask]
df = df[selection_mask]
if indexed and len(df) == 1 and len(dataset.vdims) == 1:
return df[dataset.vdims[0].name].iloc[0]
return df

@classmethod
def select_mask(cls, dataset, selection):
mask = np.ones(len(dataset.data), dtype=np.bool)
for dim, k in selection.items():
if isinstance(k, tuple):
k = slice(*k)
arr = dataset.data[dim].values
if isinstance(k, slice):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'invalid value encountered')
if k.start is not None:
mask &= k.start <= arr
if k.stop is not None:
mask &= arr < k.stop
elif isinstance(k, (set, list)):
iter_slcs = []
for ik in k:
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'invalid value encountered')
iter_slcs.append(arr == ik)
mask &= np.logical_or.reduce(iter_slcs)
elif callable(k):
mask &= k(arr)
else:
index_mask = arr == k
if dataset.ndims == 1 and np.sum(index_mask) == 0:
data_index = np.argmin(np.abs(arr - k))
mask = np.zeros(len(dataset), dtype=np.bool)
mask[data_index] = True
else:
mask &= index_mask
return mask
return cls.base_interface.select_mask(dataset, selection)

@classmethod
def geom_dims(cls, dataset):
Expand All @@ -203,13 +188,7 @@ def geom_dims(cls, dataset):
@classmethod
def dimension_type(cls, dataset, dim):
dim = dataset.get_dimension(dim)
col = cls.geo_column(dataset.data)
if dim in cls.geom_dims(dataset) and len(dataset.data):
arr = geom_to_array(dataset.data[col].iloc[0])
ds = dataset.clone(arr, datatype=cls.subtypes, vdims=[])
return ds.interface.dimension_type(ds, dim)
else:
return cls.dtype(dataset, dim).type
return cls.dtype(dataset, dim).type

@classmethod
def isscalar(cls, dataset, dim, per_geom=False):
Expand Down Expand Up @@ -238,15 +217,15 @@ def range(cls, dataset, dim):
else:
return (bounds[1], bounds[3])
else:
return Interface.range(dataset, dim)
return cls.base_interface.range(dataset, dim)

@classmethod
def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs):
geo_dims = cls.geom_dims(dataset)
if any(d in geo_dims for d in dimensions):
raise DataError("SpatialPandasInterface does not allow grouping "
"by geometry dimension.", cls)
return PandasInterface.groupby(dataset, dimensions, container_type, group_type, **kwargs)
return cls.base_interface.groupby(dataset, dimensions, container_type, group_type, **kwargs)

@classmethod
def aggregate(cls, columns, dimensions, function, **kwargs):
Expand All @@ -270,7 +249,7 @@ def sort(cls, dataset, by=[], reverse=False):
if any(d in geo_dims for d in by):
raise DataError("SpatialPandasInterface does not allow sorting "
"by geometry dimension.", cls)
return PandasInterface.sort(dataset, by, reverse)
return cls.base_interface.sort(dataset, by, reverse)

@classmethod
def length(cls, dataset):
Expand All @@ -279,7 +258,7 @@ def length(cls, dataset):
column = dataset.data[col_name]
geom_type = cls.geom_type(dataset)
if not isinstance(column.dtype, MultiPointDtype) and geom_type != 'Point':
return PandasInterface.length(dataset)
return cls.base_interface.length(dataset)
length = 0
for i, geom in enumerate(column):
if isinstance(geom, Point):
Expand All @@ -290,11 +269,11 @@ def length(cls, dataset):

@classmethod
def nonzero(cls, dataset):
return bool(cls.length(dataset))
return bool(len(dataset.data.head(1)))

@classmethod
def redim(cls, dataset, dimensions):
return PandasInterface.redim(dataset, dimensions)
return cls.base_interface.redim(dataset, dimensions)

@classmethod
def add_dimension(cls, dataset, dimension, dim_pos, values, vdim):
Expand Down Expand Up @@ -386,13 +365,18 @@ def values(cls, dataset, dimension, expanded=True, flat=True, compute=True, keep
if isgeom and keep_index:
return data[geom_col]
elif not isgeom:
if is_points:
return data[dimension.name].values
return get_value_array(data, dimension, expanded, keep_index, geom_col, is_points)
elif not len(data):
return np.array([])

geom_type = cls.geom_type(dataset)
index = geom_dims.index(dimension)
return geom_array_to_array(data[geom_col].values, index, expanded, geom_type)
geom_series = data[geom_col]
if compute and hasattr(geom_series, 'compute'):
geom_series = geom_series.compute()
return geom_array_to_array(geom_series.values, index, expanded, geom_type)

@classmethod
def split(cls, dataset, start, end, datatype, **kwargs):
Expand Down Expand Up @@ -604,7 +588,7 @@ def get_value_array(data, dimension, expanded, keep_index, geom_col,
all_scalar = True
arrays, scalars = [], []
for i, geom in enumerate(data[geom_col]):
length = geom_length(geom)
length = 1 if is_points else geom_length(geom)
val = column.iloc[i]
scalar = isscalar(val)
if scalar:
Expand Down
84 changes: 84 additions & 0 deletions holoviews/core/data/spatialpandas_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import absolute_import

import sys

import numpy as np

from .dask import DaskInterface
from .interface import Interface
from .spatialpandas import SpatialPandasInterface


class DaskSpatialPandasInterface(SpatialPandasInterface):
    """
    Interface for spatialpandas.dask DaskGeoDataFrame and DaskGeoSeries
    objects.

    Reuses the SpatialPandasInterface implementation but swaps the
    base interface it delegates to for the DaskInterface and overrides
    the methods that need to handle a partitioned dataframe.
    """

    base_interface = DaskInterface

    datatype = 'dask_spatialpandas'

    @classmethod
    def loaded(cls):
        """
        Whether spatialpandas.dask has already been imported.
        """
        return 'spatialpandas.dask' in sys.modules

    @classmethod
    def data_types(cls):
        """
        Tuple of the data types this interface applies to.
        """
        from spatialpandas.dask import DaskGeoDataFrame, DaskGeoSeries
        return (DaskGeoDataFrame, DaskGeoSeries)

    @classmethod
    def series_type(cls):
        """
        The series type holding a geometry column.
        """
        from spatialpandas.dask import DaskGeoSeries
        return DaskGeoSeries

    @classmethod
    def frame_type(cls):
        """
        The dataframe type this interface stores its data as.
        """
        from spatialpandas.dask import DaskGeoDataFrame
        return DaskGeoDataFrame

    @classmethod
    def init(cls, eltype, data, kdims, vdims):
        """
        Initializes the data using the SpatialPandasInterface and, if
        the result is not already a DaskGeoDataFrame, wraps it in a
        dask dataframe with a single partition.

        Returns the same (data, dims, params) tuple as the parent
        implementation.
        """
        import dask.dataframe as dd
        data, dims, params = super(DaskSpatialPandasInterface, cls).init(
            eltype, data, kdims, vdims
        )
        if not isinstance(data, cls.frame_type()):
            data = dd.from_pandas(data, npartitions=1)
        return data, dims, params

    @classmethod
    def partition_values(cls, df, dataset, dimension, expanded, flat):
        """
        Returns the values along a dimension for a single dataframe
        (as passed in by map_partitions in values) by cloning the
        dataset around it with the 'spatialpandas' datatype.
        """
        ds = dataset.clone(df, datatype=['spatialpandas'])
        return ds.interface.values(ds, dimension, expanded, flat)

    @classmethod
    def values(cls, dataset, dimension, expanded=True, flat=True, compute=True, keep_index=False):
        """
        Returns the values along the requested dimension.

        When compute is enabled and the index does not have to be kept
        the values are extracted per partition and computed eagerly,
        otherwise the call is forwarded to the SpatialPandasInterface
        implementation and its result is returned as is.
        """
        if compute and not keep_index:
            # Empty array of the right dtype describing the per-partition output
            meta = np.array([], dtype=cls.dtype(dataset, dimension))
            return dataset.data.map_partitions(
                cls.partition_values, meta=meta, dataset=dataset,
                dimension=dimension, expanded=expanded, flat=flat
            ).compute()
        # Reaching this point means compute is False or keep_index is
        # True, so the parent result is never computed here.
        return super(DaskSpatialPandasInterface, cls).values(
            dataset, dimension, expanded, flat, compute, keep_index
        )

    @classmethod
    def split(cls, dataset, start, end, datatype, **kwargs):
        """
        Computes the whole dataframe and delegates the split to the
        interface selected for the 'spatialpandas' datatype.
        """
        ds = dataset.clone(dataset.data.compute(), datatype=['spatialpandas'])
        return ds.interface.split(ds, start, end, datatype, **kwargs)

    @classmethod
    def iloc(cls, dataset, index):
        """
        Integer indexing, supported only when no row index is given.

        Raises NotImplementedError if the row component of the index
        is not None.
        """
        rows, cols = index
        if rows is not None:
            raise NotImplementedError(
                "%s does not support integer based row indexing, "
                "only column selection is available." % cls.__name__
            )
        return super(DaskSpatialPandasInterface, cls).iloc(dataset, index)

    @classmethod
    def add_dimension(cls, dataset, dimension, dim_pos, values, vdim):
        """
        Adds a dimension by delegating to the base interface.
        """
        return cls.base_interface.add_dimension(dataset, dimension, dim_pos, values, vdim)


# Register the interface with the Interface base class when this module is
# imported (presumably keyed by its 'dask_spatialpandas' datatype -- confirm
# against Interface.register).
Interface.register(DaskSpatialPandasInterface)
6 changes: 3 additions & 3 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def _process(self, element, key=None):
if element._plot_id in self._precomputed:
data, col = self._precomputed[element._plot_id]
else:
if element.interface.datatype != 'spatialpandas':
if 'spatialpandas' not in element.interface.datatype:
element = element.clone(datatype=['spatialpandas'])
data = element.data
col = element.interface.geo_column(data)
Expand Down Expand Up @@ -1429,8 +1429,8 @@ class rasterize(AggregationOperation):

_transforms = [(Image, regrid),
(Polygons, geometry_rasterize),
(lambda x: (isinstance(x, Path) and
x.interface.datatype == 'spatialpandas'),
(lambda x: (isinstance(x, (Path, Points)) and
'spatialpandas' in x.interface.datatype),
geometry_rasterize),
(TriMesh, trimesh_rasterize),
(QuadMesh, quadmesh_rasterize),
Expand Down
Loading