Skip to content

Instantly share code, notes, and snippets.

@joisino joisino/README.md
Created Oct 3, 2017

Embed
What would you like to do?
"""
VGG for Art Class (IOI2013)
"""
import numpy as np
import os
from PIL import Image
from chainer import dataset
from chainer import links as L
from chainer import Variable
from chainer import Chain
class ArtclassNetwork(Chain):
    """VGG16-based classifier for Art Class (IOI 2013).

    Reuses the pretrained VGG16 feature extractor and adds a fresh
    4-way linear head (one output per art style class).
    """

    def __init__(self):
        super(ArtclassNetwork, self).__init__(
            vgg=L.VGG16Layers(),
            fc=L.Linear(4096, 4))

    def __call__(self, x):
        """Forward pass: raw image batch -> 4 class scores (pre-softmax).

        Args:
            x: array-like of prepared images, shape (batch, 3, 224, 224).

        Returns:
            chainer.Variable of shape (batch, 4).
        """
        # Move/convert the input on the model's device once.
        h = Variable(self.xp.asarray(x, dtype=np.float32))
        # BUG FIX: the original passed the raw `x` to the VGG, silently
        # discarding the converted Variable `h` built on the line above.
        h = self.vgg(h, layers=['fc7'])['fc7']
        h = self.fc(h)
        return h
class Dataset(dataset.DatasetMixin):
    """Dataset for Art Class (IOI 2013).

    Args:
        directory (str): directory containing only the image files
            (no other data, e.g. label files). The first character of
            each filename must correspond to the label
            ('1', '2', '3', '4').
        augment (bool): if True, apply a random crop to each image
            on load.
    """

    def __init__(self, directory, augment=False):
        self.directory = directory
        self.augment = augment
        self.files = os.listdir(directory)

    def __len__(self):
        return len(self.files)

    def random_box(self, size, d):
        """Return a random (left, top, right, bottom) crop box.

        Each edge of an image of ``size`` (width, height) is moved
        inward by a random amount in [0, d).
        """
        # Robustness: np.random.randint(0, d) raises ValueError when
        # d <= 0 (small images make min(size)//2 - 32 non-positive).
        d = max(int(d), 1)
        l = np.random.randint(0, d)
        t = np.random.randint(0, d)
        r = size[0] - np.random.randint(0, d)
        b = size[1] - np.random.randint(0, d)
        return (l, t, r, b)

    def get_example(self, i):
        """Load the i-th image; return (float32 image, int32 label)."""
        # os.path.join: plain string concatenation broke when
        # `directory` had no trailing separator.
        img = Image.open(os.path.join(self.directory, self.files[i]))
        if self.augment:
            img = img.crop(self.random_box(img.size, min(img.size) // 2 - 32))
        # Resize/transpose/mean-subtract to VGG16's expected input.
        img = L.model.vision.vgg.prepare(img)
        img = np.array(img, dtype=np.float32)
        # Label is encoded as the first filename character, 1-based.
        label = np.array(int(self.files[i][0]) - 1, dtype=np.int32)
        return img, label
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
predict images with VGG model for Art Class (IOI2013)
"""
import numpy as np
import argparse
from artclass import ArtclassNetwork, Dataset
from chainer import iterators, optimizers, serializers, training
from chainer import cuda
from chainer import links as L
from chainer.training import extensions
def main():
    """Evaluate a trained ArtclassNetwork on ./test_data/ and print accuracy."""
    import chainer  # local import: needed for test-mode configuration

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--model', type=str, default="./results/model")
    args = parser.parse_args()

    model = ArtclassNetwork()
    serializers.load_npz(args.model, model)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    test_data = Dataset("./test_data/")
    ac = 0
    cnt = 0
    # Disable train-mode behavior (dropout in VGG16's fc layers would
    # otherwise stay active) and skip gradient bookkeeping for speed.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for (x, t) in test_data:
            x = np.expand_dims(x, axis=0)  # add batch dimension
            if args.gpu >= 0:
                x = cuda.to_gpu(x)
            y = model(x)
            # Bring scores back to host before argmax so the GPU path works.
            res = int(np.argmax(cuda.to_cpu(y.data)[0]))
            t = int(t)
            print( str(res) + " " + str(t) )
            if res == t:
                ac += 1
            cnt += 1
    print( str(ac) + " / " + str(cnt) )


if __name__ == '__main__':
    main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
train VGG model for Art Class (IOI2013)
"""
import numpy as np
import argparse
from artclass import ArtclassNetwork, Dataset
from chainer import iterators, optimizers, serializers, training
from chainer import cuda
from chainer import links as L
from chainer.training import extensions
def main():
    """Train the VGG-based Art Class model; save weights to ./results/model."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--epoch', '-n', type=int, default=100)
    parser.add_argument('--decay', type=float, default=0.99)
    args = parser.parse_args()

    net = ArtclassNetwork()
    # Classifier wraps the net with softmax cross-entropy loss + accuracy.
    classifier = L.Classifier(net)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        classifier.to_gpu()

    # Augmented training data, shuffled mini-batches of 8.
    train_iter = iterators.SerialIterator(
        Dataset("./train_data/", True), batch_size=8, shuffle=True)

    optimizer = optimizers.Adam()
    optimizer.setup(classifier)

    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out='results')
    # Decay Adam's learning rate (its "alpha" hyperparameter) every epoch.
    trainer.extend(extensions.ExponentialShift("alpha", args.decay),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'main/loss']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()

    # Save only the bare network (not the classifier wrapper).
    modelname = "./results/model"
    print("saving model to " + modelname)
    serializers.save_npz(modelname, net)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.