@sjchoi86
Created March 21, 2020 04:50
Desktop/Framework/Libraries/vibroptml/scripts/demo_tf_gradient.ipynb
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "markdown",
"source": "### Custom Gradient in TF"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import math, random\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nprint (tf.__version__)",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": "1.12.0\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Define custom operation with gradient override"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "$f(x,y,z) = \\sqrt{x^4 + y^4 + z^4}$\n\n$\\frac{\\partial}{\\partial x} f(\\cdot) = \\frac{2x^3}{\\sqrt{x^4 + y^4 + z^4}}$, \n$\\frac{\\partial}{\\partial y} f(\\cdot) = \\frac{2y^3}{\\sqrt{x^4 + y^4 + z^4}}$, \n$\\frac{\\partial}{\\partial z} f(\\cdot) = \\frac{2z^3}{\\sqrt{x^4 + y^4 + z^4}}$"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The goal is to recover the unknown constant $z$ by gradient descent, given samples of $x$, $y$, and $f(x,y,z)$."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def custom_func(x,y,z): # x,y,z: numpy ndarray of shape [N x 1]\n    n = x.shape[0]\n    out = np.zeros(shape=(n,1))\n    for i_idx in range(n):\n        x_i,y_i,z_i = x[i_idx,0],y[i_idx,0],z[i_idx,0]\n        out[i_idx,:] = math.sqrt( (x_i**4) + (y_i**4) + (z_i**4) )\n    out = out.astype(np.float32) # cast into float32\n    return out\n\ndef custom_func_derivatives(x,y,z,grads): # x,y,z,grads: numpy ndarray\n    out = custom_func(x,y,z)\n    # chain rule: upstream gradient times the analytic partial derivatives\n    dx = grads*2*np.power(x,3)/out\n    dy = grads*2*np.power(y,3)/out\n    dz = grads*2*np.power(z,3)/out\n    return dx,dy,dz\n\ndef py_func_wrapper(func, inp, Tout, stateful=True, name=None, grad=None):\n    # tf.py_func wrapper that attaches a custom gradient function\n    rnd_name = 'custom_gradient_name' # gradient name (must be unique: registering the same name twice raises an error)\n    tf.RegisterGradient(rnd_name)(grad)\n    g = tf.get_default_graph()\n    with g.gradient_override_map({\"PyFunc\": rnd_name, \"PyFuncStateless\": rnd_name}):\n        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)\n\ndef grad_wrapper(op, grads):\n    \"\"\"\n    :param op: Operation -\n        operation.inputs = [x,y,z],\n        operation.outputs = [equation]\n    :param grads: upstream gradient of the cost with respect to the op output\n    \"\"\"\n    # The following are tf tensors, so we call tf.py_func() to use the numpy-based gradients\n    x,y,z,out = op.inputs[0],op.inputs[1],op.inputs[2],op.outputs[0]\n    dx,dy,dz = tf.py_func(custom_func_derivatives,[x,y,z,grads],\n                          [tf.float32,tf.float32,tf.float32])\n    return dx,dy,dz\n\ndef model(x, y, z, name=None):\n    z_tile = tf.tile(tf.expand_dims(z,1),multiples=[tf.shape(x)[0],1]) # tile z from [1] to [N x 1] to match x and y\n    with ops.name_scope(name, \"EuDist\", [x, y, z_tile]) as name:\n        out = py_func_wrapper(custom_func, # function\n                              [x, y, z_tile], # input\n                              [tf.float32], # output type of 'custom_func'\n                              name=name,\n                              grad=grad_wrapper)\n    return out\nprint (\"Custom functions ready.\")",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "Custom functions ready.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Training Data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "CONSTANT = 9.87654321 # ground-truth z that training should recover\n\nn_data = 100\ntrain_data = -3.0+6.0*np.random.rand(n_data,2) # x,y sampled uniformly from [-3,3)\ntrain_label = np.sqrt(np.power(train_data[:,0],4) +\n                      np.power(train_data[:,1],4) +\n                      np.power(CONSTANT,4) ).reshape((-1,+1))\nprint (\"train_data: %s train_label:%s\"%(train_data.shape,train_label.shape))",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": "train_data: (100, 2) train_label:(100, 1)\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Model"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Input Placeholder for training data\nph_in = tf.placeholder(dtype=tf.float32, shape=[None, 2]) # [N x 2]\nph_out = tf.placeholder(dtype=tf.float32, shape=[None, 1]) # [N x 1]\nvar_z = tf.Variable(tf.random_uniform(shape=[1]), dtype='float32', trainable=True) # [1]\nph_x, ph_y = tf.split(ph_in, 2, axis=1)\nmodel_out = model(ph_x, ph_y, var_z)\ncost = tf.reduce_mean(tf.reduce_mean(tf.square(ph_out - model_out)))\noptimizer = tf.train.AdagradOptimizer(learning_rate=1.0).minimize(cost)\nprint (\"Model ready.\")",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": "Model ready.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Run"
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "init = tf.global_variables_initializer()\nwith tf.Session() as sess:\n    # Run Initializer\n    sess.run(init)\n    # Epoch Training\n    max_epoch = 10000\n    for e in range(max_epoch):\n        c, _, ed_layer_val, var_z_val = \\\n            sess.run([cost, optimizer, model_out, var_z],\n                     feed_dict={\n                         ph_in: train_data,\n                         ph_out: train_label})\n        if (e==0) or (((e+1)%500)==0):\n            print (\"Epoch:[%d/%d] cost:[%.3e] estimate:[%f] answer:[%f] err:[%.3e]\"%\n                   (e+1,max_epoch,c,var_z_val[0],CONSTANT,np.abs(var_z_val[0]-CONSTANT)))\nprint (\"Done.\")",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "Epoch:[1/10000] cost:[8.612e+03] estimate:[1.503146] answer:[9.876543] err:[8.373e+00]\nEpoch:[500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[1000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[1500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[2000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[2500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[3000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[3500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[4000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[4500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[5000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[5500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[6000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[6500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[7000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[7500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[8000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[8500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[9000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[9500/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nEpoch:[10000/10000] cost:[1.718e-09] estimate:[9.876541] answer:[9.876543] err:[2.072e-06]\nDone.\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.7",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Desktop/Framework/Libraries/vibroptml/scripts/demo_tf_gradient.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
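
The notebook relies on the hand-derived partial derivatives $\frac{2x^3}{\sqrt{x^4 + y^4 + z^4}}$ (and likewise for $y$ and $z$) being correct, since TensorFlow cannot differentiate through `tf.py_func` on its own. A minimal sketch of how one might sanity-check those formulas against central finite differences, without TensorFlow, is given below. The helper names `f`, `analytic_grad`, and `numeric_grad`, the step size `eps`, and the test point are assumptions made for illustration and are not part of the notebook.

```python
import math

def f(x, y, z):
    """f(x, y, z) = sqrt(x^4 + y^4 + z^4), as defined in the notebook."""
    return math.sqrt(x**4 + y**4 + z**4)

def analytic_grad(x, y, z):
    """Closed-form partial derivatives: 2*v^3 / f for v in (x, y, z)."""
    out = f(x, y, z)
    return tuple(2.0 * v**3 / out for v in (x, y, z))

def numeric_grad(x, y, z, eps=1e-6):
    """Central finite differences; eps is an illustrative choice."""
    p = [x, y, z]
    grads = []
    for i in range(3):
        hi = list(p)
        lo = list(p)
        hi[i] += eps
        lo[i] -= eps
        grads.append((f(*hi) - f(*lo)) / (2.0 * eps))
    return tuple(grads)

# Hypothetical test point: x, y inside the training range, z at the target constant
point = (1.5, -2.0, 9.87654321)
max_abs_diff = max(abs(a - n)
                   for a, n in zip(analytic_grad(*point), numeric_grad(*point)))
print("max |analytic - numeric| = %.3e" % max_abs_diff)
```

If the two gradients disagree by much more than the finite-difference error, the analytic expressions passed to the gradient override are the first place to look.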