Skip to content

Instantly share code, notes, and snippets.

@doug65536
Created March 19, 2023 23:13
Show Gist options
  • Save doug65536/7f697fa87233df17cb1f404f83224774 to your computer and use it in GitHub Desktop.
Save doug65536/7f697fa87233df17cb1f404f83224774 to your computer and use it in GitHub Desktop.
/// Smoke-test a genann network on randomly generated token contexts.
///
/// Each iteration draws `context_size + 1` random token indices. The first
/// `context_size` are one-hot encoded into consecutive `tok.size()`-wide
/// slots of the input vector, and the last one selects the single target
/// output that is set to 1. The network is trained 16 times on that pair.
/// After every training step, each output that is non-zero in either the
/// target or the prediction is printed to stderr as a clamped 60-column bar
/// followed by the target value, the predicted value and the output index.
///
/// @param tok  token table; only its size() is used here.
/// @return 0 on completion; 1 if the tokenizer is empty or the network
///         could not be allocated.
int self_test(tokenizer_map const& tok)
{
    // Expect tok.size() to be 102495
    std::cerr << "Testing with " << tok.size() << " tokens\n";

    // With no tokens, `tok.size() - 1` below would wrap around and every
    // one-hot write would index past the end of the (empty) input vector.
    if (tok.size() == 0) {
        std::cerr << "self_test: tokenizer is empty\n";
        return 1;
    }

    size_t const vocab = tok.size();
    size_t const input_size = vocab * context_size;
    size_t const output_size = vocab;
    genann_real const learning_rate = 0.3;
    constexpr size_t iterations = 9001;
    constexpr size_t repeats_per_sample = 16;
    constexpr size_t bar_whole = 60;    // width of the progress bar in columns

    std::vector<genann_real> inputs(input_size);
    std::vector<genann_real> outputs(output_size);

    uint64_t mersenne_seed =
        std::chrono::high_resolution_clock::now().time_since_epoch().count();
    std::mt19937_64 mersenne_rand_engine(mersenne_seed);
    std::uniform_int_distribution<size_t> random_token_index(0, vocab - 1);

    // context[0 .. context_size-1] are the input tokens; context.back() is
    // the target token.
    std::vector<size_t> context(context_size + 1);

    genann *ann = genann_init(inputs.size(), 4, 1024, outputs.size());
    if (!ann) {
        std::cerr << "self_test: genann_init failed\n";
        return 1;
    }
    // Release the network on every exit path, including exceptions thrown
    // from inside the training/printing loop.
    struct ann_guard {
        genann *p;
        ~ann_guard() { genann_free(p); }
    } const guard{ann};

    // ann->activation_hidden = genann_act_linear;
    // ann->activation_output = genann_act_linear;

    std::string bar;
    for (size_t iter = 0; iter < iterations; ++iter) {
        for (size_t &token : context)
            token = random_token_index(mersenne_rand_engine);

        // One-hot encode each context position into its own vocab-wide slot.
        for (size_t i = 0; i < context_size; ++i)
            inputs[vocab * i + context[i]] = genann_real(1);
        outputs[context.back()] = genann_real(1);

        for (size_t rep = 0; rep < repeats_per_sample; ++rep) {
            genann_train(ann, inputs.data(), outputs.data(), learning_rate);
            genann_real const *got = genann_run(ann, inputs.data());
            for (size_t i = 0; i < outputs.size(); ++i) {
                // Skip outputs that are exactly zero in both target and
                // prediction.
                // NOTE(review): if the output activation never yields an
                // exact 0, this prints nearly every output; confirm intent.
                if (!outputs[i] && !got[i])
                    continue;
                // Clamp the prediction to [0, 1] so the bar never exceeds
                // bar_whole columns.
                genann_real const clamped = std::min(genann_real(1),
                    std::max(genann_real(0), got[i]));
                size_t const bar_len =
                    static_cast<size_t>(clamped * bar_whole);
                bar.assign(bar_len, '=');
                bar.append(bar_whole - bar_len, ' ');
                std::cerr << bar << ' ' << outputs[i] <<
                    ' ' << got[i] << ' ' << i << '\n';
            }
        }

        // Undo this sample's one-hot writes so both buffers are all-zero
        // again for the next iteration.
        for (size_t i = 0; i < context_size; ++i)
            inputs[vocab * i + context[i]] = genann_real(0);
        outputs[context.back()] = genann_real(0);
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment