import torch as t
import plotly.express as px
from tqdm import tqdm
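
# Context assumed by this gist (defined elsewhere in the author's notebook):
# `model` is a GPT-2-small-scale transformer whose unembed W_U is [d_vocab, d_model]
# and embed W_E is [d_model, d_vocab] (the shapes the indexing below relies on),
# plus the usual `device` and `to_numpy` helpers. A minimal, hypothetical sketch
# of the latter two, assuming plain PyTorch:
device = t.device("cuda" if t.cuda.is_available() else "cpu")

def to_numpy(tensor):
    # hypothetical stand-in for the notebook's helper, matching its usage below
    return tensor.detach().cpu().numpy()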
def top_1_acc(OV_circuit):
    # Fraction of tokens whose highest-scoring output token is the token itself
    return ((OV_circuit.argmax(dim=0) == t.arange(0, OV_circuit.size(-1)).to(device)).sum() / OV_circuit.size(-1)).item()

def top_5_acc(OV_circuit):
    # Fraction of tokens for which the token itself is among the top 5 output scores
    return ((OV_circuit.topk(5, dim=0)[1] == t.arange(0, OV_circuit.size(-1)).to(device)).sum() / OV_circuit.size(-1)).item()
W_U = model.unembed.W_U.to(device)  # [d_vocab, d_model]
W_E = model.embed.W_E.to(device)    # [d_model, d_vocab]

# Low-rank factors to train: W_o @ W_v is a rank-64 [d_model, d_model] matrix,
# matching the shape of a single attention head's W_O @ W_V
W_o = t.randn(768, 64, requires_grad=True, device=device)
W_v = t.randn(64, 768, requires_grad=True, device=device)

optimizer = t.optim.AdamW([W_o, W_v], lr=0.001)
steps = 120000
batch_size = 1024
pbar = tqdm(range(steps))
i = t.eye(batch_size).to(device)
i_diag = i.diag()  # the diagonal of the identity, i.e. just t.ones(batch_size)
top_1 = 0
top_5 = 0
for step in pbar:
    optimizer.zero_grad()  # clear the gradients of W_o and W_v
    # Sample a random batch of token indices and slice the corresponding
    # rows of W_U and columns of W_E
    indices = t.randperm(W_U.size(0))[:batch_size].to(device)
    W_U_subset = W_U[indices]
    W_E_subset = W_E[:, indices]
    combined_WoWv = W_U_subset @ W_o @ W_v @ W_E_subset
    softmax = t.softmax(combined_WoWv, dim=0).diag()
    # Stop pushing on diagonal entries that are already confidently correct
    softmax[softmax > 0.9] = 1
    loss = (t.nn.functional.mse_loss(softmax, i_diag)
            + t.nn.functional.mse_loss(combined_WoWv, i) / 160)
    # The second mse_loss just penalizes large activations: it keeps them from
    # exploding and makes them similar in magnitude to the head's
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        pbar.set_postfix({'info': "Loss: {:.4f}, Top-1: {:.4f}, Top-5: {:.4f}".format(
            loss.item(), top_1, top_5)})
    if step % 5000 == 0:
        # Periodically re-evaluate top-k accuracy on the full vocabulary
        try:
            t.cuda.empty_cache()
            del OV_circuit_full
        except NameError:
            pass
        OV_circuit_full = model.unembed.W_U @ W_o @ W_v @ model.embed.W_E
        top_1 = top_1_acc(OV_circuit_full)
        top_5 = top_5_acc(OV_circuit_full)
        del OV_circuit_full

# Final evaluation on the full vocabulary
try:
    t.cuda.empty_cache()
    del OV_circuit_full
except NameError:
    pass
OV_circuit_full = model.unembed.W_U @ W_o @ W_v @ model.embed.W_E
print("Top 1 accuracy for the trained OV circuit:", top_1_acc(OV_circuit_full))
print("Top 5 accuracy for the trained OV circuit:", top_5_acc(OV_circuit_full))
try:
    del OV_circuit_full
except NameError:
    pass
print("The trained matrix rank (expected to be 64):", t.linalg.matrix_rank(W_o @ W_v).item())
all_WoWvs = model.blocks[1].attn.W_O @ model.blocks[1].attn.W_V
print("The combained L1H4 + L1H10 W_O@W_V rank (expected to be 128):", t.linalg.matrix_rank(all_WoWvs[4] + all_WoWvs[10]).item())
px.imshow(
    to_numpy(W_U_subset @ (all_WoWvs[4] + all_WoWvs[10]) @ W_E_subset),
    labels={"x": "Output tokens", "y": "Input tokens"},
    width=800, height=800,
    title="Two copying OV circuits from L1",
).show()
px.imshow(
    to_numpy(combined_WoWv),
    labels={"x": "Output tokens", "y": "Input tokens"},
    width=800, height=800,
    title="Trained token-copying circuit",
).show()
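
# Also not in the original gist: a quick sanity check on the trained factorization is
# to inspect the singular values of W_o @ W_v; for a rank-64 matrix everything past
# the 64th value should be ~0 (t.linalg.svdvals returns them in descending order):
singular_values = t.linalg.svdvals((W_o @ W_v).detach())
print("64th singular value:", singular_values[63].item())
print("65th singular value (expected ~0):", singular_values[64].item())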