Created
November 20, 2023 17:59
-
-
Save ChrisMzz/7e8ec5f19e73f4192dfe51b0e971042f to your computer and use it in GitHub Desktop.
Sexual selection model for Computational Biology course (IBM)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# tqdm is optional: it only adds a progress bar to the generation loop.
# Fall back to the builtin ``range`` when the package is not installed.
try:
    import tqdm
except ImportError:
    print('tqdm package not found')
    cstmrange = range
else:
    cstmrange = tqdm.trange
class Peacock:
    """Individual carrying a preference trait ``p`` and a display trait ``z``.

    Construction reads the module-level parameters ``p0``, ``z0``, ``r``,
    ``mu``, ``stdp`` and ``stdz``. Subclasses must implement
    ``compute_fitness``; its result is stored in ``self.fitness``.
    """

    def __init__(self, mother=None, father=None):
        """Create a founder or the offspring of ``mother`` and ``father``.

        If either parent is missing the individual starts at ``(p0, z0)``.
        Otherwise ``p`` is inherited from one parent chosen at random and
        ``z`` comes from that same parent with probability ``1 - r`` or from
        the other parent with probability ``r``.
        """
        if mother is None or father is None:
            self.p, self.z = p0, z0  # initial conditions
        else:
            # Shuffle the two parents so either may be the donor of p.
            parent1, parent2 = np.random.choice([mother, father], 2, replace=False)
            self.p = parent1.p
            self.z = np.random.choice([parent1.z, parent2.z], p=[1 - r, r])

        # NOTE(review): the source lost its indentation; mutation is assumed to
        # apply to every new individual (founders included) -- confirm.
        if np.random.rand() < mu:
            self.p += np.random.normal(0, stdp)  # mutation on p
        if np.random.rand() < mu:
            self.z += np.random.normal(0, stdz)  # mutation on z

        # Keep p inside [-1, 1]; z is left unbounded.
        if self.p > 1:
            self.p = 1
        elif self.p < -1:
            self.p = -1

        self.fitness = self.compute_fitness()

    def compute_fitness(self):
        """Return the fitness used for selection; defined by subclasses."""
        raise NotImplementedError("subclasses must implement compute_fitness")
class Female(Peacock):
    """Peacock whose fitness depends on its preference trait ``p``."""

    def __init__(self, mother=None, father=None):
        super().__init__(mother=mother, father=father)

    def compute_fitness(self):
        """Return ``exp(-sp * p**2)``, using the module-level ``sp``."""
        cost = sp * self.p ** 2
        return np.exp(-cost)
class Male(Peacock):
    """Peacock whose fitness depends on its display trait ``z``."""

    def __init__(self, mother=None, father=None):
        super().__init__(mother=mother, father=father)

    def compute_fitness(self):
        """Return ``exp(-sz * z**2)``, using the module-level ``sz``."""
        cost = sz * self.z ** 2
        return np.exp(-cost)
# FOR GRAPHS --------------------------------------------------- | |
def draw_population(population, axes):
    """Scatter every individual's z (red) and p (blue) against its index.

    ``population`` is a list of individuals exposing ``.p`` and ``.z``.
    ``axes`` is a two-element sequence: ``axes[0]`` receives the z values and
    ``axes[1]`` (typically ``axes[0].twinx()``, so each trait gets its own
    y scale) receives the p values. Returns ``axes``.
    """
    p_values = [bird.p for bird in population]
    z_values = [bird.z for bird in population]
    positions = list(range(len(population)))

    z_axis, p_axis = axes[0], axes[1]
    z_axis.plot(positions, z_values, marker='o', linestyle='none', color='Red')
    z_axis.tick_params(axis='y', colors='Red')
    p_axis.plot(positions, p_values, marker='o', linestyle='none', color='Blue')
    p_axis.tick_params(axis='y', colors='Blue')
    return axes
def draw_heatmaps(poplist, axes):
    """Draw heatmaps of the z and p distributions across generations.

    Each row of a heatmap is one generation of ``poplist`` and each column is
    one of 40 histogram bins. ``axes[0]`` shows z binned over
    ``[-zmax, zmax]`` (module-level ``zmax``); ``axes[1]`` shows p binned over
    ``[-1, 1]``. Returns ``axes``.
    """
    z_axis, p_axis = axes[0], axes[1]

    z_counts = np.array([
        np.histogram([bird.z for bird in generation], bins=40, range=[-zmax, zmax])[0]
        for generation in poplist
    ])
    sns.heatmap(z_counts, ax=z_axis, cmap='RdBu_r')

    p_counts = np.array([
        np.histogram([bird.p for bird in generation], bins=40, range=[-1, 1])[0]
        for generation in poplist
    ])
    sns.heatmap(p_counts, ax=p_axis, cmap='RdBu_r')

    # Relabel the existing ticks with evenly spaced trait values instead of bin indices.
    z_tick_count = len(z_axis.get_xticks())
    z_axis.set_xticklabels(np.round(np.linspace(-zmax, zmax, z_tick_count), 2))
    p_tick_count = len(p_axis.get_xticks())
    p_axis.set_xticklabels(np.round(np.linspace(-1, 1, p_tick_count), 2))

    z_axis.set_xlabel("z range")
    p_axis.set_xlabel("p range")
    z_axis.set_ylabel("generation number")
    p_axis.set_ylabel("generation number")
    return axes
def draw_over_time(poplist, axes):
    """Plot the per-generation mean of z (red) and of p (blue).

    ``poplist`` is a list of generations, each a list of individuals exposing
    ``.z`` and ``.p``. ``axes[0]`` receives the z curve and ``axes[1]``
    (typically ``axes[0].twinx()``) receives the p curve. Returns ``axes``.
    """
    z_axis, p_axis = axes[0], axes[1]
    generations = list(range(len(poplist)))

    z_means = [np.mean([bird.z for bird in generation]) for generation in poplist]
    z_axis.plot(generations, z_means, color="Red", label="z")
    z_axis.tick_params(axis='y', colors='Red')

    p_means = [np.mean([bird.p for bird in generation]) for generation in poplist]
    p_axis.plot(generations, p_means, color="Blue", label="p")
    p_axis.tick_params(axis='y', colors='Blue')

    z_axis.set_xlabel("generation number")
    z_axis.set_ylabel("z mean")
    p_axis.set_ylabel("p mean")
    z_axis.legend(loc='upper left')
    p_axis.legend(loc='upper right')
    return axes
def draw_covariance(poplist, ax):
    """Plot, for each generation, the covariance between z and p.

    The covariance is computed across the individuals of one generation
    (off-diagonal entry of ``np.cov``) and drawn in black on ``ax``.
    Returns ``ax``.
    """
    generations = list(range(len(poplist)))
    covariances = []
    for generation in poplist:
        z_values = [bird.z for bird in generation]
        p_values = [bird.p for bird in generation]
        covariances.append(np.cov(z_values, p_values)[0, 1])

    ax.plot(generations, covariances, color='Black', label="cov(z,p)")
    ax.legend()
    ax.set_xlabel("generation number")
    return ax
# -------------------------------------------------------------- | |
# I mostly used the same variable notations and format as the teacher | |
# Model parameters, read as module-level globals by the classes and the script.
Tmax = 1000  # number of generations simulated
zmax = 5  # the z heatmap covers [-zmax, zmax]
z0 = 0  # initial value of z
p0 = 0  # initial value of p
N = 500  # number of females, and also number of males, per generation
n = 20  # lek size (the teacher calls it leksize; the PDF calls it n)
r = 0.01  # probability that z is taken from the parent that did not give p
mu = 0.01  # per-trait mutation probability
sz = 0.1  # strength of selection on z (male fitness exp(-sz * z**2))
sp = 0.05  # strength of selection on p (female fitness exp(-sp * p**2))
stdz = 0.1  # standard deviation of a mutation on z
stdp = 0.05  # standard deviation of a mutation on p
# Generation 0: N females followed by N males, all built without parents.
population = [Female() for _ in range(N)] + [Male() for _ in range(N)]
poplist = [population]  # poplist[t] is the population of generation t

for _ in cstmrange(Tmax):
    # Selection: an individual becomes a potential parent with probability
    # equal to its (sex-dependent) fitness.
    parents = {"Females": [], "Males": []}
    for peacock in population:
        if np.random.rand() > 1 - peacock.fitness:
            sex = "Males" if isinstance(peacock, Male) else "Females"
            parents[sex].append(peacock)

    # Convert the male pool once; np.random.choice would otherwise rebuild an
    # array from the list for every mother.
    males = np.asarray(parents["Males"], dtype=object)

    # Reproduction: draw 2N mothers with replacement; the first N offspring
    # are females and the remaining N are males.
    females = np.random.choice(parents["Females"], 2 * N, replace=True)
    offspring = []
    for i, f in enumerate(females):
        if np.random.rand() < 1 - abs(f.p):
            # With probability 1-|p| the female mates with a random male.
            m = np.random.choice(males)
        else:
            # Otherwise she inspects a lek of n males (drawn with replacement)
            # and picks the one maximising z*p, i.e. the largest z when p > 0
            # and the smallest z when p < 0.
            lek = np.random.choice(males, n)
            criteria = [lm.z * f.p for lm in lek]
            m = lek[int(np.argmax(criteria))]
        offspring.append(Female(f, m) if i < N else Male(f, m))

    population = offspring  # the offspring become the next generation
    poplist.append(population)  # recorded after reproduction
# 2x2 figure: z and p heatmaps on the top row, trait means bottom-left,
# covariance bottom-right.
fig, grid = plt.subplots(2, 2)
(ax1, ax2), (axt, axc) = grid
fig.tight_layout()

draw_heatmaps(poplist, [ax1, ax2])
draw_over_time(poplist, [axt, axt.twinx()])  # twin y axis so z and p each get their own scale
draw_covariance(poplist, axc)
# The covariance panel is drawn and then hidden: the author was unsure what
# this plot is supposed to show.
axc.set_visible(False)
plt.show()
# Note:
# Each generation can be recorded either after selection or after reproduction.
# Here it is recorded after reproduction, so the plots show the offspring of each generation.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment