Skip to content

Commit

Permalink
feat(restapi): Moved validate service and helpers to shared folder
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewhand committed Nov 22, 2024
1 parent 1622ded commit 5128579
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 74 deletions.
71 changes: 1 addition & 70 deletions src/dioptra/restapi/v1/entrypoints/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
EntityExistsError,
QueryParameterNotUniqueError,
SortParameterValidationError,
EntrpointWorkflowYamlValidationError,
)
from dioptra.restapi.utils import find_non_unique
from dioptra.restapi.v1 import utils
Expand All @@ -42,8 +41,7 @@
from dioptra.restapi.v1.queues.service import RESOURCE_TYPE as QUEUE_RESOURCE_TYPE
from dioptra.restapi.v1.queues.service import QueueIdsService
from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters
from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import build_task_engine_dict_for_validation
from dioptra.task_engine.validation import validate
from dioptra.restapi.v1.shared.entrypoint_validate_service import EntrypointValidateService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
PLUGIN_RESOURCE_TYPE: Final[str] = "entry_point_plugin"
Expand Down Expand Up @@ -90,7 +88,6 @@ def __init__(
self._queue_ids_service = queue_ids_service
self._group_id_service = group_id_service

# add here
def create(
self,
name: str,
Expand Down Expand Up @@ -406,7 +403,6 @@ def get(
entry_point=entrypoint, queues=queues, has_draft=has_draft
)

# add here
def modify(
self,
entrypoint_id: int,
Expand Down Expand Up @@ -509,13 +505,6 @@ def modify(
queue_resources = [queue.resource for queue in queues]
new_entrypoint.children = plugin_resources + queue_resources

# plugin_ids = [plugin.plugin.resource_id for plugin in new_entrypoint.entry_point_plugin_files]
# self._entrypoint_validate_service.validate(
# task_graph=task_graph,
# plugin_ids=plugin_ids,
# entrypoint_parameters=parameters,
# )

db.session.add(new_entrypoint)

if commit:
Expand Down Expand Up @@ -614,7 +603,6 @@ def get(

return _get_entrypoint_plugin_snapshots(entrypoint["entry_point"])

# add here
def append(
self,
entrypoint_id: int,
Expand Down Expand Up @@ -798,7 +786,6 @@ def get(

return plugin

# add here
def delete(
self,
entrypoint_id: int,
Expand Down Expand Up @@ -1261,62 +1248,6 @@ def get(
return entrypoint


class EntrypointValidateService(object):
"""The service for handling requests with entrypoint workflow yamls."""

@inject
def __init__(
self,
plugin_id_service: PluginIdsService,
) -> None:
"""Initialize the entrypoint service.
All arguments are provided via dependency injection.
Args:
plugin_ids_service: A PluginIdsService object.
"""
self._plugin_id_service = plugin_id_service

def validate(
self,
task_graph: str,
plugin_ids: list[int],
entrypoint_parameters: list[dict[str, Any]],
**kwargs,
) -> dict[str, Any]:
"""Validate a entrypoint workflow before the entrypoint is created.
Args:
task_graph: The proposed task graph of a new entrypoint resource.
plugin_ids: A list of plugin files for the new entrypoint.
parameters: A list of parameters for the new entrypoint.
Returns:
A success response and a indicator that states the entrypoint worklflow yaml is valid.
Raises:
EntrpointWorkflowYamlValidationError: If the entrypoint worklflow yaml is not valid.
"""
log: BoundLogger = kwargs.get("log", LOGGER.new())
log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters)

parameters = {param['name']: param['default_value'] for param in entrypoint_parameters}
plugins = self._plugin_id_service.get(plugin_ids)
task_engine_dict = build_task_engine_dict_for_validation(
plugins=plugins,
parameters=parameters,
task_graph=task_graph
)

issues = validate(task_engine_dict)

if not issues:
return {"status": "Success", "valid": True}
else:
raise EntrpointWorkflowYamlValidationError(issues)


def _get_entrypoint_plugin_snapshots(
entrypoint: models.EntryPoint,
) -> list[utils.PluginWithFilesDict]:
Expand Down
117 changes: 117 additions & 0 deletions src/dioptra/restapi/v1/shared/build_task_engine_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from pathlib import Path
from typing import Any, cast

import structlog
import yaml
from structlog.stdlib import BoundLogger

from dioptra.restapi.db import models
from dioptra.task_engine.type_registry import BUILTIN_TYPES

# from .type_coercions import (
# BOOLEAN_PARAM_TYPE,
# FLOAT_PARAM_TYPE,
# INTEGER_PARAM_TYPE,
# STRING_PARAM_TYPE,
# coerce_to_type,
# )

LOGGER: BoundLogger = structlog.stdlib.get_logger()

# EXPLICIT_GLOBAL_TYPES: Final[set[str]] = {
# STRING_PARAM_TYPE,
# BOOLEAN_PARAM_TYPE,
# INTEGER_PARAM_TYPE,
# FLOAT_PARAM_TYPE,
# }
# YAML_FILE_ENCODING: Final[str] = "utf-8"
# YAML_EXPORT_SETTINGS: Final[dict[str, Any]] = {
# "indent": 2,
# "sort_keys": False,
# "encoding": YAML_FILE_ENCODING,
# }


def build_task_engine_dict(
plugins: list[Any],
parameters: dict[str, Any],
task_graph: str,
) -> dict[str, Any]:
"""Build a dictionary representation of a task engine YAML file.
Args:
plugins: The entrypoint's plugin files.
parameters: The entrypoint parameteres.
task_graph: The task graph of the entrypoint.
Returns:
The task engine dictionary.
"""
tasks: dict[str, Any] = {}
parameter_types: dict[str, Any] = {}
for plugin in plugins:
for plugin_file in plugin['plugin_files']:
for task in plugin_file.tasks:
input_parameters = task.input_parameters
output_parameters = task.output_parameters
tasks[task.plugin_task_name] = {
"plugin": _build_plugin_field(plugin['plugin'], plugin_file, task),
}
if input_parameters:
tasks[task.plugin_task_name]["inputs"] = _build_task_inputs(
input_parameters
)
if output_parameters:
tasks[task.plugin_task_name]["outputs"] = _build_task_outputs(
output_parameters
)
for param in input_parameters + output_parameters:
name = param.parameter_type.name
if name not in BUILTIN_TYPES:
parameter_types[name] = param.parameter_type.structure

task_engine_dict = {
"types": parameter_types,
"parameters": parameters,
"tasks": tasks,
"graph": cast(dict[str, Any], yaml.safe_load(task_graph)),
}
return task_engine_dict


def _build_plugin_field(
plugin: models.Plugin, plugin_file: models.PluginFile, task: models.PluginTask
) -> str:
if plugin_file.filename == "__init__.py":
# Omit filename from plugin import path if it is an __init__.py file.
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts[:-1]]

else:
module_parts = [Path(x).stem for x in Path(plugin_file.filename).parts]

return ".".join([plugin.name, *module_parts, task.plugin_task_name])


def _build_task_inputs(
input_parameters: list[models.PluginTaskInputParameter],
) -> list[dict[str, Any]]:
return [
{
"name": input_param.name,
"type": input_param.parameter_type.name,
"required": input_param.required,
}
for input_param in input_parameters
]


def _build_task_outputs(
output_parameters: list[models.PluginTaskOutputParameter],
) -> list[dict[str, Any]] | dict[str, Any]:
if len(output_parameters) == 1:
return {output_parameters[0].name: output_parameters[0].parameter_type.name}

return [
{output_param.name: output_param.parameter_type.name}
for output_param in output_parameters
]
85 changes: 85 additions & 0 deletions src/dioptra/restapi/v1/shared/entrypoint_validate_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""The shared server-side functions that perform entrypoint validatation operations."""
from typing import Any, Final

import structlog
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.v1.plugins.service import PluginIdsService
from dioptra.restapi.v1.shared.build_task_engine_dict import build_task_engine_dict
from dioptra.task_engine.validation import validate
from dioptra.restapi.errors import EntrpointWorkflowYamlValidationError

LOGGER: BoundLogger = structlog.stdlib.get_logger()


class EntrypointValidateService(object):
"""The service for handling requests with entrypoint workflow yamls."""

@inject
def __init__(
self,
plugin_id_service: PluginIdsService,
) -> None:
"""Initialize the entrypoint service.
All arguments are provided via dependency injection.
Args:
plugin_ids_service: A PluginIdsService object.
"""
self._plugin_id_service = plugin_id_service

def validate(
self,
task_graph: str,
plugin_ids: list[int],
entrypoint_parameters: list[dict[str, Any]],
**kwargs,
) -> dict[str, Any]:
"""Validate a entrypoint workflow before the entrypoint is created.
Args:
task_graph: The proposed task graph of a new entrypoint resource.
plugin_ids: A list of plugin files for the new entrypoint.
parameters: A list of parameters for the new entrypoint.
Returns:
A success response and a indicator that states the entrypoint worklflow yaml is valid.
Raises:
EntrpointWorkflowYamlValidationError: If the entrypoint worklflow yaml is not valid.
"""
log: BoundLogger = kwargs.get("log", LOGGER.new())
log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters)

parameters = {param['name']: param['default_value'] for param in entrypoint_parameters}
plugins = self._plugin_id_service.get(plugin_ids)
task_engine_dict = build_task_engine_dict(
plugins=plugins,
parameters=parameters,
task_graph=task_graph
)

issues = validate(task_engine_dict)

if not issues:
return {"status": "Success", "valid": True}
else:
raise EntrpointWorkflowYamlValidationError(issues)
5 changes: 3 additions & 2 deletions src/dioptra/restapi/v1/workflows/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from .schema import FileTypes, JobFilesDownloadQueryParametersSchema, EntrypointWorkflowSchema
from .service import JobFilesDownloadService
from dioptra.restapi.v1.entrypoints.service import EntrypointValidateService

from dioptra.restapi.v1.shared.entrypoint_validate_service import EntrypointValidateService

LOGGER: BoundLogger = structlog.stdlib.get_logger()

Expand Down Expand Up @@ -83,7 +84,7 @@ def get(self):

@api.route("/entrypointValidate")
class EntrypointValidateEndpoint(Resource):
"""Wrapper endpoint to expose entrypoint validation service."""
"""Wrapper endpoint to expose shared entrypoint validation service."""

@inject
def __init__(
Expand Down
1 change: 0 additions & 1 deletion src/dioptra/restapi/v1/workflows/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import structlog
from structlog.stdlib import BoundLogger
from injector import inject
import yaml

from .lib import views
from .lib.package_job_files import package_job_files
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/restapi/v1/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from dioptra.restapi.routes import V1_WORKFLOWS_ROUTE, V1_ROOT

from ..lib import actions, asserts, helpers
from ..lib import actions


# -- Actions ---------------------------------------------------------------------------
Expand Down

0 comments on commit 5128579

Please sign in to comment.