# Gist volkanunsal/32a602c062b0968a92c5 (created February 23, 2016 19:02).
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters.
class TransitionParser(object):
    """
    An arc-eager transition parser.

    Training turns each gold dependency graph into a sequence of
    (configuration features, transition) examples in libsvm format and fits
    a support vector machine on them.
    """

    def __init__(self, transition, feature_extractor):
        """
        :param transition: object supplying the transition names
            (``LEFT_ARC``, ``RIGHT_ARC``, ``REDUCE``, ``SHIFT``) and the
            matching operations (``left_arc``, ``right_arc``, ``reduce``,
            ``shift``) applied to a configuration
        :param feature_extractor: object whose ``extract_features`` callable
            is handed to every ``Configuration``
        """
        # feature string -> integer feature id (ids start at 0)
        self._dictionary = {}
        # transition key -> integer class label (labels start at 1)
        self._transition = {}
        # integer class label -> transition key (inverse of _transition)
        self._match_transition = {}
        # trained classifier; stays None until train() succeeds
        self._model = None
        self._user_feature_extractor = feature_extractor
        self.transitions = transition

    def _get_dep_relation(self, idx_parent, idx_child, depgraph):
        """
        Return the relation label of the arc idx_parent -> idx_child.

        :return: the child's ``'rel'`` value when the child's head is the
            parent's address; ``None`` when there is no such arc or when the
            child has no word (the root)
        """
        p_node = depgraph.nodes[idx_parent]
        c_node = depgraph.nodes[idx_child]

        if c_node['word'] is None:
            return None  # Root word

        if c_node['head'] == p_node['address']:
            return c_node['rel']
        return None

    def _convert_to_binary_features(self, features):
        """
        Convert features into libsvm format, registering unseen features in
        the feature dictionary.

        :param features: list of feature strings to convert to binary
            features
        :type features: list(str)
        :return: string of binary features in libsvm format, i.e.
            space-separated 'featureID:value' pairs sorted by feature id
        """
        # setdefault() hands out the next free id to a feature seen for the
        # first time and returns the existing id otherwise.
        feature_ids = sorted(
            self._dictionary.setdefault(feature, len(self._dictionary))
            for feature in features
        )
        # Default value of each feature is 1.0
        return ' '.join(str(feature_id) + ':1.0' for feature_id in feature_ids)

    @staticmethod
    def _is_projective(depgraph):
        """
        Check whether a dependency graph is projective.

        A graph is rejected as soon as some node strictly inside the span of
        an arc is connected to a node outside that span.

        :return: ``True`` if no such crossing arc exists, else ``False``
        """
        arcs = set()
        for key in depgraph.nodes:
            node = depgraph.nodes[key]
            if 'head' in node:
                head = node['head']
                # A node without a head (the root) contributes no arc; a
                # None endpoint would break the span arithmetic below.
                if head is not None:
                    arcs.add((head, node['address']))

        node_count = len(depgraph.nodes)
        for head, child in arcs:
            # Order the endpoints so that low <= high.
            low, high = (head, child) if head < child else (child, head)
            for inner in range(low + 1, high):
                for outer in range(node_count):
                    if low <= outer <= high:
                        continue
                    if (inner, outer) in arcs or (outer, inner) in arcs:
                        return False
        return True

    def _write_to_file(self, key, binary_features, input_file):
        """
        Write one training example to input_file and update the transition
        dictionaries.

        :param key: transition key used as the example's class
        :param binary_features: libsvm-formatted feature string
        :param input_file: file object opened in binary mode
        """
        # Class labels are handed out in order of first appearance, from 1.
        label = self._transition.setdefault(key, len(self._transition) + 1)
        self._match_transition[label] = key

        line = str(label) + ' ' + binary_features + '\n'
        input_file.write(line.encode('utf-8'))

    def _create_training_examples_arc_eager(self, depgraphs, input_file):
        """
        Create the training examples in libsvm format and write them to
        input_file. Non-projective graphs are skipped.

        Reference : 'A Dynamic Oracle for Arc-Eager Dependency Parsing' by
        Yoav Goldberg and Joakim Nivre

        :return: the list of transition keys, in the order they were applied
        """
        training_seq = []
        projective_dependency_graphs = [
            dg for dg in depgraphs if TransitionParser._is_projective(dg)
        ]
        countProj = len(projective_dependency_graphs)

        for depgraph in projective_dependency_graphs:
            conf = Configuration(
                depgraph, self._user_feature_extractor.extract_features)
            while conf.buffer:
                b0 = conf.buffer[0]
                features = conf.extract_features()
                # Verbose: echoes the raw features of every configuration.
                print(features)
                binary_features = self._convert_to_binary_features(features)

                if conf.stack:
                    s0 = conf.stack[-1]

                    # Left-arc operation
                    rel = self._get_dep_relation(b0, s0, depgraph)
                    if rel is not None:
                        key = self.transitions.LEFT_ARC + ':' + rel
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.left_arc(conf, rel)
                        training_seq.append(key)
                        continue

                    # Right-arc operation
                    rel = self._get_dep_relation(s0, b0, depgraph)
                    if rel is not None:
                        key = self.transitions.RIGHT_ARC + ':' + rel
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.right_arc(conf, rel)
                        training_seq.append(key)
                        continue

                    # Reduce operation: chosen when some node with an index
                    # below s0 shares an arc (either direction) with b0.
                    linked_below_s0 = any(
                        self._get_dep_relation(k, b0, depgraph) is not None
                        or self._get_dep_relation(b0, k, depgraph) is not None
                        for k in range(s0)
                    )
                    if linked_below_s0:
                        key = self.transitions.REDUCE
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.reduce(conf)
                        training_seq.append(key)
                        continue

                # Shift operation as the default
                key = self.transitions.SHIFT
                self._write_to_file(key, binary_features, input_file)
                self.transitions.shift(conf)
                training_seq.append(key)

        print(" Number of training examples : {}".format(len(depgraphs)))
        print(" Number of valid (projective) examples : {}".format(countProj))
        return training_seq

    def train(self, depgraphs):
        """
        Train the classifier on the given dependency graphs and store it in
        ``self._model``.

        The training examples are written to a temporary file that is
        removed again whether or not training succeeds.

        :param depgraphs : list of DependencyGraph as the training data
        :type depgraphs : DependencyGraph
        """
        # Create the file before entering the try block so the cleanup below
        # never refers to an unbound name if creation itself fails.
        input_file = tempfile.NamedTemporaryFile(
            prefix='transition_parse.train',
            dir=tempfile.gettempdir(),
            delete=False)
        try:
            try:
                self._create_training_examples_arc_eager(depgraphs, input_file)
            finally:
                # Close even on failure so the file can always be removed.
                input_file.close()

            # Using the temporary file to train the libsvm classifier
            x_train, y_train = load_svmlight_file(input_file.name)

            # The parameter is set according to the paper:
            # Algorithms for Deterministic Incremental Dependency Parsing by
            # Joakim Nivre
            # this is very slow.
            self._model = svm.SVC(
                kernel='poly',
                degree=2,
                coef0=0,
                gamma=0.2,
                C=0.5,
                verbose=False,
                probability=True)

            print('Training support vector machine...')
            self._model.fit(x_train, y_train)
            print('done!')
        finally:
            os.remove(input_file.name)
# (GitHub page footer: sign up or sign in on GitHub to comment on this gist.)