@mrm1001
Created August 11, 2018 16:40
FastText
void FastText::trainThread(int32_t threadId) {
  // Each thread opens its own stream and seeks to its slice of the input file.
  std::ifstream ifs(args_->input);
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

  Model model(input_, output_, args_, threadId);
  if (args_->model == model_name::sup) {
    model.setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model.setTargetCounts(dict_->getCounts(entry_type::word));
  }

  const int64_t ntokens = dict_->ntokens();
  int64_t localTokenCount = 0;
  std::vector<int32_t> line, labels;
  // Train until the shared token counter reaches epoch * ntokens.
  while (tokenCount_ < args_->epoch * ntokens) {
    real progress = real(tokenCount_) / (args_->epoch * ntokens);
    real lr = args_->lr * (1.0 - progress); // linearly decay the learning rate
    if (args_->model == model_name::sup) {
      localTokenCount += dict_->getLine(ifs, line, labels);
      supervised(model, lr, line, labels);
    } else if (args_->model == model_name::cbow) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      cbow(model, lr, line);
    } else if (args_->model == model_name::sg) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      skipgram(model, lr, line);
    }
    // Flush the local count into the shared counter only every lrUpdateRate
    // tokens, so threads do not touch the shared state on every example.
    if (localTokenCount > args_->lrUpdateRate) {
      tokenCount_ += localTokenCount;
      localTokenCount = 0;
      if (threadId == 0 && args_->verbose > 1)
        loss_ = model.getLoss();
    }
  }
  if (threadId == 0)
    loss_ = model.getLoss();
  ifs.close();
}
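
The function above is written to be run concurrently by several workers: each worker seeks to its own byte offset in the training file, keeps a private localTokenCount, and periodically folds it into the shared counter that drives both termination and the linearly decaying learning rate. The following is a minimal, self-contained sketch of that pattern only; the names (tokenCount, workerLoop, updateRate, baseLr) and the simulated work are illustrative assumptions, not fastText's actual driver code.

#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  const int nThreads = 4;
  const int64_t ntokens = 1000000;   // tokens per "epoch" (made up for the sketch)
  const int epochs = 5;
  const int64_t updateRate = 100;    // flush the local count this often
  const double baseLr = 0.05;

  std::atomic<int64_t> tokenCount{0}; // shared progress counter, as in the gist

  auto workerLoop = [&](int threadId) {
    int64_t localTokenCount = 0;
    while (tokenCount < int64_t(epochs) * ntokens) {
      double progress = double(tokenCount) / (double(epochs) * ntokens);
      double lr = baseLr * (1.0 - progress); // linear learning-rate decay
      (void)lr;                              // a real worker would train a model here
      localTokenCount += 10;                 // pretend this iteration consumed 10 tokens
      if (localTokenCount > updateRate) {    // amortize updates to the shared atomic
        tokenCount += localTokenCount;
        localTokenCount = 0;
      }
    }
    if (threadId == 0)
      std::cout << "worker 0 done; tokens seen: " << tokenCount << "\n";
  };

  std::vector<std::thread> workers;
  for (int i = 0; i < nThreads; ++i)
    workers.emplace_back(workerLoop, i);
  for (auto& t : workers)
    t.join();
}

Because only the flushed totals touch the shared counter, the learning rate each worker sees decays slightly behind the true global progress, which is the trade-off the lrUpdateRate parameter controls in the original function.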