diff --git a/odmantic/model.py b/odmantic/model.py index 08308f29..a927d4f4 100644 --- a/odmantic/model.py +++ b/odmantic/model.py @@ -69,6 +69,7 @@ dataclass_transform, get_args, get_first_type_argument_subclassing, + get_generic_origin, get_origin, is_classvar, is_type_argument_subclass, @@ -303,7 +304,7 @@ def __validate_cls_namespace__( # noqa C901 "Declaring a generic type of embedded models containing " f"references is not allowed: {field_name} in {name}" ) - generic_origin = get_origin(field_type) + generic_origin = get_generic_origin(field_type) assert generic_origin is not None odm_fields[field_name] = ODMEmbeddedGeneric( model=model, diff --git a/odmantic/typing.py b/odmantic/typing.py index 4beb82cb..ab3f33f7 100644 --- a/odmantic/typing.py +++ b/odmantic/typing.py @@ -1,5 +1,5 @@ import sys -from typing import TYPE_CHECKING, AbstractSet, Any # noqa: F401 +from typing import TYPE_CHECKING, AbstractSet, Any, Optional # noqa: F401 from typing import Callable as TypingCallable from typing import Dict, Iterable, Mapping, Tuple, Type, TypeVar, Union # noqa: F401 @@ -36,9 +36,27 @@ IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +def is_optional(type_: Type) -> bool: + type_origin: Optional[Type] = getattr(type_, "__origin__", None) + if type_origin is Union: + type_args: Tuple[Type, ...] = getattr(type_, "__args__", ()) + if type_args: + return type_origin is Union and type_args[1] is type(None) + return False + + +def resolve_optional_to_some(type_: Type) -> Type: + if is_optional(type_): + type_args: Tuple[Type, ...] = getattr(type_, "__args__", ()) + assert type_args + type_ = type_args[0] + return type_ + + def is_type_argument_subclass( type_: Type, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]] ) -> bool: + type_ = resolve_optional_to_some(type_) args = get_args(type_) return any(lenient_issubclass(arg, class_or_tuple) for arg in args) @@ -49,8 +67,14 @@ def is_type_argument_subclass( def get_first_type_argument_subclassing( type_: Type, cls: Type[T] ) -> Union[Type[T], None]: + type_ = resolve_optional_to_some(type_) args: Tuple[Type, ...] = get_args(type_) for arg in args: if lenient_issubclass(arg, cls): return arg return None + + +def get_generic_origin(type_: Type) -> Optional[Any]: + type_ = resolve_optional_to_some(type_) + return get_origin(type_) diff --git a/tests/integration/test_embedded_model.py b/tests/integration/test_embedded_model.py index 9c648890..e7f6391b 100644 --- a/tests/integration/test_embedded_model.py +++ b/tests/integration/test_embedded_model.py @@ -1,6 +1,7 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import pytest +from pymongo.database import Database from odmantic.engine import AIOEngine, SyncEngine from odmantic.field import Field @@ -227,7 +228,9 @@ class User(Model): assert sync_engine.find_one(User, User.id == Id(user=1, chat=1001)) is not None -async def test_embedded_model_custom_key_name_save_and_fetch(aio_engine: AIOEngine): +async def test_embedded_model_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine +): class In(EmbeddedModel): a: int = Field(key_name="in-a") @@ -238,9 +241,19 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"]["in-a"] == 3 -def test_sync_embedded_model_custom_key_name_save_and_fetch(sync_engine: SyncEngine): +def test_sync_embedded_model_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine +): class In(EmbeddedModel): a: int = Field(key_name="in-a") @@ -251,9 +264,18 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"]["in-a"] == 3 async def test_embedded_model_list_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine, ): class In(EmbeddedModel): @@ -266,9 +288,18 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"][0]["in-a"] == 3 def test_sync_embedded_model_list_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine, ): class In(EmbeddedModel): @@ -281,9 +312,64 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id # type: ignore + } + ) + assert document + assert document["in"] + assert document["in"][0]["in-a"] == 3 + + +async def test_embedded_model_optional_list_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[List[In]] = Field(key_name="in", default_factory=list) + + instance = Out(inner=[In(a=3)]) + await aio_engine.save(instance) + fetched = await aio_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"][0]["in-a"] == 3 + + +def test_sync_embedded_model_optional_list_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[List[In]] = Field(key_name="in", default_factory=list) + + instance = Out(inner=[In(a=3)]) + sync_engine.save(instance) + fetched = sync_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"][0]["in-a"] == 3 async def test_embedded_model_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, aio_engine: AIOEngine, ): class In(EmbeddedModel): @@ -296,9 +382,18 @@ class Out(Model): await aio_engine.save(instance) fetched = await aio_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 def test_sync_embedded_model_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, sync_engine: SyncEngine, ): class In(EmbeddedModel): @@ -311,3 +406,59 @@ class Out(Model): sync_engine.save(instance) fetched = sync_engine.find_one(Out) assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 + + +async def test_embedded_model_optional_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, + aio_engine: AIOEngine, +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[Dict[str, In]] = Field(key_name="in", default_factory=dict) + + instance = Out(inner={"key": In(a=3)}) + await aio_engine.save(instance) + fetched = await aio_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 + + +def test_sync_embedded_model_optional_dict_custom_key_name_save_and_fetch( + pymongo_database: Database, + sync_engine: SyncEngine, +): + class In(EmbeddedModel): + a: int = Field(key_name="in-a") + + class Out(Model): + inner: Optional[Dict[str, In]] = Field(key_name="in", default_factory=dict) + + instance = Out(inner={"key": In(a=3)}) + sync_engine.save(instance) + fetched = sync_engine.find_one(Out) + assert instance == fetched + document = pymongo_database[Out.__collection__].find_one( + { + +Out.id: instance.id, # type: ignore + } + ) + assert document is not None + assert document["in"] + assert document["in"]["key"]["in-a"] == 3 diff --git a/tests/integration/test_index.py b/tests/integration/test_index.py index 1c8e631f..6f90752b 100644 --- a/tests/integration/test_index.py +++ b/tests/integration/test_index.py @@ -367,7 +367,7 @@ class Post(Model): assert await aio_engine.find_one(Post, {"$text": {"$search": "python"}}) is not None -async def test_sync_custom_text_index(sync_engine: SyncEngine): +def test_sync_custom_text_index(sync_engine: SyncEngine): class Post(Model): title: str content: str