Skip to content

Instantly share code, notes, and snippets.

@usmcamp0811
Last active March 6, 2017 22:02
Show Gist options
  • Save usmcamp0811/db7feacca9a8f2894eb92b0a6195206d to your computer and use it in GitHub Desktop.
Save usmcamp0811/db7feacca9a8f2894eb92b0a6195206d to your computer and use it in GitHub Desktop.
A class that takes in a DataFrame or array and returns successive batches each time it is called. It also returns a one-hot encoded target variable if a target column is provided.
import numpy as np
import pandas as pd
class batcher(object):
    """Serve consecutive row-wise batches of a data set, cycling endlessly.

    ``data`` may be anything exposing ``.shape`` and row slicing (a NumPy
    array) or ``.iloc`` (a pandas DataFrame/Series).  When ``target`` names
    a DataFrame column, that column is one-hot encoded with
    ``pd.get_dummies`` and every batch is returned as a ``(x, y)`` pair.

    Attributes:
        data: The data being batched.  With a target, this is the feature
            columns followed by the one-hot target columns.
        batch_size: Requested number of rows per batch.
        batch_n: Index of the batch the next ``batch()`` call will return.
        n_batches: Number of batches in one pass over ``data`` (at least 1).
        target: Name of the target column, or ``None``.
        x_cols: Feature column labels (only set when ``target`` is given).
        y_cols: One-hot target column labels (only set when ``target`` is
            given).
    """

    def __init__(self, data, batch_size, target=None):
        """Store the data and, if requested, one-hot encode the target.

        Args:
            data: Array or DataFrame to split along its first axis.
            batch_size: Approximate rows per batch; must be positive.
            target: Optional label of the DataFrame column to one-hot
                encode and return separately from the features.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        self.data = data
        self.batch_size = batch_size
        self.batch_n = 0
        self.n_batches = self._count_batches(data.shape[0], batch_size)
        self.target = target
        if target is not None:
            self.x_cols = list(self.data.columns)
            self.x_cols.remove(self.target)
            one_hot = pd.get_dummies(self.data[target])
            self.y_cols = list(one_hot.columns)
            # Features and encoded target live side by side so a single row
            # slice yields both; they are separated again by column label
            # in batch().
            self.data = pd.concat([self.data[self.x_cols], one_hot], axis=1)

    @staticmethod
    def _count_batches(n_rows, batch_size):
        """Return how many batches ``n_rows`` rows split into (minimum 1)."""
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be positive, got {!r}".format(batch_size)
            )
        # Clamp to one batch so data shorter than batch_size is served as a
        # single (short) batch instead of failing on a zero-way split.
        return max(1, int(n_rows / batch_size))

    def train_batch(self, data, batch_size, batch_n):
        """Return batch number ``batch_n`` of ``data``.

        The rows are divided into ``int(rows / batch_size)`` nearly equal
        sections using the same sizing rule as ``np.array_split``: the
        leading ``rows % sections`` sections receive one extra row.

        Args:
            data: Array or DataFrame to slice along its first axis.
            batch_size: Approximate rows per batch; must be positive.
            batch_n: Index of the wanted batch.  Values outside
                ``range(number of batches)`` wrap around.

        Returns:
            The selected slice of ``data`` (same container type).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        n_rows = data.shape[0]
        n_batches = self._count_batches(n_rows, batch_size)
        batch_n %= n_batches
        base, extra = divmod(n_rows, n_batches)
        # Every batch before this one contributed `base` rows, plus one
        # more for each of the first `extra` batches.
        start = batch_n * base + min(batch_n, extra)
        stop = start + base + (1 if batch_n < extra else 0)
        if hasattr(data, "iloc"):
            # pandas objects: slice by position, never by index label.
            return data.iloc[start:stop]
        return data[start:stop]

    def batch(self):
        """Return the next batch, restarting from the first after the last.

        Returns:
            ``(x, y)`` -- feature columns and one-hot target columns of the
            batch -- when a target was given; otherwise the batch itself.
        """
        if self.batch_n >= self.n_batches:
            self.batch_n = 0
        batch = self.train_batch(self.data, self.batch_size, self.batch_n)
        self.batch_n += 1
        if self.target is None:
            return batch
        return batch[self.x_cols], batch[self.y_cols]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment