Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure import resolution for compatibility layer #48

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion django_pydantic_field/compat/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ def compat_getattr(module_name: str):

def compat_dir(module_name: str):
compat_module = _import_compat_module(module_name)
return dir(compat_module)
module_ns = vars(compat_module)

if "__dir__" in module_ns:
return module_ns["__dir__"]

if "__all__" in module_ns:
return functools.partial(list, module_ns["__all__"])

return functools.partial(dir, compat_module)


def _import_compat_module(module_name: str) -> types.ModuleType:
Expand Down
16 changes: 10 additions & 6 deletions django_pydantic_field/rest_framework.pyi
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import typing as ty
import typing_extensions as te

import typing_extensions as te
from django.utils.functional import _StrOrPromise
from rest_framework import parsers, renderers
from rest_framework.fields import _DefaultInitial, Field
from rest_framework.schemas.openapi import AutoSchema as _OpenAPIAutoSchema
from rest_framework.validators import Validator

from django.utils.functional import _StrOrPromise

from .fields import ST, ConfigType, _ExportKwargs
from .fields import _ExportKwargs, ConfigType, ST

__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer")
__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer", "AutoSchema")

class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False):
read_only: bool
Expand Down Expand Up @@ -45,7 +45,9 @@ class SchemaField(Field, ty.Generic[ST]):
**kwargs: te.Unpack[_SchemaFieldKwargs[ST]],
) -> None: ...
@ty.overload
@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.")
@te.deprecated(
"Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions."
)
def __init__(
self,
schema: ty.Type[ST] | ty.ForwardRef | str,
Expand All @@ -61,3 +63,5 @@ class SchemaParser(parsers.JSONParser, ty.Generic[ST]):
class SchemaRenderer(renderers.JSONRenderer, ty.Generic[ST]):
schema_context_key: ty.ClassVar[str]
config_context_key: ty.ClassVar[str]

class AutoSchema(_OpenAPIAutoSchema): ...
2 changes: 2 additions & 0 deletions django_pydantic_field/v2/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from django_pydantic_field.compat import deprecation
from . import types

__all__ = ("SchemaField",)


class SchemaField(JSONField, ty.Generic[types.ST]):
adapter: types.SchemaAdapter
Expand Down
16 changes: 15 additions & 1 deletion django_pydantic_field/v2/rest_framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from django_pydantic_field.compat import PYDANTIC_V2

from . import coreapi as coreapi
from . import openapi as openapi
from .fields import SchemaField as SchemaField
from .parsers import SchemaParser as SchemaParser
from .renderers import SchemaRenderer as SchemaRenderer
Expand All @@ -8,8 +12,18 @@
"or `django_pydantic_field.rest_framework.coreapi.AutoSchema` instead."
)

__all__ = (
"coreapi",
"openapi",
"SchemaField",
"SchemaParser",
"SchemaRenderer",
"AutoSchema", # type: ignore
)


def __getattr__(key):
if key == "AutoSchema":
if key == "AutoSchema" and PYDANTIC_V2:
import warnings

from .openapi import AutoSchema
Expand Down
10 changes: 5 additions & 5 deletions django_pydantic_field/v2/rest_framework/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework.test import APIRequestFactory

from . import fields, parsers, renderers
from ..utils import get_origin_type
from django_pydantic_field.v2 import utils

if ty.TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_request_body(self, path, method):
schema_content = {}

for parser, ct in zip(self.view.parser_classes, self.request_media_types):
if issubclass(get_origin_type(parser), parsers.SchemaParser):
if issubclass(utils.get_origin_type(parser), parsers.SchemaParser):
parser_schema = self.collected_adapter_schema_refs[repr(parser)]
else:
parser_schema = request_schema
Expand All @@ -86,7 +86,7 @@ def get_responses(self, path, method):

schema_content = {}
for renderer, ct in zip(self.view.renderer_classes, self.response_media_types):
if issubclass(get_origin_type(renderer), renderers.SchemaRenderer):
if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer):
renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]}
if is_list_view:
renderer_schema = self._get_paginated_schema(renderer_schema)
Expand All @@ -108,7 +108,7 @@ def map_parsers(self, path: str, method: str) -> list[str]:

for parser in self.view.parser_classes:
media_types.append(parser.media_type)
if issubclass(get_origin_type(parser), parsers.SchemaParser):
if issubclass(utils.get_origin_type(parser), parsers.SchemaParser):
schema_parsers.append(parser)

if schema_parsers:
Expand All @@ -125,7 +125,7 @@ def map_renderers(self, path: str, method: str) -> list[str]:

for renderer in self.view.renderer_classes:
media_types.append(renderer.media_type)
if issubclass(get_origin_type(renderer), renderers.SchemaRenderer):
if issubclass(utils.get_origin_type(renderer), renderers.SchemaRenderer):
schema_renderers.append(renderer)

if schema_renderers:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema


@pytest.mark.parametrize(
"exported_primitive_name",
["SchemaField"],
)
def test_module_imports(exported_primitive_name):
assert exported_primitive_name in dir(fields)
assert getattr(fields, exported_primitive_name, None) is not None


def test_sample_field():
sample_field = fields.PydanticSchemaField(schema=InnerSchema)
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
Expand Down
53 changes: 53 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import warnings

import pytest

import django_pydantic_field
from django_pydantic_field import fields, forms, rest_framework
from django_pydantic_field.compat import PYDANTIC_V1, PYDANTIC_V2


@pytest.mark.parametrize(
"module, exported_primitive_name",
[
(django_pydantic_field, "SchemaField"),
(fields, "SchemaField"),
(forms, "SchemaField"),
(rest_framework, "SchemaParser"),
(rest_framework, "SchemaRenderer"),
(rest_framework, "SchemaField"),
(rest_framework, "AutoSchema"),
pytest.param(
rest_framework,
"openapi",
marks=pytest.mark.skipif(
not PYDANTIC_V2,
reason="`.rest_framework.openapi` module is only appearing in v2 layer",
),
),
pytest.param(
rest_framework,
"coreapi",
marks=pytest.mark.skipif(
not PYDANTIC_V2,
reason="`.rest_framework.coreapi` module is only appearing in v2 layer",
),
),
],
)
def test_module_imports(module, exported_primitive_name):
assert exported_primitive_name in dir(module)
assert getattr(module, exported_primitive_name, None) is not None


@pytest.mark.skipif(not PYDANTIC_V2, reason="AutoSchema import warning is only appearing in v2 layer")
def test_rest_framework_autoschema_warning_v2():
with pytest.deprecated_call(match="`django_pydantic_field.rest_framework.AutoSchema` is deprecated.*"):
rest_framework.AutoSchema


@pytest.mark.skipif(not PYDANTIC_V1, reason="Deprecation warning should not be raised in v1 layer")
def test_rest_framework_autoschema_no_warning_v1():
with warnings.catch_warnings():
warnings.simplefilter("error")
rest_framework.AutoSchema