From 9883552646196bd913dbbb33e39320825d9bec43 Mon Sep 17 00:00:00 2001 From: Anentropic Date: Thu, 30 May 2024 17:36:31 +0100 Subject: [PATCH 1/3] failing test reproducing issue #484 --- .pre-commit-config.yaml | 2 +- tests/integration/test_embedded_model.py | 91 +++++++++++++++++++++++- tests/integration/test_index.py | 2 +- tests/integration/test_types.py | 12 +++- 4 files changed, 101 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7837431a..06fe7623 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks default_language_version: - python: python3.8 + python: python3.11 node: 15.4.0 repos: - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/tests/integration/test_embedded_model.py b/tests/integration/test_embedded_model.py index 9c648890..25555bd2 100644 --- a/tests/integration/test_embedded_model.py +++ b/tests/integration/test_embedded_model.py @@ -1,6 +1,7 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import pytest +from pymongo.database import Database from odmantic.engine import AIOEngine, SyncEngine from odmantic.field import Field @@ -227,7 +228,9 @@ class User(Model): assert sync_engine.find_one(User, User.id == Id(user=1, chat=1001)) is not None -async def test_embedded_model_custom_key_name_save_and_fetch(aio_engine: AIOEngine): +async def test_embedded_model_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine +): class In(EmbeddedModel): a: int = Field(key_name="in-a") @@ -238,9 +241,19 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"]["in-a"] == 3 -def test_sync_embedded_model_custom_key_name_save_and_fetch(sync_engine: SyncEngine): +def test_sync_embedded_model_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine +): class In(EmbeddedModel): a: int = Field(key_name="in-a") @@ -251,9 +264,18 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"]["in-a"] == 3 async def test_embedded_model_list_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine, ): class In(EmbeddedModel): @@ -266,9 +288,18 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"][0]["in-a"] == 3 def test_sync_embedded_model_list_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine, ): class In(EmbeddedModel): @@ -281,6 +312,60 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"][0]["in-a"] == 3 + + +async def test_embedded_model_optional_list_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[List[In]] = Field(key_name="in", default_factory=list) + + instance = Out(inner=[In(a=3)]) + await aio_engine.save(instance) + fetched = await aio_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"][0]["in-a"] == 3 + + +def test_sync_embedded_model_optional_list_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[List[In]] = Field(key_name="in", default_factory=list) + + instance = Out(inner=[In(a=3)]) + sync_engine.save(instance) + fetched = sync_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"][0]["in-a"] == 3 async def test_embedded_model_dict_custom_key_name_save_and_fetch( diff --git a/tests/integration/test_index.py b/tests/integration/test_index.py index 1c8e631f..6f90752b 100644 --- a/tests/integration/test_index.py +++ b/tests/integration/test_index.py @@ -367,7 +367,7 @@ class Post(Model): assert await aio_engine.find_one(Post, {"$text": {"$search": "python"}}) is not None -async def test_sync_custom_text_index(sync_engine: SyncEngine): +def test_sync_custom_text_index(sync_engine: SyncEngine): class Post(Model): title: str content: str diff --git a/tests/integration/test_types.py b/tests/integration/test_types.py index 361cd884..d3c79186 100644 --- a/tests/integration/test_types.py +++ b/tests/integration/test_types.py @@ -2,7 +2,17 @@ import re from datetime import datetime from decimal import Decimal -from typing import Any, Dict, Generic, List, Pattern, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + List, + Pattern, + Tuple, + Type, + TypeVar, + Union, +) import pytest from bson import Binary, Decimal128, Int64, ObjectId, Regex From 469925495b2551d42ae5b5a83214869728ec3157 Mon Sep 17 00:00:00 2001 From: Anentropic Date: Thu, 30 May 2024 17:36:32 +0100 Subject: [PATCH 2/3] handle optional embededd generic fields, fixes #484 --- .pre-commit-config.yaml | 2 +- odmantic/model.py | 3 ++- odmantic/typing.py | 26 +++++++++++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06fe7623..7837431a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks default_language_version: - python: python3.11 + python: python3.8 node: 15.4.0 repos: - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/odmantic/model.py b/odmantic/model.py index 08308f29..a927d4f4 100644 --- a/odmantic/model.py +++ b/odmantic/model.py @@ -69,6 +69,7 @@ dataclass_transform, get_args, get_first_type_argument_subclassing, + get_generic_origin, get_origin, is_classvar, is_type_argument_subclass, @@ -303,7 +304,7 @@ def __validate_cls_namespace__( # noqa C901 "Declaring a generic type of embedded models containing " f"references is not allowed: {field_name} in {name}" ) - generic_origin = get_origin(field_type) + generic_origin = get_generic_origin(field_type) assert generic_origin is not None odm_fields[field_name] = ODMEmbeddedGeneric( model=model, diff --git a/odmantic/typing.py b/odmantic/typing.py index 4beb82cb..ab3f33f7 100644 --- a/odmantic/typing.py +++ b/odmantic/typing.py @@ -1,5 +1,5 @@ import sys -from typing import TYPE_CHECKING, AbstractSet, Any # noqa: F401 +from typing import TYPE_CHECKING, AbstractSet, Any, Optional # noqa: F401 from typing import Callable as TypingCallable from typing import Dict, Iterable, Mapping, Tuple, Type, TypeVar, Union # noqa: F401 @@ -36,9 +36,27 @@ IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +def is_optional(type_: Type) -> bool: + type_origin: Optional[Type] = getattr(type_, "__origin__", None) + if type_origin is Union: + type_args: Tuple[Type, ...] = getattr(type_, "__args__", ()) + if type_args: + return type_origin is Union and type_args[1] is type(None) + return False + + +def resolve_optional_to_some(type_: Type) -> Type: + if is_optional(type_): + type_args: Tuple[Type, ...] = getattr(type_, "__args__", ()) + assert type_args + type_ = type_args[0] + return type_ + + def is_type_argument_subclass( type_: Type, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]] ) -> bool: + type_ = resolve_optional_to_some(type_) args = get_args(type_) return any(lenient_issubclass(arg, class_or_tuple) for arg in args) @@ -49,8 +67,14 @@ def is_type_argument_subclass( def get_first_type_argument_subclassing( type_: Type, cls: Type[T] ) -> Union[Type[T], None]: + type_ = resolve_optional_to_some(type_) args: Tuple[Type, ...] = get_args(type_) for arg in args: if lenient_issubclass(arg, cls): return arg return None + + +def get_generic_origin(type_: Type) -> Optional[Any]: + type_ = resolve_optional_to_some(type_) + return get_origin(type_) From 927a93f27bdefa5c3715b91f5d7665038992bab6 Mon Sep 17 00:00:00 2001 From: Anentropic Date: Thu, 30 May 2024 17:36:32 +0100 Subject: [PATCH 3/3] tests for optional dict too --- tests/integration/test_embedded_model.py | 66 ++++++++++++++++++++++++ tests/integration/test_types.py | 12 +---- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/tests/integration/test_embedded_model.py b/tests/integration/test_embedded_model.py index 25555bd2..e7f6391b 100644 --- a/tests/integration/test_embedded_model.py +++ b/tests/integration/test_embedded_model.py @@ -369,6 +369,7 @@ class Out(Model): async def test_embedded_model_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine, ): class In(EmbeddedModel): @@ -381,9 +382,18 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 def test_sync_embedded_model_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine, ): class In(EmbeddedModel): @@ -396,3 +406,59 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 + + +async def test_embedded_model_optional_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, + aio_engine: AIOEngine, +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[Dict[str, In]] = Field(key_name="in", default_factory=dict) + + instance = Out(inner={"key": In(a=3)}) + await aio_engine.save(instance) + fetched = await aio_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 + + +def test_sync_embedded_model_optional_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, + sync_engine: SyncEngine, +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[Dict[str, In]] = Field(key_name="in", default_factory=dict) + + instance = Out(inner={"key": In(a=3)}) + sync_engine.save(instance) + fetched = sync_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 diff --git a/tests/integration/test_types.py b/tests/integration/test_types.py index d3c79186..361cd884 100644 --- a/tests/integration/test_types.py +++ b/tests/integration/test_types.py @@ -2,17 +2,7 @@ import re from datetime import datetime from decimal import Decimal -from typing import ( - Any, - Dict, - Generic, - List, - Pattern, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Dict, Generic, List, Pattern, Tuple, Type, TypeVar, Union import pytest from bson import Binary, Decimal128, Int64, ObjectId, Regex