Skip to content

Instantly share code, notes, and snippets.

@shrubb
Created July 24, 2015 03:43
Show Gist options
  • Save shrubb/e79c86ccb5ab1e4d36f2 to your computer and use it in GitHub Desktop.
Save shrubb/e79c86ccb5ab1e4d36f2 to your computer and use it in GitHub Desktop.
DQN helper
#include <cmath>
#include <iostream>
#include <ale_interface.hpp>
#include <glog/logging.h>
#include <gflags/gflags.h>
#include "prettyprint.hpp"
#include "dqn.hpp"
#include <fstream>
#include <string>
// Command-line flags (gflags). Each DEFINE_* creates a FLAGS_<name> variable
// that is populated by gflags::ParseCommandLineFlags in main.

// Selects caffe::Caffe::GPU instead of caffe::Caffe::CPU in main.
DEFINE_bool(gpu, false, "Use GPU to brew Caffe");
// Passed to both ALEInterface constructors; also gates ale.act() in make_action.
DEFINE_bool(gui, false, "Open a GUI window");
// ROM loaded into both ALEInterface instances in main.
DEFINE_string(rom, "/home/shrubb/Projects/deephack/games/tutankham.bin", "Atari 2600 ROM to play");
// Sent as the last field of the first line main writes to stdout.
// NOTE(review): the default ("gopher") does not match the default ROM/model (tutankham) - confirm.
DEFINE_string(game_name, "gopher", "Game name");
// Passed to the dqn::DQN constructor.
DEFINE_string(solver, "/home/shrubb/Projects/deephack/Net/Snapshot-tutankham/dqn_tutankham_solver.prototxt", "Solver parameter file (*.prototxt)");
// Passed to the dqn::DQN constructor.
DEFINE_int32(memory, 500000, "Capacity of replay memory");
// Read by CalculateEpsilon as the length of the linear annealing schedule.
DEFINE_int32(explore, 1000000, "Number of iterations needed for epsilon to reach 0.1");
// Passed to the dqn::DQN constructor.
DEFINE_double(gamma, 0.95, "Discount factor of future rewards (0,1]");
// Not referenced in the visible code.
DEFINE_int32(memory_threshold, 100, "Enough amount of transitions to start learning");
// main repeats each selected action for skip_frame + 1 frames.
DEFINE_int32(skip_frame, 3, "Number of frames skipped");
// Not referenced in the visible code.
DEFINE_bool(show_frame, false, "Show the current frame in CUI");
// Weights file handed to dqn.LoadTrainedModel in main.
DEFINE_string(model, "/home/shrubb/Projects/deephack/Net/Snapshot-tutankham/tutankham_iter_140000.caffemodel", "Model file to load");
// Not referenced in the visible code; main always plays without updates.
DEFINE_bool(evaluate, true, "Evaluation mode: only playing a game, no updates");
// Epsilon passed to dqn.SelectAction on every decision in main.
DEFINE_double(evaluate_with_epsilon, 0.05, "Epsilon value to be used in evaluation mode");
// Not referenced in the visible code.
// NOTE(review): declared as a double although it is described as a count - confirm the intended type.
DEFINE_double(repeat_games, 30, "Number of games played in evaluation mode");
// Exploration rate for a given training iteration.
//
// Decreases linearly from 1.0 at iteration 0 to 0.1 at FLAGS_explore
// iterations, and stays at 0.1 from then on.
double CalculateEpsilon(const int iter) {
  // Annealing finished: hold the floor value.
  if (iter >= FLAGS_explore) {
    return 0.1;
  }
  // Fraction of the annealing schedule already completed.
  const double progress = static_cast<double>(iter) / FLAGS_explore;
  return 1.0 - 0.9 * progress;
}
// Converts a single hexadecimal digit character to its numeric value.
//
// Accepts '0'-'9', 'A'-'F' and 'a'-'f' and returns 0..15. Any other
// character returns 0 instead of an out-of-range value, so a malformed
// input byte can never contribute more than a nibble to a decoded pixel.
unsigned char hex(char x) {
  if (x >= '0' && x <= '9') {
    return static_cast<unsigned char>(x - '0');
  }
  if (x >= 'A' && x <= 'F') {
    return static_cast<unsigned char>(10 + (x - 'A'));
  }
  if (x >= 'a' && x <= 'f') {
    return static_cast<unsigned char>(10 + (x - 'a'));
  }
  // Not a hexadecimal digit.
  return 0;
}
// Number of consecutive read_screen calls whose reward field started with '0';
// read_screen logs it and resets it to 0 whenever a different reward arrives.
int zero_count = 0;
// Running sum of every reward parsed by read_screen; main prints it on exit.
int total_score = 0;
bool read_screen(std::vector<std::vector<unsigned char>> &raw_screen, ALEInterface & interface, int iter){
bool term = false;
char a, b;
char terminate;
char reward[10];
const ALEScreen * screen_const = & interface.getScreen();
ALEScreen * screen = const_cast<ALEScreen*>(screen_const);
for (int i = 0; i < 210; ++i) {
for (int j = 0; j < 160; ++j) {
int scan_res = fscanf(stdin, "%c%c", &a, &b);
if (b == 'I') { // DIE
term = true;
break;
}
raw_screen[i][j] = hex(a) * (unsigned char)16 + hex(b);
*screen->pixel(i, j) = raw_screen[i][j];
}
}
char temp;
int scan_res = fscanf(stdin, "%c", &temp); // :
scan_res = fscanf(stdin, "%c", &terminate);
scan_res = fscanf(stdin, "%c", &temp); // :
std::cin.get(reward, 9, ':');
if (reward[0] != '0') {
std::cerr << zero_count << " zeros\n" << reward << std::endl;
zero_count = 0;
total_score += atoi(reward);
} else {
zero_count++;
}
if (terminate == '1') {
term = true;
}
//char temp[70000];
//fgets(temp, 69999, stdin);
std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
if (iter % 25 == 0) {
interface.saveScreenPNG(std::string("/home/shrubb/screen") + std::to_string(iter) + std::string(".png"));
}
return term;
}
void make_action(ALEInterface ale, Action action){
fprintf(stdout, "%d,18\n", action);
fflush(stdout);
if(FLAGS_gui)
ale.act(action);
return;
}
// Entry point: plays one episode with a pre-trained DQN against a peer that
// is driven over stdin/stdout.
//
// Sequence: parse flags and set up logging, pick the Caffe mode, exchange the
// opening lines with the peer, load the ROM and the trained model, then loop
// reading frames from stdin (read_screen) and writing actions to stdout
// (make_action) until read_screen reports termination. The accumulated
// total_score is printed to stderr before returning 0.
int main(int argc, char** argv) {
/*int test;
std::ifstream in("/home/shrubb/test.txt");
for (int i = 0; i < 10; ++i) {
test = in.get();
std::cout << test << std::endl;
}
return 0;*/
//google::
//minloglevel=google::ERROR;
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
// stdout carries the action stream, so log output is directed to stderr.
google::LogToStderr();
if (FLAGS_gpu) {
caffe::Caffe::set_mode(caffe::Caffe::GPU);
} else {
caffe::Caffe::set_mode(caffe::Caffe::CPU);
}
//freopen("simple_in", "r", stdin);
// First line sent to the peer: two fixed fields followed by the game name.
fprintf(stdout, "team_4,CZELol,%s\n", FLAGS_game_name.c_str());
fflush(stdout);
//std::ofstream out("test.txt");
// Read and discard one line of reply (at most 69998 characters).
// NOTE(review): the fgets result is not checked.
char temp[70000];
fgets(temp, 69999, stdin);
//out << temp << std::endl;
//out.close();
// Second fixed line sent before the frame/action exchange starts.
// NOTE(review): the meaning of these four fields is not visible here - confirm.
fprintf(stdout, "0,0,0,1\n");
fflush(stdout);
// Two emulator instances: `ale` supplies the action set and is handed to
// make_action; `ale2` receives the decoded screens inside read_screen.
ALEInterface ale(FLAGS_gui);
ALEInterface ale2(FLAGS_gui);
// Load the ROM file
ale.loadROM(FLAGS_rom);
ale2.loadROM(FLAGS_rom);
// Get the vector of legal actions
const auto legal_actions = ale.getMinimalActionSet();
dqn::DQN dqn(legal_actions, FLAGS_solver, FLAGS_memory, FLAGS_gamma);
dqn.Initialize();
std::cerr << "Loading " << FLAGS_model << std::endl;
dqn.LoadTrainedModel(FLAGS_model);
// char test[70000];
// fscanf(stdin, "%s", test);
// std::cout << strlen(test);
// return 0;
// Sliding window of the most recent preprocessed frames.
std::deque<dqn::FrameDataSp> past_frames;
dqn::FrameDataSp current_frame;
// 210 rows x 160 columns, the dimensions read_screen fills in.
std::vector<std::vector<unsigned char>> raw_screen(210, std::vector<unsigned char>(160));
bool term;
// If there are not past frames enough for DQN input, just select NOOP
int frame = 0;
for (; frame < 4; ++frame){
term = read_screen(raw_screen, ale2, frame);
if (term)
goto konec;
//std::cout << "Term " << term << std::endl;
current_frame = dqn::PreprocessArrayScreen(raw_screen);
past_frames.push_back(current_frame);
make_action(ale, PLAYER_A_NOOP);
}
// Main loop; `frame` continues from 4 and only leaves via `goto konec`.
for (; ; ++frame) {
// The third argument only controls when read_screen saves a PNG and how the file is named.
term = read_screen(raw_screen, ale2, 4 * frame);
if (term)
goto konec;
std::cerr << "frame " << frame << std::endl;
//std::cout << "Term " << term << std::endl;
if (frame % 500 == 0) std::cerr << "frame: " << frame << std::endl;
// Slide the window: append the newest frame, drop the oldest (size stays at 4).
current_frame = dqn::PreprocessArrayScreen(raw_screen);
past_frames.push_back(current_frame);
past_frames.pop_front();
dqn::InputFrames input_frames;
// NOTE(review): assumes dqn::InputFrames has room for the 4 frames copied here - confirm in dqn.hpp.
std::copy(past_frames.begin(), past_frames.end(), input_frames.begin());
float max_qvalue;
// Always acts with the evaluation epsilon; FLAGS_evaluate is not consulted.
const auto action = dqn.SelectAction(input_frames, FLAGS_evaluate_with_epsilon, max_qvalue);
// NOTE(review): immediate_score is never read after this initialization.
auto immediate_score = 0.0;
//for (auto i = 0; i < FLAGS_skip_frame + 1 && !ale.game_over(); ++i) {
// Frames read inside this loop overwrite raw_screen but are not added to past_frames.
for (auto i = 0; i < FLAGS_skip_frame + 1; ++i) {
// Last action is repeated on skipped frames
make_action(ale, action);
term = read_screen(raw_screen, ale2, 4 * frame + i + 1);
//std::cout << "Term " << term << std::endl;
if (term)
goto konec;
}
// One more send of the same action before the read at the top of the next iteration.
// NOTE(review): presumably every frame read must be answered by an action - confirm against the protocol.
make_action(ale, action);
}
//out.close();
// "konec" is Czech for "end": the common exit point once read_screen reports termination.
konec:
std::cerr << "total score: " << total_score << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment