Skip to content

Instantly share code, notes, and snippets.

@KUNAL1612
Created November 1, 2018 09:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KUNAL1612/b854297bfee902ac5292af76d45026bb to your computer and use it in GitHub Desktop.
Save KUNAL1612/b854297bfee902ac5292af76d45026bb to your computer and use it in GitHub Desktop.
Step by step guide to practical transfer learning
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Sequential,Model
from keras import backend as k
from keras.callbacks import ModelCheckpoint,LearningRateScheduler,TensorBoard,EarlyStopping
from keras.layers import GlobalAveragePooling2D,Flatten,Dense,Dropout
# Spatial size every image is resized to before being fed to the network.
img_width = 256
img_height = 256

# Root folders of the training and validation images.
train_data_directory = 'data/train'
val_data_directory = 'data/val'

# Sample counts used to size one training epoch / one validation pass.
n_train = 4125
n_validation_samples = 466

# Training hyper-parameters.
batch_size = 16
epochs = 50
# Load the VGG19 convolutional base with ImageNet weights; include_top=False
# drops the original fully connected classifier so a custom head can be added.
# NOTE: Keras expects input_shape as (height, width, channels); the
# (img_width, img_height, 3) order used here is only equivalent because the
# configured image size is square.
model = applications.VGG19(
    weights='imagenet',
    include_top=False,
    input_shape=(img_width, img_height, 3),
)

# Freeze the layers we don't want to train: the first five layers keep their
# pre-trained weights during fine-tuning.
for layer in model.layers[:5]:
    layer.trainable = False
# Add custom layers: a small fully connected classifier on top of the
# convolutional base. Each layer object is instantiated first and then called
# on the incoming tensor (functional API).
x = model.output
x = Flatten()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation='relu')(x)
# 16 output units: one softmax probability per target class.
predictions = Dense(16, activation='softmax')(x)

# Assemble the end-to-end model (Keras 2 keyword names: inputs / outputs).
model_final = Model(inputs=model.input, outputs=predictions)
model_final.compile(
    loss='categorical_crossentropy',
    optimizer=optimizers.SGD(lr=0.001, momentum=0.9),
    metrics=['accuracy'],
)
# Training images are rescaled to [0, 1] and randomly augmented
# (horizontal flips and zoom) to reduce over-fitting.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    horizontal_flip=True,
    zoom_range=0.3,
    fill_mode='nearest',
)

# Validation images are only rescaled: random augmentation would make the
# validation metric noisy from epoch to epoch.
val_datagen = ImageDataGenerator(rescale=1. / 255)

# NOTE: Keras documents target_size as (height, width); the
# (img_width, img_height) order below is only equivalent because the
# configured image size is square.
train_generator = train_datagen.flow_from_directory(
    train_data_directory,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
)
validation_generator = val_datagen.flow_from_directory(
    val_data_directory,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
)
# Save the full model (architecture + weights) whenever validation accuracy
# improves.
# NOTE(review): 'val_acc' is the metric key in older Keras 2 releases; newer
# releases log 'val_accuracy' instead - confirm against the installed version.
checkpoint = ModelCheckpoint(
    "vgg_1.h5",
    monitor='val_acc',
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    period=1,
)
# Stop training once validation accuracy has not improved for 10 epochs.
early = EarlyStopping(
    monitor='val_acc',
    min_delta=0,
    patience=10,
    verbose=1,
    mode='auto',
)

# Keras 2 counts batches, not samples. Ceiling division (-(-a // b)) makes the
# last partial batch count, so every sample is seen once per epoch. The batch
# size is read from each generator so the step count always matches it.
steps_per_epoch = -(-n_train // train_generator.batch_size)
validation_steps = -(-n_validation_samples // validation_generator.batch_size)

model_final.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=[checkpoint, early],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment