

@Guitaricet
Last active September 1, 2022 18:18
Short example of distributed generation. Pure PyTorch Distributed.
# torchrun --nproc_per_node 2 generate_multigpu.py
# tested on torch==1.12.1 and transformers==4.21
import os
import torch
import torch.distributed as dist
import transformers
import datasets
from tqdm import tqdm

MODEL_NAME = "facebook/opt-350m"
PER_DEVICE_BATCH_SIZE = 8
MAX_BATCHES = 2

if __name__ == '__main__':
    dist.init_process_group(backend="nccl")

    model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    # decoder-only models should be left-padded, so that generation continues
    # from the end of the prompt rather than after a run of padding tokens
    tokenizer.padding_side = "left"

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = dist.get_rank()  # in case you use multiple nodes

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)  # required for dist.gather_object
    model.to(device)

    dataset = datasets.load_dataset("wikitext", "wikitext-2-v1")["train"]
    ddp_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=PER_DEVICE_BATCH_SIZE, sampler=ddp_sampler)

    all_predictions_str = []

    print(f"RANK {global_rank}: Starting generation")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, disable=local_rank != 0)):
            if i >= MAX_BATCHES:
                break

            encoded = tokenizer(batch["text"], return_tensors="pt", padding=True, truncation=True, max_length=128)
            encoded = encoded.to(device)

            # max_length counts the prompt tokens too, so at most 256 - prompt_len new tokens are sampled
            predictions = model.generate(**encoded, do_sample=True, max_length=256)
            predictions = predictions.cpu().numpy()
            predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
            all_predictions_str.extend(predictions_str)

    print(f"RANK {global_rank}: Finished generation")
    # collect all of the predicted strings from all GPUs and send them to rank 0
    all_predictions_str_gathered = None
    if global_rank == 0:
        all_predictions_str_gathered = dist.get_world_size() * [None]

    dist.gather_object(all_predictions_str, all_predictions_str_gathered, dst=0)

    if global_rank == 0:
        # flatten the list of per-rank lists into a single list of strings
        all_predictions_str_gathered = [item for sublist in all_predictions_str_gathered for item in sublist]

        with open("predictions.txt", "w") as f:
            f.write("\n".join(all_predictions_str_gathered))
        print(f"RANK {global_rank}: Predictions saved to predictions.txt")

    print(f"RANK {global_rank}: Done")
    dist.destroy_process_group()