Gist maximus009/6d8a2e4d6d56d02e5581eb11ab1b2ea2, created February 16, 2017 20:06
Calculate the outer product/bilinear projection in Keras
from keras.layers import Input, Lambda
from keras.models import Model
from keras import backend as K
from numpy import newaxis

def outer_product(inputs):
    """
    inputs: list of two tensors of shape (batch_size, dim) for which
    you need to compute the per-sample outer product.
    """
    x, y = inputs
    batchSize = K.shape(x)[0]
    # (batch, dim, 1) * (batch, 1, dim) broadcasts to (batch, dim, dim)
    outerProduct = x[:, :, newaxis] * y[:, newaxis, :]
    # flatten each sample's dim x dim matrix into a vector of length dim**2
    outerProduct = K.reshape(outerProduct, (batchSize, -1))
    # returns a flattened batch-wise set of tensors
    return outerProduct

def main(input_dim=32):
    inputX = Input(shape=(input_dim,))
    inputY = Input(shape=(input_dim,))
    bilinearProduct = Lambda(outer_product, output_shape=(input_dim**2,))([inputX, inputY])
    # tensor 'bilinearProduct' contains the batch-wise outer product
    model = Model(inputs=[inputX, inputY], outputs=bilinearProduct)
    # run model.predict to get the outer product of each sample in your batch
    return model
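To see what the Lambda layer computes without needing Keras, here is a minimal NumPy sketch that mirrors `outer_product` under the assumption that both inputs are plain `(batch, dim)` arrays. The helper name `outer_product_np` is hypothetical. It also shows why the result is called a "bilinear" projection: dotting the flattened outer product with a flattened weight matrix `W` gives the bilinear form x^T W y for each sample.

```python
import numpy as np

def outer_product_np(x, y):
    """Hypothetical NumPy mirror of outer_product.

    x is (batch, m), y is (batch, n); returns (batch, m * n) where row k
    is the flattened outer product of x[k] and y[k].
    """
    outer = x[:, :, np.newaxis] * y[:, np.newaxis, :]  # (batch, m, n)
    return outer.reshape(x.shape[0], -1)               # (batch, m * n)

rng = np.random.default_rng(0)
x = rng.random((4, 3))
y = rng.random((4, 3))
flat = outer_product_np(x, y)

# Same result via einsum: out[b, i, j] = x[b, i] * y[b, j]
expected = np.einsum("bi,bj->bij", x, y).reshape(4, -1)
print(flat.shape)                   # (4, 9)
print(np.allclose(flat, expected))  # True

# Why "bilinear": flat[b, i * n + j] = x[b, i] * y[b, j], so dotting with
# W.ravel() (where W.ravel()[i * n + j] = W[i, j]) gives x[b] @ W @ y[b].
W = rng.random((3, 3))
bilinear = flat @ W.ravel()                 # shape (batch,)
direct = np.einsum("bi,ij,bj->b", x, W, y)  # x[b] @ W @ y[b] per sample
print(np.allclose(bilinear, direct))        # True
```

In the Keras model, a `Dense` layer placed after the Lambda layer would learn one such flattened `W` per output unit (plus a bias), which is one way to turn this outer product into a learned bilinear projection.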
import numpy as np

input_a = np.random.rand(10, 32)  # batch of 10 vectors of length 32
input_b = np.random.rand(10, 32)
model = main()
model.summary()
y = model.predict([input_a, input_b])
# y is the flattened output tensor, shape (10, 32 * 32) = (10, 1024)
print(y.shape)
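Since each row of `y` is a flattened 32 x 32 matrix, you can recover the per-sample matrices with a reshape. The sketch below assumes row-major (C-order) flattening, which is what both NumPy and the TensorFlow backend's reshape use; to stay self-contained it builds a NumPy stand-in for the `model.predict` output instead of calling the model.

```python
import numpy as np

input_dim = 32
input_a = np.random.rand(10, input_dim)
input_b = np.random.rand(10, input_dim)
# Stand-in for model.predict output: a (batch, input_dim**2) array
y = (input_a[:, :, np.newaxis] * input_b[:, np.newaxis, :]).reshape(10, -1)

# Undo the flattening: row k becomes an (input_dim, input_dim) matrix
matrices = y.reshape(-1, input_dim, input_dim)
print(matrices.shape)  # (10, 32, 32)

# Entry [k, i, j] is input_a[k, i] * input_b[k, j], i.e. np.outer per sample
print(np.allclose(matrices[3], np.outer(input_a[3], input_b[3])))  # True
```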
Hope this helps!
What does newaxis mean in your code?
numpy.newaxis inserts a new axis of length 1, increasing the array's number of dimensions by one (it is an alias for None).
https://stackoverflow.com/questions/29241056/how-does-numpy-newaxis-work-and-when-to-use-it
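A small NumPy illustration of how `newaxis` produces the outer product in the gist: indexing with it turns a length-3 vector into a column `(3, 1)` or a row `(1, 3)`, and broadcasting the two together yields the `(3, 3)` outer product. The gist does the same thing with an extra leading batch axis.

```python
import numpy as np

v = np.array([1, 2, 3])
print(v.shape)               # (3,)
col = v[:, np.newaxis]       # same as v.reshape(3, 1)
row = v[np.newaxis, :]       # same as v.reshape(1, 3)
print(col.shape, row.shape)  # (3, 1) (1, 3)

# Broadcasting (3, 1) against (1, 3) yields a (3, 3) outer product
outer = col * row
print(outer)
# [[1 2 3]
#  [2 4 6]
#  [3 6 9]]
print(np.array_equal(outer, np.outer(v, v)))  # True
print(np.newaxis is None)  # True: newaxis is an alias for None
```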
Excellent.
Could you show how to call this model?
import numpy as np

x = np.random.rand(10, 1, 32)
input_a = np.reshape(x, (10, 1, 32))
input_b = np.reshape(x, (10, 1, 32))
print(input_a.shape)
print(input_b.shape)
model = main()
y = model([input_a, input_b])
print(y.shape)
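One way to answer this, sketched under the assumption that the model is built by `main()` above with `Input(shape=(input_dim,))`: those inputs expect arrays of shape `(batch, 32)`, so data shaped `(10, 1, 32)` needs its length-1 middle axis removed first. The snippet below does that preparation in NumPy and computes a NumPy stand-in for what `model.predict` should return; the actual model call is shown only as a comment.

```python
import numpy as np

input_dim = 32
x = np.random.rand(10, 1, input_dim)  # data shaped (batch, 1, dim)

# The model's Input layers expect (batch, input_dim), so drop the
# length-1 middle axis before calling it.
input_a = np.squeeze(x, axis=1)       # (10, 32)
input_b = np.squeeze(x, axis=1)       # (10, 32)
print(input_a.shape, input_b.shape)   # (10, 32) (10, 32)

# With the gist's model this would then be:
#     model = main(input_dim)
#     y = model.predict([input_a, input_b])
# NumPy stand-in for that output, useful for checking the result:
y = (input_a[:, :, np.newaxis] * input_b[:, np.newaxis, :]).reshape(10, -1)
print(y.shape)                        # (10, 1024)
```

`x.reshape(10, input_dim)` would work equally well; `np.squeeze(x, axis=1)` just makes it explicit which axis is being dropped and raises an error if that axis is not length 1.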