Skip to content

Instantly share code, notes, and snippets.

@devforfu
Last active August 25, 2019 17:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save devforfu/9acebd780215efe43b8b5d69ba0f3f9c to your computer and use it in GitHub Desktop.
Save devforfu/9acebd780215efe43b8b5d69ba0f3f9c to your computer and use it in GitHub Desktop.
Catalyst example (doesn't work as expected).
import os
import re
from pdb import set_trace
from multiprocessing import cpu_count
from pprint import pprint as pp
from imageio import imread
import numpy as np
import pandas as pd
import PIL.Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
from catalyst.contrib.schedulers import OneCycleLR
from catalyst.dl.callbacks import AccuracyCallback, AUCCallback, F1ScoreCallback
from catalyst.dl.runner import SupervisedRunner
import pretrainedmodels
from jupytools import auto_set_trace
# Rebind `set_trace` (imported from pdb above) to jupytools' variant.
# NOTE(review): presumably this picks a debugger suited to the current
# environment (e.g. notebook vs. terminal) -- confirm against jupytools docs.
set_trace = auto_set_trace()
# Restrict this process to GPU index 1.
# NOTE(review): this only takes effect if CUDA has not been initialised yet;
# torch is already imported above, so confirm nothing touches CUDA before here.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
def list_files(folder):
    """Return the full path of every entry inside *folder*.

    A leading ``~`` in *folder* is expanded to the user's home directory.
    Entries are returned in whatever order ``os.listdir`` yields them, and
    sub-directories are included alongside regular files.
    """
    root = os.path.expanduser(folder)
    entries = os.listdir(root)
    return [os.path.join(root, name) for name in entries]
# Matches "<anything>_<digits>.png"; the greedy ".*" makes the label the digits
# after the LAST underscore. Raw string avoids the invalid "\d" escape warning.
_LABEL_RE = re.compile(r'.*_(\d+)\.png$')


def extract_labels(files):
    """Parse the integer class label encoded in each file name.

    File names are expected to look like ``<anything>_<label>.png``; only the
    base name of each path is inspected.

    Args:
        files: iterable of file paths.

    Returns:
        List of int labels, in the same order as *files*.

    Raises:
        ValueError: if a base name does not match the expected pattern
            (e.g. a stray ``.DS_Store`` in the image folder).
    """
    labels = []
    for path in files:
        name = os.path.basename(path)
        match = _LABEL_RE.match(name)
        if match is None:
            raise ValueError(f'cannot parse label from file name {name!r}')
        labels.append(int(match.group(1)))
    return labels
class ImageDataset(Dataset):
    """Map-style dataset of images whose class label is encoded in the file name.

    Args:
        files: list of image paths; labels are parsed from the base names
            with ``extract_labels`` (``<anything>_<label>.png``).
        train: accepted for call-site compatibility; currently unused.
        tr: optional callable applied to each PIL image (e.g. a torchvision
            transform pipeline).
    """

    def __init__(self, files, train=True, tr=None):
        self.files = files
        self.tr = tr
        self.labels = extract_labels(files)

    @property
    def n_classes(self):
        """Number of distinct labels present in this dataset."""
        # NOTE(review): this counts distinct labels. nn.CrossEntropyLoss needs
        # every label < num_classes, so this is only a safe head size when the
        # labels are exactly 0..K-1 -- confirm against the data.
        return len(np.unique(self.labels))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(image, label)`` for position *index*."""
        # Open in a context manager so the file handle is released immediately.
        # convert('RGB') fully decodes the image (detaching it from the file)
        # and yields the 3-channel input an ImageNet-pretrained backbone expects,
        # even for grayscale/palette/RGBA PNGs.
        with PIL.Image.open(self.files[index]) as img:
            x = img.convert('RGB')
        if self.tr is not None:
            x = self.tr(x)
        return x, self.labels[index]
def get_model(model_name, num_classes, pretrained='imagenet'):
    """Build a ``pretrainedmodels`` network with a fresh classification head.

    The architecture named *model_name* is instantiated with its stock
    1000-class head and the weights selected by *pretrained*; its
    ``last_linear`` layer is then replaced by a new ``nn.Linear`` that maps
    the same input features to *num_classes* logits.
    """
    backbone = pretrainedmodels.__dict__[model_name](
        num_classes=1000, pretrained=pretrained)
    in_feats = backbone.last_linear.in_features
    backbone.last_linear = nn.Linear(in_feats, num_classes)
    return backbone
model_name = 'resnet50'
params = pretrainedmodels.pretrained_settings[model_name]['imagenet']
pp(params)

# Resize to a fixed 224x224, convert to a tensor and normalise with the
# mean/std published for the pretrained weights.
data_tr = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(params['mean'], params['std']),
])

bs = 128
num_epochs = 1

trn_files = list_files('~/data/tmp/train')
trn_ds = ImageDataset(trn_files, tr=data_tr)
# Reshuffle the training set every epoch: os.listdir order is arbitrary (and
# possibly grouped by class), which is harmful for minibatch training.
trn_dl = DataLoader(trn_ds, batch_size=bs, shuffle=True, num_workers=cpu_count())

tst_files = list_files('~/data/tmp/test')
tst_ds = ImageDataset(tst_files, tr=data_tr)
# Evaluation order does not affect the metrics, so no shuffling here.
tst_dl = DataLoader(tst_ds, batch_size=bs, num_workers=cpu_count())
from collections import OrderedDict

# Loader name -> DataLoader mapping handed to runner.train.
loaders = OrderedDict()
loaders['train'] = trn_dl
loaders['valid'] = tst_dl

resnet = get_model(model_name, trn_ds.n_classes)

# Freeze the whole network, then re-enable training for the last residual
# stage and the new classification head.
for param in resnet.parameters():
    param.requires_grad = False
# Unfreeze the head's weight AND bias: re-enabling only the weight would leave
# the freshly initialised bias frozen at its random value.
for module in (resnet.last_linear, resnet.layer4):
    for param in module.parameters():
        param.requires_grad = True
loss_fn = nn.CrossEntropyLoss()
# Give the optimizer only the parameters that are actually being trained;
# the frozen backbone weights never receive gradients.
trainable = [p for p in resnet.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=0.0001)
logdir = '/tmp/logs/'

runner = SupervisedRunner()
runner.train(
    model=resnet,
    criterion=loss_fn,
    optimizer=opt,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    callbacks=[
        AccuracyCallback(),
        AUCCallback(),
        F1ScoreCallback(activation='Softmax'),
    ],
    verbose=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment