Skip to content

Instantly share code, notes, and snippets.

@zomux
Created May 8, 2015 00:04
Show Gist options
Save zomux/cee5c4878c9256ecf6c0 to your computer and use it in GitHub Desktop.
learn: function(r1) {
// perform an update on Q function
if(!(this.r0 == null) && this.alpha > 0) {
// learn from this tuple to get a sense of how "surprising" it is to the agent
var tderror = this.learnFromTuple(this.s0, this.a0, this.r0, this.s1, this.a1);
this.tderror = tderror; // a measure of surprise
// decide if we should keep this experience in the replay
if(this.t % this.experience_add_every === 0) {
this.exp[this.expi] = [this.s0, this.a0, this.r0, this.s1, this.a1];
this.expi += 1;
if(this.expi > this.experience_size) { this.expi = 0; } // roll over when we run out
}
this.t += 1;
// sample some additional experience from replay memory and learn from it
for(var k=0;k<this.learning_steps_per_iteration;k++) {
var ri = randi(0, this.exp.length); // todo: priority sweeps?
var e = this.exp[ri];
this.learnFromTuple(e[0], e[1], e[2], e[3], e[4])
}
}
this.r0 = r1; // store for next update
},
learnFromTuple: function(s0, a0, r0, s1, a1) {
// want: Q(s,a) = r + gamma * max_a' Q(s',a')
// compute the target Q value
var tmat = this.forwardQ(this.net, s1, false);
var qmax = r0 + this.gamma * tmat.w[R.maxi(tmat.w)];
// now predict
var pred = this.forwardQ(this.net, s0, true);
var tderror = pred.w[a0] - qmax;
var clamp = this.tderror_clamp;
if(Math.abs(tderror) > clamp) { // huber loss to robustify
if(tderror > clamp) tderror = clamp;
if(tderror < -clamp) tderror = -clamp;
}
pred.dw[a0] = tderror;
this.lastG.backward(); // compute gradients on net params
// update net
R.updateNet(this.net, this.alpha);
return tderror;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment