@ajsyp
Created March 9, 2017 14:56
Recreates the bug with multi-GPU TensorFlow in Kur
import keras.models as M
import keras.layers as L
from kur.utils.parallelism import make_parallel
# Pretend to have some 32 x 32 images.
input = x = L.Input(shape=(32, 32))
# Shape: (samples, 32, 32)
x = L.TimeDistributed(
    L.Dense(100)
)(x)
# Shape: (samples, 32, 100)
x = L.LSTM(50, return_sequences=True)(x)
# Shape: (samples, 32, 50)
x = L.TimeDistributed(
    L.Dense(20)
)(x)
# Shape: (samples, 32, 20)
# Compile the model.
model = M.Model(input=input, output=x)
model.summary()
# Try to parallelize it.
# It will fail for any value of `gpu_count`.
model = make_parallel(model, 1)
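For context on what `make_parallel` is attempting: data-parallel helpers in this style typically shard each input batch across the available GPUs, run a replica of the model on each shard, and concatenate the outputs. The sketch below shows only the batch-slicing arithmetic in pure Python (the function name and equal-split-with-remainder semantics are assumptions for illustration, not Kur's actual implementation):

```python
def shard_batch(batch_size, gpu_count):
    """Split a batch into per-GPU (start, size) shards, in the style of
    make_parallel-type helpers that slice the input tensor per device.
    The last shard absorbs any remainder when the batch does not divide
    evenly."""
    base = batch_size // gpu_count
    shards = []
    for i in range(gpu_count):
        start = i * base
        size = batch_size - start if i == gpu_count - 1 else base
        shards.append((start, size))
    return shards

print(shard_batch(10, 4))  # [(0, 2), (2, 2), (4, 2), (6, 4)]
```

Uneven shards like the `(6, 4)` slice above are one common source of shape-inference failures in multi-GPU wrappers, which is why such bugs can surface even when `gpu_count` is 1.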