Last active
April 5, 2017 18:57
-
-
Save ColaColin/1a8c7ef9a76a96a3868756bbaafbd4a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import numpy as np | |
from keras.models import Model | |
from keras.layers import Input | |
from keras.layers.core import Dense | |
import keras.backend as K | |
from keras.optimizers import RMSprop | |
learn_rate = 0.00005 | |
batch_size = 512 | |
lamb = 10 | |
def get_gradient_norm(model, y_pred):
    """Return a symbolic L2 norm of d(mean(y_pred))/d(trainable weights).

    Sums the squared elements of every weight-gradient tensor, then takes
    the square root — i.e. the Euclidean norm of all gradients flattened
    and concatenated.

    NOTE(review): WGAN-GP penalizes the gradient w.r.t. the critic's
    *inputs*, not its weights; differentiating w.r.t. the weights here is
    likely the source of the Theano DisconnectedInputError seen when this
    expression is itself differentiated during training — confirm intent
    against the gradient-penalty paper.
    """
    weights = model.trainable_weights
    gradients = model.optimizer.get_gradients(K.mean(y_pred), weights)
    # BUG FIX: the original seeded `acc` with gradients[0] and then also
    # looped over gradients[0], double-counting the first tensor; its
    # `acc == None` guard could never fire (acc was always initialized,
    # and identity should be tested with `is`, not `==`).
    acc = None
    for g in gradients:
        s = K.sum(K.square(g))
        acc = s if acc is None else acc + s
    return K.sqrt(acc)
# this loss function, mainly defined through the function above, | |
# appears to be the culprit. I don't see what is wrong with it however | |
def make_w_reg_loss(model): | |
lvar = K.variable(lamb, name="Lambda") | |
def foo(y_true, y_pred): | |
gnorm = get_gradient_norm(model, y_pred) | |
return lvar * K.square(gnorm - 1) | |
return foo | |
def make_critic():
    """Build a tiny fully-connected critic: 2 -> 128 -> 128 (tanh) -> 1 (linear)."""
    critic_in = Input(shape=(2,), name='critic input')
    hidden = critic_in
    # Two identical tanh hidden layers.
    for _ in range(2):
        hidden = Dense(128, activation="tanh")(hidden)
    score = Dense(1)(hidden)  # unbounded scalar critic score
    return Model(critic_in, score, name="Toy Critic")
# 700 two-dimensional all-zero dummy points; the critic is not meant to learn
# anything from them, only to exercise the gradient-norm regularizer.
toy_data = np.zeros((7 * 100, 2))
critic = make_critic()
# The ONLY loss is the gradient-norm penalty (no Wasserstein term).
critic.compile(loss=make_w_reg_loss(critic), optimizer=RMSprop(learn_rate))
# not supposed to learn anything, just don't explode please?
# Sample batch_size rows (with replacement) from the dummy data.
X_real_batch = toy_data[np.random.choice(toy_data.shape[0], size=batch_size)]
# The -1 labels are placeholders; the custom loss ignores y_true entirely.
reg_loss = critic.train_on_batch(X_real_batch, -np.ones(X_real_batch.shape[0]))
print reg_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Traceback (most recent call last): | |
File "<ipython-input-31-ba6d4d1a6ec9>", line 1, in <module> | |
runfile('/UltraKeks/Dev/nn/wgan/minimal_disconnect.py', wdir='/UltraKeks/Dev/nn/wgan') | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/spyder/utils/site/sitecustomize.py", line 866, in runfile | |
execfile(filename, namespace) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/spyder/utils/site/sitecustomize.py", line 94, in execfile | |
builtins.execfile(filename, *where) | |
File "/UltraKeks/Dev/nn/wgan/minimal_disconnect.py", line 54, in <module> | |
reg_loss = critic.train_on_batch(X_real_batch, -np.ones(X_real_batch.shape[0])) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1619, in train_on_batch | |
self._make_train_function() | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1001, in _make_train_function | |
self.total_loss) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/optimizers.py", line 197, in get_updates | |
grads = self.get_gradients(loss, params) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/optimizers.py", line 47, in get_gradients | |
grads = K.gradients(loss, params) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/backend/theano_backend.py", line 1108, in gradients | |
return T.grad(loss, variables) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/theano/gradient.py", line 539, in grad | |
handle_disconnected(elem) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/theano/gradient.py", line 526, in handle_disconnected | |
raise DisconnectedInputError(message) | |
DisconnectedInputError: | |
Backtrace when that variable is created: | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/spyder/utils/site/sitecustomize.py", line 866, in runfile | |
execfile(filename, namespace) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/spyder/utils/site/sitecustomize.py", line 94, in execfile | |
builtins.execfile(filename, *where) | |
File "/UltraKeks/Dev/nn/wgan/minimal_disconnect.py", line 48, in <module> | |
critic = make_critic() | |
File "/UltraKeks/Dev/nn/wgan/minimal_disconnect.py", line 41, in make_critic | |
x = Dense(1)(x) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 528, in __call__ | |
self.build(input_shapes[0]) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/layers/core.py", line 833, in build | |
constraint=self.bias_constraint) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 364, in add_weight | |
weight = K.variable(initializer(shape), dtype=K.floatx(), name=name) | |
File "/home/cclausen/anaconda2/lib/python2.7/site-packages/keras/backend/theano_backend.py", line 146, in variable | |
strict=False) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment