Skip to content

Instantly share code, notes, and snippets.

@abhayraw1
Last active April 17, 2019 06:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save abhayraw1/af7ab2bb524b0392c6495c4d8d90c3f4 to your computer and use it in GitHub Desktop.
Save abhayraw1/af7ab2bb524b0392c6495c4d8d90c3f4 to your computer and use it in GitHub Desktop.
A simple feed-forward neural net built with tf.keras, written to reproduce issue #27316 reported in tensorflow/issues.
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.activations as Z
from tensorflow.keras import Model
from pprint import pprint
def feed_forward_nn(input_shape):
    """Build a small dense Keras model inside the "keras_" variable scope.

    Architecture: Input(input_shape) -> Dense(5, relu) -> Dense(1, tanh).

    Args:
        input_shape: Shape tuple forwarded to ``L.Input`` (e.g. ``(3,)``).

    Returns:
        A ``Model`` whose input is the created Input tensor and whose
        output is the single-unit tanh layer.
    """
    with tf.variable_scope("keras_"):
        net_in = L.Input(shape=input_shape)
        hidden = L.Dense(units=5, activation=Z.relu)(net_in)
        net_out = L.Dense(units=1, activation=Z.tanh)(hidden)
        return Model(inputs=net_in, outputs=net_out)
def main():
    """Run the weight-copy reproduction and return both models.

    Steps, in order:
      1. Build two networks with the same architecture.
      2. Print both weight sets, copy the second model's weights into the
         first, compare them element-wise, and print both weight sets again.
      3. Print each model's prediction on one random ``(1, 3)`` input.
      4. Create a new ``tf.Session``, run the global-variables initializer
         in it, then print the predictions and the trainable weights (as
         fetched through the new session) for comparison.

    Returns:
        Tuple ``(model1, model2)`` of the two Keras models.
    """
    target_net = feed_forward_nn((3, ))
    source_net = feed_forward_nn((3, ))
    both_nets = (target_net, source_net)

    def dump_weights(title):
        # Print a title, each model's weights, then a blank separator line.
        print(title)
        for net in both_nets:
            pprint(net.get_weights())
        print()

    def dump_predictions(sample):
        # Print each model's prediction for the same input sample.
        for net in both_nets:
            print(net.predict(sample))

    dump_weights("Model weights before:\n")

    # Copy weights from the second model into the first, then verify that
    # every weight array matches element-wise.
    target_net.set_weights(source_net.get_weights())
    weight_pairs = zip(target_net.get_weights(), source_net.get_weights())
    are_wts_equal = all((lhs == rhs).all() for lhs, rhs in weight_pairs)

    dump_weights("Model weights after:\n")

    print("\nAre Model Weights Equal (using <tf.keras.Model>.get_weights()): {}".format(are_wts_equal))
    print("\nIf YES then these should be equal!!!")

    some_input = np.random.random((1, 3))
    dump_predictions(some_input)

    # Open a separate session and run the initializer op in it.
    # NOTE(review): predict() presumably runs in the Keras backend session
    # rather than this one -- confirm against the tf.keras backend docs.
    fresh_sess = tf.Session()
    fresh_sess.run(tf.global_variables_initializer())
    # Fetch every global variable once; the fetched values are discarded.
    fresh_sess.run(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

    dump_predictions(some_input)

    # Read the trainable weights through the new session for comparison
    # with the get_weights() output printed earlier.
    for net in both_nets:
        print(fresh_sess.run(net.trainable_weights))

    return target_net, source_net
if __name__ == "__main__":
    # Script entry point: run the reproduction and bind both models at
    # module level (presumably so they can be inspected interactively).
    model1, model2 = main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment