
@nirum
Last active February 3, 2018 23:58
Unfold a multi-dimensional NumPy array along an arbitrary axis
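import numpy as np


def unfold(arr, ax):
    """
    Unfold (matricize) an array along the given axis.

    Returns a 2-D array with arr.shape[ax] rows; the remaining axes are
    flattened in C (row-major) order to form the columns.
    """
    # Move `ax` to the front, then collapse all other axes into one.
    return np.moveaxis(arr, ax, 0).reshape(arr.shape[ax], -1)


def test_unfold():
    """
    Test unfold on a 2 x 3 x 4 tensor along each axis.
    """
    tensor = np.arange(24).reshape(2, 3, 4)

    def compare_unfolding(tensor, ax, matrix):
        return np.array_equal(unfold(tensor, ax), matrix)

    # Axis 0: each row is one 3 x 4 slice, flattened.
    assert compare_unfolding(tensor, 0, np.vstack((np.arange(12), np.arange(12, 24))))
    # Axis 1: row j joins row j of both 3 x 4 slices.
    assert compare_unfolding(tensor, 1, np.hstack((np.arange(12).reshape(3, 4), np.arange(12, 24).reshape(3, 4))))
    # Axis 2: row k holds tensor[i, j, k] for all (i, j) in C order.
    assert compare_unfolding(tensor, 2, np.arange(24).reshape(6, 4).T)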
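The following is a minimal sketch, not part of the original gist, that checks two properties of this unfolding. First, each column of the unfolded matrix is a fiber of the array along `ax` (all other indices fixed), with the columns enumerated in C order. Second, the unfolding is lossless. `fold` and `check_fibers` are hypothetical helpers written for illustration, and `fold` assumes you still know the original array's shape.

```python
import numpy as np


def unfold(arr, ax):
    # Same operation as the gist: move `ax` to the front, flatten the rest in C order.
    return np.moveaxis(arr, ax, 0).reshape(arr.shape[ax], -1)


def fold(mat, ax, shape):
    # Hypothetical inverse of unfold (not in the original gist).
    # `shape` is the shape of the array that was unfolded.
    ax = ax % len(shape)  # normalize negative axes
    moved_shape = (shape[ax],) + tuple(s for i, s in enumerate(shape) if i != ax)
    # Undo the reshape, then move the leading axis back to position `ax`.
    return np.moveaxis(mat.reshape(moved_shape), 0, ax)


def check_fibers(arr, ax):
    """Return True if every column of unfold(arr, ax) is a fiber along `ax`."""
    mat = unfold(arr, ax)
    ax = ax % arr.ndim
    other_shape = arr.shape[:ax] + arr.shape[ax + 1:]
    # np.ndindex walks the remaining indices in C order, matching the columns.
    for col, idx in enumerate(np.ndindex(*other_shape)):
        # Insert a full slice at position `ax` to select the fiber.
        full_idx = idx[:ax] + (slice(None),) + idx[ax:]
        if not np.array_equal(mat[:, col], arr[full_idx]):
            return False
    return True


if __name__ == "__main__":
    tensor = np.arange(24).reshape(2, 3, 4)
    for ax in range(tensor.ndim):
        mat = unfold(tensor, ax)
        assert check_fibers(tensor, ax)
        assert np.array_equal(fold(mat, ax, tensor.shape), tensor)
        print(ax, mat.shape)  # prints 0 (2, 12), then 1 (3, 8), then 2 (4, 6)
```

Because only the shape is needed to invert the reshape, `fold` simply replays the two steps of `unfold` in reverse order.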