From e6ca8d20d3598f42d7d6454f871c93d1fd2fdab4 Mon Sep 17 00:00:00 2001
From: Alexander Goscinski
Date: Mon, 18 Nov 2024 16:58:30 +0100
Subject: [PATCH 1/3] Migrate to the same linting and formatting as in aiida-core

We migrate to ruff for the linter and formatter, with the settings of
aiida-core, as well as the pre-commit hooks used for formatting. This makes
the repository more compatible with aiida-core and also solves more linter
problems automatically, which should speed up development.
---
 .pre-commit-config.yaml | 89 ++++++++++++++++++++++++-----------------
 pyproject.toml          | 48 +++++++++++++---------
 2 files changed, 81 insertions(+), 56 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8da31be..8662c91 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,24 +1,24 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.0.1
+  rev: v4.6.0
   hooks:
-  - id: end-of-file-fixer
-  - id: fix-encoding-pragma
-  - id: mixed-line-ending
-  - id: trailing-whitespace
-  - id: check-json
-  - id: check-yaml
-
-- repo: https://github.com/pycqa/isort
-  rev: '5.12.0'
-  hooks:
-  - id: isort
-
-- repo: https://github.com/psf/black
-  rev: '22.10.0'
-  hooks:
-  - id: black
-
+  - id: check-merge-conflict
+  - id: check-yaml
+  - id: double-quote-string-fixer
+  - id: end-of-file-fixer
+    exclude: &exclude_pre_commit_hooks >
+      (?x)^(
+        tests/.*(?
+  - id: check-github-workflows
+
+- repo: https://github.com/ikamensh/flynt/
+  rev: 1.0.1
+  hooks:
+  - id: flynt
+    args: [--line-length=120, --fail-on-change]
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.5.0
+  hooks:
+  - id: ruff-format
+    exclude: &exclude_ruff >
       (?x)^(
-        docs/.*|
+        docs/source/topics/processes/include/snippets/functions/parse_docstring_expose_ipython.py|
+        docs/source/topics/processes/include/snippets/functions/signature_plain_python_call_illegal.py|
+      )$
+  - id: ruff
+    exclude: *exclude_ruff
+    args: [--fix, --exit-non-zero-on-fix, --show-fixes]
+
+- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
+  rev: v2.13.0
+  hooks:
+  - id: pretty-format-toml
+    args: [--autofix]
+  - id: pretty-format-yaml
+    args: [--autofix]
+    exclude: >-
       (?x)^(
+        tests/.*|
+        environment.yml|
       )$

diff --git a/pyproject.toml b/pyproject.toml
index 9e18e15..6a37b83 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,26 +82,36 @@ exclude = [
   'codecov.yml',
 ]
 
-[tool.isort]
-profile = 'black'
-
-[tool.pylint.master]
-extension-pkg-whitelist = ['pydantic']
-
-[tool.pylint.format]
-max-line-length = 125
-
-[tool.pylint.messages_control]
-disable = [
-  'duplicate-code',
-  'fixme',
-  'invalid-name',
-  'too-many-ancestors',
-  'too-many-arguments',
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.format]
+quote-style = 'single'
+
+[tool.ruff.lint]
+ignore = [
+  'F403', # Star imports unable to detect undefined names
+  'F405', # Import may be undefined or defined from star imports
+  'PLR0911', # Too many return statements
+  'PLR0912', # Too many branches
+  'PLR0913', # Too many arguments in function definition
+  'PLR0915', # Too many statements
+  'PLR2004', # Magic value used in comparison
+  'RUF005', # Consider iterable unpacking instead of concatenation
+  'RUF012', # Mutable class attributes should be annotated with `typing.ClassVar`
+]
+select = [
+  'E', # pydocstyle
+  'W', # pydocstyle
+  'F', # pyflakes
+  'I', # isort
+  'N', # pep8-naming
+  'PLC', # pylint-convention
+  'PLE', # pylint-error
+  'PLR', # pylint-refactor
+  'PLW', # pylint-warning
+  'RUF' # ruff
 ]
-
-[tool.pylint.similarities]
-ignore-imports = 'yes'
 
 [tool.pytest.ini_options]
 python_files = 'test_*.py example_*.py'

From bf53b22d448d6e0e2e774ea57bea48b902596c0a Mon Sep 17 00:00:00 2001
From: Alexander Goscinski
Date: Mon, 18 Nov 2024 17:01:13 +0100
Subject: [PATCH 2/3] Remove pep8 naming convention checks

In several places we do not follow the pep8 naming convention but rather the
GraphQL one. We therefore remove the pep8-naming ('N') rules from the linter
selection so that these names are no longer flagged.
---
 pyproject.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6a37b83..667f209 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,7 +105,6 @@ select = [
   'W', # pydocstyle
   'F', # pyflakes
   'I', # isort
-  'N', # pep8-naming
   'PLC', # pylint-convention
   'PLE', # pylint-error
   'PLR', # pylint-refactor
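A note on why the 'N' rules clash with this code base rather than catch real
mistakes: graphene derives each resolver name directly from its field name,
and the schema is built with auto_camelcase=False (see
aiida_restapi/graphql/plugins.py in the next patch), so a camelCase GraphQL
field forces a camelCase Python function. A minimal illustrative sketch,
modelled on resolve_aiidaVersion from aiida_restapi/graphql/basic.py; the
snippet is not itself part of the series:

    import aiida
    import graphene as gr

    class Query(gr.ObjectType):
        # With auto_camelcase=False the field name is used verbatim, so
        # graphene looks up a resolver literally named 'resolve_aiidaVersion'.
        aiidaVersion = gr.String(description='Version of aiida-core')

        @staticmethod
        def resolve_aiidaVersion(parent, info):
            """Resolution function."""
            # pep8-naming flags this mandated name as N802
            # ('function name should be lowercase').
            return aiida.__version__

Silencing N802 line-by-line would require a noqa marker on every resolver, so
the patch above drops the 'N' rule set instead.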
From 11d8aad9a07336da61eec7fe217c530d5c08384c Mon Sep 17 00:00:00 2001
From: Alexander Goscinski
Date: Wed, 20 Nov 2024 10:03:43 +0100
Subject: [PATCH 3/3] Fix errors produced by formatter and linter migration

This commit is mainly the result of running `pre-commit run --all-files` with
the formatter and linter changes from the previous commits. There are some
additional manual changes to fix line breaks.
---
 .github/workflows/cd.yml | 66 +++---
 .github/workflows/ci.yml | 114 ++++-----
 .github/workflows/validate_release_tag.py | 26 +-
 .pre-commit-config.yaml | 142 +++++------
 aiida_restapi/__init__.py | 5 +-
 aiida_restapi/aiida_db_mappings.py | 150 ++++++------
 aiida_restapi/config.py | 22 +-
 aiida_restapi/filter_syntax.py | 56 ++---
 aiida_restapi/graphql/basic.py | 12 +-
 aiida_restapi/graphql/comments.py | 19 +-
 aiida_restapi/graphql/computers.py | 25 +-
 aiida_restapi/graphql/config.py | 1 -
 aiida_restapi/graphql/entry_points.py | 16 +-
 aiida_restapi/graphql/groups.py | 19 +-
 aiida_restapi/graphql/logs.py | 19 +-
 aiida_restapi/graphql/main.py | 2 +-
 aiida_restapi/graphql/nodes.py | 122 ++++------
 aiida_restapi/graphql/orm_factories.py | 112 ++++-----
 aiida_restapi/graphql/plugins.py | 26 +-
 aiida_restapi/graphql/sphinx_ext.py | 7 +-
 aiida_restapi/graphql/users.py | 32 +--
 aiida_restapi/graphql/utils.py | 8 +-
 aiida_restapi/main.py | 4 +-
 aiida_restapi/models.py | 215 ++++++++---------
 aiida_restapi/routers/auth.py | 39 ++-
 aiida_restapi/routers/computers.py | 16 +-
 aiida_restapi/routers/daemon.py | 22 +-
 aiida_restapi/routers/groups.py | 16 +-
 aiida_restapi/routers/nodes.py | 36 ++-
 aiida_restapi/routers/process.py | 40 ++--
 aiida_restapi/routers/users.py | 14 +-
 aiida_restapi/utils.py | 2 +-
 docs/source/conf.py | 111 ++++-----
 examples/daemon_management/script.py | 60 +++--
 examples/process_management/script.py | 96 ++++----
 examples/submit_quantumespresso_pw/script.py | 154 ++++++------
 pyproject.toml | 166 ++++++-------
 tests/__init__.py | 4 +-
 tests/conftest.py | 160 ++++++-------
 tests/test_auth.py | 12 +-
 tests/test_computers.py | 43 ++--
 tests/test_daemon.py | 34 +--
 tests/test_filter_syntax.py | 44 ++--
 tests/test_graphql/test_basic.py | 14 +-
 tests/test_graphql/test_comments.py | 22 +-
 tests/test_graphql/test_computers.py | 22 +-
 tests/test_graphql/test_entry_points.py | 23 +-
 tests/test_graphql/test_full.py | 6 +-
 tests/test_graphql/test_groups.py | 24 +-
 tests/test_graphql/test_logs.py | 10 +-
 tests/test_graphql/test_nodes.py | 70 +++---
 tests/test_graphql/test_users.py | 10 +-
 tests/test_groups.py | 39 ++-
 tests/test_models.py | 28 +--
 tests/test_nodes.py | 235 +++++++++----------
 tests/test_processes.py | 138 +++++------
 tests/test_users.py | 23 +-
 57 files changed, 1350
insertions(+), 1603 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index dcc3827..3c9efff 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -1,49 +1,49 @@ name: cd on: - push: - tags: - - 'v[0-9]+.[0-9]+.[0-9]+*' + push: + tags: + - v[0-9]+.[0-9]+.[0-9]+* jobs: - validate-release-tag: + validate-release-tag: - if: github.repository == 'aiidateam/aiida-restapi' - runs-on: ubuntu-latest + if: github.repository == 'aiidateam/aiida-restapi' + runs-on: ubuntu-latest - steps: - - name: Checkout source - uses: actions/checkout@v2 + steps: + - name: Checkout source + uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: '3.9' - - name: Validate the tag version against the package version - run: python .github/workflows/validate_release_tag.py $GITHUB_REF + - name: Validate the tag version against the package version + run: python .github/workflows/validate_release_tag.py $GITHUB_REF - publish: + publish: - name: Publish to PyPI - needs: [validate-release-tag] - runs-on: ubuntu-latest + name: Publish to PyPI + needs: [validate-release-tag] + runs-on: ubuntu-latest - steps: - - name: Checkout source - uses: actions/checkout@v2 + steps: + - name: Checkout source + uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: '3.9' - - name: Install flit - run: pip install flit~=3.4 + - name: Install flit + run: pip install flit~=3.4 - - name: Build and publish - run: flit publish - env: - FLIT_USERNAME: __token__ - FLIT_PASSWORD: ${{ secrets.pypi_token }} + - name: Build and publish + run: flit publish + env: + FLIT_USERNAME: __token__ + FLIT_PASSWORD: ${{ secrets.pypi_token }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 96a25f9..490a9c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,73 +4,73 @@ on: [push, pull_request] jobs: - pre-commit: + pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 + steps: + - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: '3.9' - - uses: pre-commit/action@v2.0.0 + - uses: pre-commit/action@v2.0.0 - tests: + tests: - runs-on: ubuntu-latest - timeout-minutes: 30 - strategy: - matrix: - python-version: ['3.9', '3.10', '3.11'] + runs-on: ubuntu-latest + timeout-minutes: 30 + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] - services: - postgres: - image: postgres:latest - env: - POSTGRES_DB: test_db - POSTGRES_PASSWORD: '' - POSTGRES_HOST_AUTH_METHOD: trust - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - rabbitmq: - image: rabbitmq:latest - ports: - - 5672:5672 + services: + postgres: + image: postgres:latest + env: + POSTGRES_DB: test_db + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: rabbitmq:latest + ports: + - 5672:5672 - steps: - - uses: actions/checkout@v2 + steps: + - uses: actions/checkout@v2 - - name: Set up Python 
${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} - - name: Install Python dependencies - run: | - pip install --upgrade pip - pip install -e .[testing,auth] + - name: Install Python dependencies + run: | + pip install --upgrade pip + pip install -e .[testing,auth] - - name: Run test suite - env: + - name: Run test suite + env: # show timings of tests - PYTEST_ADDOPTS: "--durations=0" - AIIDA_WARN_v3: true - run: pytest --cov aiida_restapi --cov-report=xml + PYTEST_ADDOPTS: --durations=0 + AIIDA_WARN_v3: true + run: pytest --cov aiida_restapi --cov-report=xml - - name: Upload to Codecov - if: matrix.python-version == 3.8 - uses: codecov/codecov-action@v4 - with: - name: pytests - flags: pytests - file: ./coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + - name: Upload to Codecov + if: matrix.python-version == 3.8 + uses: codecov/codecov-action@v4 + with: + name: pytests + flags: pytests + file: ./coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true diff --git a/.github/workflows/validate_release_tag.py b/.github/workflows/validate_release_tag.py index e333933..74f28f4 100644 --- a/.github/workflows/validate_release_tag.py +++ b/.github/workflows/validate_release_tag.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Validate that the version in the tag label matches the version of the package.""" + import argparse import ast from pathlib import Path @@ -13,7 +13,7 @@ def get_version_from_module(content: str) -> str: try: module = ast.parse(content) except SyntaxError as exception: - raise IOError("Unable to parse module.") from exception + raise IOError('Unable to parse module.') from exception try: return next( @@ -21,25 +21,19 @@ def get_version_from_module(content: str) -> str: for statement in module.body if isinstance(statement, ast.Assign) for target in statement.targets - if isinstance(target, ast.Name) and target.id == "__version__" + if isinstance(target, ast.Name) and target.id == '__version__' ) except StopIteration as exception: - raise IOError( - "Unable to find the `__version__` attribute in the module." 
- ) from exception + raise IOError('Unable to find the `__version__` attribute in the module.') from exception -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("GITHUB_REF", help="The GITHUB_REF environmental variable") + parser.add_argument('GITHUB_REF', help='The GITHUB_REF environmental variable') args = parser.parse_args() - tag_prefix = "refs/tags/v" - assert args.GITHUB_REF.startswith( - tag_prefix - ), f'GITHUB_REF should start with "{tag_prefix}": {args.GITHUB_REF}' + tag_prefix = 'refs/tags/v' + assert args.GITHUB_REF.startswith(tag_prefix), f'GITHUB_REF should start with "{tag_prefix}": {args.GITHUB_REF}' tag_version = args.GITHUB_REF.removeprefix(tag_prefix) - package_version = get_version_from_module( - Path("aiida_restapi/__init__.py").read_text(encoding="utf-8") - ) - error_message = f"The tag version `{tag_version}` is different from the package version `{package_version}`" + package_version = get_version_from_module(Path('aiida_restapi/__init__.py').read_text(encoding='utf-8')) + error_message = f'The tag version `{tag_version}` is different from the package version `{package_version}`' assert tag_version == package_version, error_message diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8662c91..97b86f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,73 +1,73 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 - hooks: - - id: check-merge-conflict - - id: check-yaml - - id: double-quote-string-fixer - - id: end-of-file-fixer - exclude: &exclude_pre_commit_hooks > - (?x)^( - tests/.*(? - (?x)^( - docs/.*| - examples/.*| - tests/.*| - conftest.py - )$ - - - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.28.6 - hooks: - - id: check-github-workflows - - - repo: https://github.com/ikamensh/flynt/ - rev: 1.0.1 - hooks: - - id: flynt - args: [--line-length=120, --fail-on-change] - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 - hooks: - - id: ruff-format - exclude: &exclude_ruff > - (?x)^( - docs/source/topics/processes/include/snippets/functions/parse_docstring_expose_ipython.py| - docs/source/topics/processes/include/snippets/functions/signature_plain_python_call_illegal.py| - )$ - - id: ruff - exclude: *exclude_ruff - args: [--fix, --exit-non-zero-on-fix, --show-fixes] - - - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.13.0 - hooks: - - id: pretty-format-toml - args: [--autofix] - - id: pretty-format-yaml - args: [--autofix] - exclude: >- - (?x)^( +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-merge-conflict + - id: check-yaml + - id: double-quote-string-fixer + - id: end-of-file-fixer + exclude: &exclude_pre_commit_hooks > + (?x)^( + tests/.*(? 
+ (?x)^( + docs/.*| + examples/.*| tests/.*| - environment.yml| - )$ + conftest.py + )$ + +- repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.28.6 + hooks: + - id: check-github-workflows + +- repo: https://github.com/ikamensh/flynt/ + rev: 1.0.1 + hooks: + - id: flynt + args: [--line-length=120, --fail-on-change] + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 + hooks: + - id: ruff-format + exclude: &exclude_ruff > + (?x)^( + docs/source/topics/processes/include/snippets/functions/parse_docstring_expose_ipython.py| + docs/source/topics/processes/include/snippets/functions/signature_plain_python_call_illegal.py| + )$ + - id: ruff + exclude: *exclude_ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] + +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.13.0 + hooks: + - id: pretty-format-toml + args: [--autofix] + - id: pretty-format-yaml + args: [--autofix] + exclude: >- + (?x)^( + tests/.*| + environment.yml| + )$ diff --git a/aiida_restapi/__init__.py b/aiida_restapi/__init__.py index f170106..d3618b6 100644 --- a/aiida_restapi/__init__.py +++ b/aiida_restapi/__init__.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- """AiiDA REST API for data queries and workflow managment.""" -__version__ = "0.1.0a1" +__version__ = '0.1.0a1' -from .main import app +from .main import app # noqa: F401 diff --git a/aiida_restapi/aiida_db_mappings.py b/aiida_restapi/aiida_db_mappings.py index 46eb6e7..b0457af 100644 --- a/aiida_restapi/aiida_db_mappings.py +++ b/aiida_restapi/aiida_db_mappings.py @@ -1,10 +1,10 @@ -# -*- coding: utf-8 -*- """The 'source of truth' for mapping AiiDA's database table models to pydantic models. Note in the future we may want to do this programmatically, however, there are two issues: - AiiDA uses both SQLAlchemy and Django backends, so one would need to be chosen - Neither model includes descriptions of fields """ + from datetime import datetime from typing import Dict, Optional, Type from uuid import UUID @@ -16,129 +16,119 @@ class AuthInfo(BaseModel): """AiiDA AuthInfo SQL table fields.""" - id: int = Field(description="Unique id (pk)") - aiidauser_id: int = Field(description="Relates to user") - dbcomputer_id: int = Field(description="Relates to computer") - metadata: Json = Field(description="Metadata of the authorisation") - auth_params: Json = Field(description="Parameters of the authorisation") - enabled: bool = Field(description="Whether the computer is enabled", default=True) + id: int = Field(description='Unique id (pk)') + aiidauser_id: int = Field(description='Relates to user') + dbcomputer_id: int = Field(description='Relates to computer') + metadata: Json = Field(description='Metadata of the authorisation') + auth_params: Json = Field(description='Parameters of the authorisation') + enabled: bool = Field(description='Whether the computer is enabled', default=True) class Comment(BaseModel): """AiiDA Comment SQL table fields.""" - id: int = Field(description="Unique id (pk)") - uuid: UUID = Field(description="Universally unique id") - ctime: datetime = Field(description="Creation time") - mtime: datetime = Field(description="Last modification time") - content: Optional[str] = Field(None, description="Content of the comment") - user_id: int = Field(description="Created by user id (pk)") - dbnode_id: int = Field(description="Associated node id (pk)") + id: int = Field(description='Unique id (pk)') + uuid: UUID = Field(description='Universally unique id') + ctime: datetime = 
Field(description='Creation time') + mtime: datetime = Field(description='Last modification time') + content: Optional[str] = Field(None, description='Content of the comment') + user_id: int = Field(description='Created by user id (pk)') + dbnode_id: int = Field(description='Associated node id (pk)') class Computer(BaseModel): """AiiDA Computer SQL table fields.""" - id: int = Field(description="Unique id (pk)") - uuid: UUID = Field(description="Universally unique id") - label: str = Field(description="Computer name") - hostname: str = Field(description="Identifier for the computer within the network") - description: Optional[str] = Field(None, description="Description of the computer") - scheduler_type: str = Field( - description="Scheduler plugin type, to manage compute jobs" - ) - transport_type: str = Field( - description="Transport plugin type, to manage file transfers" - ) - metadata: Json = Field(description="Metadata of the computer") + id: int = Field(description='Unique id (pk)') + uuid: UUID = Field(description='Universally unique id') + label: str = Field(description='Computer name') + hostname: str = Field(description='Identifier for the computer within the network') + description: Optional[str] = Field(None, description='Description of the computer') + scheduler_type: str = Field(description='Scheduler plugin type, to manage compute jobs') + transport_type: str = Field(description='Transport plugin type, to manage file transfers') + metadata: Json = Field(description='Metadata of the computer') class Group(BaseModel): """AiiDA Group SQL table fields.""" - id: int = Field(description="Unique id (pk)") - uuid: UUID = Field(description="Universally unique id") - label: str = Field(description="Label of group") - type_string: str = Field(description="type of the group") - time: datetime = Field(description="Created time") - description: Optional[str] = Field(None, description="Description of group") - extras: Json = Field(description="extra data about for the group") - user_id: int = Field(description="Created by user id (pk)") + id: int = Field(description='Unique id (pk)') + uuid: UUID = Field(description='Universally unique id') + label: str = Field(description='Label of group') + type_string: str = Field(description='type of the group') + time: datetime = Field(description='Created time') + description: Optional[str] = Field(None, description='Description of group') + extras: Json = Field(description='extra data about for the group') + user_id: int = Field(description='Created by user id (pk)') class Log(BaseModel): """AiiDA Log SQL table fields.""" - id: int = Field(description="Unique id (pk)") - uuid: UUID = Field(description="Universally unique id") - time: datetime = Field(description="Creation time") - loggername: str = Field(description="The loggers name") - levelname: str = Field(description="The log level") - message: Optional[str] = Field(None, description="The log message") - metadata: Json = Field(description="Metadata associated with the log") - dbnode_id: int = Field(description="Associated node id (pk)") + id: int = Field(description='Unique id (pk)') + uuid: UUID = Field(description='Universally unique id') + time: datetime = Field(description='Creation time') + loggername: str = Field(description='The loggers name') + levelname: str = Field(description='The log level') + message: Optional[str] = Field(None, description='The log message') + metadata: Json = Field(description='Metadata associated with the log') + dbnode_id: int = Field(description='Associated 
node id (pk)') class Node(BaseModel): """AiiDA Node SQL table fields.""" - id: int = Field(description="Unique id (pk)") - uuid: UUID = Field(description="Universally unique id") - node_type: str = Field(description="Node type") - process_type: str = Field(description="Process type") - label: str = Field(description="Label of node") - description: str = Field(description="Description of node") - ctime: datetime = Field(description="Creation time") - mtime: datetime = Field(description="Last modification time") - user_id: int = Field(description="Created by user id (pk)") - dbcomputer_id: Optional[int] = Field( - None, description="Associated computer id (pk)" - ) + id: int = Field(description='Unique id (pk)') + uuid: UUID = Field(description='Universally unique id') + node_type: str = Field(description='Node type') + process_type: str = Field(description='Process type') + label: str = Field(description='Label of node') + description: str = Field(description='Description of node') + ctime: datetime = Field(description='Creation time') + mtime: datetime = Field(description='Last modification time') + user_id: int = Field(description='Created by user id (pk)') + dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)') attributes: Json = Field( - description="Attributes of the node (immutable after storing the node)", + description='Attributes of the node (immutable after storing the node)', ) extras: Json = Field( - description="Extra attributes of the node (mutable)", + description='Extra attributes of the node (mutable)', ) class User(BaseModel): """AiiDA User SQL table fields.""" - id: int = Field(description="Unique id (pk)") - email: str = Field(description="Email address of the user") - first_name: Optional[str] = Field(None, description="First name of the user") - last_name: Optional[str] = Field(None, description="Last name of the user") - institution: Optional[str] = Field( - None, description="Host institution or workplace of the user" - ) + id: int = Field(description='Unique id (pk)') + email: str = Field(description='Email address of the user') + first_name: Optional[str] = Field(None, description='First name of the user') + last_name: Optional[str] = Field(None, description='Last name of the user') + institution: Optional[str] = Field(None, description='Host institution or workplace of the user') class Link(BaseModel): """AiiDA Link SQL table fields.""" - id: int = Field(description="Unique id (pk)") - input_id: int = Field(description="Unique id (pk) of the input node") - output_id: int = Field(description="Unique id (pk) of the output node") - label: Optional[str] = Field(None, description="The label of the link") - type: str = Field(description="The type of link") + id: int = Field(description='Unique id (pk)') + input_id: int = Field(description='Unique id (pk) of the input node') + output_id: int = Field(description='Unique id (pk) of the output node') + label: Optional[str] = Field(None, description='The label of the link') + type: str = Field(description='The type of link') ORM_MAPPING: Dict[str, Type[BaseModel]] = { - "AuthInfo": AuthInfo, - "Comment": Comment, - "Computer": Computer, - "Group": Group, - "Log": Log, - "Node": Node, - "User": User, - "Link": Link, + 'AuthInfo': AuthInfo, + 'Comment': Comment, + 'Computer': Computer, + 'Group': Group, + 'Log': Log, + 'Node': Node, + 'User': User, + 'Link': Link, } -def get_model_from_orm( - orm_cls: Type[orm.Entity], allow_subclasses: bool = True -) -> Type[BaseModel]: +def 
get_model_from_orm(orm_cls: Type[orm.Entity], allow_subclasses: bool = True) -> Type[BaseModel]: """Return the pydantic model related to the orm class. :param allow_subclasses: Return the base class mapping for subclasses @@ -149,4 +139,4 @@ def get_model_from_orm( return Node if allow_subclasses and issubclass(orm_cls, orm.Group): return Group - raise KeyError(f"{orm_cls}") + raise KeyError(f'{orm_cls}') diff --git a/aiida_restapi/config.py b/aiida_restapi/config.py index a1f16da..fead108 100644 --- a/aiida_restapi/config.py +++ b/aiida_restapi/config.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- """Configuration of API""" + # to get a string like this run: # openssl rand -hex 32 -SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" -ALGORITHM = "HS256" +SECRET_KEY = '09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7' +ALGORITHM = 'HS256' ACCESS_TOKEN_EXPIRE_MINUTES = 30 fake_users_db = { - "johndoe@example.com": { - "pk": 23, - "first_name": "John", - "last_name": "Doe", - "institution": "EPFL", - "email": "johndoe@example.com", - "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", - "disabled": False, + 'johndoe@example.com': { + 'pk': 23, + 'first_name': 'John', + 'last_name': 'Doe', + 'institution': 'EPFL', + 'email': 'johndoe@example.com', + 'hashed_password': '$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW', + 'disabled': False, } } diff --git a/aiida_restapi/filter_syntax.py b/aiida_restapi/filter_syntax.py index 6859ea0..ba7010f 100644 --- a/aiida_restapi/filter_syntax.py +++ b/aiida_restapi/filter_syntax.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The AiiDA QueryBuilder filter grammar resolver. Converts the string into a dict that can be passed to @@ -7,6 +6,7 @@ This grammar was originally adapted from: https://github.com/Materials-Consortia/OPTIMADE/blob/master/optimade.rst#the-filter-language-ebnf-grammar """ + # pylint: disable=too-many-branches from importlib import resources from typing import Any, Callable, Dict, List, Optional, Union @@ -16,19 +16,19 @@ from . 
import static from .utils import parse_date -FILTER_GRAMMAR = resources.open_text(static, "filter_grammar.lark") +FILTER_GRAMMAR = resources.open_text(static, 'filter_grammar.lark') -FILTER_PARSER = Lark(FILTER_GRAMMAR, start="filter") +FILTER_PARSER = Lark(FILTER_GRAMMAR, start='filter') _converters: Dict[str, Callable[[str], Any]] = { - "FLOAT": float, - "STRING": lambda s: s[1:-1], - "PROPERTY": str, - "INTEGER": int, - "DIGITS": int, - "DATE": parse_date, - "TIME": parse_date, - "DATETIME": parse_date, + 'FLOAT': float, + 'STRING': lambda s: s[1:-1], + 'PROPERTY': str, + 'INTEGER': int, + 'DIGITS': int, + 'DATE': parse_date, + 'TIME': parse_date, + 'DATETIME': parse_date, } @@ -42,7 +42,7 @@ def _parse_valuelist(valuelist: Tree) -> List[Union[int, float, str]]: output = [] for child in valuelist.children: try: - if child.data != "value": + if child.data != 'value': continue except AttributeError: continue @@ -58,11 +58,11 @@ def parse_filter_str(string: Optional[str]) -> Dict[str, Any]: try: tree = FILTER_PARSER.parse(string) except Exception as err: - raise ValueError(f"Malformed filter string: {err}") from err + raise ValueError(f'Malformed filter string: {err}') from err for child in tree.children: try: - if child.data != "comparison": + if child.data != 'comparison': continue except AttributeError: continue @@ -72,32 +72,32 @@ def parse_filter_str(string: Optional[str]) -> Dict[str, Any]: rhs_compare = rhs_tree.children[0] # parse the comparator value: Any - if rhs_compare.data == "value_op_rhs": + if rhs_compare.data == 'value_op_rhs': operator = rhs_compare.children[0].value.strip() value = _parse_value(rhs_compare.children[1].children[0]) - elif rhs_compare.data == "fuzzy_string_op_rhs": + elif rhs_compare.data == 'fuzzy_string_op_rhs': operator = rhs_compare.children[0].type.lower() value = _parse_value(rhs_compare.children[-1]) - elif rhs_compare.data == "length_op_rhs": - operator = "of_length" + elif rhs_compare.data == 'length_op_rhs': + operator = 'of_length' value = _parse_value(rhs_compare.children[-1]) - elif rhs_compare.data == "contains_op_rhs": - operator = "contains" + elif rhs_compare.data == 'contains_op_rhs': + operator = 'contains' value = _parse_valuelist(rhs_compare.children[-1]) - elif rhs_compare.data == "is_in_op_rhs": - operator = "in" + elif rhs_compare.data == 'is_in_op_rhs': + operator = 'in' value = _parse_valuelist(rhs_compare.children[-1]) - elif rhs_compare.data == "has_op_rhs": - operator = "has_key" + elif rhs_compare.data == 'has_op_rhs': + operator = 'has_key' value = _parse_value(rhs_compare.children[-1]) else: - raise ValueError(f"Unknown comparison: {rhs_compare.data}") + raise ValueError(f'Unknown comparison: {rhs_compare.data}') if prop_token.value in filters: - if "and" not in filters[prop_token.value]: + if 'and' not in filters[prop_token.value]: current = filters.pop(prop_token.value) - filters[prop_token.value] = {"and": [current]} - filters[prop_token.value]["and"].append({operator: value}) + filters[prop_token.value] = {'and': [current]} + filters[prop_token.value]['and'].append({operator: value}) else: filters[prop_token.value] = {operator: value} return filters diff --git a/aiida_restapi/graphql/basic.py b/aiida_restapi/graphql/basic.py index cff6d53..7219021 100644 --- a/aiida_restapi/graphql/basic.py +++ b/aiida_restapi/graphql/basic.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Defines plugins for basic information about aiida etc.""" + # pylint: disable=too-few-public-methods,unused-argument from typing import Any @@ -22,15 
+22,13 @@ def resolve_aiidaVersion(parent: Any, info: gr.ResolveInfo) -> str: rowLimitMaxPlugin = QueryPlugin( - "rowLimitMax", - gr.Int( - description="Maximum number of entity rows allowed to be returned from a query" - ), + 'rowLimitMax', + gr.Int(description='Maximum number of entity rows allowed to be returned from a query'), resolve_rowLimitMax, ) aiidaVersionPlugin = QueryPlugin( - "aiidaVersion", - gr.String(description="Version of aiida-core"), + 'aiidaVersion', + gr.String(description='Version of aiida-core'), resolve_aiidaVersion, ) diff --git a/aiida_restapi/graphql/comments.py b/aiida_restapi/graphql/comments.py index 7fad36f..49467ac 100644 --- a/aiida_restapi/graphql/comments.py +++ b/aiida_restapi/graphql/comments.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA comments.""" # pylint: disable=too-few-public-methods,redefined-builtin,unused-argument @@ -23,7 +22,7 @@ class CommentQuery(single_cls_factory(Comment)): # type: ignore[misc] """Query an AiiDA Comment""" -class CommentsQuery(multirow_cls_factory(CommentQuery, Comment, "comments")): # type: ignore[misc] +class CommentsQuery(multirow_cls_factory(CommentQuery, Comment, 'comments')): # type: ignore[misc] """Query all AiiDA Comments.""" @@ -37,28 +36,24 @@ def resolve_Comment( return resolve_entity(Comment, info, id, uuid) -def resolve_Comments( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None -) -> dict: +def resolve_Comments(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to CommentsQuery - return {"filters": parse_filter_str(filters)} + return {'filters': parse_filter_str(filters)} CommentQueryPlugin = QueryPlugin( - "comment", + 'comment', gr.Field( CommentQuery, id=gr.Int(), uuid=gr.String(), - description="Query for a single Comment", + description='Query for a single Comment', ), resolve_Comment, ) CommentsQueryPlugin = QueryPlugin( - "comments", - gr.Field( - CommentsQuery, description="Query for multiple Comments", filters=FilterString() - ), + 'comments', + gr.Field(CommentsQuery, description='Query for multiple Comments', filters=FilterString()), resolve_Comments, ) diff --git a/aiida_restapi/graphql/computers.py b/aiida_restapi/graphql/computers.py index 023c930..bc87718 100644 --- a/aiida_restapi/graphql/computers.py +++ b/aiida_restapi/graphql/computers.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA computers.""" # pylint: disable=too-few-public-methods,redefined-builtin,,unused-argument @@ -26,17 +25,15 @@ class ComputerQuery(single_cls_factory(Computer)): # type: ignore[misc] nodes = gr.Field(NodesQuery, filters=FilterString()) @staticmethod - def resolve_nodes( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_nodes(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter specification to NodesQuery parsed_filters = parse_filter_str(filters) - parsed_filters["dbcomputer_id"] = parent["id"] - return {"filters": parsed_filters} + parsed_filters['dbcomputer_id'] = parent['id'] + return {'filters': parsed_filters} -class ComputersQuery(multirow_cls_factory(ComputerQuery, Computer, "computers")): # type: ignore[misc] +class ComputersQuery(multirow_cls_factory(ComputerQuery, Computer, 'computers')): # type: ignore[misc] """Query all AiiDA Computers""" @@ -50,29 +47,27 @@ def resolve_Computer( return resolve_entity(Computer, info, id, uuid) -def resolve_Computers( - parent: 
Any, info: gr.ResolveInfo, filters: Optional[str] = None -) -> dict: +def resolve_Computers(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to ComputersQuery - return {"filters": parse_filter_str(filters)} + return {'filters': parse_filter_str(filters)} ComputerQueryPlugin = QueryPlugin( - "computer", + 'computer', gr.Field( ComputerQuery, id=gr.Int(), uuid=gr.String(), - description="Query for a single Computer", + description='Query for a single Computer', ), resolve_Computer, ) ComputersQueryPlugin = QueryPlugin( - "computers", + 'computers', gr.Field( ComputersQuery, - description="Query for multiple Computers", + description='Query for multiple Computers', filters=FilterString(), ), resolve_Computers, diff --git a/aiida_restapi/graphql/config.py b/aiida_restapi/graphql/config.py index b6f63c3..5fab8a3 100644 --- a/aiida_restapi/graphql/config.py +++ b/aiida_restapi/graphql/config.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Configuration for Graphql.""" ENTITY_LIMIT = 100 diff --git a/aiida_restapi/graphql/entry_points.py b/aiida_restapi/graphql/entry_points.py index b4e9ecb..d1150d5 100644 --- a/aiida_restapi/graphql/entry_points.py +++ b/aiida_restapi/graphql/entry_points.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Defines plugins for retrieving entry-point group and name lists.""" + # pylint: disable=too-few-public-methods,unused-argument from typing import Any, Dict, List @@ -24,23 +24,21 @@ def resolve_aiidaEntryPointGroups(parent: Any, info: gr.ResolveInfo) -> List[str return list(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()) -def resolve_aiidaEntryPoints( - parent: Any, info: gr.ResolveInfo, group: str -) -> Dict[str, Any]: +def resolve_aiidaEntryPoints(parent: Any, info: gr.ResolveInfo, group: str) -> Dict[str, Any]: """Resolution function.""" - return {"group": group, "names": get_entry_point_names(group)} + return {'group': group, 'names': get_entry_point_names(group)} aiidaEntryPointGroupsPlugin = QueryPlugin( - "aiidaEntryPointGroups", - gr.List(gr.String, description="List of the entrypoint group names"), + 'aiidaEntryPointGroups', + gr.List(gr.String, description='List of the entrypoint group names'), resolve_aiidaEntryPointGroups, ) aiidaEntryPointsPlugin = QueryPlugin( - "aiidaEntryPoints", + 'aiidaEntryPoints', gr.Field( EntryPoints, - description="List of the entrypoint names in a group", + description='List of the entrypoint names in a group', group=gr.String(required=True), ), resolve_aiidaEntryPoints, diff --git a/aiida_restapi/graphql/groups.py b/aiida_restapi/graphql/groups.py index e34a983..21f7b12 100644 --- a/aiida_restapi/graphql/groups.py +++ b/aiida_restapi/graphql/groups.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA groups.""" # pylint: disable=too-few-public-methods,redefined-builtin,,unused-argument @@ -29,10 +28,10 @@ class GroupQuery(single_cls_factory(Group)): # type: ignore[misc] def resolve_nodes(parent: Any, info: gr.ResolveInfo) -> dict: """Resolution function.""" # pass group specification to NodesQuery - return {"group_id": parent["id"]} + return {'group_id': parent['id']} -class GroupsQuery(multirow_cls_factory(GroupQuery, Group, "groups")): # type: ignore[misc] +class GroupsQuery(multirow_cls_factory(GroupQuery, Group, 'groups')): # type: ignore[misc] """Query all AiiDA Groups""" @@ -47,30 +46,28 @@ def resolve_Group( return resolve_entity(Group, info, id, uuid, label) -def resolve_Groups( - parent: Any, info: gr.ResolveInfo, filters: 
Optional[str] = None -) -> dict: +def resolve_Groups(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to GroupsQuery - return {"filters": parse_filter_str(filters)} + return {'filters': parse_filter_str(filters)} GroupQueryPlugin = QueryPlugin( - "group", + 'group', gr.Field( GroupQuery, id=gr.Int(), uuid=gr.String(), label=gr.String(), - description="Query for a single Group", + description='Query for a single Group', ), resolve_Group, ) GroupsQueryPlugin = QueryPlugin( - "groups", + 'groups', gr.Field( GroupsQuery, - description="Query for multiple Groups", + description='Query for multiple Groups', filters=FilterString(), ), resolve_Groups, diff --git a/aiida_restapi/graphql/logs.py b/aiida_restapi/graphql/logs.py index 00c15dc..eebd2e6 100644 --- a/aiida_restapi/graphql/logs.py +++ b/aiida_restapi/graphql/logs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA process node logs.""" # pylint: disable=too-few-public-methods,redefined-builtin,,unused-argument @@ -23,7 +22,7 @@ class LogQuery(single_cls_factory(Log)): # type: ignore[misc] """Query an AiiDA Log""" -class LogsQuery(multirow_cls_factory(LogQuery, Log, "logs")): # type: ignore[misc] +class LogsQuery(multirow_cls_factory(LogQuery, Log, 'logs')): # type: ignore[misc] """Query all AiiDA Logs.""" @@ -37,26 +36,22 @@ def resolve_Log( return resolve_entity(Log, info, id, uuid) -def resolve_Logs( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None -) -> dict: +def resolve_Logs(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to LogsQuery - return {"filters": parse_filter_str(filters)} + return {'filters': parse_filter_str(filters)} LogQueryPlugin = QueryPlugin( - "log", - gr.Field( - LogQuery, id=gr.Int(), uuid=gr.String(), description="Query for a single Log" - ), + 'log', + gr.Field(LogQuery, id=gr.Int(), uuid=gr.String(), description='Query for a single Log'), resolve_Log, ) LogsQueryPlugin = QueryPlugin( - "logs", + 'logs', gr.Field( LogsQuery, - description="Query for multiple Logs", + description='Query for multiple Logs', filters=FilterString(), ), resolve_Logs, diff --git a/aiida_restapi/graphql/main.py b/aiida_restapi/graphql/main.py index d135838..059b22d 100644 --- a/aiida_restapi/graphql/main.py +++ b/aiida_restapi/graphql/main.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Main module that generates the full Graphql App.""" + from starlette_graphene3 import GraphQLApp from .basic import aiidaVersionPlugin, rowLimitMaxPlugin diff --git a/aiida_restapi/graphql/nodes.py b/aiida_restapi/graphql/nodes.py index 818a35e..b0e0fae 100644 --- a/aiida_restapi/graphql/nodes.py +++ b/aiida_restapi/graphql/nodes.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA nodes.""" + # pylint: disable=redefined-builtin,too-few-public-methods,unused-argument from typing import Any, Dict, List, Optional @@ -20,7 +20,7 @@ ) from .utils import JSON, FilterString -Link = type("LinkObjectType", (gr.ObjectType,), fields_from_name("Link")) +Link = type('LinkObjectType', (gr.ObjectType,), fields_from_name('Link')) class LinkQuery(gr.ObjectType): @@ -28,151 +28,135 @@ class LinkQuery(gr.ObjectType): link = gr.Field(Link) # note: we must refer to this query using a string, to prevent circular dependencies - node = gr.Field("aiida_restapi.graphql.nodes.NodeQuery") + node = gr.Field('aiida_restapi.graphql.nodes.NodeQuery') -class 
LinksQuery(multirow_cls_factory(LinkQuery, orm.nodes.Node, "nodes")): # type: ignore[misc] +class LinksQuery(multirow_cls_factory(LinkQuery, orm.nodes.Node, 'nodes')): # type: ignore[misc] """Query all AiiDA Links.""" class NodeQuery( - single_cls_factory(orm.nodes.Node, exclude_fields=("attributes", "extras")) # type: ignore[misc] + single_cls_factory(orm.nodes.Node, exclude_fields=('attributes', 'extras')) # type: ignore[misc] ): """Query an AiiDA Node""" attributes = JSON( - description="Variable attributes of the node", + description='Variable attributes of the node', filter=gr.List( gr.String, - description="return an exact set of attributes keys (non-existent will return null)", + description='return an exact set of attributes keys (non-existent will return null)', ), ) extras = JSON( - description="Variable extras (unsealed) of the node", + description='Variable extras (unsealed) of the node', filter=gr.List( gr.String, - description="return an exact set of extras keys (non-existent will return null)", + description='return an exact set of extras keys (non-existent will return null)', ), ) # TODO it would be ideal if the attributes/extras were filtered via the SQL query @staticmethod - def resolve_attributes( - parent: Any, info: gr.ResolveInfo, filter: Optional[List[str]] = None - ) -> Dict[str, Any]: + def resolve_attributes(parent: Any, info: gr.ResolveInfo, filter: Optional[List[str]] = None) -> Dict[str, Any]: """Resolution function.""" - attributes = parent.get("attributes") + attributes = parent.get('attributes') if filter is None or attributes is None: return attributes return {key: attributes.get(key) for key in filter} @staticmethod - def resolve_extras( - parent: Any, info: gr.ResolveInfo, filter: Optional[List[str]] = None - ) -> Dict[str, Any]: + def resolve_extras(parent: Any, info: gr.ResolveInfo, filter: Optional[List[str]] = None) -> Dict[str, Any]: """Resolution function.""" - extras = parent.get("extras") + extras = parent.get('extras') if filter is None or extras is None: return extras return {key: extras.get(key) for key in filter} - comments = gr.Field(CommentsQuery, description="Comments attached to a node") + comments = gr.Field(CommentsQuery, description='Comments attached to a node') @staticmethod def resolve_comments(parent: Any, info: gr.ResolveInfo) -> dict: """Resolution function.""" # pass filter specification to CommentsQuery filters = {} - filters["dbnode_id"] = parent["id"] - return {"filters": filters} + filters['dbnode_id'] = parent['id'] + return {'filters': filters} - logs = gr.Field(LogsQuery, description="Logs attached to a process node") + logs = gr.Field(LogsQuery, description='Logs attached to a process node') @staticmethod def resolve_logs(parent: Any, info: gr.ResolveInfo) -> dict: """Resolution function.""" # pass filter specification to CommentsQuery filters = {} - filters["dbnode_id"] = parent["id"] - return {"filters": filters} + filters['dbnode_id'] = parent['id'] + return {'filters': filters} - incoming = gr.Field( - LinksQuery, description="Query for incoming nodes", filters=FilterString() - ) + incoming = gr.Field(LinksQuery, description='Query for incoming nodes', filters=FilterString()) @staticmethod - def resolve_incoming( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_incoming(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass edge specification to LinksQuery return { - "parent_id": parent["id"], + 'parent_id': 
parent['id'], # this node is outgoing relative to the incoming nodes - "edge_type": "outgoing", - "project_edge": True, - "filters": parse_filter_str(filters), + 'edge_type': 'outgoing', + 'project_edge': True, + 'filters': parse_filter_str(filters), } - outgoing = gr.Field( - LinksQuery, description="Query for outgoing nodes", filters=FilterString() - ) + outgoing = gr.Field(LinksQuery, description='Query for outgoing nodes', filters=FilterString()) @staticmethod - def resolve_outgoing( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_outgoing(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass edge specification to LinksQuery return { - "parent_id": parent["id"], + 'parent_id': parent['id'], # this node is incoming relative to the outgoing nodes - "edge_type": "incoming", - "project_edge": True, - "filters": parse_filter_str(filters), + 'edge_type': 'incoming', + 'project_edge': True, + 'filters': parse_filter_str(filters), } ancestors = gr.Field( - "aiida_restapi.graphql.nodes.NodesQuery", - description="Query for ancestor nodes", + 'aiida_restapi.graphql.nodes.NodesQuery', + description='Query for ancestor nodes', filters=FilterString(), ) @staticmethod - def resolve_ancestors( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_ancestors(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass edge specification to LinksQuery return { - "parent_id": parent["id"], + 'parent_id': parent['id'], # this node is a descendant relative to the ancestor nodes - "edge_type": "descendants", - "filters": parse_filter_str(filters), + 'edge_type': 'descendants', + 'filters': parse_filter_str(filters), } descendants = gr.Field( - "aiida_restapi.graphql.nodes.NodesQuery", - description="Query for descendant nodes", + 'aiida_restapi.graphql.nodes.NodesQuery', + description='Query for descendant nodes', filters=FilterString(), ) @staticmethod - def resolve_descendants( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_descendants(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass edge specification to LinksQuery return { - "parent_id": parent["id"], + 'parent_id': parent['id'], # this node is an ancestor relative to the descendant nodes - "edge_type": "ancestors", - "filters": parse_filter_str(filters), + 'edge_type': 'ancestors', + 'filters': parse_filter_str(filters), } -class NodesQuery(multirow_cls_factory(NodeQuery, orm.nodes.Node, "nodes")): # type: ignore[misc] +class NodesQuery(multirow_cls_factory(NodeQuery, orm.nodes.Node, 'nodes')): # type: ignore[misc] """Query all AiiDA Nodes""" @@ -186,25 +170,19 @@ def resolve_Node( return resolve_entity(orm.nodes.Node, info, id, uuid) -def resolve_Nodes( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None -) -> dict: +def resolve_Nodes(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to NodesQuery - return {"filters": parse_filter_str(filters)} + return {'filters': parse_filter_str(filters)} NodeQueryPlugin = QueryPlugin( - "node", - gr.Field( - NodeQuery, id=gr.Int(), uuid=gr.String(), description="Query for a single Node" - ), + 'node', + gr.Field(NodeQuery, id=gr.Int(), uuid=gr.String(), description='Query for a single Node'), resolve_Node, ) NodesQueryPlugin = QueryPlugin( - 
"nodes", - gr.Field( - NodesQuery, description="Query for multiple Nodes", filters=FilterString() - ), + 'nodes', + gr.Field(NodesQuery, description='Query for multiple Nodes', filters=FilterString()), resolve_Nodes, ) diff --git a/aiida_restapi/graphql/orm_factories.py b/aiida_restapi/graphql/orm_factories.py index eac6d1e..3a7d2f9 100644 --- a/aiida_restapi/graphql/orm_factories.py +++ b/aiida_restapi/graphql/orm_factories.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Classes and functions to auto-generate base ObjectTypes for aiida orm entities.""" + # pylint: disable=unused-argument,redefined-builtin import typing from datetime import datetime @@ -40,9 +40,7 @@ def get_pydantic_type_name(annotation: Any) -> Any: return annotation -def fields_from_orm( - cls: Type[orm.Entity], exclude_fields: Sequence[str] = () -) -> Dict[str, gr.Scalar]: +def fields_from_orm(cls: Type[orm.Entity], exclude_fields: Sequence[str] = ()) -> Dict[str, gr.Scalar]: """Extract the fields from an AIIDA ORM class and convert them to graphene objects.""" output = {} for name, field in get_model_from_orm(cls).model_fields.items(): @@ -54,9 +52,7 @@ def fields_from_orm( return output -def fields_from_name( - cls: str, exclude_fields: Sequence[str] = () -) -> Dict[str, gr.Scalar]: +def fields_from_name(cls: str, exclude_fields: Sequence[str] = ()) -> Dict[str, gr.Scalar]: """Extract the fields from an AIIDA ORM class name and convert them to graphene objects.""" output = {} for name, field in ORM_MAPPING[cls].model_fields.items(): @@ -73,9 +69,7 @@ def field_names_from_orm(cls: Type[orm.Entity]) -> Set[str]: return set(get_model_from_orm(cls).model_fields.keys()) -def get_projection( - db_fields: Set[str], info: gr.ResolveInfo, is_link: bool = False -) -> Union[List[str], str]: +def get_projection(db_fields: Set[str], info: gr.ResolveInfo, is_link: bool = False) -> Union[List[str], str]: """Traverse the child AST to work out what fields we should project. Any fields found that are not database fields, are assumed to be joins. @@ -85,77 +79,70 @@ def get_projection( """ if is_link: # TODO here we need to look deeper under the "node" field - return "**" + return '**' try: selected = set(selected_field_names_naive(info.field_nodes[0].selection_set)) fields = db_fields.intersection(selected) joins = db_fields.difference(selected) if joins: - fields.add("id") + fields.add('id') return list(fields) except NotImplementedError: - return "**" + return '**' -def single_cls_factory( - orm_cls: Type[orm.Entity], exclude_fields: Sequence[str] = () -) -> Type[gr.ObjectType]: +def single_cls_factory(orm_cls: Type[orm.Entity], exclude_fields: Sequence[str] = ()) -> Type[gr.ObjectType]: """Create a graphene class with standard fields/resolvers for querying a single AiiDA ORM entity.""" - return type( - "AiidaOrmObjectType", (gr.ObjectType,), fields_from_orm(orm_cls, exclude_fields) - ) + return type('AiidaOrmObjectType', (gr.ObjectType,), fields_from_orm(orm_cls, exclude_fields)) EntitiesParentType = Optional[Dict[str, Any]] -def create_query_path( - query: orm.QueryBuilder, parent: Dict[str, Any] -) -> Dict[str, Any]: +def create_query_path(query: orm.QueryBuilder, parent: Dict[str, Any]) -> Dict[str, Any]: """Append parent entities to the ``QueryBuilder`` path. 
:param parent: data from the parent resolver :returns: key-word arguments for the "leaf" path """ leaf_kwargs: Dict[str, Any] = {} - if "group_id" in parent: - query.append(orm.Group, filters={"id": parent["group_id"]}, tag="group") - leaf_kwargs["with_group"] = "group" - if "edge_type" in parent: + if 'group_id' in parent: + query.append(orm.Group, filters={'id': parent['group_id']}, tag='group') + leaf_kwargs['with_group'] = 'group' + if 'edge_type' in parent: query.append( orm.nodes.Node, - filters={"id": parent["parent_id"]}, - tag=parent["edge_type"], + filters={'id': parent['parent_id']}, + tag=parent['edge_type'], ) - leaf_kwargs[f'with_{parent["edge_type"]}'] = parent["edge_type"] - if parent.get("project_edge"): - leaf_kwargs["edge_tag"] = f'{parent["edge_type"]}_edge' - leaf_kwargs["edge_project"] = "**" + leaf_kwargs[f'with_{parent["edge_type"]}'] = parent['edge_type'] + if parent.get('project_edge'): + leaf_kwargs['edge_tag'] = f'{parent["edge_type"]}_edge' + leaf_kwargs['edge_project'] = '**' return leaf_kwargs -def multirow_cls_factory( - entity_cls: Type[gr.ObjectType], orm_cls: Type[orm.Entity], name: str -) -> Type[gr.ObjectType]: - """Create a graphene class with standard fields/resolvers for querying multiple rows of the same AiiDA ORM entity.""" +def multirow_cls_factory(entity_cls: Type[gr.ObjectType], orm_cls: Type[orm.Entity], name: str) -> Type[gr.ObjectType]: + """Create a graphene class with standard fields/resolvers for querying multiple rows of the same AiiDA ORM + entity.""" db_fields = field_names_from_orm(orm_cls) class AiidaOrmRowsType(gr.ObjectType): """A class for querying multiple rows of the same AiiDA ORM entity.""" - count = gr.Int(description=f"Total number of rows of {name}") + count = gr.Int(description=f'Total number of rows of {name}') rows = gr.List( entity_cls, limit=gr.Int( default_value=ENTITY_LIMIT, - description=f"Maximum number of rows to return (no more than {ENTITY_LIMIT})", + description=f'Maximum number of rows to return (no more than {ENTITY_LIMIT})', ), - offset=gr.Int(default_value=0, description="Skip the first n rows"), - orderBy=gr.String(description="Field to order rows by", default_value="id"), + offset=gr.Int(default_value=0, description='Skip the first n rows'), + orderBy=gr.String(description='Field to order rows by', default_value='id'), orderAsc=gr.Boolean( default_value=True, - description="Sort field in ascending order, else descending.", + description='Sort field in ascending order, else descending.', ), ) @@ -166,7 +153,7 @@ def resolve_count(parent: EntitiesParentType, info: gr.ResolveInfo) -> int: parent = parent or {} query = orm.QueryBuilder() leaf_kwargs = create_query_path(query, parent) - leaf_kwargs["filters"] = parent.get("filters", None) + leaf_kwargs['filters'] = parent.get('filters', None) query.append(orm_cls, **leaf_kwargs) return query.count() @@ -189,19 +176,15 @@ def resolve_rows( # pylint: disable=too-many-arguments :param orderAsc: Sort field in ascending order, else descending """ if limit > ENTITY_LIMIT: - raise GraphQLError( - f"{name} 'limit' must be no more than {ENTITY_LIMIT}" - ) + raise GraphQLError(f"{name} 'limit' must be no more than {ENTITY_LIMIT}") parent = parent or {} # setup the query query = orm.QueryBuilder() leaf_kwargs = create_query_path(query, parent) - leaf_kwargs["filters"] = parent.get("filters", None) - leaf_kwargs["project"] = get_projection( - db_fields, info, is_link=(parent.get("project_edge") is True) - ) - leaf_kwargs["tag"] = "fields" + leaf_kwargs['filters'] = 
parent.get('filters', None) + leaf_kwargs['project'] = get_projection(db_fields, info, is_link=(parent.get('project_edge') is True)) + leaf_kwargs['tag'] = 'fields' query.append(orm_cls, **leaf_kwargs) # setup returned rows configuration of the query @@ -209,15 +192,12 @@ def resolve_rows( # pylint: disable=too-many-arguments query.offset(offset) query.limit(limit) if orderBy: - query.order_by({"fields": {orderBy: "asc" if orderAsc else "desc"}}) + query.order_by({'fields': {orderBy: 'asc' if orderAsc else 'desc'}}) # run query - if parent.get("project_edge") is True: - return [ - {"node": d["fields"], "link": d[f'{parent["edge_type"]}_edge']} - for d in query.dict() - ] - return [d["fields"] for d in query.dict()] + if parent.get('project_edge') is True: + return [{'node': d['fields'], 'link': d[f'{parent["edge_type"]}_edge']} for d in query.dict()] + return [d['fields'] for d in query.dict()] return AiidaOrmRowsType @@ -232,7 +212,7 @@ def resolve_entity( id: Optional[int] = None, uuid: Optional[str] = None, label: Optional[str] = None, - uuid_name: str = "uuid", + uuid_name: str = 'uuid', ) -> ENTITY_DICT_TYPE: """Query for a single entity, and project only the fields requested. @@ -240,22 +220,18 @@ def resolve_entity( """ filters: Dict[str, Union[str, int]] if id is not None: - assert uuid is None, f"Only one of id or {uuid_name} can be specified" - filters = {"id": id} + assert uuid is None, f'Only one of id or {uuid_name} can be specified' + filters = {'id': id} elif uuid is not None: filters = {uuid_name: uuid} elif label is not None: - filters = {"label": label} + filters = {'label': label} else: - raise AssertionError(f"One of id, {uuid_name}, or label must be specified") + raise AssertionError(f'One of id, {uuid_name}, or label must be specified') db_fields = field_names_from_orm(orm_cls) project = get_projection(db_fields, info) - entities = ( - orm.QueryBuilder() - .append(orm_cls, tag="result", filters=filters, project=project) - .dict() - ) + entities = orm.QueryBuilder().append(orm_cls, tag='result', filters=filters, project=project).dict() if not entities: return None - return entities[0]["result"] + return entities[0]['result'] diff --git a/aiida_restapi/graphql/plugins.py b/aiida_restapi/graphql/plugins.py index ea8254d..080c662 100644 --- a/aiida_restapi/graphql/plugins.py +++ b/aiida_restapi/graphql/plugins.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Module defining the graphql plugin mechanism.""" + from typing import Any, Callable, Dict, NamedTuple, Sequence, Type, Union import graphene as gr @@ -17,31 +17,27 @@ class QueryPlugin(NamedTuple): resolver: ResolverType -def create_query( - queries: Sequence[QueryPlugin], docstring: str = "The root query" -) -> Type[gr.ObjectType]: +def create_query(queries: Sequence[QueryPlugin], docstring: str = 'The root query') -> Type[gr.ObjectType]: """Generate a query from a sequence of query plugins.""" # check that there are no duplicate names name_map: Dict[str, QueryPlugin] = {} # construct the dict of attributes/methods on the class attr_map: Dict[str, Union[gr.ObjectType, ResolverType]] = {} for query in queries: - if query.name.startswith("resolve_"): - raise ValueError("Plugin name cannot") + if query.name.startswith('resolve_'): + raise ValueError('Plugin name cannot') if query.name in name_map: - raise ValueError( - f"Duplicate plugin name '{query.name}': {query} and {name_map[query.name]}" - ) + raise ValueError(f"Duplicate plugin name '{query.name}': {query} and {name_map[query.name]}") name_map[query.name] = query 
attr_map[query.name] = query.field - attr_map[f"resolve_{query.name}"] = query.resolver - attr_map["__doc__"] = docstring - return type("RootQuery", (gr.ObjectType,), attr_map) + attr_map[f'resolve_{query.name}'] = query.resolver + attr_map['__doc__'] = docstring + return type('RootQuery', (gr.ObjectType,), attr_map) def create_schema( queries: Sequence[QueryPlugin], - docstring: str = "The root query", + docstring: str = 'The root query', auto_camelcase: bool = False, **kwargs: Any, ) -> gr.Schema: @@ -49,6 +45,4 @@ def create_schema( Note we set auto_camelcase False, since this keeps database field names the same. """ - return gr.Schema( - query=create_query(queries, docstring), auto_camelcase=auto_camelcase, **kwargs - ) + return gr.Schema(query=create_query(queries, docstring), auto_camelcase=auto_camelcase, **kwargs) diff --git a/aiida_restapi/graphql/sphinx_ext.py b/aiida_restapi/graphql/sphinx_ext.py index dfa2a89..329b1ea 100644 --- a/aiida_restapi/graphql/sphinx_ext.py +++ b/aiida_restapi/graphql/sphinx_ext.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Sphinx extension for documenting the GraphQL schema.""" + # pylint: disable=import-outside-toplevel from typing import TYPE_CHECKING, List @@ -8,11 +8,10 @@ from .main import SCHEMA if TYPE_CHECKING: - from docutils.nodes import literal_block from sphinx.application import Sphinx -def setup(app: "Sphinx") -> None: +def setup(app: 'Sphinx') -> None: """Setup the sphinx extension.""" from docutils.nodes import Element, literal_block from sphinx.util.docutils import SphinxDirective @@ -29,4 +28,4 @@ def run(self) -> List[Element]: self.set_source_info(code_node) return [code_node] - app.add_directive("aiida-graphql-schema", SchemaDirective) + app.add_directive('aiida-graphql-schema', SchemaDirective) diff --git a/aiida_restapi/graphql/users.py b/aiida_restapi/graphql/users.py index b14a3bb..cc0a807 100644 --- a/aiida_restapi/graphql/users.py +++ b/aiida_restapi/graphql/users.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Defines plugins for AiiDA users.""" + # pylint: disable=too-few-public-methods,redefined-builtin,unused-argument from typing import Any, Optional @@ -25,14 +25,12 @@ class UserQuery(single_cls_factory(User)): # type: ignore[misc] nodes = gr.Field(NodesQuery, filters=FilterString()) @staticmethod - def resolve_nodes( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None - ) -> dict: + def resolve_nodes(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter specification to NodesQuery parsed_filters = parse_filter_str(filters) - parsed_filters["user_id"] = parent["id"] - return {"filters": parsed_filters} + parsed_filters['user_id'] = parent['id'] + return {'filters': parsed_filters} def resolve_User( @@ -42,32 +40,26 @@ def resolve_User( email: Optional[str] = None, ) -> ENTITY_DICT_TYPE: """Resolution function.""" - return resolve_entity(User, info, id, email, uuid_name="email") + return resolve_entity(User, info, id, email, uuid_name='email') -class UsersQuery(multirow_cls_factory(UserQuery, User, "users")): # type: ignore[misc] +class UsersQuery(multirow_cls_factory(UserQuery, User, 'users')): # type: ignore[misc] """Query all AiiDA Users""" -def resolve_Users( - parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None -) -> dict: +def resolve_Users(parent: Any, info: gr.ResolveInfo, filters: Optional[str] = None) -> dict: """Resolution function.""" # pass filter to UsersQuery - return {"filters": parse_filter_str(filters)} + return 
{'filters': parse_filter_str(filters)} UserQueryPlugin = QueryPlugin( - "user", - gr.Field( - UserQuery, id=gr.Int(), email=gr.String(), description="Query for a single User" - ), + 'user', + gr.Field(UserQuery, id=gr.Int(), email=gr.String(), description='Query for a single User'), resolve_User, ) UsersQueryPlugin = QueryPlugin( - "users", - gr.Field( - UsersQuery, description="Query for multiple Users", filters=FilterString() - ), + 'users', + gr.Field(UsersQuery, description='Query for multiple Users', filters=FilterString()), resolve_Users, ) diff --git a/aiida_restapi/graphql/utils.py b/aiida_restapi/graphql/utils.py index 97ad042..6559926 100644 --- a/aiida_restapi/graphql/utils.py +++ b/aiida_restapi/graphql/utils.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """Utility functions for graphql.""" # pylint: disable=unused-argument,too-many-arguments - -from typing import Any, Iterator +from typing import Iterator import graphene as gr from graphene.types.scalars import Scalar @@ -35,8 +33,6 @@ def selected_field_names_naive(selection_set: ast.SelectionSetNode) -> Iterator[ yield node.name.value # Fragment spread (`... fragmentName`) elif isinstance(node, (ast.FragmentSpreadNode, ast.InlineFragmentNode)): - raise NotImplementedError( - "Fragments are not supported by this simplistic function" - ) + raise NotImplementedError('Fragments are not supported by this simplistic function') else: raise NotImplementedError(str(type(node))) diff --git a/aiida_restapi/main.py b/aiida_restapi/main.py index bbbc180..5e75b8c 100644 --- a/aiida_restapi/main.py +++ b/aiida_restapi/main.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI application.""" + from fastapi import FastAPI from aiida_restapi.graphql import main @@ -13,4 +13,4 @@ app.include_router(groups.router) app.include_router(users.router) app.include_router(process.router) -app.add_route("/graphql", main.app, name="graphql") +app.add_route('/graphql', main.app, name='graphql') diff --git a/aiida_restapi/models.py b/aiida_restapi/models.py index a5c3222..9c00e62 100644 --- a/aiida_restapi/models.py +++ b/aiida_restapi/models.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Schemas for AiiDA REST API. Models in this module mirror those in @@ -17,7 +16,7 @@ from pydantic import BaseModel, ConfigDict, Field # Template type for subclasses of `AiidaModel` -ModelType = TypeVar("ModelType", bound="AiidaModel") +ModelType = TypeVar('ModelType', bound='AiidaModel') def as_form(cls: Type[BaseModel]) -> Type[BaseModel]: @@ -34,9 +33,7 @@ def as_form(cls: Type[BaseModel]) -> Type[BaseModel]: inspect.Parameter( name=field_name, kind=inspect.Parameter.POSITIONAL_ONLY, - default=Form(...) - if model_field.is_required() - else Form(model_field.default), + default=Form(...) 
if model_field.is_required() else Form(model_field.default),
 annotation=model_field.annotation,
 )
 )
@@ -47,7 +44,7 @@
 async def as_form_func(**data: Dict[str, Any]) -> Any:
 sig = inspect.signature(as_form_func)
 sig = sig.replace(parameters=new_parameters)
 as_form_func.__signature__ = sig # type: ignore
- setattr(cls, "as_form", as_form_func)
+ setattr(cls, 'as_form', as_form_func)
 return cls

@@ -55,12 +52,12 @@ class AiidaModel(BaseModel):
 """A mapping of an AiiDA entity to a pydantic model."""

 _orm_entity: ClassVar[Type[orm.entities.Entity]] = orm.entities.Entity
- model_config = ConfigDict(from_attributes=True, extra="forbid")
+ model_config = ConfigDict(from_attributes=True, extra='forbid')

 @classmethod
 def get_projectable_properties(cls) -> List[str]:
 """Return projectable properties."""
- return list(cls.schema()["properties"].keys())
+ return list(cls.schema()['properties'].keys())

 @classmethod
 def get_entities(
@@ -82,19 +79,17 @@
 else:
 assert set(cls.get_projectable_properties()).issuperset(
 project
- ), f"projection not subset of projectable properties: {project!r}"
- query = orm.QueryBuilder().append(
- cls._orm_entity, tag="fields", project=project
- )
+ ), f'projection not subset of projectable properties: {project!r}'
+ query = orm.QueryBuilder().append(cls._orm_entity, tag='fields', project=project)
 if page_size is not None:
 query.offset(page_size * (page - 1))
 query.limit(page_size)
 if order_by is not None:
 assert set(cls.get_projectable_properties()).issuperset(
 order_by
- ), f"order_by not subset of projectable properties: {project!r}"
- query.order_by({"fields": order_by})
- return [cls(**result["fields"]) for result in query.dict()]
+ ), f'order_by not subset of projectable properties: {order_by!r}'
+ query.order_by({'fields': order_by})
+ return [cls(**result['fields']) for result in query.dict()]

 class Comment(AiidaModel):
@@ -102,28 +97,26 @@

 _orm_entity = orm.Comment

- id: Optional[int] = Field(None, description="Unique comment id (pk)")
- uuid: str = Field(description="Unique comment uuid")
- ctime: Optional[datetime] = Field(None, description="Creation time")
- mtime: Optional[datetime] = Field(None, description="Last modification time")
- content: Optional[str] = Field(None, description="Comment content")
- dbnode_id: Optional[int] = Field(None, description="Unique node id (pk)")
- user_id: Optional[int] = Field(None, description="Unique user id (pk)")
+ id: Optional[int] = Field(None, description='Unique comment id (pk)')
+ uuid: str = Field(description='Unique comment uuid')
+ ctime: Optional[datetime] = Field(None, description='Creation time')
+ mtime: Optional[datetime] = Field(None, description='Last modification time')
+ content: Optional[str] = Field(None, description='Comment content')
+ dbnode_id: Optional[int] = Field(None, description='Unique node id (pk)')
+ user_id: Optional[int] = Field(None, description='Unique user id (pk)')

 class User(AiidaModel):
 """AiiDA User model."""

 _orm_entity = orm.User
- model_config = ConfigDict(extra="allow")
-
- id: Optional[int] = Field(None, description="Unique user id (pk)")
- email: str = Field(description="Email address of the user")
- first_name: Optional[str] = Field(None, description="First name of the user")
- last_name: Optional[str] = Field(None, description="Last name of the user")
- institution: Optional[str] = Field(
- None, description="Host institution or workplace of the user"
- )
+ model_config = ConfigDict(extra='allow')
+
+ id: Optional[int] = Field(None,
description='Unique user id (pk)')
+ email: str = Field(description='Email address of the user')
+ first_name: Optional[str] = Field(None, description='First name of the user')
+ last_name: Optional[str] = Field(None, description='Last name of the user')
+ institution: Optional[str] = Field(None, description='Host institution or workplace of the user')

 class Computer(AiidaModel):
@@ -131,27 +124,25 @@

 _orm_entity = orm.Computer

- id: Optional[int] = Field(None, description="Unique computer id (pk)")
- uuid: Optional[str] = Field(None, description="Unique id for computer")
- label: str = Field(description="Used to identify a computer. Must be unique")
- hostname: Optional[str] = Field(
- None, description="Label that identifies the computer within the network"
- )
+ id: Optional[int] = Field(None, description='Unique computer id (pk)')
+ uuid: Optional[str] = Field(None, description='Unique id for computer')
+ label: str = Field(description='Used to identify a computer. Must be unique')
+ hostname: Optional[str] = Field(None, description='Label that identifies the computer within the network')
 scheduler_type: Optional[str] = Field(
 None,
- description="The scheduler (and plugin) that the computer uses to manage jobs",
+ description='The scheduler (and plugin) that the computer uses to manage jobs',
 )
 transport_type: Optional[str] = Field(
 None,
- description="The transport (and plugin) \
- required to copy files and communicate to and from the computer",
+ description='The transport (and plugin) \
+ required to copy files and communicate to and from the computer',
 )
 metadata: Optional[dict] = Field(
 None,
- description="General settings for these communication and management protocols",
+ description='General settings for these communication and management protocols',
 )
- description: Optional[str] = Field(None, description="Description of node")
+ description: Optional[str] = Field(None, description='Description of the computer')

 class Node(AiidaModel):
@@ -159,29 +150,27 @@

 _orm_entity = orm.Node

- id: Optional[int] = Field(None, description="Unique id (pk)")
- uuid: Optional[UUID] = Field(None, description="Unique uuid")
- node_type: Optional[str] = Field(None, description="Node type")
- process_type: Optional[str] = Field(None, description="Process type")
- label: str = Field(description="Label of node")
- description: Optional[str] = Field(None, description="Description of node")
- ctime: Optional[datetime] = Field(None, description="Creation time")
- mtime: Optional[datetime] = Field(None, description="Last modification time")
- user_id: Optional[int] = Field(None, description="Created by user id (pk)")
- dbcomputer_id: Optional[int] = Field(
- None, description="Associated computer id (pk)"
- )
+ id: Optional[int] = Field(None, description='Unique id (pk)')
+ uuid: Optional[UUID] = Field(None, description='Unique uuid')
+ node_type: Optional[str] = Field(None, description='Node type')
+ process_type: Optional[str] = Field(None, description='Process type')
+ label: str = Field(description='Label of node')
+ description: Optional[str] = Field(None, description='Description of node')
+ ctime: Optional[datetime] = Field(None, description='Creation time')
+ mtime: Optional[datetime] = Field(None, description='Last modification time')
+ user_id: Optional[int] = Field(None, description='Created by user id (pk)')
+ dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)')
 attributes: Optional[Dict] = Field(
 None,
-
description="Variable attributes of the node", + description='Variable attributes of the node', ) extras: Optional[Dict] = Field( None, - description="Variable extras (unsealed) of the node", + description='Variable extras (unsealed) of the node', ) repository_metadata: Optional[Dict] = Field( None, - description="Metadata about file repository associated with this node", + description='Metadata about file repository associated with this node', ) @@ -189,21 +178,19 @@ class Node(AiidaModel): class Node_Post(AiidaModel): """AiiDA model for posting Nodes.""" - entry_point: str = Field(description="Entry_point") - process_type: Optional[str] = Field(None, description="Process type") - label: Optional[str] = Field(None, description="Label of node") - description: Optional[str] = Field(None, description="Description of node") - user_id: Optional[int] = Field(None, description="Created by user id (pk)") - dbcomputer_id: Optional[int] = Field( - None, description="Associated computer id (pk)" - ) + entry_point: str = Field(description='Entry_point') + process_type: Optional[str] = Field(None, description='Process type') + label: Optional[str] = Field(None, description='Label of node') + description: Optional[str] = Field(None, description='Description of node') + user_id: Optional[int] = Field(None, description='Created by user id (pk)') + dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)') attributes: Optional[Dict] = Field( None, - description="Variable attributes of the node", + description='Variable attributes of the node', ) extras: Optional[Dict] = Field( None, - description="Variable extras (unsealed) of the node", + description='Variable extras (unsealed) of the node', ) @classmethod @@ -213,13 +200,13 @@ def create_new_node( node_dict: dict, ) -> orm.Node: """Create and Store new Node""" - attributes = node_dict.pop("attributes", {}) - extras = node_dict.pop("extras", {}) - repository_metadata = node_dict.pop("repository_metadata", {}) + attributes = node_dict.pop('attributes', {}) + extras = node_dict.pop('extras', {}) + repository_metadata = node_dict.pop('repository_metadata', {}) if issubclass(orm_class, orm.BaseType): orm_object = orm_class( - attributes["value"], + attributes['value'], **node_dict, ) elif issubclass(orm_class, orm.Dict): @@ -229,17 +216,17 @@ def create_new_node( ) elif issubclass(orm_class, orm.InstalledCode): orm_object = orm_class( - computer=orm.load_computer(pk=node_dict.get("dbcomputer_id")), - filepath_executable=attributes["filepath_executable"], + computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')), + filepath_executable=attributes['filepath_executable'], ) - orm_object.label = node_dict.get("label") + orm_object.label = node_dict.get('label') elif issubclass(orm_class, orm.PortableCode): orm_object = orm_class( - computer=orm.load_computer(pk=node_dict.get("dbcomputer_id")), - filepath_executable=attributes["filepath_executable"], - filepath_files=attributes["filepath_files"], + computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')), + filepath_executable=attributes['filepath_executable'], + filepath_files=attributes['filepath_files'], ) - orm_object.label = node_dict.get("label") + orm_object.label = node_dict.get('label') else: orm_object = orm_class(**node_dict) orm_object.base.attributes.set_many(attributes) @@ -257,8 +244,8 @@ def create_new_node_with_file( file: Path, ) -> orm.Node: """Create and Store new Node with file""" - attributes = node_dict.pop("attributes", {}) - extras = 
node_dict.pop("extras", {})
+ attributes = node_dict.pop('attributes', {})
+ extras = node_dict.pop('extras', {})

 orm_object = orm_class(file=file, **node_dict, **attributes)

@@ -272,14 +259,14 @@ class Group(AiidaModel):

 _orm_entity = orm.Group

- id: int = Field(description="Unique id (pk)")
- uuid: UUID = Field(description="Universally unique id")
- label: str = Field(description="Label of group")
- type_string: str = Field(description="type of the group")
- description: Optional[str] = Field(None, description="Description of group")
- extras: Optional[Dict] = Field(None, description="extra data about for the group")
- time: datetime = Field(description="Created time")
- user_id: int = Field(description="Created by user id (pk)")
+ id: int = Field(description='Unique id (pk)')
+ uuid: UUID = Field(description='Universally unique id')
+ label: str = Field(description='Label of group')
+ type_string: str = Field(description='Type of the group')
+ description: Optional[str] = Field(None, description='Description of group')
+ extras: Optional[Dict] = Field(None, description='Extra data for the group')
+ time: datetime = Field(description='Created time')
+ user_id: int = Field(description='Created by user id (pk)')

 @classmethod
 def from_orm(cls, orm_entity: orm.Group) -> orm.Group:
@@ -295,14 +282,14 @@
 orm.QueryBuilder()
 .append(
 cls._orm_entity,
- filters={"pk": orm_entity.id},
- tag="fields",
- project=["user_id", "time"],
+ filters={'pk': orm_entity.id},
+ tag='fields',
+ project=['user_id', 'time'],
 )
 .limit(1)
 )
- orm_entity.user_id = query.dict()[0]["fields"]["user_id"]
- orm_entity.time = query.dict()[0]["fields"]["time"]
+ orm_entity.user_id = query.dict()[0]['fields']['user_id']
+ orm_entity.time = query.dict()[0]['fields']['time']

 return super().from_orm(orm_entity)

@@ -312,11 +299,9 @@ class Group_Post(AiidaModel):

 _orm_entity = orm.Group

- label: str = Field(description="Used to access the group. Must be unique.")
- type_string: Optional[str] = Field(None, description="Type of the group")
- description: Optional[str] = Field(
- None, description="Short description of the group."
- )
+ label: str = Field(description='Used to access the group.
Must be unique.')
+ type_string: Optional[str] = Field(None, description='Type of the group')
+ description: Optional[str] = Field(None, description='Short description of the group.')

 class Process(AiidaModel):
@@ -324,35 +309,33 @@

 _orm_entity = orm.ProcessNode

- id: Optional[int] = Field(None, description="Unique id (pk)")
- uuid: Optional[UUID] = Field(None, description="Universally unique identifier")
- node_type: Optional[str] = Field(None, description="Node type")
- process_type: Optional[str] = Field(None, description="Process type")
- label: str = Field(description="Label of node")
- description: Optional[str] = Field(None, description="Description of node")
- ctime: Optional[datetime] = Field(None, description="Creation time")
- mtime: Optional[datetime] = Field(None, description="Last modification time")
- user_id: Optional[int] = Field(None, description="Created by user id (pk)")
- dbcomputer_id: Optional[int] = Field(
- None, description="Associated computer id (pk)"
- )
+ id: Optional[int] = Field(None, description='Unique id (pk)')
+ uuid: Optional[UUID] = Field(None, description='Universally unique identifier')
+ node_type: Optional[str] = Field(None, description='Node type')
+ process_type: Optional[str] = Field(None, description='Process type')
+ label: str = Field(description='Label of node')
+ description: Optional[str] = Field(None, description='Description of node')
+ ctime: Optional[datetime] = Field(None, description='Creation time')
+ mtime: Optional[datetime] = Field(None, description='Last modification time')
+ user_id: Optional[int] = Field(None, description='Created by user id (pk)')
+ dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)')
 attributes: Optional[Dict] = Field(
 None,
- description="Variable attributes of the node",
+ description='Variable attributes of the node',
 )
 extras: Optional[Dict] = Field(
 None,
- description="Variable extras (unsealed) of the node",
+ description='Variable extras (unsealed) of the node',
 )
 repository_metadata: Optional[Dict] = Field(
 None,
- description="Metadata about file repository associated with this node",
+ description='Metadata about file repository associated with this node',
 )

 class Process_Post(AiidaModel):
 """AiiDA Process Post Model"""

- label: str = Field(description="Label of node")
- inputs: dict = Field(description="Input parmeters")
- process_entry_point: str = Field(description="Entry Point for process")
+ label: str = Field(description='Label of node')
+ inputs: dict = Field(description='Input parameters')
+ process_entry_point: str = Field(description='Entry point for the process')
diff --git a/aiida_restapi/routers/auth.py b/aiida_restapi/routers/auth.py
index 01339ce..9daff7e 100644
--- a/aiida_restapi/routers/auth.py
+++ b/aiida_restapi/routers/auth.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Handle API authentication and authorization."""
+
 # pylint: disable=missing-function-docstring,missing-class-docstring
 from datetime import datetime, timedelta
 from typing import Any, Dict, Optional
@@ -28,20 +28,18 @@ class UserInDB(User):
 disabled: Optional[bool] = None

-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')

-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')

 router = APIRouter()

 def verify_password(plain_password: str, hashed_password: str) -> bool:
- return pwd_context.verify(plain_password,
hashed_password) def get_password_hash(password: str) -> str: - return pwd_context.hash(password) @@ -53,15 +51,12 @@ def get_user(db: dict, email: str) -> Optional[UserInDB]: def authenticate_user(fake_db: dict, email: str, password: str) -> Optional[UserInDB]: - user = get_user(fake_db, email) if not user: - return None if not verify_password(password, user.hashed_password): - return None return user @@ -73,7 +68,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) + to_encode.update({'exp': expire}) encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM) return encoded_jwt @@ -81,12 +76,12 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) - async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, + detail='Could not validate credentials', + headers={'WWW-Authenticate': 'Bearer'}, ) try: payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM]) - email: str = payload.get("sub") + email: str = payload.get('sub') if email is None: raise credentials_exception token_data = TokenData(email=email) @@ -102,31 +97,27 @@ async def get_current_active_user( current_user: UserInDB = Depends(get_current_user), ) -> UserInDB: if current_user.disabled: - raise HTTPException(status_code=400, detail="Inactive user") + raise HTTPException(status_code=400, detail='Inactive user') return current_user -@router.post("/token", response_model=Token) +@router.post('/token', response_model=Token) async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), ) -> Dict[str, Any]: - user = authenticate_user( - config.fake_users_db, form_data.username, form_data.password - ) + user = authenticate_user(config.fake_users_db, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password", - headers={"WWW-Authenticate": "Bearer"}, + detail='Incorrect email or password', + headers={'WWW-Authenticate': 'Bearer'}, ) access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user.email}, expires_delta=access_token_expires - ) - return {"access_token": access_token, "token_type": "bearer"} + access_token = create_access_token(data={'sub': user.email}, expires_delta=access_token_expires) + return {'access_token': access_token, 'token_type': 'bearer'} -@router.get("/auth/me/", response_model=User) +@router.get('/auth/me/', response_model=User) async def read_users_me(current_user: User = Depends(get_current_active_user)) -> User: return current_user diff --git a/aiida_restapi/routers/computers.py b/aiida_restapi/routers/computers.py index ba3f147..b97d466 100644 --- a/aiida_restapi/routers/computers.py +++ b/aiida_restapi/routers/computers.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI application.""" + from typing import List, Optional from aiida import orm @@ -14,7 +14,7 @@ router = APIRouter() -@router.get("/computers", response_model=List[Computer]) +@router.get('/computers', response_model=List[Computer]) @with_dbenv() async def read_computers() -> List[Computer]: """Get list of all computers""" 
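(A hedged aside on the authentication flow defined in auth.py above: the sketch below shows how a client is expected to use the ``/token`` endpoint and bearer header. The base URL and demo credentials are the ones used by the example scripts later in this patch; the snippet itself is illustrative and not part of the diff.)

import requests

BASE_URL = 'http://127.0.0.1:8000'

# POST /token takes form-encoded credentials (OAuth2PasswordRequestForm) and returns a bearer token
token = requests.post(
    f'{BASE_URL}/token',
    data={'username': 'johndoe@example.com', 'password': 'secret'},
).json()['access_token']

# send the token as an Authorization header on endpoints that require an authenticated user
headers = {'Authorization': f'Bearer {token}'}
print(requests.get(f'{BASE_URL}/computers', headers=headers).json())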
@@ -22,26 +22,24 @@ async def read_computers() -> List[Computer]: return Computer.get_entities() -@router.get("/computers/projectable_properties", response_model=List[str]) +@router.get('/computers/projectable_properties', response_model=List[str]) async def get_computers_projectable_properties() -> List[str]: """Get projectable properties for computers endpoint""" return Computer.get_projectable_properties() -@router.get("/computers/{comp_id}", response_model=Computer) +@router.get('/computers/{comp_id}', response_model=Computer) @with_dbenv() async def read_computer(comp_id: int) -> Optional[Computer]: """Get computer by id.""" qbobj = QueryBuilder() - qbobj.append( - orm.Computer, filters={"id": comp_id}, project="**", tag="computer" - ).limit(1) + qbobj.append(orm.Computer, filters={'id': comp_id}, project='**', tag='computer').limit(1) - return qbobj.dict()[0]["computer"] + return qbobj.dict()[0]['computer'] -@router.post("/computers", response_model=Computer) +@router.post('/computers', response_model=Computer) @with_dbenv() async def create_computer( computer: Computer, diff --git a/aiida_restapi/routers/daemon.py b/aiida_restapi/routers/daemon.py index f10cba6..8f2f425 100644 --- a/aiida_restapi/routers/daemon.py +++ b/aiida_restapi/routers/daemon.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI router for daemon endpoints.""" + from __future__ import annotations import typing as t @@ -18,13 +18,11 @@ class DaemonStatusModel(BaseModel): """Response model for daemon status.""" - running: bool = Field(description="Whether the daemon is running or not.") - num_workers: t.Optional[int] = Field( - description="The number of workers if the daemon is running." - ) + running: bool = Field(description='Whether the daemon is running or not.') + num_workers: t.Optional[int] = Field(description='The number of workers if the daemon is running.') -@router.get("/daemon/status", response_model=DaemonStatusModel) +@router.get('/daemon/status', response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_status() -> DaemonStatusModel: """Return the daemon status.""" @@ -35,10 +33,10 @@ async def get_daemon_status() -> DaemonStatusModel: response = client.get_numprocesses() - return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) + return DaemonStatusModel(running=True, num_workers=response['numprocesses']) -@router.post("/daemon/start", response_model=DaemonStatusModel) +@router.post('/daemon/start', response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_start( current_user: User = Depends( # pylint: disable=unused-argument @@ -49,7 +47,7 @@ async def get_daemon_start( client = get_daemon_client() if client.is_daemon_running: - raise HTTPException(status_code=400, detail="The daemon is already running.") + raise HTTPException(status_code=400, detail='The daemon is already running.') try: client.start_daemon() @@ -58,10 +56,10 @@ async def get_daemon_start( response = client.get_numprocesses() - return DaemonStatusModel(running=True, num_workers=response["numprocesses"]) + return DaemonStatusModel(running=True, num_workers=response['numprocesses']) -@router.post("/daemon/stop", response_model=DaemonStatusModel) +@router.post('/daemon/stop', response_model=DaemonStatusModel) @with_dbenv() async def get_daemon_stop( current_user: User = Depends( # pylint: disable=unused-argument @@ -72,7 +70,7 @@ async def get_daemon_stop( client = get_daemon_client() if not client.is_daemon_running: - raise HTTPException(status_code=400, detail="The 
daemon is not running.") + raise HTTPException(status_code=400, detail='The daemon is not running.') try: client.stop_daemon() diff --git a/aiida_restapi/routers/groups.py b/aiida_restapi/routers/groups.py index 59a2f5b..e7e017e 100644 --- a/aiida_restapi/routers/groups.py +++ b/aiida_restapi/routers/groups.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI application.""" + from typing import List, Optional from aiida import orm @@ -13,7 +13,7 @@ router = APIRouter() -@router.get("/groups", response_model=List[Group]) +@router.get('/groups', response_model=List[Group]) @with_dbenv() async def read_groups() -> List[Group]: """Get list of all groups""" @@ -21,26 +21,24 @@ async def read_groups() -> List[Group]: return Group.get_entities() -@router.get("/groups/projectable_properties", response_model=List[str]) +@router.get('/groups/projectable_properties', response_model=List[str]) async def get_groups_projectable_properties() -> List[str]: """Get projectable properties for groups endpoint""" return Group.get_projectable_properties() -@router.get("/groups/{group_id}", response_model=Group) +@router.get('/groups/{group_id}', response_model=Group) @with_dbenv() async def read_group(group_id: int) -> Optional[Group]: """Get group by id.""" qbobj = orm.QueryBuilder() - qbobj.append(orm.Group, filters={"id": group_id}, project="**", tag="group").limit( - 1 - ) - return qbobj.dict()[0]["group"] + qbobj.append(orm.Group, filters={'id': group_id}, project='**', tag='group').limit(1) + return qbobj.dict()[0]['group'] -@router.post("/groups", response_model=Group) +@router.post('/groups', response_model=Group) @with_dbenv() async def create_group( group: Group_Post, diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py index 3d79a70..42e5d62 100644 --- a/aiida_restapi/routers/nodes.py +++ b/aiida_restapi/routers/nodes.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI application.""" + import json import os import tempfile @@ -20,30 +20,30 @@ router = APIRouter() -@router.get("/nodes", response_model=List[models.Node]) +@router.get('/nodes', response_model=List[models.Node]) @with_dbenv() async def read_nodes() -> List[models.Node]: """Get list of all nodes""" return models.Node.get_entities() -@router.get("/nodes/projectable_properties", response_model=List[str]) +@router.get('/nodes/projectable_properties', response_model=List[str]) async def get_nodes_projectable_properties() -> List[str]: """Get projectable properties for nodes endpoint""" return models.Node.get_projectable_properties() -@router.get("/nodes/{nodes_id}", response_model=models.Node) +@router.get('/nodes/{nodes_id}', response_model=models.Node) @with_dbenv() async def read_node(nodes_id: int) -> Optional[models.Node]: """Get nodes by id.""" qbobj = orm.QueryBuilder() - qbobj.append(orm.Node, filters={"id": nodes_id}, project="**", tag="node").limit(1) - return qbobj.dict()[0]["node"] + qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) + return qbobj.dict()[0]['node'] -@router.post("/nodes", response_model=models.Node) +@router.post('/nodes', response_model=models.Node) @with_dbenv() async def create_node( node: models.Node_Post, @@ -53,10 +53,10 @@ async def create_node( ) -> models.Node: """Create new AiiDA node.""" node_dict = node.dict(exclude_unset=True) - entry_point = node_dict.pop("entry_point", None) + entry_point = node_dict.pop('entry_point', None) try: - cls = load_entry_point(group="aiida.data", name=entry_point) + cls = 
load_entry_point(group='aiida.data', name=entry_point) except EntryPointError as exception: raise HTTPException(status_code=404, detail=str(exception)) from exception @@ -68,7 +68,7 @@ async def create_node( return models.Node.from_orm(orm_object) -@router.post("/nodes/singlefile", response_model=models.Node) +@router.post('/nodes/singlefile', response_model=models.Node) @with_dbenv() async def create_upload_file( params: str = Form(...), @@ -90,34 +90,32 @@ async def create_upload_file( except json.JSONDecodeError as exception: raise HTTPException( status_code=400, - detail=f"Invalid JSON format: {str(exception)}", + detail=f'Invalid JSON format: {exception!s}', ) from exception except ValidationError as exception: raise HTTPException( status_code=422, - detail=f"Validation failed: {exception}", + detail=f'Validation failed: {exception}', ) from exception node_dict = params_obj.dict(exclude_unset=True) - entry_point = node_dict.pop("entry_point", None) + entry_point = node_dict.pop('entry_point', None) try: - cls = load_entry_point(group="aiida.data", name=entry_point) + cls = load_entry_point(group='aiida.data', name=entry_point) except EntryPointError as exception: raise HTTPException( status_code=404, - detail=f"Could not load entry point: {exception}", + detail=f'Could not load entry point: {exception}', ) from exception - with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file: + with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file: # Todo: read in chunks content = await upload_file.read() temp_file.write(content) temp_path = temp_file.name - orm_object = models.Node_Post.create_new_node_with_file( - cls, node_dict, Path(temp_path) - ) + orm_object = models.Node_Post.create_new_node_with_file(cls, node_dict, Path(temp_path)) # Clean up the temporary file if os.path.exists(temp_path): diff --git a/aiida_restapi/routers/process.py b/aiida_restapi/routers/process.py index ab7571c..905a468 100644 --- a/aiida_restapi/routers/process.py +++ b/aiida_restapi/routers/process.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI router for processes.""" from typing import List, Optional @@ -27,28 +26,27 @@ def process_inputs(inputs: dict) -> dict: :returns: The deserialized inputs dictionary. :raises HTTPException: If the inputs contain a UUID that does not correspond to an existing node. 
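 For example (illustrative), an input item ``{'x.uuid': '<uuid>'}`` is returned as ``{'x': <node loaded via orm.load_node>}``.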
""" - uuid_suffix = ".uuid" + uuid_suffix = '.uuid' results = {} for key, value in inputs.items(): if isinstance(value, dict): results[key] = process_inputs(value) + elif key.endswith(uuid_suffix): + try: + results[key[: -len(uuid_suffix)]] = orm.load_node(uuid=value) + except NotExistent as exc: + raise HTTPException( + status_code=404, + detail=f'Node with UUID `{value}` does not exist.', + ) from exc else: - if key.endswith(uuid_suffix): - try: - results[key[: -len(uuid_suffix)]] = orm.load_node(uuid=value) - except NotExistent as exc: - raise HTTPException( - status_code=404, - detail=f"Node with UUID `{value}` does not exist.", - ) from exc - else: - results[key] = value + results[key] = value return results -@router.get("/processes", response_model=List[Process]) +@router.get('/processes', response_model=List[Process]) @with_dbenv() async def read_processes() -> List[Process]: """Get list of all processes""" @@ -56,26 +54,24 @@ async def read_processes() -> List[Process]: return Process.get_entities() -@router.get("/processes/projectable_properties", response_model=List[str]) +@router.get('/processes/projectable_properties', response_model=List[str]) async def get_processes_projectable_properties() -> List[str]: """Get projectable properties for processes endpoint""" return Process.get_projectable_properties() -@router.get("/processes/{proc_id}", response_model=Process) +@router.get('/processes/{proc_id}', response_model=Process) @with_dbenv() async def read_process(proc_id: int) -> Optional[Process]: """Get process by id.""" qbobj = QueryBuilder() - qbobj.append( - orm.ProcessNode, filters={"id": proc_id}, project="**", tag="process" - ).limit(1) + qbobj.append(orm.ProcessNode, filters={'id': proc_id}, project='**', tag='process').limit(1) - return qbobj.dict()[0]["process"] + return qbobj.dict()[0]['process'] -@router.post("/processes", response_model=Process) +@router.post('/processes', response_model=Process) @with_dbenv() async def post_process( process: Process_Post, @@ -85,8 +81,8 @@ async def post_process( ) -> Optional[Process]: """Create new process.""" process_dict = process.dict(exclude_unset=True, exclude_none=True) - inputs = process_inputs(process_dict["inputs"]) - entry_point = process_dict.get("process_entry_point") + inputs = process_inputs(process_dict['inputs']) + entry_point = process_dict.get('process_entry_point') try: entry_point_process = load_entry_point_from_string(entry_point) diff --git a/aiida_restapi/routers/users.py b/aiida_restapi/routers/users.py index 147a754..2277f51 100644 --- a/aiida_restapi/routers/users.py +++ b/aiida_restapi/routers/users.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Declaration of FastAPI application.""" + from typing import List, Optional from aiida import orm @@ -14,31 +14,31 @@ router = APIRouter() -@router.get("/users", response_model=List[User]) +@router.get('/users', response_model=List[User]) @with_dbenv() async def read_users() -> List[User]: """Get list of all users""" return User.get_entities() -@router.get("/users/projectable_properties", response_model=List[str]) +@router.get('/users/projectable_properties', response_model=List[str]) async def get_users_projectable_properties() -> List[str]: """Get projectable properties for users endpoint""" return User.get_projectable_properties() -@router.get("/users/{user_id}", response_model=User) +@router.get('/users/{user_id}', response_model=User) @with_dbenv() async def read_user(user_id: int) -> Optional[User]: """Get user by id.""" qbobj = QueryBuilder() - 
qbobj.append(orm.User, filters={"id": user_id}, project="**", tag="user").limit(1) + qbobj.append(orm.User, filters={'id': user_id}, project='**', tag='user').limit(1) - return qbobj.dict()[0]["user"] + return qbobj.dict()[0]['user'] -@router.post("/users", response_model=User) +@router.post('/users', response_model=User) @with_dbenv() async def create_user( user: User, diff --git a/aiida_restapi/utils.py b/aiida_restapi/utils.py index 653878c..ec25655 100644 --- a/aiida_restapi/utils.py +++ b/aiida_restapi/utils.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """General utility functions.""" + import datetime from dateutil.parser import parser as date_parser diff --git a/docs/source/conf.py b/docs/source/conf.py index 6345dda..38ab465 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Sphinx configuration for aiida-restapi # @@ -26,30 +25,26 @@ load_documentation_profile() extensions = [ - "myst_parser", - "sphinx_external_toc", - "sphinx_panels", - "sphinx.ext.autodoc", - "sphinx.ext.intersphinx", - "sphinx.ext.viewcode", - "aiida_restapi.graphql.sphinx_ext", + 'myst_parser', + 'sphinx_external_toc', + 'sphinx_panels', + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.viewcode', + 'aiida_restapi.graphql.sphinx_ext', ] # General information about the project. -project = "aiida-restapi" -copyright_first_year = "2021" -copyright_owners = "The AiiDA Team" +project = 'aiida-restapi' +copyright_first_year = '2021' +copyright_owners = 'The AiiDA Team' show_authors = True current_year = str(time.localtime().tm_year) copyright_year_string = ( - current_year - if current_year == copyright_first_year - else "{}-{}".format(copyright_first_year, current_year) + current_year if current_year == copyright_first_year else f'{copyright_first_year}-{current_year}' ) # pylint: disable=redefined-builtin -copyright = "{}, {}. All rights reserved".format( - copyright_year_string, copyright_owners -) +copyright = f'{copyright_year_string}, {copyright_owners}. All rights reserved' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -58,49 +53,49 @@ # The full version, including alpha/beta/rc tags. release = aiida_restapi.__version__ # The short X.Y version. 
-version = ".".join(release.split(".")[:2]) +version = '.'.join(release.split('.')[:2]) -myst_enable_extensions = ["replacements", "deflist", "colon_fence", "linkify"] +myst_enable_extensions = ['replacements', 'deflist', 'colon_fence', 'linkify'] -pygments_style = "sphinx" +pygments_style = 'sphinx' -html_theme = "sphinx_book_theme" +html_theme = 'sphinx_book_theme' html_theme_options = { - "home_page_in_toc": True, - "show_navbar_depth": 2, - "path_to_docs": "docs/source", + 'home_page_in_toc': True, + 'show_navbar_depth': 2, + 'path_to_docs': 'docs/source', } -html_title = "REST API" -html_logo = "images/AiiDA_transparent_logo.png" -html_use_opensearch = "http://aiida-restapi.readthedocs.io" -html_search_language = "en" +html_title = 'REST API' +html_logo = 'images/AiiDA_transparent_logo.png' +html_use_opensearch = 'http://aiida-restapi.readthedocs.io' +html_search_language = 'en' intersphinx_mapping = { - "python": ("https://docs.python.org/3", None), - "aiida": ("https://aiida-core.readthedocs.io/en/latest", None), + 'python': ('https://docs.python.org/3', None), + 'aiida': ('https://aiida-core.readthedocs.io/en/latest', None), # note pydantic and fastapi are not on sphinx - "graphql": ("https://graphql-core-3.readthedocs.io/en/latest", None), - "graphene": ("https://docs.graphene-python.org/en/latest", None), + 'graphql': ('https://graphql-core-3.readthedocs.io/en/latest', None), + 'graphene': ('https://docs.graphene-python.org/en/latest', None), } -autodoc_typehints = "none" +autodoc_typehints = 'none' nitpick_ignore = [ - ("py:class", name) + ('py:class', name) for name in [ - "pydantic.main.BaseModel", - "pydantic.types.Json", - "graphene.types.generic.GenericScalar", - "graphene.types.objecttype.ObjectType", - "graphene.types.scalars.String", - "aiida_restapi.aiida_db_mappings.Config", - "aiida_restapi.models.Config", - "aiida_restapi.routers.auth.Config", - "aiida_restapi.graphql.orm_factories.AiidaOrmObjectType", - "aiida_restapi.graphql.nodes.LinkObjectType", - "aiida_restapi.graphql.orm_factories.multirow_cls_factory..AiidaOrmRowsType", + 'pydantic.main.BaseModel', + 'pydantic.types.Json', + 'graphene.types.generic.GenericScalar', + 'graphene.types.objecttype.ObjectType', + 'graphene.types.scalars.String', + 'aiida_restapi.aiida_db_mappings.Config', + 'aiida_restapi.models.Config', + 'aiida_restapi.routers.auth.Config', + 'aiida_restapi.graphql.orm_factories.AiidaOrmObjectType', + 'aiida_restapi.graphql.nodes.LinkObjectType', + 'aiida_restapi.graphql.orm_factories.multirow_cls_factory..AiidaOrmRowsType', ] ] -suppress_warnings = ["etoc.toctree"] +suppress_warnings = ['etoc.toctree'] def run_apidoc(_): @@ -112,8 +107,8 @@ def run_apidoc(_): See also https://github.com/rtfd/readthedocs.org/issues/1139 """ source_dir = os.path.abspath(os.path.dirname(__file__)) - apidoc_dir = os.path.join(source_dir, "apidoc") - package_dir = os.path.join(source_dir, os.pardir, os.pardir, "aiida_restapi") + apidoc_dir = os.path.join(source_dir, 'apidoc') + package_dir = os.path.join(source_dir, os.pardir, os.pardir, 'aiida_restapi') # In #1139, they suggest the route below, but this ended up # calling sphinx-build, not sphinx-apidoc @@ -122,27 +117,25 @@ def run_apidoc(_): import subprocess - cmd_path = "sphinx-apidoc" - if hasattr(sys, "real_prefix"): # Check to see if we are in a virtualenv + cmd_path = 'sphinx-apidoc' + if hasattr(sys, 'real_prefix'): # Check to see if we are in a virtualenv # If we are, assemble the path manually - cmd_path = os.path.abspath(os.path.join(sys.prefix, "bin", 
"sphinx-apidoc")) + cmd_path = os.path.abspath(os.path.join(sys.prefix, 'bin', 'sphinx-apidoc')) options = [ - "-o", + '-o', apidoc_dir, package_dir, - "--private", - "--force", - "--no-toc", + '--private', + '--force', + '--no-toc', ] # See https://stackoverflow.com/a/30144019 env = os.environ.copy() - env[ - "SPHINX_APIDOC_OPTIONS" - ] = "members,special-members,private-members,undoc-members,show-inheritance" + env['SPHINX_APIDOC_OPTIONS'] = 'members,special-members,private-members,undoc-members,show-inheritance' subprocess.check_call([cmd_path] + options, env=env) def setup(app): - app.connect("builder-inited", run_apidoc) + app.connect('builder-inited', run_apidoc) diff --git a/examples/daemon_management/script.py b/examples/daemon_management/script.py index ecb9058..e362996 100755 --- a/examples/daemon_management/script.py +++ b/examples/daemon_management/script.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Example script to demonstrate daemon management over the web API.""" + from __future__ import annotations import os @@ -9,7 +9,7 @@ import click import requests -BASE_URL = "http://127.0.0.1:8000" +BASE_URL = 'http://127.0.0.1:8000' def echo_error(message: str) -> None: @@ -17,7 +17,7 @@ def echo_error(message: str) -> None: :param message: The error message to echo. """ - click.echo(click.style("Error: ", fg="red", bold=True), nl=False) + click.echo(click.style('Error: ', fg='red', bold=True), nl=False) click.echo(message) @@ -25,7 +25,7 @@ def request( url, json: dict[str, t.Any] | None = None, data: dict[str, t.Any] | None = None, - method="POST", + method='POST', ) -> dict[str, t.Any] | None: """Perform a request to the web API of ``aiida-restapi``. @@ -37,16 +37,16 @@ def request( :param method: The request method, POST by default. :returns: The response in JSON or ``None``. """ - access_token = os.getenv("ACCESS_TOKEN", None) + access_token = os.getenv('ACCESS_TOKEN', None) if access_token: - headers = {"Authorization": f"Bearer {access_token}"} + headers = {'Authorization': f'Bearer {access_token}'} else: headers = {} response = requests.request( # pylint: disable=missing-timeout method, - f"{BASE_URL}/{url}", + f'{BASE_URL}/{url}', json=json, data=data, headers=headers, @@ -57,21 +57,19 @@ def request( except requests.HTTPError: results = response.json() - echo_error(f"{response.status_code} {response.reason}") + echo_error(f'{response.status_code} {response.reason}') - if "detail" in results: - echo_error(results["detail"]) + if 'detail' in results: + echo_error(results['detail']) - for error in results.get("errors", []): - click.echo(error["message"]) + for error in results.get('errors', []): + click.echo(error['message']) return None return response.json() -def authenticate( - username: str = "johndoe@example.com", password: str = "secret" -) -> str | None: +def authenticate(username: str = 'johndoe@example.com', password: str = 'secret') -> str | None: """Authenticate with the web API to obtain an access token. Note that if authentication is successful, the access token is stored in the ``ACCESS_TOKEN`` environment variable. @@ -80,11 +78,11 @@ def authenticate( :param password: The password. :returns: The access token or ``None`` if authentication was unsuccessful. 
""" - results = request("token", data={"username": username, "password": password}) + results = request('token', data={'username': username, 'password': password}) if results: - access_token = results["access_token"] - os.environ["ACCESS_TOKEN"] = access_token + access_token = results['access_token'] + os.environ['ACCESS_TOKEN'] = access_token return access_token return None @@ -96,25 +94,25 @@ def main(): token = authenticate() if token is None: - echo_error("Could not authenticate with the API, aborting") + echo_error('Could not authenticate with the API, aborting') return - response = request("daemon/status", method="GET") + response = request('daemon/status', method='GET') - if response["running"]: - num_workers = response["num_workers"] - click.echo(f"The daemon is currently running with {num_workers} workers") + if response['running']: + num_workers = response['num_workers'] + click.echo(f'The daemon is currently running with {num_workers} workers') - click.echo("Stopping the daemon.") - response = request("daemon/stop", method="POST") + click.echo('Stopping the daemon.') + response = request('daemon/stop', method='POST') else: - click.echo("The daemon is currently not running.") - click.echo("Starting the daemon.") - response = request("daemon/start", method="POST") - num_workers = response["num_workers"] - click.echo(f"The daemon is currently running with {num_workers} workers") + click.echo('The daemon is currently not running.') + click.echo('Starting the daemon.') + response = request('daemon/start', method='POST') + num_workers = response['num_workers'] + click.echo(f'The daemon is currently running with {num_workers} workers') -if __name__ == "__main__": +if __name__ == '__main__': main() # pylint: disable=no-value-for-parameter diff --git a/examples/process_management/script.py b/examples/process_management/script.py index 1d04672..b44e926 100755 --- a/examples/process_management/script.py +++ b/examples/process_management/script.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Example script to demonstrate process management over the web API.""" + from __future__ import annotations import os @@ -10,7 +10,7 @@ import click import requests -BASE_URL = "http://127.0.0.1:8000" +BASE_URL = 'http://127.0.0.1:8000' def echo_error(message: str) -> None: @@ -18,7 +18,7 @@ def echo_error(message: str) -> None: :param message: The error message to echo. """ - click.echo(click.style("Error: ", fg="red", bold=True), nl=False) + click.echo(click.style('Error: ', fg='red', bold=True), nl=False) click.echo(message) @@ -26,7 +26,7 @@ def request( url, json: dict[str, t.Any] | None = None, data: dict[str, t.Any] | None = None, - method="POST", + method='POST', ) -> dict[str, t.Any] | None: """Perform a request to the web API of ``aiida-restapi``. @@ -38,15 +38,15 @@ def request( :param method: The request method, POST by default. :returns: The response in JSON or ``None``. 
""" - access_token = os.getenv("ACCESS_TOKEN", None) + access_token = os.getenv('ACCESS_TOKEN', None) if access_token: - headers = {"Authorization": f"Bearer {access_token}"} + headers = {'Authorization': f'Bearer {access_token}'} else: headers = {} response = requests.request( # pylint: disable=missing-timeout - method, f"{BASE_URL}/{url}", json=json, data=data, headers=headers + method, f'{BASE_URL}/{url}', json=json, data=data, headers=headers ) try: @@ -54,21 +54,19 @@ def request( except requests.HTTPError: results = response.json() - echo_error(f"{response.status_code} {response.reason}") + echo_error(f'{response.status_code} {response.reason}') - if "detail" in results: - echo_error(results["detail"]) + if 'detail' in results: + echo_error(results['detail']) - for error in results.get("errors", []): - click.echo(error["message"]) + for error in results.get('errors', []): + click.echo(error['message']) return None return response.json() -def authenticate( - username: str = "johndoe@example.com", password: str = "secret" -) -> str | None: +def authenticate(username: str = 'johndoe@example.com', password: str = 'secret') -> str | None: """Authenticate with the web API to obtain an access token. Note that if authentication is successful, the access token is stored in the ``ACCESS_TOKEN`` environment variable. @@ -77,11 +75,11 @@ def authenticate( :param password: The password. :returns: The access token or ``None`` if authentication was unsuccessful. """ - results = request("token", data={"username": username, "password": password}) + results = request('token', data={'username': username, 'password': password}) if results: - access_token = results["access_token"] - os.environ["ACCESS_TOKEN"] = access_token + access_token = results['access_token'] + os.environ['ACCESS_TOKEN'] = access_token return access_token return None @@ -95,13 +93,13 @@ def create_node(entry_point: str, attributes: dict[str, t.Any]) -> str | None: :returns: The UUID of the created node or ``None`` if it failed. """ data = { - "entry_point": entry_point, - "attributes": attributes, + 'entry_point': entry_point, + 'attributes': attributes, } - result = request("nodes", data) + result = request('nodes', data) if result: - return result["uuid"] + return result['uuid'] return None @@ -114,7 +112,7 @@ def get_code(default_calc_job_plugin: str) -> dict[str, t.Any] | None: :param default_calc_job_plugin: The default calculation job plugin the code should have. :raises ValueError: If no code could be found. """ - variables = {"default_calc_job_plugin": default_calc_job_plugin} + variables = {'default_calc_job_plugin': default_calc_job_plugin} query = """ { nodes(filters: "node_type ILIKE 'data.core.code.installed.InstalledCode%'") { @@ -126,21 +124,19 @@ def get_code(default_calc_job_plugin: str) -> dict[str, t.Any] | None: } } """ - results = request("graphql", {"query": query, "variables": variables}) + results = request('graphql', {'query': query, 'variables': variables}) if results is None: return None node = None - for row in results["data"]["nodes"]["rows"]: - if row["attributes"]["input_plugin"] == default_calc_job_plugin: + for row in results['data']['nodes']['rows']: + if row['attributes']['input_plugin'] == default_calc_job_plugin: node = row if node is None: - raise ValueError( - f"No code with default calculation job plugin `{default_calc_job_plugin}` found." 
- )
+ raise ValueError(f'No code with default calculation job plugin `{default_calc_job_plugin}` found.')

 return node

@@ -169,16 +165,16 @@ def get_outputs(process_id: int) -> dict[str, t.Any]:
 }
 }
 """
- variables = {"process_id": process_id}
- results = request("graphql", {"query": query, "variables": variables})
+ variables = {'process_id': process_id}
+ results = request('graphql', {'query': query, 'variables': variables})

 outputs = {}

- for value in results["data"]["node"]["outgoing"]["rows"]:
- link_label = value["link"]["label"]
+ for value in results['data']['node']['outgoing']['rows']:
+ link_label = value['link']['label']
 outputs[link_label] = {
- "uuid": value["node"]["uuid"],
- "attributes": value["node"]["attributes"],
+ 'uuid': value['node']['uuid'],
+ 'attributes': value['node']['attributes'],
 }

 return outputs
@@ -190,44 +186,44 @@ def main():
 token = authenticate()

 if token is None:
- echo_error("Could not authenticate with the API, aborting")
+ echo_error('Could not authenticate with the API, aborting')
 return

 # Inputs for an ``ArithmeticAddCalculation``
 inputs = {
- "label": "Launched over the web API",
- "process_entry_point": "aiida.calculations:core.arithmetic.add",
- "inputs": {
- "code.uuid": get_code("core.arithmetic.add")["uuid"],
- "x.uuid": create_node("core.int", {"value": 1}),
- "y.uuid": create_node("core.int", {"value": 1}),
+ 'label': 'Launched over the web API',
+ 'process_entry_point': 'aiida.calculations:core.arithmetic.add',
+ 'inputs': {
+ 'code.uuid': get_code('core.arithmetic.add')['uuid'],
+ 'x.uuid': create_node('core.int', {'value': 1}),
+ 'y.uuid': create_node('core.int', {'value': 1}),
 },
 }
- results = request("processes", json=inputs)
+ results = request('processes', json=inputs)

 click.echo(f'Successfully submitted process with pk<{results["id"]}>')

- process_id = results["id"]
+ process_id = results['id']

 while True:
- results = request(f"processes/{process_id}", method="GET")
- process_state = results["attributes"]["process_state"]
+ results = request(f'processes/{process_id}', method='GET')
+ process_state = results['attributes']['process_state']

- if process_state in ["finished", "excepted", "killed"]:
+ if process_state in ['finished', 'excepted', 'killed']:
 break

 time.sleep(2)

- click.echo(f"Calculation terminated with state `{process_state}`")
+ click.echo(f'Calculation terminated with state `{process_state}`')

 results = get_outputs(process_id)

- click.echo("Output nodes:")
+ click.echo('Output nodes:')

 for key, value in results.items():
 click.echo(f"* {key}: UUID<{value['uuid']}>")

 click.echo(f"Computed sum: {results['sum']['attributes']['value']}")

-if __name__ == "__main__":
+if __name__ == '__main__':
 main() # pylint: disable=no-value-for-parameter
diff --git a/examples/submit_quantumespresso_pw/script.py b/examples/submit_quantumespresso_pw/script.py
index cd5c9e7..8464ccb 100755
--- a/examples/submit_quantumespresso_pw/script.py
+++ b/examples/submit_quantumespresso_pw/script.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 """Example script to submit a ``PwCalculation`` or ``PwBaseWorkChain`` over the web API."""
+
 from __future__ import annotations

 import os
@@ -9,7 +9,7 @@
 import click
 import requests

-BASE_URL = "http://127.0.0.1:8000"
+BASE_URL = 'http://127.0.0.1:8000'

 def echo_error(message: str) -> None:
 """Echo the message prefixed with ``Error`` in bold red.

 :param message: The error message to echo.
""" - click.echo(click.style("Error: ", fg="red", bold=True), nl=False) + click.echo(click.style('Error: ', fg='red', bold=True), nl=False) click.echo(message) @@ -25,7 +25,7 @@ def request( url, json: dict[str, t.Any] | None = None, data: dict[str, t.Any] | None = None, - method="POST", + method='POST', ) -> dict[str, t.Any] | None: """Perform a request to the web API of ``aiida-restapi``. @@ -37,15 +37,15 @@ def request( :param method: The request method, POST by default. :returns: The response in JSON or ``None``. """ - access_token = os.getenv("ACCESS_TOKEN", None) + access_token = os.getenv('ACCESS_TOKEN', None) if access_token: - headers = {"Authorization": f"Bearer {access_token}"} + headers = {'Authorization': f'Bearer {access_token}'} else: headers = {} response = requests.request( # pylint: disable=missing-timeout - method, f"{BASE_URL}/{url}", json=json, data=data, headers=headers + method, f'{BASE_URL}/{url}', json=json, data=data, headers=headers ) try: @@ -53,21 +53,19 @@ def request( except requests.HTTPError: results = response.json() - echo_error(f"{response.status_code} {response.reason}") + echo_error(f'{response.status_code} {response.reason}') - if "detail" in results: - echo_error(results["detail"]) + if 'detail' in results: + echo_error(results['detail']) - for error in results.get("errors", []): - click.echo(error["message"]) + for error in results.get('errors', []): + click.echo(error['message']) return None return response.json() -def authenticate( - username: str = "johndoe@example.com", password: str = "secret" -) -> str | None: +def authenticate(username: str = 'johndoe@example.com', password: str = 'secret') -> str | None: """Authenticate with the web API to obtain an access token. Note that if authentication is successful, the access token is stored in the ``ACCESS_TOKEN`` environment variable. @@ -76,11 +74,11 @@ def authenticate( :param password: The password. :returns: The access token or ``None`` if authentication was unsuccessful. """ - results = request("token", data={"username": username, "password": password}) + results = request('token', data={'username': username, 'password': password}) if results: - access_token = results["access_token"] - os.environ["ACCESS_TOKEN"] = access_token + access_token = results['access_token'] + os.environ['ACCESS_TOKEN'] = access_token return access_token return None @@ -92,7 +90,7 @@ def get_pseudo_family_pseudos(pseudo_family_label: str) -> dict[str, t.Any] | No :param pseudo_family_label: The label of the pseudopotential family. :returns: The pseudopotential family and the pseudos it contains. """ - variables = {"label": pseudo_family_label} + variables = {'label': pseudo_family_label} query = """ query function($label: String) { group(label: $label) { @@ -107,17 +105,15 @@ def get_pseudo_family_pseudos(pseudo_family_label: str) -> dict[str, t.Any] | No } } """ - results = request("graphql", {"query": query, "variables": variables}) + results = request('graphql', {'query': query, 'variables': variables}) if results: - return results["data"]["group"] + return results['data']['group'] return None -def get_pseudo_for_element( - pseudo_family_label: str, element: str -) -> dict[str, t.Any] | None: +def get_pseudo_for_element(pseudo_family_label: str, element: str) -> dict[str, t.Any] | None: """Return the pseudo potential for a given pseudo potential family and element. :param pseudo_family_label: The label of the pseudopotential family. 
@@ -131,8 +127,8 @@ def get_pseudo_for_element( pseudo = None - for row in family["nodes"]["rows"]: - if row["attributes"]["element"] == element: + for row in family['nodes']['rows']: + if row['attributes']['element'] == element: pseudo = row break @@ -147,13 +143,13 @@ def create_node(entry_point: str, attributes: dict[str, t.Any]) -> str | None: :returns: The UUID of the created node or ``None`` if it failed. """ data = { - "entry_point": entry_point, - "attributes": attributes, + 'entry_point': entry_point, + 'attributes': attributes, } - result = request("nodes", data) + result = request('nodes', data) if result: - return result["uuid"] + return result['uuid'] return None @@ -166,7 +162,7 @@ def get_code(default_calc_job_plugin: str) -> dict[str, t.Any] | None: :param default_calc_job_plugin: The default calculation job plugin the code should have. :raises ValueError: If no code could be found. """ - variables = {"default_calc_job_plugin": default_calc_job_plugin} + variables = {'default_calc_job_plugin': default_calc_job_plugin} query = """ { nodes(filters: "node_type ILIKE 'data.core.code.installed.InstalledCode%'") { @@ -178,81 +174,79 @@ def get_code(default_calc_job_plugin: str) -> dict[str, t.Any] | None: } } """ - results = request("graphql", {"query": query, "variables": variables}) + results = request('graphql', {'query': query, 'variables': variables}) if results is None: return None node = None - for row in results["data"]["nodes"]["rows"]: - if row["attributes"]["input_plugin"] == default_calc_job_plugin: + for row in results['data']['nodes']['rows']: + if row['attributes']['input_plugin'] == default_calc_job_plugin: node = row if node is None: - raise ValueError( - f"No code with default calculation job plugin `{default_calc_job_plugin}` found." 
- ) + raise ValueError(f'No code with default calculation job plugin `{default_calc_job_plugin}` found.') return node @click.command() @click.option( - "--workchain", + '--workchain', is_flag=True, - help="Submit a ``PwBaseWorkChain`` instead of a ``PwCalculation``.", + help='Submit a ``PwBaseWorkChain`` instead of a ``PwCalculation``.', ) def main(workchain): """Authenticate with the web API and submit a ``PwCalculation`` or ``PwBaseWorkChain``.""" token = authenticate() if token is None: - echo_error("Could not authenticate with the API, aborting") + echo_error('Could not authenticate with the API, aborting') return - kpoints = {"mesh": [2, 2, 2], "offset": [0, 0, 0]} + kpoints = {'mesh': [2, 2, 2], 'offset': [0, 0, 0]} parameters = { - "CONTROL": { - "calculation": "scf", + 'CONTROL': { + 'calculation': 'scf', }, - "SYSTEM": { - "ecutwfc": 30, - "ecutrho": 240, + 'SYSTEM': { + 'ecutwfc': 30, + 'ecutrho': 240, }, } structure = { - "cell": [[0.0, 2.715, 2.715], [2.715, 0.0, 2.715], [2.715, 2.715, 0.0]], - "kinds": [{"mass": 28.085, "name": "Si", "symbols": ["Si"], "weights": [1.0]}], - "pbc1": True, - "pbc2": True, - "pbc3": True, - "sites": [ - {"kind_name": "Si", "position": [0.0, 0.0, 0.0]}, - {"kind_name": "Si", "position": [1.3575, 1.3575, 1.3575]}, + 'cell': [[0.0, 2.715, 2.715], [2.715, 0.0, 2.715], [2.715, 2.715, 0.0]], + 'kinds': [{'mass': 28.085, 'name': 'Si', 'symbols': ['Si'], 'weights': [1.0]}], + 'pbc1': True, + 'pbc2': True, + 'pbc3': True, + 'sites': [ + {'kind_name': 'Si', 'position': [0.0, 0.0, 0.0]}, + {'kind_name': 'Si', 'position': [1.3575, 1.3575, 1.3575]}, ], } - code_uuid = get_code("quantumespresso.pw")["uuid"] - structure_uuid = create_node("core.structure", structure) - parameters_uuid = create_node("core.dict", parameters) - kpoints_uuid = create_node("core.array.kpoints", kpoints) - pseudo_si_uuid = get_pseudo_for_element("SSSP/1.2/PBE/efficiency", "Si")["uuid"] + code_uuid = get_code('quantumespresso.pw')['uuid'] + structure_uuid = create_node('core.structure', structure) + parameters_uuid = create_node('core.dict', parameters) + kpoints_uuid = create_node('core.array.kpoints', kpoints) + pseudo_si_uuid = get_pseudo_for_element('SSSP/1.2/PBE/efficiency', 'Si')['uuid'] if workchain: # Inputs for a ``PwBaseWorkChain`` to compute SCF of Si crystal structure inputs = { - "label": "PwCalculation over REST API", - "process_entry_point": "aiida.workflows:quantumespresso.pw.base", - "inputs": { - "kpoints.uuid": kpoints_uuid, - "pw": { - "code.uuid": code_uuid, - "structure.uuid": structure_uuid, - "parameters.uuid": parameters_uuid, - "pseudos": { - "Si.uuid": pseudo_si_uuid, + 'label': 'PwCalculation over REST API', + 'process_entry_point': 'aiida.workflows:quantumespresso.pw.base', + 'inputs': { + 'kpoints.uuid': kpoints_uuid, + 'pw': { + 'code.uuid': code_uuid, + 'structure.uuid': structure_uuid, + 'parameters.uuid': parameters_uuid, + 'pseudos': { + 'Si.uuid': pseudo_si_uuid, }, }, }, @@ -260,22 +254,22 @@ def main(workchain): else: # Inputs for a ``PwCalculation`` to compute SCF of Si crystal structure inputs = { - "label": "PwCalculation over REST API", - "process_entry_point": "aiida.calculations:quantumespresso.pw", - "inputs": { - "code.uuid": code_uuid, - "structure.uuid": structure_uuid, - "parameters.uuid": parameters_uuid, - "kpoints.uuid": kpoints_uuid, - "pseudos": { - "Si.uuid": pseudo_si_uuid, + 'label': 'PwCalculation over REST API', + 'process_entry_point': 'aiida.calculations:quantumespresso.pw', + 'inputs': { + 'code.uuid': code_uuid, + 
'structure.uuid': structure_uuid,
+ 'parameters.uuid': parameters_uuid,
+ 'kpoints.uuid': kpoints_uuid,
+ 'pseudos': {
+ 'Si.uuid': pseudo_si_uuid,
 },
 },
 }

- results = request("processes", json=inputs)
+ results = request('processes', json=inputs)

 click.echo(f'Successfully submitted process with pk<{results["id"]}>')


-if __name__ == "__main__":
+if __name__ == '__main__':
 main() # pylint: disable=no-value-for-parameter
diff --git a/pyproject.toml b/pyproject.toml
index 667f209..5d1ac28 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,86 +1,112 @@
 [build-system]
-requires = ['flit_core>=3.4,<4']
 build-backend = 'flit_core.buildapi'
+requires = ['flit_core>=3.4,<4']

 [project]
-name = 'aiida-restapi'
-dynamic = ['description', 'version']
 authors = [
- {name = 'The AiiDA Team', email = 'developers@aiida.net'}
+ {name = 'The AiiDA Team', email = 'developers@aiida.net'}
 ]
-readme = 'README.md'
-license = {file = 'LICENSE.txt'}
 classifiers = [
- 'Development Status :: 3 - Alpha',
- 'Framework :: AiiDA',
- 'Intended Audience :: Science/Research',
- 'License :: OSI Approved :: MIT License',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 3.9',
- 'Programming Language :: Python :: 3.10',
- 'Programming Language :: Python :: 3.11',
- 'Topic :: Scientific/Engineering'
+ 'Development Status :: 3 - Alpha',
+ 'Framework :: AiiDA',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
+ 'Topic :: Scientific/Engineering'
 ]
-keywords = ['aiida', 'workflows']
-requires-python = '>=3.9'
 dependencies = [
- 'aiida-core~=2.0',
- 'fastapi~=0.115.5',
- 'uvicorn[standard]~=0.19.0',
- 'pydantic~=2.0',
- 'starlette-graphene3~=0.6.0',
- 'graphene~=3.0',
- 'python-dateutil~=2.0',
- 'lark~=0.11.0',
+ 'aiida-core~=2.0',
+ 'fastapi~=0.115.5',
+ 'uvicorn[standard]~=0.19.0',
+ 'pydantic~=2.0',
+ 'starlette-graphene3~=0.6.0',
+ 'graphene~=3.0',
+ 'python-dateutil~=2.0',
+ 'lark~=0.11.0'
 ]
-
-[project.urls]
-Source = 'https://github.com/aiidateam/aiida-restapi'
+dynamic = ['description', 'version']
+keywords = ['aiida', 'workflows']
+license = {file = 'LICENSE.txt'}
+name = 'aiida-restapi'
+readme = 'README.md'
+requires-python = '>=3.9'

 [project.optional-dependencies]
 auth = [
- 'python-jose',
- 'python-multipart',
- 'passlib',
+ 'python-jose',
+ 'python-multipart',
+ 'passlib'
 ]
 docs = [
- 'sphinx',
- 'myst-parser[linkify]>=0.13.7',
- 'sphinx-external-toc',
- 'sphinx-book-theme',
- 'sphinx-panels',
- 'pygments-graphql',
+ 'sphinx',
+ 'myst-parser[linkify]>=0.13.7',
+ 'sphinx-external-toc',
+ 'sphinx-book-theme',
+ 'sphinx-panels',
+ 'pygments-graphql'
 ]
 pre-commit = [
- 'pre-commit~=2.12'
+ 'pre-commit~=2.12'
 ]
 testing = [
- 'aiida-restapi[auth]',
- 'pgtest~=1.3.1',
- 'wheel~=0.31',
- 'coverage',
- 'pytest~=6.2',
- 'pytest-regressions',
- 'pytest-cov',
- 'requests',
- 'httpx',
+ 'aiida-restapi[auth]',
+ 'pgtest~=1.3.1',
+ 'wheel~=0.31',
+ 'coverage',
+ 'pytest~=6.2',
+ 'pytest-regressions',
+ 'pytest-cov',
+ 'requests',
+ 'httpx'
 ]

+[project.urls]
+Source = 'https://github.com/aiidateam/aiida-restapi'
+
 [tool.flit.module]
 name = 'aiida_restapi'

 [tool.flit.sdist]
 exclude = [
- '.github/',
- 'docs/',
- 'examples/',
- 'tests/',
- '.coveragerc',
- '.gitignore',
- '.pre-commit-config.yaml',
- '.readthedocs.yml',
- 'codecov.yml',
+ '.github/',
+ 'docs/',
+ 'examples/',
+ 'tests/',
+ '.coveragerc',
+ 
'.gitignore', + '.pre-commit-config.yaml', + '.readthedocs.yml', + 'codecov.yml' +] + +[tool.mypy] +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true +plugins = ['pydantic.mypy'] +show_error_codes = true +strict_equality = true +warn_redundant_casts = true +warn_unused_ignores = true + +[[tool.pydantic.mypy.overrides]] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true +warn_untyped_fields = false + +[tool.pytest.ini_options] +filterwarnings = [ + 'ignore:Creating AiiDA configuration folder.*:UserWarning', + 'ignore::DeprecationWarning:aiida:', + 'ignore::DeprecationWarning:plumpy:', + 'ignore::DeprecationWarning:django:', + 'ignore::DeprecationWarning:yaml:' ] +python_files = 'test_*.py example_*.py' [tool.ruff] line-length = 120 @@ -98,7 +124,7 @@ ignore = [ 'PLR0915', # Too many statements 'PLR2004', # Magic value used in comparison 'RUF005', # Consider iterable unpacking instead of concatenation - 'RUF012', # Mutable class attributes should be annotated with `typing.ClassVar` + 'RUF012' # Mutable class attributes should be annotated with `typing.ClassVar` ] select = [ 'E', # pydocstyle @@ -112,32 +138,6 @@ select = [ 'RUF' # ruff ] -[tool.pytest.ini_options] -python_files = 'test_*.py example_*.py' -filterwarnings = [ - 'ignore:Creating AiiDA configuration folder.*:UserWarning', - 'ignore::DeprecationWarning:aiida:', - 'ignore::DeprecationWarning:plumpy:', - 'ignore::DeprecationWarning:django:', - 'ignore::DeprecationWarning:yaml:', -] - -[tool.mypy] -show_error_codes = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -warn_unused_ignores = true -warn_redundant_casts = true -no_implicit_optional = true -strict_equality = true -plugins = ['pydantic.mypy'] - -[[tool.pydantic.mypy.overrides]] -init_forbid_extra = true -init_typed = true -warn_required_dynamic_aliases = true -warn_untyped_fields = false - [tool.tox] legacy_tox_ini = """ [tox] diff --git a/tests/__init__.py b/tests/__init__.py index 424be79..484b467 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- -""" Tests for the plugin. +"""Tests for the plugin. Includes both tests written in unittest style (test_cli.py) and tests written in pytest style (test_calculations.py). 
""" + import os TEST_DIR = os.path.dirname(os.path.realpath(__file__)) diff --git a/tests/conftest.py b/tests/conftest.py index 819b053..c39c403 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Test fixtures specific to this package.""" + # pylint: disable=too-many-arguments import tempfile from datetime import datetime @@ -17,15 +17,15 @@ from aiida_restapi import app, config from aiida_restapi.routers.auth import UserInDB, get_current_user -pytest_plugins = ["aiida.manage.tests.pytest_fixtures"] +pytest_plugins = ['aiida.manage.tests.pytest_fixtures'] -@pytest.fixture(scope="function", autouse=True) +@pytest.fixture(scope='function', autouse=True) def clear_database_auto(aiida_profile_clean): # pylint: disable=unused-argument """Automatically clear database in between tests.""" -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def client(): """Return fastapi test client.""" yield TestClient(app) @@ -38,57 +38,52 @@ def anyio_backend(): Returns: str: The name of the backend to use """ - return "asyncio" + return 'asyncio' -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') async def async_client(): """Return fastapi async test client.""" - async with AsyncClient(app=app, base_url="http://test") as async_test_client: + async with AsyncClient(app=app, base_url='http://test') as async_test_client: yield async_test_client -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def default_users(): """Populate database with some users.""" - user_1 = orm.User( - email="verdi@opera.net", first_name="Giuseppe", last_name="Verdi" - ).store() - user_2 = orm.User( - email="stravinsky@symphony.org", first_name="Igor", last_name="Stravinsky" - ).store() + user_1 = orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store() + user_2 = orm.User(email='stravinsky@symphony.org', first_name='Igor', last_name='Stravinsky').store() return [user_1.pk, user_2.pk] -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def default_computers(): """Populate database with some computer""" comp_1 = orm.Computer( - label="test_comp_1", - hostname="localhost_1", - transport_type="core.local", - scheduler_type="core.pbspro", + label='test_comp_1', + hostname='localhost_1', + transport_type='core.local', + scheduler_type='core.pbspro', ).store() comp_2 = orm.Computer( - label="test_comp_2", - hostname="localhost_2", - transport_type="core.local", - scheduler_type="core.pbspro", + label='test_comp_2', + hostname='localhost_2', + transport_type='core.local', + scheduler_type='core.pbspro', ).store() return [comp_1.pk, comp_2.pk] -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def example_processes(): """Populate database with some processes""" calcs = [] - process_label = "SomeDummyWorkFunctionNode" + process_label = 'SomeDummyWorkFunctionNode' # Create 6 WorkFunctionNodes and WorkChainNodes (one for each ProcessState) for state in ProcessState: - calc = WorkFunctionNode() calc.set_process_state(state) @@ -97,7 +92,7 @@ def example_processes(): calc.set_exit_status(0) # Give a `process_label` to the `WorkFunctionNodes` so the `--process-label` option can be tested - calc.base.attributes.set("process_label", process_label) + calc.base.attributes.set('process_label', process_label) calc.store() calcs.append(calc.pk) @@ -118,27 +113,27 @@ def example_processes(): return calcs -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def 
default_test_add_process():
 """Populate database with some nodes to test adding a process."""
 workdir = tempfile.mkdtemp()

 computer = orm.Computer(
- label="localhost",
- hostname="localhost",
+ label='localhost',
+ hostname='localhost',
 workdir=workdir,
- transport_type="core.local",
- scheduler_type="core.direct",
+ transport_type='core.local',
+ scheduler_type='core.direct',
 )
 computer.store()
 computer.set_minimum_job_poll_interval(0.0)
 computer.configure()

 code = orm.InstalledCode(
- default_calc_job_plugin="core.arithmetic.add",
+ default_calc_job_plugin='core.arithmetic.add',
 computer=computer,
- filepath_executable="/bin/true",
+ filepath_executable='/bin/true',
 ).store()

 x = orm.Int(1).store()
@@ -148,32 +143,28 @@ def default_test_add_process():
 return [code.uuid, x.uuid, y.uuid]


-@pytest.fixture(scope="function")
+@pytest.fixture(scope='function')
 def default_groups():
 """Populate database with some groups."""
- test_user_1 = orm.User(
- email="verdi@opera.net", first_name="Giuseppe", last_name="Verdi"
- ).store()
- test_user_2 = orm.User(
- email="stravinsky@symphony.org", first_name="Igor", last_name="Stravinsky"
- ).store()
- group_1 = orm.Group(label="test_label_1", user=test_user_1).store()
- group_2 = orm.Group(label="test_label_2", user=test_user_2).store()
+ test_user_1 = orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store()
+ test_user_2 = orm.User(email='stravinsky@symphony.org', first_name='Igor', last_name='Stravinsky').store()
+ group_1 = orm.Group(label='test_label_1', user=test_user_1).store()
+ group_2 = orm.Group(label='test_label_2', user=test_user_2).store()
 return [group_1.pk, group_2.pk]


-@pytest.fixture(scope="function")
+@pytest.fixture(scope='function')
 def default_nodes():
 """Populate database with some nodes."""
 node_1 = orm.Int(1).store()
 node_2 = orm.Float(1.1).store()
- node_3 = orm.Str("test_string").store()
+ node_3 = orm.Str('test_string').store()
 node_4 = orm.Bool(False).store()

 return [node_1.pk, node_2.pk, node_3.pk, node_4.pk]


-@pytest.fixture(scope="function")
+@pytest.fixture(scope='function')
 def authenticate():
 """Authenticate user.
@@ -182,7 +173,7 @@ def authenticate():

 async def logged_in_user(token=None): # pylint: disable=unused-argument
 """Fake active user."""
- return UserInDB(**config.fake_users_db["johndoe@example.com"])
+ return UserInDB(**config.fake_users_db['johndoe@example.com'])

 app.dependency_overrides[get_current_user] = logged_in_user
 yield
@@ -201,9 +192,8 @@
 for key, item in mapping.items():
 if isinstance(item, (MutableMapping, list)):
 mutate_mapping(item, mutations)
- else:
- if key in mutations:
- mapping[key] = mutations[key](mapping[key])
+ elif key in mutations:
+ mapping[key] = mutations[key](mapping[key])


 @pytest.fixture
@@ -213,15 +203,15 @@ def orm_regression(data_regression):
 def _func(
 data: dict,
 varfields=(
- "id",
- "uuid",
- "time",
- "ctime",
- "mtime",
- "dbnode_id",
- "user_id",
- "_aiida_hash",
- "aiidaVersion",
+ 'id',
+ 'uuid',
+ 'time',
+ 'ctime',
+ 'mtime',
+ 'dbnode_id',
+ 'user_id',
+ '_aiida_hash',
+ 'aiidaVersion',
 ),
 ):
 mutate_mapping(
@@ -238,10 +228,10 @@ def create_user():
 """Create and store an AiiDA User."""

 def _func(
- email="a@b.com",
- first_name: str = "",
- last_name: str = "",
- institution: str = "",
+ email='a@b.com',
+ first_name: str = '',
+ last_name: str = '',
+ institution: str = '',
 ) -> orm.User:
 return orm.User(
 email=email,
@@ -258,16 +248,14 @@ def create_comment():
 """Create and store an AiiDA Comment (+ associated user and node)."""

 def _func(
- content: str = "content",
- user_email="verdi@opera.net",
+ content: str = 'content',
+ user_email='verdi@opera.net',
 node: Optional[orm.nodes.Node] = None,
 ) -> orm.Comment:
 try:
 user = orm.User.collection.get(email=user_email)
 except NotExistent:
- user = orm.User(
- email=user_email, first_name="Giuseppe", last_name="Verdi"
- ).store()
+ user = orm.User(email=user_email, first_name='Giuseppe', last_name='Verdi').store()
 if node is None:
 node = orm.Data()
 node.user = user
@@ -282,15 +270,13 @@ def create_log():
 """Create and store an AiiDA Log (and node)."""

 def _func(
- loggername: str = "name",
- level_name: str = "level 1",
- message="",
+ loggername: str = 'name',
+ level_name: str = 'level 1',
+ message='',
 node: Optional[orm.nodes.Node] = None,
 ) -> orm.Log:
 orm_node = node or orm.Data().store()
- return orm.Log(
- datetime.now(pytz.UTC), loggername, level_name, orm_node.pk, message=message
- ).store()
+ return orm.Log(datetime.now(pytz.UTC), loggername, level_name, orm_node.pk, message=message).store()

 return _func


@@ -300,11 +286,11 @@ def create_computer():
 """Create and store an AiiDA Computer."""

 def _func(
- label: str = "localhost",
- hostname: str = "localhost",
- transport_type: str = "core.local",
- scheduler_type: str = "core.direct",
- description: str = "",
+ label: str = 'localhost',
+ hostname: str = 'localhost',
+ transport_type: str = 'core.local',
+ scheduler_type: str = 'core.direct',
+ description: str = '',
 workdir: Optional[str] = None,
 ) -> orm.Computer:
 return orm.Computer(
@@ -325,19 +311,15 @@ def create_node():

 def _func(
 *,
- label: str = "",
- description: str = "",
+ label: str = '',
+ description: str = '',
 attributes: Optional[dict] = None,
 extras: Optional[dict] = None,
 process_type: Optional[str] = None,
 computer: Optional[orm.Computer] = None,
- store: bool = True
+ store: bool = True,
 ) -> orm.nodes.Node:
- node = (
- orm.CalcJobNode(computer=computer)
- if process_type
- else orm.Data(computer=computer)
- )
+ node = orm.CalcJobNode(computer=computer) if process_type else orm.Data(computer=computer)
 node.label = label
 node.description = 
description
 node.base.attributes.reset(attributes or {})
@@ -356,8 +338,8 @@ def create_group():
 """Create and store an AiiDA Group."""

 def _func(
- label: str = "group",
- description: str = "",
+ label: str = 'group',
+ description: str = '',
 type_string: Optional[str] = None,
 ) -> orm.Group:
 return orm.Group(
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 7e5bd15..c2d7f9c 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Test the authentication endpoints"""
+
 from fastapi.testclient import TestClient

 from aiida_restapi import app
@@ -10,13 +10,11 @@

 def test_authenticate_user():
 """Test authenticating as a user."""
 # authenticate with username and password
- response = client.post(
- "/token", data={"username": "johndoe@example.com", "password": "secret"}
- )
+ response = client.post('/token', data={'username': 'johndoe@example.com', 'password': 'secret'})
 assert response.status_code == 200, response.content
- token = response.json()["access_token"]
+ token = response.json()['access_token']

 # use JSON web token to access protected endpoint
- response = client.get("/auth/me", headers={"Authorization": "Bearer " + str(token)})
+ response = client.get('/auth/me', headers={'Authorization': 'Bearer ' + str(token)})
 assert response.status_code == 200, response.content
- assert response.json()["last_name"] == "Doe"
+ assert response.json()['last_name'] == 'Doe'
diff --git a/tests/test_computers.py b/tests/test_computers.py
index 7d47d60..4d98cb9 100644
--- a/tests/test_computers.py
+++ b/tests/test_computers.py
@@ -1,10 +1,9 @@
-# -*- coding: utf-8 -*-
 """Test the /computers endpoint"""


 def test_get_computers(default_computers, client): # pylint: disable=unused-argument
 """Test listing existing computers."""
- response = client.get("/computers/")
+ response = client.get('/computers/')

 assert response.status_code == 200
 assert len(response.json()) == 2
@@ -12,44 +11,42 @@ def test_get_computers(default_computers, client): # pylint: disable=unused-arg

 def test_get_computers_projectable(client):
 """Test get projectable properties for computer."""
- response = client.get("/computers/projectable_properties")
+ response = client.get('/computers/projectable_properties')

 assert response.status_code == 200
 assert response.json() == [
- "id",
- "uuid",
- "label",
- "hostname",
- "scheduler_type",
- "transport_type",
- "metadata",
- "description",
+ 'id',
+ 'uuid',
+ 'label',
+ 'hostname',
+ 'scheduler_type',
+ 'transport_type',
+ 'metadata',
+ 'description',
 ]


-def test_get_single_computers(
- default_computers, client
-): # pylint: disable=unused-argument
+def test_get_single_computers(default_computers, client): # pylint: disable=unused-argument
 """Test retrieving a single computer."""

 for comp_id in default_computers:
- response = client.get(f"/computers/{comp_id}")
+ response = client.get(f'/computers/{comp_id}')
 assert response.status_code == 200


 def test_create_computer(client, authenticate): # pylint: disable=unused-argument
 """Test creating a new computer."""
 response = client.post(
- "/computers",
+ '/computers',
 json={
- "label": "test_comp",
- "hostname": "fake_host",
- "transport_type": "core.local",
- "scheduler_type": "core.pbspro",
+ 'label': 'test_comp',
+ 'hostname': 'fake_host',
+ 'transport_type': 'core.local',
+ 'scheduler_type': 'core.pbspro',
 },
 )
 assert response.status_code == 200, response.content

- response = client.get("/computers")
- computers = [comp["label"] for comp in response.json()]
- assert "test_comp" in computers
+ 
response = client.get('/computers') + computers = [comp['label'] for comp in response.json()] + assert 'test_comp' in computers diff --git a/tests/test_daemon.py b/tests/test_daemon.py index 0dba53b..e3124a5 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Test the /daemon endpoint""" + import pytest from fastapi.testclient import TestClient @@ -8,43 +8,43 @@ client = TestClient(app) -@pytest.mark.usefixtures("stopped_daemon_client", "authenticate") +@pytest.mark.usefixtures('stopped_daemon_client', 'authenticate') def test_status_and_start(): """Test ``/daemon/status`` when the daemon is not running and ``/daemon/start``.""" - response = client.get("/daemon/status") + response = client.get('/daemon/status') assert response.status_code == 200, response.content results = response.json() - assert results["running"] is False - assert results["num_workers"] is None + assert results['running'] is False + assert results['num_workers'] is None - response = client.post("/daemon/start") + response = client.post('/daemon/start') assert response.status_code == 200, response.content results = response.json() - assert results["running"] is True - assert results["num_workers"] == 1 + assert results['running'] is True + assert results['num_workers'] == 1 - response = client.post("/daemon/start") + response = client.post('/daemon/start') assert response.status_code == 400, response.content -@pytest.mark.usefixtures("started_daemon_client", "authenticate") +@pytest.mark.usefixtures('started_daemon_client', 'authenticate') def test_status_and_stop(): """Test ``/daemon/status`` when the daemon is running and ``/daemon/stop``.""" - response = client.get("/daemon/status") + response = client.get('/daemon/status') assert response.status_code == 200, response.content results = response.json() - assert results["running"] is True - assert results["num_workers"] == 1 + assert results['running'] is True + assert results['num_workers'] == 1 - response = client.post("/daemon/stop") + response = client.post('/daemon/stop') assert response.status_code == 200, response.content results = response.json() - assert results["running"] is False - assert results["num_workers"] is None + assert results['running'] is False + assert results['num_workers'] is None - response = client.post("/daemon/stop") + response = client.post('/daemon/stop') assert response.status_code == 400, response.content diff --git a/tests/test_filter_syntax.py b/tests/test_filter_syntax.py index 0ae80f9..5d7ebb7 100644 --- a/tests/test_filter_syntax.py +++ b/tests/test_filter_syntax.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for syntax filter.""" + import datetime import pytest @@ -8,28 +8,28 @@ @pytest.mark.parametrize( - "input_str,output", + 'input_str,output', [ - ("a==1", {"a": {"==": 1}}), - ("a_bc>='d'", {"a_bc": {">=": "d"}}), - ("a.b<=c", {"a.b": {"<=": "c"}}), - ("a !== 1.0", {"a": {"!==": 1.0}}), - ("a==2020-01-01", {"a": {"==": datetime.datetime(2020, 1, 1, 0, 0)}}), - ("a==2020-01-01 10:11", {"a": {"==": datetime.datetime(2020, 1, 1, 10, 11)}}), - ("a == 1 AND b == 2", {"a": {"==": 1}, "b": {"==": 2}}), - ('a LIKE "x%"', {"a": {"like": "x%"}}), - ('a ILIKE "x%"', {"a": {"ilike": "x%"}}), - ('a ILIKE "x%"', {"a": {"ilike": "x%"}}), - ("a LENGTH 33", {"a": {"of_length": 33}}), - ("a OF LENGTH 33", {"a": {"of_length": 33}}), - ("a IN 1", {"a": {"in": [1]}}), - ("a IS IN 1", {"a": {"in": [1]}}), - ("a IN 1,2,3", {"a": {"in": [1, 2, 3]}}), - ("a IN x,y,z", {"a": {"in": ["x", "y", 
"z"]}}), - ('a IN "x","y","z"', {"a": {"in": ["x", "y", "z"]}}), - ('a HAS "x"', {"a": {"has_key": "x"}}), - ('a HAS KEY "y"', {"a": {"has_key": "y"}}), - ("a < 2 & a >=1 & a == 3", {"a": {"and": [{"<": 2}, {">=": 1}, {"==": 3}]}}), + ('a==1', {'a': {'==': 1}}), + ("a_bc>='d'", {'a_bc': {'>=': 'd'}}), + ('a.b<=c', {'a.b': {'<=': 'c'}}), + ('a !== 1.0', {'a': {'!==': 1.0}}), + ('a==2020-01-01', {'a': {'==': datetime.datetime(2020, 1, 1, 0, 0)}}), + ('a==2020-01-01 10:11', {'a': {'==': datetime.datetime(2020, 1, 1, 10, 11)}}), + ('a == 1 AND b == 2', {'a': {'==': 1}, 'b': {'==': 2}}), + ('a LIKE "x%"', {'a': {'like': 'x%'}}), + ('a ILIKE "x%"', {'a': {'ilike': 'x%'}}), + ('a ILIKE "x%"', {'a': {'ilike': 'x%'}}), + ('a LENGTH 33', {'a': {'of_length': 33}}), + ('a OF LENGTH 33', {'a': {'of_length': 33}}), + ('a IN 1', {'a': {'in': [1]}}), + ('a IS IN 1', {'a': {'in': [1]}}), + ('a IN 1,2,3', {'a': {'in': [1, 2, 3]}}), + ('a IN x,y,z', {'a': {'in': ['x', 'y', 'z']}}), + ('a IN "x","y","z"', {'a': {'in': ['x', 'y', 'z']}}), + ('a HAS "x"', {'a': {'has_key': 'x'}}), + ('a HAS KEY "y"', {'a': {'has_key': 'y'}}), + ('a < 2 & a >=1 & a == 3', {'a': {'and': [{'<': 2}, {'>=': 1}, {'==': 3}]}}), ], ) def test_parser(input_str, output): diff --git a/tests/test_graphql/test_basic.py b/tests/test_graphql/test_basic.py index 9a67286..80c5b2b 100644 --- a/tests/test_graphql/test_basic.py +++ b/tests/test_graphql/test_basic.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for basic plugins.""" + from aiida import __version__ from graphene.test import Client @@ -12,15 +12,15 @@ def test_aiidaVersion(): """Test aiidaVersion query.""" schema = create_schema([aiidaVersionPlugin]) client = Client(schema) - executed = client.execute("{ aiidaVersion }") - assert "aiidaVersion" in executed["data"] - assert executed["data"]["aiidaVersion"] == __version__ + executed = client.execute('{ aiidaVersion }') + assert 'aiidaVersion' in executed['data'] + assert executed['data']['aiidaVersion'] == __version__ def test_rowLimitMax(): """Test rowLimitMax query.""" schema = create_schema([rowLimitMaxPlugin]) client = Client(schema) - executed = client.execute("{ rowLimitMax }") - assert "rowLimitMax" in executed["data"] - assert executed["data"]["rowLimitMax"] == ENTITY_LIMIT + executed = client.execute('{ rowLimitMax }') + assert 'rowLimitMax' in executed['data'] + assert executed['data']['rowLimitMax'] == ENTITY_LIMIT diff --git a/tests/test_graphql/test_comments.py b/tests/test_graphql/test_comments.py index bcc343d..93711cd 100644 --- a/tests/test_graphql/test_comments.py +++ b/tests/test_graphql/test_comments.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for comments plugins.""" + from graphene.test import Client from aiida_restapi.graphql.comments import CommentQueryPlugin, CommentsQueryPlugin @@ -13,31 +13,27 @@ def test_comment(create_comment, orm_regression): fields = field_names_from_orm(type(comment)) schema = create_schema([CommentQueryPlugin]) client = Client(schema) - executed = client.execute( - "{ comment(id: %r) { %s } }" % (comment.id, " ".join(fields)) - ) + executed = client.execute('{ comment(id: %r) { %s } }' % (comment.id, ' '.join(fields))) orm_regression(executed) def test_comments(create_comment, orm_regression): """Test Comments query, for all fields.""" - create_comment(content="comment 1") - comment = create_comment(content="comment 2") + create_comment(content='comment 1') + comment = create_comment(content='comment 2') fields = field_names_from_orm(type(comment)) schema = 
create_schema([CommentsQueryPlugin]) client = Client(schema) - executed = client.execute("{ comments { count rows { %s } } }" % " ".join(fields)) + executed = client.execute('{ comments { count rows { %s } } }' % ' '.join(fields)) orm_regression(executed) def test_comments_order_by(create_comment, orm_regression): """Test Comments query, when ordering by a field.""" - create_comment(content="b") - create_comment(content="a") - create_comment(content="c") + create_comment(content='b') + create_comment(content='a') + create_comment(content='c') schema = create_schema([CommentsQueryPlugin]) client = Client(schema) - executed = client.execute( - '{ comments { count rows(orderBy: "content", orderAsc: false) { content } } }' - ) + executed = client.execute('{ comments { count rows(orderBy: "content", orderAsc: false) { content } } }') orm_regression(executed) diff --git a/tests/test_graphql/test_computers.py b/tests/test_graphql/test_computers.py index 6e63986..036c1c5 100644 --- a/tests/test_graphql/test_computers.py +++ b/tests/test_graphql/test_computers.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for computer plugins.""" + from graphene.test import Client from aiida_restapi.graphql.computers import ComputerQueryPlugin, ComputersQueryPlugin @@ -13,31 +13,27 @@ def test_computer(create_computer, orm_regression): fields = field_names_from_orm(type(computer)) schema = create_schema([ComputerQueryPlugin]) client = Client(schema) - executed = client.execute( - "{ computer(id: %r) { %s } }" % (computer.id, " ".join(fields)) - ) + executed = client.execute('{ computer(id: %r) { %s } }' % (computer.id, ' '.join(fields))) orm_regression(executed) def test_computer_nodes(create_computer, create_node, orm_regression): """Test querying Nodes inside Computer.""" - computer = create_computer(label="mycomputer") - create_node(label="node 1", computer=computer) - create_node(label="node 2", computer=computer) + computer = create_computer(label='mycomputer') + create_node(label='node 1', computer=computer) + create_node(label='node 2', computer=computer) schema = create_schema([ComputerQueryPlugin]) client = Client(schema) - executed = client.execute( - "{ computer(id: %r) { nodes { count rows{ label } } } }" % (computer.id) - ) + executed = client.execute('{ computer(id: %r) { nodes { count rows{ label } } } }' % (computer.id)) orm_regression(executed) def test_computers(create_computer, orm_regression): """Test Computers query, for all fields.""" - create_computer(label="computer 1") - computer = create_computer(label="computer 2") + create_computer(label='computer 1') + computer = create_computer(label='computer 2') fields = field_names_from_orm(type(computer)) schema = create_schema([ComputersQueryPlugin]) client = Client(schema) - executed = client.execute("{ computers { count rows { %s } } }" % " ".join(fields)) + executed = client.execute('{ computers { count rows { %s } } }' % ' '.join(fields)) orm_regression(executed) diff --git a/tests/test_graphql/test_entry_points.py b/tests/test_graphql/test_entry_points.py index 9222dda..7de3d3e 100644 --- a/tests/test_graphql/test_entry_points.py +++ b/tests/test_graphql/test_entry_points.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for basic plugins.""" -from aiida import __version__ + from graphene.test import Client from aiida_restapi.graphql.entry_points import ( @@ -14,20 +13,18 @@ def test_aiidaEntryPointGroups(): """Test test_aiidaEntryPointGroups query.""" schema = create_schema([aiidaEntryPointGroupsPlugin]) client = Client(schema) - 
executed = client.execute("{ aiidaEntryPointGroups }") - assert "data" in executed, executed - assert "aiidaEntryPointGroups" in executed["data"], executed["data"] - assert "aiida.data" in executed["data"]["aiidaEntryPointGroups"], executed["data"] + executed = client.execute('{ aiidaEntryPointGroups }') + assert 'data' in executed, executed + assert 'aiidaEntryPointGroups' in executed['data'], executed['data'] + assert 'aiida.data' in executed['data']['aiidaEntryPointGroups'], executed['data'] def test_aiidaEntryPoints(): """Test aiidaEntryPoints query.""" schema = create_schema([aiidaEntryPointsPlugin]) client = Client(schema) - executed = client.execute( - '{ aiidaEntryPoints(group: "aiida.schedulers") { group names } }' - ) - assert "data" in executed, executed - assert "aiidaEntryPoints" in executed["data"] - assert executed["data"]["aiidaEntryPoints"]["group"] == "aiida.schedulers" - assert "core.direct" in executed["data"]["aiidaEntryPoints"]["names"] + executed = client.execute('{ aiidaEntryPoints(group: "aiida.schedulers") { group names } }') + assert 'data' in executed, executed + assert 'aiidaEntryPoints' in executed['data'] + assert executed['data']['aiidaEntryPoints']['group'] == 'aiida.schedulers' + assert 'core.direct' in executed['data']['aiidaEntryPoints']['names'] diff --git a/tests/test_graphql/test_full.py b/tests/test_graphql/test_full.py index 3dad4e0..a1a811a 100644 --- a/tests/test_graphql/test_full.py +++ b/tests/test_graphql/test_full.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Simple test for the full schema.""" + from graphene.test import Client from aiida_restapi.graphql.main import SCHEMA @@ -7,7 +7,7 @@ def test_full(create_node, orm_regression): """Test loading the full schema.""" - node = create_node(label="node 1") + node = create_node(label='node 1') client = Client(SCHEMA) - executed = client.execute("{ aiidaVersion node(id: %r) { label } }" % (node.id)) + executed = client.execute('{ aiidaVersion node(id: %r) { label } }' % (node.id)) orm_regression(executed) diff --git a/tests/test_graphql/test_groups.py b/tests/test_graphql/test_groups.py index 16f714d..0e0b107 100644 --- a/tests/test_graphql/test_groups.py +++ b/tests/test_graphql/test_groups.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Tests for group plugins.""" + from graphene.test import Client from aiida_restapi.graphql.groups import GroupQueryPlugin, GroupsQueryPlugin @@ -13,41 +13,37 @@ def test_group(create_group, orm_regression): fields = field_names_from_orm(type(group)) schema = create_schema([GroupQueryPlugin]) client = Client(schema) - executed = client.execute("{ group(id: %r) { %s } }" % (group.id, " ".join(fields))) + executed = client.execute('{ group(id: %r) { %s } }' % (group.id, ' '.join(fields))) orm_regression(executed) def test_group_label(create_group, orm_regression): """Test Group query on the label.""" - group = create_group(label="custom-label") + group = create_group(label='custom-label') fields = field_names_from_orm(type(group)) schema = create_schema([GroupQueryPlugin]) client = Client(schema) - executed = client.execute( - '{ group(label: "custom-label") { %s } }' % (" ".join(fields)) - ) + executed = client.execute('{ group(label: "custom-label") { %s } }' % (' '.join(fields))) orm_regression(executed) def test_group_nodes(create_group, create_node, orm_regression): """Test querying Nodes inside Group.""" - create_node(label="not in group") + create_node(label='not in group') group = create_group() - group.add_nodes([create_node(label="node 1"), 
create_node(label="node 2")])
+ group.add_nodes([create_node(label='node 1'), create_node(label='node 2')])
 schema = create_schema([GroupQueryPlugin])
 client = Client(schema)
- executed = client.execute(
- "{ group(id: %r) { nodes { count rows{ label } } } }" % (group.id)
- )
+ executed = client.execute('{ group(id: %r) { nodes { count rows{ label } } } }' % (group.id))
 orm_regression(executed)


 def test_groups(create_group, orm_regression):
 """Test Groups query, for all fields."""
- create_group(label="group1")
- group = create_group(label="group2")
+ create_group(label='group1')
+ group = create_group(label='group2')
 fields = field_names_from_orm(type(group))
 schema = create_schema([GroupsQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ groups { count rows { %s } } }" % " ".join(fields))
+ executed = client.execute('{ groups { count rows { %s } } }' % ' '.join(fields))
 orm_regression(executed)
diff --git a/tests/test_graphql/test_logs.py b/tests/test_graphql/test_logs.py
index be152f6..3f5a6f1 100644
--- a/tests/test_graphql/test_logs.py
+++ b/tests/test_graphql/test_logs.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Tests for log plugins."""
+
 from graphene.test import Client

 from aiida_restapi.graphql.logs import LogQueryPlugin, LogsQueryPlugin
@@ -13,16 +13,16 @@ def test_log(create_log, orm_regression):
 fields = field_names_from_orm(type(log))
 schema = create_schema([LogQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ log(id: %r) { %s } }" % (log.id, " ".join(fields)))
+ executed = client.execute('{ log(id: %r) { %s } }' % (log.id, ' '.join(fields)))
 orm_regression(executed)


 def test_logs(create_log, orm_regression):
 """Test Logs query, for all fields."""
- create_log(message="log 1")
- log = create_log(message="log 2")
+ create_log(message='log 1')
+ log = create_log(message='log 2')
 fields = field_names_from_orm(type(log))
 schema = create_schema([LogsQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ logs { count rows { %s } } }" % " ".join(fields))
+ executed = client.execute('{ logs { count rows { %s } } }' % ' '.join(fields))
 orm_regression(executed)
diff --git a/tests/test_graphql/test_nodes.py b/tests/test_graphql/test_nodes.py
index ad3550f..18ee144 100644
--- a/tests/test_graphql/test_nodes.py
+++ b/tests/test_graphql/test_nodes.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Tests for node plugins."""
+
 from aiida.common.links import LinkType
 from graphene.test import Client

@@ -10,46 +10,42 @@

 def test_node(create_node, orm_regression):
 """Test Node query."""
- node = create_node(process_type="my_process")
+ node = create_node(process_type='my_process')
 fields = field_names_from_orm(type(node))
 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ node(id: %r) { %s } }" % (node.id, " ".join(fields)))
+ executed = client.execute('{ node(id: %r) { %s } }' % (node.id, ' '.join(fields)))
 orm_regression(executed)


 def test_node_logs(create_node, create_log, orm_regression):
 """Test querying logs of a node."""
- node = create_node(label="mynode")
- create_log(message="log 1", node=node)
- create_log(message="log 2", node=node)
+ node = create_node(label='mynode')
+ create_log(message='log 1', node=node)
+ create_log(message='log 2', node=node)
 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
- executed = client.execute(
- "{ node(id: %r) { logs { count rows{ message } } } }" % (node.id)
- )
+ executed = client.execute('{ node(id: %r) { logs { count rows{ message } } 
} }' % (node.id))
 orm_regression(executed)


 def test_node_comments(create_node, create_comment, orm_regression):
 """Test querying comments of a node."""
- node = create_node(label="mynode")
- create_comment(content="comment 1", node=node)
- create_comment(content="comment 2", node=node)
+ node = create_node(label='mynode')
+ create_comment(content='comment 1', node=node)
+ create_comment(content='comment 2', node=node)
 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
- executed = client.execute(
- "{ node(id: %r) { comments { count rows{ content } } } }" % (node.id)
- )
+ executed = client.execute('{ node(id: %r) { comments { count rows{ content } } } }' % (node.id))
 orm_regression(executed)


 def test_node_incoming(create_node, orm_regression):
 """Test querying incoming links to a node."""
- node = create_node(label="mynode", process_type="process", store=False)
+ node = create_node(label='mynode', process_type='process', store=False)
 for label, link_type, link_label in [
- ("incoming1", LinkType.INPUT_CALC, "link1"),
- ("incoming2", LinkType.INPUT_CALC, "link2"),
+ ('incoming1', LinkType.INPUT_CALC, 'link1'),
+ ('incoming2', LinkType.INPUT_CALC, 'link2'),
 ]:
 node.base.links.add_incoming(create_node(label=label), link_type, link_label)
 node.store()
@@ -57,18 +53,17 @@ def test_node_incoming(create_node, orm_regression):
 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
 executed = client.execute(
- "{ node(id: %r) { incoming { count rows{ node { label } link { label type } } } } }"
- % (node.id)
+ '{ node(id: %r) { incoming { count rows{ node { label } link { label type } } } } }' % (node.id)
 )
 orm_regression(executed)


 def test_node_outgoing(create_node, orm_regression):
 """Test querying outgoing links from a node."""
- node = create_node(label="mynode", process_type="process")
+ node = create_node(label='mynode', process_type='process')
 for label, link_type, link_label in [
- ("outgoing1", LinkType.CREATE, "link1"),
- ("outgoing2", LinkType.CREATE, "link2"),
+ ('outgoing1', LinkType.CREATE, 'link1'),
+ ('outgoing2', LinkType.CREATE, 'link2'),
 ]:
 outgoing = create_node(label=label)
 outgoing.base.links.add_incoming(node, link_type, link_label)
@@ -77,36 +72,33 @@
 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
 executed = client.execute(
- "{ node(id: %r) { outgoing { count rows{ node { label } link { label type } } } } }"
- % (node.id)
+ '{ node(id: %r) { outgoing { count rows{ node { label } link { label type } } } } }' % (node.id)
 )
 orm_regression(executed)


 def test_node_ancestors(create_node, orm_regression):
 """Test querying ancestors of a node."""
- node = create_node(label="mynode", process_type="process", store=False)
+ node = create_node(label='mynode', process_type='process', store=False)
 for label, link_type, link_label in [
- ("incoming1", LinkType.INPUT_CALC, "link1"),
- ("incoming2", LinkType.INPUT_CALC, "link2"),
+ ('incoming1', LinkType.INPUT_CALC, 'link1'),
+ ('incoming2', LinkType.INPUT_CALC, 'link2'),
 ]:
 node.base.links.add_incoming(create_node(label=label), link_type, link_label)
 node.store()

 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
- executed = client.execute(
- "{ node(id: %r) { ancestors { count rows{ label } } } }" % (node.id)
- )
+ executed = client.execute('{ node(id: %r) { ancestors { count rows{ label } } } }' % (node.id))
 orm_regression(executed)


 def test_node_descendants(create_node, orm_regression):
 """Test querying descendants 
of a node."""
- node = create_node(label="mynode", process_type="process")
+ node = create_node(label='mynode', process_type='process')
 for label, link_type, link_label in [
- ("outgoing1", LinkType.CREATE, "link1"),
- ("outgoing2", LinkType.CREATE, "link2"),
+ ('outgoing1', LinkType.CREATE, 'link1'),
+ ('outgoing2', LinkType.CREATE, 'link2'),
 ]:
 outgoing = create_node(label=label)
 outgoing.base.links.add_incoming(node, link_type, link_label)
@@ -114,18 +106,16 @@

 schema = create_schema([NodeQueryPlugin])
 client = Client(schema)
- executed = client.execute(
- "{ node(id: %r) { descendants { count rows{ label } } } }" % (node.id)
- )
+ executed = client.execute('{ node(id: %r) { descendants { count rows{ label } } } }' % (node.id))
 orm_regression(executed)


 def test_nodes(create_node, orm_regression):
 """Test Nodes query, for all fields."""
- create_node(label="node 1")
- node = create_node(label="node 2")
+ create_node(label='node 1')
+ node = create_node(label='node 2')
 fields = field_names_from_orm(type(node))
 schema = create_schema([NodesQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ nodes { count rows { %s } } }" % " ".join(fields))
+ executed = client.execute('{ nodes { count rows { %s } } }' % ' '.join(fields))
 orm_regression(executed)
diff --git a/tests/test_graphql/test_users.py b/tests/test_graphql/test_users.py
index 83495eb..cd6b52d 100644
--- a/tests/test_graphql/test_users.py
+++ b/tests/test_graphql/test_users.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Tests for user plugins."""
+
 from graphene.test import Client

 from aiida_restapi.graphql.orm_factories import field_names_from_orm
@@ -13,16 +13,16 @@ def test_user(create_user, orm_regression):
 fields = field_names_from_orm(type(user))
 schema = create_schema([UserQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ user(id: %r) { %s } }" % (user.id, " ".join(fields)))
+ executed = client.execute('{ user(id: %r) { %s } }' % (user.id, ' '.join(fields)))
 orm_regression(executed)


 def test_users(create_user, orm_regression):
 """Test Users query, for all fields."""
- create_user(email="a@b.com")
- user = create_user(email="c@d.com")
+ create_user(email='a@b.com')
+ user = create_user(email='c@d.com')
 fields = field_names_from_orm(type(user))
 schema = create_schema([UsersQueryPlugin])
 client = Client(schema)
- executed = client.execute("{ users { count rows { %s } } }" % " ".join(fields))
+ executed = client.execute('{ users { count rows { %s } } }' % ' '.join(fields))
 orm_regression(executed)
diff --git a/tests/test_groups.py b/tests/test_groups.py
index 8268d20..09e245d 100644
--- a/tests/test_groups.py
+++ b/tests/test_groups.py
@@ -1,28 +1,27 @@
-# -*- coding: utf-8 -*-
 """Test the /groups endpoint"""


 def test_get_group(default_groups, client): # pylint: disable=unused-argument
 """Test listing existing groups."""
- response = client.get("/groups")
+ response = client.get('/groups')

 assert response.status_code == 200
 assert len(response.json()) == 2


 def test_get_group_projectable(client):
 """Test get projectable properties for group."""
- response = client.get("/groups/projectable_properties")
+ response = client.get('/groups/projectable_properties')

 assert response.status_code == 200
 assert response.json() == [
- "id",
- "uuid",
- "label",
- "type_string",
- "description",
- "extras",
- "time",
- "user_id",
+ 'id',
+ 'uuid',
+ 'label',
+ 'type_string',
+ 'description',
+ 'extras',
+ 'time',
+ 'user_id',
 ]


@@ -30,26 +29,24 @@ def 
test_get_single_group(default_groups, client): # pylint: disable=unused-arg
 """Test retrieving a single group."""

 for group_id in default_groups:
- response = client.get(f"/groups/{group_id}")
+ response = client.get(f'/groups/{group_id}')
 assert response.status_code == 200


 def test_create_group(client, authenticate): # pylint: disable=unused-argument
 """Test creating a new group."""
- response = client.post("/groups", json={"label": "test_label_create"})
+ response = client.post('/groups', json={'label': 'test_label_create'})
 assert response.status_code == 200, response.content

- response = client.get("/groups")
- first_names = [group["label"] for group in response.json()]
+ response = client.get('/groups')
+ first_names = [group['label'] for group in response.json()]

- assert "test_label_create" in first_names
+ assert 'test_label_create' in first_names


-def test_create_group_returns_user_id(
- client, authenticate
-): # pylint: disable=unused-argument
+def test_create_group_returns_user_id(client, authenticate): # pylint: disable=unused-argument
 """Test creating a new group returns user_id."""
- response = client.post("/groups", json={"label": "test_label_create"})
+ response = client.post('/groups', json={'label': 'test_label_create'})

 assert response.status_code == 200, response.content
- assert response.json()["user_id"]
+ assert response.json()['user_id']
diff --git a/tests/test_models.py b/tests/test_models.py
index c7715a7..b28416b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Test that all aiida entity models can be loaded into pydantic models."""
+
 from aiida import orm

 from aiida_restapi import models
@@ -7,7 +7,7 @@

 def replace_dynamic(data: dict) -> dict:
 """Replace dynamic fields with their type name."""
- for key in ["id", "uuid", "dbnode_id", "user_id", "mtime", "ctime", "time"]:
+ for key in ['id', 'uuid', 'dbnode_id', 'user_id', 'mtime', 'ctime', 'time']:
 if key in data:
 data[key] = type(data[key]).__name__
 return data
@@ -15,29 +15,27 @@ def replace_dynamic(data: dict) -> dict:

 def test_comment_get_entities(data_regression):
 """Test ``Comment.get_entities``"""
- orm_user = orm.User(
- email="verdi@opera.net", first_name="Giuseppe", last_name="Verdi"
- ).store()
+ orm_user = orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store()
 orm_node = orm.Data().store()
- orm.Comment(orm_node, orm_user, "content").store()
- py_comments = models.Comment.get_entities(order_by=["id"])
+ orm.Comment(orm_node, orm_user, 'content').store()
+ py_comments = models.Comment.get_entities(order_by=['id'])
 data_regression.check([replace_dynamic(c.dict()) for c in py_comments])


 def test_user_get_entities(data_regression):
 """Test ``User.get_entities``"""
- orm.User(email="verdi@opera.net", first_name="Giuseppe", last_name="Verdi").store()
- py_users = models.User.get_entities(order_by=["id"])
+ orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store()
+ py_users = models.User.get_entities(order_by=['id'])
 data_regression.check([replace_dynamic(c.dict()) for c in py_users])


 def test_computer_get_entities(data_regression):
 """Test ``Computer.get_entities``"""
 orm.Computer(
- label="test_comp_1",
- hostname="localhost_1",
- transport_type="core.local",
- scheduler_type="core.pbspro",
+ label='test_comp_1',
+ hostname='localhost_1',
+ transport_type='core.local',
+ scheduler_type='core.pbspro',
 ).store()
 py_computer = models.Computer.get_entities()
data_regression.check([replace_dynamic(c.dict()) for c in py_computer])
@@ -45,6 +43,6 @@ def test_computer_get_entities(data_regression):

 def test_group_get_entities(data_regression):
 """Test ``Group.get_entities``"""
- orm.Group(label="regression_label_1", description="regrerssion_test").store()
- py_group = models.Group.get_entities(order_by=["id"])
+ orm.Group(label='regression_label_1', description='regrerssion_test').store()
+ py_group = models.Group.get_entities(order_by=['id'])
 data_regression.check([replace_dynamic(c.dict()) for c in py_group])
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index a273b99..fba4831 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -1,7 +1,5 @@
-# -*- coding: utf-8 -*-
 """Test the /nodes endpoint"""
-
 import io
 import json

@@ -10,23 +8,23 @@

 def test_get_nodes_projectable(client):
 """Test get projectable properties for nodes."""
- response = client.get("/nodes/projectable_properties")
+ response = client.get('/nodes/projectable_properties')

 assert response.status_code == 200
 assert response.json() == [
- "id",
- "uuid",
- "node_type",
- "process_type",
- "label",
- "description",
- "ctime",
- "mtime",
- "user_id",
- "dbcomputer_id",
- "attributes",
- "extras",
- "repository_metadata",
+ 'id',
+ 'uuid',
+ 'node_type',
+ 'process_type',
+ 'label',
+ 'description',
+ 'ctime',
+ 'mtime',
+ 'user_id',
+ 'dbcomputer_id',
+ 'attributes',
+ 'extras',
+ 'repository_metadata',
 ]


@@ -34,13 +32,13 @@ def test_get_single_nodes(default_nodes, client): # pylint: disable=unused-argu
 """Test retrieving a single node."""

 for nodes_id in default_nodes:
- response = client.get(f"/nodes/{nodes_id}")
+ response = client.get(f'/nodes/{nodes_id}')
 assert response.status_code == 200


 def test_get_nodes(default_nodes, client): # pylint: disable=unused-argument
 """Test listing existing nodes."""
- response = client.get("/nodes")
+ response = client.get('/nodes')

 assert response.status_code == 200
 assert len(response.json()) == 4

@@ -48,30 +46,28 @@ def test_get_nodes(default_nodes, client): # pylint: disable=unused-argument
 def test_create_dict(client, authenticate): # pylint: disable=unused-argument
 """Test creating a new dict."""
 response = client.post(
- "/nodes",
+ '/nodes',
 json={
- "entry_point": "core.dict",
- "attributes": {"x": 1, "y": 2},
- "label": "test_dict",
+ 'entry_point': 'core.dict',
+ 'attributes': {'x': 1, 'y': 2},
+ 'label': 'test_dict',
 },
 )
 assert response.status_code == 200, response.content


 @pytest.mark.anyio
-async def test_create_code(
- default_computers, async_client, authenticate
-): # pylint: disable=unused-argument
+async def test_create_code(default_computers, async_client, authenticate): # pylint: disable=unused-argument
 """Test creating a new Code."""
 for comp_id in default_computers:
 response = await async_client.post(
- "/nodes",
+ '/nodes',
 json={
- "entry_point": "core.code.installed",
- "dbcomputer_id": comp_id,
- "attributes": {"filepath_executable": "/bin/true"},
- "label": "test_code",
+ 'entry_point': 'core.code.installed',
+ 'dbcomputer_id': comp_id,
+ 'attributes': {'filepath_executable': '/bin/true'},
+ 'label': 'test_code',
 },
 )
 assert response.status_code == 200, response.content
@@ -80,10 +76,10 @@ async def test_create_code(
 def test_create_list(client, authenticate): # pylint: disable=unused-argument
 """Test creating a new list."""
 response = client.post(
- "/nodes",
+ '/nodes',
 json={
- "entry_point": "core.list",
- "attributes": {"list": [2, 3]},
+ 'entry_point': 'core.list',
+ 'attributes': {'list': [2, 
         },
     )
 
 
@@ -93,10 +89,10 @@ def test_create_int(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new Int."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.int",
-            "attributes": {"value": 6},
+            'entry_point': 'core.int',
+            'attributes': {'value': 6},
         },
     )
     assert response.status_code == 200, response.content
 
 
@@ -105,10 +101,10 @@ def test_create_float(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new Float."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.float",
-            "attributes": {"value": 6.6},
+            'entry_point': 'core.float',
+            'attributes': {'value': 6.6},
         },
     )
     assert response.status_code == 200, response.content
 
 
@@ -117,10 +113,10 @@ def test_create_string(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new string."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.str",
-            "attributes": {"value": "test_string"},
+            'entry_point': 'core.str',
+            'attributes': {'value': 'test_string'},
         },
     )
     assert response.status_code == 200, response.content
 
 
@@ -129,10 +125,10 @@ def test_create_bool(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new Bool."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.bool",
-            "attributes": {"value": "True"},
+            'entry_point': 'core.bool',
+            'attributes': {'value': 'True'},
         },
     )
     assert response.status_code == 200, response.content
 
 
@@ -141,18 +137,18 @@ def test_create_structure_data(client, authenticate):  # pylint: disable=unused-
     """Test creating a new StructureData."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.structure",
-            "process_type": None,
-            "description": "",
-            "attributes": {
-                "cell": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
-                "pbc": [True, True, True],
-                "ase": None,
-                "pymatgen": None,
-                "pymatgen_structure": None,
-                "pymatgen_molecule": None,
+            'entry_point': 'core.structure',
+            'process_type': None,
+            'description': '',
+            'attributes': {
+                'cell': [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                'pbc': [True, True, True],
+                'ase': None,
+                'pymatgen': None,
+                'pymatgen_structure': None,
+                'pymatgen_molecule': None,
             },
         },
     )
 
 
@@ -163,27 +159,27 @@ def test_create_orbital_data(client, authenticate):  # pylint: disable=unused-ar
     """Test creating a new OrbitalData."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.orbital",
-            "process_type": None,
-            "description": "",
-            "attributes": {
-                "orbital_dicts": [
+            'entry_point': 'core.orbital',
+            'process_type': None,
+            'description': '',
+            'attributes': {
+                'orbital_dicts': [
                     {
-                        "spin": 0,
-                        "position": [
+                        'spin': 0,
+                        'position': [
                             -1,
                             1,
                             1,
                         ],
-                        "kind_name": "As",
-                        "diffusivity": None,
-                        "radial_nodes": 0,
-                        "_orbital_type": "realhydrogen",
-                        "x_orientation": None,
-                        "z_orientation": None,
-                        "angular_momentum": -3,
+                        'kind_name': 'As',
+                        'diffusivity': None,
+                        'radial_nodes': 0,
+                        '_orbital_type': 'realhydrogen',
+                        'x_orientation': None,
+                        'z_orientation': None,
+                        'angular_momentum': -3,
                     }
                 ]
             },
         },
 
 
@@ -193,122 +189,107 @@ def test_create_orbital_data(client, authenticate):  # pylint: disable=unused-ar
     assert response.status_code == 200, response.content
 
 
-def test_create_single_file_upload(
-    client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_single_file_upload(client, authenticate):  # pylint: disable=unused-argument
     """Testing file upload"""
     test_file = {
-        "upload_file": (
-            "test_file.txt",
-            io.BytesIO(b"Some test strings"),
-            "multipart/form-data",
+        'upload_file': (
+            'test_file.txt',
+            io.BytesIO(b'Some test strings'),
+            'multipart/form-data',
         )
     }
     data = {
-        "params": json.dumps(
+        'params': json.dumps(
             {
-                "entry_point": "core.singlefile",
-                "process_type": None,
-                "description": "Testing single upload file",
-                "attributes": {},
+                'entry_point': 'core.singlefile',
+                'process_type': None,
+                'description': 'Testing single upload file',
+                'attributes': {},
             }
         )
     }
 
-    response = client.post("/nodes/singlefile", files=test_file, data=data)
+    response = client.post('/nodes/singlefile', files=test_file, data=data)
     assert response.status_code == 200, response.json()
 
 
-def test_create_node_wrong_value(
-    client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_node_wrong_value(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new node with wrong value."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.float",
-            "attributes": {"value": "tests"},
+            'entry_point': 'core.float',
+            'attributes': {'value': 'tests'},
        },
    )
     assert response.status_code == 400, response.content
 
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.int",
-            "attributes": {"value": "tests"},
+            'entry_point': 'core.int',
+            'attributes': {'value': 'tests'},
         },
     )
     assert response.status_code == 400, response.content
 
 
-def test_create_node_wrong_attribute(
-    client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_node_wrong_attribute(client, authenticate):  # pylint: disable=unused-argument
     """Test adding node with wrong attributes."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.str",
-            "attributes": {"value1": 5},
+            'entry_point': 'core.str',
+            'attributes': {'value1': 5},
         },
     )
     assert response.status_code == 400, response.content
 
 
-def test_create_unknown_entry_point(
-    default_computers, client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_unknown_entry_point(default_computers, client, authenticate):  # pylint: disable=unused-argument
     """Test error message when specifying unknown ``entry_point``."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.not.existing.entry.point",
-            "label": "test_code",
+            'entry_point': 'core.not.existing.entry.point',
+            'label': 'test_code',
         },
     )
     assert response.status_code == 404, response.content
-    assert (
-        response.json()["detail"]
-        == "Entry point 'core.not.existing.entry.point' not found in group 'aiida.data'"
-    )
+    assert response.json()['detail'] == "Entry point 'core.not.existing.entry.point' not found in group 'aiida.data'"
 
 
-def test_create_additional_attribute(
-    default_computers, client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_additional_attribute(default_computers, client, authenticate):  # pylint: disable=unused-argument
     """Test adding additional properties returns errors."""
     for comp_id in default_computers:
         response = client.post(
-            "/nodes",
+            '/nodes',
             json={
-                "uuid": "3",
-                "entry_point": "core.code.installed",
-                "dbcomputer_id": comp_id,
-                "attributes": {"filepath_executable": "/bin/true"},
-                "label": "test_code",
+                'uuid': '3',
+                'entry_point': 'core.code.installed',
+                'dbcomputer_id': comp_id,
+                'attributes': {'filepath_executable': '/bin/true'},
+                'label': 'test_code',
             },
         )
         assert response.status_code == 422, response.content
 
 
-def test_create_bool_with_extra(
-    client, authenticate
-):  # pylint: disable=unused-argument
+def test_create_bool_with_extra(client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new Bool with extra."""
     response = client.post(
-        "/nodes",
+        '/nodes',
         json={
-            "entry_point": "core.bool",
-            "attributes": {"value": "True"},
-            "extras": {"extra_one": "value_1", "extra_two": "value_2"},
+            'entry_point': 'core.bool',
+            'attributes': {'value': 'True'},
+            'extras': {'extra_one': 'value_1', 'extra_two': 'value_2'},
         },
     )
 
     check_response = client.get(f"/nodes/{response.json()['id']}")
 
     assert check_response.status_code == 200, response.content
-    assert check_response.json()["extras"]["extra_one"] == "value_1"
-    assert check_response.json()["extras"]["extra_two"] == "value_2"
+    assert check_response.json()['extras']['extra_one'] == 'value_1'
+    assert check_response.json()['extras']['extra_two'] == 'value_2'
diff --git a/tests/test_processes.py b/tests/test_processes.py
index 2d28806..f32e219 100644
--- a/tests/test_processes.py
+++ b/tests/test_processes.py
@@ -1,5 +1,5 @@
-# -*- coding: utf-8 -*-
 """Test the /processes endpoint"""
+
 import io
 
 import pytest
@@ -8,7 +8,7 @@
 def test_get_processes(example_processes, client):  # pylint: disable=unused-argument
     """Test listing existing processes."""
-    response = client.get("/processes/")
+    response = client.get('/processes/')
     assert response.status_code == 200
     assert len(response.json()) == 12
 
 
@@ -16,52 +16,48 @@
 def test_get_processes_projectable(client):
     """Test get projectable properties for processes."""
-    response = client.get("/processes/projectable_properties")
+    response = client.get('/processes/projectable_properties')
     assert response.status_code == 200
     assert response.json() == [
-        "id",
-        "uuid",
-        "node_type",
-        "process_type",
-        "label",
-        "description",
-        "ctime",
-        "mtime",
-        "user_id",
-        "dbcomputer_id",
-        "attributes",
-        "extras",
-        "repository_metadata",
+        'id',
+        'uuid',
+        'node_type',
+        'process_type',
+        'label',
+        'description',
+        'ctime',
+        'mtime',
+        'user_id',
+        'dbcomputer_id',
+        'attributes',
+        'extras',
+        'repository_metadata',
     ]
 
 
-def test_get_single_processes(
-    example_processes, client
-):  # pylint: disable=unused-argument
+def test_get_single_processes(example_processes, client):  # pylint: disable=unused-argument
     """Test retrieving a single processes."""
     for proc_id in example_processes:
-        response = client.get(f"/processes/{proc_id}")
+        response = client.get(f'/processes/{proc_id}')
         assert response.status_code == 200
 
 
 @pytest.mark.anyio
-async def test_add_process(
-    default_test_add_process, async_client, authenticate
-):  # pylint: disable=unused-argument
+async def test_add_process(default_test_add_process, async_client, authenticate):  # pylint: disable=unused-argument
     """Test adding new process"""
     code_id, x_id, y_id = default_test_add_process
     response = await async_client.post(
-        "/processes",
+        '/processes',
         json={
-            "label": "test_new_process",
-            "process_entry_point": "aiida.calculations:core.arithmetic.add",
-            "inputs": {
-                "code.uuid": code_id,
-                "x.uuid": x_id,
-                "y.uuid": y_id,
-                "metadata": {
-                    "description": "Test job submission with the add plugin",
+            'label': 'test_new_process',
+            'process_entry_point': 'aiida.calculations:core.arithmetic.add',
+            'inputs': {
+                'code.uuid': code_id,
+                'x.uuid': x_id,
+                'y.uuid': y_id,
+                'metadata': {
+                    'description': 'Test job submission with the add plugin',
                 },
             },
         },
     )
@@ -69,22 +65,20 @@
     assert response.status_code == 200
 
 
-def test_add_process_invalid_entry_point(
-    default_test_add_process, client, authenticate
-):  # pylint: disable=unused-argument
+def test_add_process_invalid_entry_point(default_test_add_process, client, authenticate):  # pylint: disable=unused-argument
     """Test adding new process with invalid entry point"""
     code_id, x_id, y_id = default_test_add_process
     response = client.post(
-        "/processes",
+        '/processes',
         json={
-            "label": "test_new_process",
-            "process_entry_point": "wrong_entry_point",
-            "inputs": {
-                "code.uuid": code_id,
-                "x.uuid": x_id,
-                "y.uuid": y_id,
-                "metadata": {
-                    "description": "Test job submission with the add plugin",
+            'label': 'test_new_process',
+            'process_entry_point': 'wrong_entry_point',
+            'inputs': {
+                'code.uuid': code_id,
+                'x.uuid': x_id,
+                'y.uuid': y_id,
+                'metadata': {
+                    'description': 'Test job submission with the add plugin',
                 },
            },
        },
    )
@@ -92,59 +86,51 @@
     assert response.status_code == 404
 
 
-def test_add_process_invalid_node_id(
-    default_test_add_process, client, authenticate
-):  # pylint: disable=unused-argument
+def test_add_process_invalid_node_id(default_test_add_process, client, authenticate):  # pylint: disable=unused-argument
     """Test adding new process with invalid Node ID"""
     code_id, x_id, _ = default_test_add_process
     response = client.post(
-        "/processes",
+        '/processes',
         json={
-            "label": "test_new_process",
-            "process_entry_point": "aiida.calculations:core.arithmetic.add",
-            "inputs": {
-                "code.uuid": code_id,
-                "x.uuid": x_id,
-                "y.uuid": "891a9efa-f90e-11eb-9a03-0242ac130003",
-                "metadata": {
-                    "description": "Test job submission with the add plugin",
+            'label': 'test_new_process',
+            'process_entry_point': 'aiida.calculations:core.arithmetic.add',
+            'inputs': {
+                'code.uuid': code_id,
+                'x.uuid': x_id,
+                'y.uuid': '891a9efa-f90e-11eb-9a03-0242ac130003',
+                'metadata': {
+                    'description': 'Test job submission with the add plugin',
                 },
             },
         },
     )
     assert response.status_code == 404
-    assert response.json() == {
-        "detail": "Node with UUID `891a9efa-f90e-11eb-9a03-0242ac130003` does not exist."
-    }
+    assert response.json() == {'detail': 'Node with UUID `891a9efa-f90e-11eb-9a03-0242ac130003` does not exist.'}
 
 
 @pytest.mark.anyio
-async def test_add_process_nested_inputs(
-    default_test_add_process, async_client, authenticate
-):  # pylint: disable=unused-argument
+async def test_add_process_nested_inputs(default_test_add_process, async_client, authenticate):  # pylint: disable=unused-argument
     """Test adding new process that has nested inputs"""
     code_id, _, _ = default_test_add_process
     template = Dict(
         {
-            "files_to_copy": [("file", "file.txt")],
+            'files_to_copy': [('file', 'file.txt')],
         }
     ).store()
-    single_file = SinglefileData(io.StringIO("content")).store()
+    single_file = SinglefileData(io.StringIO('content')).store()
 
     response = await async_client.post(
-        "/processes",
+        '/processes',
         json={
-            "label": "test_new_process",
-            "process_entry_point": "aiida.calculations:core.templatereplacer",
-            "inputs": {
-                "code.uuid": code_id,
-                "template.uuid": template.uuid,
-                "files": {"file.uuid": single_file.uuid},
-                "metadata": {
-                    "description": "Test job submission with the add plugin",
-                    "options": {
-                        "resources": {"num_machines": 1, "num_mpiprocs_per_machine": 1}
-                    },
+            'label': 'test_new_process',
+            'process_entry_point': 'aiida.calculations:core.templatereplacer',
+            'inputs': {
+                'code.uuid': code_id,
+                'template.uuid': template.uuid,
+                'files': {'file.uuid': single_file.uuid},
+                'metadata': {
+                    'description': 'Test job submission with the add plugin',
+                    'options': {'resources': {'num_machines': 1, 'num_mpiprocs_per_machine': 1}},
                 },
            },
        },
diff --git a/tests/test_users.py b/tests/test_users.py
index 67bc9b1..d7ccf25 100644
--- a/tests/test_users.py
+++ b/tests/test_users.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """Test the /users endpoint"""
 
 import pytest
@@ -7,7 +6,7 @@
 def test_get_single_user(default_users, client):  # pylint: disable=unused-argument
     """Test retrieving a single user."""
     for user_id in default_users:
-        response = client.get(f"/users/{user_id}")
+        response = client.get(f'/users/{user_id}')
         assert response.status_code == 200
 
 
@@ -17,29 +16,25 @@ def test_get_users(default_users, client):  # pylint: disable=unused-argument
 
     Note: Besides the default users set up by the pytest fixture the test
     profile includes a default user.
     """
-    response = client.get("/users")
+    response = client.get('/users')
     assert response.status_code == 200
     assert len(response.json()) == 2 + 1
 
 
 @pytest.mark.anyio
-async def test_create_user(
-    async_client, authenticate
-):  # pylint: disable=unused-argument
+async def test_create_user(async_client, authenticate):  # pylint: disable=unused-argument
     """Test creating a new user."""
-    response = await async_client.post(
-        "/users", json={"first_name": "New", "email": "aiida@localhost"}
-    )
+    response = await async_client.post('/users', json={'first_name': 'New', 'email': 'aiida@localhost'})
     assert response.status_code == 200, response.content
 
-    response = await async_client.get("/users")
-    first_names = [user["first_name"] for user in response.json()]
-    assert "New" in first_names
+    response = await async_client.get('/users')
+    first_names = [user['first_name'] for user in response.json()]
+    assert 'New' in first_names
 
 
 def test_get_users_projectable(client):
     """Test get projectable properites for users."""
-    response = client.get("/users/projectable_properties")
+    response = client.get('/users/projectable_properties')
     assert response.status_code == 200
-    assert response.json() == ["id", "email", "first_name", "last_name", "institution"]
+    assert response.json() == ['id', 'email', 'first_name', 'last_name', 'institution']