Skip to content

Instantly share code, notes, and snippets.

@nudles
Last active February 1, 2018 05:33
Show Gist options
  • Save nudles/dc2c97f3b3f007109bffbd3e721f2318 to your computer and use it in GitHub Desktop.
Save nudles/dc2c97f3b3f007109bffbd3e721f2318 to your computer and use it in GitHub Desktop.
xception-foodlg
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# =============================================================================
from builtins import str
from builtins import object
from multiprocessing import Process, Queue
from flask import Flask,request, send_from_directory, jsonify
from flask_cors import CORS, cross_origin
import os, traceback, sys
import time
from werkzeug.utils import secure_filename
from werkzeug.datastructures import CombinedMultiDict, MultiDict
import pickle
import uuid
class MsgType(object):
    '''Named tag attached to every message exchanged through the agent queues.

    A tag's category is encoded in the prefix of its name (``kInfo``,
    ``kCommand``, ``kStatus``, ``kRequest``, ``kResponse``); the ``is_*``
    predicates test that prefix. One shared instance per known tag is attached
    to the class as an attribute (e.g. ``MsgType.kRequest``) right after the
    class definition.
    '''

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return "<Msg: %s>" % (self,)

    def equal(self, target):
        '''Return True when *target* has the same string form as this tag.'''
        mine, theirs = str(self), str(target)
        return mine == theirs

    def _has_prefix(self, prefix):
        # Categories are identified purely by the tag-name prefix.
        return self.name.startswith(prefix)

    def is_info(self):
        return self._has_prefix('kInfo')

    def is_command(self):
        return self._has_prefix('kCommand')

    def is_status(self):
        return self._has_prefix('kStatus')

    def is_request(self):
        return self._has_prefix('kRequest')

    def is_response(self):
        return self._has_prefix('kResponse')

    @staticmethod
    def parse(name):
        '''Return the registered tag whose name equals ``str(name)``.'''
        return getattr(MsgType, str(name))

    @staticmethod
    def get_command(name):
        '''Map 'stop' / 'pause' / 'resume' to their tags; anything else to kCommand.'''
        for word, attr in (('stop', 'kCommandStop'),
                           ('pause', 'kCommandPause'),
                           ('resume', 'kCommandResume')):
            if name == word:
                return getattr(MsgType, attr)
        return MsgType.kCommand


types = ['kInfo', 'kInfoMetric',
         'kCommand', 'kCommandStop', 'kCommandPause', 'kCommandResume',
         'kStatus', 'kStatusRunning', 'kStatusPaused', 'kStatusError',
         'kRequest', 'kResponse']
# Register one shared instance per tag as a class attribute, e.g. MsgType.kRequest.
for t in types:
    setattr(MsgType, t, MsgType(t))
##### NOTE: the server can currently handle only one request at a time (requests are processed sequentially)
# Flask application serving the monitoring/command endpoints below.
app = Flask(import_name=__name__)
# Default number of entries returned by /getTopKData when no ``k`` is given.
top_k_ = 5
class Agent(object):
    '''Trainer-side handle to the monitoring HTTP server.

    Construction starts ``start`` (the Flask server) in a child process and
    keeps the two queues shared with it:

    * ``info_queue``    -- trainer -> server, written by ``push``
    * ``command_queue`` -- server -> trainer, read by ``pull``
    '''

    def __init__(self, port):
        info_queue = Queue()
        command_queue = Queue()
        self.p = Process(target=start, args=(port, info_queue, command_queue))
        self.p.start()
        self.info_queue = info_queue
        self.command_queue = command_queue

    def pull(self):
        '''Fetch the next command/request sent by the server, without blocking.

        Returns:
            ``(msg, data)`` for the next queued message. For request messages
            *data* is un-pickled before being returned. ``(None, None)`` when
            nothing is queued.
        '''
        from queue import Empty
        try:
            # get_nowait() checks and takes atomically, unlike empty() + get().
            msg, data = self.command_queue.get_nowait()
        except Empty:
            return None, None
        if msg.is_request():
            # NOTE(review): pickle.loads is acceptable only because the payload
            # is produced by this program's own server process.
            data = pickle.loads(data)
        return msg, data

    def push(self, msg, value):
        '''Send ``(msg, value)`` to the server, e.g. a metric or a response.'''
        self.info_queue.put((msg, value))

    def stop(self):
        '''Shut down the server process.'''
        # Sleep a while so an in-flight HTTP response can finish.
        time.sleep(1)
        self.p.terminate()
        # Reap the terminated child so it does not linger as a zombie.
        self.p.join()
def start(port, info_queue, command_queue):
    '''Entry point of the server process spawned by ``Agent``.

    Publishes the queues as module globals for the route handlers, resets the
    collected metrics, then runs the Flask app (this call blocks).

    Args:
        port: TCP port to listen on (all interfaces).
        info_queue: queue the trainer pushes metrics/responses onto.
        command_queue: queue the handlers put commands/requests onto.
    '''
    global info_queue_, command_queue_, data_
    info_queue_ = info_queue
    command_queue_ = command_queue
    # Values received so far; served by /getAllData and /getTopKData.
    data_ = []
    # Handlers pair each command with the next reply read from the shared
    # info queue, so requests must be served one at a time. Flask >= 1.0
    # defaults to a threaded dev server, hence the explicit threaded=False.
    app.run(host='0.0.0.0', port=port, threaded=False)
def getDataFromInfoQueue(need_return=False):
    '''Move messages pushed by the trainer into ``data_``.

    Args:
        need_return: if False, append the payload of every message currently
            queued to ``data_`` and return immediately. If True, block until a
            non-info message (the reply to a command/request) arrives; payloads
            of info messages seen while waiting are appended to ``data_``.
            NOTE: this waits indefinitely if the trainer never replies.

    Returns:
        ``(msg, data)`` of the first non-info message when *need_return* is
        True, otherwise None.
    '''
    global info_queue_, data_
    from queue import Empty
    if not need_return:
        while True:
            try:
                msg, d = info_queue_.get_nowait()
            except Empty:
                return None
            data_.append(d)
    while True:
        # A blocking get() replaces the empty()/sleep(0.01) polling loop.
        msg, d = info_queue_.get()
        if not msg.is_info():
            return msg, d
        data_.append(d)
@app.route("/")
@cross_origin()
def index():
    '''Serve ``index.html`` from the current working directory.'''
    try:
        response = send_from_directory(os.getcwd(), "index.html", mimetype='text/html')
    except Exception:
        # Narrower than a bare except, which would also trap SystemExit and
        # KeyboardInterrupt.
        traceback.print_exc()
        return "error"
    return response
# Two endpoints that let the user monitor the training status.
@app.route('/getAllData')
@cross_origin()
def getAllData():
    '''Return every value received from the trainer so far, as a success envelope.'''
    try:
        # Pull anything the trainer pushed since the last call into data_.
        getDataFromInfoQueue()
    except Exception:
        traceback.print_exc()
        return failure("Internal Error")
    return success(data_)
@app.route('/getTopKData')
@cross_origin()
def getTopKData():
    '''Return the ``k`` most recent values received from the trainer.

    Query parameters:
        k: number of entries to return (default ``top_k_``); must be a
            non-negative integer.
    '''
    try:
        k = int(request.args.get("k", top_k_))
    except (TypeError, ValueError):
        traceback.print_exc()
        return failure("k should be integer")
    if k < 0:
        return failure("k should be non-negative")
    try:
        getDataFromInfoQueue()
    except Exception:
        traceback.print_exc()
        return failure("Internal Error")
    # data_[-0:] is the whole list, so k == 0 needs its own case.
    return success(data_[-k:] if k > 0 else [])
@app.route("/api", methods=['POST'])
@cross_origin()
def api():
    '''Forward a POST request to the trainer and return the trainer's reply.

    Uploaded files are first saved to disk (``transformFile``). The query
    args, form fields and saved file paths are then combined, pickled and put
    on ``command_queue_`` as a ``kRequest``. The handler blocks until the
    trainer pushes a non-info message and returns that message's payload.
    Saved uploads are deleted afterwards, whether or not the call succeeded.
    '''
    files = None
    try:
        files = transformFile(request.files)
        values = CombinedMultiDict([request.args, request.form, files])
        req_str = pickle.dumps(values)
        command_queue_.put((MsgType.kRequest, req_str))
        msg, response = getDataFromInfoQueue(True)
        return response
    except Exception:
        traceback.print_exc()
        return failure("Internal Error")
    finally:
        # Clean up in `finally` so uploads are not left on disk when any
        # step above raises.
        if files is not None:
            try:
                deleteFiles(files)
            except OSError:
                traceback.print_exc()
@app.route("/command/<name>", methods=['GET','POST'])
@cross_origin()
def command(name):
    '''Send command *name* to the trainer and return the trainer's reply.

    'stop', 'pause' and 'resume' map to their dedicated tags; any other name
    is sent as a generic ``kCommand``. Blocks until the trainer replies.
    '''
    try:
        # Named `cmd` so it does not shadow this view function's own name.
        cmd = MsgType.get_command(name)
        command_queue_.put((cmd, ""))
        msg, response = getDataFromInfoQueue(True)
        return response
    except Exception:
        traceback.print_exc()
        return failure("Internal Error")
def success(data=""):
    '''Build a JSON response of the form {"result": "success", "data": data}.'''
    return jsonify(result="success", data=data)
def failure(message):
    '''Build a JSON response of the form {"result": "failure", "message": message}.

    Counterpart of ``success``: clients can branch on the ``result`` field.
    '''
    res = dict(result="failure", message=message)
    return jsonify(res)
def transformFile(files):
    '''Save uploaded files into the current working directory.

    Args:
        files: mapping of form field name -> uploaded file object exposing
            ``filename`` and ``save(path)`` (e.g. ``request.files``). For a
            MultiDict, ``files[field]`` yields only the first file of a field.

    Returns:
        MultiDict mapping each field name to the path of its saved copy. Saved
        names are prefixed with a random UUID so repeated uploads of the same
        file name do not overwrite each other.

    If saving fails part-way, the files already written by this call are
    removed before the exception is re-raised.
    '''
    result = MultiDict([])
    saved = []
    try:
        for field in files:
            upload = files[field]
            unique_filename = str(uuid.uuid4()) + secure_filename(upload.filename)
            filepath = os.path.join(os.getcwd(), unique_filename)
            # Record before save() so a partially written file is cleaned up too.
            saved.append(filepath)
            upload.save(filepath)
            result.add(field, filepath)
    except Exception:
        for path in saved:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
        raise
    return result
def deleteFiles(files):
    '''Delete from disk the path stored under each key of *files*.

    *files* is a mapping of key -> file path, such as the MultiDict returned
    by ``transformFile`` (one path per key is looked up).
    '''
    for path in (files[key] for key in files):
        os.remove(path)
## Xception model (pretrained ImageNet weights are downloaded automatically)
from keras.applications.xception import Xception
from keras.layers import GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.optimizers import SGD
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.models import Model
import sys
import argparse
import traceback
from agent import MsgType, Agent
def _build_model(num_classes):
    '''Return a compiled Xception (ImageNet weights) with a new softmax head.'''
    base_model = Xception(weights='imagenet', include_top=False)
    # add a global spatial average pooling layer
    x = GlobalAveragePooling2D()(base_model.output)
    # add a fully-connected layer
    x = Dense(1024, activation='relu')(x)
    # and a softmax classification layer with one unit per class
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, nesterov=True, momentum=0.9),
                  loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def _make_generators(train_dir, val_dir, target_size, batch_size):
    '''Return (train, validation) directory iterators; only training data is augmented.'''
    from keras.applications.xception import preprocess_input
    # Xception's ImageNet weights expect inputs passed through its
    # preprocess_input (pixels scaled to [-1, 1]), not raw 0-255 values.
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_input,
        horizontal_flip=True,
        fill_mode="nearest",
        zoom_range=0.3,
        width_shift_range=0.3,
        height_shift_range=0.3,
        rotation_range=30)
    # Validation images are preprocessed the same way but never augmented.
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
    train_gen = train_datagen.flow_from_directory(
        train_dir, target_size=target_size, batch_size=batch_size)
    val_gen = val_datagen.flow_from_directory(
        val_dir, target_size=target_size, batch_size=batch_size)
    return train_gen, val_gen


def main(args, agent):
    '''Fine-tune Xception on an image-folder dataset, reporting progress via *agent*.

    Args:
        args: parsed CLI namespace. ``args.train_data`` and ``args.val_data``
            are image directories in the layout Keras ``flow_from_directory``
            expects (one sub-directory per class).
        agent: ``Agent`` whose ``push`` receives a ``kInfoMetric`` message with
            that epoch's logged metrics after every epoch.

    Checkpoints are written under ``./models/`` whenever the training loss
    improves. Training stops early once ``val_loss`` has not improved for
    3 epochs.
    '''
    import os
    from keras.callbacks import LambdaCallback

    img_rows, img_cols = 299, 299  # Resolution of inputs
    num_classes = 100
    batch_size = 8
    nb_epoch = 100

    model = _build_model(num_classes)

    filepath = "./models/xception-{epoch:02d}-{loss:0.3f}-{acc:0.3f}-{val_loss:0.3f}-{val_acc:0.3f}.hdf5"
    # Make sure the checkpoint directory exists before the first save.
    checkpoint_dir = os.path.dirname(filepath)
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint = ModelCheckpoint(filepath, monitor="loss", verbose=1, save_best_only=True, mode='min')
    early_stopping = EarlyStopping(monitor="val_loss", patience=3)

    def report_epoch(epoch, logs):
        # Convert to plain floats so the values pickle and jsonify cleanly.
        metrics = {key: float(value) for key, value in (logs or {}).items()}
        agent.push(MsgType.kInfoMetric, metrics)

    callbacks_list = [checkpoint, early_stopping,
                      LambdaCallback(on_epoch_end=report_epoch)]

    train_gen, val_gen = _make_generators(
        args.train_data, args.val_data, (img_rows, img_cols), batch_size)

    # A single fit call for all epochs keeps the epoch counter (used in the
    # checkpoint names) and EarlyStopping's patience state across epochs.
    # Calling fit once per epoch would restart both on every call.
    # len() of a directory iterator is already the number of batches per epoch.
    model.fit_generator(train_gen,
                        steps_per_epoch=len(train_gen),
                        epochs=nb_epoch, verbose=1,
                        validation_data=val_gen,
                        validation_steps=len(val_gen),
                        callbacks=callbacks_list)
    # NOTE(review): agent.pull() is never called during training, so requests
    # sent to the server's /api and /command endpoints are never answered and
    # those HTTP calls block. TODO: poll agent.pull() from a callback.
if __name__ == '__main__':
    agent = None
    try:
        parser = argparse.ArgumentParser()
        action = parser.add_mutually_exclusive_group(required=True)
        action.add_argument('--train', help='Train a model', action='store_true')
        action.add_argument('--test', help='Predict using a saved model', metavar='MODEL')
        action.add_argument('--extract', help='Extract features using a saved model', metavar='MODEL')
        # --port configures the monitoring server and is not an action, so it
        # must not be in the mutually exclusive group. Inside the group,
        # "--train --port N" would be rejected.
        parser.add_argument('--port', type=int, default=8333,
                            help='Port of the monitoring HTTP server')
        # Directories read by main().
        parser.add_argument('--train_data',
                            help='Directory of training images, one sub-directory per class')
        parser.add_argument('--val_data',
                            help='Directory of validation images, one sub-directory per class')
        args = parser.parse_args()
        if not args.train:
            # main() implements training only.
            parser.error('only --train is currently implemented')
        if args.train_data is None or args.val_data is None:
            parser.error('--train requires --train_data and --val_data')
        agent = Agent(args.port)
        main(args, agent)
    except Exception:
        # Catch Exception rather than everything: argparse exits via
        # SystemExit for --help and usage errors, and that must propagate.
        traceback.print_exc()
        sys.stderr.write(" for help use --help \n\n")
    finally:
        # Always stop the server process, even if training failed. Otherwise
        # the non-daemon child process keeps the program from exiting.
        if agent is not None:
            agent.stop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment