-
-
Save marcociccone/b82bcd3e4d9aafae8946108306d0e183 to your computer and use it in GitHub Desktop.
Reward Augmented Maximum Likelihood (RAML; https://arxiv.org/pdf/1609.00150.pdf) -- Python code snippet that computes the marginal distribution over the number of edits for a given sequence length, temperature, and vocabulary size.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import scipy.misc as misc
import scipy.special as special
def raml_edit_marginal(len_target, vocab_size, temperature, max_edits=None):
    """Return the RAML marginal distribution over the number of edits.

    For each edit count ``n`` in ``range(max_edits)`` the unnormalized
    log-score is::

        log(count(n)) + n * log(vocab_size)
            - (n / temperature) * log(vocab_size) - n / temperature

    where ``count(n)`` sums, over every split of ``n`` into substitutions
    and insertions, the number of ways to place them (the ``vocab_size**n``
    factor is added separately in log space).

    Args:
        len_target: Length of the target sequence (non-negative int).
        vocab_size: Vocabulary size (positive).
        temperature: Temperature; must be strictly positive.
        max_edits: Number of edit counts to score, i.e. edit counts
            ``0 .. max_edits - 1``. Defaults to ``len_target``.

    Returns:
        A ``(log_scores, probs)`` tuple of 1-D arrays of length
        ``max_edits``: the unnormalized log-scores and the normalized
        probabilities (which sum to 1).

    Raises:
        ValueError: If ``temperature <= 0``, ``vocab_size < 1``,
            ``len_target < 0`` or ``max_edits < 1``.
    """
    if max_edits is None:
        max_edits = len_target
    # Coerce so that an integer temperature cannot trigger integer division.
    temperature = float(temperature)
    if temperature <= 0:
        raise ValueError("temperature must be > 0, got {}".format(temperature))
    if vocab_size < 1:
        raise ValueError("vocab_size must be >= 1, got {}".format(vocab_size))
    if len_target < 0:
        raise ValueError("len_target must be >= 0, got {}".format(len_target))
    if max_edits < 1:
        raise ValueError("max_edits must be >= 1, got {}".format(max_edits))

    log_vocab = np.log(vocab_size)
    log_scores = np.zeros(max_edits)
    for n_edits in range(max_edits):
        # Every way to split n_edits into substitutions and insertions.
        n_substitutes = np.arange(min(len_target, n_edits) + 1)
        n_insert = n_edits - n_substitutes
        # Total edits with n_edits edits, without the v^n_edits term.
        total_n_edits = np.sum(
            special.comb(len_target, n_substitutes, exact=False)
            * special.comb(len_target + n_insert - n_substitutes, n_insert,
                           exact=False)
        )
        # log(tot_edits * v^n_edits * exp(-n_edits / T) * v^(-n_edits / T))
        log_scores[n_edits] = (
            np.log(total_n_edits)
            + n_edits * log_vocab
            - n_edits / temperature * log_vocab
            - n_edits / temperature
        )

    # Subtract the max before exponentiating so large scores cannot overflow;
    # the shift cancels in the normalization.
    weights = np.exp(log_scores - log_scores.max())
    return log_scores, weights / np.sum(weights)


len_target = 20
v = 60  # Vocabulary size
T = .9  # Temperature
max_edits = len_target
# x: unnormalized log-scores per edit count; p: normalized marginal.
x, p = raml_edit_marginal(len_target, v, T, max_edits)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment