Skip to content

Commit

Permalink
(squash) Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Nov 25, 2024
1 parent 106b7bc commit 687513a
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 32 deletions.
5 changes: 4 additions & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,10 @@ def validate(self) -> None:

@tracer.start_as_current_span(f"{__name__}.run_workflows")
def run_workflows(
self, runtime: HookRuntime, storage: Storage, ensemble: Ensemble
self,
runtime: HookRuntime,
storage: Storage | None = None,
ensemble: Ensemble | None = None,
) -> None:
for workflow in self.ert_config.hooked_workflows[runtime]:
WorkflowRunner(workflow, storage, ensemble).run_blocking()
Expand Down
7 changes: 3 additions & 4 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np

from ert.config import HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
Expand All @@ -15,7 +16,7 @@
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, HookRuntime, QueueConfig
from ert.config import ErtConfig, QueueConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,6 +65,7 @@ def run_experiment(
) -> None:
self.log_at_startup()
if not restart:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand All @@ -83,16 +85,13 @@ def run_experiment(

self.set_env_key("_ERT_EXPERIMENT_ID", str(self.experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(self.ensemble.id))
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "False")

run_args = create_run_arguments(
self.run_paths,
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
Expand Down
4 changes: 1 addition & 3 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> None:
self.log_at_startup()
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand All @@ -82,10 +83,7 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)
self.set_env_key("_ERT_ITERATION", "0")
self.set_env_key("_IS_FINAL_ITERATION", "True")

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down
4 changes: 2 additions & 2 deletions src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> None:
self.log_at_startup()

self.run_workflows(HookRuntime.PRE_EXPERIMENT)
target_ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand Down Expand Up @@ -151,7 +151,7 @@ def run_experiment(
"_IS_FINAL_ITERATION",
"False",
)
self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, prior)

sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def run_experiment(
f"Prior ensemble with ID: {id} does not exists"
) from err
else:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
sim_args = {"weights": self._relative_weights}
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand All @@ -135,7 +136,6 @@ def run_experiment(
ensemble=prior,
)

self.run_workflows(HookRuntime.PRE_EXPERIMENT, self._storage, self.ensemble)
sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down
41 changes: 23 additions & 18 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,49 +545,49 @@ def test_that_stop_on_fail_workflow_jobs_stop_ert(


@pytest.mark.usefixtures("copy_poly_case")
def test_that_post_experiment_hook_works(
def test_that_pre_post_experiment_hook_works(
monkeypatch,
):
monkeypatch.setattr(_ert.threading, "_can_raise", False)

# The executable
with open("dump_final_ensemble_id.sh", "w", encoding="utf-8") as f:
with open("hello_post_exp.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_IS_FINAL_ITERATION> final_ensemble_info.txt
echo "just sending regards" > from_post_experiment.txt
""")
)
os.chmod("dump_final_ensemble_id.sh", 0o755)
os.chmod("hello_post_exp.sh", 0o755)

# The workflow job
with open("DUMP_FINAL_ENSEMBLE_ID", "w", encoding="utf-8") as s:
with open("SAY_HELLO_POST_EXP", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_final_ensemble_info.sh
EXECUTABLE hello_post_exp.sh
""")

# The workflow
with open("POST_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
with open("SAY_HELLO_POST_EXP.wf", "w", encoding="utf-8") as s:
s.write("""dump_final_ensemble_id""")

# The executable
with open("dump_first_ensemble_id.sh", "w", encoding="utf-8") as f:
with open("hello_pre_exp.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo $_ERT_ITERATION > first_ensemble_id.txt
echo "first" > from_pre_experiment.txt
""")
)
os.chmod("dump_first_ensemble_id.sh", 0o755)
os.chmod("hello_pre_exp.sh", 0o755)

# The workflow job
with open("DUMP_FIRST_ENSEMBLE_ID", "w", encoding="utf-8") as s:
with open("SAY_HELLO_PRE_EXP", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE dump_first_ensemble_id.sh
EXECUTABLE hello_pre_exp.sh
""")

# The workflow
with open("PRE_EXPERIMENT_DUMP.WF", "w", encoding="utf-8") as s:
with open("SAY_HELLO_PRE_EXP.wf", "w", encoding="utf-8") as s:
s.write("""dump_first_ensemble_id""")

with open("poly.ert", mode="a", encoding="utf-8") as fh:
Expand All @@ -596,20 +596,25 @@ def test_that_post_experiment_hook_works(
"""
NUM_REALIZATIONS 2
LOAD_WORKFLOW_JOB DUMP_FINAL_ENSEMBLE_ID dump_final_ensemble_id
LOAD_WORKFLOW POST_EXPERIMENT_DUMP.WF POST_EXPERIMENT_DUMP
LOAD_WORKFLOW_JOB SAY_HELLO_POST_EXP dump_final_ensemble_id
LOAD_WORKFLOW SAY_HELLO_POST_EXP.wf POST_EXPERIMENT_DUMP
HOOK_WORKFLOW POST_EXPERIMENT_DUMP POST_EXPERIMENT
LOAD_WORKFLOW_JOB DUMP_FIRST_ENSEMBLE_ID dump_first_ensemble_id
LOAD_WORKFLOW PRE_EXPERIMENT_DUMP.WF PRE_EXPERIMENT_DUMP
LOAD_WORKFLOW_JOB SAY_HELLO_PRE_EXP dump_first_ensemble_id
LOAD_WORKFLOW SAY_HELLO_PRE_EXP.wf PRE_EXPERIMENT_DUMP
HOOK_WORKFLOW PRE_EXPERIMENT_DUMP PRE_EXPERIMENT
"""
)
)

run_cli(ITERATIVE_ENSEMBLE_SMOOTHER_MODE, "--disable-monitor", "poly.ert")

# ...2do assert correct contents in files
assert (Path(os.getcwd()) / "from_pre_experiment.txt").read_text(
"utf-8"
) == "first\n"
assert (Path(os.getcwd()) / "from_post_experiment.txt").read_text(
"utf-8"
) == "just sending regards\n"


@pytest.fixture(name="mock_cli_run")
Expand Down
11 changes: 8 additions & 3 deletions tests/ert/unit_tests/cli/test_model_hook_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
)

EXPECTED_CALL_ORDER = [
HookRuntime.PRE_EXPERIMENT,
HookRuntime.PRE_SIMULATION,
HookRuntime.POST_SIMULATION,
HookRuntime.PRE_FIRST_UPDATE,
HookRuntime.PRE_UPDATE,
HookRuntime.POST_UPDATE,
HookRuntime.PRE_SIMULATION,
HookRuntime.POST_SIMULATION,
HookRuntime.POST_EXPERIMENT,
]


Expand Down Expand Up @@ -57,7 +59,8 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER[1:]],
]
assert run_wfs_mock.mock_calls == expected_calls

Expand Down Expand Up @@ -93,7 +96,8 @@ def test_hook_call_order_es_mda(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER[1:]],
]
assert run_wfs_mock.mock_calls == expected_calls

Expand Down Expand Up @@ -128,6 +132,7 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER[1:]],
]
assert run_wfs_mock.mock_calls == expected_calls

0 comments on commit 687513a

Please sign in to comment.