Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Created April 7, 2021 02:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ckrapu/e2fb8692972ec2b499a1494760ff626e to your computer and use it in GitHub Desktop.
Save ckrapu/e2fb8692972ec2b499a1494760ff626e to your computer and use it in GitHub Desktop.
import enum
import numpy as np
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
from tqdm import tqdm
class InfectionModel(Model):
    """A model for infection spread.

    ``N`` agents are placed at random cells of a toroidal ``MultiGrid`` and
    are activated in random order each step (``RandomActivation``). Each
    agent independently starts infected with probability
    ``p_infected_initial``.

    Args:
        N: Number of agents.
        width: Grid width.
        height: Grid height.
        ptrans: Transmission probability, read by the agents' contact step.
        death_rate: Death probability, read by the agents' status step.
        recovery_days: Mean of the normal distribution sampled by
            :meth:`get_recovery_time`.
        recovery_sd: Standard deviation of that distribution.
        p_infected_initial: Probability in [0, 1] that an agent starts
            infected.
        seed: Optional keyword-only seed for ``self.random``. Mesa's
            ``Model`` reads ``seed`` from the constructor keyword arguments
            when creating its RNG, so it only needs to be accepted here.

    Raises:
        ValueError: If ``p_infected_initial`` is outside [0, 1].
    """

    def __init__(self, N=10, width=10, height=10, ptrans=0.1,
                 death_rate=0.02, recovery_days=21,
                 recovery_sd=7, p_infected_initial=0.001, *, seed=None):
        # Let Mesa's Model initialise its own bookkeeping before agents exist.
        super().__init__()
        if not 0 <= p_infected_initial <= 1:
            raise ValueError(
                "p_infected_initial must be in [0, 1], "
                f"got {p_infected_initial!r}"
            )
        self.num_agents = N
        self.recovery_days = recovery_days
        self.recovery_sd = recovery_sd
        self.ptrans = ptrans
        self.death_rate = death_rate
        # Third argument is ``torus``: the grid wraps around at the edges.
        self.grid = MultiGrid(width, height, True)
        self.schedule = RandomActivation(self)
        self.running = True
        # List reserved for agents that have died.
        self.dead_agents = []
        # Create agents
        for i in range(self.num_agents):
            a = MyAgent(i, self)
            self.schedule.add(a)
            # Add the agent to a random grid cell
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(a, (x, y))
            # Seed some initial infections. Use the model's own RNG (not
            # NumPy's global one) so the model seed governs this draw.
            if self.random.random() < p_infected_initial:
                a.state = State.INFECTED
                a.recovery_time = self.get_recovery_time()
        self.datacollector = DataCollector(
            agent_reporters={"State": "state"})

    def get_recovery_time(self):
        """Draw a recovery time (in steps).

        The value is sampled from Normal(recovery_days, recovery_sd) and
        truncated toward zero by ``int``, so it may be zero or negative.
        """
        return int(self.random.normalvariate(self.recovery_days,
                                             self.recovery_sd))

    def step(self):
        """Record every agent's state, then advance all scheduled agents."""
        self.datacollector.collect(self)
        self.schedule.step()
# Compartments of the SIR model. The integer values double as indices into
# the last axis of per-cell count arrays, so they must stay 0, 1, 2.
State = enum.IntEnum(
    "State",
    [
        ("SUSCEPTIBLE", 0),
        ("INFECTED", 1),
        ("REMOVED", 2),
    ],
    module=__name__,
)
class MyAgent(Agent):
    """An agent in an epidemic model.

    Each step the agent updates its infection status, then, if still alive,
    random-walks to a neighbouring cell and, if infected, may infect
    susceptible cellmates. Agents that die are taken off the schedule but
    stay on the grid in the REMOVED state.
    """

    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        # NOTE(review): mean 20 with sd 40 yields many negative ages; the
        # arguments may be swapped — confirm. ``age`` is not read by any
        # code visible here.
        self.age = self.random.normalvariate(20, 40)
        self.state = State.SUSCEPTIBLE
        # Schedule time at which the agent became infected.
        self.infection_time = 0
        # Steps until recovery; assigned whenever the agent becomes infected.
        self.recovery_time = None

    def move(self):
        """Move to a random cell of the Moore neighbourhood (not staying put)."""
        possible_steps = self.model.grid.get_neighborhood(
            self.pos,
            moore=True,
            include_center=False)
        new_position = self.random.choice(possible_steps)
        self.model.grid.move_agent(self, new_position)

    def status(self):
        """Check infection status.

        An infected agent dies with probability ``model.death_rate``. A dead
        agent becomes REMOVED, is taken off the schedule (it stays on the
        grid) and is appended to ``model.dead_agents``. Otherwise an infected
        agent whose infection has lasted at least ``recovery_time`` steps
        recovers and becomes REMOVED.

        Returns:
            bool: False if the agent died during this call, True otherwise.
        """
        if self.state != State.INFECTED:
            return True
        # Use the agent/model RNG rather than NumPy's global one so that the
        # model seed governs every random draw.
        if self.random.random() < self.model.death_rate:
            self.state = State.REMOVED
            self.model.schedule.remove(self)
            self.model.dead_agents.append(self)
            return False
        t = self.model.schedule.time - self.infection_time
        if t >= self.recovery_time:
            self.state = State.REMOVED
        return True

    def contact(self):
        """Infect susceptible cellmates, each with probability ``model.ptrans``.

        Only an infected agent transmits. Newly infected agents record the
        current schedule time and draw their own recovery time.
        """
        if self.state != State.INFECTED:
            # Non-infected agents cannot transmit; skip the random draws.
            return
        cellmates = self.model.grid.get_cell_list_contents([self.pos])
        for other in cellmates:
            if other is self or other.state != State.SUSCEPTIBLE:
                continue
            if self.random.random() < self.model.ptrans:
                other.state = State.INFECTED
                other.infection_time = self.model.schedule.time
                other.recovery_time = self.model.get_recovery_time()

    def step(self):
        """Advance one tick: update status, then move and contact if alive."""
        if not self.status():
            # The agent died this tick: it no longer moves or interacts.
            return
        self.move()
        self.contact()
# Human-readable label for each integer State value, in State order.
state_dict = dict(enumerate(("Susceptible", "Infected", "Removed")))
def SIR(steps, model_kwargs=None):
    """Run an InfectionModel and record per-cell state counts after each step.

    Args:
        steps: Number of model steps to run.
        model_kwargs: Keyword arguments forwarded to ``InfectionModel``.
            ``None`` (the default) means no arguments.

    Returns:
        tuple: ``(model, grid_state)``. ``grid_state`` is a float array of
        shape ``(steps, width, height, len(State))``.
        ``grid_state[i, x, y, s]`` is the number of agents in state ``s`` in
        cell ``(x, y)`` after step ``i``. Counts cover every agent still on
        the grid, including ones no longer on the schedule.
    """
    # Avoid a shared mutable default argument.
    if model_kwargs is None:
        model_kwargs = {}
    model = InfectionModel(**model_kwargs)
    grid_state = np.zeros(
        (steps, model.grid.width, model.grid.height, len(State)))
    for i in tqdm(range(steps)):
        model.step()
        for agents, *coords in model.grid.coord_iter():
            # Mesa 1.x yields (contents, x, y); Mesa 2.x yields
            # (contents, (x, y)). Accept either shape.
            x, y = coords if len(coords) == 2 else coords[0]
            for a in agents:
                grid_state[i, x, y, a.state] += 1
    return model, grid_state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment