Skip to content

Instantly share code, notes, and snippets.

Created November 4, 2021 16:05
Show Gist options
  • Save dmasad/e080016635dc8ca914ec6439d97287ea to your computer and use it in GitHub Desktop.
Save dmasad/e080016635dc8ca914ec6439d97287ea to your computer and use it in GitHub Desktop.
Mesa spatial SIR model with visualization server
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import MultiGrid
from mesa.datacollection import DataCollector
from mesa.visualization.ModularVisualization import ModularServer
from mesa.visualization.modules import CanvasGrid, ChartModule
from mesa.visualization.UserParam import UserSettableParameter
# Model
# ===============================================================
class SpatialSIRModel(Model):
    """Spatial SIR (Susceptible / Infected / Recovered) model on a 20x20 grid.

    Agents are scattered at random over a wrap-around ``MultiGrid`` and
    activated in random order each step.  The model stops running once no
    agent is infected any more.
    """

    def __init__(self, n_agents, initial_infected=1, infection_radius=1):
        """Build the grid, the agents and the data collector.

        Args:
            n_agents: Number of agents to create.
            initial_infected: How many agents start out infected.  Values
                larger than ``n_agents`` are clamped to ``n_agents``.
            infection_radius: Stored as ``self.infection_radius`` for the
                agents to read when looking for neighbours to infect.
        """
        super().__init__()
        self.schedule = RandomActivation(self)
        self.grid = MultiGrid(20, 20, torus=True)
        self.infection_radius = infection_radius
        # String reporters are looked up as attributes, which is why the
        # counts below are exposed as properties rather than plain methods.
        self.datacollector = DataCollector(
            model_reporters={
                "Susceptible": 'susceptible',
                "Infected": 'infected',
                "Recovered": 'recovered',
            },
            agent_reporters={
                "coordinates": "pos",
                "status": "status",
            },
        )

        # Create agents
        for i in range(n_agents):
            a = SpatialSIRAgent(i, self)
            # Without this the scheduler (and therefore ``self.agents``)
            # would stay empty and nobody would ever be stepped.
            self.schedule.add(a)
            # Place the agent somewhere at random on the grid
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(a, (x, y))

        # Seed the epidemic with the first few agents created.
        population = self.agents
        for i in range(min(initial_infected, len(population))):
            population[i].status = "Infected"

        self.running = True

    def _count_status(self, status):
        """Return how many scheduled agents currently have ``status``."""
        return sum(1 for agent in self.agents if agent.status == status)

    @property
    def susceptible(self):
        """Number of agents whose status is ``"Susceptible"``."""
        return self._count_status('Susceptible')

    @property
    def infected(self):
        """Number of agents whose status is ``"Infected"``."""
        return self._count_status('Infected')

    @property
    def recovered(self):
        """Number of agents whose status is ``"Recovered"``."""
        return self._count_status('Recovered')

    @property
    def agents(self):
        """The agents held by the scheduler."""
        return self.schedule.agents

    def step(self):
        """Record the current state, advance every agent, then test for the end."""
        self.datacollector.collect(self)
        self.schedule.step()
        # Check end condition
        if self.infected == 0:
            self.running = False
class SpatialSIRAgent(Agent):
    """A grid-walking agent that is Susceptible, Infected or Recovered."""

    # NOTE(review): the source shows two identical ``random() < 0.25`` checks
    # with the statement between them missing; they are reconstructed here as
    # a transmission check and a recovery check -- confirm against upstream.
    INFECTION_CHANCE = 0.25
    RECOVERY_CHANCE = 0.25

    def __init__(self, unique_id, model):
        """Create a susceptible agent registered with ``model``."""
        super().__init__(unique_id, model)
        self.status = "Susceptible"

    def move(self):
        """Move to a randomly chosen adjacent cell (diagonals included)."""
        # We can use self.model.grid.torus_adj, or:
        possible_steps = self.model.grid.get_neighborhood(
            self.pos, moore=True, include_center=False)
        new_pos = self.random.choice(possible_steps)
        self.model.grid.move_agent(self, new_pos)

    def infect(self):
        """Infect every susceptible agent within the model's infection radius."""
        exposed = self.model.grid.get_neighbors(
            self.pos, moore=True, radius=self.model.infection_radius)
        for agent in exposed:
            if agent.status == "Susceptible":
                agent.status = "Infected"

    def step(self):
        """Move, then (if infected) possibly spread the infection and recover."""
        self.move()
        if self.status == "Infected":
            if self.random.random() < self.INFECTION_CHANCE:
                self.infect()
            if self.random.random() < self.RECOVERY_CHANCE:
                self.status = "Recovered"
# Server
# ===============================================================
# Display colour for each agent status; the keys match the status strings
# the agents carry.
COLORS = {"Susceptible": "blue", "Infected": "red", "Recovered": "green"}


def sir_model_portrayal(cell):
    """Describe how one grid occupant should be drawn.

    Args:
        cell: An object exposing ``pos`` (an ``(x, y)`` pair) and ``status``
            (a key of ``COLORS``), or ``None`` when there is nothing to draw.

    Returns:
        A portrayal dict (filled circle, coloured by status, positioned at
        ``cell.pos``), or ``None`` when ``cell`` is ``None``.
    """
    # Guard first: without the early return the code below would try to
    # read ``None.pos``.
    if cell is None:
        return None

    portrayal = {"Shape": "circle",
                 "r": 0.9,
                 "Filled": "true",
                 "Layer": 0}
    (x, y) = cell.pos
    portrayal["x"] = x
    portrayal["y"] = y
    portrayal["Color"] = COLORS[cell.status]
    return portrayal
if __name__ == "__main__":
    # Grid view: 20x20 cells drawn on a 500x500 pixel canvas.
    grid_element = CanvasGrid(sir_model_portrayal, 20, 20, 500, 500)
    # One chart series per status, drawn in the same colour as on the grid.
    chart_element = ChartModule(
        [{"Label": label, "Color": color} for label, color in COLORS.items()]
    )
    model_params = {
        "n_agents": 100,
        "initial_infected": 10,
        # Slider arguments: initial value 1, range 1..10, step 1.
        "infection_radius": UserSettableParameter(
            "slider", "Infection Radius", 1, 1, 10, 1
        ),
    }
    server = ModularServer(
        SpatialSIRModel,
        [grid_element, chart_element],
        "Spatial SIR Model",
        model_params,
    )
    # Building the server alone does nothing; it has to be launched to serve
    # the visualization.
    server.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment