Skip to content

Instantly share code, notes, and snippets.

@KennethEnevoldsen
Created August 10, 2021 11:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KennethEnevoldsen/ecb881e81006aeeb1a269181c41c08d1 to your computer and use it in GitHub Desktop.
ktom_fit_pymc3.py
"""
This script fits the k-ToM model to data using pymc3
"""
import pymc3 as pm
import tomsup as ts
# Simulate a small dataset: a 1-ToM and a 2-ToM agent play the
# "penny_competitive" game for 30 rounds in a round-robin, keeping the history.
group = ts.create_agents(["1-ToM", "2-ToM"])
penny = ts.PayoffMatrix("penny_competitive")
results = group.compete(
    p_matrix=penny,
    n_rounds=30,
    env="round_robin",
    save_history=True,
)

# Inputs for the fit.
trials = len(results)  # number of rows in the results table
op_choices = results.choice_agent1  # opponent choices (agent1 column)
agent_choices = results.choice_agent0  # choices of the fitted agent (agent0 column)
p_matrix = penny
# Role index (side of the payoff matrix) of the fitted agent.
# NOTE(review): this name is rebound to the agent object inside the model block.
agent = 0
levels = [0, 4]  # lower/upper bound used for the sophistication-level prior
def initialize_agent(level: int, volatility: float, b_temp: float, bias: float, dilution: float):
    """Create a k-ToM agent from the given parameters.

    Args:
        level: Sophistication level k of the agent.
        volatility: Volatility; log-transformed with ``ts.log`` before being
            passed to ``ts.agent.TOM``.
        b_temp: Behavioural temperature; log-transformed with ``ts.log``
            before being passed on.
        bias: Choice bias, passed through unchanged.
        dilution: Dilution; transformed with ``ts.inv_logit`` before being
            passed on.

    Returns:
        The constructed ``ts.agent.TOM`` instance.
    """
    # Use this function's own arguments. Reading the module-level names
    # ``k``, ``sigma``, ``b`` and ``delta`` (as before) silently ignored
    # every parameter except ``b_temp``.
    # NOTE(review): confirm tomsup exposes ``log`` / ``inv_logit`` at top level,
    # and that TOM expects dilution on the scale ``inv_logit`` produces (for a
    # value already in (0, 1) a logit transform may be what is intended).
    return ts.agent.TOM(
        level=level,
        volatility=ts.log(volatility),
        b_temp=ts.log(b_temp),
        bias=bias,
        dilution=ts.inv_logit(dilution),
    )
def agent_compete(prev_choice: int, choice_op: int, agent: ts.Agent, agent_role: int = 0):
    """Advance ``agent`` by one round and return its probability of choosing 1.

    Args:
        prev_choice: The agent's own choice in the previous round
            (None on the first round).
        choice_op: Opponent choice forwarded to ``agent.compete`` as
            ``op_choice``.
        agent: The agent object to update in place.
        agent_role: Side of the payoff matrix the agent plays (0 or 1),
            forwarded to ``agent.compete`` as its ``agent`` argument.
            Defaults to 0, the role index set at module level.

    Returns:
        The ``"p_self"`` entry of ``agent.get_internal_states()``.
    """
    # Force the previous choice to the observed one so the k-ToM update is
    # computed from it rather than from the agent's own simulated choice.
    agent.choice = prev_choice
    # ``compete``'s ``agent`` argument is the role index; passing the agent
    # object itself (as ``agent=agent`` did) was a bug caused by the parameter
    # shadowing the module-level role ``agent = 0``.
    agent.compete(op_choice=choice_op, p_matrix=p_matrix, agent=agent_role)
    # Extract the probability of choosing 1.
    internal = agent.get_internal_states()
    # NOTE(review): confirm the key layout; in tomsup this value may be nested
    # (e.g. under "own_states") rather than stored at the top level.
    return internal["p_self"]
with pm.Model() as ktom_fitting:
    # Some very simple priors.
    k = pm.distributions.discrete.DiscreteUniform("level", lower=levels[0], upper=levels[1])
    sigma = pm.Uniform("sigma", 0, 10)
    beta = pm.Uniform("beta", 0, 10)
    b = pm.Normal("bias", 0, 1)
    delta = pm.Uniform("delta", 0, 1)

    # NOTE(review): k, sigma, beta, b and delta are symbolic pymc3 random
    # variables, not numbers. tomsup updates its agents numerically, so this
    # construction is unlikely to work as-is; a working fit probably needs the
    # k-ToM likelihood wrapped as a Theano Op (black-box likelihood added via
    # pm.Potential). TODO confirm.
    agent = initialize_agent(k, sigma, beta, b, delta)

    # On each trial the agent receives its own and the opponent's choices from
    # the *previous* trial (neither exists on the first trial). Feeding the
    # current trial's opponent choice would reveal information the agent
    # cannot have yet.
    prev_choice = None
    prev_op_choice = None
    for trial, (c_op, c_self) in enumerate(zip(op_choices, agent_choices)):
        p = agent_compete(prev_choice, choice_op=prev_op_choice, agent=agent)
        # Likelihood of this single trial's observed 0/1 choice. Each trial
        # needs a unique RV name, because pymc3 rejects duplicate names.
        pm.Bernoulli(f"choice_{trial}", p=p, observed=c_self)
        # Carry the observed values (not the RV) into the next trial.
        prev_choice = int(c_self)
        prev_op_choice = int(c_op)

    trace_h = pm.sample(1000, return_inferencedata=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment