Skip to content

Commit

Permalink
Move everserver config to ServerConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Nov 13, 2024
1 parent 22b7f71 commit 93081a1
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 126 deletions.
4 changes: 2 additions & 2 deletions src/everest/bin/everest_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ert.config import ErtConfig
from ert.storage import open_storage
from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import (
ServerStatus,
everserver_status,
Expand Down Expand Up @@ -84,7 +84,7 @@ def run_everest(options):
logger = logging.getLogger("everest_main")
server_state = everserver_status(options.config)

if server_is_running(*options.config.server_context):
if server_is_running(*ServerConfig.get_server_context(options.config.output_dir)):
config_file = options.config.config_file
print(
"An optimization is currently running.\n"
Expand Down
6 changes: 4 additions & 2 deletions src/everest/bin/kill_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import traceback
from functools import partial

from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import server_is_running, stop_server, wait_for_server_to_stop
from everest.util import version_info

Expand Down Expand Up @@ -70,7 +70,9 @@ def _handle_keyboard_interrupt(signal, frame, after=False):


def kill_everest(options):
if not server_is_running(*options.config.server_context):
if not server_is_running(
*ServerConfig.get_server_context(options.config.output_dir)
):
print("Server is not running.")
return

Expand Down
4 changes: 2 additions & 2 deletions src/everest/bin/monitor_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
from functools import partial

from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status, server_is_running

from .utils import (
Expand Down Expand Up @@ -63,7 +63,7 @@ def monitor_everest(options):
config: EverestConfig = options.config
server_state = everserver_status(options.config)

if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
run_detached_monitor(config, show_all_jobs=options.show_all_jobs)
server_state = everserver_status(config)
if server_state["status"] == ServerStatus.failed:
Expand Down
75 changes: 1 addition & 74 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import os
import shutil
Expand All @@ -11,7 +10,6 @@
Literal,
Optional,
Protocol,
Tuple,
no_type_check,
)

Expand Down Expand Up @@ -45,14 +43,9 @@

from ..config_file_loader import yaml_file_to_substituted_config_dict
from ..strings import (
CERTIFICATE_DIR,
DEFAULT_OUTPUT_DIR,
DETACHED_NODE_DIR,
HOSTFILE_NAME,
OPTIMIZATION_LOG_DIR,
OPTIMIZATION_OUTPUT_DIR,
SERVER_STATUS,
SESSION_DIR,
STORAGE_DIR,
)
from .control_config import ControlConfig
Expand Down Expand Up @@ -605,7 +598,7 @@ def config_file(self) -> Optional[str]:
return None

@property
def output_dir(self) -> Optional[str]:
def output_dir(self) -> str:
assert self.environment is not None
path = self.environment.output_folder

Expand Down Expand Up @@ -655,67 +648,6 @@ def storage_dir(self):
def log_dir(self):
return self._get_output_subdirectory(OPTIMIZATION_LOG_DIR)

@property
def detached_node_dir(self):
return self._get_output_subdirectory(DETACHED_NODE_DIR)

@property
def session_dir(self):
"""Return path to the session directory containing information about the
certificates and host information"""
return os.path.join(self.detached_node_dir, SESSION_DIR)

@property
def certificate_dir(self):
"""Return the path to certificate folder"""
return os.path.join(self.session_dir, CERTIFICATE_DIR)

def get_server_url(self, server_info=None):
"""Return the url of the server.
If server_info are given, the url is generated using that info. Otherwise
server information are retrieved from the hostfile
"""
if server_info is None:
server_info = self.server_info

url = f"https://{server_info['host']}:{server_info['port']}"
return url

@property
def hostfile_path(self):
return os.path.join(self.session_dir, HOSTFILE_NAME)

@property
def server_info(self):
"""Load server information from the hostfile"""
host_file_path = self.hostfile_path
try:
with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")
return data
except FileNotFoundError:
# No host file
return {"host": None, "port": None, "cert": None, "auth": None}

@property
def server_context(self) -> Tuple[str, str, Tuple[str, str]]:
"""Returns a tuple with
- url of the server
- path to the .cert file
- password for the certificate file
"""

return (
self.get_server_url(self.server_info),
self.server_info[CERTIFICATE_DIR],
("username", self.server_info["auth"]),
)

@property
def export_path(self):
"""Returns the export file path. If not file name is provide the default
Expand All @@ -738,11 +670,6 @@ def export_path(self):
default_export_file = f"{os.path.splitext(self.config_file)[0]}.csv"
return os.path.join(full_file_path, default_export_file)

@property
def everserver_status_path(self):
"""Returns path to the everest server status file"""
return os.path.join(self.session_dir, SERVER_STATUS)

def to_dict(self) -> dict:
the_dict = self.model_dump(exclude_none=True)

Expand Down
75 changes: 74 additions & 1 deletion src/everest/config/server_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from typing import Literal, Optional
import json
import os
from typing import Literal, Optional, Tuple

from pydantic import BaseModel, ConfigDict, Field

from ..strings import (
CERTIFICATE_DIR,
DETACHED_NODE_DIR,
HOSTFILE_NAME,
SERVER_STATUS,
SESSION_DIR,
)
from .has_ert_queue_options import HasErtQueueOptions


Expand Down Expand Up @@ -41,3 +50,67 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore
model_config = ConfigDict(
extra="forbid",
)

@staticmethod
def get_server_url(output_dir: str) -> str:
"""Return the url of the server.
If server_info are given, the url is generated using that info. Otherwise
server information are retrieved from the hostfile
"""
server_info = ServerConfig.get_server_info(output_dir)
return f"https://{server_info['host']}:{server_info['port']}"

@staticmethod
def get_server_context(output_dir: str) -> Tuple[str, bool, Tuple[str, str]]:
"""Returns a tuple with
- url of the server
- path to the .cert file
- password for the certificate file
"""
server_info = ServerConfig.get_server_info(output_dir)
return (
ServerConfig.get_server_url(output_dir),
server_info[CERTIFICATE_DIR],
("username", server_info["auth"]),
)

@staticmethod
def get_server_info(output_dir: str) -> dict:
"""Load server information from the hostfile"""
host_file_path = ServerConfig.get_hostfile_path(output_dir)
try:
with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")
return data
except FileNotFoundError:
# No host file
return {"host": None, "port": None, "cert": None, "auth": None}

@staticmethod
def get_detached_node_dir(output_dir: str) -> str:
return os.path.join(os.path.abspath(output_dir), DETACHED_NODE_DIR)

@staticmethod
def get_hostfile_path(output_dir: str) -> str:
return os.path.join(ServerConfig.get_session_dir(output_dir), HOSTFILE_NAME)

@staticmethod
def get_session_dir(output_dir: str) -> str:
"""Return path to the session directory containing information about the
certificates and host information"""
return os.path.join(ServerConfig.get_detached_node_dir(output_dir), SESSION_DIR)

@staticmethod
def get_everserver_status_path(output_dir: str) -> str:
"""Returns path to the everest server status file"""
return os.path.join(ServerConfig.get_session_dir(output_dir), SERVER_STATUS)

@staticmethod
def get_certificate_dir(output_dir: str) -> str:
"""Return the path to certificate folder"""
return os.path.join(ServerConfig.get_session_dir(output_dir), CERTIFICATE_DIR)
30 changes: 17 additions & 13 deletions src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ert import BatchContext, BatchSimulator, JobState
from ert.config import ErtConfig, QueueSystem
from everest.config import EverestConfig
from everest.config import EverestConfig, ServerConfig
from everest.config_keys import ConfigKeys as CK
from everest.strings import (
EVEREST,
Expand Down Expand Up @@ -59,7 +59,9 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage):
"""
Start an Everest server running the optimization defined in the config
"""
if server_is_running(*config.server_context): # better safe than sorry
if server_is_running(
*ServerConfig.get_server_context(config.output_dir)
): # better safe than sorry
return

log_dir = config.log_dir
Expand Down Expand Up @@ -143,7 +145,7 @@ def stop_server(config: EverestConfig, retries: int = 5):
"""
for retry in range(retries):
try:
url, cert, auth = config.server_context
url, cert, auth = ServerConfig.get_server_context(config.output_dir)
stop_endpoint = "/".join([url, STOP_ENDPOINT])
response = requests.post(
stop_endpoint,
Expand Down Expand Up @@ -174,7 +176,7 @@ def wait_for_server(
Raise an exception when the timeout is reached.
"""
if not server_is_running(*config.server_context):
if not server_is_running(*ServerConfig.get_server_context(config.output_dir)):
sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1)
for retry_count in range(_HTTP_REQUEST_RETRY):
# Failure may occur before contact with the server is established:
Expand Down Expand Up @@ -218,11 +220,11 @@ def wait_for_server(

sleep_time = sleep_time_increment * (2**retry_count)
time.sleep(sleep_time)
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
return

# If number of retries reached and server is not running - throw exception
if not server_is_running(*config.server_context):
if not server_is_running(*ServerConfig.get_server_context(config.output_dir)):
raise RuntimeError("Failed to start server within configured timeout.")


Expand Down Expand Up @@ -264,16 +266,18 @@ def wait_for_server_to_stop(config: EverestConfig, timeout):
Raise an exception when the timeout is reached.
"""
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
sleep_time_increment = float(timeout) / (2**_HTTP_REQUEST_RETRY - 1)
for retry_count in range(_HTTP_REQUEST_RETRY):
sleep_time = sleep_time_increment * (2**retry_count)
time.sleep(sleep_time)
if not server_is_running(*config.server_context):
if not server_is_running(
*ServerConfig.get_server_context(config.output_dir)
):
return

# If number of retries reached and server still running - throw exception
if server_is_running(*config.server_context):
if server_is_running(*ServerConfig.get_server_context(config.output_dir)):
raise Exception("Failed to stop server within configured timeout.")


Expand Down Expand Up @@ -310,7 +314,7 @@ def start_monitor(config: EverestConfig, callback, polling_interval=5):
Monitoring stops when the server stops answering. It can also be
interrupted by returning True from the callback
"""
url, cert, auth = config.server_context
url, cert, auth = ServerConfig.get_server_context(config.output_dir)
sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT])
opt_endpoint = "/".join([url, OPT_PROGRESS_ENDPOINT])

Expand Down Expand Up @@ -448,7 +452,7 @@ def generate_everserver_ert_config(config: EverestConfig, debug_mode: bool = Fal

site_config = ErtConfig.read_site_config()
abs_everest_config = os.path.join(config.config_directory, config.config_file)
detached_node_dir = config.detached_node_dir
detached_node_dir = ServerConfig.get_detached_node_dir(config.output_dir)
simulation_path = os.path.join(detached_node_dir, SIMULATION_DIR)
queue_system = _find_res_queue_system(config)
arg_list = ["--config-file", abs_everest_config]
Expand Down Expand Up @@ -532,7 +536,7 @@ def update_everserver_status(
):
"""Update the everest server status with new status information"""
new_status = {"status": status, "message": message}
path = config.everserver_status_path
path = ServerConfig.get_everserver_status_path(config.output_dir)
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
with open(path, "w", encoding="utf-8") as outfile:
Expand Down Expand Up @@ -560,7 +564,7 @@ def everserver_status(config: EverestConfig):
'message': None
}
"""
path = config.everserver_status_path
path = ServerConfig.get_everserver_status_path(config.output_dir)
if os.path.exists(path):
with open(path, "r", encoding="utf-8") as f:
return json.load(f, object_hook=ServerStatusEncoder.decode)
Expand Down
Loading

0 comments on commit 93081a1

Please sign in to comment.