Skip to content

Commit

Permalink
Use spawn context for local driver
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Nov 25, 2024
1 parent 968a97c commit 5f0f2d9
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 19 deletions.
5 changes: 5 additions & 0 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _read_jobs_file(retry=True):


def main(args):
print("Started job_runner")
parser = argparse.ArgumentParser(
description=(
"Run all the jobs specified in jobs.json, "
Expand Down Expand Up @@ -127,6 +128,7 @@ def main(args):
dispatch_url = jobs_data.get("dispatch_url")

is_interactive_run = len(parsed_args.job) > 0
print("Setting up reporter")
reporters = _setup_reporters(
is_interactive_run,
ens_id,
Expand All @@ -136,11 +138,13 @@ def main(args):
experiment_id,
)

print("creating runner")
job_runner = ForwardModelRunner(jobs_data)

for job_status in job_runner.run(parsed_args.job):
logger.info(f"Job status: {job_status}")
for reporter in reporters:
print(reporter)
try:
reporter.report(job_status)
except OSError as oserror:
Expand All @@ -153,3 +157,4 @@ def main(args):
if isinstance(job_status, Finish) and not job_status.success():
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGKILL)
print("Job runner finished")
17 changes: 17 additions & 0 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,21 @@ def __init__(self, evaluator_url, token=None, cert_path=None):

def _event_publisher(self):
logger.debug("Publishing event.")
print(
self._evaluator_url,
self._token,
self._cert,
)
with Client(
url=self._evaluator_url,
token=self._token,
cert=self._cert,
) as client:
event = None
while True:
print("In client loop")
with self._timestamp_lock:
print("Got timestamp lock")
if (
self._timeout_timestamp is not None
and datetime.now() > self._timeout_timestamp
Expand All @@ -102,11 +109,15 @@ def _event_publisher(self):
if event is None:
# if we successfully sent the event, we can proceed
# to the next one
print("Getting event_queue")
event = self._event_queue.get()
print(f"Got {event}")
if event is self._sentinel:
break
try:
print("Sending")
client.send(event_to_json(event))
print("Sent")
event = None
except ClientConnectionError as exception:
# Possible intermittent failure, we retry sending the event
Expand All @@ -118,7 +129,9 @@ def _event_publisher(self):
break

def report(self, msg):
print(f"Reporting {msg}")
self._statemachine.transition(msg)
print(f"Reported {msg}")

def _dump_event(self, event: events.Event):
logger.debug(f'Schedule "{type(event)}" for delivery')
Expand All @@ -127,7 +140,9 @@ def _dump_event(self, event: events.Event):
def _init_handler(self, msg: Init):
self._ens_id = str(msg.ens_id)
self._real_id = str(msg.real_id)
print("Starting event_publisher thread")
self._event_publisher_thread.start()
print(self._event_publisher_thread)

def _job_handler(self, msg: Union[Start, Running, Exited]):
assert msg.job
Expand Down Expand Up @@ -184,7 +199,9 @@ def _finished_handler(self, _):
seconds=self._reporter_timeout
)
if self._event_publisher_thread.is_alive():
print("Joining publisher thread")
self._event_publisher_thread.join()
print("publisher thread joined")

def _checksum_handler(self, msg: Checksum):
fm_checksum = ForwardModelStepChecksum(
Expand Down
1 change: 1 addition & 0 deletions src/_ert/forward_model_runner/reporting/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def transition(self, message: Message):
if self._state not in self._transitions or not isinstance(
message, self._transitions[self._state]
):
print("Illegal transition")
logger.error(
f"{message} illegal state transition: {self._state} -> {new_state}"
)
Expand Down
3 changes: 3 additions & 0 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ async def handle_client(self, websocket: ServerConnection) -> None:
await websocket.send(event_to_json(event))

async for raw_msg in websocket:
print(f"{raw_msg=}")
event = event_from_json(raw_msg)
logger.debug(f"got message from client: {event}")
if type(event) is EEUserCancel:
Expand All @@ -238,9 +239,11 @@ async def count_dispatcher(self) -> AsyncIterator[None]:
self._dispatchers_connected.task_done()

async def handle_dispatch(self, websocket: ServerConnection) -> None:
print("handle_dispatch")
async with self.count_dispatcher():
try:
async for raw_msg in websocket:
print(f"{raw_msg=}")
try:
event = dispatch_event_from_json(raw_msg)
if event.ensemble != self.ensemble.id_:
Expand Down
42 changes: 24 additions & 18 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import asyncio
import contextlib
import logging
import os
import multiprocessing
import signal
from asyncio.subprocess import Process
from contextlib import suppress
from pathlib import Path
from typing import MutableMapping, Optional, Set
Expand Down Expand Up @@ -36,6 +35,7 @@ async def submit(
realization_memory: Optional[int] = 0,
) -> None:
self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args))
self._spawn_context = multiprocessing.get_context("spawn")
with suppress(KeyError):
self._sent_finished_events.remove(iens)

Expand Down Expand Up @@ -73,6 +73,9 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:
executable,
*args,
)
print("start")
proc.start()
print("starte")
except FileNotFoundError as err:
# /bin/sh uses returncode 127 for FileNotFound, so copy that
# behaviour.
Expand All @@ -86,8 +89,11 @@ async def _run(self, iens: int, executable: str, /, *args: str) -> None:

returncode = 1
try:
returncode = await self._wait(proc)
logger.info(f"Realization {iens} finished with {returncode=}")
while proc.is_alive():
proc.join(timeout=1)
await asyncio.sleep(1)
logger.info(f"Realization {iens} finished with exitcode={proc.exitcode}")
returncode = proc.exitcode if proc.exitcode is not None else 0
except asyncio.CancelledError:
returncode = await self._kill(proc)
finally:
Expand All @@ -99,35 +105,35 @@ async def _dispatch_finished_event(self, iens: int, returncode: int) -> None:
await self.event_queue.put(FinishedEvent(iens=iens, returncode=returncode))
self._sent_finished_events.add(iens)

@staticmethod
async def _init(iens: int, executable: str, /, *args: str) -> Process:
async def _init(
self, iens: int, executable: str, /, *args: str
) -> multiprocessing.Process:
"""This method exists to allow for mocking it in tests"""
return await asyncio.create_subprocess_exec(
executable,
*args,
preexec_fn=os.setpgrp,
)
from _ert.forward_model_runner.cli import main # noqa

@staticmethod
async def _wait(proc: Process) -> int:
return await proc.wait()
return self._spawn_context.Process(
target=main, args=[["job_dispatch.py", *args]]
) # type: ignore

@staticmethod
async def _kill(proc: Process) -> int:
async def _kill(proc: multiprocessing.Process) -> int:
try:
proc.terminate()
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)
proc.join()
except asyncio.TimeoutError:
proc.kill()
except ProcessLookupError:
# This will happen if the subprocess has not yet started
return signal.SIGTERM + SIGNAL_OFFSET
ret_val = await proc.wait()
proc.join()
# the returncode of a subprocess will be the negative signal value
# if it terminated due to a signal.
# https://docs.python.org/3/library/subprocess.html#subprocess.CompletedProcess.returncode
# we return SIGNAL_OFFSET + signal value to be in line with lsf/pbs drivers.
return -ret_val + SIGNAL_OFFSET
if proc.exitcode is not None:
return proc.exitcode + SIGNAL_OFFSET
else:
return -1 + SIGNAL_OFFSET

async def poll(self) -> None:
"""LocalDriver does not poll"""
1 change: 0 additions & 1 deletion tests/ert/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ def _run_heat_equation(source_root):
parser,
[
ES_MDA_MODE,
"--disable-monitor",
"config.ert",
],
)
Expand Down

0 comments on commit 5f0f2d9

Please sign in to comment.