Skip to content

Instantly share code, notes, and snippets.

@workflow
Last active September 27, 2018 20:58
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save workflow/294c3cc2c202e196a2687700136e3dc2 to your computer and use it in GitHub Desktop.
Save workflow/294c3cc2c202e196a2687700136e3dc2 to your computer and use it in GitHub Desktop.
Hacked data loader to support multi-class probability loading from CSV
def parse_csv_multi_class_probabilities(path_to_csv):
    """Parse filenames and per-class probabilities from a CSV file.

    The CSV must start with a header row. The first column holds filenames;
    every remaining column is one class, named by its header cell, and each
    cell in that column is the row's probability for that class.
    Blank rows are skipped.

    Arguments:
        path_to_csv: Path to a CSV file.

    Returns:
        a three-tuple of:
            a list of filenames, in file order
            a list of per-row probability lists in the same order; the values
                are the raw CSV strings (callers convert them, e.g. via numpy)
            a dict mapping class column index (0-based, not counting the
                filename column) to class name

    Raises:
        ValueError: if the file is empty (no header row).
    """
    # newline='' lets the csv module handle line endings and quoted newlines
    # itself, as the csv documentation requires for files passed to reader().
    with open(path_to_csv, newline='') as fileobj:
        reader = csv.reader(fileobj)
        header = next(reader, None)
        if header is None:
            raise ValueError(
                'CSV file {!r} is empty; expected a header row'.format(path_to_csv))
        # A blank line parses as []; unpacking it below would fail, so drop it.
        rows = [row for row in reader if row]
    fnames = [fname for fname, *_ in rows]
    probabilities = [probs for _, *probs in rows]
    idx2class = dict(enumerate(header[1:]))
    return fnames, probabilities, idx2class
def csv_source_multi_class(folder, csv_file, suffix=''):
    """Load a multi-class probability CSV and build full image paths.

    Arguments:
        folder: directory prefix joined onto every filename from the CSV.
        csv_file: path to a CSV readable by parse_csv_multi_class_probabilities.
        suffix: string appended to each filename before joining (e.g. '.jpg').

    Returns:
        (paths, prob_matrix, index_to_class): the joined paths, a float32
        numpy array built from the CSV probability rows, and the mapping from
        class column index to class name.
    """
    parsed = parse_csv_multi_class_probabilities(csv_file)
    names, raw_probs, index_to_class = parsed
    paths = [os.path.join(folder, name + suffix) for name in names]
    # The parser yields strings; convert the whole table to float32 at once.
    prob_matrix = np.array(raw_probs)
    prob_matrix = prob_matrix.astype(np.float32)
    return paths, prob_matrix, index_to_class
@classmethod
def from_multiclass_csv(cls, path, folder, csv_fname, bs=64, tfms=(None, None),
                        val_idxs=None, suffix='', test_name=None, num_workers=8):
    """Read in images and their per-class probability targets from a CSV file.
    --
    Use this when each training image's target is a row of class probabilities
    in a CSV file, rather than a single label given by a sub-directory name.
    The CSV must have a header row. Its first column holds image filenames, and
    each further column is one class (named by its header cell) holding that
    image's probability for the class.

    Arguments:
        path: a root path of the data (used for storing trained models, precomputed values, etc)
        folder: a name of the folder in which training images are contained.
        csv_fname: a name of the CSV file which contains the target probabilities.
        bs: batch size
        tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
        val_idxs: index of images to be used for validation. e.g. output of `get_cv_idxs`.
            If None, default arguments to get_cv_idxs are used.
        suffix: suffix to add to image names in CSV file (sometimes CSV only contains the file name without file
            extension e.g. '.jpg' - in which case, you can set suffix as '.jpg')
        test_name: a name of the folder which contains test images.
        num_workers: number of workers

    Returns:
        an instance of `cls` (ImageClassifierData when attached to that class)
        whose `classes` are the CSV header class names in column order.
    """
    fnames, y, idx2class = csv_source_multi_class(folder, csv_fname, suffix)
    if val_idxs is None:
        val_idxs = get_cv_idxs(len(fnames))
    # split_by_idx is unpacked as one (validation, training) pair per array passed.
    ((val_fnames, trn_fnames), (val_y, trn_y)) = split_by_idx(val_idxs, np.array(fnames), y)
    test_fnames = read_dir(path, test_name) if test_name else None
    datasets = cls.get_ds(FilesNhotArrayDataset, (trn_fnames, trn_y), (val_fnames, val_y), tfms,
                          path=path, test=test_fnames)
    # Order classes by column index explicitly instead of relying on dict ordering,
    # so classes[i] always names column i of y.
    classes = [idx2class[i] for i in range(len(idx2class))]
    return cls(path, datasets, bs, num_workers, classes=classes)
# Attach the loader as an alternate constructor; because the function is wrapped
# in @classmethod, it can be called as ImageClassifierData.from_multiclass_csv(...).
setattr(ImageClassifierData, 'from_multiclass_csv', from_multiclass_csv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment