Skip to content

Instantly share code, notes, and snippets.

@shawntan
Created May 26, 2015 05:56
Show Gist options
  • Save shawntan/a384b18ad8ec4daf55de to your computer and use it in GitHub Desktop.
Save shawntan/a384b18ad8ec4daf55de to your computer and use it in GitHub Desktop.
// Small numeric-array wrapper: a `Symbol` holds a (possibly nested) array in
// `.data` and its per-level lengths in `.shape`, with element-wise helpers and
// a vector-by-matrix product.
//
// NOTE(review): this assigns to the global identifier `Symbol`, which replaces
// the built-in ES2015 `Symbol` for every script sharing the global object.
// The name is kept because the rest of the file refers to it; consider
// renaming in a coordinated change.
Symbol = (function() {
  /**
   * Build a zero-filled array. Extra arguments describe nested dimensions,
   * e.g. createArray(2, 3) -> [[0, 0, 0], [0, 0, 0]].
   *
   * @param {number} length size of the outermost dimension (falsy -> empty array)
   * @returns {Array} the newly allocated array
   */
  function createArray(length) {
    var arr = new Array(length || 0);
    var i;
    if (arguments.length > 1) {
      // Remaining arguments are the sizes of the inner dimensions.
      var args = Array.prototype.slice.call(arguments, 1);
      for (i = 0; i < arr.length; i++) arr[i] = createArray.apply(this, args);
    } else {
      for (i = 0; i < arr.length; i++) arr[i] = 0.0;
    }
    return arr;
  }

  /**
   * Lift a scalar function into one that maps over every element of a
   * Symbol's `data` and returns a new Symbol.
   * The returned function throws a string starting with "NaN alert! " as soon
   * as `fun` produces a value for which isNaN() is true.
   *
   * @param {function(number): number} fun scalar transform
   * @returns {function(Symbol): Symbol}
   */
  function transformApplier(fun) {
    return function(sym) {
      // Locals are declared explicitly; bare assignments would create globals
      // (or throw a ReferenceError in strict mode).
      var data = sym.data;
      var data_res = createArray(sym.shape[0]);
      for (var i = 0; i < data.length; i++) {
        data_res[i] = fun(data[i]);
        if (isNaN(data_res[i])) throw "NaN alert! " + data;
      }
      return new Symbol(data_res);
    };
  }

  /**
   * Lift a two-argument scalar function into an element-wise operation on two
   * Symbols. Iterates over the first operand's leading dimension.
   * The returned function logs both operands and throws a string starting
   * with "NaN alert!" when a result element is NaN.
   *
   * @param {function(number, number): number} fun scalar binary operation
   * @returns {function(Symbol, Symbol): Symbol}
   */
  function binaryApplier(fun) {
    return function(sym1, sym2) {
      var data1 = sym1.data;
      var data2 = sym2.data;
      var data_res = createArray(sym1.shape[0]);
      for (var i = 0; i < sym1.shape[0]; i++) {
        data_res[i] = fun(data1[i], data2[i]);
        if (isNaN(data_res[i])) {
          console.log(sym1);
          console.log(sym2);
          throw "NaN alert!" + data1[i] + ", " + data2[i];
        }
      }
      return new Symbol(data_res);
    };
  }

  /**
   * Compute the nesting lengths of an array by following the first element
   * at each level. Non-array input yields [].
   *
   * @param {*} data
   * @returns {number[]}
   */
  function getDimensions(data) {
    if (Array.isArray(data)) {
      return [data.length].concat(getDimensions(data[0]));
    }
    return [];
  }

  /**
   * @constructor
   * @param {*} data array (possibly nested) or scalar to wrap; stored by reference
   */
  var Symbol = function(data) {
    this.data = data;
    this.shape = getDimensions(data);
  };

  // Inputs beyond +/-thresh are clamped to the limiting value of the
  // activation instead of being passed to Math.exp.
  var thresh = 10;

  Symbol.sigmoid = transformApplier(function(x) {
    if (x > thresh) {
      return 1;
    } else if (x < -thresh) {
      return 0;
    } else {
      return 1 / (1 + Math.exp(-x));
    }
  });

  Symbol.tanh = transformApplier(function(x) {
    if (x > thresh) {
      return 1;
    } else if (x < -thresh) {
      return -1;
    } else {
      var x_ = Math.exp(2 * x);
      return (x_ - 1) / (x_ + 1);
    }
  });

  Symbol.neg = transformApplier(function(x) { return -x; });
  Symbol.plus = binaryApplier(function(x, y) { return x + y; });
  Symbol.mult = binaryApplier(function(x, y) { return x * y; });
  Symbol.sub = binaryApplier(function(x, y) { return x - y; });
  Symbol.div = binaryApplier(function(x, y) { return x / y; });

  /**
   * Index of the largest element along the leading dimension.
   *
   * @param {Symbol} vec
   * @returns {?number} first index of the maximum, or null when no element is
   *     greater than -Infinity (e.g. empty input)
   */
  Symbol.argmax = function(vec) {
    var max_id = null;
    var max = -Infinity;
    for (var i = 0; i < vec.shape[0]; i++) {
      if (vec.data[i] > max) {
        max = vec.data[i];
        max_id = i;
      }
    }
    return max_id;
  };

  Symbol.prototype = {
    /**
     * Vector-by-matrix product: result[j] = sum_i this.data[i] * sym2.data[i][j].
     *
     * @param {Symbol} sym2 operand whose leading dimension must equal this one's
     * @returns {Symbol} vector of length sym2.shape[1]
     * @throws {string} "Dimensions wrong!" on a leading-dimension mismatch, or a
     *     "NaN alert! ..." string when the running sum becomes NaN
     */
    dot: function(sym2) {
      var sym1 = this;
      if (sym1.shape[0] !== sym2.shape[0]) throw "Dimensions wrong!";
      var data1 = sym1.data;
      var data2 = sym2.data;
      var data_res = createArray(sym2.shape[1]);
      for (var j = 0; j < data_res.length; j++) {
        var sum = 0;
        for (var i = 0; i < sym1.shape[0]; i++) {
          sum += data1[i] * data2[i][j];
          if (isNaN(sum)) {
            throw "NaN alert! Sum so far: " + sum +
              " Value 1: " + data1[i] +
              " Value 2: " + data2[i][j];
          }
        }
        data_res[j] = sum;
      }
      return new Symbol(data_res);
    },
    // Method forms of the element-wise static operations.
    plus: function(sym2) { return Symbol.plus(this, sym2); },
    mult: function(sym2) { return Symbol.mult(this, sym2); },
    sub: function(sym2) { return Symbol.sub(this, sym2); },
    div: function(sym2) { return Symbol.div(this, sym2); },
    // Wrap element i of the data (shares the underlying element, no copy).
    idx: function(i) { return new Symbol(this.data[i]); },
    // Wrap a shallow copy of data[start, end) along the leading dimension.
    slice: function(start, end) {
      return new Symbol(this.data.slice(start, end));
    }
  };

  return Symbol;
})();
// Holds the object built by loadModel(); null until loadModel has run.
var model = null;

/**
 * Wrap raw parameters in Symbol objects, build a two-layer recurrent
 * next-character function from them, and publish the results on `window`
 * (`window.params`, `window.model`, `window.max_sample`).
 *
 * Reads these keys from `params`: V, W_recurrent_{1,2}_{cell,hidden,input},
 * b_recurrent_{1,2}, init_recurrent_{1,2}_{hidden,cell}, W_output, b_output.
 *
 * @param {Object} params map from parameter name to (nested) numeric array
 */
var loadModel = function(params) {
  var P = {};
  for (var k in params) { P[k] = new Symbol(params[k]); }

  /**
   * Build the step function for one recurrent layer.
   *
   * @param {Symbol} W_cell   matrix applied to the cell state
   * @param {Symbol} W_hidden matrix applied to the previous hidden state
   * @param {Symbol} W_input  matrix applied to the layer input
   * @param {Symbol} b        biases; rows 0..3 are used for the input, forget,
   *                          cell-update and output terms respectively
   * @returns {function(Symbol, Symbol, Symbol): {hidden: Symbol, cell: Symbol}}
   */
  function lstm_builder(W_cell, W_hidden, W_input, b) {
    // Width of each gate's slice of the projected vectors.
    // NOTE(review): hard-coded; presumably the hidden size of the trained
    // model. Confirm against the parameter shapes before reusing elsewhere.
    var size = 100;
    var b_i = b.idx(0);
    var b_f = b.idx(1);
    var b_c = b.idx(2);
    var b_o = b.idx(3);
    return function(input, prev_hidden, prev_cell) {
      var x = input.dot(W_input);
      var h = prev_hidden.dot(W_hidden);
      var c = prev_cell.dot(W_cell);

      // Input/hidden projections are cut into four consecutive slices:
      // input gate, forget gate, cell update, output gate.
      var x_i = x.slice(0 * size, 1 * size);
      var x_f = x.slice(1 * size, 2 * size);
      var x_c = x.slice(2 * size, 3 * size);
      var x_o = x.slice(3 * size, 4 * size);
      var h_i = h.slice(0 * size, 1 * size);
      var h_f = h.slice(1 * size, 2 * size);
      var h_c = h.slice(2 * size, 3 * size);
      var h_o = h.slice(3 * size, 4 * size);

      // The previous cell state contributes to the input and forget gates.
      var c_i = c.slice(0 * size, 1 * size);
      var c_f = c.slice(1 * size, 2 * size);

      var in_lin = x_i.plus(h_i).plus(b_i).plus(c_i);
      var forget_lin = x_f.plus(h_f).plus(b_f).plus(c_f);
      var cell_lin = x_c.plus(h_c).plus(b_c);

      var in_gate = Symbol.sigmoid(in_lin);
      var forget_gate = Symbol.sigmoid(forget_lin);
      var cell_updates = Symbol.tanh(cell_lin);

      var cell = (forget_gate.mult(prev_cell)).plus(in_gate.mult(cell_updates));

      // The output gate uses the *new* cell state, via the third slice of
      // its projection through W_cell.
      var c_o = (cell.dot(W_cell)).slice(2 * size, 3 * size);
      var out_lin = x_o.plus(h_o).plus(b_o).plus(c_o);
      var out_gate = Symbol.sigmoid(out_lin);
      var hidden = out_gate.mult(Symbol.tanh(cell));
      return { "hidden": hidden, "cell": cell };
    };
  }

  // Declared locally; bare assignments would create globals (or throw a
  // ReferenceError in strict mode).
  var lstm_1 = lstm_builder(
    P.W_recurrent_1_cell,
    P.W_recurrent_1_hidden,
    P.W_recurrent_1_input,
    P.b_recurrent_1
  );
  var lstm_2 = lstm_builder(
    P.W_recurrent_2_cell,
    P.W_recurrent_2_hidden,
    P.W_recurrent_2_input,
    P.b_recurrent_2
  );

  /**
   * Advance both layers by one step.
   *
   * @param {number} c  row index into P.V selecting the input vector
   * @param {Symbol} h1 layer-1 hidden state
   * @param {Symbol} c1 layer-1 cell state
   * @param {Symbol} h2 layer-2 hidden state
   * @param {Symbol} c2 layer-2 cell state
   * @returns {{output: Symbol, h1: Symbol, c1: Symbol, h2: Symbol, c2: Symbol}}
   *     output scores plus the updated states
   */
  var next_word = function(c, h1, c1, h2, c2) {
    var word_vec = P.V.idx(c);
    var layer_1 = lstm_1(word_vec, h1, c1);
    var layer_2 = lstm_2(layer_1.hidden, h2, c2);
    var output = (layer_2.hidden.dot(P.W_output)).plus(P.b_output);
    return {
      "output": output,
      "h1": layer_1.hidden,
      "c1": layer_1.cell,
      "h2": layer_2.hidden,
      "c2": layer_2.cell
    };
  };

  window.params = P;
  model = window.model = {
    "next_word": next_word,
    "init_h1": Symbol.tanh(P.init_recurrent_1_hidden),
    "init_h2": Symbol.tanh(P.init_recurrent_2_hidden),
    "init_c1": P.init_recurrent_1_cell,
    "init_c2": P.init_recurrent_2_cell,
    "vocab": [
      '\n', ' ', '!', '"', '#', '$', '%', '&', "'",
      '(', ')', '*', '+', ',', '-', '.', '/', '0',
      '1', '2', '3', '4', '5', '6', '7', '8', '9',
      ':', ';', '<', '=', '>', '?', '@', 'A', 'B',
      'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
      'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
      'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']',
      '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f',
      'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
      'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
      'y', 'z', '{', '|', '}', '~'
    ]
  };

  /**
   * Greedy decode: repeatedly feed back the argmax of the output and collect
   * 100 characters. The text is logged to the console and also returned.
   *
   * @returns {string} the generated text
   */
  window.max_sample = function() {
    var result = "";
    // The first input index is one past the last vocabulary entry.
    // NOTE(review): presumably a dedicated start-of-sequence row in P.V; confirm.
    var start_id = model.vocab.length;
    var state = model.next_word(start_id, model.init_h1, model.init_c1, model.init_h2, model.init_c2);
    var id = Symbol.argmax(state.output);
    for (var i = 0; i < 100; i++) {
      result += model.vocab[id];
      state = model.next_word(id, state.h1, state.c1, state.h2, state.c2);
      id = Symbol.argmax(state.output);
    }
    console.log(result);
    return result;
  };
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment