w2v_training_error_backpropagation
class word2vec():
    ##Removed##
    for i in range(self.epochs):
        self.loss = 0
        for w_t, w_c in training_data:
            ##Removed##
            # Calculate error
            # 1. For the target word, take the difference between y_pred and the one-hot vector of each context word
            # 2. Sum up these differences using np.sum to give us the error EI for this particular target word
            EI = np.sum([np.subtract(y_pred, word) for word in w_c], axis=0)
            # Backpropagation
            # We use SGD to backpropagate the error calculated on the output layer
            self.backprop(EI, h, w_t)
            # Calculate loss
            # There are 2 parts to the loss function
            # Part 1: negative sum of the output-layer values (before softmax) at the positions of the actual context words
            # Part 2: number of context words * log of the sum over all elements of the output layer before softmax (u), exponentiated
            # Note: word.index(1) returns the index in the context word's one-hot vector with value 1
            # Note: u[word.index(1)] returns the value of the output layer before softmax at that index
            self.loss += -np.sum([u[word.index(1)] for word in w_c]) + len(w_c) * np.log(np.sum(np.exp(u)))
        print('Epoch:', i, "Loss:", self.loss)
    def backprop(self, e, h, x):
        # https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.outer.html
        # The error vector e (EI) is the element-wise sum of the prediction errors over each context word for the current target word
        # Going backwards, we take the derivative of the loss E with respect to w2
        # h - shape 10x1, e - shape 9x1, dl_dw2 - shape 10x9
        dl_dw2 = np.outer(h, e)
        # x - shape 9x1 (one-hot), w2 - shape 10x9, e.T - shape 9x1
        # np.dot(self.w2, e.T) - shape 10x1, dl_dw1 - shape 9x10
        dl_dw1 = np.outer(x, np.dot(self.w2, e.T))
        # Update weights using the learning rate self.lr
        self.w1 = self.w1 - (self.lr * dl_dw1)
        self.w2 = self.w2 - (self.lr * dl_dw2)
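For reference, the two parts of the loss computed in the training loop are the standard skip-gram negative log-likelihood. Writing u for the output layer before softmax, j*_c for the vocabulary index of the c-th context word, C = len(w_c), and V for the vocabulary size, each (target, context) pair contributes:

E = -\sum_{c=1}^{C} u_{j^{*}_{c}} + C \cdot \log \sum_{j=1}^{V} \exp(u_j)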
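To see how the shapes line up end to end, here is a minimal standalone sketch of one SGD step, assuming a 9-word vocabulary (V=9) and 10 hidden units (N=10) as in the shape comments above. The softmax forward pass and the random weight initialisation are assumptions, since that code is elided (##Removed##) in the snippet:

import numpy as np

V, N, lr = 9, 10, 0.01
rng = np.random.default_rng(0)
w1 = rng.uniform(-0.8, 0.8, (V, N))     # input weight matrix,  shape (V, N)
w2 = rng.uniform(-0.8, 0.8, (N, V))     # output weight matrix, shape (N, V)

w_t = np.eye(V)[2]                      # one-hot target word, shape (V,)
w_c = [np.eye(V)[1], np.eye(V)[3]]      # one-hot context words

# Assumed forward pass: hidden layer h, output layer u, softmax prediction y_pred
h = np.dot(w1.T, w_t)                   # shape (N,)
u = np.dot(w2.T, h)                     # shape (V,)
y_pred = np.exp(u) / np.sum(np.exp(u))  # softmax, shape (V,)

# Error: sum of (y_pred - context one-hot) over all context words
EI = np.sum([y_pred - word for word in w_c], axis=0)  # shape (V,)

# Gradients, mirroring backprop() above
dl_dw2 = np.outer(h, EI)                # shape (N, V)
dl_dw1 = np.outer(w_t, np.dot(w2, EI))  # shape (V, N)
w1 -= lr * dl_dw1
w2 -= lr * dl_dw2

# Loss for this (target, context) pair; the context words here are numpy
# arrays, so np.argmax replaces the list method word.index(1) used above
loss = -np.sum([u[int(np.argmax(word))] for word in w_c]) + len(w_c) * np.log(np.sum(np.exp(u)))
print(loss)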