Skip to content

Instantly share code, notes, and snippets.

@adash333
Created October 18, 2017 15:53
Show Gist options
  • Save adash333/c895b792eca527701ffca54f9dc59331 to your computer and use it in GitHub Desktop.
Save adash333/c895b792eca527701ffca54f9dc59331 to your computer and use it in GitHub Desktop.
predict.py -error
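# Known pitfall: passing the 2-D (height, width) grayscale image array
# straight to MyModel.fwd fails with
#   chainer.utils.type_check.InvalidType:
#   Invalid operation is performed in: Convolution2DFunction (Forward)
#
#   Expect: in_types[0].ndim == 4
#   Actual: 2 != 4
# because Convolution2D requires a 4-D (batch, channel, height, width) input.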
from flask import Flask, render_template, request, redirect, url_for
import numpy as np
from PIL import Image
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.datasets import tuple_dataset
from chainer import training
from chainer.training import extensions
from datetime import datetime
"""
# Network definition
class MLP(chainer.Chain):
def __init__(self, n_units, n_out):
super(MLP, self).__init__()
with self.init_scope():
# the size of the inputs to each layer will be inferred
self.l1 = L.Linear(None, n_units) # n_in -> n_units
self.l2 = L.Linear(None, n_units) # n_units -> n_units
self.l3 = L.Linear(None, n_out) # n_units -> n_out
def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
return self.l3(h2)
"""
# 2. Model definition
class MyModel(Chain):
    """CNN with two conv + max-pool stages and two fully connected layers.

    ``fwd`` expects a 4-D ``(batch, 1, height, width)`` array (cn1 has a
    single input channel). fc1 takes 800 = 50 * 4 * 4 features, which is
    what a 28x28 input produces after the two conv(5x5) + max-pool(2)
    stages. fc2 produces 10 output scores.
    """

    def __init__(self):
        """Create the parameterized layers (links)."""
        super(MyModel, self).__init__()
        # Registering links as attributes inside init_scope() replaces the
        # keyword-argument form of Chain.__init__, deprecated in Chainer v2.
        # The attribute names are the link names, so the serialized
        # parameter names (cn1, cn2, fc1, fc2) stay the same.
        with self.init_scope():
            self.cn1 = L.Convolution2D(1, 20, 5)
            self.cn2 = L.Convolution2D(20, 50, 5)
            self.fc1 = L.Linear(800, 500)
            self.fc2 = L.Linear(500, 10)

    def __call__(self, x, t):
        """Return the softmax cross-entropy loss of ``fwd(x)`` against labels ``t``."""
        return F.softmax_cross_entropy(self.fwd(x), t)

    def fwd(self, x):
        """Return the unnormalized class scores (logits) for the batch ``x``.

        F.dropout only drops units while ``chainer.config.train`` is True
        (the default), so call this under
        ``chainer.using_config('train', False)`` for inference.
        """
        h1 = F.max_pooling_2d(F.relu(self.cn1(x)), 2)
        h2 = F.max_pooling_2d(F.relu(self.cn2(h1)), 2)
        h3 = F.dropout(F.relu(self.fc1(h2)))
        return self.fc2(h3)
# Flask application serving the upload form and the digit prediction.
app = Flask(__name__)


@app.route('/', methods=['GET', 'POST'])
def upload_file():
    """Show the upload form (GET) or classify an uploaded image (POST).

    On POST the uploaded file is saved under ./static/ with a timestamped
    name, converted to the 4-D input MyModel expects, and classified with
    the weights loaded from 'mnist-cnn.model'. The template receives the
    saved path and the predicted class index.
    """
    if request.method == 'GET':
        return render_template('index.html')

    # Save the uploaded file.
    f = request.files.get('file')
    if f is None or f.filename == '':
        # Nothing was uploaded. Show the form again instead of failing later
        # when PIL tries to open an empty file.
        return render_template('index.html')
    # Microseconds (%f) keep two uploads in the same second from
    # overwriting each other's file.
    filepath = "./static/" + datetime.now().strftime("%Y%m%d%H%M%S%f") + ".png"
    f.save(filepath)

    # Classify with the trained model. No optimizer is needed for inference.
    model = MyModel()
    serializers.load_hdf5('mnist-cnn.model', model)

    # Convolution2D requires a 4-D (batch, channel, height, width) array.
    # Passing the raw 2-D grayscale image raised
    # "InvalidType: in_types[0].ndim == 4". fc1 expects 800 = 50 * 4 * 4
    # features, which is what a 28x28 image yields after the two
    # conv + pool stages, so the upload is resized to 28x28 first.
    # NOTE(review): MNIST digits are white on black. If uploads are dark
    # strokes on a light background they may need inverting (1 - image).
    with Image.open(filepath) as img:
        gray = img.convert('L').resize((28, 28))
    image = np.asarray(gray, dtype=np.float32) / 255
    x = image.reshape((1, 1, 28, 28))

    # Inference mode: train=False makes F.dropout a no-op, and no
    # computational graph is recorded for backprop.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        out = model.fwd(x)

    predict = int(np.argmax(out.data))
    return render_template('index.html', filepath=filepath, predict=predict)
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it must not be exposed on all interfaces (0.0.0.0) by
    # default. Opt in with FLASK_DEBUG=1 when running on a trusted machine.
    app.run(host="0.0.0.0", port=5000,
            debug=os.environ.get("FLASK_DEBUG") == "1")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment