Last active
April 9, 2018 16:10
-
-
Save radekosmulski/e78802dd000c3427a2cdb7f9f199ebfe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ConcatDataset(Dataset):
    """Pair each item of a dataset with an extra target taken from a second sequence.

    Indexing returns ``(x, (extra, y))`` where ``(x, y)`` is the item of the
    wrapped dataset at that index and ``extra`` is the element of ``y2`` at
    the same index. ``y2`` must support ``len()`` and integer indexing.
    """

    def __init__(self, ds, y2):
        # Both sources are indexed in lockstep, so their lengths must match.
        assert len(ds) == len(y2)
        self.ds = ds
        self.y2 = y2

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, i):
        """Return ``(x, (y2[i], y))`` for the wrapped item ``(x, y)`` at index ``i``."""
        sample, target = self.ds[i]
        extra = self.y2[i]
        return sample, (extra, target)

    def denorm(self, im):
        """Forward the call to the wrapped dataset's ``denorm``."""
        return self.ds.denorm(im)
def concat_datasets_for_detection(fnames, ys, transform, path):
    """Build a detection dataset from file names, labels and bounding boxes.

    Arguments:
        fnames: image file names
        ys: indexable holding the targets, where
            ys[0] is an array of labels for each example and
            ys[1] holds the bounding box coordinates
        transform: passed through to FilesIndexArrayRegressionDataset
        path: passed through to FilesIndexArrayRegressionDataset
    fnames, ys[0] and ys[1] need to be in corresponding order.

    Returns:
        ConcatDataset wrapping the bounding-box dataset together with the
        labels, so each item has the form (x, (labels, bbox)).
    """
    # The boxes travel with the images in the regression dataset; the labels
    # are attached afterwards as the extra target.
    bbox_ds = FilesIndexArrayRegressionDataset(fnames, ys[1], transform, path)
    return ConcatDataset(bbox_ds, ys[0])
class ObjectDetectionData(ImageData):
    """ImageData whose datasets are built with `concat_datasets_for_detection`."""

    @classmethod
    def from_csv(cls, path, folder, csv_fname, bs=64, tfms=(None,None),
                 val_idxs=None, suffix='', test_name=None, skip_header=True, num_workers=8):
        """Read in images and associated bounding boxes with labels given as a CSV file.

        The first CSV column holds the file names and is used as the index. The
        next column holds the space-separated classes and the one after it the
        space-separated integer bounding box coordinates.

        Example:
            file_name,category,bbox_coords
            000012.jpg,car,96 155 269 350
            000017.jpg,car person,61 184 198 278 77 89 335 402

        Arguments:
            path: a root path of the data (used for storing trained models, precomputed values, etc)
            folder: a name of the folder in which training images are contained;
                it is joined in front of every file name.
            csv_fname: the CSV file which contains target labels; passed to `pd.read_csv` as given.
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            val_idxs: index of images to be used for validation. e.g. output of `get_cv_idxs`.
                If None, default arguments to get_cv_idxs are used.
            suffix: suffix to add to image names in CSV file (sometimes CSV only contains the
                file name without file extension e.g. '.jpg' - in which case, you can set
                suffix as '.jpg')
            test_name: a name of the folder which contains test images.
            skip_header: treat the first row of the CSV file as a header rather than data.
            num_workers: number of workers

        Returns:
            ObjectDetectionData
        """
        header_row = 0 if skip_header else None
        df = pd.read_csv(csv_fname, index_col=0, header=header_row, dtype=str)
        n_cols = df.shape[1]

        # Every cell is a space-separated list of tokens.
        for col in range(n_cols):
            df.iloc[:, col] = df.iloc[:, col].str.split(' ')

        # The class vocabulary is every distinct token of the first data
        # column, sorted so that the ids are deterministic.
        classes = sorted({name for row in df.iloc[:, 0] for name in row})
        class2id = {name: idx for idx, name in enumerate(classes)}

        def encode_classes(row):
            return np.array([class2id[name] for name in row])

        def parse_ints(row):
            return np.array([int(token) for token in row])

        df.iloc[:, 0] = df.iloc[:, 0].apply(encode_classes)
        for col in range(1, n_cols):
            df.iloc[:, col] = df.iloc[:, col].apply(parse_ints)

        fnames = df.index.values
        y = [df.values[:, col] for col in range(n_cols)]
        full_names = [os.path.join(folder, str(fn) + suffix) for fn in fnames]
        return cls.from_names_and_arrays(path, full_names, y, classes, val_idxs, test_name,
                                         num_workers=num_workers, tfms=tfms, bs=bs)

    @classmethod
    def from_names_and_arrays(cls, path, fnames, y, classes, val_idxs=None, test_name=None,
                              num_workers=8, tfms=(None,None), bs=64):
        """Create the data object from file names and per-image target arrays.

        Arguments:
            path: a root path of the data
            fnames: image file names
            y: list of target arrays, each in the same order as `fnames`
                (as built by `from_csv`: class ids first, then bounding boxes)
            classes: class names, forwarded to the constructor
            val_idxs: index of images to be used for validation.
                If None, `get_cv_idxs(len(fnames))` is used.
            test_name: a name of the folder which contains test images; when
                falsy no test files are read.
            num_workers: number of workers
            tfms: transformations, forwarded to `get_ds`
            bs: batch size

        Returns:
            ObjectDetectionData
        """
        if val_idxs is None:
            val_idxs = get_cv_idxs(len(fnames))

        # Each split is unpacked as (validation part, training part), exactly
        # like the file names that come first in the result.
        (val_fnames, trn_fnames), *splits = split_by_idx(val_idxs, np.array(fnames), *y)
        test_fnames = read_dir(path, test_name) if test_name else None

        trn_ys = [pair[1] for pair in splits]
        val_ys = [pair[0] for pair in splits]
        datasets = cls.get_ds(concat_datasets_for_detection,
                              (trn_fnames, trn_ys), (val_fnames, val_ys), tfms,
                              path=path, test=test_fnames)
        return cls(path, datasets, bs, num_workers, classes=classes)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment