import os

import torch
import torch.utils.data as data
from tqdm import tqdm
def train():
    # Checkpoint directory and saving frequency
    model_path = './ckpt'
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    save_step = 2         # save a state_dict every `save_step` batches
    save_epoch_step = 50  # save the full model every `save_epoch_step` epochs

    # Data (MyDataset and `others` are placeholders for your own dataset)
    batch_size = 16
    my_data = MyDataset(others)
    train_loader = data.DataLoader(dataset=my_data,
                                   batch_size=batch_size,
                                   shuffle=True)

    # Model (MyModel and `sizes` are placeholders for your own RNN);
    # keep model, inputs, and hidden states on the same device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyModel(sizes).to(device)

    # Loss and optimizer
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters())

    def dev():
        # Evaluate on the dev set (`dev_loader` is assumed to be built
        # like `train_loader`)
        model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
        with torch.no_grad():
            losses = 0
            states = torch.zeros(params, device=device)  # `params` = hidden-state shape (placeholder)
            for i, (x, y) in enumerate(dev_loader):
                x = x.to(device)
                y = y.to(device)
                outputs, states = model(x.float(), states)
                losses += criterion(outputs, y.float()).item()
        model.train()  # switch back to training mode before resuming training
        return losses / (i + 1)

    # Train the model
    num_epochs = 100
    for epoch in range(num_epochs):
        states = torch.zeros(params, device=device)  # `params` = hidden-state shape (placeholder)
        pbar = tqdm(train_loader)
        for i, (x, y) in enumerate(pbar, 1):
            x = x.to(device)
            y = y.to(device)
            # Detach the RNN hidden state so backpropagation is truncated at
            # the batch boundary instead of flowing through the whole epoch
            states = states.detach()
            # Forward pass
            outputs, states = model(x.float(), states)
            loss = criterion(outputs, y.float())
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            dev_loss = dev()
            pbar.set_description(
                "Epoch: {:d}/{:d}, train_loss:{:.4f}, dev_loss:{:.4f}.".format(
                    epoch + 1, num_epochs, loss.item(), dev_loss))
            # Save the model checkpoints
            if i % save_step == 0:
                torch.save(model.state_dict(),
                           os.path.join(model_path,
                                        'model-{}-{}.ckpt'.format(epoch + 1, i)))
        if (epoch + 1) % save_epoch_step == 0:
            torch.save(model, os.path.join(
                model_path, 'model-{}.pt'.format(epoch + 1)))
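
# `params` above is a placeholder for the hidden-state shape. For a
# single-layer GRU or vanilla RNN in PyTorch it would plausibly be
# (num_layers, batch_size, hidden_size), e.g.
#
#     states = torch.zeros(1, batch_size, hidden_size)
#
# Detaching this state every batch keeps the autograd graph small: gradients
# stop at the batch boundary (truncated backpropagation through time).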
def test():
    # Test the model (this branch assumes a classifier whose forward pass
    # returns class scores; `model` and `test_loader` are placeholders)
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for x, y in test_loader:
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            outputs = model(x)
            _, predicted = torch.max(outputs, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    print('Test accuracy: {:.4f} %'.format(100 * correct / total))
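
# A minimal sketch of the placeholder pieces this template never defines
# (MyDataset, MyModel, `others`, `sizes`). The classes below are illustrative
# assumptions for a binary sequence task that match BCELoss and the
# (outputs, states) forward signature used above, not part of the template.

import torch.nn as nn


class MyDataset(data.Dataset):
    """Toy dataset: random sequences paired with random binary labels."""

    def __init__(self, num_samples=128, seq_len=10, input_size=8):
        self.x = torch.randn(num_samples, seq_len, input_size)
        self.y = torch.randint(0, 2, (num_samples, 1)).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MyModel(nn.Module):
    """Toy single-layer GRU with a sigmoid head, matching BCELoss above."""

    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, states):
        # states: (num_layers, batch_size, hidden_size), i.e. what
        # torch.zeros(params) would need to produce in train()
        out, states = self.gru(x, states)
        # Predict from the last time step's output
        return torch.sigmoid(self.fc(out[:, -1, :])), states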