Skip to content

Instantly share code, notes, and snippets.

@knowblesse
Created January 11, 2021 08:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save knowblesse/c2feaf5eae5bdb728db47698d6406bc6 to your computer and use it in GitHub Desktop.
Save knowblesse/c2feaf5eae5bdb728db47698d6406bc6 to your computer and use it in GitHub Desktop.
// Pin numbers for the motor driver.
// PIN_Motor_L / PIN_Motor_R are configured as outputs by setup(); PIN_Motor_R is also
// driven with analogWrite(). NOTE(review): presumably these are the PWM/enable lines — confirm.
const int PIN_Motor_L = 3;
const int PIN_Motor_R = 11;
// Direction pins; the serial commands "A"/"B" in loop() write HIGH/LOW to PIN_MOTOR_R_DIR.
const int PIN_MOTOR_L_DIR = 12;
const int PIN_MOTOR_R_DIR = 13;
void setup() {
Serial.begin(9600);
pinMode(PIN_Motor_R, OUTPUT);
// Motor pins: PIN_Motor_L / PIN_Motor_R receive analogWrite() values in setup();
// the *_DIR pins are written with digitalWrite() by goFront()/goBack().
const int PIN_Motor_L = 3;
const int PIN_Motor_R = 11;
const int PIN_MOTOR_L_DIR = 12;
const int PIN_MOTOR_R_DIR = 13;
// Analog input number for the bar angle. NOTE(review): setup() calls analogRead(2)
// with a literal rather than this constant.
const int PIN_Angle = 2;
// Angle readings: [0] and [2] are the lower/upper limits used by the fall-down check,
// [1] is the reference value subtracted from every reading.
const int BAR_DOWN[3] = {300,488,670}; // DOWN, UP, DOWN
// Largest |position| (counted in +1/-1 action steps) before the cart is treated as off track.
const int CART_OUT = 20;
// Input->hidden weights, indexed [input node incl. bias][hidden node].
double w12[4+1][8];
// Hidden->output weights, indexed [hidden node incl. bias][action].
double w23[8+1][2];
// Number of learning iterations to run.
// NOTE(review): never assigned in the visible code, so it keeps its zero
// initial value and the `trial < N_TRIAL` test in setup() can never pass.
int N_TRIAL;
// Layer sizes; they must agree with the literal array dimensions above.
int N_STATE = 4;// setup() fills the state as: loc(last), loc(curr), ang(last), ang(curr)
int N_HIDDEN = 8;
int N_ACTION = 2;
void setup(){
// Setup functions
Serial.begin(9600);
pinMode(PIN_Motor_R, OUTPUT);
pinMode(PIN_Motor_L, OUTPUT);
// Network Parameters
double GAMMA = 0.99;
double LR = 0.05;
// node
double state[4];
double node_1[4+1];
double node_2[8+1];
double node_3[2];
//weights
for(int i = 0; i < N_STATE+1; i++){
for(int j = 0; j < N_HIDDEN; j++){
w12[i][j] = 0.5;
}
}
for(int j = 0; j < N_HIDDEN+1; j++){
for(int k = 0; k < N_ACTION; k++){
w23[j][k] = 0.5;
}
}
// delta
double delta_12[8];
double delta_23[4];
// reward
int reward_memory = 30;
double reward_arr[30];
double reward;
double discounted_reward;
int num_reward = 0;
// Loop function
bool isLearn = false;
bool isFirstLearn = false;
double lastPos;
double currPos;
double lastAng;
double currAng;
int action;
int trial;
String incoming;
while (true) {
/******************************************/
/***********READ SERIAL COMMANDS***********/
/******************************************/
if (Serial.available() > 0) {
// read the incoming:
incoming = Serial.readString();
// stop
if (incoming == 's') {
break;
}
// print weights
if (incoming == 'w') {
break;
}
// learn
if (incoming == 'l') {
isLearn = true;
isFirstLearn = true;
trial = 0;
}
// test
if (incoming == 't') {
isLearn = false;
break;
}
}
/******************************************/
/*********Neural Network Learning**********/
/******************************************/
if (isLearn && (trial < N_TRIAL)){
// read current angle
currAng = analogRead(2);
if (isFirstLearn){
// Motor Power On
analogWrite(PIN_Motor_L, 255);
analogWrite(PIN_Motor_R, 255);
lastAng = currAng;
lastPos = 0;
currPos = 0;
isFirstLearn = false;
num_reward = 0;
}
state[0] = lastPos;
state[1] = currPos;
state[2] = lastAng - BAR_DOWN[1];
state[3] = currAng - BAR_DOWN[1];
node_1[0] = 1;
node_1[1] = state[0]; // add 1 for bias node
node_1[2] = state[1];
node_1[3] = state[2];
node_1[4] = state[3];
// FEEDFORWARD : LAYER 12
for (int j = 1; j < N_HIDDEN+1; j++){
node_2[j] = 0;
for (int i = 0; i < N_STATE+1; i++){
node_2[j] += node_1[i] * w12[i][j];
}
node_2[j] = 1.0 / (1.0 + pow(2.71828, node_2[j])); // Soft-Max function
}
// FEEDFORWARD : LAYER 23
node_2[0] = 1; //bias
for (int k = 0; k < N_ACTION; k++){
node_3[k] = 0;
for (int j = 0; j < N_STATE; j++){
node_3[k] += node_2[j] * w23[j][k];
}
node_3[k] = 1.0 / (1.0 + pow(2.71828, node_3[k])); // Soft-Max function
}
// Select action
if (node_3[0] > node_3[1]){
action = 1;
goFront();
}
else{
action = -1;
goBack();
}
//action을 샐랙트 하면 그 액션을 더해줌
lastPos = currPos;
currPos = currPos + action;
lastAng = currAng;
// check fall down
if (currAng < BAR_DOWN[0] || BAR_DOWN[2] < currAng || abs(currPos) > CART_OUT){ // fall down or out of track
analogWrite(PIN_Motor_L, 0);
analogWrite(PIN_Motor_R, 0);
isFirstLearn = true;
delay(1000);
}
//reward computation
reward = (BAR_DOWN[1] - abs(currAng - BAR_DOWN[1])) / BAR_DOWN[1];
if(num_reward >= reward_memory){
for(int i = 0; i < reward_memory-1; i++){
reward_arr[i] = reward_arr[i+1]; // side by side
}
num_reward = reward_memory-1;
}
reward_arr[num_reward] = reward;
// discounted reward computation
discounted_reward = 0;
for(int i = num_reward; i >= 0 ; i--){
discounted_reward += reward_arr[i] * pow(GAMMA, num_reward - i);
}
num_reward++;
//
//error;
int d[2] = {1, 2};
// Backpropagation : LAYER 32
for (int k = 0; k < N_ACTION; k++){
delta_23[k] = - (d[k] - node_3[k]) * (node_3[k] * (1-node_3[k]));
for (int j = 0; j < N_HIDDEN+1; j++){
w23[j][k] += LR * delta_23[k] * node_2[j];
}
}
// delta computation for layer 12 update
for (int j = 0; j < N_HIDDEN; j++){
delta_12[j] = 0;
for (int k = 0; k < N_ACTION; k++){
delta_12[j] += delta_23[k] * w23[j][k] * node_2[j] * (1-node_2[j]);
}
}
// Backpropagation : LAYER 21
for (int j = 0; j < N_HIDDEN; j++){
for (int i = 0; i < N_STATE+1; i++){
w12[i][j] += LR * delta_12[j] * node_1[i];
}
}
trial++;
Serial.println(trial);
}
}
}
// Selects the "front" drive direction: left direction pin LOW, right direction pin HIGH.
void goFront()
{
    digitalWrite(PIN_MOTOR_L_DIR, LOW);
    digitalWrite(PIN_MOTOR_R_DIR, HIGH);
}
// Selects the "back" drive direction: left direction pin HIGH, right direction pin LOW.
void goBack()
{
    digitalWrite(PIN_MOTOR_L_DIR, HIGH);
    digitalWrite(PIN_MOTOR_R_DIR, LOW);
}
void printWeights(){
Serial.println("w12");
for (int i = 0; i < N_STATE+1; i++){
for (int j = 0; j < N_HIDDEN; j++){
Serial.print(w12[i][j]);
Serial.print(" ");
}
Serial.print("\n");
}
Serial.println("w23");
for (int j = 0; j < N_HIDDEN+1; j++){
for (int k = 0; k < N_STATE; k++){
Serial.print(w23[j][k]);
Serial.print(" ");
}
Serial.print("\n");
}
}
// Intentionally empty: this sketch does all of its work inside setup().
void loop()
{
}
// Multiplies two row-major matrices: (a_row x a_col) * (b_row x b_col).
//
// a, b      : input matrices stored row by row; neither is modified.
// a_row ... : dimensions of the inputs; a_col must equal b_row.
//
// Returns a newly malloc'ed a_row x b_col row-major array that the caller must
// free(). Returns NULL when an input pointer is NULL, a dimension is not
// positive, the inner dimensions disagree, or the allocation fails (in the last
// case "Low memory" is also printed).
double* multiplyMatrix(double* a, int a_row, int a_col, double* b, int b_row, int b_col) {
  if (a == NULL || b == NULL) {
    return NULL;
  }
  // Mismatched or non-positive shapes would index outside the input buffers.
  if (a_row <= 0 || a_col <= 0 || b_col <= 0 || a_col != b_row) {
    return NULL;
  }

  double* output_arr = static_cast<double*>(malloc(sizeof(double) * a_row * b_col));
  if (output_arr == NULL) {
    printf("Low memory\n");
    return NULL;
  }

  for (int i = 0; i < a_row; i++) {
    for (int j = 0; j < b_col; j++) {
      // Dot product of row i of a with column j of b.
      double sum = 0;
      for (int t = 0; t < a_col; t++) {
        sum += a[a_col * i + t] * b[b_col * t + j];
      }
      output_arr[b_col * i + j] = sum;
    }
  }
  return output_arr;
}
pinMode(PIN_Motor_L, OUTPUT);
analogWrite(PIN_Motor_R, 128);
}
// Most recent command text received over Serial.
String incoming;

// Polls the serial port and sets the right motor's direction pin:
//   "A" -> PIN_MOTOR_R_DIR HIGH
//   "B" -> PIN_MOTOR_R_DIR LOW
// Any other input is ignored.
void loop() {
  // The Arduino core calls loop() repeatedly, so no inner endless loop is needed.
  if (Serial.available() > 0) {
    // read the incoming:
    incoming = Serial.readString();
    // Strip line endings / spaces so "A\n" from a serial monitor still matches.
    incoming.trim();
    if (incoming == "A"){
      digitalWrite(PIN_MOTOR_R_DIR,HIGH);
    }
    else if (incoming == "B"){
      digitalWrite(PIN_MOTOR_R_DIR,LOW);
    }
  }
}
A * B
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment