From 7a59d84772dc845e341cfaf4589308ab33465199 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 22 Oct 2024 11:10:26 -0600 Subject: [PATCH] Fix JSON encoding of complex fill values We were not replacing NaNs and Infs with the string versions. --- src/zarr/core/metadata/v3.py | 2 +- tests/test_array.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 6b6f28dd9..931905bee 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -149,7 +149,7 @@ def default(self, o: object) -> Any: if isinstance(out, complex): # python complex types are not JSON serializable, so we use the # serialization defined in the zarr v3 spec - return [out.real, out.imag] + return _replace_special_floats([out.real, out.imag]) elif np.isnan(out): return "NaN" elif np.isinf(out): diff --git a/tests/test_array.py b/tests/test_array.py index f182cb1a1..2443482a9 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1,3 +1,5 @@ +import json +import math import pickle from itertools import accumulate from typing import Any, Literal @@ -9,6 +11,7 @@ from zarr import Array, AsyncArray, Group from zarr.codecs import BytesCodec, VLenBytesCodec from zarr.core.array import chunks_initialized +from zarr.core.buffer import default_buffer_prototype from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import JSON, MemoryOrder, ZarrFormat from zarr.core.group import AsyncGroup @@ -436,3 +439,23 @@ def test_array_create_order( assert vals.flags.f_contiguous else: raise AssertionError + + +@pytest.mark.parametrize( + ("fill_value", "expected"), + [ + (np.nan * 1j, ["NaN", "NaN"]), + (np.nan, ["NaN", 0.0]), + (np.inf, ["Infinity", 0.0]), + (np.inf * 1j, ["NaN", "Infinity"]), + (-np.inf, ["-Infinity", 0.0]), + (math.inf, ["Infinity", 0.0]), + ], +) +async def test_special_complex_fill_values_roundtrip(fill_value: Any, expected: list[Any]) -> None: + store = MemoryStore({}, mode="w") + Array.create(store=store, shape=(1,), dtype=np.complex64, fill_value=fill_value) + content = await store.get("zarr.json", prototype=default_buffer_prototype()) + assert content is not None + actual = json.loads(content.to_bytes()) + assert actual["fill_value"] == expected