pytorch-lightning ddp BatchEncoding
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
import os
from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding)


class BoringModel(LightningModule):
    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BoringModel):
            def training_step(...):
                # do your own thing

        or:

        model = BoringModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased"
        )

    def forward(self, x):
        return self.model(**x)

    def training_step(self, batch, batch_idx):
        # These prints are the point of the repro: they show whether the
        # collated batch is still a BatchEncoding and on which device its
        # tensors ended up.
        print("type(batch)", type(batch))
        print(batch["attention_mask"].device)
        labels = batch.pop("labels")
        output = self.model(**batch, labels=labels)
        loss = output[0]
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        # Exercise the epoch-end hook; the result is intentionally unused.
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        output = self.model(**batch, labels=labels)
        loss = output[0]
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x["x"] for x in outputs]).mean()

    def configure_optimizers(self):
        # Optimize all parameters; the template this was adapted from only
        # optimized the unused Linear layer, so the transformer never trained.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
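

# A possible workaround (a sketch, not part of the original repro):
# BatchEncoding implements .to(device), so moving the whole batch in the
# transfer_batch_to_device hook sidesteps Lightning's default batch transfer,
# which may not recognize the dict subclass. The hook signature below assumes
# pytorch-lightning ~1.0/1.1; newer releases add more arguments.
from transformers import BatchEncoding


class PatchedModel(BoringModel):
    def transfer_batch_to_device(self, batch, device):
        # BatchEncoding is a dict subclass; move it as a whole rather than
        # letting Lightning recurse into it.
        if isinstance(batch, BatchEncoding):
            return batch.to(device)
        return super().transfer_batch_to_device(batch, device)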


# NOTE: If you are using a cmd line to run your script,
# provide the cmd line as below.
# opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
# parser = ArgumentParser()
# args = parser.parse_args(opt)
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# DataCollatorWithPadding pads the tokenized examples to a common length and
# returns a transformers.BatchEncoding (a dict subclass), not a plain dict.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


class RandomDataset(Dataset):
    def __init__(self):
        # Pre-tokenized equivalent of:
        #   lines = ['I love dogs.']
        #   self.data = [tokenizer(line)] * 10
        self.data = [
            {
                "input_ids": [101, 1045, 2293, 6077, 1012, 102],
                "token_type_ids": [0, 0, 0, 0, 0, 0],
                "attention_mask": [1, 1, 1, 1, 1, 1],
                "labels": 0,
            }
        ] * 10

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
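

# Illustration (not in the original gist): the collator output is a
# transformers.BatchEncoding, which is what the repro exercises.
# Uncomment to check locally:
# _probe = data_collator([RandomDataset()[0], RandomDataset()[1]])
# print(type(_probe))  # expected: transformers.tokenization_utils_base.BatchEncoding
# print(_probe["input_ids"].shape)  # padded to a common length, e.g. torch.Size([2, 6])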


def run_test():
    class TestModel(BoringModel):
        def on_train_epoch_start(self) -> None:
            print("override any method to prove your bug")

    # fake data
    dset = RandomDataset()
    train_dataloader = DataLoader(dset, batch_size=2, collate_fn=data_collator)
    val_dataloader = DataLoader(dset, batch_size=2, collate_fn=data_collator)

    # model
    model = TestModel()

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    args.default_root_dir = os.getcwd()
    args.limit_train_batches = 1
    args.limit_val_batches = 1
    args.max_epochs = 1
    args.weights_summary = None
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, train_dataloader, val_dataloader)


if __name__ == "__main__":
    run_test()

# Usage:
# python bug_report_model.py
# python bug_report_model.py --gpus=1
# python bug_report_model.py --gpus=2 --accelerator=ddp
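#
# What the prints in training_step should reveal: whether the batch is still a
# BatchEncoding after Lightning's batch transfer, and which device the
# attention_mask tensor landed on; compare the single-GPU and ddp runs.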