Skip to content

Instantly share code, notes, and snippets.

@strubell
Last active December 11, 2017 18:31
Show Gist options
  • Save strubell/ba042ae0aa5ca77a56183ec9fc9d9ac2 to your computer and use it in GitHub Desktop.
Save strubell/ba042ae0aa5ca77a56183ec9fc9d9ac2 to your computer and use it in GitHub Desktop.
from __future__ import division
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import os
import string
import operator
def sorted_mean(l):
    """Return the ``(key, rows)`` pairs of *l* ordered by mean score.

    Each element of *l* is a pair whose second item is a non-empty sequence
    of ``[index, score]`` rows; the pairs are sorted in ascending order of
    the mean of the second column of those rows. The sort is stable, so
    pairs with equal means keep their input order.

    Args:
        l: Any iterable of ``(key, rows)`` pairs, e.g. the items of a dict.

    Returns:
        A new list containing the same pairs, sorted by mean score.
    """
    def mean_score(pair):
        # Column 1 of the rows holds the scores; column 0 is an index.
        return np.mean(np.array(pair[1])[:, 1])

    return sorted(l, key=mean_score)
# --- Load the scores ---------------------------------------------------------
# Each input row is tab-separated: dataset name, integer distance, float score.
coarse_data_file = "naacl_data.tsv"
fine_data_file = "vary_len_scores.tsv"

# Selects which input file is read and, later, the name of the saved figure.
which_data = 'fine'
data_file = coarse_data_file if which_data == 'coarse' else fine_data_file

# Read eagerly into a list (a lazy map object would print uselessly on
# Python 3) and close the file promptly. Blank lines are skipped so that a
# trailing newline in the TSV does not crash the column parsing below.
with open(data_file, 'r') as tsv_file:
    data_lines = [row.strip().split('\t') for row in tsv_file if row.strip()]
print(data_lines)

# Bidirectional mapping between dataset names and dense integer indices,
# assigned in order of first appearance, so that we can look up the name later.
dataset_int_to_str = {}
dataset_str_to_int = {}

# distance -> list of [dataset index, score] rows.
data_map = {}

for line in data_lines:
    dataset_name = line[0]
    dist = int(line[1])
    if dataset_name not in dataset_str_to_int:
        curr_idx = len(dataset_int_to_str)
        dataset_int_to_str[curr_idx] = dataset_name
        dataset_str_to_int[dataset_name] = curr_idx
    data_map.setdefault(dist, []).append(
        [dataset_str_to_int[dataset_name], float(line[2])])

print(dataset_str_to_int)
print(dataset_int_to_str)
print(data_map)
# --- Grouped bar chart: one bar series per distance, one group per dataset ---
num_datum = len(data_map)
print(num_datum)

# Leave one bar-width of empty space between neighbouring dataset groups.
bar_width = 1.0 / (num_datum + 1)

fig, ax = plt.subplots()
start_loc = 0
rects = []

# Draw the series in ascending order of their mean score.
sorted_data = sorted_mean(data_map.items())
for d, data in sorted_data:
    print(d, data)
    data = np.array(data)
    dataset_idx = data[:, 0]
    values = data[:, 1]
    # Position each bar by its dataset index rather than by its row position,
    # so a distance that lacks some dataset still lines up with the x ticks.
    rect = ax.bar(dataset_idx + start_loc, height=values, width=bar_width)
    rects.append(rect)
    start_loc += bar_width

# One legend entry per distance, using the first bar of each series as handle.
legend_handles = [series[0] for series in rects]
legend_labels = [d for d, _ in sorted_data]
lgd = ax.legend(legend_handles, legend_labels, bbox_to_anchor=(1.0, 1.0))

ax.set_xlabel('Dataset')
ax.set_ylabel('F1 Score')

# Tick labels are the dataset names ordered by their integer index, which
# matches the x positions used for the bars above.
labels = sorted(dataset_str_to_int, key=dataset_str_to_int.get)
ax.set_xticks(np.arange(len(labels)) + bar_width)
if which_data == 'fine':
    # The fine-grained data uses rotated, smaller tick labels.
    ax.set_xticklabels(labels, rotation=90, fontsize=8)
else:
    ax.set_xticklabels(labels)

fig.tight_layout()
fig_name = '%s.pdf' % which_data
# The legend is anchored outside the axes; pass it as an extra artist so the
# tight bounding box does not crop it from the saved file.
plt.savefig(fig_name, bbox_extra_artists=(lgd,), bbox_inches='tight')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment