@foowaa
Last active December 13, 2018 18:27
training procedure
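'''
num_epochs: number of epochs to run
train: training data
dev: validation (dev) set
evalp: evaluate on the dev set every `evalp` epochs
model: the model being trained
metric_best: best dev metric so far (its initial value is the baseline to beat)
metric_stop: training stops once the dev metric exceeds this value
cnt_stop: training stops after the dev metric fails to beat metric_best this many times in a row
'''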
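from tqdm import tqdm

cnt = 0  # consecutive dev evaluations without improvement
pbar_epochs = tqdm(range(num_epochs))
for epoch in pbar_epochs:
    pbar_epochs.set_description("Epoch:{:d}.".format(epoch))
    pbar = tqdm(train)
    for data in pbar:
        model.train_batch(data)
        pbar.set_description("Loss:{:.4f}.".format(model.loss))
        pbar.refresh()
    # Evaluate on the dev set every `evalp` epochs.
    if (epoch+1) % evalp == 0:
        metric = model.evaluate(dev)
        if metric >= metric_best:
            metric_best = metric
            cnt = 0
        else:
            cnt += 1
        # Stop once the target metric is reached...
        if metric > metric_stop:
            break
        # ...or after cnt_stop evaluations in a row without improvement.
        if cnt >= cnt_stop:
            break
    pbar_epochs.refresh()
pbar_epochs.close()  # a `break` leaves the outer bar open, so close it explicitly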
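To check the stopping rule on its own, the loop can be wrapped in a function and run against a scripted model. This is a minimal sketch under stated assumptions: `train_with_early_stopping` and `ToyModel` are hypothetical names, the model only needs the `train_batch`, `loss` and `evaluate` members the loop above already uses, and the tqdm bars are left out so the control flow is easy to test. It stops after `cnt_stop` consecutive non-improving evaluations, as the docstring describes.

```python
from typing import Any, Iterable, List, Tuple


class ToyModel:
    """Hypothetical stand-in for `model`: dev scores are replayed from a list."""

    def __init__(self, dev_scores: List[float]) -> None:
        self.dev_scores = list(dev_scores)
        self.loss = 0.0
        self.batches_seen = 0

    def train_batch(self, data: Any) -> None:
        # Pretend training: the loss just shrinks as more batches are seen.
        self.batches_seen += 1
        self.loss = 1.0 / self.batches_seen

    def evaluate(self, dev: Any) -> float:
        # Return the next scripted dev metric.
        return self.dev_scores.pop(0)


def train_with_early_stopping(model, train: Iterable, dev: Any, num_epochs: int,
                              evalp: int, metric_best: float, metric_stop: float,
                              cnt_stop: int) -> Tuple[float, int]:
    """Run the training loop without progress bars (hypothetical helper).

    Returns (best dev metric, number of epochs actually run).
    """
    cnt = 0  # consecutive dev evaluations without improvement
    epochs_run = 0
    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        for data in train:
            model.train_batch(data)
        if (epoch + 1) % evalp != 0:
            continue  # not an evaluation epoch
        metric = model.evaluate(dev)
        if metric >= metric_best:
            metric_best = metric
            cnt = 0
        else:
            cnt += 1
        if metric > metric_stop or cnt >= cnt_stop:
            break  # target reached, or patience exhausted
    return metric_best, epochs_run


model = ToyModel([0.5, 0.6, 0.55, 0.58, 0.59, 0.9])
best, epochs = train_with_early_stopping(model, train=[1, 2, 3], dev=None,
                                         num_epochs=10, evalp=1, metric_best=0.0,
                                         metric_stop=0.95, cnt_stop=3)
print(best, epochs)  # 0.6 5
```

Three evaluations in a row (0.55, 0.58, 0.59) fail to beat 0.6, so training stops after the fifth epoch even though a 0.9 was still to come. Whether a tie should reset the counter is a design choice; like the original loop, this sketch uses `>=` and so treats a tie as an improvement.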