Created
December 13, 2019 12:27
-
-
Save anirudhshenoy/089a70deed944d0ca7ab0b6a5eb5a7f1 to your computer and use it in GitHub Desktop.
Code for the Blog at:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from time import time | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
from skimage.util.shape import view_as_windows | |
torch.manual_seed(42) | |
def conv_2d(kernel, bias, x):
    """Naive 2-D "valid" cross-correlation (padding 0, stride 1).

    Slides ``kernel`` over ``x`` one position at a time and, for each
    position, sums the element-wise product of the kernel and the window
    underneath it.

    Args:
        kernel: 2-D array of weights, shape ``(kh, kw)``.
        bias: scalar (or array broadcastable to the output) added to every
            output element.
        x: 2-D input array, shape ``(h, w)``.

    Returns:
        Array of shape ``(h - kh + 1, w - kw + 1)`` holding the filtered
        input plus ``bias``.

    Raises:
        ValueError: if the kernel is so much larger than the input that an
            output dimension would be negative.
    """
    kernel_rows, kernel_cols = kernel.shape
    # Number of valid window positions along each axis.  The loop bounds
    # must come from these, not from ``x.shape - 1`` (which is only right
    # for a 2x2 kernel), otherwise windows at the edge are truncated.
    out_rows = x.shape[0] - kernel_rows + 1
    out_cols = x.shape[1] - kernel_cols + 1
    if out_rows < 0 or out_cols < 0:
        raise ValueError("kernel is larger than the input")

    result = np.zeros((out_rows, out_cols))
    for row in range(out_rows):
        for col in range(out_cols):
            window = x[row: row + kernel_rows, col: col + kernel_cols]
            result[row, col] = np.sum(np.multiply(kernel, window))
    return result + bias
def memory_strided_im2col(x, kernel):
    """Build the im2col matrix of ``x`` from a strided window view.

    Assumes padding 0 and stride 1.

    Args:
        x: 2-D input array.
        kernel: 2-D kernel; only its shape is used.

    Returns:
        Array of shape ``(kernel.size, n_windows)`` in which column ``j`` is
        the ``j``-th window of ``x`` (row-major window order), flattened, so
        that ``np.dot(kernel.flatten(), result)`` yields the flattened
        convolution output.
    """
    # view_as_windows yields one kernel-shaped window per output position,
    # with the two window axes last.
    windows = view_as_windows(x, kernel.shape)
    # Flatten each window into a row first, then transpose so windows become
    # columns.  Reshaping straight to (k*k, n_windows) would only reinterpret
    # the buffer and scramble elements across windows.
    return windows.reshape(-1, kernel.size).T
def naive_im2col(x, kernel):
    """Build the im2col matrix of ``x`` with explicit Python loops.

    Assumes padding 0 and stride 1.

    Args:
        x: 2-D input array.
        kernel: 2-D kernel; only its shape is used.

    Returns:
        Array of shape ``(kh * kw, n_windows)`` in which column ``j`` is the
        ``j``-th window of ``x`` (row-major window order), flattened.  When
        the kernel does not fit inside ``x`` the result has zero columns.
    """
    kernel_rows, kernel_cols = kernel.shape
    # Valid window positions per axis; bounding the loops by ``x.shape - 1``
    # is only correct for a 2x2 kernel and produces ragged rows otherwise.
    out_rows = x.shape[0] - kernel_rows + 1
    out_cols = x.shape[1] - kernel_cols + 1
    rows = [
        x[row: row + kernel_rows, col: col + kernel_cols].flatten()
        for row in range(out_rows)
        for col in range(out_cols)
    ]
    # The explicit reshape keeps the result two-dimensional even when no
    # window fits and ``rows`` is empty.
    return np.array(rows).reshape(-1, kernel_rows * kernel_cols).T
if __name__ == "__main__":
    # Benchmark four ways of applying a 2x2 convolution to square inputs of
    # growing size and plot the mean run time of each on a log scale.

    # perf_counter is the high-resolution monotonic clock intended for timing
    # short spans; time() is a wall clock that can be coarse or be adjusted
    # mid-measurement.  Imported locally so this block is self-contained.
    from time import perf_counter

    naive_time_log = []
    torch_time_log = []
    strided_time_log = []
    naive_im2col_time_log = []

    MAX_INPUT_SIZE = 300
    NUM_RUNS = 20
    input_size_list = list(range(3, MAX_INPUT_SIZE, 5))

    for input_size in tqdm(input_size_list):
        # Per-run timings for this input size; averaged after the inner loop.
        torch_time = np.zeros(NUM_RUNS)
        naive_time = np.zeros(NUM_RUNS)
        im2col_time = np.zeros(NUM_RUNS)
        strided_time = np.zeros(NUM_RUNS)

        for run in range(NUM_RUNS):
            # Fresh layer, input and weights every run; values are drawn
            # from the integers 0..9 stored as float32.
            conv = nn.Conv2d(1, 1, 2)
            ip = torch.randint(
                low=0, high=10,
                size=(1, 1, input_size, input_size), dtype=torch.float32)
            conv.weight = nn.Parameter(torch.randint(
                low=0, high=10, size=(1, 1, 2, 2), dtype=torch.float32))

            # NumPy versions of the same input, kernel and bias for the
            # hand-written implementations.
            ip_np = ip.numpy().reshape(-1, ip.shape[-1])
            kernel_np = conv.weight.detach().squeeze().numpy()
            bias_np = conv.bias.detach().squeeze().numpy()

            # Only elapsed time is recorded; the computed results are
            # discarded.
            start = perf_counter()
            conv_2d(kernel_np, bias_np, ip_np)
            naive_time[run] = perf_counter() - start

            start = perf_counter()
            np.dot(kernel_np.flatten(),
                   naive_im2col(ip_np, kernel_np)) + bias_np
            im2col_time[run] = perf_counter() - start

            start = perf_counter()
            conv(ip)
            torch_time[run] = perf_counter() - start

            start = perf_counter()
            np.dot(kernel_np.flatten(),
                   memory_strided_im2col(ip_np, kernel_np)) + bias_np
            strided_time[run] = perf_counter() - start

        naive_time_log.append(naive_time.mean())
        torch_time_log.append(torch_time.mean())
        strided_time_log.append(strided_time.mean())
        naive_im2col_time_log.append(im2col_time.mean())

    plt.plot(input_size_list, naive_time_log,
             label='Naive Conv 2D', color='red')
    plt.plot(input_size_list, torch_time_log,
             label='PyTorch Conv 2D', color='blue')
    plt.plot(input_size_list, naive_im2col_time_log,
             label='Im2Col Conv 2D', color='green')
    plt.plot(input_size_list, strided_time_log,
             label='Mem Strided Im2Col Conv 2D', color='purple')
    plt.xlabel('Size of Input (n x n)')
    plt.ylabel('Execution Time (secs) - Log Scale')
    plt.yscale("log")
    plt.legend()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
When the input shape is larger than 3, the output of your convolution is not the same as the output of the PyTorch convolution. Try printing lines 77 and 73: they should be the same, but they are not. I know the output of the NumPy convolution is a vector and the output of the PyTorch convolution is a tensor, but even if you reshape the NumPy convolution result to (1, 1, input_size, input_size), the numbers in the matrix are not the same. The naive_im2col version, however, works fine.