import torch
from torch import tensor
from torch.nn.parameter import Parameter
import unittest
from unittest import TestCase
torch.manual_seed(42)
'''
Convolutions for augmented images
Given a regular (n x n) image, e.g. n=2
[[x1 x2]
[x3 x4]]
The augmented image turns each pixel value into a vector of length l
e.g. l=3
[[[x11 x12 x13] [x21 x22 x23]]
[[x31 x32 x33] [x41 x42 x43]]]
The dimensions of this image are (n x n x l)
If we have multiple channels, the dimension is (channels, n, n, l)
e.g. channels = 2
[ [[[x11 x12 x13] [x21 x22 x23]]
[[x31 x32 x33] [x41 x42 x43]]],
[[[y11 y12 y13] [y21 y22 y23]]
[[y31 y32 y33] [y41 y42 y43]]] ]
Convolution is defined exactly as it is in PyTorch,
except that pixels undergo linear combinations as weighted elementwise sums
e.g.
2([x11 x12 x13]) - 3([x21 x22 x23]) = [2(x11)-3(x21) 2(x12)-3(x22) 2(x13)-3(x23)]
The only other significant change is that bias is added only to the first element of a pixel vector
e.g.
[x11 x12 x13] + b = [x11+b x12 x13]
(a minimal sketch of these two rules appears right after this docstring)
All other details regarding convolution are identical to PyTorch docs
Notes:
-Kernels are always square
-Each channel of an augmented image is (n x n x l), i.e. the outer dimensions are square
-Groups is always equal to one
-The batch size is always one. This seems unusual, but the big problem we're trying to solve here
is not training but proving properties about a network. The network will have additional parameters besides
weights and biases that it will try to optimize for each individual sample.
Below is a very naive and slow convolution for augmented images
Takes ~1.5s for a forward pass through Conv2D(1, 32, kernel_size=4, stride=2, padding=1)
for an augmented MNIST image with 784 errors (1, 28, 28, 785) on a 2.5 GHz Intel Core i7
An ideal implementation should be optimized for 14 cores
'''
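# A minimal sketch of the two pixel-vector rules above: a weighted elementwise
# combination of augmented pixels and a bias added only to the first element.
# The helper name and values are purely illustrative and are not used by Conv2D below.
def _augmented_pixel_demo():
    p1 = tensor([1.0, 0.1, 0.0])  # augmented pixel x1 with two error terms
    p2 = tensor([2.0, 0.0, 0.1])  # augmented pixel x2 with two error terms
    combined = 2 * p1 - 3 * p2    # 2*x1 - 3*x2, applied elementwise
    # the bias b = 0.5 touches only element 0, mirroring the index_add call in forward()
    combined = combined.index_add(0, tensor([0]), tensor([0.5]))
    return combined               # tensor([-3.5000,  0.2000, -0.3000])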
class Conv2D(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super(Conv2D, self).__init__()
# These are ones only for testing
self.weight = Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size), requires_grad=False)
self.bias = Parameter(torch.ones(out_channels), requires_grad=False)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
@staticmethod
def val_or_zero_2d(x, i, j):
# implicitly zero pad an image
# x is (n x n x l), where l is the dimension of the error vector
n = x.shape[0]
if not ((0 <= i < n) and (0 <= j < n)):
return torch.zeros_like(x[0, 0])
return x[i, j]
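    # For example (an illustrative assumption: a 2x2 image whose pixel vectors have
    # length 2), val_or_zero_2d returns a zero vector for out-of-range indices,
    # which is what implements the implicit zero padding used in forward():
    #   x = torch.ones(2, 2, 2)
    #   Conv2D.val_or_zero_2d(x, -1, 0)  # -> tensor([0., 0.])
    #   Conv2D.val_or_zero_2d(x, 0, 1)   # -> tensor([1., 1.])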
def forward(self, img):
# img is (num_channels x n x n x num_errors)
# outer dimension of each image channel (assumes n x n)
n = img.shape[-2]
# kernel is square
k = self.kernel_size
# padding is the same in both H and W
p = self.padding
# <5> stack all output channels
return torch.stack(
# <4> take all input channels, stack and sum
[torch.stack(
# <3> take all rows and stack
[torch.stack(
# <2> take inner products across a row and stack
[torch.stack(
# <1> take inner product with kernel at top left corner pixel (l, m)
[torch.stack(
[kernel[i, j] * self.val_or_zero_2d(img_channel, l + i, m + j) for i in range(k) for
j in range(k)]).sum(0).index_add(0, tensor([0]), bias)
# <1>
for m in range(-p, n - k + p + 1, self.stride)])
# <2>
for l in range(-p, n - k + p + 1, self.stride)])
# <3>
for img_channel, kernel in zip(img, kernel_group)]).sum(0)
# <4>
for kernel_group, bias in zip(self.weight, self.bias)])
# <5>
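# A usage sketch (illustrative only and never invoked; the shapes are an assumed
# example): with the toy all-ones weights/biases set in __init__, a single-channel
# 4x4 augmented image whose pixel vectors have length 3 maps to (out_channels, 3, 3, 3).
def _usage_sketch():
    conv = Conv2D(in_channels=1, out_channels=2, kernel_size=2)
    img = torch.randn(1, 4, 4, 3)  # (channels, n, n, num_errors)
    out = conv(img)
    assert out.shape == torch.Size([2, 3, 3, 3])
    return out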
class TestConv2D(TestCase):
def test1(self):
conv = Conv2D(1, 1, 2)
img = tensor([[[[0.8635, 0.1, 0., 0., 0., 0., 0., 0., 0., 0.],
[0.2933, 0., 0.1, 0., 0., 0., 0., 0., 0., 0.],
[0.0534, 0., 0., 0.1, 0., 0., 0., 0., 0., 0.]],
[[0.5245, 0., 0., 0., 0.1, 0., 0., 0., 0., 0.],
[0.8011, 0., 0., 0., 0., 0.1, 0., 0., 0., 0.],
[0.2956, 0., 0., 0., 0., 0., 0.1, 0., 0., 0.]],
[[0.7472, 0., 0., 0., 0., 0., 0., 0.1, 0., 0.],
[0.1142, 0., 0., 0., 0., 0., 0., 0., 0.1, 0.],
[0.6793, 0., 0., 0., 0., 0., 0., 0., 0., 0.1]]]])
expected = tensor([[[[3.4824, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0, 0],
[2.4434, 0, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0]],
[[3.187, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1, 0],
[2.8902, 0, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1]]]])
self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)
def test2(self):
conv = Conv2D(2, 1, 2, stride=2, padding=1)
img = tensor([[[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],
[0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423]],
[[0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895],
[0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071],
[0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278],
[0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018, 0.9497, 0.6666],
[0.9811, 0.0874, 0.0041, 0.1088, 0.1637, 0.7025, 0.6790, 0.9155]],
[[0.2418, 0.1591, 0.7653, 0.2979, 0.8035, 0.3813, 0.7860, 0.1115],
[0.2477, 0.6524, 0.6057, 0.3725, 0.7980, 0.8399, 0.1374, 0.2331],
[0.9578, 0.3313, 0.3227, 0.0162, 0.2137, 0.6249, 0.4340, 0.1371],
[0.5117, 0.1585, 0.0758, 0.2247, 0.0624, 0.1816, 0.9998, 0.5944],
[0.6541, 0.0337, 0.1716, 0.3336, 0.5782, 0.0600, 0.2846, 0.2007]],
[[0.5014, 0.3139, 0.4654, 0.1612, 0.1568, 0.2083, 0.3289, 0.1054],
[0.9192, 0.4008, 0.9302, 0.6558, 0.0766, 0.8460, 0.3624, 0.3083],
[0.0850, 0.0029, 0.6431, 0.3908, 0.6947, 0.0897, 0.8712, 0.1330],
[0.4137, 0.6044, 0.7581, 0.9037, 0.9555, 0.1035, 0.6258, 0.2849],
[0.4452, 0.1258, 0.9554, 0.1330, 0.7672, 0.6757, 0.6625, 0.2297]],
[[0.9545, 0.6099, 0.5643, 0.0594, 0.7099, 0.4250, 0.2709, 0.9295],
[0.6115, 0.2234, 0.2469, 0.4761, 0.7792, 0.3722, 0.2147, 0.3288],
[0.1265, 0.6783, 0.8870, 0.0293, 0.6161, 0.7583, 0.5907, 0.3219],
[0.7610, 0.7628, 0.6870, 0.4121, 0.3676, 0.5535, 0.4117, 0.3510],
[0.8196, 0.9297, 0.4505, 0.3881, 0.5073, 0.4701, 0.6202, 0.6401]]],
[[[0.0459, 0.3155, 0.9211, 0.6948, 0.4751, 0.1985, 0.1941, 0.0521],
[0.3370, 0.6689, 0.8188, 0.7308, 0.0580, 0.1993, 0.4211, 0.9837],
[0.5723, 0.3705, 0.7069, 0.3096, 0.1764, 0.8649, 0.2726, 0.3998],
[0.0026, 0.8346, 0.8788, 0.6822, 0.1514, 0.0065, 0.0939, 0.8729],
[0.7401, 0.9208, 0.7619, 0.6265, 0.4951, 0.1197, 0.0716, 0.0323]],
[[0.7047, 0.2545, 0.3994, 0.2122, 0.4089, 0.1481, 0.1733, 0.6659],
[0.3514, 0.8087, 0.3396, 0.1332, 0.4118, 0.2576, 0.3470, 0.0240],
[0.7797, 0.1519, 0.7513, 0.7269, 0.8572, 0.1165, 0.8596, 0.2636],
[0.6855, 0.9696, 0.4295, 0.4961, 0.3849, 0.0825, 0.7400, 0.0036],
[0.8104, 0.8741, 0.9729, 0.3821, 0.0892, 0.6124, 0.7762, 0.0023]],
[[0.3865, 0.2003, 0.4563, 0.2539, 0.2956, 0.3413, 0.0248, 0.9103],
[0.9192, 0.4216, 0.4431, 0.2959, 0.0485, 0.0134, 0.6858, 0.2255],
[0.1786, 0.4610, 0.3335, 0.3382, 0.5161, 0.3939, 0.3278, 0.2606],
[0.0931, 0.9193, 0.2999, 0.6325, 0.3265, 0.5406, 0.9662, 0.7304],
[0.0667, 0.6985, 0.9746, 0.6315, 0.8352, 0.9929, 0.4234, 0.6038]],
[[0.1525, 0.3970, 0.8703, 0.7563, 0.1836, 0.0991, 0.1583, 0.0066],
[0.1142, 0.3764, 0.8374, 0.5837, 0.1197, 0.0989, 0.7487, 0.1281],
[0.4384, 0.7399, 0.2686, 0.4455, 0.4565, 0.3817, 0.2465, 0.0543],
[0.0958, 0.2323, 0.9829, 0.2585, 0.1642, 0.6212, 0.6378, 0.7740],
[0.8801, 0.7784, 0.0042, 0.5443, 0.8029, 0.4538, 0.2054, 0.9767]],
[[0.3130, 0.2153, 0.0492, 0.5223, 0.7216, 0.6107, 0.5989, 0.1208],
[0.0331, 0.5088, 0.9559, 0.7885, 0.2089, 0.4351, 0.1314, 0.2588],
[0.5905, 0.7723, 0.9142, 0.0409, 0.8343, 0.1474, 0.6872, 0.9231],
[0.5070, 0.9549, 0.0740, 0.3090, 0.7916, 0.3911, 0.3976, 0.2916],
[0.8447, 0.7453, 0.6602, 0.2190, 0.0941, 0.5541, 0.6481, 0.2691]]]])
expected = tensor([[[[2.9282, 1.2305, 1.3040, 1.6541, 0.8655, 0.7994, 0.4507, 0.8457],
[4.7355, 1.7465, 2.7269, 2.2614, 1.3734, 2.0733, 1.7317, 2.6446],
[3.7340, 2.6081, 2.3371, 2.3171, 1.7716, 1.0364, 1.6718, 1.3228]],
[[3.9673, 0.9783, 2.3314, 1.7104, 2.2970, 1.1521, 1.7727, 2.2772],
[6.8464, 3.5134, 3.6922, 2.3344, 3.4933, 3.3152, 4.0898, 2.1788],
[6.4558, 4.1369, 3.8431, 3.0129, 2.6419, 3.3743, 5.8189, 3.7173]],
[[3.9214, 1.5361, 1.9492, 1.4992, 1.7719, 1.3431, 1.3570, 1.1623],
[4.9184, 3.7028, 5.6833, 3.4106, 3.7860, 3.1293, 3.8528, 2.4563],
[6.7671, 5.1336, 4.5723, 3.1677, 4.4504, 3.8230, 4.2091, 3.8171]]]])
self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)
def test3(self):
conv = Conv2D(2, 2, 3, stride=2, padding=0)
img = torch.randn(2, 4, 4, 5)
out = conv(img)
self.assertEqual(torch.Size([2, 1, 1, 5]), out.shape)
def test4(self):
conv = Conv2D(2, 2, 3, stride=2, padding=0)
img = tensor([[[[1.9269e+00, 1.4873e+00, 9.0072e-01, -2.1055e+00, 6.7842e-01],
[-1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01, 1.6487e+00],
[-3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01],
[7.6245e-01, 1.6423e+00, -1.5960e-01, -4.9740e-01, 4.3959e-01]],
[[-7.5813e-01, 1.0783e+00, 8.0080e-01, 1.6806e+00, 1.2791e+00],
[1.2964e+00, 6.1047e-01, 1.3347e+00, -2.3162e-01, 4.1759e-02],
[-2.5158e-01, 8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01],
[1.7174e+00, 3.1888e-01, -4.2452e-01, 3.0572e-01, -7.7459e-01]],
[[-1.5576e+00, 9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00],
[2.1228e+00, -1.2347e+00, -4.8791e-01, -9.1382e-01, -6.5814e-01],
[7.8024e-02, 5.2581e-01, -4.8799e-01, 1.1914e+00, -8.1401e-01],
[-7.3599e-01, -1.4032e+00, 3.6004e-02, -6.3477e-02, 6.7561e-01]],
[[-9.7807e-02, 1.8446e+00, -1.1845e+00, 1.3835e+00, 1.4451e+00],
[8.5641e-01, 2.2181e+00, 5.2317e-01, 3.4665e-01, -1.9733e-01],
[-1.0546e+00, 1.2780e+00, -1.7219e-01, 5.2379e-01, 5.6622e-02],
[4.2630e-01, 5.7501e-01, -6.4172e-01, -2.2064e+00, -7.5080e-01]]],
[[[1.0868e-02, -3.3874e-01, -1.3407e+00, -5.8537e-01, 5.3619e-01],
[5.2462e-01, 1.1412e+00, 5.1644e-02, 7.4395e-01, -4.8158e-01],
[-1.0495e+00, 6.0390e-01, -1.7223e+00, -8.2777e-01, 1.3347e+00],
[4.8354e-01, -2.5095e+00, 4.8800e-01, 7.8459e-01, 2.8647e-02]],
[[6.4076e-01, 5.8325e-01, 1.0669e+00, -4.5015e-01, -1.8527e-01],
[7.5276e-01, 4.0476e-01, 1.7847e-01, 2.6491e-01, 1.2732e+00],
[-1.3109e-03, -3.0360e-01, -1.4570e+00, -1.0234e-01, -5.9915e-01],
[4.7706e-01, 7.2618e-01, 9.1152e-02, -3.8907e-01, 5.2792e-01]],
[[-1.2685e-02, 2.4084e-01, 1.3254e-01, 7.6424e-01, 1.0950e+00],
[3.3989e-01, 7.1997e-01, 4.1141e-01, 1.9312e+00, 1.0119e+00],
[-1.4364e+00, -1.1299e+00, -1.3603e-01, 1.6354e+00, 6.5474e-01],
[5.7600e-01, 1.1415e+00, 1.8565e-02, -1.8058e+00, 9.2543e-01]],
[[-3.7534e-01, 1.0331e+00, -6.8665e-01, 6.3681e-01, -9.7267e-01],
[9.5846e-01, 1.6192e+00, 1.4506e+00, 2.6948e-01, -2.1038e-01],
[-7.3280e-01, 1.0430e-01, 3.4875e-01, 9.6759e-01, -4.6569e-01],
[1.6048e+00, -2.4801e+00, -4.1754e-01, -1.1955e+00, 8.1234e-01]]]])
expected = tensor([[[[2.9989, 4.7977, -5.3517, 0.2111, 4.5492]]],
[[[2.9989, 4.7977, -5.3517, 0.2111, 4.5492]]]])
self.assertAlmostEqual(0, torch.sum(expected - conv(img)).numpy(), 3)
if __name__ == '__main__':
unittest.main()