Skip to content

Instantly share code, notes, and snippets.

@julienr
Created October 26, 2015 16:04
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save julienr/6b9b9a03bd8224db7b4f to your computer and use it in GitHub Desktop.
Save julienr/6b9b9a03bd8224db7b4f to your computer and use it in GitHub Desktop.
Parse a scikit-learn classification_report and convert it to a LaTeX table
"""
Code to parse a scikit-learn classification_report and convert it to a
LaTeX table.

When run as a script, the first command-line argument is the path of a
text file containing the report; the resulting LaTeX table is printed to
stdout.
"""
##
import sys
import collections
##
def parse_classification_report(clfreport):
    """
    Parse a sklearn classification report into a dict keyed by class name
    and containing a tuple (precision, recall, fscore, support) for each class.

    The expected layout is a header line ``precision recall f1-score support``,
    one line per class, and a final summary line whose first word is ``avg``
    (e.g. ``avg / total``). The summary scores are stored under the key 'avg'.

    Every line is parsed from the right: its last four whitespace-separated
    fields are the scores and everything before them is the class name. Class
    names may therefore contain spaces and may be wider than the indentation
    of the header line.

    :param clfreport: the report text
    :return: collections.OrderedDict mapping class name -> (precision, recall,
        fscore, support), floats for the first three and an int for support,
        in report order, with the 'avg' entry last
    :raises ValueError: if the report is empty or a class/summary line does
        not end with four score fields
    :raises AssertionError: if the header or summary line is not as expected
    """
    # Drop blank lines. Build a real list (not a lazy filter object) so it
    # can be indexed on Python 3 as well as Python 2.
    lines = [l for l in clfreport.splitlines() if l.strip()]
    if not lines:
        raise ValueError('Empty classification report')
    # Starts with a header, then score for each class and finally an average
    header = lines[0]
    cls_lines = lines[1:-1]
    avg_line = lines[-1]
    assert header.split() == ['precision', 'recall', 'f1-score', 'support']
    assert avg_line.split()[0] == 'avg'

    def parse_line(l):
        """Split one report line into (cls_name, precision, recall, fscore, support)."""
        # Peel the four score columns off the right end; the remainder is the
        # class name, whatever its width or number of embedded spaces.
        fields = l.rsplit(None, 4)
        if len(fields) == 4:
            # Only scores on the line: the class name is empty
            fields.insert(0, '')
        if len(fields) != 5:
            raise ValueError('Cannot parse classification report line: %r' % l)
        cls_name, precision, recall, fscore, support = fields
        return (cls_name.strip(), float(precision), float(recall),
                float(fscore), int(support))

    # Now, collect all the class names and scores in report order
    data = collections.OrderedDict()
    for l in cls_lines:
        ret = parse_line(l)
        data[ret[0]] = ret[1:]
    # average
    data['avg'] = parse_line(avg_line)[1:]
    return data
#parse_classification_report(clfreport)
##
# Replacements for characters that are special in LaTeX text mode.
_LATEX_ESCAPES = {
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
    '\\': r'\textbackslash{}',
}


def _latex_escape(text):
    """Return *text* with LaTeX special characters escaped."""
    return ''.join(_LATEX_ESCAPES.get(c, c) for c in text)


def report_to_latex_table(data):
    """
    Render parsed classification report scores as a LaTeX tabular.

    :param data: mapping of class name (string) -> sequence of scores, e.g.
        the (precision, recall, fscore, support) tuples returned by
        parse_classification_report; rows follow the mapping's iteration order
    :return: LaTeX source of a ``tabular`` environment, without a trailing
        newline. Class names are escaped so that characters such as ``_`` or
        ``&`` do not break the table; scores are rendered with ``str``.
    """
    parts = [
        "\\begin{tabular}{c | c c c c}\n",
        "Class & Precision & Recall & F-score & Support\\\\\n",
        "\\hline\n",
        # Second rule; the trailing \\ ends an empty row below the double rule
        "\\hline\\\\\n",
    ]
    # .items() works on both Python 2 and 3 (iteritems is Python-2-only)
    for cls, scores in data.items():
        parts.append(_latex_escape(cls) + " & "
                     + " & ".join(str(s) for s in scores) + "\\\\\n")
    parts.append("\\end{tabular}")
    return "".join(parts)
#print report_to_latex_table(data)
##
if __name__ == '__main__':
    # Usage: <script> REPORT_FILE -- prints the LaTeX table to stdout
    if len(sys.argv) != 2:
        sys.exit('usage: %s REPORT_FILE' % sys.argv[0])
    with open(sys.argv[1]) as f:
        data = parse_classification_report(f.read())
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the old print statement was a SyntaxError on Python 3.
    print(report_to_latex_table(data))
##
@Mikelew88
Copy link

There is a bug in `cls_field_width = len(header) - len(header.lstrip())` when class-name lengths vary widely. I have class labels as long as 38 characters, but the header row leaves only 13 blank characters before "precision". To work around this, I've added try/except blocks:

    def parse_line(l):
        """Parse a line of classification_report"""

        try:
            cls_name = l[:cls_field_width].strip()
            precision, recall, fscore, support = l[cls_field_width:].split()
        except ValueError:
            try:
                cls_name = l[:15].strip()
                precision, recall, fscore, support = l[15:].split()
            except ValueError:
                try:
                    cls_name = l[:16].strip()
                    precision, recall, fscore, support = l[16:].split()
                except ValueError:
                    try:
                        cls_name = l[:24].strip()
                        precision, recall, fscore, support = l[24:].split()
                    except ValueError: 
                        cls_name = l[:38].strip()
                        precision, recall, fscore, support = l[38:].split()

I'm sure there is a way to make this more dynamic. Just wanted to let you know about this bug.

Cheers,
Mike

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment