Skip to content

Instantly share code, notes, and snippets.

@gmanny
Created August 5, 2022 10:07
Show Gist options
  • Save gmanny/c809f06c49a2ce96017be181a0c29c9e to your computer and use it in GitHub Desktop.
Save gmanny/c809f06c49a2ce96017be181a0c29c9e to your computer and use it in GitHub Desktop.
Python generic type argument resolution
# Annotations are quoted: GenericAlias and TypeVar are only imported further down this
# file, so evaluating them when this ``def`` executes would raise NameError.
def resolve_type_argument(
    query_type: "type", target_type: "type | GenericAlias", argument: "TypeVar"
) -> "type | TypeVar":
    """
    Resolve a single TypeVar of the generic type ``query_type`` as supplied by ``target_type``.

    :param query_type: Generic class declaring ``argument`` among its type parameters,
        supplied without args (e.g. ``Mapping``, not ``Mapping[KT, VT]``).
    :param target_type: Class or parameterized alias deriving from ``query_type``.
    :param argument: One of ``query_type.__parameters__``.
    :return: The argument ``target_type`` supplies for ``argument``; it may itself be a TypeVar.
    :raises KeyError: If no argument was resolved for ``argument`` (e.g. it is not a type
        parameter of ``query_type``).
    :raises ValueError: Propagated from ``resolve_type_arguments`` when ``target_type`` is
        not a subclass of ``query_type``.
    """
    type_params = query_type.__parameters__
    type_arguments = resolve_type_arguments(query_type, target_type)
    # Parameters and arguments are positionally aligned.
    params_to_args = dict(zip(type_params, type_arguments))
    if argument not in params_to_args:
        raise KeyError(f"no argument resolved for {argument!r} of {query_type!r}")
    return params_to_args[argument]
def resolve_type_arguments(
    query_type: "type", target_type: "type | GenericAlias"
) -> "tuple[type | TypeVar, ...]":
    """
    Resolve the type arguments of ``query_type`` as supplied by ``target_type`` or any of its bases.

    This code was adapted from https://stackoverflow.com/a/69862817/579817.

    Works tail-recursively: the bases declared directly on ``target_type``'s class are
    examined left to right, and the first one deriving from ``query_type`` is
    re-parameterized with the arguments ``target_type`` supplies before recursing into it,
    until ``query_type`` itself is reached.

    :param query_type: Must be supplied without args (e.g. ``Mapping``, not ``Mapping[KT, VT]``).
    :param target_type: Either a parameterized alias (e.g. ``Mapping[KT, T]`` or
        ``Mapping[str, int]``) or a bare class, in which case its own TypeVars are
        propagated unresolved.
    :return: A tuple of the arguments ``target_type`` supplies for the type parameters of
        ``query_type``. These arguments may themselves be TypeVars. The tuple is empty when
        ``query_type`` has no type parameters, or when no declared base leads to it
        (e.g. a subclass relationship that exists only through ABC registration).
    :raises ValueError: If ``target_type`` is not a subclass of ``query_type``.
    """
    # get_origin/get_args are not imported at module level in this file.
    from typing import get_args, get_origin

    target_origin = get_origin(target_type)
    if target_origin is None:
        if target_type is query_type:
            # Unparameterized query type: its own TypeVars are the "arguments".
            # Non-generic classes have no __parameters__, hence the default.
            return getattr(target_type, "__parameters__", ())
        target_origin = target_type
        supplied_args = ()
    else:
        supplied_args = get_args(target_type)
        if target_origin is query_type:
            return supplied_args

    # Map the target's own type parameters to the supplied arguments. __parameters__ is in
    # declaration order, which honours an explicit Generic[...] ordering.
    params_to_args = dict(zip(getattr(target_origin, "__parameters__", ()), supplied_args))

    # Only the bases declared on this class itself. __orig_bases__ is set only when a class
    # lists subscripted generics; a plain attribute lookup would otherwise return an
    # ancestor's tuple and silently drop this class's other bases.
    bases = target_origin.__dict__.get("__orig_bases__", target_origin.__bases__)
    for each_base in bases:
        each_origin = get_origin(each_base)
        if each_origin is None:
            # Plain class base: it carries no arguments of its own, just descend.
            if issubclass(each_base, query_type):
                return resolve_type_arguments(query_type, each_base)
        elif issubclass(each_origin, query_type):
            base_params = each_base.__parameters__
            # Only a base that still has open TypeVars can (and may) be subscripted;
            # subscripting a fully concrete alias such as G[int] raises TypeError.
            if params_to_args and base_params:
                # each_base[args] forwards the args into each_base's TypeVar slots; it is
                # not quite equivalent to GenericAlias(each_origin, resolved_args).
                each_base = each_base[tuple(params_to_args[each] for each in base_params)]
            return resolve_type_arguments(query_type, each_base)

    if not issubclass(target_origin, query_type):
        raise ValueError(f"{target_type} is not a subclass of {query_type}")
    return ()
from types import GenericAlias
from typing import Generic, TypeVar
from resolve_type_arguments import resolve_type_argument, resolve_type_arguments
T = TypeVar("T")
U = TypeVar("U")
Q = TypeVar("Q")
R = TypeVar("R")
W = TypeVar("W")
X = TypeVar("X")
Y = TypeVar("Y")
Z = TypeVar("Z")
# Four-parameter generic at the root of the first fixture hierarchy.
class A(Generic[T, U, Q, R]):
    pass


# Three-parameter generic, used both as a base of B and as an argument to A.
class NestedA(Generic[T, U, Q]):
    pass


# Single-parameter generic that B binds to its own R.
class NestedB(Generic[T]):
    pass


# Plain, non-generic base mixed into B.
class NoParams:
    pass


# Repeats U inside NestedA and nests NestedA[Q, Q, Q] within A's arguments; its own
# type parameters are (U, Q, R), in order of first appearance.
class B(NoParams, NestedA[U, Q, U], A[int, NestedA[Q, Q, Q], Q, U], NestedB[R]):
    pass


# Binds B's Q and R, leaving T open.
class C(B[T, str, int]):
    pass


# Binds C's remaining parameter, making the chain fully concrete.
class D(C[int]):
    pass


# Plain subclasses that inherit D's bindings without adding any type arguments.
class E(D):
    pass


class F(E):
    pass
# Two independent single-parameter generics.
class G(Generic[T]):
    pass


class H(Generic[T]):
    pass


# Binds G's parameter; J reaches G only through this plain class.
class I(G[int]):
    pass


# Multiple inheritance: a plain base (I) followed by a parameterized one (H[str]).
class J(I, H[str]):
    pass
_INT = "<class 'int'>"
_STR = "<class 'str'>"

# Shared expectations for both tests: (query_type, target_type, expected reprs of the
# resolved arguments, in query_type's parameter order). Expected strings are written
# against "__main__" and rewritten to the running module's name before comparison.
# Examples taken from https://stackoverflow.com/a/69862817/579817
_TYPE_ARGUMENT_CASES = (
    (A, A, ("~T", "~U", "~Q", "~R")),
    (A, A[W, X, Y, Z], ("~W", "~X", "~Y", "~Z")),
    (A, B, (_INT, "__main__.NestedA[~Q, ~Q, ~Q]", "~Q", "~U")),
    (A, B[W, X, Y], (_INT, "__main__.NestedA[~X, ~X, ~X]", "~X", "~W")),
    (B, B, ("~U", "~Q", "~R")),
    (B, B[W, X, Y], ("~W", "~X", "~Y")),
    (A, C, (_INT, "__main__.NestedA[str, str, str]", _STR, "~T")),
    (A, C[W], (_INT, "__main__.NestedA[str, str, str]", _STR, "~W")),
    (B, C, ("~T", _STR, _INT)),
    (B, C[W], ("~W", _STR, _INT)),
    (C, C, ("~T",)),
    (C, C[W], ("~W",)),
    (A, D, (_INT, "__main__.NestedA[str, str, str]", _STR, _INT)),
    (B, D, (_INT, _STR, _INT)),
    (C, D, (_INT,)),
    (D, D, ()),
    (A, E, (_INT, "__main__.NestedA[str, str, str]", _STR, _INT)),
    (B, E, (_INT, _STR, _INT)),
    (C, E, (_INT,)),
    (D, E, ()),
    (E, E, ()),
    (A, F, (_INT, "__main__.NestedA[str, str, str]", _STR, _INT)),
    (B, F, (_INT, _STR, _INT)),
    (C, F, (_INT,)),
    (D, F, ()),
    (E, F, ()),
    (F, F, ()),
    (G, J, (_INT,)),
)


def _expected_reprs(verify_strs):
    """Return *verify_strs* with the ``__main__`` module prefix replaced by this module's name."""
    return tuple(each.replace("__main__", __name__) for each in verify_strs)


def test_resolve_type_arguments():
    """
    Various test cases for resolve_type_arguments
    Taken from examples in https://stackoverflow.com/a/69862817/579817
    """
    for query_type, target_type, verify_strs in _TYPE_ARGUMENT_CASES:
        arg_tuple = resolve_type_arguments(query_type, target_type)
        # str() of a tuple shows repr() of each element, so compare element-wise.
        actual = tuple(repr(arg) for arg in arg_tuple)
        assert actual == _expected_reprs(verify_strs), (query_type, target_type, arg_tuple)

    # A target outside the query type's hierarchy is rejected.
    try:
        resolve_type_arguments(G, H)
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError: H is not a subclass of G")


def test_resolve_type_argument():
    "Test resolving a single type argument"
    for query_type, target_type, verify_strs in _TYPE_ARGUMENT_CASES:
        parameters = query_type.__parameters__
        expected = _expected_reprs(verify_strs)
        assert len(parameters) == len(expected), (query_type, target_type)
        for parameter, verify_str in zip(parameters, expected):
            argument = resolve_type_argument(query_type, target_type, parameter)
            assert str(argument) == verify_str, (query_type, target_type, parameter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment