Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DX-1019: Read Your Writes #50

Merged
merged 6 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/test_read_your_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest

from upstash_redis import Redis

def test_should_update_sync_token_on_basic_request(redis: Redis):
redis = Redis.from_env(read_your_writes=True)

initial_token = redis._upstash_sync_token

redis.set("key", "value")

updated_token = redis._upstash_sync_token

assert initial_token != updated_token

def test_should_update_sync_token_on_pipeline(redis: Redis):
redis = Redis.from_env()

initial_token = redis._upstash_sync_token

pipeline = redis.pipeline()

pipeline.set("key", "value")
pipeline.set("key2", "value2")

pipeline.exec()

updated_token = redis._upstash_sync_token

assert initial_token != updated_token

def test_updates_after_successful_lua_script_call(redis):
s = """
redis.call('SET', 'mykey', 'myvalue')
return 1
"""
initial_sync = redis._upstash_sync_token
redis.eval(
s, keys=[], args=[]
)

updated_sync = redis._upstash_sync_token
assert updated_sync != initial_sync

def test_should_not_update_sync_state_with_opt_out_ryw():
redis = Redis.from_env(read_your_writes=False)
initial_sync = redis._upstash_sync_token
redis.set("key", "value")
updated_sync = redis._upstash_sync_token
assert updated_sync == initial_sync

def test_should_update_sync_state_with_default_behavior():
redis = Redis.from_env()
initial_sync = redis._upstash_sync_token
redis.set("key", "value")
updated_sync = redis._upstash_sync_token
assert updated_sync != initial_sync

13 changes: 13 additions & 0 deletions upstash_redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
rest_retries: int = 1,
rest_retry_interval: float = 3, # Seconds.
allow_telemetry: bool = True,
read_your_writes: bool = True
):
"""
Creates a new blocking Redis client.
Expand All @@ -44,6 +45,7 @@ def __init__(
:param rest_retries: how many times an HTTP request will be retried if it fails
:param rest_retry_interval: how many seconds will be waited between each retry
:param allow_telemetry: whether anonymous telemetry can be collected
:param read_your_writes: whether the client should wait for the response of a write operation before sending the next one
mdumandag marked this conversation as resolved.
Show resolved Hide resolved
"""

self._url = url
Expand All @@ -55,6 +57,9 @@ def __init__(
self._rest_retries = rest_retries
self._rest_retry_interval = rest_retry_interval

self._read_your_writes = read_your_writes
self._upstash_sync_token = ""

self._headers = make_headers(token, rest_encoding, allow_telemetry)
self._session = Session()

Expand All @@ -65,6 +70,7 @@ def from_env(
rest_retries: int = 1,
rest_retry_interval: float = 3,
allow_telemetry: bool = True,
read_your_writes: bool = True
):
"""
Load the credentials from environment.
Expand All @@ -82,6 +88,7 @@ def from_env(
rest_retries,
rest_retry_interval,
allow_telemetry,
read_your_writes
)

def __enter__(self) -> "Redis":
Expand All @@ -101,10 +108,14 @@ def close(self):
"""
self._session.close()

def _update_sync_token(self, new_token: str):
self._upstash_sync_token = new_token
def execute(self, command: List) -> RESTResultT:
"""
Executes the given command.
"""
if self._read_your_writes:
self._headers["upstash-sync-token"] = self._upstash_sync_token

res = sync_execute(
session=self._session,
Expand All @@ -114,8 +125,10 @@ def execute(self, command: List) -> RESTResultT:
retries=self._rest_retries,
retry_interval=self._rest_retry_interval,
command=command,
upstash_sync_token_callback=self._update_sync_token
)

print(res)
return cast_response(command, res)

def pipeline(self) -> "Pipeline":
Expand Down
25 changes: 20 additions & 5 deletions upstash_redis/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from base64 import b64decode
from json import dumps
from platform import python_version
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union, Callable

from aiohttp import ClientSession
from requests import Session
Expand Down Expand Up @@ -48,7 +48,8 @@ async def async_execute(
retries: int,
retry_interval: float,
command: List,
from_pipeline: bool = False
from_pipeline: bool = False,
upstash_sync_token_callback: Optional[Callable[[str], None]] = None
) -> Union[RESTResultT, List[RESTResultT]]:
"""
Execute the given command over the REST API.
Expand All @@ -57,6 +58,7 @@ async def async_execute(
:param retries: how many times an HTTP request will be retried if it fails
:param retry_interval: how many seconds will be waited between each retry
:param allow_telemetry: whether anonymous telemetry can be collected
:param upstash_sync_token_callback: This callback is called with the new Upstash Sync Token after each request to update the client's token
"""

# Serialize the command; more specifically, write string-incompatible types as JSON strings.
Expand All @@ -68,6 +70,12 @@ async def async_execute(
for attempts_left in range(max(0, retries), -1, -1):
try:
async with session.post(url, headers=headers, json=command) as r:
headers = await r.headers
new_upstash_sync_token = headers.get("Upstash-Sync-Token")

if upstash_sync_token_callback and new_upstash_sync_token:
upstash_sync_token_callback(new_upstash_sync_token)

response = await r.json()
break # Break the loop as soon as we receive a proper response
except Exception as e:
Expand Down Expand Up @@ -99,7 +107,8 @@ def sync_execute(
retries: int,
retry_interval: float,
command: List[Any],
from_pipeline: bool = False
from_pipeline: bool = False,
upstash_sync_token_callback: Optional[Callable[[str], None]] = None
) -> Union[RESTResultT, List[RESTResultT]]:
command = _format_command(command, from_pipeline=from_pipeline)

Expand All @@ -108,7 +117,14 @@ def sync_execute(

for attempts_left in range(max(0, retries), -1, -1):
try:
response = session.post(url, headers=headers, json=command).json()
response = session.post(url, headers=headers, json=command)

new_upstash_sync_token = response.headers.get("Upstash-Sync-Token")
if upstash_sync_token_callback and new_upstash_sync_token:
upstash_sync_token_callback(new_upstash_sync_token)

response = response.json()

break # Break the loop as soon as we receive a proper response
except Exception as e:
last_error = e
Expand All @@ -121,7 +137,6 @@ def sync_execute(

# Exhausted all retries, but no response is received
raise last_error

if not from_pipeline:
return format_response(response, encoding)
else:
Expand Down
Loading