{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# REQUIRES\n",
"\n",
"`[h5py, numpy, matplotlib, pytorch]`"
]
},
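{
"cell_type": "markdown",
"metadata": {},
"source": [
"These can be installed with, for example:\n",
"\n",
"`pip install h5py numpy matplotlib torch opencv-python`"
]
},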
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class CNet(nn.Module):\n",
"\n",
" def __init__(self, input_shape):\n",
" super(CNet, self).__init__() \n",
" self.input_shape = input_shape\n",
" \n",
" self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2)\n",
" self.conv2 = nn.Conv2d(64, 32, kernel_size=4, stride=1)\n",
" self.conv3 = nn.Conv2d(32, 16, kernel_size=4, stride=1)\n",
" \n",
" s1 = conv_output_shape(input_shape[1:], kernel_size=4, stride=2)\n",
" s2 = conv_output_shape(s1, kernel_size=4, stride=1)\n",
" s3 = conv_output_shape(s2, kernel_size=4, stride=1)\n",
" \n",
" self.output_shape = np.prod(s3) * 16\n",
" \n",
" def to(self, device):\n",
" self.device = device\n",
" return super(CNet, self).to(device)\n",
"\n",
" def forward(self, x_):\n",
" x_ = x_.to(self.device)\n",
" y_ = F.leaky_relu(self.conv1(x_))\n",
" y_ = F.leaky_relu(self.conv2(y_))\n",
" y_ = F.leaky_relu(self.conv3(y_)).view(x_.shape[0], -1)\n",
" return y_\n",
" \n",
"class CNet2(CNet):\n",
" \n",
" def __init__(self, input_shape, output_shape, activation=lambda x: x):\n",
" super(CNet2, self).__init__(input_shape)\n",
" self.out_layer = nn.Linear(self.output_shape, output_shape)\n",
" self.output_shape = output_shape\n",
" self.activation = activation\n",
" \n",
" def forward(self, x_):\n",
" x_ = super(CNet2, self).forward(x_)\n",
" y_ = self.activation(self.out_layer(x_))\n",
" return y_\n",
"\n",
"def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):\n",
" from math import floor\n",
" if type(kernel_size) is not tuple:\n",
" kernel_size = (kernel_size, kernel_size)\n",
" h = floor( ((h_w[0] + (2 * pad) - ( dilation * (kernel_size[0] - 1) ) - 1 )/ stride) + 1)\n",
" w = floor( ((h_w[1] + (2 * pad) - ( dilation * (kernel_size[1] - 1) ) - 1 )/ stride) + 1)\n",
" return h, w"
]
},
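{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check of `conv_output_shape` (the standard `nn.Conv2d` output-size formula\n",
"$H_{out} = \\lfloor (H + 2p - d(k-1) - 1)/s + 1 \\rfloor$) on a $28 \\times 28$ MNIST input:\n",
"`conv1` (k=4, s=2) gives $28 \\to 13$, `conv2` (k=4, s=1) gives $13 \\to 10$ and `conv3`\n",
"(k=4, s=1) gives $10 \\to 7$, so the flattened encoder output has $7 \\times 7 \\times 16 = 784$ features."
]
},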
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import h5py #pip install h5py -- https://www.h5py.org/\n",
"\n",
"def mnist():\n",
" #load train\n",
" f = h5py.File(\"./train.hdf5\", 'r')\n",
" train_x, train_y = f['image'][...], f['label'][...]\n",
" f.close()\n",
"\n",
" #load test\n",
" f = h5py.File(\"./test.hdf5\", 'r')\n",
" test_x, test_y = f['image'][...], f['label'][...]\n",
" f.close()\n",
"\n",
" print(\"train_x\", train_x.shape, train_x.dtype)\n",
" print(\"train_y\", train_y.shape, train_y.dtype)\n",
" \n",
" return train_x[:,np.newaxis], train_y, test_x[:,np.newaxis], test_y\n",
" "
]
},
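{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not already have `./train.hdf5` and `./test.hdf5`, the sketch below is one way to\n",
"create them (it assumes `torchvision` is installed; any MNIST source with the same layout works)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# minimal sketch: convert torchvision's MNIST to the HDF5 layout mnist() expects\n",
"from torchvision.datasets import MNIST\n",
"\n",
"for split, train in [(\"train\", True), (\"test\", False)]:\n",
"    ds = MNIST(\"./data\", train=train, download=True)\n",
"    with h5py.File(\"./{0}.hdf5\".format(split), \"w\") as f:\n",
"        f[\"image\"] = ds.data.numpy().astype(np.float32) / 255. # (N, 28, 28) in [0, 1]\n",
"        f[\"label\"] = ds.targets.numpy() # (N,) integer digit labels"
]
},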
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def distance_matrix(x1, x2=None): #L22 distance by default\n",
" if x2 is None:\n",
" x2 = x1\n",
" n_dif = x1.unsqueeze(1) - x2.unsqueeze(0)\n",
" return torch.sum(n_dif * n_dif, -1)"
]
},
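{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# quick sanity check (added): a 3-4-5 triangle gives a squared distance of 25\n",
"a = torch.tensor([[0., 0.], [3., 4.]])\n",
"print(distance_matrix(a)) # expect [[0., 25.], [25., 0.]]"
]
},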
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss(model, x, y, margin=0.2):\n",
" x_ = model(x)\n",
" unique = np.unique(y)\n",
" device = list(model.parameters())[0].device\n",
" loss = torch.FloatTensor(np.array([0.])).to(device)\n",
"\n",
" for u in unique:\n",
" pi = np.nonzero(y == u)[0]\n",
" ni = np.nonzero(y != u)[0]\n",
" \n",
" #slightly more efficient below\n",
" xp_ = x_[pi] # get all positive images\n",
" xn_ = x_[ni] # get all negative images\n",
" xp = distance_matrix(xp_, xp_) #P-P distance\n",
" xn = distance_matrix(xp_, xn_) #P-N distance\n",
"\n",
" #3D tensor, (a - p) - (a - n) \n",
" xf = xp.unsqueeze(2) - xn\n",
"\n",
" xf = F.relu(xf + margin) #triplet loss\n",
" loss += xf.sum()\n",
"\n",
" return loss"
]
},
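{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss above is the standard triplet loss summed over all in-batch triples,\n",
"\n",
"$$\\mathcal{L} = \\sum_{(a,\\,p,\\,n)} \\max(0,\\; d(a, p) - d(a, n) + \\alpha),$$\n",
"\n",
"where $d$ is the squared Euclidean distance between embeddings, $a$ and $p$ share a label,\n",
"$n$ has a different label, and $\\alpha$ is the `margin`. Note that triples with $a = p$\n",
"(where $d(a, p) = 0$) are included in the sum."
]
},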
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.rcParams['figure.figsize'] = [8,4.5]\n",
"\n",
"def plot(fig, model, x, y):\n",
" plt.clf()\n",
" with torch.no_grad():\n",
" z = model(x).cpu().numpy()\n",
" for i in range(0,10):\n",
" plt.scatter(*z[y==i].T, marker=\".\", alpha=0.5, edgecolors='none')\n",
" plt.legend([str(i) for i in range(0,10)], loc=\"upper right\")\n",
" fig.canvas.draw()\n",
" \n",
"def figtoimage(fig):\n",
" # Get the RGBA buffer from the figure\n",
" w,h = fig.canvas.get_width_height()\n",
" buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n",
" return np.flip(buf.reshape((h,w,3)), 2) #bgr format for opencv!\n",
"\n"
]
},
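{
"cell_type": "markdown",
"metadata": {},
"source": [
"`figtoimage` grabs the rendered canvas as an RGB byte buffer and flips the channel order\n",
"because OpenCV's `VideoWriter` (used in the final cell) expects BGR frames; `plt.imsave`,\n",
"by contrast, expects RGB, hence the flip back when saving `initial.png` below."
]
},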
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"input_dim = (1, 28, 28)\n",
"batch_size = 100\n",
"margin = 0.2\n",
"latent_dim = 2\n",
"lr = 0.0005\n",
"epochs = 3\n",
"\n",
"if torch.cuda.is_available(): \n",
" device = 'cuda'\n",
"else:\n",
" device = 'cpu'\n",
"print(\"USING DEVICE:\", device)\n",
"\n",
"x_train, y_train, x_test, y_test = mnist()\n",
"x_train = torch.FloatTensor(x_train).to(device)\n",
"x_test = torch.FloatTensor(x_test).to(device)\n",
"model = CNet2(input_dim, latent_dim).to(device)\n",
"\n",
"optim = torch.optim.Adam(model.parameters(), lr=lr)\n",
"\n",
"fig = plt.figure()\n",
"fig.tight_layout()\n",
"plot(fig, model, x_test, y_test)\n",
"img = figtoimage(fig)\n",
"\n",
"plt.imsave(\"./initial.png\", img)\n"
]
},
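{
"cell_type": "markdown",
"metadata": {},
"source": [
"`initial.png` shows the untrained 2D embedding of the test set: with random weights the\n",
"ten digit classes should appear as one mixed cloud. The training loop below re-plots the\n",
"embedding after every epoch and collects the frames into a video."
]
},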
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"fig = plt.figure()\n",
"fig.tight_layout()\n",
"video = []\n",
"\n",
"x_train = x_train.reshape(x_train.shape[0] // batch_size, batch_size, *x_train.shape[1:])\n",
"y_train = y_train.reshape(y_train.shape[0] // batch_size, batch_size, *y_train.shape[1:])\n",
"\n",
"for e in range(epochs):\n",
" for x,y in zip(*[x_train, y_train]):\n",
" optim.zero_grad()\n",
" _loss = loss(model, x, y, margin=margin)\n",
" _loss.backward()\n",
" optim.step()\n",
" #print(_loss.item())\n",
" \n",
" plot(fig, model, x_test, y_test)\n",
" video.append(figtoimage(fig))\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import cv2\n",
"\n",
"file = \"./video.mp4\"\n",
"fps = 24\n",
"#video must be CV format (NHWC)\n",
"fourcc = cv2.VideoWriter_fourcc(*'mp4v') #ehhh.... platform specific?\n",
"writer = cv2.VideoWriter(file, fourcc, fps, (video[0].shape[1], video[0].shape[0]), True)\n",
"for frame in video:\n",
" writer.write(frame)\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "PhD",
"language": "python",
"name": "phd"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}