Skip to content

Commit

Permalink
Initialize ert storage outside Simulator.start
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk authored Oct 16, 2024
1 parent 353edda commit 2d4421b
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 171 deletions.
5 changes: 3 additions & 2 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
experiment: Experiment,
controls: Iterable[str],
results: Iterable[str],
callback: Optional[Callable[[BatchContext], None]] = None,
Expand Down Expand Up @@ -98,6 +99,7 @@ def callback(*args, **kwargs):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.experiment = experiment
self.control_keys = set(controls)
self.result_keys = set(results)
self.callback = callback
Expand Down Expand Up @@ -162,7 +164,6 @@ def start(
self,
case_name: str,
case_data: List[Tuple[int, Dict[str, Dict[str, Any]]]],
experiment: Experiment,
) -> BatchContext:
"""Start batch simulation, return a simulation context
Expand Down Expand Up @@ -221,7 +222,7 @@ def start(
time, so when you have called the 'start' method you need to let that
batch complete before you start a new batch.
"""
ensemble = experiment.create_ensemble(
ensemble = self.experiment.create_ensemble(
name=case_name,
ensemble_size=self.ert_config.model_config.num_realizations,
)
Expand Down
4 changes: 2 additions & 2 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage):
responses=[],
)

_server = BatchSimulator(ert_config, {}, [])
_context = _server.start("dispatch_server", [(0, {})], experiment)
_server = BatchSimulator(ert_config, experiment, {}, [])
_context = _server.start("dispatch_server", [(0, {})])

return _context

Expand Down
62 changes: 30 additions & 32 deletions src/everest/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import time
from collections import defaultdict
from datetime import datetime
from itertools import count
from typing import Any, DefaultDict, Dict, List, Mapping, Optional, Tuple

Expand All @@ -10,20 +10,30 @@
from ropt.evaluator import EvaluatorContext, EvaluatorResult

from ert import BatchSimulator, WorkflowRunner
from ert.config import HookRuntime
from ert.storage import open_storage
from ert.config import ErtConfig, HookRuntime
from ert.storage import Storage
from everest.config import EverestConfig
from everest.simulator.everest_to_ert import everest_to_ert_config


class Simulator(BatchSimulator):
"""Everest simulator: BatchSimulator"""

def __init__(self, ever_config: EverestConfig, callback=None) -> None:
ert_config = everest_to_ert_config(ever_config)
def __init__(
self,
ever_config: EverestConfig,
ert_config: ErtConfig,
storage: Storage,
callback=None,
) -> None:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=ert_config.ensemble_config.parameter_configuration,
responses=ert_config.ensemble_config.response_configuration,
)

super(Simulator, self).__init__(
ert_config,
experiment,
self._get_controls(ever_config),
self._get_results(ever_config),
callback=callback,
Expand All @@ -36,6 +46,8 @@ def __init__(self, ever_config: EverestConfig, callback=None) -> None:
if ever_config.simulator is not None and ever_config.simulator.enable_cache:
self._cache = _SimulatorCache()

self.storage = storage

def _get_controls(self, ever_config: EverestConfig) -> List[str]:
controls = ever_config.controls or []
return [control.name for control in controls]
Expand Down Expand Up @@ -103,33 +115,19 @@ def __call__(
self._add_control(controls, control_name, control_value)
case_data.append((real_id, controls))

with open_storage(self.ert_config.ens_path, "w") as storage:
if self._experiment_id is None:
experiment = storage.create_experiment(
name=f"EnOpt@{datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self.ert_config.ensemble_config.parameter_configuration,
responses=self.ert_config.ensemble_config.response_configuration,
)
sim_context = self.start(f"batch_{self._batch}", case_data)

self._experiment_id = experiment.id
else:
experiment = storage.get_experiment(self._experiment_id)

sim_context = self.start(f"batch_{self._batch}", case_data, experiment)

while sim_context.running():
time.sleep(0.2)
results = sim_context.results()

# Pre-simulation workflows are run by sim_context, but
# post-stimulation workflows are not, do it here:
ensemble = sim_context.get_ensemble()
for workflow in self.ert_config.hooked_workflows[
HookRuntime.POST_SIMULATION
]:
WorkflowRunner(
workflow, storage, ensemble, ert_config=self.ert_config
).run_blocking()
while sim_context.running():
time.sleep(0.2)
results = sim_context.results()

# Pre-simulation workflows are run by sim_context, but
# post-stimulation workflows are not, do it here:
ensemble = sim_context.get_ensemble()
for workflow in self.ert_config.hooked_workflows[HookRuntime.POST_SIMULATION]:
WorkflowRunner(
workflow, self.storage, ensemble, ert_config=self.ert_config
).run_blocking()

for fnc_name, alias in self._function_aliases.items():
for result in results:
Expand Down
74 changes: 41 additions & 33 deletions src/everest/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from seba_sqlite import SqliteStorage

from ert.resources import all_shell_script_fm_steps
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.optimizer.everest2ropt import everest2ropt
from everest.plugins.site_config_env import PluginSiteConfigEnv
from everest.simulator import Simulator
from everest.simulator.everest_to_ert import everest_to_ert_config
from everest.strings import EVEREST, SIMULATOR_END, SIMULATOR_START, SIMULATOR_UPDATE
from everest.util import makedirs_if_needed

Expand Down Expand Up @@ -388,39 +390,45 @@ def start_optimization(self):
of this method will probably lead to a crash
"""
assert self._monitor_thread is None

# Initialize the Everest simulator:
simulator = Simulator(self.config, callback=self._simulation_callback)

# Initialize the ropt optimizer:
optimizer = self._configure_optimizer(simulator)

# Before each batch evaluation we check if we should abort:
optimizer.add_observer(
EventType.START_EVALUATION,
partial(self._ropt_callback, optimizer=optimizer, simulator=simulator),
)

# The SqliteStorage object is used to store optimization results from
# Seba in an sqlite database. It reacts directly to events emitted by
# Seba and is not called by Everest directly. The stored results are
# accessed by Everest via separate SebaSnapshot objects.
# This mechanism is outdated and not supported by the ropt package. It
# is retained for now via the seba_sqlite package.
seba_storage = SqliteStorage(optimizer, self.config.optimization_output_dir)

# Run the optimization:
exit_code = optimizer.run().exit_code

# Extract the best result from the storage.
self._result = seba_storage.get_optimal_result()

if self._monitor_thread is not None:
self._monitor_thread.stop()
self._monitor_thread.join()
self._monitor_thread = None

return "max_batch_num_reached" if self._max_batch_num_reached else exit_code
ert_config = everest_to_ert_config(self._config)

with open_storage(ert_config.ens_path, mode="w") as storage:
simulator = Simulator(
self.config,
ert_config,
storage,
callback=self._simulation_callback,
)

# Initialize the ropt optimizer:
optimizer = self._configure_optimizer(simulator)

# Before each batch evaluation we check if we should abort:
optimizer.add_observer(
EventType.START_EVALUATION,
partial(self._ropt_callback, optimizer=optimizer, simulator=simulator),
)

# The SqliteStorage object is used to store optimization results from
# Seba in an sqlite database. It reacts directly to events emitted by
# Seba and is not called by Everest directly. The stored results are
# accessed by Everest via separate SebaSnapshot objects.
# This mechanism is outdated and not supported by the ropt package. It
# is retained for now via the seba_sqlite package.
seba_storage = SqliteStorage(optimizer, self.config.optimization_output_dir)

# Run the optimization:
exit_code = optimizer.run().exit_code

# Extract the best result from the storage.
self._result = seba_storage.get_optimal_result()

if self._monitor_thread is not None:
self._monitor_thread.stop()
self._monitor_thread.join()
self._monitor_thread = None

return "max_batch_num_reached" if self._max_batch_num_reached else exit_code

@property
def result(self):
Expand Down
Loading

0 comments on commit 2d4421b

Please sign in to comment.