Skip to content

Instantly share code, notes, and snippets.

@tklein23
Created March 13, 2014 21:43
Show Gist options
  • Save tklein23/9537592 to your computer and use it in GitHub Desktop.
Save tklein23/9537592 to your computer and use it in GitHub Desktop.
SHOGUN: MultilabelLabels::save() and MultilabelLabels::load()
// The following includes are only needed for save() and load()
#include <cstdio>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
void
CMultilabelLabels::save(const char * fname)
{
FILE * fh = fopen(fname, "wb");
for (int32_t label_j = 0; label_j < get_num_labels(); label_j++)
{
SGVector <int16_t> yb = get_label(label_j);
for (int32_t i_pos = 0; i_pos < yb.vlen; i_pos++)
{
fprintf(fh, "%d ", yb[i_pos]);
}
fprintf(fh, "\n");
}
fclose(fh);
}
/**
 * Scan a label file and determine its dimensions without storing the labels.
 *
 * The file is read line by line; every line counts as one label and is split
 * into whitespace-separated tokens, each of which must be a non-negative
 * integer class index that fits into int16_t.
 *
 * @param fname       path of the label file to scan
 * @param num_labels  [out] set to the number of lines in the file
 * @param num_classes [in,out] raised to (largest class index + 1) if that is
 *                    larger than the value passed in; never lowered. Because
 *                    the largest index starts at 0, the result is at least 1.
 */
void
CMultilabelLabels::load_info(const char * fname, int32_t &num_labels, int16_t &num_classes)
{
	std::ifstream labelfile(fname);

	// Without this check an unreadable file would be indistinguishable from
	// an empty one.
	if (!labelfile.is_open())
	{
		SG_SERROR("CMultilabelLabels::load_info(): cannot open file %s\n", fname);
		// only reached if SG_SERROR returns to the caller
		return;
	}

	std::string line;
	int32_t lineno = 0;
	int16_t max_class_index = 0;

	while (std::getline(labelfile, line))
	{
		// guards the lineno++ at the end of this iteration
		REQUIRE(lineno < std::numeric_limits <int32_t>::max(),
		        "lineno will overflow");

		std::istringstream iss(line);
		std::string token;

		while (iss >> token)
		{
			int16_t class_index;
			std::istringstream token_stream(token);
			token_stream >> class_index;

			// fail(): token does not start with a number or is out of range.
			// !eof(): characters are left after the number (e.g. "3x", "1.5"),
			// which would otherwise be silently truncated.
			if (token_stream.fail() || !token_stream.eof())
			{
				SG_SERROR("INPUT ERROR (line %d): cannot cast token %s to integer\n",
				          lineno + 1, token.c_str());
				break;
			}

			REQUIRE(class_index >= 0, "class_index too small");
			// guards the max_class_index++ after the loop
			REQUIRE(class_index < std::numeric_limits <int16_t>::max(),
			        "class_index will overflow ");

			if (class_index > max_class_index)
			{
				max_class_index = class_index;
			}
		}

		lineno++;
	}

	labelfile.close();

	// convert the largest index into a count
	max_class_index++;

	num_labels = lineno;
	num_classes = CMath::max(max_class_index, num_classes);
	return;
}
/**
 * Read a label file and build a CMultilabelLabels object from it.
 *
 * The file is read twice: load_info() first determines the number of labels
 * (lines) and classes, then every line is parsed into one label vector that
 * holds the class indices listed on that line.
 *
 * @param fname path of the label file to read
 * @return object allocated with new; NOTE(review): release it according to
 *         the project's ownership conventions, which are not visible here
 */
CMultilabelLabels *
CMultilabelLabels::load(const char * fname)
{
	int32_t num_labels = 0;
	int16_t num_classes = 0;

	CMultilabelLabels::load_info(fname, num_labels, num_classes);

	SG_SINFO("CMultilabelLabels::load(%s): found %d multilabels with %d classes\n",
	         fname, num_labels, num_classes);

	// Open the file before allocating anything so that a failure here does
	// not leave an allocated row array behind.
	std::ifstream labelfile(fname);

	if (!labelfile.is_open())
	{
		SG_SERROR("CMultilabelLabels::load(): cannot open file %s\n", fname);
		// only reached if SG_SERROR returns to the caller
		return NULL;
	}

	// Scratch buffer for the class indices of one line. A std::vector replaces
	// the former variable-length array, which is not standard C++.
	std::vector <int16_t> temp(num_classes);
	int16_t * scratch = temp.empty() ? NULL : &temp[0];

	SGVector <int16_t> * output_rows =
	        SG_CALLOC(SGVector <int16_t>, num_labels);

	std::string line;
	int32_t label_j = 0;

	while (std::getline(labelfile, line))
	{
		// the file must not have grown since load_info() counted its lines
		REQUIRE(label_j < num_labels, "label_j exceeds num_labels");

		int32_t num_label_classes = 0;
		std::istringstream iss(line);
		std::string token;

		while (iss >> token)
		{
			int16_t class_i;
			std::istringstream token_stream(token);
			token_stream >> class_i;

			// fail(): token does not start with a number or is out of range.
			// !eof(): characters are left after the number (e.g. "3x", "1.5"),
			// which would otherwise be silently truncated.
			if (token_stream.fail() || !token_stream.eof())
			{
				SG_SERROR("INPUT ERROR (line %d): cannot cast token %s to integer\n",
				          label_j + 1, token.c_str());
				break;
			}

			REQUIRE(class_i >= 0, "class_i too small");
			REQUIRE(class_i < num_classes, "class_i exceeds num_classes");
			// keeps the write below inside the scratch buffer
			REQUIRE(num_label_classes < num_classes,
			        "line contains more classes than num_classes");

			scratch[num_label_classes] = class_i;
			num_label_classes++;
		}

		// store a copy of the used part of the scratch buffer as this label
		output_rows[label_j] = SGVector <int16_t> (
		        SGVector <int16_t>::clone_vector(scratch, num_label_classes),
		        num_label_classes);
		label_j++;
	}

	REQUIRE(label_j == num_labels,
	        "label count differs from what we read before");

	CMultilabelLabels * outputs;
	outputs = new CMultilabelLabels(num_labels, num_classes);
	outputs->set_labels(output_rows);

	// NOTE(review): SG_FREE releases the array without running the SGVector
	// destructors; whether that leaks the row data depends on how
	// set_labels() takes over the vectors - confirm against its implementation.
	SG_FREE(output_rows);

	return outputs;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment