Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yang-zhang/56fdd7f055d0ad090ded3cd05df31474 to your computer and use it in GitHub Desktop.
Save yang-zhang/56fdd7f055d0ad090ded3cd05df31474 to your computer and use it in GitHub Desktop.
fastai-ImageRegressionDataset-y_range-dev
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "import fastai; fastai.__version__",
"execution_count": 1,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 1,
"data": {
"text/plain": "'1.0.19'"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai import *\nfrom fastai.vision import *",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class ImageRegressionDataset(ImageClassificationBase):\n    \"Image dataset whose targets are real numbers rather than class labels.\"\n    def __init__(self, fns:FilePathList, y:Collection[Number]):\n        # Regression has no label vocabulary, so an empty class list is passed to the base class.\n        super().__init__(fns, classes=[])\n        # For a flat collection of n numbers this yields a float32 array of shape (n, 1).\n        self.y = np.array(y, dtype=np.float32)[:, None]\n        # Presumably the number of model outputs (create_cnn sizes the head from `data.c`) -- confirm.\n        self.c = 1\n        self.loss_func = F.mse_loss",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Folder the PNG images are read from.\npdata = Path('/data/cifar10/train/airplane/')\n\n# Preview the first three PNG paths found in it.\n[png_path for png_path in pdata.glob('*.png')][:3]",
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 4,
"data": {
"text/plain": "[PosixPath('/data/cifar10/train/airplane/17015_airplane.png'),\n PosixPath('/data/cifar10/train/airplane/44932_airplane.png'),\n PosixPath('/data/cifar10/train/airplane/43160_airplane.png')]"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "n = 1000\nn_val = 20\nfns = list(pdata.glob('*.png'))[:n]\n# The last `n_val` files become the validation set, so more than `n_val` files are required.\nassert len(fns) > n_val, f'need more than {n_val} images in {pdata}, found {len(fns)}'\n\n# Seed NumPy so the synthetic targets are the same on every Restart & Run All.\nnp.random.seed(42)\n# One random target per file actually found (the folder may hold fewer than `n` images).\ntargets = np.random.randn(len(fns))\n\nds_trn = ImageRegressionDataset(fns[:-n_val], targets[:-n_val])\nds_val = ImageRegressionDataset(fns[-n_val:], targets[-n_val:])\n\n# Separate names so the target array is not overwritten by a single sample's label.\nimg, img_target = ds_val[0]\nimg.show(title=img_target[0])",
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 216x216 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMUAAADSCAYAAAD66wTTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAFDhJREFUeJztnVusXOdVx/9r9syc+/H9VrsXMGkaJ5CQEpU2qRLFEaWAeEjTAI2UiAeK6AMtPEAjRBGIVFQ8gHhFQNQqlEslkNxLkBCkpC2QtEmd1Jc0sePYx/bx9djnfmb2no+HGYvJWv/t7mO741j9/6SR7DX78u09s2af/7cun6WUIIT4f2rXewBCvNWQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQU6wSM/sdM5s2s4tm9rdmNlRhnz8ys2RmD/TZnjSzlpnN972y3ntNM/uSmR3p7XcfOeadZvZfvf1Omdkn+967w8ye7Y1xysw+4/bdbWYHzWzRzP7TzN7Z994+N6bczPb0vX+/mb1gZrNmdtjMPr7qm/hWJ6WkV8UXgA8BOAXgVgDrADwD4M9+wD47AbwM4ASAB/rsTwL405J9mgA+BeAeACcB3Ofe3wjgNIBHAAwBmABwS9/7+wE8ASDrnf8kgF/u2/cigI8CGAbw5wD+p2QcBuAwgEd7/2/09v3N3nt3AZgHcPv1/myu6ed8vQdwI70A/D2Az/b9fzeA6R+wz9cA/AKAI1Wdwu0/RZziswC+cJl9FgHs6vv/PwN4vPfvjwP4Vt97YwCWALyHHOfe3pd+rPf/LQASgNG+bZ4H8GvX+7O5li/9+bQ6bgWwt+//ewFsMbMNbGMz+yiAVkrpqyXH+4SZnTez75jZR1Yxjp8FcN7MvmVmp81sj5m9o+/9vwTwqJk1zOxmAO8H8O/sGlJKCwAO9eyexwB8qbcNUkqnAHwRwK+bWWZm7wfwTgDfWMXY3/LIKVbHOLp/Plzi0r8n/IZmNo7uL/qnSo71VwBuArAZwB8CeNLM7q44jh3ofmE/CeAdAF5H98t6iS8DeAjdJ8BBAH+TUnq+5BouXcebrsHMRnvHeNJt+0UAnwGwAuBZAH+QUjpWcdw3BHKKy2Bmj/QJzq+h+6fEZN8ml/49R3b/Y3T/xHmdHTul9EJK6VxKKe89SZ4C8GDFoS0B+JeU0vMppeXeuT5gZmvMbD2ApwH8Cbqa4e0APmRmn+jt66/h0nX4a3gQwHkAX79kMLP3APhHAI+iq3tuBfB7ZvaLFcd9QyCnuAwppadSSuO914cB7ANwe98mtwM4lVI6R3bfDeC3ezNV0+h+Of/JzH6/7HToitcqvNTbvn9f9Pb/cQBFSunzPYebAvAP6Ooa+GswszF0xfg+d47HAHw+9YRDj9sAvJJS+reUUiel9AqArwD4cMVx3xhcb1FzI70A/DyAaQC70J19+g+UzD4B2ABga9/rGLozPuO99x9C90+ZGoCfQ/eX+r6+/YfQ/aWf6r0/DMB6790PYAbAHejOCP0FgGd7700CuADgY71jbwXw3wCe6L2/Cd0/lz7SO+bn4Gaf0P3zLAew09l3ovukuR9dB9wJ4DUAv3G9P5tr+jlf7wHcaC8Av4vutOwsgL8DMNT33j4Aj5TsdwRvnn16tvflnEVX+P4q2T6517v63v8tAMd7zrEHwNv73rsf3Vmhiz0n/mu8ecboAXS1xhK608rvcud+/JKTket4GMD3ek481XOq2vX+XK7l69IvjxCihzSFEA45hRAOOYUQDjmFEA45hRCO+iBP9tRTXwhTXVmWhe1qWfTVmhEb2bd0/1rFYxKbWYyp1WrRxrYri8fxba98Oz6LGG1su7IZyE6nc8XnZvvS7VJB9mU2PpaiiNsyG9v/kY89Rm+unhRCOOQUQjjkFEI4Bqop2N/H7E9maiPuy2zl5yG6oOJ4iBypfI6qmmC128Z9oy2la6tbuseslg
FRWTMRzbWa+8j1XvzAVpO5oSeFEA45hRAOOYUQDjmFEI63gNCuKnaZUOY+zUV11aBcte2uVmhfjajmsONFccnOWyZCq4+R3ce4FQugXa3QriqqV3O/9aQQwiGnEMIhpxDCIacQwjFYoV0jgo5GtK9ckANAjQm/qxDGVcU3Y3BCu9o5VpMlW1XEphQFdIcckn0G6LD7EM9bmrHLxl4xY7gMPSmEcMgphHDIKYRwyCmEcAxUaDOq6s2rjRb/MFK9PUyYljEIoQ1jIjSaspKxMFFdEAVNhTYtCSW2iudlZcsAL0lmX+qqpbWAnhRCBOQUQjjkFEI45BRCOOQUQjjekvUUVWdxWC5+2TGvZjzXuu7icvYrhaZfkLkdOlNUMjPDEiPYLNDo6EiwjQwT20i0DRFbsxmXJs/qfPbJSJ3Maq6RoSeFEA45hRAOOYUQDjmFEI6BCu08z4Otan4/ze0vcekaqdugXceJjXZBryr8r7Lg/mqEf563g63RiIJ1dHQ02MYnxukYJybXBBsTwfU6+xrFz6DVagVbux3HzbZbWlqiY1xpxe9Um+zPvnvbt+6gx9STQgiHnEIIh5xCCIecQgjHQIX2rT/103EARKTVM2JrRFuDbFe2f61G8u7r1cR39Tb+FdcVAG+uQAP07DxkjDkRrEPDw8HG7jcToQCw0loh28bIcJ4zYbwctyvieTrEdvrMuWA7fuI4HWOtxq4n3gsW0b7rvT/Dj0mtQvwII6cQwiGnEMIhpxDCMVChvW79ujgAuo52tahyRsQzUNKJjpARQc4i2lz/XnnXwNVsy9Kg6fHI/SmIuMyXowBuF/wcbbLG9bGT08H23AsHgm3H1o3BNj0d9z164myw3bFrZ7At5vyr+uKL8dyNoWawnT1/IdgefvAhekw9KYRwyCmEcMgphHDIKYRwDFRonz8bRRUVtiRKyWp0y4R2RiLVtDNe1iC2eO7KNdqraNnP1vBjkj7RSmmWbh+3Yh356DqBrG8+uMjf8/Q3g+3bB2K0efvGyWBbWY6R75nZ2WD7yVveHWyH3niDjrFBarzPX5gJtrmlGJ0vQ08KIRxyCiEccgohHHIKIRxWtt7ZD4NnvvHNcDImtOtEAPMoNxexBUlHZoKX1Suz1Goj52Gimk4alDVDy6qJ8iKPYpnVILOw+9BQTB1n4plFvgGgU8QU7OXFKFhrjXgtrcUYOWfp8okV2nfi9S2SSDzAlwYoyAQDu0G7772Xfjh6UgjhkFMI4ZBTCOGQUwjhGGhE+9SpM8GWEWGbNaJgbTZiOjCtcwbQJuut1YmwXViMDbYK0qm7bqzmu1qNdXnXcfJ7RNaoGyEdvWcvXgy2NRMxgnyS3O9GI05ioM2EKY+Iz80vBlurRWyk7Jutg9cuoqhutePOWSqbVIn7T46PBdvSChfqDD0phHDIKYRwyCmEcMgphHAMVGi3LYrYCxdj7ewy6TC9ZeOmYBsmacMAMLcwF7cldbszczFtme2bZXFflmJekAgwi3wDQINNMJDJgG1btwXbzPk47mY9Rq9PTZ8OtmEi3EmTdgBAh0SL2eeVd6IwZg3SWMO2FdKILSdCuw5eJsCaqS2vROG/vBLPXYaeFEI45BRCOOQUQjjkFEI45BRCOAY6+1QnjQaOHjsWbBPr1gbbImmv/uorvJh9dn4+2FgNwvo1cU23jevWB1tzLN6mEbKe3MT6DcE2TNrhA8BwM85osVSUgjQpWE86LbLGDrfdsivYcpK6UZDZHgBoLcfaibVrYjpJXsQxWiKpGiQ1Zom0+z83cz7YpqZ4K/4L5+K2oxMTwcZmqcrQk0IIh5xCCIecQgiHnEIIx0CF9qc/90SwbduyJdjed+d7g61GBOeWTVHYAsCum28KthPHTwTb0898Pdj2HYnivTkcRfUwSdMYH4+NEJgwBYDJ8SgGNxEBfdtNNwdbIs0MclKrsEwmF1hzhbVkLACwdm2c8BglEwcTo3HSYIllVZAakiaxvW1TbOM/wr
o+AniVLCNw+kKsN2m3SbOHEvSkEMIhpxDCIacQwiGnEMIxUKE9QQRdg9QbPHD3B4Pt6FSMfG/fEgUZANx5x53B1iHriL/8/cPxPGdivUCRiGAlzRVy0vRgltRnAECHdOobGomC9fXpOEHwto2bg61NuvyNjsXaiYW5OJ7nnt9Px/jt114NNrN43dvXR0HeyePEyMYNsSZmqBm/gtuI0H73zh+jY1xLJifmSJOCM2dP0f0ZelII4ZBTCOGQUwjhkFMI4Rio0F4mHfnSWOzm1lqIheeHXj8SbEMkOgsAL+7dF2ysG+DBQ68FW0bEMmtVzzrW5WRNt8l1UQgCwARZBmCoHrv3FSsxtXr3B+8Jtn/d8+Vge/iXfiXYNm+IWQD7D36fjvF/P/14sB07GYX/4aloY2ni9drBYBsimQGb10Th/t0D8TMFeLMH9nktLsbvVBl6UgjhkFMI4ZBTCOGQUwjhGKjQXlqIYnd+NNpePBAjrMdPxohkVtJ978y5GJW+QCK5rLNd1mTpzTEyPL8YRTVbTy7vxLEAwAoR0Bnp1LeuGc89c+ZssE2fPBlsU8ej7TTZ96UDUQADwDJZZ46lx7MF60eH4qTB2DCpSye15axl/5ETU3SMLdJhcJlMeLRaasUvxBUjpxDCIacQwiGnEMIxUKF9390fCLbZuYVge+nggWA7TcTz9Mw5ep71kzFFfZE09homwm/Iou3EmXieFhGDY2RpgBW2EDyAheUY8Z1fjPdidilGYvfuixMRp2Zmgu3lg3E7VqO9d3+83wAwOhLrsYfHooDOV6LYXSCTKoskU4EJ5UYzjnGEpOoDwHIrFoNPjse6eMv4sg0MPSmEcMgphHDIKYRwyCmEcAxUaG/bEmuLd2yLQzh9Nq7VRrKBkZMIMsAXEmfNwupk4fecdMteQ5ucxY7lm0ha9omT03yMJMK6SGyHjh0NtkTGPTMfI/bf3f+9YKuTVO3Db8RzAMB2Uiu9aUvsyp6Rum0moL+z96VgmzsTJwjYEnwdcjwAAOl4vn1zHLdl6jouxBUjpxDCIacQwiGnEMIx2BptssB3rRYF8No1sa553booYlnqNwDML8TIcEGEWidFkZbIzwST86PDMULaJounr50cp2NsLkVxakRAkyHiIokMN4diZ/TXp2K69WmyHFajHqP4APATO3YEW53UprNf1q1kqbN73ndXsBWkGzirVbeSMoF6FseDTsxeOHCUp8cz9KQQwiGnEMIhpxDCIacQwiGnEMIx0NmnNlkg3hBnD1iHN9a6nqVkAHx9M5bm0SAF91ktznzU2KLopM4hJwuYp4It/gaMkBb0LL1hiTQ4KPKYDjJB2vi3ctKljzQKaI7EmSsAqJMxsvGwLn1LZ2KjiRGyXt6GyZgus2Y0ztjVa7yeokPu2uxiXPNudi7Wd5ShJ4UQDjmFEA45hRAOOYUQjoEK7bxNUi2IAGbpFwWzkYXFy+zMltfiMbOCpFoQMVcjwq/RiEJydDQKdwAYI00BErnGZSJsx0kbf5LsQJsHbI3ZF5gi7fUBYN/8fLBl5Lozdn9IygpLY8lIW8RmI96zjF4hn0DJO9FW5NV///WkEMIhpxDCIacQwiGnEMIxUKF96NXngi3L4hCYSKuRSGyqcaENEv02Er1O9Wgjm9GaBiNRXDpuEg0HgNn5eFAWnWVx7vNkvfqC7UsC/mzSYMs2/tuYLB7ASD0FqzhhGQhsjGwR+hUmlEm2AAAU5DwdlhHBbkYJelII4ZBTCOGQUwjhkFMI4Rio0F5ZiJ3/MragOolS1oxEPsd4e/W6xcsqOiRFnWQjk8xxFPS3gxXwk2g46WAHAB1SiN+xKBBrTNiyQ5LjkbkJKkyXeXY7yC3nrRpZswcmdomA7pDPpUNENbMBQJuNh54n2srQk0IIh5xCCIecQgiHnEIIx0CFNls8nWRvo0G63bHaYrZ+GwAUebUIqyWyP6trJlFpdmYWvaYBYABM9llGOuORlHcjXfHYsgT0mo
l6ZhME3W3ZTAQRxqyHIjmkkbTzGpk1SAXLFuAp+PUU72QiN91Y68cS9KQQwiGnEMIhpxDCIacQwjFQoc2EqDGBSBYMz4l4tjpPHW+QJl6s8RlL9WbKmK2NV29Ua2ZWL1kUvSACsSDR70YzCkxaw04iw+0OiZqTseQsKgygRZrN8bRsAhkju99ZjZQOkMXlWf06AOR1UqNNlkToEFsZelII4ZBTCOGQUwjhkFMI4RhsRHuYSDIioDpMFJOIdqtkzbsVIgbrJLWara3GQrEFqwVnkVRytKLNRSzTjWTOAYmt1cfGSNLt22SMYJMLJYywNfhILL9DouQF+WzYuoNFih3i2yTen0o+6w6J+Cfy+edFPE8ZelII4ZBTCOGQUwjhkFMI4Rio0J5fiB206yT9u0kKpY0IrUaJaGRmlqLeJktvUVlM06DJ7wkRnFbSDI2lN7PzJHJMliZORT6rVSYKnzWA6x6TCWPS5KxDzk5seSfeb9YNvmp99+Xs4dwly6wx9KQQwiGnEMIhpxDCIacQwiGnEMIx0NmnhVZcFJ3NPi1mJCRPMxZK1kFjsyGk7X6Ntecnx6MzTaxJAcvTKMmqYMdkNQNstojVpdRJMwO2/ACdXSvpYmh0ncH4QbRYrQI5JEv9YAPqkO3K6inYV4CNscO+EyXoSSGEQ04hhENOIYRDTiGEY6BC+0IehbaRxdxpBgURknXSca67P7GTIn6m0pgurpF9a6R7HkuXYN0Fu8dkNQPVRHWNpWqQteNYowCWFlGUNC5IZGkAJqDrJBWFra3HZks6JP+mRmxMPAMlApp+1GrFL8QVI6cQwiGnEMIhpxDCMVChnRORx7sGkpbtRCdxMcfrAGgrdqKMjUptIjgzcg6yb14SLU7s3Ez4M1HNBDQNxZOaBraeYMZ/G2skQl/1V5QtJF+QGhJ2f9gYabYAgJw0hmATFqtBTwohHHIKIRxyCiEccgohHFaWkivEjyp6UgjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4ZBTCOGQUwjhkFMI4fg/mSx8IJiYO+cAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.vision.learner import cnn_config\n\ndef create_cnn(data:DataBunch, arch:Callable, cut:Union[int,Callable]=None, pretrained:bool=True,\n               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,\n               custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,\n               classification:bool=True, y_range:OptRange=None, **kwargs:Any)->ClassificationLearner:\n    \"\"\"Build a convnet-style learner from an `arch` body and a head.\n\n    The head is `custom_head` when one is given; otherwise `create_head` builds it and\n    receives `y_range`, so the outputs can be squashed into that interval.\n    `learn.freeze()` is called when `pretrained`, and the head (`model[1]`) is initialised\n    through `apply_init` with `nn.init.kaiming_normal_`.\"\"\"\n    assert classification, 'Regression CNN not implemented yet, bug us on the forums if you want this!'\n    meta = cnn_config(arch)\n    body = create_body(arch(pretrained), ifnone(cut,meta['cut']))\n    # Doubled, presumably because the head's AdaptiveConcatPool2d concatenates two poolings -- confirm.\n    nf = num_features_model(body) * 2\n    # Explicit None test: `custom_head or ...` would silently replace a falsy module such as an empty nn.Sequential.\n    head = create_head(nf, data.c, lin_ftrs, ps, y_range) if custom_head is None else custom_head\n    model = nn.Sequential(body, head)\n    learn = ClassificationLearner(data, model, **kwargs)\n    learn.split(ifnone(split_on,meta['split']))\n    if pretrained: learn.freeze()\n    apply_init(model[1], nn.init.kaiming_normal_)\n    return learn\n\nclass ToRange(nn.Module):\n    \"Layer that rescales a sigmoid of `x` so the output lies between `x_range[0]` and `x_range[1]`.\"\n    def __init__(self, x_range):\n        super().__init__()\n        self.x_range=x_range\n\n    def forward(self, x):\n        low, high = self.x_range[0], self.x_range[1]\n        # `x.sigmoid()` instead of `F.sigmoid(x)`, which PyTorch has deprecated.\n        return (high-low) * x.sigmoid() + low\n\ndef create_head(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, y_range:OptRange=None):\n    \"\"\"Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` values.\n    :param ps: dropout, can be a single float or a list for each layer.\n    :param y_range: optional pair (low, high); when given, a final `ToRange` layer maps the outputs into it.\"\"\"\n    # `list(...)` lets `lin_ftrs` be any collection (e.g. a tuple), not only a list.\n    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]\n    ps = listify(ps)\n    # A single dropout value applies to the last layer; earlier layers get half of it.\n    if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n    # ReLU after every linear layer except the last one.\n    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n    layers = [AdaptiveConcatPool2d(), Flatten()]\n    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):\n        layers += bn_drop_lin(ni,no,True,p,actn)\n    if y_range is not None:\n        layers.append(ToRange(y_range))\n    return nn.Sequential(*layers)\n",
"execution_count": 16,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Interval that create_cnn forwards to the head's ToRange layer.\ny_range = [0, 1]\ndata = ImageDataBunch.create(ds_trn, ds_val)\n\n# Build the learner with explicit keyword arguments, then train for one epoch.\nlearn = create_cnn(data=data, arch=models.resnet18, y_range=y_range)\n\nlearn.fit(1)",
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": "Total time: 00:00\nepoch train_loss valid_loss\n1 1.165546 1.695799 (00:00)\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Keep only the first item returned by get_preds (the predictions) and convert it to a NumPy array.\npreds = learn.get_preds()[0].numpy()",
"execution_count": 18,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# In float32 the sigmoid can round to exactly 0.0 or 1.0, so the bounds are checked inclusively.\nassert np.all(preds >= y_range[0]) and np.all(preds <= y_range[1]), f'predictions outside {y_range}: min={preds.min()}, max={preds.max()}'",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-fastaiv1-py",
"display_name": "Python [conda env:fastaiv1]",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.0",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"gist": {
"id": "",
"data": {
"description": "fastai-ImageRegressionDataset-y_range-dev",
"public": false
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment