# %%
from typing import Tuple

import einops
import numpy as np
import pandas as pd  # type: ignore
import seaborn as sns  # type: ignore
import torch

n_steps = 10000


def train(
    steps=n_steps, weight_decay=0.0, scale=10.0, lr=0.01
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    # Ground truth: a rank-1 matrix C = a b^T, with the entries of a and b drawn
    # uniformly from [-scale/2, scale/2).
    a = torch.rand((100,)) * scale - scale / 2
    b = torch.rand((100,)) * scale - scale / 2
    C = torch.einsum("i,j->ij", a, b)
    # Learnable factors x and y, initialized the same way as a and b.
    x = torch.autograd.Variable(
        torch.rand((100,)) * scale - scale / 2, requires_grad=True
    )
    y = torch.autograd.Variable(
        torch.rand((100,)) * scale - scale / 2, requires_grad=True
    )
    learning_rate = lr
    # Loss is the MSE between C and Z = x y^T.
    # Do gradient descent on x and y with SGD while logging the values of x, y, and the loss.
    x_vals = []
    y_vals = []
    losses = []
    optim = torch.optim.SGD([x, y], lr=learning_rate, weight_decay=weight_decay)
    for i in range(steps):
        Z = torch.einsum("i,j->ij", x, y)
        loss = torch.mean((C - Z) ** 2)
        loss.backward()
        optim.step()
        optim.zero_grad()
        x_vals.append(x.detach().numpy().copy())
        y_vals.append(y.detach().numpy().copy())
        losses.append(loss.detach().numpy())
    return (
        np.array(x_vals),
        np.array(y_vals),
        np.array(losses),
        a.detach().numpy(),
        b.detach().numpy(),
    )
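# %%
# Because C = a b^T has rank 1, the rank-1 model Z = x y^T can fit it exactly:
# the MSE is zero whenever x = c * a and y = b / c for some nonzero scalar c.
# A quick, purely illustrative sanity check of the shapes `train` returns
# (a short run with the defaults above, separate from the experiments below):
_x, _y, _loss, _a, _b = train(steps=10)
assert _x.shape == (10, 100) and _y.shape == (10, 100)
assert _loss.shape == (10,) and _a.shape == (100,) and _b.shape == (100,)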
# %%
# Save the values of x, y, and the loss for 10 independent runs ("seeds").
x_vals = []
y_vals = []
losses = []
a_list = []
b_list = []
for i in range(10):
    x, y, loss, a, b = train()
    x_vals.append(x)
    y_vals.append(y)
    losses.append(loss)
    a_list.append(a)
    b_list.append(b)
# %%
# Put them into a long-format dataframe with one row per (seed, step).
x_vals = np.array(x_vals)
y_vals = np.array(y_vals)
losses = np.array(losses)
a_np = np.array(a_list)
b_np = np.array(b_list)
df = pd.DataFrame(
    {
        "x_norms": np.linalg.norm(x_vals, axis=2).flatten(),
        "y_norms": np.linalg.norm(y_vals, axis=2).flatten(),
        "x_a_inner_product": np.einsum("ijk,ik->ij", x_vals, a_np).flatten(),
        "y_b_inner_product": np.einsum("ijk,ik->ij", y_vals, b_np).flatten(),
        "loss": losses.flatten(),
        "seed": np.repeat(np.arange(10), n_steps),
        "steps": np.tile(np.arange(n_steps), 10),
    }
)
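# The einsum signature "ijk,ik->ij" contracts the 100-dimensional feature axis:
# for seed i and step j it computes dot(x_vals[i, j], a_np[i]). An equivalent
# spelling (for illustration only) is (x_vals * a_np[:, None, :]).sum(axis=-1).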
# %%
# Plot the loss over time for each seed using seaborn.
# Plot only every 50th point to make the figure easier to read.
df = df[df["steps"] % 50 == 0]
fig = sns.lineplot(
    data=df,
    x="steps",
    y="loss",
    errorbar=None,
    color="xkcd:dark blue",
    label="Average Loss",
)
fig.set_title("Loss over time, weight decay = 0.0")
fig.set_xlabel("Steps")
fig.set_ylabel("Loss")
# Add the loss curve for each individual seed to `fig` with low alpha.
for i in range(10):
    sns.lineplot(
        data=df[df.seed == i],
        x="steps",
        y="loss",
        alpha=0.3,
        ax=fig,
        label=f"Seed {i}",
        color="blue",
        dashes=True,
        legend=False,
    )
fig.figure.savefig("loss_over_time.png")
# %%
# Repeat the above experiment with weight decay.
weight_decay = 0.1
x_vals_wd = []
y_vals_wd = []
losses_wd = []
a_list_wd = []
b_list_wd = []
for i in range(10):
    x, y, loss, a, b = train(weight_decay=weight_decay)
    x_vals_wd.append(x)
    y_vals_wd.append(y)
    losses_wd.append(loss)
    a_list_wd.append(a)
    b_list_wd.append(b)
x_vals_wd = np.array(x_vals_wd)
y_vals_wd = np.array(y_vals_wd)
losses_wd = np.array(losses_wd)
a_np_wd = np.array(a_list_wd)
b_np_wd = np.array(b_list_wd)
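# For reference: `weight_decay` in torch.optim.SGD adds weight_decay * param to
# each parameter's gradient at every step, i.e. an L2 penalty of
# (weight_decay / 2) * (||x||^2 + ||y||^2) on the objective, so these runs trade
# reconstruction error against the norms of x and y.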
df_wd = pd.DataFrame(
    {
        "x_norms": np.linalg.norm(x_vals_wd, axis=2).flatten(),
        "y_norms": np.linalg.norm(y_vals_wd, axis=2).flatten(),
        "x_a_inner_product": np.einsum("ijk,ik->ij", x_vals_wd, a_np_wd).flatten(),
        "y_b_inner_product": np.einsum("ijk,ik->ij", y_vals_wd, b_np_wd).flatten(),
        "loss": losses_wd.flatten(),
        "seed": np.repeat(np.arange(10), n_steps),
        "steps": np.tile(np.arange(n_steps), 10),
    }
)
# %%
df_wd = df_wd[df_wd["steps"] % 50 == 0]
fig = sns.lineplot(
    data=df_wd,
    x="steps",
    y="loss",
    errorbar=None,
    color="xkcd:dark red",
    label="Average Loss",
)
for i in range(10):
    sns.lineplot(
        data=df_wd[df_wd.seed == i],
        x="steps",
        y="loss",
        alpha=0.3,
        ax=fig,
        label=f"Seed {i}",
        color="red",
        legend=False,
    )
fig.set_title("Loss over time, weight decay = 0.1")
fig.set_xlabel("Steps")
fig.set_ylabel("Loss")
fig.figure.savefig("loss_over_time_wd.png")
# %%
# Plot the weight norms of x and y over time on the same figure.
fig = sns.lineplot(
    data=df,
    x="steps",
    y="x_norms",
    color="xkcd:dark blue",
    label="Average Norm of x",
)
fig = sns.lineplot(
    data=df,
    x="steps",
    y="y_norms",
    color="xkcd:blue",
    label="Average Norm of y",
    ax=fig,
)
fig.set_title("Norms of x and y over time, weight decay = 0.0")
fig.set_xlabel("Steps")
fig.set_ylabel("Norm")
fig.figure.savefig("norms_over_time.png")
# %%
# Plot the weight norms of x and y over time on the same figure, with weight decay.
fig = sns.lineplot(
    data=df_wd,
    x="steps",
    y="x_norms",
    label="Average Norm of x",
    color="xkcd:dark red",
)
fig = sns.lineplot(
    data=df_wd,
    x="steps",
    y="y_norms",
    label="Average Norm of y",
    color="xkcd:red",
    ax=fig,
)
fig.set_title("Norms of x and y over time, weight decay = 0.1")
fig.set_xlabel("Steps")
fig.set_ylabel("Norm")
fig.figure.savefig("norms_over_time_wd.png")
# %%
# Plot the absolute value of the inner products of x with a and of y with b over time.
df["abs_x_a_inner_product"] = np.abs(df["x_a_inner_product"])
df["abs_y_b_inner_product"] = np.abs(df["y_b_inner_product"])
fig = sns.lineplot(
    data=df,
    x="steps",
    y="abs_x_a_inner_product",
    color="xkcd:dark blue",
    label="Average Inner Product of x and a",
)
fig = sns.lineplot(
    data=df,
    x="steps",
    y="abs_y_b_inner_product",
    color="xkcd:blue",
    label="Average Inner Product of y and b",
    ax=fig,
)
fig.set_title(
    "Absolute value of inner products with truth over time, weight decay = 0.0"
)
fig.set_xlabel("Steps")
fig.set_ylabel("Inner Product")
fig.figure.savefig("inner_products_over_time.png")
# %%
# Finally, plot the absolute value of the same inner products over time, with weight decay.
df_wd["abs_x_a_inner_product"] = np.abs(df_wd["x_a_inner_product"])
df_wd["abs_y_b_inner_product"] = np.abs(df_wd["y_b_inner_product"])
fig = sns.lineplot(
    data=df_wd,
    x="steps",
    y="abs_x_a_inner_product",
    color="xkcd:dark red",
    label="Average Inner Product of x and a",
)
fig = sns.lineplot(
    data=df_wd,
    x="steps",
    y="abs_y_b_inner_product",
    color="xkcd:light red",
    label="Average Inner Product of y and b",
    ax=fig,
)
fig.set_title(
    "Absolute value of inner products with truth over time, weight decay = 0.1"
)
fig.set_xlabel("Steps")
fig.set_ylabel("Inner Product")
fig.figure.savefig("inner_products_over_time_wd.png")