Created August 20, 2017 05:52
Save cacaocake/4cc4d1dc2cb1f729fd6d381170dfc188 to your computer and use it in GitHub Desktop.
20% 모델
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LogisticModel(BaseModel):
    """Convolutional network with L2 weight decay.

    Despite the class name, this is not a plain logistic-regression model:
    it stacks five convolutions (one 5x5, four 3x3) interleaved with three
    2x2 max-pools, then a 1024-unit fully connected ReLU layer with dropout,
    and finally a linear fully connected output layer.
    """

    def create_model(self, model_input, num_classes=10, l2_penalty=1e-8,
                     is_training=True, weight_decay=0.0005, **unused_params):
        """Build the network and return its output tensor.

        Args:
          model_input: Input tensor fed to the first slim.conv2d.
            NOTE(review): presumably a 4-D NHWC image batch -- confirm
            against the caller.
          num_classes: Number of units in the final (output) layer.
          l2_penalty: L2 regularization scale for the output layer's weights.
          is_training: Whether dropout is active. Defaults to True, which
            matches the previous behaviour; pass False for evaluation or
            inference so that dropout is disabled.
          weight_decay: L2 regularization scale applied, via arg_scope, to the
            weights of every conv layer and the hidden fully connected layer.
          **unused_params: Collected and ignored.

        Returns:
          A dict whose "predictions" entry is the output layer's tensor of
          shape [batch, num_classes]. That layer uses activation_fn=None, so
          the values are unnormalized scores (logits).
          NOTE(review): if the consumer expects probabilities (as the class
          name suggests), a sigmoid/softmax must be applied -- confirm.
        """
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            # slim.repeat with a count of 1 is a single conv layer, but it
            # keeps the same variable naming scheme as the deeper blocks.
            net = slim.repeat(model_input, 1, slim.conv2d, 32, [5, 5], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv2')
            net = slim.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv3')
            net = slim.max_pool2d(net, [2, 2], scope='pool3')
            net = slim.flatten(net)
            net = slim.fully_connected(net, 1024, scope='fc8')
            # slim.dropout defaults to is_training=True, which would keep
            # dropout active at evaluation time; gate it explicitly.
            # 0.5 is the keep probability.
            net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout8')
            # Override the arg_scope defaults: linear output with its own L2 scale.
            output = slim.fully_connected(
                net, num_classes, activation_fn=None,
                weights_regularizer=slim.l2_regularizer(l2_penalty))
        return {"predictions": output}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.