Last active
February 7, 2020 20:03
-
-
Save chaserileyroberts/e44533fa67ee2ebe4466ad73bb980029 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tensornetwork as tn | |
class TNLayer(tf.keras.layers.Layer): | |
def __init__(self): | |
super(TNLayer, self).__init__() | |
# Create the variables for the layer. | |
self.a_var = tf.Variable(tf.random.normal( | |
shape=(32, 32, 2), stddev=1.0/32.0), | |
name="a", trainable=True) | |
self.b_var = tf.Variable(tf.random.normal(shape=(32, 32, 2), stddev=1.0/32.0), | |
name="b", trainable=True) | |
self.bias = tf.Variable(tf.zeros(shape=(32, 32)), name="bias", trainable=True) | |
def call(self, inputs): | |
# Define the contraction. | |
# We break it out so we can parallelize a batch using | |
# tf.vectorized_map (see below). | |
def f(input_vec, a_var, b_var, bias_var): | |
# Reshape to a matrix instead of a vector. | |
input_vec = tf.reshape(input_vec, (32,32)) | |
# Now we create the network. | |
a = tn.Node(a_var, backend="tensorflow") | |
b = tn.Node(b_var, backend="tensorflow") | |
x_node = tn.Node(input_vec, backend="tensorflow") | |
a[1] ^ x_node[0] | |
b[1] ^ x_node[1] | |
a[2] ^ b[2] | |
# The TN should now look like this | |
# | | | |
# a --- b | |
# \ / | |
# x | |
# Now we begin the contraction. | |
c = a @ x_node | |
result = (c @ b).tensor | |
# To make the code shorter, we also could've used Ncon. | |
# The above few lines of code is the same as this: | |
# result = tn.ncon([x, a_var, b_var], [[1, 2], [-1, 1, 3], [-2, 2, 3]]) | |
# Finally, add bias. | |
return result + bias_var | |
# To deal with a batch of items, we can use the tf.vectorized_map | |
# function. | |
# https://www.tensorflow.org/api_docs/python/tf/vectorized_map | |
result = tf.vectorized_map( | |
lambda vec: f(vec, self.a_var, self.b_var, self.bias), inputs) | |
return tf.nn.relu(tf.reshape(result, (-1, 1024))) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment