# Source: GitHub Gist by @thomashikaru
# (gist id 5cc972db0e9ce92a591cc6300db17cac), last active July 11, 2020 19:16.
import torch
import plotly.graph_objects as go
import numpy as np
# Problem size: N samples per batch, D_in input features, H hidden units,
# D_out output features.
N, D_in, H, D_out = 16, 1, 1024, 1

# Toy regression data: inputs and targets are drawn independently from a
# standard normal, so there is no underlying relationship between them --
# training simply fits the N random points.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# One-hidden-layer MLP: Linear(D_in -> H), ReLU, Linear(H -> D_out).
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=D_in, out_features=H),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=H, out_features=D_out),
)

# Loss: with reduction='sum' this is the SUM of squared differences between
# prediction and target (not the mean, despite the class name MSELoss).
loss_fn = torch.nn.MSELoss(reduction='sum')

# Adam optimizer over all model parameters. It does not compute gradients
# itself: loss.backward() fills each parameter's .grad, and optimizer.step()
# then uses those gradients to update the weights.
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Perform 30000 full-batch training steps: every step feeds all N points
# through the model. The loss is printed every 100 steps to monitor progress.
for t in range(30000):
    # Forward pass: predictions for the whole batch.
    y_pred = model(x)
    # Summed squared error between predictions and targets.
    loss = loss_fn(y_pred, y)
    if t % 100 == 0:
        print(t, loss.item())
    # Clear the gradients from the previous step (PyTorch accumulates into
    # .grad rather than overwriting), backpropagate to compute fresh ones,
    # then let the optimizer update the weights from them.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Visualise the fit: the random training points as markers, plus the trained
# network's predictions as a line over the range of the training inputs.
fig = go.Figure()
fig.add_trace(go.Scatter(x=x.flatten().numpy(), y=y.flatten().numpy(), mode="markers"))

# Range of the training inputs as plain Python floats (min/max over the
# rows of an (N, 1) array would give 1-element arrays, not scalars).
minx = x.min().item()
maxx = x.max().item()

# 640 evenly spaced query points, shaped (640, 1) to match the model's
# (batch, D_in) input layout with D_in == 1.
c = torch.linspace(minx, maxx, steps=640).reshape(-1, 1)

# Inference only: disabling autograd avoids building a graph that is never
# backpropagated through, so the output converts to NumPy without .detach().
with torch.no_grad():
    d = model(c)

# Draw the predicted function as a line graph.
fig.add_trace(go.Scatter(x=c.flatten().numpy(), y=d.flatten().numpy(), mode="lines"))
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment