Skip to content

Instantly share code, notes, and snippets.

@olbat
Last active November 24, 2017 09:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save olbat/41bc27ad1da84742c812b42ebf0c5868 to your computer and use it in GitHub Desktop.
Save olbat/41bc27ad1da84742c812b42ebf0c5868 to your computer and use it in GitHub Desktop.
Python 3 matplotlib script that plots a histogram of each feature of an ML corpus, for the whole corpus and for each class
#!/usr/bin/env python3
"""
Plot a histogram of every feature of an ML corpus, overall and per class.

usage: {} < corpus.json > plot.pdf
The corpus file must contain one JSON document per line;
features must be stored in a field named "{}" (a list of numbers,
the same length in every document) and classes in a field named "{}".
The resulting PDF is written to stdout.
"""
import sys
import json
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter
# names of the JSON fields read from every corpus document
FEATURES_FIELD, CLASS_FIELD = "features", "category"
# figure-size multipliers; each element is multiplied by the matching
# element of the subplot grid shape (rows, columns) to size the figure
RATIO = (2, 3)
# display help if no data on stdin.
# The help goes to stderr: stdout is meant to be redirected to the PDF,
# so printing there would hide the message inside the output file.
if sys.stdin.isatty():
    print(__doc__.format(sys.argv[0], FEATURES_FIELD, CLASS_FIELD),
          file=sys.stderr)
    sys.exit(1)
# load corpus from stdin: one JSON document per line
features = []
class_features = defaultdict(list)
for line in sys.stdin:
    if not line.strip():
        # tolerate blank lines (e.g. a trailing empty line) instead of
        # crashing in json.loads
        continue
    data = json.loads(line)
    features.append(data[FEATURES_FIELD])
    class_features[data[CLASS_FIELD]].append(data[FEATURES_FIELD])
if not features:
    # the plotting code below indexes features[0]; fail with a clear
    # message rather than an IndexError
    print("error: the corpus contains no JSON document", file=sys.stderr)
    sys.exit(1)
# transpose (samples x features) -> (features x samples) so that each row
# holds the values of one feature; assumes every document has the same
# number of features -- TODO confirm for the corpora in use
features = np.transpose(np.array(features))
class_features = {
    cl: np.transpose(np.array(fts))
    for cl, fts in class_features.items()}
# setup global figure parameters
# subplot grid shape (rows, columns): one row per feature, one column for
# the whole corpus plus one column per class
plotsize = (len(features), len(class_features) + 1)
# matplotlib's figsize is (width, height) whereas plotsize is
# (rows, columns): scale the height by the row count and the width by the
# column count, so each subplot cell keeps roughly the same size whatever
# the grid shape
fig_height, fig_width = (x * n for x, n in zip(RATIO, plotsize))
fig = plt.figure(figsize=(fig_width, fig_height))
# set before any subplot is created so every axes picks them up
plt.rcParams["axes.grid"] = True
plt.rcParams["grid.linestyle"] = "dotted"
# summarize the corpus in the figure title: the total number of samples,
# then the number of samples of each class in parentheses
class_counts = ", ".join(
    "{}:{}".format(label, len(values[0]))
    for label, values in class_features.items())
fig.suptitle("samples:{} (".format(len(features[0])) + class_counts + ")")
# first column: histogram of each feature over the whole corpus; the
# y-range of every subplot is remembered for the per-class columns
ylims = []
for row, values in enumerate(features):
    ax = plt.subplot2grid(plotsize, (row, 0))
    ax.yaxis.set_major_formatter(EngFormatter())
    if row == 0:
        ax.set_title("all")
    ax.set_ylabel("feature #{}".format(row))
    ax.yaxis.set_label_coords(-0.15, 0.5)
    ax.hist(values)
    ylims.append(ax.get_ylim())
# remaining columns (one per class): histogram of each feature restricted
# to that class, drawn with the y-range of the matching "all" subplot
for col, (label, class_values) in enumerate(class_features.items(), start=1):
    for row, values in enumerate(class_values):
        ax = plt.subplot2grid(plotsize, (row, col))
        ax.yaxis.set_major_formatter(EngFormatter())
        if row == 0:
            ax.set_title("{}".format(label))
        ax.hist(values)
        ax.set_ylim(ylims[row])
# tighten the layout, then keep the top 10% of the figure above the
# subplots (where the suptitle sits)
fig.tight_layout()
fig.subplots_adjust(top=0.90)
# write the figure as PDF bytes to stdout, then release it
fig.savefig(sys.stdout.buffer, format="pdf")
plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment