Last active
August 29, 2015 14:11
-
-
Save shoyer/3e36af0a8196c82d4b42 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy.lib.stride_tricks import as_strided | |
def broadcast_to(array, shape):
    """Expand an array to ``shape`` according to NumPy broadcasting rules.

    The result is a view built with ``as_strided``: no data is copied, and
    every broadcast axis has stride 0, so it aliases the same elements.

    Parameters
    ----------
    array : array_like
        Input data; converted with ``np.asarray``.
    shape : int or sequence of int
        Target shape. A bare integer ``n`` is treated as ``(n,)``.

    Returns
    -------
    numpy.ndarray
        A view of ``array`` with the requested shape. The view is writable
        whenever ``array`` is, but a write through it changes every broadcast
        copy of the underlying element at once.

    Raises
    ------
    ValueError
        If ``shape`` has a negative entry, has fewer dimensions than
        ``array``, or is otherwise incompatible with ``array.shape``.
    """
    array = np.asarray(array)
    try:
        shape = tuple(int(size) for size in shape)
    except TypeError:
        # A scalar target such as ``5`` means the 1-d shape ``(5,)``.
        shape = (int(shape),)

    if any(size < 0 for size in shape):
        raise ValueError('all elements of broadcast shape must be non-negative')

    ndim_extra = len(shape) - array.ndim
    if ndim_extra < 0:
        raise ValueError(
            'input operand has more dimensions than allowed by the target '
            'shape: %r -> %r' % (array.shape, shape))

    # New leading axes repeat the whole array, so they advance by 0 bytes.
    strides = [0] * ndim_extra
    trailing = zip(array.shape, array.strides, shape[ndim_extra:])
    for size, stride, target in trailing:
        if size == 1:
            # A length-1 axis is stretched by never advancing along it.
            strides.append(0)
        elif size == target:
            strides.append(stride)
        else:
            raise ValueError(
                'array of shape %r cannot be broadcast to shape %r'
                % (array.shape, shape))
    return as_strided(array, shape=shape, strides=strides)
def broadcast_shape(*args):
    """Return the shape produced by broadcasting all ``args`` together.

    Each argument may be anything ``np.shape`` accepts (arrays, nested lists,
    scalars). Shapes are right-aligned; along each axis the sizes must be
    equal or one of them must be 1. Zero-size dimensions and any number of
    arguments are supported. With no arguments the result is ``()``.

    Raises
    ------
    ValueError
        If the argument shapes are not mutually broadcastable.
    """
    result = ()
    for arg in args:
        shape = tuple(np.shape(arg))
        ndim = max(len(result), len(shape))
        # Right-align the two shapes by padding the shorter with leading 1s.
        left = (1,) * (ndim - len(result)) + result
        right = (1,) * (ndim - len(shape)) + shape
        merged = []
        for a, b in zip(left, right):
            if a == b or b == 1:
                merged.append(a)
            elif a == 1:
                merged.append(b)
            else:
                raise ValueError(
                    'operands could not be broadcast together: shapes %r '
                    'and %r' % (result, shape))
        result = tuple(merged)
    return result
def broadcast_arrays(*args):
    """Broadcast any number of array-like arguments against each other.

    The common shape comes from ``broadcast_shape``. Each argument is then
    expanded to that shape with ``broadcast_to``.

    Returns
    -------
    list of numpy.ndarray
        One broadcast view per argument, in the same order as ``args``.
    """
    # Convert once up front so list/scalar inputs are not materialised twice
    # (once for the shape computation and again inside broadcast_to).
    arrays = [np.asarray(arg) for arg in args]
    shape = broadcast_shape(*arrays)
    return [broadcast_to(array, shape) for array in arrays]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I think broadcast_to and broadcast_arrays could also be one-liners: