Skip to content

Instantly share code, notes, and snippets.

@XiangruiCAI
Last active October 11, 2016 08:07
Show Gist options
  • Save XiangruiCAI/a38ae090eecefae4ac975da43ab1365b to your computer and use it in GitHub Desktop.
Save XiangruiCAI/a38ae090eecefae4ac975da43ab1365b to your computer and use it in GitHub Desktop.
rafiki-cifar-example
  • This is an example of SINGA CNN model.
  • Training cifar dataset with VGG model.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import os, sys, shutil
import urllib
import cPickle
import numpy as np
# Folder holding the CIFAR-10 download, the extracted batches and the mean file.
data_folder = "data_"
# Source URL and local archive name of the CIFAR-10 (python version) dataset.
tar_data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
tar_data_name = 'cifar-10-python.tar.gz'
# Sub-directory of data_folder that contains the batch files after extraction.
data_path = 'cifar-10-batches-py'
# Folder for model parameter files (downloaded for serving or saved by training).
parameter_folder = "parameter_"
# Name whose presence in parameter_folder marks the parameters as available.
parameter_name = "parameter"
# Parameter archive downloaded for serving (presumably pre-trained weights).
tar_parameter_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/parameter.tar.gz"
tar_parameter_name = 'parameter.tar.gz'
# Mean image of the training set; it is subtracted from inputs before use.
mean_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/train.mean.npy"
mean_name = "train.mean.npy"
def load_dataset(filepath):
    '''Read one pickled CIFAR-10 batch file.

    Returns a tuple ``(image, label)``: ``image`` is a uint8 array reshaped
    to (-1, 3, 32, 32) and ``label`` is a uint8 column vector of shape (N, 1).
    '''
    print('Loading data file %s' % filepath)
    with open(filepath, 'rb') as handle:
        batch = cPickle.load(handle)
    pixels = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    targets = np.asarray(batch['labels'], dtype=np.uint8)
    return pixels, targets.reshape(targets.size, 1)
def load_train_data(num_batches=5):
    '''Load the first ``num_batches`` CIFAR-10 training batches.

    Batch files ``data_batch_1`` .. ``data_batch_<num_batches>`` are read from
    ``data_folder/data_path``; each is assumed to hold 10000 images.
    Returns float32 images of shape (num_batches * 10000, 3, 32, 32) and
    int32 labels of shape (num_batches * 10000, 1).
    '''
    per_batch = 10000
    images = np.empty((num_batches * per_batch, 3, 32, 32), dtype=np.uint8)
    labels = []
    for slot in range(num_batches):
        batch_file = os.path.join(data_folder, data_path,
                                  "data_batch_{}".format(slot + 1))
        batch_images, batch_labels = load_dataset(batch_file)
        images[slot * per_batch:(slot + 1) * per_batch] = batch_images
        labels.extend(batch_labels)
    return (np.array(images, dtype=np.float32),
            np.array(labels, dtype=np.int32))
def load_test_data():
    '''Load the CIFAR-10 ``test_batch`` as (float32 images, int32 labels).'''
    test_file = os.path.join(data_folder, data_path, "test_batch")
    images, labels = load_dataset(test_file)
    images = np.array(images, dtype=np.float32)
    labels = np.array(labels, dtype=np.int32)
    return images, labels
def load_mean_data():
    '''Return the mean array saved in ``data_folder``, or None if it is absent.'''
    mean_path = os.path.join(data_folder, mean_name)
    if not os.path.exists(mean_path):
        return None
    return np.load(mean_path)
def save_mean_data(mean):
    '''Write ``mean`` to ``data_folder/mean_name`` with ``np.save``.'''
    np.save(os.path.join(data_folder, mean_name), mean)
def train_file_prepare():
    '''Download and extract the CIFAR-10 archive unless already extracted.

    Does nothing when ``data_folder/data_path`` already exists.
    '''
    extracted = os.path.join(data_folder, data_path)
    if not os.path.exists(extracted):
        print("download file")
        # TODO: the downloaded archive is kept in data_folder after extraction.
        download_file(tar_data_url, data_folder)
        untar_data(os.path.join(data_folder, tar_data_name), data_folder)
def serve_file_prepare():
    '''Fetch the serving assets that are missing locally.

    Downloads and extracts the parameter archive when
    ``parameter_folder/parameter_name`` does not exist, then downloads the
    mean file when ``data_folder/mean_name`` does not exist.
    '''
    # NOTE(review): get_parameter() looks for '*.model' files; confirm the
    # archive really extracts a path named exactly `parameter_name`, or this
    # re-downloads on every start.
    have_params = os.path.exists(os.path.join(parameter_folder, parameter_name))
    if not have_params:
        print("download parameter file")
        download_file(tar_parameter_url, parameter_folder)
        archive = os.path.join(parameter_folder, tar_parameter_name)
        untar_data(archive, parameter_folder)
    have_mean = os.path.exists(os.path.join(data_folder, mean_name))
    if not have_mean:
        print("download mean file")
        download_file(mean_url, data_folder)
def download_file(url, dest):
    '''Download ``url`` into directory ``dest``, creating ``dest`` if needed.

    The local file name is the last '/'-separated segment of the URL.
    URLs that do not start with 'http' are ignored (nothing is downloaded).
    '''
    if not os.path.exists(dest):
        os.makedirs(dest)
    if not url.startswith('http'):
        return
    target = os.path.join(dest, url.split('/')[-1])
    urllib.urlretrieve(url, target)
def get_parameter(file_name=None, auto_find=False, folder=None):
    '''Resolve the path of a parameter file, or return None.

    Args:
        file_name: explicit parameter file name; when given, its path inside
            the folder is returned without checking that it exists.
        auto_find: when no ``file_name`` is given, look for ``*.model`` files
            and return the path (without the ``.model`` suffix) of the one
            that sorts last in natural order, e.g. ``parameter_130`` after
            ``parameter_20``.
        folder: directory to use; defaults to the module-level
            ``parameter_folder``.

    Returns:
        The resolved path. Returns None when the folder did not exist (it is
        created in that case), when neither ``file_name`` nor ``auto_find``
        was given, or when no ``*.model`` file was found.
    '''
    import re

    if folder is None:
        folder = parameter_folder
    if not os.path.exists(folder):
        os.makedirs(folder)
        return None
    if file_name:
        return os.path.join(folder, file_name)
    if not auto_find:
        return None

    suffix = ".model"

    def natural_key(name):
        # Plain string sorting puts 'parameter_90' after 'parameter_130';
        # compare the digit runs (odd positions of the split) as integers.
        pieces = re.split(r'(\d+)', name)
        return ([int(p) if i % 2 else p for i, p in enumerate(pieces)], name)

    stems = [f[:-len(suffix)] for f in os.listdir(folder) if f.endswith(suffix)]
    if not stems:
        return None
    return os.path.join(folder, max(stems, key=natural_key))
def untar_data(file_path, dest):
    '''Extract the tar archive at ``file_path`` into directory ``dest``.

    The archive is closed even if extraction raises.
    '''
    import tarfile
    print('untar data ..................')
    tar = tarfile.open(file_path)
    try:
        print(dest)
        print(file_path)
        # NOTE(review): the archives handled here are downloaded from remote
        # URLs and extractall() trusts member paths, so a malicious archive
        # could write outside `dest`.
        tar.extractall(dest)
    finally:
        tar.close()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from multiprocessing import Process
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS, cross_origin
# Flask application shared by both modes; start_monitor()/start_serve() set
# the module-level `type_` that selects between "monitor" and "serve".
app = Flask(__name__)
# Default number of most recent records returned by /getTopKData.
top_k_ = 5
def success(data=""):
    '''Wrap ``data`` in a JSON response: {"result": "success", "data": data}.'''
    return jsonify(dict(result="success", data=data))
def failure(message):
    '''Return a JSON error response: {"result": "failure", "message": message}.

    The status value mirrors success(), which uses "success"; it was
    previously the literal "message" by mistake.
    '''
    return jsonify(dict(result="failure", message=message))
def start_monitor(port, queue):
    '''Run the Flask app in "monitor" mode on 0.0.0.0:``port``.

    Args:
        port: port to listen on.
        queue: queue carrying training status records; records drained from
            it are buffered in the module-level ``data_`` list.
    '''
    global queue_, data_, type_
    type_ = "monitor"
    queue_ = queue
    data_ = []
    app.run(host='0.0.0.0', port=port)
def start_serve(port, service):
    '''Run the Flask app in "serve" mode on 0.0.0.0:``port``.

    ``service`` must provide ``serve(request)``; /predict delegates to it.
    '''
    global type_, service_
    type_, service_ = "serve", service
    app.run(host='0.0.0.0', port=port)
def getDataFromQueue():
    '''Move every record currently available in ``queue_`` into ``data_``.

    Uses non-blocking reads so a request thread never waits on an empty
    queue (with empty()+get(), a concurrent drain could leave get() blocked).
    '''
    try:
        from queue import Empty  # Python 3
    except ImportError:
        from Queue import Empty  # Python 2
    # The buffer is `data_`; the global statement previously named `data`.
    global queue_, data_
    while True:
        try:
            record = queue_.get_nowait()
        except Empty:
            break
        data_.append(record)
@app.route("/")
@cross_origin()
def index():
    '''Root page: a greeting in monitor mode, otherwise ./index.html.'''
    global type_
    print(type_)
    if type_ != "monitor":
        return send_from_directory(".", "index.html", mimetype='text/html')
    return "Hello,This is SINGA monitor http server"
# Predict the label of an uploaded image.
@app.route("/predict", methods=['POST'])
@cross_origin()
def predict():
    '''Classify the uploaded image with the serve-mode ``service_``.

    Returns the service's response. If the service raises, the error is
    printed and reported as a JSON failure with HTTP status 500.
    '''
    global type_, service_
    if type_ == "monitor":
        return failure("not available in monitor mode")
    # The route only accepts POST, so no request.method check is needed.
    try:
        response = service_.serve(request)
    except Exception as e:
        print(str(e))
        # A view must not return an exception object; Flask cannot turn it
        # into a response. Report an explicit server error instead.
        return failure(str(e)), 500
    return response
# Two endpoints let the user monitor the training status:
# /getAllData and /getTopKData.
@app.route('/getAllData')
@cross_origin()
def getAllData():
    '''Return every training status record received so far.

    Only available in monitor mode; in serve mode a failure is returned.
    '''
    global data_, type_
    if type_ != "serve":
        getDataFromQueue()
        return success(data_)
    return failure("not available in serve mode")
@app.route('/getTopKData')
@cross_origin()
def getTopKData():
    '''Return the ``k`` most recent status records (monitor mode only).

    Query args:
        k: number of records to return, default ``top_k_``; must be a
           non-negative integer.
    '''
    global data_, type_
    if type_ == "serve":
        return failure("not available in serve mode")
    k = request.args.get("k", top_k_)
    try:
        k = int(k)
    except (TypeError, ValueError):
        return failure("k should be integer")
    if k < 0:
        # A negative k would slice from the front (data_[-k:] == data_[3:]).
        return failure("k should be a non-negative integer")
    getDataFromQueue()
    # data_[-0:] is the whole list, so k == 0 must be special-cased.
    return success(data_[-k:] if k > 0 else [])
<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!-->
<html class="no-js">
<!--<![endif]-->
<head>
<meta charset="utf-8">
<!-- NOTE(review): this title looks left over from another project; the page heading is "CIFAR PREDICTION". -->
<title>Foodlg - Insight into your titbits</title>
<meta name="description" content="">
<!-- Loaded over https so browsers do not block it as mixed content on https pages. -->
<script type="text/javascript" src="https://code.jquery.com/jquery-1.12.4.min.js"></script>
</head>
<body>
<h2>CIFAR PREDICTION</h2>
<form>
<p>Please upload a jpg file!</p>
<!-- <input> is a void element and takes no closing tag. -->
<input id="file-input" type="file" accept="image/*;capture=camera">
</form>
<button id="go">GO</button>
<br/>
<img id="image" src=""/>
<div id="result"></div>
</body>
<script type="text/javascript">
// On "GO": read the chosen file, show a preview and send it for prediction.
$("#go").click(function(){
    var f = $("#file-input")[0].files[0];
    // Without a chosen file, FileReader.readAsDataURL(undefined) throws.
    if (!f) {
        $("#result").html("Please choose an image first.");
        return;
    }
    ReadFile(f, function(result){
        var file = DataURItoBlob(result);
        $("#image").attr("src", result);
        predict_dish(file);
    });
});
// POST the image blob to /predict and display the server's reply in #result.
function predict_dish(file){
    var payload = new FormData();
    // The server only accepts *.jpg / *.jpeg names, hence the fixed file name.
    payload.append('image', file, "image.jpg");
    $.ajax({
        url: "/predict",
        type: "POST",
        data: payload,
        processData: false, // send the FormData untouched
        contentType: false, // let the browser set the multipart boundary
        success: function(reply){
            $("#result").html(reply);
        },
        error: function(err){
            console.log(err);
            $("#result").html("Error Occurs!");
        }
    });
}
// Read `file` as a data URL, then hand it to ProcessFile together with
// `callback`; alert on read errors.
var ReadFile = function(file, callback) {
    var reader = new FileReader();
    // `onload` fires only on success. `onloadend` also fires after an error,
    // which used to pass a null result on to ProcessFile.
    reader.onload = function () {
        ProcessFile(reader.result, file.type, callback);
    };
    reader.onerror = function () {
        alert('There was an error reading the file!');
    };
    reader.readAsDataURL(file);
};
// Pass `dataURL` to `callback`, first downscaling the image to fit inside
// 400x400 (keeping its aspect ratio) when it is larger than that.
var ProcessFile = function(dataURL, fileType, callback) {
    var maxWidth = 400;
    var maxHeight = 400;
    var image = new Image();
    image.onload = function () {
        var width = image.width;
        var height = image.height;
        if (width <= maxWidth && height <= maxHeight) {
            callback(dataURL);
            return;
        }
        var landscape = width > height;
        var newWidth = landscape ? maxWidth : width * (maxHeight / height);
        var newHeight = landscape ? height * (maxWidth / width) : maxHeight;
        var canvas = document.createElement('canvas');
        canvas.width = newWidth;
        canvas.height = newHeight;
        canvas.getContext('2d').drawImage(image, 0, 0, newWidth, newHeight);
        callback(canvas.toDataURL(fileType));
    };
    image.onerror = function () {
        alert('There was an error processing your file!');
    };
    image.src = dataURL;
};
// Convert a base64 data URI into a Blob carrying the URI's MIME type.
// URL-encoded (non-base64) data URIs are not supported.
var DataURItoBlob = function(dataURI) {
    var parts = dataURI.split(',');
    // decode the base64 payload into a binary string
    var byteString = atob(parts[1]);
    // "data:<mime>;base64" -> "<mime>"
    var mimeString = parts[0].split(':')[1].split(';')[0];
    var buffer = new ArrayBuffer(byteString.length);
    var bytes = new Uint8Array(buffer);
    for (var i = 0; i < byteString.length; i++) {
        bytes[i] = byteString.charCodeAt(i);
    }
    try {
        return new Blob([buffer], {type: mimeString});
    } catch (e) {
        // Older browsers lack the Blob constructor; fall back to the
        // deprecated vendor-prefixed BlobBuilder API.
        var BlobBuilder = window.WebKitBlobBuilder || window.MozBlobBuilder;
        var builder = new BlobBuilder();
        builder.append(buffer);
        return builder.getBlob(mimeString);
    }
};
</script>
</html>
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from singa import layer
from singa import metric
from singa import loss
from singa import net as ffnet
def add_layer_group(net, name, nb_filers, sample_shape=None):
    '''Append one VGG-style layer group to ``net``.

    Adds, in order: Conv2D ``<name>_1``, Activation ``<name>act_1``,
    Conv2D ``<name>_2``, Activation ``<name>act_2`` and MaxPooling2D
    ``<name>``. Only the first convolution receives ``sample_shape`` as its
    ``input_sample_shape``.
    '''
    for idx in (1, 2):
        conv_kwargs = {'pad': 1}
        if idx == 1:
            conv_kwargs['input_sample_shape'] = sample_shape
        net.add(layer.Conv2D('%s_%d' % (name, idx), nb_filers, 3, 1,
                             **conv_kwargs))
        net.add(layer.Activation('%sact_%d' % (name, idx)))
    net.add(layer.MaxPooling2D(name, 2, 2, pad=0))
def create(use_cpu=False):
    '''Build the VGG-style network used for CIFAR-10.

    Five layer groups (64, 128, 256, 512, 512 filters; the first with input
    shape (3, 32, 32)) are followed by flatten -> dense(512) -> activation
    -> dropout -> dense(10). The net uses softmax cross-entropy loss and an
    accuracy metric.

    Args:
        use_cpu: when true, set ``layer.engine`` to 'singacpp' before any
            layer is created.
    '''
    if use_cpu:
        layer.engine = 'singacpp'
    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
    groups = [('conv1', 64), ('conv2', 128), ('conv3', 256),
              ('conv4', 512), ('conv5', 512)]
    for position, (group_name, filters) in enumerate(groups):
        input_shape = (3, 32, 32) if position == 0 else None
        add_layer_group(net, group_name, filters, input_shape)
    net.add(layer.Flatten('flat'))
    net.add(layer.Dense('ip1', 512))
    net.add(layer.Activation('relu_ip1'))
    net.add(layer.Dropout('drop1'))
    net.add(layer.Dense('ip2', 10))
    return net
#!/usr/bin/env python
#/************************************************************
#*
#* Licensed to the Apache Software Foundation (ASF) under one
#* or more contributor license agreements. See the NOTICE file
#* distributed with this work for additional information
#* regarding copyright ownership. The ASF licenses this file
#* to you under the Apache License, Version 2.0 (the
#* "License"); you may not use this file except in compliance
#* with the License. You may obtain a copy of the License at
#*
#* http://www.apache.org/licenses/LICENSE-2.0
#*
#* Unless required by applicable law or agreed to in writing,
#* software distributed under the License is distributed on an
#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#* KIND, either express or implied. See the License for the
#* specific language governing permissions and limitations
#* under the License.
#*
#*************************************************************/
#**************
#* Dependencies: sudo apt-get install libjpeg-dev
#*               sudo pip install pillow
from PIL import Image
import sys, glob, os, random, shutil, time
import numpy as np
def do_resize(img, small_size):
    '''Resize ``img`` so that its shorter side equals ``small_size``.

    The longer side is scaled by the same factor and truncated to an int.
    Returns the resized image.
    '''
    width, height = img.size
    if width < height:
        target = (small_size, int(small_size * height / width))
    else:
        target = (int(small_size * width / height), small_size)
    return img.resize(target)
def do_crop(img, crop, position):
    '''Cut a ``crop[0]`` x ``crop[1]`` window out of ``img``.

    Args:
        img: image exposing ``size`` as (width, height) and ``crop(box)``.
        crop: (width, height) of the window.
        position: one of 'left_top', 'left_bottom', 'right_top',
            'right_bottom' or 'center'.

    Returns:
        The cropped image.

    Raises:
        Exception: if the image is smaller than ``crop`` in either dimension.
        ValueError: if ``position`` is not a supported name.
    '''
    # Read the dimensions from img.size: a PIL image is not subscriptable,
    # so the old `img[0]` in the messages raised TypeError instead.
    width, height = img.size
    if width < crop[0]:
        raise Exception('img size[0] %d is smaller than crop[0]: %d' %
                        (width, crop[0]))
    if height < crop[1]:
        raise Exception('img size[1] %d is smaller than crop[1]: %d' %
                        (height, crop[1]))
    spare_x = width - crop[0]
    spare_y = height - crop[1]
    offsets = {
        'left_top': (0, 0),
        'left_bottom': (0, spare_y),
        'right_top': (spare_x, 0),
        'right_bottom': (spare_x, spare_y),
        # floor division keeps the box integral on Python 3 as well
        'center': (spare_x // 2, spare_y // 2),
    }
    try:
        left, upper = offsets[position]
    except KeyError:
        raise ValueError('unknown crop position: %r' % (position,))
    box = (left, upper, left + crop[0], upper + crop[1])
    return img.crop(box)
def do_flip(img):
    '''Return ``img`` mirrored horizontally (left <-> right).'''
    return img.transpose(Image.FLIP_LEFT_RIGHT)
def load_img(path, grayscale=False):
    '''Open an image with PIL and convert it to mode 'L' or 'RGB'.

    Converting to 'RGB' ensures three channels even when the loaded image
    is grayscale.
    '''
    from PIL import Image
    mode = 'L' if grayscale else 'RGB'
    return Image.open(path).convert(mode)
def process_img(img, small_size, size, is_aug):
    '''Load ``img``, resize it and cut it into channel-first uint8 arrays.

    The image is resized so its shorter side is ``small_size`` and then
    cropped to ``size``. With ``is_aug`` the four corner crops and the
    centre crop are each followed by their horizontal flip (10 arrays);
    otherwise only the centre crop is returned. Each array is the RGB crop
    transposed with (2, 0, 1).
    '''
    resized = do_resize(load_img(img), small_size)
    if is_aug:
        positions = ["left_top", "left_bottom", "right_top", "right_bottom",
                     "center"]
    else:
        positions = ["center"]

    def to_chw(picture):
        return np.array(picture.convert("RGB")).transpose(2, 0, 1)

    arrays = []
    for position in positions:
        cropped = do_crop(resized, size, position)
        assert cropped.size == size
        arrays.append(to_chw(cropped))
        if is_aug:
            arrays.append(to_chw(do_flip(cropped)))
    return arrays
def unpickle(file):
    '''Load and return the object pickled in the file at path ``file``.

    The file is closed even if unpickling fails. Uses ``cPickle`` when it is
    available (Python 2) and falls back to ``pickle`` otherwise.
    '''
    try:
        import cPickle as pickle
    except ImportError:
        import pickle
    # NOTE(review): unpickling can execute arbitrary code; only call this on
    # trusted files.
    with open(file, 'rb') as handle:
        return pickle.load(handle)
flask>=0.10.1
flask_cors>=3.0.2
pillow>=2.3.0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from multiprocessing import Process, Queue
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
import sys, os, traceback
import flaskserver
import model
from service import Service
sys.path.append(os.getcwd())
def main(argv=None):
    '''Command line entry point: load the model and serve predictions.

    Args:
        argv: extra arguments appended to ``sys.argv`` before parsing; when
            None, ``sys.argv`` is parsed as is.

    Returns:
        None after the server stops normally, 2 if an error occurred.
    '''
    if argv is not None:
        sys.argv.extend(argv)
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            type=int,  # a string port from the command line breaks app.run()
            default=9999,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()
        m = model.create(args.use_cpu)
        service = Service(m, args.use_cpu)
        print(args.parameter)
        service.initialize(args.parameter)
        flaskserver.start_serve(args.port, service)
    except Exception:
        # Exception, not a bare except: argparse's SystemExit (e.g. --help)
        # must still exit normally instead of being reported as a failure.
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
if __name__ == '__main__':
    # Propagate main()'s return value (2 on error) as the process exit code.
    sys.exit(main())
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys, glob, os, random, shutil, time
import numpy as np
import urllib, traceback
from singa import tensor, device, optimizer
from singa import utils, initializer, metric
from singa.proto import core_pb2
import data
import process
# Number of highest-scoring labels listed in each prediction response.
top_k = 5
class Service():
    '''Serve predictions from a SINGA model for uploaded JPEG images.'''

    def __init__(self, model, use_cpu):
        '''Keep ``model`` and select the device.

        Args:
            model: network built by ``model.create``.
            use_cpu: use the default device instead of a CUDA GPU.
        '''
        self.model = model
        if use_cpu:
            print("running with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        # NOTE(review): self.opt is not used anywhere in this class.
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)

    def initialize(self, parameter_file=None):
        '''Load model parameters and the training mean for prediction.

        Missing serving assets are downloaded first. When ``parameter_file``
        is None, ``data.get_parameter`` is asked to auto-find one. The model
        is then moved to ``self.device``.
        '''
        data.serve_file_prepare()
        print('Start intialization............')
        parameter = data.get_parameter(parameter_file, True)
        print('initialize with %s' % parameter)
        self.model.load(parameter)
        self.model.to_device(self.device)
        print('End intialization............')
        self.mean = data.load_mean_data()

    def serve(self, request):
        '''Predict labels for the image uploaded under form field 'image'.

        Returns an HTML fragment listing the ``top_k`` labels with their
        averaged model outputs, or an error string.
        '''
        # .get() so a missing field reaches the check below instead of
        # raising a lookup error before it.
        image = request.files.get('image')
        if not image:
            return "error, no image file found!"
        if not allowed_file(image.filename):
            return "error, only jpg image is allowed."
        try:
            # augmented 32x32 crops of the image resized to 36 on its short side
            crops = process.process_img(image, 36, (32, 32), True)
            images = np.array(crops[0:10]).astype(np.float32)
            # normalize with the training mean
            images -= self.mean
            x = tensor.from_numpy(images.astype(np.float32))
            x.to_device(self.device)
            y = self.model.predict(x)
            y.to_host()
            y = tensor.to_numpy(y)
            # average over crops, then rank labels from highest to lowest
            prob = np.average(y, 0)
            labels = np.flipud(np.argsort(prob))
            return "".join("%s:%s<br/>" % (get_name(label), prob[label])
                           for label in labels[:top_k])
        except Exception as e:
            traceback.print_exc()
            print(e)
            return "sorry, system error."
def get_lr(epoch):
    '''Learning-rate schedule: 0.0008 before epoch 360, 0.0001 before epoch
    540 and 0.00001 afterwards.'''
    # NOTE(review): not referenced by Service; same schedule as trainer.get_lr.
    for limit, rate in ((360, 0.0008), (540, 0.0001)):
        if epoch < limit:
            return rate
    return 0.00001
def allowed_file(filename):
    '''Return True if ``filename`` ends in .jpg or .jpeg, in any letter case.

    Empty or missing names are rejected instead of raising.
    '''
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in ("jpg", "jpeg")
# CIFAR-10 class index -> class name.
label_map = dict(enumerate([
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]))


def get_name(label):
    '''Return the class name for class index ``label``.'''
    return label_map[label]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from multiprocessing import Process, Queue
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
import sys, os, traceback
import flaskserver
import model
from trainer import Trainer
sys.path.append(os.getcwd())
def main(argv=None):
    '''Command line entry point: train the model while running a monitor.

    The monitor HTTP server runs in a child process and receives training
    status records through a multiprocessing queue. The child is terminated
    when training finishes or fails.

    Args:
        argv: extra arguments appended to ``sys.argv`` before parsing.

    Returns:
        None on success, 2 if an error occurred.
    '''
    if argv is not None:
        sys.argv.extend(argv)
    # Stays None if we fail before the monitor is started; the old except
    # branch then hit a NameError on `p`.
    p = None
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            type=int,  # a string port from the command line breaks app.run()
            default=9999,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()
        # start monitor server; the queue carries training status information
        queue = Queue()
        p = Process(target=flaskserver.start_monitor, args=(args.port, queue))
        p.start()
        # start to train
        m = model.create(args.use_cpu)
        trainer = Trainer(m, args.use_cpu, queue)
        trainer.initialize(args.parameter)
        trainer.train()
    except Exception:
        # Exception, not a bare except: argparse's SystemExit (e.g. --help)
        # must still exit normally.
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
    finally:
        if p is not None:
            p.terminate()
if __name__ == '__main__':
    # Propagate main()'s return value (2 on error) as the process exit code.
    sys.exit(main())
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
import sys, glob, os, random, shutil, time
import numpy as np
import urllib
from singa import tensor, device, optimizer
from singa import utils, initializer, metric
from singa.proto import core_pb2
import data
class Trainer():
    '''train a singa model'''

    def __init__(self, model, use_cpu, queue):
        '''
        Args:
            model: network built by ``model.create``.
            use_cpu: use the default device instead of a CUDA GPU.
            queue: queue that receives training/testing status dicts.
        '''
        self.model = model
        if use_cpu:
            print("runing with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.queue = queue

    def initialize(self, parameter_file):
        '''Initialize all parameters, then move the model to ``self.device``.

        With ``parameter_file`` the parameters are loaded from
        ``data.get_parameter(parameter_file)``. Otherwise each parameter is
        filled according to its spec's filler ('gaussian', 'xavier' scaled by
        0.5, or zeros) and registered with the optimizer.
        '''
        print('Start intialization............')
        if parameter_file:
            parameter = data.get_parameter(parameter_file)
            print('initialize with %s' % parameter)
            # NOTE(review): this branch never calls self.opt.register();
            # confirm the optimizer does not need the specs.
            self.model.load(parameter)
        else:
            for (p, specs) in zip(self.model.param_values(),
                                  self.model.param_specs()):
                filler = specs.filler
                if filler.type == 'gaussian':
                    initializer.gaussian(p, filler.mean, filler.std)
                elif filler.type == 'xavier':
                    initializer.xavier(p)
                    p *= 0.5  # 0.5 if use glorot, which would have val acc to 83
                else:
                    p.set_value(0)
                self.opt.register(p, specs)
                print('%s %s %s' % (specs.name, filler.type, p.l1()))
        self.model.to_device(self.device)
        print('End intialization............')

    def data_prepare(self):
        '''Download (if needed) and load CIFAR-10, then subtract the mean.

        The mean is loaded from disk, or computed from the training images
        and saved when no mean file exists yet.
        '''
        data.train_file_prepare()
        self.train_x, self.train_y = data.load_train_data()
        self.test_x, self.test_y = data.load_test_data()
        self.mean = data.load_mean_data()
        if self.mean is None:
            self.mean = np.average(self.train_x, axis=0)
            data.save_mean_data(self.mean)
        self.train_x -= self.mean
        self.test_x -= self.mean

    def _evaluate(self, tx, ty, batch_size, num_batches):
        '''Return the mean (loss, accuracy) over ``num_batches`` test batches.

        ``tx``/``ty`` are reusable device tensors the batches are copied into.
        '''
        loss, acc = 0.0, 0.0
        for b in range(num_batches):
            lo, hi = b * batch_size, (b + 1) * batch_size
            tx.copy_from_numpy(self.test_x[lo:hi])
            ty.copy_from_numpy(self.test_y[lo:hi])
            l, a = self.model.evaluate(tx, ty)
            loss += l
            acc += a
        return loss / num_batches, acc / num_batches

    def train(self, num_epoch=140, batch_size=50):
        '''Train and test the model.

        Each epoch first evaluates on the test set, then makes one pass over
        the shuffled training set. Status dicts (phase, step, accuracy, loss,
        timestamp) are put on ``self.queue``: one per epoch for testing and
        one every ``skip`` training batches. Parameters are saved every 10
        epochs (from epoch 10 on) and once more after the last epoch.
        '''
        self.data_prepare()
        print('training shape %s %s' % (self.train_x.shape, self.train_y.shape))
        print('validation shape %s %s' % (self.test_x.shape, self.test_y.shape))
        tx = tensor.Tensor((batch_size, 3, 32, 32), self.device)
        ty = tensor.Tensor((batch_size, ), self.device, core_pb2.kInt)
        # Floor division: samples that do not fill a whole batch are skipped.
        num_train_batch = self.train_x.shape[0] // batch_size
        num_test_batch = self.test_x.shape[0] // batch_size
        idx = np.arange(self.train_x.shape[0], dtype=np.int32)
        # frequency (in training batches) of gathering training status info
        skip = 20
        for epoch in range(num_epoch):
            np.random.shuffle(idx)
            print('Epoch %d' % epoch)
            test_loss, test_acc = self._evaluate(tx, ty, batch_size,
                                                 num_test_batch)
            print('testing loss = %f, accuracy = %f' % (test_loss, test_acc))
            # put test status info into the shared queue
            self.queue.put(dict(
                phase='test',
                step=epoch * num_train_batch // skip,
                accuracy=test_acc,
                loss=test_loss,
                timestamp=time.time()))
            # Fresh accumulators for the training pass. They used to still
            # hold the test sums, inflating the reported training averages.
            loss, acc = 0.0, 0.0
            for b in range(num_train_batch):
                batch_idx = idx[b * batch_size:(b + 1) * batch_size]
                tx.copy_from_numpy(self.train_x[batch_idx])
                ty.copy_from_numpy(self.train_y[batch_idx])
                grads, (l, a) = self.model.train(tx, ty)
                loss += l
                acc += a
                for (s, p, g) in zip(self.model.param_specs(),
                                     self.model.param_values(), grads):
                    self.opt.apply_with_lr(epoch, get_lr(epoch), g, p,
                                           str(s.name))
                info = 'training loss = %f, training accuracy = %f' % (l, a)
                # put training status info into the shared queue
                if b % skip == 0:
                    self.queue.put(dict(
                        phase='train',
                        step=(epoch * num_train_batch + b) // skip,
                        accuracy=a,
                        loss=l,
                        timestamp=time.time()))
                # update progress bar
                utils.update_progress(b * 1.0 / num_train_batch, info)
            print("")
            info = 'training loss = %f, training accuracy = %f' \
                % (loss / num_train_batch, acc / num_train_batch)
            print(info)
            if epoch > 0 and epoch % 10 == 0:
                self.model.save(
                    os.path.join(data.parameter_folder, 'parameter_%d' % epoch))
        self.model.save(os.path.join(data.parameter_folder, 'parameter'))
def get_lr(epoch):
    '''Return the learning rate for ``epoch``: 0.0008 below epoch 360,
    0.0001 below epoch 540 and 0.00001 from then on.'''
    # NOTE(review): Trainer.train defaults to 140 epochs, so only the first
    # rate is reached unless num_epoch is raised.
    if epoch >= 540:
        return 0.00001
    if epoch >= 360:
        return 0.0001
    return 0.0008
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment