Skip to content

Instantly share code, notes, and snippets.

@Arturus
Last active September 13, 2017 18:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Arturus/6388505ed7eafc3a67bcd7d077add5cf to your computer and use it in GitHub Desktop.
Save Arturus/6388505ed7eafc3a67bcd7d077add5cf to your computer and use it in GitHub Desktop.
def attn_readout_v3(readout, attn_window, attn_heads, page_features, seed):
    """Compute multi-head attention features over a readout sequence.

    Per-sample attention weights are predicted from ``page_features`` by a
    dense layer and applied to ``readout`` as a sliding window along the
    time axis. The sliding window is implemented as a depthwise convolution
    in which every batch element is a separate channel, so each sample is
    convolved with its own ``attn_heads`` filters.

    Shape notes below follow the original author's annotations.

    Args:
        readout: Tensor annotated as ``[n_days, batch, readout_depth]``.
        attn_window: Width of the attention window. Annotated by the author
            as ``train_window - predict_window + 1``, so that the 'VALID'
            convolution yields ``predict_window`` output positions.
        attn_heads: Number of attention heads (filters per sample).
        page_features: Per-sample features fed to the dense layer; the
            reshape below implies a leading batch dimension.
        seed: Seed forwarded to ``default_init`` for the dense kernel
            initializer.

    Returns:
        Tensor annotated as ``[batch, predict_window, readout_depth * attn_heads]``.
    """
    # [n_days, batch, readout_depth]
    #   -> [conv_batch=readout_depth, width=n_days, channels=batch]
    readout = tf.transpose(readout, [2, 0, 1])
    # Insert a dummy height axis so the tensor is a valid NHWC conv input:
    #   -> [conv_batch=readout_depth, height=1, width=n_days, channels=batch]
    inp = readout[:, tf.newaxis, :, :]

    # One logit per (window position, head) for every sample:
    #   [batch, attn_window * n_heads]
    filter_logits = tf.layers.dense(
        page_features,
        attn_window * attn_heads,
        name="attn_focus",
        kernel_initializer=default_init(seed),
    )
    # [batch, attn_window * n_heads] -> [batch, attn_window, n_heads]
    filter_logits = tf.reshape(filter_logits, [-1, attn_window, attn_heads])

    # Normalize along the window axis so the weights of each head sum to 1.
    # NOTE(review): this divides by the raw sum of the logits rather than
    # applying a softmax. The dense layer has no activation, so the sum is
    # not guaranteed to be positive or non-zero -- confirm this is intended.
    attns_max = filter_logits / tf.reduce_sum(filter_logits, axis=1, keep_dims=True)

    # [batch, attn_window, n_heads]
    #   -> [width=attn_window, in_channels=batch, multiplier=n_heads]
    attns_max = tf.transpose(attns_max, [1, 0, 2])
    # Add the filter height axis:
    #   -> [height=1, width=attn_window, in_channels=batch, multiplier=n_heads]
    attn_filter = attns_max[tf.newaxis, :, :, :]

    # [readout_depth, 1, n_days, batch]
    #   -> [readout_depth, 1, predict_window, batch * n_heads]
    averaged = tf.nn.depthwise_conv2d_native(inp, attn_filter, [1, 1, 1, 1], 'VALID')
    # Drop the dummy height axis:
    #   -> [readout_depth, predict_window, batch * n_heads]
    attn_features = tf.squeeze(averaged, 1)
    # -> [batch * n_heads, predict_window, readout_depth]
    attn_features = tf.transpose(attn_features, [2, 1, 0])

    # Depthwise conv lays output channels out as
    # in_channel * multiplier + head, so a stride of n_heads starting at
    # head_no selects that head for every sample:
    #   [batch * n_heads, predict_window, depth]
    #     -> n_heads * [batch, predict_window, depth]
    heads = [attn_features[head_no::attn_heads] for head_no in range(attn_heads)]
    # n_heads * [batch, predict_window, depth]
    #   -> [batch, predict_window, depth * n_heads]
    result = tf.concat(heads, axis=-1)
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment