- This is an example of a SINGA CNN model.
- It trains a VGG model on the CIFAR-10 dataset.
-
-
Save aaronwwf/3231cbf85cd93558cd47907ff5561385 to your computer and use it in GitHub Desktop.
rafiki-cifar-example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
.project | |
.pydevproject | |
data_ | |
parameter_ | |
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import os, sys, shutil | |
import urllib | |
import cPickle | |
import numpy as np | |
# Local folder holding the extracted CIFAR-10 batches and the mean file.
data_folder = "data_"

# CIFAR-10 (python version) archive; the local file name is the last URL
# component, which is also how download_file() names what it saves.
tar_data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
tar_data_name = tar_data_url.rsplit('/', 1)[-1]
# Folder created inside data_folder when the archive is extracted.
data_path = 'cifar-10-batches-py'

# Folder for model checkpoints and the pretrained-parameter archive.
parameter_folder = "parameter_"
# NOTE(review): not referenced in this file.
parameter_name = "parameter"
# NOTE(review): the pretrained parameters and the mean file are fetched over
# plain http.
tar_parameter_url = ("http://comp.nus.edu.sg/~dbsystem/singa/assets/file/"
                     "cifar/parameter.tar.gz")
tar_parameter_name = tar_parameter_url.rsplit('/', 1)[-1]

# Training-set mean that is subtracted from every input image.
mean_url = "http://comp.nus.edu.sg/~dbsystem/singa/assets/file/train.mean.npy"
mean_name = mean_url.rsplit('/', 1)[-1]
def load_dataset(filepath):
    '''Load one pickled CIFAR-10 batch file.

    Returns:
        (image, label): ``image`` is the batch's 'data' array as uint8
        reshaped to (-1, 3, 32, 32); ``label`` is the 'labels' list as a
        uint8 column vector of shape (N, 1).

    NOTE: unpickling runs code embedded in the file; only load trusted
    batch files.
    '''
    print('Loading data file %s' % filepath)
    with open(filepath, 'rb') as fd:
        batch = cPickle.load(fd)
    pixels = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    targets = np.asarray(batch['labels'], dtype=np.uint8)
    return pixels, targets.reshape(targets.size, 1)
def load_train_data(num_batches=5):
    '''Load CIFAR-10 training batches ``data_batch_1`` .. ``data_batch_<num_batches>``.

    Each batch is assumed to hold exactly 10000 images.

    Returns:
        (images, labels): images as float32 of shape
        (num_batches * 10000, 3, 32, 32); labels as int32 built from the
        per-batch (N, 1) label arrays.
    '''
    batchsize = 10000
    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
    labels = []
    for batch_no in range(num_batches):
        batch_file = os.path.join(data_folder, data_path,
                                  "data_batch_{}".format(batch_no + 1))
        batch_images, batch_labels = load_dataset(batch_file)
        start = batch_no * batchsize
        images[start:start + batchsize] = batch_images
        labels.extend(batch_labels)
    return (np.array(images, dtype=np.float32),
            np.array(labels, dtype=np.int32))
def load_test_data():
    '''Load the CIFAR-10 ``test_batch`` file.

    Returns (images, labels) converted to float32 and int32 respectively.
    '''
    test_file = os.path.join(data_folder, data_path, "test_batch")
    images, labels = load_dataset(test_file)
    return images.astype(np.float32), labels.astype(np.int32)
def load_mean_data():
    '''Return the saved mean array from ``data_folder/mean_name``, or None if absent.'''
    mean_path = os.path.join(data_folder, mean_name)
    return np.load(mean_path) if os.path.exists(mean_path) else None
def save_mean_data(mean):
    '''Write ``mean`` with numpy.save to ``data_folder/mean_name``; returns None.'''
    np.save(os.path.join(data_folder, mean_name), mean)
def train_file_prepare():
    '''Download and extract the CIFAR-10 training data if it is missing.

    Also makes sure ``parameter_folder`` exists, because training writes its
    checkpoints there.
    '''
    # Create the checkpoint folder first: the early return below used to
    # skip it whenever the dataset was already present, so a later
    # model.save() into the missing folder failed.
    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
    if os.path.exists(os.path.join(data_folder, data_path)):
        return
    print("download file")
    download_file(tar_data_url, data_folder)
    untar_data(os.path.join(data_folder, tar_data_name), data_folder)
def serve_file_prepare():
    '''Fetch what serving needs: the pretrained-parameter archive and the mean file.

    The parameter archive is downloaded and extracted into
    ``parameter_folder`` unless the archive file is already there. The mean
    file is downloaded into ``data_folder`` unless it already exists.
    '''
    parameter_tar = os.path.join(parameter_folder, tar_parameter_name)
    if not os.path.exists(parameter_tar):
        print("download parameter file")
        download_file(tar_parameter_url, parameter_folder)
        untar_data(parameter_tar, parameter_folder)

    if os.path.exists(os.path.join(data_folder, mean_name)):
        return
    print("download mean file")
    download_file(mean_url, data_folder)
def download_file(url, dest):
    '''Download ``url`` into directory ``dest``, creating ``dest`` if needed.

    Only URLs starting with 'http' are fetched; any other URL is silently
    ignored. The saved file is named after the last '/'-separated component
    of the URL.

    Data is first written to ``<target>.part`` and renamed into place only
    after the transfer finishes. Callers decide whether to download by
    checking whether the final file exists, so an interrupted download must
    never leave a truncated file under that name.

    Returns:
        The path of the downloaded file, or None when nothing was fetched.
    '''
    try:
        from urllib import urlretrieve  # Python 2
    except ImportError:
        from urllib.request import urlretrieve  # Python 3
    if not os.path.exists(dest):
        os.makedirs(dest)
    if not url.startswith('http'):
        return None
    target = os.path.join(dest, url.split('/')[-1])
    partial = target + '.part'
    try:
        urlretrieve(url, partial)
    except BaseException:
        # Remove the half-written file, then let the error propagate.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    # os.rename does not overwrite an existing file on Windows.
    if os.path.exists(target):
        os.remove(target)
    os.rename(partial, target)
    return target
def _checkpoint_sort_key(path):
    '''Sort key that compares digit runs numerically ('p_130' after 'p_90').'''
    import re
    # re.split with a capturing group alternates text / digit runs, so the
    # odd positions are always digit strings and even positions text.
    parts = re.split(r'(\d+)', path)
    parts[1::2] = [int(run) for run in parts[1::2]]
    return parts


def get_parameter(file_name=None, auto_find=False):
    '''Resolve the parameter (checkpoint) path to load, or return None.

    - If ``parameter_folder`` does not exist, it is created and None is
      returned.
    - If ``file_name`` is given, it is returned joined onto
      ``parameter_folder``; existence is not checked.
    - Otherwise, when ``auto_find`` is True, the folder is scanned for
      ``*.model`` files. The last one in natural order is returned as a
      path with the '.model' suffix stripped, e.g. 'parameter_130' beats
      'parameter_90', and 'parameter_last' beats both. The old plain string
      sort would have picked 'parameter_90' over 'parameter_130'.
    '''
    if not os.path.exists(parameter_folder):
        os.makedirs(parameter_folder)
        return None
    if file_name:
        return os.path.join(parameter_folder, file_name)
    if not auto_find:
        return None
    suffix = ".model"
    candidates = [os.path.join(parameter_folder, f[:-len(suffix)])
                  for f in os.listdir(parameter_folder) if f.endswith(suffix)]
    if not candidates:
        return None
    return sorted(candidates, key=_checkpoint_sort_key)[-1]
def untar_data(file_path, dest):
    '''Extract the tar archive at ``file_path`` into directory ``dest``.

    The archives handled here are downloaded from the network, partly over
    plain http. Every member is therefore checked before extraction: any
    member whose path would resolve outside ``dest`` (absolute names, '..'
    components) makes the function raise ValueError without extracting
    anything. The archive is always closed, even when extraction fails.
    '''
    import tarfile
    print('untar data ..................')
    tar = tarfile.open(file_path)
    try:
        print(dest)
        print(file_path)
        root = os.path.realpath(dest)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if target != root and not target.startswith(root + os.sep):
                raise ValueError('refusing to extract %r outside %s'
                                 % (member.name, dest))
        tar.extractall(dest)
    finally:
        tar.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!-->
<html class="no-js">
<!--<![endif]-->
<head>
  <meta charset="utf-8">
  <title>CIFAR</title>
  <meta name="description" content="">
  <!-- jQuery is loaded over https: browsers block http scripts on pages
       served over https (mixed content), which would break the whole UI. -->
  <script type="text/javascript" src="https://code.jquery.com/jquery-1.12.4.min.js"></script>
  <style>
    body{
      text-align:center;
    }
    h2{
      text-align:center;
    }
    #result{
      font-size: 20px;
      font-weight: bold;
      color: #333;
    }
  </style>
</head>
<body>
  <h2>CIFAR 10, Image Classification Live Service</h2>
  <!--
  <button id="stop">STOP</button>
  <button id="go">Predict</button>
  -->
  <form>
    <p>Please upload a jpg file!</p>
    <!-- "image/*" is the valid form of the old "image/*;capture=camera"
         value; on mobile it still offers both camera and gallery.
         <input> is a void element, so it takes no closing tag. -->
    <input id="file-input" type="file" accept="image/*">
  </form>
  <br/>
  <!-- Preview of the uploaded (possibly downscaled) image; src is set by
       script. An empty src attribute is invalid, so none is given. -->
  <img id="image" alt=""/>
  <!-- Prediction result (HTML returned by the /api endpoint). -->
  <div id="result"></div>
</body>
<script type="text/javascript">
// When a file is chosen: downscale it client-side (ReadFile/ProcessFile),
// show the preview and send it to the prediction API.
$("#file-input").change(function(){
    var f = this.files ? this.files[0] : undefined;
    // Some browsers also fire 'change' when the selection is cleared (e.g.
    // the dialog is cancelled). There is then no file, and
    // FileReader.readAsDataURL would throw on undefined.
    if (!f) {
        return;
    }
    ReadFile(f, function(result){
        var file = DataURItoBlob(result);
        $("#image").attr("src", result);
        predict_dish(file);
    });
});
// Ask the backend to stop serving. NOTE(review): the #stop button is
// commented out in the page markup, so this handler is not attached to any
// element unless that button is restored.
$("#stop").click(function(){
    $.ajax({
        url:"/command/stop",
        type:"POST",
        processData: false, // Don't process the files
        contentType: false,
        success:function(response){
            alert("Stop Success!");
        },
        error:function(e){
            console.log(e);
            alert("Stop Failed!");
        }
    });
});
/*
 * POST the image blob to /api as the multipart field "image" (file name
 * "image.jpg") and render the reply into #result.
 * NOTE: the reply is inserted as HTML (the server formats it with <br/>),
 * so this trusts the server's response.
 */
function predict_dish(file){
    var payload = new FormData();
    payload.append('image', file, "image.jpg");

    var onSuccess = function(response){
        $("#result").html(response);
    };
    var onError = function(e){
        console.log(e);
        $("#result").html("Error Occurs!");
    };

    $.ajax({
        url: "/api",
        data: payload,
        type: "POST",
        processData: false, // send the FormData untouched
        contentType: false, // let the browser set the multipart boundary
        success: onSuccess,
        error: onError
    });
}
/*
 * Read a File as a data URL and hand it to ProcessFile, which downscales it
 * if needed and then calls `callback` with the final data URL.
 */
var ReadFile = function(file, callback) {
    var reader = new FileReader();
    // 'load' fires only when reading succeeds. The previous 'loadend'
    // handler also ran after an error, so a failed read showed the alert
    // below and then still passed a null result on to ProcessFile.
    reader.onload = function () {
        ProcessFile(reader.result, file.type, callback);
    };
    reader.onerror = function () {
        alert('There was an error reading the file!');
    };
    reader.readAsDataURL(file);
};
/*
 * Downscale an image (given as a data URL) so that it fits within 400x400
 * while keeping its aspect ratio, then pass the resulting data URL to
 * `callback`. Images already within bounds are passed through unchanged;
 * resized ones are re-encoded with canvas.toDataURL(fileType).
 */
var ProcessFile = function(dataURL, fileType, callback) {
    var LIMIT_W = 400;
    var LIMIT_H = 400;
    var img = new Image();

    img.onerror = function () {
        alert('There was an error processing your file!');
    };

    img.onload = function () {
        var w = img.width;
        var h = img.height;
        if (w <= LIMIT_W && h <= LIMIT_H) {
            callback(dataURL);
            return;
        }
        // Scale along the longer side so both sides end up within bounds.
        var targetW;
        var targetH;
        if (w > h) {
            targetW = LIMIT_W;
            targetH = h * (LIMIT_W / w);
        } else {
            targetH = LIMIT_H;
            targetW = w * (LIMIT_H / h);
        }
        var canvas = document.createElement('canvas');
        canvas.width = targetW;
        canvas.height = targetH;
        canvas.getContext('2d').drawImage(img, 0, 0, targetW, targetH);
        callback(canvas.toDataURL(fileType));
    };

    // Assign src last, after both handlers are in place.
    img.src = dataURL;
};
var DataURItoBlob = function(dataURI) { | |
// convert base64 to raw binary data held in a string | |
// doesn't handle URLEncoded DataURIs - see SO answer #6850276 for code that does this | |
var byteString = atob(dataURI.split(',')[1]); | |
// separate out the mime component | |
var mimeString = dataURI.split(',')[0].split(':')[1].split(';')[0]; | |
// write the bytes of the string to an ArrayBuffer | |
var ab = new ArrayBuffer(byteString.length); | |
var ia = new Uint8Array(ab); | |
for (var i = 0; i < byteString.length; i++) { | |
ia[i] = byteString.charCodeAt(i); | |
} | |
try { | |
return new Blob([ab], {type: mimeString}); | |
} catch (e) { | |
// The BlobBuilder API has been deprecated in favour of Blob, but older | |
// browsers don't know about the Blob constructor | |
// IE10 also supports BlobBuilder, but since the `Blob` constructor | |
// also works, there's no need to add `MSBlobBuilder`. | |
var BlobBuilder = window.WebKitBlobBuilder || window.MozBlobBuilder; | |
var bb = new BlobBuilder(); | |
bb.append(ab); | |
return bb.getBlob(mimeString); | |
} | |
} | |
</script> | |
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
from singa import layer | |
from singa import metric | |
from singa import loss | |
from singa import net as ffnet | |
def add_layer_group(net, name, nb_filers, sample_shape=None):
    '''Append one VGG-style group of layers to ``net``.

    The group is two Conv2D layers (``nb_filers`` filters, positional
    arguments 3 and 1 — presumably kernel size and stride — with pad=1),
    each followed by an Activation layer, then a MaxPooling2D layer named
    ``name`` with positional arguments 2, 2 and pad=0. Layer names are
    ``name + '_1'``, ``name + 'act_1'``, ``name + '_2'``, ``name + 'act_2'``.

    ``sample_shape`` is forwarded as ``input_sample_shape`` to the first
    conv layer only.
    '''
    for idx in (1, 2):
        suffix = str(idx)
        # Only the group's first conv layer receives the input sample shape.
        extra = {'input_sample_shape': sample_shape} if idx == 1 else {}
        net.add(layer.Conv2D(name + '_' + suffix, nb_filers, 3, 1, pad=1,
                             **extra))
        net.add(layer.Activation(name + 'act_' + suffix))
    net.add(layer.MaxPooling2D(name, 2, 2, pad=0))
def create(use_cpu=False):
    '''Build the VGG-style CIFAR-10 classifier network.

    Args:
        use_cpu: if True, sets the global ``layer.engine`` to 'singacpp'
            before any layer is created; otherwise the engine is left as is.

    Returns:
        A FeedForwardNet with SoftmaxCrossEntropy loss and Accuracy metric:
        five conv groups (64, 128, 256, 512, 512 filters; input sample shape
        (3, 32, 32)), then flatten, a 512-unit dense layer + activation +
        dropout, and a 10-unit dense output layer.
    '''
    if use_cpu:
        layer.engine = 'singacpp'
    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())

    conv_groups = (('conv1', 64), ('conv2', 128), ('conv3', 256),
                   ('conv4', 512), ('conv5', 512))
    first_name, first_width = conv_groups[0]
    add_layer_group(net, first_name, first_width, (3, 32, 32))
    for group_name, width in conv_groups[1:]:
        add_layer_group(net, group_name, width)

    # Classifier head.
    net.add(layer.Flatten('flat'))
    net.add(layer.Dense('ip1', 512))
    net.add(layer.Activation('relu_ip1'))
    net.add(layer.Dropout('drop1'))
    net.add(layer.Dense('ip2', 10))
    return net
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
flask>=0.10.1 | |
flask_cors>=3.0.2 | |
pillow>=2.3.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import sys, glob, os, random, shutil, time | |
import urllib, traceback | |
import numpy as np | |
from multiprocessing import Process, Queue | |
from argparse import ArgumentParser | |
from argparse import RawDescriptionHelpFormatter | |
from singa import tensor, device, optimizer | |
from singa import utils, initializer, metric, image_tool | |
from singa.proto import core_pb2 | |
from rafiki.agent import Agent, MsgType | |
import data | |
import model | |
# Number of most-probable classes listed in each prediction response.
top_k = 5
# Shared singa ImageTool instance used by image_transform().
tool = image_tool.ImageTool()
# Test-time augmentation sizes in pixels: images are resized to the midpoint
# of small_size and big_size, then cropped to crop_size x crop_size.
small_size, big_size = 35, 45
crop_size = 32
# NOTE(review): not referenced in this file; presumably the number of images
# image_transform() produces — confirm.
num_augmentation = 10
def image_transform(image):
    '''Return a set of augmented images (ImageTool results) for one input.

    ``image`` is whatever ``ImageTool.load`` accepts. NOTE(review): serve()
    passes the raw request payload from the agent here, not necessarily a
    file path — confirm what load() supports.

    The image is resized by (small_size + big_size) // 2, then crop5 is
    applied with (crop_size, crop_size) and 5, then flip(2).
    '''
    # Integer division: under Python 3 '/' would hand the float 40.0 to
    # resize_by_list; '//' yields 40 on both Python 2 and 3.
    size = (small_size + big_size) // 2
    return tool.load(image).resize_by_list([size]).crop5(
        (crop_size, crop_size), 5).flip(2).get()
def main(argv=None):
    '''Parse command-line options and run the CIFAR-10 prediction service.

    If ``argv`` is given it is appended to ``sys.argv`` before parsing (the
    parser itself always reads ``sys.argv``).

    Returns:
        None on success or when argparse exits (e.g. --help); 2 if any
        other exception escapes (its traceback is printed).
    '''
    if argv is not None:
        sys.argv.extend(argv)
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            default=9999,
            # Without type=int a port given on the command line arrived as
            # a str, while the default is an int.
            type=int,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            default="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()

        agent = Agent(args.port)
        try:
            service = Service(agent, args.use_cpu)
            service.initialize(args.parameter)
            service.serve()
            # wait for the agent to finish handling the last http request
            time.sleep(1)
        finally:
            # Stop the agent even when initialization or serving fails, so
            # it is not left running after the service has died.
            agent.stop()
    except SystemExit:
        return
    except Exception:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
class Service(object):
    '''CIFAR-10 prediction service driven by messages from a rafiki Agent.

    Requests are answered with the ``top_k`` highest-scoring class names and
    their crop-averaged scores. A stop command, or a message that is neither
    a request nor a command, ends serve().
    '''

    def __init__(self, agent, use_cpu):
        '''Build the network and select the device.

        Args:
            agent: rafiki Agent used to pull requests and push responses.
            use_cpu: use the default device instead of a CUDA GPU.
        '''
        self.model = model.create(use_cpu)
        if use_cpu:
            print("running with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        # NOTE(review): the optimizer is never used while serving; kept so
        # the attribute stays available to any code that reads it.
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.agent = agent
        # Filled in by initialize(); stays None if no mean file exists.
        self.mean = None

    def initialize(self, parameter_file=None):
        '''Fetch serving files, load the model parameters and the mean.

        ``parameter_file`` is resolved with data.get_parameter(..., True);
        when it is falsy the newest '*.model' checkpoint is used.
        '''
        data.serve_file_prepare()
        print('Start intialization............')
        parameter = data.get_parameter(parameter_file, True)
        print('initialize with %s' % parameter)
        self.model.load(parameter)
        self.model.to_device(self.device)
        print('End intialization............')
        self.mean = data.load_mean_data()

    def _predict(self, payload):
        '''Classify one uploaded image and return the HTML response text.

        Every augmented image from image_transform() is scored. The scores
        are averaged over them, and the ``top_k`` classes are returned as
        "name:score<br/>" lines, highest first.
        '''
        crops = []
        for im in image_transform(payload):
            ary = np.array(im.convert('RGB'), dtype=np.float32)
            # (H, W, C) -> (C, H, W), assuming im is a PIL-style image.
            crops.append(ary.transpose(2, 0, 1))
        images = np.array(crops)
        # normalize with the stored training mean
        images -= self.mean
        x = tensor.from_numpy(images.astype(np.float32))
        x.to_device(self.device)
        y = self.model.predict(x)
        y.to_host()
        prob = np.average(tensor.to_numpy(y), 0)
        ranked = np.flipud(np.argsort(prob))  # class indices, highest first
        return "".join("%s:%s<br/>" % (get_name(label), prob[label])
                       for label in ranked[:top_k])

    def serve(self):
        '''Answer prediction requests until a stop command arrives.

        NOTE(review): assumes agent.pull() returns (None, ...) when nothing
        is queued rather than blocking, as the None check implies.
        '''
        while True:
            # 'payload' rather than 'data' so the data module is not shadowed.
            msg, payload = self.agent.pull()
            if msg is None:
                # Back off before polling again. Previously `continue`
                # skipped the sleep at the end of the loop and spun a CPU
                # core while idle.
                time.sleep(0.01)
                continue
            msg = MsgType.parse(msg)
            if msg.is_request():
                try:
                    response = self._predict(payload)
                except Exception:
                    traceback.print_exc()
                    response = "sorry, system error."
                self.agent.push(MsgType.kResponse, response)
            elif msg.is_command():
                if MsgType.kCommandStop.equal(msg):
                    print('get stop command')
                    self.agent.push(MsgType.kStatus, "success")
                    break
                print('get unsupported command %s' % str(msg))
                self.agent.push(MsgType.kStatus, "failure")
            else:
                print('get unsupported message %s' % str(msg))
                self.agent.push(MsgType.kStatus, "failure")
                break
            time.sleep(0.01)
        print("server stop")
# CIFAR-10 class names keyed by integer label 0..9.
label_map = dict(enumerate((
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)))


def get_name(label):
    '''Return the CIFAR-10 class name for integer ``label`` (KeyError if unknown).'''
    return label_map[label]
if __name__ == '__main__':
    # Use main()'s return value as the exit status: None -> 0, 2 on error.
    # Previously the value was discarded and failures still exited with 0.
    sys.exit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# ============================================================================= | |
import sys, os, traceback | |
import glob, random, shutil, time | |
import numpy as np | |
from argparse import ArgumentParser | |
from argparse import RawDescriptionHelpFormatter | |
from singa import tensor, device, optimizer | |
from singa import utils, initializer, metric | |
from singa.proto import core_pb2 | |
from rafiki.agent import Agent, MsgType | |
import model | |
import data | |
def main():
    '''Parse command-line options and train the CIFAR-10 model.

    Returns:
        None on success or when argparse exits (e.g. --help); 2 if any
        other exception escapes (its traceback is printed).
    '''
    try:
        # Setup argument parser
        parser = ArgumentParser(
            description="SINGA CIFAR SVG TRANING MODEL",
            formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument(
            "-p",
            "--port",
            dest="port",
            default=9999,
            # Without type=int a port given on the command line arrived as
            # a str, while the default is an int.
            type=int,
            help="the port to listen to, default is 9999")
        parser.add_argument(
            "-param",
            "--parameter",
            dest="parameter",
            help="the parameter file path to be loaded")
        parser.add_argument(
            "-C",
            "--cpu",
            dest="use_cpu",
            action="store_true",
            default=False,
            help="Using cpu or not, default is using gpu")
        args = parser.parse_args()

        m = model.create(args.use_cpu)
        agent = Agent(args.port)
        try:
            trainer = Trainer(m, agent, args.use_cpu)
            trainer.initialize(args.parameter)
            trainer.train()
            # wait for the agent to finish handling the last http request
            time.sleep(1)
        finally:
            # Stop the agent even when initialization or training fails, so
            # it is not left running after the trainer has died.
            agent.stop()
    except SystemExit:
        return
    except Exception:
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
        return 2
class Trainer(object):
    '''Train a singa model, reporting to and taking commands from an Agent.

    Test and train metrics are pushed as MsgType.kInfoMetric dicts. Pause,
    resume and stop commands are checked once per epoch and once per
    mini-batch. Checkpoints are written under data.parameter_folder.
    '''

    def __init__(self, model, agent, use_cpu):
        '''Select the device and create the SGD optimizer.

        Args:
            model: the network to train (as built by model.create()).
            agent: rafiki Agent used for commands and metric reporting.
            use_cpu: use the default device instead of a CUDA GPU.
        '''
        self.model = model
        if use_cpu:
            print("runing with cpu")
            self.device = device.get_default_device()
        else:
            print("runing with gpu")
            self.device = device.create_cuda_gpu()
        self.opt = optimizer.SGD(momentum=0.9, weight_decay=0.0005)
        self.agent = agent

    def initialize(self, parameter_file):
        '''Set the model parameters, then move the model to the device.

        With a truthy ``parameter_file`` the parameters are loaded from
        data.get_parameter(parameter_file). Otherwise each parameter is
        initialized from its filler spec: gaussian, xavier scaled by 0.5, or
        zero for any other filler type.
        '''
        print('Start intialization............')
        if parameter_file:
            parameter = data.get_parameter(parameter_file)
            print('initialize with %s' % parameter)
            self.model.load(parameter)
        else:
            for p, specs in zip(self.model.param_values(),
                                self.model.param_specs()):
                filler = specs.filler
                if filler.type == 'gaussian':
                    initializer.gaussian(p, filler.mean, filler.std)
                elif filler.type == 'xavier':
                    initializer.xavier(p)
                    p *= 0.5  # 0.5 if use glorot, which would have val acc to 83
                else:
                    p.set_value(0)
                # NOTE(review): the parameter tensor is passed as the first
                # argument here, while apply_with_lr() in train() identifies
                # parameters by str(specs.name); confirm what
                # Optimizer.register expects. Registration is also skipped
                # entirely when parameters are loaded from a file.
                self.opt.register(p, specs)
                print('%s %s %s' % (specs.name, filler.type, p.l1()))
        self.model.to_device(self.device)
        print('End intialization............')

    def data_prepare(self):
        '''Load the train/test sets and subtract the training mean.

        The mean is read from the saved mean file. If there is none, it is
        computed over the training images and saved.
        '''
        data.train_file_prepare()
        self.train_x, self.train_y = data.load_train_data()
        self.test_x, self.test_y = data.load_test_data()
        self.mean = data.load_mean_data()
        if self.mean is None:
            self.mean = np.average(self.train_x, axis=0)
            data.save_mean_data(self.mean)
        self.train_x -= self.mean
        self.test_x -= self.mean

    def pause(self):
        '''Wait until a resume or stop command arrives.

        Returns True to resume training, False to stop. Any other message is
        answered with a warning status.
        '''
        while True:
            msg, _ = self.agent.pull()
            if msg is None:
                # Previously `continue` skipped the sleep below, so an idle
                # pause busy-spun a CPU core.
                time.sleep(0.1)
                continue
            msg = MsgType.parse(msg)
            if MsgType.kCommandResume.equal(msg):
                self.agent.push(MsgType.kStatus, "success")
                return True
            if MsgType.kCommandStop.equal(msg):
                self.agent.push(MsgType.kStatus, "success")
                return False
            self.agent.push(MsgType.kStatus, "warning, nothing happened")
            print("Receive an unsupported command: %s " % str(msg))
            time.sleep(0.1)

    def listen(self):
        '''Handle at most one pending agent message.

        Returns False if training should stop, True otherwise. A pause
        command blocks in pause() until resume or stop. Non-command messages
        are ignored.
        '''
        msg, _ = self.agent.pull()
        if msg is None:
            return True
        msg = MsgType.parse(msg)
        if not msg.is_command():
            return True
        if MsgType.kCommandPause.equal(msg):
            self.agent.push(MsgType.kStatus, "success")
            return self.pause()
        if MsgType.kCommandStop.equal(msg):
            self.agent.push(MsgType.kStatus, "success")
            return False
        self.agent.push(MsgType.kStatus, "warning, nothing happened")
        print("Unsupported command %s" % str(msg))
        return True

    def _evaluate(self, tx, ty, batch_size, num_batch):
        '''Return (mean loss, mean accuracy) over the first num_batch test batches.'''
        loss, acc = 0.0, 0.0
        for b in range(num_batch):
            lo, hi = b * batch_size, (b + 1) * batch_size
            tx.copy_from_numpy(self.test_x[lo:hi])
            ty.copy_from_numpy(self.test_y[lo:hi])
            l, a = self.model.evaluate(tx, ty)
            loss += l
            acc += a
        return loss / num_batch, acc / num_batch

    def _save_checkpoint(self, name):
        '''Save the model as data.parameter_folder/name, creating the folder if needed.'''
        folder = data.parameter_folder
        if not os.path.isdir(folder):
            os.makedirs(folder)
        self.model.save(os.path.join(folder, name))

    def train(self, num_epoch=140, batch_size=50):
        '''Train for up to ``num_epoch`` epochs, evaluating at the start of each.

        Every 10th epoch (after the first) a checkpoint 'parameter_<epoch>'
        is saved, and 'parameter_last' is saved at the end, including after
        a stop command.
        '''
        self.data_prepare()
        print('training shape %s %s' % (self.train_x.shape, self.train_y.shape))
        print('validation shape %s %s' % (self.test_x.shape, self.test_y.shape))
        tx = tensor.Tensor((batch_size, 3, 32, 32), self.device)
        ty = tensor.Tensor((batch_size, ), self.device, core_pb2.kInt)
        # '//' keeps these ints on Python 3 as well (range() needs ints);
        # a trailing partial batch is dropped, as before.
        num_train_batch = self.train_x.shape[0] // batch_size
        num_test_batch = self.test_x.shape[0] // batch_size
        idx = np.arange(self.train_x.shape[0], dtype=np.int32)
        # frequency (in batches) of pushing training status info
        skip = 20
        stop = False
        for epoch in range(num_epoch):
            # Once stopped, do not poll again: a pause command arriving now
            # would otherwise block in pause() even though training is over.
            if stop or not self.listen():
                break
            np.random.shuffle(idx)
            print('Epoch %d' % epoch)

            test_loss, test_acc = self._evaluate(tx, ty, batch_size,
                                                 num_test_batch)
            print('testing loss = %f, accuracy = %f' % (test_loss, test_acc))
            self.agent.push(MsgType.kInfoMetric, dict(
                phase='test',
                step=epoch * num_train_batch // skip,
                accuracy=test_acc,
                loss=test_loss,
                timestamp=time.time()))

            # Training accumulators start from zero. Previously they kept
            # the summed test loss/accuracy, which inflated the reported
            # per-epoch training averages.
            loss, acc = 0.0, 0.0
            for b in range(num_train_batch):
                if not self.listen():
                    stop = True
                    break
                batch_idx = idx[b * batch_size:(b + 1) * batch_size]
                tx.copy_from_numpy(self.train_x[batch_idx])
                ty.copy_from_numpy(self.train_y[batch_idx])
                grads, (l, a) = self.model.train(tx, ty)
                loss += l
                acc += a
                for (s, p, g) in zip(self.model.param_specs(),
                                     self.model.param_values(), grads):
                    self.opt.apply_with_lr(epoch, get_lr(epoch), g, p,
                                           str(s.name))
                info = 'training loss = %f, training accuracy = %f' % (l, a)
                if b % skip == 0:
                    self.agent.push(MsgType.kInfoMetric, dict(
                        phase='train',
                        step=(epoch * num_train_batch + b) // skip,
                        accuracy=a,
                        loss=l,
                        timestamp=time.time()))
                utils.update_progress(b * 1.0 / num_train_batch, info)
            info = 'training loss = %f, training accuracy = %f' \
                % (loss / num_train_batch, acc / num_train_batch)
            print(info)
            if epoch > 0 and epoch % 10 == 0:
                self._save_checkpoint('parameter_%d' % epoch)
        self._save_checkpoint('parameter_last')
def get_lr(epoch):
    '''Step learning-rate schedule.

    Returns 0.0008 for epochs below 360, 0.0001 for epochs below 540, and
    0.00001 afterwards. With Trainer.train's default of 140 epochs, only the
    first rate is ever used.
    '''
    for upper_bound, rate in ((360, 0.0008), (540, 0.0001)):
        if epoch < upper_bound:
            return rate
    return 0.00001
if __name__ == '__main__':
    # Use main()'s return value as the exit status: None -> 0, 2 on error.
    # Previously the value was discarded and failures still exited with 0.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment