Skip to content

Instantly share code, notes, and snippets.

@badbye
Created April 27, 2016 07:55
Show Gist options
  • Save badbye/35a5e087aa75e4c19a38a546036891d0 to your computer and use it in GitHub Desktop.
Save badbye/35a5e087aa75e4c19a38a546036891d0 to your computer and use it in GitHub Desktop.
# encoding: utf-8
'''
Created on 2016.04.21
@author: yalei
'''
import mxnet as mx
def cnn_text_network(num_class=3, input_shape=(20, 300), conv_kernels=(3, 4, 5),
                     num_filter=100, drop_prob=0.5):
    '''
    Build the CNN-static sentence classifier as an MXNet symbol.

    Convolutional Neural Networks for Sentence Classification
    (http://arxiv.org/pdf/1408.5882v2.pdf). CNN-static: the network has no
    embedding layer; pre-trained word2vec vectors are fed directly through
    the 'data' variable.

    Args:
        num_class: number of output classes (units of the final 'fc' layer).
        input_shape: (row_length, vec_length) -- words per sentence and
            embedding width. 'data' is presumably laid out as
            (batch, 1, row_length, vec_length), e.g. (1000, 1, 20, 300).
        conv_kernels: filter heights (window sizes in words). Each filter
            spans the full embedding width. Any iterable of distinct ints in
            [1, row_length] is accepted.
        num_filter: number of feature maps produced per kernel size.
        drop_prob: dropout probability applied before 'fc'; <= 0 disables it.

    Returns:
        The 'softmax' SoftmaxOutput symbol on top of the 'fc' layer.

    Raises:
        ValueError: if conv_kernels is empty, contains a size outside
            [1, row_length], or repeats a size.
    '''
    row_length, vec_length = input_shape
    # Materialise once so one-shot iterables survive both validation and graph building.
    conv_kernels = tuple(conv_kernels)
    if not conv_kernels:
        raise ValueError('conv_kernels must contain at least one kernel size')
    for k in conv_kernels:
        # The max-pool window below is row_length - k + 1; it must be >= 1.
        if not 1 <= k <= row_length:
            raise ValueError('kernel size %r must be between 1 and row_length (%d)'
                             % (k, row_length))
    if len(set(conv_kernels)) != len(conv_kernels):
        # Layer names are derived from the kernel size, so duplicates would
        # create clashing parameter names such as '3_conv_weight'.
        raise ValueError('conv_kernels must not contain duplicate sizes: %r'
                         % (conv_kernels,))

    data = mx.symbol.Variable('data')
    pooled = []
    for k in conv_kernels:
        # Kernel covers the full embedding width -> (batch, num_filter, row_length - k + 1, 1).
        conv = mx.symbol.Convolution(data=data, kernel=(k, vec_length),
                                     num_filter=num_filter, name='%s_conv' % k)
        tanh = mx.symbol.Activation(data=conv, act_type='tanh', name='%s_tanh' % k)
        # Max-over-time pooling: the window spans every position -> (batch, num_filter, 1, 1).
        pool = mx.symbol.Pooling(tanh, kernel=(row_length - k + 1, 1), stride=(1, 1),
                                 pool_type='max', name='%s_pool' % k)
        pooled.append(pool)

    # (batch, num_filter * len(conv_kernels), 1, 1)
    concat = mx.symbol.Concat(*pooled, name='concat_max_pool')
    # (batch, num_filter * len(conv_kernels)); Flatten avoids Reshape's deprecated target_shape.
    features = mx.symbol.Flatten(data=concat)
    if drop_prob > 0:
        features = mx.symbol.Dropout(data=features, p=drop_prob, name='dropout')
    fc = mx.symbol.FullyConnected(data=features, num_hidden=num_class, name='fc')
    return mx.symbol.SoftmaxOutput(fc, name='softmax')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment