Skip to content

Instantly share code, notes, and snippets.

@bdhammel
Last active January 30, 2020 14:29
Show Gist options
  • Save bdhammel/8c95b028442e507f961e557c57d65be8 to your computer and use it in GitHub Desktop.
Save bdhammel/8c95b028442e507f961e557c57d65be8 to your computer and use it in GitHub Desktop.
Example of MLE and MAP
import numpy as np
from scipy.stats import norm, invgamma
# The barrel of apples
# The average apples is between 70-100 g
BARREL = np.random.normal(loc=85, scale=20, size=100)
# Grid
WEIGHT_GUESSES = np.linspace(1, 200, 100)
ERROR_GUESSES = np.linspace(.1, 50, 100)
# NOTE: Try changing the scale error
# in practice, you would not know this number
SCALE_ERR = 5
# NOTE: Try changing the number of measurements taken
N_MEASURMENTS = 10
# NOTE: Try changing the prior values and distributions
PRIOR_WEIGHT = norm(50, 1).logpdf(WEIGHT_GUESSES)
PRIOR_ERR = invgamma(4).logpdf(ERROR_GUESSES)
LOG_PRIOR_GRID = np.add.outer(PRIOR_ERR, PRIOR_WEIGHT)
def read_scale(apple):
return apple + np.random.normal(loc=0, scale=SCALE_ERR)
def get_log_likelihood_grid(measurments):
log_liklelihood = [
[
norm(weight_guess, error_guess).logpdf(measurments).sum()
for weight_guess in WEIGHT_GUESSES
]
for error_guess in ERROR_GUESSES
]
return np.asarray(log_liklelihood)
def get_mle(measurments):
log_likelihood = get_log_likelihood_grid(measurments)
idx = np.argwhere(log_likelihood == log_likelihood.max())[0][1]
return WEIGHT_GUESSES[idx]
def get_map(measurments):
log_likelihood = get_log_likelihood_grid(measurments)
log_posterior = log_likelihood + LOG_PRIOR_GRID
idx = np.argwhere(log_posterior == log_posterior.max())[0][1]
return WEIGHT_GUESSES[idx]
# Pick and apple at random
apple = np.random.choice(BARREL)
# weight the apple
measurments = np.asarray([read_scale(apple) for _ in range(N_MEASURMENTS)])
print(f"Average measurement: {measurments.mean():.3f} g")
print(f"Maximum Likelihood estimate: {get_mle(measurments):.3f} g")
print(f"Maximum A Posterior estimate: {get_map(measurments):.3f} g")
print(f"The true weight of the apple was: {apple:.3f} g")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment