Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/latency store #2474

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/zarr/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from zarr.storage.logging import LoggingStore
from zarr.storage.memory import MemoryStore
from zarr.storage.remote import RemoteStore
from zarr.storage.wrapper import WrapperStore
from zarr.storage.zip import ZipStore

__all__ = [
Expand All @@ -12,6 +13,7 @@
"RemoteStore",
"StoreLike",
"StorePath",
"WrapperStore",
"ZipStore",
"make_store_path",
]
20 changes: 11 additions & 9 deletions src/zarr/storage/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

from zarr.abc.store import ByteRangeRequest, Store
from zarr.abc.store import Store
from zarr.storage.wrapper import WrapperStore

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Generator, Iterable
from collections.abc import AsyncGenerator, Generator, Iterable

from zarr.abc.store import ByteRangeRequest
from zarr.core.buffer import Buffer, BufferPrototype

counter: defaultdict[str, int]


class LoggingStore(Store):
class LoggingStore(WrapperStore[Store]):
"""
Store wrapper that logs all calls to the wrapped store.

Expand All @@ -34,7 +38,6 @@ class LoggingStore(Store):
Counter of number of times each method has been called
"""

_store: Store
counter: defaultdict[str, int]

def __init__(
Expand All @@ -43,11 +46,10 @@ def __init__(
log_level: str = "DEBUG",
log_handler: logging.Handler | None = None,
) -> None:
self._store = store
super().__init__(store)
self.counter = defaultdict(int)
self.log_level = log_level
self.log_handler = log_handler

self._configure_logger(log_level, log_handler)

def _configure_logger(
Expand Down Expand Up @@ -203,19 +205,19 @@ async def set_partial_values(
with self.log(keys):
return await self._store.set_partial_values(key_start_values=key_start_values)

async def list(self) -> AsyncIterator[str]:
async def list(self) -> AsyncGenerator[str, None]:
# docstring inherited
with self.log():
async for key in self._store.list():
yield key

async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
with self.log(prefix):
async for key in self._store.list_prefix(prefix=prefix):
yield key

async def list_dir(self, prefix: str) -> AsyncIterator[str]:
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
# docstring inherited
with self.log(prefix):
async for key in self._store.list_dir(prefix=prefix):
Expand Down
139 changes: 139 additions & 0 deletions src/zarr/storage/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
from types import TracebackType
from typing import Any, Self

from zarr.abc.store import ByteRangeRequest
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import BytesLike

from zarr.abc.store import Store

T_Store = TypeVar("T_Store", bound=Store)


class WrapperStore(Store, Generic[T_Store]):
"""
A store class that wraps an existing ``Store`` instance.
By default all of the store methods are delegated to the wrapped store instance, which is
accessible via the ``._wrapped`` attribute of this class.

Use this class to modify or extend the behavior of the other store classes.
"""

_store: T_Store

def __init__(self, store: T_Store) -> None:
self._store = store

@classmethod
async def open(cls: type[Self], store_cls: type[T_Store], *args: Any, **kwargs: Any) -> Self:
store = store_cls(*args, **kwargs)
await store._open()
return cls(store=store)

def __enter__(self) -> Self:
return type(self)(self._store.__enter__())

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
return self._store.__exit__(exc_type, exc_value, traceback)

async def _open(self) -> None:
await self._store._open()

async def _ensure_open(self) -> None:
await self._store._ensure_open()

async def is_empty(self, prefix: str) -> bool:
return await self._store.is_empty(prefix)

async def clear(self) -> None:
return await self._store.clear()

@property
def read_only(self) -> bool:
return self._store.read_only

def _check_writable(self) -> None:
return self._store._check_writable()

def __eq__(self, value: object) -> bool:
return type(self) is type(value) and self._store.__eq__(value)

async def get(
self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
) -> Buffer | None:
return await self._store.get(key, prototype, byte_range)

async def get_partial_values(
self,
prototype: BufferPrototype,
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
) -> list[Buffer | None]:
return await self._store.get_partial_values(prototype, key_ranges)

async def exists(self, key: str) -> bool:
return await self._store.exists(key)

async def set(self, key: str, value: Buffer) -> None:
await self._store.set(key, value)

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
return await self._store.set_if_not_exists(key, value)

async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
await self._store._set_many(values)

@property
def supports_writes(self) -> bool:
return self._store.supports_writes

@property
def supports_deletes(self) -> bool:
return self._store.supports_deletes

async def delete(self, key: str) -> None:
await self._store.delete(key)

@property
def supports_partial_writes(self) -> bool:
return self._store.supports_partial_writes

async def set_partial_values(
self, key_start_values: Iterable[tuple[str, int, BytesLike]]
) -> None:
return await self._store.set_partial_values(key_start_values)

@property
def supports_listing(self) -> bool:
return self._store.supports_listing

def list(self) -> AsyncIterator[str]:
return self._store.list()

def list_prefix(self, prefix: str) -> AsyncIterator[str]:
return self._store.list_prefix(prefix)

def list_dir(self, prefix: str) -> AsyncIterator[str]:
return self._store.list_dir(prefix)

async def delete_dir(self, prefix: str) -> None:
return await self._store.delete_dir(prefix)

def close(self) -> None:
self._store.close()

async def _get_many(
self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
async for req in self._store._get_many(requests):
yield req
77 changes: 53 additions & 24 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from __future__ import annotations

import asyncio
import pickle
from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from zarr.storage.wrapper import WrapperStore

if TYPE_CHECKING:
from typing import Any

from zarr.abc.store import ByteRangeRequest
from zarr.core.buffer.core import BufferPrototype

import pytest

from zarr.abc.store import Store
from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.sync import _collect_aiterator
from zarr.storage._utils import _normalize_interval_index
Expand Down Expand Up @@ -319,25 +330,43 @@ async def test_set_if_not_exists(self, store: S) -> None:
result = await store.get("k2", default_buffer_prototype())
assert result == new

async def test_getsize(self, store: S) -> None:
key = "k"
data = self.buffer_cls.from_bytes(b"0" * 10)
await self.set(store, key, data)

result = await store.getsize(key)
assert isinstance(result, int)
assert result > 0

async def test_getsize_raises(self, store: S) -> None:
with pytest.raises(FileNotFoundError):
await store.getsize("not-a-real-key")

async def test_getsize_prefix(self, store: S) -> None:
prefix = "array/c/"
for i in range(10):
data = self.buffer_cls.from_bytes(b"0" * 10)
await self.set(store, f"{prefix}/{i}", data)

result = await store.getsize_prefix(prefix)
assert isinstance(result, int)
assert result > 0

class LatencyStore(WrapperStore[Store]):
"""
A wrapper class that takes any store class in its constructor and
adds latency to the `set` and `get` methods. This can be used for
performance testing.
"""

get_latency: float
set_latency: float

def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
self.get_latency = float(get_latency)
self.set_latency = float(set_latency)
self._store = cls

async def set(self, key: str, value: Buffer) -> None:
await asyncio.sleep(self.set_latency)
await self._store.set(key, value)

async def get(
self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
) -> Buffer | None:
"""
Add latency to the get method.

Adds a sleep of `self.get_latency` seconds before calling the wrapped method.

Parameters
----------
key : str
prototype : BufferPrototype
byte_range : ByteRangeRequest, optional

Returns
-------
buffer : Buffer or None
"""
await asyncio.sleep(self.get_latency)
return await self._store.get(key, prototype=prototype, byte_range=byte_range)
46 changes: 46 additions & 0 deletions tests/test_store/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from zarr.core.buffer.cpu import Buffer, buffer_prototype
from zarr.storage.wrapper import WrapperStore

if TYPE_CHECKING:
from zarr.abc.store import Store
from zarr.core.buffer.core import BufferPrototype


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True)
async def test_wrapped_set(store: Store, capsys: pytest.CaptureFixture[str]) -> None:
# define a class that prints when it sets
class NoisySetter(WrapperStore):
async def set(self, key: str, value: Buffer) -> None:
print(f"setting {key}")
await super().set(key, value)

key = "foo"
value = Buffer.from_bytes(b"bar")
store_wrapped = NoisySetter(store)
await store_wrapped.set(key, value)
captured = capsys.readouterr()
assert f"setting {key}" in captured.out
assert await store_wrapped.get(key, buffer_prototype) == value


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True)
async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> None:
# define a class that prints when it sets
class NoisyGetter(WrapperStore):
def get(self, key: str, prototype: BufferPrototype) -> None:
print(f"getting {key}")
return super().get(key, prototype=prototype)

key = "foo"
value = Buffer.from_bytes(b"bar")
store_wrapped = NoisyGetter(store)
await store_wrapped.set(key, value)
assert await store_wrapped.get(key, buffer_prototype) == value
captured = capsys.readouterr()
assert f"getting {key}" in captured.out