Skip to content

Commit

Permalink
Convert ErtConfig to dataclass
Browse files: browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Oct 5, 2024
1 parent 611105a commit 329b324
Show file tree
Hide file tree
Showing 21 changed files with 187 additions and 85 deletions.
15 changes: 10 additions & 5 deletions src/ert/config/ensemble_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from ert.field_utils import get_shape

from .field import Field
from .ext_param_config import ExtParamConfig
from .field import Field as FieldConfig
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig
from .parameter_config import ParameterConfig
Expand Down Expand Up @@ -49,8 +50,12 @@ def _get_abs_path(file: Optional[str]) -> Optional[str]:
@dataclass
class EnsembleConfig:
grid_file: Optional[str] = None
response_configs: Dict[str, ResponseConfig] = field(default_factory=dict)
parameter_configs: Dict[str, ParameterConfig] = field(default_factory=dict)
response_configs: Dict[str, Union[SummaryConfig, GenDataConfig]] = field(
default_factory=dict
)
parameter_configs: Dict[
str, GenKwConfig | FieldConfig | SurfaceConfig | ExtParamConfig
] = field(default_factory=dict)
refcase: Optional[Refcase] = None
eclbase: Optional[str] = None

Expand Down Expand Up @@ -93,7 +98,7 @@ def from_dict(cls, config_dict: ConfigDict) -> EnsembleConfig:
grid_file_path,
) from err

def make_field(field_list: List[str]) -> Field:
def make_field(field_list: List[str]) -> FieldConfig:
if grid_file_path is None:
raise ConfigValidationError.with_context(
"In order to use the FIELD keyword, a GRID must be supplied.",
Expand All @@ -104,7 +109,7 @@ def make_field(field_list: List[str]) -> Field:
f"Grid file {grid_file_path} did not contain dimensions",
grid_file_path,
)
return Field.from_config_list(grid_file_path, dims, field_list)
return FieldConfig.from_config_list(grid_file_path, dims, field_list)

parameter_configs = (
[GenKwConfig.from_config_list(g) for g in gen_kw_list]
Expand Down
55 changes: 41 additions & 14 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import field
from datetime import datetime
from os import path
from pathlib import Path
from typing import (
Any,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Expand All @@ -24,6 +25,8 @@

import xarray as xr
from pydantic import ValidationError as PydanticValidationError
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.plugins import ErtPluginManager
Expand All @@ -49,6 +52,7 @@
ConfigWarning,
ErrorInfo,
ForwardModelStepKeys,
HistorySource,
HookRuntime,
init_forward_model_schema,
init_site_config_schema,
Expand Down Expand Up @@ -98,7 +102,9 @@ class ErtConfig:
queue_config: QueueConfig = field(default_factory=QueueConfig)
workflow_jobs: Dict[str, WorkflowJob] = field(default_factory=dict)
workflows: Dict[str, Workflow] = field(default_factory=dict)
hooked_workflows: Dict[HookRuntime, List[Workflow]] = field(default_factory=dict)
hooked_workflows: DefaultDict[HookRuntime, List[Workflow]] = field(
default_factory=lambda: defaultdict(list)
)
runpath_file: Path = Path(DEFAULT_RUNPATH_FILE)
ert_templates: List[Tuple[str, str]] = field(default_factory=list)
installed_forward_model_steps: Dict[str, ForwardModelStep] = field(
Expand All @@ -111,15 +117,21 @@ class ErtConfig:
observation_config: List[
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)
enkf_obs: EnkfObs = field(default_factory=EnkfObs)

@field_validator("substitution_list", mode="before")
@classmethod
def convert_to_substitution_list(cls, v: Dict[str, str]) -> SubstitutionList:
if isinstance(v, SubstitutionList):
return v
return SubstitutionList(v)

def __post_init__(self) -> None:
self.config_path = (
path.dirname(path.abspath(self.user_config_file))
if self.user_config_file
else os.getcwd()
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, xr.Dataset] = self.enkf_obs.datasets

@staticmethod
Expand Down Expand Up @@ -276,7 +288,7 @@ def from_dict(cls, config_dict) -> Self:
errors.append(err)

obs_config_file = config_dict.get(ConfigKeys.OBS_CONFIG)
obs_config_content = None
obs_config_content = []
try:
if obs_config_file:
if path.isfile(obs_config_file) and path.getsize(obs_config_file) == 0:
Expand Down Expand Up @@ -307,6 +319,19 @@ def from_dict(cls, config_dict) -> Self:
[key] for key in summary_obs if key not in summary_keys
]
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
if model_config:
observations = cls._create_observations(
obs_config_content,
ensemble_config,
model_config.time_map,
model_config.history_source,
)
else:
errors.append(
ConfigValidationError(
"Not possible to validate observations without valid model config"
)
)
except ConfigValidationError as err:
errors.append(err)

Expand Down Expand Up @@ -339,6 +364,7 @@ def from_dict(cls, config_dict) -> Self:
model_config=model_config,
user_config_file=config_file_path,
observation_config=obs_config_content,
enkf_obs=observations,
)

@classmethod
Expand Down Expand Up @@ -947,24 +973,25 @@ def _installed_forward_model_steps_from_dict(
def preferred_num_cpu(self) -> int:
return int(self.substitution_list.get(f"<{ConfigKeys.NUM_CPU}>", 1))

@staticmethod
def _create_observations(
self,
obs_config_content: Optional[
Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]]
],
ensemble_config: EnsembleConfig,
time_map: Optional[List[datetime]],
history: HistorySource,
) -> EnkfObs:
if not obs_config_content:
return EnkfObs({}, [])
obs_vectors: Dict[str, ObsVector] = {}
obs_time_list: Sequence[datetime] = []
if self.ensemble_config.refcase is not None:
obs_time_list = self.ensemble_config.refcase.all_dates
elif self.model_config.time_map is not None:
obs_time_list = self.model_config.time_map
if ensemble_config.refcase is not None:
obs_time_list = ensemble_config.refcase.all_dates
elif time_map is not None:
obs_time_list = time_map

history = self.model_config.history_source
time_len = len(obs_time_list)
ensemble_config = self.ensemble_config
config_errors: List[ErrorInfo] = []
for obs_name, values in obs_config_content:
try:
Expand Down Expand Up @@ -1036,7 +1063,7 @@ def _get_files_in_directory(job_path, errors):


def _substitution_list_from_dict(config_dict) -> SubstitutionList:
subst_list = SubstitutionList()
subst_list = {}

for key, val in config_dict.get("DEFINE", []):
subst_list[key] = val
Expand All @@ -1054,7 +1081,7 @@ def _substitution_list_from_dict(config_dict) -> SubstitutionList:
for key, val in config_dict.get("DATA_KW", []):
subst_list[key] = val

return subst_list
return SubstitutionList(subst_list)


@no_type_check
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import logging
import os
import time
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union, overload

import numpy as np
import xarray as xr
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.field_utils import FieldFileFormat, Shape, read_field, read_mask, save_field
Expand Down
9 changes: 9 additions & 0 deletions src/ert/config/forward_model_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from dataclasses import dataclass, field
from typing import (
ClassVar,
Dict,
Literal,
Optional,
TypedDict,
Union,
)

from pydantic import field_validator
from typing_extensions import NotRequired, Unpack

from ert.config.parsing.config_errors import ConfigWarning
Expand Down Expand Up @@ -174,6 +176,13 @@ class ForwardModelStep:
"_ERT_RUNPATH": "<RUNPATH>",
}

@field_validator("private_args", mode="before")
@classmethod
def convert_to_substitution_list(cls, v: Dict[str, str]) -> SubstitutionList:
if isinstance(v, SubstitutionList):
return v
return SubstitutionList(v)

def validate_pre_experiment(self, fm_step_json: ForwardModelStepJSON) -> None:
"""
Raise errors pertaining to the environment not being
Expand Down
4 changes: 1 addition & 3 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ class TransformFunctionDefinition:
class GenKwConfig(ParameterConfig):
template_file: Optional[str]
output_file: Optional[str]
transform_function_definitions: (
List[TransformFunctionDefinition] | List[Dict[Any, Any]]
)
transform_function_definitions: List[TransformFunctionDefinition]
forward_init_file: Optional[str] = None

def __post_init__(self) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/ert/config/general_observation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import numpy as np
import numpy.typing as npt


@dataclass(eq=False)
class GenObservation:
values: npt.NDArray[np.double]
stds: npt.NDArray[np.double]
indices: npt.NDArray[np.int32]
std_scaling: npt.NDArray[np.double]
values: List[float]
stds: List[float]
indices: List[int]
std_scaling: List[float]

def __post_init__(self) -> None:
for val in self.stds:
Expand Down
10 changes: 6 additions & 4 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -38,8 +38,8 @@ def history_key(key: str) -> str:

@dataclass
class EnkfObs:
obs_vectors: Dict[str, ObsVector]
obs_time: List[datetime]
obs_vectors: Dict[str, ObsVector] = field(default_factory=dict)
obs_time: List[datetime] = field(default_factory=list)

def __post_init__(self) -> None:
self.datasets: Dict[str, xr.Dataset] = {
Expand Down Expand Up @@ -370,7 +370,9 @@ def _create_gen_obs(
f"index list ({indices}) must be of equal length",
obs_file if obs_file is not None else "",
)
return GenObservation(values, stds, indices, std_scaling)
return GenObservation(
values.tolist(), stds.tolist(), indices.tolist(), std_scaling.tolist()
)

@classmethod
def _handle_general_observation(
Expand Down
16 changes: 8 additions & 8 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import re
import shutil
from abc import abstractmethod
from dataclasses import asdict, dataclass, field, fields
from dataclasses import asdict, field, fields
from typing import Any, Dict, List, Literal, Mapping, Optional, Union, no_type_check

import pydantic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

from .parsing import (
Expand Down Expand Up @@ -81,7 +81,7 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.dataclasses.dataclass
class LocalQueueOptions(QueueOptions):
name: Literal[QueueSystem.LOCAL.lower()] = QueueSystem.LOCAL.lower()
name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL

@property
def driver_options(self) -> Dict[str, Any]:
Expand All @@ -90,7 +90,7 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.dataclasses.dataclass
class LsfQueueOptions(QueueOptions):
name: Literal[QueueSystem.LSF.lower()] = QueueSystem.LSF.lower()
name: Literal[QueueSystem.LSF] = QueueSystem.LSF
bhist_cmd: Optional[NonEmptyString] = None
bjobs_cmd: Optional[NonEmptyString] = None
bkill_cmd: Optional[NonEmptyString] = None
Expand All @@ -113,7 +113,7 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.dataclasses.dataclass
class TorqueQueueOptions(QueueOptions):
name: Literal[QueueSystem.TORQUE.lower()] = QueueSystem.TORQUE.lower()
name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE
qsub_cmd: Optional[NonEmptyString] = None
qstat_cmd: Optional[NonEmptyString] = None
qdel_cmd: Optional[NonEmptyString] = None
Expand Down Expand Up @@ -149,7 +149,7 @@ def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]:

@pydantic.dataclasses.dataclass
class SlurmQueueOptions(QueueOptions):
name: Literal[QueueSystem.SLURM.lower()]
name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM
sbatch: NonEmptyString = "sbatch"
scancel: NonEmptyString = "scancel"
scontrol: NonEmptyString = "scontrol"
Expand Down Expand Up @@ -269,7 +269,7 @@ class QueueConfig:
queue_system: QueueSystem = QueueSystem.LOCAL
queue_options: Union[
LsfQueueOptions, TorqueQueueOptions, SlurmQueueOptions, LocalQueueOptions
] = Field(default_factory=LocalQueueOptions, discriminator="name")
] = pydantic.Field(default_factory=LocalQueueOptions, discriminator="name")
queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions)
stop_long_running: bool = False

Expand Down Expand Up @@ -363,7 +363,7 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
selected_queue_system,
queue_options,
queue_options_test_run,
stop_long_running=stop_long_running,
stop_long_running=bool(stop_long_running),
)

def create_local_copy(self) -> QueueConfig:
Expand Down
Loading

0 comments on commit 329b324

Please sign in to comment.