import torch
from torch import tensor
from torch.nn.parameter import Parameter
import unittest
from unittest import TestCase
torch.manual_seed(42)
'''
Convolutions for augmented images

Given a regular (n x n) image, e.g. n = 2
    [[x1 x2]
     [x3 x4]]
the augmented image turns each pixel value into a vector of length l, e.g. l = 3
    [[[x11 x12 x13] [x21 x22 x23]]
     [[x31 x32 x33] [x41 x42 x43]]]
The dimensions of this image are (n x n x l).
If we have multiple channels, the dimensions are (channels x n x n x l), e.g. channels = 2
    [ [[[x11 x12 x13] [x21 x22 x23]]
       [[x31 x32 x33] [x41 x42 x43]]],
      [[[y11 y12 y13] [y21 y22 y23]]
       [[y31 y32 y33] [y41 y42 y43]]] ]

Convolution is defined exactly as it is in PyTorch, except that pixels undergo
linear combinations as weighted elementwise sums, e.g.
    2([x11 x12 x13]) - 3([x21 x22 x23]) = [2(x11)-3(x21) 2(x12)-3(x22) 2(x13)-3(x23)]
The only other significant change is that the bias is added only to the first element
of a pixel vector, e.g.
    [x11 x12 x13] + b = [x11+b x12 x13]
All other details of the convolution are identical to the PyTorch docs.

Notes:
- Kernels are always square.
- Each channel of an augmented image is (n x n x l), i.e. the outer dimensions are square.
- Groups is always equal to one.
- The batch size is always one. This seems unusual, but the big problem we're trying to
  solve here is not training but proving properties about a network. The network will have
  additional parameters besides weights and biases that it will try to optimize for a
  certain sample individually.

Below is a very naive and slow convolution for augmented images.
It takes ~1.5 s for a forward pass through Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
on an augmented MNIST image with 784 errors (1, 28, 28, 785) on a 2.5 GHz Intel Core i7.
An ideal implementation should be optimized for 14 cores.
'''
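

# Illustrative sketch only (not used by the gist): the pixel-combination rule from the
# docstring written out with tensors. An augmented pixel is a length-l vector; convolution
# weights scale the whole vector, and the bias lands only on the first entry. The function
# name below is hypothetical and exists purely for illustration.
def _combine_pixels_example():
    x1 = tensor([0.5, 1.0, 0.0])   # augmented pixel [x11 x12 x13]
    x2 = tensor([0.2, 0.0, 1.0])   # augmented pixel [x21 x22 x23]
    combined = 2 * x1 - 3 * x2     # weighted elementwise sum: [2(x11)-3(x21), ...]
    b = 1.0
    combined[0] += b               # bias is added only to the first element
    return combined

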
class Conv2D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(Conv2D, self).__init__()
        # Weights and biases are all ones only for testing
        self.weight = Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size),
                                requires_grad=False)
        self.bias = Parameter(torch.ones(out_channels), requires_grad=False)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    @staticmethod
    def val_or_zero_2d(x, i, j):
        # Implicitly zero-pad an image: return the pixel vector at (i, j),
        # or a zero vector if (i, j) falls outside the image.
        # x is (n x n x l), where l is the dimension of the error vector.
        n = x.shape[0]
        if not ((0 <= i < n) and (0 <= j < n)):
            return torch.zeros_like(x[0, 0])
        return x[i, j]

    def forward(self, img):
        # img is (num_channels x n x n x num_errors)
        # outer dimension of each image channel (assumes n x n)
        n = img.shape[-2]
        # kernel is square
        k = self.kernel_size
        # padding is the same in both H and W
        p = self.padding
        # <5> stack all output channels
        return torch.stack(
            # <4> take all input channels, stack and sum
            [torch.stack(
                # <3> take all rows and stack
                [torch.stack(
                    # <2> take inner products across a row and stack
                    [torch.stack(
                        # <1> take inner product with kernel at top left corner pixel (l, m);
                        # bias.view(1) gives a one-element source so index_add only touches entry 0
                        [torch.stack(
                            [kernel[i, j] * self.val_or_zero_2d(img_channel, l + i, m + j)
                             for i in range(k) for j in range(k)]).sum(0).index_add(0, tensor([0]), bias.view(1))
                         # <1>
                         for m in range(-p, n - k + p + 1, self.stride)])
                     # <2>
                     for l in range(-p, n - k + p + 1, self.stride)])
                 # <3>
                 for img_channel, kernel in zip(img, kernel_group)]).sum(0)
             # <4>
             for kernel_group, bias in zip(self.weight, self.bias)])
        # <5>
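

# A minimal sketch (not part of the gist's API, the name is hypothetical): the number of
# output positions per spatial dimension produced by the loops in Conv2D.forward, i.e.
# len(range(-p, n - k + p + 1, stride)), matches the standard PyTorch convolution formula
# floor((n - k + 2p) / stride) + 1.
def _conv_output_size(n, kernel_size, padding, stride):
    return (n - kernel_size + 2 * padding) // stride + 1
# e.g. test2 below uses n=5, k=2, p=1, stride=2 -> 3 output rows/columns,
# and test3 uses n=4, k=3, p=0, stride=2 -> a single 1 x 1 output.

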
class TestConv2D(TestCase):
    def test1(self):
        conv = Conv2D(1, 1, 2)
        img = tensor([[[[0.8635, 0.1, 0., 0., 0., 0., 0., 0., 0., 0.],
[0.2933, 0., 0.1, 0., 0., 0., 0., 0., 0., 0.],
[0.0534, 0., 0., 0.1, 0., 0., 0., 0., 0., 0.]],
[[0.5245, 0., 0., 0., 0.1, 0., 0., 0., 0., 0.],
[0.8011, 0., 0., 0., 0., 0.1, 0., 0., 0., 0.],
[0.2956, 0., 0., 0., 0., 0., 0.1, 0., 0., 0.]],
[[0.7472, 0., 0., 0., 0., 0., 0., 0.1, 0., 0.],
[0.1142, 0., 0., 0., 0., 0., 0., 0., 0.1, 0.],
[0.6793, 0., 0., 0., 0., 0., 0., 0., 0., 0.1]]]])
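        # For the first output position with all-ones weights, the error-free entry is
        # 0.8635 + 0.2933 + 0.5245 + 0.8011 + bias(1) = 3.4824, and each 0.1 error term
        # is carried through unchanged (the bias only touches the first element).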
        expected = tensor([[[[3.4824, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0, 0],
[2.4434, 0, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0]],
[[3.187, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1, 0],
[2.8902, 0, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1]]]])
        self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)

    def test2(self):
        conv = Conv2D(2, 1, 2, stride=2, padding=1)
        img = tensor([[[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423]],
[[0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895],
[0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071],
[0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278],
[0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018, 0.9497, 0.6666],
[0.9811, 0.0874, 0.0041, 0.1088, 0.1637, 0.7025, 0.6790, 0.9155]],
[[0.2418, 0.1591, 0.7653, 0.2979, 0.8035, 0.3813, 0.7860, 0.1115],
[0.2477, 0.6524, 0.6057, 0.3725, 0.7980, 0.8399, 0.1374, 0.2331],
[0.9578, 0.3313, 0.3227, 0.0162, 0.2137, 0.6249, 0.4340, 0.1371],
[0.5117, 0.1585, 0.0758, 0.2247, 0.0624, 0.1816, 0.9998, 0.5944],
[0.6541, 0.0337, 0.1716, 0.3336, 0.5782, 0.0600, 0.2846, 0.2007]],
[[0.5014, 0.3139, 0.4654, 0.1612, 0.1568, 0.2083, 0.3289, 0.1054],
[0.9192, 0.4008, 0.9302, 0.6558, 0.0766, 0.8460, 0.3624, 0.3083],
[0.0850, 0.0029, 0.6431, 0.3908, 0.6947, 0.0897, 0.8712, 0.1330],
[0.4137, 0.6044, 0.7581, 0.9037, 0.9555, 0.1035, 0.6258, 0.2849],
[0.4452, 0.1258, 0.9554, 0.1330, 0.7672, 0.6757, 0.6625, 0.2297]],
[[0.9545, 0.6099, 0.5643, 0.0594, 0.7099, 0.4250, 0.2709, 0.9295],
[0.6115, 0.2234, 0.2469, 0.4761, 0.7792, 0.3722, 0.2147, 0.3288],
[0.1265, 0.6783, 0.8870, 0.0293, 0.6161, 0.7583, 0.5907, 0.3219],
[0.7610, 0.7628, 0.6870, 0.4121, 0.3676, 0.5535, 0.4117, 0.3510],
[0.8196, 0.9297, 0.4505, 0.3881, 0.5073, 0.4701, 0.6202, 0.6401]]],
[[[0.0459, 0.3155, 0.9211, 0.6948, 0.4751, 0.1985, 0.1941, 0.0521],
[0.3370, 0.6689, 0.8188, 0.7308, 0.0580, 0.1993, 0.4211, 0.9837],
[0.5723, 0.3705, 0.7069, 0.3096, 0.1764, 0.8649, 0.2726, 0.3998],
[0.0026, 0.8346, 0.8788, 0.6822, 0.1514, 0.0065, 0.0939, 0.8729],
[0.7401, 0.9208, 0.7619, 0.6265, 0.4951, 0.1197, 0.0716, 0.0323]],
[[0.7047, 0.2545, 0.3994, 0.2122, 0.4089, 0.1481, 0.1733, 0.6659],
[0.3514, 0.8087, 0.3396, 0.1332, 0.4118, 0.2576, 0.3470, 0.0240],
[0.7797, 0.1519, 0.7513, 0.7269, 0.8572, 0.1165, 0.8596, 0.2636],
[0.6855, 0.9696, 0.4295, 0.4961, 0.3849, 0.0825, 0.7400, 0.0036],
[0.8104, 0.8741, 0.9729, 0.3821, 0.0892, 0.6124, 0.7762, 0.0023]],
[[0.3865, 0.2003, 0.4563, 0.2539, 0.2956, 0.3413, 0.0248, 0.9103],
[0.9192, 0.4216, 0.4431, 0.2959, 0.0485, 0.0134, 0.6858, 0.2255],
[0.1786, 0.4610, 0.3335, 0.3382, 0.5161, 0.3939, 0.3278, 0.2606],
[0.0931, 0.9193, 0.2999, 0.6325, 0.3265, 0.5406, 0.9662, 0.7304],
[0.0667, 0.6985, 0.9746, 0.6315, 0.8352, 0.9929, 0.4234, 0.6038]],
[[0.1525, 0.3970, 0.8703, 0.7563, 0.1836, 0.0991, 0.1583, 0.0066],
[0.1142, 0.3764, 0.8374, 0.5837, 0.1197, 0.0989, 0.7487, 0.1281],
[0.4384, 0.7399, 0.2686, 0.4455, 0.4565, 0.3817, 0.2465, 0.0543],
[0.0958, 0.2323, 0.9829, 0.2585, 0.1642, 0.6212, 0.6378, 0.7740],
[0.8801, 0.7784, 0.0042, 0.5443, 0.8029, 0.4538, 0.2054, 0.9767]],
[[0.3130, 0.2153, 0.0492, 0.5223, 0.7216, 0.6107, 0.5989, 0.1208],
[0.0331, 0.5088, 0.9559, 0.7885, 0.2089, 0.4351, 0.1314, 0.2588],
[0.5905, 0.7723, 0.9142, 0.0409, 0.8343, 0.1474, 0.6872, 0.9231],
[0.5070, 0.9549, 0.0740, 0.3090, 0.7916, 0.3911, 0.3976, 0.2916],
[0.8447, 0.7453, 0.6602, 0.2190, 0.0941, 0.5541, 0.6481, 0.2691]]]])
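        # Note that this implementation (and the expected values below) adds the bias once per
        # input channel: with padding=1 the top-left output only sees pixel (0, 0) of each
        # channel, giving 0.8823 + 0.0459 + 2 * bias = 2.9282.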
        expected = tensor([[[[2.9282, 1.2305, 1.3040, 1.6541, 0.8655, 0.7994, 0.4507, 0.8457],
[4.7355, 1.7465, 2.7269, 2.2614, 1.3734, 2.0733, 1.7317, 2.6446],
[3.7340, 2.6081, 2.3371, 2.3171, 1.7716, 1.0364, 1.6718, 1.3228]],
[[3.9673, 0.9783, 2.3314, 1.7104, 2.2970, 1.1521, 1.7727, 2.2772],
[6.8464, 3.5134, 3.6922, 2.3344, 3.4933, 3.3152, 4.0898, 2.1788],
[6.4558, 4.1369, 3.8431, 3.0129, 2.6419, 3.3743, 5.8189, 3.7173]],
[[3.9214, 1.5361, 1.9492, 1.4992, 1.7719, 1.3431, 1.3570, 1.1623],
[4.9184, 3.7028, 5.6833, 3.4106, 3.7860, 3.1293, 3.8528, 2.4563],
[6.7671, 5.1336, 4.5723, 3.1677, 4.4504, 3.8230, 4.2091, 3.8171]]]])
        self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)

    def test3(self):
        conv = Conv2D(2, 2, 3, stride=2, padding=0)
        img = torch.randn(2, 4, 4, 5)
        out = conv(img)
        self.assertEqual(torch.Size([2, 1, 1, 5]), out.shape)

    def test4(self):
        conv = Conv2D(2, 2, 3, stride=2, padding=0)
        img = tensor([[[[1.9269e+00, 1.4873e+00, 9.0072e-01, -2.1055e+00, 6.7842e-01],
[-1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01, 1.6487e+00],
[-3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01],
[7.6245e-01, 1.6423e+00, -1.5960e-01, -4.9740e-01, 4.3959e-01]],
[[-7.5813e-01, 1.0783e+00, 8.0080e-01, 1.6806e+00, 1.2791e+00],
[1.2964e+00, 6.1047e-01, 1.3347e+00, -2.3162e-01, 4.1759e-02],
[-2.5158e-01, 8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01],
[1.7174e+00, 3.1888e-01, -4.2452e-01, 3.0572e-01, -7.7459e-01]],
[[-1.5576e+00, 9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00],
[2.1228e+00, -1.2347e+00, -4.8791e-01, -9.1382e-01, -6.5814e-01],
[7.8024e-02, 5.2581e-01, -4.8799e-01, 1.1914e+00, -8.1401e-01],
[-7.3599e-01, -1.4032e+00, 3.6004e-02, -6.3477e-02, 6.7561e-01]],
[[-9.7807e-02, 1.8446e+00, -1.1845e+00, 1.3835e+00, 1.4451e+00],
[8.5641e-01, 2.2181e+00, 5.2317e-01, 3.4665e-01, -1.9733e-01],
[-1.0546e+00, 1.2780e+00, -1.7219e-01, 5.2379e-01, 5.6622e-02],
[4.2630e-01, 5.7501e-01, -6.4172e-01, -2.2064e+00, -7.5080e-01]]],
[[[1.0868e-02, -3.3874e-01, -1.3407e+00, -5.8537e-01, 5.3619e-01],
[5.2462e-01, 1.1412e+00, 5.1644e-02, 7.4395e-01, -4.8158e-01],
[-1.0495e+00, 6.0390e-01, -1.7223e+00, -8.2777e-01, 1.3347e+00],
[4.8354e-01, -2.5095e+00, 4.8800e-01, 7.8459e-01, 2.8647e-02]],
[[6.4076e-01, 5.8325e-01, 1.0669e+00, -4.5015e-01, -1.8527e-01],
[7.5276e-01, 4.0476e-01, 1.7847e-01, 2.6491e-01, 1.2732e+00],
[-1.3109e-03, -3.0360e-01, -1.4570e+00, -1.0234e-01, -5.9915e-01],
[4.7706e-01, 7.2618e-01, 9.1152e-02, -3.8907e-01, 5.2792e-01]],
[[-1.2685e-02, 2.4084e-01, 1.3254e-01, 7.6424e-01, 1.0950e+00],
[3.3989e-01, 7.1997e-01, 4.1141e-01, 1.9312e+00, 1.0119e+00],
[-1.4364e+00, -1.1299e+00, -1.3603e-01, 1.6354e+00, 6.5474e-01],
[5.7600e-01, 1.1415e+00, 1.8565e-02, -1.8058e+00, 9.2543e-01]],
[[-3.7534e-01, 1.0331e+00, -6.8665e-01, 6.3681e-01, -9.7267e-01],
[9.5846e-01, 1.6192e+00, 1.4506e+00, 2.6948e-01, -2.1038e-01],
[-7.3280e-01, 1.0430e-01, 3.4875e-01, 9.6759e-01, -4.6569e-01],
[1.6048e+00, -2.4801e+00, -4.1754e-01, -1.1955e+00, 8.1234e-01]]]])
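        # With identical all-ones kernels and biases for both output channels,
        # the two output channels are expected to be equal.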
        expected = tensor([[[[2.9989, 4.7977, -5.3517, 0.2111, 4.5492]]],
[[[2.9989, 4.7977, -5.3517, 0.2111, 4.5492]]]])
        self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)


if __name__ == '__main__':
    unittest.main()