Skip to content

Instantly share code, notes, and snippets.

@dfm
Created October 19, 2012 20:15
Show Gist options
  • Save dfm/3920431 to your computer and use it in GitHub Desktop.
Save dfm/3920431 to your computer and use it in GitHub Desktop.
Maximum-likelihood parameter estimation for the skew-normal distribution from a set of samples.

Usage

import numpy as np
import sn

samples = np.array([...])  # Some numpy array of samples.

pars = sn.fit_skew_normal(samples)

print(pars)  # prints the location, scale and shape parameters as defined in:
             #     https://en.wikipedia.org/wiki/Skew_normal_distribution
__all__ = ["fit_skew_normal"]
import numpy as np
import scipy.optimize as op
import scipy.special as sp
def fit_skew_normal(samples, p0=None):
    """Fit a skew-normal distribution to ``samples`` by maximum likelihood.

    The mean negative log-likelihood (see ``loglike``) is minimized with
    ``scipy.optimize.fmin_bfgs``.

    Parameters
    ----------
    samples : array_like
        One-dimensional collection of observations.
    p0 : sequence of 3 floats, optional
        Initial guess ``[mu, w, alpha]`` (location, scale, shape). It is
        copied, never modified. When omitted, a guess is built from the
        sample mean, standard deviation and skewness.

    Returns
    -------
    numpy.ndarray
        Best-fit ``[mu, w, alpha]``.

    Raises
    ------
    ValueError
        If ``samples`` is empty.
    """
    samples = np.asarray(samples, dtype=float)
    if samples.size == 0:
        raise ValueError("samples must be non-empty")

    # The optimizer works on sqrt(w) rather than w, so that the scale
    # w = p[1]**2 can never go negative. Both branches below therefore
    # store sqrt(w) in slot 1.
    if p0 is None:
        mu = np.mean(samples)
        std = np.std(samples)
        # HACK: heuristic initial alpha from the sample skewness; seems to
        # work in practice.
        skewness = np.mean(((samples - mu) / std) ** 3)
        # Initial scale guess is w = std, hence sqrt(std) in slot 1.
        p0 = np.array([mu, np.sqrt(std), 10.0 * skewness])
    else:
        # Copy so the caller's list/array is left untouched (and tuples work).
        p0 = np.array(p0, dtype=float)
        p0[1] = np.sqrt(p0[1])

    # The total (mean) negative log-likelihood.
    def nll(p):
        return -np.mean(loglike(samples, p[0], p[1] * p[1], p[2]))

    p1 = op.fmin_bfgs(nll, p0)
    # Map the optimized sqrt(w) back to the scale w.
    p1[1] = p1[1] ** 2
    return p1
# log(1 / sqrt(2*pi)): normalization constant of the standard normal density.
_factor = -0.5 * np.log(2 * np.pi)


def loglike(x, mu, w, alpha):
    """Return the pointwise log-density of a skew normal at ``x``.

    The density is ``2/w * phi(z) * Phi(alpha*z)`` with ``z = (x - mu)/w``,
    where ``phi``/``Phi`` are the standard normal pdf/cdf.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the log-density.
    mu, w, alpha : float
        Location, scale (expected positive) and shape parameters.

    Returns
    -------
    numpy.ndarray
        Log-density at each point of ``x``. Values are finite even far in
        the suppressed tail.
    """
    x = np.asarray(x, dtype=float)
    v = (mu - x) / w  # v == -z
    # log(2 * Phi(alpha*z)) == log(erfc(alpha*v/sqrt(2))). Computing it as
    # log(2) + log_ndtr(-alpha*v) stays finite where erfc underflows to 0.
    # The previous code replaced the whole output array with -1e10 in that
    # case.
    log_skew = np.log(2.0) + sp.log_ndtr(-alpha * v)
    return _factor - 0.5 * v * v + log_skew - np.log(w)
def sample_sn(mu, w, alpha, N=1):
    """Draw ``N`` skew-normal variates with location ``mu``, scale ``w``
    and shape ``alpha``, using numpy's global random generator.

    Two independent standard-normal vectors are drawn. They are mixed into
    a second normal correlated with the first by ``delta``, and that second
    normal is reflected wherever the first one is negative.
    """
    first = np.random.randn(N)
    second = np.random.randn(N)

    delta = alpha / np.sqrt(1 + alpha * alpha)
    correlated = delta * first + np.sqrt(1 - delta * delta) * second

    # Flip sign where the conditioning draw is negative.
    standardized = np.where(first < 0, -correlated, correlated)
    return mu + w * standardized
if __name__ == "__main__":
    # Demo: sample from a known skew normal, fit the samples, and save a
    # plot of the true density, the fitted density and a histogram.
    import matplotlib.pyplot as pl

    params = [2.0, 3.0, -8.0]  # location, scale, shape

    # Draw some samples.
    samples = sample_sn(*params, N=10000)

    # Grid on which to evaluate the densities.
    x = np.linspace(samples.min(), samples.max(), 5000)
    y_true = np.exp(loglike(x, *params))

    # Fit for the maximum likelihood parameters.
    p_fit = fit_skew_normal(samples)
    y_fit = np.exp(loglike(x, *p_fit))

    print("True parameters: {0}".format(params))
    print("Fit parameters: {0}".format(p_fit))

    # Plot the truth.
    pl.plot(x, y_true, "-", color="#888888", lw=3, zorder=-100)
    # Plot a normalized histogram of the samples. ``normed`` was deprecated
    # and later removed from Matplotlib's hist(); ``density`` replaces it.
    pl.hist(samples, 100, histtype="step", color="k", density=True)
    # Plot the fit.
    pl.plot(x, y_fit, "--r", lw=1.5)
    pl.savefig("samps.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment