Source code for plum._alias

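"""On Python versions before 3.14, this module monkey patches `__repr__` and `__str__`
of `typing.Union` to control how `typing.Union`s are displayed. On Python 3.14 and
later, `typing.Union` is left untouched; registered aliases are instead used by plum
when it prints dispatch signatures.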

Example::

    >> plum.activate_union_aliases()

    >> IntOrFloat = typing.Union[int, float]

    >> IntOrFloat
    typing.Union[int, float]

    >> plum.set_union_alias(IntOrFloat, "IntOrFloat")

    >> IntOrFloat
    typing.Union[IntOrFloat]

    >> typing.Union[int, float]
    typing.Union[IntOrFloat]

    >> typing.Union[int, float, str]
    typing.Union[IntOrFloat, str]

Note that `IntOrFloat` prints to `typing.Union[IntOrFloat]` rather than just
`IntOrFloat`. This is deliberate, with the goal of not breaking code that relies on
parsing how unions print.
"""

# pyright: reportUnreachable=false

# Public API of this module. All three names are defined for every supported Python
# version (the implementations differ before and from Python 3.14).
__all__ = ("activate_union_aliases", "deactivate_union_aliases", "set_union_alias")

import sys
from functools import wraps
from typing import Any, TypeVar, Union, _type_repr, get_args
from typing_extensions import TypeAliasType

# Runtime class of `typing.Union[...]` objects. It is obtained from an explicit
# `Union[...]` subscription on purpose: before Python 3.14, `int | float` creates a
# `types.UnionType`, which is a different class.
_union_type = type(Union[int, float])  # noqa: UP007

# Type variable letting `set_union_alias` return its argument with an unchanged type.
UnionT = TypeVar("UnionT")

# Whether union aliases are currently in use. `activate_union_aliases` sets it and
# `deactivate_union_aliases` clears it.
# NOTE(review): on Python < 3.14 nothing in this module patches `typing.Union` at
# import time, yet this flag starts out `True`; presumably the package calls
# `activate_union_aliases()` on import - confirm.
_ALIASES_ARE_ACTIVE: bool = True

# Before Python 3.14, aliases are rendered by monkey patching `__repr__`/`__str__` of
# the class of `typing.Union[...]` objects. From Python 3.14 on, a different
# mechanism is used (see the `else` branch further down).
if sys.version_info < (3, 14):  # pragma: specific no cover 3.14
    # Keep the original methods so that `deactivate_union_aliases` can restore them
    # and `set_union_alias` can format unions without aliases in error messages.
    _original_repr = _union_type.__repr__
    _original_str = _union_type.__str__

    # Registry: argument tuple of an aliased union (a 1-tuple when a single type was
    # registered) -> alias name to print in its place.
    _ALIASED_UNIONS: dict[tuple[Any, ...], str] = {}

    @wraps(_original_repr)
    def _new_repr(self: object) -> str:
        """Print a `typing.Union`, replacing all aliased unions by their aliased names.

        An aliased union matches when all of its arguments occur in this union. When
        one matching aliased union is strictly contained in another matching one, only
        the larger one is used. Each alias is printed at the position of the first
        argument of this union that it covers, and the covered arguments are dropped.

        Note that `functools.wraps` copies `__doc__` from the original `__repr__`, so
        this docstring is not what is visible at runtime.

        Returns:
            str: Representation of a `typing.Union` taking into account union aliases.
        """
        args = get_args(self)
        args_set = set(args)

        # Find all aliased unions contained in this union. The registry is walked from
        # the most recently registered entry to the oldest; this fixes the order in
        # which aliases that share an insertion position are printed.
        found_unions: list[set[Any]] = []
        found_positions: list[int] = []
        found_aliases: list[str] = []
        for union, alias in reversed(_ALIASED_UNIONS.items()):
            union_set = set(union)
            if union_set <= args_set:
                found = False
                # The alias goes where the first argument it covers appears.
                for i, arg in enumerate(args):
                    if arg in union_set:
                        found_unions.append(union_set)
                        found_positions.append(i)
                        found_aliases.append(alias)
                        found = True
                        break
                if not found:  # pragma: no cover
                    # This branch should never be reached: `union_set` is a subset of
                    # `args_set`, so (as long as it is non-empty) some argument matches.
                    raise AssertionError(
                        "Could not identify union. This should never happen."
                    )

        # Delete any unions that are contained in strictly bigger unions. We
        # check for strict inequality because any union includes itself. Iterating
        # backwards keeps the remaining indices valid after a deletion.
        for i in range(len(found_unions) - 1, -1, -1):
            for union_ in found_unions:
                if found_unions[i] < set(union_):
                    del found_unions[i]
                    del found_positions[i]
                    del found_aliases[i]
                    break

        # Create a set with all arguments of all found unions.
        found_args = set().union(*found_unions) if found_unions else set()

        # Build a mapping from original position to aliases to insert before it.
        inserts: dict[int, list[str]] = {}
        for pos, alias in zip(found_positions, found_aliases, strict=False):
            inserts.setdefault(pos, []).append(alias)
        # Interleave aliases at the appropriate positions.
        args = tuple(
            v for i, arg in enumerate(args) for v in (*inserts.pop(i, []), arg)
        )

        # Filter all elements of unions that are aliased. The inserted alias strings
        # are not in `found_args`, so they survive.
        args = tuple(arg for arg in args if arg not in found_args)

        # Generate a string representation. Alias names (strings) are printed
        # verbatim; genuine type arguments go through `typing`'s own formatting.
        args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args]
        # Like `typing` does, print `Optional` whenever possible.
        if len(args) == 2:
            if args[0] is type(None):  # noqa: E721
                return f"typing.Optional[{args_repr[1]}]"
            elif args[1] is type(None):  # noqa: E721
                return f"typing.Optional[{args_repr[0]}]"
        # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but
        # this might break code that parses how unions print.
        return "typing.Union[" + ", ".join(args_repr) + "]"

    @wraps(_original_str)
    def _new_str(self: object) -> str:
        """Render a `typing.Union` for `str()`; delegates entirely to :func:`_new_repr`.

        Returns:
            str: The alias-aware representation produced by :func:`_new_repr`.
        """
        rendered = _new_repr(self)
        return rendered

    def activate_union_aliases() -> None:
        """When printing `typing.Union`s, replace aliased unions by the aliased names.

        Installs the alias-aware `__repr__` and `__str__` on the class of
        `typing.Union` objects (a monkey patch) and records that aliases are active.
        """
        global _ALIASES_ARE_ACTIVE
        for dunder, replacement in (("__repr__", _new_repr), ("__str__", _new_str)):
            setattr(_union_type, dunder, replacement)
        _ALIASES_ARE_ACTIVE = True

    def deactivate_union_aliases() -> None:
        """Undo what :func:`activate_union_aliases` did. This restores the original
        `__repr__` and `__str__` for `typing.Union`.

        Registered aliases are kept, so calling :func:`activate_union_aliases` again
        brings them back."""
        global _ALIASES_ARE_ACTIVE
        _union_type.__repr__ = _original_repr  # type: ignore[method-assign]
        _union_type.__str__ = _original_str  # type: ignore[method-assign]
        _ALIASES_ARE_ACTIVE = False

    def set_union_alias(union: UnionT, alias: str) -> UnionT:
        """Change how a `typing.Union` is printed. This does not modify `union`.

        Args:
            union (type or type hint): A union, or a single type.
            alias (str): How to print `union`.

        Raises:
            RuntimeError: If a union with the same set of arguments already has a
                different alias.

        Returns:
            type or type hint: `union`.
        """
        args = get_args(union) if isinstance(union, _union_type) else (union,)
        args_set = set(args)
        for existing_union, existing_alias in _ALIASED_UNIONS.items():
            if set(existing_union) != args_set:
                continue
            if alias != existing_alias:
                if isinstance(union, _union_type):
                    union_str = _original_str(union)
                else:
                    union_str = repr(union)
                raise RuntimeError(
                    f"`{union_str}` already has alias `{existing_alias}`."
                )
            # The same union (possibly with its arguments in a different order) is
            # already registered under this alias. Adding it again under a second key
            # would make `_new_repr` print the alias twice, e.g. `Union[X, X]`.
            return union
        _ALIASED_UNIONS[args] = alias
        return union

else:  # pragma: specific no cover 3.13 3.12 3.11 3.10
    _ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {}

[docs] def activate_union_aliases() -> None: """When printing `typing.Union`, replace aliased unions by the aliased names.""" global _ALIASES_ARE_ACTIVE _ALIASES_ARE_ACTIVE = True
[docs] def deactivate_union_aliases() -> None: """When printing `typing.Union`s, print as normal.""" global _ALIASES_ARE_ACTIVE _ALIASES_ARE_ACTIVE = False
[docs] def set_union_alias(union: UnionT, /, alias: str) -> UnionT: """Register a union alias for use in plum's printing of dispatch signatures. This does not modify the given `union` in any way. It only controls how the union is printed when it is registered as a union alias. Args: union (type or type hint): A union type or a single type. alias (str): Alias name for the union. Returns: type or type hint: The given union. """ # Handle both union types and single types, matching pre-3.14 behaviour. args = get_args(union) if isinstance(union, _union_type) else (union,) # Check for conflicting aliases. for existing_union, existing_alias in _ALIASED_UNIONS.items(): if set(existing_union) == set(args) and alias != existing_alias.__name__: union_str = repr(union) raise RuntimeError( f"`{union_str}` already has alias `{existing_alias.__name__}`." ) new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc] _ALIASED_UNIONS[args] = new_alias return union
def _transform_union_alias(x: object, /) -> object:
    """Transform a Union type hint to a TypeAliasType if it's registered in the alias
    registry.

    This is used by plum's dispatch machinery to use aliased names for unions.

    Args:
        x (type or type hint): Type hint, potentially a Union.

    Returns:
        type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS` with a
            `TypeAliasType`, returns that TypeAliasType. Otherwise returns `x`
            unchanged.
    """
    # Fast path: if aliases are not active, return `x` immediately.
    if not _ALIASES_ARE_ACTIVE:
        return x
    # `TypeAliasType` instances are already transformed, so return as-is.
    if isinstance(x, TypeAliasType):
        return x
    # Only unions are substituted; single types registered via `set_union_alias`
    # are left alone here.
    if not isinstance(x, _union_type):
        return x
    args_set = set(get_args(x))
    if not args_set:
        return x
    # Look for a matching alias in the registry.
    for union_args, type_alias in _ALIASED_UNIONS.items():
        # Before Python 3.14 the registry stores alias *names* (strings), which are
        # not type hints; never hand one of those back in place of `x`.
        if set(union_args) == args_set and isinstance(type_alias, TypeAliasType):
            return type_alias
    # Not aliased, so return as-is.
    return x