Skip to content

Instantly share code, notes, and snippets.

@arne-cl
Last active August 29, 2015 14:01
Show Gist options
  • Save arne-cl/bc13bdfe7d460d4a228e to your computer and use it in GitHub Desktop.
Save arne-cl/bc13bdfe7d460d4a228e to your computer and use it in GitHub Desktop.
prints a rhetorical structure tree (RS3) as table
# -*- coding: utf-8 -*-
# #!/usr/bin/python
# Title: rst.py
# Description: prints an RST tree as a table
# License: GPLv3
# Author: Andre Herzog
# vers.: 0.1c
# Date: 26.03.2014
import sys
import os
import csv
from xml.etree import cElementTree as et
from collections import deque
from collections import OrderedDict
from prettytable import PrettyTable
class Node(object):
    '''This class represents a single rst node (an inner node or a leaf).'''

    def __init__(self, parent=u"", rel=u"span", child=None,
                 status=u"nuc", start=-1, end=-1, node=True):
        '''Creates a node. The defaults describe a nucleus node (not a leaf).

        parent -- id of the parent node (empty string if there is none)
        rel    -- relation name between this node and its parent
        child  -- list of child ids; a fresh empty list is used if omitted
        status -- nucleus (u"nuc") or satellite
        start  -- position of the first token; stored as start + 1
        end    -- position of the last token
        node   -- True for a node, False for a leaf
        '''
        self.start = start + 1  # first token of the segment (position)
        self.end = end  # last token of the segment (position)
        self.status = status  # nucleus or satellite
        self.node = node  # node or leaf (True/False)
        self.parent = parent  # id of the parent node
        # A new list per instance: a shared default list would leak children
        # from one node into every other node created without `child`.
        self.child = [] if child is None else child  # id list of the children
        self.relation = rel  # relation name (between node and parent)

    def pprint(self):
        '''Pretty prints a node as a small one-row table.'''
        x = PrettyTable(["start", "end", "status", "node", "parent", "child",
                         "relation"])
        # The alignment key has to match the column name exactly.
        x.align["relation"] = "l"  # Left align
        x.add_row([self.start, self.end, self.status, self.node, self.parent,
                   self.child, self.relation])
        print(x.get_string(sortby="start"))
class RstTree(OrderedDict):
    '''This class represents a rst tree.

    The tree is an ordered mapping of node id -> Node, in the order in which
    the nodes were read. `token` holds all tokens of the text, `root` the
    ids of the nodes without a parent.
    '''

    # Column names of the satellite-nucleus-relation table.
    _SAT_NUC_HEADER = ["S-ID", "S-START", "S-END", "N-ID", "N-START",
                       "N-END", "RELATION"]

    def __init__(self, path="", nuc=None):
        '''Creates an empty tree and optionally loads a rs3 file.

        path -- path to a rs3.xml file; nothing is loaded if it is empty
        nuc  -- list of relation names that are treated as nucleus
                relations; a built-in default list is used if it is empty
                or omitted
        '''
        super(RstTree, self).__init__()
        self.token = []  # list of token
        self.root = []  # list of root node ids
        if nuc:
            self._nuc = nuc  # list with nucleus relnames (multinuc)
        else:
            self._nuc = [u'span', u'conjunction', u'contrast', u'disjunction',
                         u'joint', u'list', u'restatement-rm', u'sequence']
        if path:
            self.loadTree(path)

    def __getAttrib(self, e, name):
        '''Gets an attribute of the element by name. If the element has no
        such attribute, an empty string is returned.'''
        return e.attrib.get(name, "")

    def __addAdditionalInfos(self):
        '''Makes implicit information explicit:
        * adds the children to the nodes
        * calculates the spans of the nodes'''
        # Add children to the nodes: walk from every leaf towards the root
        # and register each visited node as child of the next one.
        for ID in self:
            if self[ID].node:
                continue
            child = ID
            for i in self.traverseToRoot(self[child].parent):
                if child not in self[i].child:
                    self[i].child.append(child)
                child = i
        # Calculate the spans of the nodes: a node only widens the span of
        # its parent if it is linked to it by a nucleus relation.
        for ID in self:
            child = ID
            for i in self.traverseToRoot(self[ID].parent):
                if self[child].relation in self._nuc:
                    start_tmp = [self[child].start, self[i].start]
                    end_tmp = [self[child].end, self[i].end]
                    for c in self[child].child:
                        start_tmp.append(self[c].start)
                        end_tmp.append(self[c].end)
                    self[i].start = min(start_tmp)
                    self[i].end = max(end_tmp)
                child = i

    def loadTree(self, path):
        '''Loads a rst tree from a given rs3.xml file into the class.

        Every <segment> becomes a leaf and every <group> an inner node; the
        text of the segments is split by whitespace into tokens.
        '''
        xml = et.ElementTree(file=path)
        count = 0  # number of tokens read so far
        for e in xml.iter():
            if e.tag not in ("segment", "group"):
                continue
            node_id = e.attrib['id']
            parent = self.__getAttrib(e, "parent")
            relname = self.__getAttrib(e, "relname")
            status = "nuc" if relname in self._nuc else "sat"
            if not parent:
                self.root.append(node_id)
            if e.tag == "segment":
                # get text and split it by whitespace; an empty segment has
                # no text at all (None), which counts as zero tokens
                txt = (e.text or "").split()
                self.token.extend(txt)
                start = count
                count += len(txt)
                self[node_id] = Node(parent, relname, [], status,
                                     start, count, False)
            else:
                # Provisional span, replaced by __addAdditionalInfos().
                # NOTE(review): the provisional start only exceeds every
                # segment start if all segments precede the groups in the
                # file -- confirm for your input.
                self[node_id] = Node(parent, relname, [], status,
                                     count + 1, -1, True)
        self.__addAdditionalInfos()

    def getRelation(self, node):
        '''Returns the relation names of a node as a list: its own relation
        if that is not "span", otherwise the non-"span" relations of its
        children.'''
        if self[node].relation != "span":
            return [self[node].relation]
        return [self[child].relation for child in self[node].child
                if self[child].relation != "span"]

    def traverseToRoot(self, ID):
        '''Traverses from the given node ID to the highest node of the rst
        tree by following the parent ids (bottom-up).

        Yields the given ID first and then every ancestor id, including the
        root. Nothing is yielded for an empty ID.
        '''
        seen = set()  # guards against endless loops on cyclic parent ids
        while ID and ID not in seen:
            seen.add(ID)
            yield ID
            ID = self[ID].parent

    def traverse(self, ID):
        '''Traverses from the given node ID to the lowest child nodes of the
        rst tree (top-down, depth first).

        Yields (depth, path) tuples. depth is 0 for the given node; path is
        a new list with the ids from the given node down to the visited one.
        The children of a node are visited from the last to the first.
        '''
        pending = [(0, ID)]  # stack of (depth, node id)
        path = []
        while pending:
            depth, ID = pending.pop()
            # cut the path back to the parent of the visited node
            del path[depth:]
            path.append(ID)
            # yield a copy, the working list is changed in the next step
            yield depth, list(path)
            pending.extend((depth + 1, node) for node in self[ID].child)

    def getAllDescendents(self, ID):
        '''Gets the ids of all descendents of a node (without the node
        itself), in the order of traverse().'''
        for depth, path in self.traverse(ID):
            if depth:
                yield path[-1]

    def _multiNucRow(self, r):
        '''Builds the table row of a nucleus that has no satellite.'''
        node = self[r]
        return ["-", "-", "-", r, node.start, node.end,
                node.relation + " (" + node.parent + ")"]

    def _satNucRows(self):
        '''Yields the rows of the satellite-nucleus-relation table, one list
        per row in the column order of _SAT_NUC_HEADER.'''
        for r in self:
            if self[r].status != "nuc":
                continue
            if not self[r].child:
                # multi nucs without children
                yield self._multiNucRow(r)
                continue
            for s in self[r].child:
                if self[s].relation == "span":
                    continue
                if self[s].relation not in self._nuc:
                    # normal nucs
                    n = self[s].parent
                    yield [s, self[s].start, self[s].end,
                           n, self[n].start, self[n].end,
                           self[s].relation]
                elif self[r].relation != "span":
                    # multi nucs with children
                    yield self._multiNucRow(r)

    def printSatNucTable(self):
        '''Prints a satellite-nucleus-relation table, sorted by N-END.'''
        x = PrettyTable(list(self._SAT_NUC_HEADER))
        x.align["RELATION"] = "l"  # Left align
        for row in self._satNucRows():
            x.add_row(row)
        print(x.get_string(sortby="N-END"))

    def writeToCsv(self, path):
        '''Writes the satellite-nucleus-relation table to a tab separated
        format.

        path -- path of the output file; an existing file is overwritten
        '''
        # The csv module wants a binary file on Python 2 and a text file
        # without newline translation on Python 3.
        if sys.version_info[0] < 3:
            csvfile = open(path, 'wb')
        else:
            csvfile = open(path, 'w', newline='')
        with csvfile:
            w = csv.writer(csvfile, delimiter='\t',
                           quotechar='|', quoting=csv.QUOTE_MINIMAL)
            w.writerow(self._SAT_NUC_HEADER)
            for row in self._satNucRows():
                w.writerow(row)
# Main #######################################################################
if __name__ == "__main__":
    # TODO A better run part.
    # Command line: rst.py <path-to-rst-file> [<path-to-csv-output-file>]
    # With a csv path the table is written to that file, otherwise it is
    # printed. Errors go to stderr and end with a non-zero exit status, so
    # that calling scripts can detect them.
    if len(sys.argv) <= 1:
        sys.stderr.write("Usage: rst.py <path-to-rst-file> "
                         "[<path-to-csv-output-file>]\n")
        sys.exit(2)
    rst_path = sys.argv[1]
    csv_path = sys.argv[2] if len(sys.argv) > 2 else ""
    if not os.path.isfile(rst_path):
        sys.stderr.write("Rst file not exists.\n")
        sys.exit(1)
    # Kept as a module-level name in case other code in this file refers
    # to it.
    rst = RstTree(rst_path)
    if csv_path:
        rst.writeToCsv(csv_path)
    else:
        rst.printSatNucTable()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment