@jgontrum
Created March 24, 2019 14:49
PyTorch Embedding Updates

A small script that checks, after every optimizer step of a toy classifier, whether the embedding row for the word "better" has changed, both for sentences that contain the word and for sentences that do not.

import torch
import numpy as np
import torch.nn as nn
import itertools

corpus = [
    ("this is a very good sentence number one".split(), [1.0, 0.0]),
    ("a rather bad example".split(), [0.0, 1.0]),
    ("an even better sentence to use".split(), [1.0, 0.0]),
    ("negative sentence very short".split(), [0.0, 1.0])
]

X, y = zip(*corpus)

vocabulary = {
    "*pad*": 0,
    "*oov*": 1
}

# Assign an id to every token that is not yet in the vocabulary.
for token in itertools.chain(*X):
    id_ = vocabulary.get(token)
    if id_ is None:
        vocabulary[token] = len(vocabulary)

# Encode every sentence as a 1 x sentence_length tensor of token ids.
X_tensors = [
    torch.tensor([[vocabulary.get(token) for token in sentence]])
    for sentence in X
]

y_tensors = [
    torch.tensor(y_) for y_ in y
]


class ExampleModule(nn.Module):
    def __init__(self, vocabulary):
        super().__init__()
        self.vocabulary = vocabulary

        np.random.seed(1234567)
        torch.manual_seed(1234567)

        self.word_embeddings = nn.Embedding(
            num_embeddings=len(self.vocabulary),
            embedding_dim=5,
            padding_idx=self.vocabulary["*pad*"],
        )

        self.lstm = nn.LSTM(
            input_size=5,
            hidden_size=15,
            bidirectional=True,
            batch_first=True
        )

        self.mlp = nn.Sequential(
            nn.Linear(30, 100),
            nn.Tanh(),
            nn.Linear(100, 2)
        )

    def forward(self, sentence):
        embeddings = self.word_embeddings(sentence)
        output, _ = self.lstm(embeddings)
        # Use the BiLSTM output at the last time step as the sentence representation.
        last_word = output[0][-1]
        flattened = last_word.view((-1,))
        return self.mlp(flattened)

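# Optional sanity check: a single forward pass on one encoded sentence should
# yield a 2-element score vector, matching the two-class targets in `corpus`
# (`example_model` is just an illustrative name). Uncomment to try it:
# example_model = ExampleModule(vocabulary)
# print(example_model(X_tensors[0]).shape)  # expected: torch.Size([2])
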
model = ExampleModule(vocabulary)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# Track how this word's embedding row changes across updates.
token_id = vocabulary["better"]

for epoch in range(10):
    print("----------------------------")
    print(f"Starting epoch #{epoch + 1}")

    for sentence, X_, y_ in zip(X, X_tensors, y_tensors):
        # Set the gradients in the model and the optimizer to zero.
        model.zero_grad()
        optimizer.zero_grad()

        prediction = model(X_)
        loss = criterion(prediction, y_)
        loss.backward()

        # Save the embedding and its gradient as numpy arrays so they don't
        # accidentally change.
        emb_before = model.word_embeddings(torch.tensor(token_id)).detach().numpy()
        gradient = model.word_embeddings.weight.grad[token_id].numpy()

        optimizer.step()

        # Look up the embedding again and check whether it changed.
        emb_after = model.word_embeddings(torch.tensor(token_id)).detach().numpy()
        emb_equal = np.array_equal(emb_before, emb_after)

        sentence_contains_token = "better" in sentence

        print(f"Current sentence: '{' '.join(sentence)}'")

        if not emb_equal and sentence_contains_token:
            print("✅ The sentence contained the word 'better' and its embedding was updated. Hurray.")

        if emb_equal and not sentence_contains_token:
            print("✅ The sentence did not contain the word 'better' and its embedding was not updated. Hurray.")

        if not emb_equal and not sentence_contains_token:
            diff = np.sum(emb_before - emb_after)
            print("❌ The sentence did not contain the word 'better', but the embedding was updated anyhow. Boo.")
            print(f"The gradient for the embedding: {gradient}")
            print(f"The embedding before the update: {emb_before}")
            print(f"The embedding after the update: {emb_after}")
            print(f"Summed difference between both embeddings: {diff}")

        if emb_equal and sentence_contains_token:
            # This never happens; it is here for the sake of completeness.
            print("❌❌❌ The sentence contained the word 'better', but its embedding was not updated. Boo.")

        print()
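
# --- A sketch of why the row for "better" keeps moving ---
# torch.optim.Adam keeps running averages per parameter (state keys "exp_avg" and
# "exp_avg_sq"). Once the word "better" has received a non-zero gradient, those
# averages stay non-zero, so later steps can still move its row even when the
# current gradient for that row is zero. The inspection below relies on Adam's
# internal per-parameter state dict.
adam_state = optimizer.state[model.word_embeddings.weight]
print(f"Adam momentum (exp_avg) for 'better': {adam_state['exp_avg'][token_id]}")
print(f"Last gradient seen for 'better':      {model.word_embeddings.weight.grad[token_id]}")

# One possible way around this (an assumption, not tested in this gist): use a
# sparse embedding, e.g. nn.Embedding(len(vocabulary), 5, padding_idx=0, sparse=True),
# trained with torch.optim.SparseAdam, which only touches rows that actually
# received a gradient, or plain SGD without momentum, which leaves zero-gradient
# rows unchanged.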