import tensorflow as tf

# this function is used for quantizing activations
def quantize( zr, k ): # zr => number to quantize, k => number of bits to use
    scaling = tf.cast( tf.pow( 2.0, k ) - 1, tf.float32 ) # number of quantization steps is 2^k - 1
    return tf.round( scaling * zr )/scaling # round the number to the nearest quantized value
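# A worked example (not part of the original gist): with k = 2 bits,
# scaling = 2^2 - 1 = 3, so the representable levels in [0, 1] are {0, 1/3, 2/3, 1}.
# quantize(0.4, 2) = round(3 * 0.4) / 3 = round(1.2) / 3 = 1/3 ~= 0.333,
# i.e. 0.4 snaps to the nearest level.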
# this function applies quantization to activations
def shaped_relu( x, k = 1.0 ): # x => number to be quantized, k => number of bits to use
    act = tf.clip_by_value( x, 0, 1 ) # clip the activation between 0 and 1 to stop overflow issues
    quant = quantize( act, k ) # quantize the value
    return act + tf.stop_gradient( quant - act ) # use the stop gradient trick
    # tf.stop_gradient(quant - act) = quant - act  on the forward path
    #                               = 0            on the backward path
    # so the layer outputs 'quant' on the forward pass and uses the gradient of 'act' on the backward pass
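# A minimal sketch (not from the original gist) checking the straight-through
# estimator: the forward value is the quantized activation, while the gradient
# is that of the clipped activation. The names below are illustrative only;
# left commented out so the module does not build demo ops on import.
# x_demo = tf.placeholder( tf.float32, shape = [ 4 ] )
# y_demo = shaped_relu( x_demo, k = 2.0 )
# grad_demo = tf.gradients( y_demo, x_demo )[ 0 ] # 1 where 0 < x < 1, 0 outside the clip range
# with tf.Session() as sess:
#     print( sess.run( [ y_demo, grad_demo ], { x_demo: [ -0.5, 0.2, 0.7, 1.5 ] } ) )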
# use the TWN (Ternary Weight Network) method
def trinarize( x, nu = 1.0 ): # x => the weights to trinarize, nu => the sparsity factor
    x_shape = x.get_shape()
    thres = nu * tf.reduce_mean( tf.abs( x ) ) # calculate the threshold
    g_e = tf.cast( tf.greater_equal( x, thres ), tf.float32 ) # 1 where x >= thres
    l_e = tf.cast( tf.less_equal( x, -thres ), tf.float32 ) # 1 where x <= -thres
    unmasked = tf.multiply( g_e + l_e, x ) # keep x where x >= thres or x <= -thres, otherwise set to 0
    # unmasked now has all 0's set correctly
    eta = tf.reduce_mean( tf.abs( unmasked ) ) # average magnitude of the remaining weights (the mean is taken over all elements, zeros included)
    t_x = tf.multiply( l_e, -eta ) # every weight that had x <= -thres is now set to -eta
    t_x = t_x + tf.multiply( g_e, eta ) # add in every weight that had x >= thres and set them to eta
    return x + tf.stop_gradient( t_x - x ) # use the stop gradient trick to quantize on the forward path and backpropagate to the real weights
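# A worked example (not part of the original gist), with nu = 1.0 and
# x = [0.5, -0.4, 0.1, -0.05]:
#   thres    = mean(|x|) = (0.5 + 0.4 + 0.1 + 0.05) / 4 = 0.2625
#   g_e      = [1, 0, 0, 0], l_e = [0, 1, 0, 0]
#   unmasked = [0.5, -0.4, 0, 0]
#   eta      = mean(|unmasked|) = 0.9 / 4 = 0.225
#   t_x      = [0.225, -0.225, 0, 0]
# so the forward pass sees ternary weights in {-eta, 0, +eta}, while the
# gradient flows to the real valued x.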
# create a convolutional layer
def get_conv_layer( x, training, no_filt = 128, nu = None, act_prec = None ):
    '''
    x => the input activations
    training => a boolean flag to indicate training
    no_filt => the number of filters for the convolution
    nu => the sparsity factor for TWN; if set to None the weights are not quantized
    act_prec => the number of bits to quantize the activations to; if set to None there is no quantization
    '''
    if nu is None: # call another function (not included here) if no TWN
        return get_conv_layer_full_prec( x, training, no_filt )
    filter_shape = [ 3, x.get_shape()[-1], no_filt ] # determine the correct size of the convolution filter
    conv_filter = tf.get_variable( "conv_filter", filter_shape ) # make a variable to store the real valued weights
    tf.summary.histogram( "conv_filter_fp", conv_filter ) # add for debugging
    conv_filter = trinarize( conv_filter, nu = nu ) # apply TWN quantization
    cnn = tf.nn.conv1d( x, conv_filter, 1, padding = "SAME" ) # use the quantized weights to compute the convolution
    cnn = tf.layers.max_pooling1d( cnn, 2, 2 ) # add max pooling
    cnn = tf.layers.batch_normalization( cnn, training = training ) # add batch normalization
    tf.summary.histogram( "conv_dist", cnn ) # for debugging
    tf.summary.histogram( "conv_filter_tri", conv_filter ) # for debugging
    if act_prec is not None: # quantize the activations
        cnn = shaped_relu( cnn, act_prec ) # apply activation quantization
    else:
        cnn = tf.nn.relu( cnn )
    return cnn
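
if __name__ == "__main__":
    # A minimal usage sketch (not from the original gist); the input shape and
    # hyper-parameters below are assumptions for illustration. Each layer gets its
    # own variable scope so the "conv_filter" variable name does not clash.
    inputs = tf.placeholder( tf.float32, [ None, 1024, 64 ] ) # (batch, sequence length, channels)
    is_training = tf.placeholder( tf.bool, [] )
    with tf.variable_scope( "conv1" ):
        net = get_conv_layer( inputs, is_training, no_filt = 128, nu = 0.7, act_prec = 2.0 )
    with tf.variable_scope( "conv2" ):
        net = get_conv_layer( net, is_training, no_filt = 256, nu = 0.7, act_prec = 2.0 )
    print( net ) # each layer halves the sequence length via max pooling -> shape (?, 256, 256)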