Skip to content

Instantly share code, notes, and snippets.

@windweller
Last active November 15, 2019 05:39
Show Gist options
  • Save windweller/7e3a55bb8d8a26fdf0064802a78f4183 to your computer and use it in GitHub Desktop.
Save windweller/7e3a55bb8d8a26fdf0064802a78f4183 to your computer and use it in GitHub Desktop.
Amortized RSA (Rational Speech Acts): listener models computed without fully materializing the S1 pragmatic speaker, which is not needed.
from collections import defaultdict
utterances = ["blue", "green", "square", "circle"]
objects = ['blue square', 'blue circle', 'green square']
def meaning(utt, obj):
    """Literal semantics: return 1 if `utt` occurs in `obj`, otherwise 0.

    With string arguments this is a substring test, e.g.
    meaning("blue", "blue square") == 1.
    """
    return 1 if utt in obj else 0
def normalize(space):
    """Rescale the values of `space` in place so that they sum to 1.

    Args:
        space: dict mapping keys to numeric weights.

    Returns:
        The same dict object (mutated), so calls can be chained.

    If the weights sum to zero (an empty dict, or only zero weights) there
    is no distribution to rescale to, so `space` is returned unchanged
    instead of raising ZeroDivisionError.
    """
    total = sum(space.values())
    if total == 0:
        return space
    total = float(total)  # force true division even for integer weights
    for key in space:
        space[key] /= total
    return space
def cond_normalize(conditional_model):
    """Normalize every inner distribution of `conditional_model` in place.

    `conditional_model` maps a conditioning key to a dict of weights; each
    inner dict is handed to normalize(). The same mapping is returned.
    """
    for distribution in conditional_model.values():
        normalize(distribution)
    return conditional_model
def prob_model():
    """Build the literal (L0) model from the module-level data.

    Returns a defaultdict shaped like {utt: {obj1: prob, obj2: prob, ...}}:
    for every utterance, meaning(utt, obj) over all objects, normalized per
    utterance by cond_normalize().
    """
    literal = defaultdict(dict)
    pairs = ((utt, obj) for utt in utterances for obj in objects)
    for utt, obj in pairs:
        literal[utt][obj] = meaning(utt, obj)
    return cond_normalize(literal)
def get_optimal(cond_model, utt):
    """Compute the next-level distribution over objects for one utterance.

    Args:
        cond_model: mapping {utt: {obj: prob}}, e.g. the result of
            prob_model() or get_optimal_model().
        utt: the utterance to evaluate; must be a key of `cond_model`.

    Returns:
        dict {obj: prob} over the module-level `objects`. Objects whose
        total mass across all utterances is zero are left out.

    For each object, the entry for `utt` is divided by that object's total
    mass over all utterances (Z_S); the resulting scores are then
    normalized over objects.
    """
    opt_model = {}
    # Complexity: |S| x |U|. The per-object denominator Z_S could be cached.
    for obj in objects:  # |S|
        # A previous get_optimal() pass drops objects with zero mass, so a
        # missing entry is treated as probability 0 rather than raising
        # KeyError when models are chained (L2, L3, ...).
        nom = cond_model[utt].get(obj, 0)
        denom = sum(obj_set.get(obj, 0) for obj_set in cond_model.values())  # |U|
        if denom > 0:
            opt_model[obj] = nom / float(denom)
    # Before this final normalize, opt_model holds the S1 scores.
    return normalize(opt_model)
def get_optimal_model(cond_model):
    """Recompute the whole model one level up, one utterance at a time.

    Returns a plain dict {utt: {obj: prob}} keyed by the module-level
    `utterances`, with the same layout as the result of prob_model().
    """
    return {utt: get_optimal(cond_model, utt) for utt in utterances}
# Literal model.
print("L0:")
literal_model = prob_model()
print(literal_model)  # L0

# Partial evaluation of the L1 model: one utterance at a time.
print("L1:")
for utt in ('blue', 'square', 'green', 'circle'):
    print(get_optimal(literal_model, utt))

# Full evaluation of the L1 model.
print("Full L1 model:")
level1_model = get_optimal_model(literal_model)
print(level1_model)

# Full evaluation of the L2 model (one more pass over L1).
print("Full L2 model:")
level2_model = get_optimal_model(level1_model)
print(level2_model)

# Full evaluation of the L3 model (one more pass over L2).
print("Full L3 model:")
print(get_optimal_model(level2_model))
# Iterating further is probably equivalent to adjusting the rationality parameter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment