Skip to content

Instantly share code, notes, and snippets.

@evanthebouncy
Created July 3, 2019 21:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save evanthebouncy/acf1bacd277ac98fdd588ee9d8ee9a57 to your computer and use it in GitHub Desktop.
Save evanthebouncy/acf1bacd277ac98fdd588ee9d8ee9a57 to your computer and use it in GitHub Desktop.
Trying to auto-encode a real number.
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import random
from tqdm import tqdm
# Chosen once at import time, like the original CUDA/CPU branch: every tensor
# built by to_torch lands on the GPU when one is available.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_torch(x, dtype, req=False):
    """Convert a NumPy array to a torch tensor on the default device.

    Args:
        x: NumPy array to convert (passed to ``torch.from_numpy``).
        dtype: ``"int"`` gives an int64 tensor; any other value gives float32.
        req: whether the returned tensor should require gradients.

    Returns:
        A leaf tensor on ``_DEVICE`` (CUDA if available, else CPU).
    """
    torch_dtype = torch.long if dtype == "int" else torch.float
    # torch.autograd.Variable is deprecated (PyTorch >= 0.4); plain tensors
    # carry requires_grad themselves, so set it directly.
    tensor = torch.from_numpy(x).to(device=_DEVICE, dtype=torch_dtype)
    return tensor.requires_grad_(req)
class Compl(nn.Module):
    """Tiny MLP trained to reproduce (auto-encode) scalar inputs.

    Each scalar goes through Linear(1, n_hidden) -> LeakyReLU ->
    Linear(n_hidden, 1); training minimizes the summed squared error between
    the input and the reconstruction.
    """

    def __init__(self, n_hidden=100, lr=1e-5):
        """Build the two linear layers and an SGD optimizer over them.

        Args:
            n_hidden: hidden-layer width (default 100, as before).
            lr: SGD learning rate (default 1e-5, as before).
        """
        super().__init__()
        self.fc1 = nn.Linear(1, n_hidden)
        self.fc2 = nn.Linear(n_hidden, 1)
        # NOTE(review): the optimizer is built before any caller-side
        # .cuda()/.to(). PyTorch docs advise moving a module before creating
        # its optimizer; current PyTorch updates parameters in place by
        # default so this works, but confirm if that setting is changed.
        self.opt = torch.optim.SGD(self.parameters(), lr=lr)

    def forward(self, yy):
        """Map a tensor of scalars of shape (*) to predictions of shape (*, 1)."""
        yy = yy.unsqueeze(-1)  # each scalar becomes a 1-feature input row
        h = F.leaky_relu(self.fc1(yy))  # same default slope as nn.LeakyReLU()
        return self.fc2(h)

    def loss_function(self, y, y_pred):
        """Return the summed squared error between ``y`` and ``y_pred``.

        ``y_pred`` is reshaped to ``y``'s shape first. ``forward`` returns a
        trailing size-1 axis; subtracting an (N, 1) tensor from an (N,) tensor
        would silently broadcast to (N, N), comparing every target with every
        prediction and driving the model toward the batch mean.
        """
        y_pred = y_pred.reshape_as(y)
        return torch.sum((y - y_pred) ** 2)

    def learn_once(self, yy):
        """Run one SGD step on a NumPy batch of scalars.

        Args:
            yy: NumPy array of scalars, converted with ``to_torch(yy, "float")``.

        Returns:
            The scalar loss tensor computed before the step, detached from
            the (already consumed) autograd graph.
        """
        yy = to_torch(yy, "float")
        self.opt.zero_grad()
        yy_pred = self(yy)
        loss = self.loss_function(yy, yy_pred)
        loss.backward()
        self.opt.step()
        return loss.detach()

    def save(self, loc):
        """Write this model's ``state_dict`` to ``loc`` via ``torch.save``."""
        torch.save(self.state_dict(), loc)

    def load(self, loc):
        """Load a ``state_dict`` written by ``save`` into this model.

        Tensors are mapped to CPU while loading so a checkpoint saved on a GPU
        machine also loads on a CPU-only one; ``load_state_dict`` then copies
        the values into the parameters wherever they currently live.
        """
        # torch.load unpickles the file: only load checkpoints you trust.
        self.load_state_dict(torch.load(loc, map_location="cpu"))
if __name__ == '__main__':
    # Pick the device the same way to_torch does, so the script also runs on
    # CPU-only machines instead of crashing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    compl = Compl().to(device)
    for i in tqdm(range(1000000)):
        yy = np.random.random((100,))  # 100 scalars drawn uniformly from [0, 1)
        loss = compl.learn_once(yy)
        if i % 1000 == 0:
            # Evaluation only: no autograd graph needed.
            with torch.no_grad():
                yy_pred = compl(to_torch(yy, "float"))
            # tqdm.write keeps the progress bar intact, unlike bare print().
            tqdm.write("------------------------------")
            tqdm.write(f"loss  {loss.item()}")
            tqdm.write(f"yy_pred  {yy_pred[0].item()}")
            tqdm.write(f"yy  {yy[0]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment