Skip to content

Commit

Permalink
Merge pull request #12 from bcdev/forman-11-persist_inmem_slices
Browse files Browse the repository at this point in the history
Persist in-memory slices before appending
  • Loading branch information
forman authored Jan 12, 2024
2 parents abd1017 + 55ca905 commit 0ee0ad7
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 19 deletions.
1 change: 1 addition & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def test_schema(self):
"fixed_dims",
"included_variables",
"logging",
"persist_mem_slices",
"slice_engine",
"slice_polling",
"slice_storage_options",
Expand Down
22 changes: 15 additions & 7 deletions tests/test_slicesource.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from zappend.context import Context
from zappend.fsutil.fileobj import FileObj
from zappend.slicesource.common import open_slice_source
from zappend.slicesource.identity import IdentitySliceSource
from zappend.slicesource.memory import MemorySliceSource
from zappend.slicesource.persistent import PersistentSliceSource
from zappend.slicesource.temporary import TemporarySliceSource
from .helpers import clear_memory_fs
from .helpers import make_test_dataset

Expand All @@ -20,16 +21,23 @@ class SliceSourceTest(unittest.TestCase):
def setUp(self):
clear_memory_fs()

def test_in_memory(self):
slice_dir = FileObj("memory://slice.zarr")
dataset = make_test_dataset(uri=slice_dir.uri)
def test_memory_slice_source(self):
dataset = make_test_dataset()
ctx = Context(dict(target_dir="memory://target.zarr"))
slice_zarr = open_slice_source(ctx, dataset)
self.assertIsInstance(slice_zarr, IdentitySliceSource)
self.assertIsInstance(slice_zarr, MemorySliceSource)
with slice_zarr as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

def test_temporary_slice_source(self):
dataset = make_test_dataset()
ctx = Context(dict(target_dir="memory://target.zarr", persist_mem_slices=True))
slice_zarr = open_slice_source(ctx, dataset)
self.assertIsInstance(slice_zarr, TemporarySliceSource)
with slice_zarr as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

def test_persistent_zarr(self):
def test_persistent_slice_source_for_zarr(self):
slice_dir = FileObj("memory://slice.zarr")
make_test_dataset(uri=slice_dir.uri)
ctx = Context(dict(target_dir="memory://target.zarr"))
Expand All @@ -38,7 +46,7 @@ def test_persistent_zarr(self):
with slice_zarr as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

