Skip to content

Instantly share code, notes, and snippets.

@XinDongol
Created November 17, 2018 04:48
Show Gist options
  • Save XinDongol/d17507b63646de419792b167cb65ed17 to your computer and use it in GitHub Desktop.
Save XinDongol/d17507b63646de419792b167cb65ed17 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
from math import sqrt
%matplotlib inline
import seaborn as sns
# Colour cycle: 9 shades of "coolwarm", one per curve drawn in the figure
# (the y=x reference line plus eight quantizer curves).
sns.set_palette(sns.color_palette("coolwarm", 9))

# TensorFlow 1.x session that is shown zero GPU devices, so all ops run on CPU.
config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(config=config)
def scale_tanh(x, x_scale, y_scale):
    """Return ``y_scale * tanh(x_scale * x)``.

    ``x_scale`` stretches the curve along the x-axis (a larger positive value
    gives a steeper transition). ``y_scale`` stretches it along the y-axis, so
    the output is bounded by ``|y_scale|`` in magnitude.
    """
    steepened = x_scale * x
    return y_scale * tf.tanh(steepened)
def move_scaled_tanh(x, x_scale, y_scale, x_range, x_move, y_move):
    """Turn a scaled tanh into one smooth step by shifting, clipping and lifting it.

    Steps:

    1. The scaled tanh is evaluated at ``x + x_move``, which moves its centre
       to ``x = -x_move``.
    2. The result is clipped to ``[-x_range/2, x_range/2]``, so the step is
       ``x_range`` tall.
    3. ``y_move`` is added to translate the step along the y-axis.
    """
    half_range = 0.5 * x_range
    shifted = scale_tanh(x + x_move, x_scale, y_scale)
    return tf.clip_by_value(shifted, -half_range, half_range) + y_move
def tanh_appro(x, x_scale, y_scale, k, delta, extra_width=0.1):
    """Return a smooth two-step staircase built from clipped, shifted tanh curves.

    Parameters
    ----------
    x : tensor
        Input values.
    x_scale : float
        Steepness of every tanh step.
    y_scale : float or tensor
        Amplitude of the first step's tanh. ``quantize`` passes
        ``0.5*delta / tanh(x_scale*0.5*delta)``.
    k : int
        Bit width. Currently unused: exactly two steps are built whatever its
        value. NOTE(review): a commented-out ``for i in range(1, 2**k)`` loop
        used to live here, which suggests ``k`` was meant to set the number of
        steps; confirm before relying on it.
    delta : float
        Width and height of the first step.
    extra_width : float, optional
        How much wider (and taller) the second step is than ``delta``.
        Defaults to 0.1, the value that was previously hard-coded.

    Returns
    -------
    tensor
        Sum of the two steps:

        - the first rises from 0 to ``delta``, centred at ``x = delta/2``;
        - the second rises by ``delta + extra_width``, centred at
          ``x = delta + (delta + extra_width)/2``.
    """
    # First step: tanh centred at x = delta/2, clipped to +-delta/2, lifted by delta/2.
    first_step = move_scaled_tanh(
        x, x_scale, y_scale, delta, (-1 + 0.5) * delta, 0.5 * delta)

    # Second step starts where the first one ends (x = delta) and is
    # `extra_width` wider. Its amplitude is rescaled the same way as
    # y_scale, so its tanh reaches the clip bound half a step from its centre.
    delta2 = extra_width + delta
    y_scale2 = (0.5 * delta2) / tf.tanh(x_scale * 0.5 * delta2)
    second_step = move_scaled_tanh(
        x, x_scale, y_scale2, delta2, -0.5 * delta2 - delta, 0.5 * delta2)

    return first_step + second_step
def quantize(x, k, x_scale):
    """Smooth k-bit quantizer whose backward pass is the identity.

    Parameters
    ----------
    x : tensor
        Values to quantize.
    k : int
        Bit width. The nominal step size is ``delta = 1 / (2**k - 1)``.
        Must satisfy ``2**k - 1 > 0``, i.e. ``k > 0``.
    x_scale : float
        Steepness of the tanh curves. Larger positive values give sharper
        steps.

    Returns
    -------
    tensor
        Forward value ``tanh_appro(x, x_scale, y_scale, k, delta)``. The
        gradient is passed straight through: the upstream gradient is returned
        unchanged.

    Raises
    ------
    ValueError
        If ``2**k - 1 <= 0``. In that case ``delta`` would divide by zero
        (k == 0) or be negative (k < 0), which inverts the clip bounds.
    """
    levels = 2 ** k - 1.0
    if levels <= 0:
        raise ValueError("k must be positive, got %r" % (k,))
    delta = float(1.0 / levels)
    # Chosen so that y_scale * tanh(x_scale * delta/2) == delta/2.
    y_scale = (0.5 * delta) / tf.tanh(x_scale * 0.5 * delta)

    @tf.custom_gradient
    def _quantize(x):
        # Forward: smooth staircase. Backward: pass dy through unchanged.
        return tanh_appro(x, x_scale, y_scale, k, delta), lambda dy: dy

    return _quantize(x)
# `mpl_toolkits.axes_grid` is deprecated and removed in current Matplotlib;
# `mpl_toolkits.axisartist.axislines` provides the same SubplotZero.
from mpl_toolkits.axisartist.axislines import SubplotZero

fig = plt.figure(figsize=(6.9, 6.9 * 0.9))  # for one column
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
# Draw arrow-headed axis lines through the origin and hide the usual frame.
for direction in ["xzero", "yzero"]:
    ax.axis[direction].set_axisline_style("-|>")
    ax.axis[direction].set_visible(True)
for direction in ["left", "right", "bottom", "top"]:
    ax.axis[direction].set_visible(False)

plt.rcParams['image.cmap'] = 'Blues'
fa = quantize
bit = 2

plt.xlim(-0.1, 0.85)
plt.ylim(-0.1, 0.85)
plt.tick_params(labelsize=50)
# The tick at 0 is left unlabelled so it does not crowd the origin.
plt.xticks([0, 0.2, 0.4, 0.6, 0.8], ['', '0.2', '0.4', '0.6', '0.8'])
plt.yticks([0, 0.2, 0.4, 0.6, 0.8], ['', '0.2', '0.4', '0.6', '0.8'])

# Dashed identity reference line.
x = tf.range(-0.1, 1.5, 0.01)
x_vals = sess.run(x)
plt.plot(x_vals, x_vals, linestyle='--', label='y=x', alpha=1, linewidth=2)

# One curve per tanh steepness b (quantize's x_scale). The steepest curve,
# b=1000, is drawn opaque and thicker.
x = tf.range(0, 1, 0.01)
x_vals = sess.run(x)  # evaluate the shared x grid once
for b in (3, 5, 10, 15, 20, 30, 45):
    plt.plot(x_vals, sess.run(fa(x, bit, b)), label='b={}'.format(b), alpha=0.99)
plt.plot(x_vals, sess.run(fa(x, bit, 1000)), label='b=1000', alpha=1, linewidth=2)

plt.tight_layout()
plt.legend(bbox_to_anchor=(1, 1), ncol=5)
plt.savefig('./warm_bin_grad.pdf')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment