diff --git a/CHANGELOG.md b/CHANGELOG.md index bd560a3d4..fdefc832b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.18.0 + +### Added + +- Displays : Tree is now a display +- Files : Method from_file now sets Binary and StringFile filename attribute +- Schemas : Order entry based on signature order + +### Changed + +- MultiObject : Now compute object names for sample names +- Workflow : Remove workflow display from WorkflowRun +- Displays : CAD/ volmdlr_primitives backward compatibility has been + + ## 0.17.0 ### Added diff --git a/code_pylint.py b/code_pylint.py index 20a456c9e..fd2c2c755 100644 --- a/code_pylint.py +++ b/code_pylint.py @@ -18,7 +18,7 @@ from pylint import __version__ from pylint.lint import Run -MIN_NOTE = 9.5 +MIN_NOTE = 9.7 EFFECTIVE_DATE = date(2023, 1, 18) WEEKLY_DECREASE = 0.03 @@ -31,7 +31,6 @@ "too-many-locals": 5, # Reduce by dropping vectored objects "too-many-branches": 8, # Huge refactor needed. Will be reduced by schema refactor "unused-argument": 3, # Some abstract functions have unused arguments (plot_data). Hence cannot decrease - "cyclic-import": 2, # Still work to do on Specific based DessiaObject "too-many-arguments": 18, # Huge refactor needed "too-few-public-methods": 4, # Abstract classes (Errors, Checks,...) "too-many-return-statements": 7, # Huge refactor needed. Will be reduced by schema refactor diff --git a/dessia_common/breakdown.py b/dessia_common/breakdown.py index ba9e0d40f..3ad466797 100644 --- a/dessia_common/breakdown.py +++ b/dessia_common/breakdown.py @@ -8,9 +8,7 @@ import collections.abc import numpy as npy -from dessia_common import REF_MARKER, OLD_REF_MARKER import dessia_common.serialization as dcs -import dessia_common.utils.types as dct def attrmethod_getter(object_, attr_methods): @@ -33,96 +31,6 @@ def attrmethod_getter(object_, attr_methods): return object_ -class ExtractionError(Exception): - """ Custom Exception for deep attributes Extraction process. """ - - -def extract_segment_from_object(object_, segment: str): - """ Try all ways to get an attribute (segment) from an object that can of numerous types. """ - if dct.is_sequence(object_): - try: - return object_[int(segment)] - except ValueError as err: - message_error = (f"Cannot extract segment {segment} from object {{str(object_)[:500]}}:" - f" segment is not a sequence index") - raise ExtractionError(message_error) from err - - if isinstance(object_, dict): - if segment in object_: - return object_[segment] - - if segment.isdigit(): - intifyed_segment = int(segment) - if intifyed_segment in object_: - return object_[intifyed_segment] - if segment in object_: - return object_[segment] - raise ExtractionError(f'Cannot extract segment {segment} from object {str(object_)[:200]}') - - # should be a tuple - if segment.startswith('(') and segment.endswith(')') and ',' in segment: - key = [] - for subsegment in segment.strip('()').replace(' ', '').split(','): - if subsegment.isdigit(): - subkey = int(subsegment) - else: - subkey = subsegment - key.append(subkey) - return object_[tuple(key)] - raise ExtractionError(f"Cannot extract segment {segment} from object {str(object_)[:500]}") - - # Finally, it is a regular object - return getattr(object_, segment) - - -def get_in_object_from_path(object_, path, evaluate_pointers=True): - """ Get deep attributes from an object. Argument 'path' represents path to deep attribute. """ - segments = path.lstrip('#/').split('/') - element = object_ - for segment in segments: - if isinstance(element, dict): - # Going down in the object and it is a reference : evaluating sub-reference - if evaluate_pointers: - if REF_MARKER in element: - try: - element = get_in_object_from_path(object_, element[REF_MARKER]) - except RecursionError as err: - err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' - raise RecursionError(err_msg) from err - elif OLD_REF_MARKER in element: # Retro-compatibility to be remove sometime - try: - element = get_in_object_from_path(object_, element[OLD_REF_MARKER]) - except RecursionError as err: - err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' - raise RecursionError(err_msg) from err - - try: - element = extract_segment_from_object(element, segment) - except ExtractionError as err: - - err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' - raise ExtractionError(err_msg) from err - - return element - - -def set_in_object_from_path(object_, path, value, evaluate_pointers=True): - """ Set deep attribute from an object to the given value. Argument 'path' represents path to deep attribute. """ - reduced_path = '/'.join(path.lstrip('#/').split('/')[:-1]) - last_segment = path.split('/')[-1] - if reduced_path: - last_object = get_in_object_from_path(object_, reduced_path, evaluate_pointers=evaluate_pointers) - else: - last_object = object_ - - if dct.is_sequence(last_object): - last_object[int(last_segment)] = value - elif isinstance(last_object, dict): - last_object[last_segment] = value - else: - setattr(last_object, last_segment, value) - - def merge_breakdown_dicts(dict1, dict2): """ Merge strategy of breakdown dictionaries. """ dict3 = dict1.copy() diff --git a/dessia_common/core.py b/dessia_common/core.py index 97bd6ab6d..9622cc751 100644 --- a/dessia_common/core.py +++ b/dessia_common/core.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Module to handle serialization for engineering objects. """ - +import base64 +import io import time import warnings import operator @@ -18,9 +19,12 @@ import traceback as tb from importlib import import_module +import numpy as npy -from dessia_common.utils.diff import data_eq, diff, choose_hash -from dessia_common.utils.types import is_bson_valid +from dessia_common import FLOAT_TOLERANCE +from dessia_common.utils.diff import diff, choose_hash +from dessia_common.utils.helpers import full_classname, is_sequence +from dessia_common.utils.types import is_bson_valid, isinstance_base_types from dessia_common.utils.copy import deepcopy_value import dessia_common.schemas.core as dcs from dessia_common.serialization import SerializableObject, deserialize_argument, serialize @@ -29,7 +33,7 @@ from dessia_common import templates import dessia_common.checks as dcc from dessia_common.displays import DisplayObject, DisplaySetting -from dessia_common.breakdown import attrmethod_getter, get_in_object_from_path +from dessia_common.breakdown import attrmethod_getter import dessia_common.utils.helpers as dch import dessia_common.files as dcf from dessia_common.document_generator import DocxWriter @@ -150,7 +154,7 @@ def _data_diff(self, other_object): def _get_from_path(self, path: str): """ Get object's deep attribute from given path. """ - return get_in_object_from_path(self, path) + return dch.get_in_object_from_path(self, path) @classmethod def raw_schema(cls): @@ -396,34 +400,43 @@ def plot(self, reference_path: str = "#", **kwargs): for data in self.plot_data(reference_path, **kwargs): plot_data.plot_canvas(plot_data_object=data, canvas_id='canvas', - width=1400, height=900, - debug_mode=False) + width=1400, height=900) else: msg = f"Class '{self.__class__.__name__}' does not implement a plot_data method to define what to plot" raise NotImplementedError(msg) - def mpl_plot(self, **kwargs): + def mpl_plot(self, selector: str): """ Plot with matplotlib using plot_data function. """ - axs = [] - if hasattr(self, 'plot_data'): - try: - plot_datas = self.plot_data(**kwargs) - except TypeError as error: - raise TypeError(f'{self.__class__.__name__}.{error}') from error - for data in plot_datas: - if hasattr(data, 'mpl_plot'): - ax = data.mpl_plot() - axs.append(ax) - else: - msg = f"Class '{self.__class__.__name__}' does not implement a plot_data method to define what to plot" - raise NotImplementedError(msg) - return axs + display_setting = self._display_settings_from_selector(selector) + if display_setting.type != "plot_data": + raise NotImplementedError(f"Selector '{selector}' depicts a display of type '{display_setting.type}'" + f" which cannot be used to plot with matplotlib." + f"\nPlease select a 'plot_data' display setting.") + display = attrmethod_getter(self, display_setting.method)(**display_setting.arguments) + if hasattr(display, 'mpl_plot'): + return display.mpl_plot() + raise NotImplementedError(f"plot_data display of type '{display.__class__.__name__}'" + f" does not implement a mpl_plot converter." + f"\nSelector used : '{selector}'.") + + def picture(self, stream, selector: str): + """ Take a stream to generate picture. """ + ax = self.mpl_plot(selector) + ax.set_axis_off() + ax.figure.savefig(stream, format="png") + stream.seek(0) @classmethod def display_settings(cls, **kwargs) -> List[DisplaySetting]: """ Return a list of objects describing how to call object displays. """ - settings = [DisplaySetting(selector="markdown", type_="markdown", method="to_markdown", load_by_default=True)] - settings.extend(cls._display_settings_from_decorators()) + decorators_settings = cls._display_settings_from_decorators() + has_markdown = any([s.type == "markdown" for s in decorators_settings]) + settings = decorators_settings + if not has_markdown: + default_md = DisplaySetting(selector="Markdown", type_="markdown", method="to_markdown", + load_by_default=True) + settings.insert(0, default_md) + settings.append(DisplaySetting(selector="Structure Tree", type_="tree", method="")) return settings @classmethod @@ -452,13 +465,12 @@ def _display_from_selector(self, selector: str) -> DisplayObject: display_setting = self._display_settings_from_selector(selector) track = "" try: - data = attrmethod_getter(self, display_setting.method)(**display_setting.arguments) + display = attrmethod_getter(self, display_setting.method)(**display_setting.arguments) except Exception: - data = None + display = None track = tb.format_exc() - if display_setting.serialize_data: - data = serialize(data) + data = serialize(display) if display_setting.serialize_data else display reference_path = display_setting.reference_path # Trying this return DisplayObject(type_=display_setting.type, data=data, reference_path=reference_path, traceback=track) @@ -652,7 +664,7 @@ def to_vector(self): """ Compute vector from object. """ vectored_objects = [] for feature in self.vector_features(): - vectored_objects.append(get_in_object_from_path(self, feature.lower())) + vectored_objects.append(dch.get_in_object_from_path(self, feature.lower())) return vectored_objects @classmethod @@ -666,14 +678,6 @@ def vector_features(cls): class PhysicalObject(DessiaObject): """ Represent an object with CAD capabilities. """ - @classmethod - def display_settings(cls, **kwargs): - """ Returns a list of DisplaySettings objects describing how to call sub-displays. """ - display_settings = super().display_settings() - display_settings.append(DisplaySetting(selector='cad', type_='babylon_data', - method='volmdlr_volume_model().babylon_data', serialize_data=True)) - return display_settings - def volmdlr_primitives(self, **kwargs): """ Return a list of volmdlr primitives to build up volume model. """ warnings.warn("This method is deprecated and will be removed in a future version. " @@ -907,7 +911,7 @@ def _comparison_operator(self): return self._REAL_OPERATORS[self.comparison_operator] def _to_lambda(self): - return lambda x: (self._comparison_operator()(get_in_object_from_path(value, f'#/{self.attribute}'), + return lambda x: (self._comparison_operator()(dch.get_in_object_from_path(value, f'#/{self.attribute}'), self.bound) for value in x) def get_booleans_index(self, values: List[DessiaObject]): @@ -1199,3 +1203,71 @@ def get_attribute_names(object_class): for item in [float, int, bool, complex])] attributes += [a for a in subclass_numeric_attributes if a not in dcs.RESERVED_ARGNAMES] return attributes + + +def data_eq(value1, value2): + """ Returns if two values are equal on data equality. """ + if is_sequence(value1) and is_sequence(value2): + return sequence_data_eq(value1, value2) + + if isinstance(value1, npy.int64) or isinstance(value2, npy.int64): + return value1 == value2 + + if isinstance(value1, npy.float64) or isinstance(value2, npy.float64): + return math.isclose(value1, value2, abs_tol=FLOAT_TOLERANCE) + + if not isinstance(value2, type(value1)) and not isinstance(value1, type(value2)): + return False + + if isinstance_base_types(value1): + if isinstance(value1, float): + return math.isclose(value1, value2, abs_tol=FLOAT_TOLERANCE) + return value1 == value2 + + if isinstance(value1, dict): + return dict_data_eq(value1, value2) + + if isinstance(value1, (dcf.BinaryFile, dcf.StringFile)): + return value1 == value2 + + if isinstance(value1, type): + return full_classname(value1) == full_classname(value2) + + # Else: its an object + if full_classname(value1) != full_classname(value2): + return False + + # Test if _data_eq is customized + if hasattr(value1, '_data_eq'): + custom_method = value1._data_eq.__code__ is not DessiaObject._data_eq.__code__ + if custom_method: + return value1._data_eq(value2) + + # Not custom, use generic implementation + eq_dict = value1._data_eq_dict() + if 'name' in eq_dict: + del eq_dict['name'] + + other_eq_dict = value2._data_eq_dict() + return dict_data_eq(eq_dict, other_eq_dict) + + +def dict_data_eq(dict1, dict2): + """ Returns True if two dictionaries are equal on data equality, False otherwise. """ + for key, value in dict1.items(): + if key not in dict2: + return False + if not data_eq(value, dict2[key]): + return False + return True + + +def sequence_data_eq(seq1, seq2): + """ Returns if two sequences are equal on data equality. """ + if len(seq1) != len(seq2): + return False + + for v1, v2 in zip(seq1, seq2): + if not data_eq(v1, v2): + return False + return True diff --git a/dessia_common/decorators.py b/dessia_common/decorators.py index c4a4901f0..28dfc5d2d 100644 --- a/dessia_common/decorators.py +++ b/dessia_common/decorators.py @@ -40,6 +40,23 @@ def get_decorated_methods(class_: Type, decorator_name: str): return [getattr(class_, n) for n in method_names] +def picture_view(selector: str = None, load_by_default: bool = False): + """ + Decorator to plot data pictures. + + :param str selector: A custom and unique name that identifies the display. + It is what is displayed on platform to select your view. + + :param bool load_by_default: Whether the view should be displayed on platform by default or not. + """ + def decorator(function): + """ Decorator to plot data.""" + set_decorated_function_metadata(function=function, type_="picture", selector=selector, + serialize_data=True, load_by_default=load_by_default) + return function + return decorator + + def plot_data_view(selector: str = None, load_by_default: bool = False): """ Decorator to plot data. diff --git a/dessia_common/errors.py b/dessia_common/errors.py index a913e4c1c..29d0d41aa 100644 --- a/dessia_common/errors.py +++ b/dessia_common/errors.py @@ -41,3 +41,7 @@ class CopyError(Exception): class UntypedArgumentError(Exception): """ Error of code annotation. """ + + +class ExtractionError(Exception): + """ Custom Exception for deep attributes Extraction process. """ diff --git a/dessia_common/files.py b/dessia_common/files.py index 69297c279..f20d390e2 100644 --- a/dessia_common/files.py +++ b/dessia_common/files.py @@ -31,7 +31,7 @@ def stream_template(cls): def from_file(cls, filepath: str): """ Get a file from a binary file. """ with open(filepath, 'rb') as file: - stream = cls() + stream = cls(filepath) stream.write(file.read()) stream.seek(0) return stream @@ -91,7 +91,7 @@ def from_stream(cls, stream_file): def from_file(cls, filepath: str): """ Get a file from a file. """ with open(filepath, 'r', encoding='utf-8') as file: - stream = cls() + stream = cls(filepath) stream.write(file.read()) stream.seek(0) return stream diff --git a/dessia_common/forms.py b/dessia_common/forms.py index a1a8c798a..1e359450b 100644 --- a/dessia_common/forms.py +++ b/dessia_common/forms.py @@ -26,6 +26,7 @@ from typing import Dict, List, Tuple, Union, Any, Literal, get_args import time from numpy import linspace +from random import randrange try: import volmdlr as vm @@ -43,7 +44,7 @@ from dessia_common.files import BinaryFile, StringFile -from dessia_common.decorators import plot_data_view, markdown_view, cad_view +from dessia_common.decorators import plot_data_view, markdown_view, cad_view, picture_view class EmbeddedBuiltinsSubobject(PhysicalObject): @@ -804,6 +805,14 @@ def plot2d(self, reference_path: str = "#"): fill_style = plot_data.SurfaceStyle(color_fill=plot_data.colors.WHITE) return contour.plot_data(edge_style=edge_style, surface_style=fill_style) + @cad_view("CAD View") + def plot3d(self, reference_path: str = "#"): + """ A dummy 3D method to test form interactions. """ + contour = self.contour(reference_path) + frame = vm.Frame3D(origin=vm.Point3D(0, 0, 0), u=vm.X3D, v=vm.Y3D, w=vm.Z3D) + primitive = p3d.ExtrudedProfile(frame, outer_contour2d=contour, inner_contours2d=[], extrusion_length=1) + return vm.core.VolumeModel([primitive]).babylon_data() + class VerticalBeam(Beam): """ A dummy class to test 2D/3D form interactions. """ @@ -820,13 +829,21 @@ def contour(self, origin: float, reference_path: str = "#"): return p2d.ClosedRoundedLineSegments2D(points=points, radius={}, reference_path=reference_path) @plot_data_view("2D View") - def plot2d(self, origin: float, reference_path: str = "#"): + def plot2d(self, origin: float = 0, reference_path: str = "#"): """ A dummy 2D method to test form interactions. """ contour = self.contour(origin=origin, reference_path=reference_path) edge_style = plot_data.EdgeStyle(color_stroke=plot_data.colors.BLUE) fill_style = plot_data.SurfaceStyle(color_fill=plot_data.colors.WHITE) return contour.plot_data(edge_style=edge_style, surface_style=fill_style) + @cad_view("CAD View") + def plot3d(self, origin: float = 0, reference_path: str = "#"): + """ A dummy 3D method to test form interactions. """ + contour = self.contour(origin=origin, reference_path=reference_path) + frame = vm.Frame3D(origin=vm.Point3D(0, 0, 0), u=vm.X3D, v=vm.Y3D, w=vm.Z3D) + primitive = p3d.ExtrudedProfile(frame, outer_contour2d=contour, inner_contours2d=[], extrusion_length=1) + return vm.core.VolumeModel([primitive]).babylon_data() + class BeamStructure(DessiaObject): """ A dummy class to test 2D/3D form interactions. """ @@ -836,10 +853,12 @@ class BeamStructure(DessiaObject): def __init__(self, horizontal_beam: HorizontalBeam, vertical_beams: List[VerticalBeam], name: str = ""): self.horizontal_beam = horizontal_beam self.vertical_beams = vertical_beams + self.n_beams = len(vertical_beams) super().__init__(name=name) @plot_data_view("2D View") + @picture_view("2D View") def plot2d(self, reference_path: str = "#"): """ A dummy 2D method to test form interactions. """ horizontal_contour = self.horizontal_beam.plot2d(reference_path=f"{reference_path}/horizontal_beam") @@ -847,8 +866,47 @@ def plot2d(self, reference_path: str = "#"): reference_path=f"{reference_path}/vertical_beams/{i}") for i, b in enumerate(self.vertical_beams)] labels = [plot_data.Label(c.reference_path, shape=c) for c in [horizontal_contour] + vertical_contours] - primtives = [horizontal_contour] + vertical_contours + labels - return plot_data.PrimitiveGroup(primitives=primtives, name="Contour") + primitives = [horizontal_contour] + vertical_contours + labels + return plot_data.PrimitiveGroup(primitives=primitives, name="Contour") + + @cad_view("CAD View") + def plot3d(self, reference_path: str = "#"): + """ A dummy 3D method to test form interactions. """ + horizontal_primitive = self.horizontal_beam.plot3d(reference_path=f"{reference_path}/horizontal_beam") + vertical_primitives = [b.plot3d(origin=self.horizontal_beam.length * i / len(self.vertical_beams), + reference_path=f"{reference_path}/vertical_beams/{i}") + for i, b in enumerate(self.vertical_beams)] + primitives = [horizontal_primitive] + vertical_primitives + return vm.core.VolumeModel(primitives).babylon_data() + + +class BeamStructureGenerator(DessiaObject): + """ A dummy class to generate a lot of BeamStructures. """ + + _standalone_in_db = True + + def __init__(self, n_solutions: int = 5, max_beams: int = 5, + max_length: int = 100, min_length: int = 10, name: str = "Beams"): + self.n_solutions = n_solutions + self.max_beams = max_beams + self.max_length = max_length + self.min_length = min_length + + super().__init__(name) + + def generate(self) -> List[BeamStructure]: + """ A dummy method to generate a certain number of BeamStructures. """ + beam_structures = [] + for i in range(self.n_solutions): + n_beams = randrange(1, self.max_beams + 1) + horizontal_beam = HorizontalBeam(length=(n_beams - 1) * 10, name="H") + vertical_beams = [] + for j in range(n_beams): + length = randrange(self.min_length * 10, self.max_length * 10) / 10 + vertical_beams.append(VerticalBeam(length=length, name=f"V{j + 1}")) + beam_structures.append(BeamStructure(horizontal_beam=horizontal_beam, vertical_beams=vertical_beams, + name=f"{self.name} {i + 1}")) + return beam_structures # Definition 1 diff --git a/dessia_common/schemas/core.py b/dessia_common/schemas/core.py index 6672c1c8c..eb22cacae 100644 --- a/dessia_common/schemas/core.py +++ b/dessia_common/schemas/core.py @@ -188,9 +188,11 @@ def standalone_properties(self) -> List[str]: def to_dict(self, **kwargs) -> Dict[str, Any]: """ Base Schema. kwargs are added to result as well. """ schema = deepcopy(SCHEMA_HEADER) + order = [a.name for a in self.attributes] properties = {a.name: self.chunk(a) for a in self.attributes} required = [a.name for a in self.required] - schema.update({"required": required, "properties": properties, "description": self.documentation}) + schema.update({"required": required, "order": order, "properties": properties, + "description": self.documentation}) schema.update(**kwargs) return schema @@ -389,8 +391,10 @@ def return_serialized(self): def to_dict(self, **kwargs): """ Write the whole schema. """ schema = deepcopy(SCHEMA_HEADER) + order = [str(i) for i in range(len(self.attributes))] properties = {str(i): self.chunk(a) for i, a in enumerate(self.attributes)} - schema.update({"required": self.required_indices, "properties": properties, "description": self.documentation}) + schema.update({"required": self.required_indices, "order": order, "properties": properties, + "description": self.documentation}) return schema def definition_json(self): diff --git a/dessia_common/serialization.py b/dessia_common/serialization.py index 6217987d9..399278721 100644 --- a/dessia_common/serialization.py +++ b/dessia_common/serialization.py @@ -15,12 +15,12 @@ import dessia_common.errors as dc_err from dessia_common.files import StringFile, BinaryFile import dessia_common.utils.types as dcty -from dessia_common.utils.helpers import full_classname, get_python_class_from_class_name +from dessia_common.utils.helpers import (full_classname, get_python_class_from_class_name, is_sequence, + get_in_object_from_path, set_in_object_from_path) from dessia_common.abstract import CoreDessiaObject from dessia_common.typings import InstanceOf, JsonSerializable from dessia_common.measures import Measure from dessia_common.graph import explore_tree_from_leaves -from dessia_common.breakdown import get_in_object_from_path, set_in_object_from_path from dessia_common.schemas.core import TYPING_EQUIVALENCES, is_typing, serialize_annotation fullargsspec_cache = {} @@ -103,7 +103,7 @@ def serialize(value): serialized_value = value.to_dict() elif isinstance(value, dict): serialized_value = serialize_dict(value) - elif dcty.is_sequence(value): + elif is_sequence(value): serialized_value = serialize_sequence(value) elif isinstance(value, (BinaryFile, StringFile)): serialized_value = value @@ -171,7 +171,7 @@ def serialize_with_pointers(value, memo=None, path='#', id_method=True, id_memo= elif isinstance(value, dict): serialized, memo = serialize_dict_with_pointers(value, memo=memo, path=path, id_method=id_method, id_memo=id_memo) - elif dcty.is_sequence(value): + elif is_sequence(value): serialized, memo = serialize_sequence_with_pointers(value, memo=memo, path=path, id_method=id_method, id_memo=id_memo) @@ -209,7 +209,7 @@ def serialize_dict_with_pointers(dict_, memo, path, id_method, id_memo): for key, value in dict_.items(): if isinstance(value, dict): dict_attrs_keys.append(key) - elif dcty.is_sequence(value): + elif is_sequence(value): seq_attrs_keys.append(key) else: other_keys.append(key) @@ -265,7 +265,7 @@ def deserialize(serialized_element, sequence_annotation: str = 'List', if isinstance(serialized_element, dict): return dict_to_object(serialized_element, global_dict=global_dict, pointers_memo=pointers_memo, path=path) - if dcty.is_sequence(serialized_element): + if is_sequence(serialized_element): return deserialize_sequence(sequence=serialized_element, annotation=sequence_annotation, global_dict=global_dict, pointers_memo=pointers_memo, path=path) if isinstance(serialized_element, str): @@ -481,7 +481,7 @@ def find_references(value, path='#'): return find_references_dict(value, path) if dcty.isinstance_base_types(value): return [] - if dcty.is_sequence(value): + if is_sequence(value): return find_references_sequence(value, path) if isinstance(value, (BinaryFile, StringFile)): return [] @@ -673,7 +673,7 @@ def pointer_graph_elements(value, path='#'): return pointer_graph_elements_dict(value, path) if dcty.isinstance_base_types(value): return [], [] - if dcty.is_sequence(value): + if is_sequence(value): return pointer_graph_elements_sequence(value, path) raise ValueError(value) @@ -790,7 +790,7 @@ def is_serializable(obj): if not is_serializable(key) or not is_serializable(value): return False return True - if dcty.is_sequence(obj): + if is_sequence(obj): for element in obj: if not is_serializable(element): return False diff --git a/dessia_common/utils/diff.py b/dessia_common/utils/diff.py index 237154d57..043c9a02b 100644 --- a/dessia_common/utils/diff.py +++ b/dessia_common/utils/diff.py @@ -5,12 +5,9 @@ import math from typing import List -import numpy as npy from dessia_common import FLOAT_TOLERANCE -import dessia_common.core as dc -from dessia_common.utils.types import isinstance_base_types, is_sequence -from dessia_common.utils.helpers import full_classname -from dessia_common.files import BinaryFile, StringFile +from dessia_common.utils.types import isinstance_base_types +from dessia_common.utils.helpers import is_sequence class DifferentValues: @@ -175,78 +172,6 @@ def sequence_diff(seq1, seq2, path='#'): return seq_diff -def data_eq(value1, value2): - """ Returns if two values are equal on data equality. """ - if is_sequence(value1) and is_sequence(value2): - return sequence_data_eq(value1, value2) - - if isinstance(value1, npy.int64) or isinstance(value2, npy.int64): - return value1 == value2 - - if isinstance(value1, npy.float64) or isinstance(value2, npy.float64): - return math.isclose(value1, value2, abs_tol=FLOAT_TOLERANCE) - - if not isinstance(value2, type(value1))\ - and not isinstance(value1, type(value2)): - return False - - if isinstance_base_types(value1): - if isinstance(value1, float): - return math.isclose(value1, value2, abs_tol=FLOAT_TOLERANCE) - - return value1 == value2 - - if isinstance(value1, dict): - return dict_data_eq(value1, value2) - - if isinstance(value1, (BinaryFile, StringFile)): - return value1 == value2 - - if isinstance(value1, type): - return full_classname(value1) == full_classname(value2) - - # Else: its an object - if full_classname(value1) != full_classname(value2): - return False - - # Test if _data_eq is customized - if hasattr(value1, '_data_eq'): - custom_method = value1._data_eq.__code__ is not dc.DessiaObject._data_eq.__code__ - if custom_method: - return value1._data_eq(value2) - - # Not custom, use generic implementation - eq_dict = value1._data_eq_dict() - if 'name' in eq_dict: - del eq_dict['name'] - - other_eq_dict = value2._data_eq_dict() - - return dict_data_eq(eq_dict, other_eq_dict) - - -def dict_data_eq(dict1, dict2): - """ Returns True if two dictionaries are equal on data equality, False otherwise. """ - for key, value in dict1.items(): - if key not in dict2: - return False - if not data_eq(value, dict2[key]): - return False - return True - - -def sequence_data_eq(seq1, seq2): - """ Returns if two sequences are equal on data equality. """ - if len(seq1) != len(seq2): - return False - - for v1, v2 in zip(seq1, seq2): - if not data_eq(v1, v2): - return False - - return True - - def choose_hash(object_): """ Base function to return hash. """ if is_sequence(object_): diff --git a/dessia_common/utils/helpers.py b/dessia_common/utils/helpers.py index ead97c34a..513c6d378 100644 --- a/dessia_common/utils/helpers.py +++ b/dessia_common/utils/helpers.py @@ -12,6 +12,9 @@ from importlib import import_module from ast import literal_eval from typing import Type +from collections.abc import Sequence +from dessia_common import REF_MARKER, OLD_REF_MARKER +from dessia_common.errors import ExtractionError _PYTHON_CLASS_CACHE = {} @@ -75,3 +78,116 @@ def get_python_class_from_class_name(full_class_name: str) -> Type: # Storing in cache _PYTHON_CLASS_CACHE[full_class_name] = class_ return class_ + + +def is_sequence(obj) -> bool: + """ + Return True if object is sequence (but not string), else False. + + :param obj: Object to check + :return: bool. True if object is a sequence but not a string. False otherwise + """ + if not hasattr(obj, "__len__") or not hasattr(obj, "__getitem__"): + # Performance improvements for trivial checks + return False + + if is_list(obj) or is_tuple(obj): + # Performance improvements for trivial checks + return True + return isinstance(obj, Sequence) and not isinstance(obj, str) + + +def is_list(obj) -> bool: + """ Check if given obj is exactly of type list (not instance of). Used mainly for performance. """ + return obj.__class__ == list + + +def is_tuple(obj) -> bool: + """ Check if given obj is exactly of type tuple (not instance of). Used mainly for performance. """ + return obj.__class__ == tuple + + +def extract_segment_from_object(object_, segment: str): + """ Try all ways to get an attribute (segment) from an object that can of numerous types. """ + if is_sequence(object_): + try: + return object_[int(segment)] + except ValueError as err: + message_error = (f"Cannot extract segment {segment} from object {{str(object_)[:500]}}:" + f" segment is not a sequence index") + raise ExtractionError(message_error) from err + + if isinstance(object_, dict): + if segment in object_: + return object_[segment] + + if segment.isdigit(): + intifyed_segment = int(segment) + if intifyed_segment in object_: + return object_[intifyed_segment] + if segment in object_: + return object_[segment] + raise ExtractionError(f'Cannot extract segment {segment} from object {str(object_)[:200]}') + + # should be a tuple + if segment.startswith('(') and segment.endswith(')') and ',' in segment: + key = [] + for subsegment in segment.strip('()').replace(' ', '').split(','): + if subsegment.isdigit(): + subkey = int(subsegment) + else: + subkey = subsegment + key.append(subkey) + return object_[tuple(key)] + raise ExtractionError(f"Cannot extract segment {segment} from object {str(object_)[:500]}") + + # Finally, it is a regular object + return getattr(object_, segment) + + +def get_in_object_from_path(object_, path, evaluate_pointers=True): + """ Get deep attributes from an object. Argument 'path' represents path to deep attribute. """ + segments = path.lstrip('#/').split('/') + element = object_ + for segment in segments: + if isinstance(element, dict): + # Going down in the object and it is a reference : evaluating sub-reference + if evaluate_pointers: + if REF_MARKER in element: + try: + element = get_in_object_from_path(object_, element[REF_MARKER]) + except RecursionError as err: + err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' + raise RecursionError(err_msg) from err + elif OLD_REF_MARKER in element: # Retro-compatibility to be remove sometime + try: + element = get_in_object_from_path(object_, element[OLD_REF_MARKER]) + except RecursionError as err: + err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' + raise RecursionError(err_msg) from err + + try: + element = extract_segment_from_object(element, segment) + except ExtractionError as err: + + err_msg = f'Cannot get segment {segment} from path {path} in element {str(element)[:500]}' + raise ExtractionError(err_msg) from err + + return element + + +def set_in_object_from_path(object_, path, value, evaluate_pointers=True): + """ Set deep attribute from an object to the given value. Argument 'path' represents path to deep attribute. """ + reduced_path = '/'.join(path.lstrip('#/').split('/')[:-1]) + last_segment = path.split('/')[-1] + if reduced_path: + last_object = get_in_object_from_path(object_, reduced_path, evaluate_pointers=evaluate_pointers) + else: + last_object = object_ + + if is_sequence(last_object): + last_object[int(last_segment)] = value + elif isinstance(last_object, dict): + last_object[last_segment] = value + else: + setattr(last_object, last_segment, value) diff --git a/dessia_common/utils/types.py b/dessia_common/utils/types.py index 028b535c2..d1c9f31e4 100644 --- a/dessia_common/utils/types.py +++ b/dessia_common/utils/types.py @@ -10,7 +10,7 @@ from dessia_common.typings import InstanceOf, MethodType, ClassMethodType from dessia_common.files import BinaryFile, StringFile from dessia_common.schemas.core import TYPING_EQUIVALENCES, union_is_default_value, is_typing, serialize_annotation -from dessia_common.utils.helpers import get_python_class_from_class_name +from dessia_common.utils.helpers import get_python_class_from_class_name, is_sequence SIMPLE_TYPES = [int, str] @@ -51,6 +51,9 @@ def is_sequence(obj) -> bool: :param obj: Object to check :return: bool. True if object is a sequence but not a string. False otherwise """ + if isinstance(obj, (str, bytes)): + return False + if not hasattr(obj, "__len__") or not hasattr(obj, "__getitem__"): # Performance improvements for trivial checks return False diff --git a/dessia_common/workflow/blocks.py b/dessia_common/workflow/blocks.py index fadaf3bf4..26ae98339 100644 --- a/dessia_common/workflow/blocks.py +++ b/dessia_common/workflow/blocks.py @@ -12,13 +12,15 @@ from dessia_common.displays import DisplaySetting, DisplayObject from dessia_common.errors import UntypedArgumentError from dessia_common.typings import (JsonSerializable, MethodType, ClassMethodType, AttributeType, ViewType, CadViewType, - PlotDataType, MarkdownType) + PlotDataType, MarkdownType, InstanceOf) from dessia_common.files import StringFile, BinaryFile, generate_archive -from dessia_common.utils.helpers import concatenate, full_classname, get_python_class_from_class_name -from dessia_common.breakdown import attrmethod_getter, get_in_object_from_path +from dessia_common.utils.helpers import (concatenate, full_classname, get_python_class_from_class_name, + get_in_object_from_path) +from dessia_common.breakdown import attrmethod_getter from dessia_common.exports import ExportFormat from dessia_common.workflow.core import Block, Variable, Workflow from dessia_common.workflow.utils import ToScriptElement +import plot_data as pd T = TypeVar("T") @@ -801,25 +803,25 @@ def equivalent_hash(self): def evaluate(self, values, **kwargs): """ Create MultiPlot from block configuration. Handle reference path. """ reference_path = kwargs.get("reference_path", "#") - import plot_data + # import plot_data objects = values[self.inputs[self._displayable_input]] - samples = [plot_data.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes}, - reference_path=f"{reference_path}/{i}", name=f"Sample {i}") + samples = [pd.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes}, + reference_path=f"{reference_path}/{i}", name=f"Sample {i}") for i, o in enumerate(objects)] - samples2d = [plot_data.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes[:2]}, - reference_path=f"{reference_path}/{i}", name=f"Sample {i}") + samples2d = [pd.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes[:2]}, + reference_path=f"{reference_path}/{i}", name=f"Sample {i}") for i, o in enumerate(objects)] - tooltip = plot_data.Tooltip(name="Tooltip", attributes=self.attributes) + tooltip = pd.Tooltip(name="Tooltip", attributes=self.attributes) - scatterplot = plot_data.Scatter(tooltip=tooltip, x_variable=self.attributes[0], y_variable=self.attributes[1], - elements=samples2d, name="Scatter Plot") + scatterplot = pd.Scatter(tooltip=tooltip, x_variable=self.attributes[0], y_variable=self.attributes[1], + elements=samples2d, name="Scatter Plot") - parallelplot = plot_data.ParallelPlot(disposition="horizontal", axes=self.attributes, - rgbs=[(192, 11, 11), (14, 192, 11), (11, 11, 192)], elements=samples) + parallelplot = pd.ParallelPlot(disposition="horizontal", axes=self.attributes, + rgbs=[(192, 11, 11), (14, 192, 11), (11, 11, 192)], elements=samples) plots = [scatterplot, parallelplot] - sizes = [plot_data.Window(width=560, height=300), plot_data.Window(width=560, height=300)] - multiplot = plot_data.MultiplePlots(elements=samples, plots=plots, sizes=sizes, - coords=[(0, 0), (0, 300)], name="Results plot") + sizes = [pd.Window(width=560, height=300), pd.Window(width=560, height=300)] + multiplot = pd.MultiplePlots(elements=samples, plots=plots, sizes=sizes, + coords=[(0, 0), (0, 300)], name="Results plot") return [multiplot.to_dict()] def _to_script(self, _) -> ToScriptElement: @@ -866,25 +868,25 @@ def equivalent_hash(self): def evaluate(self, values, **kwargs): """ Create MultiPlot from block configuration. Handle reference path. """ reference_path = kwargs.get("reference_path", "#") - import plot_data + # import plot_data objects = values[self.inputs[self._displayable_input]] - samples = [plot_data.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes}, - reference_path=f"{reference_path}/{i}", name=f"Sample {i}") + samples = [pd.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes}, + reference_path=f"{reference_path}/{i}", name=f"Sample {i}") for i, o in enumerate(objects)] - samples2d = [plot_data.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes[:2]}, - reference_path=f"{reference_path}/{i}", name=f"Sample {i}") + samples2d = [pd.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes[:2]}, + reference_path=f"{reference_path}/{i}", name=f"Sample {i}") for i, o in enumerate(objects)] - tooltip = plot_data.Tooltip(name="Tooltip", attributes=self.attributes) + tooltip = pd.Tooltip(name="Tooltip", attributes=self.attributes) - scatterplot = plot_data.Scatter(tooltip=tooltip, x_variable=self.attributes[0], y_variable=self.attributes[1], - elements=samples2d, name="Scatter Plot") + scatterplot = pd.Scatter(tooltip=tooltip, x_variable=self.attributes[0], y_variable=self.attributes[1], + elements=samples2d, name="Scatter Plot") - parallelplot = plot_data.ParallelPlot(disposition="horizontal", axes=self.attributes, - rgbs=[(192, 11, 11), (14, 192, 11), (11, 11, 192)], elements=samples) + parallelplot = pd.ParallelPlot(disposition="horizontal", axes=self.attributes, + rgbs=[(192, 11, 11), (14, 192, 11), (11, 11, 192)], elements=samples) plots = [scatterplot, parallelplot] - sizes = [plot_data.Window(width=560, height=300), plot_data.Window(width=560, height=300)] - multiplot = plot_data.MultiplePlots(elements=samples, plots=plots, sizes=sizes, - coords=[(0, 0), (0, 300)], name="Results plot") + sizes = [pd.Window(width=560, height=300), pd.Window(width=560, height=300)] + multiplot = pd.MultiplePlots(elements=samples, plots=plots, sizes=sizes, + coords=[(0, 0), (0, 300)], name="Results plot") return [multiplot.to_dict()] def _to_script(self, _) -> ToScriptElement: @@ -922,6 +924,83 @@ def dict_to_object(cls, dict_: JsonSerializable, **kwargs) -> 'MultiPlot': return block +class MultiObject(Display): + """ + Generate a MultiObject view which axes will be the given attributes. + + It Will show a Scatter and a Parallel Plot. + + :param selector_name: Name of the selector to be displayed in object page. + Must be unique throughout workflow. + :param attributes: A List of all attributes that will be shown on axes in the ParallelPlot window. + Can be deep attributes with the '/' separator. + :param name: Name of the block. + :param position: Position of the block in canvas. + """ + + _type = "plot_data" + serialize = True + + def __init__(self, selector_name: str, configurations: List[InstanceOf['PlotDataView']], + load_by_default: bool = True, name: str = "Multi Object View", position: Position = (0, 0)): + self.configurations = configurations + Display.__init__(self, inputs=[Variable(type_=List[DessiaObject])], load_by_default=load_by_default, + name=name, selector=PlotDataType(class_=DessiaObject, name=selector_name), position=position) + self.inputs[0].name = "Sequence" + + def __deepcopy__(self, memo=None): + return MultiObject(selector_name=self.selector.name, configurations=self.configurations, + load_by_default=self.load_by_default, name=self.name, position=self.position) + + def equivalent(self, other: 'MultiObject'): + """ Return whether if the block is equivalent to the other given. """ + same_attributes = self.configurations == other.configurations + return super().equivalent(other) and same_attributes + + def equivalent_hash(self): + """ Custom hash function. Related to 'equivalent' method. """ + return sum(hash(a) for a in self.configurations) + + def evaluate(self, values, **kwargs): + """ Create MultiPlot from block configuration. Handle reference path. """ + reference_path = kwargs.get("reference_path", "#") + objects = values[self.inputs[self._displayable_input]] + plots = [c.plot_data_object(objects=objects, reference_path=reference_path) for c in self.configurations] + + # TODO Mutualizing samples from multiplot and subplots should probably be done by plot_data + attributes = list(set([a for c in self.configurations for a in c.attributes])) + samples = [pd.Sample(values={a: get_in_object_from_path(o, a) for a in attributes}, + reference_path=f"{reference_path}/{i}", name=o.name if o.name else f"Sample {i}") + for i, o in enumerate(objects)] + multiplot = pd.MultiplePlots(elements=samples, plots=plots, name="Results plot") + return [multiplot.to_dict()] + + # def _to_script(self, _) -> ToScriptElement: + # """ Write block config into a chunk of script. """ + # script = (f"MultiObject(" + # f"selector_name='{self.selector.name}'," + # f" attributes={self.attributes}," + # f" {self.base_script()})") + # return ToScriptElement(declaration=script, imports=[self.full_classname]) + + def to_dict(self, use_pointers: bool = True, memo=None, path: str = '#', + id_method=True, id_memo=None, **kwargs) -> JsonSerializable: + """ Overwrite to_dict method in order to handle difference of behaviors about selector. """ + dict_ = super().to_dict(use_pointers=use_pointers, memo=memo, path=path, id_method=id_method, id_memo=id_memo) + dict_.update({"selector_name": self.selector.name, "configurations": [c.to_dict() for c in self.configurations], + "name": self.name, "load_by_default": self.load_by_default, "position": self.position}) + return dict_ + + @classmethod + def dict_to_object(cls, dict_: JsonSerializable, **kwargs) -> 'MultiObject': + """ Backward compatibility for old versions of Display blocks. """ + configurations = [PlotDataView.dict_to_object(c) for c in dict_["configurations"]] + block = MultiObject(selector_name=dict_["selector_name"], configurations=configurations, name=dict_["name"], + load_by_default=dict_["load_by_default"], position=dict_["position"]) + block.deserialize_variables(dict_) + return block + + class DeprecatedCadView(Display): """ Deprecated version of CadView block. @@ -1451,3 +1530,40 @@ def _to_script(self, _) -> ToScriptElement: """ Write block config into a chunk of script. """ script = f"Archive(number_exports={self.number_exports}, filename='{self.filename}', {self.base_script()})" return ToScriptElement(declaration=script, imports=[self.full_classname]) + + +class PlotDataView(DessiaObject): + """ Plot Data View framework base class. """ + + def __init__(self, attributes: List[str], name: str = ""): + self.attributes = attributes + + super().__init__(name) + + def samples(self, objects, reference_path: str = "#"): + return [pd.Sample(values={a: get_in_object_from_path(o, a) for a in self.attributes}, + reference_path=f"{reference_path}/{i}", name=f"Sample {i}") for i, o in enumerate(objects)] + + +class ScatterView(PlotDataView): + """ Scatter View Framework. """ + + def __init__(self, attributes: Tuple[str, str], name: str = ""): + super().__init__(attributes=list(attributes), name=name) + + def plot_data_object(self, objects, reference_path: str = "#") -> pd.Scatter: + tooltip = pd.Tooltip(name="Tooltip", attributes=list(self.attributes)) + samples = self.samples(objects=objects, reference_path=reference_path) + return pd.Scatter(tooltip=tooltip, x_variable=self.attributes[0], y_variable=self.attributes[1], + elements=samples, name=self.name) + + +class ParallelView(PlotDataView): + """ Scatter View Framework. """ + + def __init__(self, attributes: List[str], name: str = ""): + super().__init__(attributes=attributes, name=name) + + def plot_data_object(self, objects, reference_path: str = "#") -> pd.ParallelPlot: + samples = self.samples(objects=objects, reference_path=reference_path) + return pd.ParallelPlot(elements=samples, disposition="horizontal", axes=self.attributes) diff --git a/dessia_common/workflow/core.py b/dessia_common/workflow/core.py index 9442e7887..137523e3c 100644 --- a/dessia_common/workflow/core.py +++ b/dessia_common/workflow/core.py @@ -19,15 +19,14 @@ from dessia_common.schemas.core import (get_schema, FAILED_ATTRIBUTE_PARSING, EMPTY_PARSED_ATTRIBUTE, serialize_annotation, pretty_annotation, UNDEFINED, Schema, SchemaAttribute) -from dessia_common.utils.types import recursive_type, typematch, is_sequence, is_file_or_file_sequence +from dessia_common.utils.types import recursive_type, typematch, is_file_or_file_sequence from dessia_common.utils.copy import deepcopy_value from dessia_common.utils.diff import choose_hash -from dessia_common.utils.helpers import prettyname +from dessia_common.utils.helpers import prettyname, is_sequence from dessia_common.typings import JsonSerializable, ViewType from dessia_common.files import StringFile, BinaryFile from dessia_common.displays import DisplaySetting -from dessia_common.breakdown import ExtractionError -from dessia_common.errors import SerializationError +from dessia_common.errors import SerializationError, ExtractionError from dessia_common.warnings import SerializationWarning from dessia_common.exports import ExportFormat, MarkdownWriter import dessia_common.templates @@ -1899,26 +1898,13 @@ def display_settings(self, **kwargs) -> List[DisplaySetting]: Concatenate WorkflowState display_settings and inserting Workflow ones. """ - workflow_settings = [display_setting for display_setting in self.workflow.display_settings() - if display_setting.selector != "Documentation"] block_settings = self.workflow.blocks_display_settings displays_by_default = [s.load_by_default for s in block_settings] documentation = DisplaySetting(selector="Documentation", type_="markdown", method="to_markdown", load_by_default=True) documentation.load_by_default = not any(displays_by_default) - - workflow_settings_to_keep = [documentation] - for settings in workflow_settings: - # Update workflow settings - settings.compose("workflow") - - if settings.selector == "Workflow": - settings.load_by_default = False - - if settings.selector != "Tasks": - workflow_settings_to_keep.append(settings) - return workflow_settings_to_keep + block_settings + return [documentation] + block_settings def method_dict(self, method_name: str = None, method_jsonschema=None): """ Get run again default dict. """ diff --git a/dessia_common/workflow/utils.py b/dessia_common/workflow/utils.py index 8da844d9c..1706501f2 100644 --- a/dessia_common/workflow/utils.py +++ b/dessia_common/workflow/utils.py @@ -6,7 +6,8 @@ from dessia_common.schemas.core import get_schema, SchemaAttribute from dessia_common.serialization import SerializableObject -from dessia_common.utils.types import is_file_or_file_sequence, is_sequence +from dessia_common.utils.types import is_file_or_file_sequence +from dessia_common.utils.helpers import is_sequence class ToScriptElement: diff --git a/tests/test_displays/test_moving_object.py b/tests/test_displays/test_moving_object.py index 8ff318ccf..0fb68f415 100644 --- a/tests/test_displays/test_moving_object.py +++ b/tests/test_displays/test_moving_object.py @@ -12,11 +12,12 @@ def test_viability(self): self.mso._check_platform() def test_length(self): - self.assertEqual(len(self.displays), 2) + self.assertEqual(len(self.displays), 3) @parameterized.expand([ (0, "markdown"), - (1, "babylon_data"), + (1, "tree"), + (2, "babylon_data"), ]) def test_decorators(self, index, expected_type): self.assertEqual(self.displays[index]["type_"], expected_type) diff --git a/scripts/downstream.py b/tests/test_downstream.py similarity index 93% rename from scripts/downstream.py rename to tests/test_downstream.py index bb1c3eae8..e084b0eaf 100644 --- a/scripts/downstream.py +++ b/tests/test_downstream.py @@ -1,6 +1,6 @@ import unittest - + class BackendBreakingChangeTest(unittest.TestCase): def test_import_is_working(self): """Basic unittest to make sure backend import of DC is working""" @@ -10,7 +10,8 @@ def test_import_is_working(self): from dessia_common.errors import DeepAttributeError from dessia_common.files import BinaryFile, StringFile from dessia_common.serialization import serialize, serialize_with_pointers - from dessia_common.utils.types import is_bson_valid, is_jsonable, is_sequence + from dessia_common.utils.helpers import is_sequence + from dessia_common.utils.types import is_bson_valid, is_jsonable from dessia_common.workflow.core import WorkflowRun, WorkflowState from dessia_common.workflow.utils import ToScriptElement from dessia_common.schemas.core import ( diff --git a/tests/test_framework/test_types.py b/tests/test_framework/test_types.py index a71dc22e3..91ce99bfc 100644 --- a/tests/test_framework/test_types.py +++ b/tests/test_framework/test_types.py @@ -1,4 +1,5 @@ -from dessia_common.utils.types import is_sequence, is_list, is_tuple, isinstance_base_types, is_simple +from dessia_common.utils.helpers import is_sequence, is_list, is_tuple +from dessia_common.utils.types import isinstance_base_types, is_simple import unittest from parameterized import parameterized diff --git a/tests/test_schemas/test_computation_proxies.py b/tests/test_schemas/test_computation_proxies.py index c9e78872c..3789ed5db 100644 --- a/tests/test_schemas/test_computation_proxies.py +++ b/tests/test_schemas/test_computation_proxies.py @@ -10,7 +10,7 @@ class TestFaulty(unittest.TestCase): @parameterized.expand([ - (OptionalProperty(annotation=Optional[List[int]], attribute=ATTRIBUTE)) + (OptionalProperty(annotation=Optional[List[int]], attribute=ATTRIBUTE),) ]) def test_schema_check(self, schema): self.assertEqual(schema.args, (int,)) diff --git a/tests/test_schemas/test_workflows.py b/tests/test_schemas/test_workflows.py index 2bbde0f51..f45f39b32 100644 --- a/tests/test_schemas/test_workflows.py +++ b/tests/test_schemas/test_workflows.py @@ -10,7 +10,8 @@ def setUp(self) -> None: @parameterized.expand([ ("required", ["0", "3", "5"]), ("method", True), - ("type", "object") + ("type", "object"), + ("order", ['0', '1', '2', '3', '4', '5']) ]) def test_items(self, key, value): self.assertEqual(self.schema[key], value)