Skip to content

Instantly share code, notes, and snippets.

@norouzi
Last active June 18, 2019 14:28
Show Gist options
  • Save norouzi/6e65b253f7b143d339832d3b292cc3a7 to your computer and use it in GitHub Desktop.
Save norouzi/6e65b253f7b143d339832d3b292cc3a7 to your computer and use it in GitHub Desktop.
def ComputeHammingCDF(len_target, temprature, vocab):
    """Return the CDF over the number of substitutions made to a target.

    The probability of making ``n`` substitutions in a sequence with
    ``len_target`` substitutable positions is taken proportional to

        C(len_target, n) * (N - 1)**n * ((N - 1) * e)**(-n / temprature)

    where ``N = len(vocab)``.  The first two factors count the sequences
    that differ from the target in exactly ``n`` positions; the last factor
    is the temperature-dependent weight.  Everything is evaluated in log
    space and shifted by the maximum before exponentiating.

    Args:
        len_target: number of positions that may be substituted.
        temprature: sampling temperature (parameter spelling kept so that
            keyword callers keep working).  Must be non-zero.
        vocab: sized collection of symbols allowed for substitution; only
            ``len(vocab)`` is used here.

    Returns:
        numpy array of length ``len_target + 1`` whose entry ``n`` is the
        probability of making at most ``n`` substitutions.

    Raises:
        ZeroDivisionError: if ``temprature`` is zero.
    """
    import math  # function-scope import keeps the block self-contained

    max_edits = len_target + 1  # we allow between 0 and len_target subs
    inv_temperature = 1.0 / float(temprature)

    n_alternatives = len(vocab) - 1
    if n_alternatives < 1:
        # No replacement symbol exists, so only "zero substitutions" is
        # possible.  Without this guard log(0) = -inf and 0 * -inf = nan
        # would poison the whole distribution.
        return np.ones(max_edits)

    log_alternatives = np.log(n_alternatives)
    log_len_factorial = math.lgamma(len_target + 1)

    log_weights = np.zeros(max_edits)
    for n_subs in range(max_edits):
        # log C(len_target, n_subs) through log-gamma: this needs no SciPy
        # (scipy.misc.comb no longer exists) and cannot overflow.
        log_choose = (log_len_factorial
                      - math.lgamma(n_subs + 1)
                      - math.lgamma(len_target - n_subs + 1))
        # number of sequences: C(len_target, n_subs) * (N-1) ^ n_subs
        log_weights[n_subs] = log_choose + n_subs * log_alternatives
        # temperature weight: ((N-1)e) ^ (-n_subs / T)
        log_weights[n_subs] -= n_subs * inv_temperature * (log_alternatives + 1.0)

    # Subtract the maximum before exponentiating for numerical stability.
    p_subs = np.exp(log_weights - np.max(log_weights))
    p_subs /= np.sum(p_subs)
    return np.cumsum(p_subs)
# Hamming CDFs depend only on the target length, so they can be computed
# once up front and reused.  subs_cdf[k] is the CDF for a target with k
# substitutable positions; lengths 0..199 are covered (the maximum length
# is assumed to be 200).
subs_cdf = [
    ComputeHammingCDF(n_positions, temprature=augment_target_prob, vocab=vocab)
    for n_positions in range(200)
]
def SubstitutionSampling(s, temprature, hamming_cdf, vocab):
    '''
    Sample one sequence from the vicinity of a given target sequence s.
    A string t is sampled proportionally to exp{-hamming_distance(t, s) / temprature}

    The number of substitutions is drawn by inverting the precomputed CDF,
    that many distinct positions are chosen uniformly at random, and each
    chosen position receives a uniformly drawn vocabulary symbol different
    from the one it currently holds.  The last element of s is never
    substituted (presumably an end-of-sequence marker; confirm with callers).

    Args:
        s: numpy array of a sequence which is output of a seq2seq/crf model e.g. POS tag sequence
        temprature: temperature of sampling.  Not read by this function; the
            temperature is already baked into hamming_cdf.
        hamming_cdf: precomputed edit CDFs, indexed by len(s) - 1
        vocab: the vocabulary elements that are allowed for substitution
    Returns:
        numpy array of a sampled sequence t (s itself is left unmodified)
    Raises:
        ValueError: if a position has to be substituted but vocab holds no
            symbol different from the current one.
    '''
    assert(min(vocab) >= 0)
    # Only the first len(s) - 1 positions are candidates for substitution.
    len_target = len(s) - 1
    p_hamming_cdf = hamming_cdf[len_target]
    # Inverse-CDF sampling: count the CDF entries the uniform draw reaches.
    rand_n_subs = np.sum(np.random.rand() >= p_hamming_cdf)
    # Work on a copy so the caller's sequence stays intact.
    t = copy.copy(s)
    # A random prefix of a permutation gives distinct positions to change.
    perm = np.random.permutation(len_target)
    subs = perm[:rand_n_subs]
    for i in subs:
        # Without any alternative symbol the rejection loop below would
        # spin forever, so fail loudly instead.
        if all(t[i] == symbol for symbol in vocab):
            raise ValueError(
                'vocab contains no symbol different from %r; cannot substitute position %d'
                % (t[i], i))
        # Rejection sampling: redraw until the symbol actually changes.
        while True:
            rand_char = vocab[np.random.randint(len(vocab))]
            if not t[i] == rand_char:
                break
        t[i] = rand_char
    return t
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment