Created
November 3, 2014 05:22
-
-
Save tma15/1d7bd594d5be774ca6e9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import re | |
import json | |
import numpy as np | |
import maf | |
import maflib.util | |
def configure(conf):
    """Configuration hook; intentionally does nothing.

    Presumably invoked by the maf/waf build tool during its configure
    step (TODO confirm).  This script requires no configuration, so
    ``conf`` is ignored.
    """
    pass
@maflib.util.rule
def jsonize(task):
    """Compute accuracy from an evaluation report and write it as JSON.

    The input file (``task.inputs[0]``) is expected to contain lines in
    the following format::

        Recall[-1]: 0.932965 (21934/23510)
        Prec[-1]: 0.849562 (21934/25818)
        --
        Recall[+1]: 0.478378 (3562/7446)
        Prec[+1]: 0.693266 (3562/5138)

    Only lines starting with ``Prec`` are used.  The numerators and the
    denominators of their ``(n/m)`` fractions are summed separately, and
    the ratio of the two sums is stored under the key ``"accuracy"``.

    The output file (``task.outputs[0]``) receives a JSON object holding
    the entries of ``task.parameter`` plus ``"accuracy"``.

    Raises:
        ValueError: if a ``Prec`` line contains no ``n/m`` fraction, or if
            the summed denominator is zero (e.g. the report has no
            ``Prec`` lines at all).
    """
    # Compiled once, outside the loop; raw string so ``\d`` is not treated
    # as a string escape.
    fraction = re.compile(r"(\d+)/(\d+)")

    num = 0
    num_trues = 0
    with open(task.inputs[0].abspath(), 'r') as f:
        for line in f:
            if not line.startswith("Prec"):
                continue
            match = fraction.search(line)
            if match is None:
                # Fail loudly: silently skipping the line would produce a
                # wrong accuracy.
                raise ValueError(
                    "no 'n/m' fraction found in line: %r" % line)
            num_trues += int(match.group(1))
            num += int(match.group(2))

    if num == 0:
        raise ValueError(
            "cannot compute accuracy: no predictions counted in %s"
            % task.inputs[0].abspath())

    # Work on a copy so the task's own parameter object is left untouched.
    out = dict(task.parameter)
    out["accuracy"] = num_trues / float(num)

    with open(task.outputs[0].abspath(), 'w') as f:
        json.dump(out, f)
@maflib.util.rule
def aggregate_by_alg(task):
    """Merge the JSON content of every input file into one JSON array.

    Each node in ``task.inputs`` is parsed as JSON; the parsed values are
    written, in input order, as a single list to ``task.outputs[0]``.
    """
    def read_json(node):
        # Parse one input node's file and return the decoded value.
        with open(node.abspath(), 'r') as src:
            return json.load(src)

    merged = [read_json(node) for node in task.inputs]

    with open(task.outputs[0].abspath(), 'w') as dst:
        json.dump(merged, dst)
def aggregate_by_param():
    """Build a rule that gathers the aggregated values into a JSON array.

    Returns:
        A ``maflib.core.Rule`` wrapping an aggregator that serializes the
        values it receives, in order, as a JSON list.
    """
    @maflib.util.aggregator
    def body(values, outpath, parameter):
        # ``outpath`` and ``parameter`` are not used here; they are kept
        # because they are part of the function's signature.
        return json.dumps(list(values))

    return maflib.core.Rule(fun=body)
def build(exp):
    """Declare the experiment steps.

    Steps registered, in order:
      1. derive "train" and "dev" from the data file for each of
         NUM_FOLD folds;
      2. run the train command for every combination of algorithm "a"
         and value "c";
      3. run the test command with each model on "dev";
      4. convert each test report to a JSON accuracy record;
      5. aggregate the accuracy records and reduce them with the
         average / max rules.

    Args:
        exp: callable used to register one experiment step; it is called
            with keyword arguments such as ``source``, ``target``,
            ``parameters``, ``rule``, ``for_each`` and ``aggregate_by``.
    """
    traindata = 'a1a'
    train = '~/go/src/github.com/tma15/gonline/gonline/gonline train'
    test = '~/go/src/github.com/tma15/gonline/gonline/gonline test'
    NUM_FOLD = 3

    # ``range`` instead of ``xrange`` so this also runs on Python 3; the
    # resulting list is the same.
    exp(source=traindata,
        target="train dev",
        parameters=[{"fold": i} for i in range(NUM_FOLD)],
        rule=maflib.rules.segment_by_line(NUM_FOLD, 'fold'))

    # NOTE(review): "c" is swept below, but the train command only
    # substitutes ${a}; ${c} is never passed to the trainer.  Confirm
    # whether an option such as "-c ${c}" was intended.
    exp(source="train",
        target="model",
        parameters=maflib.util.product({
            "a": ["perceptron", "pa2", "adagrad"],
            "c": np.power(10., np.arange(-12, -5, dtype=np.float64)),
        }),
        rule="%s -a ${a} -m ${TGT[0].abspath()} ${SRC[0].abspath()}" % train)

    exp(source="model dev",
        target="dev_result",
        rule="%s -m ${SRC[0].abspath()} ${SRC[1].abspath()} > ${TGT[0].abspath()}" % test)

    # Output the accuracy for each parameter setting in JSON format.
    exp(source="dev_result",
        target="accuracy",
        rule=jsonize)

    # Aggregate the accuracies for each parameter setting.
    exp(source="accuracy",
        target="accuracies_by_param",
        for_each=["a", "c"],
        rule=aggregate_by_param())

    # Compute the average for each parameter setting.
    exp(source="accuracies_by_param",
        target="avg_acc",
        aggregate_by=["fold"],
        rule=maflib.rules.average)

    # Aggregate for each algorithm.
    exp(source="avg_acc",
        target="for_each_alg",
        for_each=["a"],
        rule=aggregate_by_alg)

    exp(source="for_each_alg",
        target="max_acc",
        aggregate_by=["fold"],
        rule=maflib.rules.max("accuracy"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment