Created
May 16, 2017 22:16
-
-
Save Flandan/9f9e4ce0c46b67d17f5d74423eef264c to your computer and use it in GitHub Desktop.
A modification of AlexNet that enables training on analogue (continuous-valued) input and output data — for example, to learn steering behaviour.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# alexnet.py
""" AlexNet, modified to regress continuous (analogue) outputs.

References:
    - Alex Krizhevsky, Ilya Sutskever & Geoffrey E. Hinton. ImageNet
      Classification with Deep Convolutional Neural Networks. NIPS, 2012.

Links:
    - [AlexNet Paper](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
"""
import tflearn | |
from tflearn.layers.conv import conv_2d, max_pool_2d | |
from tflearn.layers.core import input_data, dropout, fully_connected | |
from tflearn.layers.estimator import regression | |
from tflearn.layers.normalization import local_response_normalization | |
import tensorflow as tf | |
# Loss function used by the regression layer in `alexnet`.
def dist_sq(prediction, target):
    """Return the mean squared difference between ``prediction`` and ``target``.

    Used as the custom ``loss`` for ``regression(...)`` in ``alexnet``.
    It is intended to be equivalent to tflearn's built-in
    ``loss='mean_square'``. NOTE(review): confirm this against
    ``tflearn.objectives.mean_square``.

    Args:
        prediction: Tensor of network outputs.
        target: Tensor of target values. It is presumably the same shape as
            ``prediction`` or broadcast-compatible with it (TODO confirm).

    Returns:
        A scalar tensor named ``'dist_sq'``: the mean over every element,
        since no reduction axis is given.
    """
    residual = prediction - target
    squared_residual = tf.square(residual)
    return tf.reduce_mean(squared_residual, name='dist_sq')
def alexnet(width, height, lr, output=2, channels=1):
    """Build an AlexNet-style tflearn model that regresses continuous values.

    The final layer is linear and training minimises ``dist_sq`` (mean squared
    error) instead of a classification loss. This lets the network learn
    analogue targets, for example steering values.

    Args:
        width: Input image width, in pixels.
        height: Input image height, in pixels.
        lr: Learning rate passed to the Adam optimizer.
        output: Number of continuous values the network predicts. Defaults to
            2, the value previously hard-coded.
        channels: Number of channels per input image. Defaults to 1, the
            value previously hard-coded.

    Returns:
        A ``tflearn.DNN`` wrapping the network. The input layer is named
        ``'input'`` and the regression (target) layer is named ``'targets'``.
        Checkpoints use the path ``'model_alexnet'`` with
        ``max_checkpoints=1``, and TensorBoard logs go to ``'log'``.
    """
    # Input tensor: (batch, height, width, channels).
    network = input_data(shape=[None, height, width, channels], name='input')

    # Convolutional feature extractor: five ReLU conv layers. The 1st, 2nd
    # and 5th are each followed by max-pooling and local response
    # normalization.
    network = conv_2d(network, 96, 11, strides=4, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 256, 5, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 384, 3, activation='relu')
    network = conv_2d(network, 256, 3, activation='relu')
    network = max_pool_2d(network, 3, strides=2)
    network = local_response_normalization(network)

    # Fully-connected head with dropout. NOTE: these layers use tanh,
    # unlike the ReLU conv layers above.
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)
    network = fully_connected(network, 4096, activation='tanh')
    network = dropout(network, 0.5)

    # Linear output layer: unbounded continuous predictions rather than
    # class scores. This is what lets the model fit analogue targets.
    network = fully_connected(network, output, activation='linear')
    network = regression(network, optimizer='adam',
                         loss=dist_sq,
                         learning_rate=lr, name='targets')

    model = tflearn.DNN(network, checkpoint_path='model_alexnet',
                        max_checkpoints=1, tensorboard_verbose=0,
                        tensorboard_dir='log')
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment