Last active
April 19, 2024 19:44
-
-
Save bigsnarfdude/2bab6d4fd7c6c3c419cfec264fb0f082 to your computer and use it in GitHub Desktop.
convert_alpaca.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pretraining -> supervised instruction-finetuning -> RLHF | |
import json

import tiktoken
import torch
def extract_text_from_jsonl(file_path):
    """Read a JSONL file and collect its 'prompt' and 'completion' fields.

    Each line of the file is parsed as a standalone JSON object; records
    missing either key simply contribute nothing to that list, so the two
    returned lists are not guaranteed to be the same length.

    Args:
        file_path: path to a JSON-lines file (one JSON object per line).

    Returns:
        A ``(prompts, completions)`` tuple of two lists of strings.
    """
    prompt_texts = []
    completion_texts = []
    with open(file_path, 'r') as fh:
        for raw_line in fh:
            record = json.loads(raw_line)
            if 'prompt' in record:
                prompt_texts.append(record['prompt'])
            if 'completion' in record:
                completion_texts.append(record['completion'])
    return prompt_texts, completion_texts
# Extract prompt/completion texts from the source dataset.
# BUG FIX: the original called extract_text_from_jsonl(target_file=...), but the
# function's only parameter is named `file_path`, so that call raises TypeError.
# NOTE(review): the filename is .json, not .jsonl — extract_text_from_jsonl
# parses one JSON object per line; confirm alpaca_data.json is actually
# line-delimited, since the stock Alpaca release is a single JSON array.
prompts, completions = extract_text_from_jsonl("alpaca_data.json")
def generate_prompt(data_point):
    """Render one Alpaca record into the standard instruction-tuning prompt.

    Args:
        data_point: mapping with keys "instruction", "input", and "output".

    Returns:
        The formatted prompt string, including the "### Input:" section only
        when the record actually has a non-empty input.

    BUG FIX: the original branched on ``data_point["instruction"]``, so records
    with an instruction but an *empty* input still received a blank
    "### Input:" section, and the no-input template was effectively dead code.
    The Alpaca template branches on the presence of "input".
    """
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Response:
{data_point["output"]}"""
# generate train, generate labels
# output to train.bin
# NOTE(review): `data` and `tokenizer` are never defined in this script — this
# line looks pasted from a HF `datasets` workflow and raises NameError as
# written; confirm where the dataset is loaded before relying on it.
data = data.map(lambda data_point: {"prompt": tokenizer(generate_prompt(data_point))})

# Encode all data into raw token ids (GPT-2 BPE) for train.bin.
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

# NOTE(review): `input_test_text`, `device`, `ctx`, `num_samples`, `model`,
# `max_new_tokens`, `temperature`, and `top_k` are all undefined here —
# presumably inherited from a nanoGPT sample.py context; confirm.
start_ids = encode(input_test_text)
# Shape the ids as a (1, T) batch for the model.
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

with torch.no_grad():
    with ctx:
        for k in range(num_samples):
            # Prompt = first 10 tokens; everything after is the ground truth.
            first_ten_tokens_inputs = x[..., :10].contiguous()
            next_three_tokens_after_inputs = x[..., 10:].contiguous()
            y = model.generate(first_ten_tokens_inputs, max_new_tokens,
                               temperature=temperature, top_k=top_k)
            # Compare the model's last 3 generated tokens against the label.
            # (A duplicate `last_three_tokens_predictions` local and several
            # commented-out shift_logits/shift_labels experiments were removed.)
            expected_prediction = y[..., -3:]
            print('------- prediction --------')
            print(decode(expected_prediction[0].tolist()))
            print('------- label --------')
            print(decode(next_three_tokens_after_inputs[0].tolist()[:3]))
# save to train.bin
# NOTE(review): nothing is actually written to train.bin yet — TODO confirm
# whether the np.tofile step was meant to follow here.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment