@shoyer
Last active August 29, 2015 14:11
import numpy as np
from numpy.lib.stride_tricks import as_strided


def broadcast_to(array, shape):
    """Expand a numpy.ndarray to a new shape according to broadcasting rules
    """
    array = np.asarray(array)
    # will raise ValueError if shapes incompatible
    np.nditer((array,), itershape=shape)
    strides = ([0] * (len(shape) - array.ndim)
               + [0 if size == 1 else stride
                  for size, stride
                  in zip(array.shape, array.strides)])
    return as_strided(array, shape=shape, strides=strides)


def broadcast_shape(*args):
    return np.nditer(args, flags=['multi_index']).shape


def broadcast_arrays(*args):
    shape = broadcast_shape(*args)
    return [broadcast_to(array, shape) for array in args]
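
The stride arithmetic in broadcast_to can be checked in isolation. The following is a minimal sketch, not part of the gist: `zero_stride_view` is a hypothetical helper that repeats only the stride computation (it skips the shape validation that the `np.nditer` call performs), and the comparison assumes a NumPy version that ships `np.broadcast_to` and `np.shares_memory` (1.11 or later).

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def zero_stride_view(array, shape):
    """Hypothetical helper: same stride logic as broadcast_to, no validation."""
    array = np.asarray(array)
    # New leading axes get stride 0, so every index reads the same memory.
    leading = [0] * (len(shape) - array.ndim)
    # Existing axes of length 1 are stretched by setting their stride to 0;
    # all other axes keep their original stride.
    existing = [0 if size == 1 else stride
                for size, stride in zip(array.shape, array.strides)]
    return as_strided(array, shape=shape, strides=leading + existing)


column = np.array([[1], [2], [3]])          # shape (3, 1)
view = zero_stride_view(column, (2, 3, 4))  # no data is copied

# Only the middle axis actually advances through memory.
print(view.shape, view.strides)
# The result matches NumPy's built-in broadcasting and shares the buffer.
print(np.array_equal(view, np.broadcast_to(column, (2, 3, 4))))
print(np.shares_memory(view, column))
```

Because the view aliases the original buffer, writing through it would touch the same element many times, so it is safest to treat the result as read-only.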

mwiebe commented Dec 12, 2014

I think broadcast_to and broadcast_arrays could also be one-liners:

def broadcast_to(array, shape):
    return np.nditer((array,), flags=['multi_index'], itershape=shape).itviews[0]

def broadcast_arrays(*args):
    return np.nditer(args, flags=['multi_index']).itviews
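
A quick way to sanity-check the `itviews` approach is to compare it against NumPy's built-in `np.broadcast_to` and `np.broadcast_arrays`. The sketch below makes two assumptions that are not in the comment above: the function names carry an `_nditer` suffix to mark them as hypothetical, and `order='C'` is passed explicitly so the views follow the operands' logical axis order rather than a memory-optimized one.

```python
import numpy as np


def broadcast_to_nditer(array, shape):
    """Hypothetical variant of the one-liner, with order='C' added."""
    array = np.asarray(array)
    it = np.nditer((array,), flags=['multi_index'],
                   itershape=shape, order='C')
    return it.itviews[0]


def broadcast_arrays_nditer(*args):
    """Hypothetical variant of the one-liner, with order='C' added."""
    arrays = tuple(np.asarray(a) for a in args)
    it = np.nditer(arrays, flags=['multi_index'], order='C')
    return list(it.itviews)


row = np.arange(4)                    # shape (4,)
column = np.arange(3).reshape(3, 1)   # shape (3, 1)

single = broadcast_to_nditer(row, (3, 4))
pair = broadcast_arrays_nditer(column, row)
expected = np.broadcast_arrays(column, row)

print(np.array_equal(single, np.broadcast_to(row, (3, 4))))
print(all(np.array_equal(a, b) for a, b in zip(pair, expected)))
```

Incompatible shapes still raise ValueError here, since the check happens inside `np.nditer` itself.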
