-
-
Save workflow/294c3cc2c202e196a2687700136e3dc2 to your computer and use it in GitHub Desktop.
Hacked data loader to support multi-class probability loading from CSV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parse_csv_multi_class_probabilities(path_to_csv):
    """Parse filenames and per-class probabilities from a CSV file.

    The file at `path_to_csv` must start with a header row. The first column
    holds filenames; every other column represents one class, and its header
    cell is used as the class name. Blank lines in the file are ignored.

    Arguments:
        path_to_csv: Path to a CSV file.

    Returns:
        a three-tuple of:
            a list of filenames (first cell of every data row)
            a list of probability rows in the same order; each row is the
                list of remaining cells, kept as the strings read from the CSV
            a dictionary mapping class index to class name, in header order

    Raises:
        ValueError: if the file is empty and therefore has no header row.
    """
    # newline='' hands newline handling to the csv module, as its
    # documentation recommends for file objects passed to csv.reader.
    with open(path_to_csv, newline='') as fileobj:
        reader = csv.reader(fileobj)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"CSV file {path_to_csv!r} is empty; expected a header row")
        # csv.reader yields an empty list for a blank line; such rows have no
        # filename cell and would break the unpacking below, so drop them.
        rows = [row for row in reader if row]

    fnames = [row[0] for row in rows]
    probabilities = [row[1:] for row in rows]
    # The first header cell labels the filename column; the rest are classes.
    idx2class = dict(enumerate(header[1:]))
    return fnames, probabilities, idx2class
def csv_source_multi_class(folder, csv_file, suffix=''):
    """Load image paths and a float32 probability matrix from a CSV file.

    Arguments:
        folder: folder that is joined in front of every filename from the CSV.
        csv_file: path to the CSV file, passed to
            `parse_csv_multi_class_probabilities`.
        suffix: text appended to every filename before joining (e.g. '.jpg').

    Returns:
        a three-tuple of:
            a list of paths built as `os.path.join(folder, filename + suffix)`
            a numpy array of the parsed probabilities cast to float32
            the index-to-class-name dictionary returned by the parser
    """
    names, raw_probs, idx2class = parse_csv_multi_class_probabilities(csv_file)

    image_paths = []
    for name in names:
        image_paths.append(os.path.join(folder, name + suffix))

    # Build the array first, then cast, so the CSV cells end up as float32.
    targets = np.array(raw_probs).astype(np.float32)
    return image_paths, targets, idx2class
@classmethod
def from_multiclass_csv(cls, path, folder, csv_fname, bs=64, tfms=(None, None),
                        val_idxs=None, suffix='', test_name=None, num_workers=8):
    """Build a data object from images whose per-class targets live in a CSV file.

    Use this when the training targets are given in a CSV file (one filename
    column followed by one column per class, with a header row) rather than
    as sub-directories named after the labels.

    Arguments:
        path: a root path of the data; forwarded to `get_ds`, to `cls(...)` and,
            when `test_name` is set, to `read_dir`.
        folder: name of the folder that contains the training images.
        csv_fname: name of the CSV file that contains the targets.
        bs: batch size.
        tfms: transformations (for data augmentation), forwarded to `get_ds`.
        val_idxs: indexes of the images to use for validation. If None,
            `get_cv_idxs(len(fnames))` is used with its default arguments.
        suffix: suffix to add to the image names from the CSV file (for
            example '.jpg' when the CSV omits the file extension).
        test_name: name of the folder that contains test images; when falsy
            no test file names are passed on.
        num_workers: number of workers.

    Returns:
        an instance of `cls` whose `classes` are the class names from the CSV
        header, in header order.
    """
    fnames, y, idx2class = csv_source_multi_class(folder, csv_fname, suffix)

    if val_idxs is None:
        val_idxs = get_cv_idxs(len(fnames))

    # split_by_idx returns, for each array, the (validation, training) pair.
    fname_split, target_split = split_by_idx(val_idxs, np.array(fnames), y)
    val_fnames, trn_fnames = fname_split
    val_y, trn_y = target_split

    test_fnames = None
    if test_name:
        test_fnames = read_dir(path, test_name)

    datasets = cls.get_ds(
        FilesNhotArrayDataset,
        (trn_fnames, trn_y),
        (val_fnames, val_y),
        tfms,
        path=path,
        test=test_fnames,
    )
    class_names = list(idx2class.values())
    return cls(path, datasets, bs, num_workers, classes=class_names)


# Attach the constructor so it can be called as
# ImageClassifierData.from_multiclass_csv(...).
ImageClassifierData.from_multiclass_csv = from_multiclass_csv
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment