# [KERAS] Feature extractors
# GitHub gist danyashorokh/10e4ecad1f2e0a84b0050740a207a4f7 (last active July 16, 2021 14:39)
import os

import cv2
import numpy as np
from keras.applications.inception_v3 import InceptionV3, preprocess_input as preprocess_input_inception
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input as preprocess_input_mobilenet
from keras.applications.resnet import ResNet50, preprocess_input as preprocess_input_resnet
from keras.applications.vgg16 import VGG16, preprocess_input as preprocess_input_vgg
from keras.applications.vgg19 import VGG19, preprocess_input as preprocess_input_vgg19
from keras.layers import Input, GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.models import Model, load_model
from keras.preprocessing import image
class FeatureExtractor:
    """Extract image feature vectors with a Keras CNN backbone.

    The backbone is either an ImageNet-pretrained Keras application
    (``weights=None``) or a complete model loaded from disk with
    ``keras.models.load_model`` (``weights=<path>``). Optionally the output of
    an intermediate layer is used instead of the final one, and a global
    average or max pooling layer is appended on top.
    """

    def __init__(self, model_name='mobilenet_v2', weights=None, include_top=False, add_max=False,
                 add_avg=False, get_layer=None, convert_to_rgb=False):
        """Either load a pretrained model from ImageNet, or load our own saved model.

        Args:
            model_name: key of ``self.models``; selects the Keras application
                class, the input size and the matching ``preprocess_input``.
            weights: ``None`` to build ``model_name`` with ImageNet weights;
                otherwise a path passed to ``load_model``. In that case the
                saved model's own architecture is used, ``model_name`` only
                selects input size and preprocessing, and ``include_top`` is
                ignored.
            include_top: forwarded to the Keras application constructor. When
                True the backbone output is already flat, so a pooling flag
                will most likely fail in Keras.
            add_max: append a ``GlobalMaxPooling2D`` layer to the output.
            add_avg: append a ``GlobalAveragePooling2D`` layer to the output.
            get_layer: optional layer name; when given, that layer's output is
                used instead of the model's final output.
            convert_to_rgb: when True, image arrays passed to ``extract`` are
                treated as BGR (OpenCV convention) and converted to RGB.

        Raises:
            ValueError: ``model_name`` is unknown, or both ``add_max`` and
                ``add_avg`` are True.
        """
        self.model_name = model_name
        self.weights = weights
        self.include_top = include_top
        self.get_layer = get_layer
        self.convert_to_rgb = convert_to_rgb
        self.add_max = add_max
        self.add_avg = add_avg

        # model name -> [constructor, input size, preprocess_input function]
        self.models = {
            'mobilenet_v2': [MobileNetV2, (224, 224), preprocess_input_mobilenet],
            'vgg16': [VGG16, (224, 224), preprocess_input_vgg],
            # Previously this key built VGG16; it now builds the network its name promises.
            'vgg19': [VGG19, (224, 224), preprocess_input_vgg19],
            'resnet50': [ResNet50, (224, 224), preprocess_input_resnet],
            'inception_v3': [InceptionV3, (299, 299), preprocess_input_inception],
        }

        if self.model_name not in self.models:
            raise ValueError(f'Unknown model {self.model_name}. Use one of {self.models.keys()}')
        # Only one global pooling head may be appended.
        if self.add_max and self.add_avg:
            raise ValueError('Both flags (max, avg) are True')

        self.get_model, self.target_size, self.preprocess_input = self.models[self.model_name]

        if weights is None:
            base_model = self.get_model(weights='imagenet', include_top=self.include_top)
        else:
            base_model = load_model(weights)

        if self.get_layer is not None:
            # NOTE(review): the scraped source was truncated after `Model(`; this is the
            # standard Keras "cut at a named layer" construction — confirm against the original.
            base_model = Model(inputs=base_model.input,
                               outputs=base_model.get_layer(self.get_layer).output)

        x = base_model.output
        if self.add_avg:
            x = GlobalAveragePooling2D()(x)
        if self.add_max:
            x = GlobalMaxPooling2D()(x)

        # Final inference model: backbone input -> (optionally pooled) features.
        self.model = Model(inputs=base_model.input, outputs=x)

    def extract(self, image_path):
        """Return the feature vector of a single image.

        Args:
            image_path: either a file path (``str`` or ``os.PathLike``), loaded
                with ``keras.preprocessing.image.load_img``, or an image array
                that is resized with ``cv2.resize``.

        Returns:
            ``self.model.predict`` output for this image with the batch axis
            removed.
        """
        if isinstance(image_path, (str, os.PathLike)):
            # load_img defaults to color_mode='rgb', so no channel swap is needed here.
            img = image.load_img(os.fspath(image_path), target_size=self.target_size)
            x = image.img_to_array(img)
        else:
            # cv2.resize expects (width, height); every size in self.models is square,
            # so target_size can be passed unchanged.
            x = cv2.resize(image_path, self.target_size)
            if self.convert_to_rgb:
                # Only arrays may be BGR (e.g. from cv2.imread); converting the
                # already-RGB output of load_img would swap channels the wrong way.
                x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)

        batch = np.expand_dims(x, axis=0)
        batch = self.preprocess_input(batch)
        features = self.model.predict(batch)[0]
        return features
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment