# VGG for Art Class (IOI 2013) with chainer
# http://joisino.hatenablog.com/entry/2017/10/05/200000
# Copyright (c) 2017 joisino
# Released under the MIT license
# http://opensource.org/licenses/mit-license.php
# VGG for Art Class (IOI 2013) with chainer
# http://joisino.hatenablog.com/entry/2017/10/05/200000
# Copyright (c) 2017 joisino
# Released under the MIT license
# http://opensource.org/licenses/mit-license.php
""" | |
VGG for Art Class (IOI2013) | |
""" | |
import numpy as np | |
import os | |
from PIL import Image | |
from chainer import dataset | |
from chainer import links as L | |
from chainer import Variable | |
from chainer import Chain | |
class ArtclassNetwork(Chain):
    """VGG16 feature extractor followed by a 4-way linear classifier.

    The ``vgg`` child link (``L.VGG16Layers``) provides the ``fc7``
    activations; the ``fc`` child link maps those 4096 features to four
    output scores.
    """

    def __init__(self):
        super(ArtclassNetwork, self).__init__(
            vgg=L.VGG16Layers(),
            fc=L.Linear(4096, 4))

    def __call__(self, x):
        """Compute the four output scores for a batch of images.

        Args:
            x: Input batch as an array. It is converted to a float32 array
                of the model's array module (``self.xp``) and wrapped in a
                ``Variable``.
                NOTE(review): presumably already preprocessed with the VGG
                ``prepare`` helper -- confirm against callers.

        Returns:
            The output of the final linear layer ``fc``.
        """
        h = Variable(self.xp.asarray(x, dtype=np.float32))
        # Pass the converted variable on to VGG. Previously the raw ``x``
        # was passed here, which silently discarded the float32 conversion.
        h = self.vgg(h, layers=['fc7'])['fc7']
        return self.fc(h)
class Dataset(dataset.DatasetMixin):
    """Dataset for Art Class (IOI 2013).

    Args:
        directory (str): The directory containing the images. It must not
            contain any other data (e.g. label files). The first character
            of each filename must correspond to the label
            ('1', '2', '3', '4').
        augment (bool): If True, every example is randomly cropped before
            it is preprocessed.
    """

    def __init__(self, directory, augment=False):
        self.directory = directory
        self.augment = augment
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.files = sorted(os.listdir(directory))

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.files)

    def random_box(self, size, d):
        """Return a random crop box ``(left, top, right, bottom)``.

        Every edge of the full image box is moved inwards by an independent
        random margin drawn from ``[0, d)``.

        Args:
            size: ``(width, height)`` of the image.
            d (int): Exclusive upper bound of the per-edge margin. When it
                is not positive, the full image box is returned.
        """
        if d <= 0:
            # np.random.randint(0, d) raises ValueError for d <= 0, which
            # happens for small images; fall back to no cropping.
            return (0, 0, size[0], size[1])
        l = np.random.randint(0, d)
        t = np.random.randint(0, d)
        r = size[0] - np.random.randint(0, d)
        b = size[1] - np.random.randint(0, d)
        return (l, t, r, b)

    def get_example(self, i):
        """Load the ``i``-th image and its label.

        Returns:
            tuple: ``(img, label)`` where ``img`` is the float32 array
            produced by the VGG ``prepare`` helper and ``label`` is a
            zero-based int32 scalar array taken from the first character
            of the filename.
        """
        filename = self.files[i]
        # os.path.join works whether or not ``directory`` ends with a
        # path separator; plain string concatenation did not.
        path = os.path.join(self.directory, filename)
        # Everything that reads pixel data stays inside the ``with`` block
        # so the file is still open when needed and closed afterwards.
        with Image.open(path) as source:
            img = source
            if self.augment:
                img = img.crop(
                    self.random_box(img.size, min(img.size) // 2 - 32))
            img = L.model.vision.vgg.prepare(img)
        img = np.array(img, dtype=np.float32)
        # Filenames start with '1'..'4'; labels are zero-based.
        label = np.array(int(filename[0]) - 1, dtype=np.int32)
        return img, label
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
predict images with VGG model for Art Class (IOI2013) | |
""" | |
import numpy as np | |
import argparse | |
from artclass import ArtclassNetwork, Dataset | |
from chainer import iterators, optimizers, serializers, training | |
from chainer import cuda | |
from chainer import links as L | |
from chainer.training import extensions | |
def main():
    """Evaluate a saved ArtclassNetwork on the images in ``./test_data/``.

    Command-line options:
        --gpu / -g: GPU device id; a negative value (default) runs on CPU.
        --model: Path of the npz model file (default ``./results/model``).

    Prints ``<prediction> <label>`` for every image, followed by
    ``<number correct> / <number of images>``.
    """
    # Imported locally: this script otherwise only imports chainer
    # submodules, and the configuration helpers live on the top-level
    # package.
    import chainer

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--model', type=str, default="./results/model")
    args = parser.parse_args()

    model = ArtclassNetwork()
    serializers.load_npz(args.model, model)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    test_data = Dataset("./test_data/")
    ac = 0
    cnt = 0
    # Run in test mode so that train-only behaviour (e.g. dropout) is
    # disabled, and skip building the backward graph, which prediction
    # never uses.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for (x, t) in test_data:
            # The model expects a batch; add a leading batch axis of 1.
            x = np.expand_dims(x, axis=0)
            if args.gpu >= 0:
                x = cuda.to_gpu(x)
            y = model(x)
            # Copy the scores to host memory first: np.argmax cannot
            # consume a GPU array.
            scores = cuda.to_cpu(y.data)
            res = int(np.argmax(scores[0]))
            t = int(t)
            print(str(res) + " " + str(t))
            if res == t:
                ac += 1
            cnt += 1
    print(str(ac) + " / " + str(cnt))


if __name__ == '__main__':
    main()
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
train VGG model for Art Class (IOI2013) | |
""" | |
import numpy as np | |
import argparse | |
from artclass import ArtclassNetwork, Dataset | |
from chainer import iterators, optimizers, serializers, training | |
from chainer import cuda | |
from chainer import links as L | |
from chainer.training import extensions | |
def main():
    """Train ArtclassNetwork on ``./train_data/`` and save the result.

    Command-line options:
        --gpu / -g: GPU device id; a negative value (default) runs on CPU.
        --epoch / -n: Number of training epochs (default 100).
        --decay: Factor handed to the ``ExponentialShift`` extension for
            the optimizer's ``alpha`` attribute, triggered once per epoch
            (default 0.99).

    After training, the model is written to ``./results/model``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', '-g', type=int, default=-1)
    cli.add_argument('--epoch', '-n', type=int, default=100)
    cli.add_argument('--decay', type=float, default=0.99)
    options = cli.parse_args()

    # The bare network is kept separately because it, not the classifier
    # wrapper, is what gets serialized at the end.
    network = ArtclassNetwork()
    classifier = L.Classifier(network)
    if not options.gpu < 0:
        cuda.get_device_from_id(options.gpu).use()
        classifier.to_gpu()

    # Second positional argument enables augmentation in the dataset.
    samples = Dataset("./train_data/", True)
    batches = iterators.SerialIterator(samples, batch_size=8, shuffle=True)

    optimizer = optimizers.Adam()
    optimizer.setup(classifier)

    updater = training.StandardUpdater(batches, optimizer, device=options.gpu)
    stop_trigger = (options.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out='results')

    alpha_decay = extensions.ExponentialShift("alpha", options.decay)
    trainer.extend(alpha_decay, trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    report_columns = ['epoch', 'main/accuracy', 'main/loss']
    trainer.extend(extensions.PrintReport(report_columns))
    trainer.extend(extensions.ProgressBar())

    trainer.run()

    modelname = "./results/model"
    print("saving model to {}".format(modelname))
    serializers.save_npz(modelname, network)


if __name__ == '__main__':
    main()