Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare a SchemaField's default value to work with RootModel schemas #51

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions django_pydantic_field/fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ from __future__ import annotations

import json
import typing as ty
import typing_extensions as te

import typing_extensions as te
from django.db.models.expressions import BaseExpression
from pydantic import BaseConfig, BaseModel, ConfigDict

try:
Expand Down Expand Up @@ -80,7 +80,7 @@ class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False):
def SchemaField(
schema: ty.Type[ST] | None | ty.ForwardRef = ...,
config: ConfigType = ...,
default: OptSchemaT | ty.Callable[[], OptSchemaT] = ...,
default: OptSchemaT | ty.Callable[[], OptSchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_SchemaFieldKwargs],
Expand All @@ -89,7 +89,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_SchemaFieldKwargs],
Expand All @@ -99,7 +99,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | None | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs],
Expand All @@ -109,7 +109,7 @@ def SchemaField(
def SchemaField(
schema: ty.Type[ST] | ty.ForwardRef = ...,
config: ConfigType = ...,
default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ...,
default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs],
Expand Down
44 changes: 26 additions & 18 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

if ty.TYPE_CHECKING:
import json

import typing_extensions as te
from django.db.models import Model

Expand Down Expand Up @@ -94,7 +95,7 @@ def deconstruct(self) -> ty.Any:

default = kwargs.get("default", NOT_PROVIDED)
if default is not NOT_PROVIDED and not callable(default):
kwargs["default"] = self.adapter.dump_python(default, include=None, exclude=None, round_trip=True)
kwargs["default"] = self._prepare_raw_value(default, include=None, exclude=None, round_trip=True)

prep_schema = GenericContainer.wrap(self.adapter.prepared_schema)
kwargs.update(schema=prep_schema, config=self.config, **self.export_kwargs)
Expand Down Expand Up @@ -161,14 +162,7 @@ def to_python(self, value: ty.Any):
raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc

def get_prep_value(self, value: ty.Any):
if isinstance(value, Value) and isinstance(value.output_field, self.__class__):
# Prepare inner value for `Value`-wrapped expressions.
value = Value(self.get_prep_value(value.value), value.output_field)
elif not isinstance(value, BaseExpression):
# Prepare the value if it is not a query expression.
prep_value = self.adapter.validate_python(value)
value = self.adapter.dump_python(prep_value)

value = self._prepare_raw_value(value)
return super().get_prep_value(value)

def get_transform(self, lookup_name: str):
Expand All @@ -177,9 +171,11 @@ def get_transform(self, lookup_name: str):
transform = SchemaKeyTransformAdapter(transform)
return transform

def get_default(self) -> types.ST:
def get_default(self) -> ty.Any:
default_value = super().get_default()
return self.adapter.validate_python(default_value)
if self.has_default():
return self.adapter.validate_python(default_value)
return default_value

def formfield(self, **kwargs):
field_kwargs = dict(
Expand All @@ -195,7 +191,21 @@ def formfield(self, **kwargs):

def value_to_string(self, obj: Model):
value = super().value_from_object(obj)
return self.get_prep_value(value)
return self._prepare_raw_value(value)

def _prepare_raw_value(self, value: ty.Any, **dump_kwargs):
if isinstance(value, Value) and isinstance(value.output_field, self.__class__):
# Prepare inner value for `Value`-wrapped expressions.
value = Value(self._prepare_raw_value(value.value), value.output_field)
elif not isinstance(value, BaseExpression):
# Prepare the value if it is not a query expression.
try:
value = self.adapter.validate_python(value)
except pydantic.ValidationError:
"""This is a legitimate situation, the data could not be initially coerced."""
value = self.adapter.dump_python(value, **dump_kwargs)

return value


class SchemaKeyTransformAdapter:
Expand All @@ -217,24 +227,22 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None:
def SchemaField(
schema: type[types.ST | None] | ty.ForwardRef = ...,
config: pydantic.ConfigDict = ...,
default: types.SchemaT | None | ty.Callable[[], types.SchemaT | None] = ...,
default: types.SchemaT | ty.Callable[[], types.SchemaT | None] | BaseExpression | None = ...,
*args,
null: ty.Literal[True],
**kwargs: te.Unpack[_SchemaFieldKwargs],
) -> types.ST | None:
...
) -> types.ST | None: ...


@ty.overload
def SchemaField(
schema: type[types.ST] | ty.ForwardRef = ...,
config: pydantic.ConfigDict = ...,
default: ty.Union[types.SchemaT, ty.Callable[[], types.SchemaT]] = ...,
default: types.SchemaT | ty.Callable[[], types.SchemaT] | BaseExpression = ...,
*args,
null: ty.Literal[False] = ...,
**kwargs: te.Unpack[_SchemaFieldKwargs],
) -> types.ST:
...
) -> types.ST: ...


def SchemaField(schema=None, config=None, default=NOT_PROVIDED, *args, **kwargs): # type: ignore
Expand Down
26 changes: 17 additions & 9 deletions django_pydantic_field/v2/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def to_python(self, value: ty.Any) -> ty.Any:
return None

try:
if not isinstance(value, (str, bytes)):
# The form data may contain python objects for some cases (e.g. using django-constance).
value = self.adapter.validate_python(value)
elif not isinstance(value, JSONString):
# Otherwise, try to parse incoming JSON according to the schema.
value = self.adapter.validate_json(value)
value = self._try_coerce(value)
except pydantic.ValidationError as exc:
error_params = {"value": value, "title": exc.title, "detail": exc.json(), "errors": exc.errors()}
raise ValidationError(self.error_messages["schema_error"], code="invalid", params=error_params) from exc
Expand All @@ -76,10 +71,23 @@ def prepare_value(self, value):
if isinstance(value, InvalidJSONInput):
return value

value = self.adapter.validate_python(value)
value = self._try_coerce(value)
return self.adapter.dump_json(value).decode()

def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool:
if super(JSONField, self).has_changed(initial, data):
try:
initial = self._try_coerce(initial)
data = self._try_coerce(data)
return self.adapter.dump_python(initial) != self.adapter.dump_python(data)
except pydantic.ValidationError:
return True
return self.adapter.dump_json(initial) != self.adapter.dump_json(data)

def _try_coerce(self, value):
if not isinstance(value, (str, bytes)):
# The form data may contain python objects for some cases (e.g. using django-constance).
value = self.adapter.validate_python(value)
elif not isinstance(value, JSONString):
# Otherwise, try to parse incoming JSON according to the schema.
value = self.adapter.validate_json(value)

return value
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies = [
]

[project.optional-dependencies]
openapi = ["uritemplate"]
openapi = ["uritemplate", "inflection"]
coreapi = ["coreapi"]
dev = [
"build",
Expand Down
23 changes: 23 additions & 0 deletions tests/settings/django_test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
SECRET_KEY = "1"
SITE_ID = 1
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
STATIC_URL = "/static/"

INSTALLED_APPS = [
"django.contrib.contenttypes",
Expand All @@ -12,10 +13,32 @@
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"django.contrib.admin",
"tests.sample_app",
"tests.test_app",
]

MIDDLEWARE = [
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
]
TEMPLATES = [
{
"BACKEND": "django.template.backends.django.DjangoTemplates",
"DIRS": [],
"APP_DIRS": True,
"OPTIONS": {
"context_processors": [
"django.template.context_processors.debug",
"django.template.context_processors.request",
"django.contrib.auth.context_processors.auth",
"django.contrib.messages.context_processors.messages",
],
},
},
]

DATABASES = {
"default": {
"ENGINE": "django.db.backends.sqlite3",
Expand Down
23 changes: 23 additions & 0 deletions tests/test_app/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from django.contrib import admin

from . import models


@admin.register(models.SampleModel)
class SampleModelAdmin(admin.ModelAdmin):
pass


@admin.register(models.SampleForwardRefModel)
class SampleForwardRefModelAdmin(admin.ModelAdmin):
pass


@admin.register(models.SampleModelWithRoot)
class SampleModelWithRootAdmin(admin.ModelAdmin):
pass


@admin.register(models.ExampleModel)
class ExampleModelAdmin(admin.ModelAdmin):
pass
31 changes: 31 additions & 0 deletions tests/test_app/migrations/0003_samplemodelwithroot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Generated by Django 5.0.1 on 2024-03-11 22:29

import django.core.serializers.json
import django_pydantic_field.fields
import tests.test_app.models
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("test_app", "0002_examplemodel"),
]

operations = [
migrations.CreateModel(
name="SampleModelWithRoot",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
(
"root_field",
django_pydantic_field.fields.PydanticSchemaField(
config=None,
default=list,
encoder=django.core.serializers.json.DjangoJSONEncoder,
schema=tests.test_app.models.RootSchema,
),
),
],
),
]
15 changes: 14 additions & 1 deletion tests/test_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pydantic
from django.db import models
from django_pydantic_field import SchemaField
from django_pydantic_field.compat import PYDANTIC_V2

from ..conftest import InnerSchema

Expand Down Expand Up @@ -35,6 +36,18 @@ class SampleSchema(pydantic.BaseModel):
class ExampleSchema(pydantic.BaseModel):
count: int


class ExampleModel(models.Model):
example_field: ExampleSchema = SchemaField(default=ExampleSchema(count=1))


if PYDANTIC_V2:
class RootSchema(pydantic.RootModel):
root: t.List[int]

else:
class RootSchema(pydantic.BaseModel):
__root__: t.List[int]


class SampleModelWithRoot(models.Model):
root_field = SchemaField(schema=RootSchema, default=list)
31 changes: 31 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema


if PYDANTIC_V2:

class SampleRootModel(pydantic.RootModel):
root: ty.List[str]

else:

class SampleRootModel(pydantic.BaseModel):
__root__: ty.List[str]


@pytest.mark.parametrize(
"exported_primitive_name",
["SchemaField"],
Expand Down Expand Up @@ -98,6 +109,26 @@ class Meta:
default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]},
),
fields.PydanticSchemaField(schema=ty.Optional[InnerSchema], null=True, default=None),
fields.PydanticSchemaField(schema=SampleRootModel, default=[""]),
fields.PydanticSchemaField(schema=ty.Optional[SampleRootModel], default=[""]),
fields.PydanticSchemaField(schema=ty.Optional[SampleRootModel], null=True, default=None),
fields.PydanticSchemaField(schema=ty.Optional[SampleRootModel], null=True, blank=True),
pytest.param(
fields.PydanticSchemaField(schema=ty.Optional[SampleRootModel], default=SampleRootModel.parse_obj([])),
marks=pytest.mark.xfail(
PYDANTIC_V1,
reason="Prepared root-model based defaults are not supported with Pydantic v1",
raises=ValidationError,
),
),
pytest.param(
fields.PydanticSchemaField(schema=SampleRootModel, default=SampleRootModel.parse_obj([""])),
marks=pytest.mark.xfail(
PYDANTIC_V1,
reason="Prepared root-model based defaults are not supported with Pydantic v1",
raises=ValidationError,
),
),
pytest.param(
fields.PydanticSchemaField(
schema=InnerSchema,
Expand Down
Loading