Skip to content

Instantly share code, notes, and snippets.

@jzstark
Last active February 18, 2019 18:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jzstark/aa4fa6e82667d4fc89214e47febfafb1 to your computer and use it in GitHub Desktop.
Save jzstark/aa4fa6e82667d4fc89214e47febfafb1 to your computer and use it in GitHub Desktop.
Owl-Tensorflow Converter Example: Oscillator

Owl-Tensorflow Converter Example: Oscillator

This example was provided by @tachukao. It is a simple example of learning a periodic oscillator together with its initial condition.

  • Step 1: run the OCaml script oscillator.ml, which generates the file oscillator.pbtxt in the current directory. Depending on n_steps, this step might take a while.
  • Step 2: make sure oscillator.pbtxt and oscillator.py are in the same directory, and that TensorFlow, NumPy, etc. are installed.
  • Step 3: run python oscillator.py; the expected output on screen is a floating-point number.

Here we only assume that the author of the Python script knows where to find the output node (in the collection "result") and the names of the placeholders (x0 and a).

There are many possible sources of error at this stage. One of them is an incompatible TensorFlow version. In that case, find the line tensorflow_version: "1.12.0" in the generated oscillator.pbtxt and change the version number. TensorFlow may also print some warning messages about versions, datasets, etc. Finally, the current scripts ignore some practical details, such as file and directory locations.

(* simple example of learning a periodic oscillator
* and the initial condition *)
#!/usr/bin/env owl
#require "owl-tensorflow"
open Owl
open Owl_tensorflow
open Owl_converter
module N = Dense.Ndarray.S
module G = Owl_computation_cpu_engine.Make (N)
module T = Owl_converter.Make (G)
include Owl_algodiff_generic.Make (G)
(* Simulation set-up.
   [n_steps]: number of steps in the rollout performed by [f].
   [dt]: scale factor that converts a step count into the time argument
   fed to cos/sin when building the target. *)
let n_steps = 100

let dt = 0.01
(* Rollout loss of the linear oscillator model.

   Starting from the initial condition [x0], the state is advanced by the
   linear map [a] once per step, so the state at step k is x_k = a^k x0.
   At each of the [n_steps] steps, the squared L2 distance between x_k and
   the target point (cos t_k, sin t_k) is added to the loss, with
   t_k = k * dt. Returns the accumulated loss as a packed scalar.

   In this script [a] is built as a 2x2 and [x0] as a 2x1 variable (see
   make_cgraph); the target is a 2x1 column. *)
let f a x0 =
  let rec run step c x =
    if step < n_steps then
      (* step 0 compares the initial condition itself; later steps apply [a] *)
      let x = if step = 0 then x else Maths.(a *@ x) in
      (* Time of the current step. This previously used [n_steps] instead of
         [step], which pinned every target to the same point instead of
         tracing the periodic orbit. *)
      let t = pack_flt (float step *. dt) in
      let y = Maths.(of_arrays [| [| cos t |]; [| sin t |] |]) in
      run (step + 1) Maths.(c + (l2norm_sqr' (x - y))) x
    else c
  in
  run 0 (pack_flt 0.) x0
(* Build the symbolic computation graph named "oscillator".
   The graph inputs are the variables "a" (shape 2x2) and "x0" (shape 2x1);
   its single output is the scalar loss computed by [f] on those inputs. *)
let make_cgraph () =
  let a = G.var_arr ~shape:[| 2; 2 |] "a" in
  let x0 = G.var_arr ~shape:[| 2; 1 |] "x0" in
  let loss = f (pack_arr a) (pack_arr x0) in
  let input = [| G.arr_to_node a; G.arr_to_node x0 |] in
  let output = [| G.elt_to_node (unpack_elt loss) |] in
  G.make_graph ~input ~output "oscillator"
(* Render the oscillator computation graph.
   Writes the graph in Graphviz format to "oscillator.dot", then runs the
   external [dot] command to produce "oscillator.pdf". A non-zero exit
   status from [dot] (e.g. Graphviz is not installed) is reported as a
   warning instead of being silently discarded; the function still
   returns unit either way. *)
let visualise_cgraph () =
  let g = make_cgraph () in
  let s0 = G.graph_to_dot g in
  Owl_io.write_file "oscillator.dot" s0;
  match Sys.command "dot -Tpdf oscillator.dot -o oscillator.pdf" with
  | 0 -> ()
  | status ->
    Owl_log.warn "dot exited with status %i; oscillator.pdf was not generated"
      status
(* Evaluate the oscillator graph directly in Owl.
   The two graph inputs are bound to uniformly random arrays of shapes
   [|2; 2|] and [|2; 1|] (in input order), the graph is evaluated, and the
   first output is logged as a float. *)
let eval () =
  let graph = make_cgraph () in
  let inputs = G.get_inputs graph in
  let feed i shape =
    G.assign_arr (G.node_to_arr inputs.(i)) (N.uniform shape)
  in
  feed 0 [| 2; 2 |];
  feed 1 [| 2; 1 |];
  G.eval_graph graph;
  let loss = (G.get_outputs graph).(0) |> G.node_to_elt |> G.unpack_elt in
  Owl_log.info "result: %.2f" loss
(* Convert the oscillator computation graph with the Owl-TensorFlow
   converter and save the text-format result as "oscillator.pbtxt".
   Depending on [n_steps], this step might take a while. *)
let convert () =
  make_cgraph ()
  |> T.convert
  |> T.to_pbtxt
  |> Owl_io.write_file "oscillator.pbtxt"
(* Script entry point: generate oscillator.pbtxt. [let ()] (rather than
   [let _]) makes the compiler check that [convert ()] returns unit. *)
let () = convert ()
#!/usr/bin/env python
from __future__ import print_function
import numpy as np
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.framework import graph_io
def eval(meta_file):
    """Import a serialized MetaGraphDef and evaluate its result tensor.

    The imported graph must contain the tensors ``a:0`` and ``x0:0`` and must
    store its output tensor as the first entry of the ``"result"``
    collection. ``a`` is fed a (2, 2) array and ``x0`` a (2, 1) array, both
    drawn uniformly from [0, 1).

    Args:
        meta_file: path of the MetaGraphDef file to import.

    Returns:
        The evaluated value of the result tensor, which is also printed.
    """
    # The session is opened inside the graph context so it binds to this
    # graph, and the `with` guarantees it is closed afterwards.
    with tf.Graph().as_default() as graph, tf.Session() as sess:
        # Populates `graph`; the returned Saver is not needed here.
        tf.train.import_meta_graph(meta_file)
        a = graph.get_tensor_by_name('a:0')
        x0 = graph.get_tensor_by_name('x0:0')
        result0 = graph.get_collection("result")[0]
        sess.run(tf.global_variables_initializer())
        a_data = np.random.uniform(low=0, high=1, size=(2, 2))
        x0_data = np.random.uniform(low=0, high=1, size=(2, 1))
        y = sess.run(result0, feed_dict={a: a_data, x0: x0_data})
        print(y)
        return y
# Parse the text-format MetaGraphDef produced by the OCaml script
# (oscillator.pbtxt), write it back out in binary form (oscillator.pb),
# and evaluate it.
filename = 'oscillator'
with open(filename + '.pbtxt', 'r') as f:
    metagraph_def = tf.MetaGraphDef()
    file_content = f.read()
    text_format.Merge(file_content, metagraph_def)
# os.path.dirname('oscillator') is '', so name the current directory
# explicitly rather than passing an empty directory to write_graph.
graph_io.write_graph(metagraph_def,
                     os.path.dirname(filename) or os.curdir,
                     os.path.basename(filename) + '.pb',
                     as_text=False)
y = eval(filename + '.pb')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment