Skip to content

Instantly share code, notes, and snippets.

@thepabloaguilar
Created July 24, 2021 04:03
Show Gist options
  • Save thepabloaguilar/5fe6a7279d959e73a351df04407f6197 to your computer and use it in GitHub Desktop.
Save thepabloaguilar/5fe6a7279d959e73a351df04407f6197 to your computer and use it in GitHub Desktop.
Playing with unions of protocols
from abc import ABCMeta
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
Iterable,
Protocol,
Union,
get_args,
get_origin,
runtime_checkable,
)
# PROTOCOLS
@runtime_checkable
class OtherIterable(Protocol):
    """Structural type for objects that expose ``__len__`` and ``__getitem__``.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    OtherIterable)`` is true whenever *obj* provides both methods. Defining
    ``__iter__`` is not required.
    """

    def __len__(self) -> int:
        """Return how many items the object holds."""
        ...

    def __getitem__(self, item: Any):
        """Return the element selected by *item*."""
        ...
class UnionProtocol(ABCMeta):
    """Metaclass whose classes accept instances of any of several protocols.

    A class built with this metaclass lists, in its ``protocols`` attribute,
    the protocols (or plain classes) it stands for. ``isinstance(obj, cls)``
    is then true when *obj* matches at least one of them. Ordinary ABC
    semantics are kept as well, so instances of real subclasses of ``cls``
    and of classes registered via ``cls.register(...)`` also pass.
    """

    # Set on each class that uses this metaclass (i.e. on the metaclass's
    # own instances), so it is a plain attribute annotation, not a ClassVar.
    protocols: tuple[Any, ...]

    def __instancecheck__(cls, instance: Any) -> bool:
        """Return whether *instance* counts as an instance of *cls*.

        Args:
            instance: The first argument passed to ``isinstance``.

        Returns:
            ``True`` if ABCMeta's own check succeeds (real or registered
            subclass), or if *instance* matches any entry of
            ``cls.protocols``. A class that declares no ``protocols`` only
            gets the ABCMeta behaviour.
        """
        # ABCMeta's check covers genuine and registered subclasses, which
        # matching against the protocols alone would miss.
        if super().__instancecheck__(instance):
            return True
        return any(
            isinstance(instance, protocol)
            for protocol in getattr(cls, "protocols", ())
        )
class ExtendedItarable(metaclass=UnionProtocol):
    """Accepts anything that is an ``Iterable`` or an ``OtherIterable``.

    ``isinstance(obj, ExtendedItarable)`` is answered by the
    ``UnionProtocol`` metaclass, which tests *obj* against each entry of
    ``protocols``.
    """

    protocols = (
        Iterable,
        OtherIterable,
    )
# FUNCTIONS
def isinstanceprotocol(instance: Any, ref: Any) -> bool:
    """Return whether *instance* matches *ref*, expanding union forms.

    ``isinstance`` rejects ``typing.Union[...]`` on Python 3.9, including when
    the union is nested inside a tuple. This helper flattens such forms and
    checks each member separately. Anything else goes straight to
    ``isinstance``.

    Args:
        instance: The object to test.
        ref: A class, a ``typing.Union`` / ``X | Y`` union, or a (possibly
            nested) tuple of any of these.

    Returns:
        ``True`` if *instance* matches at least one member of *ref*.

    Raises:
        TypeError: Propagated from ``isinstance`` when a member of *ref*
            cannot be used in instance checks (e.g. ``list[int]``).
    """
    import types  # local import keeps this block self-contained

    origin = get_origin(ref)
    # PEP 604 unions (``X | Y``) report ``types.UnionType`` as their origin
    # before Python 3.14. That type does not exist on 3.9, so fall back to
    # ``Union`` there.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return any(isinstanceprotocol(instance, arg) for arg in get_args(ref))
    if isinstance(ref, tuple):
        return any(isinstanceprotocol(instance, item) for item in ref)
    return isinstance(instance, ref)
# CLASSES & TYPES
@dataclass
class Card:
    """A single playing card, identified by its rank and its suit."""

    rank: str  # e.g. '2', '10', 'J' or 'A' (see FrenchDeck.ranks)
    suit: str  # e.g. 'spades' (see FrenchDeck.suits)
class FrenchDeck:
    """A 52-card deck supporting ``len()`` and indexing.

    Cards are grouped by suit (in ``suits`` order) and, within each suit,
    ordered by rank (in ``ranks`` order). The class defines no ``__iter__``,
    so it provides only the ``__len__`` / ``__getitem__`` pair.
    """

    # Class-level lists shared by every instance; this class only reads them.
    ranks: list[str] = [str(n) for n in range(2, 11)] + ['J', 'Q', 'K', 'A']
    suits: list[str] = ['spades', 'diamonds', 'clubs', 'hearts']

    _cards: list[Card]

    def __init__(self) -> None:
        """Build every suit/rank combination, suit-major."""
        self._cards = [
            Card(rank=rank, suit=suit)
            for suit in self.suits
            for rank in self.ranks
        ]

    def __len__(self) -> int:
        """Return the number of cards in the deck."""
        return len(self._cards)

    def __getitem__(self, position: int) -> Card:
        """Return the card at *position* (delegated to the backing list)."""
        return self._cards[position]
# The protocols behind ``ExtendedItarable``, expressed as a ``typing.Union``
# so they can be passed to ``isinstanceprotocol``.
ExtendedItarableUnion = Union[Iterable, OtherIterable]

if __name__ == '__main__':
    my_list = [1, 2, 3, 4]
    deck = FrenchDeck()
    # Each sample is paired with the message printed when it passes a check.
    samples = (
        (my_list, '"my_list" is iterable'),
        (deck, '"deck" is iterable'),
    )
    # Check through the `UnionProtocol` metaclass (plain isinstance).
    for obj, message in samples:
        if isinstance(obj, ExtendedItarable):
            print(message)
    # Check through the rewritten `isinstanceprotocol` function.
    for obj, message in samples:
        if isinstanceprotocol(obj, ExtendedItarableUnion):
            print(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment