Skip to content

Instantly share code, notes, and snippets.

@davegreenwood
Last active April 2, 2020 11:55
Show Gist options
  • Save davegreenwood/7e0b17526796f7866bfbe23d4a069e2c to your computer and use it in GitHub Desktop.
Save davegreenwood/7e0b17526796f7866bfbe23d4a069e2c to your computer and use it in GitHub Desktop.
PyTorch demo to learn line parameters.
"""Fit a 2d line to data using auto differentiation to learn w, b."""
# %%
import torch
import matplotlib.pyplot as plt
# size of data - more data will give a more accurate result
N = 500
# Seed the global torch RNG so the noise (and therefore the data and the
# figures built from it) is reproducible from run to run.
torch.manual_seed(0)
# we can add noise to data to make it more realistic
noise = torch.randn(N)
def get_data(w, b, n=100):
    """Sample ``n`` evenly spaced points on the line ``y = w * x + b``.

    ``x`` runs from 0 to 10 inclusive. ``w`` and ``b`` may be Python floats
    or tensors; when they are tensors with ``requires_grad=True`` the returned
    ``y`` stays connected to them, so gradients flow back to the parameters.

    Returns:
        ``(x, y)`` tensors; for scalar or single-element ``w``/``b`` both are
        1-D of length ``n``.
    """
    xs = torch.linspace(0, 10.0, n)
    return xs, b + w * xs
# %% Start
# Ground-truth and initial-guess parameters for the line y = w * x + b.
w_true, b_true = 2.1, 1.8
w_start, b_start = 1.0, 0.0
x_true, y_true = get_data(w_true, b_true, N)
x_start, y_start = get_data(w_start, b_start, N)
# This is our noisy data!!
# NOTE: the same noise sample is added to x and subtracted from y, so every
# point is displaced along the (1, -1) direction rather than independently.
x = x_true + noise
y = y_true - noise
# Plot the initial guess, the true line and the noisy data; save to start.png.
fig, ax = plt.subplots(figsize=(9, 9))
ax.plot(x_start, y_start, "g", label="start")
ax.plot(x_true, y_true, "r", label="true")
ax.plot(x, y, "+k", label="data")
ax.legend()
plt.savefig("start.png")
# %% Learning
# learn these parameters by setting requires_grad = True!!
W = torch.tensor([w_start], requires_grad=True)
B = torch.tensor([b_start], requires_grad=True)
n_evals = 20000
# Step size. Gradient descent on this MSE converges at a rate set by the loss
# curvature 2 * [[mean(x^2), mean(x)], [mean(x), 1]] for x in [0, 10], whose
# eigenvalues are ~68 and ~0.5. With lr=1e-4 the slow (~0.5) direction only
# shrinks by ~e^-1 over 20000 evals, leaving b well short of b_true; lr=1e-2
# converges within a few thousand evals and is below the stability limit
# 2 / 68 ~= 0.03.
learning_rate = 0.01
optimizer = torch.optim.SGD([W, B], lr=learning_rate)
loss_func = torch.nn.MSELoss()
# repeatedly calculate the loss and back prop
for i in range(n_evals):
    optimizer.zero_grad()
    x_p, y_p = get_data(W, B, N)
    # get_data's x does not depend on W or B, so an x-residual term would be a
    # constant with no gradient; only the y residual drives the fit.
    loss = loss_func(y, y_p)
    loss.backward()
    optimizer.step()
    if i % 1000 == 0:
        # report every 1000 evals
        log = f"eval: {i}, w: {W[0]:0.3f}, b: {B[0]:0.3f}"
        print(log)
# report the final parameters (i is the index of the last eval)
log = f"eval: {i}, w: {W[0]:0.3f}, b: {B[0]:0.3f}"
print(log)
# %% Results
# Re-evaluate the line with the final parameters: the x_p, y_p left over from
# the training loop were computed before the last optimizer.step(), so they
# lag the printed w, b by one update. Under no_grad the result carries no
# autograd graph and can be handed straight to matplotlib.
with torch.no_grad():
    x_fit, y_fit = get_data(W, B, N)
fig, ax = plt.subplots(1, figsize=[9, 9])
ax.plot(x_fit, y_fit, "b")
ax.plot(x_true, y_true, "r")
ax.plot(x, y, "+k")
ax.legend(["predicted", "true", "data"])
plt.savefig("result.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment