Skip to content

Commit

Permalink
Fix the case of unprepared default value in conjunction with RootModel
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Mar 8, 2024
1 parent 23701fb commit 3288eb1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
9 changes: 5 additions & 4 deletions django_pydantic_field/fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from __future__ import annotations

import json
import typing as ty
from django.db.models.expressions import BaseExpression
import typing_extensions as te

import typing_extensions as te
Expand Down Expand Up @@ -80,7 +81,7 @@ class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False):
def SchemaField(
schema: ty.Type[ST] | None | ty.ForwardRef = ...,
config: ConfigType = ...,
default: OptSchemaT | ty.Callable[[], OptSchemaT] = ...,
default: OptSchemaT | ty.Callable[[], OptSchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_SchemaFieldKwargs],
Expand All @@ -89,7 +90,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_SchemaFieldKwargs],
Expand All @@ -99,7 +100,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | None | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs],
Expand All @@ -109,7 +110,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs],
Expand Down
37 changes: 21 additions & 16 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def deconstruct(self) -> ty.Any:

default = kwargs.get("default", NOT_PROVIDED)
if default is not NOT_PROVIDED and not callable(default):
kwargs["default"] = self.adapter.dump_python(default, include=None, exclude=None, round_trip=True)
kwargs["default"] = self._prepare_raw_value(default, include=None, exclude=None, round_trip=True)

prep_schema = GenericContainer.wrap(self.adapter.prepared_schema)
kwargs.update(schema=prep_schema, config=self.config, **self.export_kwargs)
Expand Down Expand Up @@ -161,14 +161,7 @@ def to_python(self, value: ty.Any):
raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc

def get_prep_value(self, value: ty.Any):
if isinstance(value, Value) and isinstance(value.output_field, self.__class__):
# Prepare inner value for `Value`-wrapped expressions.
value = Value(self.get_prep_value(value.value), value.output_field)
elif not isinstance(value, BaseExpression):
# Prepare the value if it is not a query expression.
prep_value = self.adapter.validate_python(value)
value = self.adapter.dump_python(prep_value)

value = self._prepare_raw_value(value)
return super().get_prep_value(value)

def get_transform(self, lookup_name: str):
Expand All @@ -195,7 +188,21 @@ def formfield(self, **kwargs):

def value_to_string(self, obj: Model):
value = super().value_from_object(obj)
return self.get_prep_value(value)
return self._prepare_raw_value(value)

def _prepare_raw_value(self, value: ty.Any, **dump_kwargs):
if isinstance(value, Value) and isinstance(value.output_field, self.__class__):
# Prepare inner value for `Value`-wrapped expressions.
value = Value(self._prepare_raw_value(value.value), value.output_field)
elif not isinstance(value, BaseExpression):
# Prepare the value if it is not a query expression.
try:
value = self.adapter.validate_python(value)
except pydantic.ValidationError:
"""This is a legitimate situation(?), the data could not be initially coerced."""
value = self.adapter.dump_python(value, **dump_kwargs)

return value


class SchemaKeyTransformAdapter:
Expand All @@ -217,24 +224,22 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None:
def SchemaField(
schema: type[types.ST | None] | ty.ForwardRef = ...,
config: pydantic.ConfigDict = ...,
default: types.SchemaT | None | ty.Callable[[], types.SchemaT | None] = ...,
default: types.SchemaT | ty.Callable[[], types.SchemaT | None] | BaseExpression | None = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_SchemaFieldKwargs],
) -> types.ST | None:
...
) -> types.ST | None: ...


@ty.overload
def SchemaField(
schema: type[types.ST] | ty.ForwardRef = ...,
config: pydantic.ConfigDict = ...,
default: ty.Union[types.SchemaT, ty.Callable[[], types.SchemaT]] = ...,
default: types.SchemaT | ty.Callable[[], types.SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_SchemaFieldKwargs],
) -> types.ST:
...
) -> types.ST: ...


def SchemaField(schema=None, config=None, default=NOT_PROVIDED, *args, **kwargs): # type: ignore
Expand Down

0 comments on commit 3288eb1

Please sign in to comment.