Skip to content

Instantly share code, notes, and snippets.

@mzr1996
Last active April 29, 2022 10:25
Show Gist options
  • Save mzr1996/4e81a4c0e59fb140ffe0d17ff4f352ed to your computer and use it in GitHub Desktop.
Save mzr1996/4e81a4c0e59fb140ffe0d17ff4f352ed to your computer and use it in GitHub Desktop.
Create a demo multi-task dataset from CIFAR10
from pathlib import Path
import mmcv
from mmcls.datasets import build_dataset
# Build a two-task demo dataset from CIFAR10.
#
# Every sample image is written to ``<data_root>/images/<index>.png`` and is
# annotated with one label per task. task1 covers the first
# NUM_TASK1_CLASSES CIFAR10 classes and task2 covers the remaining ones; each
# task gets a trailing 'other' category that is assigned to images whose
# class belongs to the opposite task. The annotations are then split into a
# train file and a test file and dumped as JSON next to the images.

# CIFAR10 class indices [0, NUM_TASK1_CLASSES) belong to task1, the rest to
# task2.
NUM_TASK1_CLASSES = 5
# The first NUM_TRAIN_SAMPLES entries go to the train file, the rest to test.
NUM_TRAIN_SAMPLES = 45000

data_root = Path("./cifar10")
extract_folder = data_root / 'images'
extract_folder.mkdir(parents=True, exist_ok=True)

# An empty pipeline is passed because only the raw ``data_infos`` are used.
dataset_cfg = dict(type='CIFAR10', data_prefix='cifar10', pipeline=())
dataset = build_dataset(dataset_cfg)
classes = dataset.CLASSES

# ``list(...)`` keeps the concatenation valid whether CLASSES is a list or a
# tuple.
metainfo = dict(tasks=[
    dict(
        name='task1',
        categories=list(classes[:NUM_TASK1_CLASSES]) + ['other'],
        type='single-label'),
    dict(
        name='task2',
        categories=list(classes[NUM_TASK1_CLASSES:]) + ['other'],
        type='single-label'),
])

# 'other' is appended last in each category list, so its index equals the
# number of real classes in that task.
task1_other_label = NUM_TASK1_CLASSES
task2_other_label = len(classes) - NUM_TASK1_CLASSES

data_list = []
for i, data_info in enumerate(dataset.data_infos):
    img_path = extract_folder / f"{i}.png"
    # NOTE(review): mmcv.imwrite hands the array to OpenCV, which treats it
    # as BGR - confirm the channel order of the CIFAR10 arrays is what the
    # downstream loader expects.
    mmcv.imwrite(data_info['img'], str(img_path))

    gt_label = int(data_info['gt_label'])
    if gt_label < NUM_TASK1_CLASSES:
        task1_label = gt_label
        task2_label = task2_other_label
    else:
        task1_label = task1_other_label
        # Re-base the class index so task2 labels start at 0.
        task2_label = gt_label - NUM_TASK1_CLASSES

    data_list.append(
        dict(
            # Stored relative to data_root so the annotation files stay
            # valid if the folder is moved.
            img_path=str(img_path.relative_to(data_root)),
            task1_img_label=task1_label,
            task2_img_label=task2_label))

demo_train = {
    'metainfo': metainfo,
    'data_list': data_list[:NUM_TRAIN_SAMPLES],
}
demo_test = {
    'metainfo': metainfo,
    'data_list': data_list[NUM_TRAIN_SAMPLES:],
}
mmcv.dump(
    demo_train,
    str(data_root / 'multi-task-train.json'),
    indent=2,
    sort_keys=False)
mmcv.dump(
    demo_test,
    str(data_root / 'multi-task-test.json'),
    indent=2,
    sort_keys=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment