Skip to content

Instantly share code, notes, and snippets.

@ricsonc
Last active June 9, 2022 05:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ricsonc/0ac6d0b9fd13c42c1c154eb4ccd7fab3 to your computer and use it in GitHub Desktop.
Save ricsonc/0ac6d0b9fd13c42c1c154eb4ccd7fab3 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python2
import random
import tensorflow as tf
import numpy as np
# Model / data dimensions.
batch = 16        # sequences per minibatch (leading dim of `inputs`)
length = 40       # time steps per sequence
in_dim = 1        # features per time step
hidden_dim = 8    # recurrent state size
out_dim = 1       # outputs per sequence

# Input sequences fed as (batch, length, in_dim), then unstacked along the
# time axis into a Python list of `length` tensors, each (batch, in_dim).
inputs = tf.placeholder(tf.float32, shape=(batch, length, in_dim))
inputs_split = tf.unstack(inputs, axis=1)

# Readout weights: one (1, out_dim, hidden_dim) variable tiled across the
# batch so it can be applied with a batched tf.matmul.
weight_output = tf.tile(
    tf.get_variable('output', shape=[1, out_dim, hidden_dim]),
    [batch, 1, 1],
)
def rnn(inputs_split):
    """Run a plain tanh RNN over a sequence of per-step input tensors.

    The update is h_t = tanh(W_in x_t + W_h h_{t-1} + b), starting from a
    zero state. Weights are created once and tiled across the batch so that
    batched tf.matmul can be used.

    Args:
        inputs_split: sequence of per-step tensors, each reshapeable to
            (batch, in_dim, 1).

    Returns:
        The final hidden state, shape (batch, hidden_dim, 1); the zero
        initial state if `inputs_split` is empty.

    NOTE: creates variables named 'bias', 'hidden' and 'input' in the current
    variable scope -- the same names `attention_rnn` uses (with a different
    shape for 'hidden'), so presumably both cannot be built in one scope.
    """
    bias = tf.get_variable('bias', shape=[hidden_dim, 1])
    weight_hidden = tf.tile(
        tf.get_variable('hidden', shape=[1, hidden_dim, hidden_dim]),
        [batch, 1, 1])
    weight_input = tf.tile(
        tf.get_variable('input', shape=[1, hidden_dim, in_dim]),
        [batch, 1, 1])

    # Only the most recent state is ever needed, so keep a running value
    # instead of a list of all states.
    state = tf.zeros((batch, hidden_dim, 1), tf.float32)
    for step_input in inputs_split:
        x_t = tf.reshape(step_input, (batch, in_dim, 1))
        state = tf.nn.tanh(
            tf.matmul(weight_input, x_t)
            + tf.matmul(weight_hidden, state)
            + bias)
    return state
def attention_rnn(inputs_split):
    """Run a tanh RNN whose recurrent input is augmented with attention.

    At each step the previous state h_{t-1} scores every *earlier* state in
    memory (all states except h_{t-1} itself, including the initial zero
    state). The score is the mean, over the hidden dimension, of the
    elementwise product. The softmax-weighted sum of those earlier states is
    a context vector, and the update is

        h_t = tanh(W_in x_t + W_h [h_{t-1}; context] + b)

    Args:
        inputs_split: sequence of per-step tensors, each reshapeable to
            (batch, in_dim, 1).

    Returns:
        The final hidden state, shape (batch, hidden_dim, 1); the zero
        initial state if `inputs_split` is empty.

    NOTE: creates variables named 'bias', 'hidden' and 'input' in the current
    variable scope -- the same names `rnn` uses (with a different shape for
    'hidden'), so presumably both cannot be built in one scope.
    NOTE(review): the zero initial state stays in memory. The context at the
    second step is therefore always zero, and at later steps the zero state
    (score 0) keeps taking softmax weight. Confirm this is intended.
    """
    bias = tf.get_variable('bias', shape=[hidden_dim, 1])
    # Recurrent weights act on [h_{t-1}; context], hence 2 * hidden_dim columns.
    weight_hidden = tf.tile(
        tf.get_variable('hidden', shape=[1, hidden_dim, 2 * hidden_dim]),
        [batch, 1, 1])
    weight_input = tf.tile(
        tf.get_variable('input', shape=[1, hidden_dim, in_dim]),
        [batch, 1, 1])

    # Every state produced so far, starting with the zero initial state.
    memory = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for step_input in inputs_split:
        x_t = tf.reshape(step_input, (batch, in_dim, 1))
        prev = memory[-1]
        earlier = memory[:-1]  # attention targets: every state before prev

        if not earlier:
            # First step: nothing to attend to yet.
            context = tf.zeros_like(prev)
        else:
            # (n, batch, hidden, 1) scores -> mean over hidden -> (n, batch);
            # transpose so each batch row holds scores over the n states.
            scores = tf.transpose(
                tf.reduce_mean(prev * tf.stack(earlier), axis=[2, 3]))
            attn = tf.reshape(tf.nn.softmax(scores), (batch, -1, 1, 1))
            # One (batch, 1, 1) weight per remembered state.
            per_state = tf.unstack(attn, axis=1)
            context = tf.add_n(
                [state * w for state, w in zip(earlier, per_state)])

        augmented = tf.concat([prev, context], axis=1)
        hidden = tf.nn.tanh(
            tf.matmul(weight_input, x_t)
            + tf.matmul(weight_hidden, augmented)
            + bias)
        memory.append(hidden)
    return memory[-1]
# Build the model: attention RNN followed by a linear readout
# (batch, out_dim, hidden_dim) x (batch, hidden_dim, 1) -> (batch, out_dim, 1).
last_state = attention_rnn(inputs_split)
output = tf.matmul(weight_output, last_state)
target = tf.placeholder(tf.float32, shape=output.shape)
loss = tf.reduce_mean(tf.square(output - target))  # mean squared error
opt = tf.train.AdamOptimizer().minimize(loss)

# Toy task: each sequence is uniform noise in [0, 1) plus a per-sequence
# offset in [0, 2); the target is the maximum over the time axis.
# NOTE: the trailing dims are hard-coded to 1, so the feeds only match the
# placeholders when in_dim == out_dim == 1.
# NOTE(review): data is generated once, so every step refits the same
# `batch` sequences; sample inside the loop if generalization is the goal.
input_data = np.random.rand(batch, length, 1) + 2 * np.random.rand(batch, 1, 1)
target_data = np.max(input_data, axis=1, keepdims=True)

# The context manager guarantees the session is closed when training ends.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    for i in range(500):
        _, loss_ = sess.run(
            [opt, loss],
            feed_dict={inputs: input_data, target: target_data})
        if i % 10 == 0:
            # print() with one argument behaves the same on Python 2 and 3.
            print(loss_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment