Skip to content

Instantly share code, notes, and snippets.

@jzstark
Last active August 21, 2019 12:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jzstark/d2bf87e02a0ebc0c39ed237af82e0a2c to your computer and use it in GitHub Desktop.
Save jzstark/d2bf87e02a0ebc0c39ed237af82e0a2c to your computer and use it in GitHub Desktop.

Owl-Tensorflow Converter Example: MNIST CNN Training

This example trains a CNN on MNIST. The network is built in Owl and then trained in TensorFlow.

  • Step 1 : run the OCaml script tfgraph_train.ml. It generates the file tf_convert_mnist.pbtxt in the current directory.
  • Step 2 : make sure tf_convert_mnist.pbtxt and tfgraph_train.py are in the same directory, and that TensorFlow, NumPy, etc. are installed.
  • Step 3 : run python tf_converter_mnist.py. The script prints its training progress: every 100 steps, it shows the current loss value and model accuracy.

The Python script only needs to know two things about the converted graph: where to find the output node (in the collection "result") and the name of the input placeholder (x:0).

Many things can go wrong at this stage. One common cause is an incompatible TensorFlow version. In that case, find the line tensorflow_version: "1.12.0" in the generated .pbtxt file (tf_convert_mnist.pbtxt in this example) and change the version number to match your installation. TensorFlow may also print warnings about versions, the dataset loader, etc. Note too that the scripts make simple assumptions about file and directory locations.

#!/usr/bin/env owl
#require "owl-tensorflow"
open Owl
open Owl_tensorflow
open Owl_converter
module G = Owl_computation_cpu_engine.Make (Dense.Ndarray.S)
module T = Owl_converter.Make (G)
module CGCompiler = Owl_neural_compiler.Make (G)
open CGCompiler.Neural
open CGCompiler.Neural.Graph
open CGCompiler.Neural.Algodiff
(** [make_network input_shape] builds the MNIST CNN used in this example.

    Layers, in order:
    - input scaling: every value is divided by 256;
    - one 5x5 convolution with 32 ReLU filters and stride 1;
    - 2x2 max pooling with stride 2;
    - a 1024-unit ReLU fully-connected layer;
    - a 10-way linear layer with softmax activation.

    The output is therefore class probabilities, not logits. The resulting
    network is named ["mnist"].

    NOTE(review): dividing by 256 assumes raw 0-255 pixel values. Confirm that
    the data fed at training time is in that range. *)
let make_network input_shape =
  let scaled =
    lambda (fun x -> Maths.(x / pack_flt 256.)) (input input_shape)
  in
  let conv = conv2d ~act_typ:Activation.Relu [|5;5;1;32|] [|1;1|] scaled in
  let pooled = max_pool2d [|2;2|] [|2;2|] conv in
  let hidden = fully_connected ~act_typ:Activation.Relu 1024 pooled in
  let probs = linear ~act_typ:Activation.(Softmax 1) 10 hidden in
  get_network ~name:"mnist" probs
(* Instantiate the CNN for single 28x28x1 images and initialise its
   parameters before tracing it. *)
let network = make_network [|28;28;1|]

let () = Graph.init network

(* Batch dimension baked into the symbolic input below. The Python training
   script must feed batches of exactly this many images. *)
let batch_size = 100

(* Trace one forward pass on a symbolic input named "x" to obtain the
   computation graph. In the exported TF graph this input is looked up as
   "x:0". *)
let x = G.var_arr "x" ~shape:[|batch_size;28;28;1|] |> pack_arr

let y = Graph.forward network x |> fst

let output = [| unpack_arr y |> G.arr_to_node |]

let input = [| unpack_arr x |> G.arr_to_node |]

let cgraph = G.make_graph ~input ~output "graph_diff"

(* Convert the Owl computation graph to a TensorFlow graph in text (pbtxt)
   format and write it under the name the Python script reads. *)
let pbtxt = T.(convert cgraph |> to_pbtxt)

let () = Owl_io.write_file "tf_convert_mnist.pbtxt" pbtxt
#!/usr/bin/env python
import numpy as np
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.framework import graph_io
# NOTE(review): presumably works around Intel OpenMP's duplicate-runtime check
# (KMP_DUPLICATE_LIB_OK). Confirm it is still needed on your platform.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from ../data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("../data/", one_hot=True)

# batch_size must match the batch dimension (100) of the "x" input in the
# converted graph. filename is the base name (no extension) of the graph file
# written by the OCaml script.
batch_size, filename = 100, 'tf_convert_mnist'
# Parse the text-format MetaGraphDef produced by the OCaml script and re-write
# it in binary form as <filename>.pb, which is what gets imported below.
with open(filename + '.pbtxt', 'r') as f:
    file_content = f.read()
metagraph_def = tf.compat.v1.MetaGraphDef()
text_format.Merge(file_content, metagraph_def)
graph_io.write_graph(
    graph_or_graph_def=metagraph_def,
    logdir=os.path.dirname(filename),
    name=os.path.basename(filename) + '.pb',
    as_text=False,
)
# Import the converted graph into a fresh TF graph and train it on MNIST.
# (An unused `input_data = np.random.rand(...)` assignment used to sit here.
# It shadowed the imported MNIST `input_data` module and has been removed.)
with tf.Graph().as_default():
    config = tf.compat.v1.ConfigProto()
    # Let TF grow GPU memory usage on demand.
    config.gpu_options.allow_growth = True
    # Context-managed session so its resources are released after training.
    with tf.compat.v1.Session(config=config) as sess:
        saver = tf.compat.v1.train.import_meta_graph(filename + '.pb')
        graph = tf.compat.v1.get_default_graph()
        x = graph.get_tensor_by_name('x:0')
        y = tf.compat.v1.placeholder("float", [None, 10])
        # The Owl network's last layer is a softmax, so `result` holds class
        # probabilities, not logits.
        result = tf.compat.v1.get_collection("result")[0]
        correct_pred = tf.equal(tf.argmax(result, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        # Extra tensors for training.
        # Cross-entropy is computed directly on the probabilities. Passing them
        # to softmax_cross_entropy_with_logits would apply softmax a second
        # time. Clipping keeps log() away from 0.
        cross_entropy = tf.reduce_mean(
            -tf.reduce_sum(
                y * tf.math.log(tf.clip_by_value(result, 1e-10, 1.0)),
                axis=1))
        train_step = tf.compat.v1.train.AdamOptimizer(1e-4).minimize(
            cross_entropy)
        sess.run(tf.compat.v1.global_variables_initializer())
        # Begin training.
        for i in range(2000):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # TF's MNIST loader yields float pixels in [0, 1], but the Owl
            # graph divides its input by 256 itself. Restore the 0-255 range
            # it was built for.
            batch_x = np.reshape(batch_x, (-1, 28, 28, 1)) * 255.0
            feed = {x: batch_x, y: batch_y}
            sess.run(train_step, feed_dict=feed)
            if i % 100 == 0:
                minibatch_loss, acc = sess.run(
                    [cross_entropy, accuracy], feed_dict=feed)
                print("Loss:%s" % str(minibatch_loss))
                print("Accuracy:%s\n" % str(acc))
        # saver.save(sess, "owl_model.ckpt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment