Skip to content

Instantly share code, notes, and snippets.

@sbugrov
Created April 21, 2018 22:13
Show Gist options
  • Save sbugrov/8b729a966098ee03681cc580b2a9736c to your computer and use it in GitHub Desktop.
Save sbugrov/8b729a966098ee03681cc580b2a9736c to your computer and use it in GitHub Desktop.
// ---- Forward pass --------------------------------------------------------
// Three fully connected layers with no bias terms: 784 -> 128 -> 64 -> 10.
// NOTE(review): dot(A, B, r, k, c) is assumed to multiply an r x k matrix by
// a k x c matrix, both stored as flat row-major vectors (inferred from how the
// shapes chain below) -- confirm against its definition.
const int n_in = 784;
const int n_h1 = 128;
const int n_h2 = 64;
const int n_out = 10;

vector<float> z1 = dot(b_X, W1, BATCH_SIZE, n_in, n_h1);
vector<float> a1 = relu(z1);
vector<float> z2 = dot(a1, W2, BATCH_SIZE, n_h1, n_h2);
vector<float> a2 = relu(z2);
vector<float> z3 = dot(a2, W3, BATCH_SIZE, n_h2, n_out);
vector<float> yhat = softmax(z3, n_out);

// ---- Backward pass -------------------------------------------------------
// NOTE(review): yhat - b_y is the logit gradient of softmax + cross-entropy
// when b_y holds one-hot targets; verify that is the intended loss.
vector<float> dyhat = yhat - b_y;

// Output layer weights: dW3 = a2^T . dyhat   (n_h2 x n_out)
vector<float> a2T = transpose(a2.data(), BATCH_SIZE, n_h2);
vector<float> dW3 = dot(a2T, dyhat, n_h2, BATCH_SIZE, n_out);

// Back through W3^T (the pre-update weights), gated by the ReLU derivative.
// reluPrime is evaluated on the activation a2 rather than on z2. For ReLU the
// sign test agrees -- NOTE(review): confirm reluPrime is (x > 0).
vector<float> W3T = transpose(W3.data(), n_h2, n_out);
vector<float> dz2 = dot(dyhat, W3T, BATCH_SIZE, n_out, n_h2) * reluPrime(a2);

// Second hidden layer weights: dW2 = a1^T . dz2   (n_h1 x n_h2)
vector<float> a1T = transpose(a1.data(), BATCH_SIZE, n_h1);
vector<float> dW2 = dot(a1T, dz2, n_h1, BATCH_SIZE, n_h2);

// Back through W2^T (the pre-update weights), gated by the ReLU derivative.
vector<float> W2T = transpose(W2.data(), n_h1, n_h2);
vector<float> dz1 = dot(dz2, W2T, BATCH_SIZE, n_h2, n_h1) * reluPrime(a1);

// First hidden layer weights: dW1 = b_X^T . dz1   (n_in x n_h1)
vector<float> XT = transpose(b_X.data(), BATCH_SIZE, n_in);
vector<float> dW1 = dot(XT, dz1, n_in, BATCH_SIZE, n_h1);

// ---- Parameter update (plain SGD) ----------------------------------------
// The products above sum over the batch dimension rather than averaging.
// NOTE(review): presumably lr already folds in 1/BATCH_SIZE -- confirm.
W3 = W3 - lr * dW3;
W2 = W2 - lr * dW2;
W1 = W1 - lr * dW1;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment