Skip to content

Instantly share code, notes, and snippets.

@macleginn
Created June 21, 2021 07:17
Show Gist options
  • Save macleginn/7a60126cf805b6dcb58a364973777deb to your computer and use it in GitHub Desktop.
Save macleginn/7a60126cf805b6dcb58a364973777deb to your computer and use it in GitHub Desktop.
Code for extracting word embeddings from RoBERTa
def rm_whitespace(s):
    """Strip RoBERTa's leading 'Ġ' word-boundary marker from *s*, if present."""
    return s[1:] if s.startswith('Ġ') else s
def get_tokens_with_ranges(input_string, tokenizer):
    '''
    Return per-word token ranges plus the tokenised model inputs.

    RoBERTa prepends 'Ġ' to the beginning of what it
    thinks to be a word in the input, except the first one
    (when it is not prefixed with a whitespace). E.g.:
    ```
    In [30]: tokenizer.tokenize('I say geography')
    Out[30]: ['I', 'Ġsay', 'Ġgeography']
    In [29]: tokenizer.tokenize(' i say chronomoscopy')
    Out[29]: ['Ġi', 'Ġsay', 'Ġchron', 'om', 'osc', 'opy']
    ```
    This function returns an array containing [start, end) ranges for
    the tokens of each word (the first and last ranges cover the <s>
    and </s> sentinel positions) together with PyTorch tensors for
    the tokens.

    Raises:
        ValueError: if `input_string` is empty.
    '''
    if not input_string:
        # Previously `assert input_string`, but asserts are stripped
        # under `python -O`; raise explicitly instead.
        raise ValueError('input_string must be non-empty')
    tokens = tokenizer.tokenize(input_string)
    ranges = [
        [0, 1]  # Start-of-sentence token <s>
    ]
    tmp = []  # token indices of the word currently being accumulated
    for i, token in enumerate(tokens):
        idx = i + 1  # 0 is <s>
        # A 'Ġ'-prefixed token starts a new word: flush the previous
        # word's range (if any) and start accumulating a new one.
        if tmp and token.startswith('Ġ'):
            ranges.append([tmp[0], tmp[-1] + 1])
            tmp = [idx]
        else:
            tmp.append(idx)
    # Guard against an empty token list (the unguarded append used to
    # raise IndexError when `tokens` was empty).
    if tmp:
        ranges.append([tmp[0], tmp[-1] + 1])
    ranges.append([len(tokens) + 1, len(tokens) + 2])  # End-of-sentence token </s>
    return ranges, tokenizer(input_string, return_tensors='pt')
def get_word_embeddings(input_string, tokenizer, model, level):
    '''
    Map each word of `input_string` to a single embedding vector.

    A word's embedding is the mean of its subword-token vectors, taken
    from the model's last hidden state when `level == 'final'`, and
    from `hidden_states[level - 1]` otherwise (so `level` is 1-based
    over the hidden-state layers).

    NOTE(review): repeated words collide on the same dictionary key,
    so only the last occurrence's embedding survives — confirm this is
    acceptable for the intended use.
    '''
    input_tokens = tokenizer.tokenize(input_string)
    ranges, inputs = get_tokens_with_ranges(input_string, tokenizer)
    with torch.no_grad():
        # No batches for now: take the single sequence at index 0.
        if level == 'final':
            hidden = model(**inputs).last_hidden_state[0]
        else:
            all_states = model(**inputs, output_hidden_states=True).hidden_states
            hidden = all_states[level - 1][0]
    embeddings = {}
    # Skip the first and last ranges: they cover <s> and </s>.
    for start, end in ranges[1:-1]:
        word = rm_whitespace(''.join(input_tokens[start - 1:end - 1]))
        embeddings[word] = hidden[start - 1:end - 1, :].mean(0).numpy()
    return embeddings
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment