Skip to content

Instantly share code, notes, and snippets.

@redwrasse
Last active January 8, 2021 21:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save redwrasse/1281b12a7012ad9e699842f2701eb8a9 to your computer and use it in GitHub Desktop.
Save redwrasse/1281b12a7012ad9e699842f2701eb8a9 to your computer and use it in GitHub Desktop.
heteroscedastic model: parameterized variance in discriminative gaussian
# -*- coding: utf-8 -*-
"""
standard discriminative gaussian
y ~ N(f(x), sigma^2)
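as well as the heteroscedastic model
y ~ N(f(x), sigma^2(x))
(the network below predicts the precision beta(x) = 1 / sigma^2(x) rather than sigma^2(x) itself)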
Training uses a synthetic dataset that requires the heteroscedastic model:
x in R, y in R^2
inputs clustered near x = -1 map to a small-variance Gaussian centred at (-1, -1);
inputs clustered near x = +1 map to a large-variance Gaussian centred at (1, 1).
standard model:
0 125.3414077758789
5000 123.4275894165039
10000 123.38433837890625
15000 123.31232452392578
20000 123.17870330810547
25000 122.92581176757812
...
heteroscedastic model:
0 25.26043701171875
5000 3.58022141456604
10000 3.5802078247070312
15000 3.5801913738250732
20000 3.5801899433135986
25000 3.580174207687378
...
"""
import torch
# Problem sizes: N samples per cluster, input width, hidden width, target width.
N, D_in, H, D_out = 100, 1, 5, 2

# Target noise scale for each cluster and the cluster centres in y-space.
small_sigma = 0.2
large_sigma = 15.5
y1_center = [-1., -1.]
y2_center = [1., 1.]

# Spread of the inputs around their cluster centres (-1 and +1).
x_sigma = 0.05

# Inputs: two tight 1-D clusters, N points each.
x1 = -1. + x_sigma * torch.randn(N, 1)
x2 = 1. + x_sigma * torch.randn(N, 1)

# Targets: low-noise cloud for the first cluster, high-noise cloud for the second.
y1 = torch.tensor(y1_center) + small_sigma * torch.randn(N, 2)
y2 = torch.tensor(y2_center) + large_sigma * torch.randn(N, 2)

# Full training set: first-cluster rows followed by second-cluster rows.
x = torch.cat((x1, x2), dim=0)
y = torch.cat((y1, y2), dim=0)
class MuOnly(torch.nn.Module):
    """One-hidden-layer ReLU network that predicts only the mean f(x).

    Args:
        D_in: width of the input features.
        H: number of hidden units.
        D_out: width of the predicted mean.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the predicted mean for a batch of inputs x."""
        # Clamping at zero is the ReLU non-linearity.
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)
class MuAndSigma(torch.nn.Module):
    """One-hidden-layer ReLU network with two heads: a mean and a scalar beta.

    beta is a single positive scalar per sample (a matrix-valued version is
    possible but not implemented here).

    Args:
        D_in: width of the input features.
        H: number of hidden units shared by both heads.
        D_out: width of the predicted mean.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear_mu = torch.nn.Linear(H, D_out)
        # Scalar head: one beta value per sample.
        self.linear_beta = torch.nn.Linear(H, 1)

    def forward(self, x):
        """Return (mu, beta) for a batch x; beta has a trailing dimension of 1."""
        # Clamping at zero is the ReLU non-linearity.
        hidden = torch.clamp(self.linear1(x), min=0)
        # Floor beta at 1e-3 so it stays strictly positive (log(beta) is taken
        # downstream). Entries sitting on the floor receive zero gradient
        # through this head.
        beta = torch.clamp(self.linear_beta(hidden), min=1e-3)
        mu = self.linear_mu(hidden)
        return mu, beta
def train_standard(num_steps=10**5, lr=1e-3, log_every=5000):
    """Fit the fixed-variance baseline (MuOnly) to the module-level x, y.

    Runs full-batch gradient steps with torch.optim.SGD on the mean squared
    error between the predicted mean and y.

    Args:
        num_steps: number of optimisation steps.
        lr: SGD learning rate.
        log_every: print ``step loss`` every this many steps; 0 or None
            disables logging.

    Returns:
        The trained MuOnly model.
    """
    model = MuOnly(D_in, H, D_out)
    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for step in range(num_steps):
        # Forward pass on the whole dataset.
        loss = criterion(model(x), y)
        if log_every and step % log_every == 0:
            print(step, loss.item())
        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
def heteroscedastic_loss(mu, beta, y):
    """Return the batch-mean heteroscedastic Gaussian loss.

    Per sample the loss is ``0.5 * (beta * ||y - mu||^2 - log(beta))``, i.e.
    the negative log-likelihood (up to an additive constant) of y under a
    Gaussian with mean mu and scalar precision beta.

    Args:
        mu: predicted means, shape (batch, D).
        beta: positive per-sample precisions, shape (batch,) or (batch, 1).
        y: targets, same shape as mu.

    Returns:
        A 0-dim tensor holding the mean loss over the batch.

    NOTE(review): for a D-dimensional isotropic Gaussian the exact NLL weights
    log(beta) by D; the unweighted log term of the original is kept here.
    """
    # Sum of squares instead of norm()**2: same value, no square root.
    sq_err = (y - mu).pow(2).sum(dim=1)
    # A (batch, 1) beta multiplied by a (batch,) residual would broadcast to a
    # (batch, batch) matrix, pairing each sample's precision with every other
    # sample's error. Align the shapes so the product is strictly per-sample.
    if beta.shape != sq_err.shape and beta.numel() == sq_err.numel():
        beta = beta.reshape(sq_err.shape)
    return torch.mean(beta * sq_err - torch.log(beta)) * 0.5
def train_heteroscedastic(num_steps=10**5, lr=1e-3, log_every=5000):
    """Fit the MuAndSigma model to the module-level x, y.

    Runs full-batch gradient steps with torch.optim.SGD, minimising
    heteroscedastic_loss on the predicted (mu, beta).

    Args:
        num_steps: number of optimisation steps.
        lr: SGD learning rate.
        log_every: print ``step loss`` every this many steps; 0 or None
            disables logging.

    Returns:
        The trained MuAndSigma model.
    """
    model = MuAndSigma(D_in, H, D_out)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for step in range(num_steps):
        # Forward pass on the whole dataset.
        mu, beta = model(x)
        loss = heteroscedastic_loss(mu, beta, y)
        if log_every and step % log_every == 0:
            print(step, loss.item())
        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
# Script entry: runs the heteroscedastic experiment whenever this file is
# executed (or imported). train_standard(), the fixed-variance baseline, is
# defined in this file but is not invoked here.
train_heteroscedastic()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment