Skip to content

Instantly share code, notes, and snippets.

@yang-zhang
Created November 7, 2018 20:26
Show Gist options
  • Save yang-zhang/ded77ce20c9228332631e84eda778c57 to your computer and use it in GitHub Desktop.
Save yang-zhang/ded77ce20c9228332631e84eda778c57 to your computer and use it in GitHub Desktop.
ImageRegressionDataset for fastai
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "import fastai; fastai.__version__",
"execution_count": 1,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 1,
"data": {
"text/plain": "'1.0.19'"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai import *\nfrom fastai.vision import *",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class ImageRegressionDataset(ImageClassificationBase):\n def __init__(self, fns:FilePathList, y:Collection[Number]):\n super().__init__(fns, classes=[])\n self.y = np.array(y, dtype=np.float32)[:, None]\n self.c = 1\n self.loss_func = F.mse_loss",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "pdata = Path('/data/cifar10/train/airplane/')\n\nlist(pdata.glob('*.png'))[:3]",
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 4,
"data": {
"text/plain": "[PosixPath('/data/cifar10/train/airplane/17015_airplane.png'),\n PosixPath('/data/cifar10/train/airplane/44932_airplane.png'),\n PosixPath('/data/cifar10/train/airplane/43160_airplane.png')]"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "n = 100\nn_val = 20\nfns = list(pdata.glob('*.png'))[:n]\ny = np.random.randn(n)\n\nds_trn = ImageRegressionDataset(fns[:-n_val], y[:-n_val])\nds_val = ImageRegressionDataset(fns[-n_val:], y[-n_val:])\n\nimg, y = ds_val[0]\nimg.show(title=y[0])",
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 216x216 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMUAAADSCAYAAAD66wTTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEP9JREFUeJztnUusVtUZht/FRVG8gB4UUC6KXERPVEBNjJdaB0iTjtqYNA504KBpdOCgTTowMW0TJx00HZhOmtiYDkyrCV5iIRVNDRADKFcNd7kqV0URVC67A/+mJ9/3LN0Hyj5q3yc5A9+997/Wvzef+3/X9621StM0Msb8l2FD3QFjvm04KIwJOCiMCTgojAk4KIwJOCiMCTgojAk4KM6QUsqjpZSVpZQvSinPfM15D5dSTpVSjg74+8GA4++XUo4POLZ4ENf+tpSyrpRyspTyZGh3QinlxVLK3lJKU0qZGo4/UEpZVko5Vkp542v6/1Dv+kda35zvOCOGugPfYfZK+p2k+ZIu+IZzlzdNc+fXHP9x0zT/PINrt0j6laSfw7HTkv4h6SlJy+D4YUl/kDRL0g/pw0spYyX9WtKGete/fzgozpCmaV6QpFLKPElXD1Ef/tLrw4NwbJ+kp0sp+Iz/E4Tf8AZ4StIfJT1w9r397uCfT91wSynlYCllUynlCfiH+tdSyoFSyuJSyk2DvPacUEq5TdI8SX/qor1vEw6Kc8+/JN0o6QpJP5H0M0m/HHD8QUlTJU2R9LqkRaWUMS2vPSeUUoZLelrSY03TnD7X7X3bcFCcY5qm2dY0zfamaU43TbNO0m8k/XTA8aVN0xxvmuZY0zRPSfpY0l1trj2H/ELS2qZplnfQ1rcOe4ruaSSVMzz+Tdf+r7hP0j2llB/1/vsyffUz7uamaR7toP0hxUFxhvR+24+QNFzS8FLKKEknm6Y5Gc5bIOntpmn2lVJmSXpC0t96xyZLmiRphb56az8mqU/S0m+6tnd8ZK/9YZJG9PpwommaU73jo3rHJen8Usqopmk+7x0bLmlk7zsM6517qmmaE5IeljRqwNd4QdLfJf357O7ad4Smafx3Bn+SntRX/+ce+PekpMmSjkqa3Dvv95L2SfpM0jZ99RNoZO/YDZLW9o4dkvSapHkD2qhe2zv+DPTh4QHH47FmwLGH4fgzle/6hqRHhvqed/VXel/aGNPDRtuYgIPCmICDwpiAg8KYgIPCmECneYpt27aloa6JEyem806dOpW0jz76KGnbt2/Hdujc0aNHJ+3SSy9N2ogR+ZaMHTs2aTRqN2xY+//HnD6dqydKyXk56s/x48db9eeiiy5K2smTJ5M2fPjwpNX6eP755yeNvje1c+LEiaTRs963b1/S6PlJ0nnnnZc06veXX36ZtP7+fkyE+k1hTMBBYUzAQWFMoFNPcejQoaTRb72+vr6kjRo1Kmn0m1mSPvnkk6S1/b3/xRdfJO3TTz9NGv2Gpz7WoN/cI0eOTBrdH2qbfpt//vnnSSPfUoP8A90f+r1Pv+upj+QzDh48mLTas6bnOhhvh595Vlcb8z3EQWFMwEFhTMBBYUygU6NNiZbdu3cnjYwSmTlKqklsTtuafIKMLZliMpe1xBh9R2qHjC1pdC2ZajLPZPprffzss89a9YeeNWlkqqk/lNCTpMsvvzxp9GwGM0XCbwpjAg4KYwIOCmMCDgpjAp0abTJphw8fThpVgV577bVJq2U5r7zyyqS1rbwls0wm/eKLL27Vn1p2lXRqm0wjZYHpnpFhbWuKJR6IuOCCvGwuGXrKplPGn+4t3Rv6zpJ04YUXJo2+T23Ag/CbwpiAg8KYgIPCmICDwphAp0b7yJEjSSNzumnTpqSRSSPzLbH5mjp1atLINFLbZGKPHj2atMsuuyxptWmU9L3bTkcl2mbDyczXsr3UNpXlU2UBmd09e/YkjYz27NmzsT8EfR/63oMpJ/ebwpiAg8KYgIPCmICDwpiAg8KYQKejTxMmTEgalSLs378/abt27Ura+PHjsR0a2aE0/+TJk5NGpShbt25NWtvShloJBY2G0FwHGgGie9a2bKTtaI3E94y+T9tymTFjxiRtzpw5Sfv444+TRmU1tT5SSQ/1p4bfFMYEHBTGBBwUxgQcFMYEOjXaVBpBpogmo69ZsyZpGzduxHZoNXGal0BmmVZBp9IGKlkhrbY4ApWitO03GW0qWSHjT32sQdfT86I+UkkH9ZEGRag05tixY9hHuhc0OOH5FMacBQ4KYwIOCmMCDgpjAp0abTLLNE+CNDJ4W7ZswXauuOKKpNWy3xHKDNPn0UR6yuwOJlvcdvU+MqJk6Ck7T3ND6N5KnL2m/lC2mbLXlKletGhR0ubPn9+qDYlXGKTsNc11qeE3hTEBB4UxAQeFMQEHhTGBTo02Gbpt27YlrW2WkzLkkvTOO+8kjUqUacI9ZWfJ5NGiCTt27Ega7Zcn8YT9tt9x586dSfvggw+Stnnz5qTdddddSaMVFaX2pewEleUvXLgwaTQYQPujk3GXuAJh7969Sfvwww+TdsMNN+Bn+k1hTMBBYUzAQWFMwEFhTKBTo02maOnSpUm79957k0bGlIygxPujvf/++0mjLCdl06nEnDLfZMhrS8jTvG+6F/RdKNNM5e1Unk7f+ZprrsE+Em33wVu3bl3SXn/99aQtWLAgaVQmTtlwiQdGaDVImvdfw28KYwIOCmMCDgpjAg4KYwKdGu2+vr6kkYEiA00Gb9KkSdgOLZxGe+uReSeoPJlKsMnY3nrrrfiZ69evTxqVwtM9I2NM92LatGlJo3nStSw1lYnTHO8lS5YkbdmyZa2upecyY8aMpFG/JS7XJ+M/c+ZMvJ7wm8KYgIPCmICDwpiAg8KYQKdGm+Yrk8nbsGFD0miONWWFJZ6jS+XRtNjX888/n7RZs2Yl7aqrrkoalUvX5hb39/e3+kxaGK7tBu9klOk70z5/EpvY1atXJ+3tt99OWttBBxrEaLsSucSZ/LarjtPzkvymMCbhoDAm4KAwJuCgMCbQqdF+7733kkYbmx84cCBplPmsLTRGi5c9++yzSXvooYeSRtlUMn6UaSYzR2ZX4vLvq6++OmlUGk3ZdPo8MtW0UvuLL76IfaR5zVTWTcb/xhtvTBo9ayqXr82dJshoU0ab2q7hN4UxAQeFMQEHhTEBB4UxAQeFMYFOR59o1ITKIGiFQBpRqG3cfttttyWN5mhQecPjjz/e6lpauY9Gn6jcQeLRELqeRnvajirR4gF03rvvvot9pNGwuXPnJo22AaBnQ4swTJ8+PWm0qiI9f4nvD5WneM87Y84CB4UxAQeFMQEHhTGBTo02rdxGpoqWkKeJ+StXrsR2KPV/xx13JI0MKxlgWgiB+k0lGaRJbN5pc3kyrLSK4bhx45JG80BorkFtP0Cag0KGleZy0HehPlL5Dhn32pYGZKrJfFPJSg2/KYwJOCiMCTgojAk4KIwJdGq0yaTRZH3a84xWCKxNPKf5FNQOGUSay0HnUXZ++fLlSaO98SQ2tzQvgbLAa9asSRoZaMqmU38oky6135yelvengQQadKDnSvNk6PvVPpOqDci81/CbwpiAg8KYgIPCmICDwphAp0abMtBk8shUUUaztp8crWJHRmv27NlJI1NNZo42c1+1alXSaK+9Wjv33Xdf0mjZfTK2bSfmU6k+VQBIfM/oebXdE5CgaQKU5a6Vfl9yySVJo1UHa9MMCL8pjAk4KIwJOCiMCTgojAl0arTbLgNPRouW3a+tvkcmmDLDlPmmJd/JSK5YsSJpa9euTRqVeUvSPffc0+rcMWPGJI1WLKR7Qf2mwYlaWTaZZcoWjx07NmltV+mjNui8Wkab+k79qVU/EH5TGBNwUBgTcFAYE3BQGBPo1GjT/nZkLsmkUeaTspmSdP311ydtwoQJSXvzzTeTRiaPMsiUNaesaS3rTvOIycSSgaYMO2WaKQtM97u2pQENjNBnkgkmk0/3h86jZ13b8476Tll7+rdXw28KYwIOCmMCDgpjAg4KYwJDvuo4GT8ySmRMqURYYnNL7dx0001JI9P43HPPJY0yqbQRfK2PlMmn0nEq66aFxsiQkwml82iRMqme6Y7Qs6H7SAMEVJ5OAy21eeRkytsuKlfDbwpjAg4KYwIOCmMCDgpjAp0abcoWt12xmhbXIuMusVGjMvG2peyUaSbjduTIkaSREZQ4Q0uLnNEWW/SZtJBa27nuNEAgcdn6zp07k7Z79+6kkYGmbDhdS/eB/k3UdKpAaDtoIPlNYUzCQWFMwEFhTMBBYUzAQWFMoNPRJ1oaniau06gQnVcbNaHSiLfeeitpNDpDbdN5VHZAfayVJ9BnHjp0KGk0YkMjV1TaQiMzg+kjlcbMmDEjaTRXYcOGDUmjUUFqg/o9mKX4287lqOE3hTEBB4UxAQeFMQEHhTGBTo326NGjk9bWVJN5os3qJd5IfP369Ukjg0iDAW373db0SWwmqcTktddeS9q0adOSRmaZSkSoXIZKRCQ2tzQfg54DDRCQ+ab7M2vWrKTRJvISL15B5UQ22sacBQ4KYwIOCmMCDgpjAkNutGmuAplBMn21fd7oejLVdD3N0SCTdjabp9eup3MXL16ctLvvvjtpZKppfkdtVUWCzDJpZJZpE3uaB7Jx48ak7d+/P2m0vL4kHThwIGn0DGmFyBp+UxgTcFAYE3BQGBNwUBgT6NRot92EvO2y8rSSoMRl1DSZve0KemTyybBSlpsMtcSmvO0Awd69e5NGWXxqg/baq91Hup7OpYEIumd9fX1JowoCynzXSsfpWdMASm1zesJvCmMCDgpjAg4KYwIOCmMCnRptMkVk0shcklGq7dVGmXMyanQeZUMpm0omdDAlyzSYQPeHMtVUWk1Z4B07diRt0aJFSaNSdIlN8JQpU5I2ceLEpLUtoydt7ty5SaP56xLP+6bBksHgN4UxAQeFMQEHhTEBB4UxgSHf844WLiOjTUu718qgyXS2XdCMlp8ns0tLu5Oprg0GkMGcM2dO0shUk7F9+eWXk3bdddcljbLhq1evxj5SBpqMLT3D8ePHt/o8WgyNoK0UJDb+q1atOuN2JL8pjEk4KIwJOCiMCTgojAkM+RxtMm5tS6hr+8mR0aaSZzJf06dPTxpldsmk03zhmtGmOcy0ovctt9yStIULFybt8OHDSWu70Bj1W+LF2ejebt68OWm333570mgggp4BDYrQvxOJ515T27t27cLrCb8pjAk4KIwJOCiMCTgojAl0arTJLNMcbSoTp+x1zXxRO7QVGGV8yZDThuo0J5rKm2ubos+cObNV26+88krS1q1blzTKxFMp+vbt25NWW72d7iOZctpu7KWXXkoaZdMpi08VBLU52jTgQc+6Nlee8JvCmICDwpiAg8KYgIPCmICDwphAp6NPx44dS1rbeQ60+l5tU3QaaSCN5kTQ6MzBgweTRsvcD2ZlukmTJqEe2bJlS9JoXkLbOSg0elTb845Gw6gMhhYpoCX2V6xYkbQ9e/Yk7c4772zVbq1t0mhEq4bfFMYEHBTGBBwUxgQcFMYEOjXaZE5pMQOC5iXUTCzNfyAT3HZj+7ab3VNZBbUhseGlJeipvIVMNd1bKr8g6FqJ9yOkOR/Un5tvvjlp9F22bt2atFdffbXVeZI0b968pNFiBrV5LYTfFMYEHBTGBBwUxgQcFMYEOjXaNHGdjDZlYilLWZurQJPhaXl2WtmOzqP98ui70HyK2pyPJUuWJK2/vz9p48aNSxrN72h7Hyl7XdvzjioGaPEJGtig6gWa50CDJbQ4Qm0VQzr3/vvvTxrNX6nhN4UxAQeFMQEHhTEBB4UxgUKZXmP+n/GbwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYgIPCmICDwpiAg8KYwL8BSlF8+BZHxLwAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = ImageDataBunch.create(ds_trn, ds_val)\n\nlearn = create_cnn(data, models.resnet18)\n\nlearn.fit(1)",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "Total time: 00:00\nepoch train_loss valid_loss\n1 1.250102 1.324627 (00:00)\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = ImageDataBunch.create(ds_trn, ds_val)\n\nlearn = create_cnn(data, models.resnet18, loss_func=F.l1_loss)\n\nlearn.fit(1)",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "Total time: 00:00\nepoch train_loss valid_loss\n1 1.430220 1.078725 (00:00)\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-fastaiv1-py",
"display_name": "Python [conda env:fastaiv1]",
"language": "python"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"language_info": {
"name": "python",
"version": "3.7.0",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "ImageRegressionDataset for fastai ",
"public": false
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment