diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py
index 5551da9b..b9775ea0 100644
--- a/src/aiidalab_qe/app/configuration/__init__.py
+++ b/src/aiidalab_qe/app/configuration/__init__.py
@@ -56,7 +56,7 @@ def __init__(self, model: ConfigurationStepModel, **kwargs):
lambda structure: ""
if structure
else """
-
+
Please set the input structure first.
""",
diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py
index d9176b9a..ba457fe7 100644
--- a/src/aiidalab_qe/app/submission/__init__.py
+++ b/src/aiidalab_qe/app/submission/__init__.py
@@ -39,10 +39,6 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs):
self._on_submission,
"confirmed",
)
- self._model.observe(
- self._on_input_structure_change,
- "input_structure",
- )
self._model.observe(
self._on_input_parameters_change,
"input_parameters",
@@ -77,22 +73,28 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs):
self.rendered = False
- global_code_model = GlobalResourceSettingsModel()
- self.global_code_settings = GlobalResourceSettingsPanel(model=global_code_model)
- self._model.add_model("global", global_code_model)
- global_code_model.observe(
+ global_resources_model = GlobalResourceSettingsModel()
+ self.global_resources = GlobalResourceSettingsPanel(
+ model=global_resources_model
+ )
+ self._model.add_model("global", global_resources_model)
+ ipw.dlink(
+ (self._model, "plugin_overrides"),
+ (global_resources_model, "plugin_overrides"),
+ )
+ global_resources_model.observe(
self._on_plugin_submission_blockers_change,
["submission_blockers"],
)
- global_code_model.observe(
+ global_resources_model.observe(
self._on_plugin_submission_warning_messages_change,
["submission_warning_messages"],
)
self.settings = {
- "global": self.global_code_settings,
+ "global": self.global_resources,
}
- self._fetch_plugin_settings()
+ self._fetch_plugin_resource_settings()
self._install_sssp(qe_auto_setup)
self._set_up_qe(qe_auto_setup)
@@ -211,14 +213,15 @@ def _on_tab_change(self, change):
tab: ResourceSettingsPanel = self.tabs.children[tab_index] # type: ignore
tab.render()
- def _on_input_structure_change(self, _):
- """"""
-
def _on_input_parameters_change(self, _):
- self._model.update_active_models()
- self._update_tabs()
self._model.update_process_label()
+ self._model.update_plugin_inclusion()
+ self._model.update_plugin_overrides()
self._model.update_submission_blockers()
+ self._update_tabs()
+
+ def _on_plugin_overrides_change(self, _):
+ self._model.update_plugin_overrides()
def _on_plugin_submission_blockers_change(self, _):
self._model.update_submission_blockers()
@@ -237,16 +240,13 @@ def _on_submission_blockers_change(self, _):
self._model.update_submission_blocker_message()
self._update_state()
- def _on_submission_warning_change(self, _):
- self._model.update_submission_warning_message()
-
def _on_installation_change(self, _):
self._model.update_submission_blockers()
def _on_qe_installed(self, _):
self._toggle_qe_installation_widget()
if self._model.qe_installed:
- self._model.refresh_codes()
+ self._model.update()
def _on_sssp_installed(self, _):
self._toggle_sssp_installation_widget()
@@ -325,14 +325,19 @@ def _update_state(self, _=None):
else:
self.state = self.state.CONFIGURED
- def _fetch_plugin_settings(self):
- eps = get_entry_items("aiidalab_qe.properties", "code")
- for identifier, data in eps.items():
+ def _fetch_plugin_resource_settings(self):
+ entries = get_entry_items("aiidalab_qe.properties", "resources")
+ for identifier, resources in entries.items():
for key in ("panel", "model"):
- if key not in data:
+ if key not in resources:
raise ValueError(f"Entry {identifier} is missing the '{key}' key")
- panel = data["panel"]
- model: ResourceSettingsModel = data["model"]()
+
+ panel = resources["panel"]
+ model: ResourceSettingsModel = resources["model"]()
+ model.observe(
+ self._on_plugin_overrides_change,
+ "override",
+ )
model.observe(
self._on_plugin_submission_blockers_change,
["submission_blockers"],
@@ -343,15 +348,6 @@ def _fetch_plugin_settings(self):
)
self._model.add_model(identifier, model)
- def toggle_plugin(_, model=model):
- model.update()
- self._update_tabs()
-
- model.observe(
- toggle_plugin,
- "include",
- )
-
self.settings[identifier] = panel(
identifier=identifier,
model=model,
diff --git a/src/aiidalab_qe/app/submission/global_settings/model.py b/src/aiidalab_qe/app/submission/global_settings/model.py
index 0345cc6a..88c07218 100644
--- a/src/aiidalab_qe/app/submission/global_settings/model.py
+++ b/src/aiidalab_qe/app/submission/global_settings/model.py
@@ -5,14 +5,11 @@
import traitlets as tl
from aiida import orm
-from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.code import CodeModel, PwCodeModel
from aiidalab_qe.common.mixins import HasInputStructure
from aiidalab_qe.common.panel import ResourceSettingsModel
from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget
-DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
-
class GlobalResourceSettingsModel(
ResourceSettingsModel,
@@ -20,6 +17,8 @@ class GlobalResourceSettingsModel(
):
"""Model for the global code setting."""
+ identifier = "global"
+
dependencies = [
"input_parameters",
"input_structure",
@@ -27,33 +26,14 @@ class GlobalResourceSettingsModel(
input_parameters = tl.Dict()
- codes = tl.Dict(
- key_trait=tl.Unicode(), # code name
- value_trait=tl.Instance(CodeModel), # code metadata
- )
- # this is a copy of the codes trait, which is used to trigger the update of the plugin
- global_codes = tl.Dict(
- key_trait=tl.Unicode(), # code name
- value_trait=tl.Dict(), # code metadata
- )
-
- plugin_mapping = tl.Dict(
- key_trait=tl.Unicode(), # plugin identifier
- value_trait=tl.List(tl.Unicode()), # list of code names
- )
-
- submission_blockers = tl.List(tl.Unicode())
- submission_warning_messages = tl.Unicode("")
+ plugin_overrides = tl.List(tl.Unicode())
+ plugin_overrides_notification = tl.Unicode("")
include = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- # Used by the code-setup thread to fetch code options
- # This is necessary to avoid passing the User object
- # between session in separate threads.
- self._default_user_email = orm.User.collection.get_default().email
self._RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10
self._RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD = 1000 # \AA^3
@@ -66,120 +46,85 @@ def __init__(self, *args, **kwargs):
"""
- def refresh_codes(self):
- for _, code_model in self.codes.items():
- code_model.update(self._default_user_email) # type: ignore
+ self.plugin_mapping: dict[str, list[str]] = {}
+
+ self.override = True
+
+ def update(self):
+ for _, code_model in self.get_models():
+ code_model.update(self.DEFAULT_USER_EMAIL)
+
+ def update_global_codes(self):
+ self.global_codes = self.get_model_state()["codes"]
def update_active_codes(self):
- for name, code_model in self.codes.items():
- if name != "quantumespresso.pw":
+ for identifier, code_model in self.get_models():
+ if identifier != "quantumespresso.pw":
code_model.deactivate()
properties = self._get_properties()
for identifier, code_names in self.plugin_mapping.items():
if identifier in properties:
for code_name in code_names:
- self.codes[code_name].activate()
+ self.get_model(code_name).activate()
- def get_model_state(self):
- codes = {name: model.get_model_state() for name, model in self.codes.items()}
-
- return {"codes": codes}
-
- def set_model_state(self, code_data: dict):
- for name, code_model in self.codes.items():
- if name in code_data and code_model.is_active:
- code_model.set_model_state(code_data[name])
+ def update_plugin_overrides_notification(self):
+ if self.plugin_overrides:
+ formatted = "\n".join(
+ f"
+
The submission is blocked due to the following reason(s):
- {fmt_list}
+ {formatted}
"""
@@ -178,8 +183,8 @@ def set_model_state(self, parameters):
def get_selected_codes(self) -> dict[str, dict]:
return {
- name: code_model.get_model_state()
- for name, code_model in self.get_model("global").codes.items()
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_model("global").get_models()
if code_model.is_ready
}
@@ -252,11 +257,9 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace:
return builder
def _check_submission_blockers(self):
- # Do not submit while any of the background setup processes are running.
if self.installing_qe or self.installing_sssp:
yield "Background setup processes must finish."
- # SSSP library not installed
if not self.sssp_installed:
yield "The SSSP library is not installed."
diff --git a/src/aiidalab_qe/common/code/model.py b/src/aiidalab_qe/common/code/model.py
index 4e40cc87..90b76157 100644
--- a/src/aiidalab_qe/common/code/model.py
+++ b/src/aiidalab_qe/common/code/model.py
@@ -24,6 +24,7 @@ class CodeModel(Model):
max_wallclock_seconds = tl.Int(3600 * 12)
allow_hidden_codes = tl.Bool(False)
allow_disabled_computers = tl.Bool(False)
+ override = tl.Bool(False)
def __init__(
self,
@@ -48,19 +49,24 @@ def __init__(
def is_ready(self):
return self.is_active and bool(self.selected)
+ @property
+ def first_option(self):
+ return self.options[0][1] if self.options else None # type: ignore
+
def activate(self):
self.is_active = True
def deactivate(self):
self.is_active = False
- def update(self, user_email: str):
- if not self.options:
+ def update(self, user_email="", refresh=False):
+ if not self.options or refresh:
self.options = self._get_codes(user_email)
- self.selected = self.options[0][1] if self.options else None
+ self.selected = self.first_option
def get_model_state(self) -> dict:
return {
+ "options": self.options,
"code": self.selected,
"nodes": self.num_nodes,
"cpus": self.num_cpus,
@@ -69,8 +75,12 @@ def get_model_state(self) -> dict:
"max_wallclock_seconds": self.max_wallclock_seconds,
}
- def set_model_state(self, parameters):
- self.selected = self._get_uuid(parameters["code"])
+ def set_model_state(self, parameters: dict):
+ self.selected = (
+ self._get_uuid(identifier)
+ if (identifier := parameters.get("code"))
+ else self.first_option
+ )
self.num_nodes = parameters.get("nodes", 1)
self.num_cpus = parameters.get("cpus", 1)
self.ntasks_per_node = parameters.get("ntasks_per_node", 1)
@@ -78,19 +88,15 @@ def set_model_state(self, parameters):
self.max_wallclock_seconds = parameters.get("max_wallclock_seconds", 3600 * 12)
def _get_uuid(self, identifier):
- if not self.selected:
- try:
- uuid = orm.load_code(identifier).uuid
- except NotExistent:
- uuid = None
- # If the code was imported from another user, it is not usable
- # in the app and thus will not be considered as an option!
- self.selected = uuid if uuid in [opt[1] for opt in self.options] else None
- return self.selected
-
- def _get_codes(self, user_email: str):
- # set default user_email if not provided
- user_email = user_email or orm.User.collection.get_default().email
+ try:
+ uuid = orm.load_code(identifier).uuid
+ except NotExistent:
+ uuid = None
+ # If the code was imported from another user, it is not usable
+ # in the app and thus will not be considered as an option!
+ return uuid if uuid in [opt[1] for opt in self.options] else None
+
+ def _get_codes(self, user_email: str = ""):
user = orm.User.collection.get(email=user_email)
filters = (
@@ -122,7 +128,7 @@ def _full_code_label(code):
class PwCodeModel(CodeModel):
- override = tl.Bool(False)
+ parallelization_override = tl.Bool(False)
npool = tl.Int(1)
def __init__(
@@ -142,14 +148,22 @@ def __init__(
def get_model_state(self) -> dict:
parameters = super().get_model_state()
- parameters["parallelization"] = {"npool": self.npool} if self.override else {}
+ parameters["parallelization"] = (
+ {
+ "npool": self.npool,
+ }
+ if self.parallelization_override
+ else {}
+ )
return parameters
def set_model_state(self, parameters):
super().set_model_state(parameters)
if "parallelization" in parameters and "npool" in parameters["parallelization"]:
- self.override = True
+ self.parallelization_override = True
self.npool = parameters["parallelization"].get("npool", 1)
+ else:
+ self.parallelization_override = False
CodesDict = dict[str, CodeModel]
diff --git a/src/aiidalab_qe/common/mixins.py b/src/aiidalab_qe/common/mixins.py
index 21421c23..bd710a8d 100644
--- a/src/aiidalab_qe/common/mixins.py
+++ b/src/aiidalab_qe/common/mixins.py
@@ -31,12 +31,19 @@ class HasModels(t.Generic[T]):
def __init__(self):
self._models: dict[str, T] = {}
+ def has_model(self, identifier):
+ return identifier in self._models
+
def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)
+ def add_models(self, models: dict[str, T]):
+ for identifier, model in models.items():
+ self.add_model(identifier, model)
+
def get_model(self, identifier) -> T:
- if identifier in self._models:
+ if self.has_model(identifier):
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")
diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py
index 3cca22ab..0a5a913a 100644
--- a/src/aiidalab_qe/common/panel.py
+++ b/src/aiidalab_qe/common/panel.py
@@ -15,8 +15,9 @@
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
+from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.code.model import CodeModel
-from aiidalab_qe.common.mixins import Confirmable, HasProcess
+from aiidalab_qe.common.mixins import Confirmable, HasModels, HasProcess
from aiidalab_qe.common.mvc import Model
from aiidalab_qe.common.widgets import (
LoadingWidget,
@@ -24,7 +25,7 @@
QEAppComputationalResourcesWidget,
)
-DEFAULT_PARAMETERS = {}
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
class Panel(ipw.VBox):
@@ -90,18 +91,12 @@ class SettingsModel(Model):
_defaults = {}
- def update(self, specific=""):
- """Updates the model.
-
- Parameters
- ----------
- `specific` : `str`, optional
- If provided, specifies the level of update.
- """
+ def update(self):
+ """Updates the model."""
pass
def get_model_state(self) -> dict:
- """Retrieves the model current state as a dictionary."""
+ """Retrieves the current state of the model as a dictionary."""
raise NotImplementedError()
def set_model_state(self, parameters: dict):
@@ -118,7 +113,7 @@ def reset(self):
class SettingsPanel(Panel, t.Generic[SM]):
title = "Settings"
- description = ""
+ identifier = ""
def __init__(self, model: SM, **kwargs):
from aiidalab_qe.common.widgets import LoadingWidget
@@ -209,11 +204,12 @@ def _reset(self):
self._model.reset()
-class ResourceSettingsModel(SettingsModel):
+class ResourceSettingsModel(SettingsModel, HasModels[CodeModel]):
"""Base model for plugin code setting models."""
- dependencies = ["global.global_codes"]
- codes = {} # To be defined by subclasses
+ dependencies = [
+ "global.global_codes",
+ ]
global_codes = tl.Dict(
key_trait=tl.Unicode(),
@@ -228,38 +224,66 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used by the code-setup thread to fetch code options
- self._default_user_email = orm.User.collection.get_default().email
+ self.DEFAULT_USER_EMAIL = orm.User.collection.get_default().email
- def refresh_codes(self):
- for _, code_model in self.codes.items():
- code_model.update(self._default_user_email)
+ def update(self):
+ """Updates the code models from the global resources.
- def update_code_from_global(self):
- # Skip the sync if the user has overridden the settings
+ Skips synchronization with global resources if the user has chosen to override
+ the resources for the plugin codes.
+ """
if self.override:
return
- for _, code_model in self.codes.items():
+ for _, code_model in self.get_models():
+ code_model.update(self.DEFAULT_USER_EMAIL)
default_calc_job_plugin = code_model.default_calc_job_plugin
if default_calc_job_plugin in self.global_codes:
- code_data = self.global_codes[default_calc_job_plugin]
- code_model.set_model_state(code_data)
+ code_resources: dict = self.global_codes[default_calc_job_plugin] # type: ignore
+ options = code_resources.get("options", [])
+ if options != code_model.options:
+ code_model.update(self.DEFAULT_USER_EMAIL, refresh=True)
+ code_model.set_model_state(code_resources)
+
+ def update_submission_blockers(self):
+ self.submission_blockers = list(self._check_submission_blockers())
def get_model_state(self):
- codes = {name: model.get_model_state() for name, model in self.codes.items()}
return {
- "codes": codes,
- "override": self.override,
+ "codes": {
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_models()
+ },
}
- def set_model_state(self, code_data: dict):
- for name, code_model in self.codes.items():
- if name in code_data:
- code_model.set_model_state(code_data[name])
+ def set_model_state(self, parameters: dict):
+ for name, code_model in self.get_models():
+ if name in parameters and code_model.is_active:
+ code_model.set_model_state(parameters[name])
+
+ def get_selected_codes(self) -> dict[str, dict]:
+ return {
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_models()
+ if code_model.is_ready
+ }
+
+ def set_selected_codes(self, code_data=DEFAULT["codes"]):
+ for identifier, code_model in self.get_models():
+ if identifier in code_data and code_model.is_active:
+ code_model.set_model_state(code_data[identifier])
def reset(self):
- """Reset the model to its default state."""
- for code_model in self.codes.values():
- code_model.reset()
+ """If not overridden, updates the model w.r.t the global resources."""
+ self.update()
+
+ def _check_submission_blockers(self):
+ return []
+
+ def _link_model(self, model: CodeModel):
+ tl.link(
+ (self, "override"),
+ (model, "override"),
+ )
RSM = t.TypeVar("RSM", bound=ResourceSettingsModel)
@@ -270,8 +294,7 @@ class ResourceSettingsPanel(SettingsPanel[RSM], t.Generic[RSM]):
def __init__(self, model, **kwargs):
super().__init__(model, **kwargs)
- self.code_widgets = {}
- self.rendered = False
+
self._model.observe(
self._on_global_codes_change,
"global_codes",
@@ -281,9 +304,12 @@ def __init__(self, model, **kwargs):
"override",
)
+ self.code_widgets = {}
+
def render(self):
if self.rendered:
return
+
self.override_help = ipw.HTML(
"Click to override the resource settings for this plugin."
)
@@ -297,35 +323,33 @@ def render(self):
(self.override, "value"),
)
self.code_widgets_container = ipw.VBox()
- self.code_widgets = {}
+
self.children = [
- ipw.HBox([self.override, self.override_help]),
+ ipw.HBox(
+ children=[
+ self.override,
+ self.override_help,
+ ]
+ ),
self.code_widgets_container,
]
self.rendered = True
- for code_model in self._model.codes.values():
+ # Render any active codes
+ for _, code_model in self._model.get_models():
self._toggle_code(code_model)
+
return self.code_widgets_container
def _on_global_codes_change(self, _):
- self._model.update_code_from_global()
+ self._model.update()
def _on_code_resource_change(self, _):
- """Update the submission blockers and warning messages."""
-
- def _on_override_change(self, change):
- if change["new"]:
- for code_widget in self.code_widgets.values():
- code_widget.num_nodes.disabled = False
- code_widget.num_cpus.disabled = False
- code_widget.code_selection.code_select_dropdown.disabled = False
- else:
- for code_widget in self.code_widgets.values():
- code_widget.num_nodes.disabled = True
- code_widget.num_cpus.disabled = True
- code_widget.code_selection.code_select_dropdown.disabled = True
+ pass
+
+ def _on_override_change(self, _):
+ self._model.reset()
def _toggle_code(self, code_model: CodeModel):
if not self.rendered:
@@ -349,7 +373,6 @@ def _render_code_widget(
code_model: CodeModel,
code_widget: QEAppComputationalResourcesWidget,
):
- code_model.update(None)
ipw.dlink(
(code_model, "options"),
(code_widget.code_selection.code_select_dropdown, "options"),
@@ -359,18 +382,28 @@ def _render_code_widget(
(code_widget.code_selection.code_select_dropdown, "value"),
)
ipw.dlink(
- (code_model, "selected"),
+ (code_model, "override"),
(code_widget.code_selection.code_select_dropdown, "disabled"),
- lambda selected: not selected,
+ lambda override: not override,
)
ipw.link(
(code_model, "num_cpus"),
(code_widget.num_cpus, "value"),
)
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.num_cpus, "disabled"),
+ lambda override: not override,
+ )
ipw.link(
(code_model, "num_nodes"),
(code_widget.num_nodes, "value"),
)
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.num_nodes, "disabled"),
+ lambda override: not override,
+ )
ipw.link(
(code_model, "ntasks_per_node"),
(code_widget.resource_detail.ntasks_per_node, "value"),
@@ -383,18 +416,47 @@ def _render_code_widget(
(code_model, "max_wallclock_seconds"),
(code_widget.resource_detail.max_wallclock_seconds, "value"),
)
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.code_selection.btn_setup_new_code, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.btn_setup_resource_detail, "disabled"),
+ lambda override: not override,
+ )
if isinstance(code_widget, PwCodeResourceSetupWidget):
ipw.link(
- (code_model, "override"),
+ (code_model, "parallelization_override"),
(code_widget.parallelization.override, "value"),
)
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.parallelization.override, "disabled"),
+ lambda override: not override,
+ )
ipw.link(
(code_model, "npool"),
(code_widget.parallelization.npool, "value"),
)
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.parallelization.npool, "disabled"),
+ lambda override: not override,
+ )
+ code_model.observe(
+ self._on_code_resource_change,
+ [
+ "parallelization_override",
+ "npool",
+ ],
+ )
code_model.observe(
self._on_code_resource_change,
[
+ "options",
+ "selected",
"num_cpus",
"num_nodes",
"ntasks_per_node",
@@ -402,15 +464,7 @@ def _render_code_widget(
"max_wallclock_seconds",
],
)
- # disable the code widget if the override is not set
- code_widget.num_nodes.disabled = not self.override.value
- code_widget.num_cpus.disabled = not self.override.value
- code_widget.code_selection.code_select_dropdown.disabled = (
- not self.override.value
- )
-
code_widgets = self.code_widgets_container.children[:-1] # type: ignore
-
self.code_widgets_container.children = [*code_widgets, code_widget]
code_model.is_rendered = True
diff --git a/src/aiidalab_qe/plugins/bands/__init__.py b/src/aiidalab_qe/plugins/bands/__init__.py
index 0c8b86ad..7b607d1a 100644
--- a/src/aiidalab_qe/plugins/bands/__init__.py
+++ b/src/aiidalab_qe/plugins/bands/__init__.py
@@ -18,7 +18,7 @@ class BandsPluginOutline(PluginOutline):
"panel": BandsConfigurationSettingsPanel,
"model": BandsConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": BandsResourceSettingsPanel,
"model": BandsResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/bands/code.py b/src/aiidalab_qe/plugins/bands/code.py
index 79571a00..fa11f431 100644
--- a/src/aiidalab_qe/plugins/bands/code.py
+++ b/src/aiidalab_qe/plugins/bands/code.py
@@ -7,18 +7,22 @@
class BandsResourceSettingsModel(ResourceSettingsModel):
"""Model for the band structure plugin."""
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "projwfc_bands": CodeModel(
- name="projwfc.x",
- description="projwfc.x",
- default_calc_job_plugin="quantumespresso.projwfc",
- ),
- }
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "projwfc_bands": CodeModel(
+ name="projwfc.x",
+ description="projwfc.x",
+ default_calc_job_plugin="quantumespresso.projwfc",
+ ),
+ }
+ )
class BandsResourceSettingsPanel(ResourceSettingsPanel[BandsResourceSettingsModel]):
diff --git a/src/aiidalab_qe/plugins/pdos/__init__.py b/src/aiidalab_qe/plugins/pdos/__init__.py
index 280a90c3..9d9d461f 100644
--- a/src/aiidalab_qe/plugins/pdos/__init__.py
+++ b/src/aiidalab_qe/plugins/pdos/__init__.py
@@ -17,7 +17,7 @@ class PdosPluginOutline(PluginOutline):
"panel": PdosConfigurationSettingPanel,
"model": PdosConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": PdosResourceSettingsPanel,
"model": PdosResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/pdos/code.py b/src/aiidalab_qe/plugins/pdos/code.py
index 1e2095a2..dbe4a68b 100644
--- a/src/aiidalab_qe/plugins/pdos/code.py
+++ b/src/aiidalab_qe/plugins/pdos/code.py
@@ -7,23 +7,27 @@
class PdosResourceSettingsModel(ResourceSettingsModel):
"""Model for the pdos code setting plugin."""
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "dos": CodeModel(
- name="dos.x",
- description="dos.x",
- default_calc_job_plugin="quantumespresso.dos",
- ),
- "projwfc": CodeModel(
- name="projwfc.x",
- description="projwfc.x",
- default_calc_job_plugin="quantumespresso.projwfc",
- ),
- }
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "dos": CodeModel(
+ name="dos.x",
+ description="dos.x",
+ default_calc_job_plugin="quantumespresso.dos",
+ ),
+ "projwfc": CodeModel(
+ name="projwfc.x",
+ description="projwfc.x",
+ default_calc_job_plugin="quantumespresso.projwfc",
+ ),
+ }
+ )
class PdosResourceSettingsPanel(ResourceSettingsPanel[PdosResourceSettingsModel]):
diff --git a/src/aiidalab_qe/plugins/xas/__init__.py b/src/aiidalab_qe/plugins/xas/__init__.py
index 76a4af00..0237636a 100644
--- a/src/aiidalab_qe/plugins/xas/__init__.py
+++ b/src/aiidalab_qe/plugins/xas/__init__.py
@@ -24,7 +24,7 @@ class XasPluginOutline(PluginOutline):
"panel": XasConfigurationSettingsPanel,
"model": XasConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": XasResourceSettingsPanel,
"model": XasResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/xas/code.py b/src/aiidalab_qe/plugins/xas/code.py
index ff07fb8f..e980b915 100644
--- a/src/aiidalab_qe/plugins/xas/code.py
+++ b/src/aiidalab_qe/plugins/xas/code.py
@@ -7,18 +7,22 @@
class XasResourceSettingsModel(ResourceSettingsModel):
"""Model for the XAS plugin."""
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "xspectra": CodeModel(
- name="xspectra.x",
- description="xspectra.x",
- default_calc_job_plugin="quantumespresso.xspectra",
- ),
- }
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "xspectra": CodeModel(
+ name="xspectra.x",
+ description="xspectra.x",
+ default_calc_job_plugin="quantumespresso.xspectra",
+ ),
+ }
+ )
class XasResourceSettingsPanel(ResourceSettingsPanel[XasResourceSettingsModel]):
diff --git a/tests/conftest.py b/tests/conftest.py
index bacfeb26..0494ff94 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -429,19 +429,12 @@ def app(pw_code, dos_code, projwfc_code, projwfc_bands_code):
app.submit_model.qe_installed = True
# set up codes
- pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw")
- dos_code_model = app.submit_model.get_model("global").get_code(
- "quantumespresso.dos"
- )
- projwfc_code_model = app.submit_model.get_model("global").get_code(
- "quantumespresso.projwfc"
- )
-
- pw_code_model.activate()
- dos_code_model.activate()
- projwfc_code_model.activate()
+ global_model = app.submit_model.get_model("global")
+ global_model.get_model("quantumespresso.pw").activate()
+ global_model.get_model("quantumespresso.dos").activate()
+ global_model.get_model("quantumespresso.projwfc").activate()
- app.submit_model.get_model("global").set_selected_codes(
+ global_model.set_selected_codes(
{
"pw": {"code": pw_code.label},
"dos": {"code": dos_code.label},
@@ -509,7 +502,9 @@ def _submit_app_generator(
app.configure_model.confirm()
app.submit_model.input_structure = generate_structure_data()
- app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 2
+ app.submit_model.get_model("global").get_model(
+ "quantumespresso.pw"
+ ).num_cpus = 2
return app
@@ -818,7 +813,9 @@ def _generate_qeapp_workchain(
app.configure_model.confirm()
# step 3 setup code and resources
- app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 4
+ app.submit_model.get_model("global").get_model(
+ "quantumespresso.pw"
+ ).num_cpus = 4
parameters = app.submit_model.get_model_state()
builder = app.submit_model._create_builder(parameters)
diff --git a/tests/test_codes.py b/tests/test_codes.py
index 98421560..8df05c6e 100644
--- a/tests/test_codes.py
+++ b/tests/test_codes.py
@@ -7,7 +7,7 @@ def test_code_not_selected(submit_app_generator):
"""Test if there is an error when the code is not selected."""
app: App = submit_app_generator(properties=["dos"])
model = app.submit_model
- model.get_model("global").get_code("quantumespresso.dos").selected = None
+ model.get_model("global").get_model("quantumespresso.dos").selected = None
# Check builder construction passes without an error
parameters = model.get_model_state()
model._create_builder(parameters)
@@ -19,8 +19,8 @@ def test_set_selected_codes(submit_app_generator):
parameters = app.submit_model.get_model_state()
model = SubmissionStepModel()
_ = SubmitQeAppWorkChainStep(model=model, qe_auto_setup=False)
- for name, code_model in app.submit_model.get_model("global").codes.items():
- model.get_model("global").get_code(name).is_active = code_model.is_active
+ for identifier, code_model in app.submit_model.get_model("global").get_models():
+ model.get_model("global").get_model(identifier).is_active = code_model.is_active
model.qe_installed = True
model.get_model("global").set_selected_codes(parameters["codes"]["global"]["codes"])
assert model.get_selected_codes() == app.submit_model.get_selected_codes()
@@ -32,23 +32,14 @@ def test_update_codes_display(app: App):
"""
app.submit_step.render()
model = app.submit_model
- model.get_model("global").update_active_codes()
- assert (
- app.submit_step.global_code_settings.code_widgets["dos"].layout.display
- == "none"
- )
+ global_model = model.get_model("global")
+ global_model.update_active_codes()
+ global_resources = app.submit_step.global_resources
+ assert global_resources.code_widgets["dos"].layout.display == "none"
model.input_parameters = {"workchain": {"properties": ["pdos"]}}
- model.get_model("global").update_active_codes()
- assert (
- app.submit_step._model.get_model("global")
- .codes["quantumespresso.dos"]
- .is_active
- is True
- )
- assert (
- app.submit_step.global_code_settings.code_widgets["dos"].layout.display
- == "block"
- )
+ global_model.update_active_codes()
+ assert global_model.get_model("quantumespresso.dos").is_active is True
+ assert global_resources.code_widgets["dos"].layout.display == "block"
def test_check_submission_blockers(app: App):
@@ -63,7 +54,7 @@ def test_check_submission_blockers(app: App):
assert len(model.internal_submission_blockers) == 0
# set dos code to None, will introduce another blocker
- dos_code = model.get_model("global").get_code("quantumespresso.dos")
+ dos_code = model.get_model("global").get_model("quantumespresso.dos")
dos_value = dos_code.selected
dos_code.selected = None
model.update_submission_blockers()
@@ -78,16 +69,16 @@ def test_check_submission_blockers(app: App):
def test_qeapp_computational_resources_widget(app: App):
"""Test QEAppComputationalResourcesWidget."""
app.submit_step.render()
- pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw")
- pw_code_widget = app.submit_step.global_code_settings.code_widgets["pw"]
+ global_model = app.submit_model.get_model("global")
+ global_resources = app.submit_step.global_resources
+ pw_code_model = global_model.get_model("quantumespresso.pw")
+ pw_code_widget = global_resources.code_widgets["pw"]
assert pw_code_widget.parallelization.npool.layout.display == "none"
- pw_code_model.override = True
+ pw_code_model.parallelization_override = True
pw_code_model.npool = 2
assert pw_code_widget.parallelization.npool.layout.display == "block"
assert pw_code_widget.parameters == {
- "code": app.submit_step.global_code_settings.code_widgets[
- "pw"
- ].value, # TODO why None?
+ "code": global_resources.code_widgets["pw"].value,
"cpus": 1,
"cpus_per_task": 1,
"max_wallclock_seconds": 43200,
diff --git a/tests/test_submit_qe_workchain.py b/tests/test_submit_qe_workchain.py
index e0c367a6..90d06d58 100644
--- a/tests/test_submit_qe_workchain.py
+++ b/tests/test_submit_qe_workchain.py
@@ -16,6 +16,7 @@ def test_create_builder_default(
app.submit_model._create_builder(parameters)
# since uuid is specific to each run, we remove it from the output
ui_parameters = remove_uuid_fields(parameters)
+ remove_code_options(ui_parameters)
# regression test for the parameters generated by the app
# this parameters are passed to the workchain
data_regression.check(ui_parameters)
@@ -144,16 +145,17 @@ def test_warning_messages(
app: App = submit_app_generator(properties=["bands", "pdos"])
submit_model = app.submit_model
+ global_model = submit_model.get_model("global")
- pw_code = submit_model.get_model("global").get_code("quantumespresso.pw")
+ pw_code = global_model.get_model("quantumespresso.pw")
pw_code.num_cpus = 1
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
# no warning:
assert submit_model.submission_warning_messages == ""
# now we increase the resources, so we should have the Warning-3
pw_code.num_cpus = len(os.sched_getaffinity(0))
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
for suggestion in ["avoid_overloading", "go_remote"]:
assert suggestions[suggestion] in submit_model.submission_warning_messages
@@ -161,12 +163,10 @@ def test_warning_messages(
structure = generate_structure_data("H2O-larger")
submit_model.input_structure = structure
pw_code.num_cpus = 1
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
num_sites = len(structure.sites)
volume = structure.get_cell_volume()
- estimated_CPUs = submit_model.get_model("global")._estimate_min_cpus(
- num_sites, volume
- )
+ estimated_CPUs = global_model._estimate_min_cpus(num_sites, volume)
assert estimated_CPUs == 2
for suggestion in ["more_resources", "change_configuration"]:
assert suggestions[suggestion] in submit_model.submission_warning_messages
@@ -232,3 +232,10 @@ def remove_uuid_fields(data):
else:
# Return the value unchanged if it's not a dictionary or list
return data
+
+
+def remove_code_options(parameters):
+ """Remove the code options from the parameters."""
+ for panel in parameters["codes"].values(): # type: ignore
+ for code in panel["codes"].values():
+ del code["options"]
diff --git a/tests/test_submit_qe_workchain/test_create_builder_default.yml b/tests/test_submit_qe_workchain/test_create_builder_default.yml
index 2d6f756e..dd3f9dac 100644
--- a/tests/test_submit_qe_workchain/test_create_builder_default.yml
+++ b/tests/test_submit_qe_workchain/test_create_builder_default.yml
@@ -37,7 +37,6 @@ codes:
nodes: 1
ntasks_per_node: 1
parallelization: {}
- override: false
global:
codes:
quantumespresso.dos:
@@ -87,7 +86,6 @@ codes:
nodes: 1
ntasks_per_node: 1
parallelization: {}
- override: false
pdos:
nscf_kpoints_distance: 0.1
pdos_degauss: 0.005