/* Author: Cody Rioux <cody.rioux@uleth.ca>
 * Description: A simple toy program that implements FizzBuzz using reinforcement learning.
 *
 * Notes: This learner is very tightly coupled to the FizzBuzz-specific implementation
 * functions. This is obviously not a favorable situation. Consider using libCello to
 * build a general SARSA learner that takes function parameters for the general case.
 */
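/* For reference, the per-step SARSA(lambda) update this learner performs is the
 * standard one, with the "state" fed to the value table being the feature class
 * of the current number:
 *
 *   delta    = r + gamma * Q(s', a') - Q(s, a)
 *   e(s, a) += 1
 *   for all (x, b):  Q(x, b) += alpha * delta * e(x, b)
 *                    e(x, b)  = gamma * lambda * e(x, b)
 */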
#include <stdio.h>
#include <assert.h>
#include <stdlib.h> //rand, malloc
#include <time.h> //time
#include <float.h> //DBL_MAX
//
// FizzBuzz specific implementation
//
typedef enum { NONE, FIZZ, BUZZ, FIZZBUZZ } Action;
typedef enum { NORMAL, DIV3, DIV5, DIV53 } Features;
typedef int State;
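/* A state is simply the current number. Rather than learning a value for every
 * number individually, each state is mapped to one of the four feature classes
 * above, so the learned value table is only feature_count x action_count (4x4). */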
/* Map a state (the current number) onto one of the four feature classes. */
int features(State state) {
  if (state % 5 == 0 && state % 3 == 0) return DIV53;
  if (state % 3 == 0) return DIV3;
  if (state % 5 == 0) return DIV5;
  return NORMAL;
}

/* An episode ends once the counter reaches 100. */
int terminal(State state) {
  return state == 100;
}

/* Reward +1.0 for the correct FizzBuzz response, -1.0 for anything else. */
double reward(State s, Action a) {
  if (features(s) == DIV53 && a == FIZZBUZZ) return 1.0;
  if (features(s) == DIV3 && a == FIZZ) return 1.0;
  if (features(s) == DIV5 && a == BUZZ) return 1.0;
  if (features(s) == NORMAL && a == NONE) return 1.0;
  return -1.0;
}

/* Transition model: every action simply advances to the next number. */
State m(State s, Action a) {
  return s + 1;
}

/* Execute an action by printing the appropriate FizzBuzz output. */
void take_action(State s, Action a) {
  if (a == FIZZ) {
    printf("Fizz\n");
  } else if (a == BUZZ) {
    printf("Buzz\n");
  } else if (a == FIZZBUZZ) {
    printf("FizzBuzz\n");
  } else {
    printf("%d\n", s);
  }
}
//
// Learner
//
/* Greedy policy: pick the action with the highest estimated value for the
 * feature class of state s. */
Action greedy(int action_count, State s, double q[4][4]) {
  Action a = 0;
  double scoremax = -DBL_MAX;
  for (int i = 0; i < action_count; ++i) {
    if (q[features(s)][i] > scoremax) {
      a = i;
      scoremax = q[features(s)][i];
    }
  }
  return a;
}

/* Epsilon-greedy policy: with probability e pick a random action (explore),
 * otherwise act greedily (exploit). */
Action e_greedy(double e, int action_count, State s, double q[4][4]) {
  double rnum = rand() / (double)RAND_MAX;
  if (rnum <= e) {
    return rand() % action_count;
  } else {
    return greedy(action_count, s, q);
  }
}
/* SARSA(lambda) learner. Returns a malloc'd feature_count x action_count table
 * of state-action value estimates; the caller owns the memory. */
double (*sarsa(int feature_count,
               int action_count,
               int episodes,
               double lambda,
               double gamma))[4] {
  double (*q)[action_count] = malloc(feature_count * action_count * sizeof(double)); // State-action value approximations
  double e[feature_count][action_count]; // Eligibility trace
  double alpha = 0.001;  // Learning rate
  double epsilon = 0.01; // Epsilon for epsilon-greedy action selection
  // Initialize q to 0 for all s, a
  for (int i = 0; i < feature_count; ++i) {
    for (int j = 0; j < action_count; ++j) {
      q[i][j] = 0.0;
    }
  }
  for (int i = 0; i < episodes; ++i) {
    State s = 1;
    Action a = e_greedy(epsilon, action_count, s, q);
    // Clear the eligibility trace at the start of each episode.
    for (int f = 0; f < feature_count; ++f)
      for (int j = 0; j < action_count; ++j)
        e[f][j] = 0.0;
    while (!terminal(s)) {
      // Take action a: observe the next state and the reward.
      State sprime = m(s, a);
      double r = reward(s, a);
      // Choose the next action from the next state using the current policy.
      Action aprime = e_greedy(epsilon, action_count, sprime, q);
      // Accumulate the trace and compute the TD error.
      e[features(s)][a] = e[features(s)][a] + 1;
      double delta = r + gamma * q[features(sprime)][aprime] - q[features(s)][a];
      // Update q and decay the eligibility traces for every feature/action pair.
      for (int f = 0; f < feature_count; ++f) {
        for (int j = 0; j < action_count; ++j) {
          q[f][j] = q[f][j] + alpha * delta * e[f][j];
          e[f][j] = gamma * lambda * e[f][j];
        }
      }
      // Move on to the next state and action.
      s = sprime;
      a = aprime;
    }
  }
  return q;
}
int main(int argc, char **argv) {
  srand(time(NULL));
  // Train the value table, then greedily play one episode of FizzBuzz with it.
  double (*q)[4] = sarsa(4, 4, 300, 0.0, 0.0);
  State s = 1;
  while (!terminal(s)) {
    Action a = greedy(4, s, q);
    take_action(s, a);
    s = m(s, a);
  }
  free(q);
  return 0;
}
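/* A possible way to build and run (assuming the file is saved as fizzbuzz.c;
 * the filename and flags are illustrative, C99 is needed for the VLAs):
 *
 *   cc -std=c99 -o fizzbuzz fizzbuzz.c
 *   ./fizzbuzz
 *
 * With gamma = 0 and lambda = 0, as passed from main, the learner reduces to
 * one-step learning of immediate rewards, which suffices here because each
 * reward depends only on the current number and action. */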