Skip to content

Instantly share code, notes, and snippets.

@winger
Created August 14, 2020 21:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save winger/edc520e46ab870abb91d3a93d391eb30 to your computer and use it in GitHub Desktop.
Save winger/edc520e46ab870abb91d3a93d391eb30 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
from matplotlib import pyplot as plt
# Format floating-point values inside printed torch tensors with 10 digits of
# precision (torch's default is 4). NOTE(review): this only changes how torch
# tensors are rendered; the training loop in this file prints `loss.item()`,
# a plain Python float, so this setting does not affect that output.
torch.set_printoptions(precision=10)
class ResidualBlock(torch.nn.Module):
    """Bottleneck residual block computing ``x + linear2(relu(linear1(relu(x))))``.

    A ReLU is applied before each of the two linear layers, and the input is
    added back unchanged (identity skip connection).

    Args:
        dims: number of input (and output) features.
        bottleneck: number of features in the hidden layer between the two
            linear projections.
    """

    def __init__(self, dims, bottleneck):
        super().__init__()
        # Keep creation order (linear1 before linear2): it fixes the order in
        # which the default initialisers draw from the RNG and the order of
        # entries in the state_dict.
        self.linear1 = torch.nn.Linear(dims, bottleneck)
        self.linear2 = torch.nn.Linear(bottleneck, dims)

    def forward(self, x):
        """Return ``x`` plus the two-layer residual branch applied to ``x``."""
        hidden = self.linear1(torch.relu(x))
        branch = self.linear2(torch.relu(hidden))
        return x + branch
class ResNet(torch.nn.Module):
    """A stack of ``ResidualBlock`` layers between two linear projections.

    The input is projected from ``inout_dims`` to ``dims`` features, passed
    through ``depth`` consecutive ``ResidualBlock(dims, bottleneck)`` layers,
    and projected back to ``inout_dims`` features.

    Args:
        inout_dims: number of features of both the input and the output.
        dims: width of the residual stream.
        bottleneck: hidden width inside each residual block.
        depth: number of residual blocks.
    """

    def __init__(self, inout_dims, dims, bottleneck, depth):
        super().__init__()
        self.linear_in = torch.nn.Linear(inout_dims, dims)
        blocks = [ResidualBlock(dims, bottleneck) for _ in range(depth)]
        self.residuals = torch.nn.Sequential(*blocks)
        self.linear_out = torch.nn.Linear(dims, inout_dims)

    def forward(self, x):
        """Apply the input projection, the residual stack, then the output projection."""
        for stage in (self.linear_in, self.residuals, self.linear_out):
            x = stage(x)
        return x
def main(num_iters=1_000_000, batch_size=1024, log_every=1, device=None):
    """Train a ResNet to reproduce its standard-normal input (identity map).

    Each step draws a fresh batch of i.i.d. N(0, 1) vectors with 64 features,
    runs it through ``ResNet(64, 1024, 256, 1)`` and minimises the squared
    reconstruction error (summed over features, averaged over the batch)
    with Adam at learning rate 3e-4.

    Args:
        num_iters: number of optimisation steps.
        batch_size: number of vectors sampled per step.
        log_every: print ``step loss`` every this many steps. The default of 1
            prints every step, as the original script did. Each print calls
            ``loss.item()``, which waits for the device to finish the step.
        device: torch device to train on. Defaults to CUDA when available and
            falls back to the CPU otherwise; the original hard-coded
            ``.cuda()`` and crashed on machines without a GPU.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``log_every`` is not a positive integer.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    dims = 64
    model = ResNet(dims, 1024, 256, 1).to(device)
    opt = torch.optim.Adam(model.parameters(), 3e-4)
    # Build the distribution from device tensors so samples are drawn directly
    # on the target device, instead of on the CPU plus a host-to-device copy
    # every step.
    dataset = torch.distributions.Normal(
        torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
    )
    for it in range(num_iters):
        x = dataset.sample([batch_size, dims])
        loss = (model(x) - x).square().sum(1).mean()
        if it % log_every == 0:
            print(it, loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model


# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment