Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save rvinas/9e81ae0f17e61cc2c54b63f45fb07a28 to your computer and use it in GitHub Desktop.
Save rvinas/9e81ae0f17e61cc2c54b63f45fb07a28 to your computer and use it in GitHub Desktop.
Solution for "Implementing a batch dependent loss in Keras" (StackOverflow)
# StackOverflow question: https://stackoverflow.com/questions/53105294/implementing-a-batch-dependent-loss-in-keras
from keras.utils import Sequence
from keras.models import Model
from keras.layers import Input, Dense
import keras.backend as K
import numpy as np
# Constants
# Stand-in data until the real dataset is wired in: random values in [0, 1).
input_dim = 64  # digits.data.shape[1]
dataset = np.random.rand(1000, input_dim)  # TODO: replace with digits.data
# One precision value per sample. Sized from ``dataset`` (rather than a
# hard-coded 1000) so it stays aligned when ``dataset`` is replaced per the
# TODO above.
prec = np.random.rand(dataset.shape[0])
# Layer widths: the encoder uses dims[0], dims[1], dims[2] (the bottleneck);
# the decoder reuses dims[1].
dims = [40, 20, 2]
def MAEpw_wrapper(y_prec):
    """Return a Keras loss whose residuals are scaled by ``y_prec``.

    Keras loss functions are called with ``(y_true, y_pred)`` only, so the
    precision weights are captured by closure instead.

    Args:
        y_prec: weights multiplied element-wise (with broadcasting) into
            ``y_pred - y_true``; in this script it is the model's precision
            ``Input`` tensor.

    Returns:
        ``MAEpw(y_true, y_pred)``, which returns the mean over all elements of
        ``(y_prec * (y_pred - y_true)) ** 2``. Despite the "MAE" in its name,
        this is a weighted *squared* error, not an absolute error.
    """
    def MAEpw(y_true, y_pred):
        scaled_error = y_prec * (y_pred - y_true)
        return K.mean(K.square(scaled_error))

    return MAEpw
class DataGenerator(Sequence):
    """Batch generator yielding ``([y, y_prec], y)`` for a two-input model.

    Each item is one batch:
    - The rows of ``y`` go to the model's first input and are also the target.
    - The matching rows of ``prec`` go to the second input.

    Only full batches are produced. The ``len(y) % batch_size`` leftover
    samples are skipped each epoch. With ``shuffle=True`` the sample order is
    re-drawn at construction and after every epoch, so the skipped samples
    vary between epochs.

    Args:
        batch_size: number of samples per batch (positive int).
        y: samples indexed along axis 0; must support indexing by an integer
            array (e.g. a NumPy array).
        prec: per-sample precision values aligned with ``y`` along axis 0.
        shuffle: whether to shuffle the sample order.

    Raises:
        ValueError: if ``batch_size`` is not positive, or ``y`` and ``prec``
            have different lengths.
    """

    def __init__(self, batch_size, y, prec, shuffle=True):
        # Recent Keras versions (where Sequence is backed by PyDataset) warn
        # when subclasses skip the base initializer; on versions without one
        # this resolves to object.__init__ and is a no-op.
        super(DataGenerator, self).__init__()
        if batch_size <= 0:
            raise ValueError(
                'batch_size must be positive, got {}'.format(batch_size))
        if len(y) != len(prec):
            raise ValueError(
                'y and prec must have the same length, got {} and {}'.format(
                    len(y), len(prec)))
        self.batch_size = batch_size
        self.y = y
        self.shuffle = shuffle
        self.prec = prec
        self.on_epoch_end()

    def on_epoch_end(self):
        """Reset the sample order, shuffling it in place if enabled."""
        self.indexes = np.arange(len(self.y))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.y) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``([y_batch, prec_batch], y_batch)``.

        Negative indices count from the end. When ``prec`` is one-dimensional,
        ``prec_batch`` is reshaped from ``(batch,)`` to ``(batch, 1)`` so it
        matches a model input declared with ``shape=(1,)``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            # Without this check an out-of-range index silently yields an
            # empty batch instead of failing.
            raise IndexError(
                'batch index out of range for {} batches'.format(n_batches))
        start = index * self.batch_size
        batch_idx = self.indexes[start:start + self.batch_size]
        y_batch = self.y[batch_idx]
        prec_batch = self.prec[batch_idx]
        if np.ndim(prec_batch) == 1:
            prec_batch = np.reshape(prec_batch, (-1, 1))
        return [y_batch, prec_batch], y_batch
# Define model
# Two inputs: the data rows (also used as the reconstruction target) and one
# precision value per sample, which only the loss reads.
y_input = Input(shape=(input_dim,))
y_prec_input = Input(shape=(1,))
# Encoder: input_dim -> dims[0] -> dims[1] -> dims[2] ('bottleneck').
h_enc = Dense(units=dims[0], activation='relu')(y_input)
h_enc = Dense(units=dims[1], activation='relu')(h_enc)
h_enc = Dense(units=dims[2], activation='relu', name='bottleneck')(h_enc)
# Decoder: dims[2] -> dims[1] -> input_dim.
# NOTE(review): the decoder has no dims[0] layer, so it does not mirror the
# encoder -- confirm this is intended.
h_dec = Dense(units=dims[1], activation='relu')(h_enc)
h_dec = Dense(units=input_dim, activation='relu')(h_dec)
model2 = Model(inputs=[y_input, y_prec_input], outputs=h_dec)
# The loss closes over the symbolic precision input, so each batch's error is
# scaled by that batch's precision values.
# NOTE(review): capturing an Input tensor inside a compiled loss relies on
# graph-mode Keras; newer eager-mode Keras may reject it -- confirm the Keras
# version in use, or move the term into model2.add_loss(...).
model2.compile(
    loss=MAEpw_wrapper(y_prec_input),
    optimizer='adam',
)
# Train model
model2.fit_generator(
    DataGenerator(batch_size=32, y=dataset, prec=prec),
    epochs=100,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment