Named entity recognition with a BERT + CRF model on a custom dataset.
"""
@author: Patrik Purgai
@copyright: Copyright 2019, named-entity-recognition
@license: MIT
@email: purgai.patrik@gmail.com
@date: 2019.07.12.
"""
# pylint: disable=import-error
# pylint: disable=no-name-in-module
# pylint: disable=no-member
# pylint: disable=not-callable
# pylint: disable=used-before-assignment
# pylint: disable=unused-variable
import sys
import torch
import random
import json
import copy
import argparse
import itertools
import collections
import statistics
import tabulate
import functools
import os
import time
import re
import tqdm
import numpy as np
try:
from apex import amp
APEX_INSTALLED = True
except ImportError:
APEX_INSTALLED = False
from ignite.contrib.handlers import ProgressBar
from ignite.handlers import (
ModelCheckpoint, EarlyStopping)
from ignite.engine import Engine as Engine_, Events
from ignite.metrics import RunningAverage, Metric
from ignite._utils import _to_hours_mins_secs
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Sampler, DataLoader
from torch.nn.modules import Module, Linear
from os.path import (
exists, join, abspath, dirname)
from transformers.modeling_bert import \
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers import (
BertModel, BertTokenizer, AdamW)
from torchcrf import CRF
IGNORE_IDX = -1
def setup_train_args():
"""
Sets up the training arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--working_dir',
type=str,
required=True,
help='Path of the working directory.')
parser.add_argument(
'--max_epochs',
type=int,
default=25,
help='Maximum number of epochs for training.')
parser.add_argument(
'--no_cuda',
action='store_true',
        help='Disables CUDA even when available.')
parser.add_argument(
'--fp16',
action='store_true',
help='Use mixed precision training.')
parser.add_argument(
'--model_type',
type=str,
default='bert-large-cased-whole-word-masking',
choices=list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP),
help='Name of the pretrained model.')
parser.add_argument(
'--lr',
type=float,
default=1e-3,
help='Learning rate for the model.')
parser.add_argument(
'--batch_size',
type=int,
default=128,
help='Batch size during training.')
parser.add_argument(
'--warmup_prop',
type=float,
default=0.1,
help='Percentage of total steps for warmup.')
parser.add_argument(
'--warmup_steps',
type=int,
default=16000,
help='Number of warmup steps.')
parser.add_argument(
'--total_steps',
type=int,
default=1000000,
help='Number of optimization steps.')
parser.add_argument(
'--patience',
type=int,
default=5,
help='Number of patience epochs before termination.')
parser.add_argument(
'--grad_accum_steps',
type=int,
default=4,
help='Number of steps for grad accum.')
parser.add_argument(
'--clip_grad',
type=float,
default=None,
help='Gradient clipping norm value.')
parser.add_argument(
'--seed',
type=int,
default=None,
help='Random seed for training.')
    parser.add_argument(
        '--warm_start',
        action='store_true',
        help='Resume training from the latest checkpoint.')
    return parser.parse_args()
def set_random_seed(args):
"""
Sets the random seed for training.
"""
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
if args.cuda:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def flatten(iterable):
"""
    Flattens a nested dict by joining its keys.
"""
result = {}
for key, value in iterable.items():
if isinstance(value, dict):
result.update({
f'{key}_{k}': v for k, v in
flatten(value).items()
})
else:
result[key] = value
return result
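# A minimal sketch of what `flatten` produces, assuming a
# nested metrics dict like the one `MultiMetric` returns:
#
#     >>> flatten({'f1': {'min': 0.1, 'max': 0.9}, 'acc': 0.8})
#     {'f1_min': 0.1, 'f1_max': 0.9, 'acc': 0.8}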
def read_file(data_path):
"""
    Reads a CoNLL formatted file.
"""
with open(data_path, 'r') as fh:
for line in fh:
yield line.strip().split()
def generate_examples(data_path):
"""
    Generates the words and their entity tags
    from a CoNLL formatted file.
"""
sentence, labels = [], []
for line in read_file(data_path):
if len(line) == 0:
yield sentence, labels
sentence, labels = [], []
        else:
            # it is assumed that the NER tag
            # is the final tag in a row
            word, *_, label = line
            sentence.append(word)
            labels.append(label)
    # yield the last sentence as well, in case
    # the file does not end with an empty line
    if sentence:
        yield sentence, labels
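# For reference, a CoNLL formatted file stores one token per
# row along with its tags, and separates sentences by empty
# lines; given the (hypothetical) two-row sentence
#
#     John  NNP  B-PER
#     works VBZ  O
#
# the generator would yield:
#
#     (['John', 'works'], ['B-PER', 'O'])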
def collate_fn(examples, pad_idx):
"""
Creates a batch from a list of examples.
"""
encoded_words, encoded_labels, _ = zip(*examples)
padded_words, padded_labels = [], []
max_len = max(len(s) for s in encoded_words)
for word_ids, label_ids in zip(
encoded_words, encoded_labels):
length_diff = max_len - len(word_ids)
padded_words.append(
word_ids + [pad_idx] * length_diff)
padded_labels.append(
label_ids + [IGNORE_IDX] * length_diff)
inputs = np.array(padded_words, dtype=np.int64)
targets = np.array(padded_labels, dtype=np.int64)
mask = inputs != pad_idx
return np.stack([inputs, mask], axis=0), targets
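# A sketch of the batching, assuming `pad_idx=0` and two
# examples of lengths 2 and 3 (the third element of each
# example, the subword lengths, is discarded here):
#
#     >>> batch = [([5, 6], [1, 0], [1, 1]),
#     ...          ([7, 8, 9], [2, 0, 1], [1, 1, 1])]
#     >>> inputs, targets = collate_fn(batch, pad_idx=0)
#     >>> inputs.shape, targets.shape
#     ((2, 2, 3), (2, 3))
#
# `inputs[0]` holds the padded word ids, `inputs[1]` the
# attention mask, and the shorter example is padded with
# IGNORE_IDX in `targets`.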
def encode_example(
words, labels, labels_to_ids, tokenizer):
"""
    Creates an encoded version of the words and labels.
"""
encoded_words, encoded_labels, lengths = [], [], []
for word, label in zip(words, labels):
encoded_word = tokenizer.encode(word)
padding = [IGNORE_IDX] * (len(encoded_word) - 1)
encoded_label = int(labels_to_ids[label])
# adding padding tokens to labels so only
# the first subword token will be labeled
encoded_words.extend(encoded_word)
encoded_labels.extend([encoded_label] + padding)
lengths.append(len(encoded_word))
assert len(encoded_words) == len(encoded_labels)
return encoded_words, encoded_labels, lengths
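# To illustrate the alignment, assume the tokenizer splits
# the word `washington` into the subwords `wash`, `##ing`,
# `##ton` (the ids below are hypothetical); only the first
# subword keeps the label, the rest receive IGNORE_IDX:
#
#     words:          ['washington']
#     encoded_words:  [2005, 2006, 2007]
#     encoded_labels: [3, -1, -1]
#     lengths:        [3]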
def generate_encoded_examples(
data_path, labels_to_ids, tokenizer):
"""
Generates encoded examples from the raw data.
"""
for words, labels in tqdm.tqdm(
list(generate_examples(data_path)),
desc='encoding dataset', leave=False):
yield encode_example(
words, labels, labels_to_ids, tokenizer)
def create_dataset(args):
"""
Creates the dataset from the working directory.
"""
tokenizer = BertTokenizer.from_pretrained(
args.model_type)
train_path = join(args.working_dir, 'train.txt')
valid_path = join(args.working_dir, 'valid.txt')
assert exists(train_path) and exists(valid_path)
labels = create_labels(train_path, args.working_dir)
enc_train_path = join(args.working_dir, 'train.pt')
enc_valid_path = join(args.working_dir, 'valid.pt')
if not exists(enc_train_path):
train_data = list(generate_encoded_examples(
train_path, labels['labels_to_ids'],
tokenizer))
torch.save(train_data, enc_train_path)
else:
train_data = torch.load(enc_train_path)
if not exists(enc_valid_path):
valid_data = list(generate_encoded_examples(
valid_path, labels['labels_to_ids'],
tokenizer))
torch.save(valid_data, enc_valid_path)
else:
valid_data = torch.load(enc_valid_path)
pad_idx = tokenizer.pad_token_id
train_loader = DataLoader(
train_data,
batch_size=args.batch_size,
sampler=BucketSampler(train_data),
num_workers=1,
pin_memory=True,
collate_fn=lambda b: collate_fn(b, pad_idx))
valid_loader = DataLoader(
valid_data,
batch_size=args.batch_size,
sampler=BucketSampler(
valid_data, shuffle=False),
num_workers=1,
pin_memory=True,
collate_fn=lambda b: collate_fn(b, pad_idx))
return train_loader, valid_loader, tokenizer, labels
def get_latest_file(working_dir, pattern):
"""
Returns the latest file from the directory.
"""
    files = sorted([
        f for f in os.listdir(working_dir)
        if re.match(pattern, f) is not None
    ], key=lambda f: int(
        # sorting by the step count in the file name,
        # so `model_10` comes after `model_2` and digits
        # in the directory path cannot interfere
        re.search(r'\d+', f).group(0)))
    return join(working_dir, files[-1]) \
        if len(files) > 0 else None
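# A sketch of the expected behavior, assuming the working
# directory contains `model_3.pth` and `model_10.pth`:
#
#     >>> get_latest_file(working_dir, r'model_\d+.pth')
#     '<working_dir>/model_10.pth'
#
# the numeric key is what makes `model_10` beat `model_3`,
# which plain lexicographic ordering would get wrong.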
def create_labels(data_path, working_dir):
"""
    Creates the labels for the provided CoNLL data.
"""
labels_path = join(working_dir, 'labels.json')
if not exists(labels_path):
labels = [
label for _, label in
generate_examples(data_path)
]
label_freqs = dict(collections.Counter(
itertools.chain(*labels)))
labels_to_ids = {
label: str(idx) for idx, label
in enumerate(label_freqs)
}
ids_to_labels = {
idx: label for label, idx
in labels_to_ids.items()
}
labels = {
'label_freqs': label_freqs,
'labels_to_ids': labels_to_ids,
'ids_to_labels': ids_to_labels
}
with open(labels_path, 'w') as fh:
json.dump(labels, fh)
else:
with open(labels_path, 'r') as fh:
labels = json.load(fh)
return labels
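# The resulting `labels.json` looks roughly like this for a
# (hypothetical) person-only tagging scheme:
#
#     {
#         "label_freqs": {"O": 1500, "B-PER": 40, "I-PER": 25},
#         "labels_to_ids": {"O": "0", "B-PER": "1", "I-PER": "2"},
#         "ids_to_labels": {"0": "O", "1": "B-PER", "2": "I-PER"}
#     }
#
# note that the ids are stored as strings, since JSON object
# keys are always strings.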
class Engine(Engine_):
"""
Simple extension to the Ignite `Engine` class
to check for exception inside the loop.
"""
def _run_once_on_dataset(self):
start_time = time.time()
try:
for batch in self.state.dataloader:
self.state.batch = batch
self.state.iteration += 1
self._fire_event(Events.ITERATION_STARTED)
try:
self.state.output = \
self._process_function(
self, self.state.batch)
except BaseException as e:
self._logger.error(
'Current step is terminating '
'due to exception: %s.',
str(e))
self._handle_exception(e)
self._fire_event(Events.ITERATION_COMPLETED)
if self.should_terminate or \
self.should_terminate_single_epoch:
self.should_terminate_single_epoch = False
break
except BaseException as e:
self._logger.error(
'Current run is terminating '
'due to exception: %s.',
str(e))
self._handle_exception(e)
time_taken = time.time() - start_time
hours, mins, secs = _to_hours_mins_secs(time_taken)
return hours, mins, secs
class BucketSampler(Sampler):
"""
Bucketized sampler that yields exclusive groups
of indices based on the sequence length.
"""
def __init__(self, data_source, bucket_size=200,
shuffle=True):
self.bucket_size = bucket_size
self.shuffle = shuffle
self.indices = sorted(
list(range(len(data_source))),
key=lambda i: len(data_source[i][0]))
def __iter__(self):
        # divides the data into bucket sized segments
        # and only these segments are shuffled
def generate_indices(group):
for idx in group:
if idx is not None:
yield idx
def group_elements(iterable, group_size):
groups = [iter(iterable)] * group_size
return itertools.zip_longest(*groups)
groups = group_elements(
iterable=self.indices,
group_size=self.bucket_size)
groups = list(groups)
random.shuffle(groups)
for group in groups:
indices = list(generate_indices(group))
if self.shuffle:
indices = copy.deepcopy(indices)
random.shuffle(indices)
yield from indices
def __len__(self):
return len(self.indices)
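# A small sketch of the sampling order, assuming 6 examples
# with sequence lengths [5, 1, 3, 2, 4, 6] and `bucket_size=3`:
# the indices are first sorted by length into [1, 3, 2, 4, 0, 5]
# and split into the buckets [1, 3, 2] and [4, 0, 5]; both the
# order of the buckets and the indices within each bucket are
# then shuffled, so similarly sized examples end up in the same
# batch while the overall order still varies between epochs.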
class MultiMetric(Metric):
"""
Computes min, max, med and mean for a single
metric and the end of and epoch.
"""
def reset(self):
self._results = []
def update(self, output):
self._results.append(output)
def compute(self):
results = []
# dividing the data to form a single
# list of results for each label class
for result in zip(*self._results):
# filtering the nan values
values = [v for v in result if v == v]
if len(values) > 0:
results.append(statistics.mean(values))
return {
'min': min(results),
'max': max(results),
'mean': statistics.mean(results),
'med': statistics.median(results)
}
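# A sketch of the aggregation, assuming per-batch F1 scores of
# three classes where the second class is always nan (neither
# predicted nor present in the batch):
#
#     >>> m = MultiMetric(output_transform=lambda x: x['f1'])
#     >>> m.update([0.5, float('nan'), 1.0])
#     >>> m.update([0.75, float('nan'), 0.5])
#     >>> m.compute()
#     {'min': 0.625, 'max': 0.75, 'mean': 0.6875, 'med': 0.6875}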
class EntityRecognizer(Module):
"""
Wrapper module for the BERT entity recognizer
with a CRF output layer.
"""
def __init__(self, model_type, n_labels):
super().__init__()
self.bert = BertModel.from_pretrained(
model_type)
self.proj = Linear(
self.bert.config.hidden_size, n_labels)
self.crf = CRF(n_labels, batch_first=True)
    def forward(self, inputs, targets):
        mask = inputs[1]
        # the first element of the output tuple
        # holds the last hidden states
        hidden_states = self.bert(
            input_ids=inputs[0],
            attention_mask=inputs[1])[0]
        emissions = self.proj(hidden_states)
        # the CRF returns the log likelihood, so
        # the loss is its negative, averaged over
        # the number of unpadded tokens
        loss = self.crf(
            emissions, targets, mask.bool())
        loss /= mask.long().sum()
        return -loss
    @torch.no_grad()
    def decode(self, inputs):
        mask = inputs[1]
        hidden_states = self.bert(
            input_ids=inputs[0],
            attention_mask=inputs[1])[0]
        emissions = self.proj(hidden_states)
        # viterbi decoding of the most likely
        # tag sequence for each unpadded position
        preds = self.crf.decode(
            emissions, mask.bool())
        return preds
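# A sketch of the tensor shapes flowing through the module,
# assuming a batch of 8 sequences of length 32 and 9 entity
# labels (hidden size 1024 for a bert-large model):
#
#     inputs:        (2, 8, 32)    word ids stacked with the mask
#     hidden_states: (8, 32, 1024) last hidden states of BERT
#     emissions:     (8, 32, 9)    per-token label scores
#     decode:        a list of 8 label id lists, each at most
#                    32 long, since the CRF drops the padding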
def compute_f1_score(
preds, targets, num_classes):
"""
Computes the multi-class f1 score of the
predictions.
"""
targets = targets.view(-1)
valid_indices = targets != IGNORE_IDX
targets = targets[valid_indices]
preds = preds[valid_indices]
labels = torch.arange(
0, num_classes, dtype=torch.int64)
labels = labels.unsqueeze(1)
labels = labels.to(preds.device)
preds = preds.unsqueeze(0).expand(
num_classes, -1) == labels
targets = targets.unsqueeze(0).expand(
num_classes, -1) == labels
recall = compute_recall(
preds=preds, targets=targets)
precision = compute_precision(
preds=preds, targets=targets)
f1_score = 2 * (precision * recall) / \
(precision + recall)
return f1_score
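# A minimal worked example with 2 classes and 4 valid tokens,
# assuming preds = [0, 1, 1, 0] and targets = [0, 1, 0, 0]:
# class 0 has precision 2 / 2 and recall 2 / 3, so f1 = 0.8,
# while class 1 has precision 1 / 2 and recall 1 / 1, so
# f1 = 2 / 3; the function returns one f1 value per class, and
# classes that are neither predicted nor present come out as
# nan, which `MultiMetric` filters out later.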
def compute_precision(preds, targets):
"""
Computes the precision of the predictions.
"""
true_positives = (preds & targets).sum(dim=-1)
selected = preds.sum(dim=-1)
precision = true_positives.float() / \
selected.float()
return precision
def compute_recall(preds, targets):
"""
Computes the recall of the predictions.
"""
true_positives = (preds & targets).sum(dim=-1)
relevants = targets.sum(dim=-1)
recall = true_positives.float() / \
relevants.float()
return recall
def compute_accuracy(preds, targets):
"""
Computes the accuracy of the predictions.
"""
targets = targets.view(-1)
# computing accuracy without including
# the values at the ignore indices
not_ignore = targets.ne(IGNORE_IDX)
num_targets = not_ignore.long().sum()
num_targets = num_targets.item()
correct = (targets == preds) & not_ignore
correct = correct.float().sum()
accuracy = correct / num_targets
return accuracy
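# For example, with targets = [1, -1, 2] and preds = [1, 0, 0]
# the middle position is ignored, so the accuracy is 1 / 2.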
def compute_metrics(preds, targets, n_labels):
"""
Computes the loss and accuracy from the
outputs of the model.
"""
accuracy = compute_accuracy(preds, targets)
f1_score = compute_f1_score(
preds, targets, n_labels)
return {
'acc': accuracy.item(),
'f1': f1_score.tolist()
}
def attach_metrics(
engine, metrics, metric_type=RunningAverage):
"""
Attaches the metrics to the engine for
logging.
"""
    for metric in metrics:
        # NOTE `metric` is bound as a default value,
        # otherwise every lambda would close over the
        # same (final) value of the loop variable
        metric_type(
            output_transform=lambda x, m=metric: x[m]
        ).attach(engine, metric)
# implementation is from DialoGPT repo
def noam_decay(step, warmup_steps, d_model):
"""
Learning rate schedule described in
https://arxiv.org/pdf/1706.03762.pdf.
"""
return (
d_model ** (-0.5) * min(step ** (-0.5),
step * warmup_steps**(-1.5)))
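# A worked example of the schedule with `warmup_steps=16000`
# and `d_model=1024`: during warmup, at step 1600, the factor
# is 1024 ** -0.5 * 1600 * 16000 ** -1.5 ~= 2.47e-5, while
# after warmup, at step 64000, it is 1024 ** -0.5 *
# 64000 ** -0.5 ~= 1.24e-4; `set_lr` below scales this by
# `lr * 1e4`, so the effective learning rate rises linearly
# during warmup and then decays with the inverse square root
# of the step count.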
# implementation is from DialoGPT repo
def set_lr(step, optimizer, lr, warmup_steps, d_model):
"""
Learning rate scheduler that applies either
noam or noamwd rule.
"""
lr_this_step = lr * 1e4 * \
noam_decay(step + 1, warmup_steps, d_model)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
def main():
"""
Performs training, validation and testing.
"""
args = setup_train_args()
args.cuda = torch.cuda.is_available() and \
not args.no_cuda
args.fp16 = args.fp16 and APEX_INSTALLED \
and args.cuda
if args.seed is not None:
set_random_seed(args)
device = torch.device(
'cuda' if args.cuda else 'cpu')
# creating dataset and storing dataset splits
# as individual variables for convenience
train_loader, valid_loader, tokenizer, labels = \
create_dataset(args)
n_labels = len(labels['label_freqs'])
model = EntityRecognizer(
args.model_type, n_labels)
model = model.to(device)
model_file = get_latest_file(
working_dir=args.working_dir,
pattern=r'model_\d+.pth')
# if model file exists resume training
if model_file is not None and args.warm_start:
print('Loading model from {}'.format(
model_file))
model_state = torch.load(
model_file, map_location=device)
model.load_state_dict(model_state)
optimizer = AdamW(
params=model.parameters(),
lr=args.lr,
weight_decay=0.01)
optim_file = get_latest_file(
working_dir=args.working_dir,
pattern=r'optim_\d+.pth')
if optim_file is not None and args.warm_start:
print('Loading optimizer from {}'.format(
optim_file))
optim_state = torch.load(
optim_file, map_location=device)
optimizer.load_state_dict(optim_state)
if args.fp16:
model, optimizer = amp.initialize(
model, optimizer, opt_level='O2')
set_lr_fn = functools.partial(
set_lr,
optimizer=optimizer,
lr=args.lr,
warmup_steps=args.warmup_steps,
d_model=model.bert.config.hidden_size)
def forward_step(batch):
"""
Applies forward pass with the given batch.
"""
inputs, targets = batch
inputs = torch.as_tensor(inputs)
inputs = inputs.to(device)
targets = torch.as_tensor(targets)
targets = targets.long().to(device)
loss = model(inputs, targets)
outputs = model.decode(inputs)
outputs = [
preds + [-1] * (
inputs.size(-1) - len(preds))
for preds in outputs
]
preds = torch.as_tensor(outputs)
preds = preds.to(device)
preds = preds.view(-1)
metrics = compute_metrics(
preds, targets, n_labels)
metrics['loss'] = loss
return metrics
def train_step(engine, batch):
"""
Propagates the inputs forward and updates
the parameters.
"""
model.train()
results = forward_step(batch)
loss = results['loss']
loss /= args.grad_accum_steps
backward(loss)
if args.clip_grad is not None:
clip_grad_norm(args.clip_grad)
if engine.state.iteration % \
args.grad_accum_steps == 0:
set_lr_fn(engine.state.iteration)
optimizer.step()
optimizer.zero_grad()
results['loss'] = results['loss'].item()
return results
def eval_step(engine, batch):
"""
Propagates the inputs forward without
storing any gradients.
"""
model.eval()
with torch.no_grad():
results = forward_step(batch)
results['loss'] = results['loss'].item()
return results
def backward(loss):
"""
Backpropagates the loss in either mixed or
normal precision mode.
"""
if args.fp16:
with amp.scale_loss(
loss, optimizer) as scaled:
scaled.backward()
else:
loss.backward()
def clip_grad_norm(max_norm):
"""
Applies gradient clipping.
"""
if args.fp16:
params = amp.master_params(optimizer)
clip_grad_norm_(params, max_norm)
else:
clip_grad_norm_(
model.parameters(), max_norm)
    # after each epoch the metrics are evaluated
    # on the whole train dataset
trainer = Engine(train_step)
train_eval = Engine(eval_step)
valid_eval = Engine(eval_step)
# these metrics are derived from MultiMetric
valid_metrics = [
'loss', 'acc', 'f1_mean',
'f1_min', 'f1_max', 'f1_med'
]
attach_metrics(trainer, ['loss'])
attach_metrics(train_eval, ['loss', 'acc'])
attach_metrics(valid_eval, ['loss', 'acc'])
attach_metrics(valid_eval, ['f1'], MultiMetric)
pbar = ProgressBar()
pbar.attach(trainer, metric_names=['loss'])
# adding model checkpoint handler
checkpoint = ModelCheckpoint(
args.working_dir, '',
n_saved=3,
require_empty=False,
save_as_state_dict=True,
score_function=lambda e: -e.state.metrics['loss'])
checkpoint_dict = {
'model': model,
'optim': optimizer
}
valid_eval.add_event_handler(
Events.COMPLETED, checkpoint,
checkpoint_dict)
early_stopping = EarlyStopping(
patience=args.patience,
score_function=lambda e: -e.state.metrics['loss'],
trainer=trainer)
valid_eval.add_event_handler(
Events.COMPLETED, early_stopping)
# loading history for training logs
history_path = join(args.working_dir, 'history.json')
history = collections.defaultdict(list)
# NOTE the hardcoded values to keep track of
# in the history
valid_headers = [
f'valid_{metric}' for metric in valid_metrics]
train_headers = [
f'train_{metric}' for metric in ['loss', 'acc']
]
headers = ['epoch'] + train_headers + valid_headers
if exists(history_path):
with open(history_path, 'r') as fh:
history = json.load(fh)
def record_history(results):
"""
Records the results to the history.
"""
# saving history and handling unexpected
# keyboard interrupt
for header in headers:
history[header].append(results[header])
while True:
try:
with open(history_path, 'w') as fh:
json.dump(history, fh)
break
except KeyboardInterrupt:
pass
def run_eval(evaluator, loader, prefix, metrics):
"""
Runs evaluation with the given evaluator.
"""
evaluator.run(loader)
results = flatten(evaluator.state.metrics)
return {
f'{prefix}_{metric}': results[metric]
for metric in metrics
}
@trainer.on(Events.EPOCH_COMPLETED)
def log_eval_results(trainer):
"""
Logs the training results.
"""
results = {'epoch': trainer.state.epoch}
results.update(run_eval(
train_eval, train_loader,
'train', ['loss', 'acc']))
results.update(run_eval(
valid_eval, valid_loader,
'valid', valid_metrics))
record_history(results)
data = list(zip(*[history[h] for h in headers]))
table = tabulate.tabulate(
data, headers, floatfmt='.3f')
print(table.split('\n')[-1])
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and \
                engine.state.iteration > 1:
            engine.terminate()
        elif isinstance(e, RuntimeError):
            # e.g. a CUDA out of memory error on a
            # single batch should not end the run
            pass
        else:
            raise e
print(vars(args))
data = list(zip(*[history[h] for h in headers]))
# printing the initial table headers and
# previous results of the training if resuming
print(tabulate.tabulate(
data, headers, floatfmt='.3f'))
trainer.run(train_loader, args.max_epochs)
if __name__ == '__main__':
main()