Skip to content

Instantly share code, notes, and snippets.

@joisino
Created October 3, 2017 07:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joisino/7c9712ec70cb026f2140c5ac1cebc2bb to your computer and use it in GitHub Desktop.
Save joisino/7c9712ec70cb026f2140c5ac1cebc2bb to your computer and use it in GitHub Desktop.
"""
VGG for Art Class (IOI2013)
"""
import numpy as np
import os
from PIL import Image
from chainer import dataset
from chainer import links as L
from chainer import Variable
from chainer import Chain
class ArtclassNetwork(Chain):
    """VGG16 feature extractor followed by a 4-way linear classifier.

    The 4096-dimensional ``fc7`` activations of ``L.VGG16Layers`` are fed
    into ``L.Linear(4096, 4)``, giving one score per Art Class label.
    """

    def __init__(self):
        super(ArtclassNetwork, self).__init__()
        # init_scope registers the links as children under the same names
        # ('vgg', 'fc') as the old keyword-argument form, so previously saved
        # npz files still load.
        with self.init_scope():
            # NOTE(review): VGG16Layers() with default arguments uses the
            # pretrained ImageNet weights (fetched on first use).
            self.vgg = L.VGG16Layers()
            self.fc = L.Linear(4096, 4)

    def __call__(self, x):
        """Return unnormalized class scores for a batch of images.

        Args:
            x: batch of VGG-preprocessed images (presumably shaped
               (batch, 3, 224, 224) as produced by ``vgg.prepare`` -- TODO
               confirm). Any numeric dtype; it is cast to float32.

        Returns:
            chainer.Variable of shape (batch, 4) with one score per label.
        """
        # Cast to float32 on the model's device (numpy or cupy) and actually
        # feed the converted value to VGG; previously this was computed and
        # then discarded in favour of the raw ``x``.
        h = Variable(self.xp.asarray(x, dtype=np.float32))
        h = self.vgg(h, layers=['fc7'])['fc7']
        return self.fc(h)
class Dataset(dataset.DatasetMixin):
    """Dataset for Art Class (IOI 2013).

    Args:
        directory (str): the directory containing the images. It must not
            contain any other files (e.g. label files). The first character
            of every filename must be its label ('1', '2', '3' or '4').
        augment (bool): if True, randomly crop each image before VGG
            preprocessing.
    """

    def __init__(self, directory, augment=False):
        self.directory = directory
        self.augment = augment
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible across runs and machines.
        self.files = sorted(os.listdir(directory))

    def __len__(self):
        return len(self.files)

    def random_box(self, size, d):
        """Return a random crop box ``(left, upper, right, lower)``.

        Each edge is moved inwards by an independent random amount drawn
        from ``[0, d)``. ``size`` is a PIL ``(width, height)`` pair and ``d``
        must be positive.
        """
        l = np.random.randint(0, d)
        t = np.random.randint(0, d)
        r = size[0] - np.random.randint(0, d)
        b = size[1] - np.random.randint(0, d)
        return (l, t, r, b)

    def get_example(self, i):
        """Return the ``i``-th ``(image, label)`` pair.

        The image is the output of VGG's ``prepare`` converted to a float32
        array; the label is an int32 scalar equal to the filename's first
        digit minus one (so '1'..'4' map to 0..3).
        """
        filename = self.files[i]
        # os.path.join works whether or not ``directory`` ends with a
        # separator; the context manager closes the underlying file handle.
        with Image.open(os.path.join(self.directory, filename)) as src:
            image = src
            if self.augment:
                margin = min(image.size) // 2 - 32
                # np.random.randint(0, d) raises for d <= 0, so images too
                # small for this margin are left uncropped.
                if margin > 0:
                    image = image.crop(self.random_box(image.size, margin))
            image = L.model.vision.vgg.prepare(image)
        image = np.array(image, dtype=np.float32)
        label = np.array(int(filename[0]) - 1, dtype=np.int32)
        return image, label
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
predict images with VGG model for Art Class (IOI2013)
"""
import numpy as np
import argparse
from artclass import ArtclassNetwork, Dataset
from chainer import iterators, optimizers, serializers, training
from chainer import cuda
from chainer import links as L
from chainer.training import extensions
def main():
    """Evaluate a trained ArtclassNetwork on the images in ./test_data/.

    Prints "<predicted> <true>" (0-based labels) for every image, then
    "<correct> / <total>".
    """
    import chainer  # using_config / no_backprop_mode for inference

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--model', type=str, default="./results/model")
    args = parser.parse_args()

    model = ArtclassNetwork()
    serializers.load_npz(args.model, model)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    test_data = Dataset("./test_data/")
    ac = 0
    cnt = 0
    # Chainer's VGG16Layers applies dropout after fc6/fc7 while
    # chainer.config.train is True (the default), which would make
    # predictions random; run in test mode and skip building backprop graphs.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for x, t in test_data:
            x = np.expand_dims(x, axis=0)  # single-image batch
            if args.gpu >= 0:
                x = cuda.to_gpu(x)
            y = model(x)
            # Bring scores back to host memory so np.argmax also works when
            # the model runs on the GPU.
            res = int(np.argmax(cuda.to_cpu(y.data)[0]))
            t = int(t)
            print(str(res) + " " + str(t))
            if res == t:
                ac += 1
            cnt += 1
    print(str(ac) + " / " + str(cnt))


if __name__ == '__main__':
    main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
train VGG model for Art Class (IOI2013)
"""
import numpy as np
import argparse
from artclass import ArtclassNetwork, Dataset
from chainer import iterators, optimizers, serializers, training
from chainer import cuda
from chainer import links as L
from chainer.training import extensions
def main():
    """Fine-tune ArtclassNetwork on ./train_data/ and save its weights.

    Uses Adam whose ``alpha`` is multiplied by ``--decay`` after every
    epoch. When training ends, the bare network (without the Classifier
    wrapper) is written to ``<out>/model`` with ``serializers.save_npz``.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--epoch', '-n', type=int, default=100)
    parser.add_argument('--decay', type=float, default=0.99)
    parser.add_argument('--batchsize', '-b', type=int, default=8)
    parser.add_argument('--out', '-o', type=str, default='./results')
    args = parser.parse_args()

    model = ArtclassNetwork()
    # Classifier supplies the loss and reports main/loss and main/accuracy.
    clf = L.Classifier(model)
    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        clf.to_gpu()

    train_data = Dataset("./train_data/", True)  # with random-crop augmentation
    train_iter = iterators.SerialIterator(
        train_data, batch_size=args.batchsize, shuffle=True)

    opt = optimizers.Adam()
    opt.setup(clf)

    updater = training.StandardUpdater(train_iter, opt, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(extensions.ExponentialShift("alpha", args.decay),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport(['epoch', 'main/accuracy', 'main/loss']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()

    # Save the bare network so predict.py can load it into ArtclassNetwork.
    modelname = os.path.join(args.out, "model")
    print("saving model to " + modelname)
    serializers.save_npz(modelname, model)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment