iranroman / mask_broadcast_multiply_output_along_batch_keras
Created April 21, 2022 01:43
A Keras DataGenerator for broadcast-multiplying the model output by a single mask shared across the entire mini-batch.
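The core operation the description refers to is NumPy-style broadcasting: one mask of shape `(n_outputs,)` multiplies every row of a `(batch_size, n_outputs)` output. The sketch below shows only that step. It uses plain NumPy and is not the gist's own code. The helper name `apply_batch_mask` and the assumed 2-D output shape are hypothetical choices made for illustration.

```python
import numpy as np


def apply_batch_mask(y_pred, mask):
    """Multiply every example's output by one mask shared by the whole mini-batch.

    Hypothetical helper (not from the gist). Assumes:
      y_pred: array of shape (batch_size, n_outputs)
      mask:   array of shape (n_outputs,) or (1, n_outputs)
    """
    y_pred = np.asarray(y_pred)
    # Add a leading batch axis of size 1 so the mask broadcasts over axis 0.
    mask = np.asarray(mask).reshape(1, -1)
    if mask.shape[1] != y_pred.shape[1]:
        raise ValueError("mask length must match the number of outputs")
    # Shapes (batch_size, n_outputs) * (1, n_outputs) -> (batch_size, n_outputs)
    return y_pred * mask


if __name__ == "__main__":
    outputs = np.arange(6, dtype=float).reshape(3, 2)  # a batch of 3 examples, 2 outputs each
    batch_mask = [1.0, 0.0]                            # keep output 0, zero out output 1
    print(apply_batch_mask(outputs, batch_mask))
```

The same mask is applied to all three rows. Inside a Keras model, the equivalent would be multiplying the output tensor by a mask tensor with a batch dimension of 1. How the gist's generator actually delivers the mask to the model is not shown in this excerpt.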
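import tensorflow.keras as tfk
import numpy as np


class DataGenerator(tfk.utils.Sequence):
    """Generates data for Keras."""

    def __init__(self, batches_per_epoch, X, masks, Y):
        """Initialization."""
        self.batches_per_epoch = batches_per_epoch  # number of mini-batches served per epoch
        self.X = X          # input features fed to the model
        self.masks = masks  # masks to broadcast-multiply with the model output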