Skip to content

Instantly share code, notes, and snippets.

@changkun
Created December 30, 2017 17:00
Show Gist options
  • Save changkun/56a11eab9dee138de1180b733c9a4c10 to your computer and use it in GitHub Desktop.
Save changkun/56a11eab9dee138de1180b733c9a4c10 to your computer and use it in GitHub Desktop.
DARC1 Loss on MNIST, Keras Implementation
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Activation, BatchNormalization
from keras.models import Model
from keras import backend as K
# 1. load data
def load_data():
    """Load MNIST and prepare it for the CNN defined in this file.

    Returns:
        ((x_train, y_train), (x_test, y_test)) where each ``x`` array is
        reshaped to (num_samples, 28, 28, 1), cast to float32 and divided
        by 255, and each ``y`` is the integer label vector one-hot encoded
        with ``to_categorical``.
    """
    def _prepare_images(images):
        # Add the trailing single-channel axis expected by the model's
        # Input(shape=(28, 28, 1)), then scale pixel values by 1/255.
        images = images.reshape(images.shape[0], 28, 28, 1).astype('float32')
        images /= 255
        return images

    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_split = (_prepare_images(train_images), to_categorical(train_labels))
    test_split = (_prepare_images(test_images), to_categorical(test_labels))
    return train_split, test_split
# 2. define model
def base_model():
    """Build the small CNN classifier and expose its pre-softmax tensor.

    Architecture: two blocks of Conv2D(5x5) -> ReLU -> MaxPooling2D(2x2) ->
    BatchNormalization (8 filters, then 16), followed by
    Flatten -> Dense(10) -> softmax.

    Returns:
        (model, intermediate): a keras ``Model`` mapping a (28, 28, 1) input
        to softmax outputs over 10 units, and the symbolic Dense(10) output
        tensor taken *before* the softmax, so a loss can penalize the raw
        scores directly.
    """
    inputs = Input(shape=(28, 28, 1))
    features = inputs
    # Layers are created in the same order as a hand-unrolled version would
    # create them, so auto-generated layer names are unchanged.
    for n_filters in (8, 16):
        features = Conv2D(n_filters, (5, 5))(features)
        features = Activation('relu')(features)
        features = MaxPooling2D((2, 2))(features)
        features = BatchNormalization()(features)
    flat = Flatten()(features)
    logits = Dense(10)(flat)
    probabilities = Activation('softmax')(logits)
    return Model(inputs=inputs, outputs=probabilities), logits
# 3. define loss
def darc1_loss(intermediate, lamb=0.01):
    """Return a Keras loss: categorical cross-entropy plus a DARC1 penalty.

    The penalty is ``lamb * max(sum(abs(intermediate), axis=0))``: absolute
    values of ``intermediate`` are summed along axis 0, and the largest of
    those sums is scaled by ``lamb``.

    Args:
        intermediate: symbolic tensor captured by the returned closure;
            presumably laid out as (batch, units) so that axis 0 is the batch
            axis — TODO confirm for other callers.
        lamb: weight of the penalty term.

    Returns:
        A ``(y_true, y_pred)`` loss function whose value is the per-sample
        cross-entropy with the scalar penalty added (broadcast) to it.
    """
    def _loss(y_true, y_pred):
        # The penalty depends only on the captured tensor, not on y_true/y_pred.
        penalty = K.max(K.sum(K.abs(intermediate), axis=0))
        return K.categorical_crossentropy(y_true, y_pred) + lamb * penalty
    return _loss
# 4. training and evaluation
(x_train, y_train), (x_test, y_test) = load_data()
model, intermediate = base_model()
model.summary()

# To train the unregularized baseline instead, pass
# loss='categorical_crossentropy' here.
model.compile(
    optimizer='adam',
    loss=darc1_loss(intermediate),
    metrics=['accuracy'],
)
# The last 20% of the training data is held out for validation.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=3,
    validation_split=0.2,
)
# evaluate() returns [loss, <metrics...>]; index 1 is the 'accuracy' metric.
results = model.evaluate(x=x_test, y=y_test)
print('test acc: ', results[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment