Last active
December 24, 2015 10:58
-
-
Save codyrioux/6787339 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Author: Cody Rioux <cody.rioux@uleth.ca>
 * Description: A simple toy program to implement fizzbuzz using reinforcement learning.
 *
 * Notes: This learner is very tightly coupled to the FizzBuzz specific implementation
 * functions. This is obviously not a favorable situation. Consider using libCello to
 * make a general SARSA learner that takes function parameters for the general case.
 */
#include <stdio.h> | |
#include <assert.h> | |
#include <stdlib.h> //rand, malloc | |
#include <time.h> //time | |
#include <float.h> //DBL_MAX | |
//
// FizzBuzz specific implementation
//
/* The four possible answers the agent can give for a number. */
typedef enum { NONE, FIZZ, BUZZ, FIZZBUZZ } Action;
/* Feature classes of a state: divisibility of the number by 3 and/or 5.
 * DIV53 = divisible by both 5 and 3 (i.e. by 15). */
typedef enum { NORMAL, DIV3, DIV5, DIV53 } Features;
/* A state is simply the current number (the game counts 1..100). */
typedef int State;
/* Map a state (the current number) to its feature class, determined by
 * divisibility by 3 and by 5. Divisible by both => DIV53. */
int features(State state) {
  int by3 = (state % 3 == 0);
  int by5 = (state % 5 == 0);
  if (by3 && by5) return DIV53;
  if (by3) return DIV3;
  if (by5) return DIV5;
  return NORMAL;
}
/* An episode ends when the counter reaches 100. Returns nonzero at the
 * terminal state, zero otherwise. */
int terminal(State state) {
  return (state == 100) ? 1 : 0;
}
/* Reward signal: +1.0 when action a is the correct FizzBuzz response for
 * state s's feature class, -1.0 otherwise.
 * (features(s) was previously recomputed up to four times; it is hoisted
 * into a single call and dispatched with a switch.) */
double reward(State s, Action a) {
  switch (features(s)) {
    case DIV53:
      return (a == FIZZBUZZ) ? 1.0 : -1.0;
    case DIV3:
      return (a == FIZZ) ? 1.0 : -1.0;
    case DIV5:
      return (a == BUZZ) ? 1.0 : -1.0;
    default: /* NORMAL */
      return (a == NONE) ? 1.0 : -1.0;
  }
}
/* Transition function: the next state is always the next number.
 * The chosen action does not influence the transition. */
State m(State s, Action a) {
  (void)a; /* transition is action-independent by design */
  return s + 1;
}
/* Print the agent's answer for state s: the word for FIZZ/BUZZ/FIZZBUZZ,
 * or the number itself for any other action. */
void take_action(State s, Action a) {
  switch (a) {
    case FIZZ:
      printf("Fizz\n");
      break;
    case BUZZ:
      printf("Buzz\n");
      break;
    case FIZZBUZZ:
      printf("FizzBuzz\n");
      break;
    default:
      printf("%d\n", s);
      break;
  }
}
//
// Learner
//
/* Return the action with the highest estimated value q[features(s)][a].
 * Ties go to the lowest-numbered action.
 * action_count: number of actions to scan (expected <= 4);
 * q: value table indexed [feature][action].
 * (features(s) was previously recomputed on every loop iteration; it is
 * loop-invariant and now hoisted out.) */
Action greedy(int action_count, State s, double q[4][4]) {
  int f = features(s); /* loop-invariant feature index */
  Action best = 0;
  double best_score = -DBL_MAX;
  for (int i = 0; i < action_count; ++i) {
    if (q[f][i] > best_score) {
      best = i;
      best_score = q[f][i];
    }
  }
  return best;
}
/* Epsilon-greedy selection: with probability e pick a uniformly random
 * action, otherwise pick the greedy (highest-valued) action.
 * BUG FIX: the exploration branch used `rand() % 4`, silently ignoring
 * the caller-supplied action_count; it now respects it. */
Action e_greedy(double e, int action_count, State s, double q[4][4]) {
  double rnum = rand() / (double)RAND_MAX;
  if (rnum <= e) {
    return rand() % action_count; /* explore: uniform over the real action set */
  }
  return greedy(action_count, s, q); /* exploit */
}
double (*sarsa(int feature_count, | |
int action_count, | |
int episodes, | |
double lambda, | |
double gamma))[4] { | |
double (*q)[action_count] = malloc(feature_count * action_count * sizeof(double)); // State-Action value approximations | |
double e[feature_count][action_count]; // Eligibility trace | |
double alpha = 0.001; // Learning Rate | |
double epsillon = 0.01; // Epsillon for e-greedy action selection | |
// Initialize q and e to 0 for all s, a | |
for(int i = 0; i < feature_count; ++i) { | |
for(int j = 0; j < action_count; ++j) { | |
q[i][j] = 0.0; | |
e[i][j] = 0.0; | |
} | |
} | |
for(int i = 0; i < episodes; ++i) { | |
State s = 1; | |
Action a = NONE; | |
while (!terminal(s)) { | |
State sprime = m(s, a); | |
double r = reward(s, a); | |
Action aprime = e_greedy(0.01, action_count, sprime, q); | |
e[features(s)][a] = e[features(s)][a] + 1; | |
double delta = r + gamma * e[features(sprime)][aprime] - e[features(s)][a]; | |
// update q and e | |
for(int i = 0; i < feature_count; ++i) { | |
for(int j = 0; j < action_count; ++j) { | |
q[i][j] = q[i][j] + alpha * delta * e[i][j]; | |
e[i][j] = gamma * lambda * e[i][j]; | |
} | |
} | |
// update s and a | |
s = sprime; | |
a = aprime; | |
} | |
} | |
return q; | |
} | |
int main(int argc, char **argv) { | |
srand(time(NULL)); | |
double (*q)[4] = sarsa(4, 4, 300, 0.0, 0.0); | |
State s = 1; | |
while(!terminal(s)) { | |
Action a = greedy(4, s, q); | |
take_action(s, a); | |
s = m(s, a); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment