Last active
November 8, 2018 16:15
-
-
Save rvinas/9e81ae0f17e61cc2c54b63f45fb07a28 to your computer and use it in GitHub Desktop.
Solution for "Implementing a batch dependent loss in Keras" (StackOverflow)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# StackOverflow question: https://stackoverflow.com/questions/53105294/implementing-a-batch-dependent-loss-in-keras | |
from keras.utils import Sequence | |
from keras.models import Model | |
from keras.layers import Input, Dense | |
import keras.backend as K | |
import numpy as np | |
# Constants
input_dim = 64  # digits.data.shape[1] — feature count per sample
# Placeholder data: 1000 samples of `input_dim` features in [0, 1).
dataset = np.random.rand(1000, input_dim)  # TODO: replace with digits.data
# Per-sample precision weights, consumed by the batch-dependent loss.
prec = np.random.rand(1000)
# Autoencoder layer widths: encoder 64 -> 40 -> 20 -> 2 (bottleneck).
dims = [40, 20, 2]
def MAEpw_wrapper(y_prec):
    """Build a Keras loss that closes over the precision tensor `y_prec`.

    `y_prec` is a model *input* tensor, so at train time it holds the
    per-sample precision values fed by the data generator — this is how the
    loss becomes batch-dependent.

    NOTE(review): despite the "MAE" name, the weighted error is squared
    (K.square), i.e. this is a precision-weighted *squared*-error loss —
    confirm that this matches the intended objective.
    """
    def MAEpw(y_true, y_pred):
        weighted_error = y_prec * (y_pred - y_true)
        return K.mean(K.square(weighted_error))

    return MAEpw
class DataGenerator(Sequence):
    """Keras Sequence yielding ``([y, y_prec], y)`` batches.

    The per-sample precision values are passed as a second model input so
    the batch-dependent loss closure can read them during training.
    """

    def __init__(self, batch_size, y, prec, shuffle=True):
        self.batch_size = batch_size
        self.y = y
        self.shuffle = shuffle
        self.prec = prec
        self.on_epoch_end()  # build the initial (possibly shuffled) order

    def on_epoch_end(self):
        """Recompute the sample order; reshuffles between epochs if enabled."""
        order = np.arange(len(self.y))
        if self.shuffle:
            np.random.shuffle(order)
        self.indexes = order

    def __len__(self):
        """Number of full batches per epoch (trailing remainder is dropped)."""
        return len(self.y) // self.batch_size

    def __getitem__(self, index):
        """Return the inputs ``[y, y_prec]`` and target ``y`` for one batch."""
        start = index * self.batch_size
        batch_idx = self.indexes[start:start + self.batch_size]
        y_batch = self.y[batch_idx]
        return [y_batch, self.prec[batch_idx]], y_batch
# Define model
# Two inputs: the data itself and a per-sample precision value. The
# precision input is not wired into any layer — it exists solely so the
# loss closure built by MAEpw_wrapper can read per-sample weights.
y_input = Input(shape=(input_dim,))
y_prec_input = Input(shape=(1,))
# Encoder: input_dim -> 40 -> 20 -> 2 (named 'bottleneck').
h_enc = Dense(dims[0], activation='relu')(y_input)
h_enc = Dense(dims[1], activation='relu')(h_enc)
h_enc = Dense(dims[2], activation='relu', name='bottleneck')(h_enc)
# Decoder: 2 -> 20 -> input_dim, reconstructing the input.
h_dec = Dense(dims[1], activation='relu')(h_enc)
h_dec = Dense(input_dim, activation='relu')(h_dec)
model2 = Model(inputs=[y_input, y_prec_input], outputs=h_dec)
# The loss is given the y_prec_input *tensor*; the generator feeds its values.
model2.compile(optimizer='adam', loss=MAEpw_wrapper(y_prec_input))
# Train model
# NOTE(review): fit_generator is deprecated in newer Keras/TF — model.fit
# accepts a Sequence directly; confirm against the installed version.
model2.fit_generator(DataGenerator(32, dataset, prec), epochs=100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment