@timvandam
Created August 23, 2022 11:51
DDP training
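Trains a Seq2Seq code completion model (default: microsoft/unixcoder-base), optionally on multiple GPUs via DistributedDataParallel, with Slurm-aware configuration.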
import math
import operator
from multiprocessing import Pool
from typing import List
import torch
import random
import os
import numpy as np
from fuzzywuzzy import fuzz
from torch.utils.data import DataLoader, Dataset, RandomSampler, DistributedSampler, SequentialSampler
from tqdm import tqdm
from transformers import RobertaConfig, RobertaTokenizer, RobertaModel, AdamW, get_linear_schedule_with_warmup
from Seq2Seq import Seq2Seq
import json
import sys
from functools import lru_cache
import argparse
import torch.multiprocessing
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import datetime
from time import time
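
# The file_system sharing strategy helps avoid running out of file descriptors when many
# DataLoader workers pass tensors between processes.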
torch.multiprocessing.set_sharing_strategy('file_system')

def parse_args():
    parser = argparse.ArgumentParser(
        description="Trains a model, optionally using multiple GPUs using DistributedDataParallel. "
                    "Requires ./datasets/train.txt and ./datasets/validation.jsonl to be present.")
    parser.add_argument(
        "--slurm",
        action='store_true',
        help='Enables automatic DistributedDataParallel based on Slurm options. '
             'Enabling this will automatically fill in all DDP options based on environment variables'
    )
    multi_gpu_manual_options = parser.add_argument_group(
        title="Multi GPU Options",
        description="Options for manually configuring DistributedDataParallel"
    )
    multi_gpu_manual_options.add_argument('--world_size', type=int, default=-1,
                                          help='Number of processes participating in the job')
    multi_gpu_manual_options.add_argument('--gpus_per_node', type=int, default=-1)
    multi_gpu_manual_options.add_argument('--rank', type=int, default=-1, help='The global rank')
    multi_gpu_manual_options.add_argument('--local_rank', type=int, default=-1,
                                          help='The local rank (determines which GPU to use)')
    multi_gpu_manual_options.add_argument('--dist_backend', type=str, default=dist.Backend.NCCL)
    multi_gpu_manual_options.add_argument('--dist_url', default='env://', type=str)
    parser.add_argument("dataset_folder", type=str, help="The folder containing the dataset", action='store')
    parser.add_argument("--model_name", default="microsoft/unixcoder-base", type=str,
                        help="The name or the path of the model to be trained")
    parser.add_argument("--learning_rate", default=2e-4, type=float, help="The learning rate")
    parser.add_argument("--max_input_length", default=936, type=int,
                        help="The maximum length of the input (validation input is left-truncated if over this length)")
    parser.add_argument("--max_output_length", default=64, type=int, help="The maximum length of the output")
    parser.add_argument("--chunk_overlap", default=100, type=int,
                        help="The window offset used for windowing train inputs larger than the max allowed length")
    parser.add_argument("--seed", default=42, type=int, help="The seed used for randomized things")
    parser.add_argument("--beam_size", default=3, type=int, help="The beam size for beam search")
    parser.add_argument("--batch_size", default=8, type=int, help="The batch size")
    parser.add_argument("--num_epochs", default=10, type=int, help="The number of epochs")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                        help="The number of steps to accumulate gradients")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="The weight decay")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="The epsilon for Adam optimizer")
    parser.add_argument("--cpu_count", default=os.cpu_count(), type=int,
                        help="The number of CPUs to use for loading data")
    return parser.parse_args()
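
# Populate the DDP-related args from the Slurm environment. Note that WORLD_SIZE is not set
# by Slurm itself; the launching job script is expected to export it (typically
# nodes * tasks-per-node) before starting this script.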
def set_slurm_args(args):
    if not args.slurm:
        raise Exception("Slurm args can only be set if slurm is enabled")
    args.rank = int(os.environ['SLURM_PROCID'])
    args.local_rank = int(os.environ['SLURM_LOCALID'])
    args.cpu_count = int(os.environ['SLURM_CPUS_PER_TASK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpus_per_node = int(os.environ['SLURM_GPUS_ON_NODE'])
    if args.gpus_per_node > torch.cuda.device_count():
        raise Exception("The number of GPUs per node is greater than the number of GPUs available")

def get_current_date_time_string():
    return datetime.datetime.now().strftime("%d.%b %Y %H:%M:%S")


def log(message):
    print(f'[{get_current_date_time_string()}] {message}', flush=True)
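
# Wraps an iterable: yields its items unchanged while periodically logging progress and,
# optionally, an ETA estimate.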
def progress_enumerator(enumerable, create_message, step=100, total=None, report_eta=False):
    start = time()
    for i, x in enumerate(enumerable):
        if step and i % step == 0:
            suffix = ""
            if report_eta and i > 0:
                if total is None:
                    total = len(enumerable)
                time_elapsed = time() - start
                avg_time_per_element = time_elapsed / i
                remaining_elements = total - i
                remaining_time_estimate = avg_time_per_element * remaining_elements
                suffix = f" [{datetime.timedelta(seconds=round(remaining_time_estimate))} remaining]"
            log(create_message(i, x) + suffix)
        yield x

def set_seed(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def get_train_file_path(dataset_folder: str):
    return os.path.join(dataset_folder, "datasets", "train.txt")


def get_validation_file_path(dataset_folder: str):
    return os.path.join(dataset_folder, "datasets", "validation.jsonl")


def get_model_folder_path(dataset_folder: str):
    return os.path.join(dataset_folder, "models")


def get_model_file_path(dataset_folder: str, model_name: str):
    return os.path.join(get_model_folder_path(dataset_folder), model_name + ".bin")


def get_training_details_file_path(dataset_folder: str):
    return os.path.join(get_model_folder_path(dataset_folder), "training_details.json")


def get_training_details(dataset_folder):
    training_details_json_path = get_training_details_file_path(dataset_folder)
    if not os.path.exists(training_details_json_path):
        return {
            "models": [],  # contains { modelName: string, exactMatch: number, editSim: number }
        }
    with open(training_details_json_path, "r", encoding='utf-8') as f:
        return json.loads(f.read())


def save_training_details(dataset_folder, training_details):
    training_details_json_path = get_training_details_file_path(dataset_folder)
    with open(training_details_json_path, "w", encoding='utf-8') as f:
        f.write(json.dumps(training_details, indent=2))

def verify_files_exist(dataset_folder: str):
    if not os.path.exists(dataset_folder):
        raise Exception("Dataset folder does not exist")
    if not os.path.exists(get_train_file_path(dataset_folder)):
        raise Exception("Train file does not exist")
    if not os.path.exists(get_validation_file_path(dataset_folder)):
        raise Exception("Validation file does not exist")
def prepare_model(model_name: str, max_output_length: int, beam_size: int):
    config = RobertaConfig.from_pretrained(model_name)
    config.is_decoder = True
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    # confirm assumptions that are made in createTrainSet.ts
    if tokenizer.bos_token != '<s>':
        raise Exception("Tokenizer bos_token is not <s>")
    if tokenizer.cls_token != '<s>':
        raise Exception("Tokenizer cls_token is not <s>")
    if tokenizer.sep_token != '</s>':
        raise Exception("Tokenizer sep_token is not </s>")
    if tokenizer.eos_token != '</s>':
        raise Exception("Tokenizer eos_token is not </s>")
    encoder = RobertaModel.from_pretrained(model_name, config=config)
    decoder = encoder
    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        config=config,
        beam_size=beam_size,
        max_length=max_output_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=[tokenizer.sep_token_id],
    )
    return model, tokenizer
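
# The index maps a global example index to a (line_index, input_index) pair and is built
# lazily: the first call to __len__ reads the whole file and counts inputs per line in a
# multiprocessing Pool.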
class IndexedFileDataset(Dataset):
    """
    Dataset that caches an input file but lazily tokenizes them (with a cache!)
    """
    # input_index -> (line_index, input_index)
    # handy when a line contains multiple inputs/chunks
    index = []
    index_loaded = False
    # all lines in the input file are cached
    lines = []
    lines_loaded = False

    def __init__(self, file_path: str, cpu_count: int):
        self.file_path = file_path
        self.cpu_count = cpu_count

    def _read_file(self):
        if self.lines_loaded:
            return
        with open(self.file_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        self.lines_loaded = True

    def _create_index(self):
        if not self.lines_loaded:
            self._read_file()
        if self.index_loaded:
            return
        with Pool(self.cpu_count) as pool:
            input_counts = progress_enumerator(
                pool.imap(self._get_line_input_count, range(len(self.lines)), chunksize=1000),
                lambda i, _: f"Creating index [{i} / {len(self.lines) - 1}]",
                step=1000,
                total=len(self.lines),
                report_eta=True,
            )
            self.index = [
                (line_index, i)
                for line_index, input_count in enumerate(input_counts)
                for i in range(input_count)
            ]
        self.index_loaded = True

    def __len__(self):
        if not self.index_loaded:
            self._create_index()
        return len(self.index)

    # TODO: Force it to get all items to test how much ram it takes
    # @lru_cache(maxsize=65536)
    def __getitem__(self, idx):
        if not self.index_loaded:
            raise Exception("Index not loaded")
        if idx < 0 or idx >= len(self.index):
            raise IndexError("Index out of range")
        line_index, input_index = self.index[idx]
        line_input = self._get_line_input(line_index, input_index)
        return line_input

    def _get_line(self, line_index: int):
        if not self.lines_loaded:
            self._read_file()
        if line_index < 0 or line_index >= len(self.lines):
            raise IndexError("Line index out of range")
        return self.lines[line_index]

    def _get_line_input_count(self, line_index: int):
        """
        Should return the amount of inputs that are in some line
        """
        raise NotImplementedError()

    def _get_line_input(self, line_index: int, input_index: int):
        """
        Should return some input from some line
        """
        raise NotImplementedError()
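
# '\u0120' is the byte-level BPE marker for a leading space (Ġ); bare whitespace tokens are dropped.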
def str_to_tokens(string: str, tokenizer: RobertaTokenizer):
    return [token for token in tokenizer.tokenize(string) if token != '\u0120']


def tokens_to_token_ids(tokenizer: RobertaTokenizer, max_length: int, tokens: List[str]):
    if len(tokens) > max_length:
        raise Exception("Input is too long")
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    padding_length = max_length - len(token_ids)
    token_ids += [tokenizer.pad_token_id] * padding_length
    return token_ids
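
# List equivalent of str.rindex: returns the highest index of `item` within [start, end).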
def rindex(items, item, start=0, end=None):
    if 0 <= start <= len(items):
        if end is None:
            end = len(items)
        end = min(end, len(items))
        for i in range(end - 1, start - 1, -1):
            if items[i] == item:
                return i
    raise ValueError("Item not found")
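
# Training lines that exceed the model's maximum length are split into overlapping windows
# (chunks). Each window is trimmed back to the last </s> it contains so that examples never
# end mid-line.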
class TrainDataset(IndexedFileDataset):
    def __init__(
            self,
            train_file_path: str,
            cpu_count: int,
            max_length: int,
            tokenizer: RobertaTokenizer,
            chunk_overlap: int,
    ):
        super().__init__(train_file_path, cpu_count)
        if chunk_overlap < 0:
            raise Exception("Chunk overlap must be >= 0")
        if chunk_overlap >= max_length - 3:
            raise Exception("Chunk overlap must be < max_length - 3")
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.chunk_overlap = chunk_overlap

    @lru_cache(maxsize=10)
    def _get_line_tokens(self, line_index: int):
        line = self._get_line(line_index)
        # remove leading <s>
        line = " ".join(line.strip().split()[1:])
        line_tokens = str_to_tokens(line, self.tokenizer)
        return line_tokens

    @lru_cache(maxsize=10)
    def _get_line_chunk_ranges(self, line_index: int):
        line_tokens = self._get_line_tokens(line_index)
        chunk_length = self.max_length - 3
        chunk_start = 0
        chunk_end = chunk_start + chunk_length
        chunk_ranges = []
        while True:
            try:
                # ensure that we always end with an eos
                last_eos_idx = rindex(line_tokens, "</s>", chunk_start, chunk_end)
                chunk_end = last_eos_idx + 1
            except ValueError:
                # ValueError means we don't have an eos
                # this only happens when lines don't fit in the model
                # we can only really skip those lines
                # including them partially would lead to partial predictions (which is not the point of line completion)
                pass
            chunk_ranges.append((chunk_start, chunk_end))
            if chunk_end >= len(line_tokens):
                break
            current_chunk_length = chunk_end - chunk_start
            if current_chunk_length > self.chunk_overlap:
                # the current chunk is larger than the chunk overlap. good!
                chunk_start = chunk_end - self.chunk_overlap
            else:
                # chunk is smaller than the chunk overlap :o
                # very small chunk means that the next lines are very long
                # in order to still make it work we will just not use the overlap
                # this only happens when the next line is extremely long, so no big issue
                chunk_start = chunk_end
            chunk_end = chunk_start + chunk_length
        if line_index < 10:
            print(f"Line {line_index} has {len(chunk_ranges)} chunks")
            for i in range(len(chunk_ranges)):
                chunk_start, chunk_end = chunk_ranges[i]
                print(f"*** {i} ***")
                print(" ".join(line_tokens[chunk_start:chunk_end]))
            print("\n\n\n")
        return chunk_ranges

    @lru_cache(maxsize=10)
    def _get_line_input_count(self, line_index: int):
        return len(self._get_line_chunk_ranges(line_index))

    @lru_cache(maxsize=10)
    def _get_line_input(self, line_index: int, input_index: int):
        # chunked approach with overlapping chunks
        line_tokens = self._get_line_tokens(line_index)
        line_chunk_ranges = self._get_line_chunk_ranges(line_index)
        if input_index < 0 or input_index >= len(line_chunk_ranges):
            raise IndexError("Input index out of range. "
                             f"Got {input_index}, expected a value in the range [0, {len(line_chunk_ranges)})")
        chunk_start, chunk_end = line_chunk_ranges[input_index]
        chunk = ['<s>', '<decoder-only>', '</s>'] + line_tokens[chunk_start:chunk_end]
        return tokens_to_token_ids(self.tokenizer, self.max_length, chunk)
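
# Each validation line is a JSON object with a "leftContext" and a "groundTruth"; the left
# context is left-truncated to fit the maximum input length.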
class ValidationDataset(IndexedFileDataset):
    def __init__(self, validation_file_path: str, cpu_count: int, max_length: int, tokenizer: RobertaTokenizer):
        super().__init__(validation_file_path, cpu_count)
        self.max_length = max_length
        self.tokenizer = tokenizer

    def _get_line_input_count(self, line_index: int):
        # the validation set is just input->output on each line
        # input is left-truncated if need be
        return 1

    def _get_line_input(self, line_index: int, input_index: int):
        if input_index != 0:
            raise IndexError("Index out of range")
        line = self._get_line(line_index)
        obj = json.loads(line)
        # replace \n with </s>, normalize spacing
        left_context = obj["leftContext"]
        left_context = left_context.replace("\n", " </s> ")
        left_context = left_context.split()
        left_context = " ".join(left_context)
        tokens = str_to_tokens(left_context, self.tokenizer)
        # truncate from the left side and add prefix
        tokens = ["<s>", "<decoder-only>", "</s>"] + tokens[-(self.max_length - 3):]
        input_tokens = tokens_to_token_ids(self.tokenizer, self.max_length, tokens)
        return input_tokens, obj["groundTruth"]

def main(args):
    set_seed(args.seed)
    verify_files_exist(args.dataset_folder)
    os.makedirs(get_model_folder_path(args.dataset_folder), exist_ok=True)
    previous_training_details = get_training_details(args.dataset_folder)
    previous_epochs = len(previous_training_details["models"])
    remaining_epochs = args.num_epochs - previous_epochs
    if remaining_epochs <= 0:
        log(f"This model has already trained for {len(previous_training_details['models'])} "
            f"out of {args.num_epochs} epochs, exiting")
        exit(0)
    model, tokenizer = prepare_model(args.model_name, args.max_output_length, args.beam_size)
    if previous_epochs > 0:
        log(f"Found a model that was already trained for {previous_epochs} epoch(s), loading it")
        model.load_state_dict(torch.load(get_model_file_path(
            args.dataset_folder,
            previous_training_details["models"][-1]["modelName"]))
        )
    if torch.cuda.is_available():
        log("CUDA available, using GPU")
        if args.local_rank != -1:
            log(f"Using distributed GPU [{args.local_rank}]")
            torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda")
        model = model.to(device)
    else:
        log("CUDA not available, using CPU")
        device = torch.device("cpu")
        model = model.to(device)
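
    # Initialize the default process group and wrap the model in DDP. With the 'env://' init
    # method, MASTER_ADDR and MASTER_PORT must be provided via the environment (e.g. by the
    # job script). find_unused_parameters=True tolerates parameters that receive no gradient
    # in a given forward/backward pass.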
    if args.local_rank != -1:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            rank=args.rank,
            world_size=args.world_size,
        )
        model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)
log("Loading train dataset")
train_dataset = TrainDataset(
train_file_path=get_train_file_path(args.dataset_folder),
cpu_count=args.cpu_count,
max_length=args.max_input_length + args.max_output_length,
tokenizer=tokenizer,
chunk_overlap=args.chunk_overlap
)
if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset)
else:
train_sampler = DistributedSampler(train_dataset, num_replicas=args.world_size, rank=args.rank)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.batch_size // args.gradient_accumulation_steps,
num_workers=args.cpu_count,
pin_memory=True,
)
log(f"Window size avg: {len(train_dataset.index)/len(train_dataset.lines)}")
log(f"Loaded train dataset: {len(train_dataloader)} batches")
if args.rank <= 0:
log("Loading validation dataset")
validation_dataset = ValidationDataset(
validation_file_path=get_validation_file_path(args.dataset_folder),
cpu_count=args.cpu_count,
max_length=args.max_input_length,
tokenizer=tokenizer
)
validation_sampler = SequentialSampler(validation_dataset)
validation_dataloader = DataLoader(
validation_dataset,
sampler=validation_sampler,
batch_size=args.batch_size,
num_workers=args.cpu_count,
pin_memory=True,
)
log(f"Loaded validation dataset: {len(validation_dataloader)} batches")

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    # TODO: Use the non deprecated optimizer
    # TODO: Research how this works to set good params and to make sure it is implemented correctly
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(len(train_dataloader) * args.num_epochs * 0.1),
        num_training_steps=len(train_dataloader) * args.num_epochs
    )
    if args.local_rank != -1:
        dist.barrier()
    if args.rank <= 0:
        log(
            "***** Running Training *****\n" +
            "\tNum examples = %d\n" % len(train_dataset) +
            "\tBatch size = %d\n" % args.batch_size +
            "\tNum epochs = %d\n" % args.num_epochs +
            "\tSteps per epoch = %d" % len(train_dataloader)
        )
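
    # Training loop. The DataLoader batch size is batch_size // gradient_accumulation_steps,
    # and the loss is divided by gradient_accumulation_steps, so the effective number of
    # examples per optimizer step (per process) stays equal to --batch_size.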
    model.train()
    nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_accuracy, best_loss = 0, 0, 0, 0, 0, 1e6
    losses = []
    # TODO: Make tqdm bars more friendly to output file
    for epoch in range(previous_epochs, args.num_epochs):
        if args.rank <= 0:
            log(f"Starting epoch {epoch}")
        for idx, batch in enumerate(train_dataloader):
            source_ids = torch.transpose(torch.stack(batch), 0, 1).to(device).contiguous()
            loss, _, _ = model(source_ids, True)
            losses.append(loss.item())
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            if (idx + 1) % 100 == 0:
                # TODO: Add loss plot
                log("epoch %d step %d loss %f" % (epoch, idx + 1, round(np.mean(losses[-100:]), 4)))
            nb_tr_examples += source_ids.size(0)
            nb_tr_steps += 1
            loss.backward()
            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                # Update parameters
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

        # Eval model with validation dataset
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        if args.local_rank != -1:
            dist.barrier()
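
        # Only rank 0 (or the single process in a non-distributed run) runs validation and
        # saves the checkpoint; the other ranks wait at the barrier at the end of the epoch.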
        if args.rank <= 0:
            log("***** Running Validation *****")
            model.eval()
            # See if it's possible to make this parallel too
            EM = 0.0
            edit_sim = 0.0
            for i, (batch, ground_truths) in enumerate(progress_enumerator(
                    validation_dataloader,
                    lambda i, _: f"Validating [{i} / {len(validation_dataloader) - 1}]",
                    total=len(validation_dataloader),
                    report_eta=True,
            )):
                source_ids = torch.transpose(torch.stack(batch), 0, 1).to(device).contiguous()
                with torch.no_grad():
                    predict = model.module if hasattr(model, 'module') else model
                    preds = predict(source_ids=source_ids)
                    for j, (gt, pred) in enumerate(zip(ground_truths, preds)):
                        t = pred[0].cpu().numpy()
                        t = list(t)
                        if 0 in t:
                            t = t[:t.index(0)]
                        pred = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                        if "</s>" in pred:
                            pred = pred[:pred.index("</s>")]
                        pred = " ".join(pred.strip().split())
                        gt = " ".join(gt.strip().split())
                        if i == 0 and j < 5:
                            log(f"Validation example {j}:\n*** Prediction ***\n{pred}\n\n*** Ground Truth ***\n{gt}\n***")
                        if pred == gt:
                            EM += 1
                        edit_sim += fuzz.ratio(pred, gt)
            # average over the full validation set
            EM /= len(validation_dataset)
            edit_sim /= len(validation_dataset)
            model.train()
            validation_accuracy = round(EM * 100, 2)
            log("\t%s = %s " % ("Acc", str(validation_accuracy)))
            log("\t%s = %s " % ("Edit sim", str(round(edit_sim, 2))))
            log(" " + "*" * 20)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
            model_name = f"epoch-{epoch}"
            torch.save(model_to_save.state_dict(), get_model_file_path(args.dataset_folder, model_name))
            training_details = get_training_details(args.dataset_folder)
            training_details["models"].append({
                "modelName": model_name,
                "exactMatch": validation_accuracy,
                "editSim": edit_sim,
            })
            save_training_details(args.dataset_folder, training_details)
        # wait for all ranks to reach this barrier
        # this ensures that all processes wait while rank 0 is validating, saving the model, and saving training details
        if args.local_rank != -1:
            dist.barrier()

if __name__ == '__main__':
    args = parse_args()
    if args.slurm:
        set_slurm_args(args)
    config_text = ""
    config_text += "*** Config ***\n"
    config_text += "\n".join(map(lambda kv: f"{kv[0]}: {kv[1]}", args.__dict__.items())) + "\n"
    config_text += "**************"
    log(config_text)
    main(args)