Skip to content

Instantly share code, notes, and snippets.

@idleuncle
Last active October 27, 2019 14:12
Show Gist options
  • Save idleuncle/edbea1d8c178194170f6f8f664658eb6 to your computer and use it in GitHub Desktop.
Save idleuncle/edbea1d8c178194170f6f8f664658eb6 to your computer and use it in GitHub Desktop.
[Activation Functions]
# https://github.com/ShahariarRabby/Mnist_cnn_Swish
from keras import backend as K
from keras.layers import Activation
# get_custom_objects() returns the registry that is updated with 'swish' below.
from keras.utils.generic_utils import get_custom_objects
def swish(x):
    """Apply the Swish activation, ``x * sigmoid(x)``.

    Args:
        x: Input tensor (anything the Keras backend ``K.sigmoid`` accepts).

    Returns:
        The product of ``K.sigmoid(x)`` and ``x``.
    """
    return K.sigmoid(x) * x
# Register the function under the key 'swish' so that layers can refer to it
# by name via activation='swish'.
get_custom_objects().update({'swish': swish})

# NOTE(review): `model`, `Conv2D` and `Dense` are not defined in this snippet;
# presumably the surrounding script builds a Sequential model and imports
# these layers from keras.layers -- confirm before running.

# Now just add Swish as an activation
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same',
                 activation='swish', input_shape=(28, 28, 1)))

# And last layer as sigmoid
model.add(Dense(10, activation="sigmoid"))
# https://github.com/ChingChuan-Chen/keras_swish_beta/blob/master/swishBeta.py
import keras
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Activation
from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling1D
from keras.layers import BatchNormalization
from keras.layers import initializers, InputSpec
from keras.models import Sequential
from keras.utils import multi_gpu_model
from keras.engine.topology import Layer
class SwishBeta(Layer):
    """Swish activation layer with a scalar ``beta`` weight.

    Computes ``inputs * sigmoid(beta * inputs)``.

    Args:
        trainable_beta: Value stored as the layer's ``trainable`` flag; since
            ``beta`` is the only weight this layer creates, it controls
            whether ``beta`` is treated as trainable.
        beta_initializer: Initializer for ``beta`` (name or instance),
            resolved through ``initializers.get``.
        **kwargs: Forwarded unchanged to the ``Layer`` base constructor.
    """

    def __init__(self, trainable_beta=False, beta_initializer='ones', **kwargs):
        super(SwishBeta, self).__init__(**kwargs)
        self.supports_masking = True
        # Keep the constructor argument under its own name: get_config()
        # serializes `self.trainable_beta`, which previously was never set
        # and made get_config() raise AttributeError.
        self.trainable_beta = trainable_beta
        # Deliberately overrides any `trainable` passed through **kwargs.
        self.trainable = trainable_beta
        self.beta_initializer = initializers.get(beta_initializer)

    def build(self, input_shape):
        """Create the single-element ``beta`` weight and fix the input rank."""
        self.beta = self.add_weight(shape=[1], name='beta',
                                    initializer=self.beta_initializer)
        # Only the number of dimensions is constrained, not their sizes.
        self.input_spec = InputSpec(ndim=len(input_shape))
        self.built = True

    def call(self, inputs):
        """Return ``inputs * sigmoid(beta * inputs)``."""
        return inputs * K.sigmoid(self.beta * inputs)

    def get_config(self):
        """Return the base layer config extended with this layer's arguments."""
        config = {'trainable_beta': self.trainable_beta,
                  'beta_initializer': initializers.serialize(self.beta_initializer)}
        base_config = super(SwishBeta, self).get_config()
        # On a key clash, this layer's entries take precedence over the base.
        return dict(list(base_config.items()) + list(config.items()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment