Skip to content

Instantly share code, notes, and snippets.

@icoxfog417
Created June 15, 2017 09:37
Show Gist options
  • Save icoxfog417/3ab2c435969a1207d7f57436625c45da to your computer and use it in GitHub Desktop.
Save icoxfog417/3ab2c435969a1207d7f57436625c45da to your computer and use it in GitHub Desktop.
broadcast test
import numpy as np
from keras import backend as K
import tensorflow as tf
def try_broadcast(seed=None):
    """Contract a rank-3 array with a transposed rank-2 array in NumPy and TensorFlow.

    Builds ``x`` with shape (10, 5, 20) and ``y`` with shape (10, 20) from
    random integers in [0, 10). It then computes ``tensordot(x, y.T, 1)``,
    which sums over the last axis of ``x`` and the first axis of ``y.T``
    (both length 20), once with NumPy and once with TensorFlow on Keras
    variables.

    The input and result shapes are printed as before. In addition, the
    two results are now compared element-wise, so the script actually
    verifies that both libraries agree instead of only printing shapes.

    Args:
        seed: Optional seed for the random generator, making a run
            reproducible. ``None`` (the default) draws fresh random data
            on every call, like the original script.

    Returns:
        Tuple ``(broadcast_np, broadcast_tf)``: the NumPy result and the
        evaluated TensorFlow result. Both are expected to have shape
        (10, 5, 10).

    Raises:
        AssertionError: if the two results differ in shape or value.
    """
    # Use a private generator so seeding does not touch NumPy's global RNG.
    rng = np.random.RandomState(seed)

    x = rng.randint(10, size=1000).reshape(10, 5, 20)
    print(x.shape)
    y = rng.randint(10, size=200).reshape(10, 20)
    print(y.shape)

    # axes=1: contract the last axis of x with the first axis of y.T.
    broadcast_np = np.tensordot(x, np.transpose(y), 1)
    print(broadcast_np.shape)  # should be 10 x 5 x 10

    # NOTE(review): Keras variables presumably default to K.floatx()
    # (float32), so broadcast_tf is float while broadcast_np is integer.
    # The values involved are small integers, exactly representable either way.
    x_k = K.variable(value=x)
    y_k = K.variable(value=y)
    broadcast = tf.tensordot(x_k, K.transpose(y_k), 1)
    broadcast_tf = K.get_value(broadcast)
    print(broadcast_tf.shape)

    # The point of the experiment: both backends must produce the same
    # tensor. assert_allclose checks shape and values, but not dtype.
    np.testing.assert_allclose(broadcast_tf, broadcast_np)
    return broadcast_np, broadcast_tf
# Run the NumPy-vs-TensorFlow tensordot comparison only when this file is
# executed as a script, not when it is imported as a module.
if __name__ == '__main__':
    try_broadcast()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment