Skip to content

Instantly share code, notes, and snippets.

@NiharG15
Last active January 17, 2018 18:42
Show Gist options
  • Save NiharG15/28275e0a89355543e2a758c735b3c23c to your computer and use it in GitHub Desktop.
Save NiharG15/28275e0a89355543e2a758c735b3c23c to your computer and use it in GitHub Desktop.
A tensorpack.dataflow based image data generator for pandas DataFrames

A data generator for pandas DataFrames

This is a tensorpack.dataflow-based data generator that takes a pandas DataFrame as input. The DataFrame must contain one or more columns from which each image path is built, plus a column holding the label for that image.

  • path_func:
    This function takes a pandas DataFrame row as input and returns the corresponding image path.

Example Usage:

    df = DFBaseDataFlow(df=data_frame,
                        path_func=lambda row: (base_path + row.filename),
                        label_column='class',
                        resize=(224, 224))
    
    generator = df.get_data()

Usage in combination with other tensorpack.dataflow classes:

    df = DFBaseDataFlow(df=data_frame,
                        path_func=lambda row: (base_path + row.filename),
                        label_column='class',
                        resize=(224, 224))
    
    df1 = BatchData(df, 32)
    df2 = PrefetchData(df1, nr_prefetch=4, nr_proc=1)

    generator = df2.get_data()
''' An image data generator based on tensorpack dataflow to generate training data using pandas DataFrames '''
import copy
import pandas as pd
from tensorpack.dataflow import DataFlow, ImageFromFile
class DFBaseDataFlow(DataFlow):
    """
    A generic class to generate image data from a pandas DataFrame.

    The data frame consists of one or more columns that make up the file path
    and a label column. Each image is paired with the label taken from the
    same DataFrame row.
    """

    def __init__(self, df, path_func, label_column, resize=None):
        """
        Parameters:
        df: a pandas DataFrame,
        path_func: function that generates a path from a dataframe row
            (it is called with each row produced by ``DataFrame.iterrows``),
        label_column: column in df that contains labels,
        resize: tuple (h, w), resize all images to this size.
        """
        # Work on a private copy so later changes to the caller's frame
        # cannot affect the paths/labels computed here.
        self.data_frame = copy.deepcopy(df)
        self.path_func = path_func
        self.label_column = label_column
        # One path per row, in row order; self.labels below uses the same
        # row order, which is what keeps images and labels aligned.
        self.file_paths = [
            self.path_func(row) for _, row in self.data_frame.iterrows()
        ]
        # Generate base dataflow using ImageFromFile (no shuffle requested,
        # so it is assumed to yield images in file_paths order -- TODO confirm
        # against the installed tensorpack version).
        self.imf = ImageFromFile(files=self.file_paths, resize=resize)
        self.labels = self.data_frame[label_column].values
        # Kept for backward compatibility with code that may read them;
        # get_data() no longer relies on them for image/label pairing.
        self.gen = self.imf.get_data()
        self.index = -1

    def reset_state(self):
        """
        Reset the base data source (ImageFromFile).
        """
        self.imf.reset_state()

    def size(self):
        """
        Returns the size of the data: the number of rows (one image path
        per row).
        """
        return len(self.file_paths)

    def get_data(self):
        """
        Returns a python generator over the dataframe.

        Each next() call returns a list [image, label]. The generator never
        terminates on its own: after the last row it starts again from the
        first row.

        The image/label pairing uses a local index, so several generators
        obtained from the same instance do not interfere with each other.
        """
        while True:
            # Informational only (backward compatibility); may be overwritten
            # by another generator created from this instance.
            self.gen = self.imf.get_data()
            self.index = -1
            # enumerate is lazy, so images are still read one at a time.
            for idx, img in enumerate(self.gen):
                self.index = idx
                # img is an ImageFromFile datapoint; its first component is
                # the image itself.
                yield [img[0], self.labels[idx]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment