Created
May 26, 2015 05:56
-
-
Save shawntan/a384b18ad8ec4daf55de to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A small dense vector/matrix wrapper used by the model code below.
// NOTE: this assigns to the global name `Symbol`, so in ES2015+ environments
// it replaces the built-in Symbol constructor for every script sharing this
// global scope. The name is kept because callers reference it.
Symbol = (function() {
    /**
     * Create a zero-filled array. Extra arguments give the lengths of nested
     * dimensions, e.g. createArray(2, 3) -> [[0, 0, 0], [0, 0, 0]].
     *
     * @param {number} length - length of the outermost dimension (0 if falsy)
     * @returns {Array} the freshly allocated (possibly nested) array
     */
    function createArray(length) {
        var arr = new Array(length || 0);
        var i;
        if (arguments.length > 1) {
            // Remaining arguments describe the inner dimensions.
            var args = Array.prototype.slice.call(arguments, 1);
            for (i = 0; i < arr.length; i++) arr[i] = createArray.apply(this, args);
        } else {
            for (i = 0; i < arr.length; i++) arr[i] = 0.0;
        }
        return arr;
    }

    /**
     * Lift a scalar function into an element-wise operation on a 1-D Symbol.
     *
     * @param {function(number): number} fun - applied to every element
     * @returns {function(Symbol): Symbol} element-wise version of `fun`
     * @throws {Error} (from the returned function) if any result is NaN
     */
    function transformApplier(fun) {
        return function(sym) {
            // Locals are declared with `var`: the original leaked them as globals.
            var data = sym.data;
            var data_res = createArray(sym.shape[0]);
            for (var i = 0; i < data.length; i++) {
                data_res[i] = fun(data[i]);
                if (isNaN(data_res[i])) throw new Error("NaN alert! " + data);
            }
            return new Symbol(data_res);
        };
    }

    /**
     * Lift a two-argument scalar function into an element-wise operation on
     * two 1-D Symbols. The result has as many elements as the first operand.
     *
     * @param {function(number, number): number} fun - applied pairwise
     * @returns {function(Symbol, Symbol): Symbol} element-wise version of `fun`
     * @throws {Error} (from the returned function) if any result is NaN; both
     *     operands are logged to the console first to aid debugging
     */
    function binaryApplier(fun) {
        return function(sym1, sym2) {
            var data1 = sym1.data;
            var data2 = sym2.data;
            var data_res = createArray(sym1.shape[0]);
            for (var i = 0; i < sym1.shape[0]; i++) {
                data_res[i] = fun(data1[i], data2[i]);
                if (isNaN(data_res[i])) {
                    console.log(sym1);
                    console.log(sym2);
                    throw new Error("NaN alert!" + data1[i] + ", " + data2[i]);
                }
            }
            return new Symbol(data_res);
        };
    }

    /**
     * Compute the shape of a nested array by following the first element of
     * each level, e.g. [[1, 2], [3, 4], [5, 6]] -> [3, 2]. Non-arrays give [].
     */
    function getDimensions(data) {
        if (Array.isArray(data)) {
            return [data.length].concat(getDimensions(data[0]));
        }
        return [];
    }

    /**
     * Wrap a (nested) array of numbers.
     *
     * @param {Array|number} data - raw values; stored by reference, not copied
     */
    var Symbol = function(data) {
        this.data = data;
        this.shape = getDimensions(data);
    };

    // Inputs beyond +/-thresh return the saturated value directly instead of
    // going through Math.exp.
    var thresh = 10;

    Symbol.sigmoid = transformApplier(function(x) {
        if (x > thresh) {
            return 1;
        } else if (x < -thresh) {
            return 0;
        } else {
            return 1 / (1 + Math.exp(-x));
        }
    });

    Symbol.tanh = transformApplier(function(x) {
        if (x > thresh) {
            return 1;
        } else if (x < -thresh) {
            return -1;
        } else {
            var x_ = Math.exp(2 * x);
            return (x_ - 1) / (x_ + 1);
        }
    });

    Symbol.neg = transformApplier(function(x) { return -x; });
    Symbol.plus = binaryApplier(function(x, y) { return x + y; });
    Symbol.mult = binaryApplier(function(x, y) { return x * y; });
    Symbol.sub = binaryApplier(function(x, y) { return x - y; });
    Symbol.div = binaryApplier(function(x, y) { return x / y; });

    /**
     * Index of the largest element of a 1-D Symbol (first one on ties).
     *
     * @returns {?number} the index, or null when no element is greater than
     *     -Infinity (e.g. an empty vector)
     */
    Symbol.argmax = function(vec) {
        var max_id = null;
        var max = -Infinity;
        for (var i = 0; i < vec.shape[0]; i++) {
            if (vec.data[i] > max) {
                max = vec.data[i];
                max_id = i;
            }
        }
        return max_id;
    };

    Symbol.prototype = {
        /**
         * Vector-matrix product: `this` is a vector of length n and `sym2` an
         * n x m matrix; the result is a vector of length m.
         *
         * @throws {Error} if the leading dimensions differ or a sum becomes NaN
         */
        dot: function(sym2) {
            var sym1 = this;
            if (sym1.shape[0] !== sym2.shape[0]) throw new Error("Dimensions wrong!");
            var data1 = sym1.data;
            var data2 = sym2.data;
            var data_res = createArray(sym2.shape[1]);
            for (var j = 0; j < data_res.length; j++) {
                var sum = 0;
                for (var i = 0; i < sym1.shape[0]; i++) {
                    sum += data1[i] * data2[i][j];
                    if (isNaN(sum)) {
                        throw new Error("NaN alert! Sum so far: " + sum +
                            " Value 1: " + data1[i] +
                            " Value 2: " + data2[i][j]);
                    }
                }
                data_res[j] = sum;
            }
            return new Symbol(data_res);
        },
        // Method forms of the element-wise binary operations.
        plus: function(sym2) { return Symbol.plus(this, sym2); },
        mult: function(sym2) { return Symbol.mult(this, sym2); },
        sub: function(sym2) { return Symbol.sub(this, sym2); },
        div: function(sym2) { return Symbol.div(this, sym2); },
        // Wrap the i-th element (a row, for a matrix) in a new Symbol.
        idx: function(i) { return new Symbol(this.data[i]); },
        // Wrap data.slice(start, end) in a new Symbol.
        slice: function(start, end) {
            return new Symbol(this.data.slice(start, end));
        }
    };

    return Symbol;
})();
var model = null; | |
var loadModel = function(params) { | |
var P = {}; | |
for (var k in params) { P[k] = new Symbol(params[k]); } | |
function lstm_builder(W_cell,W_hidden,W_input,b) { | |
var size = 100; | |
var b_i = b.idx(0); | |
var b_f = b.idx(1); | |
var b_c = b.idx(2); | |
var b_o = b.idx(3); | |
return function(input,prev_hidden,prev_cell) { | |
var x = input.dot(W_input); | |
var h = prev_hidden.dot(W_hidden); | |
var c = prev_cell.dot(W_cell); | |
var x_i = x.slice(0 * size, 1 * size); | |
var x_f = x.slice(1 * size, 2 * size); | |
var x_c = x.slice(2 * size, 3 * size); | |
var x_o = x.slice(3 * size, 4 * size); | |
var h_i = h.slice(0 * size, 1 * size); | |
var h_f = h.slice(1 * size, 2 * size); | |
var h_c = h.slice(2 * size, 3 * size); | |
var h_o = h.slice(3 * size, 4 * size); | |
var c_i = c.slice(0 * size, 1 * size); | |
var c_f = c.slice(1 * size, 2 * size); | |
var in_lin = x_i.plus(h_i).plus(b_i).plus(c_i); | |
var forget_lin = x_f.plus(h_f).plus(b_f).plus(c_f); | |
var cell_lin = x_c.plus(h_c).plus(b_c); | |
var in_gate = Symbol.sigmoid(in_lin); | |
var forget_gate = Symbol.sigmoid(forget_lin); | |
var cell_updates = Symbol.tanh(cell_lin); | |
var cell = (forget_gate.mult(prev_cell)).plus(in_gate.mult(cell_updates)); | |
var c_o = (cell.dot(W_cell)).slice(2 * size, 3 * size); | |
var out_lin = x_o.plus(h_o).plus(b_o).plus(c_o); | |
var out_gate = Symbol.sigmoid(out_lin); | |
var hidden = out_gate.mult(Symbol.tanh(cell)) | |
return { "hidden":hidden, "cell":cell }; | |
} | |
} | |
lstm_1 = lstm_builder( | |
P.W_recurrent_1_cell, | |
P.W_recurrent_1_hidden, | |
P.W_recurrent_1_input, | |
P.b_recurrent_1 | |
); | |
lstm_2 = lstm_builder( | |
P.W_recurrent_2_cell, | |
P.W_recurrent_2_hidden, | |
P.W_recurrent_2_input, | |
P.b_recurrent_2 | |
); | |
var next_word = function(c,h1,c1,h2,c2) { | |
word_vec = P.V.idx(c); | |
layer_1 = lstm_1(word_vec,h1,c1); | |
layer_2 = lstm_2(layer_1.hidden,h2,c2); | |
output = (layer_2.hidden.dot(P.W_output)).plus(P.b_output); | |
return { | |
"output":output, | |
"h1": layer_1.hidden, | |
"c1": layer_1.cell, | |
"h2": layer_2.hidden, | |
"c2": layer_2.cell, | |
}; | |
} | |
window.params = P; | |
model = window.model = { | |
"next_word":next_word, | |
"init_h1":Symbol.tanh(P.init_recurrent_1_hidden), | |
"init_h2":Symbol.tanh(P.init_recurrent_2_hidden), | |
"init_c1":P.init_recurrent_1_cell, | |
"init_c2":P.init_recurrent_2_cell, | |
"vocab": [ | |
'\n', ' ', '!', '"', '#', '$', '%', '&', "'", | |
'(', ')', '*', '+', ',', '-', '.', '/', '0', | |
'1', '2', '3', '4', '5', '6', '7', '8', '9', | |
':', ';', '<', '=', '>', '?', '@', 'A', 'B', | |
'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', | |
'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', | |
'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', | |
'^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', | |
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', | |
'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', | |
'y', 'z', '{', '|', '}', '~' | |
] | |
}; | |
window.max_sample = function() { | |
var result = ""; | |
var start_id = model.vocab.length; | |
var state = model.next_word(start_id,model.init_h1,model.init_c1,model.init_h2,model.init_c2); | |
var id = Symbol.argmax(state.output); | |
for ( var i=0;i < 100;i++) { | |
result += model.vocab[id]; | |
state = model.next_word(id,state.h1,state.c1,state.h2,state.c2); | |
id = Symbol.argmax(state.output); | |
} | |
console.log(result) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment