Skip to content

Instantly share code, notes, and snippets.

@volkanunsal
Created February 23, 2016 19:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save volkanunsal/32a602c062b0968a92c5 to your computer and use it in GitHub Desktop.
Save volkanunsal/32a602c062b0968a92c5 to your computer and use it in GitHub Desktop.
class TransitionParser(object):
    """
    An arc-eager transition parser.

    Training examples are generated by replaying the gold trees of projective
    dependency graphs, written to a temporary libsvm-format file, and used to
    fit an SVM classifier over transitions.
    """

    def __init__(self, transition, feature_extractor):
        # Feature string -> integer feature ID (0-based, assigned on first sight).
        self._dictionary = {}
        # Transition key (e.g. LEFT_ARC + ':' + rel) -> integer class label (1-based).
        self._transition = {}
        # Inverse of self._transition: class label -> transition key.
        self._match_transition = {}
        # Fitted classifier; None until train() has completed.
        self._model = None
        # Object whose ``extract_features`` is handed to each Configuration.
        self._user_feature_extractor = feature_extractor
        # Supplies the transition names (LEFT_ARC, RIGHT_ARC, REDUCE, SHIFT)
        # and the operations that apply them to a Configuration.
        self.transitions = transition

    def _get_dep_relation(self, idx_parent, idx_child, depgraph):
        """
        Return the relation label of the gold arc idx_parent -> idx_child.

        :param idx_parent: address of the candidate head node
        :param idx_child: address of the candidate dependent node
        :param depgraph: gold dependency graph whose ``nodes`` are indexed by address
        :return: the child's ``rel`` when the child's ``head`` equals the
            parent's ``address``; None when there is no such arc or when the
            child's ``word`` is None (the root word).
        """
        parent = depgraph.nodes[idx_parent]
        child = depgraph.nodes[idx_child]
        if child['word'] is None:
            return None  # Root word
        if child['head'] != parent['address']:
            return None
        return child['rel']

    def _convert_to_binary_features(self, features):
        """
        Convert feature strings into a libsvm feature string.

        Unseen feature strings are added to the feature dictionary with the
        next free ID.

        :param features: list of feature strings to convert to binary features
        :type features: list(str)
        :return: space-separated 'featureID:1.0' pairs, sorted by ID, each ID
            appearing once (libsvm requires sorted, unique indices)
        """
        feature_ids = set()
        for feature in features:
            feature_ids.add(self._dictionary.setdefault(feature, len(self._dictionary)))
        # Default value of each feature is 1.0
        return ' '.join(str(feature_id) + ':1.0' for feature_id in sorted(feature_ids))

    @staticmethod
    def _is_projective(depgraph):
        """
        Check whether a dependency graph is projective.

        The graph is non-projective when some arc spanning positions lo..hi
        has an interior node k (lo < k < hi) that is linked, in either
        direction, to a node m lying outside [lo, hi].

        :return: True if no such crossing arc exists, False otherwise
        """
        arcs = set()
        for key in depgraph.nodes:
            node = depgraph.nodes[key]
            if 'head' not in node or node['head'] is None:
                # A node without a head (presumably the root) contributes no
                # arc; comparing None with an int address below would raise
                # TypeError.
                continue
            arcs.add((node['head'], node['address']))

        n_nodes = len(depgraph.nodes)
        for head, dep in arcs:
            lo, hi = min(head, dep), max(head, dep)
            for k in range(lo + 1, hi):
                for m in range(n_nodes):
                    if lo <= m <= hi:
                        continue  # only nodes outside the arc's span can cross it
                    if (k, m) in arcs or (m, k) in arcs:
                        return False
        return True

    def _write_to_file(self, key, binary_features, input_file):
        """
        Write one training example and register its transition label.

        :param key: transition key, e.g. SHIFT or LEFT_ARC + ':' + rel
        :param binary_features: libsvm feature string for the current configuration
        :param input_file: file object opened in binary mode; the line is
            written UTF-8 encoded as '<label> <features>\\n'
        """
        label = self._transition.setdefault(key, len(self._transition) + 1)
        self._match_transition[label] = key
        line = str(label) + ' ' + binary_features + '\n'
        input_file.write(line.encode('utf-8'))

    def _create_training_examples_arc_eager(self, depgraphs, input_file):
        """
        Create training examples in libsvm format and write them to input_file.

        Non-projective graphs are skipped. For each projective graph, the gold
        tree is replayed. At each step the first applicable transition is
        chosen, in this order: LEFT-ARC, RIGHT-ARC, REDUCE, and otherwise
        SHIFT. The current features and the chosen transition are written as
        one example.

        Reference: 'A Dynamic Oracle for Arc-Eager Dependency Parsing' by
        Yoav Goldberg and Joakim Nivre.

        :param depgraphs: list of gold DependencyGraph objects
        :param input_file: binary-mode file object receiving libsvm lines
        :return: list of transition keys in the order they were applied
        """
        training_seq = []
        projective_graphs = [dg for dg in depgraphs if TransitionParser._is_projective(dg)]
        n_projective = len(projective_graphs)

        for depgraph in projective_graphs:
            conf = Configuration(depgraph, self._user_feature_extractor.extract_features)
            while conf.buffer:
                b0 = conf.buffer[0]
                features = conf.extract_features()
                binary_features = self._convert_to_binary_features(features)

                if conf.stack:
                    s0 = conf.stack[-1]

                    # Left-arc: b0 is the gold head of s0.
                    rel = self._get_dep_relation(b0, s0, depgraph)
                    if rel is not None:
                        key = self.transitions.LEFT_ARC + ':' + rel
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.left_arc(conf, rel)
                        training_seq.append(key)
                        continue

                    # Right-arc: s0 is the gold head of b0.
                    rel = self._get_dep_relation(s0, b0, depgraph)
                    if rel is not None:
                        key = self.transitions.RIGHT_ARC + ':' + rel
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.right_arc(conf, rel)
                        training_seq.append(key)
                        continue

                    # Reduce: some node k < s0 has a gold arc (either
                    # direction) with b0.
                    if any(self._get_dep_relation(k, b0, depgraph) is not None
                           or self._get_dep_relation(b0, k, depgraph) is not None
                           for k in range(s0)):
                        key = self.transitions.REDUCE
                        self._write_to_file(key, binary_features, input_file)
                        self.transitions.reduce(conf)
                        training_seq.append(key)
                        continue

                # Shift operation as the default
                key = self.transitions.SHIFT
                self._write_to_file(key, binary_features, input_file)
                self.transitions.shift(conf)
                training_seq.append(key)

        print(" Number of training examples : {}".format(len(depgraphs)))
        print(" Number of valid (projective) examples : {}".format(n_projective))
        return training_seq

    def train(self, depgraphs):
        """
        Fit the SVM transition classifier on gold dependency graphs.

        The training examples go to a temporary file, which is always deleted
        afterwards, even if training fails.

        :param depgraphs: list of DependencyGraph as the training data
        :type depgraphs: list(DependencyGraph)
        """
        # Created outside the try so a failure here cannot make the finally
        # clause reference an unbound name.
        input_file = tempfile.NamedTemporaryFile(
            prefix='transition_parse.train',
            dir=tempfile.gettempdir(),
            delete=False)
        try:
            self._create_training_examples_arc_eager(depgraphs, input_file)
            # Close (and thereby flush) so the loader reads the complete file.
            input_file.close()

            # Using the temporary file to train the libsvm classifier
            x_train, y_train = load_svmlight_file(input_file.name)
            # The parameter is set according to the paper:
            # Algorithms for Deterministic Incremental Dependency Parsing by Joakim Nivre
            # this is very slow.
            self._model = svm.SVC(
                kernel='poly',
                degree=2,
                coef0=0,
                gamma=0.2,
                C=0.5,
                verbose=False,
                probability=True)
            print('Training support vector machine...')
            self._model.fit(x_train, y_train)
            print('done!')
        finally:
            # close() is idempotent. It matters when example generation raised
            # before the explicit close above: the handle is not leaked and
            # the file can be removed.
            input_file.close()
            os.remove(input_file.name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment