Created
January 11, 2021 08:06
-
-
Save knowblesse/c2feaf5eae5bdb728db47698d6406bc6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Pin assignments used by this sketch.
// The two "Motor" pins are driven with analogWrite(); the two "DIR" pins
// are driven with digitalWrite().
const int PIN_Motor_L     = 3;
const int PIN_Motor_R     = 11;
const int PIN_MOTOR_L_DIR = 12;
const int PIN_MOTOR_R_DIR = 13;
void setup() { | |
Serial.begin(9600); | |
pinMode(PIN_Motor_R, OUTPUT); | |
// ---- Pin assignments --------------------------------------------------
// "Motor" pins are written with analogWrite(); "DIR" pins with digitalWrite().
const int PIN_Motor_L     = 3;
const int PIN_Motor_R     = 11;
const int PIN_MOTOR_L_DIR = 12;
const int PIN_MOTOR_R_DIR = 13;
const int PIN_Angle       = 2;

// ---- Sensor / track limits --------------------------------------------
// Angle readings labelled DOWN, UP, DOWN. setup() treats [0] and [2] as the
// lower/upper limits of the allowed range and [1] as the centre reference.
const int BAR_DOWN[3] = {300, 488, 670};
// Largest |position| (in accumulated action steps) before the run is reset.
const int CART_OUT = 20;

// ---- Network weights --------------------------------------------------
// Row 0 of each matrix is paired with the bias node (value 1) of its layer.
double w12[4 + 1][8];   // (inputs + bias) x hidden
double w23[8 + 1][2];   // (hidden + bias) x actions

// ---- Run configuration ------------------------------------------------
// NOTE(review): N_TRIAL is never assigned in this file, so it stays at its
// zero-initialised value and `trial < N_TRIAL` can never be true — confirm
// the intended trial count.
int N_TRIAL;
// Layer sizes. These must agree with the literal array dimensions above.
int N_STATE  = 4;   // loc(curr), loc(last), ang(curr), ang(last)
int N_HIDDEN = 8;
int N_ACTION = 2;
void setup(){ | |
// Setup functions | |
Serial.begin(9600); | |
pinMode(PIN_Motor_R, OUTPUT); | |
pinMode(PIN_Motor_L, OUTPUT); | |
// Network Parameters | |
double GAMMA = 0.99; | |
double LR = 0.05; | |
// node | |
double state[4]; | |
double node_1[4+1]; | |
double node_2[8+1]; | |
double node_3[2]; | |
//weights | |
for(int i = 0; i < N_STATE+1; i++){ | |
for(int j = 0; j < N_HIDDEN; j++){ | |
w12[i][j] = 0.5; | |
} | |
} | |
for(int j = 0; j < N_HIDDEN+1; j++){ | |
for(int k = 0; k < N_ACTION; k++){ | |
w23[j][k] = 0.5; | |
} | |
} | |
// delta | |
double delta_12[8]; | |
double delta_23[4]; | |
// reward | |
int reward_memory = 30; | |
double reward_arr[30]; | |
double reward; | |
double discounted_reward; | |
int num_reward = 0; | |
// Loop function | |
bool isLearn = false; | |
bool isFirstLearn = false; | |
double lastPos; | |
double currPos; | |
double lastAng; | |
double currAng; | |
int action; | |
int trial; | |
String incoming; | |
while (true) { | |
/******************************************/ | |
/***********READ SERIAL COMMANDS***********/ | |
/******************************************/ | |
if (Serial.available() > 0) { | |
// read the incoming: | |
incoming = Serial.readString(); | |
// stop | |
if (incoming == 's') { | |
break; | |
} | |
// print weights | |
if (incoming == 'w') { | |
break; | |
} | |
// learn | |
if (incoming == 'l') { | |
isLearn = true; | |
isFirstLearn = true; | |
trial = 0; | |
} | |
// test | |
if (incoming == 't') { | |
isLearn = false; | |
break; | |
} | |
} | |
/******************************************/ | |
/*********Neural Network Learning**********/ | |
/******************************************/ | |
if (isLearn && (trial < N_TRIAL)){ | |
// read current angle | |
currAng = analogRead(2); | |
if (isFirstLearn){ | |
// Motor Power On | |
analogWrite(PIN_Motor_L, 255); | |
analogWrite(PIN_Motor_R, 255); | |
lastAng = currAng; | |
lastPos = 0; | |
currPos = 0; | |
isFirstLearn = false; | |
num_reward = 0; | |
} | |
state[0] = lastPos; | |
state[1] = currPos; | |
state[2] = lastAng - BAR_DOWN[1]; | |
state[3] = currAng - BAR_DOWN[1]; | |
node_1[0] = 1; | |
node_1[1] = state[0]; // add 1 for bias node | |
node_1[2] = state[1]; | |
node_1[3] = state[2]; | |
node_1[4] = state[3]; | |
// FEEDFORWARD : LAYER 12 | |
for (int j = 1; j < N_HIDDEN+1; j++){ | |
node_2[j] = 0; | |
for (int i = 0; i < N_STATE+1; i++){ | |
node_2[j] += node_1[i] * w12[i][j]; | |
} | |
node_2[j] = 1.0 / (1.0 + pow(2.71828, node_2[j])); // Soft-Max function | |
} | |
// FEEDFORWARD : LAYER 23 | |
node_2[0] = 1; //bias | |
for (int k = 0; k < N_ACTION; k++){ | |
node_3[k] = 0; | |
for (int j = 0; j < N_STATE; j++){ | |
node_3[k] += node_2[j] * w23[j][k]; | |
} | |
node_3[k] = 1.0 / (1.0 + pow(2.71828, node_3[k])); // Soft-Max function | |
} | |
// Select action | |
if (node_3[0] > node_3[1]){ | |
action = 1; | |
goFront(); | |
} | |
else{ | |
action = -1; | |
goBack(); | |
} | |
//action을 샐랙트 하면 그 액션을 더해줌 | |
lastPos = currPos; | |
currPos = currPos + action; | |
lastAng = currAng; | |
// check fall down | |
if (currAng < BAR_DOWN[0] || BAR_DOWN[2] < currAng || abs(currPos) > CART_OUT){ // fall down or out of track | |
analogWrite(PIN_Motor_L, 0); | |
analogWrite(PIN_Motor_R, 0); | |
isFirstLearn = true; | |
delay(1000); | |
} | |
//reward computation | |
reward = (BAR_DOWN[1] - abs(currAng - BAR_DOWN[1])) / BAR_DOWN[1]; | |
if(num_reward >= reward_memory){ | |
for(int i = 0; i < reward_memory-1; i++){ | |
reward_arr[i] = reward_arr[i+1]; // side by side | |
} | |
num_reward = reward_memory-1; | |
} | |
reward_arr[num_reward] = reward; | |
// discounted reward computation | |
discounted_reward = 0; | |
for(int i = num_reward; i >= 0 ; i--){ | |
discounted_reward += reward_arr[i] * pow(GAMMA, num_reward - i); | |
} | |
num_reward++; | |
// | |
//error; | |
int d[2] = {1, 2}; | |
// Backpropagation : LAYER 32 | |
for (int k = 0; k < N_ACTION; k++){ | |
delta_23[k] = - (d[k] - node_3[k]) * (node_3[k] * (1-node_3[k])); | |
for (int j = 0; j < N_HIDDEN+1; j++){ | |
w23[j][k] += LR * delta_23[k] * node_2[j]; | |
} | |
} | |
// delta computation for layer 12 update | |
for (int j = 0; j < N_HIDDEN; j++){ | |
delta_12[j] = 0; | |
for (int k = 0; k < N_ACTION; k++){ | |
delta_12[j] += delta_23[k] * w23[j][k] * node_2[j] * (1-node_2[j]); | |
} | |
} | |
// Backpropagation : LAYER 21 | |
for (int j = 0; j < N_HIDDEN; j++){ | |
for (int i = 0; i < N_STATE+1; i++){ | |
w12[i][j] += LR * delta_12[j] * node_1[i]; | |
} | |
} | |
trial++; | |
Serial.println(trial); | |
} | |
} | |
} | |
// Selects the direction this sketch calls "front": left direction pin LOW,
// right direction pin HIGH. Only the direction pins are touched here.
void goFront() {
    const int leftLevel = LOW;
    const int rightLevel = HIGH;

    digitalWrite(PIN_MOTOR_L_DIR, leftLevel);
    digitalWrite(PIN_MOTOR_R_DIR, rightLevel);
}
// Selects the direction this sketch calls "back": left direction pin HIGH,
// right direction pin LOW. Only the direction pins are touched here.
void goBack() {
    const int leftLevel = HIGH;
    const int rightLevel = LOW;

    digitalWrite(PIN_MOTOR_L_DIR, leftLevel);
    digitalWrite(PIN_MOTOR_R_DIR, rightLevel);
}
void printWeights(){ | |
Serial.println("w12"); | |
for (int i = 0; i < N_STATE+1; i++){ | |
for (int j = 0; j < N_HIDDEN; j++){ | |
Serial.print(w12[i][j]); | |
Serial.print(" "); | |
} | |
Serial.print("\n"); | |
} | |
Serial.println("w23"); | |
for (int j = 0; j < N_HIDDEN+1; j++){ | |
for (int k = 0; k < N_STATE; k++){ | |
Serial.print(w23[j][k]); | |
Serial.print(" "); | |
} | |
Serial.print("\n"); | |
} | |
} | |
// Nothing is done per iteration here: setup() in this sketch contains its
// own while(true) loop.
void loop()
{
}
/**
 * Multiplies two row-major matrices stored in flat arrays.
 *
 * @param a      left operand, a_row x a_col values in row-major order
 * @param a_row  number of rows of a
 * @param a_col  number of columns of a
 * @param b      right operand, b_row x b_col values in row-major order
 * @param b_row  number of rows of b; must equal a_col
 * @param b_col  number of columns of b
 * @return a malloc()-allocated a_row x b_col row-major result that the caller
 *         must free(), or NULL when an operand is NULL, a dimension is
 *         negative, the inner dimensions differ, or allocation fails
 *         (allocation failure also prints "Low memory").
 */
double* multiplyMatrix(double* a, int a_row, int a_col, double* b, int b_row, int b_col) {
    // Reject inputs that would make the loops below read outside a or b.
    if (a == NULL || b == NULL) {
        return NULL;
    }
    if (a_row < 0 || a_col < 0 || b_row < 0 || b_col < 0) {
        return NULL;
    }
    if (a_col != b_row) {
        return NULL;
    }

    double* output_arr = (double*)malloc(sizeof(double) * (size_t)a_row * (size_t)b_col);
    if (output_arr == NULL) {
        printf("Low memory\n");
        return NULL;
    }

    for (int i = 0; i < a_row; i++) {
        for (int j = 0; j < b_col; j++) {
            // Dot product of row i of a with column j of b.
            double sum = 0;
            for (int t = 0; t < a_col; t++) {
                sum += a[a_col * i + t] * b[b_col * t + j];
            }
            output_arr[b_col * i + j] = sum;
        }
    }
    return output_arr;
}
pinMode(PIN_Motor_L, OUTPUT); | |
analogWrite(PIN_Motor_R, 128); | |
} | |
String incoming; | |
void loop() { | |
// put your main code here, to run repeatedly: | |
while (true) { | |
if (Serial.available() > 0) { | |
// read the incoming: | |
incoming = Serial.readString(); | |
if (incoming == "A"){ | |
digitalWrite(PIN_MOTOR_R_DIR,HIGH); | |
} | |
if (incoming == "B"){ | |
digitalWrite(PIN_MOTOR_R_DIR,LOW); | |
} | |
} | |
} | |
A * B |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment