Skip to content

Instantly share code, notes, and snippets.

@riccardomurri
Created October 2, 2018 19:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save riccardomurri/c293abc5a51a39a50512779dca72c727 to your computer and use it in GitHub Desktop.
Save riccardomurri/c293abc5a51a39a50512779dca72c727 to your computer and use it in GitHub Desktop.
Compute point in Mandelbrot graph with TF
#! /usr/bin/env python3
"""
Load the TF graph exported to `./mandelbrot.meta` and run it.
"""
import numpy as np
import tensorflow as tf
from tensorflow import (
constant as C_,
Variable as V_,
placeholder as X_,
)
#from tensorflow.saved_model import loader, tag_constants
from tensorflow.train import import_meta_graph
def mandelbrot(x, y, meta_path='mandelbrot.meta'):
    """
    Compute number of iterations of the Mandelbrot function at (x,y).

    Load the graph stored as a MetaGraphDef in `meta_path`, feed the point
    into the ``IN`` placeholder, run the ``OUT`` op, then print and return
    the values of the ``n_out``, ``x_out`` and ``y_out`` variables.

    :param x: real part of the point.
    :param y: imaginary part of the point.
    :param meta_path: MetaGraphDef file to load
      (default: ``mandelbrot.meta`` in the current directory).
    :return: list ``[n_out, x_out, y_out]`` as returned by `Session.run`.
    """
    # Import into a fresh graph instead of the process-wide default one, so
    # every call starts clean rather than re-importing the same nodes into
    # a graph that already contains them.
    graph = tf.Graph()
    with graph.as_default():
        import_meta_graph(meta_path)
        input_ = graph.get_tensor_by_name('IN:0')
        output_ = graph.get_operation_by_name('OUT')
        fetches = [
            graph.get_tensor_by_name('n_out:0'),
            graph.get_tensor_by_name('x_out:0'),
            graph.get_tensor_by_name('y_out:0'),
        ]
        # must be built inside `graph` so it covers the imported variables
        init = tf.global_variables_initializer()
    feed = {input_: [x, y]}
    # Using the session as a context manager closes it on exit;
    # `Session.as_default()` on its own would leave it open.
    with tf.Session(graph=graph) as session:
        session.run(init, feed)
        # run the graph at the chosen point
        session.run(output_, feed)
        # read all three outputs in a single run call
        values = session.run(fetches, feed)
    print("({0},{1}): {2}".format(x, y, values))
    return values
# Script entry point: load the exported graph and evaluate it at a
# single sample point of the complex plane.
if __name__ == '__main__':
    mandelbrot(x=0.25, y=-0.15)
#! /usr/bin/env python3
"""
Generate a TF graph computing the number of iterations for a point in the Mandelbrot plane,
and export it to `./mandelbrot.meta` (a TensorBoard event file is also written to `./events`).
"""
import numpy as np
import tensorflow as tf
from tensorflow import (
constant as C_,
Variable as V_,
placeholder as X_,
)
#from tensorflow.saved_model import simple_save as save
from tensorflow.train import export_meta_graph
def mandelbrot(x, y, meta_path='mandelbrot.meta', logdir='./events'):
    """
    Compute number of iterations of the Mandelbrot function at (x,y).

    Build the graph with `mandelbrot_`, write it to a TensorBoard event
    file under `logdir`, run it once at the point (x, y) and print the
    resulting ``[n_out, x_out, y_out]`` values, then export the graph as a
    MetaGraphDef to `meta_path`.

    :param x: real part of the point.
    :param y: imaginary part of the point.
    :param meta_path: output file for the exported MetaGraphDef.
    :param logdir: directory for the TensorBoard event file.
    :return: list ``[n_out, x_out, y_out]`` as returned by `Session.run`.
    """
    g, input_, output_ = mandelbrot_()
    fetches = [
        g.get_tensor_by_name('n_out:0'),
        g.get_tensor_by_name('x_out:0'),
        g.get_tensor_by_name('y_out:0'),
    ]
    feed = {input_: [x, y]}
    writer = tf.summary.FileWriter(logdir, g)
    try:
        with g.as_default():
            # the initializer op must live in `g`, not in the default graph
            init = tf.global_variables_initializer()
        # Using the session as a context manager closes it on exit;
        # `Session.as_default()` on its own would leave it open.
        with tf.Session(graph=g) as session:
            # initialize vars with default values
            session.run(init, feed)
            # run the graph at the chosen point
            session.run(output_, feed)
            # read all three outputs in a single run call
            values = session.run(fetches, feed)
    finally:
        # flush pending events (including the graph) and release the file
        writer.close()
    print("({0},{1}): {2}".format(x, y, values))
    # No extra collection registration is needed here: the output variables
    # were created inside `g` and are therefore in its GLOBAL_VARIABLES
    # collection, which `export_meta_graph` serializes by default.
    export_meta_graph(
        filename=meta_path,
        graph=g,
    )
    return values
def mandelbrot_(maxiter=255):
    """
    Return a graph computing the Mandelbrot iteration count at (x,y).

    The graph reads the point from the float32 placeholder ``IN`` (shape
    ``[2]``, holding x then y). It iterates ``z -> z*z + c`` starting from
    ``z = c`` while fewer than `maxiter` iterations have run and
    ``|z|**2 < 4``. The ``OUT`` op then stores the iteration count in
    ``n_out`` (int32, shape ``[1]``) and the final z in ``x_out`` / ``y_out``.

    :param maxiter: upper bound on the number of iterations.
    :return: tuple ``(graph, input_placeholder, output_op)``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # input: the point c = x + i*y
        input_ = X_(tf.float32, shape=[2], name='IN')
        x = input_[0]
        y = input_[1]
        # Output variables. `dtype` must be passed by keyword: the second
        # positional parameter of `tf.Variable` is `trainable`.
        n_ = V_([0], dtype=tf.int32, name='n_out')
        x_ = V_(0.0, dtype=tf.float32, name='x_out')
        y_ = V_(0.0, dtype=tf.float32, name='y_out')
        # main loop
        i_ = tf.constant(0)

        def cond(i_, z_re_, z_im_):
            # keep going while under the iteration cap and |z|^2 < 4
            return tf.logical_and(
                tf.less(i_, maxiter),
                (z_re_*z_re_ + z_im_*z_im_) < 4)

        def body(i_, z_re_, z_im_):
            return [
                i_+1,                           # iteration count
                z_re_*z_re_ - z_im_*z_im_ + x,  # real part of z
                2*z_re_*z_im_ + y,              # imag part of z
            ]

        l_ = tf.while_loop(cond, body, [i_, x, y],
                           name='mandelbrot',
                           parallel_iterations=1,
                           back_prop=False,
                           return_same_structure=True)
        # store the loop results into the output variables
        with tf.control_dependencies(l_):
            output_ = tf.group(
                n_.assign([l_[0]]),
                x_.assign(l_[1]),
                y_.assign(l_[2]),
                name='OUT',
            )
    return (
        graph,    # graph
        input_,   # inputs
        output_,  # outputs
    )
# Script entry point: build, run and export the graph for a single
# sample point of the complex plane.
if __name__ == '__main__':
    mandelbrot(x=0.25, y=-0.15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment