Skip to content

Instantly share code, notes, and snippets.

@kazoo04
Created February 1, 2015 13:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kazoo04/40f1acac80330dfa0312 to your computer and use it in GitHub Desktop.
Save kazoo04/40f1acac80330dfa0312 to your computer and use it in GitHub Desktop.
import arow;
import std.stdio;
import std.random;
import std.string;
import std.conv;
import std.file;
import std.stream;
/**
 * A single labelled sample: an integer class label together with a sparse
 * feature vector keyed by feature index.
 */
struct example {
    /// Class label of this sample.
    int label;

    /// Sparse features: feature index -> feature value.
    double[int] features;

    /**
     * Creates a sample from a label and a feature map.
     *
     * Params:
     *   label    = class label to store
     *   features = sparse feature map to store (assigned as-is, not duplicated)
     */
    this(int label, double[int] features) @safe nothrow {
        this.features = features;
        this.label = label;
    }
}
/**
 * Parses the feature part of one LIBSVM-style line ("label idx:val idx:val ...").
 *
 * The first whitespace-separated column (the label) is ignored here; every
 * following column of the exact form "index:value" is stored in the result.
 * Columns that do not consist of exactly two ':'-separated parts are skipped.
 *
 * Columns are separated by any run of whitespace, so repeated spaces, tabs and
 * a trailing carriage return (CRLF files) no longer end up glued to a token.
 *
 * Params:
 *   line = one line of the data file, without its line terminator
 *
 * Returns: a map from feature index to feature value; empty when the line
 *          carries no well-formed feature column.
 *
 * Throws: ConvException (from std.conv.to) when an "index:value" column has a
 *         non-integer index or a non-numeric value.
 */
double[int] parseLine(string line) {
    immutable string delim_value = ":";

    double[int] f;

    // split() without a separator splits on whitespace and drops empty tokens.
    string[] columns = line.split();
    if (columns.length < 2)
        return f; // only a label (or nothing at all): no features

    // columns[0] is the label, which the callers parse themselves.
    foreach (column; columns[1 .. $]) {
        string[] arr = column.split(delim_value);
        if (arr.length != 2)
            continue;
        f[to!int(arr[0])] = to!double(arr[1]);
    }
    return f;
}
/*
example[] readDataIris(string filename){
Stream file = new BufferedFile(filename);
size_t num_lines = 0;
example[] data;
foreach (char[] _line; file) {
string line = cast(string)_line;
if (line.length == 0) continue;
if (line[0] == '#') continue;
assert(line[0] == '1' || line[0] == '2' || line[0] == '3');
string[] columns = line.split(",");
double[int] vec = parseLine(line);
vec[0] = to!double(columns[1]);
vec[1] = to!double(columns[2]);
vec[2] = to!double(columns[3]);
example ex = example(to!int(columns[0]) - 1, vec);
data ~= ex;
}
file.close();
return data;
}
*/
/**
 * Loads a multi-class data set in "label idx:val idx:val ..." format.
 *
 * Empty lines and lines starting with '#' are skipped. The label is the first
 * space-separated column, converted to int and shifted down by one (so a file
 * label of 1 becomes class 0). Lines for which parseLine yields no features
 * are dropped.
 *
 * Params:
 *   filename = path of the data file to read
 *
 * Returns: the samples in file order.
 *
 * Throws: ConvException (from std.conv.to) on a non-numeric label or feature;
 *         whatever BufferedFile throws when the file cannot be opened.
 */
example[] readDataMulti(string filename){
    Stream file = new BufferedFile(filename);
    // Close the file on every exit path, including a conversion error below.
    scope(exit) file.close();

    example[] data;
    foreach (char[] _line; file) {
        if (_line.length == 0 || _line[0] == '#') continue;

        // The stream is allowed to reuse the line buffer between iterations,
        // so take a real immutable copy instead of casting the buffer.
        string line = _line.idup;

        int label = to!int(line.split(" ")[0]) - 1;
        double[int] vec = parseLine(line);
        if (vec.length == 0) continue; // no usable features on this line

        data ~= example(label, vec);
    }
    return data;
}
/**
 * Loads a binary data set whose lines look like "+1 idx:val ..." or
 * "-1 idx:val ...".
 *
 * Empty lines and lines starting with '#' are skipped. The label is taken
 * from the first character only: '+' gives +1, '-' gives -1. Lines for which
 * parseLine yields no features are dropped.
 *
 * Params:
 *   filename = path of the data file to read
 *
 * Returns: the samples in file order, each labelled +1 or -1.
 *
 * Throws: Exception when a data line does not start with '+' or '-';
 *         ConvException (from std.conv.to) on a non-numeric feature;
 *         whatever BufferedFile throws when the file cannot be opened.
 */
example[] readData(string filename){
    import std.exception : enforce;

    Stream file = new BufferedFile(filename);
    // Close the file on every exit path, including the errors raised below.
    scope(exit) file.close();

    example[] data;
    foreach (char[] _line; file) {
        if (_line.length == 0 || _line[0] == '#') continue;

        // The stream is allowed to reuse the line buffer between iterations,
        // so take a real immutable copy instead of casting the buffer.
        string line = _line.idup;

        // Input validation must survive -release builds, where assert is
        // compiled out and a malformed line would silently be labelled -1.
        enforce(line[0] == '-' || line[0] == '+',
                "readData: expected a line starting with '+' or '-', got: " ~ line);
        int label = line[0] == '+' ? +1 : -1;

        double[int] vec = parseLine(line);
        if (vec.length == 0) continue; // no usable features on this line

        data ~= example(label, vec);
    }
    return data;
}
/**
 * Trains one Arow classifier per class (one-vs-rest) on the first num_train
 * shuffled samples of "news20.scale", then classifies the remaining samples by
 * picking the class whose classifier reports the largest margin.
 *
 * Prints "correct/total" followed by the accuracy on the held-out samples.
 *
 * Throws: Exception when the data set does not contain more than num_train
 *         samples, since there would be nothing left to evaluate on.
 */
void main(string[] args) {
    import std.exception : enforce;

    immutable uint dimension = 62062;
    immutable int num_class = 20;
    immutable size_t num_train = 10000;

    Arow[] arow = new Arow[num_class];
    foreach (ref classifier; arow) classifier = new Arow(dimension, 0.1);

    example[] data = readDataMulti("news20.scale");
    data.randomShuffle();

    // Without this check the training loop would index past the end of data.
    enforce(data.length > num_train,
            "main: need more than 10000 samples to train and evaluate");

    // Training: every sample is a positive for its own class and a negative
    // for all other classes.
    foreach (ref sample; data[0 .. num_train]) {
        foreach (int j; 0 .. num_class) {
            /*
            Alternative kept from the original: subsample the negatives.
            if(sample.label == j)
            arow[j].update(sample.features, 1);
            else if(uniform(0,2) == 0)
            arow[j].update(sample.features, -1);
            */
            arow[j].update(sample.features, sample.label == j ? 1 : -1);
        }
    }

    // Evaluation on everything that was not used for training.
    int correct = 0;
    foreach (ref sample; data[num_train .. $]) {
        int result;
        double max_distance = -double.infinity;
        foreach (int j; 0 .. num_class) {
            auto d = arow[j].getMargin(sample.features);
            if (d > max_distance) {
                result = j;
                max_distance = d;
            }
        }
        if (result == sample.label) correct++;
    }

    // Kept as a double so the printed values match the original output.
    immutable double num_test = cast(double)data.length - num_train;
    writeln(correct, "/", num_test);
    writeln(correct / num_test);
}
/+
void main(string[] args) {
immutable uint dimension = 5;
Arow arow[3];
arow[0] = new Arow(dimension, 0.1);
arow[1] = new Arow(dimension, 0.1);
arow[2] = new Arow(dimension, 0.1);
example[] data = readDataMulti("iris.scale");
data.randomShuffle();
int i;
int num_train = 100;
for(i = 0; i < num_train; i++) {
int label = data[i].label;
for(auto j = 0; j < 3; j++) {
if(data[i].label == j)
arow[j].update(data[i].features, 1);
else if(uniform(0,2) == 0)
arow[j].update(data[i].features, -1);
//arow[j].update(data[i].features, data[i].label == j ? 1 : -1);
}
}
int correct = 0;
for(; i < data.length; i++) {
write(data[i].label, ": ");
int result;
double max_distance = -double.infinity;
for(auto j = 0; j < 3; j++) {
auto d = arow[j].getMargin(data[i].features);
if(d > max_distance) {
result = j;
max_distance = d;
}
write(d, ", ");
}
writeln("->", result);
if(result == data[i].label) correct++;
}
writeln(correct, "/", (cast(double)data.length - num_train));
writeln(correct / (cast(double)data.length - num_train));
}
+/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment