Skip to content

Instantly share code, notes, and snippets.

@currymj
Created April 5, 2017 01:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save currymj/e903644c4e54e35fdb858c94f1631fe4 to your computer and use it in GitHub Desktop.
Save currymj/e903644c4e54e35fdb858c94f1631fe4 to your computer and use it in GitHub Desktop.
Simple forward-backward algorithm for hidden Markov models, implemented in TensorFlow (1.x graph/Session API).
import tensorflow as tf
import numpy as np
def backward_tf(transition, emission, initial_state, seq):
    """Run the HMM backward recursion over an observation sequence.

    Args:
        transition: [N, N] tensor; row i is taken to be the distribution
            P(state_{t+1} = . | state_t = i) (rows of the demo matrix sum to 1).
        emission: [N, K] tensor; emission[i, o] is taken to be
            P(obs = o | state = i).
        initial_state: [N] tensor of initial state probabilities.
        seq: 1-D integer tensor of T observation indices into emission's
            second axis. Its length may be unknown at graph-build time.

    Returns:
        A [T, N] tensor laid out latest-first:
          * rows 0 .. T-2 are beta_{T-2}, ..., beta_0, where
            beta_t[i] = sum_k transition[i, k] * emission[k, seq[t+1]]
            * beta_{t+1}[k] and beta_{T-1} is all ones (that all-ones
            vector seeds the recursion but is NOT included in the output);
          * the last row is initial_state * beta_0 * emission[:, seq[0]],
            whose sum is the sequence likelihood P(seq).
    """
    # beta_{T-1}: all ones, same length/dtype as a column of `transition`.
    beta = tf.ones_like(transition[:, -1])
    # Dynamic length so `seq` may have a shape unknown at graph-build time.
    num_steps = tf.shape(seq)[0]
    # Observation positions T-1, T-2, ..., 1, visited from the end.
    jrange = tf.reverse(tf.range(1, num_steps), [0])

    def scan_op(curr_beta, j):
        # beta_{j-1}[i] = sum_k transition[i, k] * emission[k, seq[j]] * beta_j[k]
        # (the [N] vectors broadcast along axis 1, i.e. over destination k).
        return tf.reduce_sum(emission[:, seq[j]] * transition * curr_beta, 1)

    betas = tf.scan(scan_op, jrange, initializer=beta)
    # Termination step: per-state terms whose sum is P(seq).
    final_beta = initial_state * betas[-1, :] * emission[:, seq[0]]
    return tf.concat([betas, tf.reshape(final_beta, [1, -1])], 0)
def forward_tf(transition, emission, initial_state, seq):
    """Run the HMM forward recursion and return every alpha vector.

    Args:
        transition: [N, N] tensor; row i is taken to be the distribution
            P(state_{t+1} = . | state_t = i) (rows of the demo matrix sum to 1).
        emission: [N, K] tensor; emission[i, o] is taken to be
            P(obs = o | state = i).
        initial_state: [N] tensor of initial state probabilities.
        seq: 1-D integer tensor of T observation indices into emission's
            second axis. Its length may be unknown at graph-build time.

    Returns:
        A [T, N] tensor whose row t is alpha_t, with
        alpha_0 = initial_state * emission[:, seq[0]] and
        alpha_t[j] = emission[j, seq[t]] * sum_i alpha_{t-1}[i] * transition[i, j].
        The sum of the last row is the sequence likelihood P(seq).
    """
    alpha = emission[:, seq[0]] * initial_state
    # Dynamic length so `seq` may have a shape unknown at graph-build time.
    num_steps = tf.shape(seq)[0]
    jrange = tf.range(1, num_steps)

    def scan_op(curr_alpha, j):
        # Propagate through the chain by summing over the SOURCE state i
        # (axis 0): predicted[j] = sum_i alpha[i] * transition[i, j].
        # Summing over axis 1 instead would compute the backward recursion.
        predicted = tf.reduce_sum(tf.expand_dims(curr_alpha, 1) * transition, 0)
        # Then weight by the emission probability of the destination state.
        return emission[:, seq[j]] * predicted

    alphas = tf.scan(scan_op, jrange, initializer=alpha)
    # Prepend alpha_0, which tf.scan does not include in its output.
    return tf.concat([tf.reshape(alpha, [1, -1]), alphas], 0)
if __name__ == '__main__':
    # Demo: a two-state HMM over three observation symbols.
    transition = tf.constant([[.6, .4], [.3, .7]])
    em = tf.constant([[.3, .4, .3], [.4, .3, .3]])
    initial_state = tf.constant([.8, .2])
    observations = [0, 1, 2, 2]

    # Build both graphs up front, then evaluate them in order:
    # backward results first, forward results second.
    demo_ops = (
        backward_tf(transition, em, initial_state, tf.constant(observations)),
        forward_tf(transition, em, initial_state, tf.constant(observations)),
    )
    with tf.Session() as sess:
        for demo_op in demo_ops:
            print(sess.run(demo_op))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment