Skip to content

Instantly share code, notes, and snippets.

@ota42y
Last active February 11, 2019 03:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ota42y/3a7622f84ce86823f62457df4ff6639d to your computer and use it in GitHub Desktop.
Save ota42y/3a7622f84ce86823f62457df4ff6639d to your computer and use it in GitHub Desktop.
TensorFlow Serving test
# As of 2019-02-10, google-protobuf does not work on Ruby 2.6
# (the fix ships in google-protobuf 3.7.0), so this image stays on Ruby 2.5.
FROM ruby:2.5.3
# gRPC runtime plus the protoc/plugin tooling used to generate Ruby stubs.
RUN gem install \
      grpc \
      grpc-tools
import tensorflow as tf
from tensorflow import keras

# Train a small dense classifier on Fashion-MNIST and export it as a
# SavedModel for TensorFlow Serving. The trailing "1" in export_path is a
# numeric version directory (presumably the servable version that TF Serving
# picks up for the model named "fmnist_model").
export_path = './fmnist_model/1'

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixel values from 0-255 into [0, 1] for training. Any client payload
# must apply the same scaling to match what the model was trained on.
train_images = train_images / 255.0

model = keras.Sequential([
    # Flattens each 28x28 image into a single feature vector.
    keras.layers.Flatten(input_shape=(28, 28), name='inputs'),
    keras.layers.Dense(128, activation=tf.nn.relu),
    # 10-way softmax: one probability per class.
    keras.layers.Dense(10, activation=tf.nn.softmax),
])
model.compile(optimizer='adam',
              # "sparse" because labels are integer class ids, not one-hot.
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)

# Export with the TF1 session API. simple_save writes a single signature
# (the default serving signature, "serving_default") mapping the request key
# 'inputs' to the model input and 'outputs' to the softmax output.
with keras.backend.get_session() as sess:
    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={'inputs': model.input},
        outputs={'outputs': model.output})
require 'json'
require 'grpc'

# Load the Ruby classes generated from the TensorFlow / TensorFlow Serving
# protos. lib_dir goes on $LOAD_PATH first (presumably because the generated
# files require one another by paths relative to it), then every generated
# file is loaded in a stable, sorted order.
lib_dir = '/work/proto_ruby'
$LOAD_PATH.unshift(lib_dir) unless $LOAD_PATH.include?(lib_dir)
Dir.glob(File.join(lib_dir, '**', '*.rb')).sort.each do |path|
  require path unless File.directory?(path)
end

# Read the request payload written by the Python export script.
# File.read closes the file. JSON.parse avoids JSON.load's object-creation
# hooks and Kernel#open's command-pipe behaviour.
data = JSON.parse(File.read('/work/ruby_data.json'))
image = data['inputs'][0] # one image as rows of pixel values

request = Tensorflow::Serving::PredictRequest.new
request.model_spec = Tensorflow::Serving::ModelSpec.new(
  name: 'fmnist_model', signature_name: data['signature_name']
)

images_proto = Tensorflow::TensorProto.new
images_proto.dtype = :DT_FLOAT
# Shape is [batch=1, rows, cols], matching the model's (batch, 28, 28) input.
shape = Tensorflow::TensorShapeProto.new
[1, image.size, image.first.size].each do |size|
  shape.dim << Tensorflow::TensorShapeProto::Dim.new(size: size)
end
images_proto.tensor_shape = shape
# Row-major flattening into float_val.
image.each { |row| row.each { |pixel| images_proto.float_val << pixel } }
request.inputs['inputs'] = images_proto

# TensorFlow Serving gRPC endpoint on the Docker host, over a plaintext channel.
stub = Tensorflow::Serving::PredictionService::Stub.new(
  'host.docker.internal:8500', :this_channel_is_insecure
)
response = stub.predict(request)

# 'outputs' holds the 10 softmax probabilities; the predicted class is the
# index of the largest one.
probabilities = response.outputs['outputs'].float_val
probabilities.index(probabilities.max)
# => 2
# Generate Ruby protobuf messages and gRPC stubs into $PROTO_DIR.
# Presumably run inside the container with the TensorFlow Serving sources
# checked out at /work/serving (containing a tensorflow/ subtree).
cd /work
export PROTO_DIR=/work/proto_ruby/
# -p: do not fail when re-running and the directory already exists.
mkdir -p "$PROTO_DIR"
GRPC_PLUGIN="$(which grpc_tools_ruby_protoc_plugin)"

# TensorFlow Serving's own protos. The pattern is quoted so the shell cannot
# expand *.proto against the current directory before find sees it.
# -print0 / -0 keeps paths with unusual characters intact.
find ./serving/tensorflow_serving -name '*.proto' -print0 |
  xargs -0 grpc_tools_ruby_protoc -I=serving -I=serving/tensorflow \
    --ruby_out="$PROTO_DIR" --grpc_out="$PROTO_DIR" \
    --plugin=protoc-gen-grpc="$GRPC_PLUGIN"

# TensorFlow core protos referenced by the Serving protos.
grpc_tools_ruby_protoc -I serving/tensorflow \
  --ruby_out="$PROTO_DIR" --grpc_out="$PROTO_DIR" \
  --plugin=protoc-gen-grpc="$GRPC_PLUGIN" \
  serving/tensorflow/tensorflow/core/{framework,example,protobuf}/*.proto

grpc_tools_ruby_protoc -I serving/tensorflow \
  --ruby_out="$PROTO_DIR" --grpc_out="$PROTO_DIR" \
  --plugin=protoc-gen-grpc="$GRPC_PLUGIN" \
  serving/tensorflow/tensorflow/core/lib/core/error_codes.proto
import json

# Write two test images as JSON request payloads for the serving clients:
# test_images[0] -> ./test_data.json and test_images[1] -> ./ruby_data.json.
# Pixels are divided by 255 so the payload matches the preprocessing the model
# was trained with (train_images / 255.0). Raw 0-255 values would not.


def _write_request(path, image):
    """Write a one-image request payload for *image* to *path*."""
    payload = {
        "signature_name": 'serving_default',
        "inputs": [(image / 255.0).tolist()],
    }
    with open(path, mode='w') as f:
        json.dump(payload, f)


_write_request('./test_data.json', test_images[0])
_write_request('./ruby_data.json', test_images[1])

# Expected labels for the two exported images, to check client predictions.
print("{}, {}".format(test_labels[0], test_labels[1]))
# => 9, 2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment