Skip to content

Commit

Permalink
Fix _storage_main terminate_on_parent_death not working on mac
Browse files Browse the repository at this point in the history
Prior to this commit, the `terminate_on_parent_death` function was only usable on linux, due to it using the prctl command.
This commit creates a new thread which polls the parent process, and signals terminate when it can no longer find the parent.
  • Loading branch information
jonathan-eq committed Nov 4, 2024
1 parent 90d11ec commit bb80e7a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 23 deletions.
48 changes: 25 additions & 23 deletions src/ert/services/_storage_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import signal
import socket
import string
import sys
import threading
import time
import warnings
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -123,30 +124,23 @@ def run_server(args: Optional[argparse.Namespace] = None, debug: bool = False) -
server.run(sockets=[sock])


def terminate_on_parent_death() -> None:
"""Quit the server when the parent does a SIGABRT or is otherwise destroyed.
This functionality has existed on Linux for a good while, but it isn't
exposed in the Python standard library. Use ctypes to hook into the
functionality.
def terminate_on_parent_death(
stopped: threading.Event, poll_interval: float = 1.0
) -> None:
"""
Quit the server when the parent process is no longer running.
"""
if sys.platform != "linux" or "ERT_COMM_FD" not in os.environ:
return

from ctypes import CDLL, c_int, c_ulong # noqa: PLC0415

lib = CDLL(None)

# from <sys/prctl.h>
# int prctl(int option, ...)
prctl = lib.prctl
prctl.restype = c_int
prctl.argtypes = (c_int, c_ulong)
def check_parent_alive() -> bool:
return os.getppid() != 1

# from <linux/prctl.h>
PR_SET_PDEATHSIG = 1
while check_parent_alive():
if stopped.is_set():
return
time.sleep(poll_interval)

# connect parent death signal to our SIGTERM
prctl(PR_SET_PDEATHSIG, signal.SIGTERM)
# Parent is no longer alive, terminate this process.
os.kill(os.getpid(), signal.SIGTERM)


if __name__ == "__main__":
Expand All @@ -156,6 +150,14 @@ def terminate_on_parent_death() -> None:
warnings.filterwarnings("ignore", category=DeprecationWarning)
uvicorn.config.LOGGING_CONFIG.clear()
uvicorn.config.LOGGING_CONFIG.update(logging_conf)
terminate_on_parent_death()
_stopped = threading.Event()
terminate_on_parent_death_thread = threading.Thread(
target=terminate_on_parent_death, args=[_stopped, 1.0]
)
with ErtPluginContext(logger=logging.getLogger()) as context:
run_server(debug=False)
terminate_on_parent_death_thread.start()
try:
run_server(debug=False)
finally:
_stopped.set()
terminate_on_parent_death_thread.join()
40 changes: 40 additions & 0 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import fileinput
import json
import logging
Expand All @@ -13,6 +14,7 @@
import pytest
import websockets.exceptions
import xtgeo
from psutil import NoSuchProcess, Process, ZombieProcess
from resdata.summary import Summary

import _ert.threading
Expand Down Expand Up @@ -955,3 +957,41 @@ def raise_connection_error(*args, **kwargs):
ENSEMBLE_EXPERIMENT_MODE,
"poly.ert",
)


@pytest.mark.usefixtures("copy_poly_case")
async def test_that_killed_ert_does_not_leave_storage_server_process():
ert_subprocess = await asyncio.create_subprocess_exec("ert", "gui", "poly.ert")
assert ert_subprocess.returncode is None
ert_process = Process(ert_subprocess.pid)

async def _find_storage_process_pid() -> int:
while True:
for ert_child_process in ert_process.children():
try:
if "_storage_main" in ert_child_process.cmdline()[1]:
return ert_child_process.pid
except (ZombieProcess, NoSuchProcess, IndexError):
pass
await asyncio.sleep(0.10)

storage_process_pid = await asyncio.wait_for(
_find_storage_process_pid(), timeout=30
)

assert ert_subprocess.returncode is None
ert_subprocess.kill()
await ert_subprocess.wait()
assert ert_subprocess.returncode is not None

storage_process = Process(storage_process_pid)

async def _wait_for_storage_process_to_shut_down():
storage_server_has_shutdown = asyncio.Event()
while not storage_server_has_shutdown.is_set():
if not storage_process.is_running():
storage_server_has_shutdown.set()
await asyncio.sleep(0.1)

await asyncio.wait_for(_wait_for_storage_process_to_shut_down(), timeout=15)
assert not storage_process.is_running()

0 comments on commit bb80e7a

Please sign in to comment.