@nudles
Forked from aaronwwf/.gitignore
Last active January 4, 2017 09:12
rafiki-cifar10-vgg
.project
.pydevproject
data_
parameter_
*.pyc
name: VGG model on Cifar-10
SINGA version: 1.0.0
SINGA commit: 7956019cf326c5f84401551f31f1e597fba77d46
parameter_url:
parameter_sha1: 52fe5e237e9efa42f9d4ed081a5a2ee9bcd22c8e
data_url:
data_md5: c58f30108f718f92721af3b95e74349a
license: non-commercial
gist_id: 51b10a13fbc3427934591f0e135f40a3

Train VGG over CIFAR-10

This example provides the training and serving scripts for VGG over CIFAR-10 data. The best validation accuracy (without data augmentation) we achieved was about 92%.

Folder layout

The folder structure for an example is as follows, where README.md is required and the other files are optional.

  • README.md. Every example should have a README.md file with the model description, SINGA version and running instructions.
  • train.py. The training script. Users should be able to run it directly with python train.py. It is optional if the model is shared only for prediction or serving tasks.
  • serve.py. The serving script. It is typically used in the cloud mode, where users submit queries via the web front end provided by Rafiki. If the local mode is enabled, it should accept command-line input. It is optional if the model is shared only for training tasks.
  • model.py. It contains the functions for creating the neural net. It could be merged into train.py and serve.py, hence it is optional.
  • index.html. This file is used for the serving task; it provides a web page where users submit queries and view the results. It is required for running the serving task in the cloud mode. If the model is shared only for training or run in the local mode, it is optional.
  • requirements.txt. It specifies the Python libraries used by the user's code. It is optional if no third-party libs are used.

Some models may have other files and scripts. It is not recommended to put large files (e.g. >10 MB) into this folder, as that would make cloning the gist repo slow.
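For reference, a gist following this layout typically contains just the small text files listed above (the same files included in this gist):

    README.md
    model.py
    train.py
    serve.py
    index.html
    requirements.txt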

Instructions

Local mode

To run the scripts on your local computer, you need to install SINGA. Please refer to the installation page for detailed instructions.

Training

The training data should be downloaded and decompressed into the current folder:

    wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    tar xvf cifar-10-python.tar.gz
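The archive extracts into a cifar-10-batches-py/ folder, which is the path train.py reads the batches from; after decompression the current folder should contain something like:

    cifar-10-batches-py/
        batches.meta
        data_batch_1  data_batch_2  data_batch_3  data_batch_4  data_batch_5
        readme.html
        test_batch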

The training program could be started by

    python train.py

By default, the training is conducted on a GPU card. To use the CPU for training (which is very slow), run

    python train.py --use_cpu

The model parameters would be dumped periodically into model-<epoch ID>, and the last one is saved as model.
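A dumped checkpoint can later be restored in a Python session with the same net definition; a minimal sketch, reusing the model.create_net and FeedForwardNet.load calls that serve.py below also uses, and assuming the default ./model path:

    from singa import device
    import model

    net = model.create_net(use_cpu=True)
    net.load('./model')  # checkpoint written by train.py
    net.to_device(device.get_default_device())
    print net.param_names()  # inspect the restored parameter names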

Serving

The pre-trained model parameters should be downloaded and decompressed into the current folder.

This example does not have the serving script for local mode. To simulate the local mode, you can start the prediction script and use curl to pass the query image.

    python serve.py &
    curl -i -F image=@image1.jpg http://localhost:9999/api
    curl -i -F image=@image2.jpg http://localhost:9999/api
    curl -i -F image=@image3.jpg http://localhost:9999/api

The above commands start the serving program using the trained VGG model as a daemon, and then submit three queries (image1.jpg, image2.jpg, image3.jpg) to the serving port (the default port is 9999). To use another port, add -p PORT_NUMBER to the command. If you run the serving task after finishing the training task, the model parameters stored in model would be used. Otherwise, the one downloaded using data.py would be used.
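For example, to serve on a different port (here 8888, an arbitrary choice), the same workflow would look like:

    python serve.py -p 8888 &
    curl -i -F image=@image1.jpg http://localhost:8888/api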

Cloud mode

To run the scripts on the Rafiki platform, you don't need to install SINGA, but you do need to add any libraries your code depends on to the requirements.txt file.

Adding model

The Rafiki front-end provides a web page for users to import gist repos directly. Users just specify the HTTPS clone link (NOT the gist's web URL) and click load to import a repo.
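For this gist, the HTTPS clone link would have the following form, using the gist_id from the metadata above:

    https://gist.github.com/51b10a13fbc3427934591f0e135f40a3.git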

Training

The Rafiki front-end has a Job view for adding a new training job. Users need to set the job type to 'training', select the model (i.e. the repo added in the above step) and the training dataset. If the training dataset is not there, please upload it to Rafiki. With these fields configured, the job can be started by clicking the start button. Afterwards, users are redirected to the monitoring view. Note that it may take some time to download the data for the first time. The Rafiki back-end would then run python train.py.

Serving

The serving job is similar to the training job, except that the job type is 'serving' and the pre-trained parameters are used. The Rafiki back-end would run python serve.py. Users can jump to the serving view, which is rendered using the index.html from the gist repo.

<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!-->
<html class="no-js">
<!--<![endif]-->
<head>
<meta charset="utf-8">
<title>CIFAR</title>
<meta name="description" content="">
<script type="text/javascript" src="http://code.jquery.com/jquery-1.12.4.min.js"></script>
<style>
body{
text-align:center;
}
h2{
text-align:center;
}
#result{
font-size: 20px;
font-weight: bold;
color: #333;
}
</style>
</head>
<body>
<h2>CIFAR 10, Image Classification Live Service</h2>
<!--
<button id="stop">STOP</button>
<button id="go">Predict</button>
-->
<form>
<p>Please upload a jpg file!</p>
<input id="file-input" type="file" accept="image/*;capture=camera" ></input>
</form>
<br/>
<img id="image" src=""/>
<div id="result"></div>
</body>
<script language="javascript">
  $("#file-input").change(function() {
    var f = $("#file-input")[0].files[0];
    ReadFile(f, function(result) {
      var file = DataURItoBlob(result);
      $("#image").attr("src", result);
      predict_dish(file);
    });
  });

  $("#stop").click(function() {
    $.ajax({
      url: "/command/stop",
      type: "POST",
      processData: false,  // Don't process the files
      contentType: false,
      success: function(response) {
        alert("Stop Success!");
      },
      error: function(e) {
        console.log(e);
        alert("Stop Failed!");
      }
    });
  });

  function predict_dish(file) {
    var formData = new FormData();
    formData.append('image', file, "image.jpg");
    $.ajax({
      url: "/api",
      data: formData,
      type: "POST",
      processData: false,  // Don't process the files
      contentType: false,
      success: function(response) {
        $("#result").html(response);
      },
      error: function(e) {
        console.log(e);
        $("#result").html("Error Occurs!");
      }
    });
  }

  var ReadFile = function(file, callback) {
    var reader = new FileReader();
    reader.onloadend = function() {
      ProcessFile(reader.result, file.type, callback);
    };
    reader.onerror = function() {
      alert('There was an error reading the file!');
    };
    reader.readAsDataURL(file);
  };

  var ProcessFile = function(dataURL, fileType, callback) {
    var maxWidth = 400;
    var maxHeight = 400;
    var image = new Image();
    image.src = dataURL;
    image.onload = function() {
      var width = image.width;
      var height = image.height;
      var shouldResize = (width > maxWidth) || (height > maxHeight);
      if (!shouldResize) {
        callback(dataURL);
        return;
      }
      var newWidth;
      var newHeight;
      if (width > height) {
        newHeight = height * (maxWidth / width);
        newWidth = maxWidth;
      } else {
        newWidth = width * (maxHeight / height);
        newHeight = maxHeight;
      }
      var canvas = document.createElement('canvas');
      canvas.width = newWidth;
      canvas.height = newHeight;
      var context = canvas.getContext('2d');
      context.drawImage(this, 0, 0, newWidth, newHeight);
      dataURL = canvas.toDataURL(fileType);
      callback(dataURL);
    };
    image.onerror = function() {
      alert('There was an error processing your file!');
    };
  };

  var DataURItoBlob = function(dataURI) {
    // convert base64 to raw binary data held in a string
    // doesn't handle URLEncoded DataURIs - see SO answer #6850276 for code that does this
    var byteString = atob(dataURI.split(',')[1]);
    // separate out the mime component
    var mimeString = dataURI.split(',')[0].split(':')[1].split(';')[0];
    // write the bytes of the string to an ArrayBuffer
    var ab = new ArrayBuffer(byteString.length);
    var ia = new Uint8Array(ab);
    for (var i = 0; i < byteString.length; i++) {
      ia[i] = byteString.charCodeAt(i);
    }
    try {
      return new Blob([ab], {type: mimeString});
    } catch (e) {
      // The BlobBuilder API has been deprecated in favour of Blob, but older
      // browsers don't know about the Blob constructor
      // IE10 also supports BlobBuilder, but since the `Blob` constructor
      // also works, there's no need to add `MSBlobBuilder`.
      var BlobBuilder = window.WebKitBlobBuilder || window.MozBlobBuilder;
      var bb = new BlobBuilder();
      bb.append(ab);
      return bb.getBlob(mimeString);
    }
  };
</script>
</html>
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
''' The VGG model is adapted from http://torch.ch/blog/2015/07/30/cifar.html'''
from singa import layer
from singa import metric
from singa import loss
from singa import net as ffnet
def ConvBnReLU(net, name, nb_filters, sample_shape=None):
    '''Add a Conv-BatchNorm-ReLU block to the net.'''
    net.add(layer.Conv2D(name + '_1', nb_filters, 3, 1, pad=1,
                         input_sample_shape=sample_shape))
    net.add(layer.BatchNormalization(name + '_2'))
    net.add(layer.Activation(name + '_3'))


def create_net(use_cpu=False):
    '''Create the VGG net for 3x32x32 CIFAR-10 images.'''
    if use_cpu:
        layer.engine = 'singacpp'
    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
    ConvBnReLU(net, 'conv1_1', 64, (3, 32, 32))
    net.add(layer.Dropout('drop1', 0.3))
    ConvBnReLU(net, 'conv1_2', 64)
    net.add(layer.MaxPooling2D('pool1', 2, 2, border_mode='valid'))
    ConvBnReLU(net, 'conv2_1', 128)
    net.add(layer.Dropout('drop2_1', 0.4))
    ConvBnReLU(net, 'conv2_2', 128)
    net.add(layer.MaxPooling2D('pool2', 2, 2, border_mode='valid'))
    ConvBnReLU(net, 'conv3_1', 256)
    net.add(layer.Dropout('drop3_1', 0.4))
    ConvBnReLU(net, 'conv3_2', 256)
    net.add(layer.Dropout('drop3_2', 0.4))
    ConvBnReLU(net, 'conv3_3', 256)
    net.add(layer.MaxPooling2D('pool3', 2, 2, border_mode='valid'))
    ConvBnReLU(net, 'conv4_1', 512)
    net.add(layer.Dropout('drop4_1', 0.4))
    ConvBnReLU(net, 'conv4_2', 512)
    net.add(layer.Dropout('drop4_2', 0.4))
    ConvBnReLU(net, 'conv4_3', 512)
    net.add(layer.MaxPooling2D('pool4', 2, 2, border_mode='valid'))
    ConvBnReLU(net, 'conv5_1', 512)
    net.add(layer.Dropout('drop5_1', 0.4))
    ConvBnReLU(net, 'conv5_2', 512)
    net.add(layer.Dropout('drop5_2', 0.4))
    ConvBnReLU(net, 'conv5_3', 512)
    net.add(layer.MaxPooling2D('pool5', 2, 2, border_mode='valid'))
    net.add(layer.Flatten('flat'))
    net.add(layer.Dropout('drop_flat', 0.5))
    net.add(layer.Dense('ip1', 512))
    net.add(layer.BatchNormalization('batchnorm_ip1'))
    net.add(layer.Activation('relu_ip1'))
    net.add(layer.Dropout('drop_ip2', 0.5))
    net.add(layer.Dense('ip2', 10))
    return net
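For a quick sanity check of the net definition outside of train.py or serve.py, it can be instantiated on its own; a minimal sketch using only the create_net, param_names and param_values calls that already appear in this gist:

    import model

    net = model.create_net(use_cpu=True)
    # print each parameter's name and shape
    for name, p in zip(net.param_names(), net.param_values()):
        print name, p.shape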
flask>=0.10.1
flask_cors>=3.0.2
pillow>=2.3.0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys
import time
import traceback
import numpy as np
from argparse import ArgumentParser

from singa import tensor, device
from singa import image_tool
from rafiki.agent import Agent, MsgType

import model

top_k = 5
tool = image_tool.ImageTool()
num_augmentation = 10
mean = 127
std = 64


def image_transform(image):
    '''Input an image path and return a set of augmented images (type Image)'''
    global tool
    # resize to 40x40 and take 5 crops of 32x32
    return tool.load(image).resize_by_list([40]).crop5((32, 32), 5).get()


label_map = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck'
}


def get_name(label):
    return label_map[label]


def serve(agent, use_cpu, parameter_file):
    net = model.create_net(use_cpu)
    if use_cpu:
        print 'running with cpu'
        dev = device.get_default_device()
    else:
        print 'running with gpu'
        dev = device.create_cuda_gpu()
    agent = agent

    print 'Start initialization............'
    net.load(parameter_file)
    net.to_device(dev)
    print 'End initialization............'

    while True:
        key, val = agent.pull()
        if key is None:
            time.sleep(0.1)
            continue
        msg_type = MsgType.parse(key)
        if msg_type.is_request():
            try:
                response = ""
                images = []
                for im in image_transform(val['image']):
                    ary = np.array(im.convert('RGB'), dtype=np.float32)
                    images.append(ary.transpose(2, 0, 1))
                images = np.array(images)
                # normalize with the same mean/std used during training
                images -= mean
                images /= std
                x = tensor.from_numpy(images.astype(np.float32))
                x.to_device(dev)
                y = net.predict(x)
                y.to_host()
                y = tensor.to_numpy(y)
                # average the predictions over the augmented crops
                prob = np.average(y, 0)
                # sort labels by probability in descending order
                labels = np.argsort(-prob)
                for i in range(top_k):
                    response += "%s:%s<br/>" % (get_name(labels[i]),
                                                prob[labels[i]])
            except:
                traceback.print_exc()
                response = "Sorry, system error during prediction."
            agent.push(MsgType.kResponse, response)
        elif MsgType.kCommandStop.equal(msg_type):
            print 'get stop command'
            agent.push(MsgType.kStatus, "success")
            break
        else:
            print 'get unsupported message %s' % str(msg_type)
            agent.push(MsgType.kStatus, "Unknown command")
            break
    # end of while loop
    print "server stop"


def main():
    try:
        # Setup argument parser
        parser = ArgumentParser(description="SINGA CIFAR VGG SERVING MODEL")
        parser.add_argument("-p", "--port", default=9999, help="listen port")
        parser.add_argument("-C", "--use_cpu", action="store_true")
        parser.add_argument("--parameter_file", default="./model",
                            help="relative path")
        # Process arguments
        args = parser.parse_args()
        port = args.port

        # start the serving loop
        agent = Agent(port)
        serve(agent, args.use_cpu, args.parameter_file)
        agent.stop()
    except SystemExit:
        return
    except:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2


if __name__ == '__main__':
    main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys
import os
import cPickle
import traceback
import time
import numpy as np
from argparse import ArgumentParser

from singa import tensor, device, optimizer
from singa import utils
from singa import initializer
from singa.proto import core_pb2
from rafiki.agent import Agent, MsgType

import model

data_path = 'cifar-10-batches-py'
parameter_folder = './'


def main():
    '''Command line options'''
    try:
        # Setup argument parser
        parser = ArgumentParser(description="Train VGG over CIFAR10")
        parser.add_argument('-p', '--port', default=9999, help='listening port')
        parser.add_argument('-C', '--use_cpu', action="store_true")
        parser.add_argument('--max_epoch', type=int, default=250)
        # Process arguments
        args = parser.parse_args()
        port = args.port
        use_cpu = args.use_cpu
        if use_cpu:
            print "running with cpu"
            dev = device.get_default_device()
        else:
            print "running with gpu"
            dev = device.create_cuda_gpu()

        # start to train
        net = model.create_net(use_cpu)
        agent = Agent(port)
        train(net, dev, agent, args.max_epoch)
        # wait for the agent to finish handling http requests
        agent.stop()
    except SystemExit:
        return
    except:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")


def initialize(net, dev, opt):
    '''initialize all parameters in the model'''
    print 'Start initialization............'
    for (p, name) in zip(net.param_values(), net.param_names()):
        print name, p.shape
        if 'mean' in name or 'beta' in name:
            p.set_value(0.0)
        elif 'var' in name:
            p.set_value(1.0)
        elif 'gamma' in name:
            initializer.uniform(p, 0, 1)
        elif len(p.shape) > 1:
            if 'conv' in name:
                initializer.gaussian(p, 0, 3 * 3 * p.shape[0])
            else:
                p.gaussian(0, 0.02)
        else:
            p.set_value(0)
        print name, p.l1()
    net.to_device(dev)
    print 'End initialization............'


def load_dataset(filepath):
    '''load data from binary file'''
    print 'Loading data file %s' % filepath
    with open(filepath, 'rb') as fd:
        cifar10 = cPickle.load(fd)
    image = cifar10['data'].astype(dtype=np.uint8)
    image = image.reshape((-1, 3, 32, 32))
    label = np.asarray(cifar10['labels'], dtype=np.uint8)
    label = label.reshape(label.size, 1)
    return image, label


def load_train_data(num_batches=5):
    labels = []
    batchsize = 10000
    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
    for did in range(1, num_batches + 1):
        fname_train_data = os.path.join(data_path, "data_batch_{}".format(did))
        image, label = load_dataset(fname_train_data)
        images[(did - 1) * batchsize:did * batchsize] = image
        labels.extend(label)
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    return images, labels


def load_test_data():
    images, labels = load_dataset(os.path.join(data_path, "test_batch"))
    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)


def get_data():
    '''load and normalize the training and test data'''
    train_x, train_y = load_train_data()
    test_x, test_y = load_test_data()
    mean = train_x.mean()
    std = train_x.std()
    print mean
    print std
    train_x -= mean
    test_x -= mean
    train_x /= std
    test_x /= std
    return train_x, train_y, test_x, test_y


def handle_cmd(agent):
    pause = False
    stop = False
    while not stop:
        key, val = agent.pull()
        if key is not None:
            msg_type = MsgType.parse(key)
            if msg_type.is_command():
                if MsgType.kCommandPause.equal(msg_type):
                    agent.push(MsgType.kStatus, "success")
                    pause = True
                elif MsgType.kCommandResume.equal(msg_type):
                    agent.push(MsgType.kStatus, "success")
                    pause = False
                elif MsgType.kCommandStop.equal(msg_type):
                    agent.push(MsgType.kStatus, "success")
                    stop = True
                else:
                    agent.push(MsgType.kStatus, "Warning, unknown message type")
                    print "Unsupported command %s" % str(key)
        if pause and not stop:
            time.sleep(0.1)
        else:
            break
    return stop


def get_lr(epoch):
    '''halve the base learning rate (0.1) every 25 epochs'''
    return 0.1 / float(1 << (epoch / 25))


def train(net, dev, agent, max_epoch, batch_size=100):
    agent.push(MsgType.kStatus, 'Downloading data...')
    train_x, train_y, test_x, test_y = get_data()
    print 'training shape', train_x.shape, train_y.shape
    print 'validation shape', test_x.shape, test_y.shape
    agent.push(MsgType.kStatus, 'Finish downloading data')

    opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
    initialize(net, dev, opt)

    tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
    ty = tensor.Tensor((batch_size, ), dev, core_pb2.kInt)
    num_train_batch = train_x.shape[0] / batch_size
    num_test_batch = test_x.shape[0] / batch_size
    idx = np.arange(train_x.shape[0], dtype=np.int32)

    for epoch in range(max_epoch):
        if handle_cmd(agent):
            break
        np.random.shuffle(idx)
        print 'Epoch %d' % epoch

        # evaluate on the test set
        loss, acc = 0.0, 0.0
        for b in range(num_test_batch):
            x = test_x[b * batch_size:(b + 1) * batch_size]
            y = test_y[b * batch_size:(b + 1) * batch_size]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)
            l, a = net.evaluate(tx, ty)
            loss += l
            acc += a
        print 'testing loss = %f, accuracy = %f' % (loss / num_test_batch,
                                                    acc / num_test_batch)
        # put test status info into a shared queue
        info = dict(
            phase='test',
            step=epoch,
            accuracy=acc / num_test_batch,
            loss=loss / num_test_batch,
            timestamp=time.time())
        agent.push(MsgType.kInfoMetric, info)

        # train for one epoch
        loss, acc = 0.0, 0.0
        for b in range(num_train_batch):
            x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
            y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)
            grads, (l, a) = net.train(tx, ty)
            loss += l
            acc += a
            for (s, p, g) in zip(net.param_specs(),
                                 net.param_values(), grads):
                opt.apply_with_lr(epoch, get_lr(epoch), g, p, str(s.name))
            info = 'training loss = %f, training accuracy = %f' % (l, a)
            utils.update_progress(b * 1.0 / num_train_batch, info)
        # put training status info into a shared queue
        info = dict(
            phase='train',
            step=epoch,
            accuracy=acc / num_train_batch,
            loss=loss / num_train_batch,
            timestamp=time.time())
        agent.push(MsgType.kInfoMetric, info)
        info = 'training loss = %f, training accuracy = %f' \
            % (loss / num_train_batch, acc / num_train_batch)
        print info

    # save the final parameters
    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
    net.save(os.path.join(parameter_folder, 'model'))


if __name__ == '__main__':
    main()