diff --git a/src/zarr/storage/__init__.py b/src/zarr/storage/__init__.py index 6703aa272..17b11f54a 100644 --- a/src/zarr/storage/__init__.py +++ b/src/zarr/storage/__init__.py @@ -3,6 +3,7 @@ from zarr.storage.logging import LoggingStore from zarr.storage.memory import MemoryStore from zarr.storage.remote import RemoteStore +from zarr.storage.wrapper import WrapperStore from zarr.storage.zip import ZipStore __all__ = [ @@ -12,6 +13,7 @@ "RemoteStore", "StoreLike", "StorePath", + "WrapperStore", "ZipStore", "make_store_path", ] diff --git a/src/zarr/storage/logging.py b/src/zarr/storage/logging.py index bc90b4f30..b26c4a18e 100644 --- a/src/zarr/storage/logging.py +++ b/src/zarr/storage/logging.py @@ -7,15 +7,19 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any -from zarr.abc.store import ByteRangeRequest, Store +from zarr.abc.store import Store +from zarr.storage.wrapper import WrapperStore if TYPE_CHECKING: - from collections.abc import AsyncIterator, Generator, Iterable + from collections.abc import AsyncGenerator, Generator, Iterable + from zarr.abc.store import ByteRangeRequest from zarr.core.buffer import Buffer, BufferPrototype + counter: defaultdict[str, int] + -class LoggingStore(Store): +class LoggingStore(WrapperStore[Store]): """ Store wrapper that logs all calls to the wrapped store. 
class WrapperStore(Store, Generic[T_Store]):
    """
    A store class that wraps an existing ``Store`` instance.

    By default all of the store methods are delegated to the wrapped store instance, which is
    accessible via the ``._store`` attribute of this class.

    Use this class to modify or extend the behavior of the other store classes: subclass it
    and override only the methods whose behavior should change.

    Parameters
    ----------
    store : Store
        The store instance to wrap.
    """

    _store: T_Store

    def __init__(self, store: T_Store) -> None:
        self._store = store

    @classmethod
    async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self:
        """
        Instantiate and open a store of type ``store_cls`` and wrap it.

        Parameters
        ----------
        store_cls : type[Store]
            The class of the store to create.
        *args, **kwargs
            Forwarded unchanged to ``store_cls``.

        Returns
        -------
        Self
            A new wrapper around the opened store.
        """
        store = store_cls(*args, **kwargs)
        await store._open()
        # Pass the store positionally: subclasses are free to name their first
        # constructor parameter something other than ``store``.
        return cls(store)

    def __enter__(self) -> Self:
        # Enter the wrapped store but hand back *this* wrapper, so that any state
        # configured on a subclass instance survives the ``with`` statement.
        # NOTE(review): the return value of the wrapped ``__enter__`` is ignored;
        # presumably it is the wrapped store itself -- confirm against ``Store.__enter__``.
        self._store.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return self._store.__exit__(exc_type, exc_value, traceback)

    # ------------------------------------------------------------------
    # Everything below (except ``__eq__``) forwards to the wrapped store
    # without modification.
    # ------------------------------------------------------------------

    async def _open(self) -> None:
        await self._store._open()

    async def _ensure_open(self) -> None:
        await self._store._ensure_open()

    async def is_empty(self, prefix: str) -> bool:
        return await self._store.is_empty(prefix)

    async def clear(self) -> None:
        return await self._store.clear()

    @property
    def read_only(self) -> bool:
        return self._store.read_only

    def _check_writable(self) -> None:
        return self._store._check_writable()

    def __eq__(self, value: object) -> bool:
        """
        Return True when ``value`` is a wrapper of exactly the same type whose
        wrapped store compares equal to this wrapper's wrapped store.
        """
        # The ``isinstance`` check is implied by the exact-type check; it is only
        # here so that type checkers allow the ``value._store`` access.
        return (
            isinstance(value, WrapperStore)
            and type(self) is type(value)
            and self._store == value._store
        )

    async def get(
        self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
    ) -> Buffer | None:
        return await self._store.get(key, prototype, byte_range)

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRangeRequest]],
    ) -> list[Buffer | None]:
        return await self._store.get_partial_values(prototype, key_ranges)

    async def exists(self, key: str) -> bool:
        return await self._store.exists(key)

    async def set(self, key: str, value: Buffer) -> None:
        await self._store.set(key, value)

    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
        return await self._store.set_if_not_exists(key, value)

    async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
        await self._store._set_many(values)

    @property
    def supports_writes(self) -> bool:
        return self._store.supports_writes

    @property
    def supports_deletes(self) -> bool:
        return self._store.supports_deletes

    async def delete(self, key: str) -> None:
        await self._store.delete(key)

    @property
    def supports_partial_writes(self) -> bool:
        return self._store.supports_partial_writes

    async def set_partial_values(
        self, key_start_values: Iterable[tuple[str, int, BytesLike]]
    ) -> None:
        return await self._store.set_partial_values(key_start_values)

    @property
    def supports_listing(self) -> bool:
        return self._store.supports_listing

    # The listing methods are plain (non-async) functions that return the wrapped
    # store's async iterator directly.
    def list(self) -> AsyncIterator[str]:
        return self._store.list()

    def list_prefix(self, prefix: str) -> AsyncIterator[str]:
        return self._store.list_prefix(prefix)

    def list_dir(self, prefix: str) -> AsyncIterator[str]:
        return self._store.list_dir(prefix)

    async def delete_dir(self, prefix: str) -> None:
        return await self._store.delete_dir(prefix)

    def close(self) -> None:
        self._store.close()

    async def _get_many(
        self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
    ) -> AsyncGenerator[tuple[str, Buffer | None], None]:
        # Re-yield each result so that this method is itself an async generator.
        async for req in self._store._get_many(requests):
            yield req
class LatencyStore(WrapperStore[Store]):
    """
    A wrapper class that takes any store instance in its constructor and
    adds latency to the `set` and `get` methods. This can be used for
    performance testing.

    Parameters
    ----------
    cls : Store
        The store instance to wrap. Despite the parameter name, this is an
        instance: its `get` and `set` methods are awaited directly.
    get_latency : float, optional
        Seconds to sleep before every `get`. Defaults to 0.
    set_latency : float, optional
        Seconds to sleep before every `set`. Defaults to 0.
    """

    get_latency: float
    set_latency: float

    def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
        # Let the base class store the wrapped instance rather than assigning
        # ``_store`` here, so this class stays consistent with other wrappers.
        super().__init__(cls)
        self.get_latency = float(get_latency)
        self.set_latency = float(set_latency)

    async def set(self, key: str, value: Buffer) -> None:
        """
        Add latency to the set method.

        Adds a sleep of `self.set_latency` seconds before calling the wrapped method.

        Parameters
        ----------
        key : str
        value : Buffer
        """
        await asyncio.sleep(self.set_latency)
        await self._store.set(key, value)

    async def get(
        self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
    ) -> Buffer | None:
        """
        Add latency to the get method.

        Adds a sleep of `self.get_latency` seconds before calling the wrapped method.

        Parameters
        ----------
        key : str
        prototype : BufferPrototype
        byte_range : ByteRangeRequest, optional

        Returns
        -------
        buffer : Buffer or None
        """
        await asyncio.sleep(self.get_latency)
        return await self._store.get(key, prototype=prototype, byte_range=byte_range)
store_wrapped.set(key, value) + assert await store_wrapped.get(key, buffer_prototype) == value + captured = capsys.readouterr() + assert f"getting {key}" in captured.out