Requires my fork of TransformerLens (https://github.com/UFO-101/TransformerLens) - until this change is released in the main package: https://github.com/TransformerLensOrg/TransformerLens/pull/699
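One way to install the fork directly (assuming the relevant change is on its default branch, which the gist does not specify):

pip install git+https://github.com/UFO-101/TransformerLens.git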
# %%
from datetime import datetime
from math import ceil, floor
import plotly.express as px
import torch as t
from auto_circuit.data import PromptDataLoader, PromptDataset
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.types import AblationType
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model, set_all_masks
from datasets import load_dataset
from embedding_lens.custom_tqdm import tqdm
# Disable all gradients globally
t.set_grad_enabled(False)
device = t.device("cuda:3" if t.cuda.is_available() else "cpu")
model_name = "gpt2-xl"
# model_name = "gpt2"
model = load_tl_model(model_name, device)
ds = load_dataset("NeelNanda/pile-10k")
# %%
n_prompts = 100
model.tokenizer.padding_side = "left" # type: ignore
pile10k_tokens: t.Tensor = model.tokenizer(
    [model.tokenizer.bos_token + seq for seq in ds["train"]["text"][:n_prompts]],  # type: ignore
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=100,
)["input_ids"].to(device)  # type: ignore
pile10k_tokens_shuffled: t.Tensor = pile10k_tokens[
    :, t.randperm(pile10k_tokens.size(1))
]
answers = [t.tensor([0]) for _ in pile10k_tokens]
# %%
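# Build the auto_circuit dataset: clean prompts plus "corrupt" prompts made by
# shuffling token positions. The answer tensors are placeholders, since the
# evaluation below uses per-token LM loss rather than answer logits.
# (Assumption: PromptDataset takes clean prompts, corrupt prompts, answers,
# and wrong answers in that order.)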
dataset = PromptDataset(pile10k_tokens, pile10k_tokens_shuffled, answers, answers)
dataloader = PromptDataLoader(dataset, None, 0, batch_size=10, shuffle=False)
model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=False,
    device=device,
)
# %%
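# Sanity check the ordering of patch-mask source indices: MLP 1 is immediately
# followed by layer 2's attention heads, which are followed by MLP 2. So a step
# from an MLP source to the next layer's heads costs 1 index, and a step from
# the heads to the next MLP costs n_heads indices. The layer-horizon masking
# below relies on this ordering.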
mlp_1_to_4 = model.edge_name_dict[None]["MLP 1->MLP 4"].patch_idx
attn_2_0_to_mlp_4 = model.edge_name_dict[None]["A2.0->MLP 4"].patch_idx
mlp_2_to_4 = model.edge_name_dict[None]["MLP 2->MLP 4"].patch_idx
mlp_diff = attn_2_0_to_mlp_4[0] - mlp_1_to_4[0]
attn_diff = mlp_2_to_4[0] - attn_2_0_to_mlp_4[0]
assert mlp_diff == 1
assert attn_diff == model.cfg.n_heads
# %%
max_layer = max(edge.dest.layer for edge in model.edges)
print("max_layer", max_layer)
ablations = src_ablations(model, dataloader, AblationType.TOKENWISE_MEAN_CLEAN)
STEP_SIZE = 2
# Baseline per-token loss with no ablations. count_nonzero is used so that
# positions with zero loss (e.g. padding) don't inflate the denominator.
total_token_losses = t.tensor(0.0).to(device)
n_token_loss = 0
for batch in tqdm(dataloader):
    batch = batch.clean
    token_losses = model(batch, return_type="loss", loss_per_token=True)
    total_token_losses += token_losses.sum()
    n_token_loss += t.count_nonzero(token_losses).item()
baseline_loss = (total_token_losses / n_token_loss).item()
layer_horizon_losses = {}
for layer_horizon in tqdm(range(0, max_layer + 1, STEP_SIZE)):
    set_all_masks(model, 0.0)
    # For each destination node, ablate (mask = 1) every edge whose source lies
    # more than `layer_horizon` sub-layers earlier. Because source indices
    # alternate between n_heads attention heads and one MLP per layer, the
    # index offset depends on whether the destination is an MLP or an
    # attention / resid_post node.
    for wrapper in model.dest_wrappers:
        if "mlp" in wrapper.module_name:
            horizon_mask_idx_diff = model.cfg.n_heads * ceil(layer_horizon / 2) + floor(
                layer_horizon / 2
            )
        elif "attn" in wrapper.module_name or "resid_post" in wrapper.module_name:
            horizon_mask_idx_diff = model.cfg.n_heads * floor(layer_horizon / 2) + ceil(
                layer_horizon / 2
            )
        else:
            raise ValueError(f"Unknown module name: {wrapper.module_name}")
        wrapper.patch_mask.data[..., :-horizon_mask_idx_diff] = 1.0
        # print("wrapper.patch_mask", wrapper.patch_mask.shape)
        # print("horizon_mask_idx_diff", horizon_mask_idx_diff)
        # print("wrapper module", wrapper.module_name)
        # print("wrapper patch mask", wrapper.patch_mask)
    if layer_horizon == 0:
        # A horizon of 0 keeps no edges at all: ablate everything.
        set_all_masks(model, 1.0)
    with patch_mode(model, ablations):
        # draw_seq_graph(model)
        total_token_losses = t.tensor(0.0).to(device)
        n_token_loss = 0
        for idx, batch in tqdm(enumerate(dataloader)):
            batch = batch.clean
            token_losses = model(batch, return_type="loss", loss_per_token=True)
            total_token_losses += token_losses.sum()
            n_token_loss += t.count_nonzero(token_losses).item()
    layer_horizon_losses[layer_horizon] = (total_token_losses / n_token_loss).item()
# %%
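# layer_horizon counts attention/MLP sub-layers, so divide by 2 to plot the
# horizon in full transformer blocks.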
layer_horizons = [layer_horizon / 2 for layer_horizon in layer_horizon_losses.keys()]
fig = px.line(x=layer_horizons, y=list(layer_horizon_losses.values()))
fig.update_layout(
    title=f"Layer Horizon vs Loss of {model_name.upper()} ({model.cfg.n_layers} layers)"
)
fig.update_xaxes(title="Layer Horizon")
fig.update_yaxes(title="Loss")
fig.add_hline(
    y=baseline_loss,
    line_dash="dot",
    line_color="black",
    annotation_text=f"Baseline Loss: {baseline_loss:.4f}",
    annotation_position="bottom right",
)
fig.update_layout(width=1000, height=600)
fig.show()
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
path = f"figures/layer_horizon_vs_loss_{timestamp}.png"
fig.write_image(path, scale=4)
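# Note: write_image needs the kaleido package installed and an existing
# "figures/" directory; this script does not create either.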
# %%