scotthuang1989 / bnlstm.py
Created March 11, 2018 02:55 by scott huang (黄), forked from spitis/bnlstm.py
Batch-normalized LSTM cell for TensorFlow
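The core idea of this cell is to keep separate batch-norm population statistics for each time step up to `max_bn_steps`. The following is a minimal NumPy sketch of that idea, not the gist's TensorFlow implementation. The class name `PerStepBatchNorm` and the `epsilon` value are hypothetical. It also assumes, based on the `BNLSTMCell` constructor's arguments, that `decay` is the moving-average decay of the population statistics, that `initial_scale` is the initial value of the batch-norm scale γ, and that steps at or beyond `max_bn_steps` reuse the statistics of the last stored step.

```python
import numpy as np


class PerStepBatchNorm:
    """Hypothetical helper (not part of the gist): batch norm with one set of
    population statistics per time step. Steps >= max_bn_steps are assumed to
    reuse the statistics stored for step max_bn_steps - 1."""

    def __init__(self, num_features, max_bn_steps, initial_scale=0.1,
                 decay=0.95, epsilon=1e-3):
        self.max_bn_steps = max_bn_steps
        self.decay = decay
        self.epsilon = epsilon  # assumed value, added for numerical stability
        self.gamma = np.full(num_features, initial_scale)  # scale, starts small
        self.beta = np.zeros(num_features)                 # shift
        # One row of population statistics per stored time step.
        self.pop_mean = np.zeros((max_bn_steps, num_features))
        self.pop_var = np.ones((max_bn_steps, num_features))

    def __call__(self, x, step, training):
        # Clamp the step so long sequences reuse the last stored statistics.
        s = min(step, self.max_bn_steps - 1)
        if training:
            # Normalize with the current batch statistics...
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and fold them into this step's exponential moving averages.
            self.pop_mean[s] = self.decay * self.pop_mean[s] + (1 - self.decay) * mean
            self.pop_var[s] = self.decay * self.pop_var[s] + (1 - self.decay) * var
        else:
            # At inference time, use the stored statistics for this step.
            mean, var = self.pop_mean[s], self.pop_var[s]
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta


# Usage: activations whose distribution drifts with the time step.
rng = np.random.default_rng(0)
bn = PerStepBatchNorm(num_features=4, max_bn_steps=3)
for t in range(5):  # steps 3 and 4 update the statistics stored for step 2
    batch = rng.normal(loc=float(t), scale=2.0, size=(32, 4))
    out = bn(batch, step=t, training=True)
eval_out = bn(rng.normal(size=(8, 4)), step=4, training=False)
```

Storing statistics per step rather than sharing one set across the whole sequence matters because activation statistics in the first few steps of an RNN can differ noticeably from later steps. Clamping the step index keeps memory bounded and still lets the model run on sequences longer than those seen during training.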
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state"""
import numpy as np
import tensorflow as tf
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
    """Batch-normalized LSTM as described in Cooijmans et al., "Recurrent Batch Normalization" (arxiv.org/abs/1603.09025)."""
    def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """
        * max_bn_steps is the maximum number of time steps for which separate population statistics are stored
        """