-
-
Save jessstringham/c1a9f90ef62672597b07713ce68fd439 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.stats.distributions import binom | |
def binomial(n, k, p):
    """Return the binomial probability of ``k`` successes in ``n`` trials.

    Args:
        n: number of trials.
        k: number of successes whose probability is wanted.
        p: per-trial success probability (scalar or array).

    Returns:
        The value of ``binom.pmf`` evaluated at ``k`` for parameters
        ``n`` and ``p``.
    """
    # scipy takes the outcome first, then the distribution parameters.
    probability = binom.pmf(k, n, p)
    return probability
def get_uninformative_prior(K_min, K_max, num_K):
    """Build a flat prior over an evenly spaced grid of candidate values.

    Args:
        K_min: first value of the grid.
        K_max: last value of the grid (inclusive).
        num_K: number of grid points.

    Returns:
        A tuple ``(weights, grid)``: ``weights`` holds ``num_K`` entries that
        are all ``1 / num_K``, and ``grid`` holds the ``num_K`` evenly spaced
        values from ``K_min`` to ``K_max``.
    """
    grid = np.linspace(K_min, K_max, num_K)
    # Every candidate starts out equally likely.
    weights = np.ones(num_K)
    weights /= num_K
    return weights, grid
class FieldStudy:
    """Simulated capture-recapture study with a discrete Bayesian estimate.

    A hidden population of ``true_N`` individuals is sampled repeatedly.
    Every captured individual is marked, and the count of already-marked
    individuals in each sample is used to update a probability distribution
    (``P_k``) over a grid of candidate population sizes (``N``).
    """

    def __init__(self, true_N, K_min, K_max, num_K):
        """Create the hidden population and a flat prior over its size.

        Args:
            true_N: actual population size driving the simulation.
            K_min: smallest candidate population size.
            K_max: largest candidate population size.
            num_K: number of candidate sizes between ``K_min`` and ``K_max``.
        """
        self.true_N = true_N
        # One flag per individual; nobody is marked before the first sample.
        self.marked = np.zeros(self.true_N, dtype=bool)

        # parameters related to the population domain
        self.K_min = K_min
        self.K_max = K_max
        self.num_K = num_K
        self.P_k, self.N = get_uninformative_prior(self.K_min, self.K_max, self.num_K)

        # Per-sample history, one entry appended by each call to sample().
        self.C = []  # Sampled
        self.R = []  # Recaptured
        self.M = []  # Marked at start of trial

    def update_dist(self, marked, captured, recaptured):
        """Fold one sample's counts into the distribution over population sizes.

        Args:
            marked: number of individuals already marked before the sample.
            captured: number of individuals caught in the sample.
            recaptured: number of caught individuals that were already marked.
        """
        # Compute the unnormalized distribution: for each candidate size N,
        # the chance of `recaptured` marked individuals among `captured`
        # catches when each catch is marked with probability marked / N.
        unnormalized_P_k = self.P_k * binomial(captured, recaptured, marked / self.N)

        # Update the prior to trim off values that are impossible now: the
        # population cannot be smaller than the number of marked individuals.
        # These are also exactly the entries where marked / N exceeds 1 and
        # is therefore not a valid probability.
        unnormalized_P_k[self.N < marked] = 0

        self.P_k = unnormalized_P_k / np.sum(unnormalized_P_k)

    def sample(self, difficulty_of_catch=0.05):
        """Run one capture round, mark everything caught, update the estimate.

        Args:
            difficulty_of_catch: probability that any single individual is
                caught in this round. Despite the name, larger values mean
                more captures.

        Returns:
            Boolean array with one entry per individual, True where that
            individual was caught in this round.
        """
        self.M.append(np.sum(self.marked))

        # Use the instance's own population size rather than a module-level
        # global, so the class works no matter what the caller names things.
        captured = np.random.rand(self.true_N) < difficulty_of_catch
        recaptured = self.marked & captured

        self.C.append(np.sum(captured))  # how many were caught
        self.R.append(np.sum(recaptured))  # how many caught were already marked

        self.marked |= captured  # update which were marked

        self.update_dist(self.M[-1], self.C[-1], self.R[-1])
        return captured
# Example usage: run 21 capture rounds against a hidden population of 8000,
# then print the per-round counts and the most probable population size.
true_N = 8000
fs = FieldStudy(true_N, K_min=100, K_max=100000, num_K=1000)

for _ in range(21):
    fs.sample(difficulty_of_catch=0.005)

for label, history in (('M', fs.M), ('C', fs.C), ('R', fs.R)):
    print(label, history)

# Candidate population size carrying the highest probability.
print("Mode", fs.N[fs.P_k.argmax()])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment