Created
March 9, 2017 14:56
-
-
Save ajsyp/3113650c9debf960c64722d3b6e516f6 to your computer and use it in GitHub Desktop.
Recreates the bug with multi-GPU TensorFlow in Kur
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras.models as M | |
import keras.layers as L | |
from kur.utils.parallelism import make_parallel | |
# Build a small sequence model over pretend 32 x 32 "images", then try to
# parallelize it across GPUs with Kur's make_parallel to reproduce the bug.

# Pretend to have some 32 x 32 images.
# Renamed from `input` to avoid shadowing the `input` builtin.
input_layer = x = L.Input(shape=(32, 32))
# Shape: (samples, 32, 32)

x = L.TimeDistributed(
    L.Dense(100)
)(x)
# Shape: (samples, 32, 100)

x = L.LSTM(50, return_sequences=True)(x)
# Shape: (samples, 32, 50)

x = L.TimeDistributed(
    L.Dense(20)
)(x)
# Shape: (samples, 32, 20)

# Compile the model.
# NOTE(review): `input=` / `output=` are the Keras 1 keyword names; Keras 2+
# renamed them to `inputs=` / `outputs=`. Kept as-is to match the Keras
# version this reproduction targets — confirm before upgrading Keras.
model = M.Model(input=input_layer, output=x)
model.summary()

# Try to parallelize it.
# It will fail for any value of `gpu_count`.
model = make_parallel(model, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment