Created
January 20, 2020 12:05
-
-
Save vene/79ddddd3d6bb3e1f088c94bce55a88f4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Check sampling from a sequential (linear-chain) CRF model.
- The code is written for n_states=2, but the strategy is general.
- TODO: cythonize or numbaize.
- TODO: write a general implementation for clarity.
""" | |
# author: vlad niculae <vlad@vene.ro> | |
# license: MIT | |
from collections import Counter | |
import numpy as np | |
from scipy.special import logsumexp, softmax | |
import itertools | |
import matplotlib.pyplot as plt | |
def iter_configs(n):
    """Return an iterator over every binary configuration of length ``n``.

    Configurations are tuples of 0/1 produced in lexicographic order, from
    ``(0, ..., 0)`` to ``(1, ..., 1)``. The k-th configuration is therefore
    the binary expansion of k, with position 0 as the most significant bit.
    """
    binary_choices = [(0, 1)] * n
    return itertools.product(*binary_choices)
def score(config, eta_u, eta_v):
    """Return the (unnormalized) CRF score of one binary configuration.

    The score is the sum of ``eta_u[i]`` over positions labelled 1, plus the
    sum of ``eta_v[i]`` over adjacent pairs ``(i, i + 1)`` that are both
    labelled 1.
    """
    y = np.array(config)
    # Indicator of "both neighbours are on" for each adjacent pair.
    both_on = y[:-1] * y[1:]
    return np.dot(eta_u, y) + np.dot(both_on, eta_v)
def exhaustive_softmax(eta_u, eta_v):
    """Brute-force the log-partition function and the full distribution.

    Scores every binary configuration of length ``eta_u.shape[0]``, in the
    order produced by ``iter_configs``.

    Returns
    -------
    logZ : log-sum-exp of all configuration scores.
    p : softmax of the configuration scores (one entry per configuration).
    """
    length = eta_u.shape[0]
    scores = np.array(
        [score(cfg, eta_u, eta_v) for cfg in list(iter_configs(length))]
    )
    return logsumexp(scores), softmax(scores)
def exhaustive_sample(eta_u, eta_v, size=10, seed=None):
    """Sample configuration indices from the full categorical distribution.

    Builds the exact distribution over all binary configurations with
    ``exhaustive_softmax`` and draws ``size`` indices from it.

    Parameters
    ----------
    eta_u, eta_v : unary and pairwise scores, as accepted by
        ``exhaustive_softmax``.
    size : number of draws.
    seed : seed for ``np.random.RandomState``.

    Returns
    -------
    ndarray of int
        Indices into the enumeration order of ``iter_configs(len(eta_u))``.
        These are indices, not the configurations themselves.
    """
    rng = np.random.RandomState(seed)
    _, p = exhaustive_softmax(eta_u, eta_v)
    # ``p`` already has one entry per configuration, so there is no need to
    # enumerate all 2 ** n configurations a second time just to count them.
    return rng.choice(len(p), p=p, size=size)
def forward(eta_u, eta_v):
    """Compute log Z of a binary linear-chain CRF with the forward algorithm.

    A configuration y in {0, 1}^n scores
    ``sum_i eta_u[i] * y[i] + sum_i eta_v[i] * y[i] * y[i + 1]``.

    Parameters
    ----------
    eta_u : sequence of length n
        Unary score of state 1 at each position (state 0 scores 0).
    eta_v : sequence of length n - 1
        Bonus for adjacent positions (i, i + 1) that are both in state 1.

    Returns
    -------
    float
        Log of the sum of exp(score) over all 2 ** n configurations. For
        n == 0 the only configuration is the empty one (score 0), so 0.0 is
        returned. This agrees with the exhaustive computation instead of
        raising IndexError.
    """
    n = len(eta_u)
    if n == 0:
        return 0.0
    n_states = 2
    # alpha[j, s]: log-sum-exp of scores of all prefixes y[:j + 1] with y[j] == s.
    alpha = np.zeros((n, n_states))
    alpha[:, 1] = eta_u
    for j in range(1, n):
        # Entering state 0 carries no pairwise bonus from either predecessor.
        alpha[j, 0] += logsumexp(alpha[j - 1])
        # Entering state 1 earns eta_v[j - 1] only when coming from state 1.
        alpha[j, 1] += logsumexp(alpha[j - 1] + np.array([0, eta_v[j - 1]]))
    return logsumexp(alpha[-1])
def forward_sample(eta_u, eta_v, size=10, seed=None):
    """Sample from a binary linear-chain CRF (forward filtering, backward sampling).

    The algorithm is analogous to a softmax-Viterbi: in the "forward"
    direction we compute logsumexp instead of max, and the back-pointer is a
    softargmax (a conditional distribution) instead of an argmax.

    Parameters
    ----------
    eta_u : sequence of length n
        Unary score of state 1 at each position (state 0 scores 0).
    eta_v : sequence of length n - 1
        Bonus for adjacent positions that are both in state 1.
    size : number of sequences to draw.
    seed : seed for ``np.random.RandomState``.

    Returns
    -------
    list of lists
        ``size`` sequences, each a list of n states in {0, 1}. For n == 0,
        each sequence is empty (the only configuration); previously this
        case raised IndexError.
    """
    rng = np.random.RandomState(seed)
    n = len(eta_u)
    if n == 0:
        return [[] for _ in range(size)]
    n_states = 2
    alpha = np.zeros((n, n_states))
    # backp[j, s] is the distribution of y[j - 1] given y[j] == s; row 0 is unused.
    backp = np.zeros((n, n_states, n_states))
    alpha[:, 1] = eta_u
    for j in range(1, n):
        # Incoming scores into state 0: no pairwise bonus from either state.
        into_zero = alpha[j - 1]
        # Incoming scores into state 1: bonus eta_v[j - 1] only from state 1.
        into_one = alpha[j - 1] + np.array([0, eta_v[j - 1]])
        alpha[j, 0] += logsumexp(into_zero)
        backp[j, 0] = softmax(into_zero)
        alpha[j, 1] += logsumexp(into_one)
        backp[j, 1] = softmax(into_one)
    final = softmax(alpha[-1])

    seqs = []
    for _ in range(size):
        # Draw the last state from the marginal, then walk backwards.
        z = rng.choice(n_states, p=final)
        seq = [z]
        for j in range(n - 1, 0, -1):
            z = rng.choice(n_states, p=backp[j, z])
            seq.append(z)
        seq.reverse()
        seqs.append(seq)
    return seqs
def bar_hist(ax, N, samples):
    """Plot the empirical frequency of each integer in range(N) on ``ax``.

    Counts how often each value 0..N-1 occurs in ``samples``, normalizes those
    counts to sum to one, prints the rounded frequencies, and draws them with
    ``ax.bar``. Values outside 0..N-1 are neither plotted nor included in the
    normalization.
    """
    counts = Counter(samples)
    positions = np.arange(N)
    freqs = np.array([counts[k] for k in positions], dtype=np.double)
    freqs = freqs / freqs.sum()
    print("sampled:", freqs.round(3))
    ax.bar(positions, freqs)
def main():
    """Check the forward algorithm and forward sampling against brute force.

    Prints the exhaustive and forward-algorithm log partition functions. Then
    draws 50000 sequences with ``forward_sample`` and plots their empirical
    frequencies next to the exact probabilities. The figure is saved to
    "sample_seq_crf.png".
    """
    eta_u = np.array([-3, -2, -1])
    # Alternative unary scores worth trying: np.array([0, -100, 0])
    eta_v = np.array([2, -1])
    # Derive the chain length from the scores so the two can never disagree.
    n = len(eta_u)
    logZ, p = exhaustive_softmax(eta_u, eta_v)
    print("check log partition function")
    print(logZ)
    print(forward(eta_u, eta_v))
    print("check sampling")
    samples = forward_sample(eta_u, eta_v, size=50000)
    # Enumerate configurations once and reuse them for indexing and labels.
    configs = list(iter_configs(n))
    seq_to_ix = {seq: k for k, seq in enumerate(configs)}
    sample_ix = [seq_to_ix[tuple(sample)] for sample in samples]
    fig, (ax1, ax2) = plt.subplots(1, 2)
    N = len(p)
    ax1.bar(x=np.arange(N), height=p)
    print(" true:", p.round(3))
    bar_hist(ax2, N, sample_ix)
    ax1.set_title("True probabilities over sequences")
    ax2.set_title("Sample mean")
    lbls = ["".join(str(x) for x in cfg) for cfg in configs]
    for ax in (ax1, ax2):
        ax.set_xticks(np.arange(N))
        ax.set_xticklabels(lbls, rotation=45)
    fig.savefig("sample_seq_crf.png")
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)
if __name__ == "__main__":
    # Run the full check (prints and plot) only when executed as a script.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment