Skip to content

Commit

Permalink
Merge branch 'master' of github.com:wesselb/plum
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jul 6, 2024
2 parents 159a253 + d14991d commit 2fa552e
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 29 deletions.
2 changes: 2 additions & 0 deletions plum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .util import * # noqa: F401, F403

# Deprecated
# isort: split
from .parametric import Val # noqa: F401, F403
from .util import multihash # noqa: F401, F403

# Ensure that type checking is always entirely correct! The default O(1) strategy
Expand Down
26 changes: 14 additions & 12 deletions plum/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@
parsing how unions print.
"""

import typing
from functools import wraps
from typing import List, TypeVar, Union, _type_repr

from .typing import get_args

__all__ = ["activate_union_aliases", "deactivate_union_aliases", "set_union_alias"]

_union_type = type(typing.Union[int, float])
UnionT = TypeVar("UnionT")

_union_type = type(Union[int, float])
_original_repr = _union_type.__repr__
_original_str = _union_type.__str__


@wraps(_original_repr)
def _new_repr(self):
def _new_repr(self: object) -> str:
"""Print a `typing.Union`, replacing all aliased unions by their aliased names.
Returns:
Expand All @@ -52,7 +54,7 @@ def _new_repr(self):
found_unions = []
found_positions = []
found_aliases = []
for union, alias in reversed(_aliased_unions):
for union, alias in reversed(_ALIASED_UNIONS):
union_set = set(union)
if union_set <= args_set:
found = False
Expand Down Expand Up @@ -103,7 +105,7 @@ def _new_repr(self):
args = new_args

# Generate a string representation.
args_repr = [a if isinstance(a, str) else typing._type_repr(a) for a in args]
args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args]
# Like `typing` does, print `Optional` whenever possible.
if len(args) == 2:
if args[0] is type(None): # noqa: E721
Expand All @@ -116,7 +118,7 @@ def _new_repr(self):


@wraps(_original_str)
def _new_str(self):
def _new_str(self: object) -> str:
"""Does the same as :func:`_new_repr`.
Returns:
Expand All @@ -125,24 +127,24 @@ def _new_str(self):
return _new_repr(self)


def activate_union_aliases():
def activate_union_aliases() -> None:
"""When printing `typing.Union`s, replace all aliased unions by the aliased names.
This monkey patches `__repr__` and `__str__` for `typing.Union`."""
_union_type.__repr__ = _new_repr
_union_type.__str__ = _new_str


def deactivate_union_aliases():
def deactivate_union_aliases() -> None:
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
and `__str__` for `typing.Union`."""
_union_type.__repr__ = _original_repr
_union_type.__str__ = _original_str


_aliased_unions = []
_ALIASED_UNIONS: List = []


def set_union_alias(union, alias):
def set_union_alias(union: UnionT, alias: str) -> UnionT:
"""Change how a `typing.Union` is printed. This does not modify `union`.
Args:
Expand All @@ -153,12 +155,12 @@ def set_union_alias(union, alias):
type or type hint: `union`.
"""
args = get_args(union) if isinstance(union, _union_type) else (union,)
for existing_union, existing_alias in _aliased_unions:
for existing_union, existing_alias in _ALIASED_UNIONS:
if set(existing_union) == set(args) and alias != existing_alias:
if isinstance(union, _union_type):
union_str = _original_str(union)
else:
union_str = repr(union)
raise RuntimeError(f"`{union_str}` already has alias `{existing_alias}`.")
_aliased_unions.append((args, alias))
_ALIASED_UNIONS.append((args, alias))
return union
11 changes: 3 additions & 8 deletions plum/parametric.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import contextlib
import warnings
from typing import Type, TypeVar, Union

from typing_extensions import deprecated

import beartype.door
from beartype.roar import BeartypeDoorNonpepException

Expand All @@ -19,7 +20,6 @@
"type_unparametrized",
"kind",
"Kind",
"Val",
]

T = TypeVar("T")
Expand Down Expand Up @@ -632,6 +632,7 @@ def get(self):
Kind = kind() #: A default kind provided for convenience.


@deprecated("Use `typing.Literal[val]` instead.")
@parametric
class Val:
"""A parametric type used to move information from the value domain to the type
Expand Down Expand Up @@ -661,12 +662,6 @@ def __init__(self, val=None):
Args:
val (object): The value to be moved to the type domain.
"""
warnings.warn(
"`plum.Val` is deprecated and will be removed in a future version. "
"Please use `typing.Literal` instead.",
category=DeprecationWarning,
stacklevel=2,
)
if type(self).concrete:
if val is not None and type_parameter(self) != val:
raise ValueError("The value must be equal to the type parameter.")
Expand Down
10 changes: 3 additions & 7 deletions plum/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import abc
import sys
import warnings
from typing import Hashable, List, Sequence

from typing_extensions import deprecated

if sys.version_info.minor <= 8: # pragma: specific no cover 3.9 3.10 3.11
from typing import Callable
else: # pragma: specific no cover 3.8
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(self):
raise TypeError("`Missing` cannot be instantiated.")


@deprecated("Use `hash(tuple_of_args)` instead.")
def multihash(*args: Hashable) -> int:
"""Multi-argument order-sensitive hash.
Expand All @@ -56,12 +58,6 @@ def multihash(*args: Hashable) -> int:
Returns:
int: Hash.
"""
warnings.warn(
"The function `multihash` is deprecated and will be removed in a future "
"version. Please use `hash(tuple(*args))` instead.",
DeprecationWarning,
stacklevel=2,
)
return hash(args)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from plum import activate_union_aliases, deactivate_union_aliases, set_union_alias
from plum.alias import _aliased_unions
from plum.alias import _ALIASED_UNIONS


@pytest.fixture()
Expand All @@ -13,7 +13,7 @@ def union_aliases():
activate_union_aliases()
yield
deactivate_union_aliases()
_aliased_unions.clear()
_ALIASED_UNIONS.clear()


@pytest.mark.parametrize("display", [str, repr])
Expand Down

0 comments on commit 2fa552e

Please sign in to comment.