Full script; can't be run without the other files the classes are pulled from, but this gives you the whole pipeline (skeleton) of what this entails. Frankly this file is way too fucking big - code needs to be refactored so we're not reading 900+ lines of code in one file 💀
import aiohttp
import argparse
import asyncio
import backoff
import copy
import datetime
import faiss
import json
import math
import numpy as np
import os
import random
import re
import requests
import secrets
import sys
import yaml
from collections import defaultdict
from loguru import logger
from tqdm import tqdm
from typing import List, Dict, Any
from uuid import uuid4
from airoboros.embeddings import calculate_embeddings
from airoboros.exceptions import (
RateLimitError,
TooManyRequestsError,
TokensExhaustedError,
ServerOverloadedError,
ServerError,
ContextLengthExceededError,
BadResponseError,
)
from fast_sentence_transformers import FastSentenceTransformer
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
# Defaults and constants.
MAX_DOCSTORE_SIZE = 15000
OPENAI_API_BASE_URL = "https://api.openai.com"
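# Readability hint appended to writing prompts; on the Flesch reading-ease
# scale, scores of 30 and below correspond to "very difficult" text, i.e.
# college-level reading.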
READABILITY_HINT = "The output should be written in such a way as to have a Flesch-Kincaid readability score of 30 or lower - best understood by those with college education. Only output the story - don't add any notes or information about Flesch-Kincaid scores."
class SelfInstructor:
"""Class and methods used to generate instructions, based on self-instruct paper/code."""
CLI_ARGS = {
# The updated code with several instructors has way too many options to support
# as CLI args, so we just accept the config file path now.
"--config-path": {
"type": str,
"default": "config.yaml",
"help": "path to the airobors configuration file",
},
"--debug": {
"action": "store_true",
"help": "enable debug logging",
},
}
def __init__(self, *, config_path: str = "config.yaml", debug: bool = False):
"""Constructor."""
if not debug:
logger.remove()
logger.add(sys.stdout, level="INFO")
self.used_tokens = 0
self.config_path = config_path
self.load_config()
self.instructor_counts = defaultdict(int)
def load_config(self):
"""Load an advanced configuration from a YAML file."""
raw_config = self.raw_config = yaml.safe_load(open(self.config_path).read())
self.model = raw_config.get("model") or "gpt-4"
self.openai_api_key = raw_config.get("openai_api_key") or os.environ.get(
"OPENAI_API_KEY"
)
if not self.openai_api_key:
raise ValueError(
"OPENAI_API_KEY environment variable or openai_api_key must be provided"
)
self.organization_id = raw_config.get("organization_id")
self.topics_path = raw_config.get("topics_path") or "topics.txt"
self.output_path = raw_config.get("output_path") or "instructions.jsonl"
self.overwrite = str(raw_config.get("overwrite")).lower() == "true"
self.append = str(raw_config.get("append")).lower() == "true"
self.topic_avoidance = raw_config.get("topic_avoidance", "")
self.response_filters = []
for val in raw_config.get("response_filters") or []:
self.response_filters.append(re.compile(val, re.I))
self.max_tokens = (
int(raw_config["max_tokens"]) if raw_config.get("max_tokens") else None
)
self.min_docsearch_score = float(raw_config.get("min_docsearch_score") or 0.35)
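        # Similarity threshold for the faiss L2 index: smaller distances mean
        # more similar, so a candidate within this distance of an existing
        # instruction is treated as a duplicate (see is_too_similar).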
api_params = raw_config.get("api_params") or {}
self.api_params = {
"temperature": float(api_params.get("temperature") or 0.7),
"top_p": float(api_params.get("top_p") or 0.5),
"frequency_penalty": float(api_params.get("frequency_penalty") or 0.0),
"presence_penalty": float(api_params.get("presence_penalty") or 0.0),
}
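        # Default sampling parameters for OpenAI calls; individual instructors
        # can override these via their own api_params blocks.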
self.topic_prompt = raw_config["topic_prompt"].format(
topic_avoidance=self.topic_avoidance
)
self.topic_request_count = int(raw_config.get("topic_request_count") or 20)
self.default_count = 100
if raw_config.get("default_count") is not None:
self.default_count = int(raw_config["default_count"])
        self.default_batch_size = 5
        if raw_config.get("default_batch_size") is not None:
            self.default_batch_size = int(raw_config["default_batch_size"])
self.language = raw_config.get("language") or "English"
self.default_flesch = raw_config.get("default_flesch") or READABILITY_HINT
# Embedding model.
model_name = raw_config.get("embedding_model") or "thenlper/gte-small"
# Hacky, but we'll load this twice, the first time to get dimension, since
# it's not accessible in the Fast (cpu) version.
model = SentenceTransformer(model_name)
self.embedding_dimension = model.get_sentence_embedding_dimension()
model = None
if raw_config.get("embedding_device") == "cuda":
self.embedding_model = SentenceTransformer(model_name, device="cuda")
else:
self.embedding_model = FastSentenceTransformer(model_name, device="cpu")
self.embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
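        # Exact (brute-force) L2 index over instruction embeddings; every
        # accepted instruction is added so later candidates can be checked
        # for near-duplicates.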
self.index = faiss.IndexFlatL2(self.embedding_dimension)
# Validate the model for each generator.
        self.instructors = raw_config.get("instructors") or {}
self.validate_model(self.model)
valid_models = {self.model: True}
for key, config in self.instructors.items():
if config.get("model") and config["model"] not in valid_models:
self.validate_model(config["model"])
valid_models[config["model"]] = True
def initialize_index(self):
"""Initialize the in-memory faiss index to check prompt uniqueness."""
docs = []
if os.path.exists(self.output_path):
if self.overwrite:
                result = input(f"Remove and overwrite {self.output_path} (Y/N)? ")
if result.strip().lower() == "y":
os.remove(self.output_path)
else:
raise RuntimeError("Overwrite aborted.")
elif self.append:
with open(self.output_path, "r") as infile:
for line in infile.readlines():
task = json.loads(line)
category = task.get("category", "general")
if category != "rp" or "rp" in task:
self.instructor_counts[category] += 1
if task["category"] != "rp":
docs.append(task["instruction"])
logger.info(
f"Found {len(docs)} existing machine-generated instruction(s)."
)
for category, count in self.instructor_counts.items():
logger.info(f" * category {category}: {count}")
else:
raise RuntimeError(
f"{self.output_path} already exists, but overwrite and append are false!"
)
logger.info("Initializing faiss index similarity comparison...")
if not docs:
docs = ["__initialize__"]
# This is a bit slow.
for doc in tqdm(docs):
self.index.add(
np.array(
[
calculate_embeddings(
doc, self.embedding_model, self.embedding_tokenizer
)
]
)
)
def validate_model(self, model):
"""Ensure the specified model is available."""
headers = {"Authorization": f"Bearer {self.openai_api_key}"}
if self.organization_id:
headers["OpenAI-Organization"] = self.organization_id
result = requests.get(f"{OPENAI_API_BASE_URL}/v1/models", headers=headers)
if result.status_code != 200:
raise ValueError(
f"Invalid openai API key [{result.status_code}: {result.text}]"
)
available = {item["id"] for item in result.json()["data"]}
if model not in available:
raise ValueError(f"Model is not available to your API key: {model}")
logger.success(f"Successfully validated model: {model}")
    async def initialize_topics(self) -> None:
        """Ensure topics are initialized, i.e. topics already exist and are read,
        or a new list of topics is generated.
        """
if os.path.exists(self.topics_path):
self.topics = list(
{
line.strip()
for line in open(self.topics_path).readlines()
if line.strip()
}
)
logger.info(f"Using {len(self.topics)} topics from {self.topics_path}...")
return
logger.info("Generating random topics to use in prompts...")
seen = set([])
self.topics = []
with open(self.topics_path, "w") as outfile:
count = self.topic_request_count
while count > 0:
todo = 8 if count >= 8 else count
responses = await asyncio.gather(
*[
self.generate_response(self.topic_prompt, **self.api_params)
for _ in range(todo)
]
)
count -= todo
for response in responses:
if not response:
continue
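                    # Responses are numbered lists ("1. ...", "2. ..."); capture
                    # each item's text up to the next numbered entry or the end
                    # of the response.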
for topic in re.findall(
r"(?:^|\n)\d+\. (.*?)(?:$|(?=\n\d+\. ))", response, re.DOTALL
):
if (
not topic
or topic.strip().endswith(":")
or topic.lower().strip() in seen
):
continue
seen.add(topic.lower().strip())
self.topics.append(topic)
outfile.write(topic.strip() + "\n")
logger.success(
f"Successfully generated {len(self.topics)} topics, stored in {self.topics_path}..."
)
def get_instructor_topics(self, instructor_config):
"""Get the topics for a specific instructor, defaulting to main topics.
:param instructor_config: Dict containing the target instructor's config.
:type instructor_config: dict
:return: List of topic strings.
:rtype: list[str]
"""
if not instructor_config.get("topics_path"):
return self.topics
with open(instructor_config["topics_path"]) as infile:
topics = list({line.strip() for line in infile.readlines() if line.strip()})
if not topics:
raise ValueError(
f"Found empty topics file: {instructor_config['topics_path']}"
)
return topics
@staticmethod
def load_template(path: str) -> str:
"""Load a prompt template."""
if not os.path.exists(path):
path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"instructors",
"prompts",
path,
)
with open(path) as infile:
return infile.read()
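    # Retry transient API/network failures; waits follow the fibonacci
    # sequence (1, 1, 2, 3, 5, ... seconds), capped at 19 seconds. Other
    # errors propagate immediately.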
@backoff.on_exception(
backoff.fibo,
(
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
ServerError,
RateLimitError,
TooManyRequestsError,
ServerOverloadedError,
),
max_value=19,
)
async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Perform a post request to OpenAI API.
:param path: URL path to send request to.
:type path: str
:param payload: Dict containing request body/payload.
:type payload: Dict[str, Any]
:return: Response object.
:rtype: Dict[str, Any]
"""
headers = {"Authorization": f"Bearer {self.openai_api_key}"}
if self.organization_id:
headers["OpenAI-Organization"] = self.organization_id
request_id = str(uuid4())
logger.debug(f"POST [{request_id}] with payload {json.dumps(payload)}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{OPENAI_API_BASE_URL}{path}",
headers=headers,
json=payload,
timeout=600.0,
) as result:
if result.status != 200:
text = await result.text()
logger.error(f"OpenAI request error: {text}")
if "too many requests" in text.lower():
raise TooManyRequestsError(text)
if (
"rate limit reached" in text.lower()
or "rate_limit_exceeded" in text.lower()
):
                        await asyncio.sleep(30)
raise RateLimitError(text)
elif "context_length_exceeded" in text.lower():
raise ContextLengthExceededError(text)
elif "server_error" in text and "overloaded" in text.lower():
raise ServerOverloadedError(text)
elif (
"bad gateway" in text.lower() or "server_error" in text.lower()
):
raise ServerError(text)
else:
raise BadResponseError(text)
result = await result.json()
logger.debug(f"POST [{request_id}] response: {json.dumps(result)}")
self.used_tokens += result["usage"]["total_tokens"]
if self.max_tokens and self.used_tokens > self.max_tokens:
raise TokensExhaustedError(
f"Max token usage exceeded: {self.used_tokens}"
)
logger.debug(f"token usage: {self.used_tokens}")
return result
async def _post_no_exc(self, *a, **k):
"""Post, ignoring all exceptions."""
try:
return await self._post(*a, **k)
except Exception as ex:
logger.error(f"Error performing post: {ex}")
return None
async def generate_response(self, instruction: str, **kwargs) -> str:
"""Call OpenAI with the specified instruction and return the text response.
:param instruction: The instruction to respond to.
:type instruction: str
:return: Response text.
:rtype: str
"""
messages = copy.deepcopy(kwargs.pop("messages", None) or [])
filter_response = kwargs.pop("filter_response", True)
model = kwargs.get("model", self.model)
path = "/v1/chat/completions"
payload = {**kwargs}
if "model" not in payload:
payload["model"] = model
payload["messages"] = messages
if instruction:
payload["messages"].append({"role": "user", "content": instruction})
response = await self._post_no_exc(path, payload)
if (
not response
or not response.get("choices")
or response["choices"][0]["finish_reason"] == "length"
):
return None
text = response["choices"][0]["message"]["content"]
if filter_response:
for banned in self.response_filters:
                if banned.search(text):
logger.warning(f"Banned response [{banned}]: {text}")
return None
if text.startswith(("I'm sorry,", "Apologies,", "I can't", "I won't")):
logger.warning(f"Banned response [apology]: {text}")
return None
return text
async def is_decent_response(self, item):
"""Filter the responses by having the LLM score based on a set of rules."""
config = self.raw_config.get("scoring", {})
template = self.load_template(config.get("prompt_path") or "filter.txt")
api_params = {**self.api_params, **config.get("api_params", {})}
instruction = item["instruction"]
if item.get("category") == "coding" and "PLAINFORMAT" in instruction:
instruction = instruction.replace("PLAINFORMAT", "").strip()
instruction += "\n" + "\n".join(
[
"Generate only the code, as a single, plain text output.",
"Do not include an intro sentence indicating what the code will do.",
"Do not include any instructions for usage, warnings about replacing certain values, etc.",
"Do not surround the code with backticks/markdown formatting.",
"Include help code comments.",
]
)
system_prompt = ""
if item.get("system"):
system_prompt = "\n".join(
[
"- did the response respect the system prompt that was used?",
"SYSTEM PROMPT:",
item["system"],
]
)
        result = await self.generate_response(
            template.format(
                instruction=instruction,
                response=item["response"],
                threshold=config.get("threshold") or "100",
                system_prompt=system_prompt,
            ),
            filter_response=False,
            **api_params,
        )
preview = item["instruction"].splitlines()[0][0:100]
if len(preview) == 100:
preview += "..."
if not result:
logger.error(
f"Error evaluating response, assuming decent [{item['category']}]: {preview}"
)
return True
if "GOOD" in result:
logger.info(f"Judge: good [{item['category']}]: {preview}")
return True
logger.info(f"Judge: bad [{item['category']}]: {preview}")
return False
async def judge(self, instructions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Filter only the "good" instructions, as determined by an LLM."""
batch_size = (
self.raw_config.get("judge", {}).get("batch_size")
or self.default_batch_size
)
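        # Score each batch with concurrent LLM calls; gather preserves input
        # order, so results line up with the batch items.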
batches = np.array_split(
instructions, math.ceil(len(instructions) / batch_size)
)
quality = []
for batch in batches:
results = await asyncio.gather(
*[self.is_decent_response(item["item"]) for item in batch]
)
for idx in range(len(batch)):
if results[idx]:
quality.append(batch[idx])
return quality
async def cull(self, input_paths: List[str], output_path: str) -> None:
"""Use the LLM to filter bad responses based on a set of rules.
:param input_paths: List of paths to the input JSONL file(s) to filter.
:type input_paths: List[str]
:param output_path: Path to save the "good" instructions to.
:type output_path: str
"""
# See if we have any state data.
state = {}
if os.path.exists(f"{output_path}.state"):
with open(f"{output_path}.state") as infile:
state = json.loads(infile.read())
logger.info(
f"Resuming from previous cull state - to restart, delete `{output_path}.state`"
)
def _save_state(c, line):
nonlocal state
state[c] = line
with open(f"{output_path}.state", "w") as outfile:
outfile.write(json.dumps(state, indent=2) + "\n")
categories = defaultdict(list)
found = set([])
for path in input_paths:
with open(path) as infile:
for line in infile.readlines():
item = json.loads(line)
category = item.get("category", "general")
if category == "reasoning_or_math":
category = "orca"
item["category"] = category
                    # Skip items already processed: for each category, fast-forward
                    # past lines until we reach the last line recorded in the state file.
if category in state:
if line == state[category]:
found.add(category)
continue
elif category not in found:
continue
categories[category].append({"item": item, "line": line})
# Deduplicate and select best items.
output_file = open(output_path, "a+")
max_k = self.raw_config.get("cull_max_k")
if max_k is None:
max_k = 1000
for category in sorted(list(categories)):
items = categories[category]
# Skip categories that are too weird/cumbersome to score properly.
if category in [
"stylized_response",
"rp",
"detailed_writing",
"contextual",
"counterfactual_contextual",
"plan",
"song",
"wordgame",
]:
for idx in range(len(items)):
item = items[idx]["item"]
output_file.write(json.dumps(item) + "\n")
_save_state(category, items[idx]["line"])
output_file.flush()
continue
# Add all of the items in this category to a faiss index.
logger.info(
f"Initializing faiss index for {category} with {len(items)} documents..."
)
index = faiss.IndexFlatL2(self.embedding_dimension)
all_embeddings = []
for item in tqdm(items):
all_embeddings.append(
np.array(
[
calculate_embeddings(
"\n".join(
[
item["item"]["instruction"],
item["item"]["response"],
]
),
self.embedding_model,
self.embedding_tokenizer,
)
]
)
)
index.add(all_embeddings[-1])
# Here's where it's tricky...
#
            # We need to iterate through the objects, finding all matches that are under our
            # specified similarity score for this category.
#
# Once we've found all of the matches, we can select the "best" by first using
# the LLM to judge whether the response is high quality or not, but only if it's
# a category that we can score well.
#
# If multiple instructions remain that are high quality, we can use other metrics,
# such as length and complexity of speech to select the best.
#
# If none of the matching instructions are high quality, I guess we just remove
# all of them?
purged = set([])
saved = set([])
min_score = (
self.instructors.get(category, {}).get("min_docsearch_score")
or self.min_docsearch_score
)
for idx in range(len(items)):
if idx in purged or idx in saved:
continue
distances, indices = index.search(
all_embeddings[idx], k=min(len(items), max_k)
)
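                # The first hit is the query item itself (distance ~0, since its
                # embedding is in the index), so drop it from both result lists.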
distances = distances[0].tolist()[1:]
indices = indices[0].tolist()[1:]
batch = [items[idx]]
batch_idx = [idx]
for check_idx in range(len(distances)):
# Don't check items before this one (since they would have already been checked).
if indices[check_idx] < idx:
continue
# Don't check items we've judged as duplicate or low-quality.
if indices[check_idx] in purged:
continue
# Ignore coding instructions that don't match on PLAINFORMAT tag.
if category == "coding":
source_has_plain = (
"PLAINFORMAT" in items[idx]["item"]["instruction"]
)
target_has_plain = (
"PLAINFORMAT"
in items[indices[check_idx]]["item"]["instruction"]
)
if (source_has_plain and not target_has_plain) or (
target_has_plain and not source_has_plain
):
continue
# Ignore and stop checking if we exceed the min score.
if distances[check_idx] > min_score:
break
batch.append(items[indices[check_idx]])
batch_idx.append(indices[check_idx])
# Score the batch.
quality = await self.judge(batch)
if not quality:
for purge_idx in range(len(batch)):
purged.add(batch_idx[purge_idx])
preview = items[batch_idx[purge_idx]]["item"][
"instruction"
].splitlines()[0][0:100]
logger.warning(f"Removing low-quality instruction: {preview}")
continue
# Only one high-quality result, keep it.
if len(quality) == 1:
preview = quality[0]["item"]["instruction"].splitlines()[0][0:100]
logger.success(f"Saving high-quality instruction: {preview}")
output_file.write(json.dumps(quality[0]["item"]) + "\n")
output_file.flush()
_save_state(category, quality[0]["line"])
found = False
for save_idx in range(len(batch)):
if batch[save_idx] == quality[0]:
if not found:
saved.add(batch_idx[save_idx])
found = True
else:
purged.add(batch_idx[save_idx])
continue
# This is kind of a hacky fallback, but it's fast and easy.
longest = sorted(
quality,
key=lambda x: len(x["item"]["instruction"] + x["item"]["response"]),
)[-1]
                # Keep the first occurrence of the longest item, purge the rest.
                found = False
                for purge_idx in range(len(batch)):
                    if batch[purge_idx] == longest and not found:
                        found = True
                        saved.add(batch_idx[purge_idx])
                    else:
                        purged.add(batch_idx[purge_idx])
preview = longest["item"]["instruction"].splitlines()[0][0:100]
logger.success(f"Saving high-quality, longest instruction: {preview}")
                output_file.write(json.dumps(longest["item"]) + "\n")
output_file.flush()
_save_state(category, longest["line"])
output_file.close()
async def is_too_similar(
self, input_text: str, min_score: float = None, index=None
):
"""Check the similarity of an input string against an index.
:param input_text: The input string to calculate similarity of.
:type input_text: str
:param min_score: Minimum document similarity score to consider unique.
:type min_score: float
:param index: Optional faiss index to query against, defaults to main index.
        :type index: faiss index
:return: Boolean indicating if the text is too similar or not.
:rtype: bool
"""
index = index or self.index
input_embeddings = np.array(
[
calculate_embeddings(
input_text, self.embedding_model, self.embedding_tokenizer
)
]
)
min_score = min_score or self.min_docsearch_score
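        # k=1 nearest neighbor; with IndexFlatL2 a smaller distance means more
        # similar, so anything at or below min_score counts as a duplicate.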
distance, _ = index.search(input_embeddings, k=1)
distance = distance[0].tolist()
if not distance:
return False
if distance[0] <= min_score:
logger.warning(f"Too similar [{distance[0]}]: {input_text}")
return True
return False
def persist(self, item):
"""Persist a single item to the output file and add it to the index."""
skip_counting = item.pop("skip_counting", False)
if "instruction" in item:
item["instruction"] = item["instruction"].strip()
if "response" in item:
item["response"] = item["response"].strip()
if "system" in item:
item["system"] = item["system"].strip()
self.outfile.write(json.dumps(item) + "\n")
self.outfile.flush()
if item["category"] != "rp":
self.index.add(
np.array(
[
calculate_embeddings(
item["instruction"],
self.embedding_model,
self.embedding_tokenizer,
)
]
)
)
if not skip_counting:
self.instructor_counts[item["category"]] += 1
async def run_instructor(self, category, method_map, **kwargs):
"""Run a single instructor, as an async task."""
if category not in method_map:
logger.warning(f"Unknown category: {category}, skipping...")
return
logger.info(f"Generating instructions for {category}...")
started_at = datetime.datetime.now()
running_total = self.instructor_counts.get(category, 0)
async for item in method_map[category](self, **kwargs):
self.persist(item)
preview = None
if category == "rp":
if "rp" in item:
running_total += 1
preview = item["rp"][0]["content"].splitlines()[0][:100]
else:
running_total += 1
preview = item["instruction"].splitlines()[0][0:100]
if preview:
logger.success(
f"Generated unique instruction [{category}, total={running_total}]: {preview}"
)
delta = (datetime.datetime.now() - started_at).total_seconds()
logger.success(
f"Finished generating {running_total} instructions [{category}] in {delta} seconds."
)
async def run(self):
"""Run prompt generation and answer to completion."""
from airoboros.instructors.agent import generate as agent_generator
from airoboros.instructors.awareness import generate as awareness_generator
from airoboros.instructors.card import generate as card_generator
from airoboros.instructors.coding import generate as coding_generator
from airoboros.instructors.contextual import generate as contextual_generator
from airoboros.instructors.cot import generate as cot_generator
from airoboros.instructors.counterfactual_contextual import (
generate as counterfactual_contextual_generator,
)
from airoboros.instructors.detailed_writing import (
generate as detailed_writing_generator,
)
from airoboros.instructors.editor import generate as editor_generator
from airoboros.instructors.experience import generate as experience_generator
from airoboros.instructors.general import generate as general_generator
from airoboros.instructors.gtkm import generate as gtkm_generator
from airoboros.instructors.joke import generate as joke_generator
from airoboros.instructors.misconception import (
generate as misconception_generator,
)
from airoboros.instructors.multiple_choice import (
generate as multiple_choice_generator,
)
from airoboros.instructors.orca import generate as orca_generator
from airoboros.instructors.plan import generate as plan_generator
from airoboros.instructors.riddle import generate as riddle_generator
from airoboros.instructors.roleplay import generate as roleplay_generator
from airoboros.instructors.rp import generate as rp_generator
from airoboros.instructors.rp import generate_cards
from airoboros.instructors.song import generate as song_generator
from airoboros.instructors.stylized_response import (
generate as stylized_response_generator,
)
from airoboros.instructors.trivia import generate as trivia_generator
from airoboros.instructors.wordgame import generate as wordgame_generator
from airoboros.instructors.writing import generate as writing_generator
method_map = {
"agent": agent_generator,
"awareness": awareness_generator,
"card": card_generator,
"coding": coding_generator,
"contextual": contextual_generator,
"cot": cot_generator,
"counterfactual_contextual": counterfactual_contextual_generator,
"detailed_writing": detailed_writing_generator,
"experience": experience_generator,
"general": general_generator,
"joke": joke_generator,
"misconception": misconception_generator,
"multiple_choice": multiple_choice_generator,
"plan": plan_generator,
"orca": orca_generator,
"riddle": riddle_generator,
"roleplay": roleplay_generator,
"rp": rp_generator,
"song": song_generator,
"trivia": trivia_generator,
"wordgame": wordgame_generator,
"writing": writing_generator,
}
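        # editor, stylized_response, and gtkm are intentionally absent here;
        # they are added to method_map further down, once the data they
        # consume exists.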
await self.initialize_topics()
self.initialize_index()
# Generate instructions for each category.
self.outfile = open(self.output_path, "a+")
started_at = datetime.datetime.now()
try:
tasks = [
asyncio.create_task(self.run_instructor(category, method_map))
for category in self.instructors
]
for task in tasks:
await task
finally:
self.outfile.close()
# Editor needs the writing data to run.
if self.instructor_counts.get("writing"):
logger.info("Generating editor prompts using existing writing data...")
method_map["editor"] = editor_generator
self.outfile = open(self.output_path, "a+")
try:
await self.run_instructor("editor", method_map)
finally:
self.outfile.close()
# After we have a sampling of instructions, we can also generate a list of responses
# based on character cards generated.
if await generate_cards(self):
logger.info(
"Re-generating a sampling of responses using character cards..."
)
with open(self.output_path) as infile:
existing = [json.loads(line) for line in infile.readlines()]
method_map["stylized_response"] = stylized_response_generator
method_map["gtkm"] = gtkm_generator
self.outfile = open(self.output_path, "a+")
tasks = []
try:
tasks.append(
asyncio.create_task(
self.run_instructor(
"stylized_response",
method_map,
existing=existing,
)
)
)
tasks.append(
asyncio.create_task(self.run_instructor("gtkm", method_map))
)
for task in tasks:
await task
finally:
self.outfile.close()
delta = (datetime.datetime.now() - started_at).total_seconds()
logger.success(
f"Finished generating all instructions in {delta} seconds, enjoy!"
)
def generate_instructions(args):
random.seed(secrets.randbelow(1000000000))
parser = argparse.ArgumentParser()
for arg, kwargs in SelfInstructor.CLI_ARGS.items():
parser.add_argument(arg, **kwargs)
asyncio.run(SelfInstructor(**vars(parser.parse_args(args))).run())
def generate_topics(args):
random.seed(secrets.randbelow(1000000000))
parser = argparse.ArgumentParser()
for arg, kwargs in SelfInstructor.CLI_ARGS.items():
parser.add_argument(arg, **kwargs)
instructor = SelfInstructor(**vars(parser.parse_args(args)))
asyncio.run(instructor.initialize_topics())
def cull_instructions(args):
random.seed(secrets.randbelow(1000000000))
parser = argparse.ArgumentParser()
for arg, kwargs in SelfInstructor.CLI_ARGS.items():
parser.add_argument(arg, **kwargs)
    parser.add_argument(
        "--input",
        type=str,
        nargs="+",
        help="path to the file(s) containing instructions to cull",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="culled.jsonl",
        help="path to save the culled instructions to",
    )
all_args = vars(parser.parse_args(args))
instructor = SelfInstructor(config_path=all_args["config_path"])
asyncio.run(instructor.cull(all_args["input"], all_args["output"]))
if __name__ == "__main__":
generate_instructions(sys.argv[1:])
@FoobarProtocol (Author):
Placeholder (going to throw some documentation here that walks through what this code does, top to bottom).

Don't thank me though - not doing this for you. Doing it for me, because you've got to annotate a file like this (especially when devs drop something of this size w/o so much as a README.md or any documentation at all).
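In the meantime, here's a rough top-to-bottom sketch of the pipeline:

- `load_config` reads the YAML config: model, API key, output paths, sampling params, per-instructor settings, and the embedding model used for dedup.
- `initialize_topics` reads `topics.txt` if it exists, otherwise prompts the model for numbered topic lists and parses the items out.
- `initialize_index` builds an in-memory faiss `IndexFlatL2` over existing instructions so new prompts can be rejected as near-duplicates (`is_too_similar`).
- `_post`/`generate_response` wrap the OpenAI chat completions endpoint with fibonacci backoff, token accounting, and regex/apology response filters.
- `run` wires up one async generator per category, runs them concurrently, and `persist` appends each accepted item to the JSONL output and the faiss index. `editor`, `stylized_response`, and `gtkm` run afterwards because they consume earlier output.
- `cull` is a separate entry point: it clusters an existing dataset per category with faiss, has the LLM judge each cluster (`is_decent_response`/`judge`), and keeps only the best item per cluster.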
