Merge pull request #5 from ecmwf-projects/cuon_fix
Fix CUON OOM errors.
aperezpredictia authored Jan 31, 2024
2 parents 60ae651 + d88ad22 commit cdd595c
Showing 2 changed files with 130 additions and 114 deletions.
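In broad strokes, the OOM fix below restructures cdsobs/ingestion/readers/cuon.py so that each station file is read, fixed and denormalized inside its own dask task, and only the resulting per-file denormalized frames are concatenated at the end, instead of concatenating every raw table across all files up front. The following is a minimal, hedged sketch of that per-file pattern; load_one_file and process_files are illustrative stand-ins rather than real cdsobs functions, while the scheduler selection and worker cap mirror the diff.

import os

import dask
import pandas


def load_one_file(path):
    # Hypothetical stand-in for get_denormalized_table_file: in the real code
    # this reads, fixes and denormalizes all CDM tables of a single station
    # file, and a wrapper returns None when the file has no data.
    return pandas.DataFrame({"observation_value": [1.0], "source_file": [path]})


def process_files(paths):
    if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
        # This is for the tests.
        scheduler = "synchronous"
    else:
        # Do not use threads as HDF5 is not yet thread safe.
        scheduler = "processes"
    futures = [dask.delayed(load_one_file)(path) for path in paths]
    results = dask.compute(
        *futures, scheduler=scheduler, num_workers=min(len(paths), 32)
    )
    # Only the per-file denormalized frames are concatenated at the end, so no
    # intermediate "all files, one table" frame is ever held in memory.
    return pandas.concat([r for r in results if r is not None])


if __name__ == "__main__":
    print(process_files(["station_a.nc", "station_b.nc"]).head())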
6 changes: 3 additions & 3 deletions cdsobs/cli/_object_storage.py
@@ -87,9 +87,9 @@ def check_if_missing_in_catalogue(
object_names, bucket, s3client, catalogue_repo
)
else:
objects = s3client.list_directory_objects(dataset)
object_names = [o.object_name for o in objects]
red_flag = objects_in_catalogue(object_names, dataset, s3client, catalogue_repo)
bucket = s3client.get_bucket_name(dataset)
objects = s3client.list_directory_objects(bucket)
red_flag = objects_in_catalogue(objects, bucket, s3client, catalogue_repo)
if not red_flag:
console.print("[bold green] Found all assets in catalogue [/bold green]")

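For context, here is a self-contained sketch of the corrected else-branch above: the bucket name is now resolved from the dataset before listing, and the listed objects are passed on together with the bucket. Everything below except that call order is a hypothetical stand-in (FakeS3Client, the objects_in_catalogue stub, the bucket naming), kept only so the sketch runs; the real cdsobs signatures may differ.

from dataclasses import dataclass


@dataclass
class FakeObject:
    object_name: str


class FakeS3Client:
    # Hypothetical stand-in for the cdsobs S3 client; only the two methods
    # called in the diff are mocked.
    def get_bucket_name(self, dataset):
        return f"bucket-for-{dataset}"

    def list_directory_objects(self, bucket):
        return [FakeObject("asset_1.nc"), FakeObject("asset_2.nc")]


def objects_in_catalogue(objects, bucket, s3client, catalogue_repo):
    # Stub: pretend every object listed in the bucket is in the catalogue,
    # so no red flag is raised.
    return False


def check_dataset(dataset, s3client, catalogue_repo=None):
    # Call order from the diff: resolve the bucket from the dataset, list that
    # bucket, then pass the objects plus bucket to the catalogue check.
    bucket = s3client.get_bucket_name(dataset)
    objects = s3client.list_directory_objects(bucket)
    return objects_in_catalogue(objects, bucket, s3client, catalogue_repo)


if __name__ == "__main__":
    if not check_dataset("some_dataset", FakeS3Client()):
        print("Found all assets in catalogue")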
238 changes: 127 additions & 111 deletions cdsobs/ingestion/readers/cuon.py
@@ -1,6 +1,5 @@
import calendar
import importlib
import itertools
import os
import statistics
from dataclasses import dataclass
@@ -15,7 +14,7 @@

from cdsobs import constants
from cdsobs.cdm.denormalize import denormalize_tables
from cdsobs.cdm.tables import read_cdm_tables
from cdsobs.cdm.tables import CDMTable, read_cdm_tables
from cdsobs.config import CDSObsConfig
from cdsobs.ingestion.api import EmptyBatchException
from cdsobs.ingestion.core import TimeBatch, TimeSpaceBatch
@@ -201,43 +200,14 @@ def _maybe_swap_bytes(field_data):
return field_data


def read_all_nc_files(
files_and_slices: list[CUONFileandSlices], table_name: str, time_batch: TimeBatch
def read_table_data(
file_and_slices: CUONFileandSlices, table_name: str, time_batch: TimeBatch
) -> pandas.DataFrame:
"""Read nc table of all station files using h5py."""
results = []
if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
# This is for the tests.
scheduler = "synchronous"
else:
# Do not use threads as HDF5 is not yet thread safe.
scheduler = "processes"
# Use dask to speed up the process
for file_and_slices in files_and_slices:
logger.info(f"Reading {table_name=} from {file_and_slices.path}")
results.append(
dask.delayed(_read_nc_file)(file_and_slices, table_name, time_batch)
)
results = dask.compute(
*results, scheduler=scheduler, num_workers=min(len(files_and_slices), 32)
)
result = _read_nc_file(file_and_slices, table_name, time_batch)

results = [r for r in results if r is not None]
if len(results) >= 1:
fields = sorted(
set(itertools.chain.from_iterable([list(r.keys()) for r in results]))
)
final_data = {}
for field in fields:
to_concat = []
for r in results:
if field in r:
to_concat.append(r[field])
else:
file_data_len = len(r[list(r)[0]])
to_concat.append(numpy.repeat(numpy.nan, file_data_len))
final_data[field] = numpy.concatenate(to_concat)
final_df_out = pandas.DataFrame(final_data)
if result is not None:
final_df_out = pandas.DataFrame(result)
else:
final_df_out = pandas.DataFrame()
# Reduce field size for memory efficiency
@@ -295,6 +265,38 @@ def read_cuon_netcdfs(
tables_to_use = config.get_dataset(dataset_name).available_cdm_tables
cdm_tables = read_cdm_tables(config.cdm_tables_location, tables_to_use)
files_and_slices = read_all_nc_slices(files, time_space_batch.time_batch)
denormalized_tables_futures = []
if os.environ.get("CADSOBS_AVOID_MULTIPROCESS"):
# This is for the tests.
scheduler = "synchronous"
else:
# Do not use threads as HDF5 is not yet thread safe.
scheduler = "processes"
# Use dask to speed up the process
for file_and_slices in files_and_slices:
denormalized_table_future = dask.delayed(_get_denormalized_table_file)(
cdm_tables, config, file_and_slices, tables_to_use, time_space_batch
)
if denormalized_table_future is not None:
denormalized_tables_futures.append(denormalized_table_future)
denormalized_tables = dask.compute(
*denormalized_tables_futures,
scheduler=scheduler,
num_workers=min(len(files_and_slices), 32),
)
return pandas.concat(denormalized_tables)


def _get_denormalized_table_file(*args):
try:
return get_denormalized_table_file(*args)
except NoDataInFileException:
return None


def get_denormalized_table_file(
cdm_tables, config, file_and_slices, tables_to_use, time_space_batch
):
dataset_cdm: dict[str, pandas.DataFrame] = {}
for table_name, table_definition in cdm_tables.items():
# Fix era5fb having different names in the CDM and in the files
@@ -303,77 +305,13 @@ def read_cuon_netcdfs(
else:
table_name_in_file = table_name
# Read table data
table_data = read_all_nc_files(
files_and_slices, table_name_in_file, time_space_batch.time_batch
table_data = read_table_data(
file_and_slices, table_name_in_file, time_space_batch.time_batch
)
# Make sure that latitude and longitude always carry their table name.
for coord in ["latitude", "longitude", "source_id"]:
if coord in table_data:
table_data = table_data.rename(
{coord: coord + "|" + table_name}, axis=1
)
# Check that observation id is unique and fix if not
if table_name == "observations_table":
# If there is nothing here it is a waste of time to continue
if len(table_data) == 0:
raise EmptyBatchException
if not table_data.observation_id.is_unique:
logger.warning("observation_id is not unique, fixing")
table_data["observation_id"] = numpy.arange(
len(table_data), dtype="int"
).astype("bytes")
# Remove missing values to save memory
table_data = table_data.loc[~table_data.observation_value.isnull()]
# Try with sparse arrays to reduce memory usage.
for var in table_data:
if str(table_data[var].dtype) == "float32" and var != "observation_value":
table_data[var] = pandas.arrays.SparseArray(table_data[var])
# Check primary keys can be used to build a unique index
primary_keys = table_definition.primary_keys
if table_name in ["era5fb_table", "advanced_homogenisation"]:
table_data = table_data.reset_index()
table_data_len = len(table_data)
obs_table_len = len(dataset_cdm["observations_table"])
logger.warning(
"Filling era5fb table index with observation_id from observations_table"
)
obs_id_name = "obs_id" if table_name == "era5fb_table" else "observation_id"
if table_data_len < obs_table_len:
logger.warning(
"era5fb is shorter than observations_table "
"truncating observation_ids"
)
observation_ids = dataset_cdm[
"observations_table"
].observation_id.values[0:table_data_len]
table_data[obs_id_name] = observation_ids
elif table_data_len > obs_table_len:
logger.warning("era5fb is longer than observations_table " "truncating")
table_data = table_data.iloc[0:obs_table_len]
observation_ids = dataset_cdm[
"observations_table"
].observation_id.values
else:
observation_ids = dataset_cdm[
"observations_table"
].observation_id.values
table_data[obs_id_name] = observation_ids
table_data = table_data.set_index(obs_id_name).rename(
{"index": f"index|{table_name}"}, axis=1
)
if "level_0" in table_data:
table_data = table_data.drop("level_0", axis=1)
# Drop duplicates for header and stations
if table_name == ["station_configuration", "header_table"]:
table_data.drop_duplicates(inplace=True, ignore_index=False)
primary_keys_are_unique = (
table_data.reset_index().set_index(primary_keys).index.is_unique
table_data = _fix_table_data(
dataset_cdm, table_data, table_definition, table_name
)
if not primary_keys_are_unique:
logger.warning(
"Unable to build a unique index with primary_keys "
f"{table_definition.primary_keys} in table {table_name}"
)
dataset_cdm[table_name] = table_data
# Filter stations outside of the batch
lats = dataset_cdm["header_table"]["latitude|header_table"]
@@ -389,24 +327,102 @@
)
dataset_cdm["header_table"] = dataset_cdm["header_table"].loc[spatial_mask]
# Denormalize tables
denormalized_table = denormalize_tables(
denormalized_table_file = denormalize_tables(
cdm_tables, dataset_cdm, tables_to_use, ignore_errors=True
)
del cdm_tables
# Decode time
if len(denormalized_table) > 0:
if len(denormalized_table_file) > 0:
for time_field in ["record_timestamp", "report_timestamp"]:
denormalized_table.loc[:, time_field] = cftime.num2date(
denormalized_table.loc[:, time_field],
denormalized_table_file.loc[:, time_field] = cftime.num2date(
denormalized_table_file.loc[:, time_field],
constants.TIME_UNITS,
only_use_cftime_datetimes=False,
)
else:
logger.warning(f"No data was found in file {file_and_slices.path}")
# Decode variable names
code_dict = get_var_code_dict(config.cdm_tables_location)
denormalized_table["observed_variable"] = denormalized_table[
denormalized_table_file["observed_variable"] = denormalized_table_file[
"observed_variable"
].map(code_dict)
return denormalized_table
return denormalized_table_file


class NoDataInFileException(RuntimeError):
pass


def _fix_table_data(
dataset_cdm: dict[str, pandas.DataFrame],
table_data: pandas.DataFrame,
table_definition: CDMTable,
table_name: str,
):
for coord in ["latitude", "longitude", "source_id"]:
if coord in table_data:
table_data = table_data.rename({coord: coord + "|" + table_name}, axis=1)
# Check that observation id is unique and fix if not
if table_name == "observations_table":
# If there is nothing here it is a waste of time to continue
if len(table_data) == 0:
logger.warning("No data found in file for this times.")
raise NoDataInFileException
if not table_data.observation_id.is_unique:
logger.warning("observation_id is not unique, fixing")
table_data["observation_id"] = numpy.arange(
len(table_data), dtype="int"
).astype("bytes")
# Remove missing values to save memory
table_data = table_data.loc[~table_data.observation_value.isnull()]
# Remove duplicate station records
if table_name == "station_configuration":
table_data = table_data.drop_duplicates(
subset=["primary_id", "record_number"], ignore_index=True
)
# Try with sparse arrays to reduce memory usage.
for var in table_data:
if str(table_data[var].dtype) == "float32" and var != "observation_value":
table_data[var] = pandas.arrays.SparseArray(table_data[var])
# Check primary keys can be used to build a unique index
primary_keys = table_definition.primary_keys
if table_name in ["era5fb_table", "advanced_homogenisation"]:
table_data = table_data.reset_index()
table_data_len = len(table_data)
obs_table_len = len(dataset_cdm["observations_table"])
logger.warning(
"Filling era5fb table index with observation_id from observations_table"
)
obs_id_name = "obs_id" if table_name == "era5fb_table" else "observation_id"
if table_data_len < obs_table_len:
logger.warning(
"era5fb is shorter than observations_table "
"truncating observation_ids"
)
observation_ids = dataset_cdm["observations_table"].observation_id.values[
0:table_data_len
]
table_data[obs_id_name] = observation_ids
elif table_data_len > obs_table_len:
logger.warning("era5fb is longer than observations_table " "truncating")
table_data = table_data.iloc[0:obs_table_len]
observation_ids = dataset_cdm["observations_table"].observation_id.values
else:
observation_ids = dataset_cdm["observations_table"].observation_id.values
table_data[obs_id_name] = observation_ids
table_data = table_data.set_index(obs_id_name).rename(
{"index": f"index|{table_name}"}, axis=1
)
if "level_0" in table_data:
table_data = table_data.drop("level_0", axis=1)
primary_keys_are_unique = (
table_data.reset_index().set_index(primary_keys).index.is_unique
)
if not primary_keys_are_unique:
logger.warning(
"Unable to build a unique index with primary_keys "
f"{table_definition.primary_keys} in table {table_name}"
)
return table_data


def read_nc_file_slices(
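As an aside on the memory tricks in _fix_table_data above: converting float32 columns that are mostly missing into pandas sparse arrays can shrink a table considerably. Below is a small, self-contained illustration; the column names and sizes are made up, and only the dtype check and SparseArray conversion mirror the diff.

import numpy
import pandas

# Build a table where one float32 column is entirely missing
# (names and sizes are illustrative only).
table_data = pandas.DataFrame(
    {
        "observation_value": numpy.random.rand(100_000).astype("float32"),
        "mostly_missing_flag": numpy.full(100_000, numpy.nan, dtype="float32"),
    }
)
dense_bytes = table_data.memory_usage(deep=True).sum()

# Same dtype check and conversion as in _fix_table_data: keep
# observation_value dense, sparsify the other float32 columns.
for var in table_data:
    if str(table_data[var].dtype) == "float32" and var != "observation_value":
        table_data[var] = pandas.arrays.SparseArray(table_data[var])

sparse_bytes = table_data.memory_usage(deep=True).sum()
print(f"dense: {dense_bytes} bytes, sparse: {sparse_bytes} bytes")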
