Skip to content

Instantly share code, notes, and snippets.

@breuderink
Created September 22, 2017 10:53
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save breuderink/744036376f4a145982d1cb3553a7847c to your computer and use it in GitHub Desktop.
Save breuderink/744036376f4a145982d1cb3553a7847c to your computer and use it in GitHub Desktop.
// Linear TD with gradient correction (TDC)
// [1] Sutton, Richard S., et al. "Fast gradient-descent methods for
// temporal-difference learning with linear function approximation."
// Proceedings of the 26th Annual International Conference on Machine
// Learning. ACM, 2009.
// Length of every feature vector and every weight vector in this file.
#define TDC_NFEAT 4

// Learner state for linear TD with gradient correction (TDC), see [1].
typedef struct {
    // Factor applied to the phi_next terms in tdc_tderror() and tdc_update().
    float gamma;
    // Primary weights; tdc_tderror() combines them with the feature vectors.
    float theta[TDC_NFEAT];
    // Secondary weights; tdc_update() uses phi . w in its correction term.
    float w[TDC_NFEAT];
} TDC;
// One sample consumed by tdc_tderror() and tdc_update().
typedef struct {
// Reward term; it is the starting value of the TD error.
float reward;
// Feature vector phi of the current step.
float phi[TDC_NFEAT];
// Feature vector of the following step; scaled by gamma wherever it is used.
float phi_next[TDC_NFEAT];
} observation;
// Temporal-difference error for one observation, equation (2) of [1]:
//
//   delta = reward + sum_k theta[k] * (gamma * phi_next[k] - phi[k])
//
// Neither `model` nor `o` is modified. Returns delta.
float tdc_tderror(TDC *model, observation *o) {
    const float gamma = model->gamma;
    const float *theta = model->theta;
    const float *phi = o->phi;
    const float *phi_next = o->phi_next;

    // Accumulate in ascending index order, starting from the reward.
    float td_error = o->reward;
    for (int k = 0; k < TDC_NFEAT; ++k) {
        td_error += theta[k] * (gamma * phi_next[k] - phi[k]);
    }
    return td_error;
}
// Apply one TDC learning step to `model` for observation `o`.
//
//   alpha  step size for the primary weights theta, equation (10) of [1]
//   beta   step size for the secondary weights w, equation (9) of [1]
//
// Both the TD error and phi . w are evaluated from the weights as they are
// on entry, before any element of theta or w is written.
void tdc_update(TDC *model, observation *o, float alpha, float beta) {
    const float *phi = o->phi;
    const float *phi_next = o->phi_next;
    float *theta = model->theta;
    float *w = model->w;

    // Temporal-difference error under the current theta.
    const float delta = tdc_tderror(model, o);

    // Inner product phi^T w under the current w.
    float phi_dot_w = 0;
    for (int k = 0; k < TDC_NFEAT; ++k) {
        phi_dot_w += phi[k] * w[k];
    }

    // The two updates touch disjoint arrays and depend only on the values
    // computed above, so they can share a single pass over the features.
    for (int k = 0; k < TDC_NFEAT; ++k) {
        // (10): TD step on theta minus the gradient-correction term.
        theta[k] += alpha * (
            delta * phi[k]
            - model->gamma*phi_next[k]*phi_dot_w
        );
        // (9): move phi . w towards the TD error.
        w[k] += beta* (delta-phi_dot_w) * phi[k];
    }
}
#include <stdio.h>
// Demo: repeatedly feeds one fixed observation to the learner, printing the
// weights and the TD error before the first update and after every update.
int main() {
    // theta and w are not named in the initialiser, so they start at zero.
    TDC model = {
        .gamma = 0.9,
    };
    observation o = {
        .reward = 0.1,
        .phi = {0, 1, 0, 0},
        .phi_next = {0, 1, 2, 0},
    };
    const int n_updates = 10;

    // n_updates + 1 reports with an update between each consecutive pair.
    for (int step = 0; ; step++) {
        for (int k = 0; k < TDC_NFEAT; k++) {
            printf("w[%d] = %.4g, theta[%d] = %.4g\n",
                   k, model.w[k], k, model.theta[k]);
        }
        printf("TD error: %.6g\n", tdc_tderror(&model, &o));

        if (step == n_updates) {
            break;  // final report printed; no further update
        }
        tdc_update(&model, &o, 0.01, 0.02);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment