Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor datetime and timedelta encoding for increased robustness #9498

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 78 additions & 50 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,19 @@ def _unpack_time_units_and_ref_date(units: str) -> tuple[str, pd.Timestamp]:
return time_units, ref_date


def _unpack_time_units_and_ref_date_cftime(units: str, calendar: str):
    """Split CF ``units`` into a time unit and a calendar-aware reference date.

    Same as _unpack_netcdf_time_units, except the reference date is resolved
    to a cftime datetime in ``calendar`` so it is ready for further processing
    in encode_cf_datetime.
    """
    time_units, raw_ref = _unpack_netcdf_time_units(units)
    # Decoding "0 microseconds since <ref>" yields the reference date itself
    # as a cftime object in the requested calendar.
    ref_date = cftime.num2date(
        0,
        units=f"microseconds since {raw_ref}",
        calendar=calendar,
        only_use_cftime_datetimes=True,
    )
    return time_units, ref_date


def _decode_cf_datetime_dtype(
data, units: str, calendar: str | None, use_cftime: bool | None
) -> np.dtype:
Expand Down Expand Up @@ -387,6 +400,7 @@ def _unit_timedelta_numpy(units: str) -> np.timedelta64:
def _infer_time_units_from_diff(unique_timedeltas) -> str:
unit_timedelta: Callable[[str], timedelta] | Callable[[str], np.timedelta64]
zero_timedelta: timedelta | np.timedelta64
unique_timedeltas = asarray(unique_timedeltas)
if unique_timedeltas.dtype == np.dtype("O"):
time_units = _NETCDF_TIME_UNITS_CFTIME
unit_timedelta = _unit_timedelta_cftime
Expand All @@ -405,6 +419,10 @@ def _time_units_to_timedelta64(units: str) -> np.timedelta64:
return np.timedelta64(1, _netcdf_to_numpy_timeunit(units)).astype("timedelta64[ns]")


def _time_units_to_timedelta(units: str) -> timedelta:
    """Return one ``units`` (e.g. "seconds") as a ``datetime.timedelta``."""
    # _US_PER_TIME_DELTA maps each supported unit name to its length in
    # microseconds, which timedelta represents exactly.
    n_microseconds = _US_PER_TIME_DELTA[units]
    return timedelta(microseconds=n_microseconds)


def infer_calendar_name(dates) -> CFCalendar:
"""Given an array of datetimes, infer the CF calendar name"""
if is_np_datetime_like(dates.dtype):
Expand Down Expand Up @@ -667,42 +685,6 @@ def _division(deltas, delta, floor):
return num


def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="overflow")
cast_num = np.asarray(num, dtype=dtype)

if np.issubdtype(dtype, np.integer):
if not (num == cast_num).all():
if np.issubdtype(num.dtype, np.floating):
raise ValueError(
f"Not possible to cast all encoded times from "
f"{num.dtype!r} to {dtype!r} without losing precision. "
f"Consider modifying the units such that integer values "
f"can be used, or removing the units and dtype encoding, "
f"at which point xarray will make an appropriate choice."
)
else:
raise OverflowError(
f"Not possible to cast encoded times from "
f"{num.dtype!r} to {dtype!r} without overflow. Consider "
f"removing the dtype encoding, at which point xarray will "
f"make an appropriate choice, or explicitly switching to "
"a larger integer dtype."
)
else:
if np.isinf(cast_num).any():
raise OverflowError(
f"Not possible to cast encoded times from {num.dtype!r} to "
f"{dtype!r} without overflow. Consider removing the dtype "
f"encoding, at which point xarray will make an appropriate "
f"choice, or explicitly switching to a larger floating point "
f"dtype."
)

return cast_num


def encode_cf_datetime(
dates: T_DuckArray, # type: ignore[misc]
units: str | None = None,
Expand All @@ -725,6 +707,26 @@ def encode_cf_datetime(
return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)


def _infer_needed_units_numpy(ref_date, data_units):
    """Return time units that encode faithfully to int64 relative to ``ref_date``.

    ``data_units`` are the units inferred from the data; when the offset
    between the data's own reference date and the requested ``ref_date`` is
    not an integer multiple of those units, finer units are inferred from
    that offset instead.
    """
    needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
    # Offset between the two reference dates, as a numpy timedelta.
    offset = abs(data_ref_date - ref_date).to_timedelta64()
    step = _time_units_to_timedelta64(needed_units)
    remainder = offset % step
    if remainder > np.timedelta64(0, "ns"):
        needed_units = _infer_time_units_from_diff(offset)
    return needed_units


def _infer_needed_units_cftime(ref_date, data_units, calendar):
    """cftime counterpart of _infer_needed_units_numpy.

    Works with ``datetime.timedelta`` offsets between cftime reference dates
    instead of numpy timedelta64 values.
    """
    needed_units, data_ref_date = _unpack_time_units_and_ref_date_cftime(
        data_units, calendar
    )
    offset = abs(data_ref_date - ref_date)
    step = _time_units_to_timedelta(needed_units)
    # If the reference-date offset is not a whole number of steps, fall back
    # to units fine enough to represent the offset exactly.
    if offset % step > timedelta(seconds=0):
        needed_units = _infer_time_units_from_diff(offset)
    return needed_units


def _eagerly_encode_cf_datetime(
dates: T_DuckArray, # type: ignore[misc]
units: str | None = None,
Expand All @@ -744,6 +746,7 @@ def _eagerly_encode_cf_datetime(
if calendar is None:
calendar = infer_calendar_name(dates)

raise_incompatible_units_error = False
try:
if not _is_standard_calendar(calendar) or dates.dtype.kind == "O":
# parse with cftime instead
Expand All @@ -760,15 +763,7 @@ def _eagerly_encode_cf_datetime(
time_deltas = dates_as_index - ref_date

# retrieve needed units to faithfully encode to int64
needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
if data_units != units:
# this accounts for differences in the reference times
ref_delta = abs(data_ref_date - ref_date).to_timedelta64()
data_delta = _time_units_to_timedelta64(needed_units)
if (ref_delta % data_delta) > np.timedelta64(0, "ns"):
needed_units = _infer_time_units_from_diff(ref_delta)

# needed time delta to encode faithfully to int64
needed_units = _infer_needed_units_numpy(ref_date, data_units)
needed_time_delta = _time_units_to_timedelta64(needed_units)

floor_division = True
Expand All @@ -792,18 +787,47 @@ def _eagerly_encode_cf_datetime(
units = new_units
time_delta = needed_time_delta
floor_division = True
elif np.issubdtype(dtype, np.integer) and not allow_units_modification:
new_units = f"{needed_units} since {format_timestamp(ref_date)}"
raise_incompatible_units_error = True

num = _division(time_deltas, time_delta, floor_division)
num = reshape(num.values, dates.shape)

except (OutOfBoundsDatetime, OverflowError, ValueError):
time_units, ref_date = _unpack_time_units_and_ref_date_cftime(units, calendar)
time_delta_cftime = _time_units_to_timedelta(time_units)
needed_units = _infer_needed_units_cftime(ref_date, data_units, calendar)
needed_time_delta_cftime = _time_units_to_timedelta(needed_units)

if (
np.issubdtype(dtype, np.integer)
and time_delta_cftime > needed_time_delta_cftime
):
new_units = f"{needed_units} since {format_cftime_datetime(ref_date)}"
if allow_units_modification:
units = new_units
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Serializing with units {new_units!r} instead. "
f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
f"Set encoding['units'] to {new_units!r} to silence this warning ."
)
else:
raise_incompatible_units_error = True

num = _encode_datetime_with_cftime(dates, units, calendar)
# do it now only for cftime-based flow
# we already covered for this in pandas-based flow
num = cast_to_int_if_safe(num)

if dtype is not None:
num = _cast_to_dtype_if_safe(num, dtype)
if raise_incompatible_units_error:
raise ValueError(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Consider setting encoding['dtype'] to a floating point dtype to serialize with "
f"units {units!r}. Consider setting encoding['units'] to {new_units!r} to "
f"serialize with an integer dtype."
)

return num, units, calendar

Expand Down Expand Up @@ -912,13 +936,17 @@ def _eagerly_encode_cf_timedelta(
units = needed_units
time_delta = needed_time_delta
floor_division = True
elif np.issubdtype(dtype, np.integer) and not allow_units_modification:
raise ValueError(
f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
f"Consider setting encoding['dtype'] to a floating point dtype to serialize with "
f"units {units!r}. Consider setting encoding['units'] to {needed_units!r} to "
f"serialize with an integer dtype."
)

num = _division(time_deltas, time_delta, floor_division)
num = reshape(num.values, timedeltas.shape)

if dtype is not None:
num = _cast_to_dtype_if_safe(num, dtype)

return num, units


Expand Down
76 changes: 37 additions & 39 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,45 +1513,44 @@ def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None:
"use_cftime", [False, pytest.param(True, marks=requires_cftime)]
)
@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)])
def test_encode_cf_datetime_casting_value_error(use_cftime, use_dask) -> None:
def test_encode_cf_datetime_units_change(use_cftime, use_dask) -> None:
times = date_range(start="2000", freq="12h", periods=3, use_cftime=use_cftime)
encoding = dict(units="days since 2000-01-01", dtype=np.dtype("int64"))
variable = Variable(["time"], times, encoding=encoding)

if use_dask:
variable = variable.chunk({"time": 1})

if not use_cftime and not use_dask:
# In this particular case we automatically modify the encoding units to
# continue encoding with integer values. For all other cases we raise.
with pytest.raises(ValueError, match="Times can't be serialized"):
conventions.encode_cf_variable(variable).compute()
else:
with pytest.warns(UserWarning, match="Times can't be serialized"):
encoded = conventions.encode_cf_variable(variable)
assert encoded.attrs["units"] == "hours since 2000-01-01"
decoded = conventions.decode_cf_variable("name", encoded)
if use_cftime:
expected_units = "hours since 2000-01-01 00:00:00.000000"
else:
expected_units = "hours since 2000-01-01"
assert encoded.attrs["units"] == expected_units
decoded = conventions.decode_cf_variable("name", encoded, use_cftime=use_cftime)
assert_equal(variable, decoded)
else:
with pytest.raises(ValueError, match="Not possible"):
encoded = conventions.encode_cf_variable(variable)
encoded.compute()


@pytest.mark.parametrize(
"use_cftime", [False, pytest.param(True, marks=requires_cftime)]
)
@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)])
@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")])
def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) -> None:
# Regression test for GitHub issue #8542
times = date_range(start="2018", freq="5h", periods=3, use_cftime=use_cftime)
encoding = dict(units="microseconds since 2018-01-01", dtype=dtype)
def test_encode_cf_datetime_precision_loss_regression_test(use_dask):
# Regression test for
# https://github.com/pydata/xarray/issues/9134#issuecomment-2191446463
times = date_range("2000", periods=5, freq="ns")
encoding = dict(units="seconds since 1970-01-01", dtype=np.dtype("int64"))
variable = Variable(["time"], times, encoding=encoding)

if use_dask:
variable = variable.chunk({"time": 1})

with pytest.raises(OverflowError, match="Not possible"):
encoded = conventions.encode_cf_variable(variable)
encoded.compute()
with pytest.raises(ValueError, match="Times can't be serialized"):
conventions.encode_cf_variable(variable).compute()
else:
with pytest.warns(UserWarning, match="Times can't be serialized"):
encoded = conventions.encode_cf_variable(variable)
decoded = conventions.decode_cf_variable("name", encoded)
assert_equal(variable, decoded)


@requires_dask
Expand Down Expand Up @@ -1582,38 +1581,37 @@ def test_encode_cf_timedelta_via_dask(


@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)])
def test_encode_cf_timedelta_casting_value_error(use_dask) -> None:
def test_encode_cf_timedelta_units_change(use_dask) -> None:
timedeltas = pd.timedelta_range(start="0h", freq="12h", periods=3)
encoding = dict(units="days", dtype=np.dtype("int64"))
variable = Variable(["time"], timedeltas, encoding=encoding)

if use_dask:
variable = variable.chunk({"time": 1})

if not use_dask:
# In this particular case we automatically modify the encoding units to
# continue encoding with integer values.
with pytest.raises(ValueError, match="Timedeltas can't be serialized"):
conventions.encode_cf_variable(variable).compute()
else:
# In this case we automatically modify the encoding units to continue
# encoding with integer values.
with pytest.warns(UserWarning, match="Timedeltas can't be serialized"):
encoded = conventions.encode_cf_variable(variable)
assert encoded.attrs["units"] == "hours"
decoded = conventions.decode_cf_variable("name", encoded)
assert_equal(variable, decoded)
else:
with pytest.raises(ValueError, match="Not possible"):
encoded = conventions.encode_cf_variable(variable)
encoded.compute()


@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)])
@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")])
def test_encode_cf_timedelta_casting_overflow_error(use_dask, dtype) -> None:
timedeltas = pd.timedelta_range(start="0h", freq="5h", periods=3)
encoding = dict(units="microseconds", dtype=dtype)
def test_encode_cf_timedelta_small_dtype_missing_value(use_dask):
# Regression test for GitHub issue #9134
timedeltas = np.array([1, 2, "NaT", 4], dtype="timedelta64[D]").astype(
"timedelta64[ns]"
)
encoding = dict(units="days", dtype=np.dtype("int16"), _FillValue=np.int16(-1))
variable = Variable(["time"], timedeltas, encoding=encoding)

if use_dask:
variable = variable.chunk({"time": 1})

with pytest.raises(OverflowError, match="Not possible"):
encoded = conventions.encode_cf_variable(variable)
encoded.compute()
encoded = conventions.encode_cf_variable(variable)
decoded = conventions.decode_cf_variable("name", encoded)
assert_equal(variable, decoded)
Loading