Created
February 1, 2015 13:29
-
-
Save kazoo04/40f1acac80330dfa0312 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import arow; | |
import std.stdio; | |
import std.random; | |
import std.string; | |
import std.conv; | |
import std.file; | |
import std.stream; | |
/// A single labelled instance: an integer class label together with a
/// sparse feature map (feature index -> feature value).
struct example {
    int label;
    double[int] features;

    /// Stores the given label and feature map as-is (the map is not copied).
    this(int label, double[int] features) @safe nothrow {
        this.features = features;
        this.label = label;
    }
}
/**
 * Parses the feature columns of one LIBSVM-style line
 * ("label idx:val idx:val ...").
 *
 * Column 0 (the label) is ignored. Every later space-separated column of the
 * form `key:value` is stored as `f[key] = value`. Columns that do not split
 * into exactly two parts on ':' (for example the empty columns produced by
 * repeated spaces) are skipped.
 *
 * Params:
 *   line = one text line; trailing whitespace (including a '\r' left over
 *          from CRLF line endings) is ignored.
 * Returns: the sparse feature map; empty (null) when the line has no features.
 * Throws: std.conv.ConvException if a key is not an int or a value is not a
 *         double.
 */
double[int] parseLine(string line) {
    immutable string delim_value = ":";
    immutable string delim_cols = " ";

    // stripRight so a trailing '\r' or tab does not stay glued to the last
    // value, where it would make to!double throw.
    string[] columns = line.stripRight().split(delim_cols);
    double[int] f;
    if (columns.length < 2)
        return f; // no feature columns (empty line or label only)

    foreach (column; columns[1 .. $]) {
        string[] arr = column.split(delim_value);
        if (arr.length != 2)
            continue;
        f[to!int(arr[0])] = to!double(arr[1]);
    }
    return f;
}
/* | |
example[] readDataIris(string filename){ | |
Stream file = new BufferedFile(filename); | |
size_t num_lines = 0; | |
example[] data; | |
foreach (char[] _line; file) { | |
string line = cast(string)_line; | |
if (line.length == 0) continue; | |
if (line[0] == '#') continue; | |
assert(line[0] == '1' || line[0] == '2' || line[0] == '3'); | |
string[] columns = line.split(","); | |
double[int] vec = parseLine(line); | |
vec[0] = to!double(columns[1]); | |
vec[1] = to!double(columns[2]); | |
vec[2] = to!double(columns[3]); | |
example ex = example(to!int(columns[0]) - 1, vec); | |
data ~= ex; | |
} | |
file.close(); | |
return data; | |
} | |
*/ | |
/**
 * Reads a multi-class LIBSVM-style data file ("label idx:val idx:val ...").
 *
 * Blank (or whitespace-only) lines and lines starting with '#' are skipped,
 * as are lines for which parseLine finds no features. The integer label in
 * column 0 is shifted down by one (file label L becomes L - 1).
 *
 * Params:
 *   filename = path of the data file.
 * Returns: the parsed examples in file order.
 * Throws: an Exception if the file cannot be opened;
 *         std.conv.ConvException on a malformed label, key or value.
 */
example[] readDataMulti(string filename){
    example[] data;

    // Fully qualified because std.stream (also imported) declares its own
    // File class. std.stdio.File is reference-counted and closes itself when
    // it goes out of scope, including when a ConvException escapes the loop.
    auto file = std.stdio.File(filename, "r");
    foreach (raw; file.byLine()) {
        // byLine reuses its buffer, so copy; stripRight drops a CRLF '\r'.
        string line = raw.stripRight().idup;
        if (line.length == 0 || line[0] == '#')
            continue;

        immutable int label = to!int(line.split(" ")[0]) - 1;
        double[int] vec = parseLine(line);
        if (vec.length == 0)
            continue; // label-only line: nothing to learn from

        data ~= example(label, vec);
    }
    return data;
}
/**
 * Reads a binary-labelled LIBSVM-style data file whose lines start with
 * '+' or '-' ("+1 idx:val ...", "-1 idx:val ...").
 *
 * Blank (or whitespace-only) lines and lines starting with '#' are skipped,
 * as are lines for which parseLine finds no features. A leading '+' yields
 * label +1, a leading '-' yields label -1.
 *
 * Params:
 *   filename = path of the data file.
 * Returns: the parsed examples in file order.
 * Throws: an Exception if the file cannot be opened or a data line does not
 *         start with '+' or '-'; std.conv.ConvException on a malformed
 *         key or value.
 */
example[] readData(string filename){
    example[] data;

    // Fully qualified because std.stream (also imported) declares its own
    // File class. std.stdio.File closes itself when it goes out of scope,
    // including on an exception thrown from the loop below.
    auto file = std.stdio.File(filename, "r");
    foreach (raw; file.byLine()) {
        // byLine reuses its buffer, so copy; stripRight drops a CRLF '\r'.
        string line = raw.stripRight().idup;
        if (line.length == 0 || line[0] == '#')
            continue;

        // Validate with an exception rather than assert: asserts are removed
        // under -release, which would silently label such lines as -1.
        if (line[0] != '-' && line[0] != '+')
            throw new Exception("readData: " ~ filename
                ~ ": label must start with '+' or '-': " ~ line);

        immutable int label = line[0] == '+' ? +1 : -1;
        double[int] vec = parseLine(line);
        if (vec.length == 0)
            continue; // label-only line: nothing to learn from

        data ~= example(label, vec);
    }
    return data;
}
/**
 * One-vs-rest AROW experiment on "news20.scale".
 *
 * Trains one binary Arow model per class on the first `num_train` shuffled
 * examples (positive for its own class, negative otherwise), then predicts
 * each remaining example as the class whose model returns the largest
 * margin. Prints "correct/total" and the accuracy.
 */
void main(string[] args) {
    // NOTE(review): presumably must exceed the largest feature index in
    // news20.scale; confirm against Arow's constructor.
    immutable uint dimension = 62062;
    immutable int num_class = 20;
    immutable size_t num_train = 10000;

    Arow[] arow = new Arow[num_class];
    foreach (ref model; arow)
        model = new Arow(dimension, 0.1);

    example[] data = readDataMulti("news20.scale");
    data.randomShuffle();

    // Need at least one held-out example: otherwise the training slice would
    // be out of range, or the accuracy below would be 0/0.
    if (data.length <= num_train)
        throw new Exception(text("need more than ", num_train,
                                 " examples, got ", data.length));

    // Training. An earlier variant updated negatives only with probability 1/2.
    foreach (ex; data[0 .. num_train])
        foreach (j; 0 .. num_class)
            arow[j].update(ex.features, ex.label == j ? 1 : -1);

    // Evaluation: argmax of the per-class margins.
    int correct = 0;
    foreach (ex; data[num_train .. $]) {
        int result = 0;
        double max_distance = -double.infinity;
        foreach (j; 0 .. num_class) {
            auto d = arow[j].getMargin(ex.features);
            if (d > max_distance) {
                result = j;
                max_distance = d;
            }
        }
        if (result == ex.label)
            correct++;
    }

    immutable double num_test = cast(double)(data.length - num_train);
    writeln(correct, "/", num_test);
    writeln(correct / num_test);
}
/+ | |
void main(string[] args) { | |
immutable uint dimension = 5; | |
Arow arow[3]; | |
arow[0] = new Arow(dimension, 0.1); | |
arow[1] = new Arow(dimension, 0.1); | |
arow[2] = new Arow(dimension, 0.1); | |
example[] data = readDataMulti("iris.scale"); | |
data.randomShuffle(); | |
int i; | |
int num_train = 100; | |
for(i = 0; i < num_train; i++) { | |
int label = data[i].label; | |
for(auto j = 0; j < 3; j++) { | |
if(data[i].label == j) | |
arow[j].update(data[i].features, 1); | |
else if(uniform(0,2) == 0) | |
arow[j].update(data[i].features, -1); | |
//arow[j].update(data[i].features, data[i].label == j ? 1 : -1); | |
} | |
} | |
int correct = 0; | |
for(; i < data.length; i++) { | |
write(data[i].label, ": "); | |
int result; | |
double max_distance = -double.infinity; | |
for(auto j = 0; j < 3; j++) { | |
auto d = arow[j].getMargin(data[i].features); | |
if(d > max_distance) { | |
result = j; | |
max_distance = d; | |
} | |
write(d, ", "); | |
} | |
writeln("->", result); | |
if(result == data[i].label) correct++; | |
} | |
writeln(correct, "/", (cast(double)data.length - num_train)); | |
writeln(correct / (cast(double)data.length - num_train)); | |
} | |
+/ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment