Last active
November 12, 2019 22:33
-
-
Save Mrpatekful/88b06a21c0d0aebb9531519e2eae6a02 to your computer and use it in GitHub Desktop.
Named entity recognition with BERT and CRF model on custom dataset.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@author: Patrik Purgai | |
@copyright: Copyright 2019, named-entity-recognition | |
@license: MIT | |
@email: purgai.patrik@gmail.com | |
@date: 2019.07.12. | |
""" | |
# pylint: disable=import-error | |
# pylint: disable=no-name-in-module | |
# pylint: disable=no-member | |
# pylint: disable=not-callable | |
# pylint: disable=used-before-assignment | |
# pylint: disable=unused-variable | |
import sys | |
import torch | |
import random | |
import json | |
import copy | |
import argparse | |
import itertools | |
import collections | |
import statistics | |
import tabulate | |
import functools | |
import os | |
import time | |
import re | |
import numpy as np | |
try: | |
from apex import amp | |
APEX_INSTALLED = True | |
except ImportError: | |
APEX_INSTALLED = False | |
from ignite.contrib.handlers import ProgressBar | |
from ignite.handlers import ( | |
ModelCheckpoint, EarlyStopping) | |
from ignite.engine import Engine as Engine_, Events | |
from ignite.metrics import RunningAverage, Metric | |
from ignite._utils import _to_hours_mins_secs | |
from torch.nn.utils import clip_grad_norm_ | |
from torch.utils.data import Sampler, DataLoader | |
from torch.nn.modules import Module, Linear | |
from os.path import ( | |
exists, join, abspath, dirname) | |
from transformers.modeling_bert import \ | |
BERT_PRETRAINED_MODEL_ARCHIVE_MAP | |
from transformers import ( | |
BertModel, BertTokenizer, AdamW) | |
from torchcrf import CRF | |
# label id used for the non-first sub-word tokens of a word and for
# label padding; positions with this id are skipped by the accuracy
# and f1 computations
IGNORE_IDX = -1


def setup_train_args():
    """
    Sets up and parses the training arguments.

    NOTE(review): `--warmup_prop` and `--total_steps` are parsed but not
    read by the visible training code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--working_dir',
        type=str,
        required=True,
        help='Path of the working directory.')
    parser.add_argument(
        '--max_epochs',
        type=int,
        default=25,
        help='Maximum number of epochs for training.')
    parser.add_argument(
        '--no_cuda',
        action='store_true',
        help='Disable CUDA even when it is available.')
    parser.add_argument(
        '--fp16',
        action='store_true',
        help='Use mixed precision training.')
    parser.add_argument(
        '--model_type',
        type=str,
        default='bert-large-cased-whole-word-masking',
        choices=list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP),
        help='Name of the pretrained model.')
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='Learning rate for the model.')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=128,
        help='Batch size during training.')
    parser.add_argument(
        '--warmup_prop',
        type=float,
        default=0.1,
        help='Percentage of total steps for warmup.')
    parser.add_argument(
        '--warmup_steps',
        type=int,
        default=16000,
        help='Number of warmup steps.')
    parser.add_argument(
        '--total_steps',
        type=int,
        default=1000000,
        help='Number of optimization steps.')
    parser.add_argument(
        '--patience',
        type=int,
        default=5,
        help='Number of patience epochs before termination.')
    parser.add_argument(
        '--grad_accum_steps',
        type=int,
        default=4,
        help='Number of steps for grad accum.')
    parser.add_argument(
        '--clip_grad',
        type=float,
        default=None,
        help='Gradient clipping norm value.')
    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for training.')
    # read by `main` to decide whether to resume from the latest
    # checkpoints of the working directory
    parser.add_argument(
        '--warm_start',
        action='store_true',
        help='Resume from the latest model and optimizer '
             'checkpoints found in the working directory.')

    return parser.parse_args()
def set_random_seed(args):
    """
    Seeds the torch, Python `random` and NumPy generators with
    `args.seed`. When `args.cuda` is true, cuDNN is also switched to
    deterministic, non-benchmarking mode.
    """
    seed = args.seed
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    if not args.cuda:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def flatten(iterable):
    """
    Collapses a (possibly nested) dict into a single-level dict.

    The keys of a nested dict are prefixed with their parent key and
    an underscore, e.g. {'f1': {'min': 0.1}} -> {'f1_min': 0.1}.
    Non-dict values are copied unchanged.
    """
    flat = {}
    for outer_key, item in iterable.items():
        if not isinstance(item, dict):
            flat[outer_key] = item
            continue
        for inner_key, inner_value in flatten(item).items():
            flat[f'{outer_key}_{inner_key}'] = inner_value
    return flat
def read_file(data_path):
    """
    Lazily reads a CoNLL formatted file.

    Yields one list of whitespace separated columns per line; blank
    lines (sentence separators) are yielded as empty lists.

    :param data_path: path of the UTF-8 encoded text file.
    """
    # an explicit encoding keeps parsing independent of the platform's
    # locale default (e.g. cp1252 on Windows)
    with open(data_path, 'r', encoding='utf-8') as fh:
        for line in fh:
            # split() without arguments also strips the line
            yield line.split()
def generate_examples(data_path):
    """
    Generates (words, labels) pairs, one per sentence, from a CoNLL
    formatted file.

    Sentences are separated by blank lines. The first column of a row
    is taken as the word and the last column as its entity tag. Runs
    of blank lines do not produce empty sentences, and a final sentence
    that is not followed by a blank line is still generated.
    """
    sentence, labels = [], []
    for line in read_file(data_path):
        if line:
            # it is assumed that the ner tag
            # is the final tag in a row
            word, *_, label = line
            sentence.append(word)
            labels.append(label)
        elif sentence:
            yield sentence, labels
            sentence, labels = [], []

    # the file does not necessarily end with a blank line
    if sentence:
        yield sentence, labels
def collate_fn(examples, pad_idx):
    """
    Pads a list of encoded examples into a batch.

    Each example is a (word_ids, label_ids, lengths) triple; the
    lengths are ignored. Word ids are right padded with `pad_idx` and
    label ids with `IGNORE_IDX` up to the longest word id list.

    Returns a pair: an int64 array stacking the padded word ids and the
    padding mask (`word_ids != pad_idx`) along a new leading axis, and
    the padded int64 label ids.
    """
    word_seqs, label_seqs, _ = zip(*examples)
    longest = max(len(ids) for ids in word_seqs)
    # the amount of padding is decided by the word ids and applied
    # to the labels as well
    fills = [longest - len(ids) for ids in word_seqs]

    inputs = np.array(
        [ids + [pad_idx] * fill
         for ids, fill in zip(word_seqs, fills)],
        dtype=np.int64)
    targets = np.array(
        [tags + [IGNORE_IDX] * fill
         for tags, fill in zip(label_seqs, fills)],
        dtype=np.int64)

    return np.stack([inputs, inputs != pad_idx], axis=0), targets
def encode_example(
        words, labels, labels_to_ids, tokenizer):
    """
    Encodes a sentence into sub-word token ids and aligned label ids.

    Only the first sub-word token of a word carries the word's label;
    its remaining sub-word tokens are labeled with `IGNORE_IDX`.

    :param words: words of the sentence.
    :param labels: entity tags, one per word.
    :param labels_to_ids: mapping from tag to id (the ids may be
        strings, they are cast to int).
    :param tokenizer: tokenizer whose `encode` maps a word to ids.
        NOTE(review): assumes `encode` does not wrap each word with
        special tokens ([CLS]/[SEP]); confirm for the installed
        transformers version.
    :returns: (token ids, label ids, number of tokens of each word).
    """
    encoded_words, encoded_labels, lengths = [], [], []
    for word, label in zip(words, labels):
        encoded_word = tokenizer.encode(word)
        if not encoded_word:
            # a word made only of characters the tokenizer drops would
            # otherwise add a label without a token and misalign the
            # sequences, so the unknown token stands in for it
            encoded_word = [tokenizer.unk_token_id]
        padding = [IGNORE_IDX] * (len(encoded_word) - 1)
        encoded_label = int(labels_to_ids[label])
        # adding padding tokens to labels so only
        # the first subword token will be labeled
        encoded_words.extend(encoded_word)
        encoded_labels.extend([encoded_label] + padding)
        lengths.append(len(encoded_word))

    assert len(encoded_words) == len(encoded_labels)

    return encoded_words, encoded_labels, lengths
def generate_encoded_examples(
        data_path, labels_to_ids, tokenizer):
    """
    Generates encoded examples from the raw CoNLL data at `data_path`.

    Every sentence is encoded with `encode_example`; a progress bar is
    shown when `tqdm` is installed.
    """
    # tqdm only drives the progress bar, so its absence falls back
    # to a plain iteration instead of failing
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable, **_):
            return iterable

    # materialized so the progress bar knows the total count
    examples = list(generate_examples(data_path))
    for words, labels in progress(
            examples, desc='encoding dataset', leave=False):
        yield encode_example(
            words, labels, labels_to_ids, tokenizer)
def _load_or_encode(raw_path, cache_path, labels_to_ids, tokenizer):
    """
    Returns the encoded examples cached at `cache_path`, encoding the
    raw data at `raw_path` and saving it there first if not cached.
    """
    if exists(cache_path):
        return torch.load(cache_path)
    data = list(generate_encoded_examples(
        raw_path, labels_to_ids, tokenizer))
    torch.save(data, cache_path)
    return data


def create_dataset(args):
    """
    Creates the dataset from the working directory.

    Expects `train.txt` and `valid.txt` (CoNLL format) in
    `args.working_dir`; their encoded versions are cached next to them
    as `train.pt` / `valid.pt`.

    :returns: (train loader, valid loader, tokenizer, labels dict).
    :raises FileNotFoundError: if a raw data file is missing.
    """
    tokenizer = BertTokenizer.from_pretrained(
        args.model_type)

    train_path = join(args.working_dir, 'train.txt')
    valid_path = join(args.working_dir, 'valid.txt')

    for path in (train_path, valid_path):
        if not exists(path):
            raise FileNotFoundError(
                f'Missing data file: {path}')

    labels = create_labels(train_path, args.working_dir)
    labels_to_ids = labels['labels_to_ids']

    train_data = _load_or_encode(
        train_path, join(args.working_dir, 'train.pt'),
        labels_to_ids, tokenizer)
    valid_data = _load_or_encode(
        valid_path, join(args.working_dir, 'valid.pt'),
        labels_to_ids, tokenizer)

    # a partial of a module-level function (unlike a lambda) can be
    # pickled, which worker processes need with the 'spawn' start method
    collate = functools.partial(
        collate_fn, pad_idx=tokenizer.pad_token_id)

    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=BucketSampler(train_data),
        num_workers=1,
        pin_memory=True,
        collate_fn=collate)

    valid_loader = DataLoader(
        valid_data,
        batch_size=args.batch_size,
        sampler=BucketSampler(
            valid_data, shuffle=False),
        num_workers=1,
        pin_memory=True,
        collate_fn=collate)

    return train_loader, valid_loader, tokenizer, labels
def get_latest_file(working_dir, pattern):
    """
    Returns the path of the latest file in `working_dir` whose name
    matches `pattern` (with `re.match`), or None if none matches.

    "Latest" is the file whose name contains the largest first
    integer, e.g. `model_10.pth` is later than `model_9.pth`.
    """
    candidates = [
        name for name in os.listdir(working_dir)
        if re.match(pattern, name) is not None
    ]
    if not candidates:
        return None

    # the number is read from the file name only; reading it from the
    # joined path would pick up digits of the directory names instead
    candidates.sort(
        key=lambda name: int(re.search(r'\d+', name).group(0)))

    return join(working_dir, candidates[-1])
def create_labels(data_path, working_dir):
    """
    Returns the label vocabulary stored in `working_dir/labels.json`.
    On first use the file is built from the CoNLL data at `data_path`
    and saved there.

    The returned dict holds:
      'label_freqs'   -- tag -> number of occurrences in the data,
      'labels_to_ids' -- tag -> id (stored as a string),
      'ids_to_labels' -- id string -> tag.
    Ids follow the order in which the tags first appear in the data.
    """
    labels_path = join(working_dir, 'labels.json')

    if exists(labels_path):
        with open(labels_path, 'r') as fh:
            return json.load(fh)

    label_freqs = dict(collections.Counter(
        tag
        for _, sentence_tags in generate_examples(data_path)
        for tag in sentence_tags))

    labels_to_ids = {
        tag: str(position)
        for position, tag in enumerate(label_freqs)
    }

    vocab = {
        'label_freqs': label_freqs,
        'labels_to_ids': labels_to_ids,
        'ids_to_labels': {
            tag_id: tag for tag, tag_id in labels_to_ids.items()
        },
    }

    with open(labels_path, 'w') as fh:
        json.dump(vocab, fh)

    return vocab
class Engine(Engine_):
    """
    Simple extension to the Ignite `Engine` class
    to check for exception inside the loop.

    `_run_once_on_dataset` is overridden so that an exception raised by
    the process function for one batch is logged and passed to
    `_handle_exception`. If that call returns without raising, the loop
    continues with the next batch instead of aborting the epoch.

    NOTE(review): this overrides a private Ignite method and relies on
    its private helpers (`_fire_event`, `_handle_exception`,
    `_process_function`, `_logger`); verify against the installed
    Ignite version.
    """

    def _run_once_on_dataset(self):
        # Runs one pass over `self.state.dataloader` and returns the
        # elapsed time as (hours, mins, secs).
        start_time = time.time()

        try:
            for batch in self.state.dataloader:
                self.state.batch = batch
                self.state.iteration += 1
                self._fire_event(Events.ITERATION_STARTED)

                try:
                    self.state.output = \
                        self._process_function(
                            self, self.state.batch)
                except BaseException as e:
                    # a failing step is logged and handed over; if
                    # `_handle_exception` returns, the iteration still
                    # completes (with the previous `state.output`)
                    self._logger.error(
                        'Current step is terminating '
                        'due to exception: %s.',
                        str(e))
                    self._handle_exception(e)

                self._fire_event(Events.ITERATION_COMPLETED)
                # the single-epoch flag is cleared here so that it
                # only stops the current pass
                if self.should_terminate or \
                        self.should_terminate_single_epoch:
                    self.should_terminate_single_epoch = False
                    break

        except BaseException as e:
            # errors raised outside the step itself (data loading,
            # event handlers, or re-raised by the handling above)
            self._logger.error(
                'Current run is terminating '
                'due to exception: %s.',
                str(e))
            self._handle_exception(e)

        time_taken = time.time() - start_time
        hours, mins, secs = _to_hours_mins_secs(time_taken)

        return hours, mins, secs
class BucketSampler(Sampler):
    """
    Bucketized sampler that yields exclusive groups
    of indices based on the sequence length.

    Indices are sorted by the length of the first element of each
    example and cut into consecutive buckets of `bucket_size`, so
    examples of similar length are yielded close to each other.

    With `shuffle=True`, both the order of the buckets and the order
    within each bucket are randomized. With `shuffle=False`, the
    indices are yielded in ascending length order, deterministically.
    """

    def __init__(self, data_source, bucket_size=200,
                 shuffle=True):
        self.bucket_size = bucket_size
        self.shuffle = shuffle
        self.indices = sorted(
            range(len(data_source)),
            key=lambda i: len(data_source[i][0]))

    def __iter__(self):
        # slicing copies, so shuffling a bucket never
        # reorders `self.indices`
        buckets = [
            self.indices[start:start + self.bucket_size]
            for start in range(
                0, len(self.indices), self.bucket_size)
        ]

        if self.shuffle:
            random.shuffle(buckets)

        for bucket in buckets:
            if self.shuffle:
                random.shuffle(bucket)
            yield from bucket

    def __len__(self):
        return len(self.indices)
class MultiMetric(Metric):
    """
    Computes min, max, med and mean of a per-class
    metric at the end of an epoch.

    Each `update` receives a sequence with one value per class, e.g.
    the per-label f1 scores of a batch. `compute` first averages each
    class over the batches, ignoring NaN values. It then summarizes
    the per-class averages. Classes that were NaN in every batch are
    left out. If nothing is left to summarize, every statistic is NaN.
    """

    def reset(self):
        self._results = []

    def update(self, output):
        self._results.append(output)

    def compute(self):
        class_means = []
        # dividing the data to form a single
        # list of results for each label class
        for class_values in zip(*self._results):
            # filtering the nan values (NaN != NaN)
            values = [v for v in class_values if v == v]
            if values:
                class_means.append(statistics.mean(values))

        if not class_means:
            # no batches were seen (or all values were NaN);
            # min/max/mean would raise on an empty list
            nan = float('nan')
            return {'min': nan, 'max': nan, 'mean': nan, 'med': nan}

        return {
            'min': min(class_means),
            'max': max(class_means),
            'mean': statistics.mean(class_means),
            'med': statistics.median(class_means)
        }
class EntityRecognizer(Module):
    """
    BERT-based entity recognizer: the first element of the BERT output
    is projected to one emission score per label, and a CRF layer
    (`batch_first=True`) scores or decodes the tag sequences.

    `inputs` is indexed as `inputs[0]` -> token ids and
    `inputs[1]` -> attention / padding mask.
    """

    def __init__(self, model_type, n_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_type)
        hidden_size = self.bert.config.hidden_size
        self.proj = Linear(hidden_size, n_labels)
        self.crf = CRF(n_labels, batch_first=True)

    def _emissions(self, inputs):
        # label scores for every token; assumes element 0 of the BERT
        # output is the sequence of hidden states -- TODO confirm
        encoded = self.bert(
            input_ids=inputs[0],
            attention_mask=inputs[1])
        return self.proj(encoded[0])

    def forward(self, inputs, targets):
        """
        Returns the negated CRF score of `targets`, divided by the
        number of unmasked tokens. This is presumably the per-token
        negative log likelihood if `CRF.forward` returns the log
        likelihood.

        NOTE(review): the non-first sub-word positions of `targets`
        hold IGNORE_IDX (-1), and the mask does not exclude them. These
        -1 tags therefore reach the CRF; confirm how torchcrf treats them.
        """
        mask = inputs[1]
        log_likelihood = self.crf(
            self._emissions(inputs), targets, mask.bool())
        n_tokens = mask.long().sum()
        return -(log_likelihood / n_tokens)

    @torch.no_grad()
    def decode(self, inputs):
        """
        Returns the best tag sequences found by `CRF.decode` for the
        unmasked tokens of every sequence, without tracking gradients.
        """
        return self.crf.decode(
            self._emissions(inputs), inputs[1].bool())
def compute_f1_score(
        preds, targets, num_classes):
    """
    Computes the per-class f1 score of the predictions.

    `preds` is expected to be flat and aligned with `targets.view(-1)`.
    Positions whose target is IGNORE_IDX are dropped. Returns a float
    tensor of length `num_classes`. A class whose precision and recall
    are both zero, or undefined, gets NaN.
    """
    flat_targets = targets.view(-1)
    keep = flat_targets != IGNORE_IDX
    kept_targets = flat_targets[keep]
    kept_preds = preds[keep]

    # one row per class id: a boolean hit matrix of shape
    # (num_classes, number of kept positions)
    class_ids = torch.arange(
        0, num_classes, dtype=torch.int64).unsqueeze(1)
    class_ids = class_ids.to(kept_preds.device)
    pred_hits = kept_preds.unsqueeze(0).expand(
        num_classes, -1) == class_ids
    target_hits = kept_targets.unsqueeze(0).expand(
        num_classes, -1) == class_ids

    precision = compute_precision(
        preds=pred_hits, targets=target_hits)
    recall = compute_recall(
        preds=pred_hits, targets=target_hits)

    return 2 * (precision * recall) / \
        (precision + recall)
def compute_precision(preds, targets):
    """
    Row-wise precision of boolean hit matrices: true positives divided
    by predicted positives, counted along the last dimension. A row
    with no predicted positives gives NaN (0 / 0).
    """
    predicted = preds.sum(dim=-1).float()
    hits = (preds & targets).sum(dim=-1).float()
    return hits / predicted
def compute_recall(preds, targets):
    """
    Row-wise recall of boolean hit matrices: true positives divided by
    actual positives, counted along the last dimension. A row with no
    actual positives gives NaN (0 / 0).
    """
    relevant = targets.sum(dim=-1).float()
    hits = (preds & targets).sum(dim=-1).float()
    return hits / relevant
def compute_accuracy(preds, targets):
    """
    Fraction of positions where `preds` equals `targets`, counting only
    positions whose target is not IGNORE_IDX. `preds` is expected to be
    flat and aligned with `targets.view(-1)`. Returns a 0-dim float
    tensor.
    """
    flat_targets = targets.view(-1)
    counted = flat_targets != IGNORE_IDX
    n_counted = counted.long().sum().item()

    n_correct = ((flat_targets == preds) & counted).float().sum()

    return n_correct / n_counted
def compute_metrics(preds, targets, n_labels):
    """
    Computes the accuracy and the per-label f1 scores of a batch.

    :returns: {'acc': float, 'f1': list with one float per label}.
    """
    return {
        'acc': compute_accuracy(preds, targets).item(),
        'f1': compute_f1_score(preds, targets, n_labels).tolist(),
    }
def attach_metrics(
        engine, metrics, metric_type=None):
    """
    Attaches one `metric_type` instance per name in `metrics` to the
    engine for logging.

    Each metric reads `output[name]` from the engine's step output and
    is registered on the engine under that same name.

    :param metric_type: metric class constructed with an
        `output_transform` keyword and providing `attach(engine, name)`.
        Defaults to Ignite's `RunningAverage`.
    """
    from operator import itemgetter

    if metric_type is None:
        metric_type = RunningAverage

    for metric in metrics:
        # itemgetter binds the key now; a `lambda x: x[metric]` would
        # look `metric` up when called and read the last name for all
        metric_type(
            output_transform=itemgetter(metric)).attach(
                engine, metric)
# implementation is from DialoGPT repo
def noam_decay(step, warmup_steps, d_model):
    """
    Learning rate factor of the schedule described in
    https://arxiv.org/pdf/1706.03762.pdf:

        d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The factor grows linearly over the first `warmup_steps` steps and
    then decays with the inverse square root of `step`. `step` must
    be positive.
    """
    decay = step ** (-0.5)
    warmup = step * warmup_steps ** (-1.5)
    return d_model ** (-0.5) * min(decay, warmup)
# implementation is from DialoGPT repo
def set_lr(step, optimizer, lr, warmup_steps, d_model):
    """
    Sets the learning rate of every parameter group of `optimizer`
    according to the noam schedule:

        lr * 1e4 * noam_decay(step + 1, warmup_steps, d_model)

    `step + 1` is used so that step 0 is valid.
    """
    scheduled = lr * 1e4 * \
        noam_decay(step + 1, warmup_steps, d_model)
    for group in optimizer.param_groups:
        group['lr'] = scheduled
def main():
    """
    Performs training and validation.

    Parses the command line arguments, builds the data loaders and the
    BERT-CRF model, and trains it with Ignite. With `--warm_start`, the
    latest model and optimizer checkpoints in the working directory
    are loaded first.

    After every epoch the model is evaluated on the train and valid
    splits and the results are appended to `history.json`. Checkpoints
    are kept and early stopping is decided by the validation loss.
    """
    args = setup_train_args()

    args.cuda = torch.cuda.is_available() and \
        not args.no_cuda
    args.fp16 = args.fp16 and APEX_INSTALLED \
        and args.cuda

    # resuming is opt-in; read defensively so a parser
    # without the flag does not raise AttributeError
    warm_start = getattr(args, 'warm_start', False)

    if args.seed is not None:
        set_random_seed(args)

    device = torch.device(
        'cuda' if args.cuda else 'cpu')

    # creating dataset and storing dataset splits
    # as individual variables for convenience
    train_loader, valid_loader, tokenizer, labels = \
        create_dataset(args)

    n_labels = len(labels['label_freqs'])

    model = EntityRecognizer(
        args.model_type, n_labels)
    model = model.to(device)

    model_file = get_latest_file(
        working_dir=args.working_dir,
        pattern=r'model_\d+.pth')

    # resume from the latest model file if requested
    if model_file is not None and warm_start:
        print('Loading model from {}'.format(
            model_file))
        model_state = torch.load(
            model_file, map_location=device)
        model.load_state_dict(model_state)

    optimizer = AdamW(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=0.01)

    optim_file = get_latest_file(
        working_dir=args.working_dir,
        pattern=r'optim_\d+.pth')

    if optim_file is not None and warm_start:
        print('Loading optimizer from {}'.format(
            optim_file))
        optim_state = torch.load(
            optim_file, map_location=device)
        optimizer.load_state_dict(optim_state)

    if args.fp16:
        model, optimizer = amp.initialize(
            model, optimizer, opt_level='O2')

    set_lr_fn = functools.partial(
        set_lr,
        optimizer=optimizer,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        d_model=model.bert.config.hidden_size)

    def forward_step(batch):
        """
        Applies forward pass with the given batch and returns the
        batch metrics together with the loss tensor under 'loss'.
        """
        inputs, targets = batch

        inputs = torch.as_tensor(inputs)
        inputs = inputs.to(device)
        targets = torch.as_tensor(targets)
        targets = targets.long().to(device)

        loss = model(inputs, targets)

        # decoded sequences can be shorter than the batch length
        # (masked positions), so they are right padded; padded
        # positions have IGNORE_IDX targets and are not scored
        outputs = model.decode(inputs)
        outputs = [
            preds + [IGNORE_IDX] * (
                inputs.size(-1) - len(preds))
            for preds in outputs
        ]

        preds = torch.as_tensor(outputs)
        preds = preds.to(device)
        preds = preds.view(-1)

        metrics = compute_metrics(
            preds, targets, n_labels)

        metrics['loss'] = loss

        return metrics

    def train_step(engine, batch):
        """
        Propagates the inputs forward and accumulates the gradients.
        Every `grad_accum_steps` iterations the gradients are clipped
        (if requested) and applied.
        """
        model.train()

        results = forward_step(batch)
        # only the backward pass uses the scaled loss, so the logged
        # training loss stays comparable with the evaluation loss
        backward(results['loss'] / args.grad_accum_steps)

        if engine.state.iteration % \
                args.grad_accum_steps == 0:
            # clipping the fully accumulated gradients,
            # right before they are applied
            if args.clip_grad is not None:
                clip_grad_norm(args.clip_grad)

            set_lr_fn(engine.state.iteration)
            optimizer.step()
            optimizer.zero_grad()

        results['loss'] = results['loss'].item()

        return results

    def eval_step(engine, batch):
        """
        Propagates the inputs forward without
        storing any gradients.
        """
        model.eval()

        with torch.no_grad():
            results = forward_step(batch)

        results['loss'] = results['loss'].item()

        return results

    def backward(loss):
        """
        Backpropagates the loss in either mixed or
        normal precision mode.
        """
        if args.fp16:
            with amp.scale_loss(
                    loss, optimizer) as scaled:
                scaled.backward()
        else:
            loss.backward()

    def clip_grad_norm(max_norm):
        """
        Applies gradient clipping.
        """
        if args.fp16:
            params = amp.master_params(optimizer)
            clip_grad_norm_(params, max_norm)
        else:
            clip_grad_norm_(
                model.parameters(), max_norm)

    # after every epoch the metrics are evaluated
    # on the whole train and valid datasets
    trainer = Engine(train_step)
    train_eval = Engine(eval_step)
    valid_eval = Engine(eval_step)

    # the f1 metrics are derived from MultiMetric
    valid_metrics = [
        'loss', 'acc', 'f1_mean',
        'f1_min', 'f1_max', 'f1_med'
    ]

    attach_metrics(trainer, ['loss'])
    attach_metrics(train_eval, ['loss', 'acc'])
    attach_metrics(valid_eval, ['loss', 'acc'])
    attach_metrics(valid_eval, ['f1'], MultiMetric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=['loss'])

    # adding model checkpoint handler
    checkpoint = ModelCheckpoint(
        args.working_dir, '',
        n_saved=3,
        require_empty=False,
        save_as_state_dict=True,
        score_function=lambda e: -e.state.metrics['loss'])

    checkpoint_dict = {
        'model': model,
        'optim': optimizer
    }

    valid_eval.add_event_handler(
        Events.COMPLETED, checkpoint,
        checkpoint_dict)

    early_stopping = EarlyStopping(
        patience=args.patience,
        score_function=lambda e: -e.state.metrics['loss'],
        trainer=trainer)

    valid_eval.add_event_handler(
        Events.COMPLETED, early_stopping)

    # loading history for training logs
    history_path = join(args.working_dir, 'history.json')

    history = collections.defaultdict(list)

    # NOTE the hardcoded values to keep track of
    # in the history
    valid_headers = [
        f'valid_{metric}' for metric in valid_metrics]
    train_headers = [
        f'train_{metric}' for metric in ['loss', 'acc']
    ]
    headers = ['epoch'] + train_headers + valid_headers

    if exists(history_path):
        with open(history_path, 'r') as fh:
            # updating keeps the defaultdict, so headers missing
            # from the stored history still get a fresh list
            history.update(json.load(fh))

    def record_history(results):
        """
        Records the results to the history.
        """
        for header in headers:
            history[header].append(results[header])

        # saving history and handling unexpected
        # keyboard interrupt
        while True:
            try:
                with open(history_path, 'w') as fh:
                    json.dump(history, fh)
                break
            except KeyboardInterrupt:
                pass

    def run_eval(evaluator, loader, prefix, metrics):
        """
        Runs evaluation with the given evaluator.
        """
        evaluator.run(loader)

        results = flatten(evaluator.state.metrics)

        return {
            f'{prefix}_{metric}': results[metric]
            for metric in metrics
        }

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_eval_results(trainer):
        """
        Logs the training results.
        """
        results = {'epoch': trainer.state.epoch}

        results.update(run_eval(
            train_eval, train_loader,
            'train', ['loss', 'acc']))

        results.update(run_eval(
            valid_eval, valid_loader,
            'valid', valid_metrics))

        record_history(results)

        data = list(zip(*[history[h] for h in headers]))
        table = tabulate.tabulate(
            data, headers, floatfmt='.3f')

        print(table.split('\n')[-1])

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and \
                engine.state.iteration > 1:
            engine.terminate()
        elif isinstance(e, RuntimeError):
            # the failing step is skipped and training continues
            pass
        else:
            # anything not handled above is re-raised
            # instead of being silently dropped
            raise e

    print(vars(args))

    data = list(zip(*[history[h] for h in headers]))

    # printing the initial table headers and
    # previous results of the training if resuming
    print(tabulate.tabulate(
        data, headers, floatfmt='.3f'))

    trainer.run(train_loader, args.max_epochs)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment