Skip to content

Instantly share code, notes, and snippets.

@gajomi
Created February 14, 2020 00:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gajomi/d3b1c5f20e1e6a4bdc36df7e81ecdeaf to your computer and use it in GitHub Desktop.
Save gajomi/d3b1c5f20e1e6a4bdc36df7e81ecdeaf to your computer and use it in GitHub Desktop.
ALBERT hidden-state oscillations (are there any?)
import matplotlib.pyplot as plt
import torch
import numpy as np
from transformers import *
def get_albert_model(albert_model_name = 'albert-large-v2'):
    """Load the ALBERT tokenizer and model used in this experiment.

    Both objects are created with ``from_pretrained`` from the same
    checkpoint name. The model is loaded with ``output_hidden_states=True``
    and ``output_attentions=True``.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = AlbertTokenizer.from_pretrained(albert_model_name)
    model = AlbertModel.from_pretrained(
        albert_model_name,
        output_hidden_states=True,
        output_attentions=True,
    )
    return tokenizer, model
def get_gen_model():
    """Load the GPT-2 tokenizer and language-model head used to generate text.

    Both objects are created with ``from_pretrained`` from the ``'gpt2'``
    checkpoint. The model is loaded with ``output_hidden_states=True`` and
    ``output_attentions=True``.

    Returns:
        A ``(gen_tokenizer, gen_model)`` pair.
    """
    checkpoint = 'gpt2'
    gen_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    gen_model = GPT2LMHeadModel.from_pretrained(
        checkpoint,
        output_hidden_states=True,
        output_attentions=True,
    )
    return gen_tokenizer, gen_model
def random_sentence_generator():
    """Return a function that generates text from a randomly chosen start word.

    The generation model and tokenizer are loaded once, when this factory is
    called, and shared by every call of the returned function.

    Returns:
        ``generator(n=12)``: picks a random start word from the vocabulary,
        lets the model continue it with ``generate(..., max_length=n)`` and
        returns the decoded text.
    """
    gen_tokenizer, gen_model = get_gen_model()
    token_ids = range(gen_tokenizer.vocab_size)
    all_tokens = gen_tokenizer.convert_ids_to_tokens(token_ids)
    # Candidate start words are the vocabulary entries carrying the 'Ġ' prefix
    # (presumably the tokenizer's marker for a leading space), with the marker
    # stripped. A bare 'Ġ' entry would strip down to the empty string and leave
    # the model with an empty prompt, so such entries are excluded.
    valid_start_tokens = [
        t[1:] for t in all_tokens if t.startswith('Ġ') and len(t) > 1
    ]

    def generator(n = 12):
        """Generate one sentence; ``n`` is passed to ``generate`` as ``max_length``."""
        start_word = np.random.choice(valid_start_tokens)
        # Batch of one prompt.
        input_ten = torch.tensor([gen_tokenizer.encode(start_word)])
        output_toks = gen_model.generate(input_ten, max_length=n)
        return gen_tokenizer.decode(output_toks[0])

    return generator
def model_hidden_states(model, tokenizer, input, **modelkwargs):
    """Return the hidden states a model produces for one input text.

    Args:
        model: Callable taking a tensor of token ids. It may return an object
            exposing a ``hidden_states`` attribute, or a 4-item sequence
            ``(last_hidden_state, pooler_output, hidden_states, attentions)``.
        tokenizer: Object whose ``encode`` method turns ``input`` into token ids.
        input: Text to encode.
        **modelkwargs: Forwarded to ``tokenizer.encode`` (not to the model,
            despite the name).

    Returns:
        The ``hidden_states`` entry of the model output.

    Raises:
        ValueError: If the output has no ``hidden_states`` attribute and does
            not unpack into exactly four items.
    """
    # Wrap the ids in a list to form a batch of one.
    input_ids = torch.tensor([tokenizer.encode(input, **modelkwargs)])
    with torch.no_grad():
        outputs = model(input_ids)

    # Dict-like outputs iterate over their keys, so positional unpacking would
    # silently bind key strings; prefer the named attribute when it exists.
    named_hidden_states = getattr(outputs, 'hidden_states', None)
    if named_hidden_states is not None:
        return named_hidden_states

    # Plain sequence output: hidden states are the third of four entries.
    _last_hidden_state, _pooler_output, hidden_states, _attentions = outputs
    return hidden_states
def project_hidden_states(hidden_states, k = 4):
    """Project the hidden-state trajectory onto its ``k`` leading singular directions.

    Each entry of ``hidden_states`` is flattened into one row of a matrix.
    The matrix is standardised with its overall mean and standard deviation,
    then decomposed with an SVD.

    Returns:
        A pair ``(projection, residual_norm)``. ``projection`` has one row per
        entry of ``hidden_states`` and up to ``k`` columns, holding the
        coordinates along the leading singular directions. ``residual_norm``
        holds, per row, the norm of the coordinates along the remaining
        directions.
    """
    # One flattened hidden state per column, transposed to one per row.
    flattened = [hs.reshape(-1) for hs in hidden_states]
    trajectory = torch.stack(flattened, dim=1).T

    # Standardise with scalar statistics taken over the whole matrix.
    centre = trajectory.mean()
    scale = trajectory.std()
    standardised = (trajectory - centre) / scale

    left_vectors, singular_values, _ = standardised.svd()

    kept = left_vectors[:, :k] * singular_values[:k]
    dropped = left_vectors[:, k:] * singular_values[k:]
    return kept, dropped.norm(dim=1)
def _plot_proj_hidden_states(prj_hidden_states,norm_hidden_states, ax = None):
    """Plot projected hidden states and their residual norm on one axes.

    Args:
        prj_hidden_states: 2-D tensor; ``ax.plot`` draws one line per column.
        norm_hidden_states: 1-D tensor drawn as a dashed black line.
        ax: Axes to draw on. A new figure and axes are created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    ax.plot(prj_hidden_states)
    ax.plot(norm_hidden_states, '--k')

    # Legend entries follow plotting order: one per projected column, then the
    # residual curve. Each label shows the norm of the corresponding curve.
    mode_norms = prj_hidden_states.norm(dim = 0)
    res_norm = norm_hidden_states.norm()
    labels = [f"||z_{i}||={n:.2f}" for i, n in enumerate(mode_norms)]
    labels.append(f'residual norm = {res_norm:.2f}')
    ax.legend(labels)
    return ax
def make_plot_proj_hidden_states(sentence_groups, subplot_size = (8,4)):
    """Plot projected hidden states for a grid of sentences.

    Each group of sentences fills one column of subplots, with one row per
    sentence. The number of rows is taken from the first group, so every
    group is expected to hold at least that many sentences.

    Uses the module-level ``model`` and ``tokenizer`` to compute hidden states.

    Args:
        sentence_groups: Sequence of sentence sequences.
        subplot_size: ``(width, height)`` of a single subplot; the figure size
            is this multiplied by the number of columns and rows.

    Returns:
        The 2-D array of axes, indexed ``[sentence, group]``.
    """
    G = len(sentence_groups)
    N = len(sentence_groups[0])
    figsize = (subplot_size[0]*G, subplot_size[1]*N)
    # squeeze=False keeps the axes array 2-D when there is only one row or one
    # column; otherwise the ``axs[n, g]`` indexing below would fail.
    fig, axs = plt.subplots(N, G, figsize=figsize, squeeze=False)
    for g, group in enumerate(sentence_groups):
        for n in range(N):
            sentence = group[n]
            hidden_states = model_hidden_states(model, tokenizer, sentence, add_special_tokens=True)
            prj_hidden_states, norm_hidden_states = project_hidden_states(hidden_states)
            ax = _plot_proj_hidden_states(prj_hidden_states, norm_hidden_states, ax = axs[n, g])
            ax.set_title(sentence)
    return axs
# Module-level ALBERT tokenizer and model; make_plot_proj_hidden_states reads
# these two globals.
tokenizer, model = get_albert_model()

# Inputs made of a single word repeated several times.
n_repeats = 5
repeating_input_words = ['hi','yes','embryogenesis','antidisestablishmentarianism']
repeating_inputs = [
    " ".join(word for _ in range(n_repeats))
    for word in repeating_input_words
]

# Inputs generated by GPT-2 from random start words.
generator = random_sentence_generator()
gpt_gen_inputs = [generator() for _ in range(4)]

# One column of subplots per input group.
make_plot_proj_hidden_states([gpt_gen_inputs,repeating_inputs])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment