@conceptofmind · Created July 23, 2023
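"""Fine-tune openlm-research/open_llama_3b on conceptofmind/tasksource-instruct-open-llama-2k
with Hugging Face Accelerate on top of DeepSpeed (ZeRO-2 CPU offload, bf16,
activation checkpointing), logging to Weights & Biases and pushing the result
to the Hugging Face Hub.

Launch sketch (the config file and script names here are illustrative, not
part of the original gist):

    accelerate launch --config_file accelerate_ds_zero2.yaml finetune_open_llama.py
"""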
import math
import os
import random
import time
from datetime import timedelta

import wandb
from accelerate import Accelerator
from accelerate.utils import (DummyOptim, DummyScheduler,
                              InitProcessGroupKwargs, set_seed)
from datasets import load_dataset
from huggingface_hub.utils import HfHubHTTPError
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (LlamaForCausalLM, LlamaTokenizer,
                          default_data_collator)
class CFG:
    # Per-GPU batch sizes used at various model scales (ZeRO-2 offload,
    # activation checkpointing, 2k context):
    #   3B  - bs 18 - A100 80GB - lr 3e-5
    #   7B  - bs 13 - A100 80GB - lr 3e-5
    #   13B - bs  6 - A100 80GB - lr 2e-5
    #   7B  - bs  5 - A100 40GB - lr 3e-5
    BATCH_SIZE: int = 8
    GRADIENT_ACCUMULATE_EVERY: int = 1
    RESUME_FROM_CHECKPOINT: str | None = None  # e.g. a saved "step_<n>" dir
    CHECKPOINTING_STEPS: int = 500
    OUTPUT_DIR: str = ""
    ENTITY_NAME: str = ""
def main():
    wandb.login(
        key=""
    )

    set_seed(42)

    # Generous NCCL timeout so long dataset loads and checkpoint saves
    # don't time out the process group.
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))

    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="bf16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )

    accelerator.init_trackers(
        project_name="open_llama",
        init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}},
    )
accelerator.print(f"Total GPUS: {accelerator.num_processes}")
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b")
model = LlamaForCausalLM.from_pretrained(
"openlm-research/open_llama_3b",
use_cache=False,
)
model.gradient_checkpointing_enable()
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
    # Dataloader
    # with accelerator.main_process_first():
    train_dataset = load_dataset(
        "conceptofmind/tasksource-instruct-open-llama-2k", split="train"
    )

    train_loader = DataLoader(
        train_dataset,
        collate_fn=default_data_collator,
        shuffle=True,
        batch_size=CFG.BATCH_SIZE,
    )
    # Dummy optimizer for DeepSpeed: a placeholder whose real settings come
    # from the DeepSpeed config supplied at launch.
    optim = DummyOptim(
        model.parameters(),
        lr=2e-5,
    )

    # Determine the number of training steps
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    # Dummy scheduler for DeepSpeed, with ~1% of steps as warmup
    scheduler = DummyScheduler(
        optim,
        total_num_steps=max_train_steps,
        warmup_num_steps=int((max_train_steps * 0.01) / accelerator.num_processes),
    )
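    # A minimal sketch of the DeepSpeed config the dummy objects above defer to
    # (illustrative values, not part of the original gist):
    #
    #   {
    #     "bf16": {"enabled": true},
    #     "zero_optimization": {
    #       "stage": 2,
    #       "offload_optimizer": {"device": "cpu"}
    #     },
    #     "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
    #     "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
    #     "gradient_accumulation_steps": "auto",
    #     "train_micro_batch_size_per_gpu": "auto"
    #   }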
    # Prepare everything for distributed training
    model, optim, train_loader, scheduler = accelerator.prepare(
        model, optim, train_loader, scheduler
    )

    # Checkpoint the scheduler alongside model and optimizer state
    accelerator.register_for_checkpointing(scheduler)

    # Recalculate: prepare() shards the dataloader across processes,
    # so its length (and hence the step count) changes.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    # Total batch size for logging
    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")
    # Resume training
    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0
    resume_step = None

    if CFG.RESUME_FROM_CHECKPOINT:
        accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
        accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
        # Checkpoint dirs are named "step_<n>"; recover <n> from the path
        path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
        training_difference = os.path.splitext(path)[0]
        resume_step = int(training_difference.replace("step_", ""))

    if resume_step is not None:
        # Skip batches until we reach the resumed step
        train_loader = accelerator.skip_first_batches(train_loader, resume_step)
        completed_steps += resume_step
        progress_bar.update(resume_step)
        accelerator.print(f"Resuming training from step {resume_step}")
    # Training
    model.train()
    for step, batch in enumerate(train_loader):
        with accelerator.accumulate(model):
            inputs = batch["input_ids"]
            labels = batch["input_ids"]  # causal LM: the model shifts labels internally
            loss = model(inputs, labels=labels).loss
            accelerator.backward(loss)
            # Note: logged against the raw batch index, while checkpoints
            # below use the optimizer-step count.
            accelerator.log({"loss": loss.item()}, step=step)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optim.step()
            scheduler.step()
            optim.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            if isinstance(CFG.CHECKPOINTING_STEPS, int):
                if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
                    output_dir = f"step_{completed_steps}"
                    if CFG.OUTPUT_DIR is not None:
                        output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                    accelerator.save_state(output_dir)

        if completed_steps >= max_train_steps:
            break
    # End training
    accelerator.print("Training Finished")
    accelerator.end_training()

    # Save the final model
    accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
    if CFG.OUTPUT_DIR is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            f"{CFG.OUTPUT_DIR}/final/open_llama_2k_3b-test/",
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            # get_state_dict consolidates the (possibly sharded) weights
            state_dict=accelerator.get_state_dict(model),
        )
        # Push the final model to the Hub, retrying on transient HTTP errors
        max_retries = 5
        for attempt in range(max_retries):
            try:
                with accelerator.main_process_first():
                    unwrapped_model.push_to_hub("Open-Llama-3b-test", private=True)
                print(f"Pushed to hub after {attempt + 1} attempt(s).")
                break
            except HfHubHTTPError:
                wait_time = random.uniform(1, 10)  # wait between 1 and 10 seconds
                print(f"Attempt {attempt + 1} failed, waiting {wait_time:.2f} seconds before retrying...")
                time.sleep(wait_time)
        else:
            print(f"Failed to push to hub after {max_retries} attempts.")

        with accelerator.main_process_first():
            tokenizer.push_to_hub("Open-Llama-3b-test", private=True)
if __name__ == "__main__":
    main()
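# Loading the pushed model afterwards (hypothetical repo id, mirroring the
# push_to_hub calls above):
#
#   from transformers import LlamaForCausalLM, LlamaTokenizer
#   tokenizer = LlamaTokenizer.from_pretrained("<user>/Open-Llama-3b-test")
#   model = LlamaForCausalLM.from_pretrained("<user>/Open-Llama-3b-test")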