Last active
August 4, 2019 19:17
-
-
Save jakic12/414ad450d9c1222e58d8e09c6b92cebb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<html> | |
<body> | |
<canvas id="myChart"></canvas> | |
<script src="https://cdn.jsdelivr.net/npm/chart.js@2.8.0"></script> | |
<script> | |
const exp = [ // expected array
  [1, 0],
  [0, 1],
];
let f = [ // filter (reassigned by the training loop)
  [1, 1, 1],
  [1, 1, 1],
  [1, 1, 1],
];
const input = [ // input
  [1, 1, 1],
  [1, 1, 1],
  [1, 1, 1],
];
const s = 2; // stride
const p = 1; // padding
const lr = 0.0001; // learning rate (read by updateArray)

// first correlation, just to see what the raw output is without backpropagation
let out = corre(input, f, s, p);
console.log(`output before learning:`, out);

const data = []; // errors sampled every 100 iterations, for graphing
for (let ow = 0; ow < 10000; ow++) { // main training loop
  // forward pass: correlate the input with the current filter to get the output
  out = corre(input, f, s, p);
  // the error and its partial derivatives with respect to the output layer
  const { dO, err } = getError(out, exp);
  // store this error to the graphing array every 100 iterations
  if (ow % 100 === 0) {
    data.push(err);
  }
  // get the filter deltas, then take a gradient-descent step on the filter
  const dF = backpropFilter(f.length, dO, input, s, p);
  f = updateArray(f, dF);
}
// The last forward pass ran before the final filter update; recompute so the
// printed output reflects the fully trained filter.
out = corre(input, f, s, p);
console.log(`output after learning:`, out);
// chart js code
// Plot the sampled training error. The training loop records one error value
// every 100 iterations, so the x-axis labels are iteration numbers.
const ctx = document.getElementById('myChart').getContext('2d');
const myLineChart = new Chart(ctx, {
  type: 'line',
  data: {
    labels: data.map((_, i) => i * 100),
    datasets: [{
      // without a label Chart.js draws "undefined" in the legend
      label: 'error',
      data: data,
    }],
  },
  options: {},
});
// chart js code
/**
 * Gradient-descent step: subtract `rate` times each element of `darr` from the
 * matching element of `arr`.
 *
 * `arr` is modified in place and also returned, so both
 * `f = updateArray(f, dF)` and `updateArray(f, dF)` work.
 *
 * @param {number[][]} arr the array to update (mutated)
 * @param {number[][]} darr the deltas; must have at least the shape of `arr`
 * @param {number} [rate=lr] learning rate; defaults to the script-level `lr`
 * @returns {number[][]} `arr`
 */
function updateArray(arr, darr, rate = lr){
  for (let i = 0; i < arr.length; i++) {
    const row = arr[i];
    for (let j = 0; j < row.length; j++) {
      row[j] -= darr[i][j] * rate;
    }
  }
  return arr;
}
/**
 * Half squared-error loss and its gradient with respect to the output.
 *
 *   err      = 0.5 * sum over (r, c) of (exp[r][c] - actual[r][c])^2
 *   dO[r][c] = actual[r][c] - exp[r][c]      (the derivative of err)
 *
 * The gradient has as many rows as `actual` and as many columns as its first row.
 *
 * @param{*} actual the actual output
 * @param{*} exp the expected output
 * @returns {{dO: number[][], err: number}} gradient and loss
 */
function getError(actual, exp){
  const rows = actual.length;
  // only touch actual[0] when it exists, so an empty output yields { dO: [], err: 0 }
  const cols = rows > 0 ? actual[0].length : 0;
  let total = 0;
  const grad = [];
  for (let r = 0; r < rows; r++) {
    const gradRow = new Array(cols);
    for (let c = 0; c < cols; c++) {
      gradRow[c] = actual[r][c] - exp[r][c];
      total += Math.pow(exp[r][c] - actual[r][c], 2);
    }
    grad.push(gradRow);
  }
  return { dO: grad, err: total / 2 };
}
/**
 * Cross-correlate a 2-D array `a` with a filter `f` using stride `s` and zero
 * padding `p` (the "convolution" of CNN layers, without flipping the kernel).
 *
 * Output size per axis is floor((inputSize - filterSize + 2 * p) / s) + 1,
 * clamped at 0 when the filter does not fit in the padded input. Each output
 * cell (y, x) is
 *   sum over (j, i) of a[y * s + j - p][x * s + i - p] * f[j][i],
 * where positions outside `a` (the padding) contribute zero.
 *
 * @param {number[][]} a the input array
 * @param {number[][]} f the filter
 * @param {number} s stride
 * @param {number} p padding (zero rows/columns added on each side)
 * @returns {number[][]} the correlation result
 */
function corre(a, f, s, p){
  const inH = a.length;
  const inW = inH > 0 ? a[0].length : 0;
  const fH = f.length;
  const fW = fH > 0 ? f[0].length : 0;
  // Math.floor + clamp rather than parseInt: parseInt stringifies its argument,
  // and a negative size would make the array construction throw.
  const outY = Math.max(0, Math.floor((inH - fH + 2 * p) / s) + 1);
  const outX = Math.max(0, Math.floor((inW - fW + 2 * p) / s) + 1);
  return Array.from({ length: outY }, (_, y) =>
    Array.from({ length: outX }, (__, x) => {
      let sum = 0;
      for (let j = 0; j < fH; j++) {
        // the window origin advances by `s` per output cell; filter taps are contiguous
        const row = y * s + j - p;
        if (row < 0 || row >= inH) continue; // zero-padding row
        for (let i = 0; i < f[j].length; i++) {
          const col = x * s + i - p;
          if (col >= 0 && col < a[row].length) {
            sum += a[row][col] * f[j][i];
          }
        }
      }
      return sum;
    })
  );
}
/**
 * Gradient of the loss with respect to a square filter, given the gradient of
 * the loss with respect to the correlation output.
 *
 * For every filter tap (row, col) this accumulates
 *   dF[row][col] = sum over (h, k) of dO[h][k] * input[row + h * s - p][col + k * s - p],
 * skipping input positions that fall into the zero padding.
 *
 * @param{*} FS filter size (the filter is FS x FS)
 * @param{*} dO derivative with respect to the output of the correlation
 * @param{*} input the input of the correlation
 * @param{*} s stride
 * @param{*} p padding
 * @returns {number[][]} FS x FS array of filter derivatives
 */
function backpropFilter(FS, dO, input, s, p){
  const height = input.length;
  // input[0] is only read when the input has rows; otherwise no position is in range anyway
  const width = height > 0 ? input[0].length : 0;
  const grad = new Array(FS);
  for (let row = 0; row < FS; row++) {
    const gradRow = new Array(FS);
    for (let col = 0; col < FS; col++) {
      let acc = 0;
      for (const [h, dRow] of dO.entries()) {
        const y = row + h * s - p;
        for (const [k, d] of dRow.entries()) {
          const x = col + k * s - p;
          const inside = y >= 0 && y < height && x >= 0 && x < width;
          if (inside) {
            acc += d * input[y][x];
          }
        }
      }
      gradRow[col] = acc;
    }
    grad[row] = gradRow;
  }
  return grad;
}
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment