@shtern
Created July 30, 2019 19:34
import numpy as np
import tensorflow as tf

class AlexNet(object):

    def __init__(self, x, keep_prob, num_classes, skip_layer,
                 weights_path='DEFAULT'):
        # Parse input arguments into class variables
        self.X = x
        self.NUM_CLASSES = num_classes
        self.KEEP_PROB = keep_prob
        self.SKIP_LAYER = skip_layer
        if weights_path == 'DEFAULT':
            self.WEIGHTS_PATH = 'bvlc_alexnet.npy'
        else:
            self.WEIGHTS_PATH = weights_path
        # Call the create function to build the computational graph of AlexNet
        self.create()
# =====================================================================================================================
# create
# =====================================================================================================================
    def create(self):
        # 1st Layer: Conv -> ReLU -> Pool -> LRN
        conv1 = conv(self.X, [11, 11], [4, 4], 96, 'conv1', 'VALID')
        pool1 = max_pool(conv1, [3, 3], [2, 2], 'pool1')
        norm1 = lrn(pool1, 2, 2e-05, 0.75, 'norm1')

        # 2nd Layer: Conv (with groups=2) -> ReLU -> Pool -> LRN
        conv2 = conv(norm1, [5, 5], [1, 1], 256, 'conv2', 'SAME', groups=2)
        pool2 = max_pool(conv2, [3, 3], [2, 2], 'pool2')
        norm2 = lrn(pool2, 2, 2e-05, 0.75, name='norm2')

        # 3rd Layer: Conv -> ReLU
        conv3 = conv(norm2, [3, 3], [1, 1], 384, 'conv3', 'SAME')

        # 4th Layer: Conv (with groups=2) -> ReLU
        conv4 = conv(conv3, [3, 3], [1, 1], 384, 'conv4', 'SAME', groups=2)

        # 5th Layer: Conv (with groups=2) -> ReLU -> Pool
        conv5 = conv(conv4, [3, 3], [1, 1], 256, 'conv5', 'SAME', groups=2)
        pool5 = max_pool(conv5, [3, 3], [2, 2], 'pool5')

        # 6th Layer: Flatten -> FC -> Dropout -> ReLU
        flattened = tf.reshape(pool5, [-1, 6 * 6 * 256])
        fc6 = fc(flattened, 6 * 6 * 256, 4096, 'fc6')
        dropout6 = tf.nn.dropout(fc6, self.KEEP_PROB)
        dropout6r = tf.nn.relu(dropout6)

        # 7th Layer: FC -> ReLU -> Dropout
        fc7 = fc(dropout6r, 4096, 4096, 'fc7')
        fc7r = tf.nn.relu(fc7)
        dropout7 = tf.nn.dropout(fc7r, self.KEEP_PROB)

        # 8th Layer: FC and return unscaled activations (logits)
        self.fc8 = fc(dropout7, 4096, self.NUM_CLASSES, layer_name='fc8')
# =====================================================================================================================
# load_initial_weights
# =====================================================================================================================
    def load_initial_weights(self, session):
        # Load the pretrained weights into memory. allow_pickle=True is
        # required on NumPy >= 1.16.3, since the .npy file stores a pickled dict.
        weights_dict = np.load(self.WEIGHTS_PATH, encoding='bytes',
                               allow_pickle=True).item()

        # Loop over all layer names stored in the weights dict
        for op_name in weights_dict:
            # Skip layers that should be (re)trained from scratch
            if op_name not in self.SKIP_LAYER:
                with tf.variable_scope(op_name, reuse=True):
                    # Assign each weight/bias array to its corresponding tf variable
                    for data in weights_dict[op_name]:
                        # Biases (1-D arrays)
                        if len(data.shape) == 1:
                            var = tf.get_variable('biases', trainable=False)
                            session.run(var.assign(data))
                        # Weights
                        else:
                            var = tf.get_variable('weights', trainable=False)
                            session.run(var.assign(data))
# =====================================================================================================================
# conv
# =====================================================================================================================
def conv(x, kernel_size, strides, num_filters, layer_name,
         padding, groups=1):
    """Create a convolution layer followed by ReLU.

    When groups > 1, input channels and filters are split into `groups`
    parts that are convolved independently and concatenated again, as in
    the original two-GPU AlexNet.
    """
    # Get number of input channels
    input_channels = int(x.get_shape()[-1])

    # Create lambda function for the convolution
    convolve = lambda i, k: tf.nn.conv2d(i, k,
                                         strides=[1, strides[0], strides[1], 1],
                                         padding=padding)

    with tf.variable_scope(layer_name) as scope:
        # Integer division keeps the shape entry an int when groups > 1
        weights = tf.get_variable('weights',
                                  shape=[kernel_size[0], kernel_size[1],
                                         input_channels // groups, num_filters])
        biases = tf.get_variable('biases', shape=[num_filters])

        if groups == 1:
            conv = convolve(x, weights)
        else:
            # Split input and weights along the channel axis and convolve
            # each group separately
            input_groups = tf.split(axis=3, num_or_size_splits=groups, value=x)
            weight_groups = tf.split(axis=3, num_or_size_splits=groups,
                                     value=weights)
            output_groups = [convolve(i, k)
                             for i, k in zip(input_groups, weight_groups)]

            # Concat the convolved outputs together again
            conv = tf.concat(axis=3, values=output_groups)

        # Add biases
        bias = tf.reshape(tf.nn.bias_add(conv, biases),
                          conv.get_shape().as_list())

        # Apply ReLU
        relu = tf.nn.relu(bias, name=scope.name)

        return relu
# =====================================================================================================================
# FC
# =====================================================================================================================
def fc(x, num_in, num_out, layer_name):
    with tf.variable_scope(layer_name) as scope:
        # Create tf variables for the weights and biases
        weights = tf.get_variable('weights', shape=[num_in, num_out],
                                  trainable=True)
        biases = tf.get_variable('biases', [num_out], trainable=True)

        # Matrix multiply weights and inputs and add bias
        act = tf.nn.xw_plus_b(x, weights, biases, name=scope.name)
        return act
# =====================================================================================================================
# max_pool
# =====================================================================================================================
def max_pool(x, kernel_size, strides, layer_name):
    return tf.nn.max_pool(x, ksize=[1, kernel_size[0], kernel_size[1], 1],
                          strides=[1, strides[0], strides[1], 1],
                          padding='VALID', name=layer_name)
# =====================================================================================================================
# lrn
# =====================================================================================================================
def lrn(x, radius, alpha, beta, name, bias=1.0):
    return tf.nn.local_response_normalization(x, depth_radius=radius, alpha=alpha,
                                              beta=beta, bias=bias, name=name)
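# =====================================================================================================================
# example usage
# =====================================================================================================================
# Minimal usage sketch (not part of the original gist): builds the graph and
# loads the pretrained weights, assuming TF 1.x and that 'bvlc_alexnet.npy'
# sits next to this script. The 227x227x3 placeholder is the input size this
# AlexNet variant expects; class count and skip list are illustrative.
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, 227, 227, 3])
    keep_prob = tf.placeholder(tf.float32)

    # Train fc8 from scratch, so skip it when loading the pretrained weights
    model = AlexNet(x, keep_prob, num_classes=1000, skip_layer=['fc8'])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.load_initial_weights(sess)

        # Forward pass on a random batch with dropout disabled (keep_prob=1.0)
        logits = sess.run(model.fc8,
                          feed_dict={x: np.random.rand(1, 227, 227, 3),
                                     keep_prob: 1.0})
        print(logits.shape)  # -> (1, 1000)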