Skip to content

Commit

Permalink
multi types schema format unmarshal fix
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Aug 17, 2023
1 parent 5a02484 commit 3bd6279
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 42 deletions.
69 changes: 27 additions & 42 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,34 +140,15 @@ def _unmarshal_properties(

class MultiTypeUnmarshaller(PrimitiveUnmarshaller):
def __call__(self, value: Any) -> Any:
unmarshaller = self._get_best_unmarshaller(value)
primitive_type = self.schema_validator.get_primitive_type(value)
unmarshaller = self.schema_unmarshaller.get_type_unmarshaller(
primitive_type
)
return unmarshaller(value)

@property
def type(self) -> List[str]:
types = self.schema.getkey("type", ["any"])
assert isinstance(types, list)
return types

def _get_best_unmarshaller(self, value: Any) -> "PrimitiveUnmarshaller":
for schema_type in self.type:
result = self.schema_validator.type_validator(
value, type_override=schema_type
)
if not result:
continue
result = self.schema_validator.format_validator(value)
if not result:
continue
return self.schema_unmarshaller.get_type_unmarshaller(schema_type)

raise UnmarshallerError("Unmarshaller not found for type(s)")


class AnyUnmarshaller(MultiTypeUnmarshaller):
@property
def type(self) -> List[str]:
return self.schema_unmarshaller.types_unmarshaller.get_types()
pass


class TypesUnmarshaller:
Expand All @@ -187,7 +168,7 @@ def __init__(
def get_types(self) -> List[str]:
return list(self.unmarshallers.keys())

def get_unmarshaller(
def get_unmarshaller_cls(
self,
schema_type: Optional[Union[Iterable[str], str]],
) -> Type["PrimitiveUnmarshaller"]:
Expand Down Expand Up @@ -222,8 +203,8 @@ def unmarshal(self, schema_format: str, value: Any) -> Any:
return value
try:
return format_unmarshaller(value)
except (ValueError, TypeError) as exc:
raise FormatUnmarshalError(value, schema_format, exc)
except (AttributeError, ValueError, TypeError) as exc:
return value

def get_unmarshaller(
self, schema_format: str
Expand Down Expand Up @@ -270,22 +251,33 @@ def unmarshal(self, value: Any) -> Any:
schema_type = self.schema.getkey("type")
type_unmarshaller = self.get_type_unmarshaller(schema_type)
typed = type_unmarshaller(value)

schema_format = self.find_format(value)
if schema_format is None:
format_unmarshaller = self.get_format_unmarshaller(schema_format)
if format_unmarshaller is None:
return typed
try:
return format_unmarshaller(typed)
except (AttributeError, ValueError, TypeError):
return typed
return self.formats_unmarshaller.unmarshal(schema_format, typed)

def get_type_unmarshaller(
self,
schema_type: Optional[Union[Iterable[str], str]],
) -> PrimitiveUnmarshaller:
klass = self.types_unmarshaller.get_unmarshaller(schema_type)
klass = self.types_unmarshaller.get_unmarshaller_cls(schema_type)
return klass(
self.schema,
self.schema_validator,
self,
)

def get_format_unmarshaller(
self,
schema_format: str,
) -> Optional[FormatUnmarshaller]:
return self.formats_unmarshaller.get_unmarshaller(schema_format)

def evolve(self, schema: Spec) -> "SchemaUnmarshaller":
cls = self.__class__

Expand All @@ -297,17 +289,10 @@ def evolve(self, schema: Spec) -> "SchemaUnmarshaller":
)

def find_format(self, value: Any) -> Optional[str]:
for schema in self.iter_valid_schemas(value):
if "format" in schema:
primitive_type = self.schema_validator.get_primitive_type(value)
if primitive_type != "string":
return None
for schema in self.schema_validator.iter_valid_schemas(value):
if "format" in schema and schema.getkey("type") == primitive_type:
return str(schema.getkey("format"))
return None

def iter_valid_schemas(self, value: Any) -> Iterator[Spec]:
yield self.schema

one_of_schema = self.schema_validator.get_one_of_schema(value)
if one_of_schema is not None:
yield one_of_schema

yield from self.schema_validator.iter_any_of_schemas(value)
yield from self.schema_validator.iter_all_of_schemas(value)
26 changes: 26 additions & 0 deletions openapi_core/validation/schemas/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ def format_validator_callable(self) -> FormatValidator:

return lambda x: True

def get_primitive_type(self, value: Any) -> str:
schema_types = self.schema.getkey("type")
if isinstance(schema_types, str):
return schema_types
if schema_types is None:
schema_types = sorted(self.validator.TYPE_CHECKER._type_checkers)
assert isinstance(schema_types, list)
for schema_type in schema_types:
result = self.type_validator(value, type_override=schema_type)
if not result:
continue
result = self.format_validator(value)
if not result:
continue
return schema_type

def get_one_of_schema(
self,
value: Any,
Expand Down Expand Up @@ -133,3 +149,13 @@ def iter_all_of_schemas(
log.warning("invalid allOf schema found")
else:
yield subschema

def iter_valid_schemas(self, value: Any) -> Iterator[Spec]:
yield self.schema

one_of_schema = self.schema_validator.get_one_of_schema(value)
if one_of_schema is not None:
yield one_of_schema

yield from self.schema_validator.iter_any_of_schemas(value)
yield from self.schema_validator.iter_all_of_schemas(value)
45 changes: 45 additions & 0 deletions tests/integration/unmarshalling/test_unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,30 @@ def test_basic_type_formats(

assert result == unmarshalled

@pytest.mark.parametrize(
"type,format,value",
[
("string", "float", "test"),
("string", "double", "test"),
("number", "date", 3),
("number", "date-time", 3),
("number", "uuid", 3),
],
)
def test_basic_type_formats_ignored(
self, unmarshallers_factory, type, format, value
):
schema = {
"type": type,
"format": format,
}
spec = Spec.from_dict(schema, validator=None)
unmarshaller = unmarshallers_factory.create(spec)

result = unmarshaller.unmarshal(value)

assert result == value

@pytest.mark.parametrize(
"type,format,value",
[
Expand Down Expand Up @@ -2036,6 +2060,27 @@ def test_nultiple_types_invalid(self, unmarshallers_factory, types, value):
assert len(exc_info.value.schema_errors) == 1
assert "is not of type" in exc_info.value.schema_errors[0].message

@pytest.mark.parametrize(
"types,format,value,expected",
[
(["string", "null"], "date", None, None),
(["string", "null"], "date", "2018-12-13", date(2018, 12, 13)),
],
)
def test_multiple_types_format_valid_or_ignored(
self, unmarshallers_factory, types, format, value, expected
):
schema = {
"type": types,
"format": format,
}
spec = Spec.from_dict(schema, validator=None)
unmarshaller = unmarshallers_factory.create(spec)

result = unmarshaller.unmarshal(value)

assert result == expected

def test_any_null(self, unmarshallers_factory):
schema = {}
spec = Spec.from_dict(schema, validator=None)
Expand Down

0 comments on commit 3bd6279

Please sign in to comment.