Skip to content

Instantly share code, notes, and snippets.

@andiwand
Created February 20, 2020 10:24
Show Gist options
  • Save andiwand/5e2645107791370235613e9b5039d654 to your computer and use it in GitHub Desktop.
Save andiwand/5e2645107791370235613e9b5039d654 to your computer and use it in GitHub Desktop.
Flatten a nested Python list of numbers and NumPy arrays into a single 1-D array, and restore the original structure from it.
import numbers
import numpy as np
def types(x):
    """Describe the element types of a (possibly nested) structure.

    Args:
        x: an ``np.ndarray``, a ``numbers.Number``, or an iterable of such
            items, nested arbitrarily.

    Returns:
        ``x.dtype`` when ``x`` is exactly an ``np.ndarray`` (subclasses are
        not matched and are iterated instead), ``type(x)`` when ``x`` is a
        number, otherwise a list holding the descriptor of each element.
    """
    if type(x) is np.ndarray:
        return x.dtype
    if isinstance(x, numbers.Number):
        return type(x)
    # Anything else is treated as an iterable container; recurse per element.
    return [types(element) for element in x]
def shapes(x):
    """Describe the shapes found in a (possibly nested) structure.

    Args:
        x: an ``np.ndarray``, a ``numbers.Number``, or an iterable of such
            items, nested arbitrarily.

    Returns:
        The int ``1`` for a number, ``x.shape`` (a tuple) when ``x`` is
        exactly an ``np.ndarray``, otherwise a list with one shape descriptor
        per element of ``x``.
    """
    # An ndarray is never a numbers.Number, so checking numbers first is safe.
    if isinstance(x, numbers.Number):
        return 1
    if type(x) is np.ndarray:
        return x.shape
    return list(map(shapes, x))
def flatten(x):
    """Flatten a (possibly nested) structure into a single 1-D array.

    Args:
        x: an ``np.ndarray``, a ``numbers.Number``, or an iterable of such
            items, nested arbitrarily.

    Returns:
        - For an exact ``np.ndarray``: ``x.flatten()``, which keeps its dtype.
        - For a number: a one-element array ``np.array([x])``.
        - For an iterable: the concatenation of the flattened elements.
          Because it is seeded with an empty float64 array, this is at least
          float64 (complex inputs promote to complex). Large int64 values may
          therefore lose precision.
    """
    if type(x) is np.ndarray:
        return x.flatten()
    if isinstance(x, numbers.Number):
        return np.array([x])
    # Collect the pieces and concatenate once. The original called np.append
    # in a loop, which copied the growing result every time (O(n^2)). The
    # empty float64 seed preserves the dtype promotion of that version.
    pieces = [np.array([])]
    pieces.extend(flatten(element) for element in x)
    return np.concatenate(pieces)
def restore(x, shapes, types):
    """Rebuild a nested structure from a flat 1-D array.

    Walks the ``shapes``/``types`` descriptors and consumes consecutive
    elements of ``x``. The descriptors are presumably the ones returned by
    this module's ``shapes()``/``types()`` helpers, with ``x`` from
    ``flatten()``.

    Args:
        x: 1-D array holding all values in traversal order.
        shapes: a tuple (array shape), the int ``1`` (a scalar), or a list of
            such descriptors, nested arbitrarily.
        types: descriptor parallel to ``shapes``. For tuple leaves it is
            anything ``ndarray.astype`` accepts; for scalar leaves it is a
            callable applied to the value (e.g. ``int``, ``float``).

    Returns:
        ndarrays for tuple leaves, converted scalars for ``1`` leaves, and
        lists for list nodes.
    """
    def recursive(x, shapes, types):
        # Returns (restored_value, number_of_elements_of_x_consumed).
        if isinstance(shapes, tuple):
            # np.prod(()) alone yields the float 1.0, which is not a valid
            # slice bound, so 0-d arrays (shape ()) used to raise TypeError.
            # Force an integer element count.
            size = int(np.prod(shapes, dtype=np.int64))
            return x[:size].reshape(shapes).astype(types), size
        if shapes == 1:
            return types(x[0]), 1
        result = []
        offset = 0
        for s, t in zip(shapes, types):
            value, used = recursive(x[offset:], s, t)
            result.append(value)
            offset += used
        return result, offset
    return recursive(x, shapes, types)[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment