-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(restapi): Moved validate service and helpers to shared folder
- Loading branch information
1 parent
1622ded
commit 5128579
Showing
6 changed files
with
207 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
117 changes: 117 additions & 0 deletions
117
src/dioptra/restapi/v1/shared/build_task_engine_dict.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from pathlib import Path | ||
from typing import Any, cast | ||
|
||
import structlog | ||
import yaml | ||
from structlog.stdlib import BoundLogger | ||
|
||
from dioptra.restapi.db import models | ||
from dioptra.task_engine.type_registry import BUILTIN_TYPES | ||
|
||
# from .type_coercions import ( | ||
# BOOLEAN_PARAM_TYPE, | ||
# FLOAT_PARAM_TYPE, | ||
# INTEGER_PARAM_TYPE, | ||
# STRING_PARAM_TYPE, | ||
# coerce_to_type, | ||
# ) | ||
|
||
LOGGER: BoundLogger = structlog.stdlib.get_logger() | ||
|
||
# EXPLICIT_GLOBAL_TYPES: Final[set[str]] = { | ||
# STRING_PARAM_TYPE, | ||
# BOOLEAN_PARAM_TYPE, | ||
# INTEGER_PARAM_TYPE, | ||
# FLOAT_PARAM_TYPE, | ||
# } | ||
# YAML_FILE_ENCODING: Final[str] = "utf-8" | ||
# YAML_EXPORT_SETTINGS: Final[dict[str, Any]] = { | ||
# "indent": 2, | ||
# "sort_keys": False, | ||
# "encoding": YAML_FILE_ENCODING, | ||
# } | ||
|
||
|
||
def build_task_engine_dict( | ||
plugins: list[Any], | ||
parameters: dict[str, Any], | ||
task_graph: str, | ||
) -> dict[str, Any]: | ||
"""Build a dictionary representation of a task engine YAML file. | ||
Args: | ||
plugins: The entrypoint's plugin files. | ||
parameters: The entrypoint parameteres. | ||
task_graph: The task graph of the entrypoint. | ||
Returns: | ||
The task engine dictionary. | ||
""" | ||
tasks: dict[str, Any] = {} | ||
parameter_types: dict[str, Any] = {} | ||
for plugin in plugins: | ||
for plugin_file in plugin['plugin_files']: | ||
for task in plugin_file.tasks: | ||
input_parameters = task.input_parameters | ||
output_parameters = task.output_parameters | ||
tasks[task.plugin_task_name] = { | ||
"plugin": _build_plugin_field(plugin['plugin'], plugin_file, task), | ||
} | ||
if input_parameters: | ||
tasks[task.plugin_task_name]["inputs"] = _build_task_inputs( | ||
input_parameters | ||
) | ||
if output_parameters: | ||
tasks[task.plugin_task_name]["outputs"] = _build_task_outputs( | ||
output_parameters | ||
) | ||
for param in input_parameters + output_parameters: | ||
name = param.parameter_type.name | ||
if name not in BUILTIN_TYPES: | ||
parameter_types[name] = param.parameter_type.structure | ||
|
||
task_engine_dict = { | ||
"types": parameter_types, | ||
"parameters": parameters, | ||
"tasks": tasks, | ||
"graph": cast(dict[str, Any], yaml.safe_load(task_graph)), | ||
} | ||
return task_engine_dict | ||
|
||
|
||
def _build_plugin_field( | ||
plugin: models.Plugin, plugin_file: models.PluginFile, task: models.PluginTask | ||
) -> str: | ||
if plugin_file.filename == "__init__.py": | ||
# Omit filename from plugin import path if it is an __init__.py file. | ||
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts[:-1]] | ||
|
||
else: | ||
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts] | ||
|
||
return ".".join([plugin.name, *module_parts, task.plugin_task_name]) | ||
|
||
|
||
def _build_task_inputs( | ||
input_parameters: list[models.PluginTaskInputParameter], | ||
) -> list[dict[str, Any]]: | ||
return [ | ||
{ | ||
"name": input_param.name, | ||
"type": input_param.parameter_type.name, | ||
"required": input_param.required, | ||
} | ||
for input_param in input_parameters | ||
] | ||
|
||
|
||
def _build_task_outputs( | ||
output_parameters: list[models.PluginTaskOutputParameter], | ||
) -> list[dict[str, Any]] | dict[str, Any]: | ||
if len(output_parameters) == 1: | ||
return {output_parameters[0].name: output_parameters[0].parameter_type.name} | ||
|
||
return [ | ||
{output_param.name: output_param.parameter_type.name} | ||
for output_param in output_parameters | ||
] |
85 changes: 85 additions & 0 deletions
85
src/dioptra/restapi/v1/shared/entrypoint_validate_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""The shared server-side functions that perform entrypoint validatation operations.""" | ||
from typing import Any, Final | ||
|
||
import structlog | ||
from injector import inject | ||
from structlog.stdlib import BoundLogger | ||
|
||
from dioptra.restapi.v1.plugins.service import PluginIdsService | ||
from dioptra.restapi.v1.shared.build_task_engine_dict import build_task_engine_dict | ||
from dioptra.task_engine.validation import validate | ||
from dioptra.restapi.errors import EntrpointWorkflowYamlValidationError | ||
|
||
LOGGER: BoundLogger = structlog.stdlib.get_logger() | ||
|
||
|
||
class EntrypointValidateService(object): | ||
"""The service for handling requests with entrypoint workflow yamls.""" | ||
|
||
@inject | ||
def __init__( | ||
self, | ||
plugin_id_service: PluginIdsService, | ||
) -> None: | ||
"""Initialize the entrypoint service. | ||
All arguments are provided via dependency injection. | ||
Args: | ||
plugin_ids_service: A PluginIdsService object. | ||
""" | ||
self._plugin_id_service = plugin_id_service | ||
|
||
def validate( | ||
self, | ||
task_graph: str, | ||
plugin_ids: list[int], | ||
entrypoint_parameters: list[dict[str, Any]], | ||
**kwargs, | ||
) -> dict[str, Any]: | ||
"""Validate a entrypoint workflow before the entrypoint is created. | ||
Args: | ||
task_graph: The proposed task graph of a new entrypoint resource. | ||
plugin_ids: A list of plugin files for the new entrypoint. | ||
parameters: A list of parameters for the new entrypoint. | ||
Returns: | ||
A success response and a indicator that states the entrypoint worklflow yaml is valid. | ||
Raises: | ||
EntrpointWorkflowYamlValidationError: If the entrypoint worklflow yaml is not valid. | ||
""" | ||
log: BoundLogger = kwargs.get("log", LOGGER.new()) | ||
log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters) | ||
|
||
parameters = {param['name']: param['default_value'] for param in entrypoint_parameters} | ||
plugins = self._plugin_id_service.get(plugin_ids) | ||
task_engine_dict = build_task_engine_dict( | ||
plugins=plugins, | ||
parameters=parameters, | ||
task_graph=task_graph | ||
) | ||
|
||
issues = validate(task_engine_dict) | ||
|
||
if not issues: | ||
return {"status": "Success", "valid": True} | ||
else: | ||
raise EntrpointWorkflowYamlValidationError(issues) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters