Created
October 18, 2017 15:53
predict.py -error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
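# Known failure: predicting on an uploaded image raised the error below because
# the image was passed to the model as a 2-D (height, width) array, while
# Convolution2D requires 4-D input of shape (batch, channels, height, width).
#
#   chainer.utils.type_check.InvalidType:
#   Invalid operation is performed in: Convolution2DFunction (Forward)
#
#   Expect: in_types[0].ndim == 4
#   Actual: 2 != 4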
from flask import Flask, render_template, request, redirect, url_for | |
import numpy as np | |
from PIL import Image | |
import chainer | |
from chainer import cuda, Function, gradient_check, report, training, utils, Variable | |
from chainer import datasets, iterators, optimizers, serializers | |
from chainer import Link, Chain, ChainList | |
import chainer.functions as F | |
import chainer.links as L | |
from chainer.datasets import tuple_dataset | |
from chainer import training | |
from chainer.training import extensions | |
from datetime import datetime | |
""" | |
# Network definition | |
class MLP(chainer.Chain): | |
def __init__(self, n_units, n_out): | |
super(MLP, self).__init__() | |
with self.init_scope(): | |
# the size of the inputs to each layer will be inferred | |
self.l1 = L.Linear(None, n_units) # n_in -> n_units | |
self.l2 = L.Linear(None, n_units) # n_units -> n_units | |
self.l3 = L.Linear(None, n_out) # n_units -> n_out | |
def __call__(self, x): | |
h1 = F.relu(self.l1(x)) | |
h2 = F.relu(self.l2(h1)) | |
return self.l3(h2) | |
""" | |
# Step 2: model definition.
class MyModel(Chain):
    """Small convolutional classifier with 10 output classes.

    Two stages of (5x5 convolution -> ReLU -> 2x2 max-pooling) followed by two
    fully connected layers. ``fc1`` takes 800 = 50 * 4 * 4 features, which is
    what the convolution stack yields for a single-channel 28x28 input.
    """

    def __init__(self):
        # Links carrying trainable parameters, registered under these names
        # (the names are what the serializers use when loading weights).
        layers = dict(
            cn1=L.Convolution2D(1, 20, 5),    # 1 -> 20 channels, 5x5 kernel
            cn2=L.Convolution2D(20, 50, 5),   # 20 -> 50 channels, 5x5 kernel
            fc1=L.Linear(800, 500),
            fc2=L.Linear(500, 10),
        )
        super(MyModel, self).__init__(**layers)

    def __call__(self, x, t):
        """Return the softmax cross-entropy loss of the logits for ``x`` against labels ``t``."""
        logits = self.fwd(x)
        return F.softmax_cross_entropy(logits, t)

    def fwd(self, x):
        """Return unnormalised class scores (logits) for the batch ``x``.

        ``x`` must be 4-D, (batch, channels=1, height, width), because
        Convolution2D rejects lower-rank input. F.dropout is called without a
        train flag; in Chainer v2+ it is active whenever chainer.config.train is
        True, so callers should disable it for inference — confirm for the
        Chainer version in use.
        """
        h = x
        for conv in (self.cn1, self.cn2):
            h = F.max_pooling_2d(F.relu(conv(h)), 2)
        h = F.dropout(F.relu(self.fc1(h)))
        return self.fc2(h)
app = Flask(__name__)

# Weights file for MyModel; deserialized once and reused across requests.
MODEL_PATH = 'mnist-cnn.model'
# Side length the network was built for: fc1 expects 50 * 4 * 4 = 800 features,
# which is what two (5x5 conv, 2x2 max-pool) stages produce from a 28x28 input.
IMAGE_SIZE = 28

# path -> loaded MyModel instance.
_model_cache = {}


def _load_model(path=MODEL_PATH):
    """Return a MyModel with weights read from *path*, loading it only on first use."""
    model = _model_cache.get(path)
    if model is None:
        model = MyModel()
        serializers.load_hdf5(path, model)
        _model_cache[path] = model
    return model


def _preprocess(path, size=IMAGE_SIZE):
    """Load the image at *path* as a float32 array of shape (1, 1, size, size) in [0, 1].

    Convolution2D requires 4-D (batch, channels, height, width) input; feeding
    the plain 2-D (height, width) grayscale array is what raised
    ``InvalidType: Expect: in_types[0].ndim == 4``.
    """
    with Image.open(path) as img:
        gray = img.convert('L')  # convert() loads the pixels, so closing img is safe
    if gray.size != (size, size):
        gray = gray.resize((size, size))
    # NOTE(review): MNIST digits are white on black; if users upload dark digits
    # on a light background the array probably needs inverting (1 - x) — confirm.
    pixels = np.asarray(gray, dtype=np.float32) / 255
    return pixels.reshape(1, 1, size, size)


@app.route('/', methods=['GET', 'POST'])
def upload_file():
    """Show the upload form (GET) or classify an uploaded image (POST).

    On POST the upload is saved under ./static/ with a timestamped .png name,
    and index.html is rendered with ``filepath`` and ``predict`` (the index of
    the highest-scoring of the model's 10 outputs).
    """
    if request.method == 'GET':
        return render_template('index.html')

    # Save the uploaded file.
    f = request.files['file']
    if not f.filename:
        # Form submitted without choosing a file: nothing to classify.
        return redirect(url_for('upload_file'))
    # Microseconds keep uploads within the same second from overwriting each other.
    filepath = "./static/" + datetime.now().strftime("%Y%m%d%H%M%S%f") + ".png"
    f.save(filepath)

    # Classify the image with the trained model.
    model = _load_model()
    x = Variable(_preprocess(filepath))
    # Inference only: turn dropout off and skip building the backprop graph.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        out = model.fwd(x)
    predict = int(np.argmax(out.data))
    return render_template('index.html', filepath=filepath, predict=predict)
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach the server run
    # arbitrary Python, so it must not be on by default while listening on all
    # interfaces. Opt in for local development with FLASK_DEBUG=1.
    import os
    app.run(host="0.0.0.0", port=5000, debug=os.environ.get("FLASK_DEBUG") == "1")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment