Skip to content

Commit

Permalink
Clarify computation of accumulated precipitation in docstring and tests.
Browse files Browse the repository at this point in the history
BUGFIX in compute_derived_variables, whereby we remove (possibly) dropped dims from source_chunks.

PiperOrigin-RevId: 679676003
  • Loading branch information
langmore authored and Weatherbench2 authors committed Sep 27, 2024
1 parent 47d7257 commit de3f56e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
14 changes: 10 additions & 4 deletions scripts/compute_derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def main(argv: list[str]) -> None:

source_dataset, source_chunks = xbeam.open_zarr(INPUT_PATH.value)

for var_name in PREEXISTING_VARIABLES_TO_REMOVE.value:
if var_name in source_dataset:
del source_dataset[var_name]
source_chunks = {
# Removing variables may remove some dims.
k: v
for k, v in source_chunks.items()
if k in source_dataset.dims
}

# Validate and clean up the source dataset.
if RENAME_RAW_TP_NAME.value:
source_dataset = source_dataset.rename(
Expand All @@ -168,10 +178,6 @@ def main(argv: list[str]) -> None:
rename_variables.get(k, k): v for k, v in source_chunks.items()
}

for var_name in PREEXISTING_VARIABLES_TO_REMOVE.value:
if var_name in source_dataset:
del source_dataset[var_name]

for var_name, dv in derived_variables.items():
if var_name in source_dataset:
raise ValueError(
Expand Down
7 changes: 5 additions & 2 deletions weatherbench2/derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,11 @@ def compute(self, dataset: xr.Dataset) -> xr.DataArray:
class PrecipitationAccumulation(DerivedVariable):
"""Compute precipitation accumulation from hourly accumulations.
Accumulation is computed for the time period leading up to the lead_time.
E.g. 24h accumulation at lead_time=24h indicates 0-24h accumulation.
Accumulation is computed for the time period leading up to and including the
lead_time. E.g. 24h accumulation at lead_time=24h indicates accumulation
from lead_time=0h to lead_time=24h. This is equal to the value of
`total_precipitation_name` at 24h, minus the value at 0h.
Caution: Small negative values sometimes appear in model output.
Here, we set them to zero.
Expand Down
31 changes: 31 additions & 0 deletions weatherbench2/derived_variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ def testPrecipitationAccumulation6hr(self):
accumulation_hours=6,
)
result = derived_variable.compute(dataset)

# Test a few specific times for example's sake.
# We want to verify that
# PrecipAccum6hr[t] = ReLu(TotalPrecip[t] - TotalPrecip[t - 6])
sel = lambda ds, hr: ds.sel(prediction_timedelta=f'{hr}hr')
relu = lambda ds: np.maximum(0, ds)
np.testing.assert_array_equal(
relu(sel(dataset, 24) - sel(dataset, 24 - 6)).total_precipitation.data,
sel(result, 24),
)
np.testing.assert_array_equal(
relu(sel(dataset, 18) - sel(dataset, 18 - 6)).total_precipitation.data,
sel(result, 18),
)

# Test every timedelta.
expected = xr.DataArray(
[np.nan, 5, 10, 0, 6, 10, 0],
dims=['prediction_timedelta'],
Expand All @@ -154,6 +170,21 @@ def testPrecipitationAccumulation24hr(self):
accumulation_hours=24,
)
result = derived_variable.compute(dataset)

# Test a few specific times for example's sake.
# We want to verify that
# PrecipAccum24hr[t] = ReLu(TotalPrecip[t] - TotalPrecip[t - 24])
sel = lambda ds, hr: ds.sel(prediction_timedelta=f'{hr}hr')
relu = lambda ds: np.maximum(0, ds)
np.testing.assert_array_equal(
relu(sel(dataset, 36) - sel(dataset, 36 - 24)).total_precipitation.data,
sel(result, 36),
)
np.testing.assert_array_equal(
relu(sel(dataset, 30) - sel(dataset, 30 - 24)).total_precipitation.data,
sel(result, 30),
)

expected = xr.DataArray(
[np.nan, np.nan, np.nan, np.nan, 20, 25, 15],
dims=['prediction_timedelta'],
Expand Down

0 comments on commit de3f56e

Please sign in to comment.