Skip to content

Commit

Permalink
CLI: Filter cutoffs from JSON file for family cutoffs set (#132)
Browse files Browse the repository at this point in the history
Currently the `family cutoffs set` command requires that the JSON file
that contains the cutoffs _only_ has the cutoff keys specified for each
element. This is because the cutoffs are validated by the
`RecommendedCutoffMixin.validate_cutoffs()` method, which doesn't allow
for extraneous keys.

However, the  `.json`  file could have been adapted by the user from e.g. the
SSSP `metadata.json` or some other dictionary that also contains other keys.
As long as the cutoff keys are defined, we shouldn't raise an error just because
there are other keys present.

Here we first filter the data loaded from the JSON file for the cutoffs,
before passing the dictionary to the `set_cutoffs` method.
  • Loading branch information
mbercx authored Oct 10, 2022
1 parent aa819c9 commit 5028c9a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
11 changes: 10 additions & 1 deletion src/aiida_pseudo/cli/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,17 @@ def cmd_family_cutoffs_set(family, cutoffs, stringency, unit): # noqa: D301
except ValueError as exception:
raise click.BadParameter(f'`{cutoffs.name}` contains invalid JSON: {exception}', param_hint='CUTOFFS')

cutoffs_dict = {}
for element, values in data.items():
try:
cutoffs_dict[element] = {'cutoff_wfc': values['cutoff_wfc'], 'cutoff_rho': values['cutoff_rho']}
except KeyError as exception:
raise click.BadParameter(
f'`{cutoffs.name}` is missing cutoffs for element `{element}`: {exception}', param_hint='CUTOFFS'
) from exception

try:
family.set_cutoffs(data, stringency, unit=unit)
family.set_cutoffs(cutoffs_dict, stringency, unit=unit)
except ValueError as exception:
raise click.BadParameter(f'`{cutoffs.name}` contains invalid cutoffs: {exception}', param_hint='CUTOFFS')

Expand Down
21 changes: 17 additions & 4 deletions tests/cli/test_family.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument,redefined-outer-name
"""Tests for the command `aiida-pseudo family`."""
from copy import deepcopy
import json

from aiida.orm import Group
Expand Down Expand Up @@ -31,13 +32,15 @@ def test_family_cutoffs_set(run_cli_command, get_pseudo_family, generate_cutoffs
assert "Error: Missing option '-s' / '--stringency'" in result.output
assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal'])

# Invalid cutoffs structure
filepath.write_text(json.dumps({'Ar': {'cutoff_rho': 300}}))
# Missing cutoffs
high_cutoffs = deepcopy(cutoffs_dict['high'])
high_cutoffs['Ar'].pop('cutoff_wfc')
filepath.write_text(json.dumps(high_cutoffs))
result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', 'high'], raises=True)
assert 'Error: Invalid value for CUTOFFS:' in result.output
assert 'Error: Invalid value for CUTOFFS: ' in result.output
assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal'])

# Set correct stringency
# Set the high stringency
stringency = 'high'
filepath.write_text(json.dumps(cutoffs_dict['high']))
result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency])
Expand All @@ -46,6 +49,16 @@ def test_family_cutoffs_set(run_cli_command, get_pseudo_family, generate_cutoffs
assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal', 'high'])
assert family.get_cutoffs(stringency) == cutoffs_dict[stringency]

# Additional keys in the cutoffs should be accepted and simply ignored
stringency = 'invalid'
high_cutoffs = deepcopy(cutoffs_dict['high'])
high_cutoffs['Ar']['GME'] = 'moon'
filepath.write_text(json.dumps(high_cutoffs))
result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency])
assert 'Success: set cutoffs for' in result.output
assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal', 'high', stringency])
assert family.get_cutoffs(stringency) == cutoffs_dict['high']


@pytest.mark.usefixtures('clear_db')
def test_family_cutoffs_set_unit(run_cli_command, get_pseudo_family, generate_cutoffs, tmp_path):
Expand Down

0 comments on commit 5028c9a

Please sign in to comment.