Last active
January 8, 2021 21:23
-
-
Save redwrasse/1281b12a7012ad9e699842f2701eb8a9 to your computer and use it in GitHub Desktop.
heteroscedastic model: parameterized variance in discriminative gaussian
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
"""
Standard (homoscedastic) discriminative Gaussian model
    y ~ N(f(x), sigma^2)
compared with the heteroscedastic model, whose variance depends on the input
    y ~ N(f(x), sigma^2(x))
Both are trained on a dataset that requires the heteroscedastic model:
    x in R, y in R^2
    inputs near x = -1 map to a small-variance Gaussian about (-1, -1)
    inputs near x = +1 map to a large-variance Gaussian about (1, 1)
Loss printed every 5000 steps for the standard model:
0 125.3414077758789 | |
5000 123.4275894165039 | |
10000 123.38433837890625 | |
15000 123.31232452392578 | |
20000 123.17870330810547 | |
25000 122.92581176757812 | |
... | |
heteroscedastic model: | |
0 25.26043701171875 | |
5000 3.58022141456604 | |
10000 3.5802078247070312 | |
15000 3.5801913738250732 | |
20000 3.5801899433135986 | |
25000 3.580174207687378 | |
... | |
""" | |
import torch | |
# Problem sizes: samples per cluster, input width, hidden width, output width.
N = 100
D_in = 1
H = 5
D_out = 2

# Standard deviations of the target noise for the two clusters.
small_sigma = 0.2
large_sigma = 15.5

# Centres of the two target clusters in R^2.
y1_center = [-1., -1.]
y2_center = [1., 1.]

# Inputs are tightly concentrated around -1 and +1.
x_sigma = 0.05
x1 = -1. + x_sigma * torch.randn(N, 1)
x2 = 1. + x_sigma * torch.randn(N, 1)

# Targets: low-noise cluster for x1, high-noise cluster for x2.
y1 = torch.Tensor(y1_center) + small_sigma * torch.randn(N, 2)
y2 = torch.Tensor(y2_center) + large_sigma * torch.randn(N, 2)

# Full training set: the two clusters stacked along the sample axis.
x = torch.cat((x1, x2), dim=0)
y = torch.cat((y1, y2), dim=0)
class MuOnly(torch.nn.Module):
    """Two-layer network that predicts only the mean ``f(x)``.

    Args:
        D_in: size of each input sample.
        H: width of the hidden layer.
        D_out: size of the predicted mean.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the predicted mean for every row of ``x``."""
        # Rectify the hidden pre-activations (negative values become 0).
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)
class MuAndSigma(torch.nn.Module):
    """Two-headed network predicting a mean and a scalar ``beta`` per sample.

    ``beta`` multiplies the squared residual in ``heteroscedastic_loss``,
    i.e. it plays the role of a precision (inverse variance).

    Args:
        D_in: size of each input sample.
        H: width of the shared hidden layer.
        D_out: size of the predicted mean.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear_mu = torch.nn.Linear(H, D_out)
        # In this case assume a scalar sigma; a matrix could also be considered.
        self.linear_beta = torch.nn.Linear(H, 1)

    def forward(self, x):
        """Return ``(mu, beta)`` for every row of ``x``.

        ``mu`` has ``D_out`` columns and ``beta`` has a single column whose
        entries are always greater than ``1e-3``.
        """
        h_relu = self.linear1(x).clamp(min=0)
        mu = self.linear_mu(h_relu)
        # A hard clamp(min=1e-3) has zero gradient once the pre-activation
        # drops below the floor, so the beta head could never recover from an
        # overshoot. Softplus keeps beta positive with a non-zero gradient
        # everywhere; adding 1e-3 preserves the original lower bound.
        beta = torch.nn.functional.softplus(self.linear_beta(h_relu)) + 1e-3
        return mu, beta
def train_standard(num_steps=10**5, lr=1e-3, log_every=5000):
    """Fit a ``MuOnly`` model to the module-level ``(x, y)`` with MSE and SGD.

    Args:
        num_steps: number of full-batch gradient steps.
        lr: SGD learning rate.
        log_every: print ``step loss`` every this many steps; a falsy value
            disables printing.

    Returns:
        The trained ``MuOnly`` model.
    """
    model = MuOnly(D_in, H, D_out)
    criterion = torch.nn.MSELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for t in range(num_steps):
        # Forward pass: compute predicted y by passing x to the model.
        y_pred = model(x)
        loss = criterion(y_pred, y)
        if log_every and t % log_every == 0:
            print(t, loss.item())
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
def heteroscedastic_loss(mu, beta, y):
    """Mean Gaussian-style loss with one scalar precision per sample.

    Computes ``0.5 * mean_i(beta_i * ||y_i - mu_i||^2 - log(beta_i))``.

    Args:
        mu: predicted means, one row per sample.
        beta: positive per-sample precision, shaped ``(N,)`` or ``(N, 1)``.
        y: targets with the same shape as ``mu``.

    Returns:
        A scalar tensor holding the loss.
    """
    # NOTE(review): the log term is not scaled by the target dimension; an
    # isotropic D-dimensional Gaussian NLL would use D * log(beta). Confirm
    # whether the omission is intended.
    sq_err = torch.sum((y - mu) ** 2, dim=1)
    # An (N, 1) beta multiplied by the (N,) squared error would broadcast to
    # an (N, N) matrix, pairing every sample's precision with every other
    # sample's residual. Drop the trailing axis so the product is per-sample.
    if beta.dim() > sq_err.dim():
        beta = beta.squeeze(-1)
    return torch.mean(beta * sq_err - torch.log(beta)) * 0.5
def train_heteroscedastic(num_steps=10**5, lr=1e-3, log_every=5000):
    """Fit a ``MuAndSigma`` model to the module-level ``(x, y)`` with SGD.

    The objective is ``heteroscedastic_loss`` evaluated on the full batch.

    Args:
        num_steps: number of full-batch gradient steps.
        lr: SGD learning rate.
        log_every: print ``step loss`` every this many steps; a falsy value
            disables printing.

    Returns:
        The trained ``MuAndSigma`` model.
    """
    model = MuAndSigma(D_in, H, D_out)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for i in range(num_steps):
        mu, beta = model(x)
        loss = heteroscedastic_loss(mu, beta, y)
        if log_every and i % log_every == 0:
            print(i, loss.item())
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model
# Run the heteroscedastic experiment only when executed as a script, so that
# importing this module does not start the long training loop.
if __name__ == "__main__":
    train_heteroscedastic()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment