Skip to content

Instantly share code, notes, and snippets.

@aaronwwf
Forked from XiangruiCAI/README.md
Last active November 9, 2016 08:38
Show Gist options
  • Save aaronwwf/3231cbf85cd93558cd47907ff5561385 to your computer and use it in GitHub Desktop.
Save aaronwwf/3231cbf85cd93558cd47907ff5561385 to your computer and use it in GitHub Desktop.
rafiki-cifar-example
.project
.pydevproject
data_
parameter_
*.pyc
  • This is an example of a SINGA CNN model.
  • It trains a VGG model on the CIFAR-10 dataset.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import os, sys, shutil
import urllib
import cPickle
import numpy as np
# Local directory holding the downloaded CIFAR-10 archive, its extracted
# batches and the mean image.
data_folder = "data_"

# CIFAR-10 (python version) archive. download_file() stores a download under
# the last path segment of its URL, so the archive name is derived from it.
tar_data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
tar_data_name = tar_data_url.rsplit('/', 1)[-1]
# Directory (inside data_folder) that the archive extracts to.
data_path = 'cifar-10-batches-py'

# Local directory for model parameter files, and a default parameter name.
parameter_folder = "parameter_"
parameter_name = "parameter"

# Parameter archive fetched by serve_file_prepare(); named like the download.
tar_parameter_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/cifar/parameter.tar.gz"
tar_parameter_name = tar_parameter_url.rsplit('/', 1)[-1]

# Mean image file, fetched into / loaded from data_folder.
mean_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/train.mean.npy"
mean_name = mean_url.rsplit('/', 1)[-1]
def load_dataset(filepath):
    '''Load one pickled CIFAR-10 batch file.

    Args:
        filepath: path of a batch file containing a dict with 'data' and
            'labels' entries.

    Returns:
        (image, label): ``image`` is the 'data' array cast to uint8 and
        reshaped to (-1, 3, 32, 32); ``label`` is a uint8 column vector of
        shape (N, 1).
    '''
    print('Loading data file %s' % filepath)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filepath, 'rb') as fd:
        batch = cPickle.load(fd)
    pixels = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    classes = np.asarray(batch['labels'], dtype=np.uint8)
    return pixels, classes.reshape(classes.size, 1)
def load_train_data(num_batches=5):
    '''Load the first ``num_batches`` CIFAR-10 training batch files.

    Files data_batch_1 ... data_batch_<num_batches> are read from
    <data_folder>/<data_path>; each is expected to hold exactly 10000 images.

    Returns:
        (images, labels) as float32 and int32 numpy arrays.
    '''
    per_batch = 10000
    pixels = np.empty((num_batches * per_batch, 3, 32, 32), dtype=np.uint8)
    classes = []
    for offset in range(num_batches):
        batch_file = os.path.join(data_folder, data_path,
                                  "data_batch_{}".format(offset + 1))
        batch_pixels, batch_classes = load_dataset(batch_file)
        start = offset * per_batch
        pixels[start:start + per_batch] = batch_pixels
        classes.extend(batch_classes)
    return (np.array(pixels, dtype=np.float32),
            np.array(classes, dtype=np.int32))
def load_test_data():
    '''Load the CIFAR-10 test batch as (float32 images, int32 labels).'''
    test_file = os.path.join(data_folder, data_path, "test_batch")
    pixels, classes = load_dataset(test_file)
    pixels = np.array(pixels, dtype=np.float32)
    classes = np.array(classes, dtype=np.int32)
    return pixels, classes
def load_mean_data():
    '''Return the array stored at <data_folder>/<mean_name>, or None if the
    file does not exist.'''
    mean_path = os.path.join(data_folder, mean_name)
    if not os.path.exists(mean_path):
        return None
    return np.load(mean_path)
def save_mean_data(mean):
    '''Save ``mean`` as <data_folder>/<mean_name> in numpy .npy format.

    The data folder is created first if it is missing, because np.save does
    not create parent directories.
    '''
    if not os.path.exists(data_folder):
        os.makedirs(data_folder)
    np.save(os.path.join(data_folder, mean_name), mean)
def train_file_prepare():
    '''Make sure the CIFAR-10 training data and the parameter folder exist.

    The archive is downloaded and extracted into data_folder only when
    <data_folder>/<data_path> is missing. The parameter folder is created in
    every case (previously it was skipped when the data was already present),
    since the training script saves its checkpoints there.
    '''
    if not os.path.exists(os.path.join(data_folder, data_path)):
        print("download file")
        download_file(tar_data_url, data_folder)
        untar_data(os.path.join(data_folder, tar_data_name), data_folder)
    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
def serve_file_prepare():
    '''Fetch the files the serving script needs, skipping those present.

    * If <parameter_folder>/<tar_parameter_name> is missing, the parameter
      archive is downloaded into parameter_folder and extracted there.
    * If <data_folder>/<mean_name> is missing, the mean file is downloaded
      into data_folder.
    '''
    archive = os.path.join(parameter_folder, tar_parameter_name)
    if not os.path.exists(archive):
        print("download parameter file")
        download_file(tar_parameter_url, parameter_folder)
        untar_data(archive, parameter_folder)

    if os.path.exists(os.path.join(data_folder, mean_name)):
        return
    print("download mean file")
    download_file(mean_url, data_folder)
def download_file(url, dest):
    '''Download ``url`` into directory ``dest``, creating ``dest`` if needed.

    The local file is named after the last '/'-separated segment of the URL.
    URLs that do not start with "http" are ignored: the directory is still
    created, but nothing is downloaded.
    '''
    if not os.path.exists(dest):
        os.makedirs(dest)
    if not url.startswith('http'):
        return
    target = os.path.join(dest, url.split('/')[-1])
    urllib.urlretrieve(url, target)
def get_parameter(file_name=None, auto_find=False):
    '''Return the path of a parameter file inside parameter_folder, or None.

    Args:
        file_name: if given, it is joined onto parameter_folder and returned
            (the file's existence is not checked).
        auto_find: when no file_name is given, scan parameter_folder for
            ``*.model`` files and return the latest checkpoint path with the
            ``.model`` suffix stripped. Names are compared with numbers
            ordered numerically (parameter_130 after parameter_90), and
            non-numeric names such as parameter_last rank after numbered ones.

    If parameter_folder does not exist it is created and None is returned.
    '''
    import re

    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
        return None
    if file_name:
        return os.path.join(parameter_folder, file_name)
    # find the last parameter file if auto_find is True
    if not auto_find:
        return None

    suffix = ".model"
    candidates = [os.path.join(parameter_folder, f[:-len(suffix)])
                  for f in os.listdir(parameter_folder) if f.endswith(suffix)]
    if not candidates:
        return None

    def natural_key(path):
        # Digit runs compare as integers; the leading 0/1 tag keeps ints and
        # strings from ever being compared with each other.
        return [(0, int(tok), '') if tok.isdigit() else (1, 0, tok)
                for tok in re.split(r'(\d+)', os.path.basename(path))]

    return max(candidates, key=natural_key)
def untar_data(file_path, dest):
    '''Extract every member of the tar archive ``file_path`` into ``dest``.

    The archive is closed even when extraction raises.
    '''
    import tarfile
    print('untar data ..................')
    print(dest)
    print(file_path)
    # NOTE(review): extractall() trusts member paths. These archives are
    # downloaded from remote URLs, so a malicious archive could write outside
    # ``dest``; consider validating members before extracting.
    tar = tarfile.open(file_path)
    try:
        tar.extractall(dest)
    finally:
        tar.close()
<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!-->
<html class="no-js">
<!--<![endif]-->
<head>
<meta charset="utf-8">
<title>CIFAR</title>
<meta name="description" content="">
<!-- jQuery is loaded over HTTPS: browsers block plain-http scripts on pages
     served over HTTPS (mixed content). -->
<script type="text/javascript" src="https://code.jquery.com/jquery-1.12.4.min.js"></script>
<style>
body{
text-align:center;
}
h2{
text-align:center;
}
#result{
font-size: 20px;
font-weight: bold;
color: #333;
}
</style>
</head>
<body>
<h2>CIFAR 10, Image Classification Live Service</h2>
<!-- These control buttons are disabled, so a "#stop" click binding has no
     element to attach to. -->
<!--
<button id="stop">STOP</button>
<button id="go">Predict</button>
-->
<form>
<p>Please upload a jpg file!</p>
<!-- <input> is a void element and takes no closing tag. -->
<input id="file-input" type="file" accept="image/*;capture=camera">
</form>
<br/>
<!-- Preview of the uploaded image; its src is filled in by script. An empty
     src attribute is invalid and makes some browsers re-request the page. -->
<img id="image" alt=""/>
<div id="result"></div>
</body>
<script type="text/javascript">
// When the user picks a file: read it (downscaling if needed), show the
// preview and send it to the prediction service.
$("#file-input").change(function(){
    var files = $("#file-input")[0].files;
    var f = files && files[0];
    // The change event also fires when the selection is cleared (e.g. the
    // dialog is cancelled after an earlier pick in some browsers); FileReader
    // would throw on `undefined`, so there is nothing to do.
    if (!f) {
        return;
    }
    ReadFile(f, function(result){
        var file = DataURItoBlob(result);
        $("#image").attr("src", result);
        predict_dish(file);
    });
});
// POST /command/stop and report the outcome in an alert.
$("#stop").click(function(){
    var onStopped = function(){
        alert("Stop Success!");
    };
    var onFailed = function(err){
        console.log(err);
        alert("Stop Failed!");
    };
    $.ajax({
        type: "POST",
        url: "/command/stop",
        // send the (empty) request body untouched
        contentType: false,
        processData: false,
        success: onStopped,
        error: onFailed
    });
});
// Upload `file` as multipart form field "image" (filename "image.jpg") to
// /api and render the server's reply, as HTML, into #result.
function predict_dish(file){
    var payload = new FormData();
    payload.append('image', file, "image.jpg");
    var showReply = function(response){
        $("#result").html(response);
    };
    var showError = function(err){
        console.log(err);
        $("#result").html("Error Occurs!");
    };
    $.ajax({
        type: "POST",
        url: "/api",
        data: payload,
        // let the browser serialise the FormData and set the multipart header
        processData: false,
        contentType: false,
        success: showReply,
        error: showError
    });
}
// Read `file` as a data URL and pass it through ProcessFile (which may
// downscale it) to `callback`.
var ReadFile = function(file, callback) {
    var reader = new FileReader();
    // `onload` fires only after a successful read. `onloadend` also fires
    // after a failure, which would hand a null result to ProcessFile and
    // trigger a second error alert.
    reader.onload = function () {
        ProcessFile(reader.result, file.type, callback);
    };
    reader.onerror = function () {
        alert('There was an error reading the file!');
    };
    reader.readAsDataURL(file);
};
// Pass `dataURL` to `callback`, first downscaling the image (aspect ratio
// kept) and re-encoding it as `fileType` when it exceeds 400x400.
var ProcessFile = function(dataURL, fileType, callback) {
    var MAX_W = 400;
    var MAX_H = 400;
    var img = new Image();
    img.src = dataURL;
    img.onload = function () {
        var w = img.width;
        var h = img.height;
        if (w <= MAX_W && h <= MAX_H) {
            // already small enough: use the original data URL as-is
            callback(dataURL);
            return;
        }
        var scaledW, scaledH;
        if (w > h) {
            scaledW = MAX_W;
            scaledH = h * (MAX_W / w);
        } else {
            scaledH = MAX_H;
            scaledW = w * (MAX_H / h);
        }
        var canvas = document.createElement('canvas');
        canvas.width = scaledW;
        canvas.height = scaledH;
        canvas.getContext('2d').drawImage(this, 0, 0, scaledW, scaledH);
        callback(canvas.toDataURL(fileType));
    };
    img.onerror = function () {
        alert('There was an error processing your file!');
    };
};
// Convert a base64 data URI ("data:<mime>;base64,<payload>") into a Blob of
// that MIME type. URL-encoded (non-base64) data URIs are not supported.
var DataURItoBlob = function(dataURI) {
    var parts = dataURI.split(',');
    var binary = atob(parts[1]);
    var mime = parts[0].split(':')[1].split(';')[0];
    var buffer = new ArrayBuffer(binary.length);
    var bytes = new Uint8Array(buffer);
    var i = binary.length;
    while (i--) {
        bytes[i] = binary.charCodeAt(i);
    }
    try {
        return new Blob([buffer], {type: mime});
    } catch (e) {
        // Fallback for old browsers that lack the Blob constructor but offer
        // the deprecated, vendor-prefixed BlobBuilder API.
        var Builder = window.WebKitBlobBuilder || window.MozBlobBuilder;
        var builder = new Builder();
        builder.append(buffer);
        return builder.getBlob(mime);
    }
};
</script>
</html>
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from singa import layer
from singa import metric
from singa import loss
from singa import net as ffnet
def add_layer_group(net, name, nb_filers, sample_shape=None):
    '''Append one VGG-style group of layers to ``net``.

    The group is two Conv2D layers (``nb_filers`` filters, arguments 3, 1 and
    pad=1, presumably a 3x3 kernel with stride 1), each followed by an
    Activation layer, then a MaxPooling2D layer (arguments 2, 2, pad=0).
    ``sample_shape`` is passed only to the first convolution, so it is needed
    only for the first group of a network.
    '''
    kernel, stride = 3, 1
    # NOTE: activation names are name + 'act_N' (no underscore); kept as-is
    # so layer names stay stable.
    first_conv = layer.Conv2D(name + '_1', nb_filers, kernel, stride, pad=1,
                              input_sample_shape=sample_shape)
    net.add(first_conv)
    net.add(layer.Activation(name + 'act_1'))
    second_conv = layer.Conv2D(name + '_2', nb_filers, kernel, stride, pad=1)
    net.add(second_conv)
    net.add(layer.Activation(name + 'act_2'))
    net.add(layer.MaxPooling2D(name, 2, 2, pad=0))
def create(use_cpu=False):
    '''Build the VGG-style network for 3x32x32 inputs and 10 classes.

    Five conv groups (64, 128, 256, 512, 512 filters) are followed by
    flatten, a 512-unit dense layer with activation and dropout, and a 10-unit
    dense output. Loss is softmax cross-entropy and the metric is accuracy.
    When ``use_cpu`` is true, ``layer.engine`` is set to 'singacpp' before any
    layer is created.
    '''
    if use_cpu:
        layer.engine = 'singacpp'
    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())

    groups = (('conv1', 64), ('conv2', 128), ('conv3', 256),
              ('conv4', 512), ('conv5', 512))
    for position, (group_name, filters) in enumerate(groups):
        # only the first group needs to know the input sample shape
        shape = (3, 32, 32) if position == 0 else None
        add_layer_group(net, group_name, filters, shape)

    net.add(layer.Flatten('flat'))
    net.add(layer.Dense('ip1', 512))
    net.add(layer.Activation('relu_ip1'))
    net.add(layer.Dropout('drop1'))
    net.add(layer.Dense('ip2', 10))
    return net
flask>=0.10.1
flask_cors>=3.0.2
pillow>=2.3.0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys, glob, os, random, shutil, time
import urllib, traceback
import numpy as np
from multiprocessing import Process, Queue
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
from singa import tensor, device, optimizer
from singa import utils, initializer, metric, image_tool
from singa.proto import core_pb2
from rafiki.agent import Agent, MsgType
import data
import model
# Number of highest-scoring classes reported per prediction request.
top_k = 5
# Shared image loader/augmenter used by image_transform().
tool = image_tool.ImageTool()
# Augmentation geometry: images are resized using the mean of small_size and
# big_size, then crops of crop_size x crop_size are taken.
small_size = 35
big_size = 45
crop_size = 32
# NOTE(review): not referenced in this script; presumably the number of
# images image_transform() yields per input (crop5 then flip) - confirm.
num_augmentation = 10
def image_transform(image):
    '''Return augmented copies of ``image`` built with the shared ``tool``.

    ``image`` is whatever ImageTool.load accepts; the serving loop passes the
    payload pulled from the agent (presumably the uploaded image file -
    confirm against rafiki.agent). The image is resized to the mean of
    small_size and big_size, cropped with crop5 to crop_size x crop_size, and
    flipped. The serving loop calls .convert('RGB') on each returned item.
    '''
    # Floor division keeps the resize target an int on Python 2 and 3
    # ("/" would produce 40.0 on Python 3).
    resize_to = (small_size + big_size) // 2
    return tool.load(image).resize_by_list([resize_to]).crop5(
        (crop_size, crop_size), 5).flip(2).get()
def main(argv=None):
    '''Command line entry point of the CIFAR-10 prediction service.

    Parses --port, --parameter and --cpu, starts a rafiki Agent on the port,
    loads the model parameters and serves requests until a stop command.

    Args:
        argv: optional extra arguments; they are appended to sys.argv, which
            is what parse_args() reads.

    Returns:
        None on success or when argparse exits (e.g. --help), 2 on error.
    '''
    if argv is not None:
        sys.argv.extend(argv)
    agent = None
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            type=int,  # command-line values arrive as strings otherwise
            default=9999,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            default="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()
        # start serving
        agent = Agent(args.port)
        service = Service(agent, args.use_cpu)
        service.initialize(args.parameter)
        service.serve()
        # wait for the agent to finish handling the last http request
        time.sleep(1)
    except SystemExit:
        return
    except Exception:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
    finally:
        # Stop the agent on every exit path, not only after a clean shutdown.
        if agent is not None:
            agent.stop()
class Service():
    '''Prediction service driven by a rafiki Agent.

    Requests pulled from the agent are run through the model on augmented
    copies of the uploaded image; the top_k labels are pushed back as an
    HTML snippet.
    '''

    def __init__(self, agent, use_cpu):
        '''Create the model and select the device.

        Args:
            agent: rafiki Agent used to pull messages and push replies.
            use_cpu: run on the default device instead of a CUDA GPU.
        '''
        self.model = model.create(use_cpu)
        if use_cpu:
            print("running with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        # NOTE(review): the optimizer is not used while serving.
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.agent = agent

    def initialize(self, parameter_file=None):
        '''Prepare the model for prediction.

        Downloads missing serving files, loads the parameters resolved by
        data.get_parameter(parameter_file, True), moves the model to the
        device and loads the mean image used for normalisation.
        '''
        data.serve_file_prepare()
        print('Start intialization............')
        parameter = data.get_parameter(parameter_file, True)
        print('initialize with %s' % parameter)
        self.model.load(parameter)
        self.model.to_device(self.device)
        print('End intialization............')
        self.mean = data.load_mean_data()

    def serve(self):
        '''Answer agent messages until a stop command or an unknown message.'''
        while True:
            # ``payload`` (not ``data``) so the ``data`` module is not shadowed
            msg, payload = self.agent.pull()
            if msg is None:
                # Nothing queued: back off instead of spinning at 100% CPU.
                time.sleep(0.01)
                continue
            msg = MsgType.parse(msg)
            if msg.is_request():
                self.agent.push(MsgType.kResponse, self._predict(payload))
            elif msg.is_command():
                if MsgType.kCommandStop.equal(msg):
                    print('get stop command')
                    self.agent.push(MsgType.kStatus, "success")
                    break
                print('get unsupported command %s' % str(msg))
                self.agent.push(MsgType.kStatus, "failure")
            else:
                print('get unsupported message %s' % str(msg))
                self.agent.push(MsgType.kStatus, "failure")
                break
            time.sleep(0.01)
        print("server stop")

    def _predict(self, payload):
        '''Return the response text for one uploaded image.

        The result is "<label>:<prob><br/>" for the top_k classes, with class
        scores averaged over all augmented copies, or "sorry, system error."
        if anything fails (the traceback is printed).
        '''
        try:
            batch = np.array([
                np.array(im.convert('RGB'), dtype=np.float32).transpose(2, 0, 1)
                for im in image_transform(payload)])
            # normalize with the training mean image
            batch -= self.mean
            x = tensor.from_numpy(batch.astype(np.float32))
            x.to_device(self.device)
            y = self.model.predict(x)
            y.to_host()
            prob = np.average(tensor.to_numpy(y), 0)
            # class indices, highest score first
            ranked = np.argsort(prob)[::-1]
            return "".join("%s:%s<br/>" % (get_name(label), prob[label])
                           for label in ranked[:top_k])
        except Exception:
            traceback.print_exc()
            return "sorry, system error."
# CIFAR-10 class index -> class name.
label_map = dict(enumerate((
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)))


def get_name(label):
    '''Return the class name for class index ``label`` (KeyError if unknown).'''
    return label_map[label]


if __name__ == '__main__':
    main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys, os, traceback
import glob, random, shutil, time
import numpy as np
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
from singa import tensor, device, optimizer
from singa import utils, initializer, metric
from singa.proto import core_pb2
from rafiki.agent import Agent, MsgType
import model
import data
def main():
    '''Command line entry point for training the CIFAR-10 model.

    Parses --port, --parameter and --cpu, builds the model, starts a rafiki
    Agent on the port and trains until finished or stopped.

    Returns:
        None on success or when argparse exits (e.g. --help), 2 on error.
    '''
    agent = None
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            type=int,  # command-line values arrive as strings otherwise
            default=9999,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()
        # build the model and start to train
        m = model.create(args.use_cpu)
        agent = Agent(args.port)
        trainer = Trainer(m, agent, args.use_cpu)
        trainer.initialize(args.parameter)
        trainer.train()
        # wait for the agent to finish handling the last http request
        time.sleep(1)
    except SystemExit:
        return
    except Exception:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
    finally:
        # Stop the agent on every exit path, not only after a clean finish.
        if agent is not None:
            agent.stop()
class Trainer():
    '''Train a singa model while talking to a rafiki Agent.

    The agent delivers pause / resume / stop commands and receives training
    and test metrics.
    '''

    def __init__(self, model, agent, use_cpu):
        '''
        Args:
            model: the network to train.
            agent: rafiki Agent used for commands and metric reporting.
            use_cpu: train on the default device instead of a CUDA GPU.
        '''
        self.model = model
        if use_cpu:
            print("runing with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.agent = agent

    def initialize(self, parameter_file):
        '''Initialize all model parameters and move the model to the device.

        With ``parameter_file``, values are loaded from
        data.get_parameter(parameter_file). Otherwise each parameter is
        filled per its spec's filler type (gaussian; xavier scaled by 0.5;
        anything else set to 0) and registered with the optimizer.
        '''
        print('Start intialization............')
        if parameter_file:
            parameter = data.get_parameter(parameter_file)
            print('initialize with %s' % parameter)
            self.model.load(parameter)
        else:
            for (p, specs) in zip(self.model.param_values(),
                                  self.model.param_specs()):
                filler = specs.filler
                if filler.type == 'gaussian':
                    initializer.gaussian(p, filler.mean, filler.std)
                elif filler.type == 'xavier':
                    initializer.xavier(p)
                    p *= 0.5  # 0.5 if use glorot, which would have val acc to 83
                else:
                    p.set_value(0)
                # NOTE(review): registration is keyed by the tensor ``p``,
                # while train() passes str(s.name) to apply_with_lr, and it is
                # skipped when loading from a file. Confirm against singa's
                # Optimizer.register whether per-parameter specs take effect.
                self.opt.register(p, specs)
                print('%s %s %s' % (specs.name, filler.type, p.l1()))
        self.model.to_device(self.device)
        print('End intialization............')

    def data_prepare(self):
        '''Load train/test data and subtract the mean image from both.

        The mean is loaded from disk, or computed over the training images
        and saved when no stored mean exists.
        '''
        data.train_file_prepare()
        self.train_x, self.train_y = data.load_train_data()
        self.test_x, self.test_y = data.load_test_data()
        self.mean = data.load_mean_data()
        if self.mean is None:
            self.mean = np.average(self.train_x, axis=0)
            data.save_mean_data(self.mean)
        self.train_x -= self.mean
        self.test_x -= self.mean

    def pause(self):
        '''Block until a resume or stop command arrives.

        Returns True on resume and False on stop; any other message is
        answered with a warning status.
        '''
        while True:
            msg, _ = self.agent.pull()
            if msg is not None:
                msg = MsgType.parse(msg)
                if MsgType.kCommandResume.equal(msg):
                    self.agent.push(MsgType.kStatus, "success")
                    return True
                if MsgType.kCommandStop.equal(msg):
                    self.agent.push(MsgType.kStatus, "success")
                    return False
                self.agent.push(MsgType.kStatus, "warning, nothing happened")
                print("Receive an unsupported command: %s " % str(msg))
            # Sleep on every pass, including when nothing was pulled, so
            # waiting does not spin the CPU.
            time.sleep(0.1)

    def listen(self):
        '''Handle at most one pending agent message.

        Returns False when training should stop (a stop command, or a stop
        received while paused), True otherwise. Non-command messages are
        dropped.
        '''
        msg, _ = self.agent.pull()
        if msg is None:
            return True
        msg = MsgType.parse(msg)
        if not msg.is_command():
            return True
        if MsgType.kCommandPause.equal(msg):
            self.agent.push(MsgType.kStatus, "success")
            return self.pause()
        if MsgType.kCommandStop.equal(msg):
            self.agent.push(MsgType.kStatus, "success")
            return False
        self.agent.push(MsgType.kStatus, "warning, nothing happened")
        print("Unsupported command %s" % str(msg))
        return True

    def train(self, num_epoch=140, batch_size=50):
        '''Train for ``num_epoch`` epochs, evaluating on the test set first
        in each epoch.

        Metrics are pushed to the agent ('test' once per epoch, 'train' every
        ``skip`` batches). Parameters are saved as parameter_<epoch> every 10
        epochs and as parameter_last at the end. Training stops as soon as
        listen() reports a stop.
        '''
        self.data_prepare()
        print('training shape %s %s' % (self.train_x.shape,
                                        self.train_y.shape))
        print('validation shape %s %s' % (self.test_x.shape,
                                          self.test_y.shape))
        tx = tensor.Tensor((batch_size, 3, 32, 32), self.device)
        ty = tensor.Tensor((batch_size, ), self.device, core_pb2.kInt)
        # Floor division: only full batches are used, and range() needs an
        # int on Python 3.
        num_train_batch = self.train_x.shape[0] // batch_size
        num_test_batch = self.test_x.shape[0] // batch_size
        idx = np.arange(self.train_x.shape[0], dtype=np.int32)
        # frequency of gathering training status info
        skip = 20
        for epoch in range(num_epoch):
            if not self.listen():
                break
            np.random.shuffle(idx)
            print('Epoch %d' % epoch)
            self._test_one_epoch(epoch, tx, ty, batch_size, num_test_batch,
                                 num_train_batch, skip)
            if not self._train_one_epoch(epoch, tx, ty, idx, batch_size,
                                         num_train_batch, skip):
                break
            if epoch > 0 and epoch % 10 == 0:
                self.model.save(
                    os.path.join(data.parameter_folder, 'parameter_%d' %
                                 epoch))
        self.model.save(os.path.join(data.parameter_folder, 'parameter_last'))

    def _test_one_epoch(self, epoch, tx, ty, batch_size, num_test_batch,
                        num_train_batch, skip):
        '''Evaluate on the test set; print and push the average loss/acc.'''
        loss, acc = 0.0, 0.0
        for b in range(num_test_batch):
            tx.copy_from_numpy(self.test_x[b * batch_size:(b + 1) * batch_size])
            ty.copy_from_numpy(self.test_y[b * batch_size:(b + 1) * batch_size])
            l, a = self.model.evaluate(tx, ty)
            loss += l
            acc += a
        print('testing loss = %f, accuracy = %f' % (loss / num_test_batch,
                                                    acc / num_test_batch))
        # put test status info into a shared queue
        self.agent.push(MsgType.kInfoMetric, dict(
            phase='test',
            step=epoch * num_train_batch // skip,
            accuracy=acc / num_test_batch,
            loss=loss / num_test_batch,
            timestamp=time.time()))

    def _train_one_epoch(self, epoch, tx, ty, idx, batch_size,
                         num_train_batch, skip):
        '''Run one pass over the training data in ``idx`` order.

        Returns False if a stop arrived mid-epoch, True otherwise. The printed
        averages cover only the batches actually trained.
        '''
        # Fresh totals so training averages are not mixed with test metrics.
        loss, acc = 0.0, 0.0
        trained = 0
        completed = True
        lr = get_lr(epoch)
        for b in range(num_train_batch):
            if not self.listen():
                completed = False
                break
            batch_idx = idx[b * batch_size:(b + 1) * batch_size]
            tx.copy_from_numpy(self.train_x[batch_idx])
            ty.copy_from_numpy(self.train_y[batch_idx])
            grads, (l, a) = self.model.train(tx, ty)
            loss += l
            acc += a
            trained += 1
            for (s, p, g) in zip(self.model.param_specs(),
                                 self.model.param_values(), grads):
                self.opt.apply_with_lr(epoch, lr, g, p, str(s.name))
            info = 'training loss = %f, training accuracy = %f' % (l, a)
            # put training status info into a shared queue
            if b % skip == 0:
                self.agent.push(MsgType.kInfoMetric, dict(
                    phase='train',
                    step=(epoch * num_train_batch + b) // skip,
                    accuracy=a,
                    loss=l,
                    timestamp=time.time()))
            # update progress bar
            utils.update_progress(b * 1.0 / num_train_batch, info)
        trained = max(trained, 1)
        print('training loss = %f, training accuracy = %f'
              % (loss / trained, acc / trained))
        return completed
def get_lr(epoch):
    '''Return the step-schedule learning rate for ``epoch``.

    0.0008 before epoch 360, 0.0001 before epoch 540, 0.00001 afterwards.
    With Trainer.train's default num_epoch=140 only the first rate is used.
    '''
    schedule = ((360, 0.0008), (540, 0.0001))
    for boundary, rate in schedule:
        if epoch < boundary:
            return rate
    return 0.00001


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment