Janky pretraining script for small Llama models using HF FineWeb - modify according to your needs
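
The script expects a Llama-3-style tokenizer saved locally under ./tokenizer/ (all three configs below assume a 128,256-token vocabulary) plus FineWeb Parquet shards at the hard-coded paths. A minimal one-time setup sketch, assuming access to a matching tokenizer on the Hub (the repo id here is illustrative):

# one-time setup, run separately from the training script
from transformers import AutoTokenizer
AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B").save_pretrained("./tokenizer/")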

import os
import torch
import psutil
import datasets
import glob
from transformers import (
    AutoTokenizer, LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments,
    DataCollatorForLanguageModeling
)

N_CTX = 512  # don't touch

LOG_DIR = f'/home/dylan/Documents/AI/train/logs'
OUTPUT_DIR = f'/home/dylan/Documents/AI/train/output'

DATA_DIR_350BT = f'/media/dylan/SanDisk/350BT'
TOKENIZED_DIR_350BT = f"{DATA_DIR_350BT}/tokenized"  # Where to store processed data

DATA_DIR_10BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/10BT'
TOKENIZED_DIR_10BT = f"{DATA_DIR_10BT}/tokenized"  # Where to store processed data

DATA_FILE_1BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/1BT/1BT.parquet'
TOKENIZED_FILE_1BT = f'{DATA_FILE_1BT}.tokenized'

DATA_DIR_WIKITEXT = f'/home/dylan/Documents/AI/datasets/wikitext/wikitext-103-raw-v1'
DATA_FILE_EVAL = f'{DATA_DIR_WIKITEXT}/train-00000-of-00002.parquet'
TOKENIZED_FILE_EVAL = f"{DATA_FILE_EVAL}.tokenized"

def print_used_ram():
    memory_info = psutil.virtual_memory()
    used_ram_gib = memory_info.used / (1024 ** 3)
    print(f"Used System RAM: {used_ram_gib:.2f} GiB")

print(f"Script start.")
print_used_ram()

print(f"Loading tokenizer ...")
tokenizer = AutoTokenizer.from_pretrained('./tokenizer/')
tokenizer.pad_token = tokenizer.eos_token

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # For causal LM
)
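# Note: with mlm=False the collator pads each batch to its longest sequence and
# builds labels as a copy of input_ids with pad positions set to -100. Since
# pad_token is aliased to eos_token above, any real EOS tokens in the data are
# masked out of the loss as well.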

def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, max_length=N_CTX)
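# Each document is truncated to N_CTX tokens; there is no packing or grouping,
# so text beyond 512 tokens per document is simply dropped and shorter documents
# are padded per batch by the collator.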

n_cpu = os.cpu_count()

def dataset_from_parquet(file_path: str) -> datasets.Dataset:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'file {file_path!r} does not exist')
    if os.path.isdir(file_path):
        raise IsADirectoryError(f'{file_path!r} is a directory, not a file')
    print(f'Loading parquet file {file_path!r} ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_path,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )
    print(f'Finished loading parquet file.')
    return ds

def dataset_from_parquet_dir(dir_path: str) -> datasets.Dataset:
    if not os.path.exists(dir_path):
        raise FileNotFoundError(f'directory {dir_path!r} does not exist')
    if not os.path.isdir(dir_path):
        raise NotADirectoryError(f'{dir_path!r} is a file, not a directory')
    abs_dir_path = os.path.abspath(dir_path)
    file_paths: list[str] = []
    print(f'Looking for Parquet files in {abs_dir_path!r}:')
    for file_name in os.listdir(abs_dir_path):
        if file_name.endswith('.parquet'):
            print(f'-- Found {file_name!r}')
            file_path = os.path.join(abs_dir_path, file_name)
            file_paths.append(file_path)
    n_file_paths = len(file_paths)
    if n_file_paths == 0:
        raise RuntimeError('No Parquet files were found.')
    print(f'Loading {n_file_paths} Parquet files ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_paths,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )
    print(f'Finished loading {n_file_paths} Parquet files.')
    return ds

def get_tokenized_dataset(data_dir: str, tokenized_dir: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_dir):
        print(f"Loading pre-tokenized dataset from {tokenized_dir}")
        return datasets.load_from_disk(tokenized_dir)
    print(f"Tokenizing and caching dataset to {tokenized_dir}")
    raw_dataset = dataset_from_parquet_dir(data_dir)
    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=n_cpu,
        remove_columns=["text"]
    )
    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_dir)
    return tokenized_dataset
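# As written, main() only uses get_tokenized_dataset_file() with the 1BT sample;
# this directory-based variant (and the 10BT/350BT paths above) is presumably
# kept for switching to the larger FineWeb samples.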

def get_tokenized_dataset_file(data_file: str, tokenized_file: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_file):
        print(f"Loading pre-tokenized dataset from {tokenized_file}")
        return datasets.load_from_disk(tokenized_file)
    print(f"Tokenizing and caching dataset to {tokenized_file}")
    raw_dataset = dataset_from_parquet(data_file)
    # Tokenize in a single process (one input file)
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=1,
        remove_columns=["text"]
    )
    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_file)
    return tokenized_dataset

# extremely tiny model used for faster testing
nano_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=1,
    hidden_act="gelu",
    hidden_size=256,
    initializer_range=0.02,
    intermediate_size=512,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=1,
    num_key_value_heads=1,
    num_hidden_layers=1,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
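# Note: head_dim=1 with a single head squeezes attention through a 1-dimensional
# bottleneck, so this config only exercises the pipeline; it is not expected to
# learn anything useful.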

# small model used to test training before training the actual model
micro_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=1536,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=6,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)

# actual target model to train - similar arch to llama 3.2 1B
model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=16,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=100_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
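# Compared to Llama 3.2 1B this keeps the same 2048/8192/16-layer GQA shape but
# unties the embeddings (roughly 1.5B parameters instead of ~1.24B), uses gelu
# rather than silu, and trains with a 512-token context and rope_theta=100k.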

def get_latest_checkpoint(output_dir):
    """Get the path of the latest checkpoint in the output directory."""
    checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    return latest_checkpoint
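# Note: this picks the checkpoint with the newest ctime; parsing the step number
# out of the 'checkpoint-<step>' directory name would be more robust if
# checkpoints are ever copied or restored.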

def main() -> int:
    print(f"Start main")
    print(f"Init model ...")
    model = LlamaForCausalLM(model_config)  # actual model for real training (largest)
    print(f"n_params: {model.num_parameters():,}")
    if torch.cuda.is_available():
        print('Using CUDA device')
        device = torch.device("cuda")
        device_str = "CUDA device"
    else:
        print('Using CPU')
        device = "cpu"
        device_str = "CPU"
    bf16_enabled = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    print(f"bf16 enabled: {bf16_enabled}")
    print(f"Moving model to {device_str} ...")
    model.to(device, dtype=torch.bfloat16)
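    # Note: the model is cast to bf16 unconditionally, even on CPU or when
    # bf16_enabled is False; guard the dtype argument if that is not intended.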
    # For training data
    print(f"Loading training data ...")
    training_data = get_tokenized_dataset_file(DATA_FILE_1BT, TOKENIZED_FILE_1BT)
    training_data.set_format("torch", columns=["input_ids", "attention_mask"])
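    # set_format() restricts examples to input_ids / attention_mask; labels are
    # generated on the fly by the data collator.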
    trainer_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,
        save_steps=64,
        save_total_limit=16,
        logging_dir=LOG_DIR,
        logging_steps=1,
        eval_strategy="no",
        learning_rate=2e-5,
        bf16=bf16_enabled,
        bf16_full_eval=bf16_enabled,
    )
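    # Effective batch size: 2 x 16 = 32 sequences (up to ~16K tokens at N_CTX=512)
    # per optimizer step, a checkpoint every 64 steps, and evaluation disabled
    # (so the wikitext eval paths above are unused).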
print(f"Create Trainer ...") | |
trainer = Trainer( | |
model=model, | |
args=trainer_args, | |
train_dataset=training_data, | |
data_collator=data_collator | |
) | |
    # Check for the latest checkpoint
    latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR)
    try:
        if latest_checkpoint:
            print(f"Resuming training from checkpoint: {latest_checkpoint}")
            trainer.train(resume_from_checkpoint=latest_checkpoint)
        else:
            print(f"Starting training from scratch ...")
            trainer.train()
        print(f"Done training.")
    except KeyboardInterrupt:
        print(f"KeyboardInterrupt: Training interrupted by user.")
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        print(f"Save model ...")
        model.save_pretrained(f"{OUTPUT_DIR}/final", safe_serialization=True)
        tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
    return 0
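
# Note: the broad `except Exception` in main() swallows training errors and the
# function still returns 0, so a crashed run exits with a success status; the
# finally block at least saves whatever weights exist to OUTPUT_DIR/final.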

if __name__ == '__main__':
    exit(main())