Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/relax and bands wc #254

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
324 changes: 324 additions & 0 deletions aiida_common_workflows/workflows/relax_and_bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
"""
Workflow that runs a relaxation and subsequently calculates bands.
It can use any code plugin implementing the common relax workflow and the
common bands workflow.
It also allows the automatic use of `seekpath` in order to get the high
symmetries path for bands.
"""
bosonie marked this conversation as resolved.
Show resolved Hide resolved
import inspect

from aiida import orm
from aiida.common import AttributeDict, exceptions
from aiida.engine import ToContext, WorkChain, calcfunction, if_
from aiida.orm.nodes.data.base import to_aiida_type
from aiida.plugins import WorkflowFactory

from aiida_common_workflows.workflows.bands.generator import CommonBandsInputGenerator
from aiida_common_workflows.workflows.bands.workchain import CommonBandsWorkChain
from aiida_common_workflows.workflows.relax.generator import CommonRelaxInputGenerator, RelaxType
from aiida_common_workflows.workflows.relax.workchain import CommonRelaxWorkChain


def deserialize(inputs):
    """
    Convert the AiiDA data nodes found in ``inputs`` into plain python values.

    Serialization turns simple python types into AiiDA ``Data`` types; this
    function performs the opposite conversion. The dictionary is modified in
    place and also returned. Values that are plain dictionaries are processed
    recursively; values of any other type are left untouched.

    Conversions applied:

    * ``Float``, ``Str``, ``Int``, ``Bool`` -> their ``value``
    * ``Dict`` -> the result of ``get_dict()``
    * ``List`` -> the result of ``get_list()``
    * ``Code`` -> its ``label``

    :param inputs: dictionary containing the elements to deserialize
    :return: the same dictionary, now containing the deserialized inputs.
    """
    scalar_types = (orm.Float, orm.Str, orm.Int, orm.Bool)

    # Iterate over a snapshot of the keys, because entries are re-assigned while looping.
    for name in list(inputs):
        node = inputs[name]
        if isinstance(node, scalar_types):
            inputs[name] = node.value
        elif isinstance(node, orm.Dict):
            inputs[name] = node.get_dict()
        elif isinstance(node, orm.List):
            inputs[name] = node.get_list()
        elif isinstance(node, orm.Code):
            inputs[name] = node.label
        elif isinstance(node, dict):
            # Nested namespace: deserialize its content in place.
            deserialize(node)

    return inputs


@calcfunction
def seekpath_explicit_kp_path(structure, seekpath_params):
    """
    Run SeeK-path on a structure and return its primitive structure and explicit k-points.

    :param structure: StructureData containing the structure information.
    :param seekpath_params: Dict of seekpath parameters, unpacked as keyword arguments
        of `get_explicit_kpoints_path`.
    :return: dictionary with the keys `structure` (the `primitive_structure` result of
        `get_explicit_kpoints_path`) and `kpoints` (its `explicit_kpoints` result).
    """
    from aiida.tools import get_explicit_kpoints_path

    seekpath_output = get_explicit_kpoints_path(structure, **seekpath_params)
    primitive_structure = seekpath_output['primitive_structure']
    explicit_kpoints = seekpath_output['explicit_kpoints']

    return {'structure': primitive_structure, 'kpoints': explicit_kpoints}


def validate_inputs(value, _):
    """
    Validate the entire input namespace.

    Two checks are performed, in this order:

    1. the content of ``relax_inputs`` must be accepted by ``get_builder`` of the
       input generator belonging to the selected relax sub process class;
    2. the relax and bands entry points must refer to the same plugin, i.e. they
       must be equal once their ``common_workflows.relax.`` and
       ``common_workflows.bands.`` prefixes are removed.

    :param value: the mapping holding all inputs of the work chain.
    :param _: unused second argument of the validator signature.
    :return: an error message if one of the checks fails, ``None`` otherwise.
    """
    relax_class = WorkflowFactory(value['relax_sub_process_class'].value)
    generator = relax_class.get_input_generator()

    # Building the inputs is the actual check: any failure is reported as a validation message.
    try:
        generator.get_builder(**AttributeDict(value['relax_inputs']))
    except Exception as exc:  # pylint: disable=broad-except
        return f'`{generator.__class__.__name__}.get_builder()` fails for the provided `relax_inputs`: {exc}'

    # Strip the entry point prefixes so that only the plugin names are compared.
    bands_plugin = value['bands_sub_process_class'].value.replace('common_workflows.bands.', '')
    relax_plugin = value['relax_sub_process_class'].value.replace('common_workflows.relax.', '')
    if bands_plugin == relax_plugin:
        return None

    return 'Different code between relax and bands. Not supported yet.'


def validate_sub_process_class_r(value, _):
    """
    Validate the entry point of the relax sub process class.

    :param value: node whose ``value`` is the workflow entry point name.
    :param _: unused second argument of the validator signature.
    :return: an error message if the entry point cannot be loaded or does not load a
        subclass of ``CommonRelaxWorkChain``, ``None`` otherwise.
    """
    try:
        loaded = WorkflowFactory(value.value)
    except exceptions.EntryPointError:
        return f'`{value.value}` is not a valid or registered workflow entry point.'

    # `issubclass` is only evaluated when the loaded object is a class.
    is_common_relax = inspect.isclass(loaded) and issubclass(loaded, CommonRelaxWorkChain)
    if is_common_relax:
        return None

    return f'`{value.value}` is not a subclass of the `CommonRelaxWorkChain` common workflow.'


def validate_sub_process_class_b(value, _):
    """
    Validate the entry point of the bands sub process class.

    :param value: node whose ``value`` is the workflow entry point name.
    :param _: unused second argument of the validator signature.
    :return: an error message if the entry point cannot be loaded or does not load a
        subclass of ``CommonBandsWorkChain``, ``None`` otherwise.
    """
    try:
        loaded = WorkflowFactory(value.value)
    except exceptions.EntryPointError:
        return f'`{value.value}` is not a valid or registered workflow entry point.'

    # `issubclass` is only evaluated when the loaded object is a class.
    is_common_bands = inspect.isclass(loaded) and issubclass(loaded, CommonBandsWorkChain)
    if is_common_bands:
        return None

    return f'`{value.value}` is not a subclass of the `CommonBandsWorkChain` common workflow.'
bosonie marked this conversation as resolved.
Show resolved Hide resolved


class RelaxAndBandsWorkChain(WorkChain):
"""
Workflow to carry on a relaxation and subsequently calculate the bands.
"""

@classmethod
def define(cls, spec):
"""Define the inputs, the outline, the outputs and the exit codes of the work chain."""
# yapf: disable
super().define(spec)
# Parameters forwarded to `get_explicit_kpoints_path`; every port below declares a default value.
spec.input_namespace(
'seekpath_parameters',
help='Inputs for the seekpath to be passed to `get_explicit_kpoints_path`.',
)
spec.input(
'seekpath_parameters.reference_distance',
valid_type=orm.Float,
default=lambda: orm.Float(0.025),
serializer=to_aiida_type,
help='Reference target distance between neighboring k-points along the path in units 1/Å.',
)
spec.input(
'seekpath_parameters.symprec',
valid_type=orm.Float,
default=lambda: orm.Float(0.00001),
serializer=to_aiida_type,
help='The symmetry precision used internally by SPGLIB.',
)
spec.input(
'seekpath_parameters.angle_tolerance',
valid_type=orm.Float,
default=lambda: orm.Float(-1.0),
serializer=to_aiida_type,
help='The angle tollerance used internally by SPGLIB.',
)
spec.input(
'seekpath_parameters.threshold',
valid_type=orm.Float,
default=lambda: orm.Float(0.0000001),
serializer=to_aiida_type,
help='The treshold for determining edge cases. Meaning is different depending on bravais lattice.',
)
spec.input(
'seekpath_parameters.with_time_reversal',
valid_type=orm.Bool,
default=lambda: orm.Bool(True),
serializer=to_aiida_type,
help='If False, and the group has no inversion symmetry, additional lines are returned.',
)

# Copy the input ports of the common relax input generator into the `relax_inputs` namespace.
namspac = spec.inputs.create_port_namespace('relax_inputs')
namspac.absorb(CommonRelaxInputGenerator.spec().inputs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use expose_inputs. If the namespace is fully optional, you set populate_defaults = False:

        spec.expose_inputs(
            CommonRelaxInputGenerator,
            namespace='relax',
            exclude=('structure',),
            namespace_options={'populate_defaults': False, 'required': False}
        )
        spec.expose_inputs(
            CommonRelaxInputGenerator,
            namespace='scf',
            exclude=('structure',),
            namespace_options={'populate_defaults': False, 'required': False}
        )
        spec.expose_inputs(
            CommonBandsInputGenerator,
            namespace='bands'
        )

# Flag the `protocol` port of `relax_inputs` as `non_db`; the other ports of the namespace are flagged further below.
namspac['protocol'].non_db = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing this manually, I would write a recursive function and call it:

def set_port_non_db(port):
    """Set ``non_db =  True`` for the given port.

    .. note:: If ``port``, is a port namespace, the function will be called recursively on all sub ports.

    :param port: a port or port namespace.
    """
    if isinstance(port, PortNamespace):
        for subport in port.values():
            set_port_non_db(subport)
    elif isinstance(port, InputPort):
        port.non_db = True

Then in the define method, you simply call the method:

set_port_non_db(spec.inputs)

# Flag the remaining ports of `relax_inputs` as `non_db`, like `protocol` above.
for port_name in (
    'spin_type',
    'relax_type',
    'electronic_type',
    'magnetization_per_site',
    'threshold_forces',
    'threshold_stress',
):
    namspac[port_name].non_db = True
namspac['engines']['relax']['options'].non_db = True

# Inputs of the bands step. `parent_folder` is left out: `run_bands` passes the `remote_folder`
# of the relaxation instead.
# NOTE: `exclude` has to be a real tuple. `('parent_folder')` is just a parenthesized string, which
# only excluded the port by accident through substring matching.
namspac2 = spec.inputs.create_port_namespace('bands_inputs')
namspac2.absorb(CommonBandsInputGenerator.spec().inputs, exclude=('parent_folder',))
namspac2['engines']['bands']['options'].non_db = True
# The k-points are optional: when they are absent `prepare_bands` generates them with SeeK-path.
namspac2['bands_kpoints'].required = False

# Optional overrides for the second run of the relax workflow. `structure` is left out because
# `prepare_bands` sets the structure of that run itself.
namspac3 = spec.inputs.create_port_namespace('second_relax_inputs')
namspac3.absorb(CommonRelaxInputGenerator.spec().inputs, exclude=('structure',))
for port_name in (
    'protocol',
    'spin_type',
    'relax_type',
    'electronic_type',
    'magnetization_per_site',
    'threshold_forces',
    'threshold_stress',
):
    namspac3[port_name].non_db = True
namspac3['engines']['relax']['options'].non_db = True
# Make every override optional and remove its default.
# NOTE(review): `()` is presumably the "unspecified" sentinel of the port implementation -- confirm.
for key in namspac3:
    namspac3[key].required = False
    namspac3[key].populate_defaults = False
    namspac3[key].default = ()
# `relax_type` is the exception: it stays required and defaults to `RelaxType.NONE`.
namspac3['relax_type'].required = True
namspac3['relax_type'].default = RelaxType.NONE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't rely on this through a default because then a user can run anything else. We should always set this no? So just set it in the workchain step when preparing the inputs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this was against the idea to fully expose the inputs of the sub-processes, but ok, this makes sense. And also avoids the problem of my last comments above.


# Entry point names of the code-specific relax and bands workflows; each has its own port validator.
spec.input('relax_sub_process_class',
valid_type=orm.Str,
serializer=to_aiida_type,
validator=validate_sub_process_class_r
)
spec.input('bands_sub_process_class',
valid_type=orm.Str,
serializer=to_aiida_type,
validator=validate_sub_process_class_b
)

# Validator for the whole input namespace (checks `relax_inputs` and that relax and bands use the same plugin).
spec.inputs.validator = validate_inputs

# `run_relax` appears twice: the second run, preceded by `fix_inputs`, is only executed when
# `should_run_other_scf` returns True.
spec.outline(
cls.initialize,
cls.run_relax,
cls.prepare_bands,
if_(cls.should_run_other_scf)(
cls.fix_inputs,
cls.run_relax
),
cls.run_bands,
cls.inspect_bands
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was wondering if we should have a slightly different outline:

Suggested change
spec.outline(
cls.initialize,
cls.run_relax,
cls.prepare_bands,
if_(cls.should_run_other_scf)(
cls.fix_inputs,
cls.run_relax
),
cls.run_bands,
cls.inspect_bands
)
spec.outline(
cls.setup,
if_(cls.should_run_relax)(
cls.run_relax,
cls.inspect_relax,
),
if_(cls.should_run_scf)(
cls.run_seekpath,
cls.run_scf,
cls.inspect_scf,
),
cls.run_bands,
cls.inspect_bands,
cls.results,
)

I think this is a bit more explicit and clear as to what will happen.

This would make the relax optional, which is still useful, because the user can then compute the bands for an already optimized structure but be sure that the bands are computed in exactly the same way as if it were to include relaxation.

I would also call the second relax run an SCF, because that is what it is, even though we are using the RelaxWorkflow but that is merely because we decided that this is the way an SCF is run.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I am sensing a conceptual difference here. In my mind the user would NOT have the possibility to decide whether to run an extra scf or not. The presence of an extra step is determined only by whether seekpath was used (and, in the future, maybe by the decision to use a different code for bands than for relax). The user could only select some inputs for this step.

Your idea only gives the additional flexibility to run an extra scf even when it is not necessary (when seekpath is not used) and I believe it is a valid use case.

However I have strong doubts about your other suggestion to make the relaxation optional

This would make the relax optional, which is still useful, because the user can then compute the bands for an already optimized structure but be sure that the bands are computed in exactly the same way as if it were to include relaxation.

This can already be done by setting the relax_type to None in the relax input, so it is not needed. I do not see any other case where this is useful, and the price to pay is very high. In fact I believe that making both the relax step and the scf step optional is difficult to support. How do we understand that the user wants only the scf and not the relaxation? If we set populate_defaults = False and required = False for both relaxation and scf, then we have to implement additional logic that checks that one of the two is set, and that the one selected sets at least the structure, protocol, engines. Basically it makes it useless to expose the inputs, since none of the properties of the input ports can be used; they must be implemented all over again. It seems to me an additional complication that has no motivation to exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that making both optional adds more complexity. But I don't understand your current design. If I understand correctly, in prepare_bands you check whether explicit kpoints were specified in the bands option, in which case you don't rerun the relax workflow as scf. But the first relax workflow and seekpath are always run. So it is practically guaranteed that the structure changes, even if the user sets RelaxType.NONE. Even if the relax doesn't change the structure, it is possible that seekpath still normalizes the structure, which will cause the k-points to no longer be commensurate.

I guess what I am saying is that we should make the relax+seekpath optional, or otherwise allowing to directly specify k-points doesn't make any sense. Or we just require relax+seekpath+scf always and just don't support the user specifying their own k-points. At least that would be consistent.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly. The relaxation always runs. The use of SeeKpath is optional. If you set kpoints for bands in input, SeekPath is not used. This is a valid use case: for instance, I know the kpoints path along which I want to calculate the bands, and I also want a relaxation of just the atomic coordinates. Of course it does not make much sense to specify a kpoints path and also allow relaxation with variable cell, since you will not know the final cell shape. But the rest is perfectly fine.

When you DO NOT use Seekpath (i.e. you specify kpoints for bands in input), you do not need an extra scf. In fact, you pass to the bands calculation the remote_folder of the relaxation and this guarantees to copy the correct final structure and the density matrix / wave-functions. However, based on your suggestion, I understood that the possibility for the user to trigger an extra scf anyway is a valid use case. For instance, a user wants to do a final scf with the "precise" protocol before calculating bands.

To summarize, three options are available

  1. relax+seekpath+scf+bands
  2. relax+kp_in_input+bands
  3. relax+kp_in_input+scf+bands

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course it does not make much sense to specify a kpoints path and also allow relaxation with variable cell, since you will not know the final cell shape. But the rest is perfectly fine.

This is what was confusing me. Should we maybe have validation in there to prevent this? If you specify kpoints, then the relax cannot change the cell.

To summarize, three options are available

relax+seekpath+scf+bands
relax+kp_in_input+bands
relax+kp_in_input+scf+bands

This is great, this makes things a lot clearer. I would propose to at least describe this in the docstring of the workchain. And then maybe adapt the outline to the following that makes it a lot clearer what is happening:

        spec.outline(
            cls.setup,
            cls.run_relax,
            cls.inspect_relax,
            if_(cls.should_run_seekpath)(
                cls.run_seekpath,
                cls.run_scf,
                cls.inspect_scf,
            ).elif_(cls.should_run_scf)(
                cls.run_scf,
                cls.inspect_scf,
            ),
            cls.run_bands,
            cls.inspect_bands,
            cls.results,
        )


# Outputs attached by `inspect_bands` once the bands sub process finished successfully.
spec.output('final_structure', valid_type=orm.StructureData, help='The final structure.')
# The help text used to be a leftover from another workflow (it described total energies).
spec.output('bands', valid_type=orm.BandsData, help='The computed band structure.')
# Single exit code shared by all steps that check a sub process.
spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED',
    message='At least one of the sub processes did not finish successfully.')


def initialize(self):
    """
    Set up the context variables that later steps read and modify.

    ``ctx.inputs`` starts as a copy of the ``relax_inputs`` namespace and
    ``ctx.need_other_scf`` starts as ``False``.
    """
    relax_inputs = AttributeDict(self.inputs.relax_inputs)

    self.ctx.inputs = relax_inputs
    self.ctx.need_other_scf = False


def run_relax(self):
    """
    Submit the relax sub process built from the inputs stored in ``ctx.inputs``.

    The builder is obtained from the input generator of the workflow registered under
    the ``relax_sub_process_class`` entry point. The submitted process is stored in the
    context under ``workchain_relax``.
    """
    entry_point = self.inputs.relax_sub_process_class.value
    generator = WorkflowFactory(entry_point).get_input_generator()
    builder = generator.get_builder(**self.ctx.inputs)

    self.report(f'submitting `{builder.process_class.__name__}` for relaxation.')

    return ToContext(workchain_relax=self.submit(builder))


def prepare_bands(self):
    """
    Inspect the relaxation and determine the k-points for the bands calculation.

    The work chain is aborted with ``ERROR_SUB_PROCESS_FAILED`` if the relaxation did not
    finish successfully. Otherwise the ``bands_inputs`` namespace decides how to proceed:

    * without ``bands_kpoints``, SeeK-path is run on the relaxed structure (or on the input
      structure when the relaxation has no ``relaxed_structure`` output). SeeK-path might
      change the structure, so the returned structure replaces the one in ``ctx.inputs``
      and a further scf run is requested through ``ctx.need_other_scf``;
    * with ``bands_kpoints``, those k-points are used and no further scf run is requested.
    """
    relax_wc = self.ctx.workchain_relax

    if not relax_wc.is_finished_ok:
        self.report('Relaxation did not finish successful so aborting the workchain.')
        return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.relax_sub_process_class.value)  # pylint: disable=no-member

    # Fall back to the input structure when the relaxation returned no relaxed structure.
    if 'relaxed_structure' in relax_wc.outputs:
        structure = relax_wc.outputs.relaxed_structure
    else:
        structure = self.ctx.inputs['structure']

    if 'bands_kpoints' in self.inputs.bands_inputs:
        self.report('Kpoints for bands in inputs detected.')
        self.ctx.need_other_scf = False
        self.ctx.bandskpoints = self.inputs.bands_inputs['bands_kpoints']
        return None

    self.report('Using SekPath to create kpoints for bands. Structure might change.')
    seekpath_parameters = deserialize(AttributeDict(self.inputs.seekpath_parameters))
    seekpath_result = seekpath_explicit_kp_path(structure, orm.Dict(dict=seekpath_parameters))
    self.ctx.inputs['structure'] = seekpath_result['structure']
    self.ctx.bandskpoints = seekpath_result['kpoints']
    self.ctx.need_other_scf = True
    self.report('A new scf cycle needed')
    return None

def should_run_other_scf(self):
    """
    Tell the outline whether a further scf run is needed before the bands run.

    :return: the value of ``ctx.need_other_scf``.
    """
    needs_scf = self.ctx.need_other_scf
    return needs_scf

def fix_inputs(self):
    """
    Apply the overrides of the `second_relax_inputs` namespace to the stored relax inputs.

    Only keys that already exist in ``ctx.inputs`` are replaced; keys that are present
    only in `second_relax_inputs` are not added.
    """
    overrides = self.inputs.second_relax_inputs
    current = self.ctx.inputs

    # Loop over a snapshot of the keys since the values are re-assigned inside the loop.
    for name in list(current):
        if name not in overrides:
            continue
        current[name] = overrides[name]


def run_bands(self):
    """
    Run the sub process to obtain the bands.

    The bands workflow is built by the input generator of the workflow registered under
    the ``bands_sub_process_class`` entry point. It receives the k-points stored in
    ``ctx.bandskpoints``, the ``remote_folder`` of the last relax sub process and the
    ``engines`` of the ``bands_inputs`` namespace. The submitted process is stored in the
    context under ``workchain_bands``.

    :return: ``ERROR_SUB_PROCESS_FAILED`` if the last relax sub process did not finish
        successfully, otherwise the ``ToContext`` awaitable of the bands sub process.
    """
    rel_wc = self.ctx.workchain_relax

    # `ctx.workchain_relax` may refer to the second run of `run_relax` (the one inside the
    # `if_(should_run_other_scf)` branch of the outline), which no other step inspects.
    # Check it here, so that a failure yields the proper exit code rather than an exception
    # when accessing `remote_folder`.
    if not rel_wc.is_finished_ok:
        self.report('The run preceding the bands calculation did not finish successfully so aborting the workchain.')
        return self.exit_codes.ERROR_SUB_PROCESS_FAILED  # pylint: disable=no-member

    process_class = WorkflowFactory(self.inputs.bands_sub_process_class.value)

    builder = process_class.get_input_generator().get_builder(
        bands_kpoints=self.ctx.bandskpoints,
        parent_folder=rel_wc.outputs.remote_folder,
        engines=AttributeDict(self.inputs.bands_inputs['engines']),
    )

    self.report(f'submitting `{builder.process_class.__name__}` for bands.')
    running = self.submit(builder)

    return ToContext(workchain_bands=running)

def inspect_bands(self):
    """
    Inspect the bands sub process and attach the outputs of the work chain.

    On success the ``structure`` input of the bands sub process is returned as
    ``final_structure`` and its ``bands`` output as ``bands``.

    :return: ``ERROR_SUB_PROCESS_FAILED`` if the bands sub process did not finish
        successfully, ``None`` otherwise.
    """
    bands_wc = self.ctx.workchain_bands

    if bands_wc.is_finished_ok:
        self.report('Bands calculation finished successfully, returning outputs')
        self.out('final_structure', bands_wc.inputs.structure)
        self.out('bands', bands_wc.outputs.bands)
        return None

    self.report('Bands calculation did not finish successful so aborting the workchain.')
    return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(cls=self.inputs.bands_sub_process_class)
1 change: 1 addition & 0 deletions setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"aiida.workflows": [
"common_workflows.dissociation_curve = aiida_common_workflows.workflows.dissociation:DissociationCurveWorkChain",
"common_workflows.eos = aiida_common_workflows.workflows.eos:EquationOfStateWorkChain",
"common_workflows.relax_and_bands = aiida_common_workflows.workflows.relax_and_bands:RelaxAndBandsWorkChain",
"common_workflows.relax.abinit = aiida_common_workflows.workflows.relax.abinit.workchain:AbinitCommonRelaxWorkChain",
"common_workflows.relax.bigdft = aiida_common_workflows.workflows.relax.bigdft.workchain:BigDftCommonRelaxWorkChain",
"common_workflows.relax.castep = aiida_common_workflows.workflows.relax.castep.workchain:CastepCommonRelaxWorkChain",
Expand Down