@gajomi
Created November 18, 2020 18:19
Baseline loss functions for first order markov chains
"""
Quick tutorial on comparing loss function for a fitted markov chain model to `baselines`.
Illustrates the effect of randomness in data.
Shows principle behind following gradients of loss to optimize the parameters
"""
import torch
from torch import tensor
from torch import nn
from torch.nn import Softmax
#some utility functions
softmax = Softmax(dim = 0)
def _rand_choice(probs):
    """Draw a single index using `probs` as (possibly unnormalized) sampling weights"""
    return torch.multinomial(probs, 1, replacement=True)[0]
def _one_hot(k, n):
    """Length-n one-hot vector with a 1. at index k"""
    x = torch.zeros(n)
    x[k] = 1.
    return x
def _one_hot3s(k):
    return _one_hot(k, 3)
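# For example (illustration only): _one_hot(1, 3) -> tensor([0., 1., 0.]), and
# _rand_choice(tensor([0.2, 0.5, 0.3])) returns 0, 1 or 2 with those probabilities.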
#neural network representation of markov chain step
class LogReg(nn.Module):
    def __init__(self, input_size):
        super(LogReg, self).__init__()
        self.linear = nn.Linear(input_size, input_size)
    def forward(self, x):
        return self.linear(x)
# translate the neural network representation of a model to transition probabilities
def transition_probs_from_linear(model, x=None):
    n = model.linear.bias.shape[0]
    if x is None:
        # stack the columns for every source state to recover the full transition matrix
        return torch.stack([transition_probs_from_linear(model, x=i) for i in range(n)]).t()
    else:
        x = _one_hot(x, n) if isinstance(x, int) else x
        return softmax(model(x))
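# With a one-hot input e_i the linear layer returns W[:, i] + b, so column i of the implied
# transition matrix is softmax(W[:, i] + b); calling the function with x=None rebuilds the
# whole matrix, one column per source state.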
#classical markov chain representation sampling
def markov_chain_sample(Q, T, x0):
    """Sample T steps from markov chain with transition kernel Q, starting from x0"""
    trajectory = torch.empty(T, dtype=torch.long)
    trajectory[0] = x0
    for t in range(1, T):
        probs = Q[:, trajectory[t-1]]
        trajectory[t] = _rand_choice(probs)
    return trajectory
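# Convention used here: Q[:, prev] is the distribution over the next state. For example
# (illustration only), markov_chain_sample(torch.eye(3), 5, 0) stays in state 0 for all
# 5 steps, since the identity kernel never leaves the current state.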
# neural network representations of the models of interest
def stupidest_model(nstates = 3):
    """Guesses that all transitions are equally likely for all states"""
    model = LogReg(nstates)
    with torch.no_grad():
        model.linear.bias.fill_(0)
        model.linear.weight.fill_(0)
    return model
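# With zero weights and bias every logit is 0, so softmax gives 1/nstates everywhere and the
# cross-entropy loss of this model is exactly log(nstates) per step, regardless of the data.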
def stupid_model(trajectory, nstates = 3):
    """Memoryless model that captures the frequency of states"""
    model = LogReg(nstates)
    probs = torch.bincount(trajectory).float()/len(trajectory)
    log_probs = probs.log()
    bias = log_probs - log_probs.mean()  # subtract the mean for numerical stability
    with torch.no_grad():
        model.linear.weight.fill_(0.)
        model.linear.bias = nn.Parameter(bias)
    return model
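# Because the weights are zero the prediction ignores the previous state: every column of the
# implied transition matrix equals softmax(bias), i.e. the empirical state frequencies.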
def exact_model_from_transitions(transitions):
    """Model whose implied transition matrix reproduces the given kernel (up to column normalization)"""
    nstates = transitions.shape[0]
    model = LogReg(nstates)
    # one way to parameterize is to just ignore the bias term and set non-zero weights
    log_probs = transitions.log()
    weight = log_probs - log_probs.mean(dim=0)
    with torch.no_grad():
        model.linear.weight = nn.Parameter(weight)
        model.linear.bias.fill_(0.)
    return model
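# Subtracting the per-column mean of the log-probabilities is harmless because softmax is
# invariant to adding a constant to all logits; it just keeps the weights centered near zero.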
def eval_loss(model, trajectory):
    """Convenience function to build inputs/targets from a trajectory and evaluate the loss"""
    n = model.linear.bias.shape[0]
    features = torch.stack([_one_hot(k, n) for k in trajectory[:-1]])
    targets = trajectory[1:]
    loss = nn.CrossEntropyLoss()
    return loss(model(features), targets)
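# nn.CrossEntropyLoss takes raw logits and averages over the T-1 transitions, so the value
# returned above is the mean per-step negative log-likelihood in nats.
#
# Not part of the original gist: a small sketch of the theoretical values those losses can be
# compared against, assuming the column convention used throughout (Q[:, prev] is the
# next-state distribution) and strictly positive entries. Columns are renormalized first,
# since torch.multinomial treats them as unnormalized weights. The name and signature below
# are illustrative only.
def _theoretical_baselines(Q, iters=1000):
    """Return (uniform-model loss, memoryless-model loss, entropy rate) in nats"""
    Qn = Q / Q.sum(dim=0, keepdim=True)  # make each column a proper distribution
    n = Qn.shape[0]
    pi = torch.full((n,), 1.0/n)  # power iteration for the stationary distribution
    for _ in range(iters):
        pi = Qn @ pi
        pi = pi / pi.sum()
    uniform_loss = torch.log(torch.tensor(float(n)))  # stupidest model: log(n) per step
    memoryless_loss = -(pi * pi.log()).sum()  # stupid model: roughly the entropy of pi
    entropy_rate = -(pi * (Qn * Qn.log()).sum(dim=0)).sum()  # exact model: entropy rate
    return uniform_loss, memoryless_loss, entropy_rate
# e.g. printing _theoretical_baselines(Q) next to the empirical means in main() shows how close
# each model gets to its long-run limit.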
def main():
    # timesteps and initial state
    T = 2**10
    x0 = tensor(2)
    ## initialize markov chain (torch.multinomial treats each column as unnormalized weights)
    Q = tensor([
        [0.005, 0.845, 0.150],
        [0.845, 0.005, 0.150],
        [0.700, 0.100, 0.100]
    ]).t()
    exact_model = exact_model_from_transitions(Q)
    print('This is the exact transition matrix')
    print(Q)
    print('\n')
    # sample trajectories
    replicas = 2**6
    trajectories = [markov_chain_sample(Q, T, x0) for r in range(replicas)]
    trajectory = trajectories[0]
    ## create the stupidest possible model
    stupidest = stupidest_model()
    Q_stupidest = transition_probs_from_linear(stupidest)
    print('This is the simplest possible transition matrix')
    print(Q_stupidest)
    print('\n')
    ## create the merely stupid model; its transition probabilities differ from uniform but are identical across columns
    stupid = stupid_model(trajectory)
    Q_stupid = transition_probs_from_linear(stupid)
    print('This transition matrix takes into account the baseline frequencies')
    print(Q_stupid)
    print('\n')
    ## loss function
    loss = nn.CrossEntropyLoss()
    ## evaluate loss function for different models on a full single trajectory
    print('Evaluate loss on single trajectory')
    print(trajectory)
    models = {'model 0': stupidest, 'model 1': stupid, 'exact model': exact_model}
    for name, model in models.items():
        g = eval_loss(model, trajectory)
        print(f'{name}: {g}')
    print('\n')
    ## examine the effect of randomness in data via replicas
    print(f'Evaluate loss on {replicas} trajectories')
    models = {'model 0': stupidest, 'model 1': stupid, 'exact model': exact_model}
    for name, model in models.items():
        losses = torch.stack([eval_loss(model, traj) for traj in trajectories])
        print(f'{name}: mean = {losses.mean()}, std = {losses.std()}')
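    # The exact model should achieve the lowest mean loss (it uses the true conditional
    # distribution); the spread across replicas is sampling noise from the finite trajectories.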
    ## show that, starting from the stupidest model, the gradient takes you in the right direction
    print('Inspect the gradients away from optimum')
    g = eval_loss(stupidest, trajectory)
    g.backward()
    bgrad, wgrad = stupidest.linear.bias.grad, stupidest.linear.weight.grad
    print(f'Bias gradient:\n {bgrad}')
    print(f'Weight gradient:\n {wgrad}')
    print('\n')
    print('Follow the direction of this gradient, computing the loss without recomputing the gradient')
    rate = 0.25
    for step in range(2**5):
        with torch.no_grad():
            stupidest.linear.bias -= rate*bgrad
            stupidest.linear.weight -= rate*wgrad
            g = eval_loss(stupidest, trajectory)
            print(g)
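    # Note: the gradient was computed once at the starting point and then reused, so the loss
    # only keeps dropping while that fixed direction still points downhill; recomputing the
    # gradient each iteration would turn this into ordinary gradient descent.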
    ##TODO: how would stochastic gradients impact the above? What is the impact of batch size?

if __name__ == "__main__":
    main()