Skip to content

Commit

Permalink
Fix JSON encoding of complex fill values
Browse files Browse the repository at this point in the history
We were not replacing NaNs and Infs with the string versions.
  • Loading branch information
dcherian committed Oct 22, 2024
1 parent a9d6d74 commit 7a59d84
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def default(self, o: object) -> Any:
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return _replace_special_floats([out.real, out.imag])
elif np.isnan(out):
return "NaN"
elif np.isinf(out):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import math
import pickle
from itertools import accumulate
from typing import Any, Literal
Expand All @@ -9,6 +11,7 @@
from zarr import Array, AsyncArray, Group
from zarr.codecs import BytesCodec, VLenBytesCodec
from zarr.core.array import chunks_initialized
from zarr.core.buffer import default_buffer_prototype
from zarr.core.buffer.cpu import NDBuffer
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
from zarr.core.group import AsyncGroup
Expand Down Expand Up @@ -436,3 +439,23 @@ def test_array_create_order(
assert vals.flags.f_contiguous
else:
raise AssertionError


@pytest.mark.parametrize(
("fill_value", "expected"),
[
(np.nan * 1j, ["NaN", "NaN"]),
(np.nan, ["NaN", 0.0]),
(np.inf, ["Infinity", 0.0]),
(np.inf * 1j, ["NaN", "Infinity"]),
(-np.inf, ["-Infinity", 0.0]),
(math.inf, ["Infinity", 0.0]),
],
)
async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None:
store = MemoryStore({}, mode="w")
Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value)
content = await store.get("zarr.json", prototype=default_buffer_prototype())
assert content is not None
actual = json.loads(content.to_bytes())
assert actual["fill_value"] == expected

0 comments on commit 7a59d84

Please sign in to comment.