# def test_persistent_nc(self):
# def test_persistent_slice_source_for_nc(self):
# slice_ds = make_test_dataset()
# slice_file = FileObj("memory:///slice.nc")
# with slice_file.fs.open(slice_file.path, "wb") as f:
Expand Down
10 changes: 10 additions & 0 deletions zappend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,16 @@
"type": "array",
"items": {"type": "string", "minLength": 1},
},
persist_mem_slices={
"description": (
"Persist in-memory slices and reopen from a temporary Zarr before"
" appending them to the target dataset."
" This can prevent expensive re-computation of dask chunks at the"
" cost of additional i/o."
),
"type": "boolean",
"default": False,
},
disable_rollback={
"description": (
"Disable rolling back dataset changes on failure."
Expand Down
6 changes: 5 additions & 1 deletion zappend/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from .config import DEFAULT_SLICE_POLLING_INTERVAL
from .config import DEFAULT_SLICE_POLLING_TIMEOUT
from .config import DEFAULT_ZARR_VERSION
from .metadata import DatasetMetadata
from .fsutil.fileobj import FileObj
from .metadata import DatasetMetadata


class Context:
Expand Down Expand Up @@ -92,6 +92,10 @@ def slice_polling(self) -> tuple[float, float] | tuple[None, None]:
def temp_dir(self) -> FileObj:
return self._temp_dir

@property
def persist_mem_slices(self) -> bool:
return self._config.get("persist_mem_slices", False)

@property
def disable_rollback(self) -> bool:
return self._config.get("disable_rollback", False)
Expand Down
4 changes: 4 additions & 0 deletions zappend/slicesource/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class SliceSource(ABC):
def __init__(self, ctx: Context):
self._ctx = ctx

@property
def ctx(self) -> Context:
return self._ctx

def __enter__(self) -> xr.Dataset:
return self.open()

Expand Down
12 changes: 8 additions & 4 deletions zappend/slicesource/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

import xarray as xr

from ..context import Context
from ..fsutil.fileobj import FileObj
from .abc import SliceSource
from .identity import IdentitySliceSource
from .memory import MemorySliceSource
from .persistent import PersistentSliceSource
from .temporary import TemporarySliceSource
from ..context import Context
from ..fsutil.fileobj import FileObj


def open_slice_source(
Expand All @@ -23,7 +24,10 @@ def open_slice_source(
:return: A new slice source instance
"""
if isinstance(slice_obj, xr.Dataset):
return IdentitySliceSource(ctx, slice_obj, slice_index)
if ctx.persist_mem_slices:
return TemporarySliceSource(ctx, slice_obj, slice_index)
else:
return MemorySliceSource(ctx, slice_obj, slice_index)
if isinstance(slice_obj, str):
slice_file = FileObj(slice_obj, storage_options=ctx.slice_storage_options)
return PersistentSliceSource(ctx, slice_file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from .abc import SliceSource


class IdentitySliceSource(SliceSource):
"""A slice source that returns the dataset passed in when opened.
class MemorySliceSource(SliceSource):
"""A slice source that uses the in-memory dataset passed in.
:param ctx: Processing context
:param slice_ds: The slice dataset
Expand Down
9 changes: 4 additions & 5 deletions zappend/slicesource/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def close(self):

def _wait_for_slice_dataset(self) -> xr.Dataset:
slice_ds: xr.Dataset | None = None
interval, timeout = self._ctx.slice_polling
interval, timeout = self.ctx.slice_polling
if timeout is not None:
# t0 = time.monotonic()
# while (time.monotonic() - t0) < timeout:
Expand All @@ -52,8 +52,7 @@ def _wait_for_slice_dataset(self) -> xr.Dataset:
slice_ds = self._open_slice_dataset()
except OSError:
logger.debug(
f"Slice not ready or corrupt,"
f" retrying after {interval} seconds"
f"Slice not ready or corrupt, retrying after {interval} seconds"
)
time.sleep(interval)
else:
Expand All @@ -64,14 +63,14 @@ def _wait_for_slice_dataset(self) -> xr.Dataset:
return slice_ds

def _open_slice_dataset(self) -> xr.Dataset:
engine = self._ctx.slice_engine
engine = self.ctx.slice_engine
if engine is None and (
self._slice_file.path.endswith(".zarr")
or self._slice_file.path.endswith(".zarr.zip")
):
engine = "zarr"
if engine == "zarr":
storage_options = self._ctx.slice_storage_options
storage_options = self.ctx.slice_storage_options
return xr.open_zarr(self._slice_file.uri, storage_options=storage_options)

with self._slice_file.fs.open(self._slice_file.path, "rb") as f:
Expand Down
58 changes: 58 additions & 0 deletions zappend/slicesource/temporary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright © 2024 Norman Fomferra
# Permissions are hereby granted under the terms of the MIT License:
# https://opensource.org/licenses/MIT.

import xarray as xr

from .memory import MemorySliceSource
from ..context import Context
from ..fsutil.fileobj import FileObj
from ..log import logger


class TemporarySliceSource(MemorySliceSource):
"""A slice source that persists the in-memory dataset and returns
the re-opened dataset instance.
:param ctx: Processing context
:param slice_ds: The slice dataset
:param slice_index: An index for slice identification
"""

def __init__(self, ctx: Context, slice_ds: xr.Dataset, slice_index: int):
super().__init__(ctx, slice_ds, slice_index)
self._temp_slice_dir: FileObj | None = None
self._temp_slice_ds: xr.Dataset | None = None

def open(self) -> xr.Dataset:
slice_index = self._slice_index
temp_slice_dir = self.ctx.temp_dir / f"slice-{self._slice_index}.zarr"
self._temp_slice_dir = temp_slice_dir
temp_slice_store = temp_slice_dir.fs.get_mapper(
temp_slice_dir.path, create=True
)
logger.info(
f"Persisting in-memory slice dataset #{self._slice_index}"
f" to {temp_slice_dir.uri}"
)
self._slice_ds.to_zarr(temp_slice_store)
self._slice_ds = None
self._temp_slice_ds = xr.open_zarr(temp_slice_store)
return self._temp_slice_ds

def close(self):
if self._temp_slice_ds is not None:
self._temp_slice_ds.close()
self._temp_slice_ds = None

temp_slice_dir = self._temp_slice_dir
if temp_slice_dir is not None:
self._temp_slice_dir = None
if temp_slice_dir.exists():
logger.info(
f"Removing temporary dataset {temp_slice_dir.uri}"
f" for slice #{self._slice_index}"
)
temp_slice_dir.delete(recursive=True)

super().close()

0 comments on commit 0ee0ad7

Please sign in to comment.