Tim Rocktäschel rockt

@rockt
rockt / zweisum.py
Created February 12, 2021 09:22
PyTorch einsum with named tensors
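import torch
import re

def einsumfy_exp(exp):
    # Split on commas, spaces, parentheses and "->" to collect the dimension names.
    names = set(re.split(r"[, \(\)]|->", exp))
    # discard() avoids a KeyError when the split yields no empty string.
    names.discard("")
    # Names longer than one character are not valid einsum indices...
    invalid_names = set(filter(lambda x: len(x) > 1, names))
    # ...except the ellipsis, which einsum accepts as-is.
    if "..." in invalid_names:
        invalid_names.remove("...")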
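The listing above shows only the start of zweisum.py, so the rest of the author's approach is not visible here. As a rough illustration of the general idea (rewriting multi-character dimension names into the single letters that einsum expects), here is a minimal sketch. The function `to_einsum_letters`, its space-separated input syntax, and the example expression are assumptions made for this illustration, not the gist's actual code.

```python
import re
import string


def to_einsum_letters(exp):
    """Hypothetical helper: replace each named dimension with one letter.

    Assumes dimensions within an operand are separated by spaces and
    operands by commas, e.g. "batch seq, seq hidden -> batch hidden".
    """
    mapping = {}
    letters = iter(string.ascii_lowercase)

    def replace(match):
        name = match.group(0)
        # Reuse the same letter every time a name reappears.
        if name not in mapping:
            mapping[name] = next(letters)
        return mapping[name]

    # Identifiers are rewritten; "...", "," and "->" pass through unchanged.
    translated = re.sub(r"[A-Za-z_]\w*", replace, exp)
    return translated.replace(" ", ""), mapping


spec, mapping = to_einsum_letters("batch seq, seq hidden -> batch hidden")
ellipsis_spec, _ = to_einsum_letters("... seq, seq hidden -> ... hidden")
```

A string such as `spec` is in the form that `torch.einsum(spec, x, y)` takes. The sketch supports at most 26 distinct names, because it draws letters from the lowercase alphabet only.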
rockt / gather_by_lengths.py
Created August 25, 2016 19:45
TensorFlow utility method to obtain the correct last outputs from a tf.nn.dynamic_rnn as determined by sequence length
import tensorflow as tf

def gather_by_lengths(outputs, seq_lengths):
    """
    :param outputs: [batch_size x max_seq_length x hidden_size] tensor of dynamic_rnn outputs
    :param seq_lengths: [batch_size] int32 tensor of sequence lengths
    :return: [batch_size x hidden_size] tensor of last outputs
    """
    # tf.unpack was renamed to tf.unstack in TensorFlow 1.0.
    batch_size, max_seq_length, hidden_size = tf.unstack(tf.shape(outputs))
    # Row of each sequence's last valid step in the flattened
    # [batch_size * max_seq_length, hidden_size] view of outputs.
    index = tf.range(0, batch_size) * max_seq_length + (seq_lengths - 1)
    return tf.gather(tf.reshape(outputs, [-1, hidden_size]), index)
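The flat-index arithmetic in gather_by_lengths can be checked on a small array without building a TensorFlow graph. The following is a NumPy sketch of the same indexing, not the author's code. The name `gather_by_lengths_np` and the toy inputs are assumptions made for this illustration.

```python
import numpy as np


def gather_by_lengths_np(outputs, seq_lengths):
    """NumPy analogue (hypothetical) of the flat-index gather."""
    batch_size, max_seq_length, hidden_size = outputs.shape
    # In the flattened view, sequence b occupies rows
    # b * max_seq_length ... (b + 1) * max_seq_length - 1,
    # so its last valid step is at offset seq_lengths[b] - 1.
    index = np.arange(batch_size) * max_seq_length + (seq_lengths - 1)
    return outputs.reshape(-1, hidden_size)[index]


# Two sequences padded to length 3, hidden size 2.
outputs = np.arange(2 * 3 * 2).reshape(2, 3, 2)
seq_lengths = np.array([2, 3])
last = gather_by_lengths_np(outputs, seq_lengths)
# last is [[2, 3], [10, 11]]: step 1 of sequence 0, step 2 of sequence 1.
```

Flattening to two dimensions turns a per-row lookup into a single one-dimensional gather, which is presumably why the original reshapes before calling tf.gather.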