-
-
Save rayansostenes/baabf3d783641be3fccb2b92fbf466bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import TYPE_CHECKING, Iterable, TypeVar, Union, cast

if TYPE_CHECKING:
    # Only static checkers follow this branch; they special-case reveal_type
    # and report the inferred type of its argument.
    from typing import reveal_type  # type: ignore
else:

    def reveal_type(x):
        """Runtime stand-in for ``reveal_type``: return *x* unchanged."""
        return x
# Leaf type of a nested structure.
_T = TypeVar("_T")

# An iterable whose items are either leaves or further (arbitrarily deep)
# iterables of the same shape; the string is a forward reference to the alias.
UnfoldIterable = Iterable[Union[_T, "UnfoldIterable[_T]"]]
# A list whose items are either leaves or lists of leaves.
UnfoldList = list[Union[_T, list[_T]]]

# Without an explicit annotation here, pyright infers test_list as
# list[Unknown].
test_list: UnfoldList = [1, 2, [3], [4, [5, 6]], [[7]], 8]
result_list = [1, 2, 3, 4, 5, 6, 7, 8]

test_list2: UnfoldList = ["one", "two", ["three"], ["four", ["five", "six"]], [["seven"]], "eight"]
result_list2 = ["one", "two", "three", "four", "five", "six", "seven", "eight"]
def unfold(iterable: UnfoldIterable[_T]) -> Iterable[_T]:
    """Lazily flatten an arbitrarily nested iterable, depth-first.

    Items are yielded in their original left-to-right order. ``str`` and
    ``bytes`` values are treated as leaves and yielded whole, never split
    into characters or integer code units.

    Args:
        iterable: An iterable whose items are leaves or nested iterables.

    Yields:
        Each leaf value found at any nesting depth.
    """
    for element in iterable:
        if isinstance(element, (str, bytes)):
            # Text and byte strings are iterable, but they are leaves here.
            # Recursing into a str would never terminate (each character is
            # itself a one-character str), and iterating bytes would yield
            # ints instead of the original value.
            yield cast(_T, element)
        elif isinstance(element, Iterable):
            # Nested container: flatten it in place, one recursion level per
            # level of nesting.
            yield from unfold(element)
        else:
            yield element
def test_unfold():
    """Check that ``unfold`` flattens both sample lists to their expected form.

    The ``reveal_type`` calls are for static checkers; at runtime the
    fallback simply returns its argument.
    """
    assert result_list == list(unfold(test_list))
    assert result_list2 == list(unfold(test_list2))
    reveal_type(unfold(test_list))  # Type of "unfold(test_list)" is "Iterable[int]"
    reveal_type(unfold(test_list2))  # Type of "unfold(test_list2)" is "Iterable[str]"
# Run the self-check when the module is executed as a script.
if __name__ == "__main__":
    test_unfold()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment