Last active: July 11, 2020 19:16
Save thomashikaru/5cc972db0e9ce92a591cc6300db17cac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit a small fully-connected network to random 1-D points and plot the fit.
#
# A Linear -> ReLU -> Linear model is trained with Adam on N random (x, y)
# pairs drawn from a standard normal distribution, printing the loss every
# 100 steps. Afterwards the original points are drawn as a scatter plot and
# the model's predictions on an evenly spaced grid over [min(x), max(x)] are
# drawn as a line.
import numpy as np
import torch

# Batch size, input features, hidden units, output features.
N, D_in, H, D_out = 16, 1, 1024, 1
learning_rate = 1e-4
num_steps = 30000


def make_data(n=N, d_in=D_in, d_out=D_out):
    """Return random inputs of shape (n, d_in) and targets of shape (n, d_out).

    Both tensors are sampled independently from a standard normal
    distribution, so there is no underlying relationship to learn; a large
    enough network simply memorizes the points.
    """
    return torch.randn(n, d_in), torch.randn(n, d_out)


def build_model(d_in=D_in, hidden=H, d_out=D_out):
    """Return a Linear(d_in, hidden) -> ReLU -> Linear(hidden, d_out) network."""
    return torch.nn.Sequential(
        torch.nn.Linear(d_in, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, d_out),
    )


def train(model, x, y, steps=num_steps, lr=learning_rate, log_every=100):
    """Train *model* on (x, y) with Adam and a summed squared-error loss.

    Args:
        model: Module mapping x to predictions; expected to produce the
            same shape as y.
        x, y: Input and target tensors (the whole dataset is one batch).
        steps: Number of optimization steps.
        lr: Adam learning rate.
        log_every: Print "step loss" every this many steps; 0 disables.

    Returns:
        The loss (float) computed at the last step, before that step's
        parameter update, or None when steps is 0.
    """
    # reduction="sum" yields the *sum* of squared errors, not their mean,
    # so the loss scale grows with the number of points.
    loss_fn = torch.nn.MSELoss(reduction="sum")
    # loss.backward() computes the gradients via autograd; the optimizer
    # only uses those gradients to update the parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = None
    for step in range(steps):
        loss = loss_fn(model(x), y)
        if log_every and step % log_every == 0:
            print(step, loss.item())
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return None if loss is None else loss.item()


def predict_grid(model, x, num_points=640):
    """Evaluate *model* on num_points evenly spaced inputs spanning x.

    Only single-feature inputs are supported because the grid is 1-D.

    Returns:
        (grid, preds): grid has shape (num_points, 1); preds is the model
        output on grid, computed without gradient tracking.

    Raises:
        ValueError: if x is not of shape (n, 1).
    """
    if x.dim() != 2 or x.shape[1] != 1:
        raise ValueError(f"expected x of shape (n, 1), got {tuple(x.shape)}")
    lo, hi = x.min().item(), x.max().item()
    grid = torch.from_numpy(
        np.linspace(lo, hi, num=num_points, dtype=np.float32)
    ).reshape(-1, 1)
    # Inference only: skip building an autograd graph (no .detach() needed).
    with torch.no_grad():
        preds = model(grid)
    return grid, preds


def plot_fit(x, y, grid, preds):
    """Return a plotly figure: data points as markers, predictions as a line."""
    # Imported lazily so the training helpers above work without plotly.
    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x.flatten().numpy(), y=y.flatten().numpy(), mode="markers"))
    fig.add_trace(go.Scatter(x=grid.flatten().numpy(), y=preds.flatten().numpy(), mode="lines"))
    return fig


def main(seed=None):
    """Train on fresh random data and show the fitted curve.

    Pass an integer *seed* for a reproducible run; by default the global
    torch RNG is left unseeded.
    """
    if seed is not None:
        torch.manual_seed(seed)
    x, y = make_data()
    model = build_model()
    train(model, x, y)
    grid, preds = predict_grid(model, x)
    plot_fit(x, y, grid, preds).show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.