"""
Check sampling from a sequential CRF model.
- Code is for n_states=2 but the strategy is general.
- TODO: cythonize or numbaize
- TODO: write a general impl for clarity
"""
# author: vlad niculae <vlad@vene.ro>
# license: MIT
from collections import Counter
import numpy as np
from scipy.special import logsumexp, softmax
import itertools
import matplotlib.pyplot as plt


def iter_configs(n):
    """Enumerate all binary configurations of length n."""
    return itertools.product(*((0, 1) for _ in range(n)))
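# e.g. list(iter_configs(2)) == [(0, 0), (0, 1), (1, 0), (1, 1)]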


def score(config, eta_u, eta_v):
    """Unnormalized log-score of a binary config: <eta_u, y> + <eta_v, y[:-1] * y[1:]>."""
    config = np.array(config)
    unary_score = np.dot(eta_u, config)
    pairwise_marg = config[:-1] * config[1:]
    pairwise_score = np.dot(pairwise_marg, eta_v)
    return unary_score + pairwise_score
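# Worked example with the values used in main() below:
# for config (1, 1, 0), eta_u = [-3, -2, -1], eta_v = [2, -1]:
#   unary    = -3 * 1 + -2 * 1 + -1 * 0 = -5
#   pairwise =  2 * (1 * 1) + -1 * (1 * 0) =  2
#   score    = -5 + 2 = -3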


def exhaustive_softmax(eta_u, eta_v):
    """Compute logZ and softmax-p exhaustively by trying all configs"""
    n = eta_u.shape[0]
    all_cfgs = list(iter_configs(n))
    all_scores = [score(cfg, eta_u, eta_v) for cfg in all_cfgs]
    all_scores = np.array(all_scores)
    logZ = logsumexp(all_scores)
    p = softmax(all_scores)
    return logZ, p
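# In other words: p(y) = exp(score(y)) / Z with Z = sum_y exp(score(y)),
# the sum running over all 2**n binary configurations.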


def exhaustive_sample(eta_u, eta_v, size=10, seed=None):
    """Sample from the full categorical distribution over sequences.

    Returns indices into list(iter_configs(n)).
    """
    rng = np.random.RandomState(seed)
    n = eta_u.shape[0]
    all_cfgs = list(iter_configs(n))
    _, p = exhaustive_softmax(eta_u, eta_v)
    ix = rng.choice(len(all_cfgs), p=p, size=size)
    return ix
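# Usage sketch (illustration only, not exercised by main() below):
#
#   cfgs = list(iter_configs(3))
#   ix = exhaustive_sample(np.array([-3., -2., -1.]), np.array([2., -1.]), size=5)
#   sampled_cfgs = [cfgs[i] for i in ix]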


def forward(eta_u, eta_v):
    """Forward algorithm for the sequence CRF; returns logZ."""
    n = len(eta_u)
    n_states = 2
    alpha = np.zeros((len(eta_u), n_states))
    # state 0 has zero unary potential; state 1 at position j scores eta_u[j]
    alpha[:, 1] = eta_u
    for j in range(1, n):
        alpha[j, 0] += logsumexp(alpha[j - 1])
        # the pairwise potential eta_v[j - 1] fires only for the (1, 1) transition
        alpha[j, 1] += logsumexp(alpha[j - 1] + np.array([0, eta_v[j - 1]]))
    return logsumexp(alpha[-1])
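# Written out, the recursion above is (everything in log space):
#   alpha[j, 0] = logsumexp(alpha[j - 1, :])
#   alpha[j, 1] = eta_u[j] + logsumexp(alpha[j - 1, :] + [0, eta_v[j - 1]])
# and logZ = logsumexp(alpha[n - 1, :]).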


def forward_sample(eta_u, eta_v, size=10, seed=None):
    """Sample from a linear-chain CRF.

    The algorithm is analogous to a softmax-Viterbi: in the "forward"
    direction we compute logsumexp instead of max, and the backpointer is
    a softargmax (a distribution) instead of an argmax.
    """
    rng = np.random.RandomState(seed)
    n = len(eta_u)
    n_states = 2
    alpha = np.zeros((n, n_states))
    # backp[j, z] is the distribution of the state at position j - 1,
    # conditioned on the state at position j being z.
    backp = np.zeros((n, n_states, n_states))
    alpha[:, 1] = eta_u
    for j in range(1, n):
        alpha[j, 0] += logsumexp(alpha[j - 1])
        backp[j, 0] = softmax(alpha[j - 1])
        alpha[j, 1] += logsumexp(alpha[j - 1] + np.array([0, eta_v[j - 1]]))
        backp[j, 1] = softmax(alpha[j - 1] + np.array([0, eta_v[j - 1]]))
    # sample the last state from the final forward distribution, then walk
    # backward through the stored backpointer distributions
    final = softmax(alpha[-1])
    seqs = []
    for _ in range(size):
        seq = []
        z = rng.choice(2, p=final)
        seq.append(z)
        for j in range(n - 1, 0, -1):
            z = rng.choice(2, p=backp[j, z])
            seq.append(z)
        seq.reverse()
        seqs.append(seq)
    return seqs
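

# The module docstring's TODO asks for a general implementation. Below is a
# sketch (an addition, not part of the original gist) of the same
# forward-filter / backward-sampling pass for an arbitrary number of states.
# Here `unary` has shape (n, n_states) and `pairwise` has shape
# (n - 1, n_states, n_states), with pairwise[j, a, b] scoring the transition
# from state a at position j to state b at position j + 1; the n_states=2
# code above corresponds to unary[:, 0] = 0, unary[:, 1] = eta_u and
# pairwise[j, 1, 1] = eta_v[j] (all other entries zero).
def forward_sample_general(unary, pairwise, size=10, seed=None):
    rng = np.random.RandomState(seed)
    n, n_states = unary.shape
    alpha = np.zeros((n, n_states))
    alpha[0] = unary[0]
    backp = np.zeros((n, n_states, n_states))
    for j in range(1, n):
        # scores[a, b] = alpha[j - 1, a] + pairwise[j - 1, a, b]
        scores = alpha[j - 1][:, np.newaxis] + pairwise[j - 1]
        alpha[j] = unary[j] + logsumexp(scores, axis=0)
        # backp[j, b] is the distribution over the previous state a,
        # conditioned on position j being in state b
        backp[j] = softmax(scores, axis=0).T
    seqs = []
    for _ in range(size):
        seq = [rng.choice(n_states, p=softmax(alpha[-1]))]
        for j in range(n - 1, 0, -1):
            seq.append(rng.choice(n_states, p=backp[j, seq[-1]]))
        seq.reverse()
        seqs.append(seq)
    return seqs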


def bar_hist(ax, N, samples):
    """Helper: plot the empirical distribution of `samples` over {0, ..., N - 1}."""
    bins = Counter(samples)
    x = np.arange(N)
    height = np.array([bins[i] for i in x], dtype=np.double)
    height /= height.sum()
    print("sampled:", height.round(3))
    ax.bar(x, height)


def main():
    n = 3
    eta_u = np.array([-3, -2, -1])
    # eta_u = np.array([0, -100, 0])
    eta_v = np.array([2, -1])

    logZ, p = exhaustive_softmax(eta_u, eta_v)
    print("check log partition function")
    print(logZ)
    print(forward(eta_u, eta_v))

    print("check sampling")
    samples = forward_sample(eta_u, eta_v, size=50000)
    seq_to_ix = {seq: k for k, seq in enumerate(iter_configs(n))}
    sample_ix = [seq_to_ix[tuple(sample)] for sample in samples]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    N = len(p)
    ax1.bar(x=np.arange(N), height=p)
    print(" true:", p.round(3))
    bar_hist(ax2, N, sample_ix)
    ax1.set_title("True probabilities over sequences")
    ax2.set_title("Sample mean")
    lbls = ["".join(str(x) for x in cfg) for cfg in iter_configs(n)]
    for ax in (ax1, ax2):
        ax.set_xticks(np.arange(N))
        ax.set_xticklabels(lbls, rotation=45)
    fig.savefig("sample_seq_crf.png")


if __name__ == '__main__':
    main()