Skip to content

Instantly share code, notes, and snippets.

@cleoold
Last active December 31, 2020 07:34
Show Gist options
  • Save cleoold/48f244eee94a01f7082613347b916201 to your computer and use it in GitHub Desktop.
Save cleoold/48f244eee94a01f7082613347b916201 to your computer and use it in GitHub Desktop.
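Example: overloading a method on the type of its callable argument (`Callable[[T], R]` vs. `Callable[[T, int], R]`) and dispatching at runtime on the callable's arity (tested on Python 3.7+).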
from __future__ import annotations

from inspect import getattr_static, isclass, signature
from typing import Callable, Generic, Iterable, List, TypeVar, overload
# Element type of the wrapped sequence (declared covariant).
TSource_co = TypeVar('TSource_co', covariant=True)
# Element type produced by a selector passed to Enumerable.select.
TResult = TypeVar('TResult')


class Enumerable(Generic[TSource_co]):
    """LINQ-style wrapper around an iterable.

    ``select`` does not process the elements itself: it wraps the source in a
    generator, so selectors run only when the result is consumed (for example
    by ``to_list``).
    """

    def __init__(self, iter_: Iterable[TSource_co]) -> None:
        super().__init__()
        self._iter = iter_

    @overload
    def select(self, selector: Callable[[TSource_co], TResult]) -> Enumerable[TResult]: ...
    @overload
    def select(self, selector: Callable[[TSource_co, int], TResult]) -> Enumerable[TResult]: ...
    def select(self, selector: Callable) -> Enumerable:
        """Project each element through *selector*, lazily.

        How *selector* is called depends on how many positional parameters it
        declares (not counting an implicit ``self``/``cls``). Exactly one means
        ``selector(elem)``; any other count means ``selector(elem, index)``,
        where ``index`` is the 0-based position of ``elem``. Parameters with
        defaults still count, so ``lambda e, i=0: ...`` receives the index.

        Plain functions, lambdas, bound methods, class/static methods, classes
        (inspected through their ``__init__``) and instances with a
        ``__call__`` method are supported. Some callables have no Python code
        object to inspect: builtins such as ``len`` or ``str.upper``,
        ``functools.partial`` objects, and classes without a Python
        ``__init__`` such as ``int``. These are inspected with
        ``inspect.signature`` instead. When no signature is available either,
        they are called as ``selector(elem)``.

        Raises:
            TypeError: if *selector* is not callable.
        """
        if not callable(selector):
            raise TypeError(
                f'selector must be callable, not {type(selector).__name__}')
        if self._selector_takes_index(selector):
            wrapped = (selector(elem, i) for i, elem in enumerate(self._iter))
        else:
            wrapped = (selector(elem) for elem in self._iter)
        return Enumerable(wrapped)

    @staticmethod
    def _selector_takes_index(selector: Callable) -> bool:
        """Return True if *selector* should be called as ``selector(elem, index)``."""
        # Find the Python function that will actually run, and whether Python
        # passes it an implicit first argument (self/cls) the caller doesn't supply.
        if hasattr(selector, '__code__'):
            # Function or lambda; bound methods (incl. classmethods) carry __self__.
            target = selector
            implicit_first = hasattr(selector, '__self__')
        elif isclass(selector):
            # Calling a class runs __init__(self, ...).
            target = selector.__init__
            implicit_first = True
        else:
            # Callable instance: only a staticmethod __call__ gets no implicit argument.
            target = selector.__call__
            implicit_first = not isinstance(getattr_static(selector, '__call__'), staticmethod)

        code = getattr(target, '__code__', None)
        if code is not None:
            return code.co_argcount != (2 if implicit_first else 1)

        # No Python bytecode to inspect (builtins, functools.partial, classes
        # whose __init__ is inherited from a C base): ask inspect.signature.
        try:
            sig = signature(selector)
        except (TypeError, ValueError):
            # Some builtins expose no signature: assume the one-argument form.
            return False
        try:
            sig.bind(None)  # callable with exactly one positional argument?
        except TypeError:
            return True
        return False

    def to_list(self) -> List[TSource_co]:
        """Return the elements as a new list.

        This consumes the underlying iterable. The result of ``select`` wraps a
        one-shot generator, so converting it a second time yields ``[]``.
        """
        return list(self._iter)
def Test_Select():
    """Exercise ``Enumerable.select`` with every kind of selector it supports.

    Each selector flavour comes in two variants. The ``...1`` variant takes
    only the element; the ``...2`` variant takes the element and its index.
    A table pairs every selector with the list it must produce. A class
    used as a selector (``ctor``) is checked separately.
    """
    class instance1:
        def __call__(self, e: str) -> str:
            return f'{e}_'

        def method(self, e: str) -> str:
            return f'{e}_'

    class instance2:
        def __call__(self, e: str, i: int) -> str:
            return f'{e}_{i}'

        def method(self, e: str, i: int) -> str:
            return f'{e}_{i}'

    class class1:
        @classmethod
        def __call__(cls, e: str) -> str:
            return f'{e}_'

        @classmethod
        def method(cls, e: str) -> str:
            return f'{e}_'

    class class2:
        @classmethod
        def __call__(cls, e: str, i: int) -> str:
            return f'{e}_{i}'

        @classmethod
        def method(cls, e: str, i: int) -> str:
            return f'{e}_{i}'

    class static1:
        @staticmethod
        def __call__(e: str) -> str:
            return f'{e}_'

        @staticmethod
        def method(e: str) -> str:
            return f'{e}_'

    class static2:
        @staticmethod
        def __call__(e: str, i: int) -> str:
            return f'{e}_{i}'

        @staticmethod
        def method(e: str, i: int) -> str:
            return f'{e}_{i}'

    class ctor:
        def __init__(self, e: str) -> None:
            self.e = e

        def __repr__(self) -> str:
            return f'<ctor {self.e}>'

    source = Enumerable(['a', 'b', 'c', 'd'])
    suffixed = ['a_', 'b_', 'c_', 'd_']
    indexed = ['a_0', 'b_1', 'c_2', 'd_3']

    # (selector, expected output): one-argument selectors produce `suffixed`,
    # two-argument selectors produce `indexed`.
    cases = [
        (lambda e: f'{e}_', suffixed),
        (lambda e, i: f'{e}_{i}', indexed),
        (instance1(), suffixed),
        (instance1().method, suffixed),
        (instance2(), indexed),
        (instance2().method, indexed),
        (class1(), suffixed),
        (class1().method, suffixed),
        (class1.method, suffixed),
        (class2(), indexed),
        (class2().method, indexed),
        (class2.method, indexed),
        (static1(), suffixed),
        (static1().method, suffixed),
        (static1.method, suffixed),
        (static2(), indexed),
        (static2().method, indexed),
        (static2.method, indexed),
    ]
    for selector, expected in cases:
        assert source.select(selector).to_list() == expected

    reprs = ['<ctor a>', '<ctor b>', '<ctor c>', '<ctor d>']
    # A class is a valid selector: it is called once per element ...
    assert source.select(ctor).select(ctor.__repr__).to_list() == reprs
    # ... which is equivalent to spelling the calls out with lambdas.
    assert source.select(lambda e: ctor(e)).select(lambda e: repr(e)).to_list() == reprs
if __name__ == '__main__':
    # Run the self-test only when executed as a script, not when imported.
    Test_Select()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment