Created
July 25, 2019 18:11
-
-
Save zlapp/e82902577aade2d296e52e0c26bd5ac0 to your computer and use it in GitHub Desktop.
Fastai Active Learning.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Fastai Active Learning.ipynb", | |
"version": "0.3.2", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/zlapp/e82902577aade2d296e52e0c26bd5ac0/fastai-active-learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "6UfvLedCQz4O", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Fastai Active Learning" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "jyNrrwBL3AQp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!curl -s https://course.fast.ai/setup/colab | bash" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hOCfmhtk3Lze", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import fastai\n", | |
"fastai.__version__" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bvM9NMt33rJp", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from fastai.vision import *\n", | |
"from fastai.basic_train import BasicLearner\n", | |
"from fastai.basic_data import DataBunch\n", | |
"from fastprogress.fastprogress import format_time, IN_NOTEBOOK\n", | |
"\n", | |
"\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from tqdm import tqdm" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "TdsPZd37-j_t", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Data Prep" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KPl5rgOETVst", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"np.random.uniform()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5lTbHa2zUDjn", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from fastai.script import *\n", | |
"from fastai.vision import *\n", | |
"from fastai.distributed import *" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "28B-sLgsU7jT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"path = untar_data(URLs.MNIST_TINY)\n", | |
"data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), num_workers=2)\n", | |
"data.normalize()\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zx8FIoMXroq_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"learn = cnn_learner(data, models.resnet18, metrics=[accuracy])\n", | |
"learn.fit_one_cycle(3)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YyrhPAreVGgN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"unlabeled_path = untar_data(URLs.MNIST_SAMPLE)\n", | |
"unlabeled_data = ImageDataBunch.from_folder(unlabeled_path, ds_tfms=(rand_pad(2, 28), []), num_workers=2)\n", | |
"unlabeled_data.normalize()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_xHPx-SmsehI", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"unlabeled_data.test_dl = unlabeled_data.train_dl\n", | |
"unlabeled_data" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "r6hvAddo-fki", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Active Learner and Measurement" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "p0_FX0JG4Cls", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class UncertaintyMeasurement:\n", | |
"    \"\"\"Base class for per-item scores used to rank unlabeled samples.\n", | |
"\n", | |
"    An instance is called with a single item `x` and returns a number.\n", | |
"    `reverse` is the flag `ActiveLearner` hands to `list.sort`: False sorts\n", | |
"    the scores ascending, True descending. This base class scores every\n", | |
"    item as 0.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, learn:BasicLearner):\n", | |
"        self.reverse = False\n", | |
"        self.learn = learn\n", | |
"\n", | |
"    def __call__(self, x):\n", | |
"        return 0\n", | |
" \n", | |
"class RandomUncertaintyMeasurement(UncertaintyMeasurement):\n", | |
"    \"\"\"Baseline scorer: ignores the item and returns a uniform random draw.\"\"\"\n", | |
"\n", | |
"    def __call__(self, x):\n", | |
"        # low=0.0 / high=1.0 are numpy's defaults, written out for the reader\n", | |
"        return np.random.uniform(low=0.0, high=1.0)\n", | |
" \n", | |
" \n", | |
"class ModelSoftmaxMeasurement(UncertaintyMeasurement):\n", | |
"    \"\"\"Scores an item with the model's output for its own predicted class.\n", | |
"\n", | |
"    The score is the entry of the `probs` tensor returned by\n", | |
"    `learn.predict` at the predicted class index.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __call__(self, x):\n", | |
"        _, class_idx, class_probs = self.learn.predict(x)\n", | |
"        confidence = class_probs.data[class_idx.item()].item()\n", | |
"        return confidence\n", | |
" \n", | |
"class MCDStandardDevMeasurement(UncertaintyMeasurement):\n", | |
"    \"\"\"Scores an item by the spread of its Monte Carlo dropout predictions.\n", | |
"\n", | |
"    The score is the standard deviation, across the dropout passes, of the\n", | |
"    value predicted for the class with the highest mean value.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, learn:BasicLearner):\n", | |
"        super().__init__(learn)\n", | |
"        # high std means very uncertain, so rank scores in descending order\n", | |
"        self.reverse = True\n", | |
"\n", | |
"    def __call__(self, x):\n", | |
"        # Use the learner this instance was built with, not whichever\n", | |
"        # notebook global happens to be named `learn`.\n", | |
"        mcd_preds = self.learn.predict_with_mc_dropout(x)\n", | |
"        # every prediction is a 3-tuple; its last element is the probs tensor\n", | |
"        mcd_probs = torch.stack([probs for _, _, probs in mcd_preds])\n", | |
"        mcd_means = mcd_probs.mean(dim=0)\n", | |
"        mcd_std = mcd_probs.std(dim=0)\n", | |
"        # index of the class with the highest mean over the passes\n", | |
"        _, idx = mcd_means.max(0)\n", | |
"        return mcd_std.data[idx.item()].item()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "r1fY6lvQ3XY8", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class ActiveLearner:\n", | |
"    \"\"\"Ranks the items of a DataBunch's test set with an uncertainty measurement.\n", | |
"\n", | |
"    measurement: callable that scores one item; its `reverse` flag sets the sort direction.\n", | |
"    export_path: CSV file written by `export_top_uncertain`.\n", | |
"    per_batch:   if True, pick items round-robin from batch-sized groups instead of globally.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, measurement:UncertaintyMeasurement, export_path:str=\"./uncertain.csv\", per_batch:bool=False):\n", | |
"        self.measurement = measurement\n", | |
"        self.export_path = export_path\n", | |
"        self.per_batch = per_batch\n", | |
"        self.reverse = self.measurement.reverse\n", | |
"\n", | |
"    def get_top_uncertain(self, data:DataBunch, count:int, export:bool=True)->list:\n", | |
"        \"\"\"Return the `count` top-ranked `(uncertainty, filename)` tuples of `data.test_dl`.\n", | |
"\n", | |
"        Every item of `data.test_dl.dataset` is scored with `self.measurement`;\n", | |
"        scores are ordered with `reverse=self.reverse`. When `export` is True the\n", | |
"        result is also written to `self.export_path`.\n", | |
"        \"\"\"\n", | |
"        # Read from the `data` argument rather than a notebook global.\n", | |
"        dataset = data.test_dl.dataset\n", | |
"        assert count<=len(dataset)\n", | |
"\n", | |
"        # Group size for per_batch mode: the loader's batch size\n", | |
"        # (len(test_dl) would be the number of batches).\n", | |
"        bs = data.test_dl.batch_size\n", | |
"\n", | |
"        # list of lists: one sublist of (uncertainty, filename) per group of `bs` items\n", | |
"        uncertainty_vals = []\n", | |
"        for x_idx,(x,y) in enumerate(tqdm(dataset)):\n", | |
"            if x_idx%bs==0:\n", | |
"                uncertainty_vals.append([])\n", | |
"            uncertainty = self.measurement(x)\n", | |
"            filename = str(dataset.items[x_idx])\n", | |
"            uncertainty_vals[-1].append((uncertainty,filename))\n", | |
"\n", | |
"        def score(tup): return tup[0]\n", | |
"\n", | |
"        if self.per_batch:\n", | |
"            for sublist in uncertainty_vals:\n", | |
"                sublist.sort(key=score, reverse=self.reverse)\n", | |
"\n", | |
"            top_uncertain = []\n", | |
"            while len(top_uncertain)<count:\n", | |
"                # remove empty sublists\n", | |
"                uncertainty_vals = [sublist for sublist in uncertainty_vals if sublist]\n", | |
"                if not uncertainty_vals:\n", | |
"                    # nothing left to draw from; avoids spinning if the assert is disabled\n", | |
"                    break\n", | |
"                # visit the groups in order of their current best item\n", | |
"                uncertainty_vals.sort(key=lambda sublist: score(sublist[0]), reverse=self.reverse)\n", | |
"                for sublist in uncertainty_vals:\n", | |
"                    if len(top_uncertain)>=count:\n", | |
"                        break\n", | |
"                    top_uncertain.append(sublist.pop(0))\n", | |
"        else:\n", | |
"            flat_vals = [tup for sublist in uncertainty_vals for tup in sublist]\n", | |
"            flat_vals.sort(key=score, reverse=self.reverse)\n", | |
"            top_uncertain = flat_vals[:count]\n", | |
"\n", | |
"        top_uncertain.sort(key=score, reverse=self.reverse)\n", | |
"\n", | |
"        if export:\n", | |
"            self.export_top_uncertain(top_uncertain)\n", | |
"\n", | |
"        return top_uncertain\n", | |
"\n", | |
"    def export_top_uncertain(self, top_uncertain:list)->None:\n", | |
"        \"\"\"Write `top_uncertain` to `self.export_path` as a CSV with `name` and `uncertainty` columns.\"\"\"\n", | |
"        print(\"Exporting to\",self.export_path)\n", | |
"        df = pd.DataFrame({\n", | |
"            \"name\": [filename for _,filename in top_uncertain],\n", | |
"            \"uncertainty\": [uncertainty for uncertainty,_ in top_uncertain],\n", | |
"        })\n", | |
"        df.to_csv(self.export_path,index=False)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "eK5PeyPnWy2Q", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"rum = RandomUncertaintyMeasurement(learn)\n", | |
"msum = ModelSoftmaxMeasurement(learn)\n", | |
"mcdum = MCDStandardDevMeasurement(learn)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ClT2SXI56D87", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"a_learner = ActiveLearner(mcdum)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lVv2xQWZpWYD", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"a_learner.get_top_uncertain(unlabeled_data,50)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "E9OSXu2W8IGk", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# for x,y in unlabeled_data.test_dl.dataset:\n", | |
"# mcd_preds = learn.predict_with_mc_dropout(x)\n", | |
"# print(mcd_preds)\n", | |
"# mcd_probs = [probs for _,_,probs in mcd_preds]\n", | |
"# mcd_probs = torch.stack(mcd_probs)\n", | |
"# print(mcd_probs)\n", | |
"# mcd_means = mcd_probs.mean(dim=0)\n", | |
"# print(mcd_means)\n", | |
"# mcd_std = mcd_probs.std(dim=0)\n", | |
"# print(mcd_std)\n", | |
"# _,idx = mcd_means.max(0)\n", | |
"# std = mcd_std.data[idx.item()].item()\n", | |
"# print(std)\n", | |
"# break" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T__OVSpTGuj9", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment