Skip to content

Instantly share code, notes, and snippets.

@hmmhmmhm
Created July 18, 2023 08:05
Show Gist options
  • Save hmmhmmhm/fe42a1cb0c99b842bcdd83109e8c425a to your computer and use it in GitHub Desktop.
Save hmmhmmhm/fe42a1cb0c99b842bcdd83109e8c425a to your computer and use it in GitHub Desktop.
Fine-tuning OpenAI
import fs from "fs";
import * as tokenizer from "gpt-3-encoder";
import { Configuration, OpenAIApi } from "openai";
import { logger } from "./utils/logger.js";
import jsonl from "jsonl";
import inquirer from "inquirer";
import chalk from "chalk";
/**
 * Reads the collected project data: a JSON object mapping file paths to file contents.
 *
 * The parsed value is checked before being returned. Bad data then fails here with a
 * clear message instead of later, deep inside tokenization.
 *
 * @param assetPath - Path of the JSON file (read as UTF-8).
 * @returns The parsed `{ [filePath]: contents }` object.
 * @throws SyntaxError if the file is not valid JSON.
 * @throws TypeError if the JSON is not a plain object whose values are all strings.
 */
export const loadProjectDataFromJSON = (
  assetPath: string
): Record<string, string> => {
  const parsed: unknown = JSON.parse(fs.readFileSync(assetPath, "utf8"));
  if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
    throw new TypeError(
      `${assetPath}: expected a JSON object mapping file paths to contents`
    );
  }
  for (const [key, value] of Object.entries(parsed)) {
    if (typeof value !== "string") {
      throw new TypeError(
        `${assetPath}: value for "${key}" must be a string, got ${typeof value}`
      );
    }
  }
  // Safe: every value was verified to be a string above.
  return parsed as Record<string, string>;
};
/**
 * Splits `code` into consecutive chunks of at most `limitToken` tokens each.
 *
 * The text is encoded once with the GPT-3 tokenizer. The token array is cut into
 * slices of `limitToken` tokens, and each slice is decoded back to a string. Empty
 * input yields `[]`.
 *
 * NOTE(review): the tokenizer works on bytes, so a slice boundary may fall inside a
 * multi-byte UTF-8 character (e.g. Korean text). That character may then not survive
 * decoding at the chunk edges. Confirm whether this matters for the training data.
 *
 * @param code - Text to split.
 * @param limitToken - Maximum tokens per chunk. Must be >= 1; fractional values are floored.
 * @returns The decoded chunks, in order.
 * @throws RangeError if `limitToken` is below 1 or NaN.
 */
export const tokenSplitter = (code: string, limitToken: number): string[] => {
  const chunkSize = Math.floor(limitToken);
  // A size below 1 (or NaN) cannot produce sensible chunks; fail loudly instead of
  // emitting an empty first chunk or silently ignoring the limit.
  if (!(chunkSize >= 1)) {
    throw new RangeError(`limitToken must be at least 1, got ${limitToken}`);
  }
  const tokens = tokenizer.encode(code);
  const chunks: string[] = [];
  for (let start = 0; start < tokens.length; start += chunkSize) {
    chunks.push(tokenizer.decode(tokens.slice(start, start + chunkSize)));
  }
  return chunks;
};
/**
 * Turns `{ [filePath]: contents }` into prompt/completion records for fine-tuning.
 *
 * A file whose token count is at most `tokenSplitCount` becomes one record. Larger
 * files are cut with `tokenSplitter` into several records, each titled with its part
 * number ("분할됨: i/n", i.e. "split: i/n").
 *
 * Every prompt carries a `lastModified` field. Note that this is the time this
 * function ran (one timestamp shared by all records), not the file's modification time.
 *
 * @param data - Map of file path to file contents.
 * @param tokenSplitCount - Maximum tokens per completion.
 * @returns The records, in the iteration order of `data`.
 */
export const convertToTrainingData = (
  data: Record<string, string>,
  tokenSplitCount: number
): { prompt: string; completion: string }[] => {
  const timestamp = new Date().toISOString();

  // Renders the pseudo-object header used as the prompt text.
  const buildPrompt = (title: string, filePath: string): string =>
    [
      "{",
      `title: "${title}",`,
      `filePath: ${filePath},`,
      `lastModified: ${timestamp}`,
      "}",
    ].join("\n");

  return Object.entries(data).flatMap(([filePath, source]) => {
    const fitsInOneRecord =
      tokenizer.encode(source).length <= tokenSplitCount;
    if (fitsInOneRecord) {
      return [
        {
          prompt: buildPrompt("friday-gpt 프로젝트 파일", filePath),
          completion: `${source}`,
        },
      ];
    }
    const parts = tokenSplitter(source, tokenSplitCount);
    return parts.map((part, index) => ({
      prompt: buildPrompt(
        `friday-gpt 프로젝트 파일 (분할됨: ${index + 1}/${parts.length})`,
        filePath
      ),
      completion: part,
    }));
  });
};
/**
 * Uploads a training file to OpenAI with purpose "fine-tune".
 *
 * Reads the API key from `OPENAI_API_KEY`.
 *
 * @param assetPath - Local path of the file to upload (the caller passes a .jsonl file).
 * @returns The ID OpenAI assigned to the uploaded file.
 */
export const uploadTrainingData = async (assetPath: string) => {
  const openai = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  // The cast works around the SDK's declared parameter type for createFile when
  // passing a Node read stream; see
  // https://github.com/openai/openai-node/issues/25#issuecomment-1291536117
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const upload = fs.createReadStream(assetPath) as any;
  const { data } = await openai.createFile(upload, "fine-tune");
  return data.id;
};
/**
 * Deletes a previously uploaded file from OpenAI.
 *
 * Reads the API key from `OPENAI_API_KEY`.
 *
 * @param fileId - ID returned by the upload.
 * @returns The `deleted` flag from the API response.
 */
export const deleteTrainingData = async (fileId: string) => {
  const client = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  const { data } = await client.deleteFile(fileId);
  return data.deleted;
};
/**
 * Starts a fine-tune job on an uploaded training file.
 *
 * Reads the API key from `OPENAI_API_KEY`.
 *
 * NOTE(review): this calls the SDK's legacy `createFineTune` endpoint. Confirm it is
 * still available for the target account and base model.
 *
 * @param traningDataId - ID of the uploaded training file. The misspelled name is kept
 *   for caller compatibility.
 * @param epochs - Number of training epochs (`n_epochs`).
 * @param model - Base model to fine-tune. Defaults to "davinci".
 * @param suffix - Suffix added to the fine-tuned model name. Defaults to "friday-gpt".
 * @returns The events included in the creation response and the job ID.
 */
export const createFineTuneModel = async ({
  traningDataId,
  epochs,
  model = "davinci",
  suffix = "friday-gpt",
}: {
  traningDataId: string;
  epochs: number;
  model?: string;
  suffix?: string;
}) => {
  const openai = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  const { data } = await openai.createFineTune({
    training_file: traningDataId,
    model,
    suffix,
    n_epochs: epochs,
  });
  return {
    events: data.events,
    fineTuneId: data.id,
  };
};
/**
 * Fetches the event list of a fine-tune job.
 *
 * Reads the API key from `OPENAI_API_KEY`.
 *
 * @param fineTuneId - ID of the fine-tune job.
 * @returns The `data` array of the list-events response.
 */
export const getStatusOfFineTuning = async (fineTuneId: string) => {
  const api = new OpenAIApi(
    new Configuration({ apiKey: process.env.OPENAI_API_KEY })
  );
  const { data: eventList } = await api.listFineTuneEvents(fineTuneId);
  return eventList.data;
};
/**
 * Returns the current UTC time as "HH:MM:SS", for log prefixes.
 */
export const now = () => {
  // toISOString() has the form "YYYY-MM-DDTHH:MM:SS.sssZ"; take the text between
  // the "T" and the fractional-seconds dot.
  const iso = new Date().toISOString();
  return iso.slice(iso.indexOf("T") + 1, iso.indexOf("."));
};
/** How a polled fine-tune job ended, as inferred from its event messages. */
type FineTuneOutcome =
  | { kind: "succeeded"; modelId: string }
  | { kind: "failed" };

/**
 * Writes the records as a JSON array to ./training/training.json, then converts that
 * file to ./training/training.jsonl through the `jsonl` transform stream.
 */
const writeTrainingFiles = async (
  prompts: ReadonlyArray<{ prompt: string; completion: string }>
): Promise<void> => {
  fs.writeFileSync("./training/training.json", JSON.stringify(prompts));
  await new Promise<void>((resolve, reject) => {
    const source = fs.createReadStream("./training/training.json");
    const converter = jsonl();
    const sink = fs.createWriteStream("./training/training.jsonl");
    // .pipe() does not forward errors between stages. Each stream gets its own
    // handler so a failure rejects this promise instead of surfacing as an
    // unhandled 'error' event.
    source.on("error", reject);
    converter.on("error", reject);
    sink.on("error", reject);
    sink.on("finish", () => resolve());
    source.pipe(converter).pipe(sink);
  });
};

/**
 * Asks how many epochs to train for (default 1, though 4 is recommended).
 * Anything other than a positive integer falls back to 1.
 */
const askEpochs = async (): Promise<number> => {
  const { repeatCount } = await inquirer.prompt({
    name: "repeatCount",
    message: chalk.magentaBright(
      `몇번 반복해서 학습할까요? (epochs 기본값: 1): `
    ),
  });
  const parsed = Number(repeatCount);
  // `Number(x) || 1` alone would let negative or fractional values reach the API.
  return Number.isInteger(parsed) && parsed > 0 ? parsed : 1;
};

/**
 * Starts the fine-tune job and returns its ID. If the job cannot be created, the
 * uploaded training file is deleted (best effort) before the error is rethrown.
 */
const startFineTune = async (
  traningDataId: string,
  epochs: number
): Promise<string> => {
  try {
    const { fineTuneId } = await createFineTuneModel({ traningDataId, epochs });
    return fineTuneId;
  } catch (error: unknown) {
    // No job references the file yet, so it is safe to remove. A cleanup failure is
    // only logged so the original error is not masked.
    await deleteTrainingData(traningDataId).catch(() => {
      logger(`[${now()}] 트레이닝 데이터 삭제 실패: ${traningDataId}`);
    });
    throw error;
  }
};

/**
 * Polls the job's events every 2 seconds and logs each new one, until an event
 * signals the end of the job. On success, the model ID is written to
 * ./training/modelId.txt.
 *
 * NOTE(review): the end of the job is detected by matching message text
 * ("cancelled", "failed", "Uploaded model: "). Confirm against the API's event
 * messages, or switch to the job's status field.
 */
const waitForFineTune = async (
  fineTuneId: string
): Promise<FineTuneOutcome> => {
  const uploadedPrefix = "Uploaded model: ";
  // Deduplicate on timestamp AND message: several events can share one created_at value.
  const seen = new Set<string>();
  for (;;) {
    const events = await getStatusOfFineTuning(fineTuneId);
    for (const event of events) {
      const key = `${event.created_at}:${event.message}`;
      if (!seen.has(key)) {
        logger(`[${now()}] [Open A.I]: ${event.message}`);
        seen.add(key);
      }
      if (
        event.message.includes("cancelled") ||
        event.message.includes("failed")
      ) {
        logger(
          `[${now()}] 파인튜닝 모델 학습이 실패했거나 취소되었습니다: ${fineTuneId}`
        );
        return { kind: "failed" };
      }
      if (event.message.startsWith(uploadedPrefix)) {
        const modelId = event.message.slice(uploadedPrefix.length);
        fs.writeFileSync("./training/modelId.txt", modelId);
        logger(`[${now()}] 파인튜닝 모델 학습이 완료되었습니다: ${modelId}`);
        return { kind: "succeeded", modelId };
      }
    }
    await new Promise((resolve) => setTimeout(resolve, 2000));
  }
};

/**
 * CLI entry point. Steps:
 * 1. Build training records from ./training/collected.json.
 * 2. Show a cost estimate and ask for confirmation.
 * 3. Upload the data and start a fine-tune job.
 * 4. Poll until the job ends.
 * 5. Delete the uploaded training file.
 *
 * Side effects: writes ./training/training.json and ./training/training.jsonl (and
 * ./training/modelId.txt on success); prompts on stdin via inquirer.
 */
export const main = async () => {
  logger(`[${now()}] 파인튜닝할 프로젝트 데이터 로딩중...`);
  const data = loadProjectDataFromJSON("./training/collected.json");
  // https://platform.openai.com/docs/models/gpt-3
  // Files longer than 1500 tokens are split into several records.
  const prompts = convertToTrainingData(data, 1500);
  if (prompts.length === 0) {
    logger(`[${now()}] 학습할 데이터가 없어 종료합니다.`);
    return;
  }
  await writeTrainingFiles(prompts);

  const epochs = await askEpochs();

  // Cost estimate (https://openai.com/pricing). This is approximate: it counts the
  // tokens of the whole JSON array, JSON punctuation included, at $0.03 per 1K tokens
  // per epoch.
  const tokenCount = tokenizer.encode(JSON.stringify(prompts)).length;
  const trainingPrice = (tokenCount / 1000) * 0.03 * epochs;
  logger(
    `[${now()}] 프로젝트 전체 트레이닝 비용: $${trainingPrice} (${tokenCount} 토큰)`
  );

  const { answer } = await inquirer.prompt({
    name: "answer",
    message: chalk.magentaBright(`진행하시겠습니까? (y/n): `),
  });
  if (String(answer).trim().toLowerCase() !== "y") {
    logger(`[${now()}] 종료합니다.`);
    return;
  }

  const traningDataId = await uploadTrainingData("./training/training.jsonl");
  logger(`[${now()}] 트레이닝 데이터 업로드 완료: ${traningDataId}`);

  const fineTuneId = await startFineTune(traningDataId, epochs);
  logger(`[${now()}] 파인튜닝 모델 학습이 시작되었습니다: ${fineTuneId}`);
  const startFineTuningTime = Date.now();

  const outcome = await waitForFineTune(fineTuneId);
  if (outcome.kind === "succeeded") {
    const fineTuningTime = (Date.now() - startFineTuningTime) / 1000;
    logger(`[${now()}] ${fineTuningTime}초 만에 모델 학습이 완료되었습니다.`);
  }

  // The job has ended (either way), so the uploaded training file is no longer needed.
  await deleteTrainingData(traningDataId);
  logger(`[${now()}] 트레이닝 데이터 삭제 완료: ${traningDataId}`);
};
// Run the CLI. Handle the returned promise explicitly: print the failure and set a
// non-zero exit code rather than leaving a floating, unhandled rejection.
main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment