From 3bd6279288360f861dca2a2be419ab116997b171 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Mon, 24 Apr 2023 12:14:43 +0100 Subject: [PATCH] multi types schema format unmarshal fix --- .../unmarshalling/schemas/unmarshallers.py | 69 ++++++++----------- openapi_core/validation/schemas/validators.py | 26 +++++++ .../unmarshalling/test_unmarshallers.py | 45 ++++++++++++ 3 files changed, 98 insertions(+), 42 deletions(-) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 39617f51..50417fbd 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -140,34 +140,15 @@ def _unmarshal_properties( class MultiTypeUnmarshaller(PrimitiveUnmarshaller): def __call__(self, value: Any) -> Any: - unmarshaller = self._get_best_unmarshaller(value) + primitive_type = self.schema_validator.get_primitive_type(value) + unmarshaller = self.schema_unmarshaller.get_type_unmarshaller( + primitive_type + ) return unmarshaller(value) - @property - def type(self) -> List[str]: - types = self.schema.getkey("type", ["any"]) - assert isinstance(types, list) - return types - - def _get_best_unmarshaller(self, value: Any) -> "PrimitiveUnmarshaller": - for schema_type in self.type: - result = self.schema_validator.type_validator( - value, type_override=schema_type - ) - if not result: - continue - result = self.schema_validator.format_validator(value) - if not result: - continue - return self.schema_unmarshaller.get_type_unmarshaller(schema_type) - - raise UnmarshallerError("Unmarshaller not found for type(s)") - class AnyUnmarshaller(MultiTypeUnmarshaller): - @property - def type(self) -> List[str]: - return self.schema_unmarshaller.types_unmarshaller.get_types() + pass class TypesUnmarshaller: @@ -187,7 +168,7 @@ def __init__( def get_types(self) -> List[str]: return list(self.unmarshallers.keys()) - def get_unmarshaller( + def get_unmarshaller_cls( self, schema_type: Optional[Union[Iterable[str], str]], ) -> Type["PrimitiveUnmarshaller"]: @@ -222,8 +203,8 @@ def unmarshal(self, schema_format: str, value: Any) -> Any: return value try: return format_unmarshaller(value) - except (ValueError, TypeError) as exc: - raise FormatUnmarshalError(value, schema_format, exc) + except (AttributeError, ValueError, TypeError) as exc: + return value def get_unmarshaller( self, schema_format: str @@ -270,22 +251,33 @@ def unmarshal(self, value: Any) -> Any: schema_type = self.schema.getkey("type") type_unmarshaller = self.get_type_unmarshaller(schema_type) typed = type_unmarshaller(value) + schema_format = self.find_format(value) - if schema_format is None: + format_unmarshaller = self.get_format_unmarshaller(schema_format) + if format_unmarshaller is None: + return typed + try: + return format_unmarshaller(typed) + except (AttributeError, ValueError, TypeError): return typed - return self.formats_unmarshaller.unmarshal(schema_format, typed) def get_type_unmarshaller( self, schema_type: Optional[Union[Iterable[str], str]], ) -> PrimitiveUnmarshaller: - klass = self.types_unmarshaller.get_unmarshaller(schema_type) + klass = self.types_unmarshaller.get_unmarshaller_cls(schema_type) return klass( self.schema, self.schema_validator, self, ) + def get_format_unmarshaller( + self, + schema_format: str, + ) -> Optional[FormatUnmarshaller]: + return self.formats_unmarshaller.get_unmarshaller(schema_format) + def evolve(self, schema: Spec) -> "SchemaUnmarshaller": cls = self.__class__ @@ -297,17 +289,10 @@ def evolve(self, schema: Spec) -> "SchemaUnmarshaller": ) def find_format(self, value: Any) -> Optional[str]: - for schema in self.iter_valid_schemas(value): - if "format" in schema: + primitive_type = self.schema_validator.get_primitive_type(value) + if primitive_type != "string": + return None + for schema in self.schema_validator.iter_valid_schemas(value): + if "format" in schema and schema.getkey("type") == primitive_type: return str(schema.getkey("format")) return None - - def iter_valid_schemas(self, value: Any) -> Iterator[Spec]: - yield self.schema - - one_of_schema = self.schema_validator.get_one_of_schema(value) - if one_of_schema is not None: - yield one_of_schema - - yield from self.schema_validator.iter_any_of_schemas(value) - yield from self.schema_validator.iter_all_of_schemas(value) diff --git a/openapi_core/validation/schemas/validators.py b/openapi_core/validation/schemas/validators.py index b9f73940..a6a96072 100644 --- a/openapi_core/validation/schemas/validators.py +++ b/openapi_core/validation/schemas/validators.py @@ -78,6 +78,22 @@ def format_validator_callable(self) -> FormatValidator: return lambda x: True + def get_primitive_type(self, value: Any) -> str: + schema_types = self.schema.getkey("type") + if isinstance(schema_types, str): + return schema_types + if schema_types is None: + schema_types = sorted(self.validator.TYPE_CHECKER._type_checkers) + assert isinstance(schema_types, list) + for schema_type in schema_types: + result = self.type_validator(value, type_override=schema_type) + if not result: + continue + result = self.format_validator(value) + if not result: + continue + return schema_type + def get_one_of_schema( self, value: Any, @@ -133,3 +149,13 @@ def iter_all_of_schemas( log.warning("invalid allOf schema found") else: yield subschema + + def iter_valid_schemas(self, value: Any) -> Iterator[Spec]: + yield self.schema + + one_of_schema = self.schema_validator.get_one_of_schema(value) + if one_of_schema is not None: + yield one_of_schema + + yield from self.schema_validator.iter_any_of_schemas(value) + yield from self.schema_validator.iter_all_of_schemas(value) diff --git a/tests/integration/unmarshalling/test_unmarshallers.py b/tests/integration/unmarshalling/test_unmarshallers.py index 6fa0708d..f09dce7f 100644 --- a/tests/integration/unmarshalling/test_unmarshallers.py +++ b/tests/integration/unmarshalling/test_unmarshallers.py @@ -240,6 +240,30 @@ def test_basic_type_formats( assert result == unmarshalled + @pytest.mark.parametrize( + "type,format,value", + [ + ("string", "float", "test"), + ("string", "double", "test"), + ("number", "date", 3), + ("number", "date-time", 3), + ("number", "uuid", 3), + ], + ) + def test_basic_type_formats_ignored( + self, unmarshallers_factory, type, format, value + ): + schema = { + "type": type, + "format": format, + } + spec = Spec.from_dict(schema, validator=None) + unmarshaller = unmarshallers_factory.create(spec) + + result = unmarshaller.unmarshal(value) + + assert result == value + @pytest.mark.parametrize( "type,format,value", [ @@ -2036,6 +2060,27 @@ def test_nultiple_types_invalid(self, unmarshallers_factory, types, value): assert len(exc_info.value.schema_errors) == 1 assert "is not of type" in exc_info.value.schema_errors[0].message + @pytest.mark.parametrize( + "types,format,value,expected", + [ + (["string", "null"], "date", None, None), + (["string", "null"], "date", "2018-12-13", date(2018, 12, 13)), + ], + ) + def test_multiple_types_format_valid_or_ignored( + self, unmarshallers_factory, types, format, value, expected + ): + schema = { + "type": types, + "format": format, + } + spec = Spec.from_dict(schema, validator=None) + unmarshaller = unmarshallers_factory.create(spec) + + result = unmarshaller.unmarshal(value) + + assert result == expected + def test_any_null(self, unmarshallers_factory): schema = {} spec = Spec.from_dict(schema, validator=None)