Janky pretraining script for small llama models using HF fineweb - modify according to your needs
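# Usage sketch (assumptions -- adjust to your own setup):
#   * a local './tokenizer/' directory holding a Llama-3-style tokenizer is expected
#     (the configs below use vocab_size=128256 and bos/eos token ids 128000/128001)
#   * the LOG_DIR / OUTPUT_DIR / DATA_* paths below are hard-coded for the original
#     author's machine; point them at your own directories and parquet files
#   * then run the script directly, e.g.:  python train.py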
import os
import torch
import psutil
import datasets
import glob
from transformers import (
    AutoTokenizer, LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments,
    DataCollatorForLanguageModeling
)
N_CTX = 512  # don't touch
LOG_DIR = '/home/dylan/Documents/AI/train/logs'
OUTPUT_DIR = '/home/dylan/Documents/AI/train/output'
DATA_DIR_350BT = '/media/dylan/SanDisk/350BT'
TOKENIZED_DIR_350BT = f"{DATA_DIR_350BT}/tokenized"  # Where to store processed data
DATA_DIR_10BT = '/home/dylan/Documents/AI/datasets/fineweb/sample/10BT'
TOKENIZED_DIR_10BT = f"{DATA_DIR_10BT}/tokenized"  # Where to store processed data
DATA_FILE_1BT = '/home/dylan/Documents/AI/datasets/fineweb/sample/1BT/1BT.parquet'
TOKENIZED_FILE_1BT = f'{DATA_FILE_1BT}.tokenized'
DATA_DIR_WIKITEXT = '/home/dylan/Documents/AI/datasets/wikitext/wikitext-103-raw-v1'
DATA_FILE_EVAL = f'{DATA_DIR_WIKITEXT}/train-00000-of-00002.parquet'
TOKENIZED_FILE_EVAL = f"{DATA_FILE_EVAL}.tokenized"
def print_used_ram():
    memory_info = psutil.virtual_memory()
    used_ram_gib = memory_info.used / (1024 ** 3)
    print(f"Used System RAM: {used_ram_gib:.2f} GiB")
print(f"Script start.")
print_used_ram()
print(f"Loading tokenizer ...")
tokenizer = AutoTokenizer.from_pretrained('./tokenizer/')
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # For causal LM: labels are the input ids, shifting happens inside the model
)
def tokenize_function(examples):
    # each document is tokenized on its own and truncated to N_CTX tokens (no packing)
    return tokenizer(examples['text'], truncation=True, max_length=N_CTX)
n_cpu = os.cpu_count()
def dataset_from_parquet(file_path: str) -> datasets.Dataset:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'file {file_path!r} does not exist')
    if os.path.isdir(file_path):
        raise IsADirectoryError(f'{file_path!r} is a directory, not a file')
    print(f'Loading parquet file {file_path!r} ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_path,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )
    print('Finished loading parquet file.')
    return ds
def dataset_from_parquet_dir(dir_path: str) -> datasets.Dataset:
    if not os.path.exists(dir_path):
        raise FileNotFoundError(f'directory {dir_path!r} does not exist')
    if not os.path.isdir(dir_path):
        raise NotADirectoryError(f'{dir_path!r} is a file, not a directory')
    abs_dir_path = os.path.abspath(dir_path)
    file_paths: list[str] = []
    print(f'Looking for Parquet files in {abs_dir_path!r}:')
    for file_name in os.listdir(abs_dir_path):
        if file_name.endswith('.parquet'):
            print(f'-- Found {file_name!r}')
            file_path = os.path.join(abs_dir_path, file_name)
            file_paths.append(file_path)
    n_file_paths = len(file_paths)
    if n_file_paths == 0:
        raise RuntimeError('No Parquet files were found.')
    print(f'Loading {n_file_paths} Parquet files ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_paths,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )
    print(f'Finished loading {n_file_paths} Parquet files.')
    return ds
def get_tokenized_dataset(data_dir: str, tokenized_dir: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_dir):
        print(f"Loading pre-tokenized dataset from {tokenized_dir}")
        return datasets.load_from_disk(tokenized_dir)
    print(f"Tokenizing and caching dataset to {tokenized_dir}")
    raw_dataset = dataset_from_parquet_dir(data_dir)
    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=n_cpu,
        remove_columns=["text"]
    )
    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_dir)
    return tokenized_dataset
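# NOTE: main() below only calls get_tokenized_dataset_file() with the 1BT sample;
# this directory-based variant is presumably intended for the larger 10BT / 350BT
# samples and is otherwise unused as written.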
def get_tokenized_dataset_file(data_file: str, tokenized_file: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_file):
        print(f"Loading pre-tokenized dataset from {tokenized_file}")
        return datasets.load_from_disk(tokenized_file)
    print(f"Tokenizing and caching dataset to {tokenized_file}")
    raw_dataset = dataset_from_parquet(data_file)
    # Tokenize (num_proc=1 here; raise it for faster tokenization of large files)
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=1,
        remove_columns=["text"]
    )
    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_file)
    return tokenized_dataset
# extremely tiny model used for faster testing
nano_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=1,
    hidden_act="gelu",
    hidden_size=256,
    initializer_range=0.02,
    intermediate_size=512,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=1,
    num_key_value_heads=1,
    num_hidden_layers=1,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
# small model used to test the training loop before training the actual target model
micro_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=1536,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=6,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
# actual target model to train - similar arch to Llama 3.2 1B
model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=16,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=100_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
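# Rough size check for model_config above (assuming the standard Llama layout):
# untied embeddings + lm_head give about 2 * 128256 * 2048 ~= 0.53B parameters, plus
# 16 layers of attention (~10.5M each) and gated MLP (~50.3M each) ~= 0.97B, so
# roughly 1.5B parameters in total -- main() prints the exact count at startup.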
def get_latest_checkpoint(output_dir):
    """Get the path of the latest checkpoint in the output directory."""
    checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    return latest_checkpoint
def main() -> int:
    print("Start main")
    print("Init model ...")
    model = LlamaForCausalLM(model_config)  # actual model for real training (largest)
    print(f"n_params: {model.num_parameters():,}")

    if torch.cuda.is_available():
        print('Using CUDA device')
        device = torch.device("cuda")
        device_str = "CUDA device"
    else:
        print('Using CPU')
        device = torch.device("cpu")
        device_str = "CPU"

    bf16_enabled = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    print(f"bf16 enabled: {bf16_enabled}")

    print(f"Moving model to {device_str} ...")
    # note: the weights are cast to bfloat16 regardless of bf16_enabled
    model.to(device, dtype=torch.bfloat16)

    # For training data
    print("Loading training data ...")
    training_data = get_tokenized_dataset_file(DATA_FILE_1BT, TOKENIZED_FILE_1BT)
    training_data.set_format("torch", columns=["input_ids", "attention_mask"])
    trainer_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,
        save_steps=64,
        save_total_limit=16,
        logging_dir=LOG_DIR,
        logging_steps=1,
        eval_strategy="no",
        learning_rate=2e-5,
        bf16=bf16_enabled,
        bf16_full_eval=bf16_enabled,
    )
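    # With per_device_train_batch_size=2 and gradient_accumulation_steps=16, each
    # optimizer step sees an effective batch of 32 sequences per device, i.e. up to
    # 32 * N_CTX = 16,384 tokens; a checkpoint is written every 64 optimizer steps.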
print(f"Create Trainer ...")
trainer = Trainer(
model=model,
args=trainer_args,
train_dataset=training_data,
data_collator=data_collator
)
# Check for the latest checkpoint
latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR)
try:
if latest_checkpoint:
print(f"Resuming training from checkpoint: {latest_checkpoint}")
trainer.train(resume_from_checkpoint=latest_checkpoint)
else:
print(f"Starting training from scratch ...")
trainer.train()
print(f"Done training.")
except KeyboardInterrupt:
print(f"KeyboardInterrupt: Training interrupted by user.")
except Exception as e:
print(f"Caught exception: {e}")
finally:
print(f"Save model ...")
model.save_pretrained(f"{OUTPUT_DIR}/final", safe_serialization=True)
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
return 0
if __name__ == '__main__':
    exit(main())
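# After a run, the final weights and tokenizer land in OUTPUT_DIR/final. A minimal
# sketch for loading them back (path assumed from the constants above):
#
#     from transformers import AutoTokenizer, LlamaForCausalLM
#     model = LlamaForCausalLM.from_pretrained('/home/dylan/Documents/AI/train/output/final')
#     tokenizer = AutoTokenizer.from_pretrained('/home/dylan/Documents/AI/train/output/final')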