Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save UFO-101/7b5e27291424029d092d8798ee1a1161 to your computer and use it in GitHub Desktop.
Save UFO-101/7b5e27291424029d092d8798ee1a1161 to your computer and use it in GitHub Desktop.
#%%
import torch as t
from embedding_lens.custom_tqdm import tqdm
from datetime import datetime
from auto_circuit.data import PromptDataLoader, PromptDataset
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.types import AblationType
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model
from datasets import load_dataset
import plotly.express as px
# Disable all gradients globally (this script only runs forward passes).
t.set_grad_enabled(False)
# Preferred GPU index. Fall back to the default GPU when the host exposes fewer
# devices, instead of failing later on a non-existent "cuda:3".
PREFERRED_CUDA_INDEX = 3
if t.cuda.is_available():
    cuda_index = (
        PREFERRED_CUDA_INDEX
        if t.cuda.device_count() > PREFERRED_CUDA_INDEX
        else 0
    )
    device = t.device(f"cuda:{cuda_index}")
else:
    device = t.device("cpu")
model_name = "gpt2"
model = load_tl_model(model_name, device)
ds = load_dataset("NeelNanda/pile-10k")
#%%
n_prompts = 100
# Left padding keeps each prompt's last real token in the final sequence position.
model.tokenizer.padding_side = "left"  # type: ignore
prompt_texts = ds["train"]["text"][:n_prompts]  # type: ignore
tokenizer_output = model.tokenizer(  # type: ignore
    prompt_texts,
    max_length=100,
    truncation=True,
    padding=True,
    return_tensors="pt",
)
pile10k_tokens: t.Tensor = tokenizer_output["input_ids"].to(device)  # type: ignore
# "Corrupted" prompts: a single random permutation of sequence positions,
# applied identically to every row (no RNG seed is set in this script).
position_perm = t.randperm(pile10k_tokens.size(1))
pile10k_tokens_shuffled: t.Tensor = pile10k_tokens[:, position_perm]
# One placeholder answer tensor (containing 0) per prompt.
# NOTE(review): presumably only needed to satisfy PromptDataset's interface — confirm.
answers = [t.tensor([0]) for _ in range(pile10k_tokens.size(0))]
#%%
# NOTE(review): argument order assumed to be (clean, corrupt, answers,
# wrong_answers); the same placeholder answers fill both answer slots — confirm.
dataset = PromptDataset(
    pile10k_tokens,
    pile10k_tokens_shuffled,
    answers,
    answers,
)
# NOTE(review): the positional None / 0 are presumably seq_len and diverge_idx
# in auto_circuit's PromptDataLoader — confirm.
dataloader = PromptDataLoader(dataset, None, 0, batch_size=50)
# Rebind `model` to the patchable wrapper; its `edges` and patch_mode are used
# by the ablation sweep.
model = patchable_model(
    model,
    device=device,
    factorized=True,
    separate_qkv=False,
    slice_output="last_seq",
)
#%%
# Layer-horizon ablation sweep: for each horizon h, every edge whose
# dest.layer - src.layer exceeds h is patched (via patch_mode) with the
# ablations from src_ablations(..., TOKENWISE_MEAN_CLEAN). The mean next-token
# loss on the clean prompts is then recorded, alongside an unpatched baseline.
IGNORE_INDEX = -100  # target value skipped by cross_entropy's ignore_index


def _next_token_loss_terms(model: t.nn.Module, tokens: t.Tensor) -> tuple[float, int]:
    """Return (summed next-token cross-entropy, number of scored targets) for a batch.

    Logits at position i are scored against token i + 1. Targets equal to the
    tokenizer's pad id are skipped: with left padding they only ask the model to
    predict more padding, which is not part of the text.
    NOTE(review): assumes model(tokens) returns (batch, seq, vocab) logits, as the
    slicing implies. If pad_token_id equals the EOS id (common for GPT-2), genuine
    EOS targets are skipped as well — confirm this is acceptable.
    """
    logits = model(tokens)[:, :-1].flatten(0, 1)
    targets = tokens[:, 1:].flatten()
    pad_id = model.tokenizer.pad_token_id
    if pad_id is not None:
        # masked_fill is out-of-place, so the caller's batch tensor is never modified.
        targets = targets.masked_fill(targets == pad_id, IGNORE_INDEX)
    loss_sum = t.nn.functional.cross_entropy(
        logits, targets, ignore_index=IGNORE_INDEX, reduction="sum"
    )
    n_scored = int((targets != IGNORE_INDEX).sum().item())
    return loss_sum.item(), n_scored


def _mean_next_token_loss(model: t.nn.Module, dataloader) -> float:
    """Return the token-weighted mean next-token loss over every batch's clean prompts.

    Raises:
        ValueError: if no target token was scored (e.g. all targets are padding).
    """
    total_loss, total_scored = 0.0, 0
    for batch in tqdm(dataloader):
        batch_loss, batch_scored = _next_token_loss_terms(model, batch.clean)
        total_loss += batch_loss
        total_scored += batch_scored
    if total_scored == 0:
        raise ValueError("no non-padding target tokens to score")
    return total_loss / total_scored


max_layer = max(edge.dest.layer for edge in model.edges)
print("max_layer", max_layer)
ablations = src_ablations(model, dataloader, AblationType.TOKENWISE_MEAN_CLEAN)
STEP_SIZE = 2
baseline_loss = _mean_next_token_loss(model, dataloader)
layer_horizon_losses = {}
for layer_horizon in tqdm(range(0, max_layer + 1, STEP_SIZE)):
    # Patch only the edges that span more than `layer_horizon` graph layers.
    edges_to_patch = [
        edge
        for edge in model.edges
        if edge.dest.layer - edge.src.layer > layer_horizon
    ]
    print("layer_horizon", layer_horizon, "edges_to_patch", len(edges_to_patch))
    with patch_mode(model, ablations, edges_to_patch):
        layer_horizon_losses[layer_horizon] = _mean_next_token_loss(model, dataloader)
#%%
from pathlib import Path

# Convert graph-layer horizons to transformer-block units for the x axis.
# NOTE(review): assumes auto_circuit assigns two graph layers per transformer
# block (attention + MLP) — confirm against its node layer numbering.
GRAPH_LAYERS_PER_BLOCK = 2
layer_horizons = [h / GRAPH_LAYERS_PER_BLOCK for h in layer_horizon_losses]
fig = px.line(x=layer_horizons, y=list(layer_horizon_losses.values()))
fig.update_layout(title=f"Layer Horizon vs Loss of {model_name.upper()} ({model.cfg.n_layers} layers)")
fig.update_xaxes(title="Layer Horizon")
fig.update_yaxes(title="Loss")
fig.add_hline(
    y=baseline_loss,
    line_dash="dot",
    line_color="black",
    annotation_text=f"Baseline Loss: {baseline_loss:.4f}",
    annotation_position="bottom right",
)
fig.show()
# Ensure the output directory exists; writing into a missing directory fails.
Path("figures").mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
path = f"figures/layer_horizon_vs_loss_{timestamp}.png"
fig.write_image(path, scale=4)
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment