Skip to content

Instantly share code, notes, and snippets.

@georgeadam
Created September 8, 2019 15:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save georgeadam/b48c948ce71a16bbd915b09384e431ee to your computer and use it in GitHub Desktop.
Save georgeadam/b48c948ce71a16bbd915b09384e431ee to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST Distances"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torchvision.datasets import MNIST\n",
"import torchvision.transforms as transforms"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f6c0195ef70>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Root directory handed to the torchvision MNIST dataset constructors.\n",
"mnist_dir = \".\"\n",
"\n",
"# Seed torch's global RNG; the returned Generator is this cell's displayed value.\n",
"torch.manual_seed(1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Tensor conversion applied to every image; the test-set cell reuses this name.\n",
"transform = transforms.ToTensor()\n",
"\n",
"# download=True fetches MNIST into `mnist_dir` only when the files are missing,\n",
"# so the notebook also runs on a machine without a local copy.\n",
"train_set = MNIST(root=mnist_dir, train=True, transform=transform, download=True)\n",
"\n",
"# A single batch spans the whole shuffled training set; sizing it from the\n",
"# dataset avoids hard-coding the number of training images.\n",
"train_loader = torch.utils.data.DataLoader(\n",
"    train_set, batch_size=len(train_set), shuffle=True, num_workers=2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# download=True fetches MNIST into `mnist_dir` only when the files are missing.\n",
"test_set = MNIST(root=mnist_dir, train=False, transform=transform, download=True)\n",
"\n",
"# A single batch spans the whole shuffled test set; sizing it from the dataset\n",
"# avoids hard-coding the number of test images.\n",
"test_loader = torch.utils.data.DataLoader(\n",
"    test_set, batch_size=len(test_set), shuffle=True, num_workers=2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Pull the first (shuffled) batch out of each loader, train before test.\n",
"train_batch = next(iter(train_loader))\n",
"test_batch = next(iter(test_loader))\n",
"\n",
"# Every batch is an (images, labels) pair.\n",
"train_imgs, train_labels = train_batch\n",
"test_imgs, test_labels = test_batch"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Axes of the experiment: dataset partitions, vector norms, summary statistics, digit labels.\n",
"partitions = [\"train\", \"test\"]\n",
"norms = [float(\"inf\"), 1, 2]\n",
"stats = [\"min\", \"max\", \"mean\", \"std\"]\n",
"labels = list(range(10))\n",
"\n",
"# Images and labels of each partition, keyed by partition name.\n",
"data = {\n",
"    \"train\": {\"imgs\": train_imgs, \"labels\": train_labels},\n",
"    \"test\": {\"imgs\": test_imgs, \"labels\": test_labels},\n",
"}\n",
"\n",
"# Empty result containers.\n",
"within_class_distances = {\n",
"    partition: {label: {norm: [] for norm in norms} for label in labels}\n",
"    for partition in partitions\n",
"}\n",
"all_distances = {\n",
"    partition: {norm: {} for norm in norms}\n",
"    for partition in partitions\n",
"}\n",
"imgs_per_class = {\n",
"    partition: {label: torch.tensor([]) for label in labels}\n",
"    for partition in partitions\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extract the first 100 images of each class (in shuffled batch order) from every partition."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"n_per_class = 100  # number of images kept for each digit class\n",
"\n",
"for partition in partitions:\n",
"    partition_imgs = data[partition][\"imgs\"]\n",
"    partition_labels = data[partition][\"labels\"]\n",
"\n",
"    for label in labels:\n",
"        # Boolean mask selecting this class. Masking keeps the batch order, so the\n",
"        # slice takes the first `n_per_class` matches (fewer if the class is smaller).\n",
"        class_mask = partition_labels == label\n",
"        imgs_per_class[partition][label] = partition_imgs[class_mask][:n_per_class]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
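"Compute pairwise L-infinity, L1 and L2 distances between the sampled images for every pair of classes (within a class, each unordered image pair is counted once)."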
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"for partition in partitions:\n",
"    for norm in norms:\n",
"        for i, label1 in enumerate(labels):\n",
"            imgs1 = imgs_per_class[partition][label1]\n",
"\n",
"            # Visit each unordered class pair once: label2 starts at label1's position.\n",
"            for label2 in labels[i:]:\n",
"                imgs2 = imgs_per_class[partition][label2]\n",
"                pair = \"{}_{}\".format(label1, label2)\n",
"                pair_distances = []\n",
"\n",
"                for j, img1 in enumerate(imgs1):\n",
"                    # Within a class, start after j so every image pair is counted once\n",
"                    # and an image is never compared with itself.\n",
"                    start = j + 1 if label1 == label2 else 0\n",
"                    others = imgs2[start:]\n",
"                    if len(others) == 0:\n",
"                        continue\n",
"\n",
"                    # One batched norm per image instead of one torch.norm call per pair:\n",
"                    # flatten each difference image to a vector and reduce along it.\n",
"                    diffs = (img1 - others).flatten(start_dim=1)\n",
"                    pair_distances.extend(torch.norm(diffs, p=norm, dim=1).tolist())\n",
"\n",
"                # Distances are symmetric, so both key orders share the same list.\n",
"                all_distances[partition][norm][pair] = pair_distances\n",
"                all_distances[partition][norm][\"{}_{}\".format(label2, label1)] = pair_distances"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# One (n_labels x n_labels) matrix of mean distances per partition and norm.\n",
"heatmaps = {\n",
"    partition: {norm: np.zeros([len(labels), len(labels)]) for norm in norms}\n",
"    for partition in partitions\n",
"}\n",
"\n",
"for partition in partitions:\n",
"    for norm in norms:\n",
"        distances_by_pair = all_distances[partition][norm]\n",
"        for j in labels:\n",
"            for k in labels:\n",
"                pair = \"{}_{}\".format(j, k)\n",
"                # .item() stores a plain Python float in the array rather than a 0-d tensor.\n",
"                mean_dist = torch.mean(torch.tensor(distances_by_pair[pair])).item()\n",
"                heatmaps[partition][norm][j, k] = mean_dist"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.style.use(\"ggplot\")\n",
"fig, ax = plt.subplots(figsize=(8, 8))\n",
"\n",
"# Mean L2 (norm key 2) distance between every pair of classes in the train partition.\n",
"sns.heatmap(heatmaps[\"train\"][2], ax=ax, cmap=\"viridis_r\", annot=True, fmt=\".1f\")\n",
"\n",
"# Title and axis labels so the figure can be read on its own.\n",
"ax.set(\n",
"    title=\"Mean L2 distance between MNIST classes (train)\",\n",
"    xlabel=\"Digit class\",\n",
"    ylabel=\"Digit class\",\n",
")\n",
"fig.tight_layout()\n",
"fig.savefig(\"temp.png\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dl",
"language": "python",
"name": "dl"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment