# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @memo
# Last active February 28, 2018 06:37
# Show Gist options
#   - Save memo/c21356c3a2cf39948985f05352ea5212 to your computer and use it in GitHub Desktop.
# Save memo/c21356c3a2cf39948985f05352ea5212 to your computer and use it in GitHub Desktop.
# very quick & simple dictionary / json based graph builder for tensorflow
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 27 02:46:51 2018
@author: memo
very quick & simple dictionary / json based graph builder for tensorflow
( inspired by https://github.com/dribnet/discgen/blob/master/discgen/vae.py#L43-L163 )
"""
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
from pprint import pprint
import msa.tf.ops
def example():
    '''Build a small conv encoder/decoder from op-spec lists and return both op lists.

    Resets the default TensorFlow graph, builds an encoder under variable scope
    'encoder' and a decoder under 'decoder' with build_graph(), patching the decoder
    spec with sizes taken from the built encoder in between.

    Returns:
        (encoder_ops, decoder_ops): the op-output lists returned by build_graph for
        the encoder and the decoder respectively.
    '''
    tf.reset_default_graph()

    # per-op-type default kwargs: { <opname> : kwargs, ... }
    default_op_args = {
        'conv2d' : { 'padding':'same', 'kernel_size':(3,3), 'strides':(1,1) },
        'conv2d_transpose' : { 'kernel_size':(2,2), 'strides':(2,2) },
    }

    # list of dicts [ {'op':<opname>, kwargs }, ... ]
    encoder_ops_info = [
        { 'op':'conv2d', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'identity', 'name':'pre_z_conv' },  # tagged so the conv shape can be looked up below
        { 'op':'flatten' },
        { 'op':'dense', 'units':1024 }, { 'op':'batch_norm' }, { 'op':'relu' },
    ]

    decoder_ops_info = [
        { 'op':'dense', 'units':128, 'name':'z' }, { 'op':'batch_norm' }, { 'op':'relu' },
        # 'units' is a placeholder, overwritten below with the encoder's flattened size
        { 'op':'dense', 'units':0, 'name':'post_z_flat' }, { 'op':'batch_norm' }, { 'op':'relu' },
        # 'shape' is filled in below from the encoder's pre-flatten conv shape
        { 'op':'tf.reshape', 'name':'post_z_conv' }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d_transpose','filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d_transpose', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d_transpose', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' },
        { 'op':'conv2d', 'kernel_size':(1,1), 'filters':3 }, { 'op':'tanh', 'name':'output'}
    ]

    x = tf.placeholder(tf.float32, [None, 64, 64, 3])

    # build encoder (build_graph prints any errors itself, so the error list is not kept)
    with tf.variable_scope('encoder'):
        encoder_ops, _ = build_graph(x, encoder_ops_info, default_op_args)

    # The decoder must mirror the encoder's last conv shape, which is only known once
    # the encoder is built: find the tensor tagged 'pre_z_conv' and patch the decoder spec.
    # list(...) makes the [0] indexing work whether the lookup helpers return a list
    # (Python 2 filter) or a lazy iterator (Python 3 filter).
    pre_z_conv = list(get_tensors_by_name(encoder_ops, 'pre_z_conv'))[0]
    conv_dims = pre_z_conv.shape.as_list()[1:]  # static dims, excluding the batch dim

    # flattened size is the product of all dims except the batch size
    list(get_ops_by_name(decoder_ops_info, 'post_z_flat'))[0]['units'] = int(np.prod(conv_dims))
    # reshape back to the conv shape; -1 leaves the batch size free, so the decoder does
    # not depend on the encoder's runtime input tensor
    list(get_ops_by_name(decoder_ops_info, 'post_z_conv'))[0]['shape'] = [-1] + conv_dims

    # now build decoder, fed by the last encoder output
    with tf.variable_scope('decoder'):
        decoder_ops, _ = build_graph(encoder_ops[-1], decoder_ops_info, default_op_args)

    return encoder_ops, decoder_ops
# Sample console output from a previous run of example(), kept for reference only
# (this bare string is never used at runtime).
# NOTE(review): it looks stale relative to the current code - it shows a 256x256 input
# (example() now uses a 64x64 placeholder) and an unnamed 1024-unit decoder dense layer
# where the spec now has 'post_z_flat'; regenerate if the exact output matters.
'''
Output:
--------------------------------------------------------------------------------
> msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 64} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_2/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_2/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_1:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_2:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 128} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_4/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_3:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_5/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 256} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_6/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_5:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.identity {'name': 'pre_z_conv'} --> Tensor("encoder/pre_z_conv:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.flatten {} --> Tensor("encoder/Flatten/flatten/Reshape:0", shape=(?, 262144), dtype=float32)
> msa.tf.ops.dense {'units': 1024} --> Tensor("encoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_7/batchnorm/add_1:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_6:0", shape=(?, 1024), dtype=float32)
--------------------------------------------------------------------------------
23 ops added
--------------------------------------------------------------------------------
> msa.tf.ops.dense {'units': 128, 'name': 'z'} --> Tensor("decoder/z/BiasAdd:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization/batchnorm/add_1:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.dense {'units': 1024} --> Tensor("decoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_2/batchnorm/add_1:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_1:0", shape=(?, 1024), dtype=float32)
> tf.reshape {'shape': <tf.Tensor 'Shape:0' shape=(4,) dtype=int32>, 'name': 'post_z_conv'} --> Tensor("decoder/post_z_conv:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_2:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_3:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 256} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_2/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_5:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 128} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_2/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_7/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_6:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_8/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_7:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 64} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_3/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_9/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_8:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.conv2d {'kernel_size': (1, 1), 'filters': 3} + defaults {'padding': 'same', 'strides': (1, 1)} --> Tensor("decoder/conv2d_4/BiasAdd:0", shape=(?, 256, 256, 3), dtype=float32)
> msa.tf.ops.tanh {'name': 'output'} --> Tensor("decoder/output:0", shape=(?, 256, 256, 3), dtype=float32)
--------------------------------------------------------------------------------
29 ops added
'''
#%%
# Prefixes tried, in this order, when build_graph resolves an op name to a function;
# the first prefix under which the name resolves wins, so earlier entries take priority.
namespaces = [
    '',                   # the op name exactly as written (e.g. 'tf.reshape')
    'msa.tf.ops',         # project op wrappers
    'tf',
    'tf.layers',
    'tf.nn',
    'tf.contrib.layers',
]
def get_tensors_by_name(tensors, name):
    '''Return a list of the tensors whose .name contains `name` as a substring.

    Substring (rather than exact) matching is used because TF tensor names carry
    scope prefixes and output suffixes, e.g. 'encoder/pre_z_conv:0'. A short
    `name` may therefore match several tensors; the list keeps input order.
    '''
    # A list, not filter(): on Python 3 filter() returns a lazy iterator, which
    # callers indexing the result (e.g. [0]) cannot use.
    return [tensor for tensor in tensors if name in tensor.name]
def get_ops_by_name(ops_info, name):
    '''Return a list of the op-info dicts whose 'name' entry contains `name`.

    Matching is by substring, so e.g. 'z' also matches 'post_z_flat'. Entries that
    are not dicts, or have no 'name' key, are skipped. The returned dicts are the
    same objects held in ops_info, so mutating them edits ops_info in place.
    '''
    # A list, not filter(): on Python 3 filter() returns a lazy iterator, which
    # callers indexing the result (e.g. [0]) cannot use. The isinstance check keeps
    # non-dict entries (which build_graph reports as errors) from crashing the lookup.
    return [op for op in ops_info
            if isinstance(op, dict) and 'name' in op and name in op['name']]
def _resolve_op(op_str, search_namespaces):
    '''Resolve an op name to a callable by trying each namespace prefix in order.

    Returns (fn_path, fn) for the first '<namespace>.<op_str>' (or the bare op_str
    for an empty namespace) that evaluates to a callable, else (None, None).
    '''
    for namespace in search_namespaces:
        try:
            fn_path = '.'.join([namespace, op_str]) if namespace else op_str
            # Evaluate against module globals only, so an op name can never resolve
            # to a local variable of the caller.
            # NOTE(review): eval() executes arbitrary expressions - only feed it op
            # specs (e.g. JSON configs) that come from a trusted source.
            op_fn = eval(fn_path, globals())
        except Exception:
            continue  # not resolvable under this prefix; try the next one
        if callable(op_fn):
            return fn_path, op_fn
    return None, None


def build_graph(input_T, ops_info, default_op_args=None, verbose=True,
                search_namespaces=None):
    '''Chain the ops described in ops_info, feeding each op's output into the next.

    Args:
        input_T: value passed as the first positional argument to the first op.
        ops_info: list of dicts {'op': <op name>, **kwargs}. The op name is looked
            up under each prefix in search_namespaces; the remaining keys are
            passed to it as keyword arguments.
        default_op_args: optional dict {<op name>: {kwarg: value}}. For an entry
            whose op name appears here, kwargs it does not set are filled in.
        verbose: if True, print each op call and its result.
        search_namespaces: prefixes tried in order to resolve op names; defaults
            to the module-level `namespaces` list.

    Returns:
        (ops, errors): ops is the list of outputs of each successful op call, in
        order; errors is a list of {message: offending entry} dicts. A failing
        entry is skipped and the previous output is passed on to the next op.
    '''
    if search_namespaces is None:
        search_namespaces = namespaces  # module-level default search order
    print('-'*80)
    errors = []

    def handle_error(msg, op_dict):
        # report immediately and record for the caller
        print('\n** ERROR', msg, op_dict,'\n')
        errors.append( {msg : op_dict} )

    t = input_T  # running value: each op is applied to the previous op's output
    ops = []
    for op_dict in ops_info:
        if not isinstance(op_dict, dict):
            handle_error('unknown entry type', op_dict)
            continue
        if 'op' not in op_dict:
            handle_error('missing op key', op_dict)
            continue

        op_str = op_dict['op']
        fn_path, op_fn = _resolve_op(op_str, search_namespaces)
        if op_fn is None:
            handle_error('function not found', op_dict)
            continue

        # kwargs for the call: every key of the entry except the op name itself
        args = {k: v for k, v in op_dict.items() if k != 'op'}
        if verbose:
            print('>', fn_path, args, end=' ')
        # fill in per-op-type defaults only for kwargs the entry did not set
        if default_op_args and op_str in default_op_args:
            extra_args = {k: v for k, v in default_op_args[op_str].items()
                          if k not in args}
            if extra_args:
                if verbose:
                    print('+ defaults', extra_args, end=' ')
                args.update(extra_args)

        try:
            t = op_fn(t, **args)
        except Exception as e:
            # best effort: record the failure and keep chaining from the last good output
            handle_error(fn_path + ' : ' + str(e), op_dict)
            continue
        if verbose:
            print('-->', t)
        ops.append(t)

    print('-'*80)
    print('{} ops added'.format(len(ops)))
    if errors:
        print('{} errors found:'.format(len(errors)))
        pprint(errors)
    return ops, errors
#%%
if __name__ == '__main__':
    # Script entry point: build the example encoder/decoder graph. The two op lists
    # are bound at module level so they can be inspected after the script has run.
    encoder_ops, decoder_ops = example()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment