Created
March 13, 2014 21:43
-
-
Save tklein23/9537592 to your computer and use it in GitHub Desktop.
SHOGUN: MultilabelLabels::save() and MultilabelLabels::load()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// The following includes are only needed for save(), load_info() and load()
#include <cstdio>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
void | |
CMultilabelLabels::save(const char * fname) | |
{ | |
FILE * fh = fopen(fname, "wb"); | |
for (int32_t label_j = 0; label_j < get_num_labels(); label_j++) | |
{ | |
SGVector <int16_t> yb = get_label(label_j); | |
for (int32_t i_pos = 0; i_pos < yb.vlen; i_pos++) | |
{ | |
fprintf(fh, "%d ", yb[i_pos]); | |
} | |
fprintf(fh, "\n"); | |
} | |
fclose(fh); | |
} | |
void | |
CMultilabelLabels::load_info(const char * fname, int32_t &num_labels, int16_t &num_classes) | |
{ | |
std::ifstream labelfile(fname); | |
std::string line; | |
int32_t lineno = 0; | |
int16_t max_class_index = 0; | |
while (std::getline(labelfile, line)) | |
{ | |
REQUIRE(lineno < std::numeric_limits <int32_t>::max(), | |
"lineno will overflow"); | |
std::istringstream iss(line); | |
std::string token; | |
while (iss >> token) | |
{ | |
int16_t class_index; | |
if ((std::istringstream(token) >> class_index).fail()) | |
{ | |
SG_SERROR("INPUT ERROR (line %d): cannot cast token %s to integer\n", | |
lineno + 1, token.c_str()); | |
break; | |
} | |
REQUIRE(class_index >= 0, "class_index too small"); | |
REQUIRE(class_index < std::numeric_limits <int16_t>::max(), | |
"class_index will overflow "); | |
if (class_index > max_class_index) | |
{ | |
max_class_index = class_index; | |
} | |
} | |
lineno++; | |
} | |
labelfile.close(); | |
max_class_index++; | |
num_labels = lineno; | |
num_classes = CMath::max(max_class_index, num_classes); | |
return; | |
} | |
/** Read multilabels from a text file in the format written by save().
 *
 * The file is read twice:
 * 1. load_info() determines the number of labels (lines) and classes.
 * 2. Each line is parsed into one label's vector of class indices.
 *
 * @param fname path of the label file
 * @return a CMultilabelLabels object allocated with new and holding one
 *         label per line of the file
 */
CMultilabelLabels *
CMultilabelLabels::load(const char * fname)
{
	int32_t num_labels = 0;
	int16_t num_classes = 0;
	CMultilabelLabels::load_info(fname, num_labels, num_classes);
	SG_SINFO("CMultilabelLabels::load(%s): found %d multilabels with %d classes\n",
	         fname, num_labels, num_classes);

	std::ifstream labelfile(fname);
	REQUIRE(labelfile.is_open(), "cannot open file %s for reading", fname);

	// Scratch buffer for the class indices of one line. Heap-allocated
	// because a variable-length stack array is not standard C++.
	SGVector <int16_t> temp(num_classes);
	SGVector <int16_t> * output_rows =
	        SG_CALLOC(SGVector <int16_t>, num_labels);

	std::string line;
	int32_t label_j = 0;
	while (std::getline(labelfile, line))
	{
		REQUIRE(label_j < num_labels, "label_j exceeds num_labels");

		int32_t num_label_classes = 0;
		std::istringstream iss(line);
		std::string token;
		while (iss >> token)
		{
			int16_t class_i;
			std::istringstream token_stream(token);
			// A complete parse consumes the whole token and leaves the stream
			// at EOF; anything else (e.g. "3abc", "1.5") is rejected.
			if ((token_stream >> class_i).fail() || !token_stream.eof())
			{
				SG_SERROR("INPUT ERROR (line %d): cannot cast token %s to integer\n",
				          label_j + 1, token.c_str());
				break;
			}
			REQUIRE(class_i >= 0, "class_i too small");
			REQUIRE(class_i < num_classes, "class_i exceeds num_classes");
			// Guards the write into temp, which holds num_classes entries.
			REQUIRE(num_label_classes < num_classes,
			        "line contains more classes than num_classes");
			temp[num_label_classes] = class_i;
			num_label_classes++;
		}
		// Copy only the indices actually read on this line.
		output_rows[label_j] = SGVector <int16_t> (
		        SGVector <int16_t>::clone_vector(temp.vector, num_label_classes),
		        num_label_classes);
		label_j++;
	}
	// The file may have changed between load_info() and this pass.
	REQUIRE(label_j == num_labels,
	        "label count differs from what we read before");

	CMultilabelLabels * outputs = new CMultilabelLabels(num_labels, num_classes);
	// output_rows is freed right after this call, so set_labels() is
	// assumed to copy the rows rather than keep the pointer.
	outputs->set_labels(output_rows);
	SG_FREE(output_rows);
	return outputs;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment