- This is an example of a SINGA CNN model.
- It trains a VGG model on the CIFAR-10 dataset.
-
-
Save XiangruiCAI/a38ae090eecefae4ac975da43ab1365b to your computer and use it in GitHub Desktop.
rafiki-cifar-example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import os, sys, shutil | |
import urllib | |
import cPickle | |
import numpy as np | |
data_folder = "data_" | |
tar_data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' | |
tar_data_name = 'cifar-10-python.tar.gz' | |
data_path = 'cifar-10-batches-py' | |
parameter_folder = "parameter_" | |
parameter_name = "parameter" | |
tar_parameter_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/parameter.tar.gz" | |
tar_parameter_name = 'parameter.tar.gz' | |
mean_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/train.mean.npy" | |
mean_name = "train.mean.npy" | |
def load_dataset(filepath):
    '''Load one pickled CIFAR batch file.

    Args:
        filepath: path of the batch file; it must unpickle to a mapping with
            'data' and 'labels' entries.

    Returns:
        (image, label): image is a uint8 array reshaped to (-1, 3, 32, 32),
        label is a uint8 column array of shape (number of labels, 1).
    '''
    print('Loading data file %s' % filepath)
    # NOTE(review): unpickling can execute arbitrary code; only feed this
    # function files that come from a trusted source.
    with open(filepath, 'rb') as batch_file:
        batch = cPickle.load(batch_file)
    pixels = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    classes = np.asarray(batch['labels'], dtype=np.uint8)
    return pixels, classes.reshape(classes.size, 1)
def load_train_data(num_batches=5):
    '''Load the first `num_batches` training batch files.

    Files are read from <data_folder>/<data_path>/data_batch_<i> for
    i = 1..num_batches; every file is expected to hold 10000 images.

    Returns:
        (images, labels): images as float32, labels as int32.
    '''
    batchsize = 10000
    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
    labels = []
    for index in range(num_batches):
        batch_file = os.path.join(data_folder, data_path,
                                  "data_batch_{}".format(index + 1))
        image, label = load_dataset(batch_file)
        start = index * batchsize
        images[start:start + batchsize] = image
        labels.extend(label)
    return (np.array(images, dtype=np.float32),
            np.array(labels, dtype=np.int32))
def load_test_data():
    '''Load <data_folder>/<data_path>/test_batch.

    Returns:
        (images, labels): images as float32, labels as int32.
    '''
    test_file = os.path.join(data_folder, data_path, "test_batch")
    images, labels = load_dataset(test_file)
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    return images, labels
def load_mean_data():
    '''Return the array stored in <data_folder>/<mean_name>, or None when
    that file does not exist.'''
    mean_path = os.path.join(data_folder, mean_name)
    if not os.path.exists(mean_path):
        return None
    return np.load(mean_path)
def save_mean_data(mean):
    '''Save `mean` with np.save to <data_folder>/<mean_name>.'''
    np.save(os.path.join(data_folder, mean_name), mean)
def train_file_prepare():
    '''Download and unpack the training archive unless the extracted
    directory <data_folder>/<data_path> is already present.'''
    extracted_dir = os.path.join(data_folder, data_path)
    if os.path.exists(extracted_dir):
        return
    print("download file")
    download_file(tar_data_url, data_folder)
    archive = os.path.join(data_folder, tar_data_name)
    untar_data(archive, data_folder)
def serve_file_prepare():
    '''Fetch the files needed for serving when they are missing.

    The parameter archive is downloaded and unpacked if
    <parameter_folder>/<parameter_name> does not exist; the mean file is
    downloaded if <data_folder>/<mean_name> does not exist.
    '''
    parameter_path = os.path.join(parameter_folder, parameter_name)
    if not os.path.exists(parameter_path):
        print("download parameter file")
        download_file(tar_parameter_url, parameter_folder)
        archive = os.path.join(parameter_folder, tar_parameter_name)
        untar_data(archive, parameter_folder)

    mean_path = os.path.join(data_folder, mean_name)
    if not os.path.exists(mean_path):
        print("download mean file")
        download_file(mean_url, data_folder)
def download_file(url, dest):
    '''Download `url` into directory `dest`.

    `dest` is created when missing. The local file name is the text after
    the last '/' of the url. Urls that do not start with 'http' are
    ignored (nothing is downloaded).
    '''
    if not os.path.exists(dest):
        os.makedirs(dest)
    if not url.startswith('http'):
        return
    target = os.path.join(dest, url.rsplit('/', 1)[-1])
    urllib.urlretrieve(url, target)
def get_parameter(file_name=None, auto_find=False):
    '''Return the path of a parameter file, or None.

    Args:
        file_name: when given, the result is <parameter_folder>/<file_name>
            (its existence is not checked).
        auto_find: when True and no file_name is given, pick a file from
            the '*.model' files in the parameter folder.

    Returns:
        A path without the '.model' suffix, or None. None is also returned
        when the parameter folder had to be created by this call.
    '''
    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
        return None
    if file_name:
        return os.path.join(parameter_folder, file_name)
    if not auto_find:
        return None

    # find the last parameter file if auto_find is True
    # NOTE(review): the ordering is plain string ordering, so e.g.
    # 'parameter_90' sorts after 'parameter_130' -- confirm this is intended.
    suffix = ".model"
    candidates = sorted(
        os.path.join(parameter_folder, entry[0:-6])
        for entry in os.listdir(parameter_folder)
        if entry.endswith(suffix))
    if not candidates:
        return None
    return candidates[-1]
def untar_data(file_path, dest):
    '''Extract every member of the tar archive `file_path` into `dest`.

    The archive is opened in a `with` block so the handle is closed even
    when extraction raises.
    '''
    import tarfile
    print('untar data ..................')
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use it on archives from a trusted source.
    with tarfile.open(file_path) as tar:
        print(dest)
        print(file_path)
        tar.extractall(dest)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
from multiprocessing import Process | |
from flask import Flask, request, jsonify, send_from_directory | |
from flask_cors import CORS, cross_origin | |
app = Flask(__name__) | |
top_k_ = 5 | |
def success(data=""):
    '''Build the JSON reply for a successful call: result is "success" and
    `data` carries the payload.'''
    return jsonify(dict(result="success", data=data))
def failure(message):
    '''Build the JSON reply for a failed call; `message` explains why.'''
    # NOTE(review): the "result" field carries the literal "message", not
    # something like "failure" -- confirm with the clients before changing.
    return jsonify(dict(result="message", message=message))
def start_monitor(port, queue):
    '''Run the flask app in monitor mode on 0.0.0.0:port.

    Args:
        port: port to listen on.
        queue: queue from which status records are read by the handlers.

    The module globals queue_, data_ (buffer for records taken from the
    queue) and type_ are (re)initialised before the server starts.
    '''
    global queue_, data_, type_
    type_ = "monitor"
    data_ = []
    queue_ = queue
    app.run(host='0.0.0.0', port=port)
def start_serve(port, service):
    '''Run the flask app in serve mode on 0.0.0.0:port.

    `service` is stored in the module global service_; the /predict handler
    calls its serve(request) method.
    '''
    global type_, service_
    type_ = "serve"
    service_ = service
    app.run(host='0.0.0.0', port=port)
def getDataFromQueue():
    '''Move every record currently in queue_ to the end of data_.'''
    # The buffer is named data_; the original declared a non-existent
    # global called `data`.
    global queue_, data_
    while not queue_.empty():
        data_.append(queue_.get())
@app.route("/")
@cross_origin()
def index():
    '''Landing page: a greeting string in monitor mode, otherwise the
    index.html file from the current directory.'''
    global type_
    print(type_)
    if type_ != "monitor":
        return send_from_directory(".", "index.html", mimetype='text/html')
    return "Hello,This is SINGA monitor http server"
# predict an uploaded image
@app.route("/predict", methods=['POST'])
@cross_origin()
def predict():
    '''Pass the request to service_.serve() and return its reply.

    In monitor mode a failure reply is returned instead. If the service
    raises, the error is printed and a generic message is returned with
    HTTP status 500. The original returned the exception object itself,
    which is not a valid flask response.
    '''
    global type_, service_
    if type_ == "monitor":
        return failure("not available in monitor mode")
    # The route only accepts POST, so no further method check is needed.
    try:
        print("test")
        return service_.serve(request)
    except Exception as e:
        print(str(e))
        return "sorry, system error.", 500
# support two operations for user to monitor the training status
@app.route('/getAllData')
@cross_origin()
def getAllData():
    '''Return every buffered status record (monitor mode only).'''
    global data_, type_
    if type_ == "serve":
        return failure("not available in serve mode")
    # drain the queue first so the reply includes the newest records
    getDataFromQueue()
    return success(data_)
@app.route('/getTopKData')
@cross_origin()
def getTopKData():
    '''Return the last k buffered status records (monitor mode only).

    k comes from the "k" query argument and defaults to top_k_. A value
    that is not an integer yields a failure reply; k <= 0 yields an empty
    list.
    '''
    global data_, type_
    if type_ == "serve":
        return failure("not available in serve mode")
    try:
        k = int(request.args.get("k", top_k_))
    except (TypeError, ValueError):
        return failure("k should be integer")
    getDataFromQueue()
    # data_[-0:] would be the whole buffer and a negative k would slice
    # from the front, so non-positive values are handled explicitly.
    return success(data_[-k:] if k > 0 else [])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]--> | |
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]--> | |
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]--> | |
<!--[if gt IE 8]><!--> | |
<html class="no-js"> | |
<!--<![endif]--> | |
<head> | |
<meta charset="utf-8"> | |
<title>Foodlg - Insight into your titbits</title> | |
<meta name="description" content=""> | |
<script type="text/javascript" src="http://code.jquery.com/jquery-1.12.4.min.js"></script> | |
</head> | |
<body> | |
<h2>CIFAR PREDICTION</h2> | |
<form> | |
<p>Please upload a jpg file!</p> | |
<input id="file-input" type="file" accept="image/*;capture=camera" ></input> | |
</form> | |
<button id="go">GO</button> | |
<br/> | |
<img id="image" src=""/> | |
<div id="result"></div> | |
</body> | |
<script language="javascript"> | |
// Click handler of the GO button: read the chosen file, show it in the
// page and send it to the server for prediction.
$("#go").click(function () {
  var selected = $("#file-input")[0].files[0];
  // Without this guard FileReader.readAsDataURL(undefined) throws when the
  // button is pressed before a file has been chosen.
  if (!selected) {
    $("#result").html("Please choose an image file first.");
    return;
  }
  ReadFile(selected, function (dataURL) {
    var blob = DataURItoBlob(dataURL);
    $("#image").attr("src", dataURL);
    predict_dish(blob);
  });
});
// POST `file` to /predict as the multipart field "image" (file name
// "image.jpg") and write the server reply, or an error text, into #result.
function predict_dish(file) {
  var payload = new FormData();
  payload.append('image', file, "image.jpg");

  var showReply = function (response) {
    $("#result").html(response);
  };
  var showError = function (e) {
    console.log(e);
    $("#result").html("Error Occurs!");
  };

  $.ajax({
    url: "/predict",
    type: "POST",
    data: payload,
    processData: false, // hand the FormData over untouched
    contentType: false, // let the browser set the multipart header
    success: showReply,
    error: showError
  });
}
// Read `file` as a data URL and pass it on to ProcessFile, which calls
// `callback` with the (possibly resized) data URL.
var ReadFile = function (file, callback) {
  var reader = new FileReader();
  // `load` fires only after a successful read; `loadend` also fires after
  // an error and would hand a null result to ProcessFile.
  reader.onload = function () {
    ProcessFile(reader.result, file.type, callback);
  };
  reader.onerror = function () {
    alert('There was an error reading the file!');
  };
  reader.readAsDataURL(file);
};
// Decode `dataURL` as an image and call `callback` with a data URL whose
// image fits in 400x400: the input itself when it already fits, otherwise a
// version scaled down on a canvas (aspect ratio kept) and re-encoded as
// `fileType`.
var ProcessFile = function (dataURL, fileType, callback) {
  var maxWidth = 400;
  var maxHeight = 400;
  var image = new Image();

  // Handlers are attached before `src` is assigned so that neither event
  // can be missed.
  image.onload = function () {
    var width = image.width;
    var height = image.height;
    if (width <= maxWidth && height <= maxHeight) {
      callback(dataURL);
      return;
    }

    var newWidth;
    var newHeight;
    if (width > height) {
      newHeight = height * (maxWidth / width);
      newWidth = maxWidth;
    } else {
      newWidth = width * (maxHeight / height);
      newHeight = maxHeight;
    }

    var canvas = document.createElement('canvas');
    canvas.width = newWidth;
    canvas.height = newHeight;
    canvas.getContext('2d').drawImage(image, 0, 0, newWidth, newHeight);
    callback(canvas.toDataURL(fileType));
  };
  image.onerror = function () {
    alert('There was an error processing your file!');
  };
  image.src = dataURL;
};
// Convert a base64 data URI into a Blob of the same mime type.
// URL-encoded (non-base64) data URIs are not handled.
var DataURItoBlob = function (dataURI) {
  var parts = dataURI.split(',');
  // "data:<mime>;base64" -> "<mime>"
  var mimeString = parts[0].split(':')[1].split(';')[0];
  // base64 payload -> binary string, one char per byte
  var binary = atob(parts[1]);

  var buffer = new ArrayBuffer(binary.length);
  var bytes = new Uint8Array(buffer);
  var i = 0;
  while (i < binary.length) {
    bytes[i] = binary.charCodeAt(i);
    i += 1;
  }

  try {
    return new Blob([buffer], {type: mimeString});
  } catch (e) {
    // Older browsers do not know the Blob constructor and need the
    // deprecated BlobBuilder API instead. IE10 supports BlobBuilder too,
    // but its Blob constructor works, so MSBlobBuilder is not needed.
    var BlobBuilder = window.WebKitBlobBuilder || window.MozBlobBuilder;
    var builder = new BlobBuilder();
    builder.append(buffer);
    return builder.getBlob(mimeString);
  }
};
</script> | |
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
from singa import layer | |
from singa import metric | |
from singa import loss | |
from singa import net as ffnet | |
def add_layer_group(net, name, nb_filers, sample_shape=None):
    '''Append one conv-act-conv-act-pool group to `net`.

    Args:
        net: network the layers are added to.
        name: prefix of the layer names; the pooling layer is called
            exactly `name`.
        nb_filers: number of filters of both convolution layers.
        sample_shape: input sample shape handed to the first convolution.
    '''
    kernel, stride = 3, 1
    first_conv = layer.Conv2D(name + '_1', nb_filers, kernel, stride,
                              pad=1, input_sample_shape=sample_shape)
    net.add(first_conv)
    net.add(layer.Activation(name + 'act_1'))
    second_conv = layer.Conv2D(name + '_2', nb_filers, kernel, stride, pad=1)
    net.add(second_conv)
    net.add(layer.Activation(name + 'act_2'))
    net.add(layer.MaxPooling2D(name, 2, 2, pad=0))
def create(use_cpu = False):
    '''Build the vgg network: five layer groups followed by flatten,
    dense(512), activation, dropout and dense(10).

    When use_cpu is true, layer.engine is set to 'singacpp' first.
    '''
    if use_cpu:
        layer.engine = 'singacpp'
    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())

    # only the first group is given the input sample shape
    add_layer_group(net, 'conv1', 64, (3, 32, 32))
    later_groups = (('conv2', 128), ('conv3', 256), ('conv4', 512),
                    ('conv5', 512))
    for group_name, filters in later_groups:
        add_layer_group(net, group_name, filters)

    net.add(layer.Flatten('flat'))
    net.add(layer.Dense('ip1', 512))
    net.add(layer.Activation('relu_ip1'))
    net.add(layer.Dropout('drop1'))
    net.add(layer.Dense('ip2', 10))
    return net
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
#/************************************************************ | |
#* | |
#* Licensed to the Apache Software Foundation (ASF) under one | |
#* or more contributor license agreements. See the NOTICE file | |
#* distributed with this work for additional information | |
#* regarding copyright ownership. The ASF licenses this file | |
#* to you under the Apache License, Version 2.0 (the | |
#* "License"); you may not use this file except in compliance | |
#* with the License. You may obtain a copy of the License at | |
#* | |
#* http://www.apache.org/licenses/LICENSE-2.0 | |
#* | |
#* Unless required by applicable law or agreed to in writing, | |
#* software distributed under the License is distributed on an | |
#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
#* KIND, either express or implied. See the License for the | |
#* specific language governing permissions and limitations | |
#* under the License. | |
#* | |
#*************************************************************/ | |
#************** | |
#*sudo apt-get install libjpeg-dev | |
#*sudo pip install | |
from PIL import Image | |
import sys, glob, os, random, shutil, time | |
import numpy as np | |
def do_resize(img, small_size):
    '''Resize `img` so that its shorter side equals `small_size`.

    The longer side is scaled by the same factor and truncated to an int.
    Returns the result of img.resize().
    '''
    width, height = img.size[0], img.size[1]
    if width < height:
        target = (small_size, int(small_size * height / width))
    else:
        target = (int(small_size * width / height), small_size)
    return img.resize(target)
def do_crop(img, crop, position):
    '''Cut a crop[0] x crop[1] window out of `img`.

    Args:
        img: image object exposing `size` (width, height) and `crop(box)`.
        crop: (width, height) of the window.
        position: one of 'left_top', 'left_bottom', 'right_top',
            'right_bottom' or 'center'.

    Returns:
        The result of img.crop() for the selected window.

    Raises:
        Exception: if the image is smaller than the window.
        ValueError: if `position` is not one of the names above (the
            original failed with UnboundLocalError in that case).
    '''
    width, height = img.size[0], img.size[1]
    crop_w, crop_h = crop[0], crop[1]
    # The original formatted img[0] / img[1] here, which fails for image
    # objects that are not subscriptable; the intended value is img.size.
    if width < crop_w:
        raise Exception('img size[0] %d is smaller than crop[0]: %d' %
                        (width, crop_w))
    if height < crop_h:
        raise Exception('img size[1] %d is smaller than crop[1]: %d' %
                        (height, crop_h))

    right = width - crop_w
    bottom = height - crop_h
    offsets = {
        'left_top': (0, 0),
        'left_bottom': (0, bottom),
        'right_top': (right, 0),
        'right_bottom': (right, bottom),
        # floor division keeps the box integral on every Python version
        'center': (right // 2, bottom // 2),
    }
    if position not in offsets:
        raise ValueError('unknown crop position: %r' % (position,))
    left, upper = offsets[position]
    return img.crop((left, upper, left + crop_w, upper + crop_h))
def do_flip(img):
    '''Return `img` mirrored left-to-right.'''
    return img.transpose(Image.FLIP_LEFT_RIGHT)
def load_img(path, grayscale=False):
    '''Open the image at `path` and convert it to mode 'L' when
    `grayscale` is true, otherwise to 'RGB'.'''
    from PIL import Image
    # 'RGB' guarantees 3 channels even when the stored image is grayscale
    mode = 'L' if grayscale else 'RGB'
    return Image.open(path).convert(mode)
def process_img(img, small_size, size, is_aug):
    '''Turn one image into a list of channel-first crops.

    The image is loaded with load_img, resized with do_resize and cropped
    to `size` with do_crop. Without augmentation only the centre crop is
    produced; with augmentation the four corner crops and the centre crop
    are produced, each followed by its left-right mirrored copy.

    Returns:
        A list of arrays, each an RGB crop transposed with (2, 0, 1).
    '''
    resized = do_resize(load_img(img), small_size)
    if is_aug:
        positions = ["left_top", "left_bottom", "right_top", "right_bottom",
                     "center"]
    else:
        positions = ["center"]

    crops = []
    for where in positions:
        patch = do_crop(resized, size, where)
        assert patch.size == size
        crops.append(np.array(patch.convert("RGB")).transpose(2, 0, 1))
        if is_aug:
            mirrored = do_flip(patch)
            crops.append(np.array(mirrored.convert("RGB")).transpose(2, 0, 1))
    return crops
def unpickle(file):
    '''Return the object pickled in the file at path `file`.

    The file is closed even when loading fails. cPickle is used when it is
    available, the plain pickle module otherwise.
    '''
    try:
        import cPickle as pickle
    except ImportError:
        import pickle
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    with open(file, 'rb') as fo:
        return pickle.load(fo)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
flask>=0.10.1 | |
flask_cors>=3.0.2 | |
pillow>=2.3.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
from multiprocessing import Process, Queue | |
from argparse import ArgumentParser | |
from argparse import RawDescriptionHelpFormatter | |
import sys, os, traceback | |
import flaskserver | |
import model | |
from service import Service | |
sys.path.append(os.getcwd()) | |
def main(argv=None):
    '''Parse the command line, load the model and start the serve server.

    Args:
        argv: optional extra arguments; they are appended to sys.argv
            before parsing.

    Returns:
        2 when start-up or serving raised an exception, otherwise None.
    '''
    if argv is None:
        argv = sys.argv
    else:
        sys.argv.extend(argv)

    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            default=9999,
            # a port given on the command line would otherwise be a string
            type=int,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")

        # Process arguments
        args = parser.parse_args()
        port = args.port
        parameter_file = args.parameter
        use_cpu = args.use_cpu

        # build the model and start serving
        m = model.create(use_cpu)
        service = Service(m, use_cpu)
        print(parameter_file)
        service.initialize(parameter_file)
        flaskserver.start_serve(port, service)
    except Exception:
        # SystemExit (raised by argparse for --help and usage errors) and
        # KeyboardInterrupt are deliberately not caught here.
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
if __name__ == '__main__':
    # hand main()'s return value (None or an error code) to the shell
    sys.exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import sys, glob, os, random, shutil, time | |
import numpy as np | |
import urllib, traceback | |
from singa import tensor, device, optimizer | |
from singa import utils, initializer, metric | |
from singa.proto import core_pb2 | |
import data | |
import process | |
top_k = 5 | |
class Service():
    '''Serve predictions of a singa model for uploaded images.'''

    def __init__(self, model, use_cpu):
        '''Keep `model` and select the device: the default device when
        `use_cpu` is true, a cuda gpu otherwise.'''
        self.model = model
        if use_cpu:
            print("running with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)

    def initialize(self, parameter_file=None):
        '''Load the model parameters and the mean data used for prediction.

        Missing parameter/mean files are fetched first through
        data.serve_file_prepare(). When `parameter_file` is not given,
        data.get_parameter() chooses a file in the parameter folder.
        '''
        data.serve_file_prepare()
        print('Start intialization............')
        parameter = data.get_parameter(parameter_file, True)
        print('initialize with %s' % parameter)
        self.model.load(parameter)
        self.model.to_device(self.device)
        print('End intialization............')
        self.mean = data.load_mean_data()

    def serve(self, request):
        '''Predict the label of the image uploaded as form file 'image'.

        Returns:
            An html fragment with one "<name>:<probability><br/>" entry for
            each of the top_k labels, or an error message string.
        '''
        # .get() instead of ['image']: a request without the field must
        # reach the "no image" reply below instead of raising.
        image = request.files.get('image')
        if not image:
            return "error, no image file found!"
        if not allowed_file(image.filename):
            return "error, only jpg image is allowed."
        try:
            # process images: augmented crops of the uploaded picture
            images = process.process_img(image, 36, (32, 32), True)
            images = np.array(images[0:10]).astype(np.float32)
            # normalize
            images -= self.mean
            x = tensor.from_numpy(images.astype(np.float32))
            x.to_device(self.device)
            y = self.model.predict(x)
            y.to_host()
            y = tensor.to_numpy(y)
            # average the predictions over all crops
            prob = np.average(y, 0)
            # label indices ordered by descending probability
            labels = np.flipud(np.argsort(prob))
            entries = ["%s:%s<br/>" % (get_name(labels[i]), prob[labels[i]])
                       for i in range(top_k)]
            return "".join(entries)
        except Exception as e:
            traceback.print_exc()
            print(e)
            return "sorry, system error."
def get_lr(epoch):
    '''Return the learning rate for `epoch`: 0.0008 below epoch 360,
    0.0001 below epoch 540 and 0.00001 from then on.'''
    schedule = ((360, 0.0008), (540, 0.0001))
    for upper_bound, rate in schedule:
        if epoch < upper_bound:
            return rate
    return 0.00001
def allowed_file(filename):
    '''Tell whether `filename` has one of the accepted jpg extensions
    (jpg, JPG, JPEG, jpeg); the text after the last '.' is compared.'''
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension in ["jpg", "JPG", "JPEG", "jpeg"]
# class index -> class name
label_map = dict(enumerate([
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]))


def get_name(label):
    '''Return the class name stored in label_map for index `label`;
    unknown indices raise KeyError.'''
    return label_map[label]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
from multiprocessing import Process, Queue | |
from argparse import ArgumentParser | |
from argparse import RawDescriptionHelpFormatter | |
import sys, os, traceback | |
import flaskserver | |
import model | |
from trainer import Trainer | |
sys.path.append(os.getcwd()) | |
def main(argv=None):
    '''Parse the command line, start the monitor server in a separate
    process and train the model.

    Args:
        argv: optional extra arguments; they are appended to sys.argv
            before parsing.

    Returns:
        2 when set-up or training raised an exception, otherwise None.
    '''
    if argv is None:
        argv = sys.argv
    else:
        sys.argv.extend(argv)

    # Stays None until the monitor process exists, so the cleanup below
    # cannot fail on an unbound name when an early step raises.
    p = None
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            default=9999,
            # a port given on the command line would otherwise be a string
            type=int,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")

        # Process arguments
        args = parser.parse_args()
        port = args.port
        parameter_file = args.parameter
        use_cpu = args.use_cpu

        # start monitor server
        # use multiprocessing to transfer training status information
        queue = Queue()
        p = Process(target=flaskserver.start_monitor, args=(port, queue))
        p.start()

        # start to train
        m = model.create(use_cpu)
        trainer = Trainer(m, use_cpu, queue)
        trainer.initialize(parameter_file)
        trainer.train()
    except Exception:
        # SystemExit (raised by argparse for --help and usage errors) and
        # KeyboardInterrupt are deliberately not caught here.
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
    finally:
        # stop the monitor on success, on failure and on interruption
        if p is not None and p.is_alive():
            p.terminate()
if __name__ == '__main__':
    # hand main()'s return value (None or an error code) to the shell
    sys.exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import sys, glob, os, random, shutil, time | |
import numpy as np | |
import urllib | |
from singa import tensor, device, optimizer | |
from singa import utils, initializer, metric | |
from singa.proto import core_pb2 | |
import data | |
class Trainer():
    '''train a singa model'''

    def __init__(self, model, use_cpu, queue):
        '''Keep `model` and `queue` and select the device: the default
        device when `use_cpu` is true, a cuda gpu otherwise.

        Status records produced during training are put on `queue`.
        '''
        self.model = model
        if use_cpu:
            print("runing with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.queue = queue

    def initialize(self, parameter_file):
        '''initialize all parameters in the model

        With a `parameter_file` the values are loaded from the path that
        data.get_parameter() returns for it. Otherwise every parameter is
        filled according to the filler type in its spec and registered
        with the optimizer.
        '''
        print('Start intialization............')
        if parameter_file:
            parameter = data.get_parameter(parameter_file)
            print('initialize with %s' % parameter)
            self.model.load(parameter)
        else:
            for (p, specs) in zip(self.model.param_values(),
                                  self.model.param_specs()):
                filler = specs.filler
                if filler.type == 'gaussian':
                    initializer.gaussian(p, filler.mean, filler.std)
                elif filler.type == 'xavier':
                    initializer.xavier(p)
                    p *= 0.5  # 0.5 if use glorot, which would have val acc to 83
                else:
                    p.set_value(0)
                self.opt.register(p, specs)
                print('%s %s %s' % (specs.name, filler.type, p.l1()))
        self.model.to_device(self.device)
        print('End intialization............')

    def data_prepare(self):
        '''load data

        Downloads the training files when needed, loads the train and test
        sets and subtracts the mean from both. The mean is read from disk
        when available; otherwise it is computed from the training images
        and saved.
        '''
        data.train_file_prepare()
        self.train_x, self.train_y = data.load_train_data()
        self.test_x, self.test_y = data.load_test_data()
        self.mean = data.load_mean_data()
        if self.mean is None:
            self.mean = np.average(self.train_x, axis=0)
            data.save_mean_data(self.mean)
        self.train_x -= self.mean
        self.test_x -= self.mean

    def train(self, num_epoch=140, batch_size=50):
        '''train and test model

        Every epoch first evaluates the test set and then runs one pass
        over the shuffled training set. Status records (phase, step,
        accuracy, loss, timestamp) are put on the queue. Parameters are
        saved every 10 epochs and once more after the last epoch.

        Args:
            num_epoch: number of epochs to run.
            batch_size: number of samples per batch; samples that do not
                fill a whole batch are skipped.
        '''
        self.data_prepare()
        print('training shape %s %s' % (self.train_x.shape,
                                        self.train_y.shape))
        print('validation shape %s %s' % (self.test_x.shape,
                                          self.test_y.shape))
        tx = tensor.Tensor((batch_size, 3, 32, 32), self.device)
        ty = tensor.Tensor((batch_size, ), self.device, core_pb2.kInt)
        # floor division: the counts are used with range() and as step ids
        num_train_batch = self.train_x.shape[0] // batch_size
        num_test_batch = self.test_x.shape[0] // batch_size
        idx = np.arange(self.train_x.shape[0], dtype=np.int32)
        # frequency of gathering training status info
        skip = 20
        for epoch in range(num_epoch):
            np.random.shuffle(idx)
            print('Epoch %d' % epoch)

            # evaluation pass over the test set
            loss, acc = 0.0, 0.0
            for b in range(num_test_batch):
                x = self.test_x[b * batch_size:(b + 1) * batch_size]
                y = self.test_y[b * batch_size:(b + 1) * batch_size]
                tx.copy_from_numpy(x)
                ty.copy_from_numpy(y)
                l, a = self.model.evaluate(tx, ty)
                loss += l
                acc += a
            print('testing loss = %f, accuracy = %f' %
                  (loss / num_test_batch, acc / num_test_batch))
            # put test status info into a shared queue
            dic = dict(
                phase='test',
                step=epoch * num_train_batch // skip,
                accuracy=acc / num_test_batch,
                loss=loss / num_test_batch,
                timestamp=time.time())
            self.queue.put(dic)

            # Restart the accumulators: without this the training averages
            # printed below would also contain the test totals.
            loss, acc = 0.0, 0.0
            for b in range(num_train_batch):
                x = self.train_x[idx[b * batch_size:(b + 1) * batch_size]]
                y = self.train_y[idx[b * batch_size:(b + 1) * batch_size]]
                tx.copy_from_numpy(x)
                ty.copy_from_numpy(y)
                grads, (l, a) = self.model.train(tx, ty)
                loss += l
                acc += a
                for (s, p, g) in zip(self.model.param_specs(),
                                     self.model.param_values(), grads):
                    self.opt.apply_with_lr(epoch, get_lr(epoch), g, p,
                                           str(s.name))
                info = 'training loss = %f, training accuracy = %f' % (l, a)
                # put training status info into a shared queue
                if b % skip == 0:
                    dic = dict(
                        phase='train',
                        step=(epoch * num_train_batch + b) // skip,
                        accuracy=a,
                        loss=l,
                        timestamp=time.time())
                    self.queue.put(dic)
                # update progress bar
                utils.update_progress(b * 1.0 / num_train_batch, info)
            print("")
            info = 'training loss = %f, training accuracy = %f' \
                % (loss / num_train_batch, acc / num_train_batch)
            print(info)
            if epoch > 0 and epoch % 10 == 0:
                self.model.save(
                    os.path.join(data.parameter_folder,
                                 'parameter_%d' % epoch))
        self.model.save(os.path.join(data.parameter_folder, 'parameter'))
        return
def get_lr(epoch):
    '''change learning rate as epoch goes up

    Returns 0.0008 below epoch 360, 0.0001 below epoch 540 and 0.00001
    from then on.
    '''
    if epoch < 360:
        return 0.0008
    if epoch < 540:
        return 0.0001
    return 0.00001
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment