Gist: kanekomasahiro/9ab23b85829e385c077445fb8dfaf6db
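How to extract, from the (batch × sentence length × hidden size) tensor that is common in natural language processing, the hidden state of the word at a different index for each item in the batch.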
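
```python
import torch

def extract_hidden_states_by_word_index(input, index):
    '''
    Extract hidden states from a tensor (batch * sentence length * hidden size) by word index.

    Parameters
    ----------
    input : torch.Tensor
        Tensor of size batch * sentence length * hidden size.
    index : list
        List of length batch size holding the word index to extract for each sentence,
        e.g. if input size = 3 * 10 * 300, index = [9, 0, 6].

    Returns
    -------
    torch.Tensor
        Tensor of size batch * hidden size.
    '''
    # Pair batch row i with word position index[i] via advanced indexing.
    batch_index = torch.arange(input.size(0), device=input.device)
    word_index = torch.as_tensor(index, device=input.device)
    return input[batch_index, word_index, :]
```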
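For comparison, the same per-batch selection can also be written with `torch.gather`, which some readers find easier to reason about when the index arrives as a tensor. The sketch below is not part of the original gist: `gather_hidden_states` is a hypothetical helper name, and the random 3 × 10 × 300 tensor simply mirrors the docstring's example sizes. It checks that the gather version, advanced indexing, and a plain Python loop all pick the same vectors.

```python
import torch


def gather_hidden_states(hidden, index):
    # Hypothetical helper (not from the gist): same result as advanced indexing.
    # hidden: (batch, seq_len, hidden_size); index: one word position per batch row.
    idx = torch.as_tensor(index, dtype=torch.long, device=hidden.device)
    # Shape (batch, 1, hidden_size) so gather picks one position per row along dim 1.
    idx = idx.view(-1, 1, 1).expand(-1, 1, hidden.size(2))
    return hidden.gather(1, idx).squeeze(1)  # -> (batch, hidden_size)


torch.manual_seed(0)
hidden = torch.randn(3, 10, 300)  # batch=3, sentence length=10, hidden size=300
index = [9, 0, 6]

by_gather = gather_hidden_states(hidden, index)
by_indexing = hidden[torch.arange(hidden.size(0)), torch.tensor(index), :]
by_loop = torch.stack([hidden[b, i] for b, i in enumerate(index)])

print(by_gather.shape)  # torch.Size([3, 300])
print(torch.equal(by_gather, by_indexing), torch.equal(by_gather, by_loop))  # True True
```

Advanced indexing avoids materializing the expanded index, so it is usually the shorter choice; `gather` is mainly useful when the index is already shaped like the output.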