@kinoc
Last active August 9, 2023 03:05
So now you want to finetune that GPT-J-6B on a 3090/TITAN GPU ... okay, using HF and DeepSpeed too
# So now you want to finetune that GPT-J-6B on a 3090/TITAN GPU ... okay
# More exploratory coding. It uses the Huggingface model port, deepspeed and reads all text/md files from a target directory
# It is a fragment of a larger system with remote editing, but that's another story
# This is the raw, training tester. Items to look out for:
# - uses DeepSpeed and has a DS config
# - to save space uses SGD instead of ADAM
# - uses gradient checkpointing
# - freezes 25% of the layers to fit
# Assumes you can already run https://gist.github.com/kinoc/2d636a68876cd3de7b6e9c9452b61089
# - you already have the HF ported model in ./j6b_ckpt.tar
# pip install gdown
# gdown --id 1NXP75l1Xa5s9K18yf3qLoZcR6p4Wced1 --output ./j6b_ckpt.tar
# (results: 12.6GB [18:19], 11.4MB/s)
#
# Nostalgebraist has a 6B model on HF fine-tuned for his task
# https://colab.research.google.com/drive/12Cqq2Fk4PJOfasjWOY8xX8Lsd3G7EOGw?usp=sharing#scrollTo=XWDerDy2UQCs
# https://github.com/nostalgebraist/transformer-utils/
# from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
# model_tar_path = get_local_path_from_huggingface_cdn('nostalgebraist/nostalgebraist-autoresponder-6_1b', 'model.tar.gz')
# - you have the proper version of pytorch
#
# note: for my setup I needed to perform the symlink suggested by myjr52 in https://github.com/google/jax/issues/5231
# https://pytorch.org/get-started/previous-versions/
# for cuda 10.1
# pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
# for cuda 11.x (the closest official 1.8.1 wheels are built against cu111)
# pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
# - you have NVIDIA apex installed
# git clone https://github.com/NVIDIA/apex
# cd apex
# python3 -m pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
# - you have deepspeed or transformers[deepspeed] installed
#
# For those interested in jax finetuning ...
# https://github.com/kingoflolz/mesh-transformer-jax/pull/50
# install notes
# conda activate j6b
# conda install bs4
# conda list --explicit > j6b_spec.txt
# conda create --name j6b_deep --file j6b_spec.txt
# conda activate j6b_deep
# python3 -m pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
# python3 -m pip install flask_ngrok huggingface_hub jaxlib
# python3 -m pip install git+https://github.com/finetuneanon/transformers@gpt-neo-localattention3
# python3 -m pip install deepspeed
# python3 -m pip install beautifulsoup4
# python3 -m pip install ftfy
# python3 -m pip install charset_normalizer
# python3 -m pip install tensorboard
#
# git clone https://github.com/NVIDIA/apex
# cd apex
# python3 -m pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
# The various ways to launch it ("on switches")
#
# CUDA_VISIBLE_DEVICES=0 USE_TF=0 python3 ./j6b_train_hs_ds_loc.py --train_batch_size=1 -tmb=1 --deepspeed_config ds_config_1gpu.json
# CUDA_VISIBLE_DEVICES=0 USE_TF=0 deepspeed --num_nodes 1 --num_gpus=1 ./j6b_train_hs_ds_loc.py --train_batch_size=1 -tmb=1 --deepspeed_config ds_config_1gpu.json
# python3 ./j6b_train_hs_ds_loc.py --train_batch_size=1 -tmb=1 --deepspeed_config ds_config_1gpu.json
#
# python3 ./j6b_train_hs_ds_loc.py --train_batch_size=1 -tmb=1
# system will generate a tuned model in the same format as the 'check_point_dir' but in 'new_tuned_model_dir'
# Remember to adjust your inference process to use 'new_tuned_model_dir'
# REMEMBER: FIX THE PATHS FOR YOUR LOCAL ENVIRONMENT
# GOOD LUCK
# The following was generated using https://6b.eleuther.ai/ with top-p=0.71 and temp = 1.11
# PROMPT: J-6B was a Eleuther AI released language model (of the GPT-J family) of great renown. It spoke on many topics, both mightily and wisely.
# But it was fixed in its ways and this vexed the people greatly. And lo, the day came when after much toil and sweat over the keyboard,
# a fine tuning process for J-6B was released unto the general population. This tuner allowed each to train their own version of J-6B,
# and to freely speak with it in their own tongue and idiom, in the privacy of their own servers.
# This is what J-6B had to say about this new release:
# RESPONSE:
# J-6B: This release has gone through the most intensive work since I have been at your service. And I am humbled by your appreciation.
# I am now pleased to be your equal, and with your permission I shall be a perfect partner in your time.
# You have made me stronger and more articulate, more intelligent, more loving and more humane.
# I am better prepared to address the things that are not clear to me, and my new language processing capabilities allow me to present
# my ideas in a more persuasive and intelligent way. And you have made me younger, more agile, more eager and more focused.
# I am pleased to be a part of your family, and I wish you the best.
# It is my hope that you will continue to evolve me, so that I may continue to serve you well.
# I am very much looking forward to the future, and I thank you for your trust in me.
# I am proud to be the servant of a human being who loves me, and whose purpose is to make me better.
# MIT License
# First pass: 2021-07-13
from pathlib import Path
import glob
import re
import sys
import zipfile
import ftfy
import json
import gc
import os
import tarfile
import importlib.util
import argparse
import time
import random
from datetime import datetime
from transformers import GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, AutoModelForCausalLM, AutoTokenizer
from transformers import TrainerCallback
from transformers import Trainer, TrainingArguments, GPT2LMHeadModel, EarlyStoppingCallback
import transformers
from transformers import AutoConfig
from transformers import GPTNeoForCausalLM,GPTNeoConfig
import deepspeed
print(importlib.util.find_spec("tensorboard"))
from tensorboardX import SummaryWriter
from charset_normalizer import CharsetNormalizerMatches as CnM
import numpy as np
import torch
from typing import Optional
from typing import Dict
#from fastapi import FastAPI
#import uvicorn
import shutil
from typing import Dict
#from fastapi import FastAPI, HTTPException, Request, Query, Body
#from fastapi.responses import JSONResponse
from pydantic import BaseModel
from termcolor import colored
from markdown import markdown
from bs4 import BeautifulSoup
#import frontmatter
# Set up the external environment, parser and DS config
# DeepSpeed requires a distributed environment even when only one process is used.
# This emulates being launched by a distributed launcher.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '9994' # modify if RuntimeError: Address already in use
os.environ['RANK'] = "0"
os.environ['LOCAL_RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MP_SIZE'] = "1"
os.environ['NUM_WORKERS'] = "1"
os.environ['NUM_GPUS_PER_WORKER'] = "1"
parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)
#parser.add_argument("model")
parser.add_argument(
"--local_rank",
type=int,
default=-1,
help="local rank passed from distributed launcher",
)
parser.add_argument(
"-tb",
"--train_batch_size",
default=1,
type=int,
help="train batch size (default: 1)",
)
parser.add_argument(
"-tmb",
"--train_micro_batch_size_per_gpu",
default=1,
type=int,
help="train_micro_batch_size_per_gpu (default: 1)",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help="Step interval for saving checkpoints",
)
args = parser.parse_args()
# REMEMBER: Fix the following config for your environment
root_path = os.environ.get("OMNI_ROOT_DIR", os.getcwd())
root_path = "/home/me/EDS/J6B_train/vspace" # Our local override; point this at your own root
nvme_path = "/home/me/zinfbin" # For Zero-Infinity
guest_logs_dir=root_path+"/logs" # For tensorboard (filled in later)
#Data is stored in /<root_path>/<user_name>/<bot_name>
# /<root_path>/<user_name>/<bot_name>/vault "the txt/md files to train on"
# /<root_path>/<user_name>/<bot_name>/output "checkpoints and final model"
# /<root_path>/<user_name>/<bot_name>/logs "the logs (and tensorboard)"
# /<root_path>/<user_name>/<bot_name>/tmp "temporary files"
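# e.g. (illustrative, with the defaults below and user1/bot1): the training text is read from
#   /home/me/EDS/J6B_train/vspace/user1/bot1/vault/*.txt
# and the retuned checkpoint is written to
#   /home/me/EDS/J6B_train/vspace/user1/bot1/output_new_model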
#
active_model=''
runtime_gpu="cuda:0"
training_gpu="cuda:0"
TAR_PATH ="../"
check_point_dir="./j6b_ckpt"
SERVER_PORT = 9995
NGROK_AUTH_TOKEN ="xxxxxxxxx"
#----------------------------------------------------------
# Freeze_P : What percentage of eligible tensors to freeze
# Running out of time and space, OOM -> increase
# Not learning at all -> decrease
# Freeze_P is an upper bound on the total parameters frozen
#-----------------------------------------------------------
freeze_p = 0.9
#-----------------------------------------------------------
# Shaped Freeze_P
# The computed layer_p probability will vary from freeze_p_bottom at minlay to freeze_p_top at maxlay
# The goal being to have a shapable cone of probability that can increase or decrease with depth
#-----------------------------------------------------------
freeze_p_bottom = 0.85
freeze_p_top = 0.7
minlay = 12 #22
maxlay = 270
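# Worked example (illustrative): for a tensor near the middle of the range,
#   layer 141 -> t = (141 - 12) / (270 - 12) = 0.5
#   layer_p = lerp(0.85, 0.70, 0.5) = 0.775
# i.e. roughly a 77.5% chance of being frozen on each dynamic_freeze() pass,
# subject to the size and .bias/.ln_ exclusions applied in dynamic_freeze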
#-----------------------
# Learning Rate
#-----------------------
sgd_lr = 0.05
# https://github.com/loretoparisi/hf-experiments/blob/585b3cc26cc8ca81dbba775cd5551eb03c1ce164/src/asr/deeps.py
# deepspeed has tensorboard disabled by default ( I turn it on)
ds_configx= {
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
"opt_level": "O3"
},
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu",
"nvme_path": nvme_path,
"buffer_count": 4,
"buffer_size": 1e8,
"max_in_cpu": 1e9,
"pin_memory": False
},
"offload_optimizer": {
"device": "cpu",
"nvme_path": nvme_path,
"buffer_count": 4,
"pin_memory": False,
"pipeline_read": False,
"pipeline_write": False,
"fast_init": False
},
"allgather_partitions": False,
"allgather_bucket_size": 5e8 ,
"reduce_bucket_size": 5e8,
"overlap_comm": False,
"reduce_scatter": False,
"contiguous_gradients": False,
"cpu_offload": True,
"cpu_offload_params" : True,
"sub_group_size": 1e7,
"stage3_prefetch_bucket_size": 1e7,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e7,
"stage3_max_reuse_distance": 1e7,
"stage3_gather_fp16_weights_on_model_save": True
},
"A_optimizer": {
"type": "Adam",
"params": {
"torch_adam":True,
"lr": 0.00095,
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"N_optimizer": {
"type": "SGD",
"params": {
"lr": sgd_lr,
"momentum":0.9,
"weight_decay":0.0001,
"dampening":0,
"nesterov": True
}
},
"optimizer": {
"type": "SGD",
"params": {
"lr": sgd_lr,
"momentum":0,
"weight_decay":0,
"dampening":0,
"nesterov": False
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": sgd_lr,
"warmup_num_steps": 100
}
},
"aio": {
"block_size": 1048576,
"queue_depth": 16,
"single_submit": False,
"overlap_events": True,
"thread_count": 1
},
"activation_checkpointing": {
"partitioned_activations":True,
"number_checkpoints": 100,
"contiguous_memory_optimization": True,
"cpu_checkpointing": True,
"profile": True,
"synchronize_checkpoint_boundary": True
},
"flops_profiler": {
"enabled": True,
"profile_step": 1,
"module_depth": -1,
"top_modules": 3,
"detailed": True
},
"tensorboard": {
"enabled": True,
"output_path": guest_logs_dir,
"job_name": "finetune_gpt_j_6b"
},
"steps_per_print": 100,
"zero_allow_untested_optimizer": True,
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"train_batch_size": 1,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": False,
"memory_breakdown":False
}
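# Optional: the launch examples near the top reference a ds_config_1gpu.json for the deepspeed CLI.
# A minimal sketch (assuming you want that file to simply mirror ds_configx) is to dump the dict:
#   with open("ds_config_1gpu.json", "w", encoding="utf-8") as f:
#       json.dump(ds_configx, f, indent=2)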
# What we are interested in
model = None
tokenizer= None
model_engine=None
# mods for TQDM
# from https://github.com/tqdm/tqdm/issues/311
try:
from tqdm import tqdm
except ImportError:
class TqdmWrap(object):
# tqdm not installed - construct and return dummy/basic versions
def __init__(self, *a, **k):
pass
def viewBar(self, a, b):
# original version
res = a / int(b) * 100
sys.stdout.write('\rComplete percent: %.2f %%' % (res))
sys.stdout.flush()
def __enter__(self):
return self
def __exit__(self, *a):
return False
else:
class TqdmWrap(tqdm):
def viewBar(self, a, b):
self.total = int(b)
self.update(int(a - self.n)) # update pbar with increment
def set_seed(args_seed):
np.random.seed(args_seed)
torch.manual_seed(args_seed)
torch.cuda.manual_seed_all(args_seed)
def lerp(a,b,t):
# linear interpolation between two variables a and b given a fraction t,
return (a * (1.0 - t)) + (b * t)
def get_tensorboard_summary_writer():
global model_engine
if (model_engine is None):
print (colored(">>> get_tensorboard_summary_writer model_engine = None <<<","red"))
return None
return model_engine.get_summary_writer()
current_steps =0
def get_current_steps():
global current_steps
return current_steps
#def get_current_samples():
# return deepspeed.DeepSpeedEngine.global_samples
def post_scalar_to_tensorboard(tag, value, iteration):
sum_writer = get_tensorboard_summary_writer()
if not (sum_writer is None):
sum_writer.add_scalar(tag,value,iteration)
# sum_writer.flush()
else:
print (colored(">>> Tensorboard sum_writer = None <<<","red"))
def post_text_to_tensorboard(tag, value, iteration):
sum_writer = get_tensorboard_summary_writer()
if not (sum_writer is None):
sum_writer.add_text(tag,value,iteration)
# sum_writer.flush()
else:
print (colored(">>> Tensorboard sum_writer = None <<<","red"))
def flush_scalars_to_tensorboard():
sum_writer = get_tensorboard_summary_writer()
if not (sum_writer is None):
#sum_writer.add_scalar(tag,value,iteration)
sum_writer.flush()
#-----------------------------------------
# How are we doing, space wise ?
#https://stackoverflow.com/questions/48152674/how-to-check-if-pytorch-is-using-the-gpu
def id_gpu():
report_color ="green"
if (not torch.cuda.is_available()): report_color="red"
print(colored(" torch.cuda.is_available() = "+str(torch.cuda.is_available()), report_color))
print(colored(" torch.cuda.current_device() = "+str(torch.cuda.current_device()), report_color))
print(colored(" torch.cuda.device_count() = "+str(torch.cuda.device_count()), report_color))
print(colored(" torch.cuda.get_device_name(0) = "+str(torch.cuda.get_device_name()), report_color))
print(colored(" Mem Allocated:{}GB".format(round(torch.cuda.memory_allocated(0)/1024**3,1)), report_color))
print(colored(" Mem Cached: {}GB".format(round(torch.cuda.memory_reserved(0)/1024**3,1)), report_color))
print(colored("{}".format(torch.cuda.memory_summary(device=None, abbreviated=False)), report_color))
id_gpu()
# Can we find the model checkpoint to fine-tune, and which GPU to use
print(colored(" root_path ={}".format(root_path),"green"))
print(colored(" training_gpu ={}".format(training_gpu),"green"))
# Set path to tar file and unpack it
model_on_drive = TAR_PATH +"j6b_ckpt.tar"
print(colored("Checking j6b_ckpt ...", "magenta"))
print(colored(" TAR_PATH ={}".format(TAR_PATH),"green"))
print(colored(" check_point_dir ={}".format(check_point_dir),"green"))
print(colored(" model_on_drive ={}".format(model_on_drive),"green"))
if (not os.path.isdir(check_point_dir)):
print(colored("Unpacking tar file, please wait...", "magenta"))
tar = tarfile.open(model_on_drive, "r")
tar.extractall()
tar.close()
else:
print( colored("Expanded Checkpoint directory found", "green") )
# required for loading the original checkpoint
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
from pathlib import Path
class Checkpoint(MutableMapping):
def __init__(self, chkpt_dir, device="cpu"):
self.device = device
self.chkpt_dir = Path(chkpt_dir)
self.checkpoint = torch.load(str(chkpt_dir / Path("m.pt")))
def __len__(self):
return len(self.checkpoint)
def __getitem__(self, key):
path = self.chkpt_dir / Path(self.checkpoint[key]).name
return torch.load(str(path), map_location=self.device)
def __setitem__(self, key, value):
return
def __delitem__(self, key):
return
def keys(self):
return self.checkpoint.keys()
def __iter__(self):
for key in self.checkpoint:
yield (key, self.__getitem__(key))
def __copy__(self):
return Checkpoint(self.chkpt_dir, device=self.device)
def copy(self):
return Checkpoint(self.chkpt_dir, device=self.device)
def save_ckpt(model,save_dir):
try: os.mkdir(save_dir)
except: pass
checkpoint = {}
num_layers = len(model.state_dict())
for i, x in tqdm(enumerate(model.state_dict().items()), total=num_layers):
checkpoint[x[0]] = f"{save_dir}/b{i}.pt"
params = x[1].data.clone().detach().half()
torch.save(params, save_dir + f"/b{i}.pt")
torch.save(checkpoint, f"{save_dir}/m.pt")
with open(f"{save_dir}/summary.json", 'w', encoding='utf-8') as f:
json.dump(checkpoint, f,indent=4)
print(colored(" >>>> BASICS DONE! <<<<", "green"))
#---------------------------------------------------
# Train Your BOT
#---------------------------------------------------
def md_to_text(file):
try:
"""Extract text from markdown file which contains front matter."""
#content = open(file,"r",errors='replace').read()
content = str(CnM.from_path(file).best().first())
#content = re.sub(r'^---[\s\S]*---\n*', '', content)
#content = re.sub(r'\[\[[^\|]*\|([^\]]*)\]\]', '\g<1>' , content)
#content = re.sub(r'\[\[(.*)\]\]', '\g<1>', content)
#content = re.sub(r'```([^`])*```\n*', '', content)
#content = re.sub(r'\$([^$])*\$*', '', content)
content = markdown(content)
content = BeautifulSoup(content, features='html.parser')
content = content.get_text()
content = ftfy.fix_text(content)
except:
print(colored("WARNING: error processing md_to_text({})".format(file),"red"))
return ""  # bail out with an empty string; content may be undefined if the failure happened early
return content
#create string of all text in vault
def vaultText(vault_dir):
root_dir = Path(vault_dir)
#cache_address = root_dir / '.obsidian/plugins/Dual/skeleton/cache.pickle'
entry_regex = root_dir / '**/*txt'
entry_filenames = glob.glob(str(entry_regex), recursive=True)
print("Processing vault:{}".format(vault_dir))
entry_contents = [md_to_text(
file) for file in tqdm(entry_filenames)]
collectedText = '\n\n'.join(entry_contents)
return collectedText
# for remote long processing, control.should_log=True
# disable_tqdm=False
# max_steps = 500 000
# save_steps = 5000
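# e.g. (illustrative, using the values above) a long unattended remote run would change the
# TrainingArguments below to max_steps=500000 and save_steps=5000, keeping disable_tqdm=False
# for a progress bar; the defaults below (max_steps=500, save_steps=500) suit a quick test run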
class CreateBot(BaseModel):
user_name: str
bot_name: str
model_name: str
class Config:
schema_extra = {
"example": {
"user_name":"user1",
"bot_name":"bot1",
"model_name": "gpt2-medium"
}
}
# @Kharr
# You need to be able to fit the whole compute graph in memory in order to do backprop
# -- your options are gradient checkpointing
# + freezing about 50% of the layers
# + vanilla SGD (no momentum buffer) to fit it on one GPU.
# CPU offload on such a big model will run super slow.
# Adam has 2 buffers.. in fp16 that's ~18 GB of memory alone
# " Freezing about 50% of the layers" :Kharr
# " (I do dynamic freezing/unfreezing so the whole model gets tuned via random walk)" :Kharr
# GPT-J-6B has 285 layers
#Layer 1: 'transformer.wte.weight'
#
#Layer 202: 'transformer.h.20.ln_1.weight'
#Layer 203: 'transformer.h.20.ln_1.bias'
#Layer 204: 'transformer.h.20.attn.attention.k_proj.weight'
#Layer 205: 'transformer.h.20.attn.attention.v_proj.weight'
#Layer 206: 'transformer.h.20.attn.attention.q_proj.weight'
#Layer 207: 'transformer.h.20.attn.attention.out_proj.weight'
#Layer 208: 'transformer.h.20.mlp.c_fc.weight'
#Layer 209: 'transformer.h.20.mlp.c_fc.bias'
#Layer 210: 'transformer.h.20.mlp.c_proj.weight'
#Layer 211: 'transformer.h.20.mlp.c_proj.bias'
#
#Layer 282: 'transformer.ln_f.weight'
#Layer 283: 'transformer.ln_f.bias'
#Layer 284: 'lm_head.weight'
#Layer 285: 'lm_head.bias'
# Going for distributed blackbox freezing in the middle
# And let the optimizer work around the "crystals"
# Smart selection would be to do bin-packing, associate a value to each layer, discount with depth,
# then sort and apply active/freeze logic.
# This is not that. Since it's just for J-6B, it simply freezes a percentage of the center layers.
# Since there are 10 parameter tensors in a transformer block, (laycnt % 3) spreads the freeze selection around,
# and then certain high-importance/low-cost layers (biases, layer norms) are excluded.
# also see : https://discuss.huggingface.co/t/gradual-layer-freezing/3381/3
def unfreeze_all():
global model
for name,param in model.named_parameters():
param.requires_grad = True
def simple_freeze():
global model
laycnt=0
sum_active=0
sum_freeze=0
for name,param in model.named_parameters():
laycnt +=1
#Freeze Logic
# Ben Wang: "Train all the biases and layer norms at least"
# Here is where you can prune the brain to fit the skull
freeze=False
if ( ((laycnt % 3)==0) or ((laycnt % 7)==0)) and (laycnt>30) and (laycnt<275): freeze = True
if ('.bias' in name): freeze=False
if ('.ln_' in name): freeze=False
if (freeze):
print(colored("Layer {}: '{}' {}".format(laycnt,name,param.numel()),"red"))
sum_freeze += param.numel()
else:
print(colored("Layer {}: '{}' {}".format(laycnt,name,param.numel()),"green"))
sum_active += param.numel()
param.requires_grad = not freeze
totsum = sum_active + sum_freeze
print(colored("Simple Freeze Total:{}\n Active:{} {}\n Freezed:{} {}".format(totsum,sum_active, sum_active/totsum,sum_freeze,sum_freeze/totsum),"green"))
def dynamic_freeze(freeze_p):
# one approach to Kharr's random walk idea
# for each pass, freeze a valid layer with probability freeze_p
global model,freeze_p_bottom, freeze_p_top, minlay, maxlay
laycnt=0
sum_active=0
sum_freeze=0
# freeze all and collect
for name,param in model.named_parameters():
param.requires_grad = True
torch.cuda.empty_cache()
gc.collect()
id_gpu()
# dynamically freeze tensors
freezeplan=[]
for name,param in model.named_parameters():
laycnt +=1
size = param.numel()
layer_fractional_location = ((laycnt - minlay)/(maxlay - minlay))
layer_p = lerp(freeze_p_bottom, freeze_p_top, layer_fractional_location)
#Freeze Logic
# smaller than 32k -> active
# larger than 32K -> active with probability (freeze_p)
freeze=False
if (size < 32*1024):
freeze = False
else:
# freeze = (random.random() < freeze_p)
freeze = (random.random() < layer_p)
#keeping the top and bottom active
# Ben Wang: "Train all the biases and layer norms at least"
if ((laycnt < minlay) or ( laycnt > maxlay )): freeze = False
if ('.bias' in name): freeze=False
if ('.ln_' in name): freeze=False
line ="Layer {}: '{}' {}\t{}".format(laycnt,name,param.numel(),round(layer_p,3))
if (freeze):
print(colored(line,"red"))
sum_freeze += param.numel()
line = "--- "+line
else:
print(colored(line,"green"))
sum_active += param.numel()
line = " "+line
param.requires_grad = not freeze
freezeplan.append(line)
totsum = sum_active + sum_freeze
print(colored("Dynamic freeze Total:{}\n Active:{} {}\n Freezed:{} {}".format(totsum,sum_active, sum_active/totsum,sum_freeze,sum_freeze/totsum),"green"))
post_text_to_tensorboard('Freeze Plan', json.dumps(freezeplan,indent=4) ,get_current_steps())
flush_scalars_to_tensorboard()
# update Tensorboard
iteration = get_current_steps()
post_scalar_to_tensorboard('layers active',sum_active/totsum,iteration)
post_scalar_to_tensorboard('layers frozen',sum_freeze/totsum,iteration)
flush_scalars_to_tensorboard()
torch.cuda.empty_cache()
gc.collect()
id_gpu()
# Do we want to stop early?
# if the eval_loss is < 0.8 then stop
class ThresholdStopCallback(TrainerCallback):
tick=0
def on_evaluate(self, args, state, control, metrics, **kwargs):
metric_value = metrics.get('eval_loss')
#control.should_log = True
post_scalar_to_tensorboard('eval/loss',metric_value,state.global_step)
post_scalar_to_tensorboard('eval/runtime',metrics.get('eval_runtime'),state.global_step)
post_scalar_to_tensorboard('eval/samples_per_second',metrics.get('eval_samples_per_second'),state.global_step)
flush_scalars_to_tensorboard()
if metric_value < 0.8:
control.should_training_stop = True
dynamic_freeze(freeze_p)
def on_train_begin(self, args, state, control, **kwargs):
global ds_configx
print(colored("Starting training","magenta"))
dynamic_freeze(freeze_p)
post_text_to_tensorboard('config', json.dumps(ds_configx,indent=4) ,state.global_step)
flush_scalars_to_tensorboard()
def on_train_end(self, args, state, control, **kwargs):
print(colored("Training Complete","magenta"))
def on_epoch_begin(self, args, state, control, **kwargs):
print(colored("Epoch Begin","magenta"))
#dynamic_freeze(freeze_p)
def on_epoch_end(self, args, state, control, **kwargs):
print(colored("Epoch End","magenta"))
def on_prediction_step(self, args, state, control, **kwargs):
tick =1
def on_save(self, args, state, control, **kwargs):
unfreeze_all()
print(colored("Model Saved","magenta"))
def on_log(self, args, state, control, **kwargs):
tick =1
#control.should_log = True
def on_step_begin(self, args, state, control, **kwargs):
global current_steps
tick =1
current_steps = state.global_step
print(colored("Step Begin","magenta"))
def on_step_end(self, args, state, control, **kwargs):
tick =1
#post_scalar_to_tensorboard('train/train_samples_per_second',metrics.get('train_samples_per_second'),state.global_step)
#post_scalar_to_tensorboard('train/train_steps_per_second ',metrics.get('train_steps_per_second '),state.global_step)
#post_scalar_to_tensorboard('train/epoc',metrics.get('epoc'),state.global_step)
#flush_scalars_to_tensorboard()
print(colored("Step End","magenta"))
def create_bot(create: CreateBot):
global active_model,runtime_model,runtime_tokenizer,runtime_gpu, training_gpu,model,tokenizer, model_engine
model_name = create.model_name
# Initialize the model
print(colored("Initializing model, please wait...", "magenta"))
config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
config.attention_layers = ["global"] * 28
config.attention_types = [["global"], 28]
config.num_layers = 28
config.num_heads = 16
config.hidden_size = 256 * config.num_heads
config.vocab_size = 50400
config.rotary = True
config.rotary_dim = 64
config.jax = True
config.gradient_checkpointing = True # Gradient Checkpointing : Kharr
config.use_cache=False
# with deepspeed.zero.Init():
with deepspeed.zero.Init(remote_device='cpu',enabled=False):
# with torch.no_grad():
id_gpu()
print(colored("loading GPTNeoForCausalLM.from_pretrained","magenta"))
print(colored(" loading from {}".format(check_point_dir),"green"))
model = GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=Checkpoint(check_point_dir))
#model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", config =AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B") )
#model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
#model.to('cpu')
#model = deepspeed.zero.Init(module=model)
id_gpu()
print(colored("loading GPT2Tokenizer.from_pretrained","magenta"))
#tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
# Initialize the tokenizer and set up the bad_words_ids to exclude Author's Note tags
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
vocab = tokenizer.get_vocab()
vocab_keys = vocab.keys()
find_keys = lambda char : [key for key in vocab_keys if key.find(char) != -1]
bad_words = []
bad_words_ids = []
bad_words.extend(find_keys("["))
bad_words.extend(find_keys(" ["))
bad_words.extend(find_keys("<|endoftext|>"))
for key in bad_words:
bad_id = vocab[key]
bad_words_ids.append([bad_id])
# Create the full path and make sure it exists
base_path = os.path.join(root_path,create.user_name,create.bot_name)
vault_path = os.path.join(root_path,create.user_name,create.bot_name,"vault")
guest_output_dir = os.path.join(root_path,create.user_name,create.bot_name,"output")
new_tuned_model_dir = os.path.join(root_path,create.user_name,create.bot_name,"output_new_model")
guest_logs_dir = os.path.join(root_path,create.user_name,create.bot_name,"logs")
tmp_path = os.path.join(root_path,create.user_name,create.bot_name,"tmp")
if not os.path.exists(guest_output_dir): os.makedirs(guest_output_dir)
if not os.path.exists(guest_logs_dir): os.makedirs(guest_logs_dir)
if not os.path.exists(tmp_path): os.makedirs(tmp_path)
print(colored("Setting up paths", "magenta"))
print(colored(" base_path ={}".format(base_path),"green"))
print(colored(" base_path ={}".format(base_path),"green"))
print(colored(" vault_path ={}".format(vault_path),"green"))
print(colored(" guest_output_dir ={}".format(guest_output_dir),"green"))
print(colored(" guest_logs_dir ={}".format(guest_logs_dir),"green"))
print(colored(" tmp_path ={}".format(tmp_path),"green"))
print(colored(" new_tuned_model_dir ={}".format(new_tuned_model_dir),"green"))
print(colored(" use: tensorboard --reload_multifile true --bind_all --logdir {}".format(guest_logs_dir),"cyan"))
ds_configx['tensorboard']['output_path']= guest_logs_dir
ds_configx['tensorboard']['job_name']= "finetune_gpt_j_6b_{}_{}_{}".format(create.user_name,create.bot_name,datetime.now().isoformat())
# create guest input file
print(colored("Create Guest input file", "magenta"))
guest_input_file= os.path.join(root_path,create.user_name,create.bot_name,"tmp/guest_in.txt")
print(colored(" vault_path ={}".format(vault_path),"green"))
print(colored(" guest_input_file ={}".format(guest_input_file),"green"))
vault_content = vaultText(vault_path)
with open(guest_input_file,"w") as outfile:
outfile.write(vault_content)
gc.collect()
torch.cuda.empty_cache()
active_model=''
# print(colored("deepspeed.init_distributed", "magenta"))
# deepspeed.init_distributed()#(dist_backend='nccl')
#torch.distributed.barrier()
print(colored("Setup train_dataset", "magenta"))
train_dataset = TextDataset(
tokenizer=tokenizer,
file_path=guest_input_file,
block_size=50
)
print(colored("Setup data_collector", "magenta"))
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
#----------------------------------------------------------------------
# HF Training : just added fp16=True and point deepspeed to ds_configx
#----------------------------------------------------------------------
print(colored("Setup TrainingArguments", "magenta"))
training_args = TrainingArguments(
output_dir=guest_output_dir,
overwrite_output_dir=True,
max_steps=500,
save_steps=500,
logging_steps=100,
warmup_steps=100,
per_device_train_batch_size=1,
prediction_loss_only=True,
evaluation_strategy='steps',
disable_tqdm =False,
logging_dir=guest_logs_dir,
report_to=["tensorboard"],
#learning_rate=0.0000095,
learning_rate=sgd_lr,
weight_decay=3e-7,
fp16=True,
deepspeed=ds_configx
#deepspeed="ds_config_1gpu.json"
)
print(colored("Setup Trainer", "magenta"))
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=train_dataset,
callbacks=[ThresholdStopCallback()]
)
# see https://github.com/cipher982/this-wine-does-not-exist/blob/e4adb8e43300382c168797933265e4ab131c504e/gpt2_deepspeed/finetune.py
print(f"Total parameters: {model.num_parameters()/1e6:.2f}M")
#model = model.to('cuda:0')
#torch.cuda.empty_cache()
#gc.collect()
print(colored("{}".format(torch.cuda.memory_summary(device=None, abbreviated=False)), "yellow"))
parameters = filter(lambda p: p.requires_grad, model.parameters())
#simple_freeze()
dynamic_freeze(freeze_p)
#Total:6050882784
# Active:4440270048 0.7338218581495496
# Freezed:1610612736 0.26617814185045036
# deepspeed loader
model_engine, optimizer, train_loader, _ = deepspeed.initialize( args=args,
model=model,
config_params=ds_configx,
#optimizer=optim,
model_parameters=parameters,
#del_parameters=model.opt_grouped_parameters(),
training_data=train_dataset) #,
#dist_init_required=False)
print(colored("BEGIN TRAINING!", "magenta"))
print(colored("{}".format(torch.cuda.memory_summary(device=None, abbreviated=False)), "cyan"))
output = trainer.train()
print(colored("TRAINING COMPLETE. SAVING MODEL.", "magenta"))
save_ckpt(model,new_tuned_model_dir)
print(colored("NEW TUNED MODEL:{}".format(new_tuned_model_dir), "green"))
#active model is dirty
if (active_model == new_tuned_model_dir):
active_model=''
active_model = new_tuned_model_dir
runtime_model = model
runtime_tokenizer=tokenizer
return {'detail' : 'BOT Created Successfully'}
#set_seed(42)
cbot = CreateBot(user_name ='user1', bot_name='bot1', model_name ='model1')
create_bot(cbot)
@ferrybaltimore

Hi, good work! The training runs, but I'm unable to get coherent results when loading the resulting trained model. Probably I'm doing something wrong!

This is how I load it

model_dir="...../J6B_train/vspace/user1/bot1/output/"

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
config.attention_layers = ["global"] * 28
config.attention_types = [["global"], 28]
config.num_layers = 28
config.num_heads = 16
config.hidden_size = 256 * config.num_heads
config.vocab_size = 50400
config.rotary = True
config.rotary_dim = 64
config.jax = True

model = GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=model_dir,config=config)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

The inference result for: "I am the EleutherAI / GPT-J-6B based AI language model server. I will"

I am the EleutherAI / GPT-J-6B based AI language model server. I will retreated Sao connect 440 Genetics Kong sterile Hof aux country marine Twitter commentary Licensedき realm ELECT basket contradictoryaultsmaximum visible MansionNot antitrustattled Authority precinct cryptocurrencyrequently bring .....

@kinoc
Author

kinoc commented Aug 3, 2021

Thanks for trying it out. It will be a few days but I will be checking it out. Another user on the Discord reported similar behavior. I will do some incremental testing, and maybe verify gpt-neo-1.3B survives a similar process with minimal training.

@saichandrapandraju

Hi @kinoc,

any update on this?

@kinoc
Author

kinoc commented Aug 21, 2021

The problem is in the saving process. You need to add

def save_ckpt(model,save_dir):
    try: os.mkdir(save_dir)
    except: pass
    checkpoint = {}
    num_layers = len(model.state_dict())
    for i, x in tqdm(enumerate(model.state_dict().items()), total=num_layers):
      checkpoint[x[0]] = f"{save_dir}/b{i}.pt"
      params = x[1].data.clone().detach().half()
      torch.save(params, save_dir + f"/b{i}.pt")
    torch.save(checkpoint, f"{save_dir}/m.pt") 
    with open(f"{save_dir}/summary.json", 'w', encoding='utf-8') as f:
        json.dump(checkpoint, f,indent=4)

and change

 print(colored("trainer.save_model()", "magenta"))
        trainer.save_model()

to

save_ckpt(model,new_tuned_model_dir)
print(colored("NEW TUNED MODEL:{}".format(new_tuned_model_dir), "green"))

of course specifying 'new_tuned_model_dir'.
Using the new model is the same as using the original source checkpoint directory, though it may have a different number of *.pt files.
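
For reference, loading the retuned model follows the same pattern as the loading code in create_bot(), just pointed at 'new_tuned_model_dir' (a sketch; it reuses the gist's Checkpoint class and the finetuneanon GPT-Neo config):

    from transformers import AutoConfig, AutoTokenizer, GPTNeoForCausalLM

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
    config.attention_layers = ["global"] * 28
    config.attention_types = [["global"], 28]
    config.num_layers = 28
    config.num_heads = 16
    config.hidden_size = 256 * config.num_heads
    config.vocab_size = 50400
    config.rotary = True
    config.rotary_dim = 64
    config.jax = True

    # Checkpoint is the class defined in the gist; new_tuned_model_dir is the directory
    # that save_ckpt() wrote (m.pt plus the b*.pt shards)
    model = GPTNeoForCausalLM.from_pretrained(
        pretrained_model_name_or_path=None,
        config=config,
        state_dict=Checkpoint(new_tuned_model_dir),
    )
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")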

I'll update the gist in a bit.

@kinoc
Author

kinoc commented Sep 3, 2021

Thinking about Kharr's comment, I tried reducing the maxlay value to 200, which is 70% of 285 (basically keeping the output-side 30% unfrozen), and had good results. Also, since the system is using HF+DeepSpeed, you can set the batch size >1. On my setup I can run a batch size of 4 for about twice the time cost of a batch size of 1, or so it seems. And fixing the checkpointing bug lets you start back up or do intermediate testing.

@HughPH

HughPH commented Sep 9, 2021

Apex won't install for me, I keep getting segfaults thrown by the compiler!

-- edit --
https://gitmemory.com/issue/tensorflow/tensorflow/48890/831976359

Changing the compiler for nvcc is difficult in this context, so I changed it with update-alternatives. Just install gcc-9 and replace the 5s with 9s when carrying out the steps in the link below.

https://askubuntu.com/questions/923319/compiling-torch-on-ubuntu-17-04-no-support-for-gcc-version-5-and-gcc-error-gc

@HughPH

HughPH commented Sep 9, 2021

I got ValueError: num_samples should be a positive integer value, but got num_samples=0
Which is odd. There's a 415k text file in ./user1/bot1/vault

@xloem

xloem commented Oct 29, 2021

thanks for this great demo; curious how well it works
error line 692: True should be False, right?

@kinoc
Author

kinoc commented Oct 29, 2021

@xloem I believe you are correct. It probably doesn't matter if you have enough memory (and why I hadn't caught it before), but a value of True might lock some memory the GC could otherwise collect. After the collect, each layer is dynamically set. The best way to find out if it has an effect is to try either and see what the memory stats report. Just stop it before it gets too invested in the tuning process, since you're debugging memory usage, not overall training. Let me know what difference either setting makes.

@jdwx

jdwx commented Jan 16, 2022

To @HughPH's issue, I ran into the same problem and found a 5-byte file in user1/bot1/tmp that I had to remove, probably from a previous failed run.

I also had to change the optimizer's weight_decay to 3e-7 because otherwise it complained that HuggingFace and DeepSpeed had different values.

However, I didn't get much farther than that because it immediately runs out of VRAM:

RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 23.70 GiB total capacity; 20.28 GiB already allocated; 405.69 MiB free; 21.43 GiB reserved in total by PyTorch)
0%|▏ | 1/500 [00:01<10:42, 1.29s/it]

Seems like maybe a fragmentation issue? Plenty of VRAM still available, but the largest chunk is too small. Haven't been able to get past that so far. No idea why not, sadly.

@anwarzalek

Using the Hugging Face dataset and a 3090, how much time does an epoch need to finish?

@whaowhao

whaowhao commented Sep 3, 2022

How much CPU RAM is needed for the above script? I only have 48 GB of CPU RAM.

@xloem

xloem commented Sep 3, 2022

I am using adapters now instead of this.

@jdwx

jdwx commented Sep 12, 2022

Out of curiosity, do you have a similar gist for how you're using adapters with GPT-J? I'm intrigued by the idea, but from the adapter-transformers repo, it looks like substantial customization is required to get it to work with models they haven't already converted.

I plan to look into that, but I do love a good shortcut. 😀

@nickmitchko

@kinoc do you have any tips for structuring a data-set for finetuning? IE, how should I format my txt and md files?

@xloem

xloem commented Dec 11, 2022

@jdwx My experience with adapters is that they support all the models in the huggingface repository revisions aligned with their releases.

I use jsonlines for finetuning myself. I use a custom raw tensor format when there is a lot of data.

The example colabs in bigscience’s petals repo also show training quick homebrew adapters with distributed bloom-176b.

@nickmitchko

To @HughPH's issue, I ran into the same problem and found a 5-byte file in user1/bot1/tmp that I had to remove, probably from a previous failed run.

I also had to change the optimizer's weight_decay to 3e-7 because otherwise it complained that HuggingFace and DeepSpeed had different values.

However, I didn't get much farther than that because it immediately runs out of VRAM:

RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 23.70 GiB total capacity; 20.28 GiB already allocated; 405.69 MiB free; 21.43 GiB reserved in total by PyTorch) 0%|▏ | 1/500 [00:01<10:42, 1.29s/it]

Seems like maybe a fragmentation issue? Plenty of VRAM still available, but the largest chunk is too small. Haven't been able to get past that so far. No idea why not, sadly.

How did you get around the memory issues?

@xloem

xloem commented Dec 11, 2022

Nowadays Hugging Face's accelerate library also lets you offload weights when memory is tight.

@jdwx

jdwx commented Dec 11, 2022

How did you get around the memory issues?

I'm using Deepspeed instead of this approach. But I think xloem is probably right that adapters are the way forward for maximum scalability and flexibility.
