# prepare training and eval loops
import os

import torch
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage, Accuracy
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import (CosineAnnealingScheduler, PiecewiseLinear,
                                     create_lr_scheduler_with_warmup, ProgressBar)
from pytorch_transformers.optimization import AdamW
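# This snippet assumes `model`, `tokenizer`, `processor`, `train_dl`, `valid_dl`
# and a `finetuning_config` namespace are defined earlier in the notebook.
# A minimal sketch of the config fields referenced below (values are
# illustrative, not the gist's actual hyperparameters):
#
#   from types import SimpleNamespace
#   finetuning_config = SimpleNamespace(
#       lr=6.5e-5, device="cuda", gradient_acc_steps=4, max_norm=1.0,
#       n_warmup=500, n_epochs=3, log_dir="./logs")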
# BERT-style AdamW optimizer (correct_bias=False matches the original BERT implementation)
optimizer = AdamW(model.parameters(), lr=finetuning_config.lr, correct_bias=False)
def update(engine, batch):
    "update function for training"
    model.train()
    inputs, labels = (t.to(finetuning_config.device) for t in batch)
    inputs = inputs.transpose(0, 1).contiguous()  # [B, S] -> [S, B]
    # the classification head reads the hidden state at the [CLS] token positions
    _, loss = model(inputs,
                    clf_tokens_mask=(inputs == tokenizer.vocab[processor.CLS]),
                    clf_labels=labels)
    # gradient accumulation: scale the loss and step only every N iterations
    loss = loss / finetuning_config.gradient_acc_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), finetuning_config.max_norm)
    if engine.state.iteration % finetuning_config.gradient_acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
def inference(engine, batch):
    "update function for evaluation"
    model.eval()
    with torch.no_grad():
        batch, labels = (t.to(finetuning_config.device) for t in batch)
        inputs = batch.transpose(0, 1).contiguous()  # [B, S] -> [S, B]
        logits = model(inputs,
                       clf_tokens_mask=(inputs == tokenizer.vocab[processor.CLS]),
                       padding_mask=(batch == tokenizer.vocab[processor.PAD]))
    # return (y_pred, y) so ignite metrics like Accuracy can consume the output
    return logits, labels
trainer = Engine(update)
evaluator = Engine(inference)
# add metric to evaluator
Accuracy().attach(evaluator, "accuracy")
# add evaluator to trainer: eval on valid set after each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(valid_dl)
    print(f"validation epoch: {engine.state.epoch} acc: {100*evaluator.state.metrics['accuracy']:.2f}")
# lr schedule: warm up linearly to lr, then decay linearly to zero
scheduler = PiecewiseLinear(optimizer, 'lr',
                            [(0, 0.0),
                             (finetuning_config.n_warmup, finetuning_config.lr),
                             (len(train_dl)*finetuning_config.n_epochs, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
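# The imports above also pull in CosineAnnealingScheduler and
# create_lr_scheduler_with_warmup; a sketch of the warmup + cosine-decay
# schedule they enable (an alternative, not what this gist runs):
#
#   scheduler = create_lr_scheduler_with_warmup(
#       CosineAnnealingScheduler(optimizer, 'lr', finetuning_config.lr, 0.0,
#                                len(train_dl)*finetuning_config.n_epochs),
#       warmup_start_value=0.0, warmup_end_value=finetuning_config.lr,
#       warmup_duration=finetuning_config.n_warmup)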
# add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])
# save checkpoints and finetuning config
checkpoint_handler = ModelCheckpoint(finetuning_config.log_dir, 'finetuning_checkpoint',
                                     save_interval=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'imdb_model': model})
# save config to logdir
torch.save(finetuning_config, os.path.join(finetuning_config.log_dir, 'fine_tuning_args.bin'))
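# The gist stops after wiring up the engines; launching fine-tuning is a
# one-liner with ignite (assuming the config fields sketched above):
trainer.run(train_dl, max_epochs=finetuning_config.n_epochs)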