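"""Fine-tune GPT-J-6B on the tiny_shakespeare dataset with Ray AIR,
HuggingFace Transformers, and DeepSpeed ZeRO-3. Preprocessing rewrites
"Romeo" to "Bob", so sampling from the fine-tuned model at the end is a
quick check that training actually changed the weights."""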
import os
import re

import numpy as np
import pandas as pd
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    GPTJForCausalLM,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
from transformers.utils.logging import disable_progress_bar, enable_progress_bar

import ray
import ray.data
from ray.air import session
from ray.air.config import ScalingConfig
from ray.data.preprocessors import BatchMapper, Chain
from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer
model_name = "EleutherAI/gpt-j-6B"
use_gpu = True
num_workers = 8  # one GPU per worker (see ScalingConfig below)
cpus_per_worker = 8
block_size = 512  # max token length per training example

ray.data.set_progress_bars(False)
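

# --- Ray Data preprocessing ---
# Each of the three functions below is wrapped in a BatchMapper and
# chained in main(): replace "Romeo" with "Bob", split the text into
# per-line rows, then tokenize to fixed-length blocks.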
def replace_text(batch: pd.DataFrame) -> pd.DataFrame:
    # Concatenate the batch into one string and rename Romeo to Bob.
    text = "".join(list(batch["text"]))
    text = re.sub(r"Romeo", "Bob", text)
    return pd.DataFrame({"text": [text]})

def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    # Split the raw text into lines, dropping blank lines and speaker
    # headers (lines ending with ":").
    text = list(batch["text"])
    flat_text = "".join(text)
    split_text = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip()[-1] == ":"
    ]
    return pd.DataFrame(split_text, columns=["text"])

def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no pad token by default
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=block_size,
        padding="max_length",
        return_tensors="np",
    )
    # For causal LM training, the labels are the input ids themselves.
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)
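

# --- Per-worker trainer setup ---
# Ray Train calls this function on each worker to build a
# transformers.Trainer. DeepSpeed ZeRO stage 3 with CPU offload of both
# parameters and optimizer state keeps per-GPU memory low enough to
# fine-tune the 6B-parameter model.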
def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
    # Use the actual number of CPUs assigned by Ray.
    os.environ["OMP_NUM_THREADS"] = str(
        session.get_trial_resources().bundles[-1].get("CPU", 1)
    )
    # Enable tf32 for better performance.
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = config.get("batch_size", 32)
    epochs = config.get("epochs", 2)
    warmup_steps = config.get("warmup_steps", 0)
    learning_rate = config.get("learning_rate", 0.00002)
    weight_decay = config.get("weight_decay", 0.01)

    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 8,
        },
        "bf16": {"enabled": "auto"},
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": "auto",
                "betas": "auto",
                "eps": "auto",
            },
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True,
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }
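    # The "auto" values above are resolved by the HuggingFace Trainer's
    # DeepSpeed integration from the TrainingArguments defined below.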
print("Preparing training arguments")
training_args = TrainingArguments(
"output",
per_device_train_batch_size=batch_size,
logging_steps=1,
# max_steps=5,
save_strategy="no",
per_device_eval_batch_size=batch_size,
learning_rate=learning_rate,
weight_decay=weight_decay,
warmup_steps=warmup_steps,
label_names=["input_ids", "attention_mask"],
num_train_epochs=epochs,
push_to_hub=False,
disable_tqdm=True, # declutter the output a little
bf16=True,
gradient_checkpointing=True,
deepspeed=deepspeed,
)

    disable_progress_bar()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")
    model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)
    model.resize_token_embeddings(len(tokenizer))
    print("Model loaded")

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )
    return trainer
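

# --- Driver ---
# Load the dataset, build the preprocessing chain, run the distributed
# fine-tuning job, then sample completions from the resulting checkpoint.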
def main():
    ray.init()

    print("Loading tiny_shakespeare dataset")
    current_dataset = load_dataset("tiny_shakespeare")
    ray_datasets = ray.data.from_huggingface(current_dataset)

    # Run a regex to count all occurrences of "Romeo".
    text = ray_datasets["train"].take(1)[0]["text"]
    matches = re.findall(r"Romeo", text)
    print(len(matches))

    replacer = BatchMapper(replace_text, batch_format="pandas")
    splitter = BatchMapper(split_text, batch_format="pandas")
    tokenizer = BatchMapper(tokenize, batch_format="pandas")

    trainer = HuggingFaceTrainer(
        trainer_init_per_worker=trainer_init_per_worker,
        trainer_init_config={
            "batch_size": 32,  # per device
            "epochs": 1,
        },
        scaling_config=ScalingConfig(
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
        ),
        datasets={
            "train": ray_datasets["train"],
            "evaluation": ray_datasets["validation"],
        },
        preprocessor=Chain(replacer, splitter, tokenizer),
    )

    results = trainer.fit()
    checkpoint = results.checkpoint
    # Drop the tokenizing preprocessor: the text-generation pipeline
    # expects raw prompt strings and tokenizes them itself.
    checkpoint.set_preprocessor(None)

    # Predict on the head node.
    predictor = HuggingFacePredictor.from_checkpoint(
        checkpoint=checkpoint,
        task="text-generation",
        torch_dtype=torch.float16 if use_gpu else None,
        device_map="auto",
        use_gpu=use_gpu,
    )

    prompts = [
        "Juliet was in love with someone whose name starts with R. His name was",
        "Juliet was in love with someone whose name starts with B. His name was",
    ]
    # Repeat each prompt 5 times to sample several completions.
    prompts = pd.DataFrame(
        [prompt for prompt in prompts for _ in range(5)],
        columns=["text"],
    )
    kwargs = dict(
        do_sample=True,
        temperature=0.9,
        min_length=32,
        max_length=128,
    )
    predictions = predictor.predict(prompts, **kwargs)
    print(list(predictions["generated_text"]))

if __name__ == "__main__":
    main()