Demo of Optuna Artifact Store for an Experiment with LLM
import os
import pickle
import tarfile
import tempfile

from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import peft
import torch
import transformers

import optuna
DATASET_PICKLE_PATH = "qa_dataset.pkl"

question_dict = {
    "context": [
        "Cheese is the best food.",
        "Cheese is the best food.",
        "The Moon orbits Earth at an average distance of 384,400 km (238,900 mi), or about 30 times Earth's diameter. Its gravitational influence is the main driver of Earth's tides and very slowly lengthens Earth's day. The Moon's orbit around Earth has a sidereal period of 27.3 days. During each synodic period of 29.5 days, the amount of visible surface illuminated by the Sun varies from none up to 100%, resulting in lunar phases that form the basis for the months of a lunar calendar. The Moon is tidally locked to Earth, which means that the length of a full rotation of the Moon on its own axis causes its same side (the near side) to always face Earth, and the somewhat longer lunar day is the same as the synodic period. However, 59% of the total lunar surface can be seen from Earth through cyclical shifts in perspective known as libration.",
    ],
    "question": [
        "What is the best food?",
        "How far away is the Moon from the Earth?",
        "At what distance does the Moon orbit the Earth?",
    ],
}

base_path = "artifacts"
os.makedirs(base_path, exist_ok=True)
artifact_store = optuna.artifacts.FileSystemArtifactStore(base_path)
def create_qa_dataset():
    # Reuse the tokenized dataset if it was already built and pickled.
    if os.path.exists(DATASET_PICKLE_PATH):
        with open(DATASET_PICKLE_PATH, mode="rb") as f:
            return pickle.load(f)

    def _create_prompt(context, question, answer):
        if len(answer["text"]) < 1:
            answer = "Cannot Find Answer"
        else:
            answer = answer["text"][0]
        prompt_template = (
            f"### CONTEXT\n{context}\n\n### QUESTION\n{question}\n\n### ANSWER\n{answer}</s>"
        )
        return prompt_template

    tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/tokenizer")
    qa_dataset = load_dataset("squad_v2").map(
        lambda samples: tokenizer(
            _create_prompt(samples["context"], samples["question"], samples["answers"])
        )
    )
    with open(DATASET_PICKLE_PATH, mode="wb") as f:
        pickle.dump(qa_dataset, f)

    return qa_dataset
def build_model():
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype=torch.float16, device_map="auto"
    )
    config = peft.LoraConfig(
        r=4,
        lora_alpha=8,
        target_modules=["query_key_value"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return peft.get_peft_model(model, config)
def suggest_train_params(trial, trial_dir_path):
    # max_steps = 1000
    max_steps = 10
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1.0, log=True)
    per_device_batch_size = 2
    return transformers.TrainingArguments(
        per_device_eval_batch_size=per_device_batch_size,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=16 // per_device_batch_size,
        learning_rate=learning_rate,
        output_dir=trial_dir_path,
        eval_steps=max_steps,
        max_steps=max_steps,
        warmup_steps=100,
        fp16=True,
        # logging_steps=100,
        logging_steps=1,
        eval_strategy="steps",
    )
def plot_learning_curve(df, trial_dir_path, trial):
    loss_vals = np.array(df["loss"])
    steps = np.array(df["step"])
    _, ax = plt.subplots()
    ax.plot(steps, loss_vals, label="Train Loss")
    ax.legend()
    ax.grid()
    ax.set_xlabel("Number of Steps")
    ax.set_ylabel("Train Loss")
    fig_path = os.path.join(trial_dir_path, "learning-curve.png")
    plt.savefig(fig_path)
    plt.close()
    return fig_path
def _objective(trial, qa_dataset, trial_dir_path):
    model = build_model()
    tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/tokenizer")
    data_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
    train_params = suggest_train_params(trial, trial_dir_path)
    trainer = transformers.Trainer(
        model=model,
        train_dataset=qa_dataset["train"],
        eval_dataset=qa_dataset["validation"],
        args=train_params,
        data_collator=data_collator,
    )
    model.config.use_cache = False
    trainer.train()
    model.config.use_cache = True

    df = pd.DataFrame(trainer.state.log_history)
    log_path = os.path.join(trial_dir_path, "log_history.csv")
    model_path = os.path.join(trial_dir_path, "trained_model")
    trainer.save_model(model_path)
    print("model save finished")
    inference_path = inference(model_path, trial_dir_path, trial)
    print("inference finished")
    fig_path = plot_learning_curve(df, trial_dir_path, trial)
    print("plot finished")
    df.to_csv(log_path)
    print("log finished")
    tar_path = os.path.join(trial_dir_path, "model.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(model_path)
    print("tar finished")

    for file_path in [log_path, inference_path, fig_path, tar_path]:
        optuna.artifacts.upload_artifact(
            study_or_trial=trial, file_path=file_path, artifact_store=artifact_store
        )

    valid_loss = float(df["eval_loss"].max(skipna=True))
    return np.inf if np.isnan(valid_loss) else valid_loss
def objective(trial, qa_dataset):
    with tempfile.TemporaryDirectory() as trial_dir_path:
        value = _objective(trial, qa_dataset, trial_dir_path)

    return value
def inference(model_path, trial_dir_path, trial):
    config = peft.PeftConfig.from_pretrained(model_path)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        load_in_8bit=False,
        device_map="auto",
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    qa_model = peft.PeftModel.from_pretrained(model, model_path)

    def _inference(context, question):
        template = f"### CONTEXT\n{context}\n\n### QUESTION\n{question}\n\n### ANSWER\n"
        batch = tokenizer(template, return_tensors="pt").to(model.device)
        with torch.cuda.amp.autocast():
            output_tokens = qa_model.generate(**batch, max_new_tokens=200)

        return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    results = []
    questions = question_dict["question"]
    contexts = question_dict["context"]
    for question, context in zip(questions, contexts):
        answer = _inference(context, question).split("ANSWER")[-1][1:]
        results.append(dict(context=context, question=question, answer=answer))

    inference_path = os.path.join(trial_dir_path, "inference_results.jsonl")
    pd.DataFrame(results).to_json(inference_path, orient="records", lines=True)
    return inference_path
if __name__ == "__main__":
    qa_dataset = create_qa_dataset()
    study = optuna.create_study(storage="sqlite:///demo.db", study_name="demo")
    study.optimize(lambda trial: objective(trial, qa_dataset), n_trials=10)
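
After the study finishes, the files uploaded with optuna.artifacts.upload_artifact can be pulled back out of the artifact store in a separate script. Below is a minimal sketch, assuming Optuna 4.0's artifact helpers (get_all_artifact_meta and download_artifact) and the artifacts/ directory and demo.db created by the script above; the downloads/ target directory is an arbitrary choice for this example.

import os

import optuna

artifact_store = optuna.artifacts.FileSystemArtifactStore("artifacts")
storage = optuna.storages.RDBStorage("sqlite:///demo.db")
study = optuna.load_study(study_name="demo", storage=storage)

# Copy every artifact attached to the best trial into a local directory.
os.makedirs("downloads", exist_ok=True)
for meta in optuna.artifacts.get_all_artifact_meta(study.best_trial, storage=storage):
    optuna.artifacts.download_artifact(
        artifact_store=artifact_store,
        artifact_id=meta.artifact_id,
        file_path=os.path.join("downloads", meta.filename),
    )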
tf_keras==2.17.0
bitsandbytes==0.43.3
datasets==2.20.0
accelerate==0.33.0
loralib==0.1.2
peft==0.12.0
transformers==4.43.3
pandas==2.2.2
optuna==4.0.0b0
matplotlib==3.9.1
Note
You need a GPU with at least 16 GB of memory to run this experiment.
The expected runtime is around 1 hour.
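
If you prefer to browse the trials and the stored files interactively, optuna-dashboard (installed separately with pip install optuna-dashboard) can serve the same study together with the artifact store. A minimal sketch, assuming a recent optuna-dashboard whose run_server accepts an artifact_store argument; it serves the dashboard on port 8080 by default.

import optuna
from optuna_dashboard import run_server

# Point the dashboard at the study storage and the artifact directory used above.
artifact_store = optuna.artifacts.FileSystemArtifactStore("artifacts")
run_server("sqlite:///demo.db", artifact_store=artifact_store)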