Skip to content

Instantly share code, notes, and snippets.

@Phylliida
Created February 20, 2024 23:50
Show Gist options
  • Save Phylliida/ec38ffba6addf8c65bc0fd2479b1e063 to your computer and use it in GitHub Desktop.
Save Phylliida/ec38ffba6addf8c65bc0fd2479b1e063 to your computer and use it in GitHub Desktop.
Noising mamba
from einops import rearrange
import torch
from functools import partial
from jaxtyping import Float
from transformer_lens.hook_points import HookPoint
import tqdm
import pandas as pd
import plotly.express as px
from mamba_lens import HookedMamba
from test_data import greater_than_data_generator, IOI_generator, ABC_TEMPLATES, BAC_TEMPLATES, BABA_TEMPLATES, BABA_LONG_TEMPLATES, BABA_LATE_IOS, BABA_EARLY_IOS
model = HookedMamba.from_pretrained("state-spaces/mamba-370m")
# Inference only: no autograd graphs are needed for any of the runs below.
torch.set_grad_enabled(False)
seed = 27
num_examples = 120
# IOI prompts from a single BABA template. torch.stack below requires every
# prompt to tokenize to the same length; one template presumably ensures that.
data = IOI_generator(templates=[BABA_TEMPLATES[0]], tokenizer=model.tokenizer, num_examples=num_examples, seed=seed)
batched_data = []
batched_correct = []
batched_incorrect = []
for i, (prompt, corrects, incorrects) in enumerate(data):
    if i < 3:
        # Spot-check a few generated examples.
        print(prompt, corrects, incorrects)
    batched_data.append(torch.tensor(model.tokenizer.encode(prompt), device=model.cfg.device))
    # Only the first token of each answer string is scored.
    batched_correct.append(model.tokenizer.encode(corrects[0])[0])
    batched_incorrect.append(model.tokenizer.encode(incorrects[0])[0])
if not batched_data:
    raise ValueError("IOI_generator produced no examples; nothing to batch")
prompt_lengths = sorted({len(tokens) for tokens in batched_data})
if len(prompt_lengths) > 1:
    raise ValueError(f"prompts tokenize to different lengths {prompt_lengths}; cannot stack them into one batch")
batched_data = torch.stack(batched_data)
# Keep the answer-token indices on the same device as the model's logits.
batched_correct = torch.tensor(batched_correct, device=model.cfg.device)
batched_incorrect = torch.tensor(batched_incorrect, device=model.cfg.device)
# Noise standard deviations to sweep, as plain Python floats (not 0-d tensors)
# so they work directly as hook arguments, plot labels and tick values.
points = torch.linspace(0, 0.1, 200).tolist()
# One result slot per noise level.
output_accuracies = torch.zeros([len(points)], device=model.cfg.device)
output_prs = torch.zeros([len(points)], device=model.cfg.device)
output_prs_incorrect = torch.zeros([len(points)], device=model.cfg.device)
def resid_pre_hook(
    resid_pre: "Float[torch.Tensor, 'B L D']",
    hook: "HookPoint",
    noise_std: "float | torch.Tensor",
) -> "Float[torch.Tensor, 'B L D']":
    """Return ``resid_pre`` plus i.i.d. zero-mean Gaussian noise.

    Intended as a forward hook (``noise_std`` is bound with
    ``functools.partial``); presumably the returned tensor replaces the hooked
    activation, per the TransformerLens hook convention.

    Args:
        resid_pre: the activation being hooked.
        hook: the HookPoint that fired; unused.
        noise_std: standard deviation of the added noise, as a Python float
            or a 0-d tensor. Must be non-negative.

    Returns:
        A new tensor with the same shape, dtype and device as ``resid_pre``.

    Raises:
        ValueError: if ``noise_std`` is negative.
    """
    if noise_std < 0:
        raise ValueError(f"noise_std must be non-negative, got {noise_std}")
    # randn_like follows resid_pre's dtype and device. A half/bfloat16
    # activation is therefore not silently upcast to float32 by the addition,
    # and the hook does not depend on a global model for its device.
    noise = torch.randn_like(resid_pre) * noise_std
    return resid_pre + noise
# Sweep the noise level. For each level, record how often the correct answer
# token still out-scores the incorrect one, and the mean probability of each.
batch_size = batched_data.shape[0]
# Index tensors are built once, on the model's device, and reused every iteration.
example_idx = torch.arange(batch_size, device=model.cfg.device)
correct_idx = batched_correct.to(model.cfg.device)
incorrect_idx = batched_incorrect.to(model.cfg.device)
# Wrap the sized list rather than the enumerate object, so tqdm knows the total
# and can show progress and an ETA.
for i, noise_std in enumerate(tqdm.tqdm(points)):
    # The same hook (same noise level) goes on the residual input of every layer.
    hook = partial(resid_pre_hook, noise_std=noise_std)
    hooks = [(f"blocks.{layer}.hook_resid_pre", hook) for layer in range(model.cfg.n_layers)]
    # Keep only the final position, where the answer tokens are scored.
    logits = model.run_with_hooks(input=batched_data, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)[:, -1]
    prs = torch.nn.functional.softmax(logits, dim=1)
    correct_logits = logits[example_idx, correct_idx]
    incorrect_logits = logits[example_idx, incorrect_idx]
    # Fraction of examples whose correct answer beats the incorrect one.
    output_accuracies[i] = (correct_logits > incorrect_logits).float().mean()
    output_prs[i] = prs[example_idx, correct_idx].mean()
    output_prs_incorrect[i] = prs[example_idx, incorrect_idx].mean()
def bar_chart(data, x_labels, y_label, title, font_size=None):
    """Show a plotly bar chart of ``data``, one bar per entry of ``x_labels``.

    Args:
        data: values to plot. A torch tensor is moved to the CPU; anything else
            is passed to ``pandas.DataFrame`` as-is. Column 0 is plotted.
        x_labels: one label per row of ``data``, used as the bar positions.
            Any iterable is accepted; 0-d tensor labels are converted to
            Python numbers.
        y_label: name given to the plotted column and the y axis.
        title: figure title.
        font_size: if given, also show every x label as a tick and set the
            figure font size.
    """
    x_labels = [x.item() if isinstance(x, torch.Tensor) else x for x in x_labels]
    values = data.cpu().numpy() if isinstance(data, torch.Tensor) else data
    # plotly express wants named rows/columns, but pandas names them with ints
    # by default, so relabel rows by x label and column 0 by y_label.
    frame = (
        pd.DataFrame(values)
        .rename(dict(enumerate(x_labels)), axis="rows")
        .rename({0: y_label}, axis="columns")
    )
    fig = px.bar(frame, y=y_label, x=x_labels, title=title)
    if font_size is not None:
        fig.update_layout(
            xaxis=dict(
                tickmode="array",
                tickvals=x_labels,
                ticktext=x_labels,
            ),
            font=dict(size=font_size, color="black"),
        )
    fig.show()
# One chart per recorded metric, all from the same noise sweep and in the
# same order: accuracy, then P(correct), then P(incorrect).
for metric_values, metric_name in (
    (output_accuracies, 'accuracy'),
    (output_prs, 'pr correct answer'),
    (output_prs_incorrect, 'pr incorrect answer'),
):
    bar_chart(
        data=metric_values,
        x_labels=points,
        y_label=metric_name,
        title='applying mean zero noise to input of every layer',
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment