Created May 21, 2020 22:38 by timothymillar (gist c2fc61e5255f0682ccb56c2b624cd200)
Generic flatten function for Python
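```python
def flatten(obj, *args, depth=1):
    """Generic function for flattening nested sequences.

    Parameters
    ----------
    obj : object
        Sequence (or other object) to flatten.
    *args : type
        One or more container types whose instances are flattened;
        objects of any other type are yielded unchanged.
    depth : int
        Maximum number of nested levels to expand, counting ``obj``
        itself (the default of 1 yields the items of ``obj``
        unchanged). Use -1 for no limit.

    Returns
    -------
    generator
        Flattened sequence of values.

    Notes
    -----
    Including ``str`` in ``args`` with ``depth=-1`` raises a
    RecursionError, because each single-character string is itself
    a ``str`` that iterates to itself.

    Examples
    --------
    >>> seq = (1, 2, 3, [4, 5, (6,)], 7, [8, 9])
    >>> tuple(flatten(seq, tuple, list, depth=-1))
    (1, 2, 3, 4, 5, 6, 7, 8, 9)
    """
    if depth == 0:
        # Depth budget exhausted: emit the object as-is.
        yield obj
    elif isinstance(obj, args):
        # obj is one of the requested container types: expand it and
        # flatten each item with one less level of depth remaining.
        # A negative depth never reaches 0, so it means "no limit".
        for i in obj:
            yield from flatten(i, *args, depth=depth - 1)
    else:
        # Not a type to flatten: emit it unchanged.
        yield obj
```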
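The recursive generator opens one nested `flatten` call per level it expands, so input nested more deeply than Python's recursion limit (`sys.getrecursionlimit()`, 1000 by default in CPython) can raise RecursionError even when no strings are involved. As a minimal sketch, assuming the same `*args`/`depth` semantics are wanted, the recursion can be replaced by an explicit stack of iterators. `flatten_iterative` is a hypothetical name and is not part of the gist. It does not fix the string case: with `str` among the types and `depth=-1` it never terminates instead of raising.

```python
from typing import Any, Iterator


def flatten_iterative(obj: Any, *args: type, depth: int = 1) -> Iterator[Any]:
    """Explicit-stack variant of ``flatten`` (hypothetical, not in the gist).

    Instances of ``args`` are expanded until ``depth`` levels have been
    flattened, counting ``obj`` itself; a negative depth means no limit.
    """
    # Each entry pairs an iterator with the depth remaining for its items.
    stack = [(iter([obj]), depth)]
    while stack:
        iterator, remaining = stack[-1]
        try:
            item = next(iterator)
        except StopIteration:
            stack.pop()  # this container is exhausted; resume its parent
            continue
        if remaining != 0 and isinstance(item, args):
            # Descend one level by pushing an iterator instead of recursing.
            stack.append((iter(item), remaining - 1))
        else:
            yield item


# depth counts obj itself, so depth=2 expands exactly one nested level.
print(list(flatten_iterative([1, [2, [3]]], list, depth=2)))  # [1, 2, [3]]
```

Because the loop keeps its own stack on the heap, nesting depth is limited only by memory rather than by the interpreter's recursion limit.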