Skip to content

Instantly share code, notes, and snippets.

@abhaikollara
Created April 13, 2017 10:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save abhaikollara/021cd5a82b7d8020b9eedde7866cbfda to your computer and use it in GitHub Desktop.
Save abhaikollara/021cd5a82b7d8020b9eedde7866cbfda to your computer and use it in GitHub Desktop.
Recurrent shop issues
import numpy as np
def get_data(n_samples, max_len, rng=None):
    """Generate a toy variable-length sequence classification dataset.

    Each sample is a zero-padded row of ``max_len`` steps. Its first
    ``length`` steps (``length`` drawn uniformly from ``1..max_len``) hold
    standard-normal noise, and the remaining steps stay zero. The label is 1
    when ``length >= max_len / 2`` and 0 otherwise.

    Args:
        n_samples: Number of sequences to generate.
        max_len: Number of timesteps per padded sequence; must be >= 1.
        rng: Optional source of randomness exposing ``randint`` and
            ``normal`` (e.g. ``np.random.RandomState(seed)``), for
            reproducible data. Defaults to the global ``np.random`` state,
            which is what the function always used before.

    Returns:
        ``(data, labels)``: ``data`` is a float array of shape
        ``(n_samples, max_len, 1)`` and ``labels`` is a float array of
        shape ``(n_samples, 1)`` containing 0.0 or 1.0.
    """
    if rng is None:
        rng = np.random
    data = np.zeros([n_samples, max_len])
    labels = np.zeros([n_samples, 1])
    # zip over 2-D arrays yields row views, so the in-place updates below
    # write straight into `data` and `labels`.
    for row, label in zip(data, labels):
        length = rng.randint(1, max_len + 1)  # upper bound is exclusive
        row[:length] += rng.normal(size=length)
        if length >= max_len / 2:
            label += 1
    # Add a trailing feature axis of size 1: (n_samples, max_len, 1).
    data = np.expand_dims(data, axis=-1)
    return data, labels
from RWACell import RWA
import data
from keras.layers import Input
from keras.models import Model
# Toy run: train a single RWA layer to label zero-padded noise sequences as
# "long" (label 1) or "short" (label 0).
input_dim = 1    # features per timestep
output_dim = 1   # RWA state size, which is also the model's output width
timesteps = 100  # padded sequence length; the generated data must match it

rwa = RWA(input_dim, output_dim)
inp = Input((timesteps, input_dim))
out = rwa(inp)
RWAModel = Model(inp, out)
# NOTE(review): the RWA cell ends in a tanh activation, so `out` lies in
# (-1, 1), while binary_crossentropy expects probabilities in (0, 1).
# A sigmoid output head is probably needed here; confirm before relying on
# training results.
RWAModel.compile(loss='binary_crossentropy', optimizer='adam')
# Derive max_len from timesteps so the data shape always matches the Input.
train_data, train_labels = data.get_data(n_samples=1000, max_len=timesteps)
RWAModel.fit(train_data, train_labels)
'''
Machine Learning on Sequential Data Using a Recurrent
Weighted Average
- Jared Ostmeyer
- Lindsay Cowell
https://arxiv.org/pdf/1703.01253.pdf
'''
import tensorflow as tf
from recurrentshop import RecurrentModel
from keras.layers import Dense, Activation, Lambda, Input, concatenate
from keras.layers import add, multiply
from keras.backend import exp
def RWA(input_dim, output_dim):
    """Build a Recurrent Weighted Average (RWA) cell as a recurrentshop model.

    One step of the cell computes:

        u   = W_u x + b_u
        g   = tanh(W_g [x, h_{t-1}] + b_g)
        a   = W_a [x, h_{t-1}]                  (no bias)
        n_t = n_{t-1} + (u * g) * exp(a)
        d_t = d_{t-1} + exp(a)
        h_t = tanh(n_t / d_t)

    Args:
        input_dim: Number of features in each input timestep ``x``.
        output_dim: Size of the hidden state ``h`` and of the numerator and
            denominator accumulators ``n`` and ``d``.

    Returns:
        A ``RecurrentModel`` with per-step input ``x``, per-step output
        ``h_t``, and recurrent states ``[h, n, d]``.
    """
    x = Input((input_dim, ))
    h_tm1 = Input((output_dim, ))
    n_tm1 = Input((output_dim, ))
    d_tm1 = Input((output_dim, ))

    x_h = concatenate([x, h_tm1])  # Concatenated vector [x, h_{t-1}]

    u = Dense(output_dim)(x)
    g = Dense(output_dim, activation='tanh')(x_h)
    a = Dense(output_dim, use_bias=False)(x_h)
    # NOTE(review): exp(a) is unbounded and may overflow for large `a`;
    # consider the running-maximum rescaling discussed in the RWA paper
    # (arXiv:1703.01253), which this cell does not implement.
    e_a = Lambda(lambda t: exp(t))(a)

    z = multiply([u, g])
    nt = add([n_tm1, multiply([z, e_a])])  # weighted numerator
    dt = add([d_tm1, e_a])                 # sum of weights (denominator)

    # h_t = tanh(n_t / d_t): the output must be the weighted average n_t/d_t.
    # Using tanh(1 / d_t) alone would discard n_t and, with it, all
    # information from u and g.
    inv_dt = Lambda(lambda t: tf.divide(1.0, t))(dt)
    ht = multiply([nt, inv_dt])
    ht = Activation('tanh')(ht)

    return RecurrentModel(input=x, output=ht,
                          initial_states=[h_tm1, n_tm1, d_tm1],
                          final_states=[ht, nt, dt],)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment