-
-
Save thistleknot/b936477ee82ce608b3c7f47381f6b15d to your computer and use it in GitHub Desktop.
Control vector training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
!pip install repeng | |
import json | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from repeng import ControlVector, ControlModel, DatasetEntry | |
import numpy as np | |
import os | |
from datasets import load_dataset | |
## Control Vector Generation ## | |
import json | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from repeng import ControlVector, ControlModel, DatasetEntry | |
from tqdm import tqdm | |
import random | |
from datasets import load_dataset | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from repeng import ControlVector, ControlModel, DatasetEntry | |
from datasets import load_dataset | |
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use token id 0 for padding — presumably the tokenizer ships without a pad token; TODO confirm
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to("cuda")
# Wrap with repeng's ControlModel, hooking layers -5 through -17 (counted from the end)
model = ControlModel(model, list(range(-5, -18, -1)))
# Keep everything in fp16 to fit GPU memory
model.half()
def format_chat_template(row):
    """Render both preference columns of *row* with the tokenizer's chat template.

    Rewrites row["chosen"] and row["rejected"] in place as untokenized text and
    returns the mutated row, as expected by datasets.Dataset.map.
    """
    for column in ("chosen", "rejected"):
        row[column] = tokenizer.apply_chat_template(row[column], tokenize=False)
    return row
# NOTE(review): original comment read "volatile handle" — the returned Dataset
# wraps a freshly downloaded/shuffled view, so don't assume it is cached.
def subset_dataset(dataset_name, n, map=False):
    """Load *dataset_name* from the Hugging Face hub and keep a random *n*-row subset.

    Parameters
    ----------
    dataset_name : str
        Hub identifier passed to ``load_dataset`` (all splits are merged).
    n : int
        Number of rows to keep after a deterministic shuffle.
    map : bool, optional
        When True, run ``format_chat_template`` over the subset (needed for
        preference datasets whose "chosen"/"rejected" columns are message lists).
        The name shadows the ``map`` builtin, but it is kept because callers
        pass it by keyword.

    Returns
    -------
    datasets.Dataset
    """
    dataset = load_dataset(dataset_name, split="all")
    # Fixed seed so every run draws the same subset
    dataset = dataset.shuffle(seed=42).select(range(n))
    if map:
        dataset = dataset.map(
            format_chat_template,
            num_proc=os.cpu_count(),
        )
    return dataset
def make_orpo_dataset(source=None) -> list[DatasetEntry]:
    """Build repeng ``DatasetEntry`` pairs from an ORPO-style preference dataset.

    Parameters
    ----------
    source : optional
        Mapping with string columns ``'chosen'`` and ``'rejected'`` (e.g. a
        ``datasets.Dataset`` already rendered to text). Defaults to the
        module-level ``orpo_1k`` for backward compatibility.

    Returns
    -------
    list[DatasetEntry]
        One entry per row: positive = chosen completion, negative = rejected.
    """
    if source is None:
        source = orpo_1k  # backward-compatible default: module-level 1k subset

    def _to_inst_markers(text):
        # Rewrite TinyLlama chat markers into the [INST] style used in the prompts
        return text.replace('<|user|>', '[INST]').replace('<|assistant|>', '[/INST]')

    preferred = [_to_inst_markers(c) for c in source['chosen']]
    rejected = [_to_inst_markers(c) for c in source['rejected']]
    return [
        DatasetEntry(positive=pos, negative=neg)
        for pos, neg in zip(preferred, rejected)
    ]
# Draw 1000-example subsets (downloads the datasets on first run)
dolly_1k = subset_dataset("databricks/databricks-dolly-15k", 1000)
orpo_1k = subset_dataset("mlabonne/orpo-dpo-mix-40k", 1000, map=True)
# Pair chosen/rejected completions into repeng DatasetEntry objects
orpo_dataset = make_orpo_dataset()
# Pick one Dolly instruction at random as the demo prompt
example = random.sample(dolly_1k['instruction'],1)[0]
print(example)
# Baseline generation with no control vector applied — greedy decoding,
# so the output is deterministic and comparable to the controlled runs below.
out = model.generate(
    **tokenizer(
        f"[INST] {example} [/INST]",
        return_tensors="pt"
    ).to('cuda'),  # inputs must live on the same device as the model
    do_sample=False,
    max_new_tokens=128,
    repetition_penalty=1.1,
)
print(tokenizer.decode(out.squeeze()).strip())
# train the vector—takes less than a minute!
orpo_vector = ControlVector.train(model, tokenizer, orpo_dataset)

# set the control strength and let inference rip!
# Negative strength pushes toward the "rejected" direction, positive toward "chosen".
for strength in (-2.2, 1, 2.2):
    print(f"strength={strength}")
    model.set_control(orpo_vector, strength)
    out = model.generate(
        **tokenizer(
            f"[INST] {example} [/INST]",
            return_tensors="pt"
        ).to('cuda'),  # BUG FIX: original left inputs on CPU while the model is on CUDA
        do_sample=False,
        max_new_tokens=128,
        repetition_penalty=1.1,
    )
    print(tokenizer.decode(out.squeeze()).strip())
    print()
# Clear the control vector so later uses of the model are unaffected
model.reset()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment