Add obsm, obsp, varm, varp to ExperimentAxisQuery (#179)
* Add obsm, etc to ExperimentAxisQuery

Co-authored-by: Paul Fisher <[email protected]>
ebezzi and thetorpedodog authored Dec 4, 2023
1 parent c6f6fd8 commit 10dc344
Showing 1 changed file with 153 additions and 16 deletions.
python-spec/src/somacore/query/query.py (169 changes: 153 additions & 16 deletions)
@@ -236,18 +236,34 @@ def obsp(self, layer: str) -> data.SparseRead:
return self._axisp_inner(_Axis.OBS, layer)

def varp(self, layer: str) -> data.SparseRead:
"""Returns an ``varp`` layer as a sparse read.
"""Returns a ``varp`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axisp_inner(_Axis.VAR, layer)

def obsm(self, layer: str) -> data.SparseRead:
"""Returns an ``obsm`` layer as a sparse read.
Lifecycle: experimental
"""
return self._axism_inner(_Axis.OBS, layer)

def varm(self, layer: str) -> data.SparseRead:
"""Returns a ``varm`` layer as a sparse read.
Lifecycle: experimental
"""
return self._axism_inner(_Axis.VAR, layer)
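
The new obsm() and varm() accessors mirror obsp()/varp(): each returns a lazy data.SparseRead restricted to the queried joinids. A minimal usage sketch, assuming an already-constructed ExperimentAxisQuery named query and hypothetical layer keys ("X_pca", "PCs") that would have to exist in the measurement:

# Sketch only: `query`, "X_pca", and "PCs" are assumptions, not part of this commit.
pca_read = query.obsm("X_pca")            # lazy sparse read, like obsp()/varp()
pca_table = pca_read.tables().concat()    # materialize the slice as one Arrow table
loadings = query.varm("PCs").tables().concat()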

def to_anndata(
self,
X_name: str,
*,
column_names: Optional[AxisColumnNames] = None,
X_layers: Sequence[str] = (),
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
) -> anndata.AnnData:
"""
Executes the query and returns the result as an in-memory ``AnnData`` object.
@@ -258,13 +274,25 @@ def to_anndata(
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.
Lifecycle: maturing
"""
return self._read(
X_name,
column_names=column_names or AxisColumnNames(obs=None, var=None),
X_layers=X_layers,
obsm_layers=obsm_layers,
obsp_layers=obsp_layers,
varm_layers=varm_layers,
varp_layers=varp_layers,
).to_anndata()
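
With these parameters, to_anndata() can also materialize the annotation matrices: each requested layer is read through the query's thread pool and returned as a dense numpy array in the matching AnnData slot. A sketch, again assuming an existing ExperimentAxisQuery named query and hypothetical layer names:

# Sketch only: "raw", "X_pca", "PCs", "connectivities", and "distances" are
# hypothetical layer names that must exist in the queried measurement.
adata = query.to_anndata(
    X_name="raw",
    obsm_layers=["X_pca"],
    varm_layers=["PCs"],
    obsp_layers=["connectivities"],
    varp_layers=["distances"],
)
# adata.obsm["X_pca"], adata.obsp["connectivities"], etc. are dense numpy arrays.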

# Context management
@@ -306,19 +334,32 @@ def _read(
*,
column_names: AxisColumnNames,
X_layers: Sequence[str],
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
) -> "_AxisQueryResult":
"""Reads the entire query result into in-memory Arrow tables.
"""Reads the entire query result in memory.
This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
resulting Tables.
resulting objects.
Args:
X_name: The X layer to read and return in the ``X`` slot.
column_names: The columns in the ``var`` and ``obs`` dataframes
to read.
X_layers: Additional X layers to read and return
in the ``layers`` slot.
obsm_layers:
Additional obsm layers to read and return in the obsm slot.
obsp_layers:
Additional obsp layers to read and return in the obsp slot.
varm_layers:
Additional varm layers to read and return in the varm slot.
varp_layers:
Additional varp layers to read and return in the varp slot.
"""
x_collection = self._ms.X
all_x_names = [X_name] + list(X_layers)
@@ -333,6 +374,22 @@
raise NotImplementedError("Dense array unsupported")
all_x_arrays[_xname] = x_array

def _read_axis_mappings(fn, axis, keys: Sequence[str]) -> Dict[str, np.ndarray]:
return {key: fn(axis, key) for key in keys}

obsm_ft = self._threadpool.submit(
_read_axis_mappings, self._axism_inner_ndarray, _Axis.OBS, obsm_layers
)
obsp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.OBS, obsp_layers
)
varm_ft = self._threadpool.submit(
_read_axis_mappings, self._axism_inner_ndarray, _Axis.VAR, varm_layers
)
varp_ft = self._threadpool.submit(
_read_axis_mappings, self._axisp_inner_ndarray, _Axis.VAR, varp_layers
)

obs_table, var_table = self._read_both_axes(column_names)

x_matrices = {
@@ -343,7 +400,23 @@
}

x = x_matrices.pop(X_name)
return _AxisQueryResult(obs=obs_table, var=var_table, X=x, X_layers=x_matrices)

obs = obs_table.to_pandas()
obs.index = obs.index.astype(str)

var = var_table.to_pandas()
var.index = var.index.astype(str)

return _AxisQueryResult(
obs=obs,
var=var,
X=x,
obsm=obsm_ft.result(),
obsp=obsp_ft.result(),
varm=varm_ft.result(),
varp=varp_ft.result(),
X_layers=x_matrices,
)
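
The four mapping groups are submitted to the query's thread pool before the blocking obs/var/X reads and only gathered with .result() when the result object is assembled, so the mapping reads overlap the table reads. A self-contained sketch of that submit-early, gather-late pattern (the reader function and names below are illustrative, not the library's):

from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Sequence

def read_axis_mappings(reader, axis: str, keys: Sequence[str]) -> Dict[str, str]:
    # Same shape as _read_axis_mappings: one entry per requested key.
    return {key: reader(axis, key) for key in keys}

def fake_reader(axis: str, key: str) -> str:
    return f"{axis}:{key}"  # stand-in for an expensive SOMA read

with ThreadPoolExecutor() as pool:
    obsm_ft = pool.submit(read_axis_mappings, fake_reader, "obs", ["X_pca"])
    # ... other blocking reads would happen here, overlapping with obsm_ft ...
    obsm = obsm_ft.result()  # {"X_pca": "obs:X_pca"}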

def _read_both_axes(
self,
@@ -433,9 +506,64 @@ def _axisp_inner(
f" stored in {p_name} layer {layer!r}"
)

joinids = getattr(self._joinids, axis.value)
joinids = axis.getattr_from(self._joinids)
return ap_layer.read((joinids, joinids))

def _axism_inner(
self,
axis: "_Axis",
layer: str,
) -> data.SparseRead:
m_name = f"{axis.value}m"

try:
axism = axis.getitem_from(self._ms, suf="m")
except KeyError:
raise ValueError(f"Measurement does not contain {m_name} data") from None

try:
axism_layer = axism[layer]
except KeyError as ke:
raise ValueError(f"layer {layer!r} is not available in {m_name}") from ke

if not isinstance(axism_layer, data.SparseNDArray):
raise TypeError(f"Unexpected SOMA type stored in '{m_name}' layer")

joinids = axis.getattr_from(self._joinids)
return axism_layer.read((joinids, slice(None)))
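
Note the slice difference from _axisp_inner: a pairwise obsp/varp layer is filtered on both dimensions with the same joinids, while an obsm/varm mapping layer is filtered only on its first dimension and keeps every column, since its second dimension is a component index rather than a joinid. A plain-numpy analogy of the two subsets (not the SOMA read API):

import numpy as np

joinids = np.array([2, 5, 7])
pairwise = np.arange(100).reshape(10, 10)  # stand-in for an obsp-style layer
mapping = np.arange(20).reshape(10, 2)     # stand-in for an obsm-style layer

pairwise_slice = pairwise[np.ix_(joinids, joinids)]  # both axes filtered
mapping_slice = mapping[joinids, :]                  # rows filtered, columns kept
assert pairwise_slice.shape == (3, 3) and mapping_slice.shape == (3, 2)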

def _convert_to_ndarray(
self, axis: "_Axis", table: pa.Table, n_row: int, n_col: int
) -> np.ndarray:
indexer: pd.Index = axis.getattr_from(self.indexer, pre="by_")
idx = indexer(table["soma_dim_0"])
z = np.zeros(n_row * n_col, dtype=np.float32)
np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return z.reshape(n_row, n_col)
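
_convert_to_ndarray densifies one COO-style Arrow table: the query's indexer maps soma_dim_0 joinids to positional row offsets, the values are scattered into a flat float32 buffer at offset row * n_col + col, and the buffer is reshaped to (n_row, n_col). A toy numpy illustration of that scatter-and-reshape step, with made-up inputs and the re-indexing assumed to have already happened:

import numpy as np

rows = np.array([0, 1, 2])         # soma_dim_0 after re-indexing to positions
cols = np.array([1, 0, 2])         # soma_dim_1
vals = np.array([5.0, 7.0, 9.0])   # soma_data
n_row, n_col = 3, 3

dense = np.zeros(n_row * n_col, dtype=np.float32)
np.put(dense, rows * n_col + cols, vals)  # scatter into the flat buffer
dense = dense.reshape(n_row, n_col)
assert dense[0, 1] == 5.0 and dense[1, 0] == 7.0 and dense[2, 2] == 9.0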

def _axisp_inner_ndarray(
self,
axis: "_Axis",
layer: str,
) -> np.ndarray:
n_row = n_col = len(axis.getattr_from(self._joinids))

table = self._axisp_inner(axis, layer).tables().concat()
return self._convert_to_ndarray(axis, table, n_row, n_col)

def _axism_inner_ndarray(
self,
axis: "_Axis",
layer: str,
) -> np.ndarray:
axism = axis.getitem_from(self._ms, suf="m")

_, n_col = axism[layer].shape
n_row = len(axis.getattr_from(self._joinids))

table = self._axism_inner(axis, layer).tables().concat()
return self._convert_to_ndarray(axis, table, n_row, n_col)

@property
def _obs_df(self) -> data.DataFrame:
return self.experiment.obs
@@ -466,24 +594,33 @@ def _threadpool(self) -> futures.ThreadPoolExecutor:
class _AxisQueryResult:
"""The result of running :meth:`ExperimentAxisQuery.read`. Private."""

obs: pa.Table
"""Experiment.obs query slice, as an Arrow Table"""
var: pa.Table
"""Experiment.ms[...].var query slice, as an Arrow Table"""
obs: pd.DataFrame
"""Experiment.obs query slice, as a pandas DataFrame"""
var: pd.DataFrame
"""Experiment.ms[...].var query slice, as a pandas DataFrame"""
X: sparse.csr_matrix
"""Experiment.ms[...].X[...] query slice, as an SciPy sparse.csr_matrix """
X_layers: Dict[str, sparse.csr_matrix] = attrs.field(factory=dict)
"""Any additional X layers requested, as SciPy sparse.csr_matrix(s)"""
obsm: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.obsm query slice, as a numpy ndarray"""
obsp: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.obsp query slice, as a numpy ndarray"""
varm: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.varm query slice, as a numpy ndarray"""
varp: Dict[str, np.ndarray] = attrs.field(factory=dict)
"""Experiment.varp query slice, as a numpy ndarray"""

def to_anndata(self) -> anndata.AnnData:
obs = self.obs.to_pandas()
obs.index = obs.index.astype(str)

var = self.var.to_pandas()
var.index = var.index.astype(str)

return anndata.AnnData(
X=self.X, obs=obs, var=var, layers=(self.X_layers or None)
X=self.X,
obs=self.obs,
var=self.var,
obsm=(self.obsm or None),
obsp=(self.obsp or None),
varm=(self.varm or None),
varp=(self.varp or None),
layers=(self.X_layers or None),
)
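
The `or None` guards rely on empty dicts being falsy in Python: a slot with no requested layers is passed to AnnData as None rather than as an empty mapping. The idiom in isolation (the "X_pca" key is just an example value):

empty: dict = {}
assert (empty or None) is None                  # nothing requested -> None
assert ({"X_pca": 1} or None) == {"X_pca": 1}   # requested layers pass through unchanged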

