Requires my fork of TransformerLens (https://github.com/UFO-101/TransformerLens) until this change is released in the main package: https://github.com/TransformerLensOrg/TransformerLens/pull/699
# %%
from datetime import datetime
from math import ceil, floor

import plotly.express as px
import torch as t
from auto_circuit.data import PromptDataLoader, PromptDataset
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.types import AblationType
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model, set_all_masks
from datasets import load_dataset
from embedding_lens.custom_tqdm import tqdm

# Disable all gradients globally
t.set_grad_enabled(False)
device = t.device("cuda:3" if t.cuda.is_available() else "cpu")

model_name = "gpt2-xl"
# model_name = "gpt2"
model = load_tl_model(model_name, device)
ds = load_dataset("NeelNanda/pile-10k")
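# Added sanity print: confirm which model was loaded before building the prompts.
# (n_layers and n_heads are both used further down, so this only echoes known config.)
print(f"Loaded {model_name}: {model.cfg.n_layers} layers, {model.cfg.n_heads} heads per layer")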
# %%
n_prompts = 100
model.tokenizer.padding_side = "left"  # type: ignore
pile10k_tokens: t.Tensor = model.tokenizer(
    [model.tokenizer.bos_token + seq for seq in ds["train"]["text"][:n_prompts]],  # type: ignore
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=100,
)["input_ids"].to(device)  # type: ignore
pile10k_tokens_shuffled: t.Tensor = pile10k_tokens[
    :, t.randperm(pile10k_tokens.size(1))
]
answers = [t.tensor([0]) for _ in pile10k_tokens]
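# Added sanity check: the "corrupt" prompts are the clean prompts with their token
# positions permuted, so the two tensors should have identical shapes.
assert pile10k_tokens.shape == pile10k_tokens_shuffled.shape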
# %%
dataset = PromptDataset(pile10k_tokens, pile10k_tokens_shuffled, answers, answers)
dataloader = PromptDataLoader(dataset, None, 0, batch_size=10, shuffle=False)
model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=False,
    device=device,
)
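# Illustrative note (added): after patchable_model wraps the network, model.edges holds
# the factorized src->dest edges and model.edge_name_dict maps readable names such as
# "MLP 1->MLP 4" or "A2.0->MLP 4" to edges with a .patch_idx into the patch masks.
# A quick count shows the scale of the factorized graph.
print(f"Factorized graph edges: {len(model.edges)}")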
# %%
mlp_1_to_4 = model.edge_name_dict[None]["MLP 1->MLP 4"].patch_idx
attn_2_0_to_mlp_4 = model.edge_name_dict[None]["A2.0->MLP 4"].patch_idx
mlp_2_to_4 = model.edge_name_dict[None]["MLP 2->MLP 4"].patch_idx
mlp_diff = attn_2_0_to_mlp_4[0] - mlp_1_to_4[0]
attn_diff = mlp_2_to_4[0] - attn_2_0_to_mlp_4[0]
assert mlp_diff == 1
assert attn_diff == model.cfg.n_heads
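# Added explanatory note on the offset arithmetic used below: the asserts above pin
# down the factorized source ordering. A block's MLP source is followed 1 index later
# by the next block's first attention-head source, and a block's first attention-head
# source is followed n_heads indices later by that block's MLP source. Walking back
# `layer_horizon` half-layer steps from an MLP destination therefore spans
#     n_heads * ceil(layer_horizon / 2) + floor(layer_horizon / 2)
# source slots, and from an attention or resid_post destination it spans
#     n_heads * floor(layer_horizon / 2) + ceil(layer_horizon / 2)
# slots; these are exactly the offsets computed per destination wrapper in the loop below.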
# %%
max_layer = max(edge.dest.layer for edge in model.edges)
print("max_layer", max_layer)
ablations = src_ablations(model, dataloader, AblationType.TOKENWISE_MEAN_CLEAN)
STEP_SIZE = 2

# Baseline: mean per-token loss with no edges ablated.
total_token_losses = t.tensor(0.0).to(device)
n_token_loss = 0
for batch in tqdm(dataloader):
    batch = batch.clean
    token_losses = model(batch, return_type="loss", loss_per_token=True)
    total_token_losses += token_losses.sum()
    n_token_loss += t.count_nonzero(token_losses).item()
baseline_loss = (total_token_losses / n_token_loss).item()

layer_horizon_losses = {}
for layer_horizon in tqdm(range(0, max_layer + 1, STEP_SIZE)):
    set_all_masks(model, 0.0)
    for wrapper in model.dest_wrappers:
        if "mlp" in wrapper.module_name:
            horizon_mask_idx_diff = model.cfg.n_heads * ceil(layer_horizon / 2) + floor(
                layer_horizon / 2
            )
        elif "attn" in wrapper.module_name or "resid_post" in wrapper.module_name:
            horizon_mask_idx_diff = model.cfg.n_heads * floor(layer_horizon / 2) + ceil(
                layer_horizon / 2
            )
        else:
            raise ValueError(f"Unknown module name: {wrapper.module_name}")
        # Ablate (mask = 1.0) every edge except those from the most recent
        # `horizon_mask_idx_diff` sources before this destination.
        wrapper.patch_mask.data[..., :-horizon_mask_idx_diff] = 1.0
        # print("wrapper.patch_mask", wrapper.patch_mask.shape)
        # print("horizon_mask_idx_diff", horizon_mask_idx_diff)
        # print("wrapper module", wrapper.module_name)
        # print("wrapper patch mask", wrapper.patch_mask)
    if layer_horizon == 0:
        # A slice of :-0 selects nothing above, so ablate every edge explicitly.
        set_all_masks(model, 1.0)
    with patch_mode(model, ablations):
        # draw_seq_graph(model)
        total_token_losses = t.tensor(0.0).to(device)
        n_token_loss = 0
        for idx, batch in tqdm(enumerate(dataloader)):
            batch = batch.clean
            token_losses = model(batch, return_type="loss", loss_per_token=True)
            total_token_losses += token_losses.sum()
            n_token_loss += t.count_nonzero(token_losses).item()
        layer_horizon_losses[layer_horizon] = (total_token_losses / n_token_loss).item()
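# Added convenience: echo the measured loss at each horizon (converted to full layers,
# matching the plot below) so the numbers are visible in the console as well.
for horizon, loss in layer_horizon_losses.items():
    print(f"layer horizon {horizon / 2:5.1f}: loss {loss:.4f}")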
# %%
layer_horizons = [layer_horizon / 2 for layer_horizon in layer_horizon_losses.keys()]
fig = px.line(x=layer_horizons, y=list(layer_horizon_losses.values()))
fig.update_layout(
    title=f"Layer Horizon vs Loss of {model_name.upper()} ({model.cfg.n_layers} layers)"
)
fig.update_xaxes(title="Layer Horizon")
fig.update_yaxes(title="Loss")
fig.add_hline(
    y=baseline_loss,
    line_dash="dot",
    line_color="black",
    annotation_text=f"Baseline Loss: {baseline_loss:.4f}",
    annotation_position="bottom right",
)
fig.update_layout(width=1000, height=600)
fig.show()
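# Added safeguard: fig.write_image needs a static-image export backend (the `kaleido`
# package) and the output directory must exist, so create `figures/` if it is missing.
import os
os.makedirs("figures", exist_ok=True)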
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
path = f"figures/layer_horizon_vs_loss_{timestamp}.png"
fig.write_image(path, scale=4)
# %%