@marty1885
Created November 28, 2018 06:49
xt::xarray<float> compute(xt::xarray<float> input)
{
    assert(input.size() == 3);

    // Save data for training: the previous input is the sample,
    // the current input is the target the network should predict.
    if(last_input_.size() != 0) {
        for(auto v : last_input_)
            input_.push_back(v);
        for(auto v : input)
            output_.push_back(v);
    }
    last_input_ = vec_t(input.begin(), input.end());

    // Train once all needed data has been collected
    if(input_.size() == RNN_DATA_PER_EPOCH) {
        assert(input_.size() == output_.size());

        // Put the network into training mode
        nn_.at<recurrent_layer>(0).seq_len(RNN_DATA_PER_EPOCH);
        nn_.set_netphase(net_phase::train);
        nn_.fit<cross_entropy_multiclass>(optimizer_, std::vector<vec_t>({input_}), std::vector<vec_t>({output_}), 1, 1, [](){}, [](){});

        // Leave training mode and go back to predicting one step at a time
        nn_.set_netphase(net_phase::test);
        nn_.at<recurrent_layer>(0).seq_len(1);

        input_.clear();
        output_.clear();
    }

    // Predict the opponent's next move
    vec_t out = nn_.predict(vec_t(input.begin(), input.end()));
    assert(out.size() == 3);

    // Convert the prediction to an xarray
    xt::xarray<float> r = xt::zeros<float>({3});
    for(size_t i = 0; i < out.size(); i++)
        r[i] = out[i];
    return r;
}
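
The bookkeeping at the top of compute() pairs each move with the one that follows it, so the network learns "previous move -> next move". A minimal standalone sketch of that idea, using only the standard library, might look like the following. `PairCollector` and its `capacity` parameter are hypothetical names introduced here for illustration; they are not part of the gist, tiny-dnn, or xtensor, and `capacity` only plays the role that `RNN_DATA_PER_EPOCH` appears to play above.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the gist): collects (previous -> current)
// training pairs in flat buffers, mirroring how compute() fills
// input_ and output_.
class PairCollector {
public:
    // `capacity` is the number of buffered values at which training is due.
    explicit PairCollector(std::size_t capacity) : capacity_(capacity) {}

    // Record one observation. The first call only remembers the value,
    // since there is no previous observation to pair it with.
    // Returns true when the input buffer holds exactly `capacity` values.
    bool push(const std::vector<float>& current) {
        if (!last_.empty()) {
            inputs_.insert(inputs_.end(), last_.begin(), last_.end());
            targets_.insert(targets_.end(), current.begin(), current.end());
        }
        last_ = current;
        return inputs_.size() == capacity_;
    }

    // Drop the buffered pairs but keep the last observation, so the next
    // push() still forms a pair across the training boundary.
    void clear() {
        inputs_.clear();
        targets_.clear();
    }

    const std::vector<float>& inputs() const { return inputs_; }
    const std::vector<float>& targets() const { return targets_; }

private:
    std::size_t capacity_;
    std::vector<float> last_;
    std::vector<float> inputs_;
    std::vector<float> targets_;
};
```

With one-hot moves of size 3 and a capacity of 6, the third call to `push()` would report that a training step is due, and `inputs()` and `targets()` would each hold two moves, with the targets shifted one step ahead of the inputs.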