Skip to content

Instantly share code, notes, and snippets.

@DerekChia
Last active December 1, 2018 15:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DerekChia/c34685a94f2be3c5d361de8d71b97bb6 to your computer and use it in GitHub Desktop.
Save DerekChia/c34685a94f2be3c5d361de8d71b97bb6 to your computer and use it in GitHub Desktop.
w2v_training_error_backpropagation
class word2vec():
##Removed##
# NOTE(review): this excerpt has lost its indentation and omits the enclosing method header
# and the forward pass; y_pred, h and u used below are not defined in the visible code and
# presumably come from the elided lines - confirm against the full source.
for i in range(self.epochs):
# Reset the accumulated loss at the start of each epoch
self.loss = 0
# Each training example pairs a target word (w_t) with its context words (w_c)
for w_t, w_c in training_data:
##Removed##
# Calculate error
# 1. For each context word, take the element-wise difference between y_pred and that word's vector
# 2. Sum those differences over all context words (np.sum with axis=0) to get a single error vector for this target word
EI = np.sum([np.subtract(y_pred, word) for word in w_c], axis=0)
# Backpropagation
# backprop() updates self.w1 and self.w2 by one learning-rate-scaled gradient step computed from EI, h and w_t
self.backprop(EI, h, w_t)
# Calculate loss
# There are 2 parts to the loss function
# Part 1: negative sum of the output-layer values (before softmax, u) taken at the index of each context word
# Part 2: number of context words * log of the sum of exp(u) over all elements of the output layer before softmax
# Note: word.index(1) returns the index of the first element equal to 1 in the context word vector
#       (assumes each context word is a one-hot Python list - TODO confirm)
# Note: u[word.index(1)] returns the value of the output layer before softmax at that index
self.loss += -np.sum([u[word.index(1)] for word in w_c]) + len(w_c) * np.log(np.sum(np.exp(u)))
# Report the accumulated loss (presumably once per epoch - nesting level cannot be confirmed from this excerpt)
print('Epoch:', i, "Loss:", self.loss)
def backprop(self, e, h, x):
    """Apply one gradient-descent update to ``self.w1`` and ``self.w2``.

    Both gradients are built with ``np.outer`` from the pre-update weights,
    then each weight matrix is replaced by ``weights - self.lr * gradient``.
    The matrices are reassigned rather than modified in place.

    Expected shapes, writing V for ``len(x)`` / ``len(e)`` and N for
    ``len(h)``: ``self.w1`` is V x N and ``self.w2`` is N x V, so that the
    outer products below line up with the matrices they update.

    Args:
        e: summed prediction error over the context words (length V);
            must expose ``.T`` (e.g. a NumPy array).
        h: hidden-layer values for the current target word (length N).
        x: input vector of the target word (length V).

    Returns:
        None. The result is stored on ``self.w1`` and ``self.w2``.
    """
    # Gradient for the hidden->output weights: one row per hidden unit,
    # one column per output element (N x V).
    grad_w2 = np.outer(h, e)

    # Push the output error back through the (not yet updated) w2 to get
    # the error seen by the hidden layer (length N), then spread it over
    # the input elements to form the input->hidden gradient (V x N).
    hidden_error = np.dot(self.w2, e.T)
    grad_w1 = np.outer(x, hidden_error)

    # Step against the gradient, scaled by the learning rate.
    self.w1 = self.w1 - grad_w1 * self.lr
    self.w2 = self.w2 - grad_w2 * self.lr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment