Skip to content

Instantly share code, notes, and snippets.

@mokemokechicken
Created January 24, 2016 05:04
Show Gist options
  • Save mokemokechicken/8a82e18fb384db5352b1 to your computer and use it in GitHub Desktop.
Save mokemokechicken/8a82e18fb384db5352b1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
"""Compare data distribution"""
import json
from itertools import chain
import cPickle as pickle
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
from rnnlib import JSD
# Command-line flags; main() reads their values through FLAGS.
flags = tf.flags
logging = tf.logging  # alias is not used by the code visible in this file

for _flag_name, _flag_default, _flag_help in (
        ("dataset", None, "path to dataset"),
        ("sample", None, "path to sampling data"),
        ("figure", 'fig.png', 'path to output figure image'),
):
    flags.DEFINE_string(_flag_name, _flag_default, _flag_help)
# Drop the loop temporaries so they do not linger as module attributes.
del _flag_name, _flag_default, _flag_help

FLAGS = flags.FLAGS
def check_length(ax, true_data, sampling_data, xlim=(1, 50)):
    """Compare the sequence-length distributions of two datasets.

    Draws a KDE of the sequence lengths of each dataset on ``ax`` and shows
    the Jensen-Shannon divergence between the two length histograms in the
    axis title.

    Args:
        ax: matplotlib Axes to draw on.
        true_data: iterable of sequences (anything supporting ``len``).
        sampling_data: iterable of sequences to compare against true_data.
        xlim: x-axis limits of the plot; defaults to ``(1, 50)``.
    """
    len1 = pd.Series([len(x) for x in true_data], name='len')
    len2 = pd.Series([len(x) for x in sampling_data], name='len')

    # Length histograms as probability distributions, so datasets with a
    # different number of sequences can be compared fairly.
    z1 = len1.value_counts(normalize=True)
    z2 = len2.value_counts(normalize=True)

    # Align both histograms on every integer length in the observed range.
    # The "+ 1" keeps the maximum length, which a bare range() would drop.
    # Lengths that a dataset never produced get probability 0.
    len_min = min(z1.index.min(), z2.index.min())
    len_max = max(z1.index.max(), z2.index.max())
    support = range(len_min, len_max + 1)
    z = pd.concat({"d1": z1, "d2": z2}, axis=1).reindex(support).fillna(0)
    jsd = JSD(z.d1, z.d2)

    len1.plot(ax=ax, kind="kde", label="true_data")
    len2.plot(ax=ax, kind="kde", label="sampling_data")
    ax.legend()
    ax.set_xlim(xlim)
    ax.set_title("Length(JSD: %.5f)" % jsd)
def check_frequency(ax, true_data, sampling_data):
    """Compare per-token frequencies of two datasets as a bar chart.

    All sequences of each dataset are flattened into one token stream. The
    relative frequency of every token is plotted side by side on ``ax``, and
    the Jensen-Shannon divergence between the two distributions is shown in
    the title.

    Args:
        ax: matplotlib Axes to draw on.
        true_data: iterable of token sequences.
        sampling_data: iterable of token sequences to compare against.
    """
    f1 = pd.Series(list(chain.from_iterable(true_data))).value_counts(normalize=True)
    f2 = pd.Series(list(chain.from_iterable(sampling_data))).value_counts(normalize=True)

    freq = pd.concat([f1, f2], axis=1)
    freq.columns = ["true_data", "sampling_data"]
    # The outer join gives NaN for a token seen in only one dataset. It must
    # count as probability 0; otherwise NaN values are passed on to JSD.
    freq.fillna(0, inplace=True)

    jsd = JSD(freq.true_data, freq.sampling_data)
    freq.plot(ax=ax, kind='bar')
    ax.set_title("Frequency(JSD: %.5f)" % jsd)
def check_pair(ax, true_data, sampling_data, xlim=(-1, 40)):
    """Compare the distributions of adjacent token pairs (bigrams).

    Each pair is labelled ``"<a>-<b>"``. Pairs are taken only *within* each
    sequence. The last token of one sequence followed by the first token of
    the next is not a real adjacency, so it is not counted. Bars are sorted
    by combined probability (most common first), and the Jensen-Shannon
    divergence between the two pair distributions is shown in the title.

    Args:
        ax: matplotlib Axes to draw on.
        true_data: iterable of token sequences.
        sampling_data: iterable of token sequences to compare against.
        xlim: x-axis limits. The default ``(-1, 40)`` shows roughly the 40
            most common pairs.
    """
    def _pairs(data):
        # Collect "a-b" labels for consecutive tokens inside each sequence.
        labels = []
        for seq in data:
            tokens = list(seq)
            labels.extend("%s-%s" % pair for pair in zip(tokens, tokens[1:]))
        return labels

    p1 = pd.Series(_pairs(true_data)).value_counts(normalize=True)
    p2 = pd.Series(_pairs(sampling_data)).value_counts(normalize=True)

    freq = pd.concat([p1, p2], axis=1)
    freq.columns = ["true_data", "sampling_data"]
    # A pair seen in only one dataset has probability 0 in the other.
    freq.fillna(0, inplace=True)

    # Most common pairs first, so the visible window shows the ones that matter.
    order = (freq.true_data + freq.sampling_data).sort_values(ascending=False).index
    freq = freq.loc[order]

    jsd = JSD(freq.true_data, freq.sampling_data)
    freq.plot(ax=ax, kind='bar')
    ax.set_xlim(xlim)
    ax.set_title("Pair(JSD: %.5f)" % jsd)
def compare_data(fig_path, true_data, sampling_data):
    """Render length, token-frequency and pair comparisons into one image.

    Creates a 3-row figure (length KDE, token frequencies, token pairs),
    saves it to ``fig_path`` and closes it.

    Args:
        fig_path: output path passed to ``Figure.savefig``.
        true_data: sequences from the reference dataset. Each check below
            iterates it, so it must be re-iterable (e.g. a list).
        sampling_data: sampled sequences, with the same requirement.
    """
    fig = plt.figure(figsize=(12, 9))
    try:
        ax1, ax2, ax3 = [fig.add_subplot(3, 1, row) for row in (1, 2, 3)]
        check_length(ax1, true_data, sampling_data)
        check_frequency(ax2, true_data, sampling_data)
        check_pair(ax3, true_data, sampling_data)
        fig.tight_layout()
        fig.savefig(fig_path)
    finally:
        # Figures created through pyplot stay registered until closed.
        # Release this one so repeated calls do not accumulate figures.
        plt.close(fig)
def main(unused_args):
    """Load the datasets named by the flags and write the comparison figure.

    ``--dataset`` is read with pickle when it ends in ``.pkl`` and with JSON
    otherwise. ``--sample`` is always read as JSON. ``--figure`` is the
    output image path.

    Raises:
        ValueError: if ``--dataset`` or ``--sample`` is not set.
    """
    dataset_path = FLAGS.dataset
    sampling_data_path = FLAGS.sample
    fig_path = FLAGS.figure
    if not dataset_path or not sampling_data_path:
        raise ValueError("Must set --dataset and --sample")

    is_pickle = dataset_path.endswith('.pkl')
    # Pickle files are binary and must be opened in 'rb'; JSON is text.
    # NOTE(review): pickle.load can execute arbitrary code, so only load
    # trusted dataset files.
    with open(dataset_path, 'rb' if is_pickle else 'r') as f:
        load = pickle.load if is_pickle else json.load
        # The file holds a 3-item container; only the last item is used
        # (presumably train/valid/test splits -- TODO confirm).
        _, _, true_data = load(f)
    with open(sampling_data_path) as f:
        sampling_data = json.load(f)

    # Single-argument print(...) gives the same output on Python 2 and 3.
    print(len(true_data))
    print(len(sampling_data))
    compare_data(fig_path, true_data, sampling_data)
if __name__ == '__main__':
    # Run through TensorFlow's app runner, naming main() explicitly as the
    # entry point instead of relying on the implicit __main__.main lookup.
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment