Skip to content

Instantly share code, notes, and snippets.

@thomashikaru
Last active July 11, 2020 19:20
Show Gist options
  • Save thomashikaru/7b3b08eac4fe638692adfb36fa365652 to your computer and use it in GitHub Desktop.
Save thomashikaru/7b3b08eac4fe638692adfb36fa365652 to your computer and use it in GitHub Desktop.
import torch
import plotly.express as px
import pandas as pd
# Network / data dimensions.
N = 128      # batch size (number of training samples)
D_in = 2     # input features per sample
H = 1024     # hidden-layer width
D_out = 1    # output units per sample

# Random training set: inputs drawn uniformly from [0, 1), labels drawn
# uniformly from the integers {0, 1} (int64 tensor of shape (N, D_out)).
x = torch.rand((N, D_in))
y = torch.randint(low=0, high=2, size=(N, D_out))
# Plot the randomly generated training points, colored by their label.
# Tensors are converted to NumPy explicitly so pandas receives plain numeric
# columns; handing it torch tensors directly relies on version-specific
# conversion behaviour. Column slices x[:, i] are already 1-D, while y has
# shape (N, D_out) and still needs flattening.
df = pd.DataFrame({
    "x": x[:, 0].numpy(),
    "y": x[:, 1].numpy(),
    "class": y.flatten().numpy(),
})
fig = px.scatter(df, x="x", y="y", color="class", color_continuous_scale="tealrose")
fig.show()
# Step size used by the Adam optimizer below.
learning_rate = 2e-3

# Binary cross-entropy on probabilities, suited to 0/1 classification.
# BCELoss expects its input to already lie in [0, 1], which is why the model
# ends in a Sigmoid layer.
# NOTE(review): PyTorch documents BCEWithLogitsLoss (Sigmoid fused into the
# loss) as the more numerically stable choice. Switching would also change
# what model(...) returns, so every consumer of the model must be updated
# together.
loss_fn = torch.nn.BCELoss()

# Fully connected network: D_in -> H (ReLU) -> D_out (Sigmoid).
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
    torch.nn.Sigmoid(),
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Training schedule.
num_steps = 60000
log_every = 100  # record the loss once every `log_every` steps

# (step, loss) pairs recorded during training, used for the loss curve.
ts, losses = [], []

# BCELoss needs a floating-point target; `y` holds int64 labels. The
# conversion is loop-invariant, so do it once instead of on every step.
y_target = y.float()

# Full-batch training: every step uses all samples in `x`.
for t in range(num_steps):
    y_pred = model(x)  # already floating point; no extra cast needed
    loss = loss_fn(y_pred, y_target)
    if t % log_every == 0:
        ts.append(t)
        # .item() returns a plain Python float without going through the
        # discouraged `.data` attribute.
        losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Sample many random points from the same [0, 1) input range as the training
# data, so the model's predictions trace out its learned decision surface.
c = torch.rand(32000, D_in)

# Inference only. no_grad() stops autograd from recording a graph. Otherwise
# it would keep the 32000 x H hidden activations alive for a backward pass
# that never happens.
with torch.no_grad():
    d = model(c)

# Store the sampled points and the model's Sigmoid outputs in a DataFrame.
# Tensors are converted to NumPy explicitly for pandas. d has shape
# (32000, D_out), hence the flatten.
df2 = pd.DataFrame({
    "x": c[:, 0].numpy(),
    "y": c[:, 1].numpy(),
    "class": d.flatten().numpy(),
})
fig2 = px.scatter(df2, x="x", y="y", color="class", color_continuous_scale="tealrose")
fig2.show()

# Plot the recorded training loss as a function of training step.
fig3 = px.scatter(x=ts, y=losses)
fig3.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment