Skip to content

Instantly share code, notes, and snippets.

@socket-var
Last active August 1, 2019 16:31
Show Gist options
  • Save socket-var/b995f175ba6b25fea78b2049e7eca88a to your computer and use it in GitHub Desktop.
Save socket-var/b995f175ba6b25fea78b2049e7eca88a to your computer and use it in GitHub Desktop.
Generates random mini-batches
def batch(X, Y, batch_size, seed=0):
    """Shuffle (X, Y) together and split them into mini-batches.

    The global NumPy RNG is seeded with ``seed``, and one permutation of the
    first axis is applied to both arrays, so each row of X stays paired with
    the same row of Y. Batches are cut in order from the shuffled data. Every
    batch holds ``batch_size`` samples, except that the last one holds the
    remaining ``m % batch_size`` samples when ``m`` is not a multiple of
    ``batch_size``.

    Args:
        X: Array whose first axis indexes samples (any number of trailing
            dimensions, e.g. a stack of images).
        Y: Array with the same length as X along its first axis (1-D labels
            or 2-D one-hot rows both work).
        batch_size: Positive number of samples per mini-batch.
        seed: Value passed to ``np.random.seed`` before shuffling.

    Returns:
        List of ``(mini_batch_X, mini_batch_Y)`` tuples. The list is empty when
        X has no samples, and contains a single batch when fewer than
        ``batch_size`` samples are available.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if X and Y differ in
            length along the first axis.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    m = X.shape[0]  # total number of samples
    if Y.shape[0] != m:
        raise ValueError(
            f"X and Y must have the same number of samples, got {m} and {Y.shape[0]}"
        )

    # NOTE: this seeds the *global* NumPy RNG, as the original implementation
    # did; any later np.random calls are affected.
    np.random.seed(seed)
    perm = np.random.permutation(m)
    # Index only the first axis so 1-D labels and N-D inputs both work.
    shuffled_X = X[perm]
    shuffled_Y = Y[perm]

    # Stepping by batch_size yields the full batches and the trailing partial
    # batch in one pass. It also never depends on a loop variable that may be
    # unbound (the old code failed when m < batch_size).
    return [
        (shuffled_X[start:start + batch_size], shuffled_Y[start:start + batch_size])
        for start in range(0, m, batch_size)
    ]
# Mini-batch training loop: one pass over the shuffled training set per epoch.
for i in range(iter):
    # Bump the seed every epoch so batch() draws a different shuffle each time.
    seed += 1
    minibatches = batch(X_train, y_train, Batch_size, seed)
    # batch() also returns a final partial minibatch when the sample count is
    # not a multiple of Batch_size. Average over the actual number of
    # minibatches, not int(m / Batch_size).
    num_batches = len(minibatches)
    minibatch_loss = 0
    for X_mb, y_mb in minibatches:
        # One optimizer step on this minibatch; temp_loss is its cost value.
        _, temp_loss = sess.run([optimizer, cost],
                                feed_dict={X: X_mb, y: y_mb, keep_prob: drop_prob})
        minibatch_loss += temp_loss / num_batches
    # Record and report the mean minibatch loss once per epoch.
    losses.append(minibatch_loss)
    # Accuracy is measured on the last minibatch only, with keep_prob=1.0
    # (presumably disabling dropout for evaluation; confirm against the graph).
    train_accuracy = accuracy.eval(
        feed_dict={X: X_mb, y: y_mb, keep_prob: 1.0})
    print('Epoch %d, Loss %g, Training Accuracy %g'
          % (i, minibatch_loss, train_accuracy))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment