Last active
November 9, 2019 15:41
-
-
Save tomginsberg/172733c91e545648a8321d1e149047a2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import tensor | |
from torch.nn.parameter import Parameter | |
import unittest | |
from unittest import TestCase | |
# Fixed seed for the global RNG so randomly generated test inputs
# (e.g. torch.randn in TestConv2D.test3) are reproducible across runs.
_SEED = 42
torch.manual_seed(_SEED)
''' | |
Convolutions for augmented images | |
Given a regular (n x n) image, e.g. with n=2:
[[x1 x2]
 [x3 x4]]
The augmented image turns each pixel value into a vector of length l,
e.g. with l=3:
[[[x11 x12 x13] [x21 x22 x23]] | |
[[x31 x32 x33] [x41 x42 x43]]] | |
The dimensions of this image are (n x n x l).
If we have multiple channels, the dimensions are (channels, n, n, l),
e.g. with channels = 2:
[ [[[x11 x12 x13] [x21 x22 x23]]
   [[x31 x32 x33] [x41 x42 x43]]],
  [[[y11 y12 y13] [y21 y22 y23]]
   [[y31 y32 y33] [y41 y42 y43]]] ]
Convolution is defined exactly as convolution is in PyTorch,
except that pixels undergo linear combinations as weighted element-wise sums, e.g.
2([x11 x12 x13]) - 3([x21 x22 x23]) = [2(x11)-3(x21) 2(x12)-3(x22) 2(x13)-3(x23)]
The only other significant change is that the bias must be added only to the first
element of a pixel vector, e.g.
[x11 x12 x13] + b = [x11+b x12 x13]
All other details regarding convolution are identical to PyTorch docs | |
Notes:
- Kernels are always square.
- Each channel of an augmented image is (n x n x l), i.e. the outer dimensions are square.
- Groups is always equal to one.
- The batch size is always one (the input has no batch dimension). This seems unusual, but
  the big problem we're trying to solve here is not training but proving properties about a
  network. The network will have additional parameters besides weights and biases that it
  will try to optimize for each sample individually.
Below is a very naive and slow convolution for augmented images.
It takes ~1.5 s for a forward pass through Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
on an augmented MNIST image with 784 error terms, i.e. shape (1, 28, 28, 785),
on a 2.5 GHz Intel Core i7. An ideal implementation should be optimized for 14 cores.
''' | |
class Conv2D(torch.nn.Module):
    """2-D convolution for augmented images.

    An augmented image is a tensor of shape ``(channels, n, n, l)``: every
    pixel is a vector of length ``l`` instead of a scalar. The convolution
    follows ``torch.nn.Conv2d`` (square kernel, groups=1, no batch dimension),
    with pixel vectors combined as weighted element-wise sums. The bias of
    each output channel is added exactly once, and only to element 0 of each
    output pixel vector.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel.
        stride: step between kernel applications, same in both directions.
        padding: implicit zero padding added on every side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Weights and bias are all ones only for testing: this keeps the
        # expected outputs in TestConv2D computable by hand.
        self.weight = Parameter(
            torch.ones(out_channels, in_channels, kernel_size, kernel_size),
            requires_grad=False)
        self.bias = Parameter(torch.ones(out_channels), requires_grad=False)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    @staticmethod
    def val_or_zero_2d(x, i, j):
        """Return pixel vector ``x[i, j]``, or zeros if ``(i, j)`` is out of range.

        ``x`` is one augmented channel of shape ``(rows, cols, l)``; indices
        outside the image behave as implicit zero padding. ``forward`` no
        longer uses this helper; it is kept for existing callers.
        """
        rows, cols = x.shape[0], x.shape[1]
        if not ((0 <= i < rows) and (0 <= j < cols)):
            return torch.zeros_like(x[0, 0])
        return x[i, j]

    def forward(self, img):
        """Convolve an augmented image.

        Args:
            img: tensor of shape ``(in_channels, n, n, l)`` (no batch dimension).

        Returns:
            Tensor of shape ``(out_channels, n_out, n_out, l)`` where
            ``n_out = (n + 2 * padding - kernel_size) // stride + 1``.

        Raises:
            ValueError: if ``img`` is not 4-dimensional.
            RuntimeError: raised by ``conv2d`` if the channel count of ``img``
                does not match ``in_channels``.
        """
        if img.dim() != 4:
            raise ValueError(
                'expected img of shape (channels, n, n, l), got %s'
                % (tuple(img.shape),))
        # conv2d needs a floating dtype shared by input and weights.
        dtype = img.dtype if img.is_floating_point() else self.weight.dtype
        img = img.to(dtype)
        weight = self.weight.to(dtype=dtype, device=img.device)
        bias = self.bias.to(dtype=dtype, device=img.device)
        # Convolution is linear and acts on each element of a pixel vector
        # independently, so the l axis can serve as PyTorch's batch axis:
        # (C, H, W, l) -> (l, C, H, W).
        out = torch.nn.functional.conv2d(
            img.permute(3, 0, 1, 2), weight,
            stride=self.stride, padding=self.padding)
        # (l, out_channels, H', W') -> (out_channels, H', W', l)
        out = out.permute(1, 2, 3, 0)
        # Bias goes only into element 0 of each pixel vector, once per output
        # channel (not once per input channel). torch.cat instead of an
        # in-place add keeps the result autograd-friendly.
        return torch.cat(
            (out[..., :1] + bias.view(-1, 1, 1, 1), out[..., 1:]), dim=-1)


class TestConv2D(TestCase):
    """Regression tests for Conv2D with all-ones weights and bias."""

    def _assert_close(self, actual, expected, atol=1e-3):
        # Compare element-wise: a single sum of differences would let
        # positive and negative errors cancel out.
        self.assertEqual(expected.shape, actual.shape)
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(torch.allclose(actual, expected, atol=atol),
                        'max abs difference %g exceeds %g' % (max_diff, atol))

    def test1(self):
        conv = Conv2D(1, 1, 2)
        img = tensor([[[[0.8635, 0.1, 0., 0., 0., 0., 0., 0., 0., 0.],
                        [0.2933, 0., 0.1, 0., 0., 0., 0., 0., 0., 0.],
                        [0.0534, 0., 0., 0.1, 0., 0., 0., 0., 0., 0.]],
                       [[0.5245, 0., 0., 0., 0.1, 0., 0., 0., 0., 0.],
                        [0.8011, 0., 0., 0., 0., 0.1, 0., 0., 0., 0.],
                        [0.2956, 0., 0., 0., 0., 0., 0.1, 0., 0., 0.]],
                       [[0.7472, 0., 0., 0., 0., 0., 0., 0.1, 0., 0.],
                        [0.1142, 0., 0., 0., 0., 0., 0., 0., 0.1, 0.],
                        [0.6793, 0., 0., 0., 0., 0., 0., 0., 0., 0.1]]]])
        expected = tensor([[[[3.4824, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0, 0],
                             [2.4434, 0, 0.1, 0.1, 0, 0.1, 0.1, 0, 0, 0]],
                            [[3.187, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1, 0],
                             [2.8902, 0, 0, 0, 0, 0.1, 0.1, 0, 0.1, 0.1]]]])
        self._assert_close(conv(img), expected)

    def test2(self):
        conv = Conv2D(2, 1, 2, stride=2, padding=1)
        img = tensor([[[[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936],
                        [0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294],
                        [0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],
                        [0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753],
                        [0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423]],
                       [[0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895],
                        [0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071],
                        [0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278],
                        [0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018, 0.9497, 0.6666],
                        [0.9811, 0.0874, 0.0041, 0.1088, 0.1637, 0.7025, 0.6790, 0.9155]],
                       [[0.2418, 0.1591, 0.7653, 0.2979, 0.8035, 0.3813, 0.7860, 0.1115],
                        [0.2477, 0.6524, 0.6057, 0.3725, 0.7980, 0.8399, 0.1374, 0.2331],
                        [0.9578, 0.3313, 0.3227, 0.0162, 0.2137, 0.6249, 0.4340, 0.1371],
                        [0.5117, 0.1585, 0.0758, 0.2247, 0.0624, 0.1816, 0.9998, 0.5944],
                        [0.6541, 0.0337, 0.1716, 0.3336, 0.5782, 0.0600, 0.2846, 0.2007]],
                       [[0.5014, 0.3139, 0.4654, 0.1612, 0.1568, 0.2083, 0.3289, 0.1054],
                        [0.9192, 0.4008, 0.9302, 0.6558, 0.0766, 0.8460, 0.3624, 0.3083],
                        [0.0850, 0.0029, 0.6431, 0.3908, 0.6947, 0.0897, 0.8712, 0.1330],
                        [0.4137, 0.6044, 0.7581, 0.9037, 0.9555, 0.1035, 0.6258, 0.2849],
                        [0.4452, 0.1258, 0.9554, 0.1330, 0.7672, 0.6757, 0.6625, 0.2297]],
                       [[0.9545, 0.6099, 0.5643, 0.0594, 0.7099, 0.4250, 0.2709, 0.9295],
                        [0.6115, 0.2234, 0.2469, 0.4761, 0.7792, 0.3722, 0.2147, 0.3288],
                        [0.1265, 0.6783, 0.8870, 0.0293, 0.6161, 0.7583, 0.5907, 0.3219],
                        [0.7610, 0.7628, 0.6870, 0.4121, 0.3676, 0.5535, 0.4117, 0.3510],
                        [0.8196, 0.9297, 0.4505, 0.3881, 0.5073, 0.4701, 0.6202, 0.6401]]],
                      [[[0.0459, 0.3155, 0.9211, 0.6948, 0.4751, 0.1985, 0.1941, 0.0521],
                        [0.3370, 0.6689, 0.8188, 0.7308, 0.0580, 0.1993, 0.4211, 0.9837],
                        [0.5723, 0.3705, 0.7069, 0.3096, 0.1764, 0.8649, 0.2726, 0.3998],
                        [0.0026, 0.8346, 0.8788, 0.6822, 0.1514, 0.0065, 0.0939, 0.8729],
                        [0.7401, 0.9208, 0.7619, 0.6265, 0.4951, 0.1197, 0.0716, 0.0323]],
                       [[0.7047, 0.2545, 0.3994, 0.2122, 0.4089, 0.1481, 0.1733, 0.6659],
                        [0.3514, 0.8087, 0.3396, 0.1332, 0.4118, 0.2576, 0.3470, 0.0240],
                        [0.7797, 0.1519, 0.7513, 0.7269, 0.8572, 0.1165, 0.8596, 0.2636],
                        [0.6855, 0.9696, 0.4295, 0.4961, 0.3849, 0.0825, 0.7400, 0.0036],
                        [0.8104, 0.8741, 0.9729, 0.3821, 0.0892, 0.6124, 0.7762, 0.0023]],
                       [[0.3865, 0.2003, 0.4563, 0.2539, 0.2956, 0.3413, 0.0248, 0.9103],
                        [0.9192, 0.4216, 0.4431, 0.2959, 0.0485, 0.0134, 0.6858, 0.2255],
                        [0.1786, 0.4610, 0.3335, 0.3382, 0.5161, 0.3939, 0.3278, 0.2606],
                        [0.0931, 0.9193, 0.2999, 0.6325, 0.3265, 0.5406, 0.9662, 0.7304],
                        [0.0667, 0.6985, 0.9746, 0.6315, 0.8352, 0.9929, 0.4234, 0.6038]],
                       [[0.1525, 0.3970, 0.8703, 0.7563, 0.1836, 0.0991, 0.1583, 0.0066],
                        [0.1142, 0.3764, 0.8374, 0.5837, 0.1197, 0.0989, 0.7487, 0.1281],
                        [0.4384, 0.7399, 0.2686, 0.4455, 0.4565, 0.3817, 0.2465, 0.0543],
                        [0.0958, 0.2323, 0.9829, 0.2585, 0.1642, 0.6212, 0.6378, 0.7740],
                        [0.8801, 0.7784, 0.0042, 0.5443, 0.8029, 0.4538, 0.2054, 0.9767]],
                       [[0.3130, 0.2153, 0.0492, 0.5223, 0.7216, 0.6107, 0.5989, 0.1208],
                        [0.0331, 0.5088, 0.9559, 0.7885, 0.2089, 0.4351, 0.1314, 0.2588],
                        [0.5905, 0.7723, 0.9142, 0.0409, 0.8343, 0.1474, 0.6872, 0.9231],
                        [0.5070, 0.9549, 0.0740, 0.3090, 0.7916, 0.3911, 0.3976, 0.2916],
                        [0.8447, 0.7453, 0.6602, 0.2190, 0.0941, 0.5541, 0.6481, 0.2691]]]])
        # Element 0 = window sum over both input channels + bias (added once).
        expected = tensor([[[[1.9282, 1.2305, 1.3040, 1.6541, 0.8655, 0.7994, 0.4507, 0.8457],
                             [3.7355, 1.7465, 2.7269, 2.2614, 1.3734, 2.0733, 1.7317, 2.6446],
                             [2.7340, 2.6081, 2.3371, 2.3171, 1.7716, 1.0364, 1.6718, 1.3228]],
                            [[2.9673, 0.9783, 2.3314, 1.7104, 2.2970, 1.1521, 1.7727, 2.2772],
                             [5.8464, 3.5134, 3.6922, 2.3344, 3.4933, 3.3152, 4.0898, 2.1788],
                             [5.4558, 4.1369, 3.8431, 3.0129, 2.6419, 3.3743, 5.8189, 3.7173]],
                            [[2.9214, 1.5361, 1.9492, 1.4992, 1.7719, 1.3431, 1.3570, 1.1623],
                             [3.9184, 3.7028, 5.6833, 3.4106, 3.7860, 3.1293, 3.8528, 2.4563],
                             [5.7671, 5.1336, 4.5723, 3.1677, 4.4504, 3.8230, 4.2091, 3.8171]]]])
        self._assert_close(conv(img), expected)

    def test3(self):
        conv = Conv2D(2, 2, 3, stride=2, padding=0)
        img = torch.randn(2, 4, 4, 5)
        out = conv(img)
        self.assertEqual(torch.Size([2, 1, 1, 5]), out.shape)

    def test4(self):
        conv = Conv2D(2, 2, 3, stride=2, padding=0)
        img = tensor([[[[1.9269e+00, 1.4873e+00, 9.0072e-01, -2.1055e+00, 6.7842e-01],
                        [-1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01, 1.6487e+00],
                        [-3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01],
                        [7.6245e-01, 1.6423e+00, -1.5960e-01, -4.9740e-01, 4.3959e-01]],
                       [[-7.5813e-01, 1.0783e+00, 8.0080e-01, 1.6806e+00, 1.2791e+00],
                        [1.2964e+00, 6.1047e-01, 1.3347e+00, -2.3162e-01, 4.1759e-02],
                        [-2.5158e-01, 8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01],
                        [1.7174e+00, 3.1888e-01, -4.2452e-01, 3.0572e-01, -7.7459e-01]],
                       [[-1.5576e+00, 9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00],
                        [2.1228e+00, -1.2347e+00, -4.8791e-01, -9.1382e-01, -6.5814e-01],
                        [7.8024e-02, 5.2581e-01, -4.8799e-01, 1.1914e+00, -8.1401e-01],
                        [-7.3599e-01, -1.4032e+00, 3.6004e-02, -6.3477e-02, 6.7561e-01]],
                       [[-9.7807e-02, 1.8446e+00, -1.1845e+00, 1.3835e+00, 1.4451e+00],
                        [8.5641e-01, 2.2181e+00, 5.2317e-01, 3.4665e-01, -1.9733e-01],
                        [-1.0546e+00, 1.2780e+00, -1.7219e-01, 5.2379e-01, 5.6622e-02],
                        [4.2630e-01, 5.7501e-01, -6.4172e-01, -2.2064e+00, -7.5080e-01]]],
                      [[[1.0868e-02, -3.3874e-01, -1.3407e+00, -5.8537e-01, 5.3619e-01],
                        [5.2462e-01, 1.1412e+00, 5.1644e-02, 7.4395e-01, -4.8158e-01],
                        [-1.0495e+00, 6.0390e-01, -1.7223e+00, -8.2777e-01, 1.3347e+00],
                        [4.8354e-01, -2.5095e+00, 4.8800e-01, 7.8459e-01, 2.8647e-02]],
                       [[6.4076e-01, 5.8325e-01, 1.0669e+00, -4.5015e-01, -1.8527e-01],
                        [7.5276e-01, 4.0476e-01, 1.7847e-01, 2.6491e-01, 1.2732e+00],
                        [-1.3109e-03, -3.0360e-01, -1.4570e+00, -1.0234e-01, -5.9915e-01],
                        [4.7706e-01, 7.2618e-01, 9.1152e-02, -3.8907e-01, 5.2792e-01]],
                       [[-1.2685e-02, 2.4084e-01, 1.3254e-01, 7.6424e-01, 1.0950e+00],
                        [3.3989e-01, 7.1997e-01, 4.1141e-01, 1.9312e+00, 1.0119e+00],
                        [-1.4364e+00, -1.1299e+00, -1.3603e-01, 1.6354e+00, 6.5474e-01],
                        [5.7600e-01, 1.1415e+00, 1.8565e-02, -1.8058e+00, 9.2543e-01]],
                       [[-3.7534e-01, 1.0331e+00, -6.8665e-01, 6.3681e-01, -9.7267e-01],
                        [9.5846e-01, 1.6192e+00, 1.4506e+00, 2.6948e-01, -2.1038e-01],
                        [-7.3280e-01, 1.0430e-01, 3.4875e-01, 9.6759e-01, -4.6569e-01],
                        [1.6048e+00, -2.4801e+00, -4.1754e-01, -1.1955e+00, 8.1234e-01]]]])
        # Element 0 = window sum over both input channels + bias (added once).
        expected = tensor([[[[1.9989, 4.7977, -5.3517, 0.2111, 4.5492]]],
                           [[[1.9989, 4.7977, -5.3517, 0.2111, 4.5492]]]])
        self._assert_close(conv(img), expected)

    def test_bias_added_once_per_output_channel(self):
        conv = Conv2D(3, 2, 2)
        out = conv(torch.zeros(3, 2, 2, 4))
        expected = torch.zeros(2, 1, 1, 4)
        expected[..., 0] = 1.
        self._assert_close(out, expected)


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment