__author__ = 'arenduchintala'
# Jython script: trains and evaluates a linear-chain CRF tagger using the
# Pacaya graphical-model library (edu.jhu.pacaya).
from edu.jhu.pacaya.gm.model import FactorGraph, FgModel, Var, VarSet, ExpFamFactor, ExplicitExpFamFactor, ExplicitFactor, VarConfig, ClampFactor
from edu.jhu.pacaya.gm.train import CrfTrainer
from edu.jhu.pacaya.gm.train.CrfTrainer import CrfTrainerPrm
from edu.jhu.pacaya.gm.data import FgExampleList, FgExampleDiskStore, FgExampleMemoryStore, LabeledFgExample
from edu.jhu.pacaya.gm.feat import FeatureVector
from edu.jhu.nlp.joint import OptimizerFactory
from edu.jhu.pacaya.gm.inf.BeliefPropagation import BeliefPropagationPrm
from edu.jhu.pacaya.gm.inf.BeliefPropagation import BpScheduleType
from edu.jhu.pacaya.gm.inf.BeliefPropagation import BpUpdateOrder
from edu.jhu.pacaya.util.semiring import LogSemiring
from edu.jhu.hlt.optimize import AdaGradComidL2, AdaDelta
from edu.jhu.hlt.optimize.AdaGradComidL2 import AdaGradComidL2Prm
from edu.jhu.hlt.optimize.AdaDelta import AdaDeltaPrm
from edu.jhu.pacaya.gm.decode import MbrDecoder
from edu.jhu.pacaya.gm.decode.MbrDecoder import MbrDecoderPrm
import re
import logging
import random
import pdb
import sys
from optparse import OptionParser
# factor_cell_to_features maps a factor cell, i.e. a (factor_type, state1, state2)
# configuration of a variable set, to the list of (feature_label, feature_value)
# tuples fired for that configuration.
factor_cell_to_features = {}
# feature_idxs maps each feature label to its integer index in the model.
feature_idxs = {}
class ObservedFactor(ExpFamFactor):
    # Unary emission factor: the observation is fixed, so the factor ranges
    # only over the hidden tag variable and looks features up against the
    # known observed_state.
    def __init__(self, varset, var_list, factor_type, observed_state):
        ExpFamFactor.__init__(self, varset)
        self.var_list = var_list
        self.factor_type = factor_type
        self.observed_state = observed_state

    def getFeatures(self, configuration_id):
        vs = self.getVars()
        configuration = vs.getVarConfig(configuration_id)
        state1 = configuration.getStateName(self.var_list[0])
        #print 'vs:', vs.calcNumConfigs()
        #print 'config_id:', configuration_id, 'config:', state1, self.factor_type
        #print 'vars:', self.var_list[0].name, self.observed_state
        feats_fired = factor_cell_to_features[(self.factor_type, state1, self.observed_state)]
        feat_idxs = [(feature_idxs[f_label], f_val) for f_label, f_val in feats_fired]
        feats = zip(*feat_idxs)
        return FeatureVector(list(feats[0]), list(feats[1]))
class CRFFactor(ExpFamFactor):
    # Binary factor over a pair of variables (here, the tag-tag transition);
    # features are looked up by the pair of state names.
    def __init__(self, varset, var_list, factor_type):
        ExpFamFactor.__init__(self, varset)
        self.var_list = var_list
        self.factor_type = factor_type

    def getFeatures(self, configuration_id):
        vs = self.getVars()
        configuration = vs.getVarConfig(configuration_id)
        state1 = configuration.getStateName(self.var_list[0])
        state2 = configuration.getStateName(self.var_list[1])
        #print 'vs:', vs.calcNumConfigs()
        #print 'config_id:', configuration_id, 'config:', state1, state2, self.factor_type
        #print 'vars:', self.var_list[0].name, self.var_list[1].name
        feats_fired = factor_cell_to_features[(self.factor_type, state1, state2)]
        feat_idxs = [(feature_idxs[f_label], f_val) for f_label, f_val in feats_fired]
        feats = zip(*feat_idxs)
        return FeatureVector(list(feats[0]), list(feats[1]))
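
# Illustration of the lookup in getFeatures above, under a hypothetical
# feature table (the names 'trans_DT_NN' and the index 7 are made up here,
# not taken from any real data file):
#   factor_cell_to_features[('TAG-TAG', 'DT', 'NN')] = [('trans_DT_NN', 1.0)]
#   feature_idxs = {'trans_DT_NN': 7}
# For the configuration (DT, NN), getFeatures would then return
# FeatureVector([7], [1.0]): the parallel lists of feature indices and
# feature values that Pacaya's FeatureVector constructor takes.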
class Clamper(ClampFactor):
    # Clamps a variable to a fixed state; used only by the commented-out
    # variant in make_instances that models observations as explicit variables.
    def __init__(self, var, var_state):
        ClampFactor.__init__(self, var, var_state)
# Tweaks a few defaults on the CrfTrainer; not strictly necessary.
def get_trainer_prm():
    tr_prm = CrfTrainerPrm()
    ad_prm = AdaGradComidL2Prm()
    ad_prm.numPasses = 2  # number of passes through the data
    ad_prm.l2Lambda = 1. / 2000.  # L2 regularizer weight
    tr_prm.batchOptimizer = AdaGradComidL2(ad_prm)
    # Belief propagation in the log semiring. One sequential tree-like pass
    # is exact on chain-structured graphs, so maxIterations = 1 suffices here.
    # The brute-force inferencer below is an alternative for tiny graphs.
    #tr_prm.infFactory = BruteForceInferencerPrm(LogSemiring.getInstance())
    tr_prm.infFactory = BeliefPropagationPrm()
    tr_prm.infFactory.s = LogSemiring.getInstance()
    tr_prm.infFactory.schedule = BpScheduleType.TREE_LIKE
    tr_prm.infFactory.updateOrder = BpUpdateOrder.SEQUENTIAL
    tr_prm.infFactory.normalizeMessages = True
    tr_prm.infFactory.maxIterations = 1
    tr_prm.infFactory.convergenceThreshold = 1e-3
    tr_prm.infFactory.keepTape = True
    return tr_prm
def make_instances(txt_file, tag_list, obs_list):
    # Builds one linear-chain factor graph per sentence (see the format
    # sketch after this function). Sentences are separated by '###/###';
    # each line within a sentence is 'observed/hidden'.
    instances = FgExampleMemoryStore()
    text_train = [t.strip() for t in open(txt_file).read().split('###/###') if t.strip() != '']
    for x in range(len(text_train)):
        factor_graph = FactorGraph()
        vc = VarConfig()
        prev_hc = None
        for i, line in enumerate(text_train[x].split('\n')):
            hidden_state = line.split('/')[1].strip()
            observed_state = line.split('/')[0].strip()
            # make the hidden tag variable and record its gold configuration
            hc = Var(Var.VarType.PREDICTED, len(tag_list), "TAG_" + str(i), tag_list)
            vc.put(hc, hidden_state)
            #o = Var(Var.VarType.PREDICTED, len(obs_list), "OBS_" + str(i), obs_list)
            #vc.put(o, observed_state)
            # make the transition factor between consecutive tags
            if prev_hc:
                t_varset = VarSet(hc)
                t_varset.add(prev_hc)
                t_factor = CRFFactor(t_varset, [prev_hc, hc], 'TAG-TAG')
                factor_graph.addFactor(t_factor)
            prev_hc = hc
            # make the emission factor (the observation is folded into the factor)
            e_varset = VarSet(hc)
            #e_varset.add(o)
            #e_factor = CRFFactor(e_varset, [hc, o], 'TAG-OBS')
            e_factor = ObservedFactor(e_varset, [hc], 'TAG-OBS', observed_state)
            factor_graph.addFactor(e_factor)
            # clamp factor for the observed-variable variant above
            #c_factor = Clamper(o, obs_list.index(observed_state))
            #factor_graph.addFactor(c_factor)
        instances.add(LabeledFgExample(factor_graph, vc))
    return instances
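
# A minimal sketch of the input format make_instances assumes, inferred from
# the parsing above (the words and tags here are invented for illustration):
#   the/DT
#   dog/NN
#   barks/VB
#   ###/###
#   it/PR
#   runs/VB
#   ###/###
# Each line is 'observed/hidden'; '###/###' separates sentences.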
def load_factor_features(txt_file):
    # Parses the feature file (see the format sketch after this function):
    # a two-line header (tags, then observations), then 'FACTORS:' followed
    # by one 'FACTOR:' block per factor type.
    to, factors = open(txt_file).read().split('FACTORS:')
    to = to.strip()
    tag_list = list(set(to.split('\n')[0].split()))
    obs_list = list(set(to.split('\n')[1].split()))
    fac_cell_2feat = {}
    feat2id = {}
    for features_fired_in_factor in factors.strip().split('FACTOR:'):
        feature_lines = features_fired_in_factor.strip().split('\n')
        fac_type = feature_lines[0].strip()
        for fl in feature_lines[1:]:
            items = fl.split()
            label1, label2 = items[0], items[1]
            # every feature listed on the line fires with value 1.0
            fac_cell_2feat[fac_type, label1, label2] = [(f_name, 1.0) for f_name in items[2:]]
            for f_fired in items[2:]:
                # assign each feature label the next free index on first sight
                feat2id[f_fired] = feat2id.get(f_fired, len(feat2id))
    return tag_list, obs_list, fac_cell_2feat, feat2id
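
# A minimal sketch of the feature file load_factor_features expects, inferred
# from the parsing above (tags, words, and feature names are invented):
#   DT NN VB
#   the dog barks
#   FACTORS:
#   FACTOR:
#   TAG-TAG
#   DT NN f_trans_DT_NN
#   NN VB f_trans_NN_VB
#   FACTOR:
#   TAG-OBS
#   DT the f_emit_DT_the
#   NN dog f_emit_NN_dog
# Each data line is: label1 label2 feature_name1 [feature_name2 ...].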
if __name__ == '__main__':
    opt = OptionParser()
    opt.add_option('--test', dest='test_file', default='')
    opt.add_option('--train', dest='train_file', default='')
    opt.add_option('--feats', dest='feats_file', default='')
    (options, _) = opt.parse_args()
    tag_list, obs_list, factor_cell_to_features, feature_idxs = load_factor_features(options.feats_file)
    training_instances = make_instances(options.train_file, tag_list, obs_list)
    trainer = CrfTrainer(get_trainer_prm())
    factor_graph_model = FgModel(len(feature_idxs))
    trainer.train(factor_graph_model, training_instances)
    testing_instances = make_instances(options.test_file, tag_list, obs_list)
    correct_label_count = 0.0
    total_label_count = 0.0
    for test_idx in range(testing_instances.size()):
        test_instance = testing_instances.get(test_idx)
        gold_config = test_instance.getGoldConfig()
        # minimum Bayes risk decoding over the trained model
        decoder = MbrDecoder(MbrDecoderPrm())
        decoder.decode(factor_graph_model, test_instance)
        predicted_config = decoder.getMbrVarConfig()
        predictions = dict((v.getName(), predicted_config.getStateName(v)) for v in predicted_config.getVars() if v.getName().startswith('TAG'))
        gold = dict((v.getName(), gold_config.getStateName(v)) for v in gold_config.getVars() if v.getName().startswith('TAG'))
        assert len(gold) == len(predictions)
        for k in sorted(gold.keys()):
            print k, gold[k], predictions[k]
            if gold[k] == predictions[k]:
                correct_label_count += 1.0
            total_label_count += 1.0
    print 'acc', correct_label_count / total_label_count
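
# Example invocation, assuming the script is saved as crf_tagger.py and the
# Pacaya and Optimize jars are on the Jython classpath (all file names here
# are hypothetical):
#   jython crf_tagger.py --train train.txt --test test.txt --feats feats.txt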