-
-
Save djsaunde/691bd0e2f89ba0ccbc5e78f813820d02 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Usage with torchrun: torchrun --nproc_per_node=2 fp8_ddp.py | |
| """ | |
| import os | |
| import torch | |
| import torch.distributed as dist | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torchao.float8 import convert_to_float8_training | |
| from transformers.models.auto.tokenization_auto import AutoTokenizer | |
| from transformers.models.auto.modeling_auto import AutoModelForCausalLM | |
| from transformers.models.auto.configuration_auto import AutoConfig | |
| from datasets import Dataset | |
def setup_ddp():
    """Initialize the DDP process group from torchrun environment variables.

    Returns:
        (rank, world_size): the LOCAL_RANK of this process and the world size.

    NOTE(review): rank is read from LOCAL_RANK, which equals the global rank
    only on a single node. This script assumes single-node torchrun; for
    multi-node use, read RANK for the global rank separately.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Bind this process to its GPU *before* NCCL init, so process-group
    # setup doesn't create a CUDA context on GPU 0 from every rank.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl")
    return rank, world_size
def create_dummy_dataset(tokenizer, num_samples=100, seq_len=512):
    """Build a small synthetic, tokenized dataset for smoke-testing.

    Args:
        tokenizer: HF tokenizer used to encode the dummy text.
        num_samples: number of synthetic examples to generate.
        seq_len: fixed length every example is padded/truncated to.

    Returns:
        A `datasets.Dataset` with `input_ids` and `attention_mask` columns,
        stored as Python lists (the DataLoader's collate_fn tensorizes them).
    """

    def tokenize_function(examples):
        # No return_tensors="pt" here: Dataset.map stores results in Arrow,
        # which converts tensors back to lists anyway — the original kwarg
        # only added a useless tensor round-trip.
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=seq_len,
        )

    # Generate dummy text data (repeated so each sample has real length).
    dummy_texts = [
        f"This is sample text number {i} for training." * 20
        for i in range(num_samples)
    ]
    dataset = Dataset.from_dict({"text": dummy_texts})
    # Tokenize and drop the raw text column.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )
    return tokenized_dataset
| def train_step(model, batch, optimizer): | |
| """Single training step""" | |
| model.train() | |
| input_ids = batch["input_ids"].cuda() | |
| attention_mask = batch["attention_mask"].cuda() | |
| # Forward pass | |
| outputs = model( | |
| input_ids=input_ids, attention_mask=attention_mask, labels=input_ids | |
| ) | |
| loss = outputs.loss | |
| # Backward pass | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| return loss.item() | |
def run_training():
    """Main training function: FP8 + DDP + torch.compile smoke test.

    Setup order is significant: load model on this rank's GPU, convert to
    float8 training layers, wrap in DDP, then compile the wrapped module.
    """
    # Setup DDP (reads LOCAL_RANK / WORLD_SIZE injected by torchrun).
    rank, world_size = setup_ddp()
    print(f"Running on rank {rank}/{world_size}")
    # Model setup
    model_name = "meta-llama/Llama-3.2-1B"
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Llama ships no pad token; reuse EOS so padding="max_length" works.
        tokenizer.pad_token = tokenizer.eos_token
    config = AutoConfig.from_pretrained(model_name)
    # device_map={"": rank} places the entire model on this rank's GPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=torch.bfloat16, device_map={"": rank}
    )
    # Apply FP8 quantization (torchao swaps in float8 training modules);
    # done before the DDP wrap below.
    model = convert_to_float8_training(model)
    print("Applied FP8 quantization using convert_to_float8_training")
    # Wrap with DDP
    model = DDP(model, device_ids=[rank])
    print("Model wrapped with DDP")
    # Apply torch.compile to the DDP-wrapped module, last.
    compiled_model = torch.compile(model, mode="default")
    model = compiled_model
    print("Applied torch.compile")
    # Create dataset and dataloader
    print("Creating dataset...")
    dataset = create_dummy_dataset(tokenizer, num_samples=32, seq_len=256)
    # Simple data loading (not distributed for simplicity)
    # NOTE(review): no DistributedSampler — every rank iterates the same 32
    # samples, shuffled independently. Acceptable for a smoke test only.
    def collate_fn(batch):
        # Dataset.map stored lists; re-tensorize and stack into a batch.
        return {
            "input_ids": torch.stack(
                [torch.tensor(item["input_ids"]) for item in batch]
            ),
            "attention_mask": torch.stack(
                [torch.tensor(item["attention_mask"]) for item in batch]
            ),
        }
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
    )
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= 5:  # Only run a few steps
            break
        loss = train_step(model, batch, optimizer)
        if rank == 0:  # log from a single rank to avoid interleaved output
            print(f"Step {step}: Loss = {loss:.4f}")
def main():
    """Main entry point for torchrun."""
    try:
        run_training()
    finally:
        # Always tear down the process group, even if training raised, so
        # torchrun doesn't hang on dangling NCCL resources. Guard against
        # init having failed before the group was created.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment