@khirotaka
Last active September 13, 2019 12:52
Comet.ml with PyTorch
import os
import time
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
class NeuralNetworkClassifier:
    """Thin training/evaluation wrapper that logs metrics, hyperparameters,
    and checkpoints to a Comet.ml Experiment."""

    def __init__(self, model, criterion, optimizer, optimizer_config: dict, experiment) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer(self.model.parameters(), **optimizer_config)
        self.criterion = criterion
        self.experiment = experiment
        self.hyper_params = optimizer_config
        self._start_epoch = 0
        self._is_parallel = False

        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            self._is_parallel = True

            notice = "Running on {} GPUs.".format(torch.cuda.device_count())
            print("\033[33m" + notice + "\033[0m")
    def fit(self, loader: dict, epochs: int, checkpoint_path=None) -> None:
        """Train on loader["train"] and validate on loader["val"], logging metrics to Comet.ml."""
        len_of_train_dataset = len(loader["train"].dataset)
        len_of_val_dataset = len(loader["val"].dataset)
        epochs = epochs + self._start_epoch

        self.hyper_params["epochs"] = epochs
        self.hyper_params["batch_size"] = loader["train"].batch_size
        self.hyper_params["train_ds_size"] = len_of_train_dataset
        self.hyper_params["val_ds_size"] = len_of_val_dataset
        self.experiment.log_parameters(self.hyper_params)

        for epoch in range(self._start_epoch, epochs):
            if checkpoint_path is not None and epoch % 100 == 0:
                self.save_to_file(checkpoint_path)

            with self.experiment.train():
                correct = 0.0
                total = 0.0
                self.model.train()
                pbar = tqdm.tqdm(total=len_of_train_dataset)

                for x, y in loader["train"]:
                    b_size = x.shape[0]
                    total += y.shape[0]
                    x = x.to(self.device)
                    y = y.to(self.device)

                    pbar.set_description(
                        "\033[36m" + "Training" + "\033[0m" + " - Epochs: {:03d}/{:03d}".format(epoch + 1, epochs)
                    )
                    pbar.update(b_size)

                    self.optimizer.zero_grad()
                    outputs = self.model(x)
                    loss = self.criterion(outputs, y)
                    loss.backward()
                    self.optimizer.step()

                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == y).sum().float().cpu().item()

                    self.experiment.log_metric("loss", loss.cpu().item(), step=epoch)
                    self.experiment.log_metric("accuracy", float(correct / total), step=epoch)

            with self.experiment.validate():
                with torch.no_grad():
                    val_correct = 0.0
                    val_total = 0.0
                    self.model.eval()

                    for x_val, y_val in loader["val"]:
                        val_total += y_val.shape[0]
                        x_val = x_val.to(self.device)
                        y_val = y_val.to(self.device)

                        val_output = self.model(x_val)
                        val_loss = self.criterion(val_output, y_val)

                        _, val_pred = torch.max(val_output, 1)
                        val_correct += (val_pred == y_val).sum().float().cpu().item()

                        self.experiment.log_metric("loss", val_loss.cpu().item(), step=epoch)
                        self.experiment.log_metric("accuracy", float(val_correct / val_total), step=epoch)

            pbar.close()
    def evaluate(self, loader: DataLoader) -> None:
        """Run the model on a test DataLoader and log aggregate loss/accuracy to Comet.ml."""
        running_loss = 0.0
        running_corrects = 0.0
        pbar = tqdm.tqdm(total=len(loader.dataset))

        self.model.eval()
        self.experiment.log_parameter("test_ds_size", len(loader.dataset))

        with self.experiment.test():
            with torch.no_grad():
                correct = 0.0
                total = 0.0

                for step, (x, y) in enumerate(loader):
                    b_size = x.shape[0]
                    total += y.shape[0]
                    x = x.to(self.device)
                    y = y.to(self.device)

                    pbar.set_description("\033[32m" + "Evaluating" + "\033[0m")
                    pbar.update(b_size)

                    outputs = self.model(x)
                    loss = self.criterion(outputs, y)

                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == y).sum().float().cpu().item()
                    running_loss += loss.cpu().item()
                    running_corrects += torch.sum(predicted == y).float().cpu().item()

                self.experiment.log_metric("loss", running_loss)
                self.experiment.log_metric("accuracy", float(running_corrects / total))

        pbar.close()
        print("\033[33m" + "Evaluation finished. Check your workspace" + "\033[0m" + " https://www.comet.ml/")
    def save_checkpoint(self) -> dict:
        checkpoints = {
            "epoch": self.hyper_params["epochs"],
            "optimizer_state_dict": self.optimizer.state_dict()
        }

        if self._is_parallel:
            checkpoints["model_state_dict"] = self.model.module.state_dict()
        else:
            checkpoints["model_state_dict"] = self.model.state_dict()

        return checkpoints
    def save_to_file(self, path: str) -> str:
        if not os.path.isdir(path):
            os.mkdir(path)

        file_name = "model_params-epochs_{}-{}.pth".format(
            self.hyper_params["epochs"], time.ctime().replace(" ", "_")
        )
        # os.path.join works whether or not `path` already ends with a separator.
        path = os.path.join(path, file_name)

        checkpoints = self.save_checkpoint()
        torch.save(checkpoints, path)
        self.experiment.log_asset(path, file_name=file_name)

        return path
    def restore_checkpoint(self, checkpoints: dict) -> None:
        self._start_epoch = checkpoints["epoch"]
        assert isinstance(self._start_epoch, int)

        if self._is_parallel:
            self.model.module.load_state_dict(checkpoints["model_state_dict"])
        else:
            self.model.load_state_dict(checkpoints["model_state_dict"])

        self.optimizer.load_state_dict(checkpoints["optimizer_state_dict"])
    def restore_from_file(self, path: str, map_location="cpu") -> None:
        checkpoints = torch.load(path, map_location=map_location)
        self.restore_checkpoint(checkpoints)
    @property
    def experiment_tag(self) -> list:
        return self.experiment.get_tags()

    @experiment_tag.setter
    def experiment_tag(self, tag: str) -> None:
        """
        clf = NeuralNetworkClassifier(...)
        clf.experiment_tag = "tag"

        :param tag: str
        :return: None
        """
        assert isinstance(tag, str)
        self.experiment.add_tag(tag)
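
Usage sketch (not part of the original gist): a minimal, hypothetical example of wiring the classifier to a Comet.ml Experiment. It assumes comet_ml is installed and configured (e.g. the API key is available via the COMET_API_KEY environment variable or a .comet.config file); the toy model, dataset, and hyperparameters below are placeholders for illustration only.

if __name__ == "__main__":
    # comet_ml is typically imported before torch in real scripts.
    from comet_ml import Experiment
    from torch.utils.data import TensorDataset

    # Toy 10-feature, 2-class dataset; substitute your own Dataset objects.
    train_ds = TensorDataset(torch.randn(512, 10), torch.randint(0, 2, (512,)))
    val_ds = TensorDataset(torch.randn(128, 10), torch.randint(0, 2, (128,)))
    loaders = {
        "train": DataLoader(train_ds, batch_size=32, shuffle=True),
        "val": DataLoader(val_ds, batch_size=32),
    }

    experiment = Experiment()  # picks up the API key from the environment/config
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

    clf = NeuralNetworkClassifier(
        model,                   # model
        nn.CrossEntropyLoss(),   # criterion
        torch.optim.Adam,        # optimizer class (instantiated inside __init__)
        {"lr": 1e-3},            # optimizer_config, also logged as hyperparameters
        experiment,
    )
    clf.experiment_tag = "example"
    clf.fit(loaders, epochs=5)
    clf.evaluate(loaders["val"])
    # A checkpoint written by save_to_file() can later be reloaded with restore_from_file(path).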