Last active
January 12, 2016 14:49
-
-
Save lambdalisue/6060315 to your computer and use it in GitHub Desktop.
An example code of histogram plotting (Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# vim: set fileencoding=utf-8 : | |
# | |
# Author: Alisue (lambdalisue@hashnote.net) | |
# URL: http://hashnote.net/ | |
# Date: 2013-07-23 | |
# | |
# (C) 2013 hashnote.net, Alisue | |
# | |
import math | |
import itertools | |
import numpy as np | |
import matplotlib.pyplot as pl | |
from sklearn.mixture import GMM | |
def fit_gmm(X, n_components=5, **kwargs): | |
""" | |
混合ガウスモデルによるフィッティング | |
Arguments: | |
X -- サンプルデータ | |
n_components -- 使用するコンポーネント数の最大数(デフォルト 5) | |
*kwargs -- `sklearn.mixture.GMM`に渡される名前付き引数 | |
Returns: | |
best_model: AIC最小により判断された最適なモデル | |
""" | |
# covarianceのタイプリスト | |
COVARIANCE_TYPES = ['spherical', 'tied', 'diag', 'full'] | |
# covariance_typeとn_componentsの組み合わせ配列作成 | |
args = list(itertools.product(COVARIANCE_TYPES, range(1, n_components+1))) | |
# modelを格納する配列をゼロで初期化(GMMインスタンスを入れるのでdtype=object) | |
models = np.zeros(len(args), dtype=object) | |
# 各パラメータでモデルを作成 | |
for i, (ctype, n) in enumerate(args): | |
models[i] = GMM(n, covariance_type=ctype, **kwargs) | |
models[i].fit(X) | |
# 最適モデルをAICにより算出(AIC最小を選択) | |
AIC = np.array([m.aic(X) for m in models]) # 各モデルのAIC計算 | |
return models[np.argmin(AIC)] # AIC最小を選択 | |
def plot_histogram(X, **kwargs): | |
""" | |
ヒストグラムを描画 | |
Arguments: | |
X -- サンプルデータ | |
*kwargs -- `matplotlib.pyplot.bar` に渡される名前付き引数 | |
Returns: | |
hist, bins: `numpy.histogram`の戻り値 | |
""" | |
# 最適ビン数を計算(rice rule -- Wikipedia参照) | |
N = len(X) | |
k = math.pow(float(2*N), float(1)/3) | |
k = math.ceil(k) | |
# numpyを使用してヒストグラム作成 | |
# | |
# Note: | |
# matplotlibで直接ヒストグラムを描くほうが楽だがそういうサンプル | |
# はたくさんあるのであえてnumpyから行く方法を示す | |
# | |
hist, bins = np.histogram(X, bins=k) | |
# ヒストグラム描画 | |
width = (bins[1] - bins[0]) * 0.7 # ビン幅を実際のビン幅の0.7倍に | |
center = (bins[1:] + bins[:-1]) / 2 # ヒストグラム中心を計算 | |
# ヒストグラム中心で | |
pl.bar(center, hist, width=width, align='center', **kwargs) | |
return hist, bins | |
def plot_gmm(X, model, **kwargs): | |
""" | |
混合ガウスモデルを描画 | |
Arguments: | |
X -- サンプルデータ | |
model -- 描画する混合ガウスモデル | |
*kwargs -- `matplotlib.pyplot.bar` に渡される名前付き引数 | |
Returns: | |
x, pdf, pdf_individual: 描画に使用したパラメータ郡 | |
""" | |
# プロット用X座標配列作成(1000個固定) | |
x = np.linspace(np.min(X, axis=0), np.max(X, axis=0), 1000) | |
# プロット用x座標に対する確率・パラメータ比率を算出 | |
logprob, responsibilities = model.eval(x) | |
pdf = np.exp(logprob) | |
pdf_individual = responsibilities * pdf[:,np.newaxis] | |
# 混合ガウスモデルの描画 | |
pl.plot(x, pdf, 'r-', **kwargs) | |
# 各正規分布の描画 | |
args = zip(model.weights_, model.means_, model._get_covars()) | |
for i, (weight, mean, covar) in enumerate(args): | |
# varianceを計算(一次元なので別に計算する必要も無いんだけど…w) | |
var = np.diag(covar)[0] | |
# 式を作成: N(u, o^2) | |
formula_label = "N(%1.2f, %1.2f)" % (mean, var) | |
formula_label = "%d%% - %s" % (round(weight * 100), formula_label) | |
print formula_label | |
pl.plot(x, pdf_individual[:,i], 'k--', label=formula_label, alpha=0.5, **kwargs) | |
return x, pdf, pdf_individual | |
if __name__ == '__main__': | |
# | |
# サンプルデータの準備 | |
#=========================================================================== | |
# 3種類の標準分布を組み合わせたサンプルデータを作成 | |
mu = [100, 30, 50] # mean | |
sigma = [15, 5, 30] # variance | |
n = [300, 200, 500] # the number of samples | |
X = np.hstack([np.random.normal(*args) for args in zip(mu, sigma, n)]) | |
# | |
# 混合ガウスモデルによるフィッティング | |
#=========================================================================== | |
model = fit_gmm(X, n_components=5) | |
# | |
# ヒストグラムの描画 | |
#=========================================================================== | |
plot_histogram(X, color='k', alpha=0.7) | |
pl.xlabel('X') | |
pl.ylabel('Frequency') | |
# | |
# 混合ガウスモデルの描画 | |
#=========================================================================== | |
pl.twinx() | |
plot_gmm(X, model) | |
pl.ylabel('Probabilit') | |
pl.xticks() | |
pl.legend(loc='upper right') | |
pl.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment