Created
October 24, 2018 15:46
-
-
Save muhammadgaffar/2fd5abc22da4c597c20da7dfaa262d7c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def random_mini_batches(X, Y, mini_batch_size=64, seed=0):
    """Shuffle the examples in (X, Y) and split them into mini-batches.

    Examples are stored column-wise: column ``i`` of ``X`` and column ``i``
    of ``Y`` belong to the same training example.  Both arrays are shuffled
    with the same permutation so that pairing is preserved.

    Args:
        X: 2-D array of inputs, shape ``(n_features, m)``.
        Y: 2-D array of labels, shape ``(n_outputs, m)`` (typically
            ``(1, m)``).
        mini_batch_size: Number of examples per mini-batch. Must be >= 1.
        seed: Seed for the shuffle. The same seed always yields the same
            partition; pass ``None`` to let NumPy draw fresh entropy.

    Returns:
        A list of ``(mini_batch_X, mini_batch_Y)`` tuples.  Every batch has
        ``mini_batch_size`` columns except possibly the last one, which
        holds the ``m % mini_batch_size`` leftover examples.  Each example
        appears in exactly one batch.

    Raises:
        ValueError: If ``mini_batch_size`` is smaller than 1.
    """
    if mini_batch_size < 1:
        raise ValueError("mini_batch_size must be a positive integer")

    m = X.shape[1]  # number of training examples (one per column)

    # Step 1: Shuffle (X, Y) with one shared permutation.
    # A private generator honors `seed` without reseeding the global
    # np.random state that other code may rely on.
    rng = np.random.RandomState(seed)
    permutation = rng.permutation(m)
    shuffled_X = X[:, permutation]
    shuffled_Y = Y[:, permutation]

    # Step 2: Partition.  Stepping by mini_batch_size and letting the slice
    # clip at m handles the final, shorter batch without a special case
    # (and without re-emitting columns already placed in earlier batches).
    mini_batches = []
    for start in range(0, m, mini_batch_size):
        stop = start + mini_batch_size
        mini_batches.append((shuffled_X[:, start:stop], shuffled_Y[:, start:stop]))

    return mini_batches
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment