# PyTorch Embedding Updates
#
# GitHub gist jgontrum/46b14b27ea20dc238358d104d793262f, created March 24, 2019 14:49.
#
# Note (from GitHub): this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
import torch | |
import numpy as np | |
import torch.nn as nn | |
import itertools | |
# Toy corpus of (tokenised sentence, two-element label vector) pairs.
# The "good"/"better" sentences carry [1.0, 0.0]; the "bad"/"negative" ones
# carry [0.0, 1.0].
corpus = [
    ("this is a very good sentence number one".split(), [1.0, 0.0]),
    ("a rather bad example".split(), [0.0, 1.0]),
    ("an even better sentence to use".split(), [1.0, 0.0]),
    ("negative sentence very short".split(), [0.0, 1.0]),
]
X, y = zip(*corpus)

# Ids 0 and 1 are reserved for padding and out-of-vocabulary tokens; every
# corpus token gets the next free id in order of first appearance.
vocabulary = {
    "*pad*": 0,
    "*oov*": 1,
}
for token in itertools.chain.from_iterable(X):
    # setdefault leaves the id of an already-seen token untouched.
    vocabulary.setdefault(token, len(vocabulary))


def encode_sentence(sentence):
    """Return the token ids of *sentence* as a tensor with a leading batch axis.

    Tokens missing from ``vocabulary`` are mapped to the reserved ``*oov*``
    id instead of ``None``, which ``torch.tensor`` could not convert.
    """
    oov_id = vocabulary["*oov*"]
    return torch.tensor([[vocabulary.get(token, oov_id) for token in sentence]])


# One (1, sentence_length) id tensor per sentence and one label tensor each.
X_tensors = [encode_sentence(sentence) for sentence in X]
y_tensors = [torch.tensor(y_) for y_ in y]
class ExampleModule(nn.Module):
    """Small sentence classifier: embedding -> bidirectional LSTM -> MLP.

    Args:
        vocabulary: Mapping from token to integer id. It must contain the
            key ``"*pad*"``, whose id is used as the embedding padding index.
        embedding_dim: Size of each word embedding (default 5).
        hidden_size: Hidden size of the LSTM per direction (default 15).
    """

    def __init__(self, vocabulary, embedding_dim=5, hidden_size=15):
        super().__init__()
        self.vocabulary = vocabulary

        # Seed both RNGs before any layer is created so that the weight
        # initialisation is the same on every construction.
        np.random.seed(1234567)
        torch.manual_seed(1234567)

        self.word_embeddings = nn.Embedding(
            num_embeddings=len(self.vocabulary),
            embedding_dim=embedding_dim,
            padding_idx=self.vocabulary["*pad*"],
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # The LSTM is bidirectional, so each time step yields
        # 2 * hidden_size features.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_size, 100),
            nn.Tanh(),
            nn.Linear(100, 2),
        )

    def forward(self, sentence):
        """Score one sentence.

        Args:
            sentence: Integer tensor of token ids with a leading batch axis
                (the LSTM is ``batch_first``). Only the first batch entry
                is used.

        Returns:
            A tensor with two elements, the MLP output for the LSTM state
            at the last time step.
        """
        embeddings = self.word_embeddings(sentence)
        output, _ = self.lstm(embeddings)
        # First (only) batch entry, last time step.
        last_word = output[0][-1]
        flattened = last_word.view((-1,))
        return self.mlp(flattened)
model = ExampleModule(vocabulary)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# The experiment watches the embedding row of the word 'better' and reports,
# after every optimizer step, whether that row changed.
token_id = vocabulary["better"]
token_index = torch.tensor(token_id)  # loop-invariant lookup index

for epoch in range(10):
    print("----------------------------")
    print(f"Starting epoch #{epoch + 1}")
    for sentence, X_, y_ in zip(X, X_tensors, y_tensors):
        # The optimizer was built from model.parameters(), so a single
        # zero_grad() call resets every gradient of the model.
        optimizer.zero_grad()

        prediction = model(X_)
        # nn.MSELoss is documented as criterion(input, target).
        loss = criterion(prediction, y_)
        loss.backward()

        # Snapshot the embedding and its gradient before the update.
        # Tensor.numpy() shares memory with the tensor it is called on, so
        # the gradient row is cloned first to get an independent copy.
        emb_before = model.word_embeddings(token_index).detach().numpy()
        gradient = model.word_embeddings.weight.grad[token_id].clone().numpy()

        optimizer.step()

        # Look the embedding up again and compare it with the snapshot.
        emb_after = model.word_embeddings(token_index).detach().numpy()
        emb_equal = np.array_equal(emb_before, emb_after)
        sentence_contains_token = 'better' in sentence

        print(f"Current sentence: '{' '.join(sentence)}' ")
        if not emb_equal and sentence_contains_token:
            print("✅ The sentence contained the word 'better' and its embedding was updated. Hurray.")
        elif emb_equal and not sentence_contains_token:
            print("✅ The sentence did not contain the word 'better' and its embedding was not updated. Hurray.")
        elif not emb_equal and not sentence_contains_token:
            diff = sum(emb_before - emb_after)
            print("❌ The sentence did not contain the word 'better', but the embedding was updated anyhow. Boo.")
            print(f"The gradient for the embedding: {gradient}")
            print(f"The embedding before the update: {emb_before}")
            print(f"The embedding after the update: {emb_after}")
            print(f"Summed difference between both embeddings: {diff}")
        else:
            # emb_equal and sentence_contains_token.
            # This never happens, putting it here for the sake of completeness
            print("❌❌❌ The sentence contained the word 'better', but the embedding was not updated. Boo.")
        print()
# (GitHub page footer: sign up for free, or sign in, to comment on this gist.)