Skip to content

Instantly share code, notes, and snippets.

@djsaunde
Created July 22, 2025 22:35
Show Gist options
  • Select an option

  • Save djsaunde/691bd0e2f89ba0ccbc5e78f813820d02 to your computer and use it in GitHub Desktop.

Select an option

Save djsaunde/691bd0e2f89ba0ccbc5e78f813820d02 to your computer and use it in GitHub Desktop.
"""
Usage with torchrun: torchrun --nproc_per_node=2 fp8_ddp.py
"""
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.float8 import convert_to_float8_training
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig
from datasets import Dataset
def setup_ddp():
    """Join the torchrun-launched NCCL process group and pin this process to its GPU.

    Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment variables
    that torchrun sets for every worker.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` — this process's local rank
        and the total number of processes.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    # Bind this process to its own GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    # NOTE(review): LOCAL_RANK is returned as the "rank" here. On a single
    # node this equals the global RANK; multi-node runs would need
    # os.environ["RANK"] for a true global rank — confirm intended scope.
    return local_rank, world_size
def create_dummy_dataset(tokenizer, num_samples=100, seq_len=512):
    """Build a small synthetic, pre-tokenized dataset for smoke-testing training.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the dummy text.
        num_samples: Number of synthetic examples to generate.
        seq_len: Fixed length every example is padded/truncated to.

    Returns:
        A tokenized ``datasets.Dataset`` whose columns are the tokenizer
        outputs (``input_ids``, ``attention_mask``); the raw ``text`` column
        is dropped.
    """
    # Repeat a short sentence so each example comfortably exceeds seq_len
    # and exercises truncation.
    texts = [
        f"This is sample text number {i} for training." * 20
        for i in range(num_samples)
    ]
    raw = Dataset.from_dict({"text": texts})

    def _tokenize(batch):
        # Pad/truncate every example to exactly seq_len tokens.
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=seq_len,
            return_tensors="pt",
        )

    return raw.map(_tokenize, batched=True, remove_columns=raw.column_names)
def train_step(model, batch, optimizer):
    """Run a single forward/backward/optimizer step and return the loss.

    Args:
        model: Causal-LM model (or DDP/compiled wrapper) that accepts
            ``input_ids``, ``attention_mask`` and ``labels`` keywords and
            returns an output object with a ``.loss`` attribute.
        batch: Dict with CPU tensors ``input_ids`` and ``attention_mask``;
            both are moved to the current CUDA device here.
        optimizer: Optimizer over ``model.parameters()``.

    Returns:
        float: The scalar training loss for this batch.
    """
    model.train()
    ids = batch["input_ids"].cuda()
    mask = batch["attention_mask"].cuda()

    # Causal-LM convention: labels == input_ids; the model shifts them
    # internally to compute next-token loss.
    outputs = model(input_ids=ids, attention_mask=mask, labels=ids)
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
def run_training():
    """Main training function: set up DDP, build an FP8 + compiled Llama
    model, and run a few smoke-test optimization steps.

    Order matters here: fp8 conversion is applied to the bare module first,
    then the module is wrapped in DDP, then torch.compile wraps the DDP
    module. Only rank 0 prints per-step losses.
    """
    # Setup DDP (reads LOCAL_RANK/WORLD_SIZE set by torchrun).
    rank, world_size = setup_ddp()
    print(f"Running on rank {rank}/{world_size}")
    # Model setup — downloads from the HF hub on first run.
    model_name = "meta-llama/Llama-3.2-1B"
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    config = AutoConfig.from_pretrained(model_name)
    # device_map={"": rank} places the whole model on this process's GPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=torch.bfloat16, device_map={"": rank}
    )
    # Apply FP8 quantization (swaps eligible Linear layers to float8 training
    # variants) BEFORE the DDP wrap so DDP sees the converted parameters.
    model = convert_to_float8_training(model)
    print("Applied FP8 quantization using convert_to_float8_training")
    # Wrap with DDP for gradient synchronization across ranks.
    model = DDP(model, device_ids=[rank])
    print("Model wrapped with DDP")
    # Apply torch.compile on top of the DDP wrapper.
    compiled_model = torch.compile(model, mode="default")
    model = compiled_model
    print("Applied torch.compile")
    # Create dataset and dataloader
    print("Creating dataset...")
    dataset = create_dummy_dataset(tokenizer, num_samples=32, seq_len=256)
    # Simple data loading (not distributed for simplicity)
    # NOTE(review): no DistributedSampler — every rank iterates the same
    # (independently shuffled) data; fine for a smoke test, not for real
    # training.
    def collate_fn(batch):
        # Dataset rows come back as Python lists; re-tensorize and stack
        # into (batch, seq_len) CPU tensors.
        return {
            "input_ids": torch.stack(
                [torch.tensor(item["input_ids"]) for item in batch]
            ),
            "attention_mask": torch.stack(
                [torch.tensor(item["attention_mask"]) for item in batch]
            ),
        }
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
    )
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= 5:  # Only run a few steps
            break
        loss = train_step(model, batch, optimizer)
        if rank == 0:
            print(f"Step {step}: Loss = {loss:.4f}")
def main():
    """Main entry point for torchrun: run training, then tear down the
    process group so NCCL resources are released cleanly on every rank."""
    run_training()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment