#!/usr/bin/env python

"""
ablr.py

Adaptive Bayesian linear regression (ABLR): a neural feature map shared
across problems, with a per-problem Bayesian linear regression head on top.
"""
import numpy as np
import torch
from torch import nn
torch.set_default_tensor_type('torch.DoubleTensor') # double precision stabilizes torch.inverse
torch.set_num_threads(1)
# --
# Helpers
class Encoder(nn.Module):
    """ NN for learning a feature projection, shared across all problems """
    def __init__(self, input_dim=1, hidden_dim=50):
        super().__init__()
        self._encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self._encoder(x)
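
# Quick shape check (illustrative; not part of the original gist): the encoder
# maps raw 1-d inputs to `hidden_dim`-dimensional basis functions, and it is
# this full feature vector that the BLR head below regresses against.
def _check_encoder():
    phi = Encoder()(torch.randn(10, 1))
    assert phi.shape == (10, 50)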
class BLR:
    """ Bayesian linear regression on top of encoded features """
    def __init__(self, alpha, beta):
        self.alpha = alpha # precision of the Gaussian prior over the weights
        self.beta  = beta  # precision of the observation noise
    
    def fit(self, phi, y):
        # Posterior over weights: S = (alpha * I + beta * Phi^T Phi)^-1, m = beta * S @ Phi^T y
        S_inv_prior = self.alpha * torch.eye(phi.shape[1])
        S_inv       = S_inv_prior + self.beta * phi.t() @ phi
        S           = torch.inverse(S_inv)
        m           = self.beta * S @ phi.t() @ y
        
        self.S = S
        self.m = m
        return self
    
    def predict_with_nll(self, phi, y):
        mu  = phi @ self.m
        sig = 1 / self.beta + ((phi @ self.S) * phi).sum(dim=-1)
        # Per-point Gaussian NLL (up to additive constants): (y_i - mu_i)^2 / sig_i + log(sig_i)
        nll = ((y - mu).pow(2).sum(dim=-1) / sig).mean() + sig.log().mean()
        return mu, sig, nll
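
# Minimal sanity check (illustrative; not part of the original gist): the BLR
# posterior mean coincides with the ridge-regression solution with penalty
# alpha / beta, since
#   m = beta * (alpha * I + beta * Phi^T Phi)^-1 Phi^T y
#     = (Phi^T Phi + (alpha / beta) * I)^-1 Phi^T y
def _check_blr(n=20, d=5, alpha=1.0, beta=10.0):
    phi, y = torch.randn(n, d), torch.randn(n, 1)
    blr    = BLR(alpha=alpha, beta=beta).fit(phi, y)
    ridge  = torch.inverse(phi.t() @ phi + (alpha / beta) * torch.eye(d)) @ phi.t() @ y
    assert torch.allclose(blr.m, ridge)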
# --
# Make datasets
def make_problems(num_problems):
    """
    Generate synthetic problems:
    sin functions w/ different amplitude, phase and frequency + noise
    """
    problems = []
    for _ in range(num_problems):
        x = np.random.uniform(-5, 5, (10, 1))
        
        noise_std = 0.1
        amp       = np.random.uniform(0.1, 5.0)
        phase     = np.random.uniform(0, 3.14)
        freq      = np.random.uniform(0.999, 1.0)
        
        y = amp * np.sin(freq * x + phase)
        y += np.random.normal(0, noise_std, y.shape)
        
        problems.append([
            torch.Tensor(x),
            torch.Tensor(y),
        ])
    
    return problems
num_problems = 30
train_problems = make_problems(num_problems=num_problems)
# --
# Setup model
encoder = Encoder()
alphas = nn.Parameter(torch.zeros(num_problems)) # One alpha per problem
betas = nn.Parameter(torch.zeros(num_problems)) # One beta per problem
params = list(encoder.parameters()) + [alphas, betas]
opt = torch.optim.LBFGS(params, lr=0.1, max_iter=30)
# --
# Train
def _optimization_target():
    opt.zero_grad()
    
    total_nll, total_mse = 0, 0
    for idx, (X, y) in enumerate(train_problems):
        # Per-problem precisions, parameterized on a log10 scale
        alpha, beta = 10 ** alphas[idx], 1 / 10 ** betas[idx]
        
        phi = encoder(X)
        blr = BLR(alpha=alpha, beta=beta)
        blr = blr.fit(phi, y)
        mu, sig, nll = blr.predict_with_nll(phi, y)
        
        total_nll += nll
        total_mse += ((mu - y) ** 2).mean()
    
    total_nll /= len(train_problems)
    total_mse /= len(train_problems)
    
    total_nll.backward()
    print(float(total_mse), float(total_nll))
    return float(total_nll)
opt.step(_optimization_target)
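
# Usage sketch (illustrative; not part of the original gist): after training,
# the shared encoder can be reused on a new task -- fit the BLR head on a
# handful of observations, then predict mean / variance at new inputs. The
# fixed `alpha` and `beta` below are assumptions; in practice they would be
# tuned per task, as in the training loop above.
def _predict_new_task(alpha=1.0, beta=100.0):
    X_obs, y_obs = make_problems(num_problems=1)[0]
    X_new = torch.linspace(-5, 5, 100).view(-1, 1)
    with torch.no_grad():
        blr     = BLR(alpha=alpha, beta=beta).fit(encoder(X_obs), y_obs)
        phi_new = encoder(X_new)
        mu      = phi_new @ blr.m                                        # posterior predictive mean
        sig     = 1 / beta + ((phi_new @ blr.S) * phi_new).sum(dim=-1)   # predictive variance
    return mu, sig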