Last active
January 30, 2020 14:29
-
-
Save bdhammel/8c95b028442e507f961e557c57d65be8 to your computer and use it in GitHub Desktop.
Example of MLE and MAP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.stats import norm, invgamma | |
# The barrel of apples | |
# The average apples is between 70-100 g | |
BARREL = np.random.normal(loc=85, scale=20, size=100) | |
# Grid | |
WEIGHT_GUESSES = np.linspace(1, 200, 100) | |
ERROR_GUESSES = np.linspace(.1, 50, 100) | |
# NOTE: Try changing the scale error | |
# in practice, you would not know this number | |
SCALE_ERR = 5 | |
# NOTE: Try changing the number of measurements taken | |
N_MEASURMENTS = 10 | |
# NOTE: Try changing the prior values and distributions | |
PRIOR_WEIGHT = norm(50, 1).logpdf(WEIGHT_GUESSES) | |
PRIOR_ERR = invgamma(4).logpdf(ERROR_GUESSES) | |
LOG_PRIOR_GRID = np.add.outer(PRIOR_ERR, PRIOR_WEIGHT) | |
def read_scale(apple): | |
return apple + np.random.normal(loc=0, scale=SCALE_ERR) | |
def get_log_likelihood_grid(measurments): | |
log_liklelihood = [ | |
[ | |
norm(weight_guess, error_guess).logpdf(measurments).sum() | |
for weight_guess in WEIGHT_GUESSES | |
] | |
for error_guess in ERROR_GUESSES | |
] | |
return np.asarray(log_liklelihood) | |
def get_mle(measurments): | |
log_likelihood = get_log_likelihood_grid(measurments) | |
idx = np.argwhere(log_likelihood == log_likelihood.max())[0][1] | |
return WEIGHT_GUESSES[idx] | |
def get_map(measurments): | |
log_likelihood = get_log_likelihood_grid(measurments) | |
log_posterior = log_likelihood + LOG_PRIOR_GRID | |
idx = np.argwhere(log_posterior == log_posterior.max())[0][1] | |
return WEIGHT_GUESSES[idx] | |
# Pick and apple at random | |
apple = np.random.choice(BARREL) | |
# weight the apple | |
measurments = np.asarray([read_scale(apple) for _ in range(N_MEASURMENTS)]) | |
print(f"Average measurement: {measurments.mean():.3f} g") | |
print(f"Maximum Likelihood estimate: {get_mle(measurments):.3f} g") | |
print(f"Maximum A Posterior estimate: {get_map(measurments):.3f} g") | |
print(f"The true weight of the apple was: {apple:.3f} g") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment