@williamFalcon
Created June 27, 2020 17:54

import torch.utils.data as tud
import torch
from typing import List
import random
import nlp
def prepare_dataset(tokenizer, split="train", max_length=120, num_datapoints=100_000):
    """Prepares WikiText-103 dataset"""
    wikitext = nlp.load_dataset("wikitext", "wikitext-103-v1")
    data = [x["text"] for x in wikitext[split]][:num_datapoints]
    data = "".join(data)
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(data))
    chunked_token_ids = chunks(token_ids, max_length, tokenizer)
    data = Data(chunked_token_ids, tokenizer)
    return data

def chunks(lst, n, tokenizer):
    """Return successive n-sized chunks of lst, each wrapped in CLS/SEP token ids."""
    _chunks = []
    for i in range(0, len(lst), n):
        ids = [tokenizer.cls_token_id] + lst[i : i + n] + [tokenizer.sep_token_id]
        _chunks.append(torch.tensor(ids))
    return _chunks

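# Illustrative sketch of what chunks() produces, assuming a BERT tokenizer where
# cls_token_id == 101 and sep_token_id == 102 (the 0..4 values are toy token ids):
#   chunks(list(range(5)), 3, tokenizer)
#   -> [tensor([101, 0, 1, 2, 102]), tensor([101, 3, 4, 102])]
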
def noise_text_input(text: str, noise_prob=0.2):
    """Takes a string, returns noised version of it"""
    splitted = text.split(" ")
    bool_mask = torch.empty(len(splitted)).uniform_() > 1 - noise_prob
    noised = []
    for word, boolean in zip(splitted, bool_mask):
        if boolean:
            if len(word) > 1:
                idx = random.randint(1, len(word) - 1)
                noised.append(word[:idx])
                noised.append(word[idx:])
        else:
            noised.append(word)
    return " ".join(noised)

def make_transformer_inputs(
    input_ids, max_length, padding_value, prefix="", make_labels=False, **kwargs
):
    lengths = [s.size(0) for s in input_ids]
    max_len = max(lengths)
    if max_len > max_length:
        max_len = max_length
    out_dims = (len(input_ids), max_len)
    padded_input_ids = input_ids[0].data.new(*out_dims).fill_(padding_value)
    attention_mask = padded_input_ids.clone()
    token_type_ids = padded_input_ids.clone()
    for i, tensor in enumerate(input_ids):
        length = tensor.size(0)
        if length > max_length:
            length = max_length
        tensor = tensor[:length]
        padded_input_ids[i, :length] = tensor
        attention_mask[i, :length] = torch.ones_like(tensor)
    batch = {
        f"{prefix}input_ids": padded_input_ids,
        f"{prefix}attention_mask": attention_mask,
        f"{prefix}token_type_ids": token_type_ids,
    }
    if make_labels:
        lm_labels = padded_input_ids.clone()
        lm_labels[lm_labels == padding_value] = -100
        batch["lm_labels"] = lm_labels
    batch.update(kwargs)
    return batch

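# Quick sanity-check sketch (hypothetical toy inputs, BERT-style pad id 0):
#   ids = [torch.tensor([101, 7, 102]), torch.tensor([101, 7, 8, 9, 102])]
#   batch = make_transformer_inputs(ids, max_length=8, padding_value=0, make_labels=True)
#   batch["input_ids"].shape    -> torch.Size([2, 5])
#   batch["attention_mask"][0]  -> tensor([1, 1, 1, 0, 0])
#   batch["lm_labels"][0]       -> tensor([101, 7, 102, -100, -100])
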
class Data(tud.Dataset):
    def __init__(self, token_ids: List[torch.Tensor], tokenizer, noise_prob=0.2):
        self.token_ids = token_ids
        self.tokenizer = tokenizer
        self.len = len(token_ids)
        self.noise_prob = noise_prob

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        tgt_ids = self.token_ids[idx]
        decoded = self.tokenizer.decode(tgt_ids, skip_special_tokens=True)
        noised = noise_text_input(decoded, self.noise_prob)
        src = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(noised))
        src = [self.tokenizer.cls_token_id] + src + [self.tokenizer.sep_token_id]
        src_ids = torch.tensor(src)
        return dict(src_input_ids=src_ids, tgt_input_ids=tgt_ids)

class Collater:
    def __init__(self, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch: List):
        src = [x["src_input_ids"] for x in batch]
        tgt = [x["tgt_input_ids"] for x in batch]
        src_batch = self.collate(src)
        tgt_batch = self.collate(tgt, "decoder_", make_labels=True)
        src_batch.update(tgt_batch)
        return src_batch

    def collate(self, input_ids, prefix="", make_labels=False):
        return make_transformer_inputs(
            input_ids, self.max_length, self.tokenizer.pad_token_id, prefix, make_labels
        )

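# For reference: a batch coming out of Collater carries the encoder inputs
# ("input_ids", "attention_mask", "token_type_ids"), the same three keys with a
# "decoder_" prefix for the target side, plus "lm_labels" with pad positions set
# to -100. Model.forward below unpacks this dict straight into the
# EncoderDecoderModel via **batch.
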
import pytorch_lightning as pl
from transformers import EncoderDecoderModel, BertTokenizer
import torch
import torch_optimizer
import torch.utils.data as tud
class NoamScheduler(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, num_warmup_steps=1000, last_epoch=-1):
        assert num_warmup_steps > 0
        normalize = 1 / (num_warmup_steps * num_warmup_steps ** -1.5)
        super().__init__(
            optimizer,
            lambda step: normalize
            * min((step + 1) ** -0.5, (step + 1) * num_warmup_steps ** -1.5),
            last_epoch,
        )

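# The lambda above is the Noam schedule: the base lr is scaled by
# sqrt(num_warmup_steps) * min((step + 1)^-0.5, (step + 1) * num_warmup_steps^-1.5),
# i.e. a linear warmup that peaks at the configured lr when step + 1 == num_warmup_steps,
# followed by inverse-square-root decay.
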
class Model(pl.LightningModule):
    def __init__(
        self, hparams, train_dataset=None, val_dataset=None, test_dataset=None
    ):
        super().__init__()
        self.hparams = hparams
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            "bert-base-cased", "bert-base-cased"
        )  # initialize Bert2Bert
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        self.collater = Collater(self.tokenizer, self.hparams.max_length)

    def setup(self, step) -> None:
        self.train_dataset = prepare_dataset(
            self.tokenizer,
            "validation",  # to save time
            self.hparams.max_length,
            self.hparams.num_datapoints,
        )
        self.val_dataset = prepare_dataset(
            self.tokenizer,
            "validation",
            self.hparams.max_length,
            self.hparams.num_datapoints,
        )
        self.test_dataset = prepare_dataset(
            self.tokenizer, "test", self.hparams.max_length, self.hparams.num_datapoints
        )

    def train_dataloader(self):
        return tud.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.train_bs,
            shuffle=True,
            num_workers=self.hparams.num_workers or 4,
            collate_fn=self.collater,
        )

    def val_dataloader(self):
        return tud.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.val_bs,
            shuffle=False,
            num_workers=4,
            collate_fn=self.collater,
        )

    def test_dataloader(self):
        return tud.DataLoader(
            self.test_dataset,
            batch_size=self.hparams.val_bs,
            shuffle=False,
            num_workers=self.hparams.num_workers or 4,
            collate_fn=self.collater,
        )

    def forward(self, batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        loss, logits, *_ = self(batch)
        self.logger.log_metrics({"loss": loss.cpu()})
        output = {"loss": loss}
        return output

    def validation_step(self, batch, batch_idx):
        return self._shared_val_step(batch, batch_idx, "val")

    def validation_epoch_end(self, output):
        return self._shared_val_end(output, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_val_step(batch, batch_idx, "test")

    def test_epoch_end(self, output):
        return self._shared_val_end(output, "test")

    def _shared_val_step(self, batch, batch_idx, prefix):
        loss, logits, *_ = self(batch)
        preds = logits.argmax(-1)  # bs x seqlen
        lm_labels = batch["lm_labels"]  # bs x seqlen
        acc_mask = lm_labels[:, 1:].ne(-100)
        correct = preds[:, :-1].eq(lm_labels[:, 1:])  # bs x (seqlen - 1)
        frac_tokens_correct = correct.masked_select(acc_mask).float().mean()
        correct[~acc_mask] = True  # ignore padding positions for the sequence-level metric
        frac_seqs_correct = correct.all(1).float().mean()
        logs = {
            f"{prefix}_loss": loss,
            "frac_tokens_correct": frac_tokens_correct,
            "frac_seqs_correct": frac_seqs_correct,
        }
        return logs

    def _shared_val_end(self, output, prefix):
        output = self.collate(output)
        logs = {"log": output, f"{prefix}_loss": output[f"{prefix}_loss"]}
        # self.logger.log_metrics(output)
        return logs

    def configure_optimizers(self):
        opt_class = getattr(torch_optimizer, self.hparams.optimizer)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.optimizer_kwargs.weight_decay or 1e-7,
            },
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        self.optimizer = opt_class(
            optimizer_grouped_parameters, **self.hparams.optimizer_kwargs
        )
        scheduler = NoamScheduler(
            self.optimizer, self.hparams.schedulers_kwargs.num_warmup_steps
        )
        self.scheduler = {"scheduler": scheduler, "interval": "step"}
        return [self.optimizer], [self.scheduler]

    def collate(self, output):
        keys = output[0].keys()
        return_dict = {}
        for key in keys:
            tensor = output[0][key]
            if tensor.dim() == 0:
                return_dict[key] = torch.stack([x[key] for x in output]).mean()
            elif tensor.dim() == 1:
                return_dict[key] = torch.cat([x[key] for x in output]).mean()
        return return_dict

hparams = {
    "name": "MY-WANDB-NAME",
    "project": "MY-WANDB-PROJECT",
    "train_bs": 4,
    "val_bs": 4,
    "num_workers": 4,
    "max_length": 160,
    "num_datapoints": 100_000,
    "optimizer": "Ranger",
    "optimizer_kwargs": {
        "lr": 3e-4,
        "alpha": 0.5,
        "betas": [0.95, 0.999],
        "eps": 1e-5,
        "weight_decay": 1e-3,
        # "use_gc": True,
    },
    "schedulers_kwargs": {"num_warmup_steps": 1000},
    "trainer_kwargs": {
        "gpus": 2,
        "gradient_clip_val": 0.5,
        "accumulate_grad_batches": 4,
        "min_epochs": 5,
        "max_epochs": 100,
        "precision": 32,
        "distributed_backend": "ddp",  ### Change this to "ddp" when on multi-gpu to see the bug
    },
}

import wandb
wandb.login()
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
def train(hparams):
    hparams = OmegaConf.create(hparams)
    print(hparams.pretty())
    log = WandbLogger(name=hparams.name, project=hparams.project)
    checkpoint = pl.callbacks.ModelCheckpoint(
        filepath="checkpoints/", verbose=True, monitor="val_loss", mode="min"
    )
    trainer = pl.Trainer(
        logger=log, checkpoint_callback=checkpoint, **hparams.trainer_kwargs
    )
    model = Model(hparams)
    trainer.fit(model)


if __name__ == "__main__":
    train(hparams)