Created
November 18, 2020 18:19
-
-
Save gajomi/1cefa4559eef7157ab3d5e96165eaa80 to your computer and use it in GitHub Desktop.
Baseline loss functions for first order markov chains
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Quick tutorial on comparing the loss function of a fitted Markov chain model against `baselines`.
Illustrates the effect of randomness in the data.
Shows the principle behind following gradients of the loss to optimize the parameters.
""" | |
import torch | |
from torch import tensor | |
from torch import nn | |
from torch.nn import Softmax | |
# some utility functions
# Shared softmax module that normalises along dimension 0.
softmax = Softmax(dim=0)
def _rand_choice(probs):
    """Draw one index at random, weighted by the entries of ``probs``.

    Performs a single ``torch.multinomial`` draw and returns its first element.
    """
    draws = torch.multinomial(probs, num_samples=1, replacement=True)
    return draws[0]
def _one_hot(k, n):
    """Return a length-``n`` vector of zeros with a single 1.0 at position ``k``."""
    vec = torch.zeros(n)
    vec[k] = 1.0
    return vec
def _one_hot3s(k):
    """One-hot encode index ``k`` as a length-3 vector."""
    vec = torch.zeros(3)
    vec[k] = 1.0
    return vec
# neural network representation of markov chain step
class LogReg(nn.Module):
    """A single square linear layer mapping ``input_size`` inputs to ``input_size`` outputs.

    ``forward`` returns the raw output of the linear layer; no softmax is
    applied inside the module.
    """

    def __init__(self, input_size):
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=input_size)

    def forward(self, x):
        logits = self.linear(x)
        return logits
# translate neural network representation of models to transition probabilities
def transition_probs_from_linear(model, x=None):
    """Convert the logits produced by ``model`` into transition probabilities.

    Args:
        model: Callable module exposing ``model.linear.bias`` (e.g. ``LogReg``);
            the bias length is taken as the number of states ``n``.
        x: ``None`` to get the full transition matrix; an ``int`` or 0-dim
            integer tensor to get the distribution out of that state; any
            other tensor is passed to the model as features unchanged.

    Returns:
        For ``x is None``: an ``(n, n)`` tensor whose column ``j`` is the
        distribution over next states given current state ``j``.
        Otherwise: the softmax of ``model(x)`` over its last dimension.
    """
    bias = model.linear.bias
    n = bias.shape[0]
    if x is None:
        # Row i of the identity is the one-hot encoding of state i, so a single
        # batched forward pass yields every conditional distribution at once.
        states = torch.eye(n, dtype=bias.dtype, device=bias.device)
        return torch.softmax(model(states), dim=-1).t()

    # A 0-dim integer tensor (what iterating over a trajectory yields) can only
    # be a state index: it is not a valid feature vector for the linear layer.
    is_index_tensor = (
        torch.is_tensor(x)
        and x.dim() == 0
        and not x.is_floating_point()
        and not x.is_complex()
    )
    if isinstance(x, int) or is_index_tensor:
        index = int(x)
        x = torch.zeros(n, dtype=bias.dtype, device=bias.device)
        x[index] = 1.0
    # Normalise over the state dimension (the last one) so that batched
    # features are not normalised across the batch.
    return torch.softmax(model(x), dim=-1)
# classical markov chain representation sampling
def markov_chain_sample(Q, T, x0):
    """Sample a length-``T`` trajectory from a Markov chain.

    Args:
        Q: Transition kernel; column ``Q[:, j]`` holds the (possibly
            unnormalised) weights of the next state given current state ``j``.
        T: Number of states in the returned trajectory, including ``x0``
            (i.e. ``T - 1`` transitions are sampled).
        x0: Initial state index.

    Returns:
        A 1-D ``torch.long`` tensor of length ``T`` whose first entry is
        ``x0``. An empty tensor is returned when ``T == 0``.
    """
    trajectory = torch.empty(T, dtype=torch.long)
    if T == 0:
        # Nothing to sample; there is no slot to store x0 in either.
        return trajectory
    trajectory[0] = x0
    for t in range(1, T):
        next_state_weights = Q[:, trajectory[t - 1]]
        trajectory[t] = torch.multinomial(next_state_weights, 1, replacement=True)[0]
    return trajectory
# neural network representations of model of interest
def stupidest_model(nstates=3):
    """Guesses that all transitions are equally likely for all states.

    Builds a ``LogReg`` and zeroes both the weight and the bias of its linear
    layer, so every input produces identical logits.
    """
    uniform = LogReg(nstates)
    with torch.no_grad():
        for param in (uniform.linear.bias, uniform.linear.weight):
            param.zero_()
    return uniform
def stupid_model(trajectory, nstates=3):
    """Memoryless model that captures the empirical frequency of states.

    The linear weights are zeroed and the bias is set to the centred log
    frequencies, so the predicted next-state distribution is the same for
    every current state.

    Args:
        trajectory: 1-D integer tensor of visited state indices.
        nstates: Number of states of the chain.

    Returns:
        A ``LogReg`` with zero weights and frequency-derived bias.

    Raises:
        ValueError: If ``trajectory`` is empty or contains a state index
            that is not smaller than ``nstates``.
    """
    if len(trajectory) == 0:
        raise ValueError("trajectory must contain at least one state")
    # minlength keeps one count per state even when the highest-numbered
    # states never occur in the trajectory.
    counts = torch.bincount(trajectory, minlength=nstates)
    if counts.shape[0] != nstates:
        raise ValueError(
            f"trajectory contains state {counts.shape[0] - 1}, "
            f"which is out of range for nstates={nstates}"
        )
    probs = counts.float() / len(trajectory)
    # States that never occur have probability 0; clamp before the log so the
    # centring below cannot turn -inf into nan for the whole bias vector.
    log_probs = probs.clamp_min(torch.finfo(probs.dtype).tiny).log()
    bias = log_probs - log_probs.mean()  # subtract the mean for numerical stability
    model = LogReg(nstates)
    with torch.no_grad():
        model.linear.weight.fill_(0.)
        model.linear.bias.copy_(bias)
    return model
def exact_model_from_transitions(transitions):
    """Build a ``LogReg`` whose softmaxed outputs reproduce ``transitions``.

    The bias is zeroed and the weight matrix is set to the column-centred log
    of ``transitions``, so feeding the one-hot encoding of state ``j`` yields
    logits equal to column ``j`` of that matrix.

    Args:
        transitions: Square 2-D tensor; column ``j`` holds the weights of the
            next state given current state ``j``.

    Returns:
        A ``LogReg`` with the weights described above and zero bias.

    Raises:
        ValueError: If ``transitions`` is not a square 2-D tensor.
    """
    if transitions.dim() != 2 or transitions.shape[0] != transitions.shape[1]:
        raise ValueError(
            f"transitions must be a square matrix, got shape {tuple(transitions.shape)}"
        )
    nstates = transitions.shape[0]
    if not transitions.is_floating_point():
        transitions = transitions.to(torch.get_default_dtype())
    # one way to parameterize is to just ignore the bias term and set non-zero weights
    # Forbidden transitions (probability 0) are clamped before the log so the
    # centring below cannot turn -inf into nan for a whole column.
    log_probs = transitions.clamp_min(torch.finfo(transitions.dtype).tiny).log()
    weight = log_probs - log_probs.mean(dim=0)
    model = LogReg(nstates)
    with torch.no_grad():
        model.linear.weight.copy_(weight)
        model.linear.bias.fill_(0.)
    return model
def eval_loss(model, trajectory):
    """Convenience function to form inputs/targets from a trajectory and evaluate the loss.

    Each state ``trajectory[t]`` is one-hot encoded as an input and paired
    with ``trajectory[t + 1]`` as its target class.

    Args:
        model: Callable module exposing ``model.linear.bias``; the bias length
            is taken as the number of states.
        trajectory: 1-D ``torch.long`` tensor of state indices.

    Returns:
        The ``nn.CrossEntropyLoss`` of the model outputs against the targets
        (a 0-dim tensor).

    Raises:
        ValueError: If ``trajectory`` has fewer than two states, since no
            transition can be formed from it.
    """
    if len(trajectory) < 2:
        raise ValueError("trajectory must contain at least two states")
    bias = model.linear.bias
    n = bias.shape[0]
    # Vectorised one-hot encoding of every state except the last one.
    features = nn.functional.one_hot(trajectory[:-1], num_classes=n).to(bias.dtype)
    targets = trajectory[1:]
    loss = nn.CrossEntropyLoss()
    return loss(model(features), targets)
def main():
    """Run the tutorial end to end, printing every intermediate result.

    Samples replica trajectories from a fixed three-state chain, builds the
    two baseline models and the exact model, prints their losses on one
    trajectory and across all replicas, and finally takes repeated steps along
    a single fixed gradient of the uniform model while printing the loss.
    """
    # timesteps and initial state
    T = 2**10
    x0 = tensor(2)

    ## initialize markov chain
    Q = tensor([
        [0.005, 0.845, 0.150],
        [0.845, 0.005, 0.150],
        [0.700, 0.100, 0.100],
    ]).t()
    # The last row above sums to 0.9. torch.multinomial normalises its weights
    # when sampling, so normalise the columns here as well: the printed matrix
    # is then a proper stochastic matrix matching what is actually sampled.
    Q = Q / Q.sum(dim=0, keepdim=True)
    exact_model = exact_model_from_transitions(Q)
    print('This is the exact transition matrix')
    print(Q)
    print('\n')

    # sample trajectories
    replicas = 2**6
    trajectories = [markov_chain_sample(Q, T, x0) for r in range(replicas)]
    trajectory = trajectories[0]

    ## create the stupidest possible model
    stupidest = stupidest_model()
    Q_stupidest = transition_probs_from_linear(stupidest)
    print('This is the simplest possible transition matrix')
    print(Q_stupidest)
    print('\n')

    ## create merely stupid model, notice that the transition probabilities are different, but identical across columns
    stupid = stupid_model(trajectory)
    Q_stupid = transition_probs_from_linear(stupid)
    print('This transition matrix takes into account he baseline frequencies')
    print(Q_stupid)
    print('\n')

    ## evaluate loss function for different models on a full single trajectory
    print('Evaluate loss on single trajectory')
    print(trajectory)
    models = {'model 0': stupidest, 'model 1': stupid, 'exact model': exact_model}
    for name, model in models.items():
        g = eval_loss(model, trajectory)
        print(f'{name}: {g}')
    print('\n')

    ## examine the effect of randomness in data via replicas
    print(f'Evaluate loss on {replicas} trajectories')
    for name, model in models.items():
        losses = torch.stack([eval_loss(model, traj) for traj in trajectories])
        print(f'{name}: mean = {losses.mean()}, std = {losses.std()}')

    ## show that by starting at stupid model, gradient takes you in the right direction
    print('Inspect the gradients away from optimum')
    g = eval_loss(stupidest, trajectory)
    g.backward()
    bgrad, wgrad = stupidest.linear.bias.grad, stupidest.linear.weight.grad
    print(f'Bias gradient:\n {bgrad}')
    print(f'Weight gradient:\n {wgrad}')
    print('\n')

    print('Follow the direction of this gradient, computing the loss without updating the gradient')
    rate = 0.25
    for step in range(2**5):
        # backward() is not called again, so bgrad/wgrad stay fixed at the
        # values computed above and every step moves along the same direction.
        with torch.no_grad():
            stupidest.linear.bias -= rate*bgrad
            stupidest.linear.weight -= rate*wgrad
        g = eval_loss(stupidest, trajectory)
        print(g)
    ##TODO: how would stochastic gradients impact the above? What is impact of batch size?


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment