Make it possible to serialize and deserialize ErtConfig #8820

Merged
merged 3 commits on Nov 14, 2024
10 changes: 6 additions & 4 deletions src/ert/config/design_matrix.py
@@ -29,10 +29,12 @@ class DesignMatrix:
xls_filename: Path
design_sheet: str
default_sheet: str
num_realizations: Optional[int] = None
active_realizations: Optional[List[bool]] = None
design_matrix_df: Optional[pd.DataFrame] = None
parameter_configuration: Optional[Dict[str, ParameterConfig]] = None

def __post_init__(self) -> None:
self.num_realizations: Optional[int] = None
self.active_realizations: Optional[List[bool]] = None
self.design_matrix_df: Optional[pd.DataFrame] = None
self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None

@classmethod
def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
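The DesignMatrix change moves the runtime-populated attributes out of __post_init__ and into declared optional fields, which is what lets them appear in a serialized dump. A minimal stdlib-only sketch of why this matters (toy class names; the real DesignMatrix also carries a DataFrame and parameter configs):

from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass
class Before:
    name: str

    def __post_init__(self) -> None:
        # Assigned at runtime only -- not a declared dataclass field.
        self.num_realizations: Optional[int] = None


@dataclass
class After:
    name: str
    num_realizations: Optional[int] = None
    active_realizations: Optional[List[bool]] = None


asdict(Before(name="doe"))  # {'name': 'doe'} -- the runtime attribute is dropped
asdict(After(name="doe"))   # includes num_realizations and active_realizations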
30 changes: 10 additions & 20 deletions src/ert/config/ensemble_config.py
@@ -15,7 +15,8 @@

from ert.field_utils import get_shape

from .field import Field
from .ext_param_config import ExtParamConfig
from .field import Field as FieldConfig
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig
from .parameter_config import ParameterConfig
@@ -49,8 +50,12 @@ def _get_abs_path(file: Optional[str]) -> Optional[str]:
@dataclass
class EnsembleConfig:
grid_file: Optional[str] = None
response_configs: Dict[str, ResponseConfig] = field(default_factory=dict)
parameter_configs: Dict[str, ParameterConfig] = field(default_factory=dict)
response_configs: Dict[str, Union[SummaryConfig, GenDataConfig]] = field(
default_factory=dict
)
parameter_configs: Dict[
str, GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig
] = field(default_factory=dict)
refcase: Optional[Refcase] = None

def __post_init__(self) -> None:
@@ -92,7 +97,7 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
grid_file_path,
) from err

def make_field(field_list: List[str]) -> Field:
def make_field(field_list: List[str]) -> FieldConfig:
if grid_file_path is None:
raise ConfigValidationError.with_context(
"In order to use the FIELD keyword, a GRID must be supplied.",
@@ -103,7 +108,7 @@ def make_field(field_list: List[str]) -> Field:
f"Grid file {grid_file_path} did not contain dimensions",
grid_file_path,
)
return Field.from_config_list(grid_file_path, dims, field_list)
return FieldConfig.from_config_list(grid_file_path, dims, field_list)

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
@@ -152,21 +157,6 @@ def hasNodeGenData(self, key: str) -> bool:
config = self.response_configs["gen_data"]
return key in config.keys

def addNode(self, config_node: Union[ParameterConfig, ResponseConfig]) -> None:
assert config_node is not None
if config_node.name in self:
raise ConfigValidationError(
f"Config node with key {config_node.name!r} already present in ensemble config"
)

if isinstance(config_node, ParameterConfig):
logger.info(
f"Adding {type(config_node).__name__} config (of size {len(config_node)}) to parameter_configs"
)
self.parameter_configs[config_node.name] = config_node
else:
self.response_configs[config_node.name] = config_node

def get_keylist_gen_kw(self) -> List[str]:
return [
val.name
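Typing response_configs and parameter_configs as unions of concrete config classes (instead of the ParameterConfig/ResponseConfig base classes) gives pydantic a closed set of types it can rebuild when deserializing; the removed addNode helper relied on the broader base-class typing. A rough sketch of the idea with made-up stand-in classes, not ert's real SummaryConfig/GenDataConfig:

from dataclasses import field
from typing import Dict, List, Union

from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass


@dataclass
class ResponseA:
    keys: List[str]


@dataclass
class ResponseB:
    input_files: List[str]


@dataclass
class EnsembleSketch:
    # A union of concrete types tells pydantic exactly which classes it may
    # instantiate when validating a dumped config back into objects.
    responses: Dict[str, Union[ResponseA, ResponseB]] = field(default_factory=dict)


adapter = TypeAdapter(EnsembleSketch)
cfg = EnsembleSketch(responses={"gen_data": ResponseB(input_files=["out.txt"])})
restored = adapter.validate_python(adapter.dump_python(cfg))
assert isinstance(restored.responses["gen_data"], ResponseB)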
55 changes: 41 additions & 14 deletions src/ert/config/ert_config.py
@@ -4,13 +4,14 @@
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import field
from datetime import datetime
from os import path
from pathlib import Path
from typing import (
Any,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
@@ -24,6 +25,8 @@

import polars
from pydantic import ValidationError as PydanticValidationError
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.plugins import ErtPluginManager
@@ -49,6 +52,7 @@
ConfigWarning,
ErrorInfo,
ForwardModelStepKeys,
HistorySource,
HookRuntime,
init_forward_model_schema,
init_site_config_schema,
@@ -256,7 +260,9 @@ class ErtConfig:
queue_config: QueueConfig = field(default_factory=QueueConfig)
workflow_jobs: Dict[str, WorkflowJob] = field(default_factory=dict)
workflows: Dict[str, Workflow] = field(default_factory=dict)
hooked_workflows: Dict[HookRuntime, List[Workflow]] = field(default_factory=dict)
hooked_workflows: DefaultDict[HookRuntime, List[Workflow]] = field(
default_factory=lambda: defaultdict(list)
)
runpath_file: Path = Path(DEFAULT_RUNPATH_FILE)
ert_templates: List[Tuple[str, str]] = field(default_factory=list)
installed_forward_model_steps: Dict[str, ForwardModelStep] = field(
@@ -269,6 +275,14 @@
observation_config: List[
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)
enkf_obs: EnkfObs = field(default_factory=EnkfObs)

@field_validator("substitutions", mode="before")
@classmethod
def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions:
if isinstance(v, Substitutions):
return v
return Substitutions(v)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
@@ -298,8 +312,6 @@ def __post_init__(self) -> None:
if self.user_config_file
else os.getcwd()
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, polars.DataFrame] = self.enkf_obs.datasets

@staticmethod
@@ -456,7 +468,7 @@ def from_dict(cls, config_dict) -> Self:
errors.append(err)

obs_config_file = config_dict.get(ConfigKeys.OBS_CONFIG)
obs_config_content = None
obs_config_content = []
try:
if obs_config_file:
if path.isfile(obs_config_file) and path.getsize(obs_config_file) == 0:
@@ -487,6 +499,19 @@ def from_dict(cls, config_dict) -> Self:
[key] for key in summary_obs if key not in summary_keys
]
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
if model_config:
observations = cls._create_observations(
obs_config_content,
ensemble_config,
model_config.time_map,
model_config.history_source,
)
else:
errors.append(
ConfigValidationError(
"Not possible to validate observations without valid model config"
)
)
except ConfigValidationError as err:
errors.append(err)

@@ -519,6 +544,7 @@ def from_dict(cls, config_dict) -> Self:
model_config=model_config,
user_config_file=config_file_path,
observation_config=obs_config_content,
enkf_obs=observations,
)

@classmethod
@@ -970,24 +996,25 @@ def _installed_forward_model_steps_from_dict(
def preferred_num_cpu(self) -> int:
return int(self.substitutions.get(f"<{ConfigKeys.NUM_CPU}>", 1))

@staticmethod
def _create_observations(
self,
obs_config_content: Optional[
Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]]
],
ensemble_config: EnsembleConfig,
time_map: Optional[List[datetime]],
history: HistorySource,
) -> EnkfObs:
if not obs_config_content:
return EnkfObs({}, [])
obs_vectors: Dict[str, ObsVector] = {}
obs_time_list: Sequence[datetime] = []
if self.ensemble_config.refcase is not None:
obs_time_list = self.ensemble_config.refcase.all_dates
elif self.model_config.time_map is not None:
obs_time_list = self.model_config.time_map
if ensemble_config.refcase is not None:
obs_time_list = ensemble_config.refcase.all_dates
elif time_map is not None:
obs_time_list = time_map

history = self.model_config.history_source
time_len = len(obs_time_list)
ensemble_config = self.ensemble_config
config_errors: List[ErrorInfo] = []
for obs_name, values in obs_config_content:
try:
@@ -1059,7 +1086,7 @@ def _get_files_in_directory(job_path, errors):


def _substitutions_from_dict(config_dict) -> Substitutions:
subst_list = Substitutions()
subst_list = {}

for key, val in config_dict.get("DEFINE", []):
subst_list[key] = val
@@ -1077,7 +1104,7 @@ def _substitutions_from_dict(config_dict) -> Substitutions:
for key, val in config_dict.get("DATA_KW", []):
subst_list[key] = val

return subst_list
return Substitutions(subst_list)


@no_type_check
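In ErtConfig itself, hooked_workflows is now typed as a DefaultDict with an explicit default factory, enkf_obs becomes a declared field built eagerly in from_dict (via the new staticmethod _create_observations) rather than in __post_init__, and substitutions gains a mode="before" validator. A small sketch, assuming pydantic v2 and using toy field names, of how a DefaultDict-typed field can survive a dump/load round trip:

from collections import defaultdict
from dataclasses import field
from typing import DefaultDict, List

from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass


@dataclass
class HooksSketch:
    # Mutable defaults need default_factory; the DefaultDict annotation lets
    # pydantic rebuild a defaultdict (not a plain dict) when validating.
    hooked: DefaultDict[str, List[str]] = field(
        default_factory=lambda: defaultdict(list)
    )


adapter = TypeAdapter(HooksSketch)
cfg = HooksSketch()
cfg.hooked["PRE_SIMULATION"].append("MY_WORKFLOW")
restored = adapter.validate_python(adapter.dump_python(cfg))
restored.hooked["POST_SIMULATION"]  # still a defaultdict: missing keys yield []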
2 changes: 1 addition & 1 deletion src/ert/config/field.py
@@ -3,13 +3,13 @@
import logging
import os
import time
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union, overload

import numpy as np
import xarray as xr
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.field_utils import FieldFileFormat, Shape, read_field, read_mask, save_field
9 changes: 9 additions & 0 deletions src/ert/config/forward_model_step.py
@@ -5,12 +5,14 @@
from dataclasses import dataclass, field
from typing import (
ClassVar,
Dict,
Literal,
Optional,
TypedDict,
Union,
)

from pydantic import field_validator
from typing_extensions import NotRequired, Unpack

from ert.config.parsing.config_errors import ConfigWarning
@@ -172,6 +174,13 @@ class ForwardModelStep:
"_ERT_RUNPATH": "<RUNPATH>",
}

@field_validator("private_args", mode="before")
@classmethod
def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions:
if isinstance(v, Substitutions):
return v
return Substitutions(v)

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
"""
Raise errors pertaining to the environment not being
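Both ErtConfig.substitutions and ForwardModelStep.private_args use the same mode="before" validator so that a plain dict read back from a serialized config is re-wrapped as Substitutions. A self-contained sketch of the pattern; the Substitutions stand-in and arbitrary_types_allowed are assumptions of the sketch, not necessarily how ert wires it up:

from typing import Dict

from pydantic import ConfigDict, field_validator
from pydantic.dataclasses import dataclass


class Substitutions(dict):
    """Stand-in for ert's Substitutions mapping; treated here as a dict subclass."""


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class StepSketch:
    private_args: Substitutions

    @field_validator("private_args", mode="before")
    @classmethod
    def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions:
        # Accept either an existing Substitutions or a plain dict (e.g. from JSON).
        return v if isinstance(v, Substitutions) else Substitutions(v)


step = StepSketch(private_args={"<ARG>": "value"})
assert isinstance(step.private_args, Substitutions)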
10 changes: 5 additions & 5 deletions src/ert/config/general_observation.py
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import numpy as np
import numpy.typing as npt


@dataclass(eq=False)
class GenObservation:
values: npt.NDArray[np.double]
stds: npt.NDArray[np.double]
indices: npt.NDArray[np.int32]
std_scaling: npt.NDArray[np.double]
values: List[float]
stds: List[float]
indices: List[int]
std_scaling: List[float]

def __post_init__(self) -> None:
for val in self.stds:
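GenObservation switching from numpy arrays to plain lists fits the serialization goal: lists round-trip through standard serializers, while ndarrays need extra handling. A tiny generic illustration, not ert-specific:

import json

import numpy as np

values = np.array([1.0, 2.0, 3.0])
# json.dumps(values) would raise TypeError: ndarray is not JSON serializable.
serialized = json.dumps(values.tolist())  # plain lists serialize cleanly
restored = json.loads(serialized)         # [1.0, 2.0, 3.0]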
10 changes: 6 additions & 4 deletions src/ert/config/observations.py
@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union
@@ -39,8 +39,8 @@ def history_key(key: str) -> str:

@dataclass
class EnkfObs:
obs_vectors: Dict[str, ObsVector]
obs_time: List[datetime]
obs_vectors: Dict[str, ObsVector] = field(default_factory=dict)
obs_time: List[datetime] = field(default_factory=list)

def __post_init__(self) -> None:
grouped: Dict[str, List[polars.DataFrame]] = {}
@@ -394,7 +394,9 @@ def _create_gen_obs(
f"index list ({indices}) must be of equal length",
obs_file if obs_file is not None else "",
)
return GenObservation(values, stds, indices, std_scaling)
return GenObservation(
values.tolist(), stds.tolist(), indices.tolist(), std_scaling.tolist()
)

@classmethod
def _handle_general_observation(
3 changes: 1 addition & 2 deletions src/ert/config/queue_config.py
@@ -8,7 +8,6 @@
from typing import Any, Dict, List, Literal, Mapping, Optional, Union, no_type_check

import pydantic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

@@ -270,7 +269,7 @@ class QueueConfig:
queue_system: QueueSystem = QueueSystem.LOCAL
queue_options: Union[
LsfQueueOptions, TorqueQueueOptions, SlurmQueueOptions, LocalQueueOptions
] = Field(default_factory=LocalQueueOptions, discriminator="name")
] = pydantic.Field(default_factory=LocalQueueOptions, discriminator="name")
queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions)
stop_long_running: bool = False

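The queue_options field keeps its discriminated union, now referencing pydantic.Field directly. The discriminator is what makes deserialization unambiguous: the "name" tag in a dumped config selects which concrete options class to rebuild. A reduced sketch with toy option classes:

from typing import Literal, Union

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class LocalOptions:
    name: Literal["local"] = "local"
    max_running: int = 0


@dataclass
class LsfOptions:
    name: Literal["lsf"] = "lsf"
    lsf_queue: str = ""


@dataclass
class QueueSketch:
    options: Union[LsfOptions, LocalOptions] = Field(
        default_factory=LocalOptions, discriminator="name"
    )


QueueSketch()                                            # defaults to LocalOptions
QueueSketch(options={"name": "lsf", "lsf_queue": "mr"})  # "name" tag picks LsfOptions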
8 changes: 4 additions & 4 deletions src/ert/config/refcase.py
@@ -1,14 +1,12 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
List,
Optional,
Sequence,
)

import numpy as np
import numpy.typing as npt

from ._read_summary import read_summary
from .parsing.config_dict import ConfigDict
@@ -21,7 +19,7 @@ class Refcase:
start_date: datetime
keys: List[str]
dates: Sequence[datetime]
values: npt.NDArray[Any]
values: List[List[float]]

def __eq__(self, other: object) -> bool:
if not isinstance(other, Refcase):
@@ -50,5 +48,7 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional["Refcase"]:
raise ConfigValidationError(f"Could not read refcase: {err}") from err

return (
cls(start_date, refcase_keys, time_map, data) if data is not None else None
cls(start_date, refcase_keys, time_map, data.tolist())
if data is not None
else None
)