Skip to content

Instantly share code, notes, and snippets.

@thistleknot
Last active April 29, 2024 03:06
Show Gist options
  • Save thistleknot/b936477ee82ce608b3c7f47381f6b15d to your computer and use it in GitHub Desktop.
Save thistleknot/b936477ee82ce608b3c7f47381f6b15d to your computer and use it in GitHub Desktop.
control vector training
!pip install repeng
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel, DatasetEntry
import numpy as np
import os
from datasets import load_dataset
## Control Vector Generation ##
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel, DatasetEntry
from tqdm import tqdm
import random
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel, DatasetEntry
from datasets import load_dataset
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use token id 0 as the pad token — presumably because the TinyLlama tokenizer
# ships without one; TODO confirm id 0 is a safe pad for this vocab.
tokenizer.pad_token_id = 0
# Load in fp16 and move to GPU before wrapping with repeng's ControlModel.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to("cuda")
# Wrap the last layers (offsets -5 down to -17) so control vectors can be
# injected into those hidden states.
model = ControlModel(model, list(range(-5, -18, -1)))
# NOTE(review): weights were already loaded as float16 above, so this .half()
# looks redundant — confirm it is needed for the ControlModel wrapper.
model.half()
def format_chat_template(row):
    """Render the 'chosen' and 'rejected' chat turns of *row* into plain
    prompt strings using the tokenizer's chat template (mutates *row*)."""
    for column in ("chosen", "rejected"):
        row[column] = tokenizer.apply_chat_template(row[column], tokenize=False)
    return row
def subset_dataset(dataset_name, n, map=False):
    """Load a Hugging Face dataset and return a deterministic n-row subset.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier passed to ``load_dataset`` (all splits merged
        via ``split="all"``).
    n : int
        Number of rows to keep after shuffling with a fixed seed (42),
        so repeated runs select the same rows.
    map : bool, optional
        When True, apply ``format_chat_template`` over the subset to render
        the chat-formatted 'chosen'/'rejected' columns to prompt strings.
        NOTE(review): the name shadows the builtin ``map`` but is kept for
        backward compatibility — the caller passes ``map=True`` by keyword.

    Returns
    -------
    datasets.Dataset
        The shuffled n-row subset.
    """
    dataset = load_dataset(dataset_name, split="all")
    dataset = dataset.shuffle(seed=42).select(range(n))
    if map:
        dataset = dataset.map(
            format_chat_template,
            num_proc=os.cpu_count(),
        )
    return dataset
def make_orpo_dataset(source=None) -> list[DatasetEntry]:
    """Build repeng ``DatasetEntry`` pairs from an ORPO preference dataset.

    Rewrites the TinyLlama chat-template role markers (``<|user|>`` /
    ``<|assistant|>``) into Llama-style ``[INST]`` / ``[/INST]`` tags so the
    training prompts match the inference format used elsewhere in this script.

    Parameters
    ----------
    source : optional
        Mapping with 'chosen' and 'rejected' string columns. Defaults to the
        module-level ``orpo_1k`` subset, preserving the original no-argument
        call.

    Returns
    -------
    list[DatasetEntry]
        One entry per (chosen, rejected) pair, with positive=chosen and
        negative=rejected.
    """
    if source is None:
        # Backward-compatible fallback to the global built below.
        source = orpo_1k

    def _to_inst(text):
        # Translate chat-template role markers into [INST] tags.
        return text.replace('<|user|>', '[INST]').replace('<|assistant|>', '[/INST]')

    preferred = [_to_inst(c) for c in source['chosen']]
    rejected = [_to_inst(c) for c in source['rejected']]
    return [
        DatasetEntry(positive=p, negative=r)
        for p, r in zip(preferred, rejected)
    ]
# 1k-row deterministic subsets: Dolly supplies evaluation instructions,
# the ORPO mix supplies chosen/rejected preference pairs (map=True renders
# them to prompt strings via the chat template).
dolly_1k = subset_dataset("databricks/databricks-dolly-15k", 1000)
orpo_1k = subset_dataset("mlabonne/orpo-dpo-mix-40k", 1000, map=True)
orpo_dataset = make_orpo_dataset()
# Pick one random Dolly instruction to use as the demo prompt below.
example = random.sample(dolly_1k['instruction'],1)[0]
print(example)
# Baseline (uncontrolled) greedy generation for the sampled instruction;
# inputs are moved to CUDA to match the model's device.
out = model.generate(
**tokenizer(
f"[INST] {example} [/INST]",
return_tensors="pt"
).to('cuda'),
do_sample=False,
max_new_tokens=128,
repetition_penalty=1.1,
)
print(tokenizer.decode(out.squeeze()).strip())
# train the vector—takes less than a minute!
orpo_vector = ControlVector.train(model, tokenizer, orpo_dataset)
# Sweep the control strength and compare generations: a negative strength
# steers toward the "rejected" direction, positive toward "chosen".
for strength in (-2.2, 1, 2.2):
    print(f"strength={strength}")
    model.set_control(orpo_vector, strength)
    # Bug fix: the inputs must live on the same device as the model — the
    # original omitted the .to("cuda") present in the baseline call above,
    # which raises a device-mismatch error during generation.
    inputs = tokenizer(
        f"[INST] {example} [/INST]",
        return_tensors="pt",
    ).to("cuda")
    out = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=128,
        repetition_penalty=1.1,
    )
    print(tokenizer.decode(out.squeeze()).strip())
    print()
# Clear the control vector so any later generations are unaffected.
model.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment