Skip to content

Commit

Permalink
fix pipeline and transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
mdumandag committed Jul 22, 2024
1 parent 89d4866 commit fb74606
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 28 deletions.
27 changes: 27 additions & 0 deletions tests/test_read_your_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@ async def test_should_update_sync_token_on_pipeline_async(async_redis: AsyncRedi
assert initial_token != updated_token


@pytest.mark.parametrize("redis", [{"read_your_writes": True}], indirect=True)
def test_should_update_sync_token_on_multiexec(redis: Redis):
initial_token = redis._sync_token

multi = redis.multi()
multi.set("key", "value")
multi.set("key2", "value2")
multi.exec()

updated_token = redis._sync_token
assert initial_token != updated_token


@pytest.mark.parametrize("async_redis", [{"read_your_writes": True}], indirect=True)
@pytest.mark.asyncio
async def test_should_update_sync_token_on_multiexec_async(async_redis: AsyncRedis):
initial_token = async_redis._sync_token

multi = async_redis.multi()
multi.set("key", "value")
multi.set("key2", "value2")
await multi.exec()

updated_token = async_redis._sync_token
assert initial_token != updated_token


@pytest.mark.parametrize("redis", [{"read_your_writes": True}], indirect=True)
def test_updates_after_successful_lua_script_call(redis: Redis):
initial_token = redis._sync_token
Expand Down
39 changes: 25 additions & 14 deletions upstash_redis/asyncio/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os import environ
from typing import Any, List, Literal, Optional, Type, Dict
from typing import Any, List, Literal, Optional, Type, Dict, Callable

from aiohttp import ClientSession

Expand Down Expand Up @@ -60,13 +60,6 @@ def __init__(
self._read_your_writes = read_your_writes
self._sync_token = ""

def nop_sync_token_cb(_: str):
pass

self._sync_token_cb = (
self._update_sync_token if read_your_writes else nop_sync_token_cb
)

self._headers = make_headers(token, rest_encoding, allow_telemetry)
self._context_manager: Optional[_SessionContextManager] = None

Expand Down Expand Up @@ -123,7 +116,12 @@ async def close(self) -> None:
self._context_manager = None

def _update_sync_token(self, new_token: str):
self._sync_token = new_token
if self._read_your_writes:
self._sync_token = new_token

def _maybe_set_sync_token_header(self, headers: Dict[str, str]):
if self._read_your_writes:
headers["Upstash-Sync-Token"] = self._sync_token

async def execute(self, command: List) -> RESTResultT:
"""
Expand All @@ -135,9 +133,7 @@ async def execute(self, command: List) -> RESTResultT:
ClientSession(), close_session=True
)

if self._read_your_writes:
self._headers["Upstash-Sync-Token"] = self._sync_token

self._maybe_set_sync_token_header(self._headers)
async with context_manager:
res = await async_execute(
session=context_manager.session,
Expand All @@ -147,7 +143,7 @@ async def execute(self, command: List) -> RESTResultT:
retries=self._rest_retries,
retry_interval=self._rest_retry_interval,
command=command,
sync_token_cb=self._sync_token_cb,
sync_token_cb=self._update_sync_token,
)

return cast_response(command, res)
Expand All @@ -166,6 +162,8 @@ def pipeline(self) -> "AsyncPipeline":
headers=self._headers,
context_manager=self._context_manager,
multi_exec="pipeline",
set_sync_token_header_fn=self._maybe_set_sync_token_header,
sync_token_cb=self._update_sync_token,
)

def multi(self) -> "AsyncPipeline":
Expand All @@ -182,6 +180,8 @@ def multi(self) -> "AsyncPipeline":
headers=self._headers,
context_manager=self._context_manager,
multi_exec="multi-exec",
set_sync_token_header_fn=self._maybe_set_sync_token_header,
sync_token_cb=self._update_sync_token,
)


Expand All @@ -197,6 +197,8 @@ def __init__(
context_manager: Optional["_SessionContextManager"] = None,
headers: Optional[Dict[str, str]] = None,
multi_exec: Literal["multi-exec", "pipeline"] = "pipeline",
set_sync_token_header_fn: Optional[Callable[[Dict[str, str]], None]] = None,
sync_token_cb: Optional[Callable[[str], None]] = None,
):
"""
Creates a new blocking Redis client.
Expand All @@ -209,7 +211,9 @@ def __init__(
:param allow_telemetry: whether anonymous telemetry can be collected
:param context_manager: context manager
:param headers: request headers
:param miltiexec: Whether multi execution (transaction) or pipelining is to be used
:param multiexec: Whether multi execution (transaction) or pipelining is to be used
:param set_sync_token_header_fn: Function to set the Upstash-Sync-Token header
:param sync_token_cb: Function to call when a new Upstash-Sync-Token response is received
"""

self._url = url
Expand All @@ -229,6 +233,9 @@ def __init__(
self._command_stack: List[List[str]] = []
self._multi_exec = multi_exec

self._set_sync_token_header_fn = set_sync_token_header_fn
self._sync_token_cb = sync_token_cb

def execute(self, command: List) -> "AsyncPipeline": # type: ignore[override]
"""
Adds commnd to the command stack which will be sent as a batch
Expand All @@ -245,6 +252,9 @@ async def exec(self) -> List[RESTResultT]:
"""
url = f"{self._url}/{self._multi_exec}"

if self._set_sync_token_header_fn:
self._set_sync_token_header_fn(self._headers)

context_manager = self._context_manager
async with context_manager:
res: List[RESTResultT] = await async_execute( # type: ignore[assignment]
Expand All @@ -256,6 +266,7 @@ async def exec(self) -> List[RESTResultT]:
retry_interval=self._rest_retry_interval,
command=self._command_stack,
from_pipeline=True,
sync_token_cb=self._sync_token_cb,
)

response = [
Expand Down
40 changes: 26 additions & 14 deletions upstash_redis/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os import environ
from typing import Any, List, Literal, Optional, Type, Dict
from typing import Any, List, Literal, Optional, Type, Dict, Callable

from requests import Session

Expand Down Expand Up @@ -62,13 +62,6 @@ def __init__(
self._read_your_writes = read_your_writes
self._sync_token = ""

def nop_sync_token_cb(_: str):
pass

self._sync_token_cb = (
self._update_sync_token if read_your_writes else nop_sync_token_cb
)

self._headers = make_headers(token, rest_encoding, allow_telemetry)
self._session = Session()

Expand Down Expand Up @@ -120,15 +113,18 @@ def close(self):
self._session.close()

def _update_sync_token(self, new_token: str):
self._sync_token = new_token
if self._read_your_writes:
self._sync_token = new_token

def _maybe_set_sync_token_header(self, headers: Dict[str, str]):
if self._read_your_writes:
headers["Upstash-Sync-Token"] = self._sync_token

def execute(self, command: List) -> RESTResultT:
"""
Executes the given command.
"""
if self._read_your_writes:
self._headers["Upstash-Sync-Token"] = self._sync_token

self._maybe_set_sync_token_header(self._headers)
res = sync_execute(
session=self._session,
url=self._url,
Expand All @@ -137,7 +133,7 @@ def execute(self, command: List) -> RESTResultT:
retries=self._rest_retries,
retry_interval=self._rest_retry_interval,
command=command,
sync_token_cb=self._sync_token_cb,
sync_token_cb=self._update_sync_token,
)

return cast_response(command, res)
Expand All @@ -156,6 +152,8 @@ def pipeline(self) -> "Pipeline":
headers=self._headers,
session=self._session,
multi_exec="pipeline",
set_sync_token_header_fn=self._maybe_set_sync_token_header,
sync_token_cb=self._update_sync_token,
)

def multi(self) -> "Pipeline":
Expand All @@ -172,6 +170,8 @@ def multi(self) -> "Pipeline":
headers=self._headers,
session=self._session,
multi_exec="multi-exec",
set_sync_token_header_fn=self._maybe_set_sync_token_header,
sync_token_cb=self._update_sync_token,
)


Expand All @@ -187,6 +187,8 @@ def __init__(
headers: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
multi_exec: Literal["multi-exec", "pipeline"] = "pipeline",
set_sync_token_header_fn: Optional[Callable[[Dict[str, str]], None]] = None,
sync_token_cb: Optional[Callable[[str], None]] = None,
):
"""
Creates a new blocking Redis client.
Expand All @@ -199,7 +201,9 @@ def __init__(
:param allow_telemetry: whether anonymous telemetry can be collected
:param headers: request headers
:param session: A Requests session
:param miltiexec: Whether multi execution (transaction) or pipelining is to be used
:param multiexec: Whether multi execution (transaction) or pipelining is to be used
:param set_sync_token_header_fn: Function to set the Upstash-Sync-Token header
:param sync_token_cb: Function to call when a new Upstash-Sync-Token response is received
"""

self._url = url
Expand All @@ -217,6 +221,9 @@ def __init__(
self._command_stack: List[List[str]] = []
self._multi_exec = multi_exec

self._set_sync_token_header_fn = set_sync_token_header_fn
self._sync_token_cb = sync_token_cb

def execute(self, command: List) -> "Pipeline":
"""
Adds commnd to the command stack which will be sent as a batch
Expand All @@ -232,6 +239,10 @@ def exec(self) -> List[RESTResultT]:
Executes the commands in the pipeline by sending them as a batch
"""
url = f"{self._url}/{self._multi_exec}"

if self._set_sync_token_header_fn:
self._set_sync_token_header_fn(self._headers)

res: List[RESTResultT] = sync_execute( # type: ignore[assignment]
session=self._session,
url=url,
Expand All @@ -241,6 +252,7 @@ def exec(self) -> List[RESTResultT]:
retry_interval=self._rest_retry_interval,
command=self._command_stack,
from_pipeline=True,
sync_token_cb=self._sync_token_cb,
)
response = [
cast_response(command, response)
Expand Down

0 comments on commit fb74606

Please sign in to comment